MemoryOS-2.0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. memoryos-2.0.3.dist-info/METADATA +418 -0
  2. memoryos-2.0.3.dist-info/RECORD +315 -0
  3. memoryos-2.0.3.dist-info/WHEEL +4 -0
  4. memoryos-2.0.3.dist-info/entry_points.txt +3 -0
  5. memoryos-2.0.3.dist-info/licenses/LICENSE +201 -0
  6. memos/__init__.py +20 -0
  7. memos/api/client.py +571 -0
  8. memos/api/config.py +1018 -0
  9. memos/api/context/dependencies.py +50 -0
  10. memos/api/exceptions.py +53 -0
  11. memos/api/handlers/__init__.py +62 -0
  12. memos/api/handlers/add_handler.py +158 -0
  13. memos/api/handlers/base_handler.py +194 -0
  14. memos/api/handlers/chat_handler.py +1401 -0
  15. memos/api/handlers/component_init.py +388 -0
  16. memos/api/handlers/config_builders.py +190 -0
  17. memos/api/handlers/feedback_handler.py +93 -0
  18. memos/api/handlers/formatters_handler.py +237 -0
  19. memos/api/handlers/memory_handler.py +316 -0
  20. memos/api/handlers/scheduler_handler.py +497 -0
  21. memos/api/handlers/search_handler.py +222 -0
  22. memos/api/handlers/suggestion_handler.py +117 -0
  23. memos/api/mcp_serve.py +614 -0
  24. memos/api/middleware/request_context.py +101 -0
  25. memos/api/product_api.py +38 -0
  26. memos/api/product_models.py +1206 -0
  27. memos/api/routers/__init__.py +1 -0
  28. memos/api/routers/product_router.py +477 -0
  29. memos/api/routers/server_router.py +394 -0
  30. memos/api/server_api.py +44 -0
  31. memos/api/start_api.py +433 -0
  32. memos/chunkers/__init__.py +4 -0
  33. memos/chunkers/base.py +24 -0
  34. memos/chunkers/charactertext_chunker.py +41 -0
  35. memos/chunkers/factory.py +24 -0
  36. memos/chunkers/markdown_chunker.py +62 -0
  37. memos/chunkers/sentence_chunker.py +54 -0
  38. memos/chunkers/simple_chunker.py +50 -0
  39. memos/cli.py +113 -0
  40. memos/configs/__init__.py +0 -0
  41. memos/configs/base.py +82 -0
  42. memos/configs/chunker.py +59 -0
  43. memos/configs/embedder.py +88 -0
  44. memos/configs/graph_db.py +236 -0
  45. memos/configs/internet_retriever.py +100 -0
  46. memos/configs/llm.py +151 -0
  47. memos/configs/mem_agent.py +54 -0
  48. memos/configs/mem_chat.py +81 -0
  49. memos/configs/mem_cube.py +105 -0
  50. memos/configs/mem_os.py +83 -0
  51. memos/configs/mem_reader.py +91 -0
  52. memos/configs/mem_scheduler.py +385 -0
  53. memos/configs/mem_user.py +70 -0
  54. memos/configs/memory.py +324 -0
  55. memos/configs/parser.py +38 -0
  56. memos/configs/reranker.py +18 -0
  57. memos/configs/utils.py +8 -0
  58. memos/configs/vec_db.py +80 -0
  59. memos/context/context.py +355 -0
  60. memos/dependency.py +52 -0
  61. memos/deprecation.py +262 -0
  62. memos/embedders/__init__.py +0 -0
  63. memos/embedders/ark.py +95 -0
  64. memos/embedders/base.py +106 -0
  65. memos/embedders/factory.py +29 -0
  66. memos/embedders/ollama.py +77 -0
  67. memos/embedders/sentence_transformer.py +49 -0
  68. memos/embedders/universal_api.py +51 -0
  69. memos/exceptions.py +30 -0
  70. memos/graph_dbs/__init__.py +0 -0
  71. memos/graph_dbs/base.py +274 -0
  72. memos/graph_dbs/factory.py +27 -0
  73. memos/graph_dbs/item.py +46 -0
  74. memos/graph_dbs/nebular.py +1794 -0
  75. memos/graph_dbs/neo4j.py +1942 -0
  76. memos/graph_dbs/neo4j_community.py +1058 -0
  77. memos/graph_dbs/polardb.py +5446 -0
  78. memos/hello_world.py +97 -0
  79. memos/llms/__init__.py +0 -0
  80. memos/llms/base.py +25 -0
  81. memos/llms/deepseek.py +13 -0
  82. memos/llms/factory.py +38 -0
  83. memos/llms/hf.py +443 -0
  84. memos/llms/hf_singleton.py +114 -0
  85. memos/llms/ollama.py +135 -0
  86. memos/llms/openai.py +222 -0
  87. memos/llms/openai_new.py +198 -0
  88. memos/llms/qwen.py +13 -0
  89. memos/llms/utils.py +14 -0
  90. memos/llms/vllm.py +218 -0
  91. memos/log.py +237 -0
  92. memos/mem_agent/base.py +19 -0
  93. memos/mem_agent/deepsearch_agent.py +391 -0
  94. memos/mem_agent/factory.py +36 -0
  95. memos/mem_chat/__init__.py +0 -0
  96. memos/mem_chat/base.py +30 -0
  97. memos/mem_chat/factory.py +21 -0
  98. memos/mem_chat/simple.py +200 -0
  99. memos/mem_cube/__init__.py +0 -0
  100. memos/mem_cube/base.py +30 -0
  101. memos/mem_cube/general.py +240 -0
  102. memos/mem_cube/navie.py +172 -0
  103. memos/mem_cube/utils.py +169 -0
  104. memos/mem_feedback/base.py +15 -0
  105. memos/mem_feedback/feedback.py +1192 -0
  106. memos/mem_feedback/simple_feedback.py +40 -0
  107. memos/mem_feedback/utils.py +230 -0
  108. memos/mem_os/client.py +5 -0
  109. memos/mem_os/core.py +1203 -0
  110. memos/mem_os/main.py +582 -0
  111. memos/mem_os/product.py +1608 -0
  112. memos/mem_os/product_server.py +455 -0
  113. memos/mem_os/utils/default_config.py +359 -0
  114. memos/mem_os/utils/format_utils.py +1403 -0
  115. memos/mem_os/utils/reference_utils.py +162 -0
  116. memos/mem_reader/__init__.py +0 -0
  117. memos/mem_reader/base.py +47 -0
  118. memos/mem_reader/factory.py +53 -0
  119. memos/mem_reader/memory.py +298 -0
  120. memos/mem_reader/multi_modal_struct.py +965 -0
  121. memos/mem_reader/read_multi_modal/__init__.py +43 -0
  122. memos/mem_reader/read_multi_modal/assistant_parser.py +311 -0
  123. memos/mem_reader/read_multi_modal/base.py +273 -0
  124. memos/mem_reader/read_multi_modal/file_content_parser.py +826 -0
  125. memos/mem_reader/read_multi_modal/image_parser.py +359 -0
  126. memos/mem_reader/read_multi_modal/multi_modal_parser.py +252 -0
  127. memos/mem_reader/read_multi_modal/string_parser.py +139 -0
  128. memos/mem_reader/read_multi_modal/system_parser.py +327 -0
  129. memos/mem_reader/read_multi_modal/text_content_parser.py +131 -0
  130. memos/mem_reader/read_multi_modal/tool_parser.py +210 -0
  131. memos/mem_reader/read_multi_modal/user_parser.py +218 -0
  132. memos/mem_reader/read_multi_modal/utils.py +358 -0
  133. memos/mem_reader/simple_struct.py +912 -0
  134. memos/mem_reader/strategy_struct.py +163 -0
  135. memos/mem_reader/utils.py +157 -0
  136. memos/mem_scheduler/__init__.py +0 -0
  137. memos/mem_scheduler/analyzer/__init__.py +0 -0
  138. memos/mem_scheduler/analyzer/api_analyzer.py +714 -0
  139. memos/mem_scheduler/analyzer/eval_analyzer.py +219 -0
  140. memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +571 -0
  141. memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
  142. memos/mem_scheduler/base_scheduler.py +1319 -0
  143. memos/mem_scheduler/general_modules/__init__.py +0 -0
  144. memos/mem_scheduler/general_modules/api_misc.py +137 -0
  145. memos/mem_scheduler/general_modules/base.py +80 -0
  146. memos/mem_scheduler/general_modules/init_components_for_scheduler.py +425 -0
  147. memos/mem_scheduler/general_modules/misc.py +313 -0
  148. memos/mem_scheduler/general_modules/scheduler_logger.py +389 -0
  149. memos/mem_scheduler/general_modules/task_threads.py +315 -0
  150. memos/mem_scheduler/general_scheduler.py +1495 -0
  151. memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
  152. memos/mem_scheduler/memory_manage_modules/memory_filter.py +306 -0
  153. memos/mem_scheduler/memory_manage_modules/retriever.py +547 -0
  154. memos/mem_scheduler/monitors/__init__.py +0 -0
  155. memos/mem_scheduler/monitors/dispatcher_monitor.py +366 -0
  156. memos/mem_scheduler/monitors/general_monitor.py +394 -0
  157. memos/mem_scheduler/monitors/task_schedule_monitor.py +254 -0
  158. memos/mem_scheduler/optimized_scheduler.py +410 -0
  159. memos/mem_scheduler/orm_modules/__init__.py +0 -0
  160. memos/mem_scheduler/orm_modules/api_redis_model.py +518 -0
  161. memos/mem_scheduler/orm_modules/base_model.py +729 -0
  162. memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
  163. memos/mem_scheduler/orm_modules/redis_model.py +699 -0
  164. memos/mem_scheduler/scheduler_factory.py +23 -0
  165. memos/mem_scheduler/schemas/__init__.py +0 -0
  166. memos/mem_scheduler/schemas/analyzer_schemas.py +52 -0
  167. memos/mem_scheduler/schemas/api_schemas.py +233 -0
  168. memos/mem_scheduler/schemas/general_schemas.py +55 -0
  169. memos/mem_scheduler/schemas/message_schemas.py +173 -0
  170. memos/mem_scheduler/schemas/monitor_schemas.py +406 -0
  171. memos/mem_scheduler/schemas/task_schemas.py +132 -0
  172. memos/mem_scheduler/task_schedule_modules/__init__.py +0 -0
  173. memos/mem_scheduler/task_schedule_modules/dispatcher.py +740 -0
  174. memos/mem_scheduler/task_schedule_modules/local_queue.py +247 -0
  175. memos/mem_scheduler/task_schedule_modules/orchestrator.py +74 -0
  176. memos/mem_scheduler/task_schedule_modules/redis_queue.py +1385 -0
  177. memos/mem_scheduler/task_schedule_modules/task_queue.py +162 -0
  178. memos/mem_scheduler/utils/__init__.py +0 -0
  179. memos/mem_scheduler/utils/api_utils.py +77 -0
  180. memos/mem_scheduler/utils/config_utils.py +100 -0
  181. memos/mem_scheduler/utils/db_utils.py +50 -0
  182. memos/mem_scheduler/utils/filter_utils.py +176 -0
  183. memos/mem_scheduler/utils/metrics.py +125 -0
  184. memos/mem_scheduler/utils/misc_utils.py +290 -0
  185. memos/mem_scheduler/utils/monitor_event_utils.py +67 -0
  186. memos/mem_scheduler/utils/status_tracker.py +229 -0
  187. memos/mem_scheduler/webservice_modules/__init__.py +0 -0
  188. memos/mem_scheduler/webservice_modules/rabbitmq_service.py +485 -0
  189. memos/mem_scheduler/webservice_modules/redis_service.py +380 -0
  190. memos/mem_user/factory.py +94 -0
  191. memos/mem_user/mysql_persistent_user_manager.py +271 -0
  192. memos/mem_user/mysql_user_manager.py +502 -0
  193. memos/mem_user/persistent_factory.py +98 -0
  194. memos/mem_user/persistent_user_manager.py +260 -0
  195. memos/mem_user/redis_persistent_user_manager.py +225 -0
  196. memos/mem_user/user_manager.py +488 -0
  197. memos/memories/__init__.py +0 -0
  198. memos/memories/activation/__init__.py +0 -0
  199. memos/memories/activation/base.py +42 -0
  200. memos/memories/activation/item.py +56 -0
  201. memos/memories/activation/kv.py +292 -0
  202. memos/memories/activation/vllmkv.py +219 -0
  203. memos/memories/base.py +19 -0
  204. memos/memories/factory.py +42 -0
  205. memos/memories/parametric/__init__.py +0 -0
  206. memos/memories/parametric/base.py +19 -0
  207. memos/memories/parametric/item.py +11 -0
  208. memos/memories/parametric/lora.py +41 -0
  209. memos/memories/textual/__init__.py +0 -0
  210. memos/memories/textual/base.py +92 -0
  211. memos/memories/textual/general.py +236 -0
  212. memos/memories/textual/item.py +304 -0
  213. memos/memories/textual/naive.py +187 -0
  214. memos/memories/textual/prefer_text_memory/__init__.py +0 -0
  215. memos/memories/textual/prefer_text_memory/adder.py +504 -0
  216. memos/memories/textual/prefer_text_memory/config.py +106 -0
  217. memos/memories/textual/prefer_text_memory/extractor.py +221 -0
  218. memos/memories/textual/prefer_text_memory/factory.py +85 -0
  219. memos/memories/textual/prefer_text_memory/retrievers.py +177 -0
  220. memos/memories/textual/prefer_text_memory/spliter.py +132 -0
  221. memos/memories/textual/prefer_text_memory/utils.py +93 -0
  222. memos/memories/textual/preference.py +344 -0
  223. memos/memories/textual/simple_preference.py +161 -0
  224. memos/memories/textual/simple_tree.py +69 -0
  225. memos/memories/textual/tree.py +459 -0
  226. memos/memories/textual/tree_text_memory/__init__.py +0 -0
  227. memos/memories/textual/tree_text_memory/organize/__init__.py +0 -0
  228. memos/memories/textual/tree_text_memory/organize/handler.py +184 -0
  229. memos/memories/textual/tree_text_memory/organize/manager.py +518 -0
  230. memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +238 -0
  231. memos/memories/textual/tree_text_memory/organize/reorganizer.py +622 -0
  232. memos/memories/textual/tree_text_memory/retrieve/__init__.py +0 -0
  233. memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py +364 -0
  234. memos/memories/textual/tree_text_memory/retrieve/bm25_util.py +186 -0
  235. memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +419 -0
  236. memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +270 -0
  237. memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +102 -0
  238. memos/memories/textual/tree_text_memory/retrieve/reasoner.py +61 -0
  239. memos/memories/textual/tree_text_memory/retrieve/recall.py +497 -0
  240. memos/memories/textual/tree_text_memory/retrieve/reranker.py +111 -0
  241. memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +16 -0
  242. memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +472 -0
  243. memos/memories/textual/tree_text_memory/retrieve/searcher.py +848 -0
  244. memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +135 -0
  245. memos/memories/textual/tree_text_memory/retrieve/utils.py +54 -0
  246. memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +387 -0
  247. memos/memos_tools/dinding_report_bot.py +453 -0
  248. memos/memos_tools/lockfree_dict.py +120 -0
  249. memos/memos_tools/notification_service.py +44 -0
  250. memos/memos_tools/notification_utils.py +142 -0
  251. memos/memos_tools/singleton.py +174 -0
  252. memos/memos_tools/thread_safe_dict.py +310 -0
  253. memos/memos_tools/thread_safe_dict_segment.py +382 -0
  254. memos/multi_mem_cube/__init__.py +0 -0
  255. memos/multi_mem_cube/composite_cube.py +86 -0
  256. memos/multi_mem_cube/single_cube.py +874 -0
  257. memos/multi_mem_cube/views.py +54 -0
  258. memos/parsers/__init__.py +0 -0
  259. memos/parsers/base.py +15 -0
  260. memos/parsers/factory.py +21 -0
  261. memos/parsers/markitdown.py +28 -0
  262. memos/reranker/__init__.py +4 -0
  263. memos/reranker/base.py +25 -0
  264. memos/reranker/concat.py +103 -0
  265. memos/reranker/cosine_local.py +102 -0
  266. memos/reranker/factory.py +72 -0
  267. memos/reranker/http_bge.py +324 -0
  268. memos/reranker/http_bge_strategy.py +327 -0
  269. memos/reranker/noop.py +19 -0
  270. memos/reranker/strategies/__init__.py +4 -0
  271. memos/reranker/strategies/base.py +61 -0
  272. memos/reranker/strategies/concat_background.py +94 -0
  273. memos/reranker/strategies/concat_docsource.py +110 -0
  274. memos/reranker/strategies/dialogue_common.py +109 -0
  275. memos/reranker/strategies/factory.py +31 -0
  276. memos/reranker/strategies/single_turn.py +107 -0
  277. memos/reranker/strategies/singleturn_outmem.py +98 -0
  278. memos/settings.py +10 -0
  279. memos/templates/__init__.py +0 -0
  280. memos/templates/advanced_search_prompts.py +211 -0
  281. memos/templates/cloud_service_prompt.py +107 -0
  282. memos/templates/instruction_completion.py +66 -0
  283. memos/templates/mem_agent_prompts.py +85 -0
  284. memos/templates/mem_feedback_prompts.py +822 -0
  285. memos/templates/mem_reader_prompts.py +1096 -0
  286. memos/templates/mem_reader_strategy_prompts.py +238 -0
  287. memos/templates/mem_scheduler_prompts.py +626 -0
  288. memos/templates/mem_search_prompts.py +93 -0
  289. memos/templates/mos_prompts.py +403 -0
  290. memos/templates/prefer_complete_prompt.py +735 -0
  291. memos/templates/tool_mem_prompts.py +139 -0
  292. memos/templates/tree_reorganize_prompts.py +230 -0
  293. memos/types/__init__.py +34 -0
  294. memos/types/general_types.py +151 -0
  295. memos/types/openai_chat_completion_types/__init__.py +15 -0
  296. memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py +56 -0
  297. memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py +27 -0
  298. memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py +23 -0
  299. memos/types/openai_chat_completion_types/chat_completion_content_part_param.py +43 -0
  300. memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py +16 -0
  301. memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py +16 -0
  302. memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py +27 -0
  303. memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py +32 -0
  304. memos/types/openai_chat_completion_types/chat_completion_message_param.py +18 -0
  305. memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py +15 -0
  306. memos/types/openai_chat_completion_types/chat_completion_system_message_param.py +36 -0
  307. memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py +30 -0
  308. memos/types/openai_chat_completion_types/chat_completion_user_message_param.py +34 -0
  309. memos/utils.py +123 -0
  310. memos/vec_dbs/__init__.py +0 -0
  311. memos/vec_dbs/base.py +117 -0
  312. memos/vec_dbs/factory.py +23 -0
  313. memos/vec_dbs/item.py +50 -0
  314. memos/vec_dbs/milvus.py +654 -0
  315. memos/vec_dbs/qdrant.py +355 -0
memos/graph_dbs/nebular.py
@@ -0,0 +1,1794 @@
+ import json
+ import traceback
+
+ from contextlib import suppress
+ from datetime import datetime
+ from threading import Lock
+ from typing import TYPE_CHECKING, Any, ClassVar, Literal
+
+ import numpy as np
+
+ from memos.configs.graph_db import NebulaGraphDBConfig
+ from memos.dependency import require_python_package
+ from memos.graph_dbs.base import BaseGraphDB
+ from memos.log import get_logger
+ from memos.utils import timed
+
+
+ if TYPE_CHECKING:
+     from nebulagraph_python import (
+         NebulaClient,
+     )
+
+
+ logger = get_logger(__name__)
+
+
+ _TRANSIENT_ERR_KEYS = (
+     "Session not found",
+     "Connection not established",
+     "timeout",
+     "deadline exceeded",
+     "Broken pipe",
+     "EOFError",
+     "socket closed",
+     "connection reset",
+     "connection refused",
+ )
+
+
+ @timed
+ def _normalize(vec: list[float]) -> list[float]:
+     v = np.asarray(vec, dtype=np.float32)
+     norm = np.linalg.norm(v)
+     return (v / (norm if norm else 1.0)).tolist()
+
+
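`_normalize` L2-normalizes embeddings before they are stored, so the `inner_product` used later in `search_by_embedding` behaves as cosine similarity; zero vectors pass through unscaled instead of triggering a division by zero. A quick illustrative check, assuming the wheel is installed so the module imports:

    import numpy as np
    from memos.graph_dbs.nebular import _normalize

    v = _normalize([3.0, 4.0])
    print(v)                       # ~[0.6, 0.8] (float32 rounding)
    print(np.linalg.norm(v))       # ~1.0, i.e. unit length
    print(_normalize([0.0, 0.0]))  # [0.0, 0.0] -- zero-vector guard, no division by zero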
+ @timed
+ def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
+     node_id = item["id"]
+     memory = item["memory"]
+     metadata = item.get("metadata", {})
+     return node_id, memory, metadata
+
+
+ @timed
+ def _escape_str(value: str) -> str:
+     out = []
+     for ch in value:
+         code = ord(ch)
+         if ch == "\\":
+             out.append("\\\\")
+         elif ch == '"':
+             out.append('\\"')
+         elif ch == "\n":
+             out.append("\\n")
+         elif ch == "\r":
+             out.append("\\r")
+         elif ch == "\t":
+             out.append("\\t")
+         elif ch == "\b":
+             out.append("\\b")
+         elif ch == "\f":
+             out.append("\\f")
+         elif code < 0x20 or code in (0x2028, 0x2029):
+             out.append(f"\\u{code:04x}")
+         else:
+             out.append(ch)
+     return "".join(out)
+
+
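Because queries in this module are built by string interpolation rather than parameter binding, `_escape_str` is the barrier between user-supplied text and the GQL statement. A small illustrative run:

    from memos.graph_dbs.nebular import _escape_str

    print(_escape_str('He said "hi"\nline2'))
    # He said \"hi\"\nline2
    # quotes, backslashes, and control characters are escaped so the value
    # can be embedded inside a double-quoted GQL string literal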
+ @timed
+ def _format_datetime(value: str | datetime) -> str:
+     """Ensure datetime is in ISO 8601 format string."""
+     if isinstance(value, datetime):
+         return value.isoformat()
+     return str(value)
+
+
+ @timed
+ def _normalize_datetime(val):
+     """
+     Normalize a datetime to an ISO 8601 string with an explicit timezone offset.
+     - If val is a datetime object -> keep isoformat() (Neo4j)
+     - If val is a string without a timezone suffix -> append +08:00 (Nebula)
+     - Otherwise just str()
+     """
+     if hasattr(val, "isoformat"):
+         return val.isoformat()
+     if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")):
+         return val + "+08:00"
+     return str(val)
+
+
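The three branches, in order, as an illustrative run (note the naive-string case is stamped +08:00, not UTC):

    from datetime import datetime, timezone
    from memos.graph_dbs.nebular import _normalize_datetime

    print(_normalize_datetime(datetime(2024, 1, 1, tzinfo=timezone.utc)))  # 2024-01-01T00:00:00+00:00
    print(_normalize_datetime("2024-01-01T00:00:00"))                      # 2024-01-01T00:00:00+08:00
    print(_normalize_datetime("2024-01-01T00:00:00Z"))                     # unchanged: already has a zone suffix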
+ class NebulaGraphDB(BaseGraphDB):
+     """
+     NebulaGraph-based implementation of a graph memory store.
+     """
+
+     # ====== shared pool cache & refcount ======
+     # These are process-local; in a multi-process model each process will
+     # have its own cache.
+     _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {}
+     _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {}
+     _CLIENT_LOCK: ClassVar[Lock] = Lock()
+     _CLIENT_INIT_DONE: ClassVar[set[str]] = set()
+
+     @staticmethod
+     def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]:
+         hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None)
+         if isinstance(hosts, str):
+             return [hosts]
+         return list(hosts or [])
+
+     @staticmethod
+     def _make_client_key(cfg: NebulaGraphDBConfig) -> str:
+         hosts = NebulaGraphDB._get_hosts_from_cfg(cfg)
+         return "|".join(
+             [
+                 "nebula-sync",
+                 ",".join(hosts),
+                 str(getattr(cfg, "user", "")),
+                 str(getattr(cfg, "space", "")),
+             ]
+         )
+
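The cache key is just hosts, user, and space joined with `|`, so two configs that target the same endpoint and space share one pooled client. A sketch, where `SimpleNamespace` stands in for a real `NebulaGraphDBConfig` (the key builder only reads attributes via getattr):

    from types import SimpleNamespace
    from memos.graph_dbs.nebular import NebulaGraphDB

    cfg_a = SimpleNamespace(uri="graphd-1:9669", user="root", space="test")
    cfg_b = SimpleNamespace(uri="graphd-1:9669", user="root", space="test")
    key_a = NebulaGraphDB._make_client_key(cfg_a)
    print(key_a)                                           # nebula-sync|graphd-1:9669|root|test
    assert key_a == NebulaGraphDB._make_client_key(cfg_b)  # same key -> one shared client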
+     @classmethod
+     def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB":
+         tmp = object.__new__(NebulaGraphDB)
+         tmp.config = cfg
+         tmp.db_name = cfg.space
+         tmp.user_name = None
+         tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072)
+         tmp.default_memory_dimension = 3072
+         tmp.common_fields = {
+             "id",
+             "memory",
+             "user_name",
+             "user_id",
+             "session_id",
+             "status",
+             "key",
+             "confidence",
+             "tags",
+             "created_at",
+             "updated_at",
+             "memory_type",
+             "sources",
+             "source",
+             "node_type",
+             "visibility",
+             "usage",
+             "background",
+         }
+         tmp.base_fields = set(tmp.common_fields) - {"usage"}
+         tmp.heavy_fields = {"usage"}
+         tmp.dim_field = (
+             f"embedding_{tmp.embedding_dimension}"
+             if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension)
+             else "embedding"
+         )
+         tmp.system_db_name = cfg.space
+         tmp._client = client
+         tmp._owns_client = False
+         return tmp
+
+     @classmethod
+     def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]:
+         from nebulagraph_python import (
+             ConnectionConfig,
+             NebulaClient,
+             SessionConfig,
+             SessionPoolConfig,
+         )
+
+         key = cls._make_client_key(cfg)
+         with cls._CLIENT_LOCK:
+             client = cls._CLIENT_CACHE.get(key)
+             if client is None:
+                 # Connection setting: a short-lived single-session client is used
+                 # first to make sure the target space exists.
+                 tmp_client = NebulaClient(
+                     hosts=cfg.uri,
+                     username=cfg.user,
+                     password=cfg.password,
+                     session_config=SessionConfig(graph=None),
+                     session_pool_config=SessionPoolConfig(size=1, wait_timeout=3000),
+                 )
+                 try:
+                     cls._ensure_space_exists(tmp_client, cfg)
+                 finally:
+                     tmp_client.close()
+
+                 conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None)
+                 if conn_conf is None:
+                     conn_conf = ConnectionConfig.from_defults(
+                         cls._get_hosts_from_cfg(cfg),
+                         getattr(cfg, "ssl_param", None),
+                     )
+
+                 sess_conf = SessionConfig(graph=getattr(cfg, "space", None))
+                 pool_conf = SessionPoolConfig(
+                     size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000
+                 )
+
+                 client = NebulaClient(
+                     hosts=conn_conf.hosts,
+                     username=cfg.user,
+                     password=cfg.password,
+                     conn_config=conn_conf,
+                     session_config=sess_conf,
+                     session_pool_config=pool_conf,
+                 )
+                 cls._CLIENT_CACHE[key] = client
+                 cls._CLIENT_REFCOUNT[key] = 0
+                 logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}")
+
+             cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1
+
+         if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
+             with cls._CLIENT_LOCK:
+                 if key not in cls._CLIENT_INIT_DONE:
+                     admin = cls._bootstrap_admin(cfg, client)
+                     try:
+                         admin._ensure_database_exists()
+                         admin._create_basic_property_indexes()
+                         admin._create_vector_index(
+                             dimensions=int(
+                                 admin.embedding_dimension or admin.default_memory_dimension
+                             ),
+                         )
+                         cls._CLIENT_INIT_DONE.add(key)
+                         logger.info("[NebulaGraphDBSync] One-time init done")
+                     except Exception:
+                         logger.exception("[NebulaGraphDBSync] One-time init failed")
+
+         return key, client
+
+     def _refresh_client(self):
+         """
+         Refresh the shared NebulaClient: close and evict the cached client for
+         this key, then acquire a fresh one from the shared cache.
+         """
+         old_key = getattr(self, "_client_key", None)
+         if not old_key:
+             return
+
+         cls = self.__class__
+         with cls._CLIENT_LOCK:
+             try:
+                 if old_key in cls._CLIENT_CACHE:
+                     try:
+                         cls._CLIENT_CACHE[old_key].close()
+                     except Exception as e:
+                         logger.warning(f"[refresh_client] close old client error: {e}")
+                     finally:
+                         cls._CLIENT_CACHE.pop(old_key, None)
+             finally:
+                 cls._CLIENT_REFCOUNT[old_key] = 0
+
+         new_key, new_client = cls._get_or_create_shared_client(self.config)
+         self._client_key = new_key
+         self._client = new_client
+         logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}")
+
+     @classmethod
+     def _release_shared_client(cls, key: str):
+         with cls._CLIENT_LOCK:
+             if key not in cls._CLIENT_CACHE:
+                 return
+             cls._CLIENT_REFCOUNT[key] = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1)
+             if cls._CLIENT_REFCOUNT[key] == 0:
+                 try:
+                     cls._CLIENT_CACHE[key].close()
+                 except Exception as e:
+                     logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}")
+                 finally:
+                     cls._CLIENT_CACHE.pop(key, None)
+                     cls._CLIENT_REFCOUNT.pop(key, None)
+                     logger.info(f"[NebulaGraphDBSync] Closed & removed client key={key}")
+
+     @classmethod
+     def close_all_shared_clients(cls):
+         with cls._CLIENT_LOCK:
+             for key, client in list(cls._CLIENT_CACHE.items()):
+                 try:
+                     client.close()
+                 except Exception as e:
+                     logger.warning(f"[NebulaGraphDBSync] Error closing client {key}: {e}")
+                 finally:
+                     logger.info(f"[NebulaGraphDBSync] Closed client key={key}")
+             cls._CLIENT_CACHE.clear()
+             cls._CLIENT_REFCOUNT.clear()
+
+     @require_python_package(
+         import_name="nebulagraph_python",
+         install_command="pip install nebulagraph-python>=5.1.1",
+         install_link=".....",
+     )
+     def __init__(self, config: NebulaGraphDBConfig):
+         """
+         NebulaGraph DB client initialization.
+
+         Required config attributes:
+             - uri or hosts: list[str] like ["host1:port", "host2:port"]
+             - user: str
+             - password: str
+             - space: str (optional for basic commands)
+
+         Example config:
+             {
+                 "hosts": ["xxx.xx.xx.xxx:xxxx"],
+                 "user": "root",
+                 "password": "nebula",
+                 "space": "test"
+             }
+         """
+         assert config.use_multi_db is False, "Multi-DB MODE IS NOT SUPPORTED"
+         self.config = config
+         self.db_name = config.space
+         self.user_name = config.user_name
+         self.embedding_dimension = config.embedding_dimension
+         self.default_memory_dimension = 3072
+         self.common_fields = {
+             "id",
+             "memory",
+             "user_name",
+             "user_id",
+             "session_id",
+             "status",
+             "key",
+             "confidence",
+             "tags",
+             "created_at",
+             "updated_at",
+             "memory_type",
+             "sources",
+             "source",
+             "node_type",
+             "visibility",
+             "usage",
+             "background",
+         }
+         self.base_fields = set(self.common_fields) - {"usage"}
+         self.heavy_fields = {"usage"}
+         self.dim_field = (
+             f"embedding_{self.embedding_dimension}"
+             if (str(self.embedding_dimension) != str(self.default_memory_dimension))
+             else "embedding"
+         )
+         self.system_db_name = config.space
+
+         # ---- NEW: pool acquisition strategy
+         # Get or create a shared pool from the class-level cache
+         self._client_key, self._client = self._get_or_create_shared_client(config)
+         self._owns_client = True
+
+         logger.info("Connected to NebulaGraph successfully.")
+
+     @timed
+     def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True):
+         def _wrap_use_db(q: str) -> str:
+             if auto_set_db and self.db_name:
+                 return f"USE `{self.db_name}`\n{q}"
+             return q
+
+         try:
+             return self._client.execute(_wrap_use_db(gql), timeout=timeout)
+
+         except Exception as e:
+             emsg = str(e)
+             if any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS):
+                 logger.warning(f"[execute_query] {e!s} → refreshing session pool and retry once...")
+                 try:
+                     self._refresh_client()
+                     return self._client.execute(_wrap_use_db(gql), timeout=timeout)
+                 except Exception:
+                     logger.exception("[execute_query] retry after refresh failed")
+                     raise
+             raise
+
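Every statement is prefixed with `USE \`<space>\``, and a failed call is retried exactly once, but only when the error message matches one of the `_TRANSIENT_ERR_KEYS` substrings. The same case-insensitive test, pulled out as a sketch:

    from memos.graph_dbs.nebular import _TRANSIENT_ERR_KEYS

    def is_transient(exc: Exception) -> bool:
        msg = str(exc).lower()
        return any(key.lower() in msg for key in _TRANSIENT_ERR_KEYS)

    print(is_transient(RuntimeError("Session not found in pool")))  # True  -> refresh client, retry once
    print(is_transient(RuntimeError("semantic error near MATCH")))  # False -> re-raised to the caller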
+     @timed
+     def close(self):
+         """
+         Close the connection resource if this instance owns it.
+
+         - If the client was injected (this instance does not own it), do nothing.
+         - If the client was acquired via the shared cache, decrement the refcount
+           and close it when the last owner releases it.
+         """
+         if not self._owns_client:
+             logger.debug("[NebulaGraphDBSync] close() skipped (injected client).")
+             return
+         if self._client_key:
+             self._release_shared_client(self._client_key)
+             self._client_key = None
+             self._client = None
+
+     # NOTE: __del__ is best-effort; do not rely on GC order.
+     def __del__(self):
+         with suppress(Exception):
+             self.close()
+
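A sketch of the resulting lifecycle, assuming a reachable cluster and a valid `config`; two instances with the same cache key share one client, and the connection is only torn down when the last owner closes:

    db_a = NebulaGraphDB(config)   # cache miss -> client created, refcount = 1
    db_b = NebulaGraphDB(config)   # same key   -> client reused,  refcount = 2
    db_a.close()                   # refcount = 1, client stays open for db_b
    db_b.close()                   # refcount = 0, client closed and evicted from the cache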
+     @timed
+     def create_index(
+         self,
+         label: str = "Memory",
+         vector_property: str = "embedding",
+         dimensions: int = 3072,
+         index_name: str = "memory_vector_index",
+     ) -> None:
+         # Create the vector index
+         self._create_vector_index(label, vector_property, dimensions, index_name)
+         # Create the basic property indexes
+         self._create_basic_property_indexes()
+
+     @timed
+     def remove_oldest_memory(
+         self, memory_type: str, keep_latest: int, user_name: str | None = None
+     ) -> None:
+         """
+         Remove all nodes of the given memory type except the latest `keep_latest` entries.
+
+         Args:
+             memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
+             keep_latest (int): Number of latest entries to keep.
+             user_name (str, optional): User name for filtering in non-multi-db mode.
+         """
+         try:
+             user_name = user_name if user_name else self.config.user_name
+             optional_condition = f"AND n.user_name = '{user_name}'"
+             count = self.count_nodes(memory_type, user_name)
+             if count > keep_latest:
+                 delete_query = f"""
+                 MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
+                 WHERE n.memory_type = '{memory_type}'
+                 {optional_condition}
+                 ORDER BY n.updated_at DESC
+                 OFFSET {int(keep_latest)}
+                 DETACH DELETE n
+                 """
+                 self.execute_query(delete_query)
+         except Exception as e:
+             logger.warning(f"Delete old mem error: {e}")
+
+     @timed
+     def add_node(
+         self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None
+     ) -> None:
+         """
+         Insert or update a Memory node in NebulaGraph.
+         """
+         metadata = metadata.copy()  # copy first so the caller's dict is not mutated
+         metadata["user_name"] = user_name if user_name else self.config.user_name
+         now = datetime.utcnow()
+         metadata.setdefault("created_at", now)
+         metadata.setdefault("updated_at", now)
+         metadata["node_type"] = metadata.pop("type")
+         metadata["id"] = id
+         metadata["memory"] = memory
+
+         if "embedding" in metadata and isinstance(metadata["embedding"], list):
+             assert len(metadata["embedding"]) == self.embedding_dimension, (
+                 f"input embedding dimension must equal {self.embedding_dimension}"
+             )
+             embedding = metadata.pop("embedding")
+             metadata[self.dim_field] = _normalize(embedding)
+
+         metadata = self._metadata_filter(metadata)
+         properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items())
+         gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
+
+         try:
+             self.execute_query(gql)
+             logger.info("insert success")
+         except Exception as e:
+             logger.error(
+                 f"Failed to insert vertex {id}: gql: {gql}, {e}\ntrace: {traceback.format_exc()}"
+             )
+
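An illustrative call, assuming a constructed `db` instance. Note the contract implied above: the metadata must carry a `type` key (renamed to `node_type`), and any `embedding` list must match `config.embedding_dimension` and is L2-normalized before being written to `dim_field`:

    db.add_node(
        id="mem-001",
        memory="User prefers dark mode.",
        metadata={
            "type": "fact",                      # popped and stored as node_type
            "memory_type": "LongTermMemory",
            "status": "activated",
            "tags": ["preference", "ui"],
            "embedding": [0.1] * db.embedding_dimension,
        },
    )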
+     @timed
+     def node_not_exist(self, scope: str, user_name: str | None = None) -> bool:
+         user_name = user_name if user_name else self.config.user_name
+         filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{user_name}"'
+         query = f"""
+         MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
+         WHERE {filter_clause}
+         RETURN n.id AS id
+         LIMIT 1
+         """
+
+         try:
+             result = self.execute_query(query)
+             return result.size == 0
+         except Exception as e:
+             logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True)
+             raise
+
+     @timed
+     def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None:
+         """
+         Update node fields in NebulaGraph, auto-converting `created_at` and `updated_at` to datetime type if present.
+         """
+         user_name = user_name if user_name else self.config.user_name
+         fields = fields.copy()
+         set_clauses = []
+         for k, v in fields.items():
+             set_clauses.append(f"n.{k} = {self._format_value(v, k)}")
+
+         set_clause_str = ",\n ".join(set_clauses)
+
+         query = f"""
+         MATCH (n@Memory {{id: "{id}"}})
+         """
+         query += f'WHERE n.user_name = "{user_name}"'
+
+         query += f"\nSET {set_clause_str}"
+         self.execute_query(query)
+
+     @timed
+     def delete_node(self, id: str, user_name: str | None = None) -> None:
+         """
+         Delete a node from the graph.
+         Args:
+             id: Node identifier to delete.
+             user_name (str, optional): User name for filtering in non-multi-db mode
+         """
+         user_name = user_name if user_name else self.config.user_name
+         query = f"""
+         MATCH (n@Memory {{id: "{id}"}}) WHERE n.user_name = {self._format_value(user_name)}
+         DETACH DELETE n
+         """
+         self.execute_query(query)
+
+     @timed
+     def add_edge(self, source_id: str, target_id: str, type: str, user_name: str | None = None):
+         """
+         Create an edge from source node to target node.
+         Args:
+             source_id: ID of the source node.
+             target_id: ID of the target node.
+             type: Relationship type (e.g., 'RELATE_TO', 'PARENT').
+             user_name (str, optional): User name for filtering in non-multi-db mode
+         """
+         if not source_id or not target_id:
+             raise ValueError("[add_edge] source_id and target_id must be provided")
+         user_name = user_name if user_name else self.config.user_name
+         props = f'{{user_name: "{user_name}"}}'
+         insert_stmt = f'''
+         MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
+         INSERT (a) -[e@{type} {props}]-> (b)
+         '''
+         try:
+             self.execute_query(insert_stmt)
+         except Exception as e:
+             logger.error(f"Failed to insert edge: {e}", exc_info=True)
+
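A minimal sketch of wiring two stored memories together (again assuming a `db` instance; the edge carries `user_name`, so multi-tenant filtering works on edges too):

    db.add_edge("mem-001", "mem-002", type="PARENT")
    # emits: MATCH (a@Memory {id: "mem-001"}), (b@Memory {id: "mem-002"})
    #        INSERT (a) -[e@PARENT {user_name: "..."}]-> (b)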
+     @timed
+     def delete_edge(
+         self, source_id: str, target_id: str, type: str, user_name: str | None = None
+     ) -> None:
+         """
+         Delete a specific edge between two nodes.
+         Args:
+             source_id: ID of the source node.
+             target_id: ID of the target node.
+             type: Relationship type to remove.
+             user_name (str, optional): User name for filtering in non-multi-db mode
+         """
+         user_name = user_name if user_name else self.config.user_name
+         query = f"""
+         MATCH (a@Memory) -[r@{type}]-> (b@Memory)
+         WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)}
+         """
+
+         query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}"
+         query += "\nDELETE r"
+         self.execute_query(query)
+
+     @timed
+     def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int:
+         user_name = user_name if user_name else self.config.user_name
+         query = f"""
+         MATCH (n@Memory)
+         WHERE n.memory_type = "{memory_type}"
+         """
+         query += f"\nAND n.user_name = '{user_name}'"
+         query += "\nRETURN COUNT(n) AS count"
+
+         try:
+             result = self.execute_query(query)
+             return result.one_or_none()["count"].value
+         except Exception as e:
+             logger.error(f"[get_memory_count] Failed: {e}")
+             return -1
+
+     @timed
+     def count_nodes(self, scope: str, user_name: str | None = None) -> int:
+         user_name = user_name if user_name else self.config.user_name
+         query = f"""
+         MATCH (n@Memory)
+         WHERE n.memory_type = "{scope}"
+         """
+         query += f"\nAND n.user_name = '{user_name}'"
+         query += "\nRETURN count(n) AS count"
+
+         result = self.execute_query(query)
+         return result.one_or_none()["count"].value
+
+     @timed
+     def edge_exists(
+         self,
+         source_id: str,
+         target_id: str,
+         type: str = "ANY",
+         direction: str = "OUTGOING",
+         user_name: str | None = None,
+     ) -> bool:
+         """
+         Check if an edge exists between two nodes.
+         Args:
+             source_id: ID of the source node.
+             target_id: ID of the target node.
+             type: Relationship type. Use "ANY" to match any relationship type.
+             direction: Direction of the edge.
+                 Use "OUTGOING" (default), "INCOMING", or "ANY".
+             user_name (str, optional): User name for filtering in non-multi-db mode
+         Returns:
+             True if the edge exists, otherwise False.
+         """
+         # Prepare the relationship pattern
+         user_name = user_name if user_name else self.config.user_name
+         rel = "r" if type == "ANY" else f"r@{type}"
+
+         # Prepare the match pattern with direction
+         if direction == "OUTGOING":
+             pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]->(b@Memory {{id: '{target_id}'}})"
+         elif direction == "INCOMING":
+             pattern = f"(a@Memory {{id: '{source_id}'}})<-[{rel}]-(b@Memory {{id: '{target_id}'}})"
+         elif direction == "ANY":
+             pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]-(b@Memory {{id: '{target_id}'}})"
+         else:
+             raise ValueError(
+                 f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'."
+             )
+         query = f"MATCH {pattern}"
+         query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
+         query += "\nRETURN r"
+
+         # Run the GQL query
+         result = self.execute_query(query)
+         record = result.one_or_none()
+         if record is None:
+             return False
+         return record.values() is not None
+
+     # Graph Query & Reasoning
+     @timed
+     def get_node(
+         self, id: str, include_embedding: bool = False, user_name: str | None = None
+     ) -> dict[str, Any] | None:
+         """
+         Retrieve a Memory node by its unique ID.
+
+         Args:
+             id (str): Node ID (Memory.id)
+             include_embedding: with/without embedding
+             user_name (str, optional): User name for filtering in non-multi-db mode
+
+         Returns:
+             dict: Node properties as key-value pairs, or None if not found.
+         """
+         filter_clause = f'n.id = "{id}"'
+         return_fields = self._build_return_fields(include_embedding)
+         gql = f"""
+         MATCH (n@Memory)
+         WHERE {filter_clause}
+         RETURN {return_fields}
+         """
+
+         try:
+             result = self.execute_query(gql)
+             for row in result:
+                 props = {k: v.value for k, v in row.items()}
+                 node = self._parse_node(props)
+                 return node
+
+         except Exception as e:
+             logger.error(
+                 f"[get_node] Failed to retrieve node '{id}': {e}, trace: {traceback.format_exc()}"
+             )
+         return None
+
+     @timed
+     def get_nodes(
+         self,
+         ids: list[str],
+         include_embedding: bool = False,
+         user_name: str | None = None,
+         **kwargs,
+     ) -> list[dict[str, Any]]:
+         """
+         Retrieve the metadata and memory of a list of nodes.
+         Args:
+             ids: List of node identifiers.
+             include_embedding: with/without embedding
+             user_name (str, optional): User name for filtering in non-multi-db mode
+         Returns:
+             list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.
+
+         Notes:
+             - Assumes all provided IDs are valid and exist.
+             - Returns an empty list if the input is empty.
+         """
+         if not ids:
+             return []
+         # Safe formatting of the ID list
+         id_list = ",".join(f'"{_id}"' for _id in ids)
+
+         return_fields = self._build_return_fields(include_embedding)
+         query = f"""
+         MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
+         WHERE n.id IN [{id_list}]
+         RETURN {return_fields}
+         """
+         nodes = []
+         try:
+             results = self.execute_query(query)
+             for row in results:
+                 props = {k: v.value for k, v in row.items()}
+                 nodes.append(self._parse_node(props))
+         except Exception as e:
+             logger.error(
+                 f"[get_nodes] Failed to retrieve nodes {ids}: {e}, trace: {traceback.format_exc()}"
+             )
+         return nodes
+
+     @timed
+     def get_edges(
+         self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None
+     ) -> list[dict[str, str]]:
+         """
+         Get edges connected to a node, with optional type and direction filter.
+
+         Args:
+             id: Node ID to retrieve edges for.
+             type: Relationship type to match, or 'ANY' to match all.
+             direction: 'OUTGOING', 'INCOMING', or 'ANY'.
+             user_name (str, optional): User name for filtering in non-multi-db mode
+
+         Returns:
+             List of edges:
+             [
+                 {"from": "source_id", "to": "target_id", "type": "RELATE"},
+                 ...
+             ]
+         """
+         # Build relationship type filter
+         rel_type = "" if type == "ANY" else f"@{type}"
+         user_name = user_name if user_name else self.config.user_name
+         # Build the match pattern based on direction
+         if direction == "OUTGOING":
+             pattern = f"(a@Memory)-[r{rel_type}]->(b@Memory)"
+             where_clause = f"a.id = '{id}'"
+         elif direction == "INCOMING":
+             pattern = f"(a@Memory)<-[r{rel_type}]-(b@Memory)"
+             where_clause = f"a.id = '{id}'"
+         elif direction == "ANY":
+             pattern = f"(a@Memory)-[r{rel_type}]-(b@Memory)"
+             where_clause = f"a.id = '{id}' OR b.id = '{id}'"
+         else:
+             raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.")
+
+         where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
+
+         query = f"""
+         MATCH {pattern}
+         WHERE {where_clause}
+         RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
+         """
+
+         result = self.execute_query(query)
+         edges = []
+         for record in result:
+             edges.append(
+                 {
+                     "from": record["from_id"].value,
+                     "to": record["to_id"].value,
+                     "type": record["edge_type"].value,
+                 }
+             )
+         return edges
+
+     @timed
+     def get_neighbors_by_tag(
+         self,
+         tags: list[str],
+         exclude_ids: list[str],
+         top_k: int = 5,
+         min_overlap: int = 1,
+         include_embedding: bool = False,
+         user_name: str | None = None,
+     ) -> list[dict[str, Any]]:
+         """
+         Find top-K neighbor nodes with maximum tag overlap.
+
+         Args:
+             tags: The list of tags to match.
+             exclude_ids: Node IDs to exclude (e.g., local cluster).
+             top_k: Max number of neighbors to return.
+             min_overlap: Minimum number of overlapping tags required.
+             include_embedding: with/without embedding
+             user_name (str, optional): User name for filtering in non-multi-db mode
+
+         Returns:
+             List of dicts with node details and overlap count.
+         """
+         if not tags:
+             return []
+         user_name = user_name if user_name else self.config.user_name
+         # NOTE: min_overlap is accepted but not currently enforced by the query below.
+         where_clauses = [
+             'n.status = "activated"',
+             'NOT (n.node_type = "reasoning")',
+             'NOT (n.memory_type = "WorkingMemory")',
+         ]
+         if exclude_ids:
+             where_clauses.append(f"NOT (n.id IN {exclude_ids})")
+
+         where_clauses.append(f'n.user_name = "{user_name}"')
+
+         where_clause = " AND ".join(where_clauses)
+         tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"
+
+         return_fields = self._build_return_fields(include_embedding)
+         query = f"""
+         LET tag_list = {tag_list_literal}
+
+         MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
+         WHERE {where_clause}
+         RETURN {return_fields},
+             size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
+         ORDER BY overlap_count DESC
+         LIMIT {top_k}
+         """
+
+         result = self.execute_query(query)
+         neighbors: list[dict[str, Any]] = []
+         for r in result:
+             props = {k: v.value for k, v in r.items() if k != "overlap_count"}
+             parsed = self._parse_node(props)
+             parsed["overlap_count"] = r["overlap_count"].value
+             neighbors.append(parsed)
+
+         neighbors.sort(key=lambda x: x["overlap_count"], reverse=True)
+         result = []
+         for neighbor in neighbors[:top_k]:
+             neighbor.pop("overlap_count")
+             result.append(neighbor)
+         return result
+
+     @timed
+     def get_children_with_embeddings(
+         self, id: str, user_name: str | None = None
+     ) -> list[dict[str, Any]]:
+         user_name = user_name if user_name else self.config.user_name
+         where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'"
+
+         query = f"""
+         MATCH (p@Memory)-[@PARENT]->(c@Memory)
+         WHERE p.id = "{id}" {where_user}
+         RETURN c.id AS id, c.{self.dim_field} AS {self.dim_field}, c.memory AS memory
+         """
+         result = self.execute_query(query)
+         children = []
+         for row in result:
+             eid = row["id"].value  # STRING
+             emb_v = row[self.dim_field].value  # NVector
+             emb = list(emb_v.values) if emb_v else []
+             mem = row["memory"].value  # STRING
+
+             children.append({"id": eid, "embedding": emb, "memory": mem})
+         return children
+
+     @timed
+     def get_subgraph(
+         self,
+         center_id: str,
+         depth: int = 2,
+         center_status: str = "activated",
+         user_name: str | None = None,
+     ) -> dict[str, Any]:
+         """
+         Retrieve a local subgraph centered at a given node.
+         Args:
+             center_id: The ID of the center node.
+             depth: The hop distance for neighbors.
+             center_status: Required status for center node.
+             user_name (str, optional): User name for filtering in non-multi-db mode
+         Returns:
+             {
+                 "core_node": {...},
+                 "neighbors": [...],
+                 "edges": [...]
+             }
+         """
+         if not 1 <= depth <= 5:
+             raise ValueError("depth must be 1-5")
+
+         user_name = user_name if user_name else self.config.user_name
+
+         gql = f"""
+         MATCH (center@Memory /*+ INDEX(idx_memory_user_name) */)
+         WHERE center.id = '{center_id}'
+         AND center.status = '{center_status}'
+         AND center.user_name = '{user_name}'
+         OPTIONAL MATCH p = (center)-[e]->{{1,{depth}}}(neighbor@Memory)
+         WHERE neighbor.user_name = '{user_name}'
+         RETURN center,
+             collect(DISTINCT neighbor) AS neighbors,
+             collect(EDGES(p)) AS edge_chains
+         """
+
+         result = self.execute_query(gql).one_or_none()
+         if not result or result.size == 0:
+             return {"core_node": None, "neighbors": [], "edges": []}
+
+         core_node_props = result["center"].as_node().get_properties()
+         core_node = self._parse_node(core_node_props)
+         neighbors = []
+         vid_to_id_map = {result["center"].as_node().node_id: core_node["id"]}
+         for n in result["neighbors"].value:
+             n_node = n.as_node()
+             n_props = n_node.get_properties()
+             node_parsed = self._parse_node(n_props)
+             neighbors.append(node_parsed)
+             vid_to_id_map[n_node.node_id] = node_parsed["id"]
+
+         edges = []
+         for chain_group in result["edge_chains"].value:
+             for edge_wr in chain_group.value:
+                 edge = edge_wr.value
+                 edges.append(
+                     {
+                         "type": edge.get_type(),
+                         "source": vid_to_id_map.get(edge.get_src_id()),
+                         "target": vid_to_id_map.get(edge.get_dst_id()),
+                     }
+                 )
+
+         return {"core_node": core_node, "neighbors": neighbors, "edges": edges}
+
+     # Search / recall operations
+     @timed
+     def search_by_embedding(
+         self,
+         vector: list[float],
+         top_k: int = 5,
+         scope: str | None = None,
+         status: str | None = None,
+         threshold: float | None = None,
+         search_filter: dict | None = None,
+         user_name: str | None = None,
+         **kwargs,
+     ) -> list[dict]:
+         """
+         Retrieve node IDs based on vector similarity.
+
+         Args:
+             vector (list[float]): The embedding vector representing query semantics.
+             top_k (int): Number of top similar nodes to retrieve.
+             scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
+             status (str, optional): Node status filter (e.g., 'active', 'archived').
+                 If provided, restricts results to nodes with matching status.
+             threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
+             search_filter (dict, optional): Additional metadata filters for search results.
+                 Keys should match node properties, values are the expected values.
+             user_name (str, optional): User name for filtering in non-multi-db mode
+
+         Returns:
+             list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
+
+         Notes:
+             - This method uses NebulaGraph's native vector search over the stored embedding field.
+             - If scope is provided, it restricts results to nodes with matching memory_type.
+             - If 'status' is provided, only nodes with the matching status will be returned.
+             - If threshold is provided, only results with score >= threshold will be returned.
+             - If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
+             - Typical use case: restrict to 'status = activated' to avoid
+               matching archived or merged nodes.
+         """
+         user_name = user_name if user_name else self.config.user_name
+         vector = _normalize(vector)
+         dim = len(vector)
+         vector_str = ",".join(f"{float(x)}" for x in vector)
+         gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])"
+         where_clauses = [f"n.{self.dim_field} IS NOT NULL"]
+         if scope:
+             where_clauses.append(f'n.memory_type = "{scope}"')
+         if status:
+             where_clauses.append(f'n.status = "{status}"')
+         where_clauses.append(f'n.user_name = "{user_name}"')
+
+         # Add search_filter conditions
+         if search_filter:
+             for key, value in search_filter.items():
+                 if isinstance(value, str):
+                     where_clauses.append(f'n.{key} = "{value}"')
+                 else:
+                     where_clauses.append(f"n.{key} = {value}")
+
+         where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
+
+         gql = f"""
+         let a = {gql_vector}
+         MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
+         {where_clause}
+         ORDER BY inner_product(n.{self.dim_field}, a) DESC
+         LIMIT {top_k}
+         RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score"""
+         try:
+             result = self.execute_query(gql)
+         except Exception as e:
+             logger.error(f"[search_by_embedding] Query failed: {e}")
+             return []
+
+         try:
+             output = []
+             for row in result:
+                 values = row.values()
+                 id_val = values[0].as_string()
+                 score_val = values[1].as_double()
+                 score_val = (score_val + 1) / 2  # map [-1, 1] to [0, 1], aligning with Neo4j's normalized cosine score
+                 if threshold is None or score_val >= threshold:
+                     output.append({"id": id_val, "score": score_val})
+             return output
+         except Exception as e:
+             logger.error(f"[search_by_embedding] Result parse failed: {e}")
+             return []
+
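Since stored embeddings and the query vector are both L2-normalized, `inner_product` here equals cosine similarity in [-1, 1]; the `(score + 1) / 2` step rescales it to [0, 1] so thresholds stay comparable with the Neo4j backend. An illustrative call against a constructed `db`, where `query_embedding` is assumed to come from the configured embedder:

    hits = db.search_by_embedding(
        vector=query_embedding,   # normalized internally before the query
        top_k=5,
        scope="LongTermMemory",
        status="activated",
        threshold=0.6,            # applied after the (score + 1) / 2 rescaling
    )
    # -> e.g. [{"id": "mem-001", "score": 0.83}, ...]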
+     @timed
+     def get_by_metadata(
+         self, filters: list[dict[str, Any]], user_name: str | None = None
+     ) -> list[str]:
+         """
+         TODO: add "AND" vs "OR" logic combination and support nested
+         conditional expressions.
+
+         Retrieve node IDs that match given metadata filters.
+         Supports exact match.
+
+         Args:
+             filters: List of filter dicts like:
+                 [
+                     {"field": "key", "op": "in", "value": ["A", "B"]},
+                     {"field": "confidence", "op": ">=", "value": 80},
+                     {"field": "tags", "op": "contains", "value": ["AI"]},
+                     ...
+                 ]
+             user_name (str, optional): User name for filtering in non-multi-db mode
+
+         Returns:
+             list[str]: Node IDs whose metadata match the filter conditions (AND logic).
+
+         Notes:
+             - Supports structured querying such as tag/category/importance/time filtering.
+             - Can be used for faceted recall or prefiltering before embedding rerank.
+         """
+         where_clauses = []
+         user_name = user_name if user_name else self.config.user_name
+         for _i, f in enumerate(filters):
+             field = f["field"]
+             op = f.get("op", "=")
+             value = f["value"]
+
+             escaped_value = self._format_value(value)
+
+             # Build WHERE clause
+             if op == "=":
+                 where_clauses.append(f"n.{field} = {escaped_value}")
+             elif op == "in":
+                 where_clauses.append(f"n.{field} IN {escaped_value}")
+             elif op == "contains":
+                 where_clauses.append(f"size(filter(n.{field}, t -> t IN {escaped_value})) > 0")
+             elif op == "starts_with":
+                 where_clauses.append(f"n.{field} STARTS WITH {escaped_value}")
+             elif op == "ends_with":
+                 where_clauses.append(f"n.{field} ENDS WITH {escaped_value}")
+             elif op in [">", ">=", "<", "<="]:
+                 where_clauses.append(f"n.{field} {op} {escaped_value}")
+             else:
+                 raise ValueError(f"Unsupported operator: {op}")
+
+         where_clauses.append(f'n.user_name = "{user_name}"')
+
+         where_str = " AND ".join(where_clauses)
+         gql = f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_str} RETURN n.id AS id"
+         ids = []
+         try:
+             result = self.execute_query(gql)
+             ids = [record["id"].value for record in result]
+         except Exception as e:
+             logger.error(f"Failed to get metadata: {e}, gql is {gql}")
+         return ids
+
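All filters are ANDed together, plus the implicit user_name clause. A sketch assuming a `db` instance; note that the current clause builder expands `contains` to `t IN <value>`, so its value should be a list:

    ids = db.get_by_metadata([
        {"field": "key", "op": "in", "value": ["A", "B"]},
        {"field": "confidence", "op": ">=", "value": 80},
        {"field": "tags", "op": "contains", "value": ["AI"]},
    ])
    # -> node IDs satisfying every clause for the current user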
+     @timed
+     def get_grouped_counts(
+         self,
+         group_fields: list[str],
+         where_clause: str = "",
+         params: dict[str, Any] | None = None,
+         user_name: str | None = None,
+     ) -> list[dict[str, Any]]:
+         """
+         Count nodes grouped by any fields.
+
+         Args:
+             group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"]
+             where_clause (str, optional): Extra WHERE condition. E.g.,
+                 "WHERE n.status = 'activated'"
+             params (dict, optional): Parameters for WHERE clause.
+             user_name (str, optional): User name for filtering in non-multi-db mode
+
+         Returns:
+             list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...]
+         """
+         if not group_fields:
+             raise ValueError("group_fields cannot be empty")
+         user_name = user_name if user_name else self.config.user_name
+         # GQL-specific modifications
+         user_clause = f"n.user_name = '{user_name}'"
+         if where_clause:
+             where_clause = where_clause.strip()
+             if where_clause.upper().startswith("WHERE"):
+                 where_clause += f" AND {user_clause}"
+             else:
+                 where_clause = f"WHERE {where_clause} AND {user_clause}"
+         else:
+             where_clause = f"WHERE {user_clause}"
+
+         # Inline parameters if provided
+         if params:
+             for key, value in params.items():
+                 # Handle different value types appropriately
+                 if isinstance(value, str):
+                     value = f"'{value}'"
+                 where_clause = where_clause.replace(f"${key}", str(value))
+
+         return_fields = []
+         group_by_fields = []
+
+         for field in group_fields:
+             alias = field.replace(".", "_")
+             return_fields.append(f"n.{field} AS {alias}")
+             group_by_fields.append(alias)
+         # Full GQL query construction
+         gql = f"""
+         MATCH (n /*+ INDEX(idx_memory_user_name) */)
+         {where_clause}
+         RETURN {", ".join(return_fields)}, COUNT(n) AS count
+         """
+         result = self.execute_query(gql)  # Pure GQL string execution
+
+         output = []
+         for record in result:
+             group_values = {}
+             for i, field in enumerate(group_fields):
+                 value = record.values()[i].as_string()
+                 group_values[field] = value
+             count_value = record["count"].value
+             output.append({**group_values, "count": count_value})
+
+         return output
+
1191
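+    # Usage sketch for get_grouped_counts (hypothetical db instance and values):
+    #
+    #   counts = db.get_grouped_counts(
+    #       group_fields=["memory_type", "status"],
+    #       where_clause="WHERE n.status = $status",
+    #       params={"status": "activated"},
+    #   )
+    #   # -> [{"memory_type": "WorkingMemory", "status": "activated", "count": 10}, ...]
+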
+    @timed
+    def clear(self, user_name: str | None = None) -> None:
+        """
+        Clear all memory nodes belonging to the given user.
+
+        Args:
+            user_name (str, optional): User name for filtering in non-multi-db mode
+        """
+        user_name = user_name if user_name else self.config.user_name
+        try:
+            query = f"MATCH (n@Memory) WHERE n.user_name = '{user_name}' DETACH DELETE n"
+            self.execute_query(query)
+            logger.info(f"Cleared all nodes for user '{user_name}'.")
+        except Exception as e:
+            logger.error(f"[ERROR] Failed to clear database: {e}")
+
+    @timed
+    def export_graph(
+        self, include_embedding: bool = False, user_name: str | None = None, **kwargs
+    ) -> dict[str, Any]:
+        """
+        Export all graph nodes and edges in a structured form.
+
+        Args:
+            include_embedding (bool): Whether to include the large embedding field.
+            user_name (str, optional): User name for filtering in non-multi-db mode
+
+        Returns:
+            {
+                "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
+                "edges": [ { "source": ..., "target": ..., "type": ... }, ... ]
+            }
+        """
+        user_name = user_name if user_name else self.config.user_name
+        node_query = "MATCH (n@Memory)"
+        edge_query = "MATCH (a@Memory)-[r]->(b@Memory)"
+        node_query += f' WHERE n.user_name = "{user_name}"'
+        edge_query += f' WHERE r.user_name = "{user_name}"'
+
+        try:
+            if include_embedding:
+                return_fields = "n"
+            else:
+                return_fields = ",".join(
+                    [
+                        "n.id AS id",
+                        "n.memory AS memory",
+                        "n.user_name AS user_name",
+                        "n.user_id AS user_id",
+                        "n.session_id AS session_id",
+                        "n.status AS status",
+                        "n.key AS key",
+                        "n.confidence AS confidence",
+                        "n.tags AS tags",
+                        "n.created_at AS created_at",
+                        "n.updated_at AS updated_at",
+                        "n.memory_type AS memory_type",
+                        "n.sources AS sources",
+                        "n.source AS source",
+                        "n.node_type AS node_type",
+                        "n.visibility AS visibility",
+                        "n.usage AS usage",
+                        "n.background AS background",
+                    ]
+                )
+
+            full_node_query = f"{node_query} RETURN {return_fields}"
+            node_result = self.execute_query(full_node_query, timeout=20)
+            nodes = []
+            logger.debug(f"[EXPORT GRAPH] node query result: {node_result}")
+            for row in node_result:
+                if include_embedding:
+                    props = row.values()[0].as_node().get_properties()
+                else:
+                    props = {k: v.value for k, v in row.items()}
+                node = self._parse_node(props)
+                nodes.append(node)
+        except Exception as e:
+            raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e
+
+        try:
+            full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) AS edge"
+            edge_result = self.execute_query(full_edge_query, timeout=20)
+            edges = [
+                {
+                    "source": row.values()[0].value,
+                    "target": row.values()[1].value,
+                    "type": row.values()[2].value,
+                }
+                for row in edge_result
+            ]
+        except Exception as e:
+            raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e
+
+        return {"nodes": nodes, "edges": edges}
+
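+    # Shape of the export_graph payload, sketched with hypothetical values:
+    #
+    #   {
+    #       "nodes": [{"id": "m1", "memory": "...", "metadata": {"type": "fact", ...}}],
+    #       "edges": [{"source": "m1", "target": "m2", "type": "PARENT"}],
+    #   }
+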
+    @timed
+    def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None:
+        """
+        Import an entire graph from a serialized dictionary.
+
+        Args:
+            data: A dictionary containing all nodes and edges to be loaded.
+            user_name (str, optional): User name for filtering in non-multi-db mode
+        """
+        user_name = user_name if user_name else self.config.user_name
+        for node in data.get("nodes", []):
+            try:
+                node_id, memory, metadata = _compose_node(node)
+                metadata["user_name"] = user_name
+                metadata = self._prepare_node_metadata(metadata)
+                metadata.update({"id": node_id, "memory": memory})
+                properties = ", ".join(
+                    f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()
+                )
+                node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
+                self.execute_query(node_gql)
+            except Exception as e:
+                logger.error(f"Failed to load node: {node}, error: {e}")
+
+        for edge in data.get("edges", []):
+            try:
+                source_id, target_id = edge["source"], edge["target"]
+                edge_type = edge["type"]
+                props = f'{{user_name: "{user_name}"}}'
+                edge_gql = f'''
+                MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
+                INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
+                '''
+                self.execute_query(edge_gql)
+            except Exception as e:
+                logger.error(f"Failed to load edge: {edge}, error: {e}")
+
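+    # Roundtrip sketch (hypothetical instances and user names), assuming the
+    # node dicts fed to import_graph use the same shape export_graph emits:
+    #
+    #   data = db.export_graph(include_embedding=True, user_name="alice")
+    #   db.import_graph(data, user_name="alice_backup")
+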
+    @timed
+    def get_all_memory_items(
+        self, scope: str, include_embedding: bool = False, user_name: str | None = None
+    ) -> list[dict]:
+        """
+        Retrieve memory items of a specific memory_type.
+
+        Args:
+            scope (str): Must be one of 'WorkingMemory', 'LongTermMemory',
+                'UserMemory', or 'OuterMemory'.
+            include_embedding (bool): Whether to include the embedding field.
+            user_name (str, optional): User name for filtering in non-multi-db mode
+
+        Returns:
+            list[dict]: Up to 100 memory items under this scope.
+        """
+        user_name = user_name if user_name else self.config.user_name
+        if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}:
+            raise ValueError(f"Unsupported memory type scope: {scope}")
+
+        where_clause = f"WHERE n.memory_type = '{scope}'"
+        where_clause += f" AND n.user_name = '{user_name}'"
+
+        return_fields = self._build_return_fields(include_embedding)
+
+        query = f"""
+        MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
+        {where_clause}
+        RETURN {return_fields}
+        LIMIT 100
+        """
+        nodes = []
+        try:
+            results = self.execute_query(query)
+            for row in results:
+                props = {k: v.value for k, v in row.items()}
+                nodes.append(self._parse_node(props))
+        except Exception as e:
+            logger.error(f"Failed to get memories: {e}")
+        return nodes
+
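+    # Usage sketch (hypothetical instance): fetch working memories without embeddings.
+    #
+    #   items = db.get_all_memory_items("WorkingMemory")
+    #   # -> [{"id": ..., "memory": ..., "metadata": {...}}, ...] (at most 100 rows)
+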
+    @timed
+    def get_structure_optimization_candidates(
+        self, scope: str, include_embedding: bool = False, user_name: str | None = None
+    ) -> list[dict]:
+        """
+        Find nodes that are likely candidates for structure optimization.
+        The current query returns isolated nodes: activated nodes under the
+        given scope with no PARENT edge in either direction. Other cases
+        (empty background, single-child parents) are not yet covered.
+        """
+        user_name = user_name if user_name else self.config.user_name
+        where_clause = f'''
+        n.memory_type = "{scope}"
+        AND n.status = "activated"
+        '''
+        where_clause += f' AND n.user_name = "{user_name}"'
+
+        return_fields = self._build_return_fields(include_embedding)
+        # Avoid duplicating the vector field when include_embedding already added it
+        if f"n.{self.dim_field}" not in return_fields:
+            return_fields += f", n.{self.dim_field} AS {self.dim_field}"
+
+        query = f"""
+        MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
+        WHERE {where_clause}
+        OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
+        OPTIONAL MATCH (p@Memory)-[@PARENT]->(n)
+        WHERE c IS NULL AND p IS NULL
+        RETURN {return_fields}
+        """
+
+        candidates = []
+        node_ids = set()
+        try:
+            results = self.execute_query(query)
+            for row in results:
+                props = {k: v.value for k, v in row.items()}
+                node = self._parse_node(props)
+                node_id = node["id"]
+                if node_id not in node_ids:
+                    candidates.append(node)
+                    node_ids.add(node_id)
+        except Exception as e:
+            logger.error(
+                f"Failed to get structure optimization candidates: {e}, "
+                f"traceback: {traceback.format_exc()}"
+            )
+        return candidates
+
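+    # The candidate query above expands to roughly the following GQL
+    # (scope and user values are hypothetical):
+    #
+    #   MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
+    #   WHERE n.memory_type = "LongTermMemory" AND n.status = "activated"
+    #         AND n.user_name = "alice"
+    #   OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
+    #   OPTIONAL MATCH (p@Memory)-[@PARENT]->(n)
+    #   WHERE c IS NULL AND p IS NULL
+    #   RETURN n.id AS id, ...
+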
+    @timed
+    def drop_database(self) -> None:
+        """
+        Permanently delete the entire database this instance is using.
+        WARNING: This operation is destructive and cannot be undone.
+        """
+        raise ValueError(
+            f"Refusing to drop protected database: `{self.db_name}` in "
+            "Shared Database Multi-Tenant mode"
+        )
+
+    @timed
+    def detect_conflicts(self) -> list[tuple[str, str]]:
+        """
+        Detect conflicting nodes based on logical or semantic inconsistency.
+
+        Returns:
+            A list of (node_id1, node_id2) tuples that conflict.
+        """
+        raise NotImplementedError
+
+    # Structure Maintenance
+    @timed
+    def deduplicate_nodes(self) -> None:
+        """
+        Deduplicate redundant or semantically similar nodes.
+        This typically involves identifying nodes with identical or near-identical memory.
+        """
+        raise NotImplementedError
+
+    @timed
+    def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
+        """
+        Get the ordered context chain starting from a node, following a relationship type.
+
+        Args:
+            id: Starting node ID.
+            type: Relationship type to follow (e.g., 'FOLLOWS').
+
+        Returns:
+            List of ordered node IDs in the chain.
+        """
+        raise NotImplementedError
+
+    @timed
+    def get_neighbors(
+        self, id: str, type: str, direction: Literal["in", "out", "both"] = "out"
+    ) -> list[str]:
+        """
+        Get connected node IDs in a specific direction and relationship type.
+
+        Args:
+            id: Source node ID.
+            type: Relationship type.
+            direction: Edge direction to follow ('out', 'in', or 'both').
+
+        Returns:
+            List of neighboring node IDs.
+        """
+        raise NotImplementedError
+
+    @timed
+    def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]:
+        """
+        Get the path of nodes from source to target within a limited depth.
+
+        Args:
+            source_id: Starting node ID.
+            target_id: Target node ID.
+            max_depth: Maximum path length to traverse.
+
+        Returns:
+            Ordered list of node IDs along the path.
+        """
+        raise NotImplementedError
+
+    @timed
+    def merge_nodes(self, id1: str, id2: str) -> str:
+        """
+        Merge two similar or duplicate nodes into one.
+
+        Args:
+            id1: First node ID.
+            id2: Second node ID.
+
+        Returns:
+            ID of the resulting merged node.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def _ensure_space_exists(cls, tmp_client, cfg):
+        """Lightweight check to ensure the target graph (space) exists."""
+        db_name = getattr(cfg, "space", None)
+        if not db_name:
+            logger.warning("[NebulaGraphDBSync] No `space` specified in cfg.")
+            return
+
+        try:
+            res = tmp_client.execute("SHOW GRAPHS")
+            existing = {row.values()[0].as_string() for row in res}
+            if db_name not in existing:
+                tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type")
+                logger.info(f"✅ Graph `{db_name}` created before session binding.")
+            else:
+                logger.debug(f"Graph `{db_name}` already exists.")
+        except Exception:
+            logger.exception("[NebulaGraphDBSync] Failed to ensure space exists")
+
+    @timed
+    def _ensure_database_exists(self):
+        graph_type_name = "MemOSBgeM3Type"
+
+        check_type_query = "SHOW GRAPH TYPES"
+        result = self.execute_query(check_type_query, auto_set_db=False)
+
+        type_exists = any(row["graph_type"].as_string() == graph_type_name for row in result)
+
+        if not type_exists:
+            create_type = f"""
+            CREATE GRAPH TYPE IF NOT EXISTS {graph_type_name} AS {{
+                NODE Memory (:MemoryTag {{
+                    id STRING,
+                    memory STRING,
+                    user_name STRING,
+                    user_id STRING,
+                    session_id STRING,
+                    status STRING,
+                    key STRING,
+                    confidence FLOAT,
+                    tags LIST<STRING>,
+                    created_at STRING,
+                    updated_at STRING,
+                    memory_type STRING,
+                    sources LIST<STRING>,
+                    source STRING,
+                    node_type STRING,
+                    visibility STRING,
+                    usage LIST<STRING>,
+                    background STRING,
+                    {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT>,
+                    PRIMARY KEY(id)
+                }}),
+                EDGE RELATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE PARENT (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE AGGREGATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE MERGED_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE INFERS (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE FOLLOWS (Memory) -[{{user_name STRING}}]-> (Memory)
+            }}
+            """
+            self.execute_query(create_type, auto_set_db=False)
+        else:
+            describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name}"
+            desc_result = self.execute_query(describe_query, auto_set_db=False)
+
+            memory_fields = []
+            for row in desc_result:
+                field_name = row.values()[0].as_string()
+                memory_fields.append(field_name)
+
+            if self.dim_field not in memory_fields:
+                alter_query = f"""
+                ALTER GRAPH TYPE {graph_type_name} {{
+                    ALTER NODE TYPE Memory ADD PROPERTIES {{ {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT> }}
+                }}
+                """
+                self.execute_query(alter_query, auto_set_db=False)
+                logger.info(f"✅ Added vector field {self.dim_field} to {graph_type_name}")
+            else:
+                logger.info(f"✅ Graph type {graph_type_name} already includes {self.dim_field}")
+
+        create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
+        try:
+            self.execute_query(create_graph, auto_set_db=False)
+            logger.info(f"✅ Graph `{self.db_name}` is now the working graph.")
+        except Exception as e:
+            logger.error(f"❌ Failed to create graph: {e} trace: {traceback.format_exc()}")
+
+    @timed
+    def _create_vector_index(
+        self,
+        label: str = "Memory",
+        vector_property: str = "embedding",
+        dimensions: int = 3072,
+        index_name: str = "memory_vector_index",
+    ) -> None:
+        """
+        Create a vector index for the specified property on the label.
+        """
+        if str(dimensions) == str(self.default_memory_dimension):
+            index_name = f"idx_{vector_property}"
+            vector_name = vector_property
+        else:
+            index_name = f"idx_{vector_property}_{dimensions}"
+            vector_name = f"{vector_property}_{dimensions}"
+
+        create_vector_index = f"""
+        CREATE VECTOR INDEX IF NOT EXISTS {index_name}
+        ON NODE {label}::{vector_name}
+        OPTIONS {{
+            DIM: {dimensions},
+            METRIC: IP,
+            TYPE: IVF,
+            NLIST: 100,
+            TRAINSIZE: 1000
+        }}
+        FOR `{self.db_name}`
+        """
+        self.execute_query(create_vector_index)
+        logger.info(
+            f"✅ Ensured {label}::{vector_property} vector index {index_name} "
+            f"exists (DIM={dimensions})"
+        )
+
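+    # With the default dimension, the statement built above is roughly as
+    # follows (`memos_graph` is a hypothetical database name):
+    #
+    #   CREATE VECTOR INDEX IF NOT EXISTS idx_embedding
+    #   ON NODE Memory::embedding
+    #   OPTIONS { DIM: 3072, METRIC: IP, TYPE: IVF, NLIST: 100, TRAINSIZE: 1000 }
+    #   FOR `memos_graph`
+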
+    @timed
+    def _create_basic_property_indexes(self) -> None:
+        """
+        Create standard B-tree indexes on the status, memory_type, created_at,
+        and updated_at fields, plus user_name when using Shared Database
+        Multi-Tenant mode.
+        """
+        fields = [
+            "status",
+            "memory_type",
+            "created_at",
+            "updated_at",
+            "user_name",
+        ]
+
+        for field in fields:
+            index_name = f"idx_memory_{field}"
+            gql = f"""
+            CREATE INDEX IF NOT EXISTS {index_name} ON NODE Memory({field})
+            FOR `{self.db_name}`
+            """
+            try:
+                self.execute_query(gql)
+                logger.info(f"✅ Created index: {index_name} on field {field}")
+            except Exception as e:
+                logger.error(
+                    f"❌ Failed to create index {index_name}: {e}, trace: {traceback.format_exc()}"
+                )
+
+    @timed
+    def _index_exists(self, index_name: str) -> bool:
+        """
+        Check if a vector index with the given name exists in NebulaGraph.
+
+        Args:
+            index_name (str): The name of the index to check.
+
+        Returns:
+            bool: True if the index exists, False otherwise.
+        """
+        query = "SHOW VECTOR INDEXES"
+        try:
+            result = self.execute_query(query)
+            return any(row.values()[0].as_string() == index_name for row in result)
+        except Exception as e:
+            logger.error(f"[Nebula] Failed to check index existence: {e}")
+            return False
+
+    @timed
+    def _parse_value(self, value: Any) -> Any:
+        """Convert a Nebula ValueWrapper into a native Python type."""
+        from nebulagraph_python.value_wrapper import ValueWrapper
+
+        if value is None or (hasattr(value, "is_null") and value.is_null()):
+            return None
+        try:
+            prim = value.cast_primitive() if isinstance(value, ValueWrapper) else value
+        except Exception as e:
+            logger.warning(f"Failed to decode Nebula ValueWrapper: {e}")
+            prim = value.cast() if isinstance(value, ValueWrapper) else value
+
+        if isinstance(prim, ValueWrapper):
+            return self._parse_value(prim)
+        if isinstance(prim, list):
+            return [self._parse_value(v) for v in prim]
+        if type(prim).__name__ == "NVector":
+            return list(prim.values)
+
+        return prim  # already a Python primitive
+
+    def _parse_node(self, props: dict[str, Any]) -> dict[str, Any]:
+        parsed = {k: self._parse_value(v) for k, v in props.items()}
+
+        for tf in ("created_at", "updated_at"):
+            if tf in parsed and parsed[tf] is not None:
+                parsed[tf] = _normalize_datetime(parsed[tf])
+
+        node_id = parsed.pop("id")
+        memory = parsed.pop("memory", "")
+        parsed.pop("user_name", None)
+        metadata = parsed
+        metadata["type"] = metadata.pop("node_type")
+
+        # Surface the dimension-specific vector field under the generic name
+        if self.dim_field in metadata:
+            metadata["embedding"] = metadata.pop(self.dim_field)
+
+        return {"id": node_id, "memory": memory, "metadata": metadata}
+
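+    # _parse_node turns flat query properties into the node shape used across
+    # the adapter; a sketch with hypothetical values:
+    #
+    #   {"id": "m1", "memory": "...", "node_type": "fact", "embedding": [...]}
+    #   -> {"id": "m1", "memory": "...",
+    #       "metadata": {"type": "fact", "embedding": [...], ...}}
+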
+    @timed
+    def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
+        """
+        Ensure metadata has proper datetime fields and normalized types.
+
+        - Fill `created_at` and `updated_at` if missing (in ISO 8601 format).
+        - Move a list-valued embedding into the dimension-specific vector
+          field, normalized as floats.
+        """
+        now = datetime.utcnow().isoformat()
+        metadata["node_type"] = metadata.pop("type")
+
+        # Fill timestamps if missing
+        metadata.setdefault("created_at", now)
+        metadata.setdefault("updated_at", now)
+
+        # Normalize embedding type
+        embedding = metadata.get("embedding")
+        if embedding and isinstance(embedding, list):
+            metadata.pop("embedding")
+            metadata[self.dim_field] = _normalize([float(x) for x in embedding])
+
+        return metadata
+
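+    # _prepare_node_metadata is the write-side inverse of _parse_node; sketched
+    # with hypothetical values:
+    #
+    #   {"type": "fact", "embedding": [0.1, 0.2]}
+    #   -> {"node_type": "fact", "created_at": "...", "updated_at": "...",
+    #       <self.dim_field>: <normalized float vector>}
+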
+    @timed
+    def _format_value(self, val: Any, key: str = "") -> str:
+        from nebulagraph_python.py_data_types import NVector
+
+        # None
+        if val is None:
+            return "NULL"
+        # bool (checked before int, since bool is a subclass of int)
+        if isinstance(val, bool):
+            return "true" if val else "false"
+        # str
+        if isinstance(val, str):
+            return f'"{_escape_str(val)}"'
+        # num
+        if isinstance(val, int | float):
+            return str(val)
+        # time
+        if isinstance(val, datetime):
+            return f'datetime("{val.isoformat()}")'
+        # list
+        if isinstance(val, list):
+            if key == self.dim_field:
+                dim = len(val)
+                joined = ",".join(str(float(x)) for x in val)
+                return f"VECTOR<{dim}, FLOAT>([{joined}])"
+            return f"[{', '.join(self._format_value(v) for v in val)}]"
+        # NVector
+        if isinstance(val, NVector):
+            if key == self.dim_field:
+                dim = len(val)
+                joined = ",".join(str(float(x)) for x in val)
+                return f"VECTOR<{dim}, FLOAT>([{joined}])"
+            # Fall through to the string fallback below
+            logger.warning(f"NVector outside the vector field {self.dim_field}")
+        # dict
+        if isinstance(val, dict):
+            j = json.dumps(val, ensure_ascii=False, separators=(",", ":"))
+            return f'"{_escape_str(j)}"'
+        # fallback: serialize anything else as an escaped string
+        return f'"{_escape_str(str(val))}"'
+
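+    # _format_value maps Python values to GQL literals; a sketch with
+    # hypothetical inputs (exact quoting depends on _escape_str):
+    #
+    #   None                         -> NULL
+    #   True                         -> true
+    #   "hi"                         -> "hi"
+    #   datetime(2024, 1, 1)         -> datetime("2024-01-01T00:00:00")
+    #   [0.1, 0.2] (embedding key)   -> VECTOR<2, FLOAT>([0.1,0.2])
+    #   {"k": 1}                     -> JSON-serialized, then quoted as a string
+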
+    @timed
+    def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]:
+        """
+        Filter and validate a metadata dictionary against the Memory node schema.
+
+        - Removes keys not in the schema.
+        - Logs any expected fields that are missing.
+        """
+        dim_fields = {self.dim_field}
+
+        allowed_fields = self.common_fields | dim_fields
+
+        missing_fields = allowed_fields - metadata.keys()
+        if missing_fields:
+            logger.info(f"Metadata missing required fields: {sorted(missing_fields)}")
+
+        filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields}
+
+        return filtered_metadata
+
+    def _build_return_fields(self, include_embedding: bool = False) -> str:
+        # Note: iterating a set means the RETURN field order is unspecified
+        fields = set(self.base_fields)
+        if include_embedding:
+            fields.add(self.dim_field)
+        return ", ".join(f"n.{f} AS {f}" for f in fields)