memoryos-2.0.3-py3-none-any.whl
- memoryos-2.0.3.dist-info/METADATA +418 -0
- memoryos-2.0.3.dist-info/RECORD +315 -0
- memoryos-2.0.3.dist-info/WHEEL +4 -0
- memoryos-2.0.3.dist-info/entry_points.txt +3 -0
- memoryos-2.0.3.dist-info/licenses/LICENSE +201 -0
- memos/__init__.py +20 -0
- memos/api/client.py +571 -0
- memos/api/config.py +1018 -0
- memos/api/context/dependencies.py +50 -0
- memos/api/exceptions.py +53 -0
- memos/api/handlers/__init__.py +62 -0
- memos/api/handlers/add_handler.py +158 -0
- memos/api/handlers/base_handler.py +194 -0
- memos/api/handlers/chat_handler.py +1401 -0
- memos/api/handlers/component_init.py +388 -0
- memos/api/handlers/config_builders.py +190 -0
- memos/api/handlers/feedback_handler.py +93 -0
- memos/api/handlers/formatters_handler.py +237 -0
- memos/api/handlers/memory_handler.py +316 -0
- memos/api/handlers/scheduler_handler.py +497 -0
- memos/api/handlers/search_handler.py +222 -0
- memos/api/handlers/suggestion_handler.py +117 -0
- memos/api/mcp_serve.py +614 -0
- memos/api/middleware/request_context.py +101 -0
- memos/api/product_api.py +38 -0
- memos/api/product_models.py +1206 -0
- memos/api/routers/__init__.py +1 -0
- memos/api/routers/product_router.py +477 -0
- memos/api/routers/server_router.py +394 -0
- memos/api/server_api.py +44 -0
- memos/api/start_api.py +433 -0
- memos/chunkers/__init__.py +4 -0
- memos/chunkers/base.py +24 -0
- memos/chunkers/charactertext_chunker.py +41 -0
- memos/chunkers/factory.py +24 -0
- memos/chunkers/markdown_chunker.py +62 -0
- memos/chunkers/sentence_chunker.py +54 -0
- memos/chunkers/simple_chunker.py +50 -0
- memos/cli.py +113 -0
- memos/configs/__init__.py +0 -0
- memos/configs/base.py +82 -0
- memos/configs/chunker.py +59 -0
- memos/configs/embedder.py +88 -0
- memos/configs/graph_db.py +236 -0
- memos/configs/internet_retriever.py +100 -0
- memos/configs/llm.py +151 -0
- memos/configs/mem_agent.py +54 -0
- memos/configs/mem_chat.py +81 -0
- memos/configs/mem_cube.py +105 -0
- memos/configs/mem_os.py +83 -0
- memos/configs/mem_reader.py +91 -0
- memos/configs/mem_scheduler.py +385 -0
- memos/configs/mem_user.py +70 -0
- memos/configs/memory.py +324 -0
- memos/configs/parser.py +38 -0
- memos/configs/reranker.py +18 -0
- memos/configs/utils.py +8 -0
- memos/configs/vec_db.py +80 -0
- memos/context/context.py +355 -0
- memos/dependency.py +52 -0
- memos/deprecation.py +262 -0
- memos/embedders/__init__.py +0 -0
- memos/embedders/ark.py +95 -0
- memos/embedders/base.py +106 -0
- memos/embedders/factory.py +29 -0
- memos/embedders/ollama.py +77 -0
- memos/embedders/sentence_transformer.py +49 -0
- memos/embedders/universal_api.py +51 -0
- memos/exceptions.py +30 -0
- memos/graph_dbs/__init__.py +0 -0
- memos/graph_dbs/base.py +274 -0
- memos/graph_dbs/factory.py +27 -0
- memos/graph_dbs/item.py +46 -0
- memos/graph_dbs/nebular.py +1794 -0
- memos/graph_dbs/neo4j.py +1942 -0
- memos/graph_dbs/neo4j_community.py +1058 -0
- memos/graph_dbs/polardb.py +5446 -0
- memos/hello_world.py +97 -0
- memos/llms/__init__.py +0 -0
- memos/llms/base.py +25 -0
- memos/llms/deepseek.py +13 -0
- memos/llms/factory.py +38 -0
- memos/llms/hf.py +443 -0
- memos/llms/hf_singleton.py +114 -0
- memos/llms/ollama.py +135 -0
- memos/llms/openai.py +222 -0
- memos/llms/openai_new.py +198 -0
- memos/llms/qwen.py +13 -0
- memos/llms/utils.py +14 -0
- memos/llms/vllm.py +218 -0
- memos/log.py +237 -0
- memos/mem_agent/base.py +19 -0
- memos/mem_agent/deepsearch_agent.py +391 -0
- memos/mem_agent/factory.py +36 -0
- memos/mem_chat/__init__.py +0 -0
- memos/mem_chat/base.py +30 -0
- memos/mem_chat/factory.py +21 -0
- memos/mem_chat/simple.py +200 -0
- memos/mem_cube/__init__.py +0 -0
- memos/mem_cube/base.py +30 -0
- memos/mem_cube/general.py +240 -0
- memos/mem_cube/navie.py +172 -0
- memos/mem_cube/utils.py +169 -0
- memos/mem_feedback/base.py +15 -0
- memos/mem_feedback/feedback.py +1192 -0
- memos/mem_feedback/simple_feedback.py +40 -0
- memos/mem_feedback/utils.py +230 -0
- memos/mem_os/client.py +5 -0
- memos/mem_os/core.py +1203 -0
- memos/mem_os/main.py +582 -0
- memos/mem_os/product.py +1608 -0
- memos/mem_os/product_server.py +455 -0
- memos/mem_os/utils/default_config.py +359 -0
- memos/mem_os/utils/format_utils.py +1403 -0
- memos/mem_os/utils/reference_utils.py +162 -0
- memos/mem_reader/__init__.py +0 -0
- memos/mem_reader/base.py +47 -0
- memos/mem_reader/factory.py +53 -0
- memos/mem_reader/memory.py +298 -0
- memos/mem_reader/multi_modal_struct.py +965 -0
- memos/mem_reader/read_multi_modal/__init__.py +43 -0
- memos/mem_reader/read_multi_modal/assistant_parser.py +311 -0
- memos/mem_reader/read_multi_modal/base.py +273 -0
- memos/mem_reader/read_multi_modal/file_content_parser.py +826 -0
- memos/mem_reader/read_multi_modal/image_parser.py +359 -0
- memos/mem_reader/read_multi_modal/multi_modal_parser.py +252 -0
- memos/mem_reader/read_multi_modal/string_parser.py +139 -0
- memos/mem_reader/read_multi_modal/system_parser.py +327 -0
- memos/mem_reader/read_multi_modal/text_content_parser.py +131 -0
- memos/mem_reader/read_multi_modal/tool_parser.py +210 -0
- memos/mem_reader/read_multi_modal/user_parser.py +218 -0
- memos/mem_reader/read_multi_modal/utils.py +358 -0
- memos/mem_reader/simple_struct.py +912 -0
- memos/mem_reader/strategy_struct.py +163 -0
- memos/mem_reader/utils.py +157 -0
- memos/mem_scheduler/__init__.py +0 -0
- memos/mem_scheduler/analyzer/__init__.py +0 -0
- memos/mem_scheduler/analyzer/api_analyzer.py +714 -0
- memos/mem_scheduler/analyzer/eval_analyzer.py +219 -0
- memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +571 -0
- memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
- memos/mem_scheduler/base_scheduler.py +1319 -0
- memos/mem_scheduler/general_modules/__init__.py +0 -0
- memos/mem_scheduler/general_modules/api_misc.py +137 -0
- memos/mem_scheduler/general_modules/base.py +80 -0
- memos/mem_scheduler/general_modules/init_components_for_scheduler.py +425 -0
- memos/mem_scheduler/general_modules/misc.py +313 -0
- memos/mem_scheduler/general_modules/scheduler_logger.py +389 -0
- memos/mem_scheduler/general_modules/task_threads.py +315 -0
- memos/mem_scheduler/general_scheduler.py +1495 -0
- memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
- memos/mem_scheduler/memory_manage_modules/memory_filter.py +306 -0
- memos/mem_scheduler/memory_manage_modules/retriever.py +547 -0
- memos/mem_scheduler/monitors/__init__.py +0 -0
- memos/mem_scheduler/monitors/dispatcher_monitor.py +366 -0
- memos/mem_scheduler/monitors/general_monitor.py +394 -0
- memos/mem_scheduler/monitors/task_schedule_monitor.py +254 -0
- memos/mem_scheduler/optimized_scheduler.py +410 -0
- memos/mem_scheduler/orm_modules/__init__.py +0 -0
- memos/mem_scheduler/orm_modules/api_redis_model.py +518 -0
- memos/mem_scheduler/orm_modules/base_model.py +729 -0
- memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
- memos/mem_scheduler/orm_modules/redis_model.py +699 -0
- memos/mem_scheduler/scheduler_factory.py +23 -0
- memos/mem_scheduler/schemas/__init__.py +0 -0
- memos/mem_scheduler/schemas/analyzer_schemas.py +52 -0
- memos/mem_scheduler/schemas/api_schemas.py +233 -0
- memos/mem_scheduler/schemas/general_schemas.py +55 -0
- memos/mem_scheduler/schemas/message_schemas.py +173 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +406 -0
- memos/mem_scheduler/schemas/task_schemas.py +132 -0
- memos/mem_scheduler/task_schedule_modules/__init__.py +0 -0
- memos/mem_scheduler/task_schedule_modules/dispatcher.py +740 -0
- memos/mem_scheduler/task_schedule_modules/local_queue.py +247 -0
- memos/mem_scheduler/task_schedule_modules/orchestrator.py +74 -0
- memos/mem_scheduler/task_schedule_modules/redis_queue.py +1385 -0
- memos/mem_scheduler/task_schedule_modules/task_queue.py +162 -0
- memos/mem_scheduler/utils/__init__.py +0 -0
- memos/mem_scheduler/utils/api_utils.py +77 -0
- memos/mem_scheduler/utils/config_utils.py +100 -0
- memos/mem_scheduler/utils/db_utils.py +50 -0
- memos/mem_scheduler/utils/filter_utils.py +176 -0
- memos/mem_scheduler/utils/metrics.py +125 -0
- memos/mem_scheduler/utils/misc_utils.py +290 -0
- memos/mem_scheduler/utils/monitor_event_utils.py +67 -0
- memos/mem_scheduler/utils/status_tracker.py +229 -0
- memos/mem_scheduler/webservice_modules/__init__.py +0 -0
- memos/mem_scheduler/webservice_modules/rabbitmq_service.py +485 -0
- memos/mem_scheduler/webservice_modules/redis_service.py +380 -0
- memos/mem_user/factory.py +94 -0
- memos/mem_user/mysql_persistent_user_manager.py +271 -0
- memos/mem_user/mysql_user_manager.py +502 -0
- memos/mem_user/persistent_factory.py +98 -0
- memos/mem_user/persistent_user_manager.py +260 -0
- memos/mem_user/redis_persistent_user_manager.py +225 -0
- memos/mem_user/user_manager.py +488 -0
- memos/memories/__init__.py +0 -0
- memos/memories/activation/__init__.py +0 -0
- memos/memories/activation/base.py +42 -0
- memos/memories/activation/item.py +56 -0
- memos/memories/activation/kv.py +292 -0
- memos/memories/activation/vllmkv.py +219 -0
- memos/memories/base.py +19 -0
- memos/memories/factory.py +42 -0
- memos/memories/parametric/__init__.py +0 -0
- memos/memories/parametric/base.py +19 -0
- memos/memories/parametric/item.py +11 -0
- memos/memories/parametric/lora.py +41 -0
- memos/memories/textual/__init__.py +0 -0
- memos/memories/textual/base.py +92 -0
- memos/memories/textual/general.py +236 -0
- memos/memories/textual/item.py +304 -0
- memos/memories/textual/naive.py +187 -0
- memos/memories/textual/prefer_text_memory/__init__.py +0 -0
- memos/memories/textual/prefer_text_memory/adder.py +504 -0
- memos/memories/textual/prefer_text_memory/config.py +106 -0
- memos/memories/textual/prefer_text_memory/extractor.py +221 -0
- memos/memories/textual/prefer_text_memory/factory.py +85 -0
- memos/memories/textual/prefer_text_memory/retrievers.py +177 -0
- memos/memories/textual/prefer_text_memory/spliter.py +132 -0
- memos/memories/textual/prefer_text_memory/utils.py +93 -0
- memos/memories/textual/preference.py +344 -0
- memos/memories/textual/simple_preference.py +161 -0
- memos/memories/textual/simple_tree.py +69 -0
- memos/memories/textual/tree.py +459 -0
- memos/memories/textual/tree_text_memory/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/organize/handler.py +184 -0
- memos/memories/textual/tree_text_memory/organize/manager.py +518 -0
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +238 -0
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +622 -0
- memos/memories/textual/tree_text_memory/retrieve/__init__.py +0 -0
- memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +364 -0
- memos/memories/textual/tree_text_memory/retrieve/bm25_util.py +186 -0
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +419 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +270 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +102 -0
- memos/memories/textual/tree_text_memory/retrieve/reasoner.py +61 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +497 -0
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +111 -0
- memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +16 -0
- memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +472 -0
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +848 -0
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +135 -0
- memos/memories/textual/tree_text_memory/retrieve/utils.py +54 -0
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +387 -0
- memos/memos_tools/dinding_report_bot.py +453 -0
- memos/memos_tools/lockfree_dict.py +120 -0
- memos/memos_tools/notification_service.py +44 -0
- memos/memos_tools/notification_utils.py +142 -0
- memos/memos_tools/singleton.py +174 -0
- memos/memos_tools/thread_safe_dict.py +310 -0
- memos/memos_tools/thread_safe_dict_segment.py +382 -0
- memos/multi_mem_cube/__init__.py +0 -0
- memos/multi_mem_cube/composite_cube.py +86 -0
- memos/multi_mem_cube/single_cube.py +874 -0
- memos/multi_mem_cube/views.py +54 -0
- memos/parsers/__init__.py +0 -0
- memos/parsers/base.py +15 -0
- memos/parsers/factory.py +21 -0
- memos/parsers/markitdown.py +28 -0
- memos/reranker/__init__.py +4 -0
- memos/reranker/base.py +25 -0
- memos/reranker/concat.py +103 -0
- memos/reranker/cosine_local.py +102 -0
- memos/reranker/factory.py +72 -0
- memos/reranker/http_bge.py +324 -0
- memos/reranker/http_bge_strategy.py +327 -0
- memos/reranker/noop.py +19 -0
- memos/reranker/strategies/__init__.py +4 -0
- memos/reranker/strategies/base.py +61 -0
- memos/reranker/strategies/concat_background.py +94 -0
- memos/reranker/strategies/concat_docsource.py +110 -0
- memos/reranker/strategies/dialogue_common.py +109 -0
- memos/reranker/strategies/factory.py +31 -0
- memos/reranker/strategies/single_turn.py +107 -0
- memos/reranker/strategies/singleturn_outmem.py +98 -0
- memos/settings.py +10 -0
- memos/templates/__init__.py +0 -0
- memos/templates/advanced_search_prompts.py +211 -0
- memos/templates/cloud_service_prompt.py +107 -0
- memos/templates/instruction_completion.py +66 -0
- memos/templates/mem_agent_prompts.py +85 -0
- memos/templates/mem_feedback_prompts.py +822 -0
- memos/templates/mem_reader_prompts.py +1096 -0
- memos/templates/mem_reader_strategy_prompts.py +238 -0
- memos/templates/mem_scheduler_prompts.py +626 -0
- memos/templates/mem_search_prompts.py +93 -0
- memos/templates/mos_prompts.py +403 -0
- memos/templates/prefer_complete_prompt.py +735 -0
- memos/templates/tool_mem_prompts.py +139 -0
- memos/templates/tree_reorganize_prompts.py +230 -0
- memos/types/__init__.py +34 -0
- memos/types/general_types.py +151 -0
- memos/types/openai_chat_completion_types/__init__.py +15 -0
- memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py +56 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py +27 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py +23 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_param.py +43 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py +16 -0
- memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py +16 -0
- memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py +27 -0
- memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py +32 -0
- memos/types/openai_chat_completion_types/chat_completion_message_param.py +18 -0
- memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py +15 -0
- memos/types/openai_chat_completion_types/chat_completion_system_message_param.py +36 -0
- memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py +30 -0
- memos/types/openai_chat_completion_types/chat_completion_user_message_param.py +34 -0
- memos/utils.py +123 -0
- memos/vec_dbs/__init__.py +0 -0
- memos/vec_dbs/base.py +117 -0
- memos/vec_dbs/factory.py +23 -0
- memos/vec_dbs/item.py +50 -0
- memos/vec_dbs/milvus.py +654 -0
- memos/vec_dbs/qdrant.py +355 -0
memos/graph_dbs/neo4j.py
ADDED
@@ -0,0 +1,1942 @@
import json
import time

from datetime import datetime
from typing import Any, Literal

from memos.configs.graph_db import Neo4jGraphDBConfig
from memos.dependency import require_python_package
from memos.graph_dbs.base import BaseGraphDB
from memos.log import get_logger


logger = get_logger(__name__)


def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
    node_id = item["id"]
    memory = item["memory"]
    metadata = item.get("metadata", {})
    return node_id, memory, metadata


def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
    """
    Ensure metadata has proper datetime fields and normalized types.

    - Fill `created_at` and `updated_at` if missing (in ISO 8601 format).
    - Convert embedding to list of float if present.
    """
    now = datetime.utcnow().isoformat()

    # Fill timestamps if missing
    metadata.setdefault("created_at", now)
    metadata.setdefault("updated_at", now)

    # Normalize embedding type
    embedding = metadata.get("embedding")
    if embedding and isinstance(embedding, list):
        metadata["embedding"] = [float(x) for x in embedding]

    # Serialize sources entries to JSON strings (use .get to avoid a KeyError
    # when no "sources" key is present)
    if metadata.get("sources"):
        for idx in range(len(metadata["sources"])):
            metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
    return metadata


def _flatten_info_fields(metadata: dict[str, Any]) -> dict[str, Any]:
    """
    Flatten the 'info' field in metadata to the top level.

    If metadata contains an 'info' field that is a dictionary, all its key-value pairs
    will be moved to the top level of metadata, and the 'info' field will be removed.

    Args:
        metadata: Dictionary that may contain an 'info' field

    Returns:
        Dictionary with 'info' fields flattened to top level

    Example:
        Input: {"user_id": "xxx", "info": {"A": "value1", "B": "value2"}}
        Output: {"user_id": "xxx", "A": "value1", "B": "value2"}
    """
    if "info" in metadata and isinstance(metadata["info"], dict):
        # Copy info fields to top level
        info_dict = metadata.pop("info")
        for key, value in info_dict.items():
            # Only add if key doesn't already exist at top level (to avoid overwriting)
            if key not in metadata:
                metadata[key] = value
    return metadata
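
# A minimal usage sketch of the two helpers above (hypothetical input values,
# mirroring the docstring example; not part of the released file):
#
#     meta = {
#         "user_id": "u1",
#         "info": {"A": "value1", "B": "value2"},
#         "sources": [{"doc": "d1"}],
#         "embedding": [0.1, 0.2],
#     }
#     meta = _prepare_node_metadata(meta)  # fills created_at/updated_at, casts the
#                                          # embedding to floats, JSON-encodes sources
#     meta = _flatten_info_fields(meta)    # -> {"user_id": "u1", "A": "value1", ...}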


class Neo4jGraphDB(BaseGraphDB):
    """Neo4j-based implementation of a graph memory store."""

    @require_python_package(
        import_name="neo4j",
        install_command="pip install neo4j",
        install_link="https://neo4j.com/docs/python-manual/current/install/",
    )
    def __init__(self, config: Neo4jGraphDBConfig):
        """Neo4j-based implementation of a graph memory store.

        Tenant Modes:
        - use_multi_db = True:
            Dedicated Database Mode (Multi-Database Multi-Tenant).
            Each tenant or logical scope uses a separate Neo4j database.
            `db_name` is the specific tenant database.
            `user_name` can be None (optional).

        - use_multi_db = False:
            Shared Database Multi-Tenant Mode.
            All tenants share a single Neo4j database.
            `db_name` is the shared database.
            `user_name` is required to isolate each tenant's data at the node level.
            All node queries will enforce `user_name` in WHERE conditions and store it
            in metadata, but it will be removed automatically before returning to
            external consumers.
        """
        from neo4j import GraphDatabase

        self.config = config
        self.driver = GraphDatabase.driver(config.uri, auth=(config.user, config.password))
        self.db_name = config.db_name
        self.user_name = config.user_name

        self.system_db_name = "system" if config.use_multi_db else config.db_name
        if config.auto_create:
            self._ensure_database_exists()

        # Create only if not exists
        self.create_index(dimensions=config.embedding_dimension)
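
    # Instantiation sketch for the shared-database tenant mode described in the
    # docstring above. The connection values are hypothetical; the config fields
    # shown are the ones read in __init__ and create_index, and passing them as
    # constructor keywords is an assumption about Neo4jGraphDBConfig:
    #
    #     config = Neo4jGraphDBConfig(
    #         uri="bolt://localhost:7687",
    #         user="neo4j",
    #         password="secret",
    #         db_name="memos",
    #         user_name="tenant_a",   # required when use_multi_db=False
    #         use_multi_db=False,
    #         auto_create=True,
    #         embedding_dimension=1536,
    #     )
    #     graph = Neo4jGraphDB(config)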

    def create_index(
        self,
        label: str = "Memory",
        vector_property: str = "embedding",
        dimensions: int = 1536,
        index_name: str = "memory_vector_index",
    ) -> None:
        """
        Create the vector index for the embedding property, plus datetime indexes
        for the created_at and updated_at fields.
        """
        # Create vector index if it doesn't exist
        if not self._vector_index_exists(index_name):
            self._create_vector_index(label, vector_property, dimensions, index_name)
        # Create basic property indexes
        self._create_basic_property_indexes()

    def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int:
        user_name = user_name if user_name else self.config.user_name
        query = """
        MATCH (n:Memory)
        WHERE n.memory_type = $memory_type
        """
        if not self.config.use_multi_db and (self.config.user_name or user_name):
            query += "\nAND n.user_name = $user_name"
        query += "\nRETURN COUNT(n) AS count"
        with self.driver.session(database=self.db_name) as session:
            result = session.run(
                query,
                {
                    "memory_type": memory_type,
                    "user_name": user_name,
                },
            )
            return result.single()["count"]

    def node_not_exist(self, scope: str, user_name: str | None = None) -> bool:
        user_name = user_name if user_name else self.config.user_name
        query = """
        MATCH (n:Memory)
        WHERE n.memory_type = $scope
        """
        if not self.config.use_multi_db and (self.config.user_name or user_name):
            query += "\nAND n.user_name = $user_name"
        query += "\nRETURN n LIMIT 1"

        with self.driver.session(database=self.db_name) as session:
            result = session.run(
                query,
                {
                    "scope": scope,
                    "user_name": user_name,
                },
            )
            return result.single() is None

    def remove_oldest_memory(
        self, memory_type: str, keep_latest: int, user_name: str | None = None
    ) -> None:
        """
        Remove all nodes of the given memory type except the latest `keep_latest` entries.

        Args:
            memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
            keep_latest (int): Number of latest entries to keep.
            user_name (str): Optional user name.
        """
        user_name = user_name if user_name else self.config.user_name
        query = f"""
        MATCH (n:Memory)
        WHERE n.memory_type = '{memory_type}'
        """
        if not self.config.use_multi_db and (self.config.user_name or user_name):
            query += f"\nAND n.user_name = '{user_name}'"
        keep_latest = int(keep_latest)
        query += f"""
        WITH n ORDER BY n.updated_at DESC
        SKIP {keep_latest}
        DETACH DELETE n
        """
        with self.driver.session(database=self.db_name) as session:
            session.run(query)
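
    # For illustration: with memory_type='WorkingMemory', keep_latest=20, and a
    # hypothetical tenant 'tenant_a' in shared-database mode, the interpolated
    # Cypher is roughly as follows (note these values are inlined into the query
    # string rather than passed as parameters):
    #
    #     MATCH (n:Memory)
    #     WHERE n.memory_type = 'WorkingMemory'
    #     AND n.user_name = 'tenant_a'
    #     WITH n ORDER BY n.updated_at DESC
    #     SKIP 20
    #     DETACH DELETE n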

    def add_node(
        self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None
    ) -> None:
        logger.info(f"[add_node] metadata: {metadata}, info: {metadata.get('info')}")

        user_name = user_name if user_name else self.config.user_name
        if not self.config.use_multi_db and (self.config.user_name or user_name):
            metadata["user_name"] = user_name

        # Safely process metadata (this already JSON-encodes the "sources" entries,
        # so no second serialization pass is needed here)
        metadata = _prepare_node_metadata(metadata)

        # Flatten info fields to top level (for Neo4j flat structure)
        metadata = _flatten_info_fields(metadata)

        # Initialize delete_time and delete_record_id fields
        metadata.setdefault("delete_time", "")
        metadata.setdefault("delete_record_id", "")

        # Merge node and set metadata
        created_at = metadata.pop("created_at")
        updated_at = metadata.pop("updated_at")

        query = """
        MERGE (n:Memory {id: $id})
        SET n.memory = $memory,
            n.created_at = datetime($created_at),
            n.updated_at = datetime($updated_at),
            n += $metadata
        """

        with self.driver.session(database=self.db_name) as session:
            session.run(
                query,
                id=id,
                memory=memory,
                created_at=created_at,
                updated_at=updated_at,
                metadata=metadata,
            )
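
    # Usage sketch, given an instance `graph` (hypothetical IDs and values):
    #
    #     graph.add_node(
    #         id="mem-001",
    #         memory="User prefers dark mode.",
    #         metadata={
    #             "memory_type": "LongTermMemory",
    #             "status": "activated",
    #             "embedding": [0.1, 0.2, ...],    # must match the index dimension
    #             "sources": [],
    #             "info": {"topic": "ui"},         # flattened onto the node
    #         },
    #     )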

    def add_nodes_batch(
        self,
        nodes: list[dict[str, Any]],
        user_name: str | None = None,
    ) -> None:
        """
        Batch add multiple memory nodes to the graph.

        Args:
            nodes: List of node dictionaries, each containing:
                - id: str - Node ID
                - memory: str - Memory content
                - metadata: dict[str, Any] - Node metadata
            user_name: Optional user name (will use config default if not provided)
        """
        logger.info("neo4j [add_nodes_batch] starting")
        if not nodes:
            logger.warning("[add_nodes_batch] Empty nodes list, skipping")
            return

        logger.info(f"[add_nodes_batch] Adding {len(nodes)} nodes")

        # user_name comes from parameter; fall back to config if missing
        effective_user_name = user_name if user_name else self.config.user_name

        # Prepare all nodes
        prepared_nodes = []
        for node_data in nodes:
            try:
                id = node_data["id"]
                memory = node_data["memory"]
                metadata = node_data.get("metadata", {})

                logger.debug(f"[add_nodes_batch] Processing node id: {id}")

                # Set user_name in metadata if needed
                if not self.config.use_multi_db and (self.config.user_name or effective_user_name):
                    metadata["user_name"] = effective_user_name

                # Safely process metadata (this already JSON-encodes the "sources" entries)
                metadata = _prepare_node_metadata(metadata)

                # Flatten info fields to top level (for Neo4j flat structure)
                metadata = _flatten_info_fields(metadata)

                # Initialize delete_time and delete_record_id fields
                metadata.setdefault("delete_time", "")
                metadata.setdefault("delete_record_id", "")

                # Merge node and set metadata
                created_at = metadata.pop("created_at")
                updated_at = metadata.pop("updated_at")

                prepared_nodes.append(
                    {
                        "id": id,
                        "memory": memory,
                        "created_at": created_at,
                        "updated_at": updated_at,
                        "metadata": metadata,
                    }
                )
            except Exception as e:
                logger.error(
                    f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}",
                    exc_info=True,
                )
                # Continue with other nodes
                continue

        if not prepared_nodes:
            logger.warning("[add_nodes_batch] No valid nodes to insert after preparation")
            return

        # Batch insert using Neo4j UNWIND for better performance
        query = """
        UNWIND $nodes AS node
        MERGE (n:Memory {id: node.id})
        SET n.memory = node.memory,
            n.created_at = datetime(node.created_at),
            n.updated_at = datetime(node.updated_at),
            n += node.metadata
        """

        # Prepare nodes data for UNWIND
        nodes_data = [
            {
                "id": node["id"],
                "memory": node["memory"],
                "created_at": node["created_at"],
                "updated_at": node["updated_at"],
                "metadata": node["metadata"],
            }
            for node in prepared_nodes
        ]

        try:
            with self.driver.session(database=self.db_name) as session:
                session.run(query, nodes=nodes_data)
                logger.info(f"[add_nodes_batch] Successfully inserted {len(prepared_nodes)} nodes")
        except Exception as e:
            logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True)
            raise
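
    # Batch usage sketch (hypothetical payload): a single UNWIND + MERGE round trip
    # replaces N separate add_node() calls:
    #
    #     graph.add_nodes_batch(
    #         [
    #             {"id": "mem-001", "memory": "...", "metadata": {"memory_type": "LongTermMemory"}},
    #             {"id": "mem-002", "memory": "...", "metadata": {"memory_type": "LongTermMemory"}},
    #         ],
    #         user_name="tenant_a",
    #     )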

    def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None:
        """
        Update node fields in Neo4j, auto-converting `created_at` and `updated_at`
        to datetime type if present.
        """
        user_name = user_name if user_name else self.config.user_name
        fields = fields.copy()  # Avoid mutating external dict
        set_clauses = []
        params = {"id": id, "fields": fields}

        for time_field in ("created_at", "updated_at"):
            if time_field in fields:
                # Set clause like: n.created_at = datetime($created_at)
                set_clauses.append(f"n.{time_field} = datetime(${time_field})")
                params[time_field] = fields.pop(time_field)

        set_clauses.append("n += $fields")  # Merge remaining fields
        set_clause_str = ",\n    ".join(set_clauses)

        query = """
        MATCH (n:Memory {id: $id})
        """
        if not self.config.use_multi_db and (self.config.user_name or user_name):
            query += "\nWHERE n.user_name = $user_name"
            params["user_name"] = user_name

        query += f"\nSET {set_clause_str}"

        with self.driver.session(database=self.db_name) as session:
            session.run(query, **params)

    def delete_node(self, id: str, user_name: str | None = None) -> None:
        """
        Delete a node from the graph.

        Args:
            id: Node identifier to delete.
        """
        user_name = user_name if user_name else self.config.user_name
        query = "MATCH (n:Memory {id: $id})"

        params = {"id": id}
        if not self.config.use_multi_db and (self.config.user_name or user_name):
            query += " WHERE n.user_name = $user_name"
            params["user_name"] = user_name

        query += " DETACH DELETE n"

        with self.driver.session(database=self.db_name) as session:
            session.run(query, **params)

    # Edge (Relationship) Management
    def add_edge(
        self, source_id: str, target_id: str, type: str, user_name: str | None = None
    ) -> None:
        """
        Create an edge from source node to target node.

        Args:
            source_id: ID of the source node.
            target_id: ID of the target node.
            type: Relationship type (e.g., 'RELATE_TO', 'PARENT').
        """
        user_name = user_name if user_name else self.config.user_name
        query = """
        MATCH (a:Memory {id: $source_id})
        MATCH (b:Memory {id: $target_id})
        """
        params = {"source_id": source_id, "target_id": target_id}
        if not self.config.use_multi_db and (self.config.user_name or user_name):
            query += """
            WHERE a.user_name = $user_name AND b.user_name = $user_name
            """
            params["user_name"] = user_name

        query += f"\nMERGE (a)-[:{type}]->(b)"

        with self.driver.session(database=self.db_name) as session:
            session.run(query, params)

    def delete_edge(
        self, source_id: str, target_id: str, type: str, user_name: str | None = None
    ) -> None:
        """
        Delete a specific edge between two nodes.

        Args:
            source_id: ID of the source node.
            target_id: ID of the target node.
            type: Relationship type to remove.
        """
        user_name = user_name if user_name else self.config.user_name
        query = f"""
        MATCH (a:Memory {{id: $source}})
        -[r:{type}]->
        (b:Memory {{id: $target}})
        """
        params = {"source": source_id, "target": target_id}

        if not self.config.use_multi_db and (self.config.user_name or user_name):
            query += "\nWHERE a.user_name = $user_name AND b.user_name = $user_name"
            params["user_name"] = user_name

        query += "\nDELETE r"

        with self.driver.session(database=self.db_name) as session:
            session.run(query, params)

    def edge_exists(
        self,
        source_id: str,
        target_id: str,
        type: str = "ANY",
        direction: str = "OUTGOING",
        user_name: str | None = None,
    ) -> bool:
        """
        Check if an edge exists between two nodes.

        Args:
            source_id: ID of the source node.
            target_id: ID of the target node.
            type: Relationship type. Use "ANY" to match any relationship type.
            direction: Direction of the edge.
                Use "OUTGOING" (default), "INCOMING", or "ANY".

        Returns:
            True if the edge exists, otherwise False.
        """
        user_name = user_name if user_name else self.config.user_name
        # Prepare the relationship pattern
        rel = "r" if type == "ANY" else f"r:{type}"

        # Prepare the match pattern with direction
        if direction == "OUTGOING":
            pattern = f"(a:Memory {{id: $source}})-[{rel}]->(b:Memory {{id: $target}})"
        elif direction == "INCOMING":
            pattern = f"(a:Memory {{id: $source}})<-[{rel}]-(b:Memory {{id: $target}})"
        elif direction == "ANY":
            pattern = f"(a:Memory {{id: $source}})-[{rel}]-(b:Memory {{id: $target}})"
        else:
            raise ValueError(
                f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'."
            )
        query = f"MATCH {pattern}"
        params = {"source": source_id, "target": target_id}

        if not self.config.use_multi_db and (self.config.user_name or user_name):
            query += "\nWHERE a.user_name = $user_name AND b.user_name = $user_name"
            params["user_name"] = user_name

        query += "\nRETURN r"

        # Run the Cypher query
        with self.driver.session(database=self.db_name) as session:
            result = session.run(query, params)
            return result.single() is not None
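
    # Edge round-trip sketch (hypothetical IDs), given an instance `graph`:
    #
    #     graph.add_edge("mem-001", "mem-002", type="RELATE_TO")
    #     graph.edge_exists("mem-001", "mem-002", type="RELATE_TO")    # -> True
    #     graph.delete_edge("mem-001", "mem-002", type="RELATE_TO")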

    # Graph Query & Reasoning
    def get_node(self, id: str, **kwargs) -> dict[str, Any] | None:
        """
        Retrieve the metadata and memory of a node.

        Args:
            id: Node identifier.
        Returns:
            Dictionary of node fields, or None if not found.
        """
        user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name
        where_user = ""
        params = {"id": id}
        if not self.config.use_multi_db and (self.config.user_name or user_name):
            where_user = " AND n.user_name = $user_name"
            params["user_name"] = user_name

        query = f"MATCH (n:Memory) WHERE n.id = $id {where_user} RETURN n"

        with self.driver.session(database=self.db_name) as session:
            record = session.run(query, params).single()
            return self._parse_node(dict(record["n"])) if record else None

    def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]:
        """
        Retrieve the metadata and memory of a list of nodes.

        Args:
            ids: List of node identifiers.
        Returns:
            list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.

        Notes:
            - Assumes all provided IDs are valid and exist.
            - Returns an empty list if the input is empty.
        """

        if not ids:
            return []
        user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name
        where_user = ""
        params = {"ids": ids}

        if not self.config.use_multi_db and (self.config.user_name or user_name):
            where_user = " AND n.user_name = $user_name"
            if kwargs.get("cube_name"):
                params["user_name"] = kwargs["cube_name"]
            else:
                params["user_name"] = user_name

        query = f"MATCH (n:Memory) WHERE n.id IN $ids{where_user} RETURN n"

        with self.driver.session(database=self.db_name) as session:
            results = session.run(query, params)
            return [self._parse_node(dict(record["n"])) for record in results]

    def get_edges(
        self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None
    ) -> list[dict[str, str]]:
        """
        Get edges connected to a node, with optional type and direction filter.

        Args:
            id: Node ID to retrieve edges for.
            type: Relationship type to match, or 'ANY' to match all.
            direction: 'OUTGOING', 'INCOMING', or 'ANY'.

        Returns:
            List of edges:
            [
                {"from": "source_id", "to": "target_id", "type": "RELATE"},
                ...
            ]
        """
        user_name = user_name if user_name else self.config.user_name
        # Build relationship type filter
        rel_type = "" if type == "ANY" else f":{type}"

        # Build Cypher pattern based on direction
        if direction == "OUTGOING":
            pattern = f"(a:Memory)-[r{rel_type}]->(b:Memory)"
            where_clause = "a.id = $id"
        elif direction == "INCOMING":
            pattern = f"(a:Memory)<-[r{rel_type}]-(b:Memory)"
            where_clause = "a.id = $id"
        elif direction == "ANY":
            pattern = f"(a:Memory)-[r{rel_type}]-(b:Memory)"
            where_clause = "a.id = $id OR b.id = $id"
        else:
            raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.")

        params = {"id": id}

        if not self.config.use_multi_db and (self.config.user_name or user_name):
            where_clause += " AND a.user_name = $user_name AND b.user_name = $user_name"
            params["user_name"] = user_name

        query = f"""
        MATCH {pattern}
        WHERE {where_clause}
        RETURN a.id AS from_id, b.id AS to_id, type(r) AS type
        """

        with self.driver.session(database=self.db_name) as session:
            result = session.run(query, params)
            edges = []
            for record in result:
                edges.append(
                    {"from": record["from_id"], "to": record["to_id"], "type": record["type"]}
                )
            return edges

    def get_neighbors(
        self,
        id: str,
        type: str,
        direction: Literal["in", "out", "both"] = "out",
        user_name: str | None = None,
    ) -> list[str]:
        """
        Get connected node IDs in a specific direction and relationship type.

        Args:
            id: Source node ID.
            type: Relationship type.
            direction: Edge direction to follow ('out', 'in', or 'both').
        Returns:
            List of neighboring node IDs.
        """
        raise NotImplementedError

    def get_neighbors_by_tag(
        self,
        tags: list[str],
        exclude_ids: list[str],
        top_k: int = 5,
        min_overlap: int = 1,
        user_name: str | None = None,
    ) -> list[dict[str, Any]]:
        """
        Find top-K neighbor nodes with maximum tag overlap.

        Args:
            tags: The list of tags to match.
            exclude_ids: Node IDs to exclude (e.g., local cluster).
            top_k: Max number of neighbors to return.
            min_overlap: Minimum number of overlapping tags required.

        Returns:
            List of dicts with node details and overlap count.
        """
        user_name = user_name if user_name else self.config.user_name
        where_user = ""
        params = {
            "tags": tags,
            "exclude_ids": exclude_ids,
            "min_overlap": min_overlap,
            "top_k": top_k,
        }

        if not self.config.use_multi_db and (self.config.user_name or user_name):
            where_user = "AND n.user_name = $user_name"
            params["user_name"] = user_name

        query = f"""
        MATCH (n:Memory)
        WHERE NOT n.id IN $exclude_ids
        AND n.status = 'activated'
        AND n.type <> 'reasoning'
        AND n.memory_type <> 'WorkingMemory'
        {where_user}
        WITH n, [tag IN n.tags WHERE tag IN $tags] AS overlap_tags
        WHERE size(overlap_tags) >= $min_overlap
        RETURN n, size(overlap_tags) AS overlap_count
        ORDER BY overlap_count DESC
        LIMIT $top_k
        """

        with self.driver.session(database=self.db_name) as session:
            result = session.run(query, params)
            return [self._parse_node(dict(record["n"])) for record in result]

    def get_children_with_embeddings(
        self, id: str, user_name: str | None = None
    ) -> list[dict[str, Any]]:
        user_name = user_name if user_name else self.config.user_name
        where_user = ""
        params = {"id": id}

        if not self.config.use_multi_db and (self.config.user_name or user_name):
            where_user = "AND p.user_name = $user_name AND c.user_name = $user_name"
            params["user_name"] = user_name

        query = f"""
        MATCH (p:Memory)-[:PARENT]->(c:Memory)
        WHERE p.id = $id {where_user}
        RETURN c.id AS id, c.embedding AS embedding, c.memory AS memory
        """

        with self.driver.session(database=self.db_name) as session:
            result = session.run(query, params)
            return [
                {"id": r["id"], "embedding": r["embedding"], "memory": r["memory"]} for r in result
            ]

    def get_path(
        self, source_id: str, target_id: str, max_depth: int = 3, user_name: str | None = None
    ) -> list[str]:
        """
        Get the path of nodes from source to target within a limited depth.

        Args:
            source_id: Starting node ID.
            target_id: Target node ID.
            max_depth: Maximum path length to traverse.
        Returns:
            Ordered list of node IDs along the path.
        """
        raise NotImplementedError

    def get_subgraph(
        self,
        center_id: str,
        depth: int = 2,
        center_status: str = "activated",
        user_name: str | None = None,
    ) -> dict[str, Any]:
        """
        Retrieve a local subgraph centered at a given node.

        Args:
            center_id: The ID of the center node.
            depth: The hop distance for neighbors.
            center_status: Required status for center node.
        Returns:
            {
                "core_node": {...},
                "neighbors": [...],
                "edges": [...]
            }
        """
        user_name = user_name if user_name else self.config.user_name
        with self.driver.session(database=self.db_name) as session:
            params = {"center_id": center_id}
            center_user_clause = ""
            neighbor_user_clause = ""

            if not self.config.use_multi_db and (self.config.user_name or user_name):
                center_user_clause = " AND center.user_name = $user_name"
                neighbor_user_clause = " WHERE neighbor.user_name = $user_name"
                params["user_name"] = user_name
            status_clause = f" AND center.status = '{center_status}'" if center_status else ""

            query = f"""
            MATCH (center:Memory)
            WHERE center.id = $center_id{status_clause}{center_user_clause}

            OPTIONAL MATCH (center)-[r*1..{depth}]-(neighbor:Memory)
            {neighbor_user_clause}

            WITH collect(DISTINCT center) AS centers,
                 collect(DISTINCT neighbor) AS neighbors,
                 collect(DISTINCT r) AS rels
            RETURN centers, neighbors, rels
            """
            record = session.run(query, params).single()

            if not record:
                return {"core_node": None, "neighbors": [], "edges": []}

            centers = record["centers"]
            if not centers or centers[0] is None:
                return {"core_node": None, "neighbors": [], "edges": []}

            core_node = self._parse_node(dict(centers[0]))
            neighbors = [self._parse_node(dict(n)) for n in record["neighbors"] if n]
            edges = []
            for rel_chain in record["rels"]:
                for rel in rel_chain:
                    edges.append(
                        {
                            "type": rel.type,
                            "source": rel.start_node["id"],
                            "target": rel.end_node["id"],
                        }
                    )

            return {"core_node": core_node, "neighbors": neighbors, "edges": edges}

    def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
        """
        Get the ordered context chain starting from a node, following a relationship type.

        Args:
            id: Starting node ID.
            type: Relationship type to follow (e.g., 'FOLLOWS').
        Returns:
            List of ordered node IDs in the chain.
        """
        raise NotImplementedError

    # Search / recall operations
    def search_by_embedding(
        self,
        vector: list[float],
        top_k: int = 5,
        scope: str | None = None,
        status: str | None = None,
        threshold: float | None = None,
        search_filter: dict | None = None,
        user_name: str | None = None,
        filter: dict | None = None,
        knowledgebase_ids: list[str] | None = None,
        **kwargs,
    ) -> list[dict]:
        """
        Retrieve node IDs based on vector similarity.

        Args:
            vector (list[float]): The embedding vector representing query semantics.
            top_k (int): Number of top similar nodes to retrieve.
            scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
            status (str, optional): Node status filter (e.g., 'activated', 'archived').
                If provided, restricts results to nodes with matching status.
            threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
            search_filter (dict, optional): Additional metadata filters for search results.
                Keys should match node properties, values are the expected values.

        Returns:
            list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.

        Notes:
            - This method uses Neo4j native vector indexing to search for similar nodes.
            - If scope is provided, it restricts results to nodes with matching memory_type.
            - If 'status' is provided, only nodes with the matching status will be returned.
            - If threshold is provided, only results with score >= threshold will be returned.
            - If search_filter is provided, additional WHERE clauses will be added for
              metadata filtering.
            - Typical use case: restrict to 'status = activated' to avoid
              matching archived or merged nodes.
        """
        user_name = user_name if user_name else self.config.user_name
        # Build WHERE clause dynamically
        where_clauses = []
        if scope:
            where_clauses.append("node.memory_type = $scope")
        if status:
            where_clauses.append("node.status = $status")

        # Build user_name filter with knowledgebase_ids support (OR relationship)
        # using the common method
        user_name_conditions, user_name_params = self._build_user_name_and_kb_ids_conditions_cypher(
            user_name=user_name,
            knowledgebase_ids=knowledgebase_ids,
            default_user_name=self.config.user_name,
            node_alias="node",
        )

        # Add user_name WHERE clause
        if user_name_conditions:
            if len(user_name_conditions) == 1:
                where_clauses.append(user_name_conditions[0])
            else:
                where_clauses.append(f"({' OR '.join(user_name_conditions)})")

        # Add search_filter conditions
        if search_filter:
            for key, _ in search_filter.items():
                param_name = f"filter_{key}"
                where_clauses.append(f"node.{key} = ${param_name}")

        # Build filter conditions using the common method
        filter_conditions, filter_params = self._build_filter_conditions_cypher(
            filter=filter,
            param_counter_start=0,
            node_alias="node",
        )
        where_clauses.extend(filter_conditions)

        where_clause = ""
        if where_clauses:
            where_clause = "WHERE " + " AND ".join(where_clauses)

        query = f"""
        CALL db.index.vector.queryNodes('memory_vector_index', $k, $embedding)
        YIELD node, score
        {where_clause}
        RETURN node.id AS id, score
        """

        parameters = {"embedding": vector, "k": top_k}

        if scope:
            parameters["scope"] = scope
        if status:
            parameters["status"] = status

        # Add user_name and knowledgebase_ids parameters using the common method
        parameters.update(user_name_params)

        # Handle cube_name override for user_name
        if kwargs.get("cube_name"):
            parameters["user_name"] = kwargs["cube_name"]

        if search_filter:
            for key, value in search_filter.items():
                param_name = f"filter_{key}"
                parameters[param_name] = value

        # Add filter parameters
        if filter_params:
            parameters.update(filter_params)

        logger.info(f"[search_by_embedding] query: {query}, parameters: {parameters}")
        print(f"[search_by_embedding] query: {query}, parameters: {parameters}")
        with self.driver.session(database=self.db_name) as session:
            result = session.run(query, parameters)
            records = [{"id": record["id"], "score": record["score"]} for record in result]

        # Threshold filtering after retrieval
        if threshold is not None:
            records = [r for r in records if r["score"] >= threshold]

        return records
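
    # Search usage sketch (hypothetical query vector; its dimensionality must match
    # the vector index, 1536 by default):
    #
    #     hits = graph.search_by_embedding(
    #         vector=query_embedding,
    #         top_k=5,
    #         scope="LongTermMemory",
    #         status="activated",
    #         threshold=0.75,
    #     )
    #     # -> [{"id": "mem-001", "score": 0.93}, ...]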

    def get_by_metadata(
        self,
        filters: list[dict[str, Any]],
        user_name: str | None = None,
        filter: dict | None = None,
        knowledgebase_ids: list[str] | None = None,
        user_name_flag: bool = True,
        status: str | None = None,
    ) -> list[str]:
        """
        TODO:
        1. Add logic: "AND" vs "OR" (support logic combination);
        2. Support nested conditional expressions;

        Retrieve node IDs that match given metadata filters.
        Supports exact match.

        Args:
            filters: List of filter dicts like:
                [
                    {"field": "key", "op": "in", "value": ["A", "B"]},
                    {"field": "confidence", "op": ">=", "value": 80},
                    {"field": "tags", "op": "contains", "value": "AI"},
                    ...
                ]
            status (str, optional): Filter by status (e.g., 'activated', 'archived').
                If None, no status filter is applied.

        Returns:
            list[str]: Node IDs whose metadata match the filter conditions (AND logic).

        Notes:
            - Supports structured querying such as tag/category/importance/time filtering.
            - Can be used for faceted recall or prefiltering before embedding rerank.
        """
        logger.info(
            f"[get_by_metadata] filters: {filters}, user_name: {user_name}, filter: {filter}, "
            f"knowledgebase_ids: {knowledgebase_ids}, status: {status}"
        )
        print(
            f"[get_by_metadata] filters: {filters}, user_name: {user_name}, filter: {filter}, "
            f"knowledgebase_ids: {knowledgebase_ids}, status: {status}"
        )
        user_name = user_name if user_name else self.config.user_name
        where_clauses = []
        params = {}

        # Add status filter if provided
        if status:
            where_clauses.append("n.status = $status")
            params["status"] = status

        for i, f in enumerate(filters):
            field = f["field"]
            op = f.get("op", "=")
            value = f["value"]
            param_key = f"val{i}"

            # Build WHERE clause
            if op == "=":
                where_clauses.append(f"n.{field} = ${param_key}")
                params[param_key] = value
            elif op == "in":
                where_clauses.append(f"n.{field} IN ${param_key}")
                params[param_key] = value
            elif op == "contains":
                where_clauses.append(f"ANY(x IN ${param_key} WHERE x IN n.{field})")
                params[param_key] = value
            elif op == "starts_with":
                where_clauses.append(f"n.{field} STARTS WITH ${param_key}")
                params[param_key] = value
            elif op == "ends_with":
                where_clauses.append(f"n.{field} ENDS WITH ${param_key}")
                params[param_key] = value
            elif op in [">", ">=", "<", "<="]:
                where_clauses.append(f"n.{field} {op} ${param_key}")
                params[param_key] = value
            else:
                raise ValueError(f"Unsupported operator: {op}")

        # Build user_name filter with knowledgebase_ids support (OR relationship)
        # using the common method
        user_name_conditions = []
        user_name_params = {}
        if user_name_flag:
            user_name_conditions, user_name_params = (
                self._build_user_name_and_kb_ids_conditions_cypher(
                    user_name=user_name,
                    knowledgebase_ids=knowledgebase_ids,
                    default_user_name=self.config.user_name,
                    node_alias="n",
                )
            )
            print(
                f"[get_by_metadata] user_name_conditions: {user_name_conditions}, "
                f"user_name_params: {user_name_params}"
            )

        # Add user_name WHERE clause
        if user_name_conditions:
            if len(user_name_conditions) == 1:
                where_clauses.append(user_name_conditions[0])
            else:
                where_clauses.append(f"({' OR '.join(user_name_conditions)})")

        # Build filter conditions using the common method
        filter_conditions, filter_params = self._build_filter_conditions_cypher(
            filter=filter,
            param_counter_start=len(filters),  # Start from len(filters) to avoid conflicts
            node_alias="n",
        )
        where_clauses.extend(filter_conditions)

        where_str = " AND ".join(where_clauses) if where_clauses else ""
        if where_str:
            query = f"MATCH (n:Memory) WHERE {where_str} RETURN n.id AS id"
        else:
            query = "MATCH (n:Memory) RETURN n.id AS id"

        # Add user_name and knowledgebase_ids parameters using the common method
        params.update(user_name_params)

        # Merge filter parameters
        if filter_params:
            params.update(filter_params)
        logger.info(f"[get_by_metadata] query: {query}, params: {params}")
        print(f"[get_by_metadata] query: {query}, params: {params}")

        with self.driver.session(database=self.db_name) as session:
            result = session.run(query, params)
            return [record["id"] for record in result]
|
|
1048
|
+
|
|
1049
|
+
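    # Usage sketch for get_by_metadata (illustrative; `db` is an assumed instance of
    # this store): the filters below compile into a parameterized query along the
    # lines of
    #   MATCH (n:Memory) WHERE n.status = $status AND n.key IN $val0
    #   AND n.confidence >= $val1 ... RETURN n.id AS id
    #
    #     ids = db.get_by_metadata(
    #         filters=[
    #             {"field": "key", "op": "in", "value": ["A", "B"]},
    #             {"field": "confidence", "op": ">=", "value": 80},
    #         ],
    #         status="activated",
    #     )
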
    def get_grouped_counts(
        self,
        group_fields: list[str],
        where_clause: str = "",
        params: dict[str, Any] | None = None,
        user_name: str | None = None,
    ) -> list[dict[str, Any]]:
        """
        Count nodes grouped by any fields.

        Args:
            group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"]
            where_clause (str, optional): Extra WHERE condition, e.g.,
                "WHERE n.status = 'activated'"
            params (dict, optional): Parameters for the WHERE clause.

        Returns:
            list[dict]: e.g., [{'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10}, ...]
        """
        user_name = user_name if user_name else self.config.user_name
        if not group_fields:
            raise ValueError("group_fields cannot be empty")

        final_params = params.copy() if params else {}

        if not self.config.use_multi_db and (self.config.user_name or user_name):
            user_clause = "n.user_name = $user_name"
            final_params["user_name"] = user_name
            if where_clause:
                where_clause = where_clause.strip()
                if where_clause.upper().startswith("WHERE"):
                    where_clause += f" AND {user_clause}"
                else:
                    where_clause = f"WHERE {where_clause} AND {user_clause}"
            else:
                where_clause = f"WHERE {user_clause}"

        # Force RETURN field AS field to guarantee key match
        group_fields_cypher = ", ".join([f"n.{field} AS {field}" for field in group_fields])

        query = f"""
        MATCH (n:Memory)
        {where_clause}
        RETURN {group_fields_cypher}, COUNT(n) AS count
        """

        with self.driver.session(database=self.db_name) as session:
            result = session.run(query, final_params)
            return [
                {**{field: record[field] for field in group_fields}, "count": record["count"]}
                for record in result
            ]

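    # Usage sketch for get_grouped_counts (illustrative; `db` is an assumed instance):
    #
    #     rows = db.get_grouped_counts(
    #         group_fields=["memory_type", "status"],
    #         where_clause="WHERE n.status = 'activated'",
    #     )
    #     # -> e.g. [{"memory_type": "LongTermMemory", "status": "activated", "count": 42}, ...]
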
    # Structure Maintenance
    def deduplicate_nodes(self) -> None:
        """
        Deduplicate redundant or semantically similar nodes.
        This typically involves identifying nodes with identical or near-identical memory.
        """
        raise NotImplementedError

    def detect_conflicts(self) -> list[tuple[str, str]]:
        """
        Detect conflicting nodes based on logical or semantic inconsistency.

        Returns:
            A list of (node_id1, node_id2) tuples that conflict.
        """
        raise NotImplementedError

    def merge_nodes(self, id1: str, id2: str) -> str:
        """
        Merge two similar or duplicate nodes into one.

        Args:
            id1: First node ID.
            id2: Second node ID.

        Returns:
            ID of the resulting merged node.
        """
        raise NotImplementedError

    # Utilities
    def clear(self, user_name: str | None = None) -> None:
        """
        Clear the entire graph if the target database exists.
        """
        user_name = user_name if user_name else self.config.user_name
        try:
            if not self.config.use_multi_db and (self.config.user_name or user_name):
                query = "MATCH (n:Memory) WHERE n.user_name = $user_name DETACH DELETE n"
                params = {"user_name": user_name}
            else:
                query = "MATCH (n) DETACH DELETE n"
                params = {}

            # Clear the graph in the target database
            with self.driver.session(database=self.db_name) as session:
                session.run(query, params)
                logger.info(f"Cleared all nodes from database '{self.db_name}'.")

        except Exception as e:
            logger.error(f"[ERROR] Failed to clear database '{self.db_name}': {e}")
            raise

    def export_graph(
        self,
        page: int | None = None,
        page_size: int | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Export all graph nodes and edges in a structured form.

        Args:
            page (int, optional): Page number (starts from 1). If None, exports all data without pagination.
            page_size (int, optional): Number of items per page. If None, exports all data without pagination.
            **kwargs: Additional keyword arguments, including:
                - user_name (str, optional): User name for filtering in non-multi-db mode

        Returns:
            {
                "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
                "edges": [ { "source": ..., "target": ..., "type": ... }, ... ],
                "total_nodes": int,  # Total number of nodes matching the filter criteria
                "total_edges": int,  # Total number of edges matching the filter criteria
            }
        """
        user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name

        # Initialize total counts
        total_nodes = 0
        total_edges = 0

        # Determine if pagination is needed
        use_pagination = page is not None and page_size is not None

        # Validate pagination parameters if pagination is enabled
        if use_pagination:
            if page < 1:
                page = 1
            if page_size < 1:
                page_size = 10
            skip = (page - 1) * page_size

        with self.driver.session(database=self.db_name) as session:
            # Build base queries
            node_base_query = "MATCH (n:Memory)"
            edge_base_query = "MATCH (a:Memory)-[r]->(b:Memory)"
            params = {}

            if not self.config.use_multi_db and (self.config.user_name or user_name):
                node_base_query += " WHERE n.user_name = $user_name"
                edge_base_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name"
                params["user_name"] = user_name

            # Get total count of nodes before pagination
            count_node_query = node_base_query + " RETURN COUNT(n) AS count"
            count_node_result = session.run(count_node_query, params)
            total_nodes = count_node_result.single()["count"]

            # Export nodes with ORDER BY created_at DESC
            node_query = node_base_query + " RETURN n ORDER BY n.created_at DESC, n.id DESC"
            if use_pagination:
                node_query += f" SKIP {skip} LIMIT {page_size}"

            node_result = session.run(node_query, params)
            nodes = [self._parse_node(dict(record["n"])) for record in node_result]

            # Get total count of edges before pagination
            count_edge_query = edge_base_query + " RETURN COUNT(r) AS count"
            count_edge_result = session.run(count_edge_query, params)
            total_edges = count_edge_result.single()["count"]

            # Export edges with ORDER BY created_at DESC
            edge_query = (
                edge_base_query
                + " RETURN a.id AS source, b.id AS target, type(r) AS type ORDER BY a.created_at DESC, b.created_at DESC, a.id DESC, b.id DESC"
            )
            if use_pagination:
                edge_query += f" SKIP {skip} LIMIT {page_size}"

            edge_result = session.run(edge_query, params)
            edges = [
                {"source": record["source"], "target": record["target"], "type": record["type"]}
                for record in edge_result
            ]

            return {
                "nodes": nodes,
                "edges": edges,
                "total_nodes": total_nodes,
                "total_edges": total_edges,
            }

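    # Usage sketch for export_graph pagination (illustrative; `db` is an assumed
    # instance): with page/page_size set, SKIP/LIMIT apply independently to nodes
    # and edges, while total_nodes/total_edges always report the unpaginated counts.
    #
    #     snapshot = db.export_graph(page=1, page_size=100)
    #     print(snapshot["total_nodes"], len(snapshot["nodes"]))
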
    def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None:
        """
        Import the entire graph from a serialized dictionary.

        Args:
            data: A dictionary containing all nodes and edges to be loaded.
        """
        user_name = user_name if user_name else self.config.user_name
        with self.driver.session(database=self.db_name) as session:
            for node in data.get("nodes", []):
                id, memory, metadata = _compose_node(node)

                if not self.config.use_multi_db and (self.config.user_name or user_name):
                    metadata["user_name"] = user_name

                metadata = _prepare_node_metadata(metadata)

                # Merge node and set metadata
                created_at = metadata.pop("created_at")
                updated_at = metadata.pop("updated_at")

                session.run(
                    """
                    MERGE (n:Memory {id: $id})
                    SET n.memory = $memory,
                        n.created_at = datetime($created_at),
                        n.updated_at = datetime($updated_at),
                        n += $metadata
                    """,
                    id=id,
                    memory=memory,
                    created_at=created_at,
                    updated_at=updated_at,
                    metadata=metadata,
                )

            for edge in data.get("edges", []):
                session.run(
                    f"""
                    MATCH (a:Memory {{id: $source_id}})
                    MATCH (b:Memory {{id: $target_id}})
                    MERGE (a)-[:{edge["type"]}]->(b)
                    """,
                    source_id=edge["source"],
                    target_id=edge["target"],
                )

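    # Round-trip sketch (illustrative; `src_db` and `dst_db` are assumed instances,
    # and it presumes export_graph's node shape is accepted by _compose_node):
    # export_graph output feeds import_graph directly, giving a simple backup/restore
    # or migration path between two stores.
    #
    #     snapshot = src_db.export_graph()
    #     dst_db.import_graph(snapshot, user_name="alice")
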
    def get_all_memory_items(
        self,
        scope: str,
        include_embedding: bool = False,
        filter: dict | None = None,
        knowledgebase_ids: list[str] | None = None,
        status: str | None = None,
        **kwargs,
    ) -> list[dict]:
        """
        Retrieve all memory items of a specific memory_type.

        Args:
            scope (str): Must be one of 'WorkingMemory', 'LongTermMemory',
                'UserMemory', or 'OuterMemory'.
            include_embedding (bool): Whether to include embedding in results.
            filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results.
                Example: {"and": [{"id": "xxx"}, {"A": "yyy"}]} or {"or": [{"id": "xxx"}, {"A": "yyy"}]}
            knowledgebase_ids (list[str], optional): List of knowledgebase IDs to filter by.
            status (str, optional): Filter by status (e.g., 'activated', 'archived').
                If None, no status filter is applied.

        Returns:
            list[dict]: Full list of memory items under this scope.
        """
        logger.info(
            f"[get_all_memory_items] scope: {scope}, filter: {filter}, knowledgebase_ids: {knowledgebase_ids}, status: {status}"
        )

        user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name
        if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}:
            raise ValueError(f"Unsupported memory type scope: {scope}")

        where_clauses = ["n.memory_type = $scope"]
        params = {"scope": scope}

        # Add status filter if provided
        if status:
            where_clauses.append("n.status = $status")
            params["status"] = status

        # Build user_name filter with knowledgebase_ids support (OR relationship) using the common helper
        user_name_conditions, user_name_params = self._build_user_name_and_kb_ids_conditions_cypher(
            user_name=user_name,
            knowledgebase_ids=knowledgebase_ids,
            default_user_name=self.config.user_name,
            node_alias="n",
        )

        # Add user_name WHERE clause
        if user_name_conditions:
            if len(user_name_conditions) == 1:
                where_clauses.append(user_name_conditions[0])
            else:
                where_clauses.append(f"({' OR '.join(user_name_conditions)})")

        # Build filter conditions using the common helper
        filter_conditions, filter_params = self._build_filter_conditions_cypher(
            filter=filter,
            param_counter_start=0,
            node_alias="n",
        )
        where_clauses.extend(filter_conditions)

        where_clause = "WHERE " + " AND ".join(where_clauses)

        # Add user_name and knowledgebase_ids parameters
        params.update(user_name_params)

        # Add filter parameters
        if filter_params:
            params.update(filter_params)

        query = f"""
        MATCH (n:Memory)
        {where_clause}
        RETURN n
        """
        logger.info(f"[get_all_memory_items] query: {query}, params: {params}")

        with self.driver.session(database=self.db_name) as session:
            results = session.run(query, params)
            return [self._parse_node(dict(record["n"])) for record in results]

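    # Usage sketch for get_all_memory_items with a combined filter (illustrative;
    # `db` is an assumed instance). The "and"/"or" dict is compiled by
    # _build_filter_conditions_cypher further below.
    #
    #     items = db.get_all_memory_items(
    #         scope="LongTermMemory",
    #         status="activated",
    #         filter={"and": [{"created_at": {"gte": "2025-01-01"}},
    #                         {"tags": {"contains": "AI"}}]},
    #     )
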
    def get_structure_optimization_candidates(self, scope: str, **kwargs) -> list[dict]:
        """
        Find nodes that are likely candidates for structure optimization:
        - Isolated nodes, nodes with empty background, or nodes with exactly one child.
        - Plus: the child of any parent node that has exactly one child.
        """
        user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name
        # NOTE: the query below currently matches only isolated activated nodes
        # (no PARENT edge in either direction); the other candidate kinds described
        # above are not yet queried.
        where_clause = """
        WHERE n.memory_type = $scope
        AND n.status = 'activated'
        AND NOT ( (n)-[:PARENT]->() OR ()-[:PARENT]->(n) )
        """
        params = {"scope": scope}

        if not self.config.use_multi_db and (self.config.user_name or user_name):
            where_clause += " AND n.user_name = $user_name"
            params["user_name"] = user_name

        query = f"""
        MATCH (n:Memory)
        {where_clause}
        RETURN n.id AS id, n AS node
        """

        with self.driver.session(database=self.db_name) as session:
            results = session.run(query, params)
            return [
                self._parse_node({"id": record["id"], **dict(record["node"])}) for record in results
            ]

    def drop_database(self) -> None:
        """
        Permanently delete the entire database this instance is using.
        WARNING: This operation is destructive and cannot be undone.
        """
        if self.config.use_multi_db:
            if self.db_name in ("system", "neo4j"):
                raise ValueError(f"Refusing to drop protected database: {self.db_name}")

            with self.driver.session(database=self.system_db_name) as session:
                session.run(f"DROP DATABASE {self.db_name} IF EXISTS")
                logger.info(f"Database '{self.db_name}' has been dropped.")
        else:
            raise ValueError(
                f"Refusing to drop protected database: {self.db_name} in "
                "Shared Database Multi-Tenant mode"
            )

    def _ensure_database_exists(self):
        from neo4j.exceptions import ClientError

        try:
            with self.driver.session(database="system") as session:
                session.run(f"CREATE DATABASE `{self.db_name}` IF NOT EXISTS")
        except ClientError as e:
            # "Unsupported administration" also covers "Unsupported administration command"
            if "Unsupported administration" in str(e):
                logger.warning(
                    f"Could not create database '{self.db_name}' because this Neo4j instance "
                    "(likely Community Edition) does not support administrative commands. "
                    "Please ensure the database exists manually or use the default 'neo4j' database."
                )
                return
            if "ExistingDatabaseFound" in str(e):
                pass  # Ignore, database already exists
            else:
                raise

        # Wait until the database is available
        for _ in range(10):
            with self.driver.session(database=self.system_db_name) as session:
                result = session.run(
                    "SHOW DATABASES YIELD name, currentStatus RETURN name, currentStatus"
                )
                status_map = {r["name"]: r["currentStatus"] for r in result}
                if self.db_name in status_map and status_map[self.db_name] == "online":
                    return
            time.sleep(1)

        raise RuntimeError(f"Database {self.db_name} not ready after waiting.")

    def _vector_index_exists(self, index_name: str = "memory_vector_index") -> bool:
        """
        Check whether a vector index with the given name exists.
        """
        query = "SHOW INDEXES YIELD name WHERE name = $name RETURN name"
        with self.driver.session(database=self.db_name) as session:
            result = session.run(query, name=index_name)
            return result.single() is not None

    def _create_vector_index(
        self, label: str, vector_property: str, dimensions: int, index_name: str
    ) -> None:
        """
        Create a vector index for the specified property in the label.
        """
        try:
            query = f"""
            CREATE VECTOR INDEX {index_name} IF NOT EXISTS
            FOR (n:{label}) ON (n.{vector_property})
            OPTIONS {{
              indexConfig: {{
                `vector.dimensions`: {dimensions},
                `vector.similarity_function`: 'cosine'
              }}
            }}
            """
            with self.driver.session(database=self.db_name) as session:
                session.run(query)
                logger.debug(f"Vector index '{index_name}' ensured.")
        except Exception as e:
            logger.warning(f"Failed to create vector index '{index_name}': {e}")

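    # Usage sketch (illustrative; the property name "embedding" and dimension 768 are
    # assumptions, not values fixed by this module): repeated calls are safe because
    # the DDL uses CREATE VECTOR INDEX ... IF NOT EXISTS with cosine similarity.
    #
    #     db._create_vector_index(
    #         label="Memory",
    #         vector_property="embedding",
    #         dimensions=768,
    #         index_name="memory_vector_index",
    #     )
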
    def _create_basic_property_indexes(self) -> None:
        """
        Create standard B-tree indexes on the memory_type, created_at,
        and updated_at fields.
        Also create a standard B-tree index on user_name when using
        Shared Database Multi-Tenant mode.
        """
        try:
            with self.driver.session(database=self.db_name) as session:
                session.run("""
                    CREATE INDEX memory_type_index IF NOT EXISTS
                    FOR (n:Memory) ON (n.memory_type)
                """)
                logger.debug("Index 'memory_type_index' ensured.")

                session.run("""
                    CREATE INDEX memory_created_at_index IF NOT EXISTS
                    FOR (n:Memory) ON (n.created_at)
                """)
                logger.debug("Index 'memory_created_at_index' ensured.")

                session.run("""
                    CREATE INDEX memory_updated_at_index IF NOT EXISTS
                    FOR (n:Memory) ON (n.updated_at)
                """)
                logger.debug("Index 'memory_updated_at_index' ensured.")

                if not self.config.use_multi_db and self.config.user_name:
                    session.run(
                        """
                        CREATE INDEX memory_user_name_index IF NOT EXISTS
                        FOR (n:Memory) ON (n.user_name)
                        """
                    )
                    logger.debug("Index 'memory_user_name_index' ensured.")
        except Exception as e:
            logger.warning(f"Failed to create basic property indexes: {e}")

    def _index_exists(self, index_name: str) -> bool:
        """
        Check if an index with the given name exists.
        """
        query = "SHOW INDEXES"
        with self.driver.session(database=self.db_name) as session:
            result = session.run(query)
            for record in result:
                if record["name"] == index_name:
                    return True
        return False

    def _build_user_name_and_kb_ids_conditions_cypher(
        self,
        user_name: str | None,
        knowledgebase_ids: list[str] | None,
        default_user_name: str | None = None,
        node_alias: str = "node",
    ) -> tuple[list[str], dict[str, Any]]:
        """
        Build user_name and knowledgebase_ids conditions for Cypher queries.

        Args:
            user_name: User name for filtering
            knowledgebase_ids: List of knowledgebase IDs
            default_user_name: Default user name from config if user_name is None
            node_alias: Node alias in the Cypher query (default: "node" or "n")

        Returns:
            Tuple of (condition_strings_list, parameters_dict)
        """
        user_name_conditions = []
        params = {}
        effective_user_name = user_name if user_name else default_user_name

        # Only add the user_name condition if not using multi-db mode
        if not self.config.use_multi_db and (self.config.user_name or effective_user_name):
            user_name_conditions.append(f"{node_alias}.user_name = $user_name")
            params["user_name"] = effective_user_name

        # Add knowledgebase_ids conditions (checking the user_name field in the data)
        if knowledgebase_ids and isinstance(knowledgebase_ids, list) and len(knowledgebase_ids) > 0:
            for idx, kb_id in enumerate(knowledgebase_ids):
                if isinstance(kb_id, str):
                    param_name = f"kb_id_{idx}"
                    user_name_conditions.append(f"{node_alias}.user_name = ${param_name}")
                    params[param_name] = kb_id

        return user_name_conditions, params

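    # Illustrative expansion (assumed inputs): with user_name="alice",
    # knowledgebase_ids=["kb1", "kb2"], and node_alias="n" in shared-database mode,
    # this helper returns
    #   ["n.user_name = $user_name", "n.user_name = $kb_id_0", "n.user_name = $kb_id_1"]
    # with params {"user_name": "alice", "kb_id_0": "kb1", "kb_id_1": "kb2"};
    # callers join the conditions with OR, so a node may belong to the user or to
    # any listed knowledgebase.
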
    def _build_filter_conditions_cypher(
        self,
        filter: dict | None,
        param_counter_start: int = 0,
        node_alias: str = "node",
    ) -> tuple[list[str], dict[str, Any]]:
        """
        Build filter conditions for Cypher queries.

        Args:
            filter: Filter dictionary with "or" or "and" logic
            param_counter_start: Starting value for the parameter counter (to avoid conflicts)
            node_alias: Node alias in the Cypher query (default: "node" or "n")

        Returns:
            Tuple of (condition_strings_list, parameters_dict)
        """
        filter_conditions = []
        filter_params = {}

        if not filter:
            return filter_conditions, filter_params

        def build_filter_condition(condition_dict: dict, param_counter: list) -> tuple[str, dict]:
            """Build a WHERE condition for a single filter item.

            Args:
                condition_dict: A dict like {"id": "xxx"} or {"A": "xxx"} or {"created_at": {"gt": "2025-11-01"}}
                param_counter: List used to track the parameter counter for unique param names

            Returns:
                Tuple of (condition_string, parameters_dict)
            """
            condition_parts = []
            params = {}

            for key, value in condition_dict.items():
                # Check if value is a dict with comparison operators (gt, lt, gte, lte, contains, in, like)
                if isinstance(value, dict):
                    # Handle comparison operators: gt, lt, gte, lte, contains, in, like
                    for op, op_value in value.items():
                        if op in ("gt", "lt", "gte", "lte"):
                            # Map operator to Cypher operator
                            cypher_op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="}
                            cypher_op = cypher_op_map[op]

                            # All fields are stored as flat properties in Neo4j
                            param_name = f"filter_{key}_{op}_{param_counter[0]}"
                            param_counter[0] += 1
                            params[param_name] = op_value

                            # Check if the field is a date field (created_at, updated_at, etc.);
                            # use the datetime() function for date comparisons
                            if key in ("created_at", "updated_at") or key.endswith("_at"):
                                condition_parts.append(
                                    f"datetime({node_alias}.{key}) {cypher_op} datetime(${param_name})"
                                )
                            else:
                                condition_parts.append(
                                    f"{node_alias}.{key} {cypher_op} ${param_name}"
                                )
                        elif op == "contains":
                            # Handle the contains operator.
                            # For arrays: use IN to check if the array contains the value (value IN array_field).
                            # For strings: also use IN syntax to check if the string value is in the array field.
                            # Note: in Neo4j, for array fields, we use "value IN field" syntax.
                            param_name = f"filter_{key}_{op}_{param_counter[0]}"
                            param_counter[0] += 1
                            params[param_name] = op_value
                            # Use IN syntax: value IN array_field (works for both string and array values)
                            condition_parts.append(f"${param_name} IN {node_alias}.{key}")
                        elif op == "in":
                            # Handle the in operator (checks if the field value is in a list).
                            # Supports array format: {"field": {"in": ["value1", "value2"]}}
                            if not isinstance(op_value, list):
                                raise ValueError(
                                    f"in operator only supports array format. "
                                    f"Use {{'{key}': {{'in': ['{op_value}']}}}} instead of {{'{key}': {{'in': '{op_value}'}}}}"
                                )
                            # Build IN clause
                            param_name = f"filter_{key}_{op}_{param_counter[0]}"
                            param_counter[0] += 1
                            params[param_name] = op_value
                            condition_parts.append(f"{node_alias}.{key} IN ${param_name}")
                        elif op == "like":
                            # Handle the like operator (fuzzy matching, similar to SQL LIKE '%value%');
                            # Neo4j uses CONTAINS for string matching
                            param_name = f"filter_{key}_{op}_{param_counter[0]}"
                            param_counter[0] += 1
                            params[param_name] = op_value
                            condition_parts.append(f"{node_alias}.{key} CONTAINS ${param_name}")
                else:
                    # All fields are stored as flat properties in Neo4j (simple equality)
                    param_name = f"filter_{key}_{param_counter[0]}"
                    param_counter[0] += 1
                    params[param_name] = value
                    condition_parts.append(f"{node_alias}.{key} = ${param_name}")

            return " AND ".join(condition_parts), params

        param_counter = [param_counter_start]

        if isinstance(filter, dict):
            if "or" in filter:
                # OR logic: at least one condition must match
                or_conditions = []
                for condition in filter["or"]:
                    if isinstance(condition, dict):
                        condition_str, params = build_filter_condition(condition, param_counter)
                        if condition_str:
                            or_conditions.append(f"({condition_str})")
                            filter_params.update(params)
                if or_conditions:
                    filter_conditions.append(f"({' OR '.join(or_conditions)})")

            elif "and" in filter:
                # AND logic: all conditions must match
                for condition in filter["and"]:
                    if isinstance(condition, dict):
                        condition_str, params = build_filter_condition(condition, param_counter)
                        if condition_str:
                            filter_conditions.append(f"({condition_str})")
                            filter_params.update(params)
            else:
                # Handle a simple dict without "and" or "or" (e.g., {"id": "xxx"})
                condition_str, params = build_filter_condition(filter, param_counter)
                if condition_str:
                    filter_conditions.append(condition_str)
                    filter_params.update(params)

        return filter_conditions, filter_params

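    # Illustrative expansion (assumed input):
    #   {"or": [{"created_at": {"gt": "2025-11-01"}}, {"tags": {"contains": "AI"}}]}
    # with node_alias="n" becomes a single condition like
    #   ((datetime(n.created_at) > datetime($filter_created_at_gt_0)) OR ($filter_tags_contains_1 IN n.tags))
    # with params {"filter_created_at_gt_0": "2025-11-01", "filter_tags_contains_1": "AI"}.
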
    def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
        node = node_data.copy()

        # Convert Neo4j datetime to string
        for time_field in ("created_at", "updated_at"):
            if time_field in node and hasattr(node[time_field], "isoformat"):
                node[time_field] = node[time_field].isoformat()
        node.pop("user_name", None)

        # Deserialize JSON-encoded sources
        if node.get("sources"):
            for idx in range(len(node["sources"])):
                if not (
                    isinstance(node["sources"][idx], str)
                    and node["sources"][idx][0] == "{"
                    and node["sources"][idx][-1] == "}"  # was [0] == "}", which could never hold
                ):
                    break
                node["sources"][idx] = json.loads(node["sources"][idx])
        return {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node}

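    # Illustrative _parse_node behavior (assumed input values): a raw record such as
    #   {"id": "m1", "memory": "text", "user_name": "alice",
    #    "created_at": <neo4j DateTime>, "sources": ['{"type": "doc"}']}
    # comes back as
    #   {"id": "m1", "memory": "text",
    #    "metadata": {"created_at": "2025-...", "sources": [{"type": "doc"}]}}
    # i.e. datetimes become ISO strings, JSON-encoded sources are decoded, and
    # user_name is stripped from the metadata.
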
    def delete_node_by_prams(
        self,
        writable_cube_ids: list[str] | None = None,
        memory_ids: list[str] | None = None,
        file_ids: list[str] | None = None,
        filter: dict | None = None,
    ) -> int:
        """
        Delete nodes by memory_ids, file_ids, or filter.
        Supports three scenarios:
            1. Delete by memory_ids (standalone)
            2. Delete by writable_cube_ids + file_ids (combined)
            3. Delete by filter (standalone, no writable_cube_ids needed)

        Args:
            writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes.
                Only used with the file_ids scenario. If not provided, no user_name filter will be applied.
            memory_ids (list[str], optional): List of memory node IDs to delete.
            file_ids (list[str], optional): List of file node IDs to delete. Must be used with writable_cube_ids.
            filter (dict, optional): Filter dictionary for metadata filtering.
                Filter conditions are directly used in the DELETE WHERE clause without pre-querying.
                Does not require writable_cube_ids.

        Returns:
            int: Number of nodes deleted.
        """
        batch_start_time = time.time()
        logger.info(
            f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
        )

        # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id).
        # Only add the user_name filter if writable_cube_ids is provided (for the file_ids scenario).
        user_name_conditions = []
        params = {}
        if writable_cube_ids and len(writable_cube_ids) > 0:
            for idx, cube_id in enumerate(writable_cube_ids):
                param_name = f"cube_id_{idx}"
                user_name_conditions.append(f"n.user_name = ${param_name}")
                params[param_name] = cube_id

        # Build filter conditions using the common helper (no pre-query; used directly in the WHERE clause)
        filter_conditions = []
        filter_params = {}
        if filter:
            filter_conditions, filter_params = self._build_filter_conditions_cypher(
                filter, param_counter_start=0, node_alias="n"
            )
            logger.info(f"[delete_node_by_prams] filter_conditions: {filter_conditions}")
            params.update(filter_params)

        # If there are no conditions to delete by, return 0
        if not memory_ids and not file_ids and not filter_conditions:
            logger.warning(
                "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
            )
            return 0

        # Build WHERE conditions list
        where_clauses = []

        # Scenario 1: memory_ids (standalone)
        if memory_ids:
            logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids")
            where_clauses.append("n.id IN $memory_ids")
            params["memory_ids"] = memory_ids

        # Scenario 2: file_ids + writable_cube_ids (combined)
        if file_ids:
            logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids")
            file_id_conditions = []
            for idx, file_id in enumerate(file_ids):
                param_name = f"file_id_{idx}"
                params[param_name] = file_id
                # Check if this file_id is in the file_ids array field
                file_id_conditions.append(f"${param_name} IN n.file_ids")
            if file_id_conditions:
                where_clauses.append(f"({' OR '.join(file_id_conditions)})")

        # Scenario 3: filter (standalone, no writable_cube_ids needed)
        if filter_conditions:
            logger.info("[delete_node_by_prams] Processing filter conditions")
            # Combine filter conditions with AND
            filter_where = " AND ".join(filter_conditions)
            where_clauses.append(f"({filter_where})")

        # Build the final WHERE clause
        if not where_clauses:
            logger.warning("[delete_node_by_prams] No WHERE conditions to delete")
            return 0

        # Combine all conditions with AND
        data_conditions = " AND ".join([f"({clause})" for clause in where_clauses])

        # Add the user_name filter if provided (file_ids scenario)
        if user_name_conditions:
            user_name_where = " OR ".join(user_name_conditions)
            final_where = f"({user_name_where}) AND ({data_conditions})"
        else:
            final_where = data_conditions

        # Delete directly without pre-counting
        delete_query = f"MATCH (n:Memory) WHERE {final_where} DETACH DELETE n"
        logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")

        deleted_count = 0
        try:
            with self.driver.session(database=self.db_name) as session:
                # Execute the delete query
                result = session.run(delete_query, **params)
                # Consume the result to ensure deletion completes and get the summary
                summary = result.consume()
                # Get the count from the result summary
                deleted_count = summary.counters.nodes_deleted if summary.counters else 0

            elapsed_time = time.time() - batch_start_time
            logger.info(
                f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {deleted_count} nodes"
            )
        except Exception as e:
            logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
            raise

        logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
        return deleted_count

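    # Usage sketch for delete_node_by_prams (illustrative; `db` is an assumed instance):
    #
    #     # Scenario 1: delete specific memories by ID
    #     n_deleted = db.delete_node_by_prams(memory_ids=["m1", "m2"])
    #
    #     # Scenario 2: delete nodes derived from given files, scoped to writable cubes
    #     n_deleted = db.delete_node_by_prams(writable_cube_ids=["cube_a"], file_ids=["f1"])
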
    def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]:
        """Get user names by memory ids.

        Args:
            memory_ids: List of memory node IDs to query.

        Returns:
            dict[str, str | None]: Dictionary mapping memory_id to user_name.
                - Key: memory_id
                - Value: user_name if it exists, None if the memory_id does not exist
                Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None}
        """
        if not memory_ids:
            return {}

        logger.info(f"[get_user_names_by_memory_ids] Querying memory_ids {memory_ids}")

        try:
            with self.driver.session(database=self.db_name) as session:
                # Query memory_id and user_name pairs
                query = """
                MATCH (n:Memory)
                WHERE n.id IN $memory_ids
                RETURN n.id AS memory_id, n.user_name AS user_name
                """
                logger.info(f"[get_user_names_by_memory_ids] query: {query}")

                result = session.run(query, memory_ids=memory_ids)
                result_dict = {}

                # Build the result dictionary from the query results
                for record in result:
                    memory_id = record["memory_id"]
                    user_name = record["user_name"]
                    result_dict[memory_id] = user_name if user_name else None

                # Set None for memory_ids that were not found
                for mid in memory_ids:
                    if mid not in result_dict:
                        result_dict[mid] = None

                logger.info(
                    f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, "
                    f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names"
                )

                return result_dict
        except Exception as e:
            logger.error(
                f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
            )
            raise

    def exist_user_name(self, user_name: str) -> dict[str, bool]:
        """Check if the user name exists in the graph.

        Args:
            user_name: User name to check.

        Returns:
            dict[str, bool]: Dictionary with user_name as key and a bool as value indicating existence.
        """
        logger.info(f"[exist_user_name] Querying user_name {user_name}")
        if not user_name:
            return {user_name: False}

        try:
            with self.driver.session(database=self.db_name) as session:
                # Query to check if the user_name exists
                query = """
                MATCH (n:Memory)
                WHERE n.user_name = $user_name
                RETURN COUNT(n) AS count
                """
                logger.info(f"[exist_user_name] query: {query}")

                result = session.run(query, user_name=user_name)
                count = result.single()["count"]
                result_dict = {user_name: count > 0}

                logger.info(
                    f"[exist_user_name] user_name {user_name} exists: {result_dict[user_name]}"
                )
                return result_dict
        except Exception as e:
            logger.error(
                f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True
            )
            raise