MemoryOS 1.0.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of MemoryOS might be problematic. Click here for more details.

Files changed (94)
  1. {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/METADATA +8 -2
  2. {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/RECORD +92 -69
  3. {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/WHEEL +1 -1
  4. memos/__init__.py +1 -1
  5. memos/api/client.py +109 -0
  6. memos/api/config.py +35 -8
  7. memos/api/context/dependencies.py +15 -66
  8. memos/api/middleware/request_context.py +63 -0
  9. memos/api/product_api.py +5 -2
  10. memos/api/product_models.py +107 -16
  11. memos/api/routers/product_router.py +62 -19
  12. memos/api/start_api.py +13 -0
  13. memos/configs/graph_db.py +4 -0
  14. memos/configs/mem_scheduler.py +38 -3
  15. memos/configs/memory.py +13 -0
  16. memos/configs/reranker.py +18 -0
  17. memos/context/context.py +255 -0
  18. memos/embedders/factory.py +2 -0
  19. memos/graph_dbs/base.py +4 -2
  20. memos/graph_dbs/nebular.py +368 -223
  21. memos/graph_dbs/neo4j.py +49 -13
  22. memos/graph_dbs/neo4j_community.py +13 -3
  23. memos/llms/factory.py +2 -0
  24. memos/llms/openai.py +74 -2
  25. memos/llms/vllm.py +2 -0
  26. memos/log.py +128 -4
  27. memos/mem_cube/general.py +3 -1
  28. memos/mem_os/core.py +89 -23
  29. memos/mem_os/main.py +3 -6
  30. memos/mem_os/product.py +418 -154
  31. memos/mem_os/utils/reference_utils.py +20 -0
  32. memos/mem_reader/factory.py +2 -0
  33. memos/mem_reader/simple_struct.py +204 -82
  34. memos/mem_scheduler/analyzer/__init__.py +0 -0
  35. memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +569 -0
  36. memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
  37. memos/mem_scheduler/base_scheduler.py +126 -56
  38. memos/mem_scheduler/general_modules/dispatcher.py +2 -2
  39. memos/mem_scheduler/general_modules/misc.py +99 -1
  40. memos/mem_scheduler/general_modules/scheduler_logger.py +17 -11
  41. memos/mem_scheduler/general_scheduler.py +40 -88
  42. memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
  43. memos/mem_scheduler/memory_manage_modules/memory_filter.py +308 -0
  44. memos/mem_scheduler/{general_modules → memory_manage_modules}/retriever.py +34 -7
  45. memos/mem_scheduler/monitors/dispatcher_monitor.py +9 -8
  46. memos/mem_scheduler/monitors/general_monitor.py +119 -39
  47. memos/mem_scheduler/optimized_scheduler.py +124 -0
  48. memos/mem_scheduler/orm_modules/__init__.py +0 -0
  49. memos/mem_scheduler/orm_modules/base_model.py +635 -0
  50. memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
  51. memos/mem_scheduler/scheduler_factory.py +2 -0
  52. memos/mem_scheduler/schemas/monitor_schemas.py +96 -29
  53. memos/mem_scheduler/utils/config_utils.py +100 -0
  54. memos/mem_scheduler/utils/db_utils.py +33 -0
  55. memos/mem_scheduler/utils/filter_utils.py +1 -1
  56. memos/mem_scheduler/webservice_modules/__init__.py +0 -0
  57. memos/mem_user/mysql_user_manager.py +4 -2
  58. memos/memories/activation/kv.py +2 -1
  59. memos/memories/textual/item.py +96 -17
  60. memos/memories/textual/naive.py +1 -1
  61. memos/memories/textual/tree.py +57 -3
  62. memos/memories/textual/tree_text_memory/organize/handler.py +4 -2
  63. memos/memories/textual/tree_text_memory/organize/manager.py +28 -14
  64. memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +1 -2
  65. memos/memories/textual/tree_text_memory/organize/reorganizer.py +75 -23
  66. memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +10 -6
  67. memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -2
  68. memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +2 -0
  69. memos/memories/textual/tree_text_memory/retrieve/recall.py +119 -21
  70. memos/memories/textual/tree_text_memory/retrieve/searcher.py +172 -44
  71. memos/memories/textual/tree_text_memory/retrieve/utils.py +6 -4
  72. memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +5 -4
  73. memos/memos_tools/notification_utils.py +46 -0
  74. memos/memos_tools/singleton.py +174 -0
  75. memos/memos_tools/thread_safe_dict.py +22 -0
  76. memos/memos_tools/thread_safe_dict_segment.py +382 -0
  77. memos/parsers/factory.py +2 -0
  78. memos/reranker/__init__.py +4 -0
  79. memos/reranker/base.py +24 -0
  80. memos/reranker/concat.py +59 -0
  81. memos/reranker/cosine_local.py +96 -0
  82. memos/reranker/factory.py +48 -0
  83. memos/reranker/http_bge.py +312 -0
  84. memos/reranker/noop.py +16 -0
  85. memos/templates/mem_reader_prompts.py +289 -40
  86. memos/templates/mem_scheduler_prompts.py +242 -0
  87. memos/templates/mos_prompts.py +133 -60
  88. memos/types.py +4 -1
  89. memos/api/context/context.py +0 -147
  90. memos/mem_scheduler/mos_for_test_scheduler.py +0 -146
  91. {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/entry_points.txt +0 -0
  92. {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info/licenses}/LICENSE +0 -0
  93. /memos/mem_scheduler/{general_modules → webservice_modules}/rabbitmq_service.py +0 -0
  94. /memos/mem_scheduler/{general_modules → webservice_modules}/redis_service.py +0 -0
@@ -1,10 +1,10 @@
1
+ import json
1
2
  import traceback
2
3
 
3
4
  from contextlib import suppress
4
5
  from datetime import datetime
5
- from queue import Empty, Queue
6
6
  from threading import Lock
7
- from typing import Any, Literal
7
+ from typing import TYPE_CHECKING, Any, ClassVar, Literal
8
8
 
9
9
  import numpy as np
10
10
 
@@ -15,9 +15,28 @@ from memos.log import get_logger
15
15
  from memos.utils import timed
16
16
 
17
17
 
18
+ if TYPE_CHECKING:
19
+ from nebulagraph_python import (
20
+ NebulaClient,
21
+ )
22
+
23
+
18
24
  logger = get_logger(__name__)
19
25
 
20
26
 
27
+ _TRANSIENT_ERR_KEYS = (
28
+ "Session not found",
29
+ "Connection not established",
30
+ "timeout",
31
+ "deadline exceeded",
32
+ "Broken pipe",
33
+ "EOFError",
34
+ "socket closed",
35
+ "connection reset",
36
+ "connection refused",
37
+ )
38
+
39
+
21
40
  @timed
22
41
  def _normalize(vec: list[float]) -> list[float]:
23
42
  v = np.asarray(vec, dtype=np.float32)
@@ -35,7 +54,28 @@ def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
35
54
 
36
55
  @timed
37
56
  def _escape_str(value: str) -> str:
38
- return value.replace('"', '\\"')
57
+ out = []
58
+ for ch in value:
59
+ code = ord(ch)
60
+ if ch == "\\":
61
+ out.append("\\\\")
62
+ elif ch == '"':
63
+ out.append('\\"')
64
+ elif ch == "\n":
65
+ out.append("\\n")
66
+ elif ch == "\r":
67
+ out.append("\\r")
68
+ elif ch == "\t":
69
+ out.append("\\t")
70
+ elif ch == "\b":
71
+ out.append("\\b")
72
+ elif ch == "\f":
73
+ out.append("\\f")
74
+ elif code < 0x20 or code in (0x2028, 0x2029):
75
+ out.append(f"\\u{code:04x}")
76
+ else:
77
+ out.append(ch)
78
+ return "".join(out)
39
79
 
40
80
 
41
81
  @timed
@@ -61,145 +101,202 @@ def _normalize_datetime(val):
61
101
  return str(val)
62
102
 
63
103
 
64
- class SessionPoolError(Exception):
65
- pass
66
-
67
-
68
- class SessionPool:
69
- @require_python_package(
70
- import_name="nebulagraph_python",
71
- install_command="pip install ... @Tianxing",
72
- install_link=".....",
73
- )
74
- def __init__(
75
- self,
76
- hosts: list[str],
77
- user: str,
78
- password: str,
79
- minsize: int = 1,
80
- maxsize: int = 10000,
81
- ):
82
- self.hosts = hosts
83
- self.user = user
84
- self.password = password
85
- self.minsize = minsize
86
- self.maxsize = maxsize
87
- self.pool = Queue(maxsize)
88
- self.lock = Lock()
89
-
90
- self.clients = []
91
-
92
- for _ in range(minsize):
93
- self._create_and_add_client()
94
-
95
- @timed
96
- def _create_and_add_client(self):
97
- from nebulagraph_python import NebulaClient
98
-
99
- client = NebulaClient(self.hosts, self.user, self.password)
100
- self.pool.put(client)
101
- self.clients.append(client)
102
-
103
- @timed
104
- def get_client(self, timeout: float = 5.0):
105
- try:
106
- return self.pool.get(timeout=timeout)
107
- except Empty:
108
- with self.lock:
109
- if len(self.clients) < self.maxsize:
110
- from nebulagraph_python import NebulaClient
111
-
112
- client = NebulaClient(self.hosts, self.user, self.password)
113
- self.clients.append(client)
114
- return client
115
- raise RuntimeError("NebulaClientPool exhausted") from None
116
-
117
- @timed
118
- def return_client(self, client):
119
- try:
120
- client.execute("YIELD 1")
121
- self.pool.put(client)
122
- except Exception:
123
- logger.info("[Pool] Client dead, replacing...")
124
- self.replace_client(client)
125
-
126
- @timed
127
- def close(self):
128
- for client in self.clients:
129
- with suppress(Exception):
130
- client.close()
131
- self.clients.clear()
104
+ class NebulaGraphDB(BaseGraphDB):
105
+ """
106
+ NebulaGraph-based implementation of a graph memory store.
107
+ """
132
108
 
133
- @timed
134
- def get(self):
135
- """
136
- Context manager: with pool.get() as client:
137
- """
109
+ # ====== shared pool cache & refcount ======
110
+ # These are process-local; in a multi-process model each process will
111
+ # have its own cache.
112
+ _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {}
113
+ _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {}
114
+ _CLIENT_LOCK: ClassVar[Lock] = Lock()
115
+ _CLIENT_INIT_DONE: ClassVar[set[str]] = set()
116
+
117
+ @staticmethod
118
+ def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]:
119
+ hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None)
120
+ if isinstance(hosts, str):
121
+ return [hosts]
122
+ return list(hosts or [])
123
+
124
+ @staticmethod
125
+ def _make_client_key(cfg: NebulaGraphDBConfig) -> str:
126
+ hosts = NebulaGraphDB._get_hosts_from_cfg(cfg)
127
+ return "|".join(
128
+ [
129
+ "nebula-sync",
130
+ ",".join(hosts),
131
+ str(getattr(cfg, "user", "")),
132
+ str(getattr(cfg, "use_multi_db", False)),
133
+ str(getattr(cfg, "space", "")),
134
+ ]
135
+ )
138
136
 
139
- class _ClientContext:
140
- def __init__(self, outer):
141
- self.outer = outer
142
- self.client = None
137
+ @classmethod
138
+ def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB":
139
+ tmp = object.__new__(NebulaGraphDB)
140
+ tmp.config = cfg
141
+ tmp.db_name = cfg.space
142
+ tmp.user_name = getattr(cfg, "user_name", None)
143
+ tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072)
144
+ tmp.default_memory_dimension = 3072
145
+ tmp.common_fields = {
146
+ "id",
147
+ "memory",
148
+ "user_name",
149
+ "user_id",
150
+ "session_id",
151
+ "status",
152
+ "key",
153
+ "confidence",
154
+ "tags",
155
+ "created_at",
156
+ "updated_at",
157
+ "memory_type",
158
+ "sources",
159
+ "source",
160
+ "node_type",
161
+ "visibility",
162
+ "usage",
163
+ "background",
164
+ }
165
+ tmp.base_fields = set(tmp.common_fields) - {"usage"}
166
+ tmp.heavy_fields = {"usage"}
167
+ tmp.dim_field = (
168
+ f"embedding_{tmp.embedding_dimension}"
169
+ if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension)
170
+ else "embedding"
171
+ )
172
+ tmp.system_db_name = "system" if getattr(cfg, "use_multi_db", False) else cfg.space
173
+ tmp._client = client
174
+ tmp._owns_client = False
175
+ return tmp
176
+
177
+ @classmethod
178
+ def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]:
179
+ from nebulagraph_python import (
180
+ ConnectionConfig,
181
+ NebulaClient,
182
+ SessionConfig,
183
+ SessionPoolConfig,
184
+ )
143
185
 
144
- def __enter__(self):
145
- self.client = self.outer.get_client()
146
- return self.client
186
+ key = cls._make_client_key(cfg)
187
+ with cls._CLIENT_LOCK:
188
+ client = cls._CLIENT_CACHE.get(key)
189
+ if client is None:
190
+ # Connection setting
191
+ conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None)
192
+ if conn_conf is None:
193
+ conn_conf = ConnectionConfig.from_defults(
194
+ cls._get_hosts_from_cfg(cfg),
195
+ getattr(cfg, "ssl_param", None),
196
+ )
197
+
198
+ sess_conf = SessionConfig(graph=getattr(cfg, "space", None))
199
+ pool_conf = SessionPoolConfig(
200
+ size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000
201
+ )
147
202
 
148
- def __exit__(self, exc_type, exc_val, exc_tb):
149
- if self.client:
150
- self.outer.return_client(self.client)
203
+ client = NebulaClient(
204
+ hosts=conn_conf.hosts,
205
+ username=cfg.user,
206
+ password=cfg.password,
207
+ conn_config=conn_conf,
208
+ session_config=sess_conf,
209
+ session_pool_config=pool_conf,
210
+ )
211
+ cls._CLIENT_CACHE[key] = client
212
+ cls._CLIENT_REFCOUNT[key] = 0
213
+ logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}")
151
214
 
152
- return _ClientContext(self)
215
+ cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1
153
216
 
154
- @timed
155
- def reset_pool(self):
156
- """⚠️ Emergency reset: Close all clients and clear the pool."""
157
- logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
158
- with self.lock:
159
- for client in self.clients:
217
+ if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
160
218
  try:
161
- client.close()
162
- except Exception:
163
- logger.error("Fail to close!!!")
164
- self.clients.clear()
165
- while not self.pool.empty():
219
+ pass
220
+ finally:
221
+ pass
222
+
223
+ if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
224
+ with cls._CLIENT_LOCK:
225
+ if key not in cls._CLIENT_INIT_DONE:
226
+ admin = cls._bootstrap_admin(cfg, client)
227
+ try:
228
+ admin._ensure_database_exists()
229
+ admin._create_basic_property_indexes()
230
+ admin._create_vector_index(
231
+ dimensions=int(
232
+ admin.embedding_dimension or admin.default_memory_dimension
233
+ ),
234
+ )
235
+ cls._CLIENT_INIT_DONE.add(key)
236
+ logger.info("[NebulaGraphDBSync] One-time init done")
237
+ except Exception:
238
+ logger.exception("[NebulaGraphDBSync] One-time init failed")
239
+
240
+ return key, client
241
+
242
+ def _refresh_client(self):
243
+ """
244
+ refresh NebulaClient:
245
+ """
246
+ old_key = getattr(self, "_client_key", None)
247
+ if not old_key:
248
+ return
249
+
250
+ cls = self.__class__
251
+ with cls._CLIENT_LOCK:
252
+ try:
253
+ if old_key in cls._CLIENT_CACHE:
254
+ try:
255
+ cls._CLIENT_CACHE[old_key].close()
256
+ except Exception as e:
257
+ logger.warning(f"[refresh_client] close old client error: {e}")
258
+ finally:
259
+ cls._CLIENT_CACHE.pop(old_key, None)
260
+ finally:
261
+ cls._CLIENT_REFCOUNT[old_key] = 0
262
+
263
+ new_key, new_client = cls._get_or_create_shared_client(self.config)
264
+ self._client_key = new_key
265
+ self._client = new_client
266
+ logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}")
267
+
268
+ @classmethod
269
+ def _release_shared_client(cls, key: str):
270
+ with cls._CLIENT_LOCK:
271
+ if key not in cls._CLIENT_CACHE:
272
+ return
273
+ cls._CLIENT_REFCOUNT[key] = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1)
274
+ if cls._CLIENT_REFCOUNT[key] == 0:
166
275
  try:
167
- self.pool.get_nowait()
168
- except Empty:
169
- break
170
- for _ in range(self.minsize):
171
- self._create_and_add_client()
172
- logger.info("[Pool] Pool has been reset successfully.")
173
-
174
- @timed
175
- def replace_client(self, client):
176
- try:
177
- client.close()
178
- except Exception:
179
- logger.error("Fail to close client")
180
-
181
- if client in self.clients:
182
- self.clients.remove(client)
183
-
184
- from nebulagraph_python import NebulaClient
185
-
186
- new_client = NebulaClient(self.hosts, self.user, self.password)
187
- self.clients.append(new_client)
188
-
189
- self.pool.put(new_client)
190
-
191
- logger.info("[Pool] Replaced dead client with a new one.")
192
- return new_client
193
-
194
-
195
- class NebulaGraphDB(BaseGraphDB):
196
- """
197
- NebulaGraph-based implementation of a graph memory store.
198
- """
276
+ cls._CLIENT_CACHE[key].close()
277
+ except Exception as e:
278
+ logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}")
279
+ finally:
280
+ cls._CLIENT_CACHE.pop(key, None)
281
+ cls._CLIENT_REFCOUNT.pop(key, None)
282
+ logger.info(f"[NebulaGraphDBSync] Closed & removed client key={key}")
283
+
284
+ @classmethod
285
+ def close_all_shared_clients(cls):
286
+ with cls._CLIENT_LOCK:
287
+ for key, client in list(cls._CLIENT_CACHE.items()):
288
+ try:
289
+ client.close()
290
+ except Exception as e:
291
+ logger.warning(f"[NebulaGraphDBSync] Error closing client {key}: {e}")
292
+ finally:
293
+ logger.info(f"[NebulaGraphDBSync] Closed client key={key}")
294
+ cls._CLIENT_CACHE.clear()
295
+ cls._CLIENT_REFCOUNT.clear()
199
296
 
200
297
  @require_python_package(
201
298
  import_name="nebulagraph_python",
202
- install_command="pip install ... @Tianxing",
299
+ install_command="pip install nebulagraph-python>=5.1.1",
203
300
  install_link=".....",
204
301
  )
205
302
  def __init__(self, config: NebulaGraphDBConfig):
@@ -246,48 +343,65 @@ class NebulaGraphDB(BaseGraphDB):
246
343
  "usage",
247
344
  "background",
248
345
  }
346
+ self.base_fields = set(self.common_fields) - {"usage"}
347
+ self.heavy_fields = {"usage"}
249
348
  self.dim_field = (
250
349
  f"embedding_{self.embedding_dimension}"
251
350
  if (str(self.embedding_dimension) != str(self.default_memory_dimension))
252
351
  else "embedding"
253
352
  )
254
353
  self.system_db_name = "system" if config.use_multi_db else config.space
255
- self.pool = SessionPool(
256
- hosts=config.get("uri"),
257
- user=config.get("user"),
258
- password=config.get("password"),
259
- minsize=1,
260
- maxsize=config.get("max_client", 1000),
261
- )
262
-
263
- if config.auto_create:
264
- self._ensure_database_exists()
265
-
266
- self.execute_query(f"SESSION SET GRAPH `{self.db_name}`")
267
354
 
268
- # Create only if not exists
269
- self.create_index(dimensions=config.embedding_dimension)
355
+ # ---- NEW: pool acquisition strategy
356
+ # Get or create a shared pool from the class-level cache
357
+ self._client_key, self._client = self._get_or_create_shared_client(config)
358
+ self._owns_client = True
270
359
 
271
360
  logger.info("Connected to NebulaGraph successfully.")
272
361
 
273
362
  @timed
274
- def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True):
275
- with self.pool.get() as client:
276
- try:
277
- if auto_set_db and self.db_name:
278
- client.execute(f"SESSION SET GRAPH `{self.db_name}`")
279
- return client.execute(gql, timeout=timeout)
363
+ def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True):
364
+ def _wrap_use_db(q: str) -> str:
365
+ if auto_set_db and self.db_name:
366
+ return f"USE `{self.db_name}`\n{q}"
367
+ return q
280
368
 
281
- except Exception as e:
282
- if "Session not found" in str(e) or "Connection not established" in str(e):
283
- logger.warning(f"[execute_query] {e!s}, replacing client...")
284
- self.pool.replace_client(client)
285
- return self.execute_query(gql, timeout, auto_set_db)
286
- raise
369
+ try:
370
+ return self._client.execute(_wrap_use_db(gql), timeout=timeout)
371
+
372
+ except Exception as e:
373
+ emsg = str(e)
374
+ if any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS):
375
+ logger.warning(f"[execute_query] {e!s} → refreshing session pool and retry once...")
376
+ try:
377
+ self._refresh_client()
378
+ return self._client.execute(_wrap_use_db(gql), timeout=timeout)
379
+ except Exception:
380
+ logger.exception("[execute_query] retry after refresh failed")
381
+ raise
382
+ raise
287
383
 
288
384
  @timed
289
385
  def close(self):
290
- self.pool.close()
386
+ """
387
+ Close the connection resource if this instance owns it.
388
+
389
+ - If pool was injected (`shared_pool`), do nothing.
390
+ - If pool was acquired via shared cache, decrement refcount and close
391
+ when the last owner releases it.
392
+ """
393
+ if not self._owns_client:
394
+ logger.debug("[NebulaGraphDBSync] close() skipped (injected client).")
395
+ return
396
+ if self._client_key:
397
+ self._release_shared_client(self._client_key)
398
+ self._client_key = None
399
+ self._client = None
400
+
401
+ # NOTE: __del__ is best-effort; do not rely on GC order.
402
+ def __del__(self):
403
+ with suppress(Exception):
404
+ self.close()
291
405
 
292
406
  @timed
293
407
  def create_index(
@@ -366,12 +480,10 @@ class NebulaGraphDB(BaseGraphDB):
366
480
  filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"'
367
481
  else:
368
482
  filter_clause = f'n.memory_type = "{scope}"'
369
- return_fields = ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
370
-
371
483
  query = f"""
372
484
  MATCH (n@Memory)
373
485
  WHERE {filter_clause}
374
- RETURN {return_fields}
486
+ RETURN n.id AS id
375
487
  LIMIT 1
376
488
  """
377
489
 
@@ -568,10 +680,7 @@ class NebulaGraphDB(BaseGraphDB):
568
680
  try:
569
681
  result = self.execute_query(gql)
570
682
  for row in result:
571
- if include_embedding:
572
- props = row.values()[0].as_node().get_properties()
573
- else:
574
- props = {k: v.value for k, v in row.items()}
683
+ props = {k: v.value for k, v in row.items()}
575
684
  node = self._parse_node(props)
576
685
  return node
577
686
 
@@ -582,7 +691,9 @@ class NebulaGraphDB(BaseGraphDB):
582
691
  return None
583
692
 
584
693
  @timed
585
- def get_nodes(self, ids: list[str], include_embedding: bool = False) -> list[dict[str, Any]]:
694
+ def get_nodes(
695
+ self, ids: list[str], include_embedding: bool = False, **kwargs
696
+ ) -> list[dict[str, Any]]:
586
697
  """
587
698
  Retrieve the metadata and memory of a list of nodes.
588
699
  Args:
@@ -600,7 +711,10 @@ class NebulaGraphDB(BaseGraphDB):
600
711
 
601
712
  where_user = ""
602
713
  if not self.config.use_multi_db and self.config.user_name:
603
- where_user = f" AND n.user_name = '{self.config.user_name}'"
714
+ if kwargs.get("cube_name"):
715
+ where_user = f" AND n.user_name = '{kwargs['cube_name']}'"
716
+ else:
717
+ where_user = f" AND n.user_name = '{self.config.user_name}'"
604
718
 
605
719
  # Safe formatting of the ID list
606
720
  id_list = ",".join(f'"{_id}"' for _id in ids)
@@ -615,10 +729,7 @@ class NebulaGraphDB(BaseGraphDB):
615
729
  try:
616
730
  results = self.execute_query(query)
617
731
  for row in results:
618
- if include_embedding:
619
- props = row.values()[0].as_node().get_properties()
620
- else:
621
- props = {k: v.value for k, v in row.items()}
732
+ props = {k: v.value for k, v in row.items()}
622
733
  nodes.append(self._parse_node(props))
623
734
  except Exception as e:
624
735
  logger.error(
@@ -687,6 +798,7 @@ class NebulaGraphDB(BaseGraphDB):
687
798
  exclude_ids: list[str],
688
799
  top_k: int = 5,
689
800
  min_overlap: int = 1,
801
+ include_embedding: bool = False,
690
802
  ) -> list[dict[str, Any]]:
691
803
  """
692
804
  Find top-K neighbor nodes with maximum tag overlap.
@@ -696,6 +808,7 @@ class NebulaGraphDB(BaseGraphDB):
696
808
  exclude_ids: Node IDs to exclude (e.g., local cluster).
697
809
  top_k: Max number of neighbors to return.
698
810
  min_overlap: Minimum number of overlapping tags required.
811
+ include_embedding: with/without embedding
699
812
 
700
813
  Returns:
701
814
  List of dicts with node details and overlap count.
@@ -717,12 +830,13 @@ class NebulaGraphDB(BaseGraphDB):
717
830
  where_clause = " AND ".join(where_clauses)
718
831
  tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"
719
832
 
833
+ return_fields = self._build_return_fields(include_embedding)
720
834
  query = f"""
721
835
  LET tag_list = {tag_list_literal}
722
836
 
723
837
  MATCH (n@Memory)
724
838
  WHERE {where_clause}
725
- RETURN n,
839
+ RETURN {return_fields},
726
840
  size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
727
841
  ORDER BY overlap_count DESC
728
842
  LIMIT {top_k}
@@ -731,9 +845,8 @@ class NebulaGraphDB(BaseGraphDB):
731
845
  result = self.execute_query(query)
732
846
  neighbors: list[dict[str, Any]] = []
733
847
  for r in result:
734
- node_props = r["n"].as_node().get_properties()
735
- parsed = self._parse_node(node_props) # --> {id, memory, metadata}
736
-
848
+ props = {k: v.value for k, v in r.items() if k != "overlap_count"}
849
+ parsed = self._parse_node(props)
737
850
  parsed["overlap_count"] = r["overlap_count"].value
738
851
  neighbors.append(parsed)
739
852
 
@@ -840,6 +953,8 @@ class NebulaGraphDB(BaseGraphDB):
840
953
  scope: str | None = None,
841
954
  status: str | None = None,
842
955
  threshold: float | None = None,
956
+ search_filter: dict | None = None,
957
+ **kwargs,
843
958
  ) -> list[dict]:
844
959
  """
845
960
  Retrieve node IDs based on vector similarity.
@@ -851,6 +966,8 @@ class NebulaGraphDB(BaseGraphDB):
851
966
  status (str, optional): Node status filter (e.g., 'active', 'archived').
852
967
  If provided, restricts results to nodes with matching status.
853
968
  threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
969
+ search_filter (dict, optional): Additional metadata filters for search results.
970
+ Keys should match node properties, values are the expected values.
854
971
 
855
972
  Returns:
856
973
  list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
@@ -860,6 +977,7 @@ class NebulaGraphDB(BaseGraphDB):
860
977
  - If scope is provided, it restricts results to nodes with matching memory_type.
861
978
  - If 'status' is provided, only nodes with the matching status will be returned.
862
979
  - If threshold is provided, only results with score >= threshold will be returned.
980
+ - If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
863
981
  - Typical use case: restrict to 'status = activated' to avoid
864
982
  matching archived or merged nodes.
865
983
  """
@@ -874,12 +992,22 @@ class NebulaGraphDB(BaseGraphDB):
874
992
  if status:
875
993
  where_clauses.append(f'n.status = "{status}"')
876
994
  if not self.config.use_multi_db and self.config.user_name:
877
- where_clauses.append(f'n.user_name = "{self.config.user_name}"')
995
+ if kwargs.get("cube_name"):
996
+ where_clauses.append(f'n.user_name = "{kwargs["cube_name"]}"')
997
+ else:
998
+ where_clauses.append(f'n.user_name = "{self.config.user_name}"')
999
+
1000
+ # Add search_filter conditions
1001
+ if search_filter:
1002
+ for key, value in search_filter.items():
1003
+ if isinstance(value, str):
1004
+ where_clauses.append(f'n.{key} = "{value}"')
1005
+ else:
1006
+ where_clauses.append(f"n.{key} = {value}")
878
1007
 
879
1008
  where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
880
1009
 
881
1010
  gql = f"""
882
- USE `{self.db_name}`
883
1011
  MATCH (n@Memory)
884
1012
  {where_clause}
885
1013
  ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC
@@ -902,7 +1030,7 @@ class NebulaGraphDB(BaseGraphDB):
902
1030
  id_val = values[0].as_string()
903
1031
  score_val = values[1].as_double()
904
1032
  score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score
905
- if threshold is None or score_val <= threshold:
1033
+ if threshold is None or score_val >= threshold:
906
1034
  output.append({"id": id_val, "score": score_val})
907
1035
  return output
908
1036
  except Exception as e:
@@ -936,20 +1064,12 @@ class NebulaGraphDB(BaseGraphDB):
936
1064
  """
937
1065
  where_clauses = []
938
1066
 
939
- def _escape_value(value):
940
- if isinstance(value, str):
941
- return f'"{value}"'
942
- elif isinstance(value, list):
943
- return "[" + ", ".join(_escape_value(v) for v in value) + "]"
944
- else:
945
- return str(value)
946
-
947
1067
  for _i, f in enumerate(filters):
948
1068
  field = f["field"]
949
1069
  op = f.get("op", "=")
950
1070
  value = f["value"]
951
1071
 
952
- escaped_value = _escape_value(value)
1072
+ escaped_value = self._format_value(value)
953
1073
 
954
1074
  # Build WHERE clause
955
1075
  if op == "=":
@@ -1153,28 +1273,36 @@ class NebulaGraphDB(BaseGraphDB):
1153
1273
  data: A dictionary containing all nodes and edges to be loaded.
1154
1274
  """
1155
1275
  for node in data.get("nodes", []):
1156
- id, memory, metadata = _compose_node(node)
1276
+ try:
1277
+ id, memory, metadata = _compose_node(node)
1157
1278
 
1158
- if not self.config.use_multi_db and self.config.user_name:
1159
- metadata["user_name"] = self.config.user_name
1279
+ if not self.config.use_multi_db and self.config.user_name:
1280
+ metadata["user_name"] = self.config.user_name
1160
1281
 
1161
- metadata = self._prepare_node_metadata(metadata)
1162
- metadata.update({"id": id, "memory": memory})
1163
- properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items())
1164
- node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
1165
- self.execute_query(node_gql)
1282
+ metadata = self._prepare_node_metadata(metadata)
1283
+ metadata.update({"id": id, "memory": memory})
1284
+ properties = ", ".join(
1285
+ f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()
1286
+ )
1287
+ node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
1288
+ self.execute_query(node_gql)
1289
+ except Exception as e:
1290
+ logger.error(f"Fail to load node: {node}, error: {e}")
1166
1291
 
1167
1292
  for edge in data.get("edges", []):
1168
- source_id, target_id = edge["source"], edge["target"]
1169
- edge_type = edge["type"]
1170
- props = ""
1171
- if not self.config.use_multi_db and self.config.user_name:
1172
- props = f'{{user_name: "{self.config.user_name}"}}'
1173
- edge_gql = f'''
1174
- MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
1175
- INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
1176
- '''
1177
- self.execute_query(edge_gql)
1293
+ try:
1294
+ source_id, target_id = edge["source"], edge["target"]
1295
+ edge_type = edge["type"]
1296
+ props = ""
1297
+ if not self.config.use_multi_db and self.config.user_name:
1298
+ props = f'{{user_name: "{self.config.user_name}"}}'
1299
+ edge_gql = f'''
1300
+ MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
1301
+ INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
1302
+ '''
1303
+ self.execute_query(edge_gql)
1304
+ except Exception as e:
1305
+ logger.error(f"Fail to load edge: {edge}, error: {e}")
1178
1306
 
1179
1307
  @timed
1180
1308
  def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (list)[dict]:
@@ -1208,10 +1336,7 @@ class NebulaGraphDB(BaseGraphDB):
1208
1336
  try:
1209
1337
  results = self.execute_query(query)
1210
1338
  for row in results:
1211
- if include_embedding:
1212
- props = row.values()[0].as_node().get_properties()
1213
- else:
1214
- props = {k: v.value for k, v in row.items()}
1339
+ props = {k: v.value for k, v in row.items()}
1215
1340
  nodes.append(self._parse_node(props))
1216
1341
  except Exception as e:
1217
1342
  logger.error(f"Failed to get memories: {e}")
@@ -1235,9 +1360,9 @@ class NebulaGraphDB(BaseGraphDB):
1235
1360
  where_clause += f' AND n.user_name = "{self.config.user_name}"'
1236
1361
 
1237
1362
  return_fields = self._build_return_fields(include_embedding)
1363
+ return_fields += f", n.{self.dim_field} AS {self.dim_field}"
1238
1364
 
1239
1365
  query = f"""
1240
- USE `{self.db_name}`
1241
1366
  MATCH (n@Memory)
1242
1367
  WHERE {where_clause}
1243
1368
  OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
@@ -1247,14 +1372,16 @@ class NebulaGraphDB(BaseGraphDB):
1247
1372
  """
1248
1373
 
1249
1374
  candidates = []
1375
+ node_ids = set()
1250
1376
  try:
1251
1377
  results = self.execute_query(query)
1252
1378
  for row in results:
1253
- if include_embedding:
1254
- props = row.values()[0].as_node().get_properties()
1255
- else:
1256
- props = {k: v.value for k, v in row.items()}
1257
- candidates.append(self._parse_node(props))
1379
+ props = {k: v.value for k, v in row.items()}
1380
+ node = self._parse_node(props)
1381
+ node_id = node["id"]
1382
+ if node_id not in node_ids:
1383
+ candidates.append(node)
1384
+ node_ids.add(node_id)
1258
1385
  except Exception as e:
1259
1386
  logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
1260
1387
  return candidates
@@ -1408,18 +1535,19 @@ class NebulaGraphDB(BaseGraphDB):
1408
1535
  logger.info(f"✅ Graph Type {graph_type_name} already include {self.dim_field}")
1409
1536
 
1410
1537
  create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
1411
- set_graph_working = f"SESSION SET GRAPH `{self.db_name}`"
1412
-
1413
1538
  try:
1414
1539
  self.execute_query(create_graph, auto_set_db=False)
1415
- self.execute_query(set_graph_working)
1416
1540
  logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.")
1417
1541
  except Exception as e:
1418
1542
  logger.error(f"❌ Failed to create tag: {e} trace: {traceback.format_exc()}")
1419
1543
 
1420
1544
  @timed
1421
1545
  def _create_vector_index(
1422
- self, label: str, vector_property: str, dimensions: int, index_name: str
1546
+ self,
1547
+ label: str = "Memory",
1548
+ vector_property: str = "embedding",
1549
+ dimensions: int = 3072,
1550
+ index_name: str = "memory_vector_index",
1423
1551
  ) -> None:
1424
1552
  """
1425
1553
  Create a vector index for the specified property in the label.
@@ -1555,6 +1683,7 @@ class NebulaGraphDB(BaseGraphDB):
1555
1683
  # Normalize embedding type
1556
1684
  embedding = metadata.get("embedding")
1557
1685
  if embedding and isinstance(embedding, list):
1686
+ metadata.pop("embedding")
1558
1687
  metadata[self.dim_field] = _normalize([float(x) for x in embedding])
1559
1688
 
1560
1689
  return metadata
@@ -1563,12 +1692,22 @@ class NebulaGraphDB(BaseGraphDB):
1563
1692
  def _format_value(self, val: Any, key: str = "") -> str:
1564
1693
  from nebulagraph_python.py_data_types import NVector
1565
1694
 
1695
+ # None
1696
+ if val is None:
1697
+ return "NULL"
1698
+ # bool
1699
+ if isinstance(val, bool):
1700
+ return "true" if val else "false"
1701
+ # str
1566
1702
  if isinstance(val, str):
1567
1703
  return f'"{_escape_str(val)}"'
1704
+ # num
1568
1705
  elif isinstance(val, (int | float)):
1569
1706
  return str(val)
1707
+ # time
1570
1708
  elif isinstance(val, datetime):
1571
1709
  return f'datetime("{val.isoformat()}")'
1710
+ # list
1572
1711
  elif isinstance(val, list):
1573
1712
  if key == self.dim_field:
1574
1713
  dim = len(val)
@@ -1576,13 +1715,18 @@ class NebulaGraphDB(BaseGraphDB):
1576
1715
  return f"VECTOR<{dim}, FLOAT>([{joined}])"
1577
1716
  else:
1578
1717
  return f"[{', '.join(self._format_value(v) for v in val)}]"
1718
+ # NVector
1579
1719
  elif isinstance(val, NVector):
1580
1720
  if key == self.dim_field:
1581
1721
  dim = len(val)
1582
1722
  joined = ",".join(str(float(x)) for x in val)
1583
1723
  return f"VECTOR<{dim}, FLOAT>([{joined}])"
1584
- elif val is None:
1585
- return "NULL"
1724
+ else:
1725
+ logger.warning("Invalid NVector")
1726
+ # dict
1727
+ if isinstance(val, dict):
1728
+ j = json.dumps(val, ensure_ascii=False, separators=(",", ":"))
1729
+ return f'"{_escape_str(j)}"'
1586
1730
  else:
1587
1731
  return f'"{_escape_str(str(val))}"'
1588
1732
 
@@ -1607,6 +1751,7 @@ class NebulaGraphDB(BaseGraphDB):
1607
1751
  return filtered_metadata
1608
1752
 
1609
1753
  def _build_return_fields(self, include_embedding: bool = False) -> str:
1754
+ fields = set(self.base_fields)
1610
1755
  if include_embedding:
1611
- return "n"
1612
- return ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
1756
+ fields.add(self.dim_field)
1757
+ return ", ".join(f"n.{f} AS {f}" for f in fields)