MemoryOS 2.0.3 (memoryos-2.0.3-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. memoryos-2.0.3.dist-info/METADATA +418 -0
  2. memoryos-2.0.3.dist-info/RECORD +315 -0
  3. memoryos-2.0.3.dist-info/WHEEL +4 -0
  4. memoryos-2.0.3.dist-info/entry_points.txt +3 -0
  5. memoryos-2.0.3.dist-info/licenses/LICENSE +201 -0
  6. memos/__init__.py +20 -0
  7. memos/api/client.py +571 -0
  8. memos/api/config.py +1018 -0
  9. memos/api/context/dependencies.py +50 -0
  10. memos/api/exceptions.py +53 -0
  11. memos/api/handlers/__init__.py +62 -0
  12. memos/api/handlers/add_handler.py +158 -0
  13. memos/api/handlers/base_handler.py +194 -0
  14. memos/api/handlers/chat_handler.py +1401 -0
  15. memos/api/handlers/component_init.py +388 -0
  16. memos/api/handlers/config_builders.py +190 -0
  17. memos/api/handlers/feedback_handler.py +93 -0
  18. memos/api/handlers/formatters_handler.py +237 -0
  19. memos/api/handlers/memory_handler.py +316 -0
  20. memos/api/handlers/scheduler_handler.py +497 -0
  21. memos/api/handlers/search_handler.py +222 -0
  22. memos/api/handlers/suggestion_handler.py +117 -0
  23. memos/api/mcp_serve.py +614 -0
  24. memos/api/middleware/request_context.py +101 -0
  25. memos/api/product_api.py +38 -0
  26. memos/api/product_models.py +1206 -0
  27. memos/api/routers/__init__.py +1 -0
  28. memos/api/routers/product_router.py +477 -0
  29. memos/api/routers/server_router.py +394 -0
  30. memos/api/server_api.py +44 -0
  31. memos/api/start_api.py +433 -0
  32. memos/chunkers/__init__.py +4 -0
  33. memos/chunkers/base.py +24 -0
  34. memos/chunkers/charactertext_chunker.py +41 -0
  35. memos/chunkers/factory.py +24 -0
  36. memos/chunkers/markdown_chunker.py +62 -0
  37. memos/chunkers/sentence_chunker.py +54 -0
  38. memos/chunkers/simple_chunker.py +50 -0
  39. memos/cli.py +113 -0
  40. memos/configs/__init__.py +0 -0
  41. memos/configs/base.py +82 -0
  42. memos/configs/chunker.py +59 -0
  43. memos/configs/embedder.py +88 -0
  44. memos/configs/graph_db.py +236 -0
  45. memos/configs/internet_retriever.py +100 -0
  46. memos/configs/llm.py +151 -0
  47. memos/configs/mem_agent.py +54 -0
  48. memos/configs/mem_chat.py +81 -0
  49. memos/configs/mem_cube.py +105 -0
  50. memos/configs/mem_os.py +83 -0
  51. memos/configs/mem_reader.py +91 -0
  52. memos/configs/mem_scheduler.py +385 -0
  53. memos/configs/mem_user.py +70 -0
  54. memos/configs/memory.py +324 -0
  55. memos/configs/parser.py +38 -0
  56. memos/configs/reranker.py +18 -0
  57. memos/configs/utils.py +8 -0
  58. memos/configs/vec_db.py +80 -0
  59. memos/context/context.py +355 -0
  60. memos/dependency.py +52 -0
  61. memos/deprecation.py +262 -0
  62. memos/embedders/__init__.py +0 -0
  63. memos/embedders/ark.py +95 -0
  64. memos/embedders/base.py +106 -0
  65. memos/embedders/factory.py +29 -0
  66. memos/embedders/ollama.py +77 -0
  67. memos/embedders/sentence_transformer.py +49 -0
  68. memos/embedders/universal_api.py +51 -0
  69. memos/exceptions.py +30 -0
  70. memos/graph_dbs/__init__.py +0 -0
  71. memos/graph_dbs/base.py +274 -0
  72. memos/graph_dbs/factory.py +27 -0
  73. memos/graph_dbs/item.py +46 -0
  74. memos/graph_dbs/nebular.py +1794 -0
  75. memos/graph_dbs/neo4j.py +1942 -0
  76. memos/graph_dbs/neo4j_community.py +1058 -0
  77. memos/graph_dbs/polardb.py +5446 -0
  78. memos/hello_world.py +97 -0
  79. memos/llms/__init__.py +0 -0
  80. memos/llms/base.py +25 -0
  81. memos/llms/deepseek.py +13 -0
  82. memos/llms/factory.py +38 -0
  83. memos/llms/hf.py +443 -0
  84. memos/llms/hf_singleton.py +114 -0
  85. memos/llms/ollama.py +135 -0
  86. memos/llms/openai.py +222 -0
  87. memos/llms/openai_new.py +198 -0
  88. memos/llms/qwen.py +13 -0
  89. memos/llms/utils.py +14 -0
  90. memos/llms/vllm.py +218 -0
  91. memos/log.py +237 -0
  92. memos/mem_agent/base.py +19 -0
  93. memos/mem_agent/deepsearch_agent.py +391 -0
  94. memos/mem_agent/factory.py +36 -0
  95. memos/mem_chat/__init__.py +0 -0
  96. memos/mem_chat/base.py +30 -0
  97. memos/mem_chat/factory.py +21 -0
  98. memos/mem_chat/simple.py +200 -0
  99. memos/mem_cube/__init__.py +0 -0
  100. memos/mem_cube/base.py +30 -0
  101. memos/mem_cube/general.py +240 -0
  102. memos/mem_cube/navie.py +172 -0
  103. memos/mem_cube/utils.py +169 -0
  104. memos/mem_feedback/base.py +15 -0
  105. memos/mem_feedback/feedback.py +1192 -0
  106. memos/mem_feedback/simple_feedback.py +40 -0
  107. memos/mem_feedback/utils.py +230 -0
  108. memos/mem_os/client.py +5 -0
  109. memos/mem_os/core.py +1203 -0
  110. memos/mem_os/main.py +582 -0
  111. memos/mem_os/product.py +1608 -0
  112. memos/mem_os/product_server.py +455 -0
  113. memos/mem_os/utils/default_config.py +359 -0
  114. memos/mem_os/utils/format_utils.py +1403 -0
  115. memos/mem_os/utils/reference_utils.py +162 -0
  116. memos/mem_reader/__init__.py +0 -0
  117. memos/mem_reader/base.py +47 -0
  118. memos/mem_reader/factory.py +53 -0
  119. memos/mem_reader/memory.py +298 -0
  120. memos/mem_reader/multi_modal_struct.py +965 -0
  121. memos/mem_reader/read_multi_modal/__init__.py +43 -0
  122. memos/mem_reader/read_multi_modal/assistant_parser.py +311 -0
  123. memos/mem_reader/read_multi_modal/base.py +273 -0
  124. memos/mem_reader/read_multi_modal/file_content_parser.py +826 -0
  125. memos/mem_reader/read_multi_modal/image_parser.py +359 -0
  126. memos/mem_reader/read_multi_modal/multi_modal_parser.py +252 -0
  127. memos/mem_reader/read_multi_modal/string_parser.py +139 -0
  128. memos/mem_reader/read_multi_modal/system_parser.py +327 -0
  129. memos/mem_reader/read_multi_modal/text_content_parser.py +131 -0
  130. memos/mem_reader/read_multi_modal/tool_parser.py +210 -0
  131. memos/mem_reader/read_multi_modal/user_parser.py +218 -0
  132. memos/mem_reader/read_multi_modal/utils.py +358 -0
  133. memos/mem_reader/simple_struct.py +912 -0
  134. memos/mem_reader/strategy_struct.py +163 -0
  135. memos/mem_reader/utils.py +157 -0
  136. memos/mem_scheduler/__init__.py +0 -0
  137. memos/mem_scheduler/analyzer/__init__.py +0 -0
  138. memos/mem_scheduler/analyzer/api_analyzer.py +714 -0
  139. memos/mem_scheduler/analyzer/eval_analyzer.py +219 -0
  140. memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +571 -0
  141. memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
  142. memos/mem_scheduler/base_scheduler.py +1319 -0
  143. memos/mem_scheduler/general_modules/__init__.py +0 -0
  144. memos/mem_scheduler/general_modules/api_misc.py +137 -0
  145. memos/mem_scheduler/general_modules/base.py +80 -0
  146. memos/mem_scheduler/general_modules/init_components_for_scheduler.py +425 -0
  147. memos/mem_scheduler/general_modules/misc.py +313 -0
  148. memos/mem_scheduler/general_modules/scheduler_logger.py +389 -0
  149. memos/mem_scheduler/general_modules/task_threads.py +315 -0
  150. memos/mem_scheduler/general_scheduler.py +1495 -0
  151. memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
  152. memos/mem_scheduler/memory_manage_modules/memory_filter.py +306 -0
  153. memos/mem_scheduler/memory_manage_modules/retriever.py +547 -0
  154. memos/mem_scheduler/monitors/__init__.py +0 -0
  155. memos/mem_scheduler/monitors/dispatcher_monitor.py +366 -0
  156. memos/mem_scheduler/monitors/general_monitor.py +394 -0
  157. memos/mem_scheduler/monitors/task_schedule_monitor.py +254 -0
  158. memos/mem_scheduler/optimized_scheduler.py +410 -0
  159. memos/mem_scheduler/orm_modules/__init__.py +0 -0
  160. memos/mem_scheduler/orm_modules/api_redis_model.py +518 -0
  161. memos/mem_scheduler/orm_modules/base_model.py +729 -0
  162. memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
  163. memos/mem_scheduler/orm_modules/redis_model.py +699 -0
  164. memos/mem_scheduler/scheduler_factory.py +23 -0
  165. memos/mem_scheduler/schemas/__init__.py +0 -0
  166. memos/mem_scheduler/schemas/analyzer_schemas.py +52 -0
  167. memos/mem_scheduler/schemas/api_schemas.py +233 -0
  168. memos/mem_scheduler/schemas/general_schemas.py +55 -0
  169. memos/mem_scheduler/schemas/message_schemas.py +173 -0
  170. memos/mem_scheduler/schemas/monitor_schemas.py +406 -0
  171. memos/mem_scheduler/schemas/task_schemas.py +132 -0
  172. memos/mem_scheduler/task_schedule_modules/__init__.py +0 -0
  173. memos/mem_scheduler/task_schedule_modules/dispatcher.py +740 -0
  174. memos/mem_scheduler/task_schedule_modules/local_queue.py +247 -0
  175. memos/mem_scheduler/task_schedule_modules/orchestrator.py +74 -0
  176. memos/mem_scheduler/task_schedule_modules/redis_queue.py +1385 -0
  177. memos/mem_scheduler/task_schedule_modules/task_queue.py +162 -0
  178. memos/mem_scheduler/utils/__init__.py +0 -0
  179. memos/mem_scheduler/utils/api_utils.py +77 -0
  180. memos/mem_scheduler/utils/config_utils.py +100 -0
  181. memos/mem_scheduler/utils/db_utils.py +50 -0
  182. memos/mem_scheduler/utils/filter_utils.py +176 -0
  183. memos/mem_scheduler/utils/metrics.py +125 -0
  184. memos/mem_scheduler/utils/misc_utils.py +290 -0
  185. memos/mem_scheduler/utils/monitor_event_utils.py +67 -0
  186. memos/mem_scheduler/utils/status_tracker.py +229 -0
  187. memos/mem_scheduler/webservice_modules/__init__.py +0 -0
  188. memos/mem_scheduler/webservice_modules/rabbitmq_service.py +485 -0
  189. memos/mem_scheduler/webservice_modules/redis_service.py +380 -0
  190. memos/mem_user/factory.py +94 -0
  191. memos/mem_user/mysql_persistent_user_manager.py +271 -0
  192. memos/mem_user/mysql_user_manager.py +502 -0
  193. memos/mem_user/persistent_factory.py +98 -0
  194. memos/mem_user/persistent_user_manager.py +260 -0
  195. memos/mem_user/redis_persistent_user_manager.py +225 -0
  196. memos/mem_user/user_manager.py +488 -0
  197. memos/memories/__init__.py +0 -0
  198. memos/memories/activation/__init__.py +0 -0
  199. memos/memories/activation/base.py +42 -0
  200. memos/memories/activation/item.py +56 -0
  201. memos/memories/activation/kv.py +292 -0
  202. memos/memories/activation/vllmkv.py +219 -0
  203. memos/memories/base.py +19 -0
  204. memos/memories/factory.py +42 -0
  205. memos/memories/parametric/__init__.py +0 -0
  206. memos/memories/parametric/base.py +19 -0
  207. memos/memories/parametric/item.py +11 -0
  208. memos/memories/parametric/lora.py +41 -0
  209. memos/memories/textual/__init__.py +0 -0
  210. memos/memories/textual/base.py +92 -0
  211. memos/memories/textual/general.py +236 -0
  212. memos/memories/textual/item.py +304 -0
  213. memos/memories/textual/naive.py +187 -0
  214. memos/memories/textual/prefer_text_memory/__init__.py +0 -0
  215. memos/memories/textual/prefer_text_memory/adder.py +504 -0
  216. memos/memories/textual/prefer_text_memory/config.py +106 -0
  217. memos/memories/textual/prefer_text_memory/extractor.py +221 -0
  218. memos/memories/textual/prefer_text_memory/factory.py +85 -0
  219. memos/memories/textual/prefer_text_memory/retrievers.py +177 -0
  220. memos/memories/textual/prefer_text_memory/spliter.py +132 -0
  221. memos/memories/textual/prefer_text_memory/utils.py +93 -0
  222. memos/memories/textual/preference.py +344 -0
  223. memos/memories/textual/simple_preference.py +161 -0
  224. memos/memories/textual/simple_tree.py +69 -0
  225. memos/memories/textual/tree.py +459 -0
  226. memos/memories/textual/tree_text_memory/__init__.py +0 -0
  227. memos/memories/textual/tree_text_memory/organize/__init__.py +0 -0
  228. memos/memories/textual/tree_text_memory/organize/handler.py +184 -0
  229. memos/memories/textual/tree_text_memory/organize/manager.py +518 -0
  230. memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +238 -0
  231. memos/memories/textual/tree_text_memory/organize/reorganizer.py +622 -0
  232. memos/memories/textual/tree_text_memory/retrieve/__init__.py +0 -0
  233. memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +364 -0
  234. memos/memories/textual/tree_text_memory/retrieve/bm25_util.py +186 -0
  235. memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +419 -0
  236. memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +270 -0
  237. memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +102 -0
  238. memos/memories/textual/tree_text_memory/retrieve/reasoner.py +61 -0
  239. memos/memories/textual/tree_text_memory/retrieve/recall.py +497 -0
  240. memos/memories/textual/tree_text_memory/retrieve/reranker.py +111 -0
  241. memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +16 -0
  242. memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +472 -0
  243. memos/memories/textual/tree_text_memory/retrieve/searcher.py +848 -0
  244. memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +135 -0
  245. memos/memories/textual/tree_text_memory/retrieve/utils.py +54 -0
  246. memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +387 -0
  247. memos/memos_tools/dinding_report_bot.py +453 -0
  248. memos/memos_tools/lockfree_dict.py +120 -0
  249. memos/memos_tools/notification_service.py +44 -0
  250. memos/memos_tools/notification_utils.py +142 -0
  251. memos/memos_tools/singleton.py +174 -0
  252. memos/memos_tools/thread_safe_dict.py +310 -0
  253. memos/memos_tools/thread_safe_dict_segment.py +382 -0
  254. memos/multi_mem_cube/__init__.py +0 -0
  255. memos/multi_mem_cube/composite_cube.py +86 -0
  256. memos/multi_mem_cube/single_cube.py +874 -0
  257. memos/multi_mem_cube/views.py +54 -0
  258. memos/parsers/__init__.py +0 -0
  259. memos/parsers/base.py +15 -0
  260. memos/parsers/factory.py +21 -0
  261. memos/parsers/markitdown.py +28 -0
  262. memos/reranker/__init__.py +4 -0
  263. memos/reranker/base.py +25 -0
  264. memos/reranker/concat.py +103 -0
  265. memos/reranker/cosine_local.py +102 -0
  266. memos/reranker/factory.py +72 -0
  267. memos/reranker/http_bge.py +324 -0
  268. memos/reranker/http_bge_strategy.py +327 -0
  269. memos/reranker/noop.py +19 -0
  270. memos/reranker/strategies/__init__.py +4 -0
  271. memos/reranker/strategies/base.py +61 -0
  272. memos/reranker/strategies/concat_background.py +94 -0
  273. memos/reranker/strategies/concat_docsource.py +110 -0
  274. memos/reranker/strategies/dialogue_common.py +109 -0
  275. memos/reranker/strategies/factory.py +31 -0
  276. memos/reranker/strategies/single_turn.py +107 -0
  277. memos/reranker/strategies/singleturn_outmem.py +98 -0
  278. memos/settings.py +10 -0
  279. memos/templates/__init__.py +0 -0
  280. memos/templates/advanced_search_prompts.py +211 -0
  281. memos/templates/cloud_service_prompt.py +107 -0
  282. memos/templates/instruction_completion.py +66 -0
  283. memos/templates/mem_agent_prompts.py +85 -0
  284. memos/templates/mem_feedback_prompts.py +822 -0
  285. memos/templates/mem_reader_prompts.py +1096 -0
  286. memos/templates/mem_reader_strategy_prompts.py +238 -0
  287. memos/templates/mem_scheduler_prompts.py +626 -0
  288. memos/templates/mem_search_prompts.py +93 -0
  289. memos/templates/mos_prompts.py +403 -0
  290. memos/templates/prefer_complete_prompt.py +735 -0
  291. memos/templates/tool_mem_prompts.py +139 -0
  292. memos/templates/tree_reorganize_prompts.py +230 -0
  293. memos/types/__init__.py +34 -0
  294. memos/types/general_types.py +151 -0
  295. memos/types/openai_chat_completion_types/__init__.py +15 -0
  296. memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py +56 -0
  297. memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py +27 -0
  298. memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py +23 -0
  299. memos/types/openai_chat_completion_types/chat_completion_content_part_param.py +43 -0
  300. memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py +16 -0
  301. memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py +16 -0
  302. memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py +27 -0
  303. memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py +32 -0
  304. memos/types/openai_chat_completion_types/chat_completion_message_param.py +18 -0
  305. memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py +15 -0
  306. memos/types/openai_chat_completion_types/chat_completion_system_message_param.py +36 -0
  307. memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py +30 -0
  308. memos/types/openai_chat_completion_types/chat_completion_user_message_param.py +34 -0
  309. memos/utils.py +123 -0
  310. memos/vec_dbs/__init__.py +0 -0
  311. memos/vec_dbs/base.py +117 -0
  312. memos/vec_dbs/factory.py +23 -0
  313. memos/vec_dbs/item.py +50 -0
  314. memos/vec_dbs/milvus.py +654 -0
  315. memos/vec_dbs/qdrant.py +355 -0
@@ -0,0 +1,1403 @@
+ import math
+ import random
+
+ from typing import Any
+
+ from memos.log import get_logger
+ from memos.memories.activation.item import KVCacheItem
+
+
+ logger = get_logger(__name__)
+
+
+ def extract_node_name(memory: str) -> str:
+     """Extract the first two words from memory as node_name"""
+     if not memory:
+         return ""
+
+     words = [word.strip() for word in memory.split() if word.strip()]
+
+     if len(words) >= 2:
+         return " ".join(words[:2])
+     elif len(words) == 1:
+         return words[0]
+     else:
+         return ""
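Illustrative behavior of the helper above (a hand-run sketch, not part of the packaged diff):

    extract_node_name("Paris trip photos from 2023")  # -> "Paris trip"
    extract_node_name("Hello")                        # -> "Hello"
    extract_node_name("")                             # -> ""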
+
+
+ def analyze_tree_structure_enhanced(nodes: list[dict], edges: list[dict]) -> dict:
+     """Enhanced tree structure analysis, focusing on branching degree and leaf distribution"""
+     # Build adjacency list
+     adj_list = {}
+     reverse_adj = {}
+     for edge in edges:
+         source, target = edge["source"], edge["target"]
+         adj_list.setdefault(source, []).append(target)
+         reverse_adj.setdefault(target, []).append(source)
+
+     # Find all nodes and root nodes
+     all_nodes = {node["id"] for node in nodes}
+     target_nodes = {edge["target"] for edge in edges}
+     root_nodes = all_nodes - target_nodes
+
+     subtree_analysis = {}
+
+     def analyze_subtree_enhanced(root_id: str) -> dict:
+         """Enhanced subtree analysis, focusing on branching degree and structure quality"""
+         visited = set()
+         max_depth = 0
+         leaf_count = 0
+         total_nodes = 0
+         branch_nodes = 0  # Number of branch nodes with multiple children
+         chain_length = 0  # Longest single chain length
+         width_per_level = {}  # Width per level
+
+         def dfs(node_id: str, depth: int, chain_len: int):
+             nonlocal max_depth, leaf_count, total_nodes, branch_nodes, chain_length
+
+             if node_id in visited:
+                 return
+
+             visited.add(node_id)
+             total_nodes += 1
+             max_depth = max(max_depth, depth)
+             chain_length = max(chain_length, chain_len)
+
+             # Record number of nodes per level
+             width_per_level[depth] = width_per_level.get(depth, 0) + 1
+
+             children = adj_list.get(node_id, [])
+
+             if not children:  # Leaf node
+                 leaf_count += 1
+             elif len(children) > 1:  # Branch node
+                 branch_nodes += 1
+                 # Reset chain length because we encountered a branch
+                 for child in children:
+                     dfs(child, depth + 1, 0)
+             else:  # Single child node (chain structure)
+                 for child in children:
+                     dfs(child, depth + 1, chain_len + 1)
+
+         dfs(root_id, 0, 0)
+
+         # Calculate structure quality metrics
+         avg_width = sum(width_per_level.values()) / len(width_per_level) if width_per_level else 0
+         max_width = max(width_per_level.values()) if width_per_level else 0
+
+         # Calculate branch density: ratio of branch nodes to total nodes
+         branch_density = branch_nodes / total_nodes if total_nodes > 0 else 0
+
+         # Calculate depth-width ratio: an ideal tree should have moderate depth and good breadth
+         depth_width_ratio = max_depth / max_width if max_width > 0 else max_depth
+
+         quality_score = calculate_enhanced_quality(
+             max_depth,
+             leaf_count,
+             total_nodes,
+             branch_nodes,
+             chain_length,
+             branch_density,
+             depth_width_ratio,
+             max_width,
+         )
+
+         return {
+             "root_id": root_id,
+             "max_depth": max_depth,
+             "leaf_count": leaf_count,
+             "total_nodes": total_nodes,
+             "branch_nodes": branch_nodes,
+             "max_chain_length": chain_length,
+             "branch_density": branch_density,
+             "max_width": max_width,
+             "avg_width": avg_width,
+             "depth_width_ratio": depth_width_ratio,
+             "nodes_in_subtree": list(visited),
+             "quality_score": quality_score,
+             "width_per_level": width_per_level,
+         }
+
+     for root_id in root_nodes:
+         subtree_analysis[root_id] = analyze_subtree_enhanced(root_id)
+
+     return subtree_analysis
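A minimal usage sketch for the analyzer above, assuming a hand-built four-node tree (illustrative, not part of the packaged diff):

    # "a" -> "b", "a" -> "c", "c" -> "d": one branch node ("a"), two leaves ("b", "d")
    nodes = [{"id": i} for i in ("a", "b", "c", "d")]
    edges = [
        {"source": "a", "target": "b"},
        {"source": "a", "target": "c"},
        {"source": "c", "target": "d"},
    ]
    report = analyze_tree_structure_enhanced(nodes, edges)
    stats = report["a"]  # "a" is the only root
    # stats["max_depth"] == 2, stats["leaf_count"] == 2,
    # stats["branch_nodes"] == 1, stats["max_chain_length"] == 1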
+
+
+ def calculate_enhanced_quality(
+     max_depth: int,
+     leaf_count: int,
+     total_nodes: int,
+     branch_nodes: int,
+     max_chain_length: int,
+     branch_density: float,
+     depth_width_ratio: float,
+     max_width: int,
+ ) -> float:
+     """Enhanced quality calculation, prioritizing branching degree and leaf distribution"""
+
+     if total_nodes <= 1:
+         return 0.1
+
+     # 1. Branch quality score (weight: 35%)
+     # Branch node count score
+     branch_count_score = min(branch_nodes * 3, 15)  # 3 points per branch node, max 15 points
+
+     # Branch density score: ideal density between 20%-60%
+     if 0.2 <= branch_density <= 0.6:
+         branch_density_score = 10
+     elif branch_density > 0.6:
+         branch_density_score = max(5, 10 - (branch_density - 0.6) * 20)
+     else:
+         branch_density_score = branch_density * 25  # Linear growth for 0-20%
+
+     branch_score = (branch_count_score + branch_density_score) * 0.35
+
+     # 2. Leaf quality score (weight: 25%)
+     # Leaf count score
+     leaf_count_score = min(leaf_count * 2, 20)
+
+     # Leaf distribution score: ideal leaf ratio 30%-70% of total nodes
+     leaf_ratio = leaf_count / total_nodes
+     if 0.3 <= leaf_ratio <= 0.7:
+         leaf_ratio_score = 10
+     elif leaf_ratio > 0.7:
+         leaf_ratio_score = max(3, 10 - (leaf_ratio - 0.7) * 20)
+     else:
+         leaf_ratio_score = leaf_ratio * 20  # Linear growth for 0-30%
+
+     leaf_score = (leaf_count_score + leaf_ratio_score) * 0.25
+
+     # 3. Structure balance score (weight: 25%)
+     # Depth score: moderate depth is best (3-8 layers)
+     if 3 <= max_depth <= 8:
+         depth_score = 15
+     elif max_depth < 3:
+         depth_score = max_depth * 3  # Lower score for 1-2 layers
+     else:
+         depth_score = max(5, 15 - (max_depth - 8) * 1.5)  # Gradually reduce score beyond 8 layers
+
+     # Width score: larger max width is better, but with an upper limit
+     width_score = min(max_width * 1.5, 15)
+
+     # Depth-width ratio penalty: too large a ratio means the tree is too "thin"
+     if depth_width_ratio > 3:
+         ratio_penalty = (depth_width_ratio - 3) * 2
+         structure_score = max(0, (depth_score + width_score - ratio_penalty)) * 0.25
+     else:
+         structure_score = (depth_score + width_score) * 0.25
+
+     # 4. Chain structure penalty (weight: 15%)
+     # Longest single chain length penalty: overly long chains severely affect display
+     if max_chain_length <= 2:
+         chain_penalty_score = 10
+     elif max_chain_length <= 5:
+         chain_penalty_score = 8 - (max_chain_length - 2)
+     else:
+         chain_penalty_score = max(0, 3 - (max_chain_length - 5) * 0.5)
+
+     chain_score = chain_penalty_score * 0.15
+
+     # 5. Comprehensive calculation
+     total_score = branch_score + leaf_score + structure_score + chain_score
+
+     # Special case severe penalties
+     if max_chain_length > total_nodes * 0.8:  # If more than 80% are single chains
+         total_score *= 0.3
+     elif branch_density < 0.1 and total_nodes > 5:  # Large tree with almost no branches
+         total_score *= 0.5
+
+     return total_score
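To make the weighting concrete, here is one hand-computed case (a sketch; the input values are invented): max_depth=4, leaf_count=5, total_nodes=12, branch_nodes=3, max_chain_length=2, branch_density=0.25, depth_width_ratio=4/3, max_width=3.

    # branch:    (min(3 * 3, 15) + 10) * 0.35   = 6.65   (density 0.25 is in [0.2, 0.6])
    # leaf:      (min(5 * 2, 20) + 10) * 0.25   = 5.0    (ratio 5/12 is in [0.3, 0.7])
    # structure: (15 + min(3 * 1.5, 15)) * 0.25 = 4.875  (ratio 4/3 <= 3, no penalty)
    # chain:     10 * 0.15                      = 1.5    (chain length <= 2)
    # total                                     = 18.025 (no special-case penalty applies)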
+
+
+ def sample_nodes_with_type_balance(
+     nodes: list[dict],
+     edges: list[dict],
+     target_count: int = 150,
+     type_ratios: dict[str, float] | None = None,
+ ) -> tuple[list[dict], list[dict]]:
+     """
+     Balanced sampling based on type ratios and tree quality
+
+     Args:
+         nodes: List of nodes
+         edges: List of edges
+         target_count: Target number of nodes
+         type_ratios: Expected ratio for each type, e.g. {'WorkingMemory': 0.15, 'EpisodicMemory': 0.30, ...}
+
+     Returns:
+         Tuple of (sampled nodes, edges filtered to the sampled nodes)
+     """
+     if len(nodes) <= target_count:
+         return nodes, edges
+
+     # Default type ratio configuration
+     if type_ratios is None:
+         type_ratios = {
+             "WorkingMemory": 0.10,  # 10%
+             "EpisodicMemory": 0.25,  # 25%
+             "SemanticMemory": 0.25,  # 25%
+             "ProceduralMemory": 0.20,  # 20%
+             "EmotionalMemory": 0.15,  # 15%
+             "MetaMemory": 0.05,  # 5%
+         }
+
+     logger.info(
+         f"Starting type-balanced sampling, original nodes: {len(nodes)}, target nodes: {target_count}"
+     )
+     logger.info(f"Target type ratios: {type_ratios}")
+
+     # Analyze current node type distribution
+     current_type_counts = {}
+     nodes_by_type = {}
+
+     for node in nodes:
+         memory_type = node.get("metadata", {}).get("memory_type", "Unknown")
+         current_type_counts[memory_type] = current_type_counts.get(memory_type, 0) + 1
+         if memory_type not in nodes_by_type:
+             nodes_by_type[memory_type] = []
+         nodes_by_type[memory_type].append(node)
+
+     logger.info(f"Current type distribution: {current_type_counts}")
+
+     # Calculate target node count for each type
+     type_targets = {}
+     remaining_target = target_count
+
+     for memory_type, ratio in type_ratios.items():
+         if memory_type in nodes_by_type:
+             target_for_type = int(target_count * ratio)
+             # Ensure not exceeding the actual node count for this type
+             target_for_type = min(target_for_type, len(nodes_by_type[memory_type]))
+             type_targets[memory_type] = target_for_type
+             remaining_target -= target_for_type
+
+     # Handle types not in ratio configuration
+     other_types = set(nodes_by_type.keys()) - set(type_ratios.keys())
+     if other_types and remaining_target > 0:
+         per_other_type = max(1, remaining_target // len(other_types))
+         for memory_type in other_types:
+             allocation = min(per_other_type, len(nodes_by_type[memory_type]))
+             type_targets[memory_type] = allocation
+             remaining_target -= allocation
+
+     # If there's still remaining, distribute proportionally to main types
+     if remaining_target > 0:
+         main_types = [t for t in type_ratios if t in nodes_by_type]
+         if main_types:
+             extra_per_type = remaining_target // len(main_types)
+             for memory_type in main_types:
+                 additional = min(
+                     extra_per_type,
+                     len(nodes_by_type[memory_type]) - type_targets.get(memory_type, 0),
+                 )
+                 type_targets[memory_type] = type_targets.get(memory_type, 0) + additional
+
+     logger.info(f"Target node count for each type: {type_targets}")
+
+     # Perform subtree quality sampling for each type
+     selected_nodes = []
+
+     for memory_type, target_for_type in type_targets.items():
+         if target_for_type <= 0 or memory_type not in nodes_by_type:
+             continue
+
+         type_nodes = nodes_by_type[memory_type]
+         logger.info(
+             f"\n--- Processing {memory_type} type: {len(type_nodes)} -> {target_for_type} ---"
+         )
+
+         if len(type_nodes) <= target_for_type:
+             selected_nodes.extend(type_nodes)
+             logger.info(f" Select all: {len(type_nodes)} nodes")
+         else:
+             # Use enhanced subtree quality sampling
+             type_selected = sample_by_enhanced_subtree_quality(type_nodes, edges, target_for_type)
+             selected_nodes.extend(type_selected)
+             logger.info(f" Sampled selection: {len(type_selected)} nodes")
+
+     # Filter edges
+     selected_node_ids = {node["id"] for node in selected_nodes}
+     filtered_edges = [
+         edge
+         for edge in edges
+         if edge["source"] in selected_node_ids and edge["target"] in selected_node_ids
+     ]
+
+     logger.info(f"\nFinal selected nodes: {len(selected_nodes)}")
+     logger.info(f"Final edges: {len(filtered_edges)}")
+
+     # Verify final type distribution
+     final_type_counts = {}
+     for node in selected_nodes:
+         memory_type = node.get("metadata", {}).get("memory_type", "Unknown")
+         final_type_counts[memory_type] = final_type_counts.get(memory_type, 0) + 1
+
+     logger.info(f"Final type distribution: {final_type_counts}")
+     for memory_type, count in final_type_counts.items():
+         percentage = count / len(selected_nodes) * 100
+         target_percentage = type_ratios.get(memory_type, 0) * 100
+         logger.info(
+             f" {memory_type}: {count} nodes ({percentage:.1f}%, target: {target_percentage:.1f}%)"
+         )
+
+     return selected_nodes, filtered_edges
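A usage sketch (illustrative; `nodes` and `edges` stand for a real exported graph). Types missing from `type_ratios` still receive a share of any leftover quota, so the ratios are targets rather than hard guarantees:

    sampled_nodes, sampled_edges = sample_nodes_with_type_balance(
        nodes,
        edges,
        target_count=150,
        type_ratios={"EpisodicMemory": 0.5, "SemanticMemory": 0.5},
    )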
+
+
+ def sample_by_enhanced_subtree_quality(
+     nodes: list[dict], edges: list[dict], target_count: int
+ ) -> list[dict]:
+     """Sample using enhanced subtree quality"""
+     if len(nodes) <= target_count:
+         return nodes
+
+     # Analyze subtree structure
+     subtree_analysis = analyze_tree_structure_enhanced(nodes, edges)
+
+     if not subtree_analysis:
+         # If no subtree structure, sample by node importance
+         return sample_nodes_by_importance(nodes, edges, target_count)
+
+     # Sort subtrees by quality score
+     sorted_subtrees = sorted(
+         subtree_analysis.items(), key=lambda x: x[1]["quality_score"], reverse=True
+     )
+
+     logger.info(" Subtree quality ranking:")
+     for i, (root_id, analysis) in enumerate(sorted_subtrees[:5]):
+         logger.info(
+             f" #{i + 1} Root node {root_id}: Quality={analysis['quality_score']:.2f}, "
+             f"Depth={analysis['max_depth']}, Branches={analysis['branch_nodes']}, "
+             f"Leaves={analysis['leaf_count']}, Max Width={analysis['max_width']}"
+         )
+
+     # Greedy selection of high-quality subtrees
+     selected_nodes = []
+     selected_node_ids = set()
+
+     for root_id, analysis in sorted_subtrees:
+         subtree_nodes = analysis["nodes_in_subtree"]
+         new_nodes = [node_id for node_id in subtree_nodes if node_id not in selected_node_ids]
+
+         if not new_nodes:
+             continue
+
+         remaining_quota = target_count - len(selected_nodes)
+
+         if len(new_nodes) <= remaining_quota:
+             # Entire subtree can be added
+             for node_id in new_nodes:
+                 node = next((n for n in nodes if n["id"] == node_id), None)
+                 if node:
+                     selected_nodes.append(node)
+                     selected_node_ids.add(node_id)
+             logger.info(f" Select entire subtree {root_id}: +{len(new_nodes)} nodes")
+         else:
+             # Subtree too large, need partial selection
+             if analysis["quality_score"] > 5:  # Only partial selection for high-quality subtrees
+                 subtree_node_objects = [n for n in nodes if n["id"] in new_nodes]
+                 partial_selection = select_best_nodes_from_subtree(
+                     subtree_node_objects, edges, remaining_quota, root_id
+                 )
+
+                 selected_nodes.extend(partial_selection)
+                 for node in partial_selection:
+                     selected_node_ids.add(node["id"])
+                 logger.info(
+                     f" Partial selection of subtree {root_id}: +{len(partial_selection)} nodes"
+                 )
+
+         if len(selected_nodes) >= target_count:
+             break
+
+     # If target count not reached, supplement with remaining nodes
+     if len(selected_nodes) < target_count:
+         remaining_nodes = [n for n in nodes if n["id"] not in selected_node_ids]
+         remaining_count = target_count - len(selected_nodes)
+         additional = sample_nodes_by_importance(remaining_nodes, edges, remaining_count)
+         selected_nodes.extend(additional)
+         logger.info(f" Supplementary selection: +{len(additional)} nodes")
+
+     return selected_nodes
+
+
+ def select_best_nodes_from_subtree(
+     subtree_nodes: list[dict], edges: list[dict], max_count: int, root_id: str
+ ) -> list[dict]:
+     """Select the most important nodes from subtree, prioritizing branch structure"""
+     if len(subtree_nodes) <= max_count:
+         return subtree_nodes
+
+     # Build internal connection relationships of subtree
+     subtree_node_ids = {node["id"] for node in subtree_nodes}
+     subtree_edges = [
+         edge
+         for edge in edges
+         if edge["source"] in subtree_node_ids and edge["target"] in subtree_node_ids
+     ]
+
+     # Calculate importance score for each node
+     node_scores = []
+
+     for node in subtree_nodes:
+         node_id = node["id"]
+
+         # Out-degree and in-degree
+         out_degree = sum(1 for edge in subtree_edges if edge["source"] == node_id)
+         in_degree = sum(1 for edge in subtree_edges if edge["target"] == node_id)
+
+         # Content length score
+         content_score = min(len(node.get("memory", "")), 300) / 15
+
+         # Branch node bonus
+         branch_bonus = out_degree * 8 if out_degree > 1 else 0
+
+         # Root node bonus
+         root_bonus = 15 if node_id == root_id else 0
+
+         # Connection importance
+         connection_score = (out_degree + in_degree) * 3
+
+         # Leaf node moderate bonus (ensure certain number of leaf nodes)
+         leaf_bonus = 5 if out_degree == 0 and in_degree > 0 else 0
+
+         total_score = content_score + connection_score + branch_bonus + root_bonus + leaf_bonus
+         node_scores.append((node, total_score))
+
+     # Sort by score and select
+     node_scores.sort(key=lambda x: x[1], reverse=True)
+     selected = [node for node, _ in node_scores[:max_count]]
+
+     return selected
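One hand-computed score for intuition (a sketch; inputs invented): a subtree root with a 150-character memory, out_degree=2, in_degree=1.

    # content_score    = min(150, 300) / 15 = 10.0
    # connection_score = (2 + 1) * 3        = 9
    # branch_bonus     = 2 * 8              = 16   (out_degree > 1)
    # root_bonus       = 15                        (node_id == root_id)
    # leaf_bonus       = 0                         (not a leaf)
    # total            = 50.0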
+
+
+ def sample_nodes_by_importance(
+     nodes: list[dict], edges: list[dict], target_count: int
+ ) -> list[dict]:
+     """Sample by node importance (for cases without tree structure)"""
+     if len(nodes) <= target_count:
+         return nodes
+
+     node_scores = []
+
+     for node in nodes:
+         node_id = node["id"]
+         out_degree = sum(1 for edge in edges if edge["source"] == node_id)
+         in_degree = sum(1 for edge in edges if edge["target"] == node_id)
+         content_score = min(len(node.get("memory", "")), 200) / 10
+         connection_score = (out_degree + in_degree) * 5
+         random_score = random.random() * 10
+
+         total_score = content_score + connection_score + random_score
+         node_scores.append((node, total_score))
+
+     node_scores.sort(key=lambda x: x[1], reverse=True)
+     return [node for node, _ in node_scores[:target_count]]
+
+
+ # Modified main function to use new sampling strategy
+ def convert_graph_to_tree_forworkmem(
+     json_data: dict[str, Any],
+     target_node_count: int = 200,
+     type_ratios: dict[str, float] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, int]]:
+     """
+     Enhanced graph-to-tree conversion function, prioritizing branching degree and type balance.
+
+     Returns a (tree, node_type_count) tuple: the tree under a synthetic "root" node,
+     plus a count of memory types over the original (pre-sampling) nodes.
+     """
+     original_nodes = json_data.get("nodes", [])
+     original_edges = json_data.get("edges", [])
+
+     logger.info(f"Original node count: {len(original_nodes)}")
+     logger.info(f"Target node count: {target_node_count}")
+     filter_original_edges = []
+     for original_edge in original_edges:
+         if original_edge["type"] == "PARENT":
+             filter_original_edges.append(original_edge)
+     node_type_count = {}
+     for node in original_nodes:
+         node_type = node.get("metadata", {}).get("memory_type", "Unknown")
+         node_type_count[node_type] = node_type_count.get(node_type, 0) + 1
+     original_edges = filter_original_edges
+     # Use enhanced type-balanced sampling
+     if len(original_nodes) > target_node_count:
+         nodes, edges = sample_nodes_with_type_balance(
+             original_nodes, original_edges, target_node_count, type_ratios
+         )
+     else:
+         nodes, edges = original_nodes, original_edges
+
+     # The rest of tree structure building remains unchanged...
+     # [Original tree building code here]
+
+     # Create node mapping table
+     node_map = {}
+     for node in nodes:
+         memory = node.get("memory", "")
+         node_name = extract_node_name(memory)
+         memory_key = node.get("metadata", {}).get("key", node_name)
+         usage = node.get("metadata", {}).get("usage", [])
+         frequency = len(usage) if len(usage) < 100 else 100
+         node_map[node["id"]] = {
+             "id": node["id"],
+             "value": memory,
+             "frequency": frequency,
+             "node_name": memory_key,
+             "memory_type": node.get("metadata", {}).get("memory_type", "Unknown"),
+             "children": [],
+         }
+
+     # Build parent-child relationship mapping
+     children_map = {}
+     parent_map = {}
+
+     for edge in edges:
+         source = edge["source"]
+         target = edge["target"]
+         if source not in children_map:
+             children_map[source] = []
+         children_map[source].append(target)
+         parent_map[target] = source
+
+     # Find root nodes
+     all_node_ids = set(node_map.keys())
+     children_node_ids = set(parent_map.keys())
+     root_node_ids = all_node_ids - children_node_ids
+
+     # Separate WorkingMemory and other root nodes
+     working_memory_roots = []
+     other_roots = []
+
+     for root_id in root_node_ids:
+         if node_map[root_id]["memory_type"] == "WorkingMemory":
+             working_memory_roots.append(root_id)
+         else:
+             other_roots.append(root_id)
+
+     def build_tree(node_id: str, visited=None) -> dict[str, Any] | None:
+         """Recursively build tree structure with cycle detection"""
+         if visited is None:
+             visited = set()
+
+         if node_id in visited:
+             logger.warning(f"[build_tree] Detected cycle at node {node_id}, skipping.")
+             return None
+         visited.add(node_id)
+
+         if node_id not in node_map:
+             return None
+
+         children_ids = children_map.get(node_id, [])
+         children = []
+         for child_id in children_ids:
+             child_tree = build_tree(child_id, visited)
+             if child_tree:
+                 children.append(child_tree)
+
+         node = {
+             "id": node_id,
+             "node_name": node_map[node_id]["node_name"],
+             "value": node_map[node_id]["value"],
+             "memory_type": node_map[node_id]["memory_type"],
+             "frequency": node_map[node_id]["frequency"],
+         }
+
+         if children:
+             node["children"] = children
+
+         return node
+
+     # Build root tree list
+     root_trees = []
+     for root_id in other_roots:
+         tree = build_tree(root_id)
+         if tree:
+             root_trees.append(tree)
+
+     # Handle WorkingMemory
+     if working_memory_roots:
+         working_memory_children = []
+         for wm_root_id in working_memory_roots:
+             tree = build_tree(wm_root_id)
+             if tree:
+                 working_memory_children.append(tree)
+
+         working_memory_node = {
+             "id": "WorkingMemory",
+             "node_name": "WorkingMemory",
+             "value": "WorkingMemory",
+             "memory_type": "WorkingMemory",
+             "children": working_memory_children,
+             "frequency": 0,
+         }
+
+         root_trees.append(working_memory_node)
+
+     # Create total root node
+     result = {
+         "id": "root",
+         "node_name": "root",
+         "value": "root",
+         "memory_type": "Root",
+         "children": root_trees,
+         "frequency": 0,
+     }
+
+     return result, node_type_count
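End-to-end sketch (illustrative; `graph` stands for a real export whose edges carry a "type" field). Note the function returns a tuple, not a bare tree:

    graph = {"nodes": [...], "edges": [...]}  # only edges with type == "PARENT" are kept
    tree, type_counts = convert_graph_to_tree_forworkmem(graph, target_node_count=200)
    print_tree_structure(tree, max_level=3)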
+
+
+ def print_tree_structure(node: dict[str, Any], level: int = 0, max_level: int = 5):
+     """Log the first few layers of the tree structure for easy viewing"""
+     if level > max_level:
+         return
+
+     indent = " " * level
+     node_id = node.get("id", "unknown")
+     node_name = node.get("node_name", "")
+     node_value = node.get("value", "")
+     memory_type = node.get("memory_type", "Unknown")
+
+     # Determine display method based on whether there are children
+     children = node.get("children", [])
+     if children:
+         # Intermediate node: display name, type and child count
+         logger.info(f"{indent}- {node_name} [{memory_type}] ({len(children)} children)")
+         logger.info(f"{indent} ID: {node_id}")
+         display_value = node_value[:80] + "..." if len(node_value) > 80 else node_value
+         logger.info(f"{indent} Value: {display_value}")
+
+         if level < max_level:
+             for child in children:
+                 print_tree_structure(child, level + 1, max_level)
+         elif level == max_level:
+             logger.info(f"{indent} ... (expansion limited)")
+     else:
+         # Leaf node: display name, type and value
+         display_value = node_value[:80] + "..." if len(node_value) > 80 else node_value
+         logger.info(f"{indent}- {node_name} [{memory_type}]: {display_value}")
+         logger.info(f"{indent} ID: {node_id}")
+
+
+ def analyze_final_tree_quality(tree_data: dict[str, Any]) -> dict:
+     """Analyze final tree quality, including type diversity, branch structure, etc."""
+     stats = {
+         "total_nodes": 0,
+         "by_type": {},
+         "by_depth": {},
+         "max_depth": 0,
+         "total_leaves": 0,
+         "total_branches": 0,  # Number of branch nodes with multiple children
+         "subtrees": [],
+         "type_diversity": {},
+         "structure_quality": {},
+         "chain_analysis": {},  # Single chain structure analysis
+     }
+
+     def analyze_subtree(node, depth=0, parent_path="", chain_length=0):
+         stats["total_nodes"] += 1
+         stats["max_depth"] = max(stats["max_depth"], depth)
+
+         # Count by type
+         memory_type = node.get("memory_type", "Unknown")
+         stats["by_type"][memory_type] = stats["by_type"].get(memory_type, 0) + 1
+
+         # Count by depth
+         stats["by_depth"][depth] = stats["by_depth"].get(depth, 0) + 1
+
+         children = node.get("children", [])
+         current_path = (
+             f"{parent_path}/{node.get('node_name', 'unknown')}"
+             if parent_path
+             else node.get("node_name", "root")
+         )
+
+         # Analyze node type
+         if not children:  # Leaf node
+             stats["total_leaves"] += 1
+             # Record chain length
+             if "max_chain_length" not in stats["chain_analysis"]:
+                 stats["chain_analysis"]["max_chain_length"] = 0
+             stats["chain_analysis"]["max_chain_length"] = max(
+                 stats["chain_analysis"]["max_chain_length"], chain_length
+             )
+         elif len(children) == 1:  # Single child node (chain)
+             # Continue calculating chain length
+             for child in children:
+                 analyze_subtree(child, depth + 1, current_path, chain_length + 1)
+             return  # Early return to avoid duplicate processing
+         else:  # Branch node (multiple children)
+             stats["total_branches"] += 1
+             # Reset chain length
+             chain_length = 0
+
+         # If it's the root node of a major subtree, analyze its characteristics
+         if depth <= 2 and children:  # Major subtree
+             subtree_depth = 0
+             subtree_leaves = 0
+             subtree_nodes = 0
+             subtree_branches = 0
+             subtree_types = {}
+             subtree_max_width = 0
+             width_per_level = {}
+
+             def count_subtree(subnode, subdepth):
+                 nonlocal \
+                     subtree_depth, \
+                     subtree_leaves, \
+                     subtree_nodes, \
+                     subtree_branches, \
+                     subtree_max_width
+                 subtree_nodes += 1
+                 subtree_depth = max(subtree_depth, subdepth)
+
+                 # Count type distribution within subtree
+                 sub_memory_type = subnode.get("memory_type", "Unknown")
+                 subtree_types[sub_memory_type] = subtree_types.get(sub_memory_type, 0) + 1
+
+                 # Count width per level
+                 width_per_level[subdepth] = width_per_level.get(subdepth, 0) + 1
+                 subtree_max_width = max(subtree_max_width, width_per_level[subdepth])
+
+                 subchildren = subnode.get("children", [])
+                 if not subchildren:
+                     subtree_leaves += 1
+                 elif len(subchildren) > 1:
+                     subtree_branches += 1
+
+                 for child in subchildren:
+                     count_subtree(child, subdepth + 1)
+
+             count_subtree(node, 0)
+
+             # Calculate subtree quality metrics
+             branch_density = subtree_branches / subtree_nodes if subtree_nodes > 0 else 0
+             leaf_ratio = subtree_leaves / subtree_nodes if subtree_nodes > 0 else 0
+             depth_width_ratio = (
+                 subtree_depth / subtree_max_width if subtree_max_width > 0 else subtree_depth
+             )
+
+             stats["subtrees"].append(
+                 {
+                     "root": node.get("node_name", "unknown"),
+                     "type": memory_type,
+                     "depth": subtree_depth,
+                     "leaves": subtree_leaves,
+                     "nodes": subtree_nodes,
+                     "branches": subtree_branches,
+                     "branch_density": branch_density,
+                     "leaf_ratio": leaf_ratio,
+                     "max_width": subtree_max_width,
+                     "depth_width_ratio": depth_width_ratio,
+                     "path": current_path,
+                     "type_distribution": subtree_types,
+                     "quality_score": calculate_enhanced_quality(
+                         subtree_depth,
+                         subtree_leaves,
+                         subtree_nodes,
+                         subtree_branches,
+                         0,
+                         branch_density,
+                         depth_width_ratio,
+                         subtree_max_width,
+                     ),
+                 }
+             )
+
+         # Recursively analyze child nodes
+         for child in children:
+             analyze_subtree(child, depth + 1, current_path, 0)  # Reset chain length
+
+     analyze_subtree(tree_data)
+
+     # Calculate overall structure quality
+     if stats["total_nodes"] > 1:
+         branch_density = stats["total_branches"] / stats["total_nodes"]
+         leaf_ratio = stats["total_leaves"] / stats["total_nodes"]
+
+         # Calculate average width per level
+         total_width = sum(stats["by_depth"].values())
+         avg_width = total_width / len(stats["by_depth"]) if stats["by_depth"] else 0
+         max_width = max(stats["by_depth"].values()) if stats["by_depth"] else 0
+
+         stats["structure_quality"] = {
+             "branch_density": branch_density,
+             "leaf_ratio": leaf_ratio,
+             "avg_width": avg_width,
+             "max_width": max_width,
+             "depth_width_ratio": stats["max_depth"] / max_width
+             if max_width > 0
+             else stats["max_depth"],
+             "is_well_balanced": 0.2 <= branch_density <= 0.6 and 0.3 <= leaf_ratio <= 0.7,
+         }
+
+     # Calculate type diversity metrics
+     total_types = len(stats["by_type"])
+     if total_types > 1:
+         # Calculate uniformity of type distribution (Shannon diversity index)
+         shannon_diversity = 0
+         for count in stats["by_type"].values():
+             if count > 0:
+                 p = count / stats["total_nodes"]
+                 shannon_diversity -= p * math.log2(p)
+
+         # Normalize diversity index (0-1 range)
+         max_diversity = math.log2(total_types) if total_types > 1 else 0
+         normalized_diversity = shannon_diversity / max_diversity if max_diversity > 0 else 0
+
+         stats["type_diversity"] = {
+             "total_types": total_types,
+             "shannon_diversity": shannon_diversity,
+             "normalized_diversity": normalized_diversity,
+             "distribution_balance": min(stats["by_type"].values()) / max(stats["by_type"].values())
+             if max(stats["by_type"].values()) > 0
+             else 0,
+         }
+
+     # Single chain analysis
+     total_single_child_nodes = sum(
+         1 for subtree in stats["subtrees"] if subtree.get("branch_density", 0) < 0.1
+     )
+     stats["chain_analysis"].update(
+         {
+             "single_chain_subtrees": total_single_child_nodes,
+             "chain_subtree_ratio": total_single_child_nodes / len(stats["subtrees"])
+             if stats["subtrees"]
+             else 0,
+         }
+     )
+
+     return stats
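A consumption sketch (illustrative): `structure_quality` and `type_diversity` are only populated when the tree has more than one node and more than one memory type, respectively.

    stats = analyze_final_tree_quality(tree)
    if stats.get("structure_quality"):
        balanced = stats["structure_quality"]["is_well_balanced"]
    print_tree_analysis(tree)  # or log the full formatted report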
+
+
+ def print_tree_analysis(tree_data: dict[str, Any]):
+     """Log the enhanced tree analysis results"""
+     stats = analyze_final_tree_quality(tree_data)
+
+     logger.info("\n" + "=" * 60)
+     logger.info("🌳 Enhanced Tree Structure Quality Analysis Report")
+     logger.info("=" * 60)
+
+     # Basic statistics
+     logger.info("\n📊 Basic Statistics:")
+     logger.info(f" Total nodes: {stats['total_nodes']}")
+     logger.info(f" Max depth: {stats['max_depth']}")
+     logger.info(
+         f" Leaf nodes: {stats['total_leaves']} ({stats['total_leaves'] / stats['total_nodes'] * 100:.1f}%)"
+     )
+     logger.info(
+         f" Branch nodes: {stats['total_branches']} ({stats['total_branches'] / stats['total_nodes'] * 100:.1f}%)"
+     )
+
+     # Structure quality assessment
+     structure = stats.get("structure_quality", {})
+     if structure:
+         logger.info("\n🏗️ Structure Quality Assessment:")
+         logger.info(
+             f" Branch density: {structure['branch_density']:.3f} ({'✅ Good' if 0.2 <= structure['branch_density'] <= 0.6 else '⚠️ Needs improvement'})"
+         )
+         logger.info(
+             f" Leaf ratio: {structure['leaf_ratio']:.3f} ({'✅ Good' if 0.3 <= structure['leaf_ratio'] <= 0.7 else '⚠️ Needs improvement'})"
+         )
+         logger.info(f" Max width: {structure['max_width']}")
+         logger.info(
+             f" Depth-width ratio: {structure['depth_width_ratio']:.2f} ({'✅ Good' if structure['depth_width_ratio'] <= 3 else '⚠️ Too thin'})"
+         )
+         logger.info(
+             f" Overall balance: {'✅ Good' if structure['is_well_balanced'] else '⚠️ Needs improvement'}"
+         )
+
+     # Single chain analysis
+     chain_analysis = stats.get("chain_analysis", {})
+     if chain_analysis:
+         logger.info("\n🔗 Single Chain Structure Analysis:")
+         logger.info(f" Longest chain: {chain_analysis.get('max_chain_length', 0)} layers")
+         logger.info(f" Single chain subtrees: {chain_analysis.get('single_chain_subtrees', 0)}")
+         logger.info(
+             f" Single chain subtree ratio: {chain_analysis.get('chain_subtree_ratio', 0) * 100:.1f}%"
+         )
+
+         if chain_analysis.get("max_chain_length", 0) > 5:
+             logger.info(" ⚠️ Warning: Overly long single chain structure may affect display")
+         elif chain_analysis.get("chain_subtree_ratio", 0) > 0.3:
+             logger.info(
+                 " ⚠️ Warning: Too many single chain subtrees, suggest increasing branch structure"
+             )
+         else:
+             logger.info(" ✅ Single chain structure well controlled")
+
+     # Type diversity
+     type_div = stats.get("type_diversity", {})
+     if type_div:
+         logger.info("\n🎨 Type Diversity Analysis:")
+         logger.info(f" Total types: {type_div['total_types']}")
+         logger.info(f" Diversity index: {type_div['shannon_diversity']:.3f}")
+         logger.info(f" Normalized diversity: {type_div['normalized_diversity']:.3f}")
+         logger.info(f" Distribution balance: {type_div['distribution_balance']:.3f}")
+
+     # Type distribution
+     logger.info("\n📋 Type Distribution Details:")
+     for mem_type, count in sorted(stats["by_type"].items(), key=lambda x: x[1], reverse=True):
+         percentage = count / stats["total_nodes"] * 100
+         logger.info(f" {mem_type}: {count} nodes ({percentage:.1f}%)")
+
+     # Depth distribution
+     logger.info("\n📏 Depth Distribution:")
+     for depth in sorted(stats["by_depth"].keys()):
+         count = stats["by_depth"][depth]
+         logger.info(f" Depth {depth}: {count} nodes")
+
+     # Major subtree analysis
+     if stats["subtrees"]:
+         logger.info("\n🌲 Major Subtree Analysis (sorted by quality):")
+         sorted_subtrees = sorted(
+             stats["subtrees"], key=lambda x: x.get("quality_score", 0), reverse=True
+         )
+         for i, subtree in enumerate(sorted_subtrees[:8]):  # Show first 8
+             quality = subtree.get("quality_score", 0)
+             logger.info(f" #{i + 1} {subtree['root']} [{subtree['type']}]:")
+             logger.info(f" Quality score: {quality:.2f}")
+             logger.info(
+                 f" Structure: Depth={subtree['depth']}, Branches={subtree['branches']}, Leaves={subtree['leaves']}"
+             )
+             logger.info(
+                 f" Density: Branch density={subtree.get('branch_density', 0):.3f}, Leaf ratio={subtree.get('leaf_ratio', 0):.3f}"
+             )
+
+             if quality > 15:
+                 logger.info(" ✅ High quality subtree")
+             elif quality > 8:
+                 logger.info(" 🟡 Medium quality subtree")
+             else:
+                 logger.info(" 🔴 Low quality subtree")
+
+     logger.info("\n" + "=" * 60)
+
+
+ def remove_embedding_recursive(memory_info: dict) -> Any:
+     """Remove the embedding from the memory info.
+
+     Args:
+         memory_info: product memory info
+
+     Returns:
+         Any: product memory info without embedding
+     """
+     if isinstance(memory_info, dict):
+         new_dict = {}
+         for key, value in memory_info.items():
+             if key != "embedding":
+                 new_dict[key] = remove_embedding_recursive(value)
+         return new_dict
+     elif isinstance(memory_info, list):
+         return [remove_embedding_recursive(item) for item in memory_info]
+     else:
+         return memory_info
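Illustrative behavior (a hand-run sketch): "embedding" keys are dropped at every nesting level, dicts and lists are copied, and other values pass through unchanged.

    info = {"memory": "hi", "embedding": [0.1, 0.2], "metadata": {"embedding": [0.3], "key": "k"}}
    remove_embedding_recursive(info)
    # -> {"memory": "hi", "metadata": {"key": "k"}}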
+
+
+ def remove_embedding_from_memory_items(memory_items: list[Any]) -> list[dict]:
+     """Batch remove embedding fields from multiple TextualMemoryItem objects"""
+     clean_memories = []
+
+     for item in memory_items:
+         memory_dict = item.model_dump()
+
+         # Remove embedding from metadata
+         if "metadata" in memory_dict and "embedding" in memory_dict["metadata"]:
+             del memory_dict["metadata"]["embedding"]
+
+         clean_memories.append(memory_dict)
+
+     return clean_memories
+
+
+ def sort_children_by_memory_type(children: list[dict[str, Any]]) -> list[dict[str, Any]]:
+     """
+     Sort the children by memory_type.
+
+     Args:
+         children: the children of the node
+     Returns:
+         the sorted children
+     """
+     if not children:
+         return children
+
+     def get_sort_key(child):
+         memory_type = child.get("memory_type", "Unknown")
+         # Sort directly by the memory_type string; same types will naturally cluster together
+         return memory_type
+
+     # Sort by memory_type
+     sorted_children = sorted(children, key=get_sort_key)
+
+     return sorted_children
+
+
+ def extract_all_ids_from_tree(tree_node):
+     """
+     Recursively traverse tree structure to extract all node IDs
+
+     Args:
+         tree_node: Tree node (dictionary format)
+
+     Returns:
+         set: Set containing all node IDs
+     """
+     ids = set()
+
+     # Add current node ID (if exists)
+     if "id" in tree_node:
+         ids.add(tree_node["id"])
+
+     # Recursively process child nodes
+     if tree_node.get("children"):
+         for child in tree_node["children"]:
+             ids.update(extract_all_ids_from_tree(child))
+
+     return ids
+
+
+ def filter_nodes_by_tree_ids(tree_data, nodes_data):
+     """
+     Filter nodes list based on IDs used in tree structure
+
+     Args:
+         tree_data: Tree structure data (dictionary)
+         nodes_data: Data containing nodes list (dictionary)
+
+     Returns:
+         dict: Filtered nodes data, maintaining original structure
+     """
+     # Extract all IDs used in the tree
+     used_ids = extract_all_ids_from_tree(tree_data)
+
+     # Filter nodes list, keeping only nodes with IDs used in the tree
+     filtered_nodes = [node for node in nodes_data["nodes"] if node["id"] in used_ids]
+
+     # Return result maintaining original structure
+     return {"nodes": filtered_nodes}
+
+
+ def convert_activation_memory_to_serializable(
+     act_mem_items: list[KVCacheItem],
+ ) -> list[dict[str, Any]]:
+     """
+     Convert activation memory items to a serializable format.
+
+     Args:
+         act_mem_items: List of KVCacheItem objects
+
+     Returns:
+         List of dictionaries with serializable data
+     """
+     serializable_items = []
+
+     for item in act_mem_items:
+         key_layers = 0
+         val_layers = 0
+         device = "unknown"
+         dtype = "unknown"
+         key_shapes = []
+         value_shapes = []
+
+         if item.memory:
+             if hasattr(item.memory, "layers"):
+                 key_layers = len(item.memory.layers)
+                 val_layers = len(item.memory.layers)
+                 if key_layers > 0:
+                     l0 = item.memory.layers[0]
+                     k0 = getattr(l0, "key_cache", getattr(l0, "keys", None))
+                     if k0 is not None:
+                         device = str(k0.device)
+                         dtype = str(k0.dtype)
+
+                 for i, layer in enumerate(item.memory.layers):
+                     k = getattr(layer, "key_cache", getattr(layer, "keys", None))
+                     v = getattr(layer, "value_cache", getattr(layer, "values", None))
+                     if k is not None:
+                         key_shapes.append({"layer": i, "shape": list(k.shape)})
+                     if v is not None:
+                         value_shapes.append({"layer": i, "shape": list(v.shape)})
+
+             elif hasattr(item.memory, "key_cache"):
+                 key_layers = len(item.memory.key_cache)
+                 val_layers = len(item.memory.value_cache)
+                 if key_layers > 0 and item.memory.key_cache[0] is not None:
+                     device = str(item.memory.key_cache[0].device)
+                     dtype = str(item.memory.key_cache[0].dtype)
+
+                 for i, key_tensor in enumerate(item.memory.key_cache):
+                     if key_tensor is not None:
+                         key_shapes.append({"layer": i, "shape": list(key_tensor.shape)})
+
+                 for i, val_tensor in enumerate(item.memory.value_cache):
+                     if val_tensor is not None:
+                         value_shapes.append({"layer": i, "shape": list(val_tensor.shape)})
+
+         # Extract basic information that can be serialized
+         serializable_item = {
+             "id": item.id,
+             "metadata": item.metadata,
+             "memory_info": {
+                 "type": "DynamicCache",
+                 "key_cache_layers": key_layers,
+                 "value_cache_layers": val_layers,
+                 "device": device,
+                 "dtype": dtype,
+             },
+         }
+
+         # Add tensor shape information if available
+         if key_shapes:
+             serializable_item["memory_info"]["key_shapes"] = key_shapes
+         if value_shapes:
+             serializable_item["memory_info"]["value_shapes"] = value_shapes
+
+         serializable_items.append(serializable_item)
+
+     return serializable_items
1153
+
1154
+
1155
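A rough sketch of the legacy key_cache/value_cache path, using stand-in objects (SimpleNamespace and torch here are illustrative; real inputs are KVCacheItem objects built elsewhere in the package):

from types import SimpleNamespace

import torch

fake_cache = SimpleNamespace(
    key_cache=[torch.zeros(1, 8, 16, 64, dtype=torch.float16)],
    value_cache=[torch.zeros(1, 8, 16, 64, dtype=torch.float16)],
)
# No "layers" attribute, so the function takes the key_cache branch.
fake_item = SimpleNamespace(id="kv-0", metadata={}, memory=fake_cache)

info = convert_activation_memory_to_serializable([fake_item])[0]["memory_info"]
# info reports 1 key/value layer, dtype "torch.float16", and
# key_shapes == [{"layer": 0, "shape": [1, 8, 16, 64]}]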
+ def convert_activation_memory_summary(act_mem_items: list[KVCacheItem]) -> dict[str, Any]:
+     """
+     Create a summary of activation memory for API responses.
+
+     Args:
+         act_mem_items: List of KVCacheItem objects
+
+     Returns:
+         Dictionary with summary information
+     """
+     if not act_mem_items:
+         return {"total_items": 0, "summary": "No activation memory items found"}
+
+     total_items = len(act_mem_items)
+     total_layers = 0
+     total_parameters = 0
+
+     for item in act_mem_items:
+         if not item.memory:
+             continue
+
+         if hasattr(item.memory, "layers"):
+             total_layers += len(item.memory.layers)
+             for layer in item.memory.layers:
+                 k = getattr(layer, "key_cache", getattr(layer, "keys", None))
+                 v = getattr(layer, "value_cache", getattr(layer, "values", None))
+                 if k is not None:
+                     total_parameters += k.numel()
+                 if v is not None:
+                     total_parameters += v.numel()
+         elif hasattr(item.memory, "key_cache"):
+             total_layers += len(item.memory.key_cache)
+
+             # Approximate the parameter count from the cached tensors
+             for key_tensor in item.memory.key_cache:
+                 if key_tensor is not None:
+                     total_parameters += key_tensor.numel()
+
+             for value_tensor in item.memory.value_cache:
+                 if value_tensor is not None:
+                     total_parameters += value_tensor.numel()
+
+     return {
+         "total_items": total_items,
+         "total_layers": total_layers,
+         "total_parameters": total_parameters,
+         "summary": (
+             f"Activation memory contains {total_items} items with {total_layers} layers "
+             f"and approximately {total_parameters:,} parameters"
+         ),
+     }
+
+
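As a back-of-the-envelope check (shapes as in the sketch above): one [1, 8, 16, 64] float16 key tensor plus its value tensor contributes 2 * 8192 = 16,384 parameters, about 32 KiB of cache:

summary = convert_activation_memory_summary([fake_item])  # item from the sketch above
assert summary["total_parameters"] == 2 * 1 * 8 * 16 * 64  # 16,384
approx_bytes = summary["total_parameters"] * 2  # float16 is 2 bytes per element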
+ def detect_and_remove_duplicate_ids(tree_node: dict[str, Any]) -> dict[str, Any]:
+     """
+     Detect and remove duplicate IDs in a tree structure by skipping duplicate nodes.
+     The first occurrence of each ID is kept; subsequent duplicates are removed.
+
+     Args:
+         tree_node: Tree node (dictionary format)
+
+     Returns:
+         dict: Fixed tree node with duplicate nodes removed
+     """
+     used_ids = set()
+     removed_count = 0
+
+     def remove_duplicates_recursive(
+         node: dict[str, Any], parent_path: str = ""
+     ) -> dict[str, Any] | None:
+         """Recursively remove duplicate IDs by skipping duplicate nodes."""
+         nonlocal removed_count
+
+         if not isinstance(node, dict):
+             return node
+
+         # Copy the node so the input tree is left unmodified
+         fixed_node = node.copy()
+
+         # Handle the current node's ID
+         current_id = fixed_node.get("id", "")
+         if current_id in used_ids and current_id not in ["root", "WorkingMemory"]:
+             # Skip this duplicate node
+             logger.info(f"Skipping duplicate node: {current_id} (path: {parent_path})")
+             removed_count += 1
+             return None  # None signals that this node should be removed
+         else:
+             used_ids.add(current_id)
+
+         # Recursively process child nodes
+         if "children" in fixed_node and isinstance(fixed_node["children"], list):
+             fixed_children = []
+             for i, child in enumerate(fixed_node["children"]):
+                 child_path = f"{parent_path}/{fixed_node.get('node_name', 'unknown')}[{i}]"
+                 fixed_child = remove_duplicates_recursive(child, child_path)
+                 if fixed_child is not None:  # Only keep non-None children
+                     fixed_children.append(fixed_child)
+             fixed_node["children"] = fixed_children
+
+         return fixed_node
+
+     result = remove_duplicates_recursive(tree_node)
+     if result is not None:
+         logger.info(f"Removed {removed_count} duplicate nodes")
+         return result
+     else:
+         # If the root node itself was removed (which should not happen), return an empty root
+         return {
+             "id": "root",
+             "node_name": "root",
+             "value": "root",
+             "memory_type": "Root",
+             "children": [],
+         }
+
+
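A small sketch with hypothetical IDs; the first occurrence of "n1" wins and the duplicate is dropped:

tree = {
    "id": "root",
    "node_name": "root",
    "children": [
        {"id": "n1", "node_name": "first", "children": []},
        {"id": "n1", "node_name": "second", "children": []},
    ],
}
fixed = detect_and_remove_duplicate_ids(tree)
assert [c["node_name"] for c in fixed["children"]] == ["first"]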
+ def validate_tree_structure(tree_node: dict[str, Any]) -> dict[str, Any]:
+     """
+     Validate tree structure integrity, including an ID uniqueness check.
+
+     Args:
+         tree_node: Tree node (dictionary format)
+
+     Returns:
+         dict: Validation result containing error messages and fix suggestions
+     """
+     validation_result = {
+         "is_valid": True,
+         "errors": [],
+         "warnings": [],
+         "total_nodes": 0,
+         "unique_ids": set(),
+         "duplicate_ids": set(),
+         "missing_ids": set(),
+         "invalid_structure": [],
+     }
+
+     def validate_recursive(node: dict[str, Any], path: str = "", depth: int = 0):
+         """Recursively validate the tree structure."""
+         if not isinstance(node, dict):
+             validation_result["errors"].append(f"Node is not a dictionary: {path}")
+             validation_result["is_valid"] = False
+             return
+
+         validation_result["total_nodes"] += 1
+
+         # Check required fields
+         if "id" not in node:
+             validation_result["errors"].append(f"Node missing ID field: {path}")
+             validation_result["missing_ids"].add(path)
+             validation_result["is_valid"] = False
+         else:
+             node_id = node["id"]
+             if node_id in validation_result["unique_ids"]:
+                 validation_result["errors"].append(f"Duplicate node ID: {node_id} (path: {path})")
+                 validation_result["duplicate_ids"].add(node_id)
+                 validation_result["is_valid"] = False
+             else:
+                 validation_result["unique_ids"].add(node_id)
+
+         # Check the other required fields
+         required_fields = ["node_name", "value", "memory_type"]
+         for field in required_fields:
+             if field not in node:
+                 validation_result["warnings"].append(f"Node missing field '{field}': {path}")
+
+         # Recursively validate child nodes
+         if "children" in node:
+             if not isinstance(node["children"], list):
+                 validation_result["errors"].append(f"Children field is not a list: {path}")
+                 validation_result["is_valid"] = False
+             else:
+                 for i, child in enumerate(node["children"]):
+                     child_path = f"{path}/children[{i}]"
+                     validate_recursive(child, child_path, depth + 1)
+
+         # Check the depth limit
+         if depth > 20:
+             validation_result["warnings"].append(f"Tree depth too deep ({depth}): {path}")
+
+     validate_recursive(tree_node)
+
+     # Generate fix suggestions
+     if validation_result["duplicate_ids"]:
+         validation_result["fix_suggestion"] = (
+             "Use the detect_and_remove_duplicate_ids() function to fix duplicate IDs"
+         )
+
+     return validation_result
+
+
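Running the validator on the duplicate tree from the previous sketch surfaces the problem before any repair is attempted:

report = validate_tree_structure(tree)  # tree from the duplicate-ID sketch above
assert not report["is_valid"]
assert "n1" in report["duplicate_ids"]
# Missing node_name/value/memory_type fields only raise warnings, not errors.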
+ def ensure_unique_tree_ids(tree_result: dict[str, Any]) -> dict[str, Any]:
+     """
+     Ensure all node IDs in the tree structure are unique by removing duplicate nodes.
+     This is a post-processing step for convert_graph_to_tree_forworkmem.
+
+     Args:
+         tree_result: Tree structure returned by convert_graph_to_tree_forworkmem
+
+     Returns:
+         dict: Fixed tree structure with duplicate nodes removed
+     """
+     logger.info("🔍 Starting duplicate ID check in tree structure...")
+
+     # First validate the tree structure
+     validation = validate_tree_structure(tree_result)
+
+     if validation["is_valid"]:
+         logger.info("Tree structure validation passed, no duplicate IDs found")
+         return tree_result
+
+     # Report issues
+     logger.info(f"Found {len(validation['errors'])} errors:")
+     for error in validation["errors"][:5]:  # Only show the first 5 errors
+         logger.info(f"  - {error}")
+
+     if len(validation["errors"]) > 5:
+         logger.info(f"  ... and {len(validation['errors']) - 5} more errors")
+
+     logger.info("Statistics:")
+     logger.info(f"  - Total nodes: {validation['total_nodes']}")
+     logger.info(f"  - Unique IDs: {len(validation['unique_ids'])}")
+     logger.info(f"  - Duplicate IDs: {len(validation['duplicate_ids'])}")
+
+     # Remove duplicate nodes
+     logger.info("Starting duplicate node removal...")
+     fixed_tree = detect_and_remove_duplicate_ids(tree_result)
+
+     # Validate again
+     post_validation = validate_tree_structure(fixed_tree)
+     if post_validation["is_valid"]:
+         logger.info("Removal completed, tree structure is now valid")
+         logger.info(f"Final node count: {post_validation['total_nodes']}")
+     else:
+         logger.info("Issues remain after removal, please check the code logic")
+         for error in post_validation["errors"][:3]:
+             logger.info(f"  - {error}")
+
+     return fixed_tree
+
+
+ def clean_json_response(response: str) -> str:
+     """
+     Remove markdown JSON code block formatting from an LLM response.
+
+     Args:
+         response: Raw response string that may contain ```json and ``` markers
+
+     Returns:
+         str: Clean JSON string without markdown formatting
+     """
+     return response.replace("```json", "").replace("```", "").strip()
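A minimal sketch of stripping a fenced LLM reply (payload illustrative):

raw = '```json\n{"answer": 42}\n```'
assert clean_json_response(raw) == '{"answer": 42}'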