MemoryOS 0.2.0-py3-none-any.whl → 0.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (80)
  1. {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/METADATA +66 -26
  2. {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/RECORD +80 -56
  3. memoryos-0.2.1.dist-info/entry_points.txt +3 -0
  4. memos/__init__.py +1 -1
  5. memos/api/config.py +471 -0
  6. memos/api/exceptions.py +28 -0
  7. memos/api/mcp_serve.py +502 -0
  8. memos/api/product_api.py +35 -0
  9. memos/api/product_models.py +159 -0
  10. memos/api/routers/__init__.py +1 -0
  11. memos/api/routers/product_router.py +358 -0
  12. memos/chunkers/sentence_chunker.py +8 -2
  13. memos/cli.py +113 -0
  14. memos/configs/embedder.py +27 -0
  15. memos/configs/graph_db.py +83 -2
  16. memos/configs/llm.py +47 -0
  17. memos/configs/mem_cube.py +1 -1
  18. memos/configs/mem_scheduler.py +91 -5
  19. memos/configs/memory.py +5 -4
  20. memos/dependency.py +52 -0
  21. memos/embedders/ark.py +92 -0
  22. memos/embedders/factory.py +4 -0
  23. memos/embedders/sentence_transformer.py +8 -2
  24. memos/embedders/universal_api.py +32 -0
  25. memos/graph_dbs/base.py +2 -2
  26. memos/graph_dbs/factory.py +2 -0
  27. memos/graph_dbs/neo4j.py +331 -122
  28. memos/graph_dbs/neo4j_community.py +300 -0
  29. memos/llms/base.py +9 -0
  30. memos/llms/deepseek.py +54 -0
  31. memos/llms/factory.py +10 -1
  32. memos/llms/hf.py +170 -13
  33. memos/llms/hf_singleton.py +114 -0
  34. memos/llms/ollama.py +4 -0
  35. memos/llms/openai.py +67 -1
  36. memos/llms/qwen.py +63 -0
  37. memos/llms/vllm.py +153 -0
  38. memos/mem_cube/general.py +77 -16
  39. memos/mem_cube/utils.py +102 -0
  40. memos/mem_os/core.py +131 -41
  41. memos/mem_os/main.py +93 -11
  42. memos/mem_os/product.py +1098 -35
  43. memos/mem_os/utils/default_config.py +352 -0
  44. memos/mem_os/utils/format_utils.py +1154 -0
  45. memos/mem_reader/simple_struct.py +5 -5
  46. memos/mem_scheduler/base_scheduler.py +467 -36
  47. memos/mem_scheduler/general_scheduler.py +125 -244
  48. memos/mem_scheduler/modules/base.py +9 -0
  49. memos/mem_scheduler/modules/dispatcher.py +68 -2
  50. memos/mem_scheduler/modules/misc.py +39 -0
  51. memos/mem_scheduler/modules/monitor.py +228 -49
  52. memos/mem_scheduler/modules/rabbitmq_service.py +317 -0
  53. memos/mem_scheduler/modules/redis_service.py +32 -22
  54. memos/mem_scheduler/modules/retriever.py +250 -23
  55. memos/mem_scheduler/modules/schemas.py +189 -7
  56. memos/mem_scheduler/mos_for_test_scheduler.py +143 -0
  57. memos/mem_scheduler/utils.py +51 -2
  58. memos/mem_user/persistent_user_manager.py +260 -0
  59. memos/memories/activation/item.py +25 -0
  60. memos/memories/activation/kv.py +10 -3
  61. memos/memories/activation/vllmkv.py +219 -0
  62. memos/memories/factory.py +2 -0
  63. memos/memories/textual/general.py +7 -5
  64. memos/memories/textual/tree.py +9 -5
  65. memos/memories/textual/tree_text_memory/organize/conflict.py +5 -3
  66. memos/memories/textual/tree_text_memory/organize/manager.py +26 -18
  67. memos/memories/textual/tree_text_memory/organize/redundancy.py +25 -44
  68. memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +11 -13
  69. memos/memories/textual/tree_text_memory/organize/reorganizer.py +73 -51
  70. memos/memories/textual/tree_text_memory/retrieve/recall.py +0 -1
  71. memos/memories/textual/tree_text_memory/retrieve/reranker.py +2 -2
  72. memos/memories/textual/tree_text_memory/retrieve/searcher.py +6 -5
  73. memos/parsers/markitdown.py +8 -2
  74. memos/templates/mem_reader_prompts.py +65 -23
  75. memos/templates/mem_scheduler_prompts.py +96 -47
  76. memos/templates/tree_reorganize_prompts.py +85 -30
  77. memos/vec_dbs/base.py +12 -0
  78. memos/vec_dbs/qdrant.py +46 -20
  79. {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/LICENSE +0 -0
  80. {memoryos-0.2.0.dist-info → memoryos-0.2.1.dist-info}/WHEEL +0 -0
memos/mem_os/utils/format_utils.py
@@ -0,0 +1,1154 @@
+ import math
+ import random
+
+ from typing import Any
+
+ from memos.log import get_logger
+ from memos.memories.activation.item import KVCacheItem
+
+
+ logger = get_logger(__name__)
+
+
+ def extract_node_name(memory: str) -> str:
+     """Extract the first two words from memory as node_name."""
+     if not memory:
+         return ""
+
+     words = [word.strip() for word in memory.split() if word.strip()]
+
+     if len(words) >= 2:
+         return " ".join(words[:2])
+     elif len(words) == 1:
+         return words[0]
+     else:
+         return ""
+
+
+ def analyze_tree_structure_enhanced(nodes: list[dict], edges: list[dict]) -> dict:
+     """Enhanced tree structure analysis, focusing on branching degree and leaf distribution."""
+     # Build adjacency list
+     adj_list = {}
+     reverse_adj = {}
+     for edge in edges:
+         source, target = edge["source"], edge["target"]
+         adj_list.setdefault(source, []).append(target)
+         reverse_adj.setdefault(target, []).append(source)
+
+     # Find all nodes and root nodes
+     all_nodes = {node["id"] for node in nodes}
+     target_nodes = {edge["target"] for edge in edges}
+     root_nodes = all_nodes - target_nodes
+
+     subtree_analysis = {}
+
+     def analyze_subtree_enhanced(root_id: str) -> dict:
+         """Enhanced subtree analysis, focusing on branching degree and structure quality."""
+         visited = set()
+         max_depth = 0
+         leaf_count = 0
+         total_nodes = 0
+         branch_nodes = 0  # Number of branch nodes with multiple children
+         chain_length = 0  # Longest single chain length
+         width_per_level = {}  # Width per level
+
+         def dfs(node_id: str, depth: int, chain_len: int):
+             nonlocal max_depth, leaf_count, total_nodes, branch_nodes, chain_length
+
+             if node_id in visited:
+                 return
+
+             visited.add(node_id)
+             total_nodes += 1
+             max_depth = max(max_depth, depth)
+             chain_length = max(chain_length, chain_len)
+
+             # Record number of nodes per level
+             width_per_level[depth] = width_per_level.get(depth, 0) + 1
+
+             children = adj_list.get(node_id, [])
+
+             if not children:  # Leaf node
+                 leaf_count += 1
+             elif len(children) > 1:  # Branch node
+                 branch_nodes += 1
+                 # Reset chain length because we encountered a branch
+                 for child in children:
+                     dfs(child, depth + 1, 0)
+             else:  # Single child node (chain structure)
+                 for child in children:
+                     dfs(child, depth + 1, chain_len + 1)
+
+         dfs(root_id, 0, 0)
+
+         # Calculate structure quality metrics
+         avg_width = sum(width_per_level.values()) / len(width_per_level) if width_per_level else 0
+         max_width = max(width_per_level.values()) if width_per_level else 0
+
+         # Calculate branch density: ratio of branch nodes to total nodes
+         branch_density = branch_nodes / total_nodes if total_nodes > 0 else 0
+
+         # Calculate depth-width ratio: ideal tree should have moderate depth and good breadth
+         depth_width_ratio = max_depth / max_width if max_width > 0 else max_depth
+
+         quality_score = calculate_enhanced_quality(
+             max_depth,
+             leaf_count,
+             total_nodes,
+             branch_nodes,
+             chain_length,
+             branch_density,
+             depth_width_ratio,
+             max_width,
+         )
+
+         return {
+             "root_id": root_id,
+             "max_depth": max_depth,
+             "leaf_count": leaf_count,
+             "total_nodes": total_nodes,
+             "branch_nodes": branch_nodes,
+             "max_chain_length": chain_length,
+             "branch_density": branch_density,
+             "max_width": max_width,
+             "avg_width": avg_width,
+             "depth_width_ratio": depth_width_ratio,
+             "nodes_in_subtree": list(visited),
+             "quality_score": quality_score,
+             "width_per_level": width_per_level,
+         }
+
+     for root_id in root_nodes:
+         subtree_analysis[root_id] = analyze_subtree_enhanced(root_id)
+
+     return subtree_analysis
+
+
+ def calculate_enhanced_quality(
+     max_depth: int,
+     leaf_count: int,
+     total_nodes: int,
+     branch_nodes: int,
+     max_chain_length: int,
+     branch_density: float,
+     depth_width_ratio: float,
+     max_width: int,
+ ) -> float:
+     """Enhanced quality calculation, prioritizing branching degree and leaf distribution."""
+
+     if total_nodes <= 1:
+         return 0.1
+
+     # 1. Branch quality score (weight: 35%)
+     # Branch node count score
+     branch_count_score = min(branch_nodes * 3, 15)  # 3 points per branch node, max 15 points
+
+     # Branch density score: ideal density between 20%-60%
+     if 0.2 <= branch_density <= 0.6:
+         branch_density_score = 10
+     elif branch_density > 0.6:
+         branch_density_score = max(5, 10 - (branch_density - 0.6) * 20)
+     else:
+         branch_density_score = branch_density * 25  # Linear growth for 0-20%
+
+     branch_score = (branch_count_score + branch_density_score) * 0.35
+
+     # 2. Leaf quality score (weight: 25%)
+     # Leaf count score
+     leaf_count_score = min(leaf_count * 2, 20)
+
+     # Leaf distribution score: ideal leaf ratio is 30%-70% of total nodes
+     leaf_ratio = leaf_count / total_nodes
+     if 0.3 <= leaf_ratio <= 0.7:
+         leaf_ratio_score = 10
+     elif leaf_ratio > 0.7:
+         leaf_ratio_score = max(3, 10 - (leaf_ratio - 0.7) * 20)
+     else:
+         leaf_ratio_score = leaf_ratio * 20  # Linear growth for 0-30%
+
+     leaf_score = (leaf_count_score + leaf_ratio_score) * 0.25
+
+     # 3. Structure balance score (weight: 25%)
+     # Depth score: moderate depth is best (3-8 layers)
+     if 3 <= max_depth <= 8:
+         depth_score = 15
+     elif max_depth < 3:
+         depth_score = max_depth * 3  # Lower score for 1-2 layers
+     else:
+         depth_score = max(5, 15 - (max_depth - 8) * 1.5)  # Gradually reduce score beyond 8 layers
+
+     # Width score: larger max width is better, but with an upper limit
+     width_score = min(max_width * 1.5, 15)
+
+     # Depth-width ratio penalty: a ratio that is too large means the tree is too "thin"
+     if depth_width_ratio > 3:
+         ratio_penalty = (depth_width_ratio - 3) * 2
+         structure_score = max(0, (depth_score + width_score - ratio_penalty)) * 0.25
+     else:
+         structure_score = (depth_score + width_score) * 0.25
+
+     # 4. Chain structure penalty (weight: 15%)
+     # Longest single chain penalty: overly long chains severely affect display
+     if max_chain_length <= 2:
+         chain_penalty_score = 10
+     elif max_chain_length <= 5:
+         chain_penalty_score = 8 - (max_chain_length - 2)
+     else:
+         chain_penalty_score = max(0, 3 - (max_chain_length - 5) * 0.5)
+
+     chain_score = chain_penalty_score * 0.15
+
+     # 5. Comprehensive calculation
+     total_score = branch_score + leaf_score + structure_score + chain_score
+
+     # Severe penalties for special cases
+     if max_chain_length > total_nodes * 0.8:  # More than 80% of nodes form a single chain
+         total_score *= 0.3
+     elif branch_density < 0.1 and total_nodes > 5:  # Large tree with almost no branches
+         total_score *= 0.5
+
+     return total_score
+
+
+ def sample_nodes_with_type_balance(
+     nodes: list[dict],
+     edges: list[dict],
+     target_count: int = 150,
+     type_ratios: dict[str, float] | None = None,
+ ) -> tuple[list[dict], list[dict]]:
+     """
+     Balanced sampling based on type ratios and tree quality.
+
+     Args:
+         nodes: List of nodes
+         edges: List of edges
+         target_count: Target number of nodes
+         type_ratios: Expected ratio for each type, e.g. {'WorkingMemory': 0.15, 'EpisodicMemory': 0.30, ...}
+
+     Returns:
+         Tuple of (sampled nodes, edges whose endpoints both survived sampling)
+     """
+     if len(nodes) <= target_count:
+         return nodes, edges
+
+     # Default type ratio configuration
+     if type_ratios is None:
+         type_ratios = {
+             "WorkingMemory": 0.10,  # 10%
+             "EpisodicMemory": 0.25,  # 25%
+             "SemanticMemory": 0.25,  # 25%
+             "ProceduralMemory": 0.20,  # 20%
+             "EmotionalMemory": 0.15,  # 15%
+             "MetaMemory": 0.05,  # 5%
+         }
+
+     print(
+         f"Starting type-balanced sampling, original nodes: {len(nodes)}, target nodes: {target_count}"
+     )
+     print(f"Target type ratios: {type_ratios}")
+
+     # Analyze current node type distribution
+     current_type_counts = {}
+     nodes_by_type = {}
+
+     for node in nodes:
+         memory_type = node.get("metadata", {}).get("memory_type", "Unknown")
+         current_type_counts[memory_type] = current_type_counts.get(memory_type, 0) + 1
+         if memory_type not in nodes_by_type:
+             nodes_by_type[memory_type] = []
+         nodes_by_type[memory_type].append(node)
+
+     print(f"Current type distribution: {current_type_counts}")
+
+     # Calculate the target node count for each type
+     type_targets = {}
+     remaining_target = target_count
+
+     for memory_type, ratio in type_ratios.items():
+         if memory_type in nodes_by_type:
+             target_for_type = int(target_count * ratio)
+             # Ensure we do not exceed the actual node count for this type
+             target_for_type = min(target_for_type, len(nodes_by_type[memory_type]))
+             type_targets[memory_type] = target_for_type
+             remaining_target -= target_for_type
+
+     # Handle types not in the ratio configuration
+     other_types = set(nodes_by_type.keys()) - set(type_ratios.keys())
+     if other_types and remaining_target > 0:
+         per_other_type = max(1, remaining_target // len(other_types))
+         for memory_type in other_types:
+             allocation = min(per_other_type, len(nodes_by_type[memory_type]))
+             type_targets[memory_type] = allocation
+             remaining_target -= allocation
+
+     # If there is still quota remaining, distribute it proportionally to the main types
+     if remaining_target > 0:
+         main_types = [t for t in type_ratios if t in nodes_by_type]
+         if main_types:
+             extra_per_type = remaining_target // len(main_types)
+             for memory_type in main_types:
+                 additional = min(
+                     extra_per_type,
+                     len(nodes_by_type[memory_type]) - type_targets.get(memory_type, 0),
+                 )
+                 type_targets[memory_type] = type_targets.get(memory_type, 0) + additional
+
+     print(f"Target node count for each type: {type_targets}")
+
+     # Perform subtree-quality sampling for each type
+     selected_nodes = []
+
+     for memory_type, target_for_type in type_targets.items():
+         if target_for_type <= 0 or memory_type not in nodes_by_type:
+             continue
+
+         type_nodes = nodes_by_type[memory_type]
+         print(f"\n--- Processing {memory_type} type: {len(type_nodes)} -> {target_for_type} ---")
+
+         if len(type_nodes) <= target_for_type:
+             selected_nodes.extend(type_nodes)
+             print(f"  Select all: {len(type_nodes)} nodes")
+         else:
+             # Use enhanced subtree quality sampling
+             type_selected = sample_by_enhanced_subtree_quality(type_nodes, edges, target_for_type)
+             selected_nodes.extend(type_selected)
+             print(f"  Sampled selection: {len(type_selected)} nodes")
+
+     # Filter edges
+     selected_node_ids = {node["id"] for node in selected_nodes}
+     filtered_edges = [
+         edge
+         for edge in edges
+         if edge["source"] in selected_node_ids and edge["target"] in selected_node_ids
+     ]
+
+     print(f"\nFinal selected nodes: {len(selected_nodes)}")
+     print(f"Final edges: {len(filtered_edges)}")
+
+     # Verify final type distribution
+     final_type_counts = {}
+     for node in selected_nodes:
+         memory_type = node.get("metadata", {}).get("memory_type", "Unknown")
+         final_type_counts[memory_type] = final_type_counts.get(memory_type, 0) + 1
+
+     print(f"Final type distribution: {final_type_counts}")
+     for memory_type, count in final_type_counts.items():
+         percentage = count / len(selected_nodes) * 100
+         target_percentage = type_ratios.get(memory_type, 0) * 100
+         print(
+             f"  {memory_type}: {count} nodes ({percentage:.1f}%, target: {target_percentage:.1f}%)"
+         )
+
+     return selected_nodes, filtered_edges
+
+
+ def sample_by_enhanced_subtree_quality(
+     nodes: list[dict], edges: list[dict], target_count: int
+ ) -> list[dict]:
+     """Sample using enhanced subtree quality."""
+     if len(nodes) <= target_count:
+         return nodes
+
+     # Analyze subtree structure
+     subtree_analysis = analyze_tree_structure_enhanced(nodes, edges)
+
+     if not subtree_analysis:
+         # If no subtree structure, sample by node importance
+         return sample_nodes_by_importance(nodes, edges, target_count)
+
+     # Sort subtrees by quality score
+     sorted_subtrees = sorted(
+         subtree_analysis.items(), key=lambda x: x[1]["quality_score"], reverse=True
+     )
+
+     print("  Subtree quality ranking:")
+     for i, (root_id, analysis) in enumerate(sorted_subtrees[:5]):
+         print(
+             f"    #{i + 1} Root node {root_id}: Quality={analysis['quality_score']:.2f}, "
+             f"Depth={analysis['max_depth']}, Branches={analysis['branch_nodes']}, "
+             f"Leaves={analysis['leaf_count']}, Max Width={analysis['max_width']}"
+         )
+
+     # Greedy selection of high-quality subtrees
+     selected_nodes = []
+     selected_node_ids = set()
+
+     for root_id, analysis in sorted_subtrees:
+         subtree_nodes = analysis["nodes_in_subtree"]
+         new_nodes = [node_id for node_id in subtree_nodes if node_id not in selected_node_ids]
+
+         if not new_nodes:
+             continue
+
+         remaining_quota = target_count - len(selected_nodes)
+
+         if len(new_nodes) <= remaining_quota:
+             # Entire subtree can be added
+             for node_id in new_nodes:
+                 node = next((n for n in nodes if n["id"] == node_id), None)
+                 if node:
+                     selected_nodes.append(node)
+                     selected_node_ids.add(node_id)
+             print(f"  Select entire subtree {root_id}: +{len(new_nodes)} nodes")
+         else:
+             # Subtree too large, need partial selection
+             if analysis["quality_score"] > 5:  # Only partial selection for high-quality subtrees
+                 subtree_node_objects = [n for n in nodes if n["id"] in new_nodes]
+                 partial_selection = select_best_nodes_from_subtree(
+                     subtree_node_objects, edges, remaining_quota, root_id
+                 )
+
+                 selected_nodes.extend(partial_selection)
+                 for node in partial_selection:
+                     selected_node_ids.add(node["id"])
+                 print(
+                     f"  Partial selection of subtree {root_id}: +{len(partial_selection)} nodes"
+                 )
+
+         if len(selected_nodes) >= target_count:
+             break
+
+     # If target count not reached, supplement with remaining nodes
+     if len(selected_nodes) < target_count:
+         remaining_nodes = [n for n in nodes if n["id"] not in selected_node_ids]
+         remaining_count = target_count - len(selected_nodes)
+         additional = sample_nodes_by_importance(remaining_nodes, edges, remaining_count)
+         selected_nodes.extend(additional)
+         print(f"  Supplementary selection: +{len(additional)} nodes")
+
+     return selected_nodes
+
+
+ def select_best_nodes_from_subtree(
+     subtree_nodes: list[dict], edges: list[dict], max_count: int, root_id: str
+ ) -> list[dict]:
+     """Select the most important nodes from a subtree, prioritizing branch structure."""
+     if len(subtree_nodes) <= max_count:
+         return subtree_nodes
+
+     # Build the internal connection relationships of the subtree
+     subtree_node_ids = {node["id"] for node in subtree_nodes}
+     subtree_edges = [
+         edge
+         for edge in edges
+         if edge["source"] in subtree_node_ids and edge["target"] in subtree_node_ids
+     ]
+
+     # Calculate an importance score for each node
+     node_scores = []
+
+     for node in subtree_nodes:
+         node_id = node["id"]
+
+         # Out-degree and in-degree
+         out_degree = sum(1 for edge in subtree_edges if edge["source"] == node_id)
+         in_degree = sum(1 for edge in subtree_edges if edge["target"] == node_id)
+
+         # Content length score
+         content_score = min(len(node.get("memory", "")), 300) / 15
+
+         # Branch node bonus
+         branch_bonus = out_degree * 8 if out_degree > 1 else 0
+
+         # Root node bonus
+         root_bonus = 15 if node_id == root_id else 0
+
+         # Connection importance
+         connection_score = (out_degree + in_degree) * 3
+
+         # Moderate bonus for leaf nodes (ensures a certain number of leaves survive)
+         leaf_bonus = 5 if out_degree == 0 and in_degree > 0 else 0
+
+         total_score = content_score + connection_score + branch_bonus + root_bonus + leaf_bonus
+         node_scores.append((node, total_score))
+
+     # Sort by score and select
+     node_scores.sort(key=lambda x: x[1], reverse=True)
+     selected = [node for node, _ in node_scores[:max_count]]
+
+     return selected
+
+
+ def sample_nodes_by_importance(
+     nodes: list[dict], edges: list[dict], target_count: int
+ ) -> list[dict]:
+     """Sample by node importance (for cases without tree structure)."""
+     if len(nodes) <= target_count:
+         return nodes
+
+     node_scores = []
+
+     for node in nodes:
+         node_id = node["id"]
+         out_degree = sum(1 for edge in edges if edge["source"] == node_id)
+         in_degree = sum(1 for edge in edges if edge["target"] == node_id)
+         content_score = min(len(node.get("memory", "")), 200) / 10
+         connection_score = (out_degree + in_degree) * 5
+         random_score = random.random() * 10
+
+         total_score = content_score + connection_score + random_score
+         node_scores.append((node, total_score))
+
+     node_scores.sort(key=lambda x: x[1], reverse=True)
+     return [node for node, _ in node_scores[:target_count]]
+
+
+ # Modified main function to use the new sampling strategy
+ def convert_graph_to_tree_forworkmem(
+     json_data: dict[str, Any],
+     target_node_count: int = 150,
+     type_ratios: dict[str, float] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, int]]:
+     """
+     Enhanced graph-to-tree conversion, prioritizing branching degree and type balance.
+
+     Returns the root of the converted tree together with a count of nodes per memory type.
+     """
+     original_nodes = json_data.get("nodes", [])
+     original_edges = json_data.get("edges", [])
+
+     print(f"Original node count: {len(original_nodes)}")
+     print(f"Target node count: {target_node_count}")
+
+     filter_original_edges = []
+     for original_edge in original_edges:
+         if original_edge["type"] == "PARENT":
+             filter_original_edges.append(original_edge)
+
+     node_type_count = {}
+     for node in original_nodes:
+         node_type = node.get("metadata", {}).get("memory_type", "Unknown")
+         node_type_count[node_type] = node_type_count.get(node_type, 0) + 1
+     original_edges = filter_original_edges
+
+     # Use enhanced type-balanced sampling
+     if len(original_nodes) > target_node_count:
+         nodes, edges = sample_nodes_with_type_balance(
+             original_nodes, original_edges, target_node_count, type_ratios
+         )
+     else:
+         nodes, edges = original_nodes, original_edges
+
+     # Create node mapping table
+     node_map = {}
+     for node in nodes:
+         memory = node.get("memory", "")
+         node_name = extract_node_name(memory)
+         memory_key = node.get("metadata", {}).get("key", node_name)
+         usage = node.get("metadata", {}).get("usage", [])
+         frequency = len(usage)
+         node_map[node["id"]] = {
+             "id": node["id"],
+             "value": memory,
+             "frequency": frequency,
+             "node_name": memory_key,
+             "memory_type": node.get("metadata", {}).get("memory_type", "Unknown"),
+             "children": [],
+         }
+
+     # Build parent-child relationship mapping
+     children_map = {}
+     parent_map = {}
+
+     for edge in edges:
+         source = edge["source"]
+         target = edge["target"]
+         if source not in children_map:
+             children_map[source] = []
+         children_map[source].append(target)
+         parent_map[target] = source
+
+     # Find root nodes
+     all_node_ids = set(node_map.keys())
+     children_node_ids = set(parent_map.keys())
+     root_node_ids = all_node_ids - children_node_ids
+
+     # Separate WorkingMemory roots from other root nodes
+     working_memory_roots = []
+     other_roots = []
+
+     for root_id in root_node_ids:
+         if node_map[root_id]["memory_type"] == "WorkingMemory":
+             working_memory_roots.append(root_id)
+         else:
+             other_roots.append(root_id)
+
+     def build_tree(node_id: str) -> dict[str, Any] | None:
+         """Recursively build the tree structure."""
+         if node_id not in node_map:
+             return None
+
+         children_ids = children_map.get(node_id, [])
+         children = []
+         for child_id in children_ids:
+             child_tree = build_tree(child_id)
+             if child_tree:
+                 children.append(child_tree)
+
+         node = {
+             "id": node_id,
+             "node_name": node_map[node_id]["node_name"],
+             "value": node_map[node_id]["value"],
+             "memory_type": node_map[node_id]["memory_type"],
+             "frequency": node_map[node_id]["frequency"],
+         }
+
+         if children:
+             node["children"] = children
+
+         return node
+
+     # Build root tree list
+     root_trees = []
+     for root_id in other_roots:
+         tree = build_tree(root_id)
+         if tree:
+             root_trees.append(tree)
+
+     # Handle WorkingMemory
+     if working_memory_roots:
+         working_memory_children = []
+         for wm_root_id in working_memory_roots:
+             tree = build_tree(wm_root_id)
+             if tree:
+                 working_memory_children.append(tree)
+
+         working_memory_node = {
+             "id": "WorkingMemory",
+             "node_name": "WorkingMemory",
+             "value": "WorkingMemory",
+             "memory_type": "WorkingMemory",
+             "children": working_memory_children,
+             "frequency": 0,
+         }
+
+         root_trees.append(working_memory_node)
+
+     # Create the overall root node
+     result = {
+         "id": "root",
+         "node_name": "root",
+         "value": "root",
+         "memory_type": "Root",
+         "children": root_trees,
+         "frequency": 0,
+     }
+
+     return result, node_type_count
+
+
+ def print_tree_structure(node: dict[str, Any], level: int = 0, max_level: int = 5):
+     """Print the first few layers of the tree structure for easy viewing."""
+     if level > max_level:
+         return
+
+     indent = "  " * level
+     node_id = node.get("id", "unknown")
+     node_name = node.get("node_name", "")
+     node_value = node.get("value", "")
+     memory_type = node.get("memory_type", "Unknown")
+
+     # Determine display method based on whether there are children
+     children = node.get("children", [])
+     if children:
+         # Intermediate node: display name, type and child count
+         print(f"{indent}- {node_name} [{memory_type}] ({len(children)} children)")
+         print(f"{indent}  ID: {node_id}")
+         display_value = node_value[:80] + "..." if len(node_value) > 80 else node_value
+         print(f"{indent}  Value: {display_value}")
+
+         if level < max_level:
+             for child in children:
+                 print_tree_structure(child, level + 1, max_level)
+         elif level == max_level:
+             print(f"{indent}  ... (expansion limited)")
+     else:
+         # Leaf node: display name, type and value
+         display_value = node_value[:80] + "..." if len(node_value) > 80 else node_value
+         print(f"{indent}- {node_name} [{memory_type}]: {display_value}")
+         print(f"{indent}  ID: {node_id}")
+
+
+ def analyze_final_tree_quality(tree_data: dict[str, Any]) -> dict:
+     """Analyze final tree quality, including type diversity, branch structure, etc."""
+     stats = {
+         "total_nodes": 0,
+         "by_type": {},
+         "by_depth": {},
+         "max_depth": 0,
+         "total_leaves": 0,
+         "total_branches": 0,  # Number of branch nodes with multiple children
+         "subtrees": [],
+         "type_diversity": {},
+         "structure_quality": {},
+         "chain_analysis": {},  # Single-chain structure analysis
+     }
+
+     def analyze_subtree(node, depth=0, parent_path="", chain_length=0):
+         stats["total_nodes"] += 1
+         stats["max_depth"] = max(stats["max_depth"], depth)
+
+         # Count by type
+         memory_type = node.get("memory_type", "Unknown")
+         stats["by_type"][memory_type] = stats["by_type"].get(memory_type, 0) + 1
+
+         # Count by depth
+         stats["by_depth"][depth] = stats["by_depth"].get(depth, 0) + 1
+
+         children = node.get("children", [])
+         current_path = (
+             f"{parent_path}/{node.get('node_name', 'unknown')}"
+             if parent_path
+             else node.get("node_name", "root")
+         )
+
+         # Analyze node type
+         if not children:  # Leaf node
+             stats["total_leaves"] += 1
+             # Record chain length
+             if "max_chain_length" not in stats["chain_analysis"]:
+                 stats["chain_analysis"]["max_chain_length"] = 0
+             stats["chain_analysis"]["max_chain_length"] = max(
+                 stats["chain_analysis"]["max_chain_length"], chain_length
+             )
+         elif len(children) == 1:  # Single child node (chain)
+             # Continue accumulating chain length
+             for child in children:
+                 analyze_subtree(child, depth + 1, current_path, chain_length + 1)
+             return  # Early return to avoid duplicate processing
+         else:  # Branch node (multiple children)
+             stats["total_branches"] += 1
+             # Reset chain length
+             chain_length = 0
+
+         # If this is the root of a major subtree, analyze its characteristics
+         if depth <= 2 and children:  # Major subtree
+             subtree_depth = 0
+             subtree_leaves = 0
+             subtree_nodes = 0
+             subtree_branches = 0
+             subtree_types = {}
+             subtree_max_width = 0
+             width_per_level = {}
+
+             def count_subtree(subnode, subdepth):
+                 nonlocal \
+                     subtree_depth, \
+                     subtree_leaves, \
+                     subtree_nodes, \
+                     subtree_branches, \
+                     subtree_max_width
+                 subtree_nodes += 1
+                 subtree_depth = max(subtree_depth, subdepth)
+
+                 # Count type distribution within the subtree
+                 sub_memory_type = subnode.get("memory_type", "Unknown")
+                 subtree_types[sub_memory_type] = subtree_types.get(sub_memory_type, 0) + 1
+
+                 # Count width per level
+                 width_per_level[subdepth] = width_per_level.get(subdepth, 0) + 1
+                 subtree_max_width = max(subtree_max_width, width_per_level[subdepth])
+
+                 subchildren = subnode.get("children", [])
+                 if not subchildren:
+                     subtree_leaves += 1
+                 elif len(subchildren) > 1:
+                     subtree_branches += 1
+
+                 for child in subchildren:
+                     count_subtree(child, subdepth + 1)
+
+             count_subtree(node, 0)
+
+             # Calculate subtree quality metrics
+             branch_density = subtree_branches / subtree_nodes if subtree_nodes > 0 else 0
+             leaf_ratio = subtree_leaves / subtree_nodes if subtree_nodes > 0 else 0
+             depth_width_ratio = (
+                 subtree_depth / subtree_max_width if subtree_max_width > 0 else subtree_depth
+             )
+
+             stats["subtrees"].append(
+                 {
+                     "root": node.get("node_name", "unknown"),
+                     "type": memory_type,
+                     "depth": subtree_depth,
+                     "leaves": subtree_leaves,
+                     "nodes": subtree_nodes,
+                     "branches": subtree_branches,
+                     "branch_density": branch_density,
+                     "leaf_ratio": leaf_ratio,
+                     "max_width": subtree_max_width,
+                     "depth_width_ratio": depth_width_ratio,
+                     "path": current_path,
+                     "type_distribution": subtree_types,
+                     "quality_score": calculate_enhanced_quality(
+                         subtree_depth,
+                         subtree_leaves,
+                         subtree_nodes,
+                         subtree_branches,
+                         0,
+                         branch_density,
+                         depth_width_ratio,
+                         subtree_max_width,
+                     ),
+                 }
+             )
+
+         # Recursively analyze child nodes
+         for child in children:
+             analyze_subtree(child, depth + 1, current_path, 0)  # Reset chain length
+
+     analyze_subtree(tree_data)
+
+     # Calculate overall structure quality
+     if stats["total_nodes"] > 1:
+         branch_density = stats["total_branches"] / stats["total_nodes"]
+         leaf_ratio = stats["total_leaves"] / stats["total_nodes"]
+
+         # Calculate average width per level
+         total_width = sum(stats["by_depth"].values())
+         avg_width = total_width / len(stats["by_depth"]) if stats["by_depth"] else 0
+         max_width = max(stats["by_depth"].values()) if stats["by_depth"] else 0
+
+         stats["structure_quality"] = {
+             "branch_density": branch_density,
+             "leaf_ratio": leaf_ratio,
+             "avg_width": avg_width,
+             "max_width": max_width,
+             "depth_width_ratio": stats["max_depth"] / max_width
+             if max_width > 0
+             else stats["max_depth"],
+             "is_well_balanced": 0.2 <= branch_density <= 0.6 and 0.3 <= leaf_ratio <= 0.7,
+         }
+
+     # Calculate type diversity metrics
+     total_types = len(stats["by_type"])
+     if total_types > 1:
+         # Calculate uniformity of the type distribution (Shannon diversity index)
+         shannon_diversity = 0
+         for count in stats["by_type"].values():
+             if count > 0:
+                 p = count / stats["total_nodes"]
+                 shannon_diversity -= p * math.log2(p)
+
+         # Normalize the diversity index to the 0-1 range
+         max_diversity = math.log2(total_types) if total_types > 1 else 0
+         normalized_diversity = shannon_diversity / max_diversity if max_diversity > 0 else 0
+
+         stats["type_diversity"] = {
+             "total_types": total_types,
+             "shannon_diversity": shannon_diversity,
+             "normalized_diversity": normalized_diversity,
+             "distribution_balance": min(stats["by_type"].values()) / max(stats["by_type"].values())
+             if max(stats["by_type"].values()) > 0
+             else 0,
+         }
+
+     # Single-chain analysis
+     total_single_child_nodes = sum(
+         1 for subtree in stats["subtrees"] if subtree.get("branch_density", 0) < 0.1
+     )
+     stats["chain_analysis"].update(
+         {
+             "single_chain_subtrees": total_single_child_nodes,
+             "chain_subtree_ratio": total_single_child_nodes / len(stats["subtrees"])
+             if stats["subtrees"]
+             else 0,
+         }
+     )
+
+     return stats
+
+
+ def print_tree_analysis(tree_data: dict[str, Any]):
+     """Print enhanced tree analysis results."""
+     stats = analyze_final_tree_quality(tree_data)
+
+     print("\n" + "=" * 60)
+     print("🌳 Enhanced Tree Structure Quality Analysis Report")
+     print("=" * 60)
+
+     # Basic statistics
+     print("\n📊 Basic Statistics:")
+     print(f"  Total nodes: {stats['total_nodes']}")
+     print(f"  Max depth: {stats['max_depth']}")
+     print(
+         f"  Leaf nodes: {stats['total_leaves']} ({stats['total_leaves'] / stats['total_nodes'] * 100:.1f}%)"
+     )
+     print(
+         f"  Branch nodes: {stats['total_branches']} ({stats['total_branches'] / stats['total_nodes'] * 100:.1f}%)"
+     )
+
+     # Structure quality assessment
+     structure = stats.get("structure_quality", {})
+     if structure:
+         print("\n🏗️ Structure Quality Assessment:")
+         print(
+             f"  Branch density: {structure['branch_density']:.3f} ({'✅ Good' if 0.2 <= structure['branch_density'] <= 0.6 else '⚠️ Needs improvement'})"
+         )
+         print(
+             f"  Leaf ratio: {structure['leaf_ratio']:.3f} ({'✅ Good' if 0.3 <= structure['leaf_ratio'] <= 0.7 else '⚠️ Needs improvement'})"
+         )
+         print(f"  Max width: {structure['max_width']}")
+         print(
+             f"  Depth-width ratio: {structure['depth_width_ratio']:.2f} ({'✅ Good' if structure['depth_width_ratio'] <= 3 else '⚠️ Too thin'})"
+         )
+         print(
+             f"  Overall balance: {'✅ Good' if structure['is_well_balanced'] else '⚠️ Needs improvement'}"
+         )
+
+     # Single chain analysis
+     chain_analysis = stats.get("chain_analysis", {})
+     if chain_analysis:
+         print("\n🔗 Single Chain Structure Analysis:")
+         print(f"  Longest chain: {chain_analysis.get('max_chain_length', 0)} layers")
+         print(f"  Single chain subtrees: {chain_analysis.get('single_chain_subtrees', 0)}")
+         print(
+             f"  Single chain subtree ratio: {chain_analysis.get('chain_subtree_ratio', 0) * 100:.1f}%"
+         )
+
+         if chain_analysis.get("max_chain_length", 0) > 5:
+             print("  ⚠️ Warning: Overly long single chain structure may affect display")
+         elif chain_analysis.get("chain_subtree_ratio", 0) > 0.3:
+             print(
+                 "  ⚠️ Warning: Too many single chain subtrees, suggest increasing branch structure"
+             )
+         else:
+             print("  ✅ Single chain structure well controlled")
+
+     # Type diversity
+     type_div = stats.get("type_diversity", {})
+     if type_div:
+         print("\n🎨 Type Diversity Analysis:")
+         print(f"  Total types: {type_div['total_types']}")
+         print(f"  Diversity index: {type_div['shannon_diversity']:.3f}")
+         print(f"  Normalized diversity: {type_div['normalized_diversity']:.3f}")
+         print(f"  Distribution balance: {type_div['distribution_balance']:.3f}")
+
+     # Type distribution
+     print("\n📋 Type Distribution Details:")
+     for mem_type, count in sorted(stats["by_type"].items(), key=lambda x: x[1], reverse=True):
+         percentage = count / stats["total_nodes"] * 100
+         print(f"  {mem_type}: {count} nodes ({percentage:.1f}%)")
+
+     # Depth distribution
+     print("\n📏 Depth Distribution:")
+     for depth in sorted(stats["by_depth"].keys()):
+         count = stats["by_depth"][depth]
+         print(f"  Depth {depth}: {count} nodes")
+
+     # Major subtree analysis
+     if stats["subtrees"]:
+         print("\n🌲 Major Subtree Analysis (sorted by quality):")
+         sorted_subtrees = sorted(
+             stats["subtrees"], key=lambda x: x.get("quality_score", 0), reverse=True
+         )
+         for i, subtree in enumerate(sorted_subtrees[:8]):  # Show first 8
+             quality = subtree.get("quality_score", 0)
+             print(f"  #{i + 1} {subtree['root']} [{subtree['type']}]:")
+             print(f"    Quality score: {quality:.2f}")
+             print(
+                 f"    Structure: Depth={subtree['depth']}, Branches={subtree['branches']}, Leaves={subtree['leaves']}"
+             )
+             print(
+                 f"    Density: Branch density={subtree.get('branch_density', 0):.3f}, Leaf ratio={subtree.get('leaf_ratio', 0):.3f}"
+             )
+
+             if quality > 15:
+                 print("    ✅ High quality subtree")
+             elif quality > 8:
+                 print("    🟡 Medium quality subtree")
+             else:
+                 print("    🔴 Low quality subtree")
+
+     print("\n" + "=" * 60)
+
+
+ def remove_embedding_recursive(memory_info: dict) -> Any:
+     """Recursively remove embedding fields from the memory info.
+
+     Args:
+         memory_info: Product memory info
+
+     Returns:
+         Any: Product memory info without embedding fields
+     """
+     if isinstance(memory_info, dict):
+         new_dict = {}
+         for key, value in memory_info.items():
+             if key != "embedding":
+                 new_dict[key] = remove_embedding_recursive(value)
+         return new_dict
+     elif isinstance(memory_info, list):
+         return [remove_embedding_recursive(item) for item in memory_info]
+     else:
+         return memory_info
+
+
+ def remove_embedding_from_memory_items(memory_items: list[Any]) -> list[dict]:
+     """Batch remove embedding fields from multiple TextualMemoryItem objects."""
+     clean_memories = []
+
+     for item in memory_items:
+         memory_dict = item.model_dump()
+
+         # Remove embedding from metadata
+         if "metadata" in memory_dict and "embedding" in memory_dict["metadata"]:
+             del memory_dict["metadata"]["embedding"]
+
+         clean_memories.append(memory_dict)
+
+     return clean_memories
+
+
+ def sort_children_by_memory_type(children: list[dict[str, Any]]) -> list[dict[str, Any]]:
+     """
+     Sort the children of a node by memory_type.
+
+     Args:
+         children: The children of the node
+     Returns:
+         The sorted children
+     """
+     if not children:
+         return children
+
+     def get_sort_key(child):
+         memory_type = child.get("memory_type", "Unknown")
+         # Sort directly by the memory_type string; equal types naturally cluster together
+         return memory_type
+
+     # Sort by memory_type
+     sorted_children = sorted(children, key=get_sort_key)
+
+     return sorted_children
+
+
+ def extract_all_ids_from_tree(tree_node):
+     """
+     Recursively traverse the tree structure to extract all node IDs.
+
+     Args:
+         tree_node: Tree node (dictionary format)
+
+     Returns:
+         set: Set containing all node IDs
+     """
+     ids = set()
+
+     # Add current node ID (if it exists)
+     if "id" in tree_node:
+         ids.add(tree_node["id"])
+
+     # Recursively process child nodes
+     if tree_node.get("children"):
+         for child in tree_node["children"]:
+             ids.update(extract_all_ids_from_tree(child))
+
+     return ids
+
+
+ def filter_nodes_by_tree_ids(tree_data, nodes_data):
+     """
+     Filter the nodes list based on the IDs used in the tree structure.
+
+     Args:
+         tree_data: Tree structure data (dictionary)
+         nodes_data: Data containing a nodes list (dictionary)
+
+     Returns:
+         dict: Filtered nodes data, maintaining the original structure
+     """
+     # Extract all IDs used in the tree
+     used_ids = extract_all_ids_from_tree(tree_data)
+
+     # Filter the nodes list, keeping only nodes whose IDs are used in the tree
+     filtered_nodes = [node for node in nodes_data["nodes"] if node["id"] in used_ids]
+
+     # Return the result, maintaining the original structure
+     return {"nodes": filtered_nodes}
+
+
+ def convert_activation_memory_to_serializable(
+     act_mem_items: list[KVCacheItem],
+ ) -> list[dict[str, Any]]:
+     """
+     Convert activation memory items to a serializable format.
+
+     Args:
+         act_mem_items: List of KVCacheItem objects
+
+     Returns:
+         List of dictionaries with serializable data
+     """
+     serializable_items = []
+
+     for item in act_mem_items:
+         # Extract basic information that can be serialized
+         serializable_item = {
+             "id": item.id,
+             "metadata": item.metadata,
+             "memory_info": {
+                 "type": "DynamicCache",
+                 "key_cache_layers": len(item.memory.key_cache) if item.memory else 0,
+                 "value_cache_layers": len(item.memory.value_cache) if item.memory else 0,
+                 "device": str(item.memory.key_cache[0].device)
+                 if item.memory and item.memory.key_cache
+                 else "unknown",
+                 "dtype": str(item.memory.key_cache[0].dtype)
+                 if item.memory and item.memory.key_cache
+                 else "unknown",
+             },
+         }
+
+         # Add tensor shape information if available
+         if item.memory and item.memory.key_cache:
+             key_shapes = []
+             value_shapes = []
+
+             for i, key_tensor in enumerate(item.memory.key_cache):
+                 if key_tensor is not None:
+                     key_shapes.append({"layer": i, "shape": list(key_tensor.shape)})
+
+                 if i < len(item.memory.value_cache) and item.memory.value_cache[i] is not None:
+                     value_shapes.append(
+                         {"layer": i, "shape": list(item.memory.value_cache[i].shape)}
+                     )
+
+             serializable_item["memory_info"]["key_shapes"] = key_shapes
+             serializable_item["memory_info"]["value_shapes"] = value_shapes
+
+         serializable_items.append(serializable_item)
+
+     return serializable_items
+
+
+ def convert_activation_memory_summary(act_mem_items: list[KVCacheItem]) -> dict[str, Any]:
+     """
+     Create a summary of activation memory for API responses.
+
+     Args:
+         act_mem_items: List of KVCacheItem objects
+
+     Returns:
+         Dictionary with summary information
+     """
+     if not act_mem_items:
+         return {"total_items": 0, "summary": "No activation memory items found"}
+
+     total_items = len(act_mem_items)
+     total_layers = 0
+     total_parameters = 0
+
+     for item in act_mem_items:
+         if item.memory and item.memory.key_cache:
+             total_layers += len(item.memory.key_cache)
+
+             # Calculate approximate parameter count
+             for key_tensor in item.memory.key_cache:
+                 if key_tensor is not None:
+                     total_parameters += key_tensor.numel()
+
+             for value_tensor in item.memory.value_cache:
+                 if value_tensor is not None:
+                     total_parameters += value_tensor.numel()
+
+     return {
+         "total_items": total_items,
+         "total_layers": total_layers,
+         "total_parameters": total_parameters,
+         "summary": f"Activation memory contains {total_items} items with {total_layers} layers and approximately {total_parameters:,} parameters",
+     }
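
For orientation, here is a minimal usage sketch of how the new `format_utils` helpers fit together. It is an illustration, not code shipped in the package: the import path follows the file list above (`memos/mem_os/utils/format_utils.py`), the toy graph is invented, and the node/edge shape is inferred from the accessors in the diff (`id`, `memory`, `metadata.memory_type`, `metadata.usage`, optional `metadata.embedding`; edges carry `source`, `target`, and `type`, where only `"PARENT"` edges define the hierarchy).

```python
# Hypothetical demo of the format_utils helpers added in 0.2.1.
from memos.mem_os.utils.format_utils import (
    convert_graph_to_tree_forworkmem,
    print_tree_analysis,
    print_tree_structure,
    remove_embedding_recursive,
)

# Invented toy graph in the shape the module reads.
graph = {
    "nodes": [
        {"id": "n1", "memory": "User prefers dark mode",
         "metadata": {"memory_type": "SemanticMemory", "usage": ["s1"]}},
        {"id": "n2", "memory": "Enabled dark mode yesterday",
         "metadata": {"memory_type": "EpisodicMemory", "usage": [],
                      "embedding": [0.1, 0.2]}},
    ],
    "edges": [{"source": "n1", "target": "n2", "type": "PARENT"}],
}

# Strip embedding vectors before returning graph data from an API.
clean_graph = remove_embedding_recursive(graph)

# Convert the graph into the display tree; graphs larger than
# target_node_count are first reduced by type-balanced sampling.
tree, type_counts = convert_graph_to_tree_forworkmem(clean_graph, target_node_count=150)

print_tree_structure(tree, max_level=3)   # quick look at the top layers
print_tree_analysis(tree)                 # quality report (branching, chains, diversity)
print(f"Node counts by type: {type_counts}")
```

The sketch exercises only the pure-dict paths; the `KVCacheItem` serialization helpers at the bottom of the file additionally require activation-memory objects with populated `DynamicCache` tensors.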