MemoryOS 0.1.12__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of MemoryOS might be problematic.
- {memoryos-0.1.12.dist-info → memoryos-0.2.0.dist-info}/METADATA +51 -31
- {memoryos-0.1.12.dist-info → memoryos-0.2.0.dist-info}/RECORD +32 -21
- memos/__init__.py +1 -1
- memos/configs/internet_retriever.py +81 -0
- memos/configs/llm.py +1 -0
- memos/configs/mem_os.py +4 -0
- memos/configs/mem_reader.py +4 -0
- memos/configs/memory.py +11 -1
- memos/graph_dbs/item.py +46 -0
- memos/graph_dbs/neo4j.py +72 -5
- memos/llms/openai.py +1 -0
- memos/mem_os/main.py +491 -0
- memos/mem_reader/simple_struct.py +11 -6
- memos/mem_user/user_manager.py +10 -0
- memos/memories/textual/item.py +3 -1
- memos/memories/textual/tree.py +39 -3
- memos/memories/textual/tree_text_memory/organize/conflict.py +196 -0
- memos/memories/textual/tree_text_memory/organize/manager.py +49 -8
- memos/memories/textual/tree_text_memory/organize/redundancy.py +212 -0
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +235 -0
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +584 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +263 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +89 -0
- memos/memories/textual/tree_text_memory/retrieve/reasoner.py +1 -4
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +46 -4
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +3 -3
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +335 -0
- memos/templates/mem_reader_prompts.py +42 -15
- memos/templates/mos_prompts.py +63 -0
- memos/templates/tree_reorganize_prompts.py +168 -0
- {memoryos-0.1.12.dist-info → memoryos-0.2.0.dist-info}/LICENSE +0 -0
- {memoryos-0.1.12.dist-info → memoryos-0.2.0.dist-info}/WHEEL +0 -0
memos/memories/textual/tree_text_memory/organize/reorganizer.py (new file, +584 lines)

@@ -0,0 +1,584 @@
```python
import json
import threading
import time
import traceback

from concurrent.futures import ThreadPoolExecutor, as_completed
from queue import PriorityQueue
from typing import Literal

import numpy as np
import schedule

from sklearn.cluster import MiniBatchKMeans

from memos.embedders.factory import OllamaEmbedder
from memos.graph_dbs.item import GraphDBEdge, GraphDBNode
from memos.graph_dbs.neo4j import Neo4jGraphDB
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.memories.textual.item import TreeNodeTextualMemoryMetadata
from memos.memories.textual.tree_text_memory.organize.conflict import ConflictHandler
from memos.memories.textual.tree_text_memory.organize.redundancy import RedundancyHandler
from memos.memories.textual.tree_text_memory.organize.relation_reason_detector import (
    RelationAndReasoningDetector,
)
from memos.templates.tree_reorganize_prompts import LOCAL_SUBCLUSTER_PROMPT, REORGANIZE_PROMPT


logger = get_logger(__name__)


class QueueMessage:
    def __init__(
        self,
        op: Literal["add", "remove", "merge", "update"],
        # `str` for node and edge IDs, `GraphDBNode` and `GraphDBEdge` for actual objects
        before_node: list[str] | list[GraphDBNode] | None = None,
        before_edge: list[str] | list[GraphDBEdge] | None = None,
        after_node: list[str] | list[GraphDBNode] | None = None,
        after_edge: list[str] | list[GraphDBEdge] | None = None,
    ):
        self.op = op
        self.before_node = before_node
        self.before_edge = before_edge
        self.after_node = after_node
        self.after_edge = after_edge

    def __str__(self) -> str:
        return f"QueueMessage(op={self.op}, before_node={self.before_node if self.before_node is None else len(self.before_node)}, after_node={self.after_node if self.after_node is None else len(self.after_node)})"

    def __lt__(self, other: "QueueMessage") -> bool:
        op_priority = {"add": 2, "remove": 2, "merge": 1}
        return op_priority[self.op] < op_priority[other.op]
```
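`QueueMessage.__lt__` is what orders messages inside the `PriorityQueue`: "merge" maps to priority 1 and "add"/"remove" to 2, so pending merges are consumed first. Note that the `op` literal also allows "update", but `op_priority` has no entry for it, so comparing an "update" message would raise `KeyError`. A minimal sketch of the ordering, assuming memoryos 0.2.0 is installed (the import path is taken from this diff):

```python
# Sketch: "merge" (priority 1) dequeues ahead of "add"/"remove" (priority 2).
from queue import PriorityQueue

from memos.memories.textual.tree_text_memory.organize.reorganizer import QueueMessage

q = PriorityQueue()
q.put(QueueMessage(op="add", after_node=["node-a"]))
q.put(QueueMessage(op="remove", before_node=["node-b"]))
q.put(QueueMessage(op="merge", after_node=["node-c"]))

while not q.empty():
    print(q.get().op)  # "merge" first, then "add"/"remove" in heap order
```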
```python
class GraphStructureReorganizer:
    def __init__(
        self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: OllamaEmbedder, is_reorganize: bool
    ):
        self.queue = PriorityQueue()  # Min-heap
        self.graph_store = graph_store
        self.llm = llm
        self.embedder = embedder
        self.relation_detector = RelationAndReasoningDetector(
            self.graph_store, self.llm, self.embedder
        )
        self.conflict = ConflictHandler(graph_store=graph_store, llm=llm, embedder=embedder)
        self.redundancy = RedundancyHandler(graph_store=graph_store, llm=llm, embedder=embedder)

        self.is_reorganize = is_reorganize
        if self.is_reorganize:
            # ____ 1. For queue message driven thread ___________
            self.thread = threading.Thread(target=self._run_message_consumer_loop)
            self.thread.start()
            # ____ 2. For periodic structure optimization _______
            self._stop_scheduler = False
            self._is_optimizing = {"LongTermMemory": False, "UserMemory": False}
            self.structure_optimizer_thread = threading.Thread(
                target=self._run_structure_organizer_loop
            )
            self.structure_optimizer_thread.start()
```
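With `is_reorganize=True` the constructor starts two worker threads: a queue consumer and a periodic structure optimizer. A usage sketch of the lifecycle, not runnable as-is: `graph_store`, `llm`, and `embedder` are hypothetical placeholders for pre-configured `Neo4jGraphDB`, `BaseLLM`, and `OllamaEmbedder` instances whose setup is outside this diff.

```python
# Lifecycle sketch. graph_store, llm, embedder: placeholder instances (see above).
reorganizer = GraphStructureReorganizer(graph_store, llm, embedder, is_reorganize=True)

# Enqueue work for the consumer thread, then block until the queue drains.
reorganizer.add_message(QueueMessage(op="add", after_node=["<node-id>"]))
reorganizer.wait_until_current_task_done()

# An on-demand pass; the same entry point the 20-second scheduler triggers.
reorganizer.optimize_structure(scope="LongTermMemory")

# Sends the None sentinel to the consumer and flags the scheduler loop to exit.
reorganizer.stop()
```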
```python
    def add_message(self, message: QueueMessage):
        self.queue.put_nowait(message)

    def wait_until_current_task_done(self):
        """
        Wait until:
        1) queue is empty
        2) any running structure optimization is done
        """
        if not self.is_reorganize:
            return

        if not self.queue.empty():
            self.queue.join()
            logger.debug("Queue is now empty.")

        while any(self._is_optimizing.values()):
            logger.debug(f"Waiting for structure optimizer to finish... {self._is_optimizing}")
            time.sleep(1)
        logger.debug("Structure optimizer is now idle.")

    def _run_message_consumer_loop(self):
        while True:
            message = self.queue.get()
            if message is None:
                break

            try:
                if self._preprocess_message(message):
                    self.handle_message(message)
            except Exception:
                logger.error(traceback.format_exc())
            self.queue.task_done()

    def _run_structure_organizer_loop(self):
        """
        Use schedule library to periodically trigger structure optimization.
        This runs until the stop flag is set.
        """
        schedule.every(20).seconds.do(self.optimize_structure, scope="LongTermMemory")
        schedule.every(20).seconds.do(self.optimize_structure, scope="UserMemory")

        logger.info("Structure optimizer schedule started.")
        while not getattr(self, "_stop_scheduler", False):
            schedule.run_pending()
            time.sleep(1)

    def stop(self):
        """
        Stop the reorganizer thread.
        """
        if not self.is_reorganize:
            return

        self.add_message(None)
        self.thread.join()
        logger.info("Reorganize thread stopped.")
        self._stop_scheduler = True
        self.structure_optimizer_thread.join()
        logger.info("Structure optimizer stopped.")
```
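`_run_structure_organizer_loop` relies on the `schedule` package's in-process job registry: jobs registered with `every(20).seconds.do(...)` only fire when `run_pending()` is called, from the loop's own thread. A standalone sketch of the same pattern, with a toy job in place of `optimize_structure` and a bounded loop so the script terminates (the real loop runs until `_stop_scheduler` is set):

```python
import time

import schedule

def optimize(scope: str) -> None:
    print(f"optimizing {scope}")  # toy stand-in for optimize_structure

schedule.every(20).seconds.do(optimize, scope="LongTermMemory")
schedule.every(20).seconds.do(optimize, scope="UserMemory")

for _ in range(60):          # the real loop: `while not self._stop_scheduler`
    schedule.run_pending()   # fires any due jobs in this thread
    time.sleep(1)
```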
```python
    def handle_message(self, message: QueueMessage):
        handle_map = {
            "add": self.handle_add,
            "remove": self.handle_remove,
            "merge": self.handle_merge,
        }
        handle_map[message.op](message)
        logger.debug(f"message queue size: {self.queue.qsize()}")

    def handle_add(self, message: QueueMessage):
        logger.debug(f"Handling add operation: {str(message)[:500]}")
        assert message.before_node is None and message.before_edge is None, (
            "Before node and edge should be None for `add` operation."
        )
        # ———————— 1. check for conflicts ————————
        added_node = message.after_node[0]
        conflicts = self.conflict.detect(added_node, scope=added_node.metadata.memory_type)
        if conflicts:
            for added_node, existing_node in conflicts:
                self.conflict.resolve(added_node, existing_node)
                logger.info(f"Resolved conflict between {added_node.id} and {existing_node.id}.")

        # ———————— 2. check for redundancy ————————
        redundancy = self.redundancy.detect(added_node, scope=added_node.metadata.memory_type)
        if redundancy:
            for added_node, existing_node in redundancy:
                self.redundancy.resolve_two_nodes(added_node, existing_node)
                logger.info(f"Resolved redundancy between {added_node.id} and {existing_node.id}.")

    def handle_remove(self, message: QueueMessage):
        logger.debug(f"Handling remove operation: {str(message)[:50]}")

    def handle_merge(self, message: QueueMessage):
        after_node = message.after_node[0]
        logger.debug(f"Handling merge operation: <{after_node.memory}>")
        # NOTE: `self.redundancy_resolver` is never defined on this class
        # (__init__ creates `self.redundancy`), so merge messages raise
        # AttributeError, which the consumer loop logs and swallows.
        self.redundancy_resolver.resolve_one_node(after_node)

    def optimize_structure(
        self,
        scope: str = "LongTermMemory",
        local_tree_threshold: int = 10,
        min_cluster_size: int = 3,
        min_group_size: int = 10,
    ):
        """
        Periodically reorganize the graph:
        1. Weakly partition nodes into clusters.
        2. Summarize each cluster.
        3. Create parent nodes and build local PARENT trees.
        """
        if self._is_optimizing[scope]:
            logger.info(f"Already optimizing for {scope}. Skipping.")
            return

        if self.graph_store.count_nodes(scope) == 0:
            logger.debug(f"[GraphStructureReorganize] No nodes for scope={scope}. Skip.")
            return

        self._is_optimizing[scope] = True
        try:
            logger.debug(
                f"[GraphStructureReorganize] 🔍 Starting structure optimization for scope: {scope}"
            )

            logger.debug(
                f"Num of scope in self.graph_store is {self.graph_store.get_memory_count(scope)}"
            )
            # Load candidate nodes
            raw_nodes = self.graph_store.get_structure_optimization_candidates(scope)
            nodes = [GraphDBNode(**n) for n in raw_nodes]

            if not nodes:
                logger.info("[GraphStructureReorganize] No nodes to optimize. Skipping.")
                return

            if len(nodes) < min_group_size:
                logger.info(
                    f"[GraphStructureReorganize] Only {len(nodes)} candidate nodes found. Not enough to reorganize. Skipping."
                )
                return

            logger.info(f"[GraphStructureReorganize] Loaded {len(nodes)} nodes.")

            # Step 2: Partition nodes
            partitioned_groups = self._partition(nodes)

            logger.info(
                f"[GraphStructureReorganize] Partitioned into {len(partitioned_groups)} clusters."
            )

            with ThreadPoolExecutor(max_workers=4) as executor:
                futures = []
                for cluster_nodes in partitioned_groups:
                    futures.append(
                        executor.submit(
                            self._process_cluster_and_write,
                            cluster_nodes,
                            scope,
                            local_tree_threshold,
                            min_cluster_size,
                        )
                    )

                for f in as_completed(futures):
                    try:
                        f.result()
                    except Exception as e:
                        logger.warning(f"[Reorganize] Cluster processing failed: {e}")
            logger.info("[GraphStructure Reorganize] Structure optimization finished.")

        finally:
            self._is_optimizing[scope] = False
            logger.info("[GraphStructureReorganize] Structure optimization finished.")

    def _process_cluster_and_write(
        self,
        cluster_nodes: list[GraphDBNode],
        scope: str,
        local_tree_threshold: int,
        min_cluster_size: int,
    ):
        if len(cluster_nodes) <= min_cluster_size:
            return

        if len(cluster_nodes) <= local_tree_threshold:
            # Small cluster ➜ single parent
            parent_node = self._summarize_cluster(cluster_nodes, scope)
            self._create_parent_node(parent_node)
            self._link_cluster_nodes(parent_node, cluster_nodes)
        else:
            # Large cluster ➜ local sub-clustering
            sub_clusters = self._local_subcluster(cluster_nodes)
            sub_parents = []

            for sub_nodes in sub_clusters:
                if len(sub_nodes) < min_cluster_size:
                    continue  # Skip tiny noise
                sub_parent_node = self._summarize_cluster(sub_nodes, scope)
                self._create_parent_node(sub_parent_node)
                self._link_cluster_nodes(sub_parent_node, sub_nodes)
                sub_parents.append(sub_parent_node)

            if sub_parents:
                cluster_parent_node = self._summarize_cluster(cluster_nodes, scope)
                self._create_parent_node(cluster_parent_node)
                for sub_parent in sub_parents:
                    self.graph_store.add_edge(cluster_parent_node.id, sub_parent.id, "PARENT")

        logger.info("Adding relations/reasons")
        nodes_to_check = cluster_nodes
        exclude_ids = [n.id for n in nodes_to_check]

        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = []
            for node in nodes_to_check:
                futures.append(
                    executor.submit(
                        self.relation_detector.process_node,
                        node,
                        exclude_ids,
                        10,  # top_k
                    )
                )

            for f in as_completed(futures):
                results = f.result()

                # 1) Add pairwise relations
                for rel in results["relations"]:
                    if not self.graph_store.edge_exists(
                        rel["source_id"], rel["target_id"], rel["relation_type"]
                    ):
                        self.graph_store.add_edge(
                            rel["source_id"], rel["target_id"], rel["relation_type"]
                        )

                # 2) Add inferred nodes and link to sources
                for inf_node in results["inferred_nodes"]:
                    self.graph_store.add_node(
                        inf_node.id,
                        inf_node.memory,
                        inf_node.metadata.model_dump(exclude_none=True),
                    )
                    for src_id in inf_node.metadata.sources:
                        self.graph_store.add_edge(src_id, inf_node.id, "INFERS")

                # 3) Add sequence links
                for seq in results["sequence_links"]:
                    if not self.graph_store.edge_exists(seq["from_id"], seq["to_id"], "FOLLOWS"):
                        self.graph_store.add_edge(seq["from_id"], seq["to_id"], "FOLLOWS")

                # 4) Add aggregate concept nodes
                for agg_node in results["aggregate_nodes"]:
                    self.graph_store.add_node(
                        agg_node.id,
                        agg_node.memory,
                        agg_node.metadata.model_dump(exclude_none=True),
                    )
                    for child_id in agg_node.metadata.sources:
                        self.graph_store.add_edge(agg_node.id, child_id, "AGGREGATES")

        logger.info("[Reorganizer] Cluster relation/reasoning done.")
```
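`optimize_structure` fans work out with `ThreadPoolExecutor` and harvests it with `as_completed`, logging per-cluster failures instead of aborting the batch; `_process_cluster_and_write` reuses the same fan-out for relation detection (there, an exception from `f.result()` propagates up to the outer loop's handler). A self-contained sketch of the idiom with a toy task:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def process(cluster: list[str]) -> int:
    if not cluster:
        raise ValueError("empty cluster")  # toy failure mode
    return len(cluster)

clusters = [["a", "b", "c"], [], ["d"]]
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(process, c) for c in clusters]
    for f in as_completed(futures):
        try:
            print(f.result())
        except Exception as e:  # mirrors the diff: failures are logged, not raised
            print(f"cluster processing failed: {e}")
```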
```python
    def _local_subcluster(self, cluster_nodes: list[GraphDBNode]) -> list[list[GraphDBNode]]:
        """
        Use LLM to split a large cluster into semantically coherent sub-clusters.
        """
        if not cluster_nodes:
            return []

        # Prepare conversation-like input: ID + key + value
        scene_lines = []
        for node in cluster_nodes:
            line = f"- ID: {node.id} | Key: {node.metadata.key} | Value: {node.memory}"
            scene_lines.append(line)

        joined_scene = "\n".join(scene_lines)
        prompt = LOCAL_SUBCLUSTER_PROMPT.format(joined_scene=joined_scene)

        messages = [{"role": "user", "content": prompt}]
        response_text = self.llm.generate(messages)
        response_json = self._parse_json_result(response_text)
        assigned_ids = set()
        result_subclusters = []

        for cluster in response_json.get("clusters", []):
            ids = []
            for nid in cluster.get("ids", []):
                if nid not in assigned_ids:
                    ids.append(nid)
                    assigned_ids.add(nid)
            sub_nodes = [node for node in cluster_nodes if node.id in ids]
            if len(sub_nodes) >= 2:
                result_subclusters.append(sub_nodes)

        return result_subclusters
```
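The sub-cluster assignment is first-come-first-served: once the LLM response assigns an ID to one sub-cluster, later sub-clusters repeating that ID ignore it, and sub-clusters that end up with fewer than two members are dropped. A standalone sketch of the de-duplication step, with a toy `response_json` in place of real LLM output:

```python
# Toy LLM response: "b" appears in two sub-clusters; only the first wins.
response_json = {"clusters": [{"ids": ["a", "b"]}, {"ids": ["b", "c", "d"]}]}

assigned_ids: set[str] = set()
for cluster in response_json.get("clusters", []):
    ids = []
    for nid in cluster.get("ids", []):
        if nid not in assigned_ids:
            ids.append(nid)
            assigned_ids.add(nid)
    print(ids)  # ['a', 'b'] then ['c', 'd'] — "b" is not assigned twice
```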
```python
    def _partition(
        self, nodes: list[GraphDBNode], min_cluster_size: int = 3
    ) -> list[list[GraphDBNode]]:
        """
        Partition nodes by:
        1) Frequent tags (top N & above threshold)
        2) Remaining nodes by embedding clustering (MiniBatchKMeans)
        3) Small clusters merged or assigned to 'Other'

        Args:
            nodes: List of GraphDBNode
            min_cluster_size: Min size to keep a cluster as-is

        Returns:
            List of clusters, each as a list of GraphDBNode
        """
        from collections import Counter, defaultdict

        # 1) Count all tags
        tag_counter = Counter()
        for node in nodes:
            for tag in node.metadata.tags:
                tag_counter[tag] += 1

        # Select frequent tags
        top_n_tags = {tag for tag, count in tag_counter.most_common(50)}
        threshold_tags = {tag for tag, count in tag_counter.items() if count >= 50}
        frequent_tags = top_n_tags | threshold_tags

        # Group nodes by tags, ensure each group is unique internally
        tag_groups = defaultdict(list)

        for node in nodes:
            for tag in node.metadata.tags:
                if tag in frequent_tags:
                    tag_groups[tag].append(node)
                    break

        filtered_tag_clusters = []
        assigned_ids = set()
        for tag, group in tag_groups.items():
            if len(group) >= min_cluster_size:
                filtered_tag_clusters.append(group)
                assigned_ids.update(n.id for n in group)
            else:
                logger.info(f"... dropped {tag} ...")

        logger.info(
            f"[MixedPartition] Created {len(filtered_tag_clusters)} clusters from tags. "
            f"Nodes grouped by tags: {len(assigned_ids)} / {len(nodes)}"
        )

        # 5) Remaining nodes -> embedding clustering
        remaining_nodes = [n for n in nodes if n.id not in assigned_ids]
        logger.info(
            f"[MixedPartition] Remaining nodes for embedding clustering: {len(remaining_nodes)}"
        )

        embedding_clusters = []
        if remaining_nodes:
            x = np.array([n.metadata.embedding for n in remaining_nodes if n.metadata.embedding])
            k = max(1, min(len(remaining_nodes) // min_cluster_size, 20))
            if len(x) < k:
                k = len(x)

            if 1 < k <= len(x):
                kmeans = MiniBatchKMeans(n_clusters=k, batch_size=256, random_state=42)
                labels = kmeans.fit_predict(x)

                label_groups = defaultdict(list)
                for node, label in zip(remaining_nodes, labels, strict=False):
                    label_groups[label].append(node)

                embedding_clusters = list(label_groups.values())
                logger.info(
                    f"[MixedPartition] Created {len(embedding_clusters)} clusters from embedding."
                )
            else:
                embedding_clusters = [remaining_nodes]

        # Merge all & handle small clusters
        all_clusters = filtered_tag_clusters + embedding_clusters

        # Optional: merge tiny clusters
        final_clusters = []
        small_nodes = []
        for group in all_clusters:
            if len(group) < min_cluster_size:
                small_nodes.extend(group)
            else:
                final_clusters.append(group)

        if small_nodes:
            final_clusters.append(small_nodes)
            logger.info(f"[MixedPartition] {len(small_nodes)} nodes assigned to 'Other' cluster.")

        logger.info(f"[MixedPartition] Total final clusters: {len(final_clusters)}")
        return final_clusters
```
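Nodes that no frequent tag claims fall through to embedding clustering: `k` is derived from the group size (`len(remaining_nodes) // min_cluster_size`, capped at 20) and clamped to the number of vectors. Note that `x` keeps only nodes with non-empty embeddings while the later `zip(remaining_nodes, labels, strict=False)` pairs labels back against all remaining nodes, so a missing embedding would silently misalign the pairing. A self-contained sketch of the clustering step, with toy 2-D vectors standing in for real embeddings:

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans

# Toy embeddings: two obvious blobs, near the origin and near (5, 5).
embeddings = np.array(
    [[0.0, 0.1], [0.1, 0.0], [0.05, 0.05], [5.0, 5.1], [5.1, 4.9], [4.9, 5.0]]
)
min_cluster_size = 3
k = max(1, min(len(embeddings) // min_cluster_size, 20))  # k = 2 here

kmeans = MiniBatchKMeans(n_clusters=k, batch_size=256, random_state=42)
print(kmeans.fit_predict(embeddings))  # one label per blob
```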
```python
    def _summarize_cluster(self, cluster_nodes: list[GraphDBNode], scope: str) -> GraphDBNode:
        """
        Generate a cluster label using LLM, based on top keys in the cluster.
        """
        if not cluster_nodes:
            raise ValueError("Cluster nodes cannot be empty.")

        joined_keys = "\n".join(f"- {n.metadata.key}" for n in cluster_nodes if n.metadata.key)
        joined_values = "\n".join(f"- {n.memory}" for n in cluster_nodes)
        joined_backgrounds = "\n".join(
            f"- {n.metadata.background}" for n in cluster_nodes if n.metadata.background
        )

        # Build prompt
        prompt = REORGANIZE_PROMPT.format(
            joined_keys=joined_keys,
            joined_values=joined_values,
            joined_backgrounds=joined_backgrounds,
        )

        messages = [{"role": "user", "content": prompt}]
        response_text = self.llm.generate(messages)
        response_json = self._parse_json_result(response_text)

        # Extract fields
        parent_key = response_json.get("key", "").strip()
        parent_value = response_json.get("value", "").strip()
        parent_tags = response_json.get("tags", [])
        parent_background = response_json.get("background", "").strip()

        embedding = self.embedder.embed([parent_value])[0]

        parent_node = GraphDBNode(
            memory=parent_value,
            metadata=TreeNodeTextualMemoryMetadata(
                user_id="",  # TODO: summarized node: no user_id
                session_id="",  # TODO: summarized node: no session_id
                memory_type=scope,
                status="activated",
                key=parent_key,
                tags=parent_tags,
                embedding=embedding,
                usage=[],
                sources=[n.id for n in cluster_nodes],
                background=parent_background,
                confidence=0.99,
                type="topic",
            ),
        )
        return parent_node
```
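`_summarize_cluster` flattens the cluster's keys, memories, and backgrounds into bullet lists before substituting them into `REORGANIZE_PROMPT` (empty keys and backgrounds are filtered out). A standalone sketch of the assembly, with toy dicts standing in for `GraphDBNode` objects:

```python
# Toy stand-ins for GraphDBNode: only the fields the joins touch.
nodes = [
    {"key": "diet", "memory": "User is vegetarian.", "background": "stated in chat"},
    {"key": "", "memory": "User avoids dairy.", "background": ""},
]
joined_keys = "\n".join(f"- {n['key']}" for n in nodes if n["key"])
joined_values = "\n".join(f"- {n['memory']}" for n in nodes)
joined_backgrounds = "\n".join(f"- {n['background']}" for n in nodes if n["background"])
print(joined_values)
# - User is vegetarian.
# - User avoids dairy.
```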
```python
    def _parse_json_result(self, response_text):
        try:
            response_text = response_text.replace("```", "").replace("json", "")
            response_json = json.loads(response_text)
            return response_json
        except json.JSONDecodeError as e:
            logger.warning(
                f"Failed to parse LLM response as JSON: {e}\nRaw response:\n{response_text}"
            )
            return {}
```
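`_parse_json_result` strips markdown code fences by deleting every triple-backtick and every literal "json" from the response before calling `json.loads`. That handles the common fenced reply, but it also deletes "json" occurring inside values. A self-contained demonstration of both cases, replicating the same stripping logic:

```python
import json

def parse_json_result(response_text: str) -> dict:
    # Same stripping strategy as above: drop fences, then every literal "json".
    return json.loads(response_text.replace("```", "").replace("json", ""))

print(parse_json_result('```json\n{"key": "topic"}\n```'))   # {'key': 'topic'}
print(parse_json_result('{"note": "stored as json text"}'))  # {'note': 'stored as  text'}
```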
```python
    def _create_parent_node(self, parent_node: GraphDBNode) -> None:
        """
        Create a new parent node for the cluster.
        """
        self.graph_store.add_node(
            parent_node.id,
            parent_node.memory,
            parent_node.metadata.model_dump(exclude_none=True),
        )

    def _link_cluster_nodes(self, parent_node: GraphDBNode, child_nodes: list[GraphDBNode]):
        """
        Add PARENT edges from the parent node to all nodes in the cluster.
        """
        for child in child_nodes:
            if not self.graph_store.edge_exists(
                parent_node.id, child.id, "PARENT", direction="OUTGOING"
            ):
                self.graph_store.add_edge(parent_node.id, child.id, "PARENT")

    def _preprocess_message(self, message: QueueMessage) -> bool:
        message = self._convert_id_to_node(message)
        if None in message.after_node:
            logger.debug(
                f"Found non-existent node in after_node in message: {message}, skip this message."
            )
            return False
        return True

    def _convert_id_to_node(self, message: QueueMessage) -> QueueMessage:
        """
        Convert IDs in the message.after_node to GraphDBNode objects.
        """
        for i, node in enumerate(message.after_node or []):
            if not isinstance(node, str):
                continue
            raw_node = self.graph_store.get_node(node)
            if raw_node is None:
                logger.debug(f"Node with ID {node} not found in the graph store.")
                message.after_node[i] = None
            else:
                message.after_node[i] = GraphDBNode(**raw_node)
        return message
```