memoryos-2.0.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- memoryos-2.0.3.dist-info/METADATA +418 -0
- memoryos-2.0.3.dist-info/RECORD +315 -0
- memoryos-2.0.3.dist-info/WHEEL +4 -0
- memoryos-2.0.3.dist-info/entry_points.txt +3 -0
- memoryos-2.0.3.dist-info/licenses/LICENSE +201 -0
- memos/__init__.py +20 -0
- memos/api/client.py +571 -0
- memos/api/config.py +1018 -0
- memos/api/context/dependencies.py +50 -0
- memos/api/exceptions.py +53 -0
- memos/api/handlers/__init__.py +62 -0
- memos/api/handlers/add_handler.py +158 -0
- memos/api/handlers/base_handler.py +194 -0
- memos/api/handlers/chat_handler.py +1401 -0
- memos/api/handlers/component_init.py +388 -0
- memos/api/handlers/config_builders.py +190 -0
- memos/api/handlers/feedback_handler.py +93 -0
- memos/api/handlers/formatters_handler.py +237 -0
- memos/api/handlers/memory_handler.py +316 -0
- memos/api/handlers/scheduler_handler.py +497 -0
- memos/api/handlers/search_handler.py +222 -0
- memos/api/handlers/suggestion_handler.py +117 -0
- memos/api/mcp_serve.py +614 -0
- memos/api/middleware/request_context.py +101 -0
- memos/api/product_api.py +38 -0
- memos/api/product_models.py +1206 -0
- memos/api/routers/__init__.py +1 -0
- memos/api/routers/product_router.py +477 -0
- memos/api/routers/server_router.py +394 -0
- memos/api/server_api.py +44 -0
- memos/api/start_api.py +433 -0
- memos/chunkers/__init__.py +4 -0
- memos/chunkers/base.py +24 -0
- memos/chunkers/charactertext_chunker.py +41 -0
- memos/chunkers/factory.py +24 -0
- memos/chunkers/markdown_chunker.py +62 -0
- memos/chunkers/sentence_chunker.py +54 -0
- memos/chunkers/simple_chunker.py +50 -0
- memos/cli.py +113 -0
- memos/configs/__init__.py +0 -0
- memos/configs/base.py +82 -0
- memos/configs/chunker.py +59 -0
- memos/configs/embedder.py +88 -0
- memos/configs/graph_db.py +236 -0
- memos/configs/internet_retriever.py +100 -0
- memos/configs/llm.py +151 -0
- memos/configs/mem_agent.py +54 -0
- memos/configs/mem_chat.py +81 -0
- memos/configs/mem_cube.py +105 -0
- memos/configs/mem_os.py +83 -0
- memos/configs/mem_reader.py +91 -0
- memos/configs/mem_scheduler.py +385 -0
- memos/configs/mem_user.py +70 -0
- memos/configs/memory.py +324 -0
- memos/configs/parser.py +38 -0
- memos/configs/reranker.py +18 -0
- memos/configs/utils.py +8 -0
- memos/configs/vec_db.py +80 -0
- memos/context/context.py +355 -0
- memos/dependency.py +52 -0
- memos/deprecation.py +262 -0
- memos/embedders/__init__.py +0 -0
- memos/embedders/ark.py +95 -0
- memos/embedders/base.py +106 -0
- memos/embedders/factory.py +29 -0
- memos/embedders/ollama.py +77 -0
- memos/embedders/sentence_transformer.py +49 -0
- memos/embedders/universal_api.py +51 -0
- memos/exceptions.py +30 -0
- memos/graph_dbs/__init__.py +0 -0
- memos/graph_dbs/base.py +274 -0
- memos/graph_dbs/factory.py +27 -0
- memos/graph_dbs/item.py +46 -0
- memos/graph_dbs/nebular.py +1794 -0
- memos/graph_dbs/neo4j.py +1942 -0
- memos/graph_dbs/neo4j_community.py +1058 -0
- memos/graph_dbs/polardb.py +5446 -0
- memos/hello_world.py +97 -0
- memos/llms/__init__.py +0 -0
- memos/llms/base.py +25 -0
- memos/llms/deepseek.py +13 -0
- memos/llms/factory.py +38 -0
- memos/llms/hf.py +443 -0
- memos/llms/hf_singleton.py +114 -0
- memos/llms/ollama.py +135 -0
- memos/llms/openai.py +222 -0
- memos/llms/openai_new.py +198 -0
- memos/llms/qwen.py +13 -0
- memos/llms/utils.py +14 -0
- memos/llms/vllm.py +218 -0
- memos/log.py +237 -0
- memos/mem_agent/base.py +19 -0
- memos/mem_agent/deepsearch_agent.py +391 -0
- memos/mem_agent/factory.py +36 -0
- memos/mem_chat/__init__.py +0 -0
- memos/mem_chat/base.py +30 -0
- memos/mem_chat/factory.py +21 -0
- memos/mem_chat/simple.py +200 -0
- memos/mem_cube/__init__.py +0 -0
- memos/mem_cube/base.py +30 -0
- memos/mem_cube/general.py +240 -0
- memos/mem_cube/navie.py +172 -0
- memos/mem_cube/utils.py +169 -0
- memos/mem_feedback/base.py +15 -0
- memos/mem_feedback/feedback.py +1192 -0
- memos/mem_feedback/simple_feedback.py +40 -0
- memos/mem_feedback/utils.py +230 -0
- memos/mem_os/client.py +5 -0
- memos/mem_os/core.py +1203 -0
- memos/mem_os/main.py +582 -0
- memos/mem_os/product.py +1608 -0
- memos/mem_os/product_server.py +455 -0
- memos/mem_os/utils/default_config.py +359 -0
- memos/mem_os/utils/format_utils.py +1403 -0
- memos/mem_os/utils/reference_utils.py +162 -0
- memos/mem_reader/__init__.py +0 -0
- memos/mem_reader/base.py +47 -0
- memos/mem_reader/factory.py +53 -0
- memos/mem_reader/memory.py +298 -0
- memos/mem_reader/multi_modal_struct.py +965 -0
- memos/mem_reader/read_multi_modal/__init__.py +43 -0
- memos/mem_reader/read_multi_modal/assistant_parser.py +311 -0
- memos/mem_reader/read_multi_modal/base.py +273 -0
- memos/mem_reader/read_multi_modal/file_content_parser.py +826 -0
- memos/mem_reader/read_multi_modal/image_parser.py +359 -0
- memos/mem_reader/read_multi_modal/multi_modal_parser.py +252 -0
- memos/mem_reader/read_multi_modal/string_parser.py +139 -0
- memos/mem_reader/read_multi_modal/system_parser.py +327 -0
- memos/mem_reader/read_multi_modal/text_content_parser.py +131 -0
- memos/mem_reader/read_multi_modal/tool_parser.py +210 -0
- memos/mem_reader/read_multi_modal/user_parser.py +218 -0
- memos/mem_reader/read_multi_modal/utils.py +358 -0
- memos/mem_reader/simple_struct.py +912 -0
- memos/mem_reader/strategy_struct.py +163 -0
- memos/mem_reader/utils.py +157 -0
- memos/mem_scheduler/__init__.py +0 -0
- memos/mem_scheduler/analyzer/__init__.py +0 -0
- memos/mem_scheduler/analyzer/api_analyzer.py +714 -0
- memos/mem_scheduler/analyzer/eval_analyzer.py +219 -0
- memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +571 -0
- memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
- memos/mem_scheduler/base_scheduler.py +1319 -0
- memos/mem_scheduler/general_modules/__init__.py +0 -0
- memos/mem_scheduler/general_modules/api_misc.py +137 -0
- memos/mem_scheduler/general_modules/base.py +80 -0
- memos/mem_scheduler/general_modules/init_components_for_scheduler.py +425 -0
- memos/mem_scheduler/general_modules/misc.py +313 -0
- memos/mem_scheduler/general_modules/scheduler_logger.py +389 -0
- memos/mem_scheduler/general_modules/task_threads.py +315 -0
- memos/mem_scheduler/general_scheduler.py +1495 -0
- memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
- memos/mem_scheduler/memory_manage_modules/memory_filter.py +306 -0
- memos/mem_scheduler/memory_manage_modules/retriever.py +547 -0
- memos/mem_scheduler/monitors/__init__.py +0 -0
- memos/mem_scheduler/monitors/dispatcher_monitor.py +366 -0
- memos/mem_scheduler/monitors/general_monitor.py +394 -0
- memos/mem_scheduler/monitors/task_schedule_monitor.py +254 -0
- memos/mem_scheduler/optimized_scheduler.py +410 -0
- memos/mem_scheduler/orm_modules/__init__.py +0 -0
- memos/mem_scheduler/orm_modules/api_redis_model.py +518 -0
- memos/mem_scheduler/orm_modules/base_model.py +729 -0
- memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
- memos/mem_scheduler/orm_modules/redis_model.py +699 -0
- memos/mem_scheduler/scheduler_factory.py +23 -0
- memos/mem_scheduler/schemas/__init__.py +0 -0
- memos/mem_scheduler/schemas/analyzer_schemas.py +52 -0
- memos/mem_scheduler/schemas/api_schemas.py +233 -0
- memos/mem_scheduler/schemas/general_schemas.py +55 -0
- memos/mem_scheduler/schemas/message_schemas.py +173 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +406 -0
- memos/mem_scheduler/schemas/task_schemas.py +132 -0
- memos/mem_scheduler/task_schedule_modules/__init__.py +0 -0
- memos/mem_scheduler/task_schedule_modules/dispatcher.py +740 -0
- memos/mem_scheduler/task_schedule_modules/local_queue.py +247 -0
- memos/mem_scheduler/task_schedule_modules/orchestrator.py +74 -0
- memos/mem_scheduler/task_schedule_modules/redis_queue.py +1385 -0
- memos/mem_scheduler/task_schedule_modules/task_queue.py +162 -0
- memos/mem_scheduler/utils/__init__.py +0 -0
- memos/mem_scheduler/utils/api_utils.py +77 -0
- memos/mem_scheduler/utils/config_utils.py +100 -0
- memos/mem_scheduler/utils/db_utils.py +50 -0
- memos/mem_scheduler/utils/filter_utils.py +176 -0
- memos/mem_scheduler/utils/metrics.py +125 -0
- memos/mem_scheduler/utils/misc_utils.py +290 -0
- memos/mem_scheduler/utils/monitor_event_utils.py +67 -0
- memos/mem_scheduler/utils/status_tracker.py +229 -0
- memos/mem_scheduler/webservice_modules/__init__.py +0 -0
- memos/mem_scheduler/webservice_modules/rabbitmq_service.py +485 -0
- memos/mem_scheduler/webservice_modules/redis_service.py +380 -0
- memos/mem_user/factory.py +94 -0
- memos/mem_user/mysql_persistent_user_manager.py +271 -0
- memos/mem_user/mysql_user_manager.py +502 -0
- memos/mem_user/persistent_factory.py +98 -0
- memos/mem_user/persistent_user_manager.py +260 -0
- memos/mem_user/redis_persistent_user_manager.py +225 -0
- memos/mem_user/user_manager.py +488 -0
- memos/memories/__init__.py +0 -0
- memos/memories/activation/__init__.py +0 -0
- memos/memories/activation/base.py +42 -0
- memos/memories/activation/item.py +56 -0
- memos/memories/activation/kv.py +292 -0
- memos/memories/activation/vllmkv.py +219 -0
- memos/memories/base.py +19 -0
- memos/memories/factory.py +42 -0
- memos/memories/parametric/__init__.py +0 -0
- memos/memories/parametric/base.py +19 -0
- memos/memories/parametric/item.py +11 -0
- memos/memories/parametric/lora.py +41 -0
- memos/memories/textual/__init__.py +0 -0
- memos/memories/textual/base.py +92 -0
- memos/memories/textual/general.py +236 -0
- memos/memories/textual/item.py +304 -0
- memos/memories/textual/naive.py +187 -0
- memos/memories/textual/prefer_text_memory/__init__.py +0 -0
- memos/memories/textual/prefer_text_memory/adder.py +504 -0
- memos/memories/textual/prefer_text_memory/config.py +106 -0
- memos/memories/textual/prefer_text_memory/extractor.py +221 -0
- memos/memories/textual/prefer_text_memory/factory.py +85 -0
- memos/memories/textual/prefer_text_memory/retrievers.py +177 -0
- memos/memories/textual/prefer_text_memory/spliter.py +132 -0
- memos/memories/textual/prefer_text_memory/utils.py +93 -0
- memos/memories/textual/preference.py +344 -0
- memos/memories/textual/simple_preference.py +161 -0
- memos/memories/textual/simple_tree.py +69 -0
- memos/memories/textual/tree.py +459 -0
- memos/memories/textual/tree_text_memory/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/handler.py +184 -0
- memos/memories/textual/tree_text_memory/organize/manager.py +518 -0
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +238 -0
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +622 -0
- memos/memories/textual/tree_text_memory/retrieve/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +364 -0
- memos/memories/textual/tree_text_memory/retrieve/bm25_util.py +186 -0
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +419 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +270 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +102 -0
- memos/memories/textual/tree_text_memory/retrieve/reasoner.py +61 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +497 -0
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +111 -0
- memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +16 -0
- memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +472 -0
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +848 -0
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +135 -0
- memos/memories/textual/tree_text_memory/retrieve/utils.py +54 -0
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +387 -0
- memos/memos_tools/dinding_report_bot.py +453 -0
- memos/memos_tools/lockfree_dict.py +120 -0
- memos/memos_tools/notification_service.py +44 -0
- memos/memos_tools/notification_utils.py +142 -0
- memos/memos_tools/singleton.py +174 -0
- memos/memos_tools/thread_safe_dict.py +310 -0
- memos/memos_tools/thread_safe_dict_segment.py +382 -0
- memos/multi_mem_cube/__init__.py +0 -0
- memos/multi_mem_cube/composite_cube.py +86 -0
- memos/multi_mem_cube/single_cube.py +874 -0
- memos/multi_mem_cube/views.py +54 -0
- memos/parsers/__init__.py +0 -0
- memos/parsers/base.py +15 -0
- memos/parsers/factory.py +21 -0
- memos/parsers/markitdown.py +28 -0
- memos/reranker/__init__.py +4 -0
- memos/reranker/base.py +25 -0
- memos/reranker/concat.py +103 -0
- memos/reranker/cosine_local.py +102 -0
- memos/reranker/factory.py +72 -0
- memos/reranker/http_bge.py +324 -0
- memos/reranker/http_bge_strategy.py +327 -0
- memos/reranker/noop.py +19 -0
- memos/reranker/strategies/__init__.py +4 -0
- memos/reranker/strategies/base.py +61 -0
- memos/reranker/strategies/concat_background.py +94 -0
- memos/reranker/strategies/concat_docsource.py +110 -0
- memos/reranker/strategies/dialogue_common.py +109 -0
- memos/reranker/strategies/factory.py +31 -0
- memos/reranker/strategies/single_turn.py +107 -0
- memos/reranker/strategies/singleturn_outmem.py +98 -0
- memos/settings.py +10 -0
- memos/templates/__init__.py +0 -0
- memos/templates/advanced_search_prompts.py +211 -0
- memos/templates/cloud_service_prompt.py +107 -0
- memos/templates/instruction_completion.py +66 -0
- memos/templates/mem_agent_prompts.py +85 -0
- memos/templates/mem_feedback_prompts.py +822 -0
- memos/templates/mem_reader_prompts.py +1096 -0
- memos/templates/mem_reader_strategy_prompts.py +238 -0
- memos/templates/mem_scheduler_prompts.py +626 -0
- memos/templates/mem_search_prompts.py +93 -0
- memos/templates/mos_prompts.py +403 -0
- memos/templates/prefer_complete_prompt.py +735 -0
- memos/templates/tool_mem_prompts.py +139 -0
- memos/templates/tree_reorganize_prompts.py +230 -0
- memos/types/__init__.py +34 -0
- memos/types/general_types.py +151 -0
- memos/types/openai_chat_completion_types/__init__.py +15 -0
- memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py +56 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py +27 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py +23 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_param.py +43 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py +16 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py +16 -0
- memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py +27 -0
- memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py +32 -0
- memos/types/openai_chat_completion_types/chat_completion_message_param.py +18 -0
- memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py +15 -0
- memos/types/openai_chat_completion_types/chat_completion_system_message_param.py +36 -0
- memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py +30 -0
- memos/types/openai_chat_completion_types/chat_completion_user_message_param.py +34 -0
- memos/utils.py +123 -0
- memos/vec_dbs/__init__.py +0 -0
- memos/vec_dbs/base.py +117 -0
- memos/vec_dbs/factory.py +23 -0
- memos/vec_dbs/item.py +50 -0
- memos/vec_dbs/milvus.py +654 -0
- memos/vec_dbs/qdrant.py +355 -0
memos/graph_dbs/nebular.py
@@ -0,0 +1,1794 @@
import json
import traceback

from contextlib import suppress
from datetime import datetime
from threading import Lock
from typing import TYPE_CHECKING, Any, ClassVar, Literal

import numpy as np

from memos.configs.graph_db import NebulaGraphDBConfig
from memos.dependency import require_python_package
from memos.graph_dbs.base import BaseGraphDB
from memos.log import get_logger
from memos.utils import timed


if TYPE_CHECKING:
    from nebulagraph_python import (
        NebulaClient,
    )


logger = get_logger(__name__)


_TRANSIENT_ERR_KEYS = (
    "Session not found",
    "Connection not established",
    "timeout",
    "deadline exceeded",
    "Broken pipe",
    "EOFError",
    "socket closed",
    "connection reset",
    "connection refused",
)


@timed
def _normalize(vec: list[float]) -> list[float]:
    v = np.asarray(vec, dtype=np.float32)
    norm = np.linalg.norm(v)
    return (v / (norm if norm else 1.0)).tolist()


@timed
def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
    node_id = item["id"]
    memory = item["memory"]
    metadata = item.get("metadata", {})
    return node_id, memory, metadata


@timed
def _escape_str(value: str) -> str:
    out = []
    for ch in value:
        code = ord(ch)
        if ch == "\\":
            out.append("\\\\")
        elif ch == '"':
            out.append('\\"')
        elif ch == "\n":
            out.append("\\n")
        elif ch == "\r":
            out.append("\\r")
        elif ch == "\t":
            out.append("\\t")
        elif ch == "\b":
            out.append("\\b")
        elif ch == "\f":
            out.append("\\f")
        elif code < 0x20 or code in (0x2028, 0x2029):
            out.append(f"\\u{code:04x}")
        else:
            out.append(ch)
    return "".join(out)


@timed
def _format_datetime(value: str | datetime) -> str:
    """Ensure datetime is an ISO 8601 format string."""
    if isinstance(value, datetime):
        return value.isoformat()
    return str(value)


@timed
def _normalize_datetime(val):
    """
    Normalize a datetime to an ISO 8601 string with an explicit timezone offset.
    - If val is a datetime object -> keep isoformat() (Neo4j)
    - If val is a string without a timezone suffix -> append +08:00 (Nebula)
    - Otherwise just str()
    """
    if hasattr(val, "isoformat"):
        return val.isoformat()
    if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")):
        return val + "+08:00"
    return str(val)
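
# Branch-by-branch behavior of _normalize_datetime (illustrative values):
#
#     datetime(2024, 1, 1, tzinfo=timezone.utc)  # -> "2024-01-01T00:00:00+00:00"
#     "2024-01-01T08:00:00"                      # -> "2024-01-01T08:00:00+08:00"
#     "2024-01-01T00:00:00Z"                     # -> "2024-01-01T00:00:00Z" (unchanged)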


class NebulaGraphDB(BaseGraphDB):
    """
    NebulaGraph-based implementation of a graph memory store.
    """

    # ====== shared pool cache & refcount ======
    # These are process-local; in a multi-process model each process will
    # have its own cache.
    _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {}
    _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {}
    _CLIENT_LOCK: ClassVar[Lock] = Lock()
    _CLIENT_INIT_DONE: ClassVar[set[str]] = set()

    @staticmethod
    def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]:
        hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None)
        if isinstance(hosts, str):
            return [hosts]
        return list(hosts or [])

    @staticmethod
    def _make_client_key(cfg: NebulaGraphDBConfig) -> str:
        hosts = NebulaGraphDB._get_hosts_from_cfg(cfg)
        return "|".join(
            [
                "nebula-sync",
                ",".join(hosts),
                str(getattr(cfg, "user", "")),
                str(getattr(cfg, "space", "")),
            ]
        )
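
    # Cache-key shape (illustrative): with uri="127.0.0.1:9669", user="root",
    # space="memos", _make_client_key returns
    # "nebula-sync|127.0.0.1:9669|root|memos", so instances sharing host, user,
    # and space also share one pooled client.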

    @classmethod
    def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB":
        # Bypass __init__ via object.__new__ so bootstrapping does not recurse
        # into the shared-client acquisition path.
        tmp = object.__new__(NebulaGraphDB)
        tmp.config = cfg
        tmp.db_name = cfg.space
        tmp.user_name = None
        tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072)
        tmp.default_memory_dimension = 3072
        tmp.common_fields = {
            "id",
            "memory",
            "user_name",
            "user_id",
            "session_id",
            "status",
            "key",
            "confidence",
            "tags",
            "created_at",
            "updated_at",
            "memory_type",
            "sources",
            "source",
            "node_type",
            "visibility",
            "usage",
            "background",
        }
        tmp.base_fields = set(tmp.common_fields) - {"usage"}
        tmp.heavy_fields = {"usage"}
        tmp.dim_field = (
            f"embedding_{tmp.embedding_dimension}"
            if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension)
            else "embedding"
        )
        tmp.system_db_name = cfg.space
        tmp._client = client
        tmp._owns_client = False
        return tmp

    @classmethod
    def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]:
        from nebulagraph_python import (
            ConnectionConfig,
            NebulaClient,
            SessionConfig,
            SessionPoolConfig,
        )

        key = cls._make_client_key(cfg)
        with cls._CLIENT_LOCK:
            client = cls._CLIENT_CACHE.get(key)
            if client is None:
                # Use a short-lived single-session client to make sure the
                # target space exists before the pooled client binds to it.
                tmp_client = NebulaClient(
                    hosts=cfg.uri,
                    username=cfg.user,
                    password=cfg.password,
                    session_config=SessionConfig(graph=None),
                    session_pool_config=SessionPoolConfig(size=1, wait_timeout=3000),
                )
                try:
                    cls._ensure_space_exists(tmp_client, cfg)
                finally:
                    tmp_client.close()

                conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None)
                if conn_conf is None:
                    conn_conf = ConnectionConfig.from_defults(
                        cls._get_hosts_from_cfg(cfg),
                        getattr(cfg, "ssl_param", None),
                    )

                sess_conf = SessionConfig(graph=getattr(cfg, "space", None))
                pool_conf = SessionPoolConfig(
                    size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000
                )

                client = NebulaClient(
                    hosts=conn_conf.hosts,
                    username=cfg.user,
                    password=cfg.password,
                    conn_config=conn_conf,
                    session_config=sess_conf,
                    session_pool_config=pool_conf,
                )
                cls._CLIENT_CACHE[key] = client
                cls._CLIENT_REFCOUNT[key] = 0
                logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}")

            cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1

        if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
            with cls._CLIENT_LOCK:
                if key not in cls._CLIENT_INIT_DONE:
                    admin = cls._bootstrap_admin(cfg, client)
                    try:
                        admin._ensure_database_exists()
                        admin._create_basic_property_indexes()
                        admin._create_vector_index(
                            dimensions=int(
                                admin.embedding_dimension or admin.default_memory_dimension
                            ),
                        )
                        cls._CLIENT_INIT_DONE.add(key)
                        logger.info("[NebulaGraphDBSync] One-time init done")
                    except Exception:
                        logger.exception("[NebulaGraphDBSync] One-time init failed")

        return key, client
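
    # Sharing model sketch (illustrative): instances built from configs with
    # the same host/user/space reuse one NebulaClient, and the pool is closed
    # only when the last owner releases it.
    #
    #     db1 = NebulaGraphDB(cfg)   # refcount 1, client created
    #     db2 = NebulaGraphDB(cfg)   # refcount 2, cache hit
    #     db1.close()                # refcount 1, client stays open
    #     db2.close()                # refcount 0, client closed and evicted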

    def _refresh_client(self):
        """
        Drop this instance's cached NebulaClient and acquire a fresh shared
        client, used after a transient connection error.
        """
        old_key = getattr(self, "_client_key", None)
        if not old_key:
            return

        cls = self.__class__
        with cls._CLIENT_LOCK:
            try:
                if old_key in cls._CLIENT_CACHE:
                    try:
                        cls._CLIENT_CACHE[old_key].close()
                    except Exception as e:
                        logger.warning(f"[refresh_client] close old client error: {e}")
                    finally:
                        cls._CLIENT_CACHE.pop(old_key, None)
            finally:
                cls._CLIENT_REFCOUNT[old_key] = 0

        new_key, new_client = cls._get_or_create_shared_client(self.config)
        self._client_key = new_key
        self._client = new_client
        logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}")

    @classmethod
    def _release_shared_client(cls, key: str):
        with cls._CLIENT_LOCK:
            if key not in cls._CLIENT_CACHE:
                return
            cls._CLIENT_REFCOUNT[key] = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1)
            if cls._CLIENT_REFCOUNT[key] == 0:
                try:
                    cls._CLIENT_CACHE[key].close()
                except Exception as e:
                    logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}")
                finally:
                    cls._CLIENT_CACHE.pop(key, None)
                    cls._CLIENT_REFCOUNT.pop(key, None)
                    logger.info(f"[NebulaGraphDBSync] Closed & removed client key={key}")

    @classmethod
    def close_all_shared_clients(cls):
        with cls._CLIENT_LOCK:
            for key, client in list(cls._CLIENT_CACHE.items()):
                try:
                    client.close()
                except Exception as e:
                    logger.warning(f"[NebulaGraphDBSync] Error closing client {key}: {e}")
                finally:
                    logger.info(f"[NebulaGraphDBSync] Closed client key={key}")
            cls._CLIENT_CACHE.clear()
            cls._CLIENT_REFCOUNT.clear()

    @require_python_package(
        import_name="nebulagraph_python",
        install_command="pip install nebulagraph-python>=5.1.1",
        install_link=".....",
    )
    def __init__(self, config: NebulaGraphDBConfig):
        """
        NebulaGraph DB client initialization.

        Required config attributes:
        - hosts: list[str] like ["host1:port", "host2:port"]
        - user: str
        - password: str
        - db_name: str (optional for basic commands)

        Example config:
        {
            "hosts": ["xxx.xx.xx.xxx:xxxx"],
            "user": "root",
            "password": "nebula",
            "space": "test"
        }
        """
        assert config.use_multi_db is False, "Multi-DB MODE IS NOT SUPPORTED"
        self.config = config
        self.db_name = config.space
        self.user_name = config.user_name
        self.embedding_dimension = config.embedding_dimension
        self.default_memory_dimension = 3072
        self.common_fields = {
            "id",
            "memory",
            "user_name",
            "user_id",
            "session_id",
            "status",
            "key",
            "confidence",
            "tags",
            "created_at",
            "updated_at",
            "memory_type",
            "sources",
            "source",
            "node_type",
            "visibility",
            "usage",
            "background",
        }
        self.base_fields = set(self.common_fields) - {"usage"}
        self.heavy_fields = {"usage"}
        self.dim_field = (
            f"embedding_{self.embedding_dimension}"
            if (str(self.embedding_dimension) != str(self.default_memory_dimension))
            else "embedding"
        )
        self.system_db_name = config.space

        # ---- NEW: pool acquisition strategy
        # Get or create a shared pool from the class-level cache
        self._client_key, self._client = self._get_or_create_shared_client(config)
        self._owns_client = True

        logger.info("Connected to NebulaGraph successfully.")

    @timed
    def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True):
        def _wrap_use_db(q: str) -> str:
            if auto_set_db and self.db_name:
                return f"USE `{self.db_name}`\n{q}"
            return q

        try:
            return self._client.execute(_wrap_use_db(gql), timeout=timeout)
        except Exception as e:
            emsg = str(e)
            if any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS):
                logger.warning(f"[execute_query] {e!s} → refreshing session pool and retry once...")
                try:
                    self._refresh_client()
                    return self._client.execute(_wrap_use_db(gql), timeout=timeout)
                except Exception:
                    logger.exception("[execute_query] retry after refresh failed")
                    raise
            raise
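
    # Query flow sketch (illustrative): with db_name="memos",
    #
    #     db.execute_query("MATCH (n@Memory) RETURN n.id LIMIT 1")
    #
    # actually runs
    #
    #     USE `memos`
    #     MATCH (n@Memory) RETURN n.id LIMIT 1
    #
    # and if the error text matches one of _TRANSIENT_ERR_KEYS, the shared
    # client is rebuilt once and the statement retried before errors propagate.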

    @timed
    def close(self):
        """
        Close the connection resource if this instance owns it.

        - If the pool was injected (`shared_pool`), do nothing.
        - If the pool was acquired via the shared cache, decrement the refcount
          and close it when the last owner releases it.
        """
        if not self._owns_client:
            logger.debug("[NebulaGraphDBSync] close() skipped (injected client).")
            return
        if self._client_key:
            self._release_shared_client(self._client_key)
            self._client_key = None
        self._client = None

    # NOTE: __del__ is best-effort; do not rely on GC order.
    def __del__(self):
        with suppress(Exception):
            self.close()

    @timed
    def create_index(
        self,
        label: str = "Memory",
        vector_property: str = "embedding",
        dimensions: int = 3072,
        index_name: str = "memory_vector_index",
    ) -> None:
        # Create the vector index
        self._create_vector_index(label, vector_property, dimensions, index_name)
        # Create the basic property indexes
        self._create_basic_property_indexes()

    @timed
    def remove_oldest_memory(
        self, memory_type: str, keep_latest: int, user_name: str | None = None
    ) -> None:
        """
        Remove all nodes of the given memory type except the latest `keep_latest` entries.

        Args:
            memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
            keep_latest (int): Number of latest entries to keep.
            user_name (str, optional): User name for filtering.
        """
        try:
            user_name = user_name if user_name else self.config.user_name
            optional_condition = f"AND n.user_name = '{user_name}'"
            count = self.count_nodes(memory_type, user_name)
            if count > keep_latest:
                delete_query = f"""
                MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
                WHERE n.memory_type = '{memory_type}'
                {optional_condition}
                ORDER BY n.updated_at DESC
                OFFSET {int(keep_latest)}
                DETACH DELETE n
                """
                self.execute_query(delete_query)
        except Exception as e:
            logger.warning(f"Delete old mem error: {e}")

    @timed
    def add_node(
        self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None
    ) -> None:
        """
        Insert or update a Memory node in NebulaGraph.
        """
        # Copy first so the caller's dict is not mutated.
        metadata = metadata.copy()
        metadata["user_name"] = user_name if user_name else self.config.user_name
        now = datetime.utcnow()
        metadata.setdefault("created_at", now)
        metadata.setdefault("updated_at", now)
        metadata["node_type"] = metadata.pop("type")
        metadata["id"] = id
        metadata["memory"] = memory

        if "embedding" in metadata and isinstance(metadata["embedding"], list):
            assert len(metadata["embedding"]) == self.embedding_dimension, (
                f"input embedding dimension must equal {self.embedding_dimension}"
            )
            embedding = metadata.pop("embedding")
            metadata[self.dim_field] = _normalize(embedding)

        metadata = self._metadata_filter(metadata)
        properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items())
        gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"

        try:
            self.execute_query(gql)
            logger.info("insert success")
        except Exception as e:
            logger.error(
                f"Failed to insert vertex {id}: gql: {gql}, {e}\ntrace: {traceback.format_exc()}"
            )
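
    # Call sketch (illustrative): metadata must carry a "type" key, which is
    # renamed to node_type; an optional "embedding" list must match
    # embedding_dimension and is stored L2-normalized under dim_field.
    #
    #     db.add_node(
    #         id="mem-001",
    #         memory="User prefers dark mode.",
    #         metadata={"type": "fact", "memory_type": "LongTermMemory",
    #                   "tags": ["preference"], "embedding": vec},
    #     )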

    @timed
    def node_not_exist(self, scope: str, user_name: str | None = None) -> bool:
        user_name = user_name if user_name else self.config.user_name
        filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{user_name}"'
        query = f"""
        MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
        WHERE {filter_clause}
        RETURN n.id AS id
        LIMIT 1
        """

        try:
            result = self.execute_query(query)
            return result.size == 0
        except Exception as e:
            logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True)
            raise

    @timed
    def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None:
        """
        Update node fields in Nebula, auto-converting `created_at` and
        `updated_at` to datetime type if present.
        """
        user_name = user_name if user_name else self.config.user_name
        fields = fields.copy()
        set_clauses = []
        for k, v in fields.items():
            set_clauses.append(f"n.{k} = {self._format_value(v, k)}")

        set_clause_str = ",\n ".join(set_clauses)

        query = f"""
        MATCH (n@Memory {{id: "{id}"}})
        """
        query += f'WHERE n.user_name = "{user_name}"'
        query += f"\nSET {set_clause_str}"
        self.execute_query(query)

    @timed
    def delete_node(self, id: str, user_name: str | None = None) -> None:
        """
        Delete a node from the graph.

        Args:
            id: Node identifier to delete.
            user_name (str, optional): User name for filtering in non-multi-db mode.
        """
        user_name = user_name if user_name else self.config.user_name
        query = f"""
        MATCH (n@Memory {{id: "{id}"}}) WHERE n.user_name = {self._format_value(user_name)}
        DETACH DELETE n
        """
        self.execute_query(query)

    @timed
    def add_edge(self, source_id: str, target_id: str, type: str, user_name: str | None = None):
        """
        Create an edge from the source node to the target node.

        Args:
            source_id: ID of the source node.
            target_id: ID of the target node.
            type: Relationship type (e.g., 'RELATE_TO', 'PARENT').
            user_name (str, optional): User name for filtering in non-multi-db mode.
        """
        if not source_id or not target_id:
            raise ValueError("[add_edge] source_id and target_id must be provided")
        user_name = user_name if user_name else self.config.user_name
        props = f'{{user_name: "{user_name}"}}'
        insert_stmt = f'''
        MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
        INSERT (a) -[e@{type} {props}]-> (b)
        '''
        try:
            self.execute_query(insert_stmt)
        except Exception as e:
            logger.error(f"Failed to insert edge: {e}", exc_info=True)

    @timed
    def delete_edge(
        self, source_id: str, target_id: str, type: str, user_name: str | None = None
    ) -> None:
        """
        Delete a specific edge between two nodes.

        Args:
            source_id: ID of the source node.
            target_id: ID of the target node.
            type: Relationship type to remove.
            user_name (str, optional): User name for filtering in non-multi-db mode.
        """
        user_name = user_name if user_name else self.config.user_name
        query = f"""
        MATCH (a@Memory) -[r@{type}]-> (b@Memory)
        WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)}
        """
        query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}"
        query += "\nDELETE r"
        self.execute_query(query)

    @timed
    def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int:
        user_name = user_name if user_name else self.config.user_name
        query = f"""
        MATCH (n@Memory)
        WHERE n.memory_type = "{memory_type}"
        """
        query += f"\nAND n.user_name = '{user_name}'"
        query += "\nRETURN COUNT(n) AS count"

        try:
            result = self.execute_query(query)
            return result.one_or_none()["count"].value
        except Exception as e:
            logger.error(f"[get_memory_count] Failed: {e}")
            return -1

    @timed
    def count_nodes(self, scope: str, user_name: str | None = None) -> int:
        user_name = user_name if user_name else self.config.user_name
        query = f"""
        MATCH (n@Memory)
        WHERE n.memory_type = "{scope}"
        """
        query += f"\nAND n.user_name = '{user_name}'"
        query += "\nRETURN count(n) AS count"

        result = self.execute_query(query)
        return result.one_or_none()["count"].value

    @timed
    def edge_exists(
        self,
        source_id: str,
        target_id: str,
        type: str = "ANY",
        direction: str = "OUTGOING",
        user_name: str | None = None,
    ) -> bool:
        """
        Check if an edge exists between two nodes.

        Args:
            source_id: ID of the source node.
            target_id: ID of the target node.
            type: Relationship type. Use "ANY" to match any relationship type.
            direction: Direction of the edge.
                Use "OUTGOING" (default), "INCOMING", or "ANY".
            user_name (str, optional): User name for filtering in non-multi-db mode.
        Returns:
            True if the edge exists, otherwise False.
        """
        # Prepare the relationship pattern
        user_name = user_name if user_name else self.config.user_name
        rel = "r" if type == "ANY" else f"r@{type}"

        # Prepare the match pattern with direction
        if direction == "OUTGOING":
            pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]->(b@Memory {{id: '{target_id}'}})"
        elif direction == "INCOMING":
            pattern = f"(a@Memory {{id: '{source_id}'}})<-[{rel}]-(b@Memory {{id: '{target_id}'}})"
        elif direction == "ANY":
            pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]-(b@Memory {{id: '{target_id}'}})"
        else:
            raise ValueError(
                f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'."
            )
        query = f"MATCH {pattern}"
        query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
        query += "\nRETURN r"

        # Run the GQL query
        result = self.execute_query(query)
        record = result.one_or_none()
        if record is None:
            return False
        return record.values() is not None

    # Graph Query & Reasoning
    @timed
    def get_node(
        self, id: str, include_embedding: bool = False, user_name: str | None = None
    ) -> dict[str, Any] | None:
        """
        Retrieve a Memory node by its unique ID.

        Args:
            id (str): Node ID (Memory.id)
            include_embedding: with/without embedding
            user_name (str, optional): User name for filtering in non-multi-db mode

        Returns:
            dict: Node properties as key-value pairs, or None if not found.
        """
        filter_clause = f'n.id = "{id}"'
        return_fields = self._build_return_fields(include_embedding)
        gql = f"""
        MATCH (n@Memory)
        WHERE {filter_clause}
        RETURN {return_fields}
        """

        try:
            result = self.execute_query(gql)
            for row in result:
                props = {k: v.value for k, v in row.items()}
                node = self._parse_node(props)
                return node
        except Exception as e:
            logger.error(
                f"[get_node] Failed to retrieve node '{id}': {e}, trace: {traceback.format_exc()}"
            )
        return None

    @timed
    def get_nodes(
        self,
        ids: list[str],
        include_embedding: bool = False,
        user_name: str | None = None,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """
        Retrieve the metadata and memory of a list of nodes.

        Args:
            ids: List of node identifiers.
            include_embedding: with/without embedding
            user_name (str, optional): User name for filtering in non-multi-db mode
        Returns:
            list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.

        Notes:
            - Assumes all provided IDs are valid and exist.
            - Returns an empty list if the input is empty.
        """
        if not ids:
            return []
        # Safe formatting of the ID list
        id_list = ",".join(f'"{_id}"' for _id in ids)

        return_fields = self._build_return_fields(include_embedding)
        query = f"""
        MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
        WHERE n.id IN [{id_list}]
        RETURN {return_fields}
        """
        nodes = []
        try:
            results = self.execute_query(query)
            for row in results:
                props = {k: v.value for k, v in row.items()}
                nodes.append(self._parse_node(props))
        except Exception as e:
            logger.error(
                f"[get_nodes] Failed to retrieve nodes {ids}: {e}, trace: {traceback.format_exc()}"
            )
        return nodes

    @timed
    def get_edges(
        self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None
    ) -> list[dict[str, str]]:
        """
        Get edges connected to a node, with optional type and direction filters.

        Args:
            id: Node ID to retrieve edges for.
            type: Relationship type to match, or 'ANY' to match all.
            direction: 'OUTGOING', 'INCOMING', or 'ANY'.
            user_name (str, optional): User name for filtering in non-multi-db mode

        Returns:
            List of edges:
            [
                {"from": "source_id", "to": "target_id", "type": "RELATE"},
                ...
            ]
        """
        # Build relationship type filter
        rel_type = "" if type == "ANY" else f"@{type}"
        user_name = user_name if user_name else self.config.user_name
        # Build the match pattern based on direction
        if direction == "OUTGOING":
            pattern = f"(a@Memory)-[r{rel_type}]->(b@Memory)"
            where_clause = f"a.id = '{id}'"
        elif direction == "INCOMING":
            pattern = f"(a@Memory)<-[r{rel_type}]-(b@Memory)"
            where_clause = f"a.id = '{id}'"
        elif direction == "ANY":
            pattern = f"(a@Memory)-[r{rel_type}]-(b@Memory)"
            # Parenthesized so the appended user filter applies to both branches.
            where_clause = f"(a.id = '{id}' OR b.id = '{id}')"
        else:
            raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.")

        where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'"

        query = f"""
        MATCH {pattern}
        WHERE {where_clause}
        RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
        """

        result = self.execute_query(query)
        edges = []
        for record in result:
            edges.append(
                {
                    "from": record["from_id"].value,
                    "to": record["to_id"].value,
                    "type": record["edge_type"].value,
                }
            )
        return edges

    @timed
    def get_neighbors_by_tag(
        self,
        tags: list[str],
        exclude_ids: list[str],
        top_k: int = 5,
        min_overlap: int = 1,
        include_embedding: bool = False,
        user_name: str | None = None,
    ) -> list[dict[str, Any]]:
        """
        Find top-K neighbor nodes with maximum tag overlap.

        Args:
            tags: The list of tags to match.
            exclude_ids: Node IDs to exclude (e.g., local cluster).
            top_k: Max number of neighbors to return.
            min_overlap: Minimum number of overlapping tags required.
                (Accepted for API compatibility; not currently applied.)
            include_embedding: with/without embedding
            user_name (str, optional): User name for filtering in non-multi-db mode

        Returns:
            List of dicts with node details, ordered by overlap count.
        """
        if not tags:
            return []
        user_name = user_name if user_name else self.config.user_name
        where_clauses = [
            'n.status = "activated"',
            'NOT (n.node_type = "reasoning")',
            'NOT (n.memory_type = "WorkingMemory")',
        ]
        if exclude_ids:
            where_clauses.append(f"NOT (n.id IN {exclude_ids})")

        where_clauses.append(f'n.user_name = "{user_name}"')

        where_clause = " AND ".join(where_clauses)
        tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"

        return_fields = self._build_return_fields(include_embedding)
        query = f"""
        LET tag_list = {tag_list_literal}

        MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
        WHERE {where_clause}
        RETURN {return_fields},
               size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
        ORDER BY overlap_count DESC
        LIMIT {top_k}
        """

        result = self.execute_query(query)
        neighbors: list[dict[str, Any]] = []
        for r in result:
            props = {k: v.value for k, v in r.items() if k != "overlap_count"}
            parsed = self._parse_node(props)
            parsed["overlap_count"] = r["overlap_count"].value
            neighbors.append(parsed)

        # The query already orders and limits; re-sort defensively, then strip
        # the helper field before returning.
        neighbors.sort(key=lambda x: x["overlap_count"], reverse=True)
        result = []
        for neighbor in neighbors[:top_k]:
            neighbor.pop("overlap_count")
            result.append(neighbor)
        return result
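
    # Example (illustrative): fetch up to 5 activated non-working memories
    # sharing tags with a new memory, excluding the memory's own cluster.
    #
    #     related = db.get_neighbors_by_tag(
    #         tags=["preference", "ui"], exclude_ids=["mem-001"], top_k=5,
    #     )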

    @timed
    def get_children_with_embeddings(
        self, id: str, user_name: str | None = None
    ) -> list[dict[str, Any]]:
        user_name = user_name if user_name else self.config.user_name
        where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'"

        query = f"""
        MATCH (p@Memory)-[@PARENT]->(c@Memory)
        WHERE p.id = "{id}" {where_user}
        RETURN c.id AS id, c.{self.dim_field} AS {self.dim_field}, c.memory AS memory
        """
        result = self.execute_query(query)
        children = []
        for row in result:
            eid = row["id"].value  # STRING
            emb_v = row[self.dim_field].value  # NVector
            emb = list(emb_v.values) if emb_v else []
            mem = row["memory"].value  # STRING

            children.append({"id": eid, "embedding": emb, "memory": mem})
        return children

    @timed
    def get_subgraph(
        self,
        center_id: str,
        depth: int = 2,
        center_status: str = "activated",
        user_name: str | None = None,
    ) -> dict[str, Any]:
        """
        Retrieve a local subgraph centered at a given node.
        Args:
            center_id: The ID of the center node.
            depth: The hop distance for neighbors.
            center_status: Required status for the center node.
            user_name (str, optional): User name for filtering in non-multi-db mode
        Returns:
            {
                "core_node": {...},
                "neighbors": [...],
                "edges": [...]
            }
        """
        if not 1 <= depth <= 5:
            raise ValueError("depth must be 1-5")

        user_name = user_name if user_name else self.config.user_name

        gql = f"""
        MATCH (center@Memory /*+ INDEX(idx_memory_user_name) */)
        WHERE center.id = '{center_id}'
          AND center.status = '{center_status}'
          AND center.user_name = '{user_name}'
        OPTIONAL MATCH p = (center)-[e]->{{1,{depth}}}(neighbor@Memory)
        WHERE neighbor.user_name = '{user_name}'
        RETURN center,
               collect(DISTINCT neighbor) AS neighbors,
               collect(EDGES(p)) AS edge_chains
        """

        result = self.execute_query(gql).one_or_none()
        if not result or result.size == 0:
            return {"core_node": None, "neighbors": [], "edges": []}

        core_node_props = result["center"].as_node().get_properties()
        core_node = self._parse_node(core_node_props)
        neighbors = []
        # Map internal vertex IDs back to Memory.id values for edge endpoints.
        vid_to_id_map = {result["center"].as_node().node_id: core_node["id"]}
        for n in result["neighbors"].value:
            n_node = n.as_node()
            n_props = n_node.get_properties()
            node_parsed = self._parse_node(n_props)
            neighbors.append(node_parsed)
            vid_to_id_map[n_node.node_id] = node_parsed["id"]

        edges = []
        for chain_group in result["edge_chains"].value:
            for edge_wr in chain_group.value:
                edge = edge_wr.value
                edges.append(
                    {
                        "type": edge.get_type(),
                        "source": vid_to_id_map.get(edge.get_src_id()),
                        "target": vid_to_id_map.get(edge.get_dst_id()),
                    }
                )

        return {"core_node": core_node, "neighbors": neighbors, "edges": edges}

    # Search / recall operations
    @timed
    def search_by_embedding(
        self,
        vector: list[float],
        top_k: int = 5,
        scope: str | None = None,
        status: str | None = None,
        threshold: float | None = None,
        search_filter: dict | None = None,
        user_name: str | None = None,
        **kwargs,
    ) -> list[dict]:
        """
        Retrieve node IDs based on vector similarity.

        Args:
            vector (list[float]): The embedding vector representing query semantics.
            top_k (int): Number of top similar nodes to retrieve.
            scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
            status (str, optional): Node status filter (e.g., 'active', 'archived').
                If provided, restricts results to nodes with matching status.
            threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
            search_filter (dict, optional): Additional metadata filters for search results.
                Keys should match node properties, values are the expected values.
            user_name (str, optional): User name for filtering in non-multi-db mode

        Returns:
            list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.

        Notes:
            - This method uses NebulaGraph's native vector indexing to search for similar nodes.
            - If scope is provided, it restricts results to nodes with matching memory_type.
            - If 'status' is provided, only nodes with the matching status will be returned.
            - If threshold is provided, only results with score >= threshold will be returned.
            - If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
            - Typical use case: restrict to 'status = activated' to avoid
              matching archived or merged nodes.
        """
        user_name = user_name if user_name else self.config.user_name
        vector = _normalize(vector)
        dim = len(vector)
        vector_str = ",".join(f"{float(x)}" for x in vector)
        gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])"
        where_clauses = [f"n.{self.dim_field} IS NOT NULL"]
        if scope:
            where_clauses.append(f'n.memory_type = "{scope}"')
        if status:
            where_clauses.append(f'n.status = "{status}"')
        where_clauses.append(f'n.user_name = "{user_name}"')

        # Add search_filter conditions
        if search_filter:
            for key, value in search_filter.items():
                if isinstance(value, str):
                    where_clauses.append(f'n.{key} = "{value}"')
                else:
                    where_clauses.append(f"n.{key} = {value}")

        where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

        gql = f"""
        let a = {gql_vector}
        MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
        {where_clause}
        ORDER BY inner_product(n.{self.dim_field}, a) DESC
        LIMIT {top_k}
        RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score"""
        try:
            result = self.execute_query(gql)
        except Exception as e:
            logger.error(f"[search_by_embedding] Query failed: {e}")
            return []

        try:
            output = []
            for row in result:
                values = row.values()
                id_val = values[0].as_string()
                score_val = values[1].as_double()
                score_val = (score_val + 1) / 2  # align with Neo4j's normalized cosine score
                if threshold is None or score_val >= threshold:
                    output.append({"id": id_val, "score": score_val})
            return output
        except Exception as e:
            logger.error(f"[search_by_embedding] Result parse failed: {e}")
            return []
|
|
1056
|
+
|
|
1057
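Editor's note: a minimal usage sketch for search_by_embedding, assuming `db` is an already-configured instance of this graph store; the handle and the toy 4-dimensional vector are hypothetical stand-ins, not part of the packaged diff.

# `db` and the toy vector are hypothetical stand-ins.
hits = db.search_by_embedding(
    vector=[0.1, 0.2, 0.3, 0.4],  # real vectors must match the configured embedding dimension
    top_k=5,
    scope="LongTermMemory",
    status="activated",           # skip archived or merged nodes, per the docstring's note
    threshold=0.6,                # applied after scores are mapped into [0, 1]
)
for hit in hits:
    print(hit["id"], hit["score"])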
+    @timed
+    def get_by_metadata(
+        self, filters: list[dict[str, Any]], user_name: str | None = None
+    ) -> list[str]:
+        """
+        Retrieve node IDs that match the given metadata filters.
+        Supports exact match ('='), list membership ('in'), list containment
+        ('contains'), string prefix/suffix ('starts_with'/'ends_with'), and
+        numeric comparisons ('>', '>=', '<', '<=').
+
+        TODO:
+            1. Add "AND" vs "OR" logic combinations between filters;
+            2. Support nested conditional expressions.
+
+        Args:
+            filters: List of filter dicts like:
+                [
+                    {"field": "key", "op": "in", "value": ["A", "B"]},
+                    {"field": "confidence", "op": ">=", "value": 80},
+                    {"field": "tags", "op": "contains", "value": "AI"},
+                    ...
+                ]
+            user_name (str, optional): User name for filtering in non-multi-db mode.
+
+        Returns:
+            list[str]: Node IDs whose metadata match all filter conditions (AND logic).
+
+        Notes:
+            - Supports structured querying such as tag/category/importance/time filtering.
+            - Can be used for faceted recall or prefiltering before embedding rerank.
+        """
+        where_clauses = []
+        user_name = user_name if user_name else self.config.user_name
+        for f in filters:
+            field = f["field"]
+            op = f.get("op", "=")
+            value = f["value"]
+
+            escaped_value = self._format_value(value)
+
+            # Build WHERE clause
+            if op == "=":
+                where_clauses.append(f"n.{field} = {escaped_value}")
+            elif op == "in":
+                where_clauses.append(f"n.{field} IN {escaped_value}")
+            elif op == "contains":
+                where_clauses.append(f"size(filter(n.{field}, t -> t IN {escaped_value})) > 0")
+            elif op == "starts_with":
+                where_clauses.append(f"n.{field} STARTS WITH {escaped_value}")
+            elif op == "ends_with":
+                where_clauses.append(f"n.{field} ENDS WITH {escaped_value}")
+            elif op in [">", ">=", "<", "<="]:
+                where_clauses.append(f"n.{field} {op} {escaped_value}")
+            else:
+                raise ValueError(f"Unsupported operator: {op}")
+
+        where_clauses.append(f'n.user_name = "{user_name}"')
+
+        where_str = " AND ".join(where_clauses)
+        gql = f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_str} RETURN n.id AS id"
+        ids = []
+        try:
+            result = self.execute_query(gql)
+            ids = [record["id"].value for record in result]
+        except Exception as e:
+            logger.error(f"Failed to get metadata: {e}, gql is {gql}")
+        return ids
+
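Editor's note: a sketch of how the filter dicts above translate into WHERE clauses, assuming `db` is an already-configured instance; illustration only.

# Each dict becomes one AND-ed clause in the generated GQL.
ids = db.get_by_metadata(
    [
        {"field": "key", "op": "in", "value": ["A", "B"]},     # n.key IN ["A", "B"]
        {"field": "confidence", "op": ">=", "value": 80},      # n.confidence >= 80
        {"field": "tags", "op": "contains", "value": ["AI"]},  # passing a list keeps the generated `t IN ...` test well-formed
    ]
)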
+    @timed
+    def get_grouped_counts(
+        self,
+        group_fields: list[str],
+        where_clause: str = "",
+        params: dict[str, Any] | None = None,
+        user_name: str | None = None,
+    ) -> list[dict[str, Any]]:
+        """
+        Count nodes grouped by arbitrary fields.
+
+        Args:
+            group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"]
+            where_clause (str, optional): Extra WHERE condition, e.g.,
+                "WHERE n.status = 'activated'"
+            params (dict, optional): Parameters for the WHERE clause.
+            user_name (str, optional): User name for filtering in non-multi-db mode.
+
+        Returns:
+            list[dict]: e.g., [{'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10}, ...]
+        """
+        if not group_fields:
+            raise ValueError("group_fields cannot be empty")
+        user_name = user_name if user_name else self.config.user_name
+        # GQL-specific modifications
+        user_clause = f"n.user_name = '{user_name}'"
+        if where_clause:
+            where_clause = where_clause.strip()
+            if where_clause.upper().startswith("WHERE"):
+                where_clause += f" AND {user_clause}"
+            else:
+                where_clause = f"WHERE {where_clause} AND {user_clause}"
+        else:
+            where_clause = f"WHERE {user_clause}"
+
+        # Inline parameters if provided (naive "$key" string replacement)
+        if params:
+            for key, value in params.items():
+                # Handle different value types appropriately
+                if isinstance(value, str):
+                    value = f"'{value}'"
+                where_clause = where_clause.replace(f"${key}", str(value))
+
+        return_fields = []
+        for field in group_fields:
+            alias = field.replace(".", "_")
+            return_fields.append(f"n.{field} AS {alias}")
+        # Full GQL query construction; GQL groups implicitly by the
+        # non-aggregated RETURN fields, so no explicit GROUP BY is needed.
+        gql = f"""
+        MATCH (n /*+ INDEX(idx_memory_user_name) */)
+        {where_clause}
+        RETURN {", ".join(return_fields)}, COUNT(n) AS count
+        """
+        result = self.execute_query(gql)  # Pure GQL string execution
+
+        output = []
+        for record in result:
+            group_values = {}
+            for i, field in enumerate(group_fields):
+                value = record.values()[i].as_string()
+                group_values[field] = value
+            count_value = record["count"].value
+            output.append({**group_values, "count": count_value})
+
+        return output
+
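Editor's note: a usage sketch for get_grouped_counts, assuming `db` is an already-configured instance; illustration only.

# Count activated memories per memory_type and status for the current user.
rows = db.get_grouped_counts(
    group_fields=["memory_type", "status"],
    where_clause="WHERE n.status = $status",
    params={"status": "activated"},  # inlined via the "$key" replacement above
)
# e.g., [{"memory_type": "WorkingMemory", "status": "activated", "count": 10}, ...]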
+    @timed
+    def clear(self, user_name: str | None = None) -> None:
+        """
+        Delete all memory nodes (and their edges) belonging to the given user.
+
+        Args:
+            user_name (str, optional): User name for filtering in non-multi-db mode.
+        """
+        user_name = user_name if user_name else self.config.user_name
+        try:
+            query = f"MATCH (n@Memory) WHERE n.user_name = '{user_name}' DETACH DELETE n"
+            self.execute_query(query)
+            logger.info(f"Cleared all nodes for user '{user_name}' from database.")
+        except Exception as e:
+            logger.error(f"[ERROR] Failed to clear database: {e}")
+
+    @timed
+    def export_graph(
+        self, include_embedding: bool = False, user_name: str | None = None, **kwargs
+    ) -> dict[str, Any]:
+        """
+        Export all graph nodes and edges in a structured form.
+
+        Args:
+            include_embedding (bool): Whether to include the large embedding field.
+            user_name (str, optional): User name for filtering in non-multi-db mode.
+
+        Returns:
+            {
+                "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
+                "edges": [ { "source": ..., "target": ..., "type": ... }, ... ]
+            }
+        """
+        user_name = user_name if user_name else self.config.user_name
+        node_query = "MATCH (n@Memory)"
+        edge_query = "MATCH (a@Memory)-[r]->(b@Memory)"
+        node_query += f' WHERE n.user_name = "{user_name}"'
+        edge_query += f' WHERE r.user_name = "{user_name}"'
+
+        try:
+            if include_embedding:
+                return_fields = "n"
+            else:
+                return_fields = ",".join(
+                    [
+                        "n.id AS id",
+                        "n.memory AS memory",
+                        "n.user_name AS user_name",
+                        "n.user_id AS user_id",
+                        "n.session_id AS session_id",
+                        "n.status AS status",
+                        "n.key AS key",
+                        "n.confidence AS confidence",
+                        "n.tags AS tags",
+                        "n.created_at AS created_at",
+                        "n.updated_at AS updated_at",
+                        "n.memory_type AS memory_type",
+                        "n.sources AS sources",
+                        "n.source AS source",
+                        "n.node_type AS node_type",
+                        "n.visibility AS visibility",
+                        "n.usage AS usage",
+                        "n.background AS background",
+                    ]
+                )
+
+            full_node_query = f"{node_query} RETURN {return_fields}"
+            node_result = self.execute_query(full_node_query, timeout=20)
+            nodes = []
+            logger.debug(f"[export_graph] node query result: {node_result}")
+            for row in node_result:
+                if include_embedding:
+                    props = row.values()[0].as_node().get_properties()
+                else:
+                    props = {k: v.value for k, v in row.items()}
+                node = self._parse_node(props)
+                nodes.append(node)
+        except Exception as e:
+            raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e
+
+        try:
+            full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) AS edge"
+            edge_result = self.execute_query(full_edge_query, timeout=20)
+            edges = [
+                {
+                    "source": row.values()[0].value,
+                    "target": row.values()[1].value,
+                    "type": row.values()[2].value,
+                }
+                for row in edge_result
+            ]
+        except Exception as e:
+            raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e
+
+        return {"nodes": nodes, "edges": edges}
+
+    @timed
+    def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None:
+        """
+        Import an entire graph from a serialized dictionary.
+
+        Args:
+            data: A dictionary containing all nodes and edges to be loaded.
+            user_name (str, optional): User name for filtering in non-multi-db mode.
+        """
+        user_name = user_name if user_name else self.config.user_name
+        for node in data.get("nodes", []):
+            try:
+                node_id, memory, metadata = _compose_node(node)
+                metadata["user_name"] = user_name
+                metadata = self._prepare_node_metadata(metadata)
+                metadata.update({"id": node_id, "memory": memory})
+                properties = ", ".join(
+                    f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()
+                )
+                node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
+                self.execute_query(node_gql)
+            except Exception as e:
+                logger.error(f"Failed to load node: {node}, error: {e}")
+
+        for edge in data.get("edges", []):
+            try:
+                source_id, target_id = edge["source"], edge["target"]
+                edge_type = edge["type"]
+                props = f'{{user_name: "{user_name}"}}'
+                edge_gql = f'''
+                MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
+                INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
+                '''
+                self.execute_query(edge_gql)
+            except Exception as e:
+                logger.error(f"Failed to load edge: {edge}, error: {e}")
+
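Editor's note: a round-trip sketch tying export_graph and import_graph together; `src_db` and `dst_db` are hypothetical configured instances, illustration only.

# Copy one user's graph between two configured instances.
snapshot = src_db.export_graph(include_embedding=True, user_name="alice")
dst_db.import_graph(snapshot, user_name="alice")
# Nodes and edges are inserted with INSERT OR IGNORE (primary key `id`),
# so re-importing the same snapshot is idempotent.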
+    @timed
+    def get_all_memory_items(
+        self, scope: str, include_embedding: bool = False, user_name: str | None = None
+    ) -> list[dict]:
+        """
+        Retrieve all memory items of a specific memory_type.
+
+        Args:
+            scope (str): One of 'WorkingMemory', 'LongTermMemory', 'UserMemory', or 'OuterMemory'.
+            include_embedding (bool): Whether to include the embedding field.
+            user_name (str, optional): User name for filtering in non-multi-db mode.
+
+        Returns:
+            list[dict]: Memory items under this scope (the query caps results at 100).
+        """
+        user_name = user_name if user_name else self.config.user_name
+        if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}:
+            raise ValueError(f"Unsupported memory type scope: {scope}")
+
+        where_clause = f"WHERE n.memory_type = '{scope}'"
+        where_clause += f" AND n.user_name = '{user_name}'"
+
+        return_fields = self._build_return_fields(include_embedding)
+
+        query = f"""
+        MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
+        {where_clause}
+        RETURN {return_fields}
+        LIMIT 100
+        """
+        nodes = []
+        try:
+            results = self.execute_query(query)
+            for row in results:
+                props = {k: v.value for k, v in row.items()}
+                nodes.append(self._parse_node(props))
+        except Exception as e:
+            logger.error(f"Failed to get memories: {e}")
+        return nodes
+
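Editor's note: a short usage sketch, assuming `db` is an already-configured instance; illustration only.

# Fetch up to 100 working-memory items without embeddings.
items = db.get_all_memory_items(scope="WorkingMemory")
for item in items:
    print(item["id"], item["memory"][:50])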
+    @timed
+    def get_structure_optimization_candidates(
+        self, scope: str, include_embedding: bool = False, user_name: str | None = None
+    ) -> list[dict]:
+        """
+        Find nodes that are likely candidates for structure optimization:
+        - Isolated nodes, nodes with empty background, or nodes with exactly one child.
+        - Plus: the child of any parent node that has exactly one child.
+        """
+        user_name = user_name if user_name else self.config.user_name
+        where_clause = f'''
+        n.memory_type = "{scope}"
+        AND n.status = "activated"
+        '''
+        where_clause += f' AND n.user_name = "{user_name}"'
+
+        return_fields = self._build_return_fields(include_embedding)
+        return_fields += f", n.{self.dim_field} AS {self.dim_field}"
+
+        # NOTE: the query below currently selects only the isolated case, i.e.
+        # nodes with neither an incoming nor an outgoing PARENT edge.
+        query = f"""
+        MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
+        WHERE {where_clause}
+        OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
+        OPTIONAL MATCH (p@Memory)-[@PARENT]->(n)
+        WHERE c IS NULL AND p IS NULL
+        RETURN {return_fields}
+        """
+
+        candidates = []
+        node_ids = set()
+        try:
+            results = self.execute_query(query)
+            for row in results:
+                props = {k: v.value for k, v in row.items()}
+                node = self._parse_node(props)
+                node_id = node["id"]
+                if node_id not in node_ids:
+                    candidates.append(node)
+                    node_ids.add(node_id)
+        except Exception as e:
+            logger.error(
+                f"Failed to get structure optimization candidates: {e}, "
+                f"traceback: {traceback.format_exc()}"
+            )
+        return candidates
+
+    @timed
+    def drop_database(self) -> None:
+        """
+        Permanently delete the entire database this instance is using.
+        WARNING: This operation is destructive and cannot be undone.
+        In Shared Database Multi-Tenant mode this backend refuses to drop
+        the shared database and raises instead.
+        """
+        raise ValueError(
+            f"Refusing to drop protected database: `{self.db_name}` in "
+            "Shared Database Multi-Tenant mode"
+        )
+
+    @timed
+    def detect_conflicts(self) -> list[tuple[str, str]]:
+        """
+        Detect conflicting nodes based on logical or semantic inconsistency.
+
+        Returns:
+            A list of (node_id1, node_id2) tuples that conflict.
+        """
+        raise NotImplementedError
+
+    # Structure Maintenance
+    @timed
+    def deduplicate_nodes(self) -> None:
+        """
+        Deduplicate redundant or semantically similar nodes.
+        This typically involves identifying nodes with identical or near-identical memory.
+        """
+        raise NotImplementedError
+
+    @timed
+    def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
+        """
+        Get the ordered context chain starting from a node, following a relationship type.
+
+        Args:
+            id: Starting node ID.
+            type: Relationship type to follow (e.g., 'FOLLOWS').
+
+        Returns:
+            List of ordered node IDs in the chain.
+        """
+        raise NotImplementedError
+
+    @timed
+    def get_neighbors(
+        self, id: str, type: str, direction: Literal["in", "out", "both"] = "out"
+    ) -> list[str]:
+        """
+        Get connected node IDs in a specific direction and relationship type.
+
+        Args:
+            id: Source node ID.
+            type: Relationship type.
+            direction: Edge direction to follow ('out', 'in', or 'both').
+
+        Returns:
+            List of neighboring node IDs.
+        """
+        raise NotImplementedError
+
+    @timed
+    def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]:
+        """
+        Get the path of nodes from source to target within a limited depth.
+
+        Args:
+            source_id: Starting node ID.
+            target_id: Target node ID.
+            max_depth: Maximum path length to traverse.
+
+        Returns:
+            Ordered list of node IDs along the path.
+        """
+        raise NotImplementedError
+
+    @timed
+    def merge_nodes(self, id1: str, id2: str) -> str:
+        """
+        Merge two similar or duplicate nodes into one.
+
+        Args:
+            id1: First node ID.
+            id2: Second node ID.
+
+        Returns:
+            ID of the resulting merged node.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def _ensure_space_exists(cls, tmp_client, cfg):
+        """Lightweight check to ensure the target graph (space) exists."""
+        db_name = getattr(cfg, "space", None)
+        if not db_name:
+            logger.warning("[NebulaGraphDBSync] No `space` specified in cfg.")
+            return
+
+        try:
+            res = tmp_client.execute("SHOW GRAPHS")
+            existing = {row.values()[0].as_string() for row in res}
+            if db_name not in existing:
+                tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type")
+                logger.info(f"✅ Graph `{db_name}` created before session binding.")
+            else:
+                logger.debug(f"Graph `{db_name}` already exists.")
+        except Exception:
+            logger.exception("[NebulaGraphDBSync] Failed to ensure space exists")
+
+    @timed
+    def _ensure_database_exists(self):
+        graph_type_name = "MemOSBgeM3Type"
+
+        check_type_query = "SHOW GRAPH TYPES"
+        result = self.execute_query(check_type_query, auto_set_db=False)
+
+        type_exists = any(row["graph_type"].as_string() == graph_type_name for row in result)
+
+        if not type_exists:
+            create_graph_type = f"""
+            CREATE GRAPH TYPE IF NOT EXISTS {graph_type_name} AS {{
+                NODE Memory (:MemoryTag {{
+                    id STRING,
+                    memory STRING,
+                    user_name STRING,
+                    user_id STRING,
+                    session_id STRING,
+                    status STRING,
+                    key STRING,
+                    confidence FLOAT,
+                    tags LIST<STRING>,
+                    created_at STRING,
+                    updated_at STRING,
+                    memory_type STRING,
+                    sources LIST<STRING>,
+                    source STRING,
+                    node_type STRING,
+                    visibility STRING,
+                    usage LIST<STRING>,
+                    background STRING,
+                    {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT>,
+                    PRIMARY KEY(id)
+                }}),
+                EDGE RELATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE PARENT (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE AGGREGATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE MERGED_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE INFERS (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE FOLLOWS (Memory) -[{{user_name STRING}}]-> (Memory)
+            }}
+            """
+            self.execute_query(create_graph_type, auto_set_db=False)
+        else:
+            describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name}"
+            desc_result = self.execute_query(describe_query, auto_set_db=False)
+
+            memory_fields = []
+            for row in desc_result:
+                field_name = row.values()[0].as_string()
+                memory_fields.append(field_name)
+
+            if self.dim_field not in memory_fields:
+                alter_query = f"""
+                ALTER GRAPH TYPE {graph_type_name} {{
+                    ALTER NODE TYPE Memory ADD PROPERTIES {{ {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT> }}
+                }}
+                """
+                self.execute_query(alter_query, auto_set_db=False)
+                logger.info(f"✅ Added new vector field {self.dim_field} to {graph_type_name}")
+            else:
+                logger.info(f"✅ Graph type {graph_type_name} already includes {self.dim_field}")
+
+        create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
+        try:
+            self.execute_query(create_graph, auto_set_db=False)
+            logger.info(f"✅ Graph `{self.db_name}` is now the working graph.")
+        except Exception as e:
+            logger.error(f"❌ Failed to create graph: {e} trace: {traceback.format_exc()}")
+
+    @timed
+    def _create_vector_index(
+        self,
+        label: str = "Memory",
+        vector_property: str = "embedding",
+        dimensions: int = 3072,
+        index_name: str = "memory_vector_index",
+    ) -> None:
+        """
+        Create a vector index for the specified property on the label.
+        """
+        if str(dimensions) == str(self.default_memory_dimension):
+            index_name = f"idx_{vector_property}"
+            vector_name = vector_property
+        else:
+            index_name = f"idx_{vector_property}_{dimensions}"
+            vector_name = f"{vector_property}_{dimensions}"
+
+        create_vector_index = f"""
+        CREATE VECTOR INDEX IF NOT EXISTS {index_name}
+        ON NODE {label}::{vector_name}
+        OPTIONS {{
+            DIM: {dimensions},
+            METRIC: IP,
+            TYPE: IVF,
+            NLIST: 100,
+            TRAINSIZE: 1000
+        }}
+        FOR `{self.db_name}`
+        """
+        self.execute_query(create_vector_index)
+        logger.info(
+            f"✅ Ensured {label}::{vector_property} vector index {index_name} "
+            f"exists (DIM={dimensions})"
+        )
+
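Editor's note: for orientation, the DDL this method renders, reconstructed by hand from the f-string above; it assumes 3072 differs from self.default_memory_dimension and uses a hypothetical database name `memos_db` (not output captured from a live NebulaGraph).

# _create_vector_index(vector_property="embedding", dimensions=3072) would emit roughly:
#
#     CREATE VECTOR INDEX IF NOT EXISTS idx_embedding_3072
#     ON NODE Memory::embedding_3072
#     OPTIONS { DIM: 3072, METRIC: IP, TYPE: IVF, NLIST: 100, TRAINSIZE: 1000 }
#     FOR `memos_db`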
+    @timed
+    def _create_basic_property_indexes(self) -> None:
+        """
+        Create standard B-tree indexes on the status, memory_type, created_at,
+        and updated_at fields, plus user_name when running in Shared Database
+        Multi-Tenant mode.
+        """
+        fields = [
+            "status",
+            "memory_type",
+            "created_at",
+            "updated_at",
+            "user_name",
+        ]
+
+        for field in fields:
+            index_name = f"idx_memory_{field}"
+            gql = f"""
+            CREATE INDEX IF NOT EXISTS {index_name} ON NODE Memory({field})
+            FOR `{self.db_name}`
+            """
+            try:
+                self.execute_query(gql)
+                logger.info(f"✅ Created index: {index_name} on field {field}")
+            except Exception as e:
+                logger.error(
+                    f"❌ Failed to create index {index_name}: {e}, trace: {traceback.format_exc()}"
+                )
+
+    @timed
+    def _index_exists(self, index_name: str) -> bool:
+        """
+        Check if a vector index with the given name exists in NebulaGraph.
+
+        Args:
+            index_name (str): The name of the index to check.
+
+        Returns:
+            bool: True if the index exists, False otherwise.
+        """
+        query = "SHOW VECTOR INDEXES"
+        try:
+            result = self.execute_query(query)
+            return any(row.values()[0].as_string() == index_name for row in result)
+        except Exception as e:
+            logger.error(f"[Nebula] Failed to check index existence: {e}")
+            return False
+
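Editor's note: a sketch combining the two helpers above during setup; illustration only, called on an already-initialized instance.

# Idempotent setup: only issue the DDL when the index is not already present.
# The name matches what _create_vector_index derives when 3072 differs from
# self.default_memory_dimension.
if not self._index_exists("idx_embedding_3072"):
    self._create_vector_index(
        label="Memory",
        vector_property="embedding",
        dimensions=3072,
    )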
+    @timed
+    def _parse_value(self, value: Any) -> Any:
+        """Convert a Nebula ValueWrapper into a native Python type."""
+        from nebulagraph_python.value_wrapper import ValueWrapper
+
+        if value is None or (hasattr(value, "is_null") and value.is_null()):
+            return None
+        try:
+            prim = value.cast_primitive() if isinstance(value, ValueWrapper) else value
+        except Exception as e:
+            logger.warning(f"Error decoding Nebula ValueWrapper: {e}")
+            prim = value.cast() if isinstance(value, ValueWrapper) else value
+
+        if isinstance(prim, ValueWrapper):
+            return self._parse_value(prim)
+        if isinstance(prim, list):
+            return [self._parse_value(v) for v in prim]
+        if type(prim).__name__ == "NVector":
+            return list(prim.values)
+
+        return prim  # already a Python primitive
+
+    def _parse_node(self, props: dict[str, Any]) -> dict[str, Any]:
+        parsed = {k: self._parse_value(v) for k, v in props.items()}
+
+        for tf in ("created_at", "updated_at"):
+            if tf in parsed and parsed[tf] is not None:
+                parsed[tf] = _normalize_datetime(parsed[tf])
+
+        node_id = parsed.pop("id")
+        memory = parsed.pop("memory", "")
+        parsed.pop("user_name", None)
+        metadata = parsed
+        metadata["type"] = metadata.pop("node_type")
+
+        if self.dim_field in metadata:
+            metadata["embedding"] = metadata.pop(self.dim_field)
+
+        return {"id": node_id, "memory": memory, "metadata": metadata}
+
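Editor's note: an illustration of the reshaping _parse_node performs, derived from the code above; the property values are hypothetical.

# Raw properties (after _parse_value unwrapping):
#   {"id": "m1", "memory": "likes tea", "user_name": "alice",
#    "node_type": "fact", "status": "activated", ...}
# become:
#   {"id": "m1", "memory": "likes tea",
#    "metadata": {"type": "fact", "status": "activated", ...}}
# user_name is dropped, node_type is renamed to "type", and the dim field,
# if present, is renamed to "embedding".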
+    @timed
+    def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
+        """
+        Ensure metadata has proper datetime fields and normalized types.
+
+        - Fill `created_at` and `updated_at` if missing (in ISO 8601 format).
+        - Convert the embedding to a list of floats if present.
+        """
+        now = datetime.utcnow().isoformat()
+        metadata["node_type"] = metadata.pop("type")
+
+        # Fill timestamps if missing
+        metadata.setdefault("created_at", now)
+        metadata.setdefault("updated_at", now)
+
+        # Normalize embedding type
+        embedding = metadata.get("embedding")
+        if embedding and isinstance(embedding, list):
+            metadata.pop("embedding")
+            metadata[self.dim_field] = _normalize([float(x) for x in embedding])
+
+        return metadata
+
+    @timed
+    def _format_value(self, val: Any, key: str = "") -> str:
+        from nebulagraph_python.py_data_types import NVector
+
+        # None
+        if val is None:
+            return "NULL"
+        # bool
+        if isinstance(val, bool):
+            return "true" if val else "false"
+        # str
+        if isinstance(val, str):
+            return f'"{_escape_str(val)}"'
+        # num
+        elif isinstance(val, int | float):
+            return str(val)
+        # time
+        elif isinstance(val, datetime):
+            return f'datetime("{val.isoformat()}")'
+        # list
+        elif isinstance(val, list):
+            if key == self.dim_field:
+                dim = len(val)
+                joined = ",".join(str(float(x)) for x in val)
+                return f"VECTOR<{dim}, FLOAT>([{joined}])"
+            else:
+                return f"[{', '.join(self._format_value(v) for v in val)}]"
+        # NVector
+        elif isinstance(val, NVector):
+            if key == self.dim_field:
+                dim = len(val)
+                joined = ",".join(str(float(x)) for x in val)
+                return f"VECTOR<{dim}, FLOAT>([{joined}])"
+            else:
+                # Falls through to the generic branches below.
+                logger.warning("Invalid NVector")
+        # dict
+        if isinstance(val, dict):
+            j = json.dumps(val, ensure_ascii=False, separators=(",", ":"))
+            return f'"{_escape_str(j)}"'
+        else:
+            return f'"{_escape_str(str(val))}"'
+
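Editor's note: sample renderings of _format_value, derived by hand from the branches above; the escaping details assume _escape_str backslash-escapes embedded quotes.

#   _format_value(None)                          -> NULL
#   _format_value(True)                          -> true
#   _format_value('say "hi"')                    -> "say \"hi\""
#   _format_value(3.14)                          -> 3.14
#   _format_value(["a", "b"])                    -> ["a", "b"]
#   _format_value([0.1, 0.2], key=<dim field>)   -> VECTOR<2, FLOAT>([0.1,0.2])
#   _format_value({"k": 1})                      -> "{\"k\":1}"   (JSON, then escaped)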
+    @timed
+    def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]:
+        """
+        Filter and validate a metadata dictionary against the Memory node schema.
+        - Removes keys not in the schema.
+        - Logs any expected fields that are missing.
+        """
+        dim_fields = {self.dim_field}
+
+        allowed_fields = self.common_fields | dim_fields
+
+        missing_fields = allowed_fields - metadata.keys()
+        if missing_fields:
+            logger.info(f"Metadata missing required fields: {sorted(missing_fields)}")
+
+        filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields}
+
+        return filtered_metadata
+
+    def _build_return_fields(self, include_embedding: bool = False) -> str:
+        fields = set(self.base_fields)
+        if include_embedding:
+            fields.add(self.dim_field)
+        return ", ".join(f"n.{f} AS {f}" for f in fields)