memoryos-1.0.1-py3-none-any.whl → memoryos-1.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of MemoryOS might be problematic.

Files changed (82)
  1. {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/METADATA +7 -2
  2. {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/RECORD +79 -65
  3. {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/WHEEL +1 -1
  4. memos/__init__.py +1 -1
  5. memos/api/client.py +109 -0
  6. memos/api/config.py +11 -9
  7. memos/api/context/dependencies.py +15 -55
  8. memos/api/middleware/request_context.py +9 -40
  9. memos/api/product_api.py +2 -3
  10. memos/api/product_models.py +91 -16
  11. memos/api/routers/product_router.py +23 -16
  12. memos/api/start_api.py +10 -0
  13. memos/configs/graph_db.py +4 -0
  14. memos/configs/mem_scheduler.py +38 -3
  15. memos/context/context.py +255 -0
  16. memos/embedders/factory.py +2 -0
  17. memos/graph_dbs/nebular.py +230 -232
  18. memos/graph_dbs/neo4j.py +35 -1
  19. memos/graph_dbs/neo4j_community.py +7 -0
  20. memos/llms/factory.py +2 -0
  21. memos/llms/openai.py +74 -2
  22. memos/log.py +27 -15
  23. memos/mem_cube/general.py +3 -1
  24. memos/mem_os/core.py +60 -22
  25. memos/mem_os/main.py +3 -6
  26. memos/mem_os/product.py +35 -11
  27. memos/mem_reader/factory.py +2 -0
  28. memos/mem_reader/simple_struct.py +127 -74
  29. memos/mem_scheduler/analyzer/__init__.py +0 -0
  30. memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +569 -0
  31. memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
  32. memos/mem_scheduler/base_scheduler.py +126 -56
  33. memos/mem_scheduler/general_modules/dispatcher.py +2 -2
  34. memos/mem_scheduler/general_modules/misc.py +99 -1
  35. memos/mem_scheduler/general_modules/scheduler_logger.py +17 -11
  36. memos/mem_scheduler/general_scheduler.py +40 -88
  37. memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
  38. memos/mem_scheduler/memory_manage_modules/memory_filter.py +308 -0
  39. memos/mem_scheduler/{general_modules → memory_manage_modules}/retriever.py +34 -7
  40. memos/mem_scheduler/monitors/dispatcher_monitor.py +9 -8
  41. memos/mem_scheduler/monitors/general_monitor.py +119 -39
  42. memos/mem_scheduler/optimized_scheduler.py +124 -0
  43. memos/mem_scheduler/orm_modules/__init__.py +0 -0
  44. memos/mem_scheduler/orm_modules/base_model.py +635 -0
  45. memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
  46. memos/mem_scheduler/scheduler_factory.py +2 -0
  47. memos/mem_scheduler/schemas/monitor_schemas.py +96 -29
  48. memos/mem_scheduler/utils/config_utils.py +100 -0
  49. memos/mem_scheduler/utils/db_utils.py +33 -0
  50. memos/mem_scheduler/utils/filter_utils.py +1 -1
  51. memos/mem_scheduler/webservice_modules/__init__.py +0 -0
  52. memos/memories/activation/kv.py +2 -1
  53. memos/memories/textual/item.py +95 -16
  54. memos/memories/textual/naive.py +1 -1
  55. memos/memories/textual/tree.py +27 -3
  56. memos/memories/textual/tree_text_memory/organize/handler.py +4 -2
  57. memos/memories/textual/tree_text_memory/organize/manager.py +28 -14
  58. memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +1 -2
  59. memos/memories/textual/tree_text_memory/organize/reorganizer.py +75 -23
  60. memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +7 -5
  61. memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -2
  62. memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +2 -0
  63. memos/memories/textual/tree_text_memory/retrieve/recall.py +70 -22
  64. memos/memories/textual/tree_text_memory/retrieve/searcher.py +101 -33
  65. memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +5 -4
  66. memos/memos_tools/singleton.py +174 -0
  67. memos/memos_tools/thread_safe_dict.py +22 -0
  68. memos/memos_tools/thread_safe_dict_segment.py +382 -0
  69. memos/parsers/factory.py +2 -0
  70. memos/reranker/concat.py +59 -0
  71. memos/reranker/cosine_local.py +1 -0
  72. memos/reranker/factory.py +5 -0
  73. memos/reranker/http_bge.py +225 -12
  74. memos/templates/mem_scheduler_prompts.py +242 -0
  75. memos/types.py +4 -1
  76. memos/api/context/context.py +0 -147
  77. memos/api/context/context_thread.py +0 -96
  78. memos/mem_scheduler/mos_for_test_scheduler.py +0 -146
  79. {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/entry_points.txt +0 -0
  80. {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info/licenses}/LICENSE +0 -0
  81. /memos/mem_scheduler/{general_modules → webservice_modules}/rabbitmq_service.py +0 -0
  82. /memos/mem_scheduler/{general_modules → webservice_modules}/redis_service.py +0 -0
memos/graph_dbs/nebular.py

@@ -3,7 +3,6 @@ import traceback
 
 from contextlib import suppress
 from datetime import datetime
-from queue import Empty, Queue
 from threading import Lock
 from typing import TYPE_CHECKING, Any, ClassVar, Literal
 
@@ -17,12 +16,27 @@ from memos.utils import timed
 
 
 if TYPE_CHECKING:
-    from nebulagraph_python.client.pool import NebulaPool
+    from nebulagraph_python import (
+        NebulaClient,
+    )
 
 
 logger = get_logger(__name__)
 
 
+_TRANSIENT_ERR_KEYS = (
+    "Session not found",
+    "Connection not established",
+    "timeout",
+    "deadline exceeded",
+    "Broken pipe",
+    "EOFError",
+    "socket closed",
+    "connection reset",
+    "connection refused",
+)
+
+
 @timed
 def _normalize(vec: list[float]) -> list[float]:
     v = np.asarray(vec, dtype=np.float32)
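
The _TRANSIENT_ERR_KEYS tuple drives the retry path added to execute_query later in this diff: an exception counts as transient when its message contains any of these substrings, case-insensitively. A minimal sketch of that check in isolation (the helper name _is_transient is illustrative, not part of the package):

    def _is_transient(exc: Exception) -> bool:
        # Mirrors `any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS)`
        # from the new execute_query below.
        emsg = str(exc).lower()
        return any(key.lower() in emsg for key in _TRANSIENT_ERR_KEYS)

    _is_transient(RuntimeError("gRPC deadline exceeded"))  # True
    _is_transient(ValueError("syntax error in GQL"))       # False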
@@ -87,137 +101,6 @@ def _normalize_datetime(val):
     return str(val)
 
 
-class SessionPoolError(Exception):
-    pass
-
-
-class SessionPool:
-    @require_python_package(
-        import_name="nebulagraph_python",
-        install_command="pip install ... @Tianxing",
-        install_link=".....",
-    )
-    def __init__(
-        self,
-        hosts: list[str],
-        user: str,
-        password: str,
-        minsize: int = 1,
-        maxsize: int = 10000,
-    ):
-        self.hosts = hosts
-        self.user = user
-        self.password = password
-        self.minsize = minsize
-        self.maxsize = maxsize
-        self.pool = Queue(maxsize)
-        self.lock = Lock()
-
-        self.clients = []
-
-        for _ in range(minsize):
-            self._create_and_add_client()
-
-    @timed
-    def _create_and_add_client(self):
-        from nebulagraph_python import NebulaClient
-
-        client = NebulaClient(self.hosts, self.user, self.password)
-        self.pool.put(client)
-        self.clients.append(client)
-
-    @timed
-    def get_client(self, timeout: float = 5.0):
-        try:
-            return self.pool.get(timeout=timeout)
-        except Empty:
-            with self.lock:
-                if len(self.clients) < self.maxsize:
-                    from nebulagraph_python import NebulaClient
-
-                    client = NebulaClient(self.hosts, self.user, self.password)
-                    self.clients.append(client)
-                    return client
-            raise RuntimeError("NebulaClientPool exhausted") from None
-
-    @timed
-    def return_client(self, client):
-        try:
-            client.execute("YIELD 1")
-            self.pool.put(client)
-        except Exception:
-            logger.info("[Pool] Client dead, replacing...")
-            self.replace_client(client)
-
-    @timed
-    def close(self):
-        for client in self.clients:
-            with suppress(Exception):
-                client.close()
-        self.clients.clear()
-
-    @timed
-    def get(self):
-        """
-        Context manager: with pool.get() as client:
-        """
-
-        class _ClientContext:
-            def __init__(self, outer):
-                self.outer = outer
-                self.client = None
-
-            def __enter__(self):
-                self.client = self.outer.get_client()
-                return self.client
-
-            def __exit__(self, exc_type, exc_val, exc_tb):
-                if self.client:
-                    self.outer.return_client(self.client)
-
-        return _ClientContext(self)
-
-    @timed
-    def reset_pool(self):
-        """⚠️ Emergency reset: Close all clients and clear the pool."""
-        logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
-        with self.lock:
-            for client in self.clients:
-                try:
-                    client.close()
-                except Exception:
-                    logger.error("Fail to close!!!")
-            self.clients.clear()
-            while not self.pool.empty():
-                try:
-                    self.pool.get_nowait()
-                except Empty:
-                    break
-            for _ in range(self.minsize):
-                self._create_and_add_client()
-        logger.info("[Pool] Pool has been reset successfully.")
-
-    @timed
-    def replace_client(self, client):
-        try:
-            client.close()
-        except Exception:
-            logger.error("Fail to close client")
-
-        if client in self.clients:
-            self.clients.remove(client)
-
-        from nebulagraph_python import NebulaClient
-
-        new_client = NebulaClient(self.hosts, self.user, self.password)
-        self.clients.append(new_client)
-
-        self.pool.put(new_client)
-
-        logger.info("[Pool] Replaced dead client with a new one.")
-        return new_client
-
-
 class NebulaGraphDB(BaseGraphDB):
     """
     NebulaGraph-based implementation of a graph memory store.
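
The removal above deletes the hand-rolled, queue-backed SessionPool (and its SessionPoolError) in favor of a single shared NebulaClient per configuration, introduced in the next hunk. A hypothetical before/after usage sketch, where pool and db stand for instances built from the 1.0.1 and 1.1.1 code respectively:

    # 1.0.1: check a client out of the Queue-backed pool; __exit__ returns it
    with pool.get() as client:
        client.execute("YIELD 1")

    # 1.1.1: instances with identical configs share one NebulaClient, and
    # queries go through execute_query(), which retries once on transient errors
    db.execute_query("YIELD 1")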
@@ -226,94 +109,194 @@ class NebulaGraphDB(BaseGraphDB):
     # ====== shared pool cache & refcount ======
     # These are process-local; in a multi-process model each process will
     # have its own cache.
-    _POOL_CACHE: ClassVar[dict[str, "NebulaPool"]] = {}
-    _POOL_REFCOUNT: ClassVar[dict[str, int]] = {}
-    _POOL_LOCK: ClassVar[Lock] = Lock()
+    _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {}
+    _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {}
+    _CLIENT_LOCK: ClassVar[Lock] = Lock()
+    _CLIENT_INIT_DONE: ClassVar[set[str]] = set()
 
     @staticmethod
-    def _make_pool_key(cfg: NebulaGraphDBConfig) -> str:
-        """
-        Build a cache key that captures all connection-affecting options.
-        Keep this key stable and include fields that change the underlying pool behavior.
-        """
-        # NOTE: Do not include tenant-like or query-scope-only fields here.
-        # Only include things that affect the actual TCP/auth/session pool.
+    def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]:
+        hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None)
+        if isinstance(hosts, str):
+            return [hosts]
+        return list(hosts or [])
+
+    @staticmethod
+    def _make_client_key(cfg: NebulaGraphDBConfig) -> str:
+        hosts = NebulaGraphDB._get_hosts_from_cfg(cfg)
         return "|".join(
             [
-                "nebula",
-                str(getattr(cfg, "uri", "")),
+                "nebula-sync",
+                ",".join(hosts),
                 str(getattr(cfg, "user", "")),
-                str(getattr(cfg, "password", "")),
-                # pool sizing / tls / timeouts if you have them in config:
-                str(getattr(cfg, "max_client", 1000)),
-                # multi-db mode can impact how we use sessions; keep it to be safe
                 str(getattr(cfg, "use_multi_db", False)),
+                str(getattr(cfg, "space", "")),
             ]
         )
 
     @classmethod
-    def _get_or_create_shared_pool(cls, cfg: NebulaGraphDBConfig):
-        """
-        Get a shared NebulaPool from cache or create one if missing.
-        Thread-safe with a lock; maintains a simple refcount.
-        """
-        key = cls._make_pool_key(cfg)
-
-        with cls._POOL_LOCK:
-            pool = cls._POOL_CACHE.get(key)
-            if pool is None:
-                # Create a new pool and put into cache
-                pool = SessionPool(
-                    hosts=cfg.get("uri"),
-                    user=cfg.get("user"),
-                    password=cfg.get("password"),
-                    minsize=1,
-                    maxsize=cfg.get("max_client", 1000),
+    def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB":
+        tmp = object.__new__(NebulaGraphDB)
+        tmp.config = cfg
+        tmp.db_name = cfg.space
+        tmp.user_name = getattr(cfg, "user_name", None)
+        tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072)
+        tmp.default_memory_dimension = 3072
+        tmp.common_fields = {
+            "id",
+            "memory",
+            "user_name",
+            "user_id",
+            "session_id",
+            "status",
+            "key",
+            "confidence",
+            "tags",
+            "created_at",
+            "updated_at",
+            "memory_type",
+            "sources",
+            "source",
+            "node_type",
+            "visibility",
+            "usage",
+            "background",
+        }
+        tmp.base_fields = set(tmp.common_fields) - {"usage"}
+        tmp.heavy_fields = {"usage"}
+        tmp.dim_field = (
+            f"embedding_{tmp.embedding_dimension}"
+            if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension)
+            else "embedding"
+        )
+        tmp.system_db_name = "system" if getattr(cfg, "use_multi_db", False) else cfg.space
+        tmp._client = client
+        tmp._owns_client = False
+        return tmp
+
+    @classmethod
+    def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]:
+        from nebulagraph_python import (
+            ConnectionConfig,
+            NebulaClient,
+            SessionConfig,
+            SessionPoolConfig,
+        )
+
+        key = cls._make_client_key(cfg)
+        with cls._CLIENT_LOCK:
+            client = cls._CLIENT_CACHE.get(key)
+            if client is None:
+                # Connection setting
+                conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None)
+                if conn_conf is None:
+                    conn_conf = ConnectionConfig.from_defults(
+                        cls._get_hosts_from_cfg(cfg),
+                        getattr(cfg, "ssl_param", None),
+                    )
+
+                sess_conf = SessionConfig(graph=getattr(cfg, "space", None))
+                pool_conf = SessionPoolConfig(
+                    size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000
+                )
+
+                client = NebulaClient(
+                    hosts=conn_conf.hosts,
+                    username=cfg.user,
+                    password=cfg.password,
+                    conn_config=conn_conf,
+                    session_config=sess_conf,
+                    session_pool_config=pool_conf,
                 )
-                cls._POOL_CACHE[key] = pool
-                cls._POOL_REFCOUNT[key] = 0
-                logger.info(f"[NebulaGraphDB] Created new shared NebulaPool for key={key}")
+                cls._CLIENT_CACHE[key] = client
+                cls._CLIENT_REFCOUNT[key] = 0
+                logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}")
 
-        # Increase refcount for the caller
-        cls._POOL_REFCOUNT[key] = cls._POOL_REFCOUNT.get(key, 0) + 1
-        return key, pool
+            cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1
+
+        if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
+            try:
+                pass
+            finally:
+                pass
+
+        if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
+            with cls._CLIENT_LOCK:
+                if key not in cls._CLIENT_INIT_DONE:
+                    admin = cls._bootstrap_admin(cfg, client)
+                    try:
+                        admin._ensure_database_exists()
+                        admin._create_basic_property_indexes()
+                        admin._create_vector_index(
+                            dimensions=int(
+                                admin.embedding_dimension or admin.default_memory_dimension
+                            ),
+                        )
+                        cls._CLIENT_INIT_DONE.add(key)
+                        logger.info("[NebulaGraphDBSync] One-time init done")
+                    except Exception:
+                        logger.exception("[NebulaGraphDBSync] One-time init failed")
+
+        return key, client
+
+    def _refresh_client(self):
+        """
+        refresh NebulaClient:
+        """
+        old_key = getattr(self, "_client_key", None)
+        if not old_key:
+            return
+
+        cls = self.__class__
+        with cls._CLIENT_LOCK:
+            try:
+                if old_key in cls._CLIENT_CACHE:
+                    try:
+                        cls._CLIENT_CACHE[old_key].close()
+                    except Exception as e:
+                        logger.warning(f"[refresh_client] close old client error: {e}")
+                    finally:
+                        cls._CLIENT_CACHE.pop(old_key, None)
+            finally:
+                cls._CLIENT_REFCOUNT[old_key] = 0
+
+        new_key, new_client = cls._get_or_create_shared_client(self.config)
+        self._client_key = new_key
+        self._client = new_client
+        logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}")
 
     @classmethod
-    def _release_shared_pool(cls, key: str):
-        """
-        Decrease refcount for the given pool key; only close when refcount hits zero.
-        """
-        with cls._POOL_LOCK:
-            if key not in cls._POOL_CACHE:
+    def _release_shared_client(cls, key: str):
+        with cls._CLIENT_LOCK:
+            if key not in cls._CLIENT_CACHE:
                 return
-            cls._POOL_REFCOUNT[key] = max(0, cls._POOL_REFCOUNT.get(key, 0) - 1)
-            if cls._POOL_REFCOUNT[key] == 0:
+            cls._CLIENT_REFCOUNT[key] = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1)
+            if cls._CLIENT_REFCOUNT[key] == 0:
                 try:
-                    cls._POOL_CACHE[key].close()
+                    cls._CLIENT_CACHE[key].close()
                 except Exception as e:
-                    logger.warning(f"[NebulaGraphDB] Error closing shared pool: {e}")
+                    logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}")
                 finally:
-                    cls._POOL_CACHE.pop(key, None)
-                    cls._POOL_REFCOUNT.pop(key, None)
-                    logger.info(f"[NebulaGraphDB] Closed and removed shared pool key={key}")
+                    cls._CLIENT_CACHE.pop(key, None)
+                    cls._CLIENT_REFCOUNT.pop(key, None)
+                    logger.info(f"[NebulaGraphDBSync] Closed & removed client key={key}")
 
     @classmethod
-    def close_all_shared_pools(cls):
-        """Force close all cached pools. Call this on graceful shutdown."""
-        with cls._POOL_LOCK:
-            for key, pool in list(cls._POOL_CACHE.items()):
+    def close_all_shared_clients(cls):
+        with cls._CLIENT_LOCK:
+            for key, client in list(cls._CLIENT_CACHE.items()):
                 try:
-                    pool.close()
+                    client.close()
                 except Exception as e:
-                    logger.warning(f"[NebulaGraphDB] Error closing pool key={key}: {e}")
+                    logger.warning(f"[NebulaGraphDBSync] Error closing client {key}: {e}")
                 finally:
-                    logger.info(f"[NebulaGraphDB] Closed pool key={key}")
-            cls._POOL_CACHE.clear()
-            cls._POOL_REFCOUNT.clear()
+                    logger.info(f"[NebulaGraphDBSync] Closed client key={key}")
+            cls._CLIENT_CACHE.clear()
+            cls._CLIENT_REFCOUNT.clear()
 
     @require_python_package(
         import_name="nebulagraph_python",
-        install_command="pip install ... @Tianxing",
+        install_command="pip install nebulagraph-python>=5.1.1",
        install_link=".....",
    )
    def __init__(self, config: NebulaGraphDBConfig):
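
The hunk above replaces the pool cache with a class-level client cache keyed by _make_client_key; a refcount decides when the last owner actually closes the shared connection. The same pattern as a minimal standalone sketch (names are illustrative, not the package's API):

    import threading
    from typing import Any, Callable

    _cache: dict[str, Any] = {}
    _refcount: dict[str, int] = {}
    _lock = threading.Lock()

    def acquire(key: str, factory: Callable[[], Any]) -> Any:
        with _lock:
            if key not in _cache:
                _cache[key] = factory()      # first caller pays the connection cost
                _refcount[key] = 0
            _refcount[key] += 1
            return _cache[key]

    def release(key: str) -> None:
        with _lock:
            if key not in _cache:
                return
            _refcount[key] = max(0, _refcount[key] - 1)
            if _refcount[key] == 0:
                _cache.pop(key).close()      # last owner closes the real connection
                _refcount.pop(key, None)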
@@ -371,34 +354,32 @@ class NebulaGraphDB(BaseGraphDB):
 
         # ---- NEW: pool acquisition strategy
         # Get or create a shared pool from the class-level cache
-        self._pool_key, self.pool = self._get_or_create_shared_pool(config)
-        self._owns_pool = True  # We manage refcount for this instance
-
-        # auto-create graph type / graph / index if needed
-        if config.auto_create:
-            self._ensure_database_exists()
-
-        self.execute_query(f"SESSION SET GRAPH `{self.db_name}`")
-
-        # Create only if not exists
-        self.create_index(dimensions=config.embedding_dimension)
+        self._client_key, self._client = self._get_or_create_shared_client(config)
+        self._owns_client = True
 
         logger.info("Connected to NebulaGraph successfully.")
 
     @timed
-    def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True):
-        with self.pool.get() as client:
-            try:
-                if auto_set_db and self.db_name:
-                    client.execute(f"SESSION SET GRAPH `{self.db_name}`")
-                return client.execute(gql, timeout=timeout)
+    def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True):
+        def _wrap_use_db(q: str) -> str:
+            if auto_set_db and self.db_name:
+                return f"USE `{self.db_name}`\n{q}"
+            return q
 
-            except Exception as e:
-                if "Session not found" in str(e) or "Connection not established" in str(e):
-                    logger.warning(f"[execute_query] {e!s}, replacing client...")
-                    self.pool.replace_client(client)
-                    return self.execute_query(gql, timeout, auto_set_db)
-                raise
+        try:
+            return self._client.execute(_wrap_use_db(gql), timeout=timeout)
+
+        except Exception as e:
+            emsg = str(e)
+            if any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS):
+                logger.warning(f"[execute_query] {e!s} → refreshing session pool and retry once...")
+                try:
+                    self._refresh_client()
+                    return self._client.execute(_wrap_use_db(gql), timeout=timeout)
+                except Exception:
+                    logger.exception("[execute_query] retry after refresh failed")
+                    raise
+            raise
 
     @timed
     def close(self):
@@ -409,13 +390,13 @@ class NebulaGraphDB(BaseGraphDB):
         - If pool was acquired via shared cache, decrement refcount and close
           when the last owner releases it.
         """
-        if not self._owns_pool:
-            logger.debug("[NebulaGraphDB] close() skipped (injected pool).")
+        if not self._owns_client:
+            logger.debug("[NebulaGraphDBSync] close() skipped (injected client).")
             return
-        if self._pool_key:
-            self._release_shared_pool(self._pool_key)
-            self._pool_key = None
-        self.pool = None
+        if self._client_key:
+            self._release_shared_client(self._client_key)
+            self._client_key = None
+        self._client = None
 
     # NOTE: __del__ is best-effort; do not rely on GC order.
     def __del__(self):
@@ -972,6 +953,7 @@ class NebulaGraphDB(BaseGraphDB):
         scope: str | None = None,
         status: str | None = None,
         threshold: float | None = None,
+        search_filter: dict | None = None,
         **kwargs,
     ) -> list[dict]:
         """
@@ -984,6 +966,8 @@ class NebulaGraphDB(BaseGraphDB):
             status (str, optional): Node status filter (e.g., 'active', 'archived').
                 If provided, restricts results to nodes with matching status.
             threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
+            search_filter (dict, optional): Additional metadata filters for search results.
+                Keys should match node properties, values are the expected values.
 
         Returns:
             list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
@@ -993,6 +977,7 @@ class NebulaGraphDB(BaseGraphDB):
         - If scope is provided, it restricts results to nodes with matching memory_type.
         - If 'status' is provided, only nodes with the matching status will be returned.
         - If threshold is provided, only results with score >= threshold will be returned.
+        - If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
         - Typical use case: restrict to 'status = activated' to avoid
           matching archived or merged nodes.
         """
@@ -1012,10 +997,17 @@ class NebulaGraphDB(BaseGraphDB):
         else:
             where_clauses.append(f'n.user_name = "{self.config.user_name}"')
 
+        # Add search_filter conditions
+        if search_filter:
+            for key, value in search_filter.items():
+                if isinstance(value, str):
+                    where_clauses.append(f'n.{key} = "{value}"')
+                else:
+                    where_clauses.append(f"n.{key} = {value}")
+
         where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
 
         gql = f"""
-        USE `{self.db_name}`
         MATCH (n@Memory)
         {where_clause}
         ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC
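
The new search_filter branch interpolates each key/value pair directly into the WHERE clause: string values are double-quoted, everything else is rendered verbatim. A small illustration with hypothetical filter values:

    where_clauses = []
    search_filter = {"memory_type": "WorkingMemory", "confidence": 0.9}
    for key, value in search_filter.items():
        if isinstance(value, str):
            where_clauses.append(f'n.{key} = "{value}"')   # quoted string literal
        else:
            where_clauses.append(f"n.{key} = {value}")     # numeric/boolean as-is
    print(" AND ".join(where_clauses))
    # n.memory_type = "WorkingMemory" AND n.confidence = 0.9

Note that values are spliced into the query text rather than bound as parameters, so callers are expected to pass trusted filter input.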
@@ -1038,7 +1030,7 @@ class NebulaGraphDB(BaseGraphDB):
                 id_val = values[0].as_string()
                 score_val = values[1].as_double()
                 score_val = (score_val + 1) / 2  # align to neo4j, Normalized Cosine Score
-                if threshold is None or score_val <= threshold:
+                if threshold is None or score_val >= threshold:
                     output.append({"id": id_val, "score": score_val})
             return output
         except Exception as e:
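
This hunk fixes an inverted comparison: threshold is documented as a minimum similarity, so only results at or above it should be kept. Since inner_product of normalized vectors is a cosine similarity in [-1, 1], the (score_val + 1) / 2 step rescales it to [0, 1] to match the Neo4j backend. A quick worked check:

    raw = 0.8                    # cosine similarity of two normalized vectors
    score = (raw + 1) / 2        # 0.9 on the [0, 1] scale
    threshold = 0.85
    assert score >= threshold    # kept under the fixed comparison; the old
                                 # `<=` would have discarded this strong match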
@@ -1368,9 +1360,9 @@ class NebulaGraphDB(BaseGraphDB):
             where_clause += f' AND n.user_name = "{self.config.user_name}"'
 
         return_fields = self._build_return_fields(include_embedding)
+        return_fields += f", n.{self.dim_field} AS {self.dim_field}"
 
         query = f"""
-        USE `{self.db_name}`
         MATCH (n@Memory)
         WHERE {where_clause}
         OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
@@ -1380,11 +1372,16 @@ class NebulaGraphDB(BaseGraphDB):
         """
 
         candidates = []
+        node_ids = set()
         try:
             results = self.execute_query(query)
             for row in results:
                 props = {k: v.value for k, v in row.items()}
-                candidates.append(self._parse_node(props))
+                node = self._parse_node(props)
+                node_id = node["id"]
+                if node_id not in node_ids:
+                    candidates.append(node)
+                    node_ids.add(node_id)
         except Exception as e:
             logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
         return candidates
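
The added node_ids set deduplicates result rows while preserving their order, presumably needed because the OPTIONAL MATCH (n)-[@PARENT]->(c@Memory) above can emit one row per child edge for the same n. The same order-preserving dedup in isolation:

    def dedup_by_id(nodes: list[dict]) -> list[dict]:
        seen: set[str] = set()
        unique = []
        for node in nodes:
            if node["id"] not in seen:    # first occurrence wins, order preserved
                seen.add(node["id"])
                unique.append(node)
        return unique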
@@ -1538,18 +1535,19 @@ class NebulaGraphDB(BaseGraphDB):
             logger.info(f"✅ Graph Type {graph_type_name} already include {self.dim_field}")
 
         create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
-        set_graph_working = f"SESSION SET GRAPH `{self.db_name}`"
-
         try:
             self.execute_query(create_graph, auto_set_db=False)
-            self.execute_query(set_graph_working)
             logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.")
         except Exception as e:
             logger.error(f"❌ Failed to create tag: {e} trace: {traceback.format_exc()}")
 
     @timed
     def _create_vector_index(
-        self, label: str, vector_property: str, dimensions: int, index_name: str
+        self,
+        label: str = "Memory",
+        vector_property: str = "embedding",
+        dimensions: int = 3072,
+        index_name: str = "memory_vector_index",
     ) -> None:
         """
         Create a vector index for the specified property in the label.