MemoryOS 0.2.2__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic. Click here for more details.
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/METADATA +7 -1
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/RECORD +81 -66
- memos/__init__.py +1 -1
- memos/api/config.py +31 -8
- memos/api/context/context.py +1 -1
- memos/api/context/context_thread.py +96 -0
- memos/api/middleware/request_context.py +94 -0
- memos/api/product_api.py +5 -1
- memos/api/product_models.py +16 -0
- memos/api/routers/product_router.py +39 -3
- memos/api/start_api.py +3 -0
- memos/configs/internet_retriever.py +13 -0
- memos/configs/mem_scheduler.py +38 -16
- memos/configs/memory.py +13 -0
- memos/configs/reranker.py +18 -0
- memos/graph_dbs/base.py +33 -4
- memos/graph_dbs/nebular.py +631 -236
- memos/graph_dbs/neo4j.py +18 -7
- memos/graph_dbs/neo4j_community.py +6 -3
- memos/llms/vllm.py +2 -0
- memos/log.py +125 -8
- memos/mem_os/core.py +49 -11
- memos/mem_os/main.py +1 -1
- memos/mem_os/product.py +392 -215
- memos/mem_os/utils/default_config.py +1 -1
- memos/mem_os/utils/format_utils.py +11 -47
- memos/mem_os/utils/reference_utils.py +153 -0
- memos/mem_reader/simple_struct.py +112 -43
- memos/mem_scheduler/base_scheduler.py +58 -55
- memos/mem_scheduler/{modules → general_modules}/base.py +1 -2
- memos/mem_scheduler/{modules → general_modules}/dispatcher.py +54 -15
- memos/mem_scheduler/{modules → general_modules}/rabbitmq_service.py +4 -4
- memos/mem_scheduler/{modules → general_modules}/redis_service.py +1 -1
- memos/mem_scheduler/{modules → general_modules}/retriever.py +19 -5
- memos/mem_scheduler/{modules → general_modules}/scheduler_logger.py +10 -4
- memos/mem_scheduler/general_scheduler.py +110 -67
- memos/mem_scheduler/monitors/__init__.py +0 -0
- memos/mem_scheduler/monitors/dispatcher_monitor.py +305 -0
- memos/mem_scheduler/{modules/monitor.py → monitors/general_monitor.py} +57 -19
- memos/mem_scheduler/mos_for_test_scheduler.py +7 -1
- memos/mem_scheduler/schemas/general_schemas.py +3 -2
- memos/mem_scheduler/schemas/message_schemas.py +2 -1
- memos/mem_scheduler/schemas/monitor_schemas.py +10 -2
- memos/mem_scheduler/utils/misc_utils.py +43 -2
- memos/mem_user/mysql_user_manager.py +4 -2
- memos/memories/activation/item.py +1 -1
- memos/memories/activation/kv.py +20 -8
- memos/memories/textual/base.py +1 -1
- memos/memories/textual/general.py +1 -1
- memos/memories/textual/item.py +1 -1
- memos/memories/textual/tree.py +31 -1
- memos/memories/textual/tree_text_memory/organize/{conflict.py → handler.py} +30 -48
- memos/memories/textual/tree_text_memory/organize/manager.py +8 -96
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +2 -0
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +102 -140
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +231 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +9 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +67 -10
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +1 -1
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +246 -134
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +7 -2
- memos/memories/textual/tree_text_memory/retrieve/utils.py +7 -5
- memos/memos_tools/lockfree_dict.py +120 -0
- memos/memos_tools/notification_utils.py +46 -0
- memos/memos_tools/thread_safe_dict.py +288 -0
- memos/reranker/__init__.py +4 -0
- memos/reranker/base.py +24 -0
- memos/reranker/cosine_local.py +95 -0
- memos/reranker/factory.py +43 -0
- memos/reranker/http_bge.py +99 -0
- memos/reranker/noop.py +16 -0
- memos/templates/mem_reader_prompts.py +290 -39
- memos/templates/mem_scheduler_prompts.py +23 -10
- memos/templates/mos_prompts.py +133 -31
- memos/templates/tree_reorganize_prompts.py +24 -17
- memos/utils.py +19 -0
- memos/memories/textual/tree_text_memory/organize/redundancy.py +0 -193
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/LICENSE +0 -0
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/WHEEL +0 -0
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/entry_points.txt +0 -0
- /memos/mem_scheduler/{modules → general_modules}/__init__.py +0 -0
- /memos/mem_scheduler/{modules → general_modules}/misc.py +0 -0
|
@@ -39,8 +39,8 @@ class MemoryManager:
|
|
|
39
39
|
if not memory_size:
|
|
40
40
|
self.memory_size = {
|
|
41
41
|
"WorkingMemory": 20,
|
|
42
|
-
"LongTermMemory":
|
|
43
|
-
"UserMemory":
|
|
42
|
+
"LongTermMemory": 1500,
|
|
43
|
+
"UserMemory": 480,
|
|
44
44
|
}
|
|
45
45
|
self._threshold = threshold
|
|
46
46
|
self.is_reorganize = is_reorganize
|
|
@@ -158,106 +158,18 @@ class MemoryManager:
|
|
|
158
158
|
- topic_summary_prefix: summary node id prefix if applicable
|
|
159
159
|
- enable_summary_link: whether to auto-link to a summary node
|
|
160
160
|
"""
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
# Step 1: Find similar nodes for possible merging
|
|
164
|
-
similar_nodes = self.graph_store.search_by_embedding(
|
|
165
|
-
vector=embedding,
|
|
166
|
-
top_k=3,
|
|
167
|
-
scope=memory_type,
|
|
168
|
-
threshold=self._threshold,
|
|
169
|
-
status="activated",
|
|
170
|
-
)
|
|
171
|
-
|
|
172
|
-
if similar_nodes and similar_nodes[0]["score"] > self._merged_threshold:
|
|
173
|
-
return self._merge(memory, similar_nodes)
|
|
174
|
-
else:
|
|
175
|
-
node_id = str(uuid.uuid4())
|
|
176
|
-
# Step 2: Add new node to graph
|
|
177
|
-
self.graph_store.add_node(
|
|
178
|
-
node_id, memory.memory, memory.metadata.model_dump(exclude_none=True)
|
|
179
|
-
)
|
|
180
|
-
self.reorganizer.add_message(
|
|
181
|
-
QueueMessage(
|
|
182
|
-
op="add",
|
|
183
|
-
after_node=[node_id],
|
|
184
|
-
)
|
|
185
|
-
)
|
|
186
|
-
return node_id
|
|
187
|
-
|
|
188
|
-
def _merge(self, source_node: TextualMemoryItem, similar_nodes: list[dict]) -> str:
|
|
189
|
-
"""
|
|
190
|
-
TODO: Add node traceability support by optionally preserving source nodes and linking them with MERGED_FROM edges.
|
|
191
|
-
|
|
192
|
-
Merge the source memory into the most similar existing node (only one),
|
|
193
|
-
and establish a MERGED_FROM edge in the graph.
|
|
194
|
-
|
|
195
|
-
Parameters:
|
|
196
|
-
source_node: The new memory item (not yet in the graph)
|
|
197
|
-
similar_nodes: A list of dicts returned by search_by_embedding(), ordered by similarity
|
|
198
|
-
"""
|
|
199
|
-
original_node = similar_nodes[0]
|
|
200
|
-
original_id = original_node["id"]
|
|
201
|
-
original_data = self.graph_store.get_node(original_id)
|
|
202
|
-
|
|
203
|
-
target_text = original_data.get("memory", "")
|
|
204
|
-
merged_text = f"{target_text}\n⟵MERGED⟶\n{source_node.memory}"
|
|
205
|
-
|
|
206
|
-
original_meta = TreeNodeTextualMemoryMetadata(**original_data["metadata"])
|
|
207
|
-
source_meta = source_node.metadata
|
|
208
|
-
|
|
209
|
-
merged_key = source_meta.key or original_meta.key
|
|
210
|
-
merged_tags = list(set((original_meta.tags or []) + (source_meta.tags or [])))
|
|
211
|
-
merged_sources = list(set((original_meta.sources or []) + (source_meta.sources or [])))
|
|
212
|
-
merged_background = f"{original_meta.background}\n⟵MERGED⟶\n{source_meta.background}"
|
|
213
|
-
merged_embedding = self.embedder.embed([merged_text])[0]
|
|
214
|
-
|
|
215
|
-
original_conf = original_meta.confidence or 0.0
|
|
216
|
-
source_conf = source_meta.confidence or 0.0
|
|
217
|
-
merged_confidence = float((original_conf + source_conf) / 2)
|
|
218
|
-
merged_usage = list(set((original_meta.usage or []) + (source_meta.usage or [])))
|
|
219
|
-
|
|
220
|
-
# Create new merged node
|
|
221
|
-
merged_id = str(uuid.uuid4())
|
|
222
|
-
merged_metadata = source_meta.model_copy(
|
|
223
|
-
update={
|
|
224
|
-
"embedding": merged_embedding,
|
|
225
|
-
"updated_at": datetime.now().isoformat(),
|
|
226
|
-
"key": merged_key,
|
|
227
|
-
"tags": merged_tags,
|
|
228
|
-
"sources": merged_sources,
|
|
229
|
-
"background": merged_background,
|
|
230
|
-
"confidence": merged_confidence,
|
|
231
|
-
"usage": merged_usage,
|
|
232
|
-
}
|
|
233
|
-
)
|
|
234
|
-
|
|
161
|
+
node_id = str(uuid.uuid4())
|
|
162
|
+
# Step 2: Add new node to graph
|
|
235
163
|
self.graph_store.add_node(
|
|
236
|
-
|
|
164
|
+
node_id, memory.memory, memory.metadata.model_dump(exclude_none=True)
|
|
237
165
|
)
|
|
238
|
-
|
|
239
|
-
# Add traceability edges: both original and new point to merged node
|
|
240
|
-
self.graph_store.add_edge(original_id, merged_id, type="MERGED_TO")
|
|
241
|
-
self.graph_store.update_node(original_id, {"status": "archived"})
|
|
242
|
-
source_id = str(uuid.uuid4())
|
|
243
|
-
source_metadata = source_node.metadata.model_copy(update={"status": "archived"})
|
|
244
|
-
self.graph_store.add_node(source_id, source_node.memory, source_metadata.model_dump())
|
|
245
|
-
self.graph_store.add_edge(source_id, merged_id, type="MERGED_TO")
|
|
246
|
-
# After creating merged node and tracing lineage
|
|
247
|
-
self._inherit_edges(original_id, merged_id)
|
|
248
|
-
|
|
249
|
-
# log to reorganizer before updating the graph
|
|
250
166
|
self.reorganizer.add_message(
|
|
251
167
|
QueueMessage(
|
|
252
|
-
op="
|
|
253
|
-
|
|
254
|
-
original_id,
|
|
255
|
-
source_node.id,
|
|
256
|
-
],
|
|
257
|
-
after_node=[merged_id],
|
|
168
|
+
op="add",
|
|
169
|
+
after_node=[node_id],
|
|
258
170
|
)
|
|
259
171
|
)
|
|
260
|
-
return
|
|
172
|
+
return node_id
|
|
261
173
|
|
|
262
174
|
def _inherit_edges(self, from_id: str, to_id: str) -> None:
|
|
263
175
|
"""
|
|
@@ -73,10 +73,12 @@ class RelationAndReasoningDetector:
|
|
|
73
73
|
results["sequence_links"].extend(seq)
|
|
74
74
|
"""
|
|
75
75
|
|
|
76
|
+
"""
|
|
76
77
|
# 4) Aggregate
|
|
77
78
|
agg = self._detect_aggregate_node_for_group(node, nearest, min_group_size=5)
|
|
78
79
|
if agg:
|
|
79
80
|
results["aggregate_nodes"].append(agg)
|
|
81
|
+
"""
|
|
80
82
|
|
|
81
83
|
except Exception as e:
|
|
82
84
|
logger.error(
|
|
@@ -3,7 +3,7 @@ import threading
|
|
|
3
3
|
import time
|
|
4
4
|
import traceback
|
|
5
5
|
|
|
6
|
-
from collections import
|
|
6
|
+
from collections import defaultdict
|
|
7
7
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
8
8
|
from queue import PriorityQueue
|
|
9
9
|
from typing import Literal
|
|
@@ -17,8 +17,7 @@ from memos.graph_dbs.neo4j import Neo4jGraphDB
|
|
|
17
17
|
from memos.llms.base import BaseLLM
|
|
18
18
|
from memos.log import get_logger
|
|
19
19
|
from memos.memories.textual.item import TreeNodeTextualMemoryMetadata
|
|
20
|
-
from memos.memories.textual.tree_text_memory.organize.
|
|
21
|
-
from memos.memories.textual.tree_text_memory.organize.redundancy import RedundancyHandler
|
|
20
|
+
from memos.memories.textual.tree_text_memory.organize.handler import NodeHandler
|
|
22
21
|
from memos.memories.textual.tree_text_memory.organize.relation_reason_detector import (
|
|
23
22
|
RelationAndReasoningDetector,
|
|
24
23
|
)
|
|
@@ -63,10 +62,10 @@ class GraphStructureReorganizer:
|
|
|
63
62
|
self.relation_detector = RelationAndReasoningDetector(
|
|
64
63
|
self.graph_store, self.llm, self.embedder
|
|
65
64
|
)
|
|
66
|
-
self.
|
|
67
|
-
self.redundancy = RedundancyHandler(graph_store=graph_store, llm=llm, embedder=embedder)
|
|
65
|
+
self.resolver = NodeHandler(graph_store=graph_store, llm=llm, embedder=embedder)
|
|
68
66
|
|
|
69
67
|
self.is_reorganize = is_reorganize
|
|
68
|
+
self._reorganize_needed = True
|
|
70
69
|
if self.is_reorganize:
|
|
71
70
|
# ____ 1. For queue message driven thread ___________
|
|
72
71
|
self.thread = threading.Thread(target=self._run_message_consumer_loop)
|
|
@@ -125,13 +124,17 @@ class GraphStructureReorganizer:
|
|
|
125
124
|
"""
|
|
126
125
|
import schedule
|
|
127
126
|
|
|
128
|
-
schedule.every(
|
|
129
|
-
schedule.every(
|
|
127
|
+
schedule.every(100).seconds.do(self.optimize_structure, scope="LongTermMemory")
|
|
128
|
+
schedule.every(100).seconds.do(self.optimize_structure, scope="UserMemory")
|
|
130
129
|
|
|
131
130
|
logger.info("Structure optimizer schedule started.")
|
|
132
131
|
while not getattr(self, "_stop_scheduler", False):
|
|
133
|
-
|
|
134
|
-
|
|
132
|
+
if self._reorganize_needed:
|
|
133
|
+
logger.info("[Reorganizer] Triggering optimize_structure due to new nodes.")
|
|
134
|
+
self.optimize_structure(scope="LongTermMemory")
|
|
135
|
+
self.optimize_structure(scope="UserMemory")
|
|
136
|
+
self._reorganize_needed = False
|
|
137
|
+
time.sleep(30)
|
|
135
138
|
|
|
136
139
|
def stop(self):
|
|
137
140
|
"""
|
|
@@ -148,45 +151,31 @@ class GraphStructureReorganizer:
|
|
|
148
151
|
logger.info("Structure optimizer stopped.")
|
|
149
152
|
|
|
150
153
|
def handle_message(self, message: QueueMessage):
|
|
151
|
-
handle_map = {
|
|
152
|
-
"add": self.handle_add,
|
|
153
|
-
"remove": self.handle_remove,
|
|
154
|
-
"merge": self.handle_merge,
|
|
155
|
-
}
|
|
154
|
+
handle_map = {"add": self.handle_add, "remove": self.handle_remove}
|
|
156
155
|
handle_map[message.op](message)
|
|
157
156
|
logger.debug(f"message queue size: {self.queue.qsize()}")
|
|
158
157
|
|
|
159
158
|
def handle_add(self, message: QueueMessage):
|
|
160
159
|
logger.debug(f"Handling add operation: {str(message)[:500]}")
|
|
161
|
-
# ———————— 1. check for conflicts ————————
|
|
162
160
|
added_node = message.after_node[0]
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
if redundancies:
|
|
172
|
-
for added_node, existing_node in redundancies:
|
|
173
|
-
self.redundancy.resolve_two_nodes(added_node, existing_node)
|
|
174
|
-
logger.info(f"Resolved redundancy between {added_node.id} and {existing_node.id}.")
|
|
161
|
+
detected_relationships = self.resolver.detect(
|
|
162
|
+
added_node, scope=added_node.metadata.memory_type
|
|
163
|
+
)
|
|
164
|
+
if detected_relationships:
|
|
165
|
+
for added_node, existing_node, relation in detected_relationships:
|
|
166
|
+
self.resolver.resolve(added_node, existing_node, relation)
|
|
167
|
+
|
|
168
|
+
self._reorganize_needed = True
|
|
175
169
|
|
|
176
170
|
def handle_remove(self, message: QueueMessage):
|
|
177
171
|
logger.debug(f"Handling remove operation: {str(message)[:50]}")
|
|
178
172
|
|
|
179
|
-
def handle_merge(self, message: QueueMessage):
|
|
180
|
-
after_node = message.after_node[0]
|
|
181
|
-
logger.debug(f"Handling merge operation: <{after_node.memory}>")
|
|
182
|
-
self.redundancy.resolve_one_node(after_node)
|
|
183
|
-
|
|
184
173
|
def optimize_structure(
|
|
185
174
|
self,
|
|
186
175
|
scope: str = "LongTermMemory",
|
|
187
176
|
local_tree_threshold: int = 10,
|
|
188
|
-
min_cluster_size: int =
|
|
189
|
-
min_group_size: int =
|
|
177
|
+
min_cluster_size: int = 4,
|
|
178
|
+
min_group_size: int = 20,
|
|
190
179
|
):
|
|
191
180
|
"""
|
|
192
181
|
Periodically reorganize the graph:
|
|
@@ -253,7 +242,7 @@ class GraphStructureReorganizer:
|
|
|
253
242
|
except Exception as e:
|
|
254
243
|
logger.warning(
|
|
255
244
|
f"[Reorganize] Cluster processing "
|
|
256
|
-
f"failed: {e}, trace: {traceback.format_exc()}"
|
|
245
|
+
f"failed: {e}, cluster_nodes: {cluster_nodes}, trace: {traceback.format_exc()}"
|
|
257
246
|
)
|
|
258
247
|
logger.info("[GraphStructure Reorganize] Structure optimization finished.")
|
|
259
248
|
|
|
@@ -271,29 +260,23 @@ class GraphStructureReorganizer:
|
|
|
271
260
|
if len(cluster_nodes) <= min_cluster_size:
|
|
272
261
|
return
|
|
273
262
|
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
if sub_parents:
|
|
293
|
-
cluster_parent_node = self._summarize_cluster(cluster_nodes, scope)
|
|
294
|
-
self._create_parent_node(cluster_parent_node)
|
|
295
|
-
for sub_parent in sub_parents:
|
|
296
|
-
self.graph_store.add_edge(cluster_parent_node.id, sub_parent.id, "PARENT")
|
|
263
|
+
# Large cluster ➜ local sub-clustering
|
|
264
|
+
sub_clusters = self._local_subcluster(cluster_nodes)
|
|
265
|
+
sub_parents = []
|
|
266
|
+
|
|
267
|
+
for sub_nodes in sub_clusters:
|
|
268
|
+
if len(sub_nodes) < min_cluster_size:
|
|
269
|
+
continue # Skip tiny noise
|
|
270
|
+
sub_parent_node = self._summarize_cluster(sub_nodes, scope)
|
|
271
|
+
self._create_parent_node(sub_parent_node)
|
|
272
|
+
self._link_cluster_nodes(sub_parent_node, sub_nodes)
|
|
273
|
+
sub_parents.append(sub_parent_node)
|
|
274
|
+
|
|
275
|
+
if sub_parents and len(sub_parents) >= min_cluster_size:
|
|
276
|
+
cluster_parent_node = self._summarize_cluster(cluster_nodes, scope)
|
|
277
|
+
self._create_parent_node(cluster_parent_node)
|
|
278
|
+
for sub_parent in sub_parents:
|
|
279
|
+
self.graph_store.add_edge(cluster_parent_node.id, sub_parent.id, "PARENT")
|
|
297
280
|
|
|
298
281
|
logger.info("Adding relations/reasons")
|
|
299
282
|
nodes_to_check = cluster_nodes
|
|
@@ -350,7 +333,9 @@ class GraphStructureReorganizer:
|
|
|
350
333
|
|
|
351
334
|
logger.info("[Reorganizer] Cluster relation/reasoning done.")
|
|
352
335
|
|
|
353
|
-
def _local_subcluster(
|
|
336
|
+
def _local_subcluster(
|
|
337
|
+
self, cluster_nodes: list[GraphDBNode], max_length: int = 8000
|
|
338
|
+
) -> (list)[list[GraphDBNode]]:
|
|
354
339
|
"""
|
|
355
340
|
Use LLM to split a large cluster into semantically coherent sub-clusters.
|
|
356
341
|
"""
|
|
@@ -364,7 +349,9 @@ class GraphStructureReorganizer:
|
|
|
364
349
|
scene_lines.append(line)
|
|
365
350
|
|
|
366
351
|
joined_scene = "\n".join(scene_lines)
|
|
367
|
-
|
|
352
|
+
if len(joined_scene) > max_length:
|
|
353
|
+
logger.warning(f"Sub-cluster too long: {joined_scene}")
|
|
354
|
+
prompt = LOCAL_SUBCLUSTER_PROMPT.replace("{joined_scene}", joined_scene[:max_length])
|
|
368
355
|
|
|
369
356
|
messages = [{"role": "user", "content": prompt}]
|
|
370
357
|
response_text = self.llm.generate(messages)
|
|
@@ -389,12 +376,12 @@ class GraphStructureReorganizer:
|
|
|
389
376
|
install_command="pip install scikit-learn",
|
|
390
377
|
install_link="https://scikit-learn.org/stable/install.html",
|
|
391
378
|
)
|
|
392
|
-
def _partition(self, nodes, min_cluster_size: int =
|
|
379
|
+
def _partition(self, nodes, min_cluster_size: int = 10, max_cluster_size: int = 20):
|
|
393
380
|
"""
|
|
394
381
|
Partition nodes by:
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
382
|
+
- If total nodes <= max_cluster_size -> return all nodes in one cluster.
|
|
383
|
+
- If total nodes > max_cluster_size -> cluster by embeddings, recursively split.
|
|
384
|
+
- Only keep clusters with size > min_cluster_size.
|
|
398
385
|
|
|
399
386
|
Args:
|
|
400
387
|
nodes: List of GraphDBNode
|
|
@@ -405,105 +392,80 @@ class GraphStructureReorganizer:
|
|
|
405
392
|
"""
|
|
406
393
|
from sklearn.cluster import MiniBatchKMeans
|
|
407
394
|
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
# Select frequent tags
|
|
415
|
-
top_n_tags = {tag for tag, count in tag_counter.most_common(50)}
|
|
416
|
-
threshold_tags = {tag for tag, count in tag_counter.items() if count >= 50}
|
|
417
|
-
frequent_tags = top_n_tags | threshold_tags
|
|
418
|
-
|
|
419
|
-
# Group nodes by tags
|
|
420
|
-
tag_groups = defaultdict(list)
|
|
421
|
-
|
|
422
|
-
for node in nodes:
|
|
423
|
-
for tag in node.metadata.tags:
|
|
424
|
-
if tag in frequent_tags:
|
|
425
|
-
tag_groups[tag].append(node)
|
|
426
|
-
break
|
|
427
|
-
|
|
428
|
-
filtered_tag_clusters = []
|
|
429
|
-
assigned_ids = set()
|
|
430
|
-
for tag, group in tag_groups.items():
|
|
431
|
-
if len(group) >= min_cluster_size:
|
|
432
|
-
# Split large groups into chunks of at most max_cluster_size
|
|
433
|
-
for i in range(0, len(group), max_cluster_size):
|
|
434
|
-
sub_group = group[i : i + max_cluster_size]
|
|
435
|
-
filtered_tag_clusters.append(sub_group)
|
|
436
|
-
assigned_ids.update(n.id for n in sub_group)
|
|
437
|
-
else:
|
|
438
|
-
logger.info(f"... dropped tag {tag} due to low size ...")
|
|
439
|
-
|
|
440
|
-
logger.info(
|
|
441
|
-
f"[MixedPartition] Created {len(filtered_tag_clusters)} clusters from tags. "
|
|
442
|
-
f"Nodes grouped by tags: {len(assigned_ids)} / {len(nodes)}"
|
|
443
|
-
)
|
|
444
|
-
|
|
445
|
-
# Remaining nodes -> embedding clustering
|
|
446
|
-
remaining_nodes = [n for n in nodes if n.id not in assigned_ids]
|
|
447
|
-
logger.info(
|
|
448
|
-
f"[MixedPartition] Remaining nodes for embedding clustering: {len(remaining_nodes)}"
|
|
449
|
-
)
|
|
450
|
-
|
|
451
|
-
embedding_clusters = []
|
|
395
|
+
if len(nodes) <= max_cluster_size:
|
|
396
|
+
logger.info(
|
|
397
|
+
f"[KMeansPartition] Node count {len(nodes)} <= {max_cluster_size}, skipping KMeans."
|
|
398
|
+
)
|
|
399
|
+
return [nodes]
|
|
452
400
|
|
|
453
|
-
def recursive_clustering(nodes_list):
|
|
401
|
+
def recursive_clustering(nodes_list, depth=0):
|
|
454
402
|
"""Recursively split clusters until each is <= max_cluster_size."""
|
|
403
|
+
indent = " " * depth
|
|
404
|
+
logger.info(
|
|
405
|
+
f"{indent}[Recursive] Start clustering {len(nodes_list)} nodes at depth {depth}"
|
|
406
|
+
)
|
|
407
|
+
|
|
455
408
|
if len(nodes_list) <= max_cluster_size:
|
|
409
|
+
logger.info(
|
|
410
|
+
f"{indent}[Recursive] Node count <= {max_cluster_size}, stop splitting."
|
|
411
|
+
)
|
|
456
412
|
return [nodes_list]
|
|
457
|
-
|
|
458
413
|
# Try kmeans with k = ceil(len(nodes) / max_cluster_size)
|
|
459
|
-
|
|
460
|
-
|
|
414
|
+
x_nodes = [n for n in nodes_list if n.metadata.embedding]
|
|
415
|
+
x = np.array([n.metadata.embedding for n in x_nodes])
|
|
416
|
+
|
|
417
|
+
if len(x) < min_cluster_size:
|
|
418
|
+
logger.info(
|
|
419
|
+
f"{indent}[Recursive] Too few embeddings ({len(x)}), skipping clustering."
|
|
420
|
+
)
|
|
461
421
|
return [nodes_list]
|
|
462
422
|
|
|
463
423
|
k = min(len(x), (len(nodes_list) + max_cluster_size - 1) // max_cluster_size)
|
|
464
|
-
k = max(1,
|
|
424
|
+
k = max(1, k)
|
|
465
425
|
|
|
466
426
|
try:
|
|
427
|
+
logger.info(f"{indent}[Recursive] Clustering with k={k} on {len(x)} points.")
|
|
467
428
|
kmeans = MiniBatchKMeans(n_clusters=k, batch_size=256, random_state=42)
|
|
468
429
|
labels = kmeans.fit_predict(x)
|
|
469
430
|
|
|
470
431
|
label_groups = defaultdict(list)
|
|
471
|
-
for node, label in zip(
|
|
432
|
+
for node, label in zip(x_nodes, labels, strict=False):
|
|
472
433
|
label_groups[label].append(node)
|
|
473
434
|
|
|
435
|
+
# Map: label -> nodes with no embedding (fallback group)
|
|
436
|
+
no_embedding_nodes = [n for n in nodes_list if not n.metadata.embedding]
|
|
437
|
+
if no_embedding_nodes:
|
|
438
|
+
logger.warning(
|
|
439
|
+
f"{indent}[Recursive] {len(no_embedding_nodes)} nodes have no embedding. Added to largest cluster."
|
|
440
|
+
)
|
|
441
|
+
# Assign to largest cluster
|
|
442
|
+
largest_label = max(label_groups.items(), key=lambda kv: len(kv[1]))[0]
|
|
443
|
+
label_groups[largest_label].extend(no_embedding_nodes)
|
|
444
|
+
|
|
474
445
|
result = []
|
|
475
|
-
for sub_group in label_groups.
|
|
476
|
-
|
|
446
|
+
for label, sub_group in label_groups.items():
|
|
447
|
+
logger.info(f"{indent} Cluster-{label}: {len(sub_group)} nodes")
|
|
448
|
+
result.extend(recursive_clustering(sub_group, depth=depth + 1))
|
|
477
449
|
return result
|
|
450
|
+
|
|
478
451
|
except Exception as e:
|
|
479
|
-
logger.warning(
|
|
452
|
+
logger.warning(
|
|
453
|
+
f"{indent}[Recursive] Clustering failed: {e}, fallback to one cluster."
|
|
454
|
+
)
|
|
480
455
|
return [nodes_list]
|
|
481
456
|
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
embedding_clusters.extend(clusters)
|
|
485
|
-
logger.info(
|
|
486
|
-
f"[MixedPartition] Created {len(embedding_clusters)} clusters from embeddings."
|
|
487
|
-
)
|
|
488
|
-
|
|
489
|
-
# Merge all clusters
|
|
490
|
-
all_clusters = filtered_tag_clusters + embedding_clusters
|
|
457
|
+
raw_clusters = recursive_clustering(nodes)
|
|
458
|
+
filtered_clusters = [c for c in raw_clusters if len(c) > min_cluster_size]
|
|
491
459
|
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
for group in all_clusters:
|
|
496
|
-
if len(group) < min_cluster_size:
|
|
497
|
-
small_nodes.extend(group)
|
|
498
|
-
else:
|
|
499
|
-
final_clusters.append(group)
|
|
460
|
+
logger.info(f"[KMeansPartition] Total clusters before filtering: {len(raw_clusters)}")
|
|
461
|
+
for i, cluster in enumerate(raw_clusters):
|
|
462
|
+
logger.info(f"[KMeansPartition] Cluster-{i}: {len(cluster)} nodes")
|
|
500
463
|
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
464
|
+
logger.info(
|
|
465
|
+
f"[KMeansPartition] Clusters after filtering (>{min_cluster_size}): {len(filtered_clusters)}"
|
|
466
|
+
)
|
|
504
467
|
|
|
505
|
-
|
|
506
|
-
return final_clusters
|
|
468
|
+
return filtered_clusters
|
|
507
469
|
|
|
508
470
|
def _summarize_cluster(self, cluster_nodes: list[GraphDBNode], scope: str) -> GraphDBNode:
|
|
509
471
|
"""
|
|
@@ -600,7 +562,7 @@ class GraphStructureReorganizer:
|
|
|
600
562
|
for i, node in enumerate(message.after_node or []):
|
|
601
563
|
if not isinstance(node, str):
|
|
602
564
|
continue
|
|
603
|
-
raw_node = self.graph_store.get_node(node)
|
|
565
|
+
raw_node = self.graph_store.get_node(node, include_embedding=True)
|
|
604
566
|
if raw_node is None:
|
|
605
567
|
logger.debug(f"Node with ID {node} not found in the graph store.")
|
|
606
568
|
message.after_node[i] = None
|