MemoryOS 1.0.1-py3-none-any.whl → 1.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic.
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/METADATA +7 -2
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/RECORD +79 -65
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/WHEEL +1 -1
- memos/__init__.py +1 -1
- memos/api/client.py +109 -0
- memos/api/config.py +11 -9
- memos/api/context/dependencies.py +15 -55
- memos/api/middleware/request_context.py +9 -40
- memos/api/product_api.py +2 -3
- memos/api/product_models.py +91 -16
- memos/api/routers/product_router.py +23 -16
- memos/api/start_api.py +10 -0
- memos/configs/graph_db.py +4 -0
- memos/configs/mem_scheduler.py +38 -3
- memos/context/context.py +255 -0
- memos/embedders/factory.py +2 -0
- memos/graph_dbs/nebular.py +230 -232
- memos/graph_dbs/neo4j.py +35 -1
- memos/graph_dbs/neo4j_community.py +7 -0
- memos/llms/factory.py +2 -0
- memos/llms/openai.py +74 -2
- memos/log.py +27 -15
- memos/mem_cube/general.py +3 -1
- memos/mem_os/core.py +60 -22
- memos/mem_os/main.py +3 -6
- memos/mem_os/product.py +35 -11
- memos/mem_reader/factory.py +2 -0
- memos/mem_reader/simple_struct.py +127 -74
- memos/mem_scheduler/analyzer/__init__.py +0 -0
- memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +569 -0
- memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
- memos/mem_scheduler/base_scheduler.py +126 -56
- memos/mem_scheduler/general_modules/dispatcher.py +2 -2
- memos/mem_scheduler/general_modules/misc.py +99 -1
- memos/mem_scheduler/general_modules/scheduler_logger.py +17 -11
- memos/mem_scheduler/general_scheduler.py +40 -88
- memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
- memos/mem_scheduler/memory_manage_modules/memory_filter.py +308 -0
- memos/mem_scheduler/{general_modules → memory_manage_modules}/retriever.py +34 -7
- memos/mem_scheduler/monitors/dispatcher_monitor.py +9 -8
- memos/mem_scheduler/monitors/general_monitor.py +119 -39
- memos/mem_scheduler/optimized_scheduler.py +124 -0
- memos/mem_scheduler/orm_modules/__init__.py +0 -0
- memos/mem_scheduler/orm_modules/base_model.py +635 -0
- memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
- memos/mem_scheduler/scheduler_factory.py +2 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +96 -29
- memos/mem_scheduler/utils/config_utils.py +100 -0
- memos/mem_scheduler/utils/db_utils.py +33 -0
- memos/mem_scheduler/utils/filter_utils.py +1 -1
- memos/mem_scheduler/webservice_modules/__init__.py +0 -0
- memos/memories/activation/kv.py +2 -1
- memos/memories/textual/item.py +95 -16
- memos/memories/textual/naive.py +1 -1
- memos/memories/textual/tree.py +27 -3
- memos/memories/textual/tree_text_memory/organize/handler.py +4 -2
- memos/memories/textual/tree_text_memory/organize/manager.py +28 -14
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +1 -2
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +75 -23
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +7 -5
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -2
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +2 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +70 -22
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +101 -33
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +5 -4
- memos/memos_tools/singleton.py +174 -0
- memos/memos_tools/thread_safe_dict.py +22 -0
- memos/memos_tools/thread_safe_dict_segment.py +382 -0
- memos/parsers/factory.py +2 -0
- memos/reranker/concat.py +59 -0
- memos/reranker/cosine_local.py +1 -0
- memos/reranker/factory.py +5 -0
- memos/reranker/http_bge.py +225 -12
- memos/templates/mem_scheduler_prompts.py +242 -0
- memos/types.py +4 -1
- memos/api/context/context.py +0 -147
- memos/api/context/context_thread.py +0 -96
- memos/mem_scheduler/mos_for_test_scheduler.py +0 -146
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info}/entry_points.txt +0 -0
- {memoryos-1.0.1.dist-info → memoryos-1.1.1.dist-info/licenses}/LICENSE +0 -0
- /memos/mem_scheduler/{general_modules → webservice_modules}/rabbitmq_service.py +0 -0
- /memos/mem_scheduler/{general_modules → webservice_modules}/redis_service.py +0 -0
memos/graph_dbs/nebular.py
CHANGED
@@ -3,7 +3,6 @@ import traceback
 
 from contextlib import suppress
 from datetime import datetime
-from queue import Empty, Queue
 from threading import Lock
 from typing import TYPE_CHECKING, Any, ClassVar, Literal
 
@@ -17,12 +16,27 @@ from memos.utils import timed
 
 
 if TYPE_CHECKING:
-    from nebulagraph_python
+    from nebulagraph_python import (
+        NebulaClient,
+    )
 
 
 logger = get_logger(__name__)
 
 
+_TRANSIENT_ERR_KEYS = (
+    "Session not found",
+    "Connection not established",
+    "timeout",
+    "deadline exceeded",
+    "Broken pipe",
+    "EOFError",
+    "socket closed",
+    "connection reset",
+    "connection refused",
+)
+
+
 @timed
 def _normalize(vec: list[float]) -> list[float]:
     v = np.asarray(vec, dtype=np.float32)
@@ -87,137 +101,6 @@ def _normalize_datetime(val):
     return str(val)
 
 
-class SessionPoolError(Exception):
-    pass
-
-
-class SessionPool:
-    @require_python_package(
-        import_name="nebulagraph_python",
-        install_command="pip install ... @Tianxing",
-        install_link=".....",
-    )
-    def __init__(
-        self,
-        hosts: list[str],
-        user: str,
-        password: str,
-        minsize: int = 1,
-        maxsize: int = 10000,
-    ):
-        self.hosts = hosts
-        self.user = user
-        self.password = password
-        self.minsize = minsize
-        self.maxsize = maxsize
-        self.pool = Queue(maxsize)
-        self.lock = Lock()
-
-        self.clients = []
-
-        for _ in range(minsize):
-            self._create_and_add_client()
-
-    @timed
-    def _create_and_add_client(self):
-        from nebulagraph_python import NebulaClient
-
-        client = NebulaClient(self.hosts, self.user, self.password)
-        self.pool.put(client)
-        self.clients.append(client)
-
-    @timed
-    def get_client(self, timeout: float = 5.0):
-        try:
-            return self.pool.get(timeout=timeout)
-        except Empty:
-            with self.lock:
-                if len(self.clients) < self.maxsize:
-                    from nebulagraph_python import NebulaClient
-
-                    client = NebulaClient(self.hosts, self.user, self.password)
-                    self.clients.append(client)
-                    return client
-            raise RuntimeError("NebulaClientPool exhausted") from None
-
-    @timed
-    def return_client(self, client):
-        try:
-            client.execute("YIELD 1")
-            self.pool.put(client)
-        except Exception:
-            logger.info("[Pool] Client dead, replacing...")
-            self.replace_client(client)
-
-    @timed
-    def close(self):
-        for client in self.clients:
-            with suppress(Exception):
-                client.close()
-        self.clients.clear()
-
-    @timed
-    def get(self):
-        """
-        Context manager: with pool.get() as client:
-        """
-
-        class _ClientContext:
-            def __init__(self, outer):
-                self.outer = outer
-                self.client = None
-
-            def __enter__(self):
-                self.client = self.outer.get_client()
-                return self.client
-
-            def __exit__(self, exc_type, exc_val, exc_tb):
-                if self.client:
-                    self.outer.return_client(self.client)
-
-        return _ClientContext(self)
-
-    @timed
-    def reset_pool(self):
-        """⚠️ Emergency reset: Close all clients and clear the pool."""
-        logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
-        with self.lock:
-            for client in self.clients:
-                try:
-                    client.close()
-                except Exception:
-                    logger.error("Fail to close!!!")
-            self.clients.clear()
-            while not self.pool.empty():
-                try:
-                    self.pool.get_nowait()
-                except Empty:
-                    break
-            for _ in range(self.minsize):
-                self._create_and_add_client()
-            logger.info("[Pool] Pool has been reset successfully.")
-
-    @timed
-    def replace_client(self, client):
-        try:
-            client.close()
-        except Exception:
-            logger.error("Fail to close client")
-
-        if client in self.clients:
-            self.clients.remove(client)
-
-        from nebulagraph_python import NebulaClient
-
-        new_client = NebulaClient(self.hosts, self.user, self.password)
-        self.clients.append(new_client)
-
-        self.pool.put(new_client)
-
-        logger.info("[Pool] Replaced dead client with a new one.")
-        return new_client
-
-
 class NebulaGraphDB(BaseGraphDB):
     """
     NebulaGraph-based implementation of a graph memory store.
@@ -226,94 +109,194 @@ class NebulaGraphDB(BaseGraphDB):
     # ====== shared pool cache & refcount ======
     # These are process-local; in a multi-process model each process will
     # have its own cache.
-
-
-
+    _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {}
+    _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {}
+    _CLIENT_LOCK: ClassVar[Lock] = Lock()
+    _CLIENT_INIT_DONE: ClassVar[set[str]] = set()
 
     @staticmethod
-    def
-        """
-
-
-
-
-
+    def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]:
+        hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None)
+        if isinstance(hosts, str):
+            return [hosts]
+        return list(hosts or [])
+
+    @staticmethod
+    def _make_client_key(cfg: NebulaGraphDBConfig) -> str:
+        hosts = NebulaGraphDB._get_hosts_from_cfg(cfg)
         return "|".join(
             [
-                "nebula",
-
+                "nebula-sync",
+                ",".join(hosts),
                 str(getattr(cfg, "user", "")),
-                str(getattr(cfg, "password", "")),
-                # pool sizing / tls / timeouts if you have them in config:
-                str(getattr(cfg, "max_client", 1000)),
-                # multi-db mode can impact how we use sessions; keep it to be safe
                 str(getattr(cfg, "use_multi_db", False)),
+                str(getattr(cfg, "space", "")),
             ]
         )
 
     @classmethod
-    def
-
-
-
-        ""
-
-
-
-
-
-
-
-
-
-
-
-
+    def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB":
+        tmp = object.__new__(NebulaGraphDB)
+        tmp.config = cfg
+        tmp.db_name = cfg.space
+        tmp.user_name = getattr(cfg, "user_name", None)
+        tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072)
+        tmp.default_memory_dimension = 3072
+        tmp.common_fields = {
+            "id",
+            "memory",
+            "user_name",
+            "user_id",
+            "session_id",
+            "status",
+            "key",
+            "confidence",
+            "tags",
+            "created_at",
+            "updated_at",
+            "memory_type",
+            "sources",
+            "source",
+            "node_type",
+            "visibility",
+            "usage",
+            "background",
+        }
+        tmp.base_fields = set(tmp.common_fields) - {"usage"}
+        tmp.heavy_fields = {"usage"}
+        tmp.dim_field = (
+            f"embedding_{tmp.embedding_dimension}"
+            if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension)
+            else "embedding"
+        )
+        tmp.system_db_name = "system" if getattr(cfg, "use_multi_db", False) else cfg.space
+        tmp._client = client
+        tmp._owns_client = False
+        return tmp
+
+    @classmethod
+    def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]:
+        from nebulagraph_python import (
+            ConnectionConfig,
+            NebulaClient,
+            SessionConfig,
+            SessionPoolConfig,
+        )
+
+        key = cls._make_client_key(cfg)
+        with cls._CLIENT_LOCK:
+            client = cls._CLIENT_CACHE.get(key)
+            if client is None:
+                # Connection setting
+                conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None)
+                if conn_conf is None:
+                    conn_conf = ConnectionConfig.from_defults(
+                        cls._get_hosts_from_cfg(cfg),
+                        getattr(cfg, "ssl_param", None),
+                    )
+
+                sess_conf = SessionConfig(graph=getattr(cfg, "space", None))
+                pool_conf = SessionPoolConfig(
+                    size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000
+                )
+
+                client = NebulaClient(
+                    hosts=conn_conf.hosts,
+                    username=cfg.user,
+                    password=cfg.password,
+                    conn_config=conn_conf,
+                    session_config=sess_conf,
+                    session_pool_config=pool_conf,
                 )
-        cls.
-        cls.
-        logger.info(f"[
+                cls._CLIENT_CACHE[key] = client
+                cls._CLIENT_REFCOUNT[key] = 0
+                logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}")
 
-
-
-
+            cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1
+
+        if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
+            try:
+                pass
+            finally:
+                pass
+
+        if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
+            with cls._CLIENT_LOCK:
+                if key not in cls._CLIENT_INIT_DONE:
+                    admin = cls._bootstrap_admin(cfg, client)
+                    try:
+                        admin._ensure_database_exists()
+                        admin._create_basic_property_indexes()
+                        admin._create_vector_index(
+                            dimensions=int(
+                                admin.embedding_dimension or admin.default_memory_dimension
+                            ),
+                        )
+                        cls._CLIENT_INIT_DONE.add(key)
+                        logger.info("[NebulaGraphDBSync] One-time init done")
+                    except Exception:
+                        logger.exception("[NebulaGraphDBSync] One-time init failed")
+
+        return key, client
+
+    def _refresh_client(self):
+        """
+        refresh NebulaClient:
+        """
+        old_key = getattr(self, "_client_key", None)
+        if not old_key:
+            return
+
+        cls = self.__class__
+        with cls._CLIENT_LOCK:
+            try:
+                if old_key in cls._CLIENT_CACHE:
+                    try:
+                        cls._CLIENT_CACHE[old_key].close()
+                    except Exception as e:
+                        logger.warning(f"[refresh_client] close old client error: {e}")
+                    finally:
+                        cls._CLIENT_CACHE.pop(old_key, None)
+            finally:
+                cls._CLIENT_REFCOUNT[old_key] = 0
+
+        new_key, new_client = cls._get_or_create_shared_client(self.config)
+        self._client_key = new_key
+        self._client = new_client
+        logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}")
 
     @classmethod
-    def
-
-
-        """
-        with cls._POOL_LOCK:
-            if key not in cls._POOL_CACHE:
+    def _release_shared_client(cls, key: str):
+        with cls._CLIENT_LOCK:
+            if key not in cls._CLIENT_CACHE:
                 return
-            cls.
-            if cls.
+            cls._CLIENT_REFCOUNT[key] = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1)
+            if cls._CLIENT_REFCOUNT[key] == 0:
                 try:
-                    cls.
+                    cls._CLIENT_CACHE[key].close()
                 except Exception as e:
-                    logger.warning(f"[
+                    logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}")
                 finally:
-                    cls.
-                    cls.
-                    logger.info(f"[
+                    cls._CLIENT_CACHE.pop(key, None)
+                    cls._CLIENT_REFCOUNT.pop(key, None)
+                    logger.info(f"[NebulaGraphDBSync] Closed & removed client key={key}")
 
     @classmethod
-    def
-
-
-        for key, pool in list(cls._POOL_CACHE.items()):
+    def close_all_shared_clients(cls):
+        with cls._CLIENT_LOCK:
+            for key, client in list(cls._CLIENT_CACHE.items()):
                 try:
-
+                    client.close()
                 except Exception as e:
-                    logger.warning(f"[
+                    logger.warning(f"[NebulaGraphDBSync] Error closing client {key}: {e}")
                 finally:
-                    logger.info(f"[
-            cls.
-            cls.
+                    logger.info(f"[NebulaGraphDBSync] Closed client key={key}")
+            cls._CLIENT_CACHE.clear()
+            cls._CLIENT_REFCOUNT.clear()
 
     @require_python_package(
         import_name="nebulagraph_python",
-        install_command="pip install
+        install_command="pip install nebulagraph-python>=5.1.1",
         install_link=".....",
     )
     def __init__(self, config: NebulaGraphDBConfig):
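The hunk above replaces the hand-rolled SessionPool with a process-local, refcounted cache of NebulaClient objects keyed by connection settings (_make_client_key), so every NebulaGraphDB instance with the same hosts/user/space shares one client. A minimal sketch of that acquire/release pattern, using a hypothetical FakeClient stand-in rather than the real nebulagraph_python client:

import threading

class FakeClient:
    """Hypothetical stand-in for nebulagraph_python.NebulaClient."""

    def __init__(self, key: str):
        self.key = key

    def close(self) -> None:
        print(f"closed client for {self.key}")

_CACHE: dict[str, FakeClient] = {}
_REFCOUNT: dict[str, int] = {}
_LOCK = threading.Lock()

def acquire(key: str) -> FakeClient:
    # First caller for a key creates the client; later callers share it.
    with _LOCK:
        client = _CACHE.get(key)
        if client is None:
            client = FakeClient(key)
            _CACHE[key] = client
            _REFCOUNT[key] = 0
        _REFCOUNT[key] += 1
        return client

def release(key: str) -> None:
    # The last owner out closes the shared client and drops the cache entry.
    with _LOCK:
        if key not in _CACHE:
            return
        _REFCOUNT[key] = max(0, _REFCOUNT.get(key, 0) - 1)
        if _REFCOUNT[key] == 0:
            try:
                _CACHE[key].close()
            finally:
                _CACHE.pop(key, None)
                _REFCOUNT.pop(key, None)

a = acquire("nebula-sync|host:9669|root")  # creates the client
b = acquire("nebula-sync|host:9669|root")  # reuses it, refcount = 2
release("nebula-sync|host:9669|root")      # refcount = 1, client stays open
release("nebula-sync|host:9669|root")      # refcount = 0, client is closed

As in the class above, the cache is per process; forked workers each build their own clients.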
@@ -371,34 +354,32 @@ class NebulaGraphDB(BaseGraphDB):
 
         # ---- NEW: pool acquisition strategy
         # Get or create a shared pool from the class-level cache
-        self.
-        self.
-
-        # auto-create graph type / graph / index if needed
-        if config.auto_create:
-            self._ensure_database_exists()
-
-        self.execute_query(f"SESSION SET GRAPH `{self.db_name}`")
-
-        # Create only if not exists
-        self.create_index(dimensions=config.embedding_dimension)
+        self._client_key, self._client = self._get_or_create_shared_client(config)
+        self._owns_client = True
 
         logger.info("Connected to NebulaGraph successfully.")
 
     @timed
-    def execute_query(self, gql: str, timeout: float =
-
-
-
-
-        return client.execute(gql, timeout=timeout)
+    def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True):
+        def _wrap_use_db(q: str) -> str:
+            if auto_set_db and self.db_name:
+                return f"USE `{self.db_name}`\n{q}"
+            return q
 
-
-
-
-
-
-
+        try:
+            return self._client.execute(_wrap_use_db(gql), timeout=timeout)
+
+        except Exception as e:
+            emsg = str(e)
+            if any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS):
+                logger.warning(f"[execute_query] {e!s} → refreshing session pool and retry once...")
+                try:
+                    self._refresh_client()
+                    return self._client.execute(_wrap_use_db(gql), timeout=timeout)
+                except Exception:
+                    logger.exception("[execute_query] retry after refresh failed")
+                    raise
+            raise
 
     @timed
     def close(self):
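execute_query now prefixes each statement with USE `<db>` instead of pinning the session via SESSION SET GRAPH, and it retries exactly once when the error message matches one of the _TRANSIENT_ERR_KEYS substrings, refreshing the shared client in between. A self-contained sketch of that classify-then-retry-once flow, with a hypothetical flaky_execute standing in for NebulaClient.execute:

import logging

logger = logging.getLogger(__name__)

_TRANSIENT_ERR_KEYS = ("Session not found", "timeout", "connection reset")

_calls = {"n": 0}

def flaky_execute(gql: str) -> str:
    # Hypothetical stand-in: fails once with a transient error, then succeeds.
    _calls["n"] += 1
    if _calls["n"] == 1:
        raise RuntimeError("Session not found: please reconnect")
    return f"ok: {gql}"

def refresh_client() -> None:
    logger.warning("refreshing shared client...")

def execute_with_retry(gql: str) -> str:
    try:
        return flaky_execute(gql)
    except Exception as e:
        # Retry once only if the message looks like a transient connection error.
        if any(k.lower() in str(e).lower() for k in _TRANSIENT_ERR_KEYS):
            refresh_client()
            return flaky_execute(gql)
        raise

print(execute_with_retry("MATCH (n@Memory) RETURN n LIMIT 1"))  # ok: MATCH ...

A second failure during the retry propagates, matching the raise-after-logging behavior in the hunk.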
@@ -409,13 +390,13 @@ class NebulaGraphDB(BaseGraphDB):
         - If pool was acquired via shared cache, decrement refcount and close
           when the last owner releases it.
         """
-        if not self.
-            logger.debug("[
+        if not self._owns_client:
+            logger.debug("[NebulaGraphDBSync] close() skipped (injected client).")
             return
-        if self.
-            self.
-            self.
-            self.
+        if self._client_key:
+            self._release_shared_client(self._client_key)
+            self._client_key = None
+            self._client = None
 
     # NOTE: __del__ is best-effort; do not rely on GC order.
     def __del__(self):
@@ -972,6 +953,7 @@ class NebulaGraphDB(BaseGraphDB):
         scope: str | None = None,
         status: str | None = None,
         threshold: float | None = None,
+        search_filter: dict | None = None,
         **kwargs,
     ) -> list[dict]:
         """
@@ -984,6 +966,8 @@ class NebulaGraphDB(BaseGraphDB):
             status (str, optional): Node status filter (e.g., 'active', 'archived').
                 If provided, restricts results to nodes with matching status.
             threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
+            search_filter (dict, optional): Additional metadata filters for search results.
+                Keys should match node properties, values are the expected values.
 
         Returns:
             list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
@@ -993,6 +977,7 @@ class NebulaGraphDB(BaseGraphDB):
         - If scope is provided, it restricts results to nodes with matching memory_type.
         - If 'status' is provided, only nodes with the matching status will be returned.
         - If threshold is provided, only results with score >= threshold will be returned.
+        - If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
         - Typical use case: restrict to 'status = activated' to avoid
           matching archived or merged nodes.
         """
@@ -1012,10 +997,17 @@ class NebulaGraphDB(BaseGraphDB):
         else:
             where_clauses.append(f'n.user_name = "{self.config.user_name}"')
 
+        # Add search_filter conditions
+        if search_filter:
+            for key, value in search_filter.items():
+                if isinstance(value, str):
+                    where_clauses.append(f'n.{key} = "{value}"')
+                else:
+                    where_clauses.append(f"n.{key} = {value}")
+
         where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
 
         gql = f"""
-        USE `{self.db_name}`
         MATCH (n@Memory)
         {where_clause}
         ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC
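The new search_filter handling flattens the dict into extra WHERE clauses: string values are double-quoted, any other value is interpolated as-is. The same clause-building logic in isolation (the function name is illustrative, not part of the module):

def build_where(user_name: str, search_filter: dict | None) -> str:
    where_clauses = [f'n.user_name = "{user_name}"']
    if search_filter:
        for key, value in search_filter.items():
            if isinstance(value, str):
                where_clauses.append(f'n.{key} = "{value}"')
            else:
                where_clauses.append(f"n.{key} = {value}")
    return f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

print(build_where("alice", {"memory_type": "LongTermMemory", "confidence": 0.9}))
# WHERE n.user_name = "alice" AND n.memory_type = "LongTermMemory" AND n.confidence = 0.9

Keys and values go straight into the GQL string, so the filter input must be trusted by the caller.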
@@ -1038,7 +1030,7 @@ class NebulaGraphDB(BaseGraphDB):
                 id_val = values[0].as_string()
                 score_val = values[1].as_double()
                 score_val = (score_val + 1) / 2  # align to neo4j, Normalized Cosine Score
-                if threshold is None or score_val
+                if threshold is None or score_val >= threshold:
                     output.append({"id": id_val, "score": score_val})
             return output
         except Exception as e:
@@ -1368,9 +1360,9 @@ class NebulaGraphDB(BaseGraphDB):
             where_clause += f' AND n.user_name = "{self.config.user_name}"'
 
         return_fields = self._build_return_fields(include_embedding)
+        return_fields += f", n.{self.dim_field} AS {self.dim_field}"
 
         query = f"""
-        USE `{self.db_name}`
         MATCH (n@Memory)
         WHERE {where_clause}
         OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
@@ -1380,11 +1372,16 @@ class NebulaGraphDB(BaseGraphDB):
         """
 
         candidates = []
+        node_ids = set()
         try:
            results = self.execute_query(query)
             for row in results:
                 props = {k: v.value for k, v in row.items()}
-
+                node = self._parse_node(props)
+                node_id = node["id"]
+                if node_id not in node_ids:
+                    candidates.append(node)
+                    node_ids.add(node_id)
         except Exception as e:
             logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
         return candidates
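The retrieval loop above now parses each row into a node dict and skips duplicate ids, keeping the first occurrence. The order-preserving dedup pattern on its own, with inline sample rows:

rows = [{"id": "a"}, {"id": "b"}, {"id": "a"}]

candidates: list[dict] = []
node_ids: set[str] = set()
for node in rows:
    # Keep the first occurrence of each id; later duplicates are skipped.
    if node["id"] not in node_ids:
        candidates.append(node)
        node_ids.add(node["id"])

print([n["id"] for n in candidates])  # ['a', 'b']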
@@ -1538,18 +1535,19 @@ class NebulaGraphDB(BaseGraphDB):
             logger.info(f"✅ Graph Type {graph_type_name} already include {self.dim_field}")
 
         create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
-        set_graph_working = f"SESSION SET GRAPH `{self.db_name}`"
-
         try:
             self.execute_query(create_graph, auto_set_db=False)
-            self.execute_query(set_graph_working)
             logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.")
         except Exception as e:
             logger.error(f"❌ Failed to create tag: {e} trace: {traceback.format_exc()}")
 
     @timed
     def _create_vector_index(
-        self,
+        self,
+        label: str = "Memory",
+        vector_property: str = "embedding",
+        dimensions: int = 3072,
+        index_name: str = "memory_vector_index",
     ) -> None:
         """
         Create a vector index for the specified property in the label.