MemoryOS 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of MemoryOS might be problematic. Click here for more details.

Files changed (82) hide show
  1. {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/METADATA +7 -2
  2. {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/RECORD +79 -65
  3. {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/WHEEL +1 -1
  4. memos/__init__.py +1 -1
  5. memos/api/client.py +109 -0
  6. memos/api/config.py +11 -9
  7. memos/api/context/dependencies.py +15 -55
  8. memos/api/middleware/request_context.py +9 -40
  9. memos/api/product_api.py +2 -3
  10. memos/api/product_models.py +91 -16
  11. memos/api/routers/product_router.py +23 -16
  12. memos/api/start_api.py +10 -0
  13. memos/configs/graph_db.py +4 -0
  14. memos/configs/mem_scheduler.py +38 -3
  15. memos/context/context.py +255 -0
  16. memos/embedders/factory.py +2 -0
  17. memos/graph_dbs/nebular.py +230 -232
  18. memos/graph_dbs/neo4j.py +35 -1
  19. memos/graph_dbs/neo4j_community.py +7 -0
  20. memos/llms/factory.py +2 -0
  21. memos/llms/openai.py +74 -2
  22. memos/log.py +27 -15
  23. memos/mem_cube/general.py +3 -1
  24. memos/mem_os/core.py +60 -22
  25. memos/mem_os/main.py +3 -6
  26. memos/mem_os/product.py +35 -11
  27. memos/mem_reader/factory.py +2 -0
  28. memos/mem_reader/simple_struct.py +127 -74
  29. memos/mem_scheduler/analyzer/__init__.py +0 -0
  30. memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +569 -0
  31. memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
  32. memos/mem_scheduler/base_scheduler.py +126 -56
  33. memos/mem_scheduler/general_modules/dispatcher.py +2 -2
  34. memos/mem_scheduler/general_modules/misc.py +99 -1
  35. memos/mem_scheduler/general_modules/scheduler_logger.py +17 -11
  36. memos/mem_scheduler/general_scheduler.py +40 -88
  37. memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
  38. memos/mem_scheduler/memory_manage_modules/memory_filter.py +308 -0
  39. memos/mem_scheduler/{general_modules → memory_manage_modules}/retriever.py +34 -7
  40. memos/mem_scheduler/monitors/dispatcher_monitor.py +9 -8
  41. memos/mem_scheduler/monitors/general_monitor.py +119 -39
  42. memos/mem_scheduler/optimized_scheduler.py +124 -0
  43. memos/mem_scheduler/orm_modules/__init__.py +0 -0
  44. memos/mem_scheduler/orm_modules/base_model.py +635 -0
  45. memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
  46. memos/mem_scheduler/scheduler_factory.py +2 -0
  47. memos/mem_scheduler/schemas/monitor_schemas.py +96 -29
  48. memos/mem_scheduler/utils/config_utils.py +100 -0
  49. memos/mem_scheduler/utils/db_utils.py +33 -0
  50. memos/mem_scheduler/utils/filter_utils.py +1 -1
  51. memos/mem_scheduler/webservice_modules/__init__.py +0 -0
  52. memos/memories/activation/kv.py +2 -1
  53. memos/memories/textual/item.py +95 -16
  54. memos/memories/textual/naive.py +1 -1
  55. memos/memories/textual/tree.py +27 -3
  56. memos/memories/textual/tree_text_memory/organize/handler.py +4 -2
  57. memos/memories/textual/tree_text_memory/organize/manager.py +28 -14
  58. memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +1 -2
  59. memos/memories/textual/tree_text_memory/organize/reorganizer.py +75 -23
  60. memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +7 -5
  61. memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -2
  62. memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +2 -0
  63. memos/memories/textual/tree_text_memory/retrieve/recall.py +70 -22
  64. memos/memories/textual/tree_text_memory/retrieve/searcher.py +101 -33
  65. memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +5 -4
  66. memos/memos_tools/singleton.py +174 -0
  67. memos/memos_tools/thread_safe_dict.py +22 -0
  68. memos/memos_tools/thread_safe_dict_segment.py +382 -0
  69. memos/parsers/factory.py +2 -0
  70. memos/reranker/concat.py +59 -0
  71. memos/reranker/cosine_local.py +1 -0
  72. memos/reranker/factory.py +5 -0
  73. memos/reranker/http_bge.py +225 -12
  74. memos/templates/mem_scheduler_prompts.py +242 -0
  75. memos/types.py +4 -1
  76. memos/api/context/context.py +0 -147
  77. memos/api/context/context_thread.py +0 -96
  78. memos/mem_scheduler/mos_for_test_scheduler.py +0 -146
  79. {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/entry_points.txt +0 -0
  80. {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info/licenses}/LICENSE +0 -0
  81. /memos/mem_scheduler/{general_modules → webservice_modules}/rabbitmq_service.py +0 -0
  82. /memos/mem_scheduler/{general_modules → webservice_modules}/redis_service.py +0 -0
@@ -1,13 +1,48 @@
1
1
  """Defines memory item types for textual memory."""
2
2
 
3
+ import json
3
4
  import uuid
4
5
 
5
6
  from datetime import datetime
6
- from typing import Literal
7
+ from typing import Any, Literal
7
8
 
8
9
  from pydantic import BaseModel, ConfigDict, Field, field_validator
9
10
 
10
11
 
12
+ ALLOWED_ROLES = {"user", "assistant", "system"}
13
+
14
+
15
+ class SourceMessage(BaseModel):
16
+ """
17
+ Purpose: **memory provenance / traceability**.
18
+
19
+ Capture the minimal, reproducible origin context of a memory item so it can be
20
+ audited, traced, rolled back, or de-duplicated later.
21
+
22
+ Fields & conventions:
23
+ - type: Source kind (e.g., "chat", "doc", "web", "file", "system", ...).
24
+ If not provided, upstream logic may infer it:
25
+ presence of `role` ⇒ "chat"; otherwise ⇒ "doc".
26
+ - role: Conversation role ("user" | "assistant" | "system") when the
27
+ source is a chat turn.
28
+ - content: Minimal reproducible snippet from the source. If omitted,
29
+ upstream may fall back to `doc_path` / `url` / `message_id`.
30
+ - chat_time / message_id / doc_path: Locators for precisely pointing back
31
+ to the original record (timestamp, message id, document path).
32
+ - Extra fields: Allowed (`model_config.extra="allow"`) to carry arbitrary
33
+ provenance attributes (e.g., url, page, offset, span, local_confidence).
34
+ """
35
+
36
+ type: str | None = "chat"
37
+ role: Literal["user", "assistant", "system"] | None = None
38
+ chat_time: str | None = None
39
+ message_id: str | None = None
40
+ content: str | None = None
41
+ doc_path: str | None = None
42
+
43
+ model_config = ConfigDict(extra="allow")
44
+
45
+
11
46
  class TextualMemoryMetadata(BaseModel):
12
47
  """Metadata for a memory item.
13
48
 
@@ -62,7 +97,7 @@ class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata):
62
97
  memory_type: Literal["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"] = Field(
63
98
  default="WorkingMemory", description="Memory lifecycle type."
64
99
  )
65
- sources: list[str] | None = Field(
100
+ sources: list[SourceMessage] | None = Field(
66
101
  default=None, description="Multiple origins of the memory (e.g., URLs, notes)."
67
102
  )
68
103
  embedding: list[float] | None = Field(
@@ -74,8 +109,8 @@ class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata):
74
109
  description="The timestamp of the first creation to the memory. Useful "
75
110
  "for tracking memory initialization. Format: ISO 8601.",
76
111
  )
77
- usage: list[str] | None = Field(
78
- default=[],
112
+ usage: list[str] = Field(
113
+ default_factory=list,
79
114
  description="Usage history of this node",
80
115
  )
81
116
  background: str | None = Field(
@@ -83,12 +118,40 @@ class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata):
83
118
  description="background of this node",
84
119
  )
85
120
 
86
- @field_validator("sources")
121
+ @field_validator("sources", mode="before")
87
122
  @classmethod
88
- def validate_sources(cls, v):
89
- if v is not None and not isinstance(v, list):
90
- raise ValueError("Sources must be a list of strings.")
91
- return v
123
+ def coerce_sources(cls, v):
124
+ if v is None:
125
+ return v
126
+ if not isinstance(v, list):
127
+ raise TypeError("sources must be a list")
128
+ out = []
129
+ for item in v:
130
+ if isinstance(item, SourceMessage):
131
+ out.append(item)
132
+
133
+ elif isinstance(item, dict):
134
+ d = dict(item)
135
+ if d.get("type") is None:
136
+ d["type"] = "chat" if d.get("role") in ALLOWED_ROLES else "doc"
137
+ out.append(SourceMessage(**d))
138
+
139
+ elif isinstance(item, str):
140
+ try:
141
+ parsed = json.loads(item)
142
+ except Exception:
143
+ parsed = None
144
+
145
+ if isinstance(parsed, dict):
146
+ if parsed.get("type") is None:
147
+ parsed["type"] = "chat" if parsed.get("role") in ALLOWED_ROLES else "doc"
148
+ out.append(SourceMessage(**parsed))
149
+ else:
150
+ out.append(SourceMessage(type="doc", content=item))
151
+
152
+ else:
153
+ out.append(SourceMessage(type="doc", content=str(item)))
154
+ return out
92
155
 
93
156
  def __str__(self) -> str:
94
157
  """Pretty string representation of the metadata."""
@@ -114,19 +177,17 @@ class TextualMemoryItem(BaseModel):
114
177
  id: str = Field(default_factory=lambda: str(uuid.uuid4()))
115
178
  memory: str
116
179
  metadata: (
117
- TextualMemoryMetadata
180
+ SearchedTreeNodeTextualMemoryMetadata
118
181
  | TreeNodeTextualMemoryMetadata
119
- | SearchedTreeNodeTextualMemoryMetadata
182
+ | TextualMemoryMetadata
120
183
  ) = Field(default_factory=TextualMemoryMetadata)
121
184
 
122
185
  model_config = ConfigDict(extra="forbid")
123
186
 
187
+ @field_validator("id")
124
188
  @classmethod
125
- def validate_id(cls, v):
126
- try:
127
- uuid.UUID(v)
128
- except ValueError as e:
129
- raise ValueError("Invalid UUID format") from e
189
+ def _validate_id(cls, v: str) -> str:
190
+ uuid.UUID(v)
130
191
  return v
131
192
 
132
193
  @classmethod
@@ -136,6 +197,24 @@ class TextualMemoryItem(BaseModel):
136
197
  def to_dict(self) -> dict:
137
198
  return self.model_dump(exclude_none=True)
138
199
 
200
+ @field_validator("metadata", mode="before")
201
+ @classmethod
202
+ def _coerce_metadata(cls, v: Any):
203
+ if isinstance(
204
+ v,
205
+ SearchedTreeNodeTextualMemoryMetadata
206
+ | TreeNodeTextualMemoryMetadata
207
+ | TextualMemoryMetadata,
208
+ ):
209
+ return v
210
+ if isinstance(v, dict):
211
+ if v.get("relativity") is not None:
212
+ return SearchedTreeNodeTextualMemoryMetadata(**v)
213
+ if any(k in v for k in ("sources", "memory_type", "embedding", "background", "usage")):
214
+ return TreeNodeTextualMemoryMetadata(**v)
215
+ return TextualMemoryMetadata(**v)
216
+ return v
217
+
139
218
  def __str__(self) -> str:
140
219
  """Pretty string representation of the memory item."""
141
220
  return f"<ID: {self.id} | Memory: {self.memory} | Metadata: {self.metadata!s}>"
@@ -115,7 +115,7 @@ class NaiveTextMemory(BaseTextMemory):
115
115
  self.memories[i] = memory_dict
116
116
  break
117
117
 
118
- def search(self, query: str, top_k: int) -> list[TextualMemoryItem]:
118
+ def search(self, query: str, top_k: int, **kwargs) -> list[TextualMemoryItem]:
119
119
  """Search for memories based on a query."""
120
120
  sims = [
121
121
  (memory, len(set(query.split()) & set(memory["memory"].split())))
@@ -2,6 +2,7 @@ import json
2
2
  import os
3
3
  import shutil
4
4
  import tempfile
5
+ import time
5
6
 
6
7
  from datetime import datetime
7
8
  from pathlib import Path
@@ -32,15 +33,28 @@ class TreeTextMemory(BaseTextMemory):
32
33
 
33
34
  def __init__(self, config: TreeTextMemoryConfig):
34
35
  """Initialize memory with the given configuration."""
36
+ time_start = time.time()
35
37
  self.config: TreeTextMemoryConfig = config
36
38
  self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config(
37
39
  config.extractor_llm
38
40
  )
41
+ logger.info(f"time init: extractor_llm time is: {time.time() - time_start}")
42
+
43
+ time_start_ex = time.time()
39
44
  self.dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config(
40
45
  config.dispatcher_llm
41
46
  )
47
+ logger.info(f"time init: dispatcher_llm time is: {time.time() - time_start_ex}")
48
+
49
+ time_start_em = time.time()
42
50
  self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder)
51
+ logger.info(f"time init: embedder time is: {time.time() - time_start_em}")
52
+
53
+ time_start_gs = time.time()
43
54
  self.graph_store: Neo4jGraphDB = GraphStoreFactory.from_config(config.graph_db)
55
+ logger.info(f"time init: graph_store time is: {time.time() - time_start_gs}")
56
+
57
+ time_start_rr = time.time()
44
58
  if config.reranker is None:
45
59
  default_cfg = RerankerConfigFactory.model_validate(
46
60
  {
@@ -54,9 +68,10 @@ class TreeTextMemory(BaseTextMemory):
54
68
  self.reranker = RerankerFactory.from_config(default_cfg)
55
69
  else:
56
70
  self.reranker = RerankerFactory.from_config(config.reranker)
57
-
71
+ logger.info(f"time init: reranker time is: {time.time() - time_start_rr}")
58
72
  self.is_reorganize = config.reorganize
59
73
 
74
+ time_start_mm = time.time()
60
75
  self.memory_manager: MemoryManager = MemoryManager(
61
76
  self.graph_store,
62
77
  self.embedder,
@@ -69,7 +84,8 @@ class TreeTextMemory(BaseTextMemory):
69
84
  },
70
85
  is_reorganize=self.is_reorganize,
71
86
  )
72
-
87
+ logger.info(f"time init: memory_manager time is: {time.time() - time_start_mm}")
88
+ time_start_ir = time.time()
73
89
  # Create internet retriever if configured
74
90
  self.internet_retriever = None
75
91
  if config.internet_retriever is not None:
@@ -81,6 +97,7 @@ class TreeTextMemory(BaseTextMemory):
81
97
  )
82
98
  else:
83
99
  logger.info("No internet retriever configured")
100
+ logger.info(f"time init: internet_retriever time is: {time.time() - time_start_ir}")
84
101
 
85
102
  def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]:
86
103
  """Add memories.
@@ -122,6 +139,7 @@ class TreeTextMemory(BaseTextMemory):
122
139
  memory_type: str = "All",
123
140
  manual_close_internet: bool = False,
124
141
  moscube: bool = False,
142
+ search_filter: dict | None = None,
125
143
  ) -> list[TextualMemoryItem]:
126
144
  """Search for memories based on a query.
127
145
  User query -> TaskGoalParser -> MemoryPathResolver ->
@@ -136,6 +154,12 @@ class TreeTextMemory(BaseTextMemory):
136
154
  memory_type (str): Type restriction for search.
137
155
  ['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory']
138
156
  manual_close_internet (bool): If True, the internet retriever will be closed by this search, it high priority than config.
157
+ moscube (bool): whether you use moscube to answer questions
158
+ search_filter (dict, optional): Optional metadata filters for search results.
159
+ - Keys correspond to memory metadata fields (e.g., "user_id", "session_id").
160
+ - Values are exact-match conditions.
161
+ Example: {"user_id": "123", "session_id": "abc"}
162
+ If None, no additional filtering is applied.
139
163
  Returns:
140
164
  list[TextualMemoryItem]: List of matching memories.
141
165
  """
@@ -160,7 +184,7 @@ class TreeTextMemory(BaseTextMemory):
160
184
  internet_retriever=self.internet_retriever,
161
185
  moscube=moscube,
162
186
  )
163
- return searcher.search(query, top_k, info, mode, memory_type)
187
+ return searcher.search(query, top_k, info, mode, memory_type, search_filter)
164
188
 
165
189
  def get_relevant_subgraph(
166
190
  self, query: str, top_k: int = 5, depth: int = 2, center_status: str = "activated"
@@ -1,5 +1,6 @@
1
1
  import json
2
2
  import re
3
+
3
4
  from datetime import datetime
4
5
 
5
6
  from dateutil import parser
@@ -14,6 +15,7 @@ from memos.templates.tree_reorganize_prompts import (
14
15
  MEMORY_RELATION_RESOLVER_PROMPT,
15
16
  )
16
17
 
18
+
17
19
  logger = get_logger(__name__)
18
20
 
19
21
 
@@ -50,12 +52,12 @@ class NodeHandler:
50
52
  ]
51
53
  result = self.llm.generate(prompt).strip()
52
54
  if result == "contradictory":
53
- logger.warning(
55
+ logger.info(
54
56
  f'detected "{memory.memory}" <==CONFLICT==> "{embedding_candidate.memory}"'
55
57
  )
56
58
  detected_relationships.append([memory, embedding_candidate, "contradictory"])
57
59
  elif result == "redundant":
58
- logger.warning(
60
+ logger.info(
59
61
  f'detected "{memory.memory}" <==REDUNDANT==> "{embedding_candidate.memory}"'
60
62
  )
61
63
  detected_relationships.append([memory, embedding_candidate, "redundant"])
@@ -1,8 +1,10 @@
1
+ import traceback
1
2
  import uuid
2
3
 
3
- from concurrent.futures import ThreadPoolExecutor, as_completed
4
+ from concurrent.futures import as_completed
4
5
  from datetime import datetime
5
6
 
7
+ from memos.context.context import ContextThreadPoolExecutor
6
8
  from memos.embedders.factory import OllamaEmbedder
7
9
  from memos.graph_dbs.neo4j import Neo4jGraphDB
8
10
  from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM
@@ -55,24 +57,35 @@ class MemoryManager:
55
57
  """
56
58
  added_ids: list[str] = []
57
59
 
58
- with ThreadPoolExecutor(max_workers=8) as executor:
60
+ with ContextThreadPoolExecutor(max_workers=8) as executor:
59
61
  futures = {executor.submit(self._process_memory, m): m for m in memories}
60
- for future in as_completed(futures):
62
+ for future in as_completed(futures, timeout=60):
61
63
  try:
62
64
  ids = future.result()
63
65
  added_ids.extend(ids)
64
66
  except Exception as e:
65
67
  logger.exception("Memory processing error: ", exc_info=e)
66
68
 
67
- self.graph_store.remove_oldest_memory(
68
- memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"]
69
- )
70
- self.graph_store.remove_oldest_memory(
71
- memory_type="LongTermMemory", keep_latest=self.memory_size["LongTermMemory"]
72
- )
73
- self.graph_store.remove_oldest_memory(
74
- memory_type="UserMemory", keep_latest=self.memory_size["UserMemory"]
75
- )
69
+ try:
70
+ self.graph_store.remove_oldest_memory(
71
+ memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"]
72
+ )
73
+ except Exception:
74
+ logger.warning(f"Remove WorkingMemory error: {traceback.format_exc()}")
75
+
76
+ try:
77
+ self.graph_store.remove_oldest_memory(
78
+ memory_type="LongTermMemory", keep_latest=self.memory_size["LongTermMemory"]
79
+ )
80
+ except Exception:
81
+ logger.warning(f"Remove LongTermMemory error: {traceback.format_exc()}")
82
+
83
+ try:
84
+ self.graph_store.remove_oldest_memory(
85
+ memory_type="UserMemory", keep_latest=self.memory_size["UserMemory"]
86
+ )
87
+ except Exception:
88
+ logger.warning(f"Remove UserMemory error: {traceback.format_exc()}")
76
89
 
77
90
  self._refresh_memory_size()
78
91
  return added_ids
@@ -82,12 +95,12 @@ class MemoryManager:
82
95
  Replace WorkingMemory
83
96
  """
84
97
  working_memory_top_k = memories[: self.memory_size["WorkingMemory"]]
85
- with ThreadPoolExecutor(max_workers=8) as executor:
98
+ with ContextThreadPoolExecutor(max_workers=8) as executor:
86
99
  futures = [
87
100
  executor.submit(self._add_memory_to_db, memory, "WorkingMemory")
88
101
  for memory in working_memory_top_k
89
102
  ]
90
- for future in as_completed(futures):
103
+ for future in as_completed(futures, timeout=60):
91
104
  try:
92
105
  future.result()
93
106
  except Exception as e:
@@ -102,6 +115,7 @@ class MemoryManager:
102
115
  """
103
116
  Return the cached memory type counts.
104
117
  """
118
+ self._refresh_memory_size()
105
119
  return self.current_memory_size
106
120
 
107
121
  def _refresh_memory_size(self) -> None:
@@ -46,7 +46,7 @@ class RelationAndReasoningDetector:
46
46
  "sequence_links": [],
47
47
  "aggregate_nodes": [],
48
48
  }
49
-
49
+ """
50
50
  nearest = self.graph_store.get_neighbors_by_tag(
51
51
  tags=node.metadata.tags,
52
52
  exclude_ids=exclude_ids,
@@ -55,7 +55,6 @@ class RelationAndReasoningDetector:
55
55
  )
56
56
  nearest = [GraphDBNode(**cand_data) for cand_data in nearest]
57
57
 
58
- """
59
58
  # 1) Pairwise relations (including CAUSE/CONDITION/CONFLICT)
60
59
  pairwise = self._detect_pairwise_causal_condition_relations(node, nearest)
61
60
  results["relations"].extend(pairwise["relations"])
@@ -4,19 +4,20 @@ import time
4
4
  import traceback
5
5
 
6
6
  from collections import defaultdict
7
- from concurrent.futures import ThreadPoolExecutor, as_completed
7
+ from concurrent.futures import as_completed
8
8
  from queue import PriorityQueue
9
9
  from typing import Literal
10
10
 
11
11
  import numpy as np
12
12
 
13
+ from memos.context.context import ContextThreadPoolExecutor
13
14
  from memos.dependency import require_python_package
14
15
  from memos.embedders.factory import OllamaEmbedder
15
16
  from memos.graph_dbs.item import GraphDBEdge, GraphDBNode
16
17
  from memos.graph_dbs.neo4j import Neo4jGraphDB
17
18
  from memos.llms.base import BaseLLM
18
19
  from memos.log import get_logger
19
- from memos.memories.textual.item import TreeNodeTextualMemoryMetadata
20
+ from memos.memories.textual.item import SourceMessage, TreeNodeTextualMemoryMetadata
20
21
  from memos.memories.textual.tree_text_memory.organize.handler import NodeHandler
21
22
  from memos.memories.textual.tree_text_memory.organize.relation_reason_detector import (
22
23
  RelationAndReasoningDetector,
@@ -27,6 +28,22 @@ from memos.templates.tree_reorganize_prompts import LOCAL_SUBCLUSTER_PROMPT, REO
27
28
  logger = get_logger(__name__)
28
29
 
29
30
 
31
+ def build_summary_parent_node(cluster_nodes):
32
+ normalized_sources = []
33
+ for n in cluster_nodes:
34
+ sm = SourceMessage(
35
+ type="chat",
36
+ role=None,
37
+ chat_time=None,
38
+ message_id=None,
39
+ content=n.memory,
40
+ # extra
41
+ node_id=n.id,
42
+ )
43
+ normalized_sources.append(sm)
44
+ return normalized_sources
45
+
46
+
30
47
  class QueueMessage:
31
48
  def __init__(
32
49
  self,
@@ -51,6 +68,15 @@ class QueueMessage:
51
68
  return op_priority[self.op] < op_priority[other.op]
52
69
 
53
70
 
71
+ def extract_first_to_last_brace(text: str):
72
+ start = text.find("{")
73
+ end = text.rfind("}")
74
+ if start == -1 or end == -1 or end < start:
75
+ return "", None
76
+ json_str = text[start : end + 1]
77
+ return json_str, json.loads(json_str)
78
+
79
+
54
80
  class GraphStructureReorganizer:
55
81
  def __init__(
56
82
  self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: OllamaEmbedder, is_reorganize: bool
@@ -87,6 +113,7 @@ class GraphStructureReorganizer:
87
113
  1) queue is empty
88
114
  2) any running structure optimization is done
89
115
  """
116
+ deadline = time.time() + 600
90
117
  if not self.is_reorganize:
91
118
  return
92
119
 
@@ -96,6 +123,9 @@ class GraphStructureReorganizer:
96
123
 
97
124
  while any(self._is_optimizing.values()):
98
125
  logger.debug(f"Waiting for structure optimizer to finish... {self._is_optimizing}")
126
+ if time.time() > deadline:
127
+ logger.error(f"Wait timed out; flags={self._is_optimizing}")
128
+ break
99
129
  time.sleep(1)
100
130
  logger.debug("Structure optimizer is now idle.")
101
131
 
@@ -129,6 +159,9 @@ class GraphStructureReorganizer:
129
159
 
130
160
  logger.info("Structure optimizer schedule started.")
131
161
  while not getattr(self, "_stop_scheduler", False):
162
+ if any(self._is_optimizing.values()):
163
+ time.sleep(1)
164
+ continue
132
165
  if self._reorganize_needed:
133
166
  logger.info("[Reorganizer] Triggering optimize_structure due to new nodes.")
134
167
  self.optimize_structure(scope="LongTermMemory")
@@ -176,6 +209,7 @@ class GraphStructureReorganizer:
176
209
  local_tree_threshold: int = 10,
177
210
  min_cluster_size: int = 4,
178
211
  min_group_size: int = 20,
212
+ max_duration_sec: int = 600,
179
213
  ):
180
214
  """
181
215
  Periodically reorganize the graph:
@@ -183,8 +217,20 @@ class GraphStructureReorganizer:
183
217
  2. Summarize each cluster.
184
218
  3. Create parent nodes and build local PARENT trees.
185
219
  """
220
+ # --- Total time watch dog: check functions ---
221
+ start_ts = time.time()
222
+
223
+ def _check_deadline(where: str):
224
+ if time.time() - start_ts > max_duration_sec:
225
+ logger.error(
226
+ f"[GraphStructureReorganize] {scope} surpass {max_duration_sec}s,time "
227
+ f"over at {where}"
228
+ )
229
+ return True
230
+ return False
231
+
186
232
  if self._is_optimizing[scope]:
187
- logger.info(f"Already optimizing for {scope}. Skipping.")
233
+ logger.info(f"[GraphStructureReorganize] Already optimizing for {scope}. Skipping.")
188
234
  return
189
235
 
190
236
  if self.graph_store.node_not_exist(scope):
@@ -198,32 +244,35 @@ class GraphStructureReorganizer:
198
244
  )
199
245
 
200
246
  logger.debug(
201
- f"Num of scope in self.graph_store is {self.graph_store.get_memory_count(scope)}"
247
+ f"[GraphStructureReorganize] Num of scope in self.graph_store is"
248
+ f" {self.graph_store.get_memory_count(scope)}"
202
249
  )
203
250
  # Load candidate nodes
251
+ if _check_deadline("[GraphStructureReorganize] Before loading candidates"):
252
+ return
204
253
  raw_nodes = self.graph_store.get_structure_optimization_candidates(scope)
205
254
  nodes = [GraphDBNode(**n) for n in raw_nodes]
206
255
 
207
256
  if not nodes:
208
257
  logger.info("[GraphStructureReorganize] No nodes to optimize. Skipping.")
209
258
  return
210
-
211
259
  if len(nodes) < min_group_size:
212
260
  logger.info(
213
261
  f"[GraphStructureReorganize] Only {len(nodes)} candidate nodes found. Not enough to reorganize. Skipping."
214
262
  )
215
263
  return
216
264
 
217
- logger.info(f"[GraphStructureReorganize] Loaded {len(nodes)} nodes.")
218
-
219
265
  # Step 2: Partition nodes
266
+ if _check_deadline("[GraphStructureReorganize] Before partition"):
267
+ return
220
268
  partitioned_groups = self._partition(nodes)
221
-
222
269
  logger.info(
223
270
  f"[GraphStructureReorganize] Partitioned into {len(partitioned_groups)} clusters."
224
271
  )
225
272
 
226
- with ThreadPoolExecutor(max_workers=4) as executor:
273
+ if _check_deadline("[GraphStructureReorganize] Before submit partition task"):
274
+ return
275
+ with ContextThreadPoolExecutor(max_workers=4) as executor:
227
276
  futures = []
228
277
  for cluster_nodes in partitioned_groups:
229
278
  futures.append(
@@ -237,14 +286,17 @@ class GraphStructureReorganizer:
237
286
  )
238
287
 
239
288
  for f in as_completed(futures):
289
+ if _check_deadline("[GraphStructureReorganize] Waiting clusters..."):
290
+ for x in futures:
291
+ x.cancel()
292
+ return
240
293
  try:
241
294
  f.result()
242
295
  except Exception as e:
243
296
  logger.warning(
244
- f"[Reorganize] Cluster processing "
245
- f"failed: {e}, cluster_nodes: {cluster_nodes}, trace: {traceback.format_exc()}"
297
+ f"[GraphStructureReorganize] Cluster processing failed: {e}, trace: {traceback.format_exc()}"
246
298
  )
247
- logger.info("[GraphStructure Reorganize] Structure optimization finished.")
299
+ logger.info("[GraphStructure Reorganize] Structure optimization finished.")
248
300
 
249
301
  finally:
250
302
  self._is_optimizing[scope] = False
@@ -282,7 +334,7 @@ class GraphStructureReorganizer:
282
334
  nodes_to_check = cluster_nodes
283
335
  exclude_ids = [n.id for n in nodes_to_check]
284
336
 
285
- with ThreadPoolExecutor(max_workers=4) as executor:
337
+ with ContextThreadPoolExecutor(max_workers=4) as executor:
286
338
  futures = []
287
339
  for node in nodes_to_check:
288
340
  futures.append(
@@ -294,7 +346,7 @@ class GraphStructureReorganizer:
294
346
  )
295
347
  )
296
348
 
297
- for f in as_completed(futures):
349
+ for f in as_completed(futures, timeout=300):
298
350
  results = f.result()
299
351
 
300
352
  # 1) Add pairwise relations
@@ -331,11 +383,11 @@ class GraphStructureReorganizer:
331
383
  for child_id in agg_node.metadata.sources:
332
384
  self.graph_store.add_edge(agg_node.id, child_id, "AGGREGATE_TO")
333
385
 
334
- logger.info("[Reorganizer] Cluster relation/reasoning done.")
386
+ logger.info("[Reorganizer] Cluster relation/reasoning done.")
335
387
 
336
388
  def _local_subcluster(
337
- self, cluster_nodes: list[GraphDBNode], max_length: int = 8000
338
- ) -> (list)[list[GraphDBNode]]:
389
+ self, cluster_nodes: list[GraphDBNode], max_length: int = 15000
390
+ ) -> list[list[GraphDBNode]]:
339
391
  """
340
392
  Use LLM to split a large cluster into semantically coherent sub-clusters.
341
393
  """
@@ -350,7 +402,7 @@ class GraphStructureReorganizer:
350
402
 
351
403
  joined_scene = "\n".join(scene_lines)
352
404
  if len(joined_scene) > max_length:
353
- logger.warning(f"Sub-cluster too long: {joined_scene}")
405
+ logger.warning("Sub-cluster too long")
354
406
  prompt = LOCAL_SUBCLUSTER_PROMPT.replace("{joined_scene}", joined_scene[:max_length])
355
407
 
356
408
  messages = [{"role": "user", "content": prompt}]
@@ -499,17 +551,17 @@ class GraphStructureReorganizer:
499
551
  parent_node = GraphDBNode(
500
552
  memory=parent_value,
501
553
  metadata=TreeNodeTextualMemoryMetadata(
502
- user_id="", # TODO: summarized node: no user_id
503
- session_id="", # TODO: summarized node: no session_id
554
+ user_id=None,
555
+ session_id=None,
504
556
  memory_type=scope,
505
557
  status="activated",
506
558
  key=parent_key,
507
559
  tags=parent_tags,
508
560
  embedding=embedding,
509
561
  usage=[],
510
- sources=[n.id for n in cluster_nodes],
562
+ sources=build_summary_parent_node(cluster_nodes),
511
563
  background=parent_background,
512
- confidence=0.99,
564
+ confidence=0.66,
513
565
  type="topic",
514
566
  ),
515
567
  )
@@ -518,7 +570,7 @@ class GraphStructureReorganizer:
518
570
  def _parse_json_result(self, response_text):
519
571
  try:
520
572
  response_text = response_text.replace("```", "").replace("json", "")
521
- response_json = json.loads(response_text)
573
+ response_json = extract_first_to_last_brace(response_text)[1]
522
574
  return response_json
523
575
  except json.JSONDecodeError as e:
524
576
  logger.warning(