MemoryOS 1.0.0-py3-none-any.whl → 1.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (94)
  1. {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/METADATA +8 -2
  2. {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/RECORD +92 -69
  3. {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/WHEEL +1 -1
  4. memos/__init__.py +1 -1
  5. memos/api/client.py +109 -0
  6. memos/api/config.py +35 -8
  7. memos/api/context/dependencies.py +15 -66
  8. memos/api/middleware/request_context.py +63 -0
  9. memos/api/product_api.py +5 -2
  10. memos/api/product_models.py +107 -16
  11. memos/api/routers/product_router.py +62 -19
  12. memos/api/start_api.py +13 -0
  13. memos/configs/graph_db.py +4 -0
  14. memos/configs/mem_scheduler.py +38 -3
  15. memos/configs/memory.py +13 -0
  16. memos/configs/reranker.py +18 -0
  17. memos/context/context.py +255 -0
  18. memos/embedders/factory.py +2 -0
  19. memos/graph_dbs/base.py +4 -2
  20. memos/graph_dbs/nebular.py +368 -223
  21. memos/graph_dbs/neo4j.py +49 -13
  22. memos/graph_dbs/neo4j_community.py +13 -3
  23. memos/llms/factory.py +2 -0
  24. memos/llms/openai.py +74 -2
  25. memos/llms/vllm.py +2 -0
  26. memos/log.py +128 -4
  27. memos/mem_cube/general.py +3 -1
  28. memos/mem_os/core.py +89 -23
  29. memos/mem_os/main.py +3 -6
  30. memos/mem_os/product.py +418 -154
  31. memos/mem_os/utils/reference_utils.py +20 -0
  32. memos/mem_reader/factory.py +2 -0
  33. memos/mem_reader/simple_struct.py +204 -82
  34. memos/mem_scheduler/analyzer/__init__.py +0 -0
  35. memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +569 -0
  36. memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
  37. memos/mem_scheduler/base_scheduler.py +126 -56
  38. memos/mem_scheduler/general_modules/dispatcher.py +2 -2
  39. memos/mem_scheduler/general_modules/misc.py +99 -1
  40. memos/mem_scheduler/general_modules/scheduler_logger.py +17 -11
  41. memos/mem_scheduler/general_scheduler.py +40 -88
  42. memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
  43. memos/mem_scheduler/memory_manage_modules/memory_filter.py +308 -0
  44. memos/mem_scheduler/{general_modules → memory_manage_modules}/retriever.py +34 -7
  45. memos/mem_scheduler/monitors/dispatcher_monitor.py +9 -8
  46. memos/mem_scheduler/monitors/general_monitor.py +119 -39
  47. memos/mem_scheduler/optimized_scheduler.py +124 -0
  48. memos/mem_scheduler/orm_modules/__init__.py +0 -0
  49. memos/mem_scheduler/orm_modules/base_model.py +635 -0
  50. memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
  51. memos/mem_scheduler/scheduler_factory.py +2 -0
  52. memos/mem_scheduler/schemas/monitor_schemas.py +96 -29
  53. memos/mem_scheduler/utils/config_utils.py +100 -0
  54. memos/mem_scheduler/utils/db_utils.py +33 -0
  55. memos/mem_scheduler/utils/filter_utils.py +1 -1
  56. memos/mem_scheduler/webservice_modules/__init__.py +0 -0
  57. memos/mem_user/mysql_user_manager.py +4 -2
  58. memos/memories/activation/kv.py +2 -1
  59. memos/memories/textual/item.py +96 -17
  60. memos/memories/textual/naive.py +1 -1
  61. memos/memories/textual/tree.py +57 -3
  62. memos/memories/textual/tree_text_memory/organize/handler.py +4 -2
  63. memos/memories/textual/tree_text_memory/organize/manager.py +28 -14
  64. memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +1 -2
  65. memos/memories/textual/tree_text_memory/organize/reorganizer.py +75 -23
  66. memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +10 -6
  67. memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -2
  68. memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +2 -0
  69. memos/memories/textual/tree_text_memory/retrieve/recall.py +119 -21
  70. memos/memories/textual/tree_text_memory/retrieve/searcher.py +172 -44
  71. memos/memories/textual/tree_text_memory/retrieve/utils.py +6 -4
  72. memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +5 -4
  73. memos/memos_tools/notification_utils.py +46 -0
  74. memos/memos_tools/singleton.py +174 -0
  75. memos/memos_tools/thread_safe_dict.py +22 -0
  76. memos/memos_tools/thread_safe_dict_segment.py +382 -0
  77. memos/parsers/factory.py +2 -0
  78. memos/reranker/__init__.py +4 -0
  79. memos/reranker/base.py +24 -0
  80. memos/reranker/concat.py +59 -0
  81. memos/reranker/cosine_local.py +96 -0
  82. memos/reranker/factory.py +48 -0
  83. memos/reranker/http_bge.py +312 -0
  84. memos/reranker/noop.py +16 -0
  85. memos/templates/mem_reader_prompts.py +289 -40
  86. memos/templates/mem_scheduler_prompts.py +242 -0
  87. memos/templates/mos_prompts.py +133 -60
  88. memos/types.py +4 -1
  89. memos/api/context/context.py +0 -147
  90. memos/mem_scheduler/mos_for_test_scheduler.py +0 -146
  91. {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/entry_points.txt +0 -0
  92. {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info/licenses}/LICENSE +0 -0
  93. /memos/mem_scheduler/{general_modules → webservice_modules}/rabbitmq_service.py +0 -0
  94. /memos/mem_scheduler/{general_modules → webservice_modules}/redis_service.py +0 -0
memos/memories/textual/tree_text_memory/organize/reorganizer.py

@@ -4,19 +4,20 @@ import time
 import traceback
 
 from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
 from queue import PriorityQueue
 from typing import Literal
 
 import numpy as np
 
+from memos.context.context import ContextThreadPoolExecutor
 from memos.dependency import require_python_package
 from memos.embedders.factory import OllamaEmbedder
 from memos.graph_dbs.item import GraphDBEdge, GraphDBNode
 from memos.graph_dbs.neo4j import Neo4jGraphDB
 from memos.llms.base import BaseLLM
 from memos.log import get_logger
-from memos.memories.textual.item import TreeNodeTextualMemoryMetadata
+from memos.memories.textual.item import SourceMessage, TreeNodeTextualMemoryMetadata
 from memos.memories.textual.tree_text_memory.organize.handler import NodeHandler
 from memos.memories.textual.tree_text_memory.organize.relation_reason_detector import (
     RelationAndReasoningDetector,
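
Note: this release replaces every bare ThreadPoolExecutor with memos.context.context.ContextThreadPoolExecutor, from the new memos/context/context.py (+255 lines, not shown in this excerpt). A minimal sketch of what such an executor typically does, assuming it exists to propagate contextvars-based request state (e.g. trace IDs set by the new request_context middleware) into worker threads; the real implementation may differ:

import contextvars
from concurrent.futures import ThreadPoolExecutor


class ContextThreadPoolExecutor(ThreadPoolExecutor):
    """Illustrative sketch only: run each submitted task inside a copy of the
    caller's contextvars.Context so request-scoped values survive the thread
    hop. The real class lives in memos/context/context.py."""

    def submit(self, fn, /, *args, **kwargs):
        ctx = contextvars.copy_context()  # snapshot the submitting thread's context
        return super().submit(ctx.run, fn, *args, **kwargs)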
@@ -27,6 +28,22 @@ from memos.templates.tree_reorganize_prompts import LOCAL_SUBCLUSTER_PROMPT, REO
 logger = get_logger(__name__)
 
 
+def build_summary_parent_node(cluster_nodes):
+    normalized_sources = []
+    for n in cluster_nodes:
+        sm = SourceMessage(
+            type="chat",
+            role=None,
+            chat_time=None,
+            message_id=None,
+            content=n.memory,
+            # extra
+            node_id=n.id,
+        )
+        normalized_sources.append(sm)
+    return normalized_sources
+
+
 class QueueMessage:
     def __init__(
         self,
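
The build_summary_parent_node helper above wraps each cluster member in a SourceMessage. The SourceMessage definition (memos/memories/textual/item.py, +96 -17) is not part of this excerpt; judging only from its call sites in this diff it behaves roughly like the following sketch, which is an assumption, not the actual model:

from pydantic import BaseModel, ConfigDict


class SourceMessage(BaseModel):
    # Sketch inferred from call sites only: extra="allow" would explain the
    # ad-hoc fields seen in this diff (node_id for chat sources, url for web).
    model_config = ConfigDict(extra="allow")

    type: str | None = None        # "chat" for summarized nodes, "web" for search hits
    role: str | None = None
    chat_time: str | None = None
    message_id: str | None = None
    content: str | None = None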
@@ -51,6 +68,15 @@ class QueueMessage:
         return op_priority[self.op] < op_priority[other.op]
 
 
+def extract_first_to_last_brace(text: str):
+    start = text.find("{")
+    end = text.rfind("}")
+    if start == -1 or end == -1 or end < start:
+        return "", None
+    json_str = text[start : end + 1]
+    return json_str, json.loads(json_str)
+
+
 class GraphStructureReorganizer:
     def __init__(
         self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: OllamaEmbedder, is_reorganize: bool
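
extract_first_to_last_brace slices from the first "{" to the last "}" before parsing, which strips markdown fences and chatter around an LLM's JSON reply. A small usage sketch (json is already imported in this module); note that it returns ("", None) when no braces are present, but still raises json.JSONDecodeError when the braced slice is not valid JSON, which _parse_json_result later in this diff relies on catching:

# Fenced, chatty LLM output still yields clean JSON:
raw = 'Sure! Here is the result:\n```json\n{"clusters": [["n1", "n2"]]}\n```'
json_str, parsed = extract_first_to_last_brace(raw)
assert parsed == {"clusters": [["n1", "n2"]]}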
@@ -87,6 +113,7 @@ class GraphStructureReorganizer:
        1) queue is empty
        2) any running structure optimization is done
        """
+        deadline = time.time() + 600
        if not self.is_reorganize:
            return
 
@@ -96,6 +123,9 @@ class GraphStructureReorganizer:
 
        while any(self._is_optimizing.values()):
            logger.debug(f"Waiting for structure optimizer to finish... {self._is_optimizing}")
+            if time.time() > deadline:
+                logger.error(f"Wait timed out; flags={self._is_optimizing}")
+                break
            time.sleep(1)
        logger.debug("Structure optimizer is now idle.")
 
@@ -129,6 +159,9 @@ class GraphStructureReorganizer:
 
        logger.info("Structure optimizer schedule started.")
        while not getattr(self, "_stop_scheduler", False):
+            if any(self._is_optimizing.values()):
+                time.sleep(1)
+                continue
            if self._reorganize_needed:
                logger.info("[Reorganizer] Triggering optimize_structure due to new nodes.")
                self.optimize_structure(scope="LongTermMemory")
@@ -176,6 +209,7 @@ class GraphStructureReorganizer:
        local_tree_threshold: int = 10,
        min_cluster_size: int = 4,
        min_group_size: int = 20,
+        max_duration_sec: int = 600,
    ):
        """
        Periodically reorganize the graph:
@@ -183,8 +217,20 @@ class GraphStructureReorganizer:
        2. Summarize each cluster.
        3. Create parent nodes and build local PARENT trees.
        """
+        # --- Total time watch dog: check functions ---
+        start_ts = time.time()
+
+        def _check_deadline(where: str):
+            if time.time() - start_ts > max_duration_sec:
+                logger.error(
+                    f"[GraphStructureReorganize] {scope} surpass {max_duration_sec}s,time "
+                    f"over at {where}"
+                )
+                return True
+            return False
+
        if self._is_optimizing[scope]:
-            logger.info(f"Already optimizing for {scope}. Skipping.")
+            logger.info(f"[GraphStructureReorganize] Already optimizing for {scope}. Skipping.")
            return
 
        if self.graph_store.node_not_exist(scope):
@@ -198,32 +244,35 @@ class GraphStructureReorganizer:
        )
 
        logger.debug(
-            f"Num of scope in self.graph_store is {self.graph_store.get_memory_count(scope)}"
+            f"[GraphStructureReorganize] Num of scope in self.graph_store is"
+            f" {self.graph_store.get_memory_count(scope)}"
        )
        # Load candidate nodes
+        if _check_deadline("[GraphStructureReorganize] Before loading candidates"):
+            return
        raw_nodes = self.graph_store.get_structure_optimization_candidates(scope)
        nodes = [GraphDBNode(**n) for n in raw_nodes]
 
        if not nodes:
            logger.info("[GraphStructureReorganize] No nodes to optimize. Skipping.")
            return
-
        if len(nodes) < min_group_size:
            logger.info(
                f"[GraphStructureReorganize] Only {len(nodes)} candidate nodes found. Not enough to reorganize. Skipping."
            )
            return
 
-        logger.info(f"[GraphStructureReorganize] Loaded {len(nodes)} nodes.")
-
        # Step 2: Partition nodes
+        if _check_deadline("[GraphStructureReorganize] Before partition"):
+            return
        partitioned_groups = self._partition(nodes)
-
        logger.info(
            f"[GraphStructureReorganize] Partitioned into {len(partitioned_groups)} clusters."
        )
 
-        with ThreadPoolExecutor(max_workers=4) as executor:
+        if _check_deadline("[GraphStructureReorganize] Before submit partition task"):
+            return
+        with ContextThreadPoolExecutor(max_workers=4) as executor:
            futures = []
            for cluster_nodes in partitioned_groups:
                futures.append(
@@ -237,14 +286,17 @@ class GraphStructureReorganizer:
                )
 
            for f in as_completed(futures):
+                if _check_deadline("[GraphStructureReorganize] Waiting clusters..."):
+                    for x in futures:
+                        x.cancel()
+                    return
                try:
                    f.result()
                except Exception as e:
                    logger.warning(
-                        f"[Reorganize] Cluster processing "
-                        f"failed: {e}, cluster_nodes: {cluster_nodes}, trace: {traceback.format_exc()}"
+                        f"[GraphStructureReorganize] Cluster processing failed: {e}, trace: {traceback.format_exc()}"
                    )
-                logger.info("[GraphStructure Reorganize] Structure optimization finished.")
+            logger.info("[GraphStructure Reorganize] Structure optimization finished.")
 
        finally:
            self._is_optimizing[scope] = False
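
The _check_deadline watchdog threaded through optimize_structure above caps the whole reorganization pass. One caveat when reading the cancellation branch: Future.cancel() only prevents tasks that have not started; anything already running on a worker thread finishes anyway, and leaving the with block still waits for executor shutdown. A generic, standalone sketch of the same deadline-bounded fan-out pattern (not MemoryOS code):

import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def run_with_deadline(tasks, max_duration_sec=600, max_workers=4):
    """Fan tasks out, harvest results until a wall-clock deadline passes."""
    start_ts = time.time()
    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(fn, *args) for fn, args in tasks]
        for f in as_completed(futures):
            if time.time() - start_ts > max_duration_sec:
                for x in futures:
                    x.cancel()  # best effort: only not-yet-started tasks are cancelled
                break
            try:
                results.append(f.result())
            except Exception as exc:  # mirror the log-and-continue policy above
                print(f"task failed: {exc}")
    return results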
@@ -282,7 +334,7 @@ class GraphStructureReorganizer:
        nodes_to_check = cluster_nodes
        exclude_ids = [n.id for n in nodes_to_check]
 
-        with ThreadPoolExecutor(max_workers=4) as executor:
+        with ContextThreadPoolExecutor(max_workers=4) as executor:
            futures = []
            for node in nodes_to_check:
                futures.append(
@@ -294,7 +346,7 @@ class GraphStructureReorganizer:
                    )
                )
 
-            for f in as_completed(futures):
+            for f in as_completed(futures, timeout=300):
                results = f.result()
 
                # 1) Add pairwise relations
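
Unlike the watchdog above, the timeout=300 added here is enforced by as_completed itself: the iterator raises TimeoutError if 300 seconds elapse before every future has completed (it is not a per-future limit), and that exception propagates out of the loop. A minimal standalone illustration:

import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError, as_completed

with ThreadPoolExecutor(max_workers=2) as ex:
    futures = [ex.submit(time.sleep, 2) for _ in range(2)]
    try:
        for f in as_completed(futures, timeout=0.5):
            f.result()
    except TimeoutError:
        print("gave up waiting")  # raised by the as_completed iterator itself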
@@ -331,11 +383,11 @@ class GraphStructureReorganizer:
                for child_id in agg_node.metadata.sources:
                    self.graph_store.add_edge(agg_node.id, child_id, "AGGREGATE_TO")
 
-                logger.info("[Reorganizer] Cluster relation/reasoning done.")
+        logger.info("[Reorganizer] Cluster relation/reasoning done.")
 
    def _local_subcluster(
-        self, cluster_nodes: list[GraphDBNode], max_length: int = 8000
-    ) -> (list)[list[GraphDBNode]]:
+        self, cluster_nodes: list[GraphDBNode], max_length: int = 15000
+    ) -> list[list[GraphDBNode]]:
        """
        Use LLM to split a large cluster into semantically coherent sub-clusters.
        """
@@ -350,7 +402,7 @@ class GraphStructureReorganizer:
 
        joined_scene = "\n".join(scene_lines)
        if len(joined_scene) > max_length:
-            logger.warning(f"Sub-cluster too long: {joined_scene}")
+            logger.warning("Sub-cluster too long")
        prompt = LOCAL_SUBCLUSTER_PROMPT.replace("{joined_scene}", joined_scene[:max_length])
 
        messages = [{"role": "user", "content": prompt}]
@@ -499,17 +551,17 @@ class GraphStructureReorganizer:
        parent_node = GraphDBNode(
            memory=parent_value,
            metadata=TreeNodeTextualMemoryMetadata(
-                user_id="",  # TODO: summarized node: no user_id
-                session_id="",  # TODO: summarized node: no session_id
+                user_id=None,
+                session_id=None,
                memory_type=scope,
                status="activated",
                key=parent_key,
                tags=parent_tags,
                embedding=embedding,
                usage=[],
-                sources=[n.id for n in cluster_nodes],
+                sources=build_summary_parent_node(cluster_nodes),
                background=parent_background,
-                confidence=0.99,
+                confidence=0.66,
                type="topic",
            ),
        )
@@ -518,7 +570,7 @@ class GraphStructureReorganizer:
    def _parse_json_result(self, response_text):
        try:
            response_text = response_text.replace("```", "").replace("json", "")
-            response_json = json.loads(response_text)
+            response_json = extract_first_to_last_brace(response_text)[1]
            return response_json
        except json.JSONDecodeError as e:
            logger.warning(
memos/memories/textual/tree_text_memory/retrieve/bochasearch.py

@@ -2,15 +2,17 @@
 
 import json
 
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
 from datetime import datetime
+from typing import Any
 
 import requests
 
+from memos.context.context import ContextThreadPoolExecutor
 from memos.embedders.factory import OllamaEmbedder
 from memos.log import get_logger
 from memos.mem_reader.base import BaseMemReader
-from memos.memories.textual.item import TextualMemoryItem
+from memos.memories.textual.item import SourceMessage, TextualMemoryItem
 
 
 logger = get_logger(__name__)
@@ -177,7 +179,7 @@ class BochaAISearchRetriever:
        if not info:
            info = {"user_id": "", "session_id": ""}
 
-        with ThreadPoolExecutor(max_workers=8) as executor:
+        with ContextThreadPoolExecutor(max_workers=8) as executor:
            futures = [
                executor.submit(self._process_result, r, query, parsed_goal, info)
                for r in search_results
@@ -193,7 +195,7 @@ class BochaAISearchRetriever:
        return list(unique_memory_items.values())
 
    def _process_result(
-        self, result: dict, query: str, parsed_goal: str, info: None
+        self, result: dict, query: str, parsed_goal: str, info: dict[str, Any]
    ) -> list[TextualMemoryItem]:
        """Process one Bocha search result into TextualMemoryItem."""
        title = result.get("name", "")
@@ -218,12 +220,14 @@ class BochaAISearchRetriever:
        memory_items = []
        for read_item_i in read_items[0]:
            read_item_i.memory = (
-                f"Title: {title}\nNewsTime: {publish_time}\nSummary: {summary}\n"
+                f"[Outer internet view] Title: {title}\nNewsTime:"
+                f" {publish_time}\nSummary:"
+                f" {summary}\n"
                f"Content: {read_item_i.memory}"
            )
            read_item_i.metadata.source = "web"
            read_item_i.metadata.memory_type = "OuterMemory"
-            read_item_i.metadata.sources = [url] if url else []
+            read_item_i.metadata.sources = [SourceMessage(type="web", url=url)] if url else []
            read_item_i.metadata.visibility = "public"
            memory_items.append(read_item_i)
        return memory_items
memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py

@@ -7,7 +7,11 @@ from datetime import datetime
 import requests
 
 from memos.embedders.factory import OllamaEmbedder
-from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
+from memos.memories.textual.item import (
+    SourceMessage,
+    TextualMemoryItem,
+    TreeNodeTextualMemoryMetadata,
+)
 
 
 class GoogleCustomSearchAPI:
@@ -172,7 +176,7 @@ class InternetGoogleRetriever:
                visibility="public",
                memory_type="LongTermMemory",  # Internet search results as working memory
                key=title,
-                sources=[link] if link else [],
+                sources=[SourceMessage(type="web", url=link)] if link else [],
                embedding=self.embedder.embed([memory_content])[0],  # Can add embedding later
                created_at=datetime.now().isoformat(),
                usage=[],
memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py

@@ -10,6 +10,7 @@ from memos.memories.textual.tree_text_memory.retrieve.internet_retriever import
    InternetGoogleRetriever,
 )
 from memos.memories.textual.tree_text_memory.retrieve.xinyusearch import XinyuSearchRetriever
+from memos.memos_tools.singleton import singleton_factory
 
 
 class InternetRetrieverFactory:
@@ -23,6 +24,7 @@ class InternetRetrieverFactory:
    }
 
    @classmethod
+    @singleton_factory()
    def from_config(
        cls, config_factory: InternetRetrieverConfigFactory, embedder: BaseEmbedder
    ) -> InternetGoogleRetriever | None:
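
singleton_factory comes from the new memos/memos_tools/singleton.py (+174 lines), which this excerpt does not show. Its placement under @classmethod suggests it caches the constructed retriever so repeated from_config calls with the same configuration reuse one instance. A hypothetical sketch of such a decorator; the cache key and locking in the real module may differ:

import functools
import threading


def singleton_factory():
    def decorator(fn):
        cache: dict = {}
        lock = threading.Lock()

        @functools.wraps(fn)
        def wrapper(cls, config_factory, *args, **kwargs):
            key = repr(config_factory)  # assumption: config repr is a stable cache key
            with lock:
                if key not in cache:
                    cache[key] = fn(cls, config_factory, *args, **kwargs)
                return cache[key]

        return wrapper

    return decorator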
memos/memories/textual/tree_text_memory/retrieve/recall.py

@@ -1,11 +1,16 @@
 import concurrent.futures
 
+from memos.context.context import ContextThreadPoolExecutor
 from memos.embedders.factory import OllamaEmbedder
 from memos.graph_dbs.neo4j import Neo4jGraphDB
+from memos.log import get_logger
 from memos.memories.textual.item import TextualMemoryItem
 from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal
 
 
+logger = get_logger(__name__)
+
+
 class GraphMemoryRetriever:
    """
    Unified memory retriever that combines both graph-based and vector-based retrieval logic.
@@ -14,6 +19,8 @@ class GraphMemoryRetriever:
    def __init__(self, graph_store: Neo4jGraphDB, embedder: OllamaEmbedder):
        self.graph_store = graph_store
        self.embedder = embedder
+        self.max_workers = 10
+        self.filter_weight = 0.6
 
    def retrieve(
        self,
@@ -22,6 +29,7 @@ class GraphMemoryRetriever:
        top_k: int,
        memory_scope: str,
        query_embedding: list[list[float]] | None = None,
+        search_filter: dict | None = None,
    ) -> list[TextualMemoryItem]:
        """
        Perform hybrid memory retrieval:
@@ -35,7 +43,7 @@ class GraphMemoryRetriever:
            top_k (int): Number of candidates to return.
            memory_scope (str): One of ['working', 'long_term', 'user'].
            query_embedding(list of embedding): list of embedding of query
-
+            search_filter (dict, optional): Optional metadata filters for search results.
        Returns:
            list: Combined memory items.
        """
@@ -45,16 +53,20 @@ class GraphMemoryRetriever:
        if memory_scope == "WorkingMemory":
            # For working memory, retrieve all entries (no filtering)
            working_memories = self.graph_store.get_all_memory_items(
-                scope="WorkingMemory", include_embedding=True
+                scope="WorkingMemory", include_embedding=False
            )
            return [TextualMemoryItem.from_dict(record) for record in working_memories]
 
-        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        with ContextThreadPoolExecutor(max_workers=2) as executor:
            # Structured graph-based retrieval
            future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope)
            # Vector similarity search
            future_vector = executor.submit(
-                self._vector_recall, query_embedding, memory_scope, top_k
+                self._vector_recall,
+                query_embedding or [],
+                memory_scope,
+                top_k,
+                search_filter=search_filter,
            )
 
            graph_results = future_graph.result()
@@ -74,6 +86,51 @@ class GraphMemoryRetriever:
 
        return list(combined.values())
 
+    def retrieve_from_cube(
+        self,
+        top_k: int,
+        memory_scope: str,
+        query_embedding: list[list[float]] | None = None,
+        cube_name: str = "memos_cube01",
+    ) -> list[TextualMemoryItem]:
+        """
+        Perform hybrid memory retrieval:
+        - Run graph-based lookup from dispatch plan.
+        - Run vector similarity search from embedded query.
+        - Merge and return combined result set.
+
+        Args:
+            top_k (int): Number of candidates to return.
+            memory_scope (str): One of ['working', 'long_term', 'user'].
+            query_embedding(list of embedding): list of embedding of query
+            cube_name: specify cube_name
+
+        Returns:
+            list: Combined memory items.
+        """
+        if memory_scope not in ["WorkingMemory", "LongTermMemory", "UserMemory"]:
+            raise ValueError(f"Unsupported memory scope: {memory_scope}")
+
+        graph_results = self._vector_recall(
+            query_embedding, memory_scope, top_k, cube_name=cube_name
+        )
+
+        for result_i in graph_results:
+            result_i.metadata.memory_type = "OuterMemory"
+        # Merge and deduplicate by ID
+        combined = {item.id: item for item in graph_results}
+
+        graph_ids = {item.id for item in graph_results}
+        combined_ids = set(combined.keys())
+        lost_ids = graph_ids - combined_ids
+
+        if lost_ids:
+            print(
+                f"[DEBUG] The following nodes were in graph_results but missing in combined: {lost_ids}"
+            )
+
+        return list(combined.values())
+
    def _graph_recall(
        self, parsed_goal: ParsedTaskGoal, memory_scope: str
    ) -> list[TextualMemoryItem]:
@@ -108,7 +165,7 @@ class GraphMemoryRetriever:
            return []
 
        # Load nodes and post-filter
-        node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=True)
+        node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False)
 
        final_nodes = []
        for node in node_dicts:
@@ -134,31 +191,72 @@ class GraphMemoryRetriever:
        query_embedding: list[list[float]],
        memory_scope: str,
        top_k: int = 20,
-        max_num: int = 5,
+        max_num: int = 3,
+        cube_name: str | None = None,
+        search_filter: dict | None = None,
    ) -> list[TextualMemoryItem]:
        """
-        # TODO: tackle with post-filter and pre-filter(5.18+) better.
        Perform vector-based similarity retrieval using query embedding.
+        # TODO: tackle with post-filter and pre-filter(5.18+) better.
        """
-        all_matches = []
+        if not query_embedding:
+            return []
 
-        def search_single(vec):
+        def search_single(vec, filt=None):
            return (
-                self.graph_store.search_by_embedding(vector=vec, top_k=top_k, scope=memory_scope)
+                self.graph_store.search_by_embedding(
+                    vector=vec,
+                    top_k=top_k,
+                    scope=memory_scope,
+                    cube_name=cube_name,
+                    search_filter=filt,
+                )
                or []
            )
 
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            futures = [executor.submit(search_single, vec) for vec in query_embedding[:max_num]]
-            for future in concurrent.futures.as_completed(futures):
-                result = future.result()
-                all_matches.extend(result)
+        def search_path_a():
+            """Path A: search without filter"""
+            path_a_hits = []
+            with ContextThreadPoolExecutor() as executor:
+                futures = [
+                    executor.submit(search_single, vec, None) for vec in query_embedding[:max_num]
+                ]
+                for f in concurrent.futures.as_completed(futures):
+                    path_a_hits.extend(f.result() or [])
+            return path_a_hits
 
-        if not all_matches:
-            return []
+        def search_path_b():
+            """Path B: search with filter"""
+            if not search_filter:
+                return []
+            path_b_hits = []
+            with ContextThreadPoolExecutor() as executor:
+                futures = [
+                    executor.submit(search_single, vec, search_filter)
+                    for vec in query_embedding[:max_num]
+                ]
+                for f in concurrent.futures.as_completed(futures):
+                    path_b_hits.extend(f.result() or [])
+            return path_b_hits
+
+        # Execute both paths concurrently
+        all_hits = []
+        with ContextThreadPoolExecutor(max_workers=2) as executor:
+            path_a_future = executor.submit(search_path_a)
+            path_b_future = executor.submit(search_path_b)
 
-        # Step 3: Extract matched IDs and retrieve full nodes
-        unique_ids = set({r["id"] for r in all_matches})
-        node_dicts = self.graph_store.get_nodes(list(unique_ids), include_embedding=True)
+            all_hits.extend(path_a_future.result())
+            all_hits.extend(path_b_future.result())
 
-        return [TextualMemoryItem.from_dict(record) for record in node_dicts]
+        if not all_hits:
+            return []
+
+        # merge and deduplicate
+        unique_ids = {r["id"] for r in all_hits if r.get("id")}
+        node_dicts = (
+            self.graph_store.get_nodes(
+                list(unique_ids), include_embedding=False, cube_name=cube_name
+            )
+            or []
+        )
+        return [TextualMemoryItem.from_dict(n) for n in node_dicts]
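
The rewritten _vector_recall fans out two search paths: path A ignores search_filter for recall coverage, path B applies it, and the hits are pooled, deduplicated by node id, and only then hydrated via get_nodes. A toy illustration of that merge step (the hit dicts here are assumed to carry an "id" key, as the code above expects):

path_a_hits = [{"id": "n1"}, {"id": "n2"}]   # unfiltered search results
path_b_hits = [{"id": "n2"}, {"id": "n3"}]   # filter-constrained results

unique_ids = {r["id"] for r in path_a_hits + path_b_hits if r.get("id")}
assert unique_ids == {"n1", "n2", "n3"}      # "n2" deduplicated before node fetch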