MemoryOS 0.2.2__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of MemoryOS might be problematic.
Files changed (82)
  1. {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/METADATA +7 -1
  2. {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/RECORD +81 -66
  3. memos/__init__.py +1 -1
  4. memos/api/config.py +31 -8
  5. memos/api/context/context.py +1 -1
  6. memos/api/context/context_thread.py +96 -0
  7. memos/api/middleware/request_context.py +94 -0
  8. memos/api/product_api.py +5 -1
  9. memos/api/product_models.py +16 -0
  10. memos/api/routers/product_router.py +39 -3
  11. memos/api/start_api.py +3 -0
  12. memos/configs/internet_retriever.py +13 -0
  13. memos/configs/mem_scheduler.py +38 -16
  14. memos/configs/memory.py +13 -0
  15. memos/configs/reranker.py +18 -0
  16. memos/graph_dbs/base.py +33 -4
  17. memos/graph_dbs/nebular.py +631 -236
  18. memos/graph_dbs/neo4j.py +18 -7
  19. memos/graph_dbs/neo4j_community.py +6 -3
  20. memos/llms/vllm.py +2 -0
  21. memos/log.py +125 -8
  22. memos/mem_os/core.py +49 -11
  23. memos/mem_os/main.py +1 -1
  24. memos/mem_os/product.py +392 -215
  25. memos/mem_os/utils/default_config.py +1 -1
  26. memos/mem_os/utils/format_utils.py +11 -47
  27. memos/mem_os/utils/reference_utils.py +153 -0
  28. memos/mem_reader/simple_struct.py +112 -43
  29. memos/mem_scheduler/base_scheduler.py +58 -55
  30. memos/mem_scheduler/{modules → general_modules}/base.py +1 -2
  31. memos/mem_scheduler/{modules → general_modules}/dispatcher.py +54 -15
  32. memos/mem_scheduler/{modules → general_modules}/rabbitmq_service.py +4 -4
  33. memos/mem_scheduler/{modules → general_modules}/redis_service.py +1 -1
  34. memos/mem_scheduler/{modules → general_modules}/retriever.py +19 -5
  35. memos/mem_scheduler/{modules → general_modules}/scheduler_logger.py +10 -4
  36. memos/mem_scheduler/general_scheduler.py +110 -67
  37. memos/mem_scheduler/monitors/__init__.py +0 -0
  38. memos/mem_scheduler/monitors/dispatcher_monitor.py +305 -0
  39. memos/mem_scheduler/{modules/monitor.py → monitors/general_monitor.py} +57 -19
  40. memos/mem_scheduler/mos_for_test_scheduler.py +7 -1
  41. memos/mem_scheduler/schemas/general_schemas.py +3 -2
  42. memos/mem_scheduler/schemas/message_schemas.py +2 -1
  43. memos/mem_scheduler/schemas/monitor_schemas.py +10 -2
  44. memos/mem_scheduler/utils/misc_utils.py +43 -2
  45. memos/mem_user/mysql_user_manager.py +4 -2
  46. memos/memories/activation/item.py +1 -1
  47. memos/memories/activation/kv.py +20 -8
  48. memos/memories/textual/base.py +1 -1
  49. memos/memories/textual/general.py +1 -1
  50. memos/memories/textual/item.py +1 -1
  51. memos/memories/textual/tree.py +31 -1
  52. memos/memories/textual/tree_text_memory/organize/{conflict.py → handler.py} +30 -48
  53. memos/memories/textual/tree_text_memory/organize/manager.py +8 -96
  54. memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +2 -0
  55. memos/memories/textual/tree_text_memory/organize/reorganizer.py +102 -140
  56. memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +231 -0
  57. memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +9 -0
  58. memos/memories/textual/tree_text_memory/retrieve/recall.py +67 -10
  59. memos/memories/textual/tree_text_memory/retrieve/reranker.py +1 -1
  60. memos/memories/textual/tree_text_memory/retrieve/searcher.py +246 -134
  61. memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +7 -2
  62. memos/memories/textual/tree_text_memory/retrieve/utils.py +7 -5
  63. memos/memos_tools/lockfree_dict.py +120 -0
  64. memos/memos_tools/notification_utils.py +46 -0
  65. memos/memos_tools/thread_safe_dict.py +288 -0
  66. memos/reranker/__init__.py +4 -0
  67. memos/reranker/base.py +24 -0
  68. memos/reranker/cosine_local.py +95 -0
  69. memos/reranker/factory.py +43 -0
  70. memos/reranker/http_bge.py +99 -0
  71. memos/reranker/noop.py +16 -0
  72. memos/templates/mem_reader_prompts.py +290 -39
  73. memos/templates/mem_scheduler_prompts.py +23 -10
  74. memos/templates/mos_prompts.py +133 -31
  75. memos/templates/tree_reorganize_prompts.py +24 -17
  76. memos/utils.py +19 -0
  77. memos/memories/textual/tree_text_memory/organize/redundancy.py +0 -193
  78. {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/LICENSE +0 -0
  79. {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/WHEEL +0 -0
  80. {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/entry_points.txt +0 -0
  81. /memos/mem_scheduler/{modules → general_modules}/__init__.py +0 -0
  82. /memos/mem_scheduler/{modules → general_modules}/misc.py +0 -0
Diff for memos/graph_dbs/nebular.py (+631 -236):

@@ -1,10 +1,11 @@
+import json
 import traceback
 
 from contextlib import suppress
 from datetime import datetime
 from queue import Empty, Queue
 from threading import Lock
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, ClassVar, Literal
 
 import numpy as np
 
@@ -12,17 +13,24 @@ from memos.configs.graph_db import NebulaGraphDBConfig
 from memos.dependency import require_python_package
 from memos.graph_dbs.base import BaseGraphDB
 from memos.log import get_logger
+from memos.utils import timed
+
+
+if TYPE_CHECKING:
+    from nebulagraph_python.client.pool import NebulaPool
 
 
 logger = get_logger(__name__)
 
 
+@timed
 def _normalize(vec: list[float]) -> list[float]:
     v = np.asarray(vec, dtype=np.float32)
     norm = np.linalg.norm(v)
     return (v / (norm if norm else 1.0)).tolist()
 
 
+@timed
 def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
     node_id = item["id"]
     memory = item["memory"]
@@ -30,97 +38,33 @@ def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
     return node_id, memory, metadata
 
 
-def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
-    """
-    Ensure metadata has proper datetime fields and normalized types.
-
-    - Fill `created_at` and `updated_at` if missing (in ISO 8601 format).
-    - Convert embedding to list of float if present.
-    """
-    now = datetime.utcnow().isoformat()
-    metadata["node_type"] = metadata.pop("type")
-
-    # Fill timestamps if missing
-    metadata.setdefault("created_at", now)
-    metadata.setdefault("updated_at", now)
-
-    # Normalize embedding type
-    embedding = metadata.get("embedding")
-    if embedding and isinstance(embedding, list):
-        metadata["embedding"] = _normalize([float(x) for x in embedding])
-
-    return metadata
-
-
-def _metadata_filter(metadata: dict[str, Any]) -> dict[str, Any]:
-    """
-    Filter and validate metadata dictionary against the Memory node schema.
-    - Removes keys not in schema.
-    - Warns if required fields are missing.
-    """
-
-    allowed_fields = {
-        "id",
-        "memory",
-        "user_name",
-        "user_id",
-        "session_id",
-        "status",
-        "key",
-        "confidence",
-        "tags",
-        "created_at",
-        "updated_at",
-        "memory_type",
-        "sources",
-        "source",
-        "node_type",
-        "visibility",
-        "usage",
-        "background",
-        "embedding",
-    }
-
-    missing_fields = allowed_fields - metadata.keys()
-    if missing_fields:
-        logger.warning(f"Metadata missing required fields: {sorted(missing_fields)}")
-
-    filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields}
-
-    return filtered_metadata
-
-
+@timed
 def _escape_str(value: str) -> str:
-    return value.replace('"', '\\"')
-
-
-def _format_value(val: Any, key: str = "") -> str:
-    from nebulagraph_python.py_data_types import NVector
-
-    if isinstance(val, str):
-        return f'"{_escape_str(val)}"'
-    elif isinstance(val, (int | float)):
-        return str(val)
-    elif isinstance(val, datetime):
-        return f'datetime("{val.isoformat()}")'
-    elif isinstance(val, list):
-        if key == "embedding":
-            dim = len(val)
-            joined = ",".join(str(float(x)) for x in val)
-            return f"VECTOR<{dim}, FLOAT>([{joined}])"
+    out = []
+    for ch in value:
+        code = ord(ch)
+        if ch == "\\":
+            out.append("\\\\")
+        elif ch == '"':
+            out.append('\\"')
+        elif ch == "\n":
+            out.append("\\n")
+        elif ch == "\r":
+            out.append("\\r")
+        elif ch == "\t":
+            out.append("\\t")
+        elif ch == "\b":
+            out.append("\\b")
+        elif ch == "\f":
+            out.append("\\f")
+        elif code < 0x20 or code in (0x2028, 0x2029):
+            out.append(f"\\u{code:04x}")
         else:
-            return f"[{', '.join(_format_value(v) for v in val)}]"
-    elif isinstance(val, NVector):
-        if key == "embedding":
-            dim = len(val)
-            joined = ",".join(str(float(x)) for x in val)
-            return f"VECTOR<{dim}, FLOAT>([{joined}])"
-    elif val is None:
-        return "NULL"
-    else:
-        return f'"{_escape_str(str(val))}"'
+            out.append(ch)
+    return "".join(out)
 
 
+@timed
 def _format_datetime(value: str | datetime) -> str:
     """Ensure datetime is in ISO 8601 format string."""
     if isinstance(value, datetime):
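
Reviewer note (not part of the diff): in 0.2.2, _escape_str only escaped double quotes, so backslashes, newlines, and other control characters could corrupt the GQL string literals built by _format_value. The 1.0.1 rewrite escapes those explicitly and falls back to a \uXXXX escape for remaining control codes. A minimal sketch of the expected behavior, assuming the function exactly as shown above:

    assert _escape_str('say "hi"') == 'say \\"hi\\"'
    assert _escape_str("line1\nline2") == "line1\\nline2"   # real newline -> two-char \n escape
    assert _escape_str(chr(0x01)) == "\\u0001"              # other control chars -> \uXXXX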
@@ -128,6 +72,21 @@ def _format_datetime(value: str | datetime) -> str:
     return str(value)
 
 
+@timed
+def _normalize_datetime(val):
+    """
+    Normalize datetime to ISO 8601 UTC string with +00:00.
+    - If val is datetime object -> keep isoformat() (Neo4j)
+    - If val is string without timezone -> append +00:00 (Nebula)
+    - Otherwise just str()
+    """
+    if hasattr(val, "isoformat"):
+        return val.isoformat()
+    if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")):
+        return val + "+08:00"
+    return str(val)
+
+
 class SessionPoolError(Exception):
     pass
 
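Reviewer note: _normalize_datetime's docstring promises "+00:00" for timezone-less strings, but the code appends "+08:00" (UTC+8). Expected behavior as written, with hypothetical inputs:

    from datetime import datetime

    _normalize_datetime(datetime(2025, 1, 1, 12, 0))   # -> "2025-01-01T12:00:00" (isoformat kept)
    _normalize_datetime("2025-01-01T12:00:00")         # -> "2025-01-01T12:00:00+08:00"
    _normalize_datetime("2025-01-01T12:00:00Z")        # -> "2025-01-01T12:00:00Z" (left unchanged)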
@@ -149,6 +108,7 @@ class SessionPool:
         self.hosts = hosts
         self.user = user
         self.password = password
+        self.minsize = minsize
         self.maxsize = maxsize
         self.pool = Queue(maxsize)
         self.lock = Lock()
@@ -158,6 +118,7 @@ class SessionPool:
         for _ in range(minsize):
             self._create_and_add_client()
 
+    @timed
     def _create_and_add_client(self):
         from nebulagraph_python import NebulaClient
 
@@ -165,28 +126,37 @@ class SessionPool:
         self.pool.put(client)
         self.clients.append(client)
 
+    @timed
     def get_client(self, timeout: float = 5.0):
-        from nebulagraph_python import NebulaClient
-
         try:
             return self.pool.get(timeout=timeout)
         except Empty:
             with self.lock:
                 if len(self.clients) < self.maxsize:
+                    from nebulagraph_python import NebulaClient
+
                    client = NebulaClient(self.hosts, self.user, self.password)
                     self.clients.append(client)
                     return client
             raise RuntimeError("NebulaClientPool exhausted") from None
 
+    @timed
     def return_client(self, client):
-        self.pool.put(client)
+        try:
+            client.execute("YIELD 1")
+            self.pool.put(client)
+        except Exception:
+            logger.info("[Pool] Client dead, replacing...")
+            self.replace_client(client)
 
+    @timed
     def close(self):
         for client in self.clients:
             with suppress(Exception):
                 client.close()
         self.clients.clear()
 
+    @timed
     def get(self):
         """
         Context manager: with pool.get() as client:
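
Reviewer note: return_client now probes a connection with a cheap "YIELD 1" before re-queueing it and swaps in a fresh client when the probe fails, so dead sessions no longer circulate. A minimal usage sketch (hypothetical endpoint and credentials; the context manager's exit path is assumed to hand the client back through return_client):

    pool = SessionPool(hosts="127.0.0.1:9669", user="root", password="secret", minsize=1, maxsize=8)
    with pool.get() as client:   # borrow a pooled NebulaClient
        client.execute("YIELD 1")
    pool.close()                 # close every client when done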
@@ -207,12 +177,140 @@ class SessionPool:
 
         return _ClientContext(self)
 
+    @timed
+    def reset_pool(self):
+        """⚠️ Emergency reset: Close all clients and clear the pool."""
+        logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
+        with self.lock:
+            for client in self.clients:
+                try:
+                    client.close()
+                except Exception:
+                    logger.error("Fail to close!!!")
+            self.clients.clear()
+            while not self.pool.empty():
+                try:
+                    self.pool.get_nowait()
+                except Empty:
+                    break
+            for _ in range(self.minsize):
+                self._create_and_add_client()
+        logger.info("[Pool] Pool has been reset successfully.")
+
+    @timed
+    def replace_client(self, client):
+        try:
+            client.close()
+        except Exception:
+            logger.error("Fail to close client")
+
+        if client in self.clients:
+            self.clients.remove(client)
+
+        from nebulagraph_python import NebulaClient
+
+        new_client = NebulaClient(self.hosts, self.user, self.password)
+        self.clients.append(new_client)
+
+        self.pool.put(new_client)
+
+        logger.info("[Pool] Replaced dead client with a new one.")
+        return new_client
+
 
 class NebulaGraphDB(BaseGraphDB):
     """
     NebulaGraph-based implementation of a graph memory store.
     """
 
+    # ====== shared pool cache & refcount ======
+    # These are process-local; in a multi-process model each process will
+    # have its own cache.
+    _POOL_CACHE: ClassVar[dict[str, "NebulaPool"]] = {}
+    _POOL_REFCOUNT: ClassVar[dict[str, int]] = {}
+    _POOL_LOCK: ClassVar[Lock] = Lock()
+
+    @staticmethod
+    def _make_pool_key(cfg: NebulaGraphDBConfig) -> str:
+        """
+        Build a cache key that captures all connection-affecting options.
+        Keep this key stable and include fields that change the underlying pool behavior.
+        """
+        # NOTE: Do not include tenant-like or query-scope-only fields here.
+        # Only include things that affect the actual TCP/auth/session pool.
+        return "|".join(
+            [
+                "nebula",
+                str(getattr(cfg, "uri", "")),
+                str(getattr(cfg, "user", "")),
+                str(getattr(cfg, "password", "")),
+                # pool sizing / tls / timeouts if you have them in config:
+                str(getattr(cfg, "max_client", 1000)),
+                # multi-db mode can impact how we use sessions; keep it to be safe
+                str(getattr(cfg, "use_multi_db", False)),
+            ]
+        )
+
+    @classmethod
+    def _get_or_create_shared_pool(cls, cfg: NebulaGraphDBConfig):
+        """
+        Get a shared NebulaPool from cache or create one if missing.
+        Thread-safe with a lock; maintains a simple refcount.
+        """
+        key = cls._make_pool_key(cfg)
+
+        with cls._POOL_LOCK:
+            pool = cls._POOL_CACHE.get(key)
+            if pool is None:
+                # Create a new pool and put into cache
+                pool = SessionPool(
+                    hosts=cfg.get("uri"),
+                    user=cfg.get("user"),
+                    password=cfg.get("password"),
+                    minsize=1,
+                    maxsize=cfg.get("max_client", 1000),
+                )
+                cls._POOL_CACHE[key] = pool
+                cls._POOL_REFCOUNT[key] = 0
+                logger.info(f"[NebulaGraphDB] Created new shared NebulaPool for key={key}")
+
+            # Increase refcount for the caller
+            cls._POOL_REFCOUNT[key] = cls._POOL_REFCOUNT.get(key, 0) + 1
+            return key, pool
+
+    @classmethod
+    def _release_shared_pool(cls, key: str):
+        """
+        Decrease refcount for the given pool key; only close when refcount hits zero.
+        """
+        with cls._POOL_LOCK:
+            if key not in cls._POOL_CACHE:
+                return
+            cls._POOL_REFCOUNT[key] = max(0, cls._POOL_REFCOUNT.get(key, 0) - 1)
+            if cls._POOL_REFCOUNT[key] == 0:
+                try:
+                    cls._POOL_CACHE[key].close()
+                except Exception as e:
+                    logger.warning(f"[NebulaGraphDB] Error closing shared pool: {e}")
+                finally:
+                    cls._POOL_CACHE.pop(key, None)
+                    cls._POOL_REFCOUNT.pop(key, None)
+                    logger.info(f"[NebulaGraphDB] Closed and removed shared pool key={key}")
+
+    @classmethod
+    def close_all_shared_pools(cls):
+        """Force close all cached pools. Call this on graceful shutdown."""
+        with cls._POOL_LOCK:
+            for key, pool in list(cls._POOL_CACHE.items()):
+                try:
+                    pool.close()
+                except Exception as e:
+                    logger.warning(f"[NebulaGraphDB] Error closing pool key={key}: {e}")
+                finally:
+                    logger.info(f"[NebulaGraphDB] Closed pool key={key}")
+            cls._POOL_CACHE.clear()
+            cls._POOL_REFCOUNT.clear()
+
     @require_python_package(
         import_name="nebulagraph_python",
         install_command="pip install ... @Tianxing",
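
Reviewer note: NebulaGraphDB instances no longer own a private SessionPool; pools are cached per connection key and reference-counted. A sketch of the intended lifecycle, assuming two instances built from the same config:

    db1 = NebulaGraphDB(config)              # cache miss -> pool created, refcount 1
    db2 = NebulaGraphDB(config)              # same pool key -> cache hit, refcount 2
    db1.close()                              # refcount 1, pool stays open
    db2.close()                              # refcount 0, pool closed and evicted
    NebulaGraphDB.close_all_shared_pools()   # belt-and-braces cleanup at shutdown

One caveat worth flagging: _make_pool_key embeds the plaintext password, and the key is echoed verbatim by the logger.info calls above.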
@@ -240,15 +338,43 @@ class NebulaGraphDB(BaseGraphDB):
         self.config = config
         self.db_name = config.space
         self.user_name = config.user_name
-        self.system_db_name = "system" if config.use_multi_db else config.space
-        self.pool = SessionPool(
-            hosts=config.get("uri"),
-            user=config.get("user"),
-            password=config.get("password"),
-            minsize=1,
-            maxsize=config.get("max_client", 1000),
+        self.embedding_dimension = config.embedding_dimension
+        self.default_memory_dimension = 3072
+        self.common_fields = {
+            "id",
+            "memory",
+            "user_name",
+            "user_id",
+            "session_id",
+            "status",
+            "key",
+            "confidence",
+            "tags",
+            "created_at",
+            "updated_at",
+            "memory_type",
+            "sources",
+            "source",
+            "node_type",
+            "visibility",
+            "usage",
+            "background",
+        }
+        self.base_fields = set(self.common_fields) - {"usage"}
+        self.heavy_fields = {"usage"}
+        self.dim_field = (
+            f"embedding_{self.embedding_dimension}"
+            if (str(self.embedding_dimension) != str(self.default_memory_dimension))
+            else "embedding"
         )
+        self.system_db_name = "system" if config.use_multi_db else config.space
 
+        # ---- NEW: pool acquisition strategy
+        # Get or create a shared pool from the class-level cache
+        self._pool_key, self.pool = self._get_or_create_shared_pool(config)
+        self._owns_pool = True  # We manage refcount for this instance
+
+        # auto-create graph type / graph / index if needed
         if config.auto_create:
             self._ensure_database_exists()
 
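Reviewer note: the constructor now derives a dimension-specific vector property name, which threads through every query below. With the default 3072-dim setting the legacy column name is kept:

    # self.embedding_dimension == 3072  ->  self.dim_field == "embedding"
    # self.embedding_dimension == 1024  ->  self.dim_field == "embedding_1024"

This per-dimension column naming is what lets differently sized embedders share one graph type (see _ensure_database_exists further down).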
@@ -259,15 +385,44 @@ class NebulaGraphDB(BaseGraphDB):
 
         logger.info("Connected to NebulaGraph successfully.")
 
-    def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True):
+    @timed
+    def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True):
         with self.pool.get() as client:
-            if auto_set_db and self.db_name:
-                client.execute(f"SESSION SET GRAPH `{self.db_name}`")
-            return client.execute(gql, timeout=timeout)
+            try:
+                if auto_set_db and self.db_name:
+                    client.execute(f"SESSION SET GRAPH `{self.db_name}`")
+                return client.execute(gql, timeout=timeout)
 
+            except Exception as e:
+                if "Session not found" in str(e) or "Connection not established" in str(e):
+                    logger.warning(f"[execute_query] {e!s}, replacing client...")
+                    self.pool.replace_client(client)
+                    return self.execute_query(gql, timeout, auto_set_db)
+                raise
+
+    @timed
     def close(self):
-        self.pool.close()
+        """
+        Close the connection resource if this instance owns it.
 
+        - If pool was injected (`shared_pool`), do nothing.
+        - If pool was acquired via shared cache, decrement refcount and close
+          when the last owner releases it.
+        """
+        if not self._owns_pool:
+            logger.debug("[NebulaGraphDB] close() skipped (injected pool).")
+            return
+        if self._pool_key:
+            self._release_shared_pool(self._pool_key)
+            self._pool_key = None
+        self.pool = None
+
+    # NOTE: __del__ is best-effort; do not rely on GC order.
+    def __del__(self):
+        with suppress(Exception):
+            self.close()
+
+    @timed
     def create_index(
         self,
         label: str = "Memory",
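
Reviewer note: execute_query now recovers from dead sessions by replacing the client and re-issuing the query. An equivalent standalone formulation of the recovery path, for illustration only:

    def run_with_recovery(pool, gql, timeout=10.0):
        with pool.get() as client:
            try:
                return client.execute(gql, timeout=timeout)
            except Exception as e:
                if "Session not found" in str(e) or "Connection not established" in str(e):
                    pool.replace_client(client)
                    return run_with_recovery(pool, gql, timeout)  # recurses again if the error persists
                raise

Note the retry is unbounded: a persistently failing session keeps recursing rather than giving up after N attempts.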
@@ -280,6 +435,7 @@ class NebulaGraphDB(BaseGraphDB):
         # Create indexes
         self._create_basic_property_indexes()
 
+    @timed
     def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None:
         """
         Remove all WorkingMemory nodes except the latest `keep_latest` entries.
@@ -302,6 +458,7 @@ class NebulaGraphDB(BaseGraphDB):
         """
         self.execute_query(query)
 
+    @timed
     def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
         """
         Insert or update a Memory node in NebulaGraph.
@@ -318,10 +475,14 @@ class NebulaGraphDB(BaseGraphDB):
         metadata["memory"] = memory
 
         if "embedding" in metadata and isinstance(metadata["embedding"], list):
-            metadata["embedding"] = _normalize(metadata["embedding"])
+            assert len(metadata["embedding"]) == self.embedding_dimension, (
+                f"input embedding dimension must equal to {self.embedding_dimension}"
+            )
+            embedding = metadata.pop("embedding")
+            metadata[self.dim_field] = _normalize(embedding)
 
-        metadata = _metadata_filter(metadata)
-        properties = ", ".join(f"{k}: {_format_value(v, k)}" for k, v in metadata.items())
+        metadata = self._metadata_filter(metadata)
+        properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items())
         gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
 
         try:
@@ -332,16 +493,16 @@ class NebulaGraphDB(BaseGraphDB):
                 f"Failed to insert vertex {id}: gql: {gql}, {e}\ntrace: {traceback.format_exc()}"
             )
 
+    @timed
     def node_not_exist(self, scope: str) -> int:
         if not self.config.use_multi_db and self.config.user_name:
             filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"'
         else:
             filter_clause = f'n.memory_type = "{scope}"'
-
         query = f"""
         MATCH (n@Memory)
         WHERE {filter_clause}
-        RETURN n
+        RETURN n.id AS id
         LIMIT 1
         """
 
@@ -352,6 +513,7 @@ class NebulaGraphDB(BaseGraphDB):
             logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True)
             raise
 
+    @timed
     def update_node(self, id: str, fields: dict[str, Any]) -> None:
         """
         Update node fields in Nebular, auto-converting `created_at` and `updated_at` to datetime type if present.
@@ -359,7 +521,7 @@ class NebulaGraphDB(BaseGraphDB):
         fields = fields.copy()
         set_clauses = []
         for k, v in fields.items():
-            set_clauses.append(f"n.{k} = {_format_value(v, k)}")
+            set_clauses.append(f"n.{k} = {self._format_value(v, k)}")
 
         set_clause_str = ",\n ".join(set_clauses)
 
@@ -373,6 +535,7 @@ class NebulaGraphDB(BaseGraphDB):
         query += f"\nSET {set_clause_str}"
         self.execute_query(query)
 
+    @timed
     def delete_node(self, id: str) -> None:
         """
         Delete a node from the graph.
@@ -384,10 +547,11 @@ class NebulaGraphDB(BaseGraphDB):
         """
         if not self.config.use_multi_db and self.config.user_name:
             user_name = self.config.user_name
-            query += f" WHERE n.user_name = {_format_value(user_name)}"
+            query += f" WHERE n.user_name = {self._format_value(user_name)}"
         query += "\n DETACH DELETE n"
         self.execute_query(query)
 
+    @timed
     def add_edge(self, source_id: str, target_id: str, type: str):
         """
         Create an edge from source node to target node.
@@ -412,6 +576,7 @@ class NebulaGraphDB(BaseGraphDB):
         except Exception as e:
             logger.error(f"Failed to insert edge: {e}", exc_info=True)
 
+    @timed
     def delete_edge(self, source_id: str, target_id: str, type: str) -> None:
         """
         Delete a specific edge between two nodes.
@@ -422,16 +587,17 @@ class NebulaGraphDB(BaseGraphDB):
         """
         query = f"""
         MATCH (a@Memory) -[r@{type}]-> (b@Memory)
-        WHERE a.id = {_format_value(source_id)} AND b.id = {_format_value(target_id)}
+        WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)}
         """
 
         if not self.config.use_multi_db and self.config.user_name:
             user_name = self.config.user_name
-            query += f" AND a.user_name = {_format_value(user_name)} AND b.user_name = {_format_value(user_name)}"
+            query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}"
 
         query += "\nDELETE r"
         self.execute_query(query)
 
+    @timed
     def get_memory_count(self, memory_type: str) -> int:
         query = f"""
         MATCH (n@Memory)
@@ -449,6 +615,7 @@ class NebulaGraphDB(BaseGraphDB):
             logger.error(f"[get_memory_count] Failed: {e}")
             return -1
 
+    @timed
     def count_nodes(self, scope: str) -> int:
         query = f"""
         MATCH (n@Memory)
@@ -462,6 +629,7 @@ class NebulaGraphDB(BaseGraphDB):
         result = self.execute_query(query)
         return result.one_or_none()["count"].value
 
+    @timed
     def edge_exists(
         self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING"
     ) -> bool:
@@ -503,43 +671,53 @@ class NebulaGraphDB(BaseGraphDB):
             return False
         return record.values() is not None
 
+    @timed
     # Graph Query & Reasoning
-    def get_node(self, id: str) -> dict[str, Any] | None:
+    def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | None:
         """
         Retrieve a Memory node by its unique ID.
 
         Args:
             id (str): Node ID (Memory.id)
+            include_embedding: with/without embedding
 
         Returns:
             dict: Node properties as key-value pairs, or None if not found.
         """
+        if not self.config.use_multi_db and self.config.user_name:
+            filter_clause = f'n.user_name = "{self.config.user_name}" AND n.id = "{id}"'
+        else:
+            filter_clause = f'n.id = "{id}"'
+
+        return_fields = self._build_return_fields(include_embedding)
         gql = f"""
-        USE `{self.db_name}`
-        MATCH (v {{id: '{id}'}})
-        RETURN v
-        """
+        MATCH (n@Memory)
+        WHERE {filter_clause}
+        RETURN {return_fields}
+        """
 
         try:
             result = self.execute_query(gql)
-            record = result.one_or_none()
-            if record is None:
-                return None
-
-            node_wrapper = record["v"].as_node()
-            props = node_wrapper.get_properties()
-            node = self._parse_node(props)
-            return node
+            for row in result:
+                props = {k: v.value for k, v in row.items()}
+                node = self._parse_node(props)
+                return node
 
         except Exception as e:
-            logger.error(f"[get_node] Failed to retrieve node '{id}': {e}")
+            logger.error(
+                f"[get_node] Failed to retrieve node '{id}': {e}, trace: {traceback.format_exc()}"
+            )
             return None
 
-    def get_nodes(self, ids: list[str]) -> list[dict[str, Any]]:
+    @timed
+    def get_nodes(
+        self, ids: list[str], include_embedding: bool = False, **kwargs
+    ) -> list[dict[str, Any]]:
         """
         Retrieve the metadata and memory of a list of nodes.
         Args:
             ids: List of Node identifier.
+            include_embedding: with/without embedding
         Returns:
             list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.
 
@@ -552,18 +730,33 @@ class NebulaGraphDB(BaseGraphDB):
 
         where_user = ""
         if not self.config.use_multi_db and self.config.user_name:
-            where_user = f" AND n.user_name = '{self.config.user_name}'"
+            if kwargs.get("cube_name"):
+                where_user = f" AND n.user_name = '{kwargs['cube_name']}'"
+            else:
+                where_user = f" AND n.user_name = '{self.config.user_name}'"
 
-        query = f"MATCH (n@Memory) WHERE n.id IN {ids} {where_user} RETURN n"
+        # Safe formatting of the ID list
+        id_list = ",".join(f'"{_id}"' for _id in ids)
 
-        results = self.execute_query(query)
+        return_fields = self._build_return_fields(include_embedding)
+        query = f"""
+        MATCH (n@Memory)
+        WHERE n.id IN [{id_list}] {where_user}
+        RETURN {return_fields}
+        """
         nodes = []
-        for rec in results:
-            node_props = rec["n"].as_node().get_properties()
-            nodes.append(self._parse_node(node_props))
-
+        try:
+            results = self.execute_query(query)
+            for row in results:
+                props = {k: v.value for k, v in row.items()}
+                nodes.append(self._parse_node(props))
+        except Exception as e:
+            logger.error(
+                f"[get_nodes] Failed to retrieve nodes {ids}: {e}, trace: {traceback.format_exc()}"
+            )
         return nodes
 
+    @timed
     def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[dict[str, str]]:
         """
         Get edges connected to a node, with optional type and direction filter.
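
Reviewer note: get_nodes previously interpolated the Python list repr into the query (producing WHERE n.id IN ['a', 'b']); the ID list is now rendered explicitly. For ids=["a", "b"] the new clause renders as:

    WHERE n.id IN ["a","b"]

The IDs are wrapped in double quotes but not passed through _escape_str, so this assumes node IDs never contain quote characters.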
@@ -617,12 +810,14 @@ class NebulaGraphDB(BaseGraphDB):
             )
         return edges
 
+    @timed
     def get_neighbors_by_tag(
         self,
         tags: list[str],
         exclude_ids: list[str],
         top_k: int = 5,
         min_overlap: int = 1,
+        include_embedding: bool = False,
     ) -> list[dict[str, Any]]:
         """
         Find top-K neighbor nodes with maximum tag overlap.
@@ -632,6 +827,7 @@
             exclude_ids: Node IDs to exclude (e.g., local cluster).
             top_k: Max number of neighbors to return.
             min_overlap: Minimum number of overlapping tags required.
+            include_embedding: with/without embedding
 
         Returns:
             List of dicts with node details and overlap count.
@@ -653,12 +849,13 @@
         where_clause = " AND ".join(where_clauses)
         tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"
 
+        return_fields = self._build_return_fields(include_embedding)
         query = f"""
         LET tag_list = {tag_list_literal}
 
         MATCH (n@Memory)
         WHERE {where_clause}
-        RETURN n,
+        RETURN {return_fields},
                size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
         ORDER BY overlap_count DESC
         LIMIT {top_k}
@@ -667,9 +864,8 @@
         result = self.execute_query(query)
         neighbors: list[dict[str, Any]] = []
         for r in result:
-            node_props = r["n"].as_node().get_properties()
-            parsed = self._parse_node(node_props)  # --> {id, memory, metadata}
-
+            props = {k: v.value for k, v in r.items() if k != "overlap_count"}
+            parsed = self._parse_node(props)
             parsed["overlap_count"] = r["overlap_count"].value
             neighbors.append(parsed)
 
@@ -681,6 +877,7 @@
                 result.append(neighbor)
         return result
 
+    @timed
     def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]:
         where_user = ""
 
@@ -691,19 +888,20 @@
         query = f"""
         MATCH (p@Memory)-[@PARENT]->(c@Memory)
         WHERE p.id = "{id}" {where_user}
-        RETURN c.id AS id, c.embedding AS embedding, c.memory AS memory
+        RETURN c.id AS id, c.{self.dim_field} AS {self.dim_field}, c.memory AS memory
         """
         result = self.execute_query(query)
         children = []
         for row in result:
             eid = row["id"].value  # STRING
-            emb_v = row["embedding"].value  # NVector
+            emb_v = row[self.dim_field].value  # NVector
             emb = list(emb_v.values) if emb_v else []
             mem = row["memory"].value  # STRING
 
             children.append({"id": eid, "embedding": emb, "memory": mem})
         return children
 
+    @timed
     def get_subgraph(
         self, center_id: str, depth: int = 2, center_status: str = "activated"
     ) -> dict[str, Any]:
@@ -765,6 +963,7 @@
 
         return {"core_node": core_node, "neighbors": neighbors, "edges": edges}
 
+    @timed
     # Search / recall operations
     def search_by_embedding(
         self,
@@ -773,6 +972,7 @@
         scope: str | None = None,
         status: str | None = None,
         threshold: float | None = None,
+        **kwargs,
     ) -> list[dict]:
         """
         Retrieve node IDs based on vector similarity.
@@ -807,7 +1007,10 @@
         if status:
             where_clauses.append(f'n.status = "{status}"')
         if not self.config.use_multi_db and self.config.user_name:
-            where_clauses.append(f'n.user_name = "{self.config.user_name}"')
+            if kwargs.get("cube_name"):
+                where_clauses.append(f'n.user_name = "{kwargs["cube_name"]}"')
+            else:
+                where_clauses.append(f'n.user_name = "{self.config.user_name}"')
 
         where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
 
@@ -815,11 +1018,11 @@
         USE `{self.db_name}`
         MATCH (n@Memory)
         {where_clause}
-        ORDER BY inner_product(n.embedding, {gql_vector}) DESC
+        ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC
         APPROXIMATE
         LIMIT {top_k}
        OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }}
-        RETURN n.id AS id, inner_product(n.embedding, {gql_vector}) AS score
+        RETURN n.id AS id, inner_product(n.{self.dim_field}, {gql_vector}) AS score
         """
 
         try:
@@ -842,6 +1045,7 @@
             logger.error(f"[search_by_embedding] Result parse failed: {e}")
             return []
 
+    @timed
     def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:
         """
         1. ADD logic: "AND" vs "OR"(support logic combination);
@@ -868,20 +1072,12 @@
         """
         where_clauses = []
 
-        def _escape_value(value):
-            if isinstance(value, str):
-                return f'"{value}"'
-            elif isinstance(value, list):
-                return "[" + ", ".join(_escape_value(v) for v in value) + "]"
-            else:
-                return str(value)
-
         for _i, f in enumerate(filters):
             field = f["field"]
             op = f.get("op", "=")
             value = f["value"]
 
-            escaped_value = _escape_value(value)
+            escaped_value = self._format_value(value)
 
             # Build WHERE clause
             if op == "=":
@@ -912,6 +1108,7 @@
             logger.error(f"Failed to get metadata: {e}, gql is {gql}")
         return ids
 
+    @timed
     def get_grouped_counts(
         self,
         group_fields: list[str],
@@ -980,6 +1177,7 @@
 
         return output
 
+    @timed
     def clear(self) -> None:
         """
         Clear the entire graph if the target database exists.
@@ -996,9 +1194,12 @@
         except Exception as e:
             logger.error(f"[ERROR] Failed to clear database: {e}")
 
-    def export_graph(self) -> dict[str, Any]:
+    @timed
+    def export_graph(self, include_embedding: bool = False) -> dict[str, Any]:
         """
         Export all graph nodes and edges in a structured form.
+        Args:
+            include_embedding (bool): Whether to include the large embedding field.
 
         Returns:
             {
@@ -1015,13 +1216,41 @@
             edge_query += f' WHERE r.user_name = "{username}"'
 
         try:
-            full_node_query = f"{node_query} RETURN n"
-            node_result = self.execute_query(full_node_query)
+            if include_embedding:
+                return_fields = "n"
+            else:
+                return_fields = ",".join(
+                    [
+                        "n.id AS id",
+                        "n.memory AS memory",
+                        "n.user_name AS user_name",
+                        "n.user_id AS user_id",
+                        "n.session_id AS session_id",
+                        "n.status AS status",
+                        "n.key AS key",
+                        "n.confidence AS confidence",
+                        "n.tags AS tags",
+                        "n.created_at AS created_at",
+                        "n.updated_at AS updated_at",
+                        "n.memory_type AS memory_type",
+                        "n.sources AS sources",
+                        "n.source AS source",
+                        "n.node_type AS node_type",
+                        "n.visibility AS visibility",
+                        "n.usage AS usage",
+                        "n.background AS background",
+                    ]
+                )
+
+            full_node_query = f"{node_query} RETURN {return_fields}"
+            node_result = self.execute_query(full_node_query, timeout=20)
             nodes = []
+            logger.debug(f"Debugging: {node_result}")
             for row in node_result:
-                node_wrapper = row.values()[0].as_node()
-                props = node_wrapper.get_properties()
-
+                if include_embedding:
+                    props = row.values()[0].as_node().get_properties()
+                else:
+                    props = {k: v.value for k, v in row.items()}
                 node = self._parse_node(props)
                 nodes.append(node)
         except Exception as e:
@@ -1029,7 +1258,7 @@
 
         try:
             full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) as edge"
-            edge_result = self.execute_query(full_edge_query)
+            edge_result = self.execute_query(full_edge_query, timeout=20)
             edges = [
                 {
                     "source": row.values()[0].value,
@@ -1043,6 +1272,7 @@
 
         return {"nodes": nodes, "edges": edges}
 
+    @timed
     def import_graph(self, data: dict[str, Any]) -> None:
         """
         Import the entire graph from a serialized dictionary.
@@ -1051,35 +1281,45 @@
             data: A dictionary containing all nodes and edges to be loaded.
         """
         for node in data.get("nodes", []):
-            id, memory, metadata = _compose_node(node)
+            try:
+                id, memory, metadata = _compose_node(node)
 
-            if not self.config.use_multi_db and self.config.user_name:
-                metadata["user_name"] = self.config.user_name
+                if not self.config.use_multi_db and self.config.user_name:
+                    metadata["user_name"] = self.config.user_name
 
-            metadata = _prepare_node_metadata(metadata)
-            metadata.update({"id": id, "memory": memory})
-            properties = ", ".join(f"{k}: {_format_value(v, k)}" for k, v in metadata.items())
-            node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
-            self.execute_query(node_gql)
+                metadata = self._prepare_node_metadata(metadata)
+                metadata.update({"id": id, "memory": memory})
+                properties = ", ".join(
+                    f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()
+                )
+                node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
+                self.execute_query(node_gql)
+            except Exception as e:
+                logger.error(f"Fail to load node: {node}, error: {e}")
 
         for edge in data.get("edges", []):
-            source_id, target_id = edge["source"], edge["target"]
-            edge_type = edge["type"]
-            props = ""
-            if not self.config.use_multi_db and self.config.user_name:
-                props = f'{{user_name: "{self.config.user_name}"}}'
-            edge_gql = f'''
-            MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
-            INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
-            '''
-            self.execute_query(edge_gql)
+            try:
+                source_id, target_id = edge["source"], edge["target"]
+                edge_type = edge["type"]
+                props = ""
+                if not self.config.use_multi_db and self.config.user_name:
+                    props = f'{{user_name: "{self.config.user_name}"}}'
+                edge_gql = f'''
+                MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
+                INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
+                '''
+                self.execute_query(edge_gql)
+            except Exception as e:
+                logger.error(f"Fail to load edge: {edge}, error: {e}")
 
-    def get_all_memory_items(self, scope: str) -> list[dict]:
+    @timed
+    def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (list)[dict]:
         """
         Retrieve all memory items of a specific memory_type.
 
         Args:
             scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
+            include_embedding: with/without embedding
 
         Returns:
             list[dict]: Full list of memory items under this scope.
@@ -1092,22 +1332,28 @@
         if not self.config.use_multi_db and self.config.user_name:
             where_clause += f" AND n.user_name = '{self.config.user_name}'"
 
+        return_fields = self._build_return_fields(include_embedding)
+
         query = f"""
         MATCH (n@Memory)
         {where_clause}
-        RETURN n
+        RETURN {return_fields}
+        LIMIT 100
         """
         nodes = []
         try:
             results = self.execute_query(query)
-            for rec in results:
-                node_props = rec["n"].as_node().get_properties()
-                nodes.append(self._parse_node(node_props))
+            for row in results:
+                props = {k: v.value for k, v in row.items()}
+                nodes.append(self._parse_node(props))
         except Exception as e:
             logger.error(f"Failed to get memories: {e}")
         return nodes
 
-    def get_structure_optimization_candidates(self, scope: str) -> list[dict]:
+    @timed
+    def get_structure_optimization_candidates(
+        self, scope: str, include_embedding: bool = False
+    ) -> list[dict]:
         """
         Find nodes that are likely candidates for structure optimization:
         - Isolated nodes, nodes with empty background, or nodes with exactly one child.
@@ -1121,6 +1367,8 @@
         if not self.config.use_multi_db and self.config.user_name:
             where_clause += f' AND n.user_name = "{self.config.user_name}"'
 
+        return_fields = self._build_return_fields(include_embedding)
+
         query = f"""
         USE `{self.db_name}`
         MATCH (n@Memory)
@@ -1128,19 +1376,20 @@
         OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
         OPTIONAL MATCH (p@Memory)-[@PARENT]->(n)
         WHERE c IS NULL AND p IS NULL
-        RETURN n
+        RETURN {return_fields}
         """
 
         candidates = []
         try:
             results = self.execute_query(query)
-            for rec in results:
-                node_props = rec["n"].as_node().get_properties()
-                candidates.append(self._parse_node(node_props))
+            for row in results:
+                props = {k: v.value for k, v in row.items()}
+                candidates.append(self._parse_node(props))
         except Exception as e:
-            logger.error(f"Failed : {e}")
+            logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
         return candidates
 
+    @timed
     def drop_database(self) -> None:
         """
         Permanently delete the entire database this instance is using.
@@ -1155,6 +1404,7 @@
                 f"Shared Database Multi-Tenant mode"
             )
 
+    @timed
     def detect_conflicts(self) -> list[tuple[str, str]]:
         """
         Detect conflicting nodes based on logical or semantic inconsistency.
@@ -1163,6 +1413,7 @@
         """
         raise NotImplementedError
 
+    @timed
     # Structure Maintenance
     def deduplicate_nodes(self) -> None:
         """
@@ -1171,6 +1422,7 @@
         """
         raise NotImplementedError
 
+    @timed
     def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
         """
         Get the ordered context chain starting from a node, following a relationship type.
@@ -1182,6 +1434,7 @@
         """
         raise NotImplementedError
 
+    @timed
     def get_neighbors(
         self, id: str, type: str, direction: Literal["in", "out", "both"] = "out"
     ) -> list[str]:
@@ -1196,6 +1449,7 @@
         """
         raise NotImplementedError
 
+    @timed
     def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]:
         """
         Get the path of nodes from source to target within a limited depth.
@@ -1208,6 +1462,7 @@
         """
         raise NotImplementedError
 
+    @timed
     def merge_nodes(self, id1: str, id2: str) -> str:
         """
         Merge two similar or duplicate nodes into one.
@@ -1219,70 +1474,112 @@
         """
         raise NotImplementedError
 
+    @timed
     def _ensure_database_exists(self):
-        create_tag = """
-        CREATE GRAPH TYPE IF NOT EXISTS MemOSType AS {
-            NODE Memory (:MemoryTag {
-                id STRING,
-                memory STRING,
-                user_name STRING,
-                user_id STRING,
-                session_id STRING,
-                status STRING,
-                key STRING,
-                confidence FLOAT,
-                tags LIST<STRING>,
-                created_at STRING,
-                updated_at STRING,
-                memory_type STRING,
-                sources LIST<STRING>,
-                source STRING,
-                node_type STRING,
-                visibility STRING,
-                usage LIST<STRING>,
-                background STRING,
-                embedding VECTOR<3072, FLOAT>,
-                PRIMARY KEY(id)
-            }),
-            EDGE RELATE_TO (Memory) -[{user_name STRING}]-> (Memory),
-            EDGE PARENT (Memory) -[{user_name STRING}]-> (Memory),
-            EDGE AGGREGATE_TO (Memory) -[{user_name STRING}]-> (Memory),
-            EDGE MERGED_TO (Memory) -[{user_name STRING}]-> (Memory),
-            EDGE INFERS (Memory) -[{user_name STRING}]-> (Memory),
-            EDGE FOLLOWS (Memory) -[{user_name STRING}]-> (Memory)
-        }
-        """
-        create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED MemOSType"
+        graph_type_name = "MemOSBgeM3Type"
+
+        check_type_query = "SHOW GRAPH TYPES"
+        result = self.execute_query(check_type_query, auto_set_db=False)
+
+        type_exists = any(row["graph_type"].as_string() == graph_type_name for row in result)
+
+        if not type_exists:
+            create_tag = f"""
+            CREATE GRAPH TYPE IF NOT EXISTS {graph_type_name} AS {{
+                NODE Memory (:MemoryTag {{
+                    id STRING,
+                    memory STRING,
+                    user_name STRING,
+                    user_id STRING,
+                    session_id STRING,
+                    status STRING,
+                    key STRING,
+                    confidence FLOAT,
+                    tags LIST<STRING>,
+                    created_at STRING,
+                    updated_at STRING,
+                    memory_type STRING,
+                    sources LIST<STRING>,
+                    source STRING,
+                    node_type STRING,
+                    visibility STRING,
+                    usage LIST<STRING>,
+                    background STRING,
+                    {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT>,
+                    PRIMARY KEY(id)
+                }}),
+                EDGE RELATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE PARENT (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE AGGREGATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE MERGED_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE INFERS (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE FOLLOWS (Memory) -[{{user_name STRING}}]-> (Memory)
+            }}
+            """
+            self.execute_query(create_tag, auto_set_db=False)
+        else:
+            describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name};"
+            desc_result = self.execute_query(describe_query, auto_set_db=False)
+
+            memory_fields = []
+            for row in desc_result:
+                field_name = row.values()[0].as_string()
+                memory_fields.append(field_name)
+
+            if self.dim_field not in memory_fields:
+                alter_query = f"""
+                ALTER GRAPH TYPE {graph_type_name} {{
+                    ALTER NODE TYPE Memory ADD PROPERTIES {{ {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT> }}
+                }}
+                """
+                self.execute_query(alter_query, auto_set_db=False)
+                logger.info(f"✅ Add new vector search {self.dim_field} to {graph_type_name}")
+            else:
+                logger.info(f"✅ Graph Type {graph_type_name} already include {self.dim_field}")
+
+        create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
         set_graph_working = f"SESSION SET GRAPH `{self.db_name}`"
 
         try:
-            self.execute_query(create_tag, auto_set_db=False)
             self.execute_query(create_graph, auto_set_db=False)
             self.execute_query(set_graph_working)
             logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.")
         except Exception as e:
             logger.error(f"❌ Failed to create tag: {e} trace: {traceback.format_exc()}")
 
+    @timed
     def _create_vector_index(
         self, label: str, vector_property: str, dimensions: int, index_name: str
     ) -> None:
         """
         Create a vector index for the specified property in the label.
         """
+        if str(dimensions) == str(self.default_memory_dimension):
+            index_name = f"idx_{vector_property}"
+            vector_name = vector_property
+        else:
+            index_name = f"idx_{vector_property}_{dimensions}"
+            vector_name = f"{vector_property}_{dimensions}"
+
         create_vector_index = f"""
-        CREATE VECTOR INDEX IF NOT EXISTS {index_name}
-        ON NODE Memory::{vector_property}
-        OPTIONS {{
-            DIM: {dimensions},
-            METRIC: IP,
-            TYPE: IVF,
-            NLIST: 100,
-            TRAINSIZE: 1000
-        }}
-        FOR `{self.db_name}`
-        """
+            CREATE VECTOR INDEX IF NOT EXISTS {index_name}
+            ON NODE {label}::{vector_name}
+            OPTIONS {{
+                DIM: {dimensions},
+                METRIC: IP,
+                TYPE: IVF,
+                NLIST: 100,
+                TRAINSIZE: 1000
+            }}
+            FOR `{self.db_name}`
+            """
         self.execute_query(create_vector_index)
+        logger.info(
+            f"✅ Ensure {label}::{vector_property} vector index {index_name} "
+            f"exists (DIM={dimensions})"
+        )
 
+    @timed
     def _create_basic_property_indexes(self) -> None:
         """
         Create standard B-tree indexes on status, memory_type, created_at
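
Reviewer note: _ensure_database_exists has become an additive migration. It checks SHOW GRAPH TYPES, creates the MemOSBgeM3Type graph type on first run, and otherwise DESCRIBEs the Memory node type and ALTERs it when the dimension-specific vector column is missing. For a hypothetical 1024-dim embedder, the generated statement renders as:

    ALTER GRAPH TYPE MemOSBgeM3Type {
        ALTER NODE TYPE Memory ADD PROPERTIES { embedding_1024 VECTOR<1024, FLOAT> }
    }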
@@ -1304,8 +1601,11 @@
                 self.execute_query(gql)
                 logger.info(f"✅ Created index: {index_name} on field {field}")
             except Exception as e:
-                logger.error(f"❌ Failed to create index {index_name}: {e}")
+                logger.error(
+                    f"❌ Failed to create index {index_name}: {e}, trace: {traceback.format_exc()}"
+                )
 
+    @timed
     def _index_exists(self, index_name: str) -> bool:
         """
         Check if an index with the given name exists.
@@ -1327,6 +1627,7 @@
             logger.error(f"[Nebula] Failed to check index existence: {e}")
             return False
 
+    @timed
     def _parse_value(self, value: Any) -> Any:
         """turn Nebula ValueWrapper to Python type"""
         from nebulagraph_python.value_wrapper import ValueWrapper
@@ -1352,8 +1653,8 @@
         parsed = {k: self._parse_value(v) for k, v in props.items()}
 
         for tf in ("created_at", "updated_at"):
-            if tf in parsed and hasattr(parsed[tf], "isoformat"):
-                parsed[tf] = parsed[tf].isoformat()
+            if tf in parsed and parsed[tf] is not None:
+                parsed[tf] = _normalize_datetime(parsed[tf])
 
         node_id = parsed.pop("id")
         memory = parsed.pop("memory", "")
@@ -1361,4 +1662,98 @@
         metadata = parsed
         metadata["type"] = metadata.pop("node_type")
 
+        if self.dim_field in metadata:
+            metadata["embedding"] = metadata.pop(self.dim_field)
+
         return {"id": node_id, "memory": memory, "metadata": metadata}
+
+    @timed
+    def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
+        """
+        Ensure metadata has proper datetime fields and normalized types.
+
+        - Fill `created_at` and `updated_at` if missing (in ISO 8601 format).
+        - Convert embedding to list of float if present.
+        """
+        now = datetime.utcnow().isoformat()
+        metadata["node_type"] = metadata.pop("type")
+
+        # Fill timestamps if missing
+        metadata.setdefault("created_at", now)
+        metadata.setdefault("updated_at", now)
+
+        # Normalize embedding type
+        embedding = metadata.get("embedding")
+        if embedding and isinstance(embedding, list):
+            metadata.pop("embedding")
+            metadata[self.dim_field] = _normalize([float(x) for x in embedding])
+
+        return metadata
+
+    @timed
+    def _format_value(self, val: Any, key: str = "") -> str:
+        from nebulagraph_python.py_data_types import NVector
+
+        # None
+        if val is None:
+            return "NULL"
+        # bool
+        if isinstance(val, bool):
+            return "true" if val else "false"
+        # str
+        if isinstance(val, str):
+            return f'"{_escape_str(val)}"'
+        # num
+        elif isinstance(val, (int | float)):
+            return str(val)
+        # time
+        elif isinstance(val, datetime):
+            return f'datetime("{val.isoformat()}")'
+        # list
+        elif isinstance(val, list):
+            if key == self.dim_field:
+                dim = len(val)
+                joined = ",".join(str(float(x)) for x in val)
+                return f"VECTOR<{dim}, FLOAT>([{joined}])"
+            else:
+                return f"[{', '.join(self._format_value(v) for v in val)}]"
+        # NVector
+        elif isinstance(val, NVector):
+            if key == self.dim_field:
+                dim = len(val)
+                joined = ",".join(str(float(x)) for x in val)
+                return f"VECTOR<{dim}, FLOAT>([{joined}])"
+            else:
+                logger.warning("Invalid NVector")
+        # dict
+        if isinstance(val, dict):
+            j = json.dumps(val, ensure_ascii=False, separators=(",", ":"))
+            return f'"{_escape_str(j)}"'
+        else:
+            return f'"{_escape_str(str(val))}"'
+
+    @timed
+    def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]:
+        """
+        Filter and validate metadata dictionary against the Memory node schema.
+        - Removes keys not in schema.
+        - Warns if required fields are missing.
+        """
+
+        dim_fields = {self.dim_field}
+
+        allowed_fields = self.common_fields | dim_fields
+
+        missing_fields = allowed_fields - metadata.keys()
+        if missing_fields:
+            logger.info(f"Metadata missing required fields: {sorted(missing_fields)}")
+
+        filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields}
+
+        return filtered_metadata
+
+    def _build_return_fields(self, include_embedding: bool = False) -> str:
+        fields = set(self.base_fields)
+        if include_embedding:
+            fields.add(self.dim_field)
+        return ", ".join(f"n.{f} AS {f}" for f in fields)