MemoryOS 0.2.2-py3-none-any.whl → 1.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of MemoryOS might be problematic.

Files changed (82)
  1. {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/METADATA +7 -1
  2. {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/RECORD +81 -66
  3. memos/__init__.py +1 -1
  4. memos/api/config.py +31 -8
  5. memos/api/context/context.py +1 -1
  6. memos/api/context/context_thread.py +96 -0
  7. memos/api/middleware/request_context.py +94 -0
  8. memos/api/product_api.py +5 -1
  9. memos/api/product_models.py +16 -0
  10. memos/api/routers/product_router.py +39 -3
  11. memos/api/start_api.py +3 -0
  12. memos/configs/internet_retriever.py +13 -0
  13. memos/configs/mem_scheduler.py +38 -16
  14. memos/configs/memory.py +13 -0
  15. memos/configs/reranker.py +18 -0
  16. memos/graph_dbs/base.py +33 -4
  17. memos/graph_dbs/nebular.py +631 -236
  18. memos/graph_dbs/neo4j.py +18 -7
  19. memos/graph_dbs/neo4j_community.py +6 -3
  20. memos/llms/vllm.py +2 -0
  21. memos/log.py +125 -8
  22. memos/mem_os/core.py +49 -11
  23. memos/mem_os/main.py +1 -1
  24. memos/mem_os/product.py +392 -215
  25. memos/mem_os/utils/default_config.py +1 -1
  26. memos/mem_os/utils/format_utils.py +11 -47
  27. memos/mem_os/utils/reference_utils.py +153 -0
  28. memos/mem_reader/simple_struct.py +112 -43
  29. memos/mem_scheduler/base_scheduler.py +58 -55
  30. memos/mem_scheduler/{modules → general_modules}/base.py +1 -2
  31. memos/mem_scheduler/{modules → general_modules}/dispatcher.py +54 -15
  32. memos/mem_scheduler/{modules → general_modules}/rabbitmq_service.py +4 -4
  33. memos/mem_scheduler/{modules → general_modules}/redis_service.py +1 -1
  34. memos/mem_scheduler/{modules → general_modules}/retriever.py +19 -5
  35. memos/mem_scheduler/{modules → general_modules}/scheduler_logger.py +10 -4
  36. memos/mem_scheduler/general_scheduler.py +110 -67
  37. memos/mem_scheduler/monitors/__init__.py +0 -0
  38. memos/mem_scheduler/monitors/dispatcher_monitor.py +305 -0
  39. memos/mem_scheduler/{modules/monitor.py → monitors/general_monitor.py} +57 -19
  40. memos/mem_scheduler/mos_for_test_scheduler.py +7 -1
  41. memos/mem_scheduler/schemas/general_schemas.py +3 -2
  42. memos/mem_scheduler/schemas/message_schemas.py +2 -1
  43. memos/mem_scheduler/schemas/monitor_schemas.py +10 -2
  44. memos/mem_scheduler/utils/misc_utils.py +43 -2
  45. memos/mem_user/mysql_user_manager.py +4 -2
  46. memos/memories/activation/item.py +1 -1
  47. memos/memories/activation/kv.py +20 -8
  48. memos/memories/textual/base.py +1 -1
  49. memos/memories/textual/general.py +1 -1
  50. memos/memories/textual/item.py +1 -1
  51. memos/memories/textual/tree.py +31 -1
  52. memos/memories/textual/tree_text_memory/organize/{conflict.py → handler.py} +30 -48
  53. memos/memories/textual/tree_text_memory/organize/manager.py +8 -96
  54. memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +2 -0
  55. memos/memories/textual/tree_text_memory/organize/reorganizer.py +102 -140
  56. memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +231 -0
  57. memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +9 -0
  58. memos/memories/textual/tree_text_memory/retrieve/recall.py +67 -10
  59. memos/memories/textual/tree_text_memory/retrieve/reranker.py +1 -1
  60. memos/memories/textual/tree_text_memory/retrieve/searcher.py +246 -134
  61. memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +7 -2
  62. memos/memories/textual/tree_text_memory/retrieve/utils.py +7 -5
  63. memos/memos_tools/lockfree_dict.py +120 -0
  64. memos/memos_tools/notification_utils.py +46 -0
  65. memos/memos_tools/thread_safe_dict.py +288 -0
  66. memos/reranker/__init__.py +4 -0
  67. memos/reranker/base.py +24 -0
  68. memos/reranker/cosine_local.py +95 -0
  69. memos/reranker/factory.py +43 -0
  70. memos/reranker/http_bge.py +99 -0
  71. memos/reranker/noop.py +16 -0
  72. memos/templates/mem_reader_prompts.py +290 -39
  73. memos/templates/mem_scheduler_prompts.py +23 -10
  74. memos/templates/mos_prompts.py +133 -31
  75. memos/templates/tree_reorganize_prompts.py +24 -17
  76. memos/utils.py +19 -0
  77. memos/memories/textual/tree_text_memory/organize/redundancy.py +0 -193
  78. {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/LICENSE +0 -0
  79. {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/WHEEL +0 -0
  80. {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/entry_points.txt +0 -0
  81. /memos/mem_scheduler/{modules → general_modules}/__init__.py +0 -0
  82. /memos/mem_scheduler/{modules → general_modules}/misc.py +0 -0
@@ -39,8 +39,8 @@ class MemoryManager:
         if not memory_size:
             self.memory_size = {
                 "WorkingMemory": 20,
-                "LongTermMemory": 10000,
-                "UserMemory": 10000,
+                "LongTermMemory": 1500,
+                "UserMemory": 480,
             }
         self._threshold = threshold
         self.is_reorganize = is_reorganize
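
The hunk above only shrinks the fallback capacities used when no memory_size is passed in. A minimal sketch of that fallback logic (the helper name is illustrative, not part of the package's API); callers that want the old 0.2.2 limits can keep passing them explicitly:

def resolve_memory_size(memory_size: dict[str, int] | None) -> dict[str, int]:
    # Mirrors the default branch in the hunk above: no explicit sizes means
    # the new, much smaller 1.0.1 capacities.
    if not memory_size:
        return {"WorkingMemory": 20, "LongTermMemory": 1500, "UserMemory": 480}
    return memory_size

print(resolve_memory_size(None))
print(resolve_memory_size({"WorkingMemory": 20, "LongTermMemory": 10000, "UserMemory": 10000}))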
@@ -158,106 +158,18 @@ class MemoryManager:
         - topic_summary_prefix: summary node id prefix if applicable
         - enable_summary_link: whether to auto-link to a summary node
         """
-        embedding = memory.metadata.embedding
-
-        # Step 1: Find similar nodes for possible merging
-        similar_nodes = self.graph_store.search_by_embedding(
-            vector=embedding,
-            top_k=3,
-            scope=memory_type,
-            threshold=self._threshold,
-            status="activated",
-        )
-
-        if similar_nodes and similar_nodes[0]["score"] > self._merged_threshold:
-            return self._merge(memory, similar_nodes)
-        else:
-            node_id = str(uuid.uuid4())
-            # Step 2: Add new node to graph
-            self.graph_store.add_node(
-                node_id, memory.memory, memory.metadata.model_dump(exclude_none=True)
-            )
-            self.reorganizer.add_message(
-                QueueMessage(
-                    op="add",
-                    after_node=[node_id],
-                )
-            )
-            return node_id
-
-    def _merge(self, source_node: TextualMemoryItem, similar_nodes: list[dict]) -> str:
-        """
-        TODO: Add node traceability support by optionally preserving source nodes and linking them with MERGED_FROM edges.
-
-        Merge the source memory into the most similar existing node (only one),
-        and establish a MERGED_FROM edge in the graph.
-
-        Parameters:
-            source_node: The new memory item (not yet in the graph)
-            similar_nodes: A list of dicts returned by search_by_embedding(), ordered by similarity
-        """
-        original_node = similar_nodes[0]
-        original_id = original_node["id"]
-        original_data = self.graph_store.get_node(original_id)
-
-        target_text = original_data.get("memory", "")
-        merged_text = f"{target_text}\n⟵MERGED⟶\n{source_node.memory}"
-
-        original_meta = TreeNodeTextualMemoryMetadata(**original_data["metadata"])
-        source_meta = source_node.metadata
-
-        merged_key = source_meta.key or original_meta.key
-        merged_tags = list(set((original_meta.tags or []) + (source_meta.tags or [])))
-        merged_sources = list(set((original_meta.sources or []) + (source_meta.sources or [])))
-        merged_background = f"{original_meta.background}\n⟵MERGED⟶\n{source_meta.background}"
-        merged_embedding = self.embedder.embed([merged_text])[0]
-
-        original_conf = original_meta.confidence or 0.0
-        source_conf = source_meta.confidence or 0.0
-        merged_confidence = float((original_conf + source_conf) / 2)
-        merged_usage = list(set((original_meta.usage or []) + (source_meta.usage or [])))
-
-        # Create new merged node
-        merged_id = str(uuid.uuid4())
-        merged_metadata = source_meta.model_copy(
-            update={
-                "embedding": merged_embedding,
-                "updated_at": datetime.now().isoformat(),
-                "key": merged_key,
-                "tags": merged_tags,
-                "sources": merged_sources,
-                "background": merged_background,
-                "confidence": merged_confidence,
-                "usage": merged_usage,
-            }
-        )
-
+        node_id = str(uuid.uuid4())
+        # Step 2: Add new node to graph
         self.graph_store.add_node(
-            merged_id, merged_text, merged_metadata.model_dump(exclude_none=True)
+            node_id, memory.memory, memory.metadata.model_dump(exclude_none=True)
         )
-
-        # Add traceability edges: both original and new point to merged node
-        self.graph_store.add_edge(original_id, merged_id, type="MERGED_TO")
-        self.graph_store.update_node(original_id, {"status": "archived"})
-        source_id = str(uuid.uuid4())
-        source_metadata = source_node.metadata.model_copy(update={"status": "archived"})
-        self.graph_store.add_node(source_id, source_node.memory, source_metadata.model_dump())
-        self.graph_store.add_edge(source_id, merged_id, type="MERGED_TO")
-        # After creating merged node and tracing lineage
-        self._inherit_edges(original_id, merged_id)
-
-        # log to reorganizer before updating the graph
         self.reorganizer.add_message(
             QueueMessage(
-                op="merge",
-                before_node=[
-                    original_id,
-                    source_node.id,
-                ],
-                after_node=[merged_id],
+                op="add",
+                after_node=[node_id],
             )
         )
-        return merged_id
+        return node_id

     def _inherit_edges(self, from_id: str, to_id: str) -> None:
         """
@@ -73,10 +73,12 @@ class RelationAndReasoningDetector:
                results["sequence_links"].extend(seq)
            """

+            """
            # 4) Aggregate
            agg = self._detect_aggregate_node_for_group(node, nearest, min_group_size=5)
            if agg:
                results["aggregate_nodes"].append(agg)
+            """

        except Exception as e:
            logger.error(
@@ -3,7 +3,7 @@ import threading
 import time
 import traceback

-from collections import Counter, defaultdict
+from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from queue import PriorityQueue
 from typing import Literal
@@ -17,8 +17,7 @@ from memos.graph_dbs.neo4j import Neo4jGraphDB
 from memos.llms.base import BaseLLM
 from memos.log import get_logger
 from memos.memories.textual.item import TreeNodeTextualMemoryMetadata
-from memos.memories.textual.tree_text_memory.organize.conflict import ConflictHandler
-from memos.memories.textual.tree_text_memory.organize.redundancy import RedundancyHandler
+from memos.memories.textual.tree_text_memory.organize.handler import NodeHandler
 from memos.memories.textual.tree_text_memory.organize.relation_reason_detector import (
     RelationAndReasoningDetector,
 )
@@ -63,10 +62,10 @@ class GraphStructureReorganizer:
         self.relation_detector = RelationAndReasoningDetector(
             self.graph_store, self.llm, self.embedder
         )
-        self.conflict = ConflictHandler(graph_store=graph_store, llm=llm, embedder=embedder)
-        self.redundancy = RedundancyHandler(graph_store=graph_store, llm=llm, embedder=embedder)
+        self.resolver = NodeHandler(graph_store=graph_store, llm=llm, embedder=embedder)

         self.is_reorganize = is_reorganize
+        self._reorganize_needed = True
         if self.is_reorganize:
             # ____ 1. For queue message driven thread ___________
             self.thread = threading.Thread(target=self._run_message_consumer_loop)
@@ -125,13 +124,17 @@ class GraphStructureReorganizer:
         """
         import schedule

-        schedule.every(600).seconds.do(self.optimize_structure, scope="LongTermMemory")
-        schedule.every(600).seconds.do(self.optimize_structure, scope="UserMemory")
+        schedule.every(100).seconds.do(self.optimize_structure, scope="LongTermMemory")
+        schedule.every(100).seconds.do(self.optimize_structure, scope="UserMemory")

         logger.info("Structure optimizer schedule started.")
         while not getattr(self, "_stop_scheduler", False):
-            schedule.run_pending()
-            time.sleep(1)
+            if self._reorganize_needed:
+                logger.info("[Reorganizer] Triggering optimize_structure due to new nodes.")
+                self.optimize_structure(scope="LongTermMemory")
+                self.optimize_structure(scope="UserMemory")
+                self._reorganize_needed = False
+            time.sleep(30)

     def stop(self):
         """
@@ -148,45 +151,31 @@ class GraphStructureReorganizer:
         logger.info("Structure optimizer stopped.")

     def handle_message(self, message: QueueMessage):
-        handle_map = {
-            "add": self.handle_add,
-            "remove": self.handle_remove,
-            "merge": self.handle_merge,
-        }
+        handle_map = {"add": self.handle_add, "remove": self.handle_remove}
         handle_map[message.op](message)
         logger.debug(f"message queue size: {self.queue.qsize()}")

     def handle_add(self, message: QueueMessage):
         logger.debug(f"Handling add operation: {str(message)[:500]}")
-        # ———————— 1. check for conflicts ————————
         added_node = message.after_node[0]
-        conflicts = self.conflict.detect(added_node, scope=added_node.metadata.memory_type)
-        if conflicts:
-            for added_node, existing_node in conflicts:
-                self.conflict.resolve(added_node, existing_node)
-                logger.info(f"Resolved conflict between {added_node.id} and {existing_node.id}.")
-
-        # ———————— 2. check for redundancy ————————
-        redundancies = self.redundancy.detect(added_node, scope=added_node.metadata.memory_type)
-        if redundancies:
-            for added_node, existing_node in redundancies:
-                self.redundancy.resolve_two_nodes(added_node, existing_node)
-                logger.info(f"Resolved redundancy between {added_node.id} and {existing_node.id}.")
+        detected_relationships = self.resolver.detect(
+            added_node, scope=added_node.metadata.memory_type
+        )
+        if detected_relationships:
+            for added_node, existing_node, relation in detected_relationships:
+                self.resolver.resolve(added_node, existing_node, relation)
+
+        self._reorganize_needed = True

     def handle_remove(self, message: QueueMessage):
         logger.debug(f"Handling remove operation: {str(message)[:50]}")

-    def handle_merge(self, message: QueueMessage):
-        after_node = message.after_node[0]
-        logger.debug(f"Handling merge operation: <{after_node.memory}>")
-        self.redundancy.resolve_one_node(after_node)
-
     def optimize_structure(
         self,
         scope: str = "LongTermMemory",
         local_tree_threshold: int = 10,
-        min_cluster_size: int = 3,
-        min_group_size: int = 5,
+        min_cluster_size: int = 4,
+        min_group_size: int = 20,
     ):
         """
         Periodically reorganize the graph:
@@ -253,7 +242,7 @@
             except Exception as e:
                 logger.warning(
                     f"[Reorganize] Cluster processing "
-                    f"failed: {e}, trace: {traceback.format_exc()}"
+                    f"failed: {e}, cluster_nodes: {cluster_nodes}, trace: {traceback.format_exc()}"
                 )
         logger.info("[GraphStructure Reorganize] Structure optimization finished.")

@@ -271,29 +260,23 @@
         if len(cluster_nodes) <= min_cluster_size:
             return

-        if len(cluster_nodes) <= local_tree_threshold:
-            # Small cluster ➜ single parent
-            parent_node = self._summarize_cluster(cluster_nodes, scope)
-            self._create_parent_node(parent_node)
-            self._link_cluster_nodes(parent_node, cluster_nodes)
-        else:
-            # Large cluster ➜ local sub-clustering
-            sub_clusters = self._local_subcluster(cluster_nodes)
-            sub_parents = []
-
-            for sub_nodes in sub_clusters:
-                if len(sub_nodes) < min_cluster_size:
-                    continue  # Skip tiny noise
-                sub_parent_node = self._summarize_cluster(sub_nodes, scope)
-                self._create_parent_node(sub_parent_node)
-                self._link_cluster_nodes(sub_parent_node, sub_nodes)
-                sub_parents.append(sub_parent_node)
-
-            if sub_parents:
-                cluster_parent_node = self._summarize_cluster(cluster_nodes, scope)
-                self._create_parent_node(cluster_parent_node)
-                for sub_parent in sub_parents:
-                    self.graph_store.add_edge(cluster_parent_node.id, sub_parent.id, "PARENT")
+        # Large cluster ➜ local sub-clustering
+        sub_clusters = self._local_subcluster(cluster_nodes)
+        sub_parents = []
+
+        for sub_nodes in sub_clusters:
+            if len(sub_nodes) < min_cluster_size:
+                continue  # Skip tiny noise
+            sub_parent_node = self._summarize_cluster(sub_nodes, scope)
+            self._create_parent_node(sub_parent_node)
+            self._link_cluster_nodes(sub_parent_node, sub_nodes)
+            sub_parents.append(sub_parent_node)
+
+        if sub_parents and len(sub_parents) >= min_cluster_size:
+            cluster_parent_node = self._summarize_cluster(cluster_nodes, scope)
+            self._create_parent_node(cluster_parent_node)
+            for sub_parent in sub_parents:
+                self.graph_store.add_edge(cluster_parent_node.id, sub_parent.id, "PARENT")

         logger.info("Adding relations/reasons")
         nodes_to_check = cluster_nodes
@@ -350,7 +333,9 @@

         logger.info("[Reorganizer] Cluster relation/reasoning done.")

-    def _local_subcluster(self, cluster_nodes: list[GraphDBNode]) -> list[list[GraphDBNode]]:
+    def _local_subcluster(
+        self, cluster_nodes: list[GraphDBNode], max_length: int = 8000
+    ) -> (list)[list[GraphDBNode]]:
         """
         Use LLM to split a large cluster into semantically coherent sub-clusters.
         """
@@ -364,7 +349,9 @@
             scene_lines.append(line)

         joined_scene = "\n".join(scene_lines)
-        prompt = LOCAL_SUBCLUSTER_PROMPT.replace("{joined_scene}", joined_scene)
+        if len(joined_scene) > max_length:
+            logger.warning(f"Sub-cluster too long: {joined_scene}")
+        prompt = LOCAL_SUBCLUSTER_PROMPT.replace("{joined_scene}", joined_scene[:max_length])

         messages = [{"role": "user", "content": prompt}]
         response_text = self.llm.generate(messages)
@@ -389,12 +376,12 @@
         install_command="pip install scikit-learn",
         install_link="https://scikit-learn.org/stable/install.html",
     )
-    def _partition(self, nodes, min_cluster_size: int = 3, max_cluster_size: int = 20):
+    def _partition(self, nodes, min_cluster_size: int = 10, max_cluster_size: int = 20):
         """
         Partition nodes by:
-        1) Frequent tags (top N & above threshold)
-        2) Remaining nodes by embedding clustering (MiniBatchKMeans)
-        3) Small clusters merged or assigned to 'Other'
+        - If total nodes <= max_cluster_size -> return all nodes in one cluster.
+        - If total nodes > max_cluster_size -> cluster by embeddings, recursively split.
+        - Only keep clusters with size > min_cluster_size.

         Args:
             nodes: List of GraphDBNode
@@ -405,105 +392,80 @@
         """
         from sklearn.cluster import MiniBatchKMeans

-        # 1) Count all tags
-        tag_counter = Counter()
-        for node in nodes:
-            for tag in node.metadata.tags:
-                tag_counter[tag] += 1
-
-        # Select frequent tags
-        top_n_tags = {tag for tag, count in tag_counter.most_common(50)}
-        threshold_tags = {tag for tag, count in tag_counter.items() if count >= 50}
-        frequent_tags = top_n_tags | threshold_tags
-
-        # Group nodes by tags
-        tag_groups = defaultdict(list)
-
-        for node in nodes:
-            for tag in node.metadata.tags:
-                if tag in frequent_tags:
-                    tag_groups[tag].append(node)
-                    break
-
-        filtered_tag_clusters = []
-        assigned_ids = set()
-        for tag, group in tag_groups.items():
-            if len(group) >= min_cluster_size:
-                # Split large groups into chunks of at most max_cluster_size
-                for i in range(0, len(group), max_cluster_size):
-                    sub_group = group[i : i + max_cluster_size]
-                    filtered_tag_clusters.append(sub_group)
-                    assigned_ids.update(n.id for n in sub_group)
-            else:
-                logger.info(f"... dropped tag {tag} due to low size ...")
-
-        logger.info(
-            f"[MixedPartition] Created {len(filtered_tag_clusters)} clusters from tags. "
-            f"Nodes grouped by tags: {len(assigned_ids)} / {len(nodes)}"
-        )
-
-        # Remaining nodes -> embedding clustering
-        remaining_nodes = [n for n in nodes if n.id not in assigned_ids]
-        logger.info(
-            f"[MixedPartition] Remaining nodes for embedding clustering: {len(remaining_nodes)}"
-        )
-
-        embedding_clusters = []
+        if len(nodes) <= max_cluster_size:
+            logger.info(
+                f"[KMeansPartition] Node count {len(nodes)} <= {max_cluster_size}, skipping KMeans."
+            )
+            return [nodes]

-        def recursive_clustering(nodes_list):
+        def recursive_clustering(nodes_list, depth=0):
             """Recursively split clusters until each is <= max_cluster_size."""
+            indent = " " * depth
+            logger.info(
+                f"{indent}[Recursive] Start clustering {len(nodes_list)} nodes at depth {depth}"
+            )
+
             if len(nodes_list) <= max_cluster_size:
+                logger.info(
+                    f"{indent}[Recursive] Node count <= {max_cluster_size}, stop splitting."
+                )
                 return [nodes_list]
-
             # Try kmeans with k = ceil(len(nodes) / max_cluster_size)
-            x = np.array([n.metadata.embedding for n in nodes_list if n.metadata.embedding])
-            if len(x) < 2:
+            x_nodes = [n for n in nodes_list if n.metadata.embedding]
+            x = np.array([n.metadata.embedding for n in x_nodes])
+
+            if len(x) < min_cluster_size:
+                logger.info(
+                    f"{indent}[Recursive] Too few embeddings ({len(x)}), skipping clustering."
+                )
                 return [nodes_list]

             k = min(len(x), (len(nodes_list) + max_cluster_size - 1) // max_cluster_size)
-            k = max(1, min(k, len(x)))
+            k = max(1, k)

             try:
+                logger.info(f"{indent}[Recursive] Clustering with k={k} on {len(x)} points.")
                 kmeans = MiniBatchKMeans(n_clusters=k, batch_size=256, random_state=42)
                 labels = kmeans.fit_predict(x)

                 label_groups = defaultdict(list)
-                for node, label in zip(nodes_list, labels, strict=False):
+                for node, label in zip(x_nodes, labels, strict=False):
                     label_groups[label].append(node)

+                # Map: label -> nodes with no embedding (fallback group)
+                no_embedding_nodes = [n for n in nodes_list if not n.metadata.embedding]
+                if no_embedding_nodes:
+                    logger.warning(
+                        f"{indent}[Recursive] {len(no_embedding_nodes)} nodes have no embedding. Added to largest cluster."
+                    )
+                    # Assign to largest cluster
+                    largest_label = max(label_groups.items(), key=lambda kv: len(kv[1]))[0]
+                    label_groups[largest_label].extend(no_embedding_nodes)
+
                 result = []
-                for sub_group in label_groups.values():
-                    result.extend(recursive_clustering(sub_group))
+                for label, sub_group in label_groups.items():
+                    logger.info(f"{indent} Cluster-{label}: {len(sub_group)} nodes")
+                    result.extend(recursive_clustering(sub_group, depth=depth + 1))
                 return result
+
             except Exception as e:
-                logger.warning(f"Clustering failed: {e}, falling back to single cluster.")
+                logger.warning(
+                    f"{indent}[Recursive] Clustering failed: {e}, fallback to one cluster."
+                )
                 return [nodes_list]

-        if remaining_nodes:
-            clusters = recursive_clustering(remaining_nodes)
-            embedding_clusters.extend(clusters)
-            logger.info(
-                f"[MixedPartition] Created {len(embedding_clusters)} clusters from embeddings."
-            )
-
-        # Merge all clusters
-        all_clusters = filtered_tag_clusters + embedding_clusters
+        raw_clusters = recursive_clustering(nodes)
+        filtered_clusters = [c for c in raw_clusters if len(c) > min_cluster_size]

-        # Handle small clusters (< min_cluster_size)
-        final_clusters = []
-        small_nodes = []
-        for group in all_clusters:
-            if len(group) < min_cluster_size:
-                small_nodes.extend(group)
-            else:
-                final_clusters.append(group)
+        logger.info(f"[KMeansPartition] Total clusters before filtering: {len(raw_clusters)}")
+        for i, cluster in enumerate(raw_clusters):
+            logger.info(f"[KMeansPartition] Cluster-{i}: {len(cluster)} nodes")

-        if small_nodes:
-            final_clusters.append(small_nodes)
-            logger.info(f"[MixedPartition] {len(small_nodes)} nodes assigned to 'Other' cluster.")
+        logger.info(
+            f"[KMeansPartition] Clusters after filtering (>{min_cluster_size}): {len(filtered_clusters)}"
+        )

-        logger.info(f"[MixedPartition] Total final clusters: {len(final_clusters)}")
-        return final_clusters
+        return filtered_clusters

     def _summarize_cluster(self, cluster_nodes: list[GraphDBNode], scope: str) -> GraphDBNode:
         """
@@ -600,7 +562,7 @@
         for i, node in enumerate(message.after_node or []):
             if not isinstance(node, str):
                 continue
-            raw_node = self.graph_store.get_node(node)
+            raw_node = self.graph_store.get_node(node, include_embedding=True)
             if raw_node is None:
                 logger.debug(f"Node with ID {node} not found in the graph store.")
                 message.after_node[i] = None