memoryos-2.0.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memoryos-2.0.3.dist-info/METADATA +418 -0
- memoryos-2.0.3.dist-info/RECORD +315 -0
- memoryos-2.0.3.dist-info/WHEEL +4 -0
- memoryos-2.0.3.dist-info/entry_points.txt +3 -0
- memoryos-2.0.3.dist-info/licenses/LICENSE +201 -0
- memos/__init__.py +20 -0
- memos/api/client.py +571 -0
- memos/api/config.py +1018 -0
- memos/api/context/dependencies.py +50 -0
- memos/api/exceptions.py +53 -0
- memos/api/handlers/__init__.py +62 -0
- memos/api/handlers/add_handler.py +158 -0
- memos/api/handlers/base_handler.py +194 -0
- memos/api/handlers/chat_handler.py +1401 -0
- memos/api/handlers/component_init.py +388 -0
- memos/api/handlers/config_builders.py +190 -0
- memos/api/handlers/feedback_handler.py +93 -0
- memos/api/handlers/formatters_handler.py +237 -0
- memos/api/handlers/memory_handler.py +316 -0
- memos/api/handlers/scheduler_handler.py +497 -0
- memos/api/handlers/search_handler.py +222 -0
- memos/api/handlers/suggestion_handler.py +117 -0
- memos/api/mcp_serve.py +614 -0
- memos/api/middleware/request_context.py +101 -0
- memos/api/product_api.py +38 -0
- memos/api/product_models.py +1206 -0
- memos/api/routers/__init__.py +1 -0
- memos/api/routers/product_router.py +477 -0
- memos/api/routers/server_router.py +394 -0
- memos/api/server_api.py +44 -0
- memos/api/start_api.py +433 -0
- memos/chunkers/__init__.py +4 -0
- memos/chunkers/base.py +24 -0
- memos/chunkers/charactertext_chunker.py +41 -0
- memos/chunkers/factory.py +24 -0
- memos/chunkers/markdown_chunker.py +62 -0
- memos/chunkers/sentence_chunker.py +54 -0
- memos/chunkers/simple_chunker.py +50 -0
- memos/cli.py +113 -0
- memos/configs/__init__.py +0 -0
- memos/configs/base.py +82 -0
- memos/configs/chunker.py +59 -0
- memos/configs/embedder.py +88 -0
- memos/configs/graph_db.py +236 -0
- memos/configs/internet_retriever.py +100 -0
- memos/configs/llm.py +151 -0
- memos/configs/mem_agent.py +54 -0
- memos/configs/mem_chat.py +81 -0
- memos/configs/mem_cube.py +105 -0
- memos/configs/mem_os.py +83 -0
- memos/configs/mem_reader.py +91 -0
- memos/configs/mem_scheduler.py +385 -0
- memos/configs/mem_user.py +70 -0
- memos/configs/memory.py +324 -0
- memos/configs/parser.py +38 -0
- memos/configs/reranker.py +18 -0
- memos/configs/utils.py +8 -0
- memos/configs/vec_db.py +80 -0
- memos/context/context.py +355 -0
- memos/dependency.py +52 -0
- memos/deprecation.py +262 -0
- memos/embedders/__init__.py +0 -0
- memos/embedders/ark.py +95 -0
- memos/embedders/base.py +106 -0
- memos/embedders/factory.py +29 -0
- memos/embedders/ollama.py +77 -0
- memos/embedders/sentence_transformer.py +49 -0
- memos/embedders/universal_api.py +51 -0
- memos/exceptions.py +30 -0
- memos/graph_dbs/__init__.py +0 -0
- memos/graph_dbs/base.py +274 -0
- memos/graph_dbs/factory.py +27 -0
- memos/graph_dbs/item.py +46 -0
- memos/graph_dbs/nebular.py +1794 -0
- memos/graph_dbs/neo4j.py +1942 -0
- memos/graph_dbs/neo4j_community.py +1058 -0
- memos/graph_dbs/polardb.py +5446 -0
- memos/hello_world.py +97 -0
- memos/llms/__init__.py +0 -0
- memos/llms/base.py +25 -0
- memos/llms/deepseek.py +13 -0
- memos/llms/factory.py +38 -0
- memos/llms/hf.py +443 -0
- memos/llms/hf_singleton.py +114 -0
- memos/llms/ollama.py +135 -0
- memos/llms/openai.py +222 -0
- memos/llms/openai_new.py +198 -0
- memos/llms/qwen.py +13 -0
- memos/llms/utils.py +14 -0
- memos/llms/vllm.py +218 -0
- memos/log.py +237 -0
- memos/mem_agent/base.py +19 -0
- memos/mem_agent/deepsearch_agent.py +391 -0
- memos/mem_agent/factory.py +36 -0
- memos/mem_chat/__init__.py +0 -0
- memos/mem_chat/base.py +30 -0
- memos/mem_chat/factory.py +21 -0
- memos/mem_chat/simple.py +200 -0
- memos/mem_cube/__init__.py +0 -0
- memos/mem_cube/base.py +30 -0
- memos/mem_cube/general.py +240 -0
- memos/mem_cube/navie.py +172 -0
- memos/mem_cube/utils.py +169 -0
- memos/mem_feedback/base.py +15 -0
- memos/mem_feedback/feedback.py +1192 -0
- memos/mem_feedback/simple_feedback.py +40 -0
- memos/mem_feedback/utils.py +230 -0
- memos/mem_os/client.py +5 -0
- memos/mem_os/core.py +1203 -0
- memos/mem_os/main.py +582 -0
- memos/mem_os/product.py +1608 -0
- memos/mem_os/product_server.py +455 -0
- memos/mem_os/utils/default_config.py +359 -0
- memos/mem_os/utils/format_utils.py +1403 -0
- memos/mem_os/utils/reference_utils.py +162 -0
- memos/mem_reader/__init__.py +0 -0
- memos/mem_reader/base.py +47 -0
- memos/mem_reader/factory.py +53 -0
- memos/mem_reader/memory.py +298 -0
- memos/mem_reader/multi_modal_struct.py +965 -0
- memos/mem_reader/read_multi_modal/__init__.py +43 -0
- memos/mem_reader/read_multi_modal/assistant_parser.py +311 -0
- memos/mem_reader/read_multi_modal/base.py +273 -0
- memos/mem_reader/read_multi_modal/file_content_parser.py +826 -0
- memos/mem_reader/read_multi_modal/image_parser.py +359 -0
- memos/mem_reader/read_multi_modal/multi_modal_parser.py +252 -0
- memos/mem_reader/read_multi_modal/string_parser.py +139 -0
- memos/mem_reader/read_multi_modal/system_parser.py +327 -0
- memos/mem_reader/read_multi_modal/text_content_parser.py +131 -0
- memos/mem_reader/read_multi_modal/tool_parser.py +210 -0
- memos/mem_reader/read_multi_modal/user_parser.py +218 -0
- memos/mem_reader/read_multi_modal/utils.py +358 -0
- memos/mem_reader/simple_struct.py +912 -0
- memos/mem_reader/strategy_struct.py +163 -0
- memos/mem_reader/utils.py +157 -0
- memos/mem_scheduler/__init__.py +0 -0
- memos/mem_scheduler/analyzer/__init__.py +0 -0
- memos/mem_scheduler/analyzer/api_analyzer.py +714 -0
- memos/mem_scheduler/analyzer/eval_analyzer.py +219 -0
- memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +571 -0
- memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
- memos/mem_scheduler/base_scheduler.py +1319 -0
- memos/mem_scheduler/general_modules/__init__.py +0 -0
- memos/mem_scheduler/general_modules/api_misc.py +137 -0
- memos/mem_scheduler/general_modules/base.py +80 -0
- memos/mem_scheduler/general_modules/init_components_for_scheduler.py +425 -0
- memos/mem_scheduler/general_modules/misc.py +313 -0
- memos/mem_scheduler/general_modules/scheduler_logger.py +389 -0
- memos/mem_scheduler/general_modules/task_threads.py +315 -0
- memos/mem_scheduler/general_scheduler.py +1495 -0
- memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
- memos/mem_scheduler/memory_manage_modules/memory_filter.py +306 -0
- memos/mem_scheduler/memory_manage_modules/retriever.py +547 -0
- memos/mem_scheduler/monitors/__init__.py +0 -0
- memos/mem_scheduler/monitors/dispatcher_monitor.py +366 -0
- memos/mem_scheduler/monitors/general_monitor.py +394 -0
- memos/mem_scheduler/monitors/task_schedule_monitor.py +254 -0
- memos/mem_scheduler/optimized_scheduler.py +410 -0
- memos/mem_scheduler/orm_modules/__init__.py +0 -0
- memos/mem_scheduler/orm_modules/api_redis_model.py +518 -0
- memos/mem_scheduler/orm_modules/base_model.py +729 -0
- memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
- memos/mem_scheduler/orm_modules/redis_model.py +699 -0
- memos/mem_scheduler/scheduler_factory.py +23 -0
- memos/mem_scheduler/schemas/__init__.py +0 -0
- memos/mem_scheduler/schemas/analyzer_schemas.py +52 -0
- memos/mem_scheduler/schemas/api_schemas.py +233 -0
- memos/mem_scheduler/schemas/general_schemas.py +55 -0
- memos/mem_scheduler/schemas/message_schemas.py +173 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +406 -0
- memos/mem_scheduler/schemas/task_schemas.py +132 -0
- memos/mem_scheduler/task_schedule_modules/__init__.py +0 -0
- memos/mem_scheduler/task_schedule_modules/dispatcher.py +740 -0
- memos/mem_scheduler/task_schedule_modules/local_queue.py +247 -0
- memos/mem_scheduler/task_schedule_modules/orchestrator.py +74 -0
- memos/mem_scheduler/task_schedule_modules/redis_queue.py +1385 -0
- memos/mem_scheduler/task_schedule_modules/task_queue.py +162 -0
- memos/mem_scheduler/utils/__init__.py +0 -0
- memos/mem_scheduler/utils/api_utils.py +77 -0
- memos/mem_scheduler/utils/config_utils.py +100 -0
- memos/mem_scheduler/utils/db_utils.py +50 -0
- memos/mem_scheduler/utils/filter_utils.py +176 -0
- memos/mem_scheduler/utils/metrics.py +125 -0
- memos/mem_scheduler/utils/misc_utils.py +290 -0
- memos/mem_scheduler/utils/monitor_event_utils.py +67 -0
- memos/mem_scheduler/utils/status_tracker.py +229 -0
- memos/mem_scheduler/webservice_modules/__init__.py +0 -0
- memos/mem_scheduler/webservice_modules/rabbitmq_service.py +485 -0
- memos/mem_scheduler/webservice_modules/redis_service.py +380 -0
- memos/mem_user/factory.py +94 -0
- memos/mem_user/mysql_persistent_user_manager.py +271 -0
- memos/mem_user/mysql_user_manager.py +502 -0
- memos/mem_user/persistent_factory.py +98 -0
- memos/mem_user/persistent_user_manager.py +260 -0
- memos/mem_user/redis_persistent_user_manager.py +225 -0
- memos/mem_user/user_manager.py +488 -0
- memos/memories/__init__.py +0 -0
- memos/memories/activation/__init__.py +0 -0
- memos/memories/activation/base.py +42 -0
- memos/memories/activation/item.py +56 -0
- memos/memories/activation/kv.py +292 -0
- memos/memories/activation/vllmkv.py +219 -0
- memos/memories/base.py +19 -0
- memos/memories/factory.py +42 -0
- memos/memories/parametric/__init__.py +0 -0
- memos/memories/parametric/base.py +19 -0
- memos/memories/parametric/item.py +11 -0
- memos/memories/parametric/lora.py +41 -0
- memos/memories/textual/__init__.py +0 -0
- memos/memories/textual/base.py +92 -0
- memos/memories/textual/general.py +236 -0
- memos/memories/textual/item.py +304 -0
- memos/memories/textual/naive.py +187 -0
- memos/memories/textual/prefer_text_memory/__init__.py +0 -0
- memos/memories/textual/prefer_text_memory/adder.py +504 -0
- memos/memories/textual/prefer_text_memory/config.py +106 -0
- memos/memories/textual/prefer_text_memory/extractor.py +221 -0
- memos/memories/textual/prefer_text_memory/factory.py +85 -0
- memos/memories/textual/prefer_text_memory/retrievers.py +177 -0
- memos/memories/textual/prefer_text_memory/spliter.py +132 -0
- memos/memories/textual/prefer_text_memory/utils.py +93 -0
- memos/memories/textual/preference.py +344 -0
- memos/memories/textual/simple_preference.py +161 -0
- memos/memories/textual/simple_tree.py +69 -0
- memos/memories/textual/tree.py +459 -0
- memos/memories/textual/tree_text_memory/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/handler.py +184 -0
- memos/memories/textual/tree_text_memory/organize/manager.py +518 -0
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +238 -0
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +622 -0
- memos/memories/textual/tree_text_memory/retrieve/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +364 -0
- memos/memories/textual/tree_text_memory/retrieve/bm25_util.py +186 -0
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +419 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +270 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +102 -0
- memos/memories/textual/tree_text_memory/retrieve/reasoner.py +61 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +497 -0
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +111 -0
- memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +16 -0
- memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +472 -0
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +848 -0
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +135 -0
- memos/memories/textual/tree_text_memory/retrieve/utils.py +54 -0
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +387 -0
- memos/memos_tools/dinding_report_bot.py +453 -0
- memos/memos_tools/lockfree_dict.py +120 -0
- memos/memos_tools/notification_service.py +44 -0
- memos/memos_tools/notification_utils.py +142 -0
- memos/memos_tools/singleton.py +174 -0
- memos/memos_tools/thread_safe_dict.py +310 -0
- memos/memos_tools/thread_safe_dict_segment.py +382 -0
- memos/multi_mem_cube/__init__.py +0 -0
- memos/multi_mem_cube/composite_cube.py +86 -0
- memos/multi_mem_cube/single_cube.py +874 -0
- memos/multi_mem_cube/views.py +54 -0
- memos/parsers/__init__.py +0 -0
- memos/parsers/base.py +15 -0
- memos/parsers/factory.py +21 -0
- memos/parsers/markitdown.py +28 -0
- memos/reranker/__init__.py +4 -0
- memos/reranker/base.py +25 -0
- memos/reranker/concat.py +103 -0
- memos/reranker/cosine_local.py +102 -0
- memos/reranker/factory.py +72 -0
- memos/reranker/http_bge.py +324 -0
- memos/reranker/http_bge_strategy.py +327 -0
- memos/reranker/noop.py +19 -0
- memos/reranker/strategies/__init__.py +4 -0
- memos/reranker/strategies/base.py +61 -0
- memos/reranker/strategies/concat_background.py +94 -0
- memos/reranker/strategies/concat_docsource.py +110 -0
- memos/reranker/strategies/dialogue_common.py +109 -0
- memos/reranker/strategies/factory.py +31 -0
- memos/reranker/strategies/single_turn.py +107 -0
- memos/reranker/strategies/singleturn_outmem.py +98 -0
- memos/settings.py +10 -0
- memos/templates/__init__.py +0 -0
- memos/templates/advanced_search_prompts.py +211 -0
- memos/templates/cloud_service_prompt.py +107 -0
- memos/templates/instruction_completion.py +66 -0
- memos/templates/mem_agent_prompts.py +85 -0
- memos/templates/mem_feedback_prompts.py +822 -0
- memos/templates/mem_reader_prompts.py +1096 -0
- memos/templates/mem_reader_strategy_prompts.py +238 -0
- memos/templates/mem_scheduler_prompts.py +626 -0
- memos/templates/mem_search_prompts.py +93 -0
- memos/templates/mos_prompts.py +403 -0
- memos/templates/prefer_complete_prompt.py +735 -0
- memos/templates/tool_mem_prompts.py +139 -0
- memos/templates/tree_reorganize_prompts.py +230 -0
- memos/types/__init__.py +34 -0
- memos/types/general_types.py +151 -0
- memos/types/openai_chat_completion_types/__init__.py +15 -0
- memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py +56 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py +27 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py +23 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_param.py +43 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py +16 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py +16 -0
- memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py +27 -0
- memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py +32 -0
- memos/types/openai_chat_completion_types/chat_completion_message_param.py +18 -0
- memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py +15 -0
- memos/types/openai_chat_completion_types/chat_completion_system_message_param.py +36 -0
- memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py +30 -0
- memos/types/openai_chat_completion_types/chat_completion_user_message_param.py +34 -0
- memos/utils.py +123 -0
- memos/vec_dbs/__init__.py +0 -0
- memos/vec_dbs/base.py +117 -0
- memos/vec_dbs/factory.py +23 -0
- memos/vec_dbs/item.py +50 -0
- memos/vec_dbs/milvus.py +654 -0
- memos/vec_dbs/qdrant.py +355 -0
memos/mem_os/utils/format_utils.py

@@ -0,0 +1,1403 @@

import math
import random

from typing import Any

from memos.log import get_logger
from memos.memories.activation.item import KVCacheItem


logger = get_logger(__name__)


def extract_node_name(memory: str) -> str:
    """Extract the first two words from memory as node_name"""
    if not memory:
        return ""

    words = [word.strip() for word in memory.split() if word.strip()]

    if len(words) >= 2:
        return " ".join(words[:2])
    elif len(words) == 1:
        return words[0]
    else:
        return ""


def analyze_tree_structure_enhanced(nodes: list[dict], edges: list[dict]) -> dict:
    """Enhanced tree structure analysis, focusing on branching degree and leaf distribution"""
    # Build adjacency list
    adj_list = {}
    reverse_adj = {}
    for edge in edges:
        source, target = edge["source"], edge["target"]
        adj_list.setdefault(source, []).append(target)
        reverse_adj.setdefault(target, []).append(source)

    # Find all nodes and root nodes
    all_nodes = {node["id"] for node in nodes}
    target_nodes = {edge["target"] for edge in edges}
    root_nodes = all_nodes - target_nodes

    subtree_analysis = {}

    def analyze_subtree_enhanced(root_id: str) -> dict:
        """Enhanced subtree analysis, focusing on branching degree and structure quality"""
        visited = set()
        max_depth = 0
        leaf_count = 0
        total_nodes = 0
        branch_nodes = 0  # Number of branch nodes with multiple children
        chain_length = 0  # Longest single chain length
        width_per_level = {}  # Width per level

        def dfs(node_id: str, depth: int, chain_len: int):
            nonlocal max_depth, leaf_count, total_nodes, branch_nodes, chain_length

            if node_id in visited:
                return

            visited.add(node_id)
            total_nodes += 1
            max_depth = max(max_depth, depth)
            chain_length = max(chain_length, chain_len)

            # Record number of nodes per level
            width_per_level[depth] = width_per_level.get(depth, 0) + 1

            children = adj_list.get(node_id, [])

            if not children:  # Leaf node
                leaf_count += 1
            elif len(children) > 1:  # Branch node
                branch_nodes += 1
                # Reset chain length because we encountered a branch
                for child in children:
                    dfs(child, depth + 1, 0)
            else:  # Single child node (chain structure)
                for child in children:
                    dfs(child, depth + 1, chain_len + 1)

        dfs(root_id, 0, 0)

        # Calculate structure quality metrics
        avg_width = sum(width_per_level.values()) / len(width_per_level) if width_per_level else 0
        max_width = max(width_per_level.values()) if width_per_level else 0

        # Calculate branch density: ratio of branch nodes to total nodes
        branch_density = branch_nodes / total_nodes if total_nodes > 0 else 0

        # Calculate depth-width ratio: ideal tree should have moderate depth and good breadth
        depth_width_ratio = max_depth / max_width if max_width > 0 else max_depth

        quality_score = calculate_enhanced_quality(
            max_depth,
            leaf_count,
            total_nodes,
            branch_nodes,
            chain_length,
            branch_density,
            depth_width_ratio,
            max_width,
        )

        return {
            "root_id": root_id,
            "max_depth": max_depth,
            "leaf_count": leaf_count,
            "total_nodes": total_nodes,
            "branch_nodes": branch_nodes,
            "max_chain_length": chain_length,
            "branch_density": branch_density,
            "max_width": max_width,
            "avg_width": avg_width,
            "depth_width_ratio": depth_width_ratio,
            "nodes_in_subtree": list(visited),
            "quality_score": quality_score,
            "width_per_level": width_per_level,
        }

    for root_id in root_nodes:
        subtree_analysis[root_id] = analyze_subtree_enhanced(root_id)

    return subtree_analysis
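
# Illustrative example (hypothetical data): the analysis expects a flat
# node/edge list such as
#
#   nodes = [{"id": "a"}, {"id": "b"}, {"id": "c"}]
#   edges = [{"source": "a", "target": "b"}, {"source": "a", "target": "c"}]
#
# Here "a" is the only root (it never appears as an edge target), so the result
# has a single entry keyed "a" with total_nodes=3, branch_nodes=1, leaf_count=2,
# max_depth=1, max_width=2 and branch_density=1/3.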


def calculate_enhanced_quality(
    max_depth: int,
    leaf_count: int,
    total_nodes: int,
    branch_nodes: int,
    max_chain_length: int,
    branch_density: float,
    depth_width_ratio: float,
    max_width: int,
) -> float:
    """Enhanced quality calculation, prioritizing branching degree and leaf distribution"""

    if total_nodes <= 1:
        return 0.1

    # 1. Branch quality score (weight: 35%)
    # Branch node count score
    branch_count_score = min(branch_nodes * 3, 15)  # 3 points per branch node, max 15 points

    # Branch density score: ideal density between 20%-60%
    if 0.2 <= branch_density <= 0.6:
        branch_density_score = 10
    elif branch_density > 0.6:
        branch_density_score = max(5, 10 - (branch_density - 0.6) * 20)
    else:
        branch_density_score = branch_density * 25  # Linear growth for 0-20%

    branch_score = (branch_count_score + branch_density_score) * 0.35

    # 2. Leaf quality score (weight: 25%)
    # Leaf count score
    leaf_count_score = min(leaf_count * 2, 20)

    # Leaf distribution score: ideal leaf ratio 30%-70% of total nodes
    leaf_ratio = leaf_count / total_nodes
    if 0.3 <= leaf_ratio <= 0.7:
        leaf_ratio_score = 10
    elif leaf_ratio > 0.7:
        leaf_ratio_score = max(3, 10 - (leaf_ratio - 0.7) * 20)
    else:
        leaf_ratio_score = leaf_ratio * 20  # Linear growth for 0-30%

    leaf_score = (leaf_count_score + leaf_ratio_score) * 0.25

    # 3. Structure balance score (weight: 25%)
    # Depth score: moderate depth is best (3-8 layers)
    if 3 <= max_depth <= 8:
        depth_score = 15
    elif max_depth < 3:
        depth_score = max_depth * 3  # Lower score for 1-2 layers
    else:
        depth_score = max(5, 15 - (max_depth - 8) * 1.5)  # Gradually reduce score beyond 8 layers

    # Width score: larger max width is better, but with upper limit
    width_score = min(max_width * 1.5, 15)

    # Depth-width ratio penalty: too large ratio means tree is too "thin"
    if depth_width_ratio > 3:
        ratio_penalty = (depth_width_ratio - 3) * 2
        structure_score = max(0, (depth_score + width_score - ratio_penalty)) * 0.25
    else:
        structure_score = (depth_score + width_score) * 0.25

    # 4. Chain structure penalty (weight: 15%)
    # Longest single chain length penalty: overly long chains severely affect display
    if max_chain_length <= 2:
        chain_penalty_score = 10
    elif max_chain_length <= 5:
        chain_penalty_score = 8 - (max_chain_length - 2)
    else:
        chain_penalty_score = max(0, 3 - (max_chain_length - 5) * 0.5)

    chain_score = chain_penalty_score * 0.15

    # 5. Comprehensive calculation
    total_score = branch_score + leaf_score + structure_score + chain_score

    # Special case severe penalties
    if max_chain_length > total_nodes * 0.8:  # If more than 80% are single chains
        total_score *= 0.3
    elif branch_density < 0.1 and total_nodes > 5:  # Large tree with almost no branches
        total_score *= 0.5

    return total_score
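
# Worked example (hypothetical numbers): a subtree with total_nodes=10,
# branch_nodes=3, leaf_count=5, max_depth=4, max_chain_length=1, max_width=4
# (so branch_density=0.3, leaf_ratio=0.5, depth_width_ratio=1.0) scores
#   branch:    (9 + 10) * 0.35 = 6.65
#   leaf:      (10 + 10) * 0.25 = 5.00
#   structure: (15 + 6) * 0.25 = 5.25
#   chain:     10 * 0.15       = 1.50
# for a total of 14.40; neither special-case penalty applies.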


def sample_nodes_with_type_balance(
    nodes: list[dict],
    edges: list[dict],
    target_count: int = 150,
    type_ratios: dict[str, float] | None = None,
) -> tuple[list[dict], list[dict]]:
    """
    Balanced sampling based on type ratios and tree quality

    Args:
        nodes: List of nodes
        edges: List of edges
        target_count: Target number of nodes
        type_ratios: Expected ratio for each type, e.g. {'WorkingMemory': 0.15, 'EpisodicMemory': 0.30, ...}
    """
    if len(nodes) <= target_count:
        return nodes, edges

    # Default type ratio configuration
    if type_ratios is None:
        type_ratios = {
            "WorkingMemory": 0.10,  # 10%
            "EpisodicMemory": 0.25,  # 25%
            "SemanticMemory": 0.25,  # 25%
            "ProceduralMemory": 0.20,  # 20%
            "EmotionalMemory": 0.15,  # 15%
            "MetaMemory": 0.05,  # 5%
        }

    logger.info(
        f"Starting type-balanced sampling, original nodes: {len(nodes)}, target nodes: {target_count}"
    )
    logger.info(f"Target type ratios: {type_ratios}")

    # Analyze current node type distribution
    current_type_counts = {}
    nodes_by_type = {}

    for node in nodes:
        memory_type = node.get("metadata", {}).get("memory_type", "Unknown")
        current_type_counts[memory_type] = current_type_counts.get(memory_type, 0) + 1
        if memory_type not in nodes_by_type:
            nodes_by_type[memory_type] = []
        nodes_by_type[memory_type].append(node)

    logger.info(f"Current type distribution: {current_type_counts}")

    # Calculate target node count for each type
    type_targets = {}
    remaining_target = target_count

    for memory_type, ratio in type_ratios.items():
        if memory_type in nodes_by_type:
            target_for_type = int(target_count * ratio)
            # Ensure not exceeding the actual node count for this type
            target_for_type = min(target_for_type, len(nodes_by_type[memory_type]))
            type_targets[memory_type] = target_for_type
            remaining_target -= target_for_type

    # Handle types not in ratio configuration
    other_types = set(nodes_by_type.keys()) - set(type_ratios.keys())
    if other_types and remaining_target > 0:
        per_other_type = max(1, remaining_target // len(other_types))
        for memory_type in other_types:
            allocation = min(per_other_type, len(nodes_by_type[memory_type]))
            type_targets[memory_type] = allocation
            remaining_target -= allocation

    # If there's still remaining, distribute proportionally to main types
    if remaining_target > 0:
        main_types = [t for t in type_ratios if t in nodes_by_type]
        if main_types:
            extra_per_type = remaining_target // len(main_types)
            for memory_type in main_types:
                additional = min(
                    extra_per_type,
                    len(nodes_by_type[memory_type]) - type_targets.get(memory_type, 0),
                )
                type_targets[memory_type] = type_targets.get(memory_type, 0) + additional

    logger.info(f"Target node count for each type: {type_targets}")

    # Perform subtree quality sampling for each type
    selected_nodes = []

    for memory_type, target_for_type in type_targets.items():
        if target_for_type <= 0 or memory_type not in nodes_by_type:
            continue

        type_nodes = nodes_by_type[memory_type]
        logger.info(
            f"\n--- Processing {memory_type} type: {len(type_nodes)} -> {target_for_type} ---"
        )

        if len(type_nodes) <= target_for_type:
            selected_nodes.extend(type_nodes)
            logger.info(f"  Select all: {len(type_nodes)} nodes")
        else:
            # Use enhanced subtree quality sampling
            type_selected = sample_by_enhanced_subtree_quality(type_nodes, edges, target_for_type)
            selected_nodes.extend(type_selected)
            logger.info(f"  Sampled selection: {len(type_selected)} nodes")

    # Filter edges
    selected_node_ids = {node["id"] for node in selected_nodes}
    filtered_edges = [
        edge
        for edge in edges
        if edge["source"] in selected_node_ids and edge["target"] in selected_node_ids
    ]

    logger.info(f"\nFinal selected nodes: {len(selected_nodes)}")
    logger.info(f"Final edges: {len(filtered_edges)}")

    # Verify final type distribution
    final_type_counts = {}
    for node in selected_nodes:
        memory_type = node.get("metadata", {}).get("memory_type", "Unknown")
        final_type_counts[memory_type] = final_type_counts.get(memory_type, 0) + 1

    logger.info(f"Final type distribution: {final_type_counts}")
    for memory_type, count in final_type_counts.items():
        percentage = count / len(selected_nodes) * 100
        target_percentage = type_ratios.get(memory_type, 0) * 100
        logger.info(
            f"  {memory_type}: {count} nodes ({percentage:.1f}%, target: {target_percentage:.1f}%)"
        )

    return selected_nodes, filtered_edges
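
# Illustrative call (hypothetical variable names): oversample two memory types
# when shrinking a large graph for display.
#
#   nodes, edges = sample_nodes_with_type_balance(
#       all_nodes, all_edges, target_count=100,
#       type_ratios={"SemanticMemory": 0.5, "EpisodicMemory": 0.5},
#   )
#
# Types absent from type_ratios share whatever quota remains (if any), so they
# may be reduced rather than dropped outright.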


def sample_by_enhanced_subtree_quality(
    nodes: list[dict], edges: list[dict], target_count: int
) -> list[dict]:
    """Sample using enhanced subtree quality"""
    if len(nodes) <= target_count:
        return nodes

    # Analyze subtree structure
    subtree_analysis = analyze_tree_structure_enhanced(nodes, edges)

    if not subtree_analysis:
        # If no subtree structure, sample by node importance
        return sample_nodes_by_importance(nodes, edges, target_count)

    # Sort subtrees by quality score
    sorted_subtrees = sorted(
        subtree_analysis.items(), key=lambda x: x[1]["quality_score"], reverse=True
    )

    logger.info("  Subtree quality ranking:")
    for i, (root_id, analysis) in enumerate(sorted_subtrees[:5]):
        logger.info(
            f"    #{i + 1} Root node {root_id}: Quality={analysis['quality_score']:.2f}, "
            f"Depth={analysis['max_depth']}, Branches={analysis['branch_nodes']}, "
            f"Leaves={analysis['leaf_count']}, Max Width={analysis['max_width']}"
        )

    # Greedy selection of high-quality subtrees
    selected_nodes = []
    selected_node_ids = set()

    for root_id, analysis in sorted_subtrees:
        subtree_nodes = analysis["nodes_in_subtree"]
        new_nodes = [node_id for node_id in subtree_nodes if node_id not in selected_node_ids]

        if not new_nodes:
            continue

        remaining_quota = target_count - len(selected_nodes)

        if len(new_nodes) <= remaining_quota:
            # Entire subtree can be added
            for node_id in new_nodes:
                node = next((n for n in nodes if n["id"] == node_id), None)
                if node:
                    selected_nodes.append(node)
                    selected_node_ids.add(node_id)
            logger.info(f"  Select entire subtree {root_id}: +{len(new_nodes)} nodes")
        else:
            # Subtree too large, need partial selection
            if analysis["quality_score"] > 5:  # Only partial selection for high-quality subtrees
                subtree_node_objects = [n for n in nodes if n["id"] in new_nodes]
                partial_selection = select_best_nodes_from_subtree(
                    subtree_node_objects, edges, remaining_quota, root_id
                )

                selected_nodes.extend(partial_selection)
                for node in partial_selection:
                    selected_node_ids.add(node["id"])
                logger.info(
                    f"  Partial selection of subtree {root_id}: +{len(partial_selection)} nodes"
                )

        if len(selected_nodes) >= target_count:
            break

    # If target count not reached, supplement with remaining nodes
    if len(selected_nodes) < target_count:
        remaining_nodes = [n for n in nodes if n["id"] not in selected_node_ids]
        remaining_count = target_count - len(selected_nodes)
        additional = sample_nodes_by_importance(remaining_nodes, edges, remaining_count)
        selected_nodes.extend(additional)
        logger.info(f"  Supplementary selection: +{len(additional)} nodes")

    return selected_nodes
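
# Note: the greedy loop above prefers whole subtrees (keeping connected
# structures intact) and never exceeds target_count; an oversized subtree is
# only partially sampled when its quality_score is above 5, otherwise it is
# skipped and any shortfall is filled by importance-based sampling.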


def select_best_nodes_from_subtree(
    subtree_nodes: list[dict], edges: list[dict], max_count: int, root_id: str
) -> list[dict]:
    """Select the most important nodes from subtree, prioritizing branch structure"""
    if len(subtree_nodes) <= max_count:
        return subtree_nodes

    # Build internal connection relationships of subtree
    subtree_node_ids = {node["id"] for node in subtree_nodes}
    subtree_edges = [
        edge
        for edge in edges
        if edge["source"] in subtree_node_ids and edge["target"] in subtree_node_ids
    ]

    # Calculate importance score for each node
    node_scores = []

    for node in subtree_nodes:
        node_id = node["id"]

        # Out-degree and in-degree
        out_degree = sum(1 for edge in subtree_edges if edge["source"] == node_id)
        in_degree = sum(1 for edge in subtree_edges if edge["target"] == node_id)

        # Content length score
        content_score = min(len(node.get("memory", "")), 300) / 15

        # Branch node bonus
        branch_bonus = out_degree * 8 if out_degree > 1 else 0

        # Root node bonus
        root_bonus = 15 if node_id == root_id else 0

        # Connection importance
        connection_score = (out_degree + in_degree) * 3

        # Leaf node moderate bonus (ensure certain number of leaf nodes)
        leaf_bonus = 5 if out_degree == 0 and in_degree > 0 else 0

        total_score = content_score + connection_score + branch_bonus + root_bonus + leaf_bonus
        node_scores.append((node, total_score))

    # Sort by score and select
    node_scores.sort(key=lambda x: x[1], reverse=True)
    selected = [node for node, _ in node_scores[:max_count]]

    return selected
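
# Worked example (hypothetical node): a subtree root with a 150-character
# memory, out_degree=2 and in_degree=0 scores
#   content 150/15 = 10, connections (2+0)*3 = 6, branch bonus 2*8 = 16,
#   root bonus 15, leaf bonus 0  ->  total 47.0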


def sample_nodes_by_importance(
    nodes: list[dict], edges: list[dict], target_count: int
) -> list[dict]:
    """Sample by node importance (for cases without tree structure)"""
    if len(nodes) <= target_count:
        return nodes

    node_scores = []

    for node in nodes:
        node_id = node["id"]
        out_degree = sum(1 for edge in edges if edge["source"] == node_id)
        in_degree = sum(1 for edge in edges if edge["target"] == node_id)
        content_score = min(len(node.get("memory", "")), 200) / 10
        connection_score = (out_degree + in_degree) * 5
        random_score = random.random() * 10

        total_score = content_score + connection_score + random_score
        node_scores.append((node, total_score))

    node_scores.sort(key=lambda x: x[1], reverse=True)
    return [node for node, _ in node_scores[:target_count]]
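
# Note: random_score adds jitter to the ranking, so this fallback sampling is
# non-deterministic across runs unless the caller seeds the random module.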


# Modified main function to use new sampling strategy
def convert_graph_to_tree_forworkmem(
    json_data: dict[str, Any],
    target_node_count: int = 200,
    type_ratios: dict[str, float] | None = None,
) -> tuple[dict[str, Any], dict[str, int]]:
    """
    Enhanced graph-to-tree conversion function, prioritizing branching degree and type balance
    """
    original_nodes = json_data.get("nodes", [])
    original_edges = json_data.get("edges", [])

    logger.info(f"Original node count: {len(original_nodes)}")
    logger.info(f"Target node count: {target_node_count}")
    filter_original_edges = []
    for original_edge in original_edges:
        if original_edge["type"] == "PARENT":
            filter_original_edges.append(original_edge)
    node_type_count = {}
    for node in original_nodes:
        node_type = node.get("metadata", {}).get("memory_type", "Unknown")
        node_type_count[node_type] = node_type_count.get(node_type, 0) + 1
    original_edges = filter_original_edges
    # Use enhanced type-balanced sampling
    if len(original_nodes) > target_node_count:
        nodes, edges = sample_nodes_with_type_balance(
            original_nodes, original_edges, target_node_count, type_ratios
        )
    else:
        nodes, edges = original_nodes, original_edges

    # The rest of tree structure building remains unchanged...
    # [Original tree building code here]

    # Create node mapping table
    node_map = {}
    for node in nodes:
        memory = node.get("memory", "")
        node_name = extract_node_name(memory)
        memory_key = node.get("metadata", {}).get("key", node_name)
        usage = node.get("metadata", {}).get("usage", [])
        frequency = len(usage) if len(usage) < 100 else 100
        node_map[node["id"]] = {
            "id": node["id"],
            "value": memory,
            "frequency": frequency,
            "node_name": memory_key,
            "memory_type": node.get("metadata", {}).get("memory_type", "Unknown"),
            "children": [],
        }

    # Build parent-child relationship mapping
    children_map = {}
    parent_map = {}

    for edge in edges:
        source = edge["source"]
        target = edge["target"]
        if source not in children_map:
            children_map[source] = []
        children_map[source].append(target)
        parent_map[target] = source

    # Find root nodes
    all_node_ids = set(node_map.keys())
    children_node_ids = set(parent_map.keys())
    root_node_ids = all_node_ids - children_node_ids

    # Separate WorkingMemory and other root nodes
    working_memory_roots = []
    other_roots = []

    for root_id in root_node_ids:
        if node_map[root_id]["memory_type"] == "WorkingMemory":
            working_memory_roots.append(root_id)
        else:
            other_roots.append(root_id)

    def build_tree(node_id: str, visited=None) -> dict[str, Any] | None:
        """Recursively build tree structure with cycle detection"""
        if visited is None:
            visited = set()

        if node_id in visited:
            logger.warning(f"[build_tree] Detected cycle at node {node_id}, skipping.")
            return None
        visited.add(node_id)

        if node_id not in node_map:
            return None

        children_ids = children_map.get(node_id, [])
        children = []
        for child_id in children_ids:
            child_tree = build_tree(child_id, visited)
            if child_tree:
                children.append(child_tree)

        node = {
            "id": node_id,
            "node_name": node_map[node_id]["node_name"],
            "value": node_map[node_id]["value"],
            "memory_type": node_map[node_id]["memory_type"],
            "frequency": node_map[node_id]["frequency"],
        }

        if children:
            node["children"] = children

        return node

    # Build root tree list
    root_trees = []
    for root_id in other_roots:
        tree = build_tree(root_id)
        if tree:
            root_trees.append(tree)

    # Handle WorkingMemory
    if working_memory_roots:
        working_memory_children = []
        for wm_root_id in working_memory_roots:
            tree = build_tree(wm_root_id)
            if tree:
                working_memory_children.append(tree)

        working_memory_node = {
            "id": "WorkingMemory",
            "node_name": "WorkingMemory",
            "value": "WorkingMemory",
            "memory_type": "WorkingMemory",
            "children": working_memory_children,
            "frequency": 0,
        }

        root_trees.append(working_memory_node)

    # Create total root node
    result = {
        "id": "root",
        "node_name": "root",
        "value": "root",
        "memory_type": "Root",
        "children": root_trees,
        "frequency": 0,
    }

    return result, node_type_count
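
# Illustrative usage (hypothetical variable names):
#
#   graph = {"nodes": [...], "edges": [...]}  # edges need a "type" field
#   tree, type_counts = convert_graph_to_tree_forworkmem(graph, target_node_count=150)
#
# Only edges with type == "PARENT" are kept, so other relation types never
# appear as parent-child links in the rendered tree; type_counts reports the
# pre-sampling memory_type distribution.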


def print_tree_structure(node: dict[str, Any], level: int = 0, max_level: int = 5):
    """Log (via logger.info) the first few layers of the tree structure for easy viewing"""
    if level > max_level:
        return

    indent = " " * level
    node_id = node.get("id", "unknown")
    node_name = node.get("node_name", "")
    node_value = node.get("value", "")
    memory_type = node.get("memory_type", "Unknown")

    # Determine display method based on whether there are children
    children = node.get("children", [])
    if children:
        # Intermediate node, display name, type and child count
        logger.info(f"{indent}- {node_name} [{memory_type}] ({len(children)} children)")
        logger.info(f"{indent}  ID: {node_id}")
        display_value = node_value[:80] + "..." if len(node_value) > 80 else node_value
        logger.info(f"{indent}  Value: {display_value}")

        if level < max_level:
            for child in children:
                print_tree_structure(child, level + 1, max_level)
        elif level == max_level:
            logger.info(f"{indent}  ... (expansion limited)")
    else:
        # Leaf node, display name, type and value
        display_value = node_value[:80] + "..." if len(node_value) > 80 else node_value
        logger.info(f"{indent}- {node_name} [{memory_type}]: {display_value}")
        logger.info(f"{indent}  ID: {node_id}")


def analyze_final_tree_quality(tree_data: dict[str, Any]) -> dict:
    """Analyze final tree quality, including type diversity, branch structure, etc."""
    stats = {
        "total_nodes": 0,
        "by_type": {},
        "by_depth": {},
        "max_depth": 0,
        "total_leaves": 0,
        "total_branches": 0,  # Number of branch nodes with multiple children
        "subtrees": [],
        "type_diversity": {},
        "structure_quality": {},
        "chain_analysis": {},  # Single chain structure analysis
    }

    def analyze_subtree(node, depth=0, parent_path="", chain_length=0):
        stats["total_nodes"] += 1
        stats["max_depth"] = max(stats["max_depth"], depth)

        # Count by type
        memory_type = node.get("memory_type", "Unknown")
        stats["by_type"][memory_type] = stats["by_type"].get(memory_type, 0) + 1

        # Count by depth
        stats["by_depth"][depth] = stats["by_depth"].get(depth, 0) + 1

        children = node.get("children", [])
        current_path = (
            f"{parent_path}/{node.get('node_name', 'unknown')}"
            if parent_path
            else node.get("node_name", "root")
        )

        # Analyze node type
        if not children:  # Leaf node
            stats["total_leaves"] += 1
            # Record chain length
            if "max_chain_length" not in stats["chain_analysis"]:
                stats["chain_analysis"]["max_chain_length"] = 0
            stats["chain_analysis"]["max_chain_length"] = max(
                stats["chain_analysis"]["max_chain_length"], chain_length
            )
        elif len(children) == 1:  # Single child node (chain)
            # Continue calculating chain length
            for child in children:
                analyze_subtree(child, depth + 1, current_path, chain_length + 1)
            return  # Early return to avoid duplicate processing
        else:  # Branch node (multiple children)
            stats["total_branches"] += 1
            # Reset chain length
            chain_length = 0

        # If it's the root node of a major subtree, analyze its characteristics
        if depth <= 2 and children:  # Major subtree
            subtree_depth = 0
            subtree_leaves = 0
            subtree_nodes = 0
            subtree_branches = 0
            subtree_types = {}
            subtree_max_width = 0
            width_per_level = {}

            def count_subtree(subnode, subdepth):
                nonlocal \
                    subtree_depth, \
                    subtree_leaves, \
                    subtree_nodes, \
                    subtree_branches, \
                    subtree_max_width
                subtree_nodes += 1
                subtree_depth = max(subtree_depth, subdepth)

                # Count type distribution within subtree
                sub_memory_type = subnode.get("memory_type", "Unknown")
                subtree_types[sub_memory_type] = subtree_types.get(sub_memory_type, 0) + 1

                # Count width per level
                width_per_level[subdepth] = width_per_level.get(subdepth, 0) + 1
                subtree_max_width = max(subtree_max_width, width_per_level[subdepth])

                subchildren = subnode.get("children", [])
                if not subchildren:
                    subtree_leaves += 1
                elif len(subchildren) > 1:
                    subtree_branches += 1

                for child in subchildren:
                    count_subtree(child, subdepth + 1)

            count_subtree(node, 0)

            # Calculate subtree quality metrics
            branch_density = subtree_branches / subtree_nodes if subtree_nodes > 0 else 0
            leaf_ratio = subtree_leaves / subtree_nodes if subtree_nodes > 0 else 0
            depth_width_ratio = (
                subtree_depth / subtree_max_width if subtree_max_width > 0 else subtree_depth
            )

            stats["subtrees"].append(
                {
                    "root": node.get("node_name", "unknown"),
                    "type": memory_type,
                    "depth": subtree_depth,
                    "leaves": subtree_leaves,
                    "nodes": subtree_nodes,
                    "branches": subtree_branches,
                    "branch_density": branch_density,
                    "leaf_ratio": leaf_ratio,
                    "max_width": subtree_max_width,
                    "depth_width_ratio": depth_width_ratio,
                    "path": current_path,
                    "type_distribution": subtree_types,
                    "quality_score": calculate_enhanced_quality(
                        subtree_depth,
                        subtree_leaves,
                        subtree_nodes,
                        subtree_branches,
                        0,
                        branch_density,
                        depth_width_ratio,
                        subtree_max_width,
                    ),
                }
            )

        # Recursively analyze child nodes
        for child in children:
            analyze_subtree(child, depth + 1, current_path, 0)  # Reset chain length

    analyze_subtree(tree_data)

    # Calculate overall structure quality
    if stats["total_nodes"] > 1:
        branch_density = stats["total_branches"] / stats["total_nodes"]
        leaf_ratio = stats["total_leaves"] / stats["total_nodes"]

        # Calculate average width per level
        total_width = sum(stats["by_depth"].values())
        avg_width = total_width / len(stats["by_depth"]) if stats["by_depth"] else 0
        max_width = max(stats["by_depth"].values()) if stats["by_depth"] else 0

        stats["structure_quality"] = {
            "branch_density": branch_density,
            "leaf_ratio": leaf_ratio,
            "avg_width": avg_width,
            "max_width": max_width,
            "depth_width_ratio": stats["max_depth"] / max_width
            if max_width > 0
            else stats["max_depth"],
            "is_well_balanced": 0.2 <= branch_density <= 0.6 and 0.3 <= leaf_ratio <= 0.7,
        }

    # Calculate type diversity metrics
    total_types = len(stats["by_type"])
    if total_types > 1:
        # Calculate uniformity of type distribution (Shannon diversity index)
        shannon_diversity = 0
        for count in stats["by_type"].values():
            if count > 0:
                p = count / stats["total_nodes"]
                shannon_diversity -= p * math.log2(p)

        # Normalize diversity index (0-1 range)
        max_diversity = math.log2(total_types) if total_types > 1 else 0
        normalized_diversity = shannon_diversity / max_diversity if max_diversity > 0 else 0

        stats["type_diversity"] = {
            "total_types": total_types,
            "shannon_diversity": shannon_diversity,
            "normalized_diversity": normalized_diversity,
            "distribution_balance": min(stats["by_type"].values()) / max(stats["by_type"].values())
            if max(stats["by_type"].values()) > 0
            else 0,
        }

    # Single chain analysis
    total_single_child_nodes = sum(
        1 for subtree in stats["subtrees"] if subtree.get("branch_density", 0) < 0.1
    )
    stats["chain_analysis"].update(
        {
            "single_chain_subtrees": total_single_child_nodes,
            "chain_subtree_ratio": total_single_child_nodes / len(stats["subtrees"])
            if stats["subtrees"]
            else 0,
        }
    )

    return stats
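
# Worked example (hypothetical counts): three types with 5, 3 and 2 nodes out
# of 10 give p = 0.5, 0.3, 0.2, so
#   shannon_diversity    = -(0.5*log2(0.5) + 0.3*log2(0.3) + 0.2*log2(0.2)) ≈ 1.485
#   normalized_diversity = 1.485 / log2(3) ≈ 0.937
#   distribution_balance = 2 / 5 = 0.4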


def print_tree_analysis(tree_data: dict[str, Any]):
    """Log (via logger.info) the enhanced tree analysis results"""
    stats = analyze_final_tree_quality(tree_data)

    logger.info("\n" + "=" * 60)
    logger.info("🌳 Enhanced Tree Structure Quality Analysis Report")
    logger.info("=" * 60)

    # Basic statistics
    logger.info("\n📊 Basic Statistics:")
    logger.info(f"  Total nodes: {stats['total_nodes']}")
    logger.info(f"  Max depth: {stats['max_depth']}")
    logger.info(
        f"  Leaf nodes: {stats['total_leaves']} ({stats['total_leaves'] / stats['total_nodes'] * 100:.1f}%)"
    )
    logger.info(
        f"  Branch nodes: {stats['total_branches']} ({stats['total_branches'] / stats['total_nodes'] * 100:.1f}%)"
    )

    # Structure quality assessment
    structure = stats.get("structure_quality", {})
    if structure:
        logger.info("\n🏗️ Structure Quality Assessment:")
        logger.info(
            f"  Branch density: {structure['branch_density']:.3f} ({'✅ Good' if 0.2 <= structure['branch_density'] <= 0.6 else '⚠️ Needs improvement'})"
        )
        logger.info(
            f"  Leaf ratio: {structure['leaf_ratio']:.3f} ({'✅ Good' if 0.3 <= structure['leaf_ratio'] <= 0.7 else '⚠️ Needs improvement'})"
        )
        logger.info(f"  Max width: {structure['max_width']}")
        logger.info(
            f"  Depth-width ratio: {structure['depth_width_ratio']:.2f} ({'✅ Good' if structure['depth_width_ratio'] <= 3 else '⚠️ Too thin'})"
        )
        logger.info(
            f"  Overall balance: {'✅ Good' if structure['is_well_balanced'] else '⚠️ Needs improvement'}"
        )

    # Single chain analysis
    chain_analysis = stats.get("chain_analysis", {})
    if chain_analysis:
        logger.info("\n🔗 Single Chain Structure Analysis:")
        logger.info(f"  Longest chain: {chain_analysis.get('max_chain_length', 0)} layers")
        logger.info(f"  Single chain subtrees: {chain_analysis.get('single_chain_subtrees', 0)}")
        logger.info(
            f"  Single chain subtree ratio: {chain_analysis.get('chain_subtree_ratio', 0) * 100:.1f}%"
        )

        if chain_analysis.get("max_chain_length", 0) > 5:
            logger.info("  ⚠️ Warning: Overly long single chain structure may affect display")
        elif chain_analysis.get("chain_subtree_ratio", 0) > 0.3:
            logger.info(
                "  ⚠️ Warning: Too many single chain subtrees, suggest increasing branch structure"
            )
        else:
            logger.info("  ✅ Single chain structure well controlled")

    # Type diversity
    type_div = stats.get("type_diversity", {})
    if type_div:
        logger.info("\n🎨 Type Diversity Analysis:")
        logger.info(f"  Total types: {type_div['total_types']}")
        logger.info(f"  Diversity index: {type_div['shannon_diversity']:.3f}")
        logger.info(f"  Normalized diversity: {type_div['normalized_diversity']:.3f}")
        logger.info(f"  Distribution balance: {type_div['distribution_balance']:.3f}")

    # Type distribution
    logger.info("\n📋 Type Distribution Details:")
    for mem_type, count in sorted(stats["by_type"].items(), key=lambda x: x[1], reverse=True):
        percentage = count / stats["total_nodes"] * 100
        logger.info(f"  {mem_type}: {count} nodes ({percentage:.1f}%)")

    # Depth distribution
    logger.info("\n📏 Depth Distribution:")
    for depth in sorted(stats["by_depth"].keys()):
        count = stats["by_depth"][depth]
        logger.info(f"  Depth {depth}: {count} nodes")

    # Major subtree analysis
    if stats["subtrees"]:
        logger.info("\n🌲 Major Subtree Analysis (sorted by quality):")
        sorted_subtrees = sorted(
            stats["subtrees"], key=lambda x: x.get("quality_score", 0), reverse=True
        )
        for i, subtree in enumerate(sorted_subtrees[:8]):  # Show first 8
            quality = subtree.get("quality_score", 0)
            logger.info(f"  #{i + 1} {subtree['root']} [{subtree['type']}]:")
            logger.info(f"    Quality score: {quality:.2f}")
            logger.info(
                f"    Structure: Depth={subtree['depth']}, Branches={subtree['branches']}, Leaves={subtree['leaves']}"
            )
            logger.info(
                f"    Density: Branch density={subtree.get('branch_density', 0):.3f}, Leaf ratio={subtree.get('leaf_ratio', 0):.3f}"
            )

            if quality > 15:
                logger.info("    ✅ High quality subtree")
            elif quality > 8:
                logger.info("    🟡 Medium quality subtree")
            else:
                logger.info("    🔴 Low quality subtree")

    logger.info("\n" + "=" * 60)


def remove_embedding_recursive(memory_info: dict) -> Any:
    """remove the embedding from the memory info

    Args:
        memory_info: product memory info

    Returns:
        Any: product memory info without embedding
    """
    if isinstance(memory_info, dict):
        new_dict = {}
        for key, value in memory_info.items():
            if key != "embedding":
                new_dict[key] = remove_embedding_recursive(value)
        return new_dict
    elif isinstance(memory_info, list):
        return [remove_embedding_recursive(item) for item in memory_info]
    else:
        return memory_info
|
|
990
|
+
|
|
991
|
+
|
|
992
|
+
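
# Usage sketch (illustrative; the nested dict below is a hypothetical payload,
# not a fixture from this package):
#
#   info = {
#       "id": "m1",
#       "metadata": {"embedding": [0.1, 0.2], "key": "v"},
#       "related": [{"embedding": [0.3], "text": "t"}],
#   }
#   remove_embedding_recursive(info)
#   # -> {"id": "m1", "metadata": {"key": "v"}, "related": [{"text": "t"}]}
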
def remove_embedding_from_memory_items(memory_items: list[Any]) -> list[dict]:
    """Batch remove embedding fields from multiple TextualMemoryItem objects"""
    clean_memories = []

    for item in memory_items:
        memory_dict = item.model_dump()

        # Remove embedding from metadata
        if "metadata" in memory_dict and "embedding" in memory_dict["metadata"]:
            del memory_dict["metadata"]["embedding"]

        clean_memories.append(memory_dict)

    return clean_memories
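
# Usage sketch (assumes TextualMemoryItem is a pydantic model exposing
# model_dump(), as the call above implies). Unlike the recursive helper, only
# the top-level metadata["embedding"] entry is stripped here:
#
#   clean = remove_embedding_from_memory_items(memory_items)
#   assert all("embedding" not in m.get("metadata", {}) for m in clean)
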
def sort_children_by_memory_type(children: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """
    Sort the children of a node by their memory_type.

    Args:
        children: The children of the node.
    Returns:
        The sorted children.
    """
    if not children:
        return children

    def get_sort_key(child):
        memory_type = child.get("memory_type", "Unknown")
        # Sort directly by the memory_type string; identical types naturally
        # cluster together.
        return memory_type

    # Stable sort by memory_type, preserving order within each type
    sorted_children = sorted(children, key=get_sort_key)

    return sorted_children
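
# Example (hypothetical children; sorted() is stable, so entries that share a
# memory_type keep their relative order):
#
#   children = [
#       {"memory_type": "UserMemory", "id": "u1"},
#       {"memory_type": "LongTermMemory", "id": "l1"},
#       {"memory_type": "UserMemory", "id": "u2"},
#   ]
#   sort_children_by_memory_type(children)
#   # -> [l1, u1, u2]: "LongTermMemory" sorts before "UserMemory",
#   #    and u1 stays ahead of u2.
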
def extract_all_ids_from_tree(tree_node):
    """
    Recursively traverse tree structure to extract all node IDs

    Args:
        tree_node: Tree node (dictionary format)

    Returns:
        set: Set containing all node IDs
    """
    ids = set()

    # Add current node ID (if exists)
    if "id" in tree_node:
        ids.add(tree_node["id"])

    # Recursively process child nodes
    if tree_node.get("children"):
        for child in tree_node["children"]:
            ids.update(extract_all_ids_from_tree(child))

    return ids
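
# Example (hypothetical tree):
#
#   tree = {"id": "root", "children": [
#       {"id": "a", "children": [{"id": "b"}]},
#       {"id": "c"},
#   ]}
#   extract_all_ids_from_tree(tree)   # -> {"root", "a", "b", "c"}
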
def filter_nodes_by_tree_ids(tree_data, nodes_data):
    """
    Filter nodes list based on IDs used in tree structure

    Args:
        tree_data: Tree structure data (dictionary)
        nodes_data: Data containing nodes list (dictionary)

    Returns:
        dict: Filtered nodes data, maintaining original structure
    """
    # Extract all IDs used in the tree
    used_ids = extract_all_ids_from_tree(tree_data)

    # Filter nodes list, keeping only nodes with IDs used in the tree
    filtered_nodes = [node for node in nodes_data["nodes"] if node["id"] in used_ids]

    # Return result maintaining original structure
    return {"nodes": filtered_nodes}
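
# Example (hypothetical data): nodes absent from the tree are dropped, while
# tree IDs with no matching node entry are simply ignored:
#
#   tree = {"id": "root", "children": [{"id": "a"}, {"id": "b"}]}
#   nodes = {"nodes": [{"id": "a"}, {"id": "b"}, {"id": "orphan"}]}
#   filter_nodes_by_tree_ids(tree, nodes)
#   # -> {"nodes": [{"id": "a"}, {"id": "b"}]}
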
def convert_activation_memory_to_serializable(
    act_mem_items: list[KVCacheItem],
) -> list[dict[str, Any]]:
    """
    Convert activation memory items to a serializable format.

    Args:
        act_mem_items: List of KVCacheItem objects

    Returns:
        List of dictionaries with serializable data
    """
    serializable_items = []

    for item in act_mem_items:
        key_layers = 0
        val_layers = 0
        device = "unknown"
        dtype = "unknown"
        key_shapes = []
        value_shapes = []

        if item.memory:
            if hasattr(item.memory, "layers"):
                # Layer-based cache layout: each layer object carries its own
                # key/value tensors. The attribute name varies across cache
                # implementations, hence the key_cache/keys and
                # value_cache/values fallbacks below.
                key_layers = len(item.memory.layers)
                val_layers = len(item.memory.layers)
                if key_layers > 0:
                    l0 = item.memory.layers[0]
                    k0 = getattr(l0, "key_cache", getattr(l0, "keys", None))
                    if k0 is not None:
                        device = str(k0.device)
                        dtype = str(k0.dtype)

                for i, layer in enumerate(item.memory.layers):
                    k = getattr(layer, "key_cache", getattr(layer, "keys", None))
                    v = getattr(layer, "value_cache", getattr(layer, "values", None))
                    if k is not None:
                        key_shapes.append({"layer": i, "shape": list(k.shape)})
                    if v is not None:
                        value_shapes.append({"layer": i, "shape": list(v.shape)})

            elif hasattr(item.memory, "key_cache"):
                # List-based cache layout: parallel lists of per-layer tensors.
                key_layers = len(item.memory.key_cache)
                val_layers = len(item.memory.value_cache)
                if key_layers > 0 and item.memory.key_cache[0] is not None:
                    device = str(item.memory.key_cache[0].device)
                    dtype = str(item.memory.key_cache[0].dtype)

                for i, key_tensor in enumerate(item.memory.key_cache):
                    if key_tensor is not None:
                        key_shapes.append({"layer": i, "shape": list(key_tensor.shape)})

                for i, val_tensor in enumerate(item.memory.value_cache):
                    if val_tensor is not None:
                        value_shapes.append({"layer": i, "shape": list(val_tensor.shape)})

        # Extract basic information that can be serialized
        serializable_item = {
            "id": item.id,
            "metadata": item.metadata,
            "memory_info": {
                "type": "DynamicCache",
                "key_cache_layers": key_layers,
                "value_cache_layers": val_layers,
                "device": device,
                "dtype": dtype,
            },
        }

        # Add tensor shape information if available
        if key_shapes:
            serializable_item["memory_info"]["key_shapes"] = key_shapes
        if value_shapes:
            serializable_item["memory_info"]["value_shapes"] = value_shapes

        serializable_items.append(serializable_item)

    return serializable_items
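
# Minimal sketch of the expected output, using stand-in objects in place of
# real KVCacheItem / torch tensors (every name below is a test double, not a
# package fixture):
#
#   from types import SimpleNamespace
#   t = SimpleNamespace(shape=(1, 8, 128, 64), device="cpu", dtype="torch.float16")
#   layer = SimpleNamespace(keys=t, values=t)
#   item = SimpleNamespace(id="kv-1", metadata={}, memory=SimpleNamespace(layers=[layer]))
#   convert_activation_memory_to_serializable([item])[0]["memory_info"]
#   # -> {"type": "DynamicCache", "key_cache_layers": 1, "value_cache_layers": 1,
#   #     "device": "cpu", "dtype": "torch.float16",
#   #     "key_shapes": [{"layer": 0, "shape": [1, 8, 128, 64]}],
#   #     "value_shapes": [{"layer": 0, "shape": [1, 8, 128, 64]}]}
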
def convert_activation_memory_summary(act_mem_items: list[KVCacheItem]) -> dict[str, Any]:
    """
    Create a summary of activation memory for API responses.

    Args:
        act_mem_items: List of KVCacheItem objects

    Returns:
        Dictionary with summary information
    """
    if not act_mem_items:
        return {"total_items": 0, "summary": "No activation memory items found"}

    total_items = len(act_mem_items)
    total_layers = 0
    total_parameters = 0

    for item in act_mem_items:
        if not item.memory:
            continue

        if hasattr(item.memory, "layers"):
            total_layers += len(item.memory.layers)
            for layer in item.memory.layers:
                k = getattr(layer, "key_cache", getattr(layer, "keys", None))
                v = getattr(layer, "value_cache", getattr(layer, "values", None))
                if k is not None:
                    total_parameters += k.numel()
                if v is not None:
                    total_parameters += v.numel()
        elif hasattr(item.memory, "key_cache"):
            total_layers += len(item.memory.key_cache)

            # Calculate approximate parameter count
            for key_tensor in item.memory.key_cache:
                if key_tensor is not None:
                    total_parameters += key_tensor.numel()

            for value_tensor in item.memory.value_cache:
                if value_tensor is not None:
                    total_parameters += value_tensor.numel()

    return {
        "total_items": total_items,
        "total_layers": total_layers,
        "total_parameters": total_parameters,
        "summary": f"Activation memory contains {total_items} items with {total_layers} layers and approximately {total_parameters:,} parameters",
    }
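
# Usage sketch (hypothetical pairing for an API payload): the detailed
# converter and this summary can be returned side by side:
#
#   payload = {
#       "items": convert_activation_memory_to_serializable(act_mem_items),
#       "summary": convert_activation_memory_summary(act_mem_items),
#   }
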
def detect_and_remove_duplicate_ids(tree_node: dict[str, Any]) -> dict[str, Any]:
    """
    Detect and remove duplicate IDs in a tree structure by skipping duplicate nodes.
    The first occurrence of each ID is kept; subsequent duplicates are removed.

    Args:
        tree_node: Tree node (dictionary format)

    Returns:
        dict: Fixed tree node with duplicate nodes removed
    """
    used_ids = set()
    removed_count = 0

    def remove_duplicates_recursive(
        node: dict[str, Any], parent_path: str = ""
    ) -> dict[str, Any] | None:
        """Recursively remove duplicate IDs by skipping duplicate nodes"""
        nonlocal removed_count

        if not isinstance(node, dict):
            return node

        # Create node copy
        fixed_node = node.copy()

        # Handle current node ID
        current_id = fixed_node.get("id", "")
        if current_id in used_ids and current_id not in ["root", "WorkingMemory"]:
            # Skip this duplicate node (its entire subtree is dropped)
            logger.info(f"Skipping duplicate node: {current_id} (path: {parent_path})")
            removed_count += 1
            return None  # Return None to indicate this node should be removed
        else:
            used_ids.add(current_id)

        # Recursively process child nodes
        if "children" in fixed_node and isinstance(fixed_node["children"], list):
            fixed_children = []
            for i, child in enumerate(fixed_node["children"]):
                child_path = f"{parent_path}/{fixed_node.get('node_name', 'unknown')}[{i}]"
                fixed_child = remove_duplicates_recursive(child, child_path)
                if fixed_child is not None:  # Only add non-None children
                    fixed_children.append(fixed_child)
            fixed_node["children"] = fixed_children

        return fixed_node

    result = remove_duplicates_recursive(tree_node)
    if result is not None:
        logger.info(f"Removed {removed_count} duplicate nodes")
        return result
    else:
        # If the root node itself was removed (shouldn't happen), return an empty root
        return {
            "id": "root",
            "node_name": "root",
            "value": "root",
            "memory_type": "Root",
            "children": [],
        }
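
# Example (hypothetical tree): the second "a" is dropped together with its
# entire subtree, so its child "b" disappears as well:
#
#   tree = {"id": "root", "children": [
#       {"id": "a", "children": []},
#       {"id": "a", "children": [{"id": "b"}]},
#   ]}
#   detect_and_remove_duplicate_ids(tree)
#   # -> {"id": "root", "children": [{"id": "a", "children": []}]}
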
def validate_tree_structure(tree_node: dict[str, Any]) -> dict[str, Any]:
    """
    Validate tree structure integrity, including an ID-uniqueness check.

    Args:
        tree_node: Tree node (dictionary format)

    Returns:
        dict: Validation result containing error messages and fix suggestions
    """
    validation_result = {
        "is_valid": True,
        "errors": [],
        "warnings": [],
        "total_nodes": 0,
        "unique_ids": set(),
        "duplicate_ids": set(),
        "missing_ids": set(),
        "invalid_structure": [],
    }

    def validate_recursive(node: dict[str, Any], path: str = "", depth: int = 0):
        """Recursively validate tree structure"""
        if not isinstance(node, dict):
            validation_result["errors"].append(f"Node is not a dictionary: {path}")
            validation_result["is_valid"] = False
            return

        validation_result["total_nodes"] += 1

        # Check required fields
        if "id" not in node:
            validation_result["errors"].append(f"Node missing ID field: {path}")
            validation_result["missing_ids"].add(path)
            validation_result["is_valid"] = False
        else:
            node_id = node["id"]
            if node_id in validation_result["unique_ids"]:
                validation_result["errors"].append(f"Duplicate node ID: {node_id} (path: {path})")
                validation_result["duplicate_ids"].add(node_id)
                validation_result["is_valid"] = False
            else:
                validation_result["unique_ids"].add(node_id)

        # Check other required fields
        required_fields = ["node_name", "value", "memory_type"]
        for field in required_fields:
            if field not in node:
                validation_result["warnings"].append(f"Node missing field '{field}': {path}")

        # Recursively validate child nodes
        if "children" in node:
            if not isinstance(node["children"], list):
                validation_result["errors"].append(f"Children field is not a list: {path}")
                validation_result["is_valid"] = False
            else:
                for i, child in enumerate(node["children"]):
                    child_path = f"{path}/children[{i}]"
                    validate_recursive(child, child_path, depth + 1)

        # Check depth limit
        if depth > 20:
            validation_result["warnings"].append(f"Tree depth too deep ({depth}): {path}")

    validate_recursive(tree_node)

    # Generate fix suggestions
    if validation_result["duplicate_ids"]:
        validation_result["fix_suggestion"] = (
            "Use the detect_and_remove_duplicate_ids() function to remove duplicate IDs"
        )

    return validation_result
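
# Example (hypothetical tree with one duplicate):
#
#   report = validate_tree_structure(
#       {"id": "root", "children": [{"id": "x"}, {"id": "x"}]}
#   )
#   report["is_valid"]       # False
#   report["duplicate_ids"]  # {"x"}
#   # Missing node_name/value/memory_type fields only add warnings, so they
#   # never flip is_valid on their own.
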
def ensure_unique_tree_ids(tree_result: dict[str, Any]) -> dict[str, Any]:
    """
    Ensure all node IDs in the tree structure are unique by removing duplicate nodes.
    This is a post-processing step for convert_graph_to_tree_forworkmem.

    Args:
        tree_result: Tree structure returned by convert_graph_to_tree_forworkmem

    Returns:
        dict: Fixed tree structure with duplicate nodes removed
    """
    logger.info("🔍 Starting duplicate ID check in tree structure...")

    # First validate the tree structure
    validation = validate_tree_structure(tree_result)

    if validation["is_valid"]:
        logger.info("Tree structure validation passed, no duplicate IDs found")
        return tree_result

    # Report issues
    logger.info(f"Found {len(validation['errors'])} errors:")
    for error in validation["errors"][:5]:  # Only show the first 5 errors
        logger.info(f" - {error}")

    if len(validation["errors"]) > 5:
        logger.info(f" ... and {len(validation['errors']) - 5} more errors")

    logger.info("Statistics:")
    logger.info(f" - Total nodes: {validation['total_nodes']}")
    logger.info(f" - Unique IDs: {len(validation['unique_ids'])}")
    logger.info(f" - Duplicate IDs: {len(validation['duplicate_ids'])}")

    # Remove duplicate nodes
    logger.info("Starting duplicate node removal...")
    fixed_tree = detect_and_remove_duplicate_ids(tree_result)

    # Validate again
    post_validation = validate_tree_structure(fixed_tree)
    if post_validation["is_valid"]:
        logger.info("Removal completed, tree structure is now valid")
        logger.info(f"Final node count: {post_validation['total_nodes']}")
    else:
        logger.info("Issues remain after removal, please check the code logic")
        for error in post_validation["errors"][:3]:
            logger.info(f" - {error}")

    return fixed_tree
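
# Typical call site (sketch; convert_graph_to_tree_forworkmem is referenced by
# the docstring above, and its exact signature is assumed here):
#
#   tree = convert_graph_to_tree_forworkmem(...)  # built upstream
#   tree = ensure_unique_tree_ids(tree)           # safe to call repeatedly
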
def clean_json_response(response: str) -> str:
    """
    Remove markdown JSON code block formatting from LLM response.

    Args:
        response: Raw response string that may contain ```json and ```

    Returns:
        str: Clean JSON string without markdown formatting
    """
    return response.replace("```json", "").replace("```", "").strip()
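
# Example:
#
#   clean_json_response('```json\n{"a": 1}\n```')   # -> '{"a": 1}'
#
# Note that every ``` fence is stripped, including any that appear inside
# string values; for typical single-block LLM responses this is harmless.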