MemoryOS 0.2.2-py3-none-any.whl → 1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of MemoryOS has been flagged as possibly problematic.
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/METADATA +7 -1
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/RECORD +81 -66
- memos/__init__.py +1 -1
- memos/api/config.py +31 -8
- memos/api/context/context.py +1 -1
- memos/api/context/context_thread.py +96 -0
- memos/api/middleware/request_context.py +94 -0
- memos/api/product_api.py +5 -1
- memos/api/product_models.py +16 -0
- memos/api/routers/product_router.py +39 -3
- memos/api/start_api.py +3 -0
- memos/configs/internet_retriever.py +13 -0
- memos/configs/mem_scheduler.py +38 -16
- memos/configs/memory.py +13 -0
- memos/configs/reranker.py +18 -0
- memos/graph_dbs/base.py +33 -4
- memos/graph_dbs/nebular.py +631 -236
- memos/graph_dbs/neo4j.py +18 -7
- memos/graph_dbs/neo4j_community.py +6 -3
- memos/llms/vllm.py +2 -0
- memos/log.py +125 -8
- memos/mem_os/core.py +49 -11
- memos/mem_os/main.py +1 -1
- memos/mem_os/product.py +392 -215
- memos/mem_os/utils/default_config.py +1 -1
- memos/mem_os/utils/format_utils.py +11 -47
- memos/mem_os/utils/reference_utils.py +153 -0
- memos/mem_reader/simple_struct.py +112 -43
- memos/mem_scheduler/base_scheduler.py +58 -55
- memos/mem_scheduler/{modules → general_modules}/base.py +1 -2
- memos/mem_scheduler/{modules → general_modules}/dispatcher.py +54 -15
- memos/mem_scheduler/{modules → general_modules}/rabbitmq_service.py +4 -4
- memos/mem_scheduler/{modules → general_modules}/redis_service.py +1 -1
- memos/mem_scheduler/{modules → general_modules}/retriever.py +19 -5
- memos/mem_scheduler/{modules → general_modules}/scheduler_logger.py +10 -4
- memos/mem_scheduler/general_scheduler.py +110 -67
- memos/mem_scheduler/monitors/__init__.py +0 -0
- memos/mem_scheduler/monitors/dispatcher_monitor.py +305 -0
- memos/mem_scheduler/{modules/monitor.py → monitors/general_monitor.py} +57 -19
- memos/mem_scheduler/mos_for_test_scheduler.py +7 -1
- memos/mem_scheduler/schemas/general_schemas.py +3 -2
- memos/mem_scheduler/schemas/message_schemas.py +2 -1
- memos/mem_scheduler/schemas/monitor_schemas.py +10 -2
- memos/mem_scheduler/utils/misc_utils.py +43 -2
- memos/mem_user/mysql_user_manager.py +4 -2
- memos/memories/activation/item.py +1 -1
- memos/memories/activation/kv.py +20 -8
- memos/memories/textual/base.py +1 -1
- memos/memories/textual/general.py +1 -1
- memos/memories/textual/item.py +1 -1
- memos/memories/textual/tree.py +31 -1
- memos/memories/textual/tree_text_memory/organize/{conflict.py → handler.py} +30 -48
- memos/memories/textual/tree_text_memory/organize/manager.py +8 -96
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +2 -0
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +102 -140
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +231 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +9 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +67 -10
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +1 -1
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +246 -134
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +7 -2
- memos/memories/textual/tree_text_memory/retrieve/utils.py +7 -5
- memos/memos_tools/lockfree_dict.py +120 -0
- memos/memos_tools/notification_utils.py +46 -0
- memos/memos_tools/thread_safe_dict.py +288 -0
- memos/reranker/__init__.py +4 -0
- memos/reranker/base.py +24 -0
- memos/reranker/cosine_local.py +95 -0
- memos/reranker/factory.py +43 -0
- memos/reranker/http_bge.py +99 -0
- memos/reranker/noop.py +16 -0
- memos/templates/mem_reader_prompts.py +290 -39
- memos/templates/mem_scheduler_prompts.py +23 -10
- memos/templates/mos_prompts.py +133 -31
- memos/templates/tree_reorganize_prompts.py +24 -17
- memos/utils.py +19 -0
- memos/memories/textual/tree_text_memory/organize/redundancy.py +0 -193
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/LICENSE +0 -0
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/WHEEL +0 -0
- {memoryos-0.2.2.dist-info → memoryos-1.0.1.dist-info}/entry_points.txt +0 -0
- /memos/mem_scheduler/{modules → general_modules}/__init__.py +0 -0
- /memos/mem_scheduler/{modules → general_modules}/misc.py +0 -0
memos/graph_dbs/nebular.py
CHANGED
@@ -1,10 +1,11 @@
+import json
 import traceback
 
 from contextlib import suppress
 from datetime import datetime
 from queue import Empty, Queue
 from threading import Lock
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, ClassVar, Literal
 
 import numpy as np
 
@@ -12,17 +13,24 @@ from memos.configs.graph_db import NebulaGraphDBConfig
 from memos.dependency import require_python_package
 from memos.graph_dbs.base import BaseGraphDB
 from memos.log import get_logger
+from memos.utils import timed
+
+
+if TYPE_CHECKING:
+    from nebulagraph_python.client.pool import NebulaPool
 
 
 logger = get_logger(__name__)
 
 
+@timed
 def _normalize(vec: list[float]) -> list[float]:
     v = np.asarray(vec, dtype=np.float32)
     norm = np.linalg.norm(v)
     return (v / (norm if norm else 1.0)).tolist()
 
 
+@timed
 def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
     node_id = item["id"]
    memory = item["memory"]
@@ -30,97 +38,33 @@ def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
     return node_id, memory, metadata
 
 
-def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
-    """
-    Ensure metadata has proper datetime fields and normalized types.
-
-    - Fill `created_at` and `updated_at` if missing (in ISO 8601 format).
-    - Convert embedding to list of float if present.
-    """
-    now = datetime.utcnow().isoformat()
-    metadata["node_type"] = metadata.pop("type")
-
-    # Fill timestamps if missing
-    metadata.setdefault("created_at", now)
-    metadata.setdefault("updated_at", now)
-
-    # Normalize embedding type
-    embedding = metadata.get("embedding")
-    if embedding and isinstance(embedding, list):
-        metadata["embedding"] = _normalize([float(x) for x in embedding])
-
-    return metadata
-
-
-def _metadata_filter(metadata: dict[str, Any]) -> dict[str, Any]:
-    """
-    Filter and validate metadata dictionary against the Memory node schema.
-    - Removes keys not in schema.
-    - Warns if required fields are missing.
-    """
-
-    allowed_fields = {
-        "id",
-        "memory",
-        "user_name",
-        "user_id",
-        "session_id",
-        "status",
-        "key",
-        "confidence",
-        "tags",
-        "created_at",
-        "updated_at",
-        "memory_type",
-        "sources",
-        "source",
-        "node_type",
-        "visibility",
-        "usage",
-        "background",
-        "embedding",
-    }
-
-    missing_fields = allowed_fields - metadata.keys()
-    if missing_fields:
-        logger.warning(f"Metadata missing required fields: {sorted(missing_fields)}")
-
-    filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields}
-
-    return filtered_metadata
-
-
+@timed
 def _escape_str(value: str) -> str:
-… (17 removed lines not captured in this view)
+    out = []
+    for ch in value:
+        code = ord(ch)
+        if ch == "\\":
+            out.append("\\\\")
+        elif ch == '"':
+            out.append('\\"')
+        elif ch == "\n":
+            out.append("\\n")
+        elif ch == "\r":
+            out.append("\\r")
+        elif ch == "\t":
+            out.append("\\t")
+        elif ch == "\b":
+            out.append("\\b")
+        elif ch == "\f":
+            out.append("\\f")
+        elif code < 0x20 or code in (0x2028, 0x2029):
+            out.append(f"\\u{code:04x}")
         else:
-… (2 removed lines not captured in this view)
-        if key == "embedding":
-            dim = len(val)
-            joined = ",".join(str(float(x)) for x in val)
-            return f"VECTOR<{dim}, FLOAT>([{joined}])"
-        elif val is None:
-            return "NULL"
-        else:
-            return f'"{_escape_str(str(val))}"'
+            out.append(ch)
+    return "".join(out)
 
 
+@timed
 def _format_datetime(value: str | datetime) -> str:
     """Ensure datetime is in ISO 8601 format string."""
     if isinstance(value, datetime):
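The new `_escape_str` replaces the old implementation with explicit per-character escaping of backslashes, quotes, common control characters, and the U+2028/U+2029 line separators before a value is embedded in a double-quoted GQL literal. A standalone sketch of the same loop (renamed so it runs outside the module), with a quick self-check:

```python
# Sketch of the escaping rule the new _escape_str applies.
def escape_str(value: str) -> str:
    out = []
    for ch in value:
        code = ord(ch)
        if ch == "\\":
            out.append("\\\\")            # backslash -> \\
        elif ch == '"':
            out.append('\\"')             # quote -> \"
        elif ch == "\n":
            out.append("\\n")
        elif ch == "\r":
            out.append("\\r")
        elif ch == "\t":
            out.append("\\t")
        elif ch == "\b":
            out.append("\\b")
        elif ch == "\f":
            out.append("\\f")
        elif code < 0x20 or code in (0x2028, 0x2029):
            out.append(f"\\u{code:04x}")  # other control chars -> \uXXXX
        else:
            out.append(ch)
    return "".join(out)


assert escape_str('say "hi"\nline2\x01') == 'say \\"hi\\"\\nline2\\u0001'
```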
@@ -128,6 +72,21 @@ def _format_datetime(value: str | datetime) -> str:
     return str(value)
 
 
+@timed
+def _normalize_datetime(val):
+    """
+    Normalize datetime to ISO 8601 UTC string with +00:00.
+    - If val is datetime object -> keep isoformat() (Neo4j)
+    - If val is string without timezone -> append +00:00 (Nebula)
+    - Otherwise just str()
+    """
+    if hasattr(val, "isoformat"):
+        return val.isoformat()
+    if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")):
+        return val + "+08:00"
+    return str(val)
+
+
 class SessionPoolError(Exception):
     pass
 
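One detail worth flagging in the added `_normalize_datetime`: the docstring promises `+00:00` for naive strings, but the shipped code appends `+08:00`. A runnable sketch of the behavior as released (renamed for standalone use):

```python
from datetime import datetime, timezone

# Mirrors the added helper: datetime-likes keep isoformat(); strings
# without a recognized offset get "+08:00" appended (not "+00:00",
# despite the docstring); everything else is str()-ed.
def normalize_datetime(val):
    if hasattr(val, "isoformat"):
        return val.isoformat()
    if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")):
        return val + "+08:00"
    return str(val)


print(normalize_datetime(datetime(2025, 1, 1, tzinfo=timezone.utc)))  # 2025-01-01T00:00:00+00:00
print(normalize_datetime("2025-01-01T00:00:00"))                      # 2025-01-01T00:00:00+08:00
```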
@@ -149,6 +108,7 @@ class SessionPool:
         self.hosts = hosts
         self.user = user
         self.password = password
+        self.minsize = minsize
         self.maxsize = maxsize
         self.pool = Queue(maxsize)
         self.lock = Lock()
@@ -158,6 +118,7 @@ class SessionPool:
         for _ in range(minsize):
             self._create_and_add_client()
 
+    @timed
     def _create_and_add_client(self):
         from nebulagraph_python import NebulaClient
 
@@ -165,28 +126,37 @@ class SessionPool:
         self.pool.put(client)
         self.clients.append(client)
 
+    @timed
     def get_client(self, timeout: float = 5.0):
-        from nebulagraph_python import NebulaClient
-
         try:
             return self.pool.get(timeout=timeout)
         except Empty:
             with self.lock:
                 if len(self.clients) < self.maxsize:
+                    from nebulagraph_python import NebulaClient
+
                     client = NebulaClient(self.hosts, self.user, self.password)
                     self.clients.append(client)
                     return client
             raise RuntimeError("NebulaClientPool exhausted") from None
 
+    @timed
     def return_client(self, client):
-… (1 removed line not captured in this view)
+        try:
+            client.execute("YIELD 1")
+            self.pool.put(client)
+        except Exception:
+            logger.info("[Pool] Client dead, replacing...")
+            self.replace_client(client)
 
+    @timed
     def close(self):
         for client in self.clients:
             with suppress(Exception):
                 client.close()
         self.clients.clear()
 
+    @timed
     def get(self):
         """
         Context manager: with pool.get() as client:
@@ -207,12 +177,140 @@ class SessionPool:
 
         return _ClientContext(self)
 
+    @timed
+    def reset_pool(self):
+        """⚠️ Emergency reset: Close all clients and clear the pool."""
+        logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
+        with self.lock:
+            for client in self.clients:
+                try:
+                    client.close()
+                except Exception:
+                    logger.error("Fail to close!!!")
+            self.clients.clear()
+            while not self.pool.empty():
+                try:
+                    self.pool.get_nowait()
+                except Empty:
+                    break
+            for _ in range(self.minsize):
+                self._create_and_add_client()
+            logger.info("[Pool] Pool has been reset successfully.")
+
+    @timed
+    def replace_client(self, client):
+        try:
+            client.close()
+        except Exception:
+            logger.error("Fail to close client")
+
+        if client in self.clients:
+            self.clients.remove(client)
+
+        from nebulagraph_python import NebulaClient
+
+        new_client = NebulaClient(self.hosts, self.user, self.password)
+        self.clients.append(new_client)
+
+        self.pool.put(new_client)
+
+        logger.info("[Pool] Replaced dead client with a new one.")
+        return new_client
+
 
 class NebulaGraphDB(BaseGraphDB):
     """
     NebulaGraph-based implementation of a graph memory store.
     """
 
+    # ====== shared pool cache & refcount ======
+    # These are process-local; in a multi-process model each process will
+    # have its own cache.
+    _POOL_CACHE: ClassVar[dict[str, "NebulaPool"]] = {}
+    _POOL_REFCOUNT: ClassVar[dict[str, int]] = {}
+    _POOL_LOCK: ClassVar[Lock] = Lock()
+
+    @staticmethod
+    def _make_pool_key(cfg: NebulaGraphDBConfig) -> str:
+        """
+        Build a cache key that captures all connection-affecting options.
+        Keep this key stable and include fields that change the underlying pool behavior.
+        """
+        # NOTE: Do not include tenant-like or query-scope-only fields here.
+        # Only include things that affect the actual TCP/auth/session pool.
+        return "|".join(
+            [
+                "nebula",
+                str(getattr(cfg, "uri", "")),
+                str(getattr(cfg, "user", "")),
+                str(getattr(cfg, "password", "")),
+                # pool sizing / tls / timeouts if you have them in config:
+                str(getattr(cfg, "max_client", 1000)),
+                # multi-db mode can impact how we use sessions; keep it to be safe
+                str(getattr(cfg, "use_multi_db", False)),
+            ]
+        )
+
+    @classmethod
+    def _get_or_create_shared_pool(cls, cfg: NebulaGraphDBConfig):
+        """
+        Get a shared NebulaPool from cache or create one if missing.
+        Thread-safe with a lock; maintains a simple refcount.
+        """
+        key = cls._make_pool_key(cfg)
+
+        with cls._POOL_LOCK:
+            pool = cls._POOL_CACHE.get(key)
+            if pool is None:
+                # Create a new pool and put into cache
+                pool = SessionPool(
+                    hosts=cfg.get("uri"),
+                    user=cfg.get("user"),
+                    password=cfg.get("password"),
+                    minsize=1,
+                    maxsize=cfg.get("max_client", 1000),
+                )
+                cls._POOL_CACHE[key] = pool
+                cls._POOL_REFCOUNT[key] = 0
+                logger.info(f"[NebulaGraphDB] Created new shared NebulaPool for key={key}")
+
+            # Increase refcount for the caller
+            cls._POOL_REFCOUNT[key] = cls._POOL_REFCOUNT.get(key, 0) + 1
+            return key, pool
+
+    @classmethod
+    def _release_shared_pool(cls, key: str):
+        """
+        Decrease refcount for the given pool key; only close when refcount hits zero.
+        """
+        with cls._POOL_LOCK:
+            if key not in cls._POOL_CACHE:
+                return
+            cls._POOL_REFCOUNT[key] = max(0, cls._POOL_REFCOUNT.get(key, 0) - 1)
+            if cls._POOL_REFCOUNT[key] == 0:
+                try:
+                    cls._POOL_CACHE[key].close()
+                except Exception as e:
+                    logger.warning(f"[NebulaGraphDB] Error closing shared pool: {e}")
+                finally:
+                    cls._POOL_CACHE.pop(key, None)
+                    cls._POOL_REFCOUNT.pop(key, None)
+                    logger.info(f"[NebulaGraphDB] Closed and removed shared pool key={key}")
+
+    @classmethod
+    def close_all_shared_pools(cls):
+        """Force close all cached pools. Call this on graceful shutdown."""
+        with cls._POOL_LOCK:
+            for key, pool in list(cls._POOL_CACHE.items()):
+                try:
+                    pool.close()
+                except Exception as e:
+                    logger.warning(f"[NebulaGraphDB] Error closing pool key={key}: {e}")
+                finally:
+                    logger.info(f"[NebulaGraphDB] Closed pool key={key}")
+            cls._POOL_CACHE.clear()
+            cls._POOL_REFCOUNT.clear()
+
     @require_python_package(
         import_name="nebulagraph_python",
         install_command="pip install ... @Tianxing",
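The hunk above introduces a process-local, refcounted cache of `SessionPool`s keyed by connection-affecting config, so multiple `NebulaGraphDB` instances pointing at the same server share one pool and the pool closes only when the last owner releases it. A minimal standalone sketch of that acquire/release pattern (illustrative names, not the package's API):

```python
from threading import Lock

# Refcounted registry: acquire() hands out one shared resource per key,
# release() closes it only when the last holder lets go.
class SharedRegistry:
    def __init__(self):
        self._cache: dict[str, object] = {}
        self._refs: dict[str, int] = {}
        self._lock = Lock()

    def acquire(self, key: str, factory):
        with self._lock:
            if key not in self._cache:
                self._cache[key] = factory()
                self._refs[key] = 0
            self._refs[key] += 1
            return self._cache[key]

    def release(self, key: str, close=lambda res: None):
        with self._lock:
            if key not in self._cache:
                return
            self._refs[key] = max(0, self._refs[key] - 1)
            if self._refs[key] == 0:
                close(self._cache.pop(key))
                self._refs.pop(key, None)


registry = SharedRegistry()
a = registry.acquire("nebula|host:9669|root", lambda: object())
b = registry.acquire("nebula|host:9669|root", lambda: object())
assert a is b                               # same key -> same shared resource
registry.release("nebula|host:9669|root")
registry.release("nebula|host:9669|root")   # refcount hits zero, resource closed
```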
@@ -240,15 +338,43 @@ class NebulaGraphDB(BaseGraphDB):
         self.config = config
         self.db_name = config.space
         self.user_name = config.user_name
-        self.…
-        self.…
-… (5 removed lines not captured in this view)
+        self.embedding_dimension = config.embedding_dimension
+        self.default_memory_dimension = 3072
+        self.common_fields = {
+            "id",
+            "memory",
+            "user_name",
+            "user_id",
+            "session_id",
+            "status",
+            "key",
+            "confidence",
+            "tags",
+            "created_at",
+            "updated_at",
+            "memory_type",
+            "sources",
+            "source",
+            "node_type",
+            "visibility",
+            "usage",
+            "background",
+        }
+        self.base_fields = set(self.common_fields) - {"usage"}
+        self.heavy_fields = {"usage"}
+        self.dim_field = (
+            f"embedding_{self.embedding_dimension}"
+            if (str(self.embedding_dimension) != str(self.default_memory_dimension))
+            else "embedding"
         )
+        self.system_db_name = "system" if config.use_multi_db else config.space
 
+        # ---- NEW: pool acquisition strategy
+        # Get or create a shared pool from the class-level cache
+        self._pool_key, self.pool = self._get_or_create_shared_pool(config)
+        self._owns_pool = True  # We manage refcount for this instance
+
+        # auto-create graph type / graph / index if needed
         if config.auto_create:
             self._ensure_database_exists()
 
@@ -259,15 +385,44 @@ class NebulaGraphDB(BaseGraphDB):
 
         logger.info("Connected to NebulaGraph successfully.")
 
-… (1 removed line not captured in this view)
+    @timed
+    def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True):
         with self.pool.get() as client:
-… (3 removed lines not captured in this view)
+            try:
+                if auto_set_db and self.db_name:
+                    client.execute(f"SESSION SET GRAPH `{self.db_name}`")
+                return client.execute(gql, timeout=timeout)
 
+            except Exception as e:
+                if "Session not found" in str(e) or "Connection not established" in str(e):
+                    logger.warning(f"[execute_query] {e!s}, replacing client...")
+                    self.pool.replace_client(client)
+                    return self.execute_query(gql, timeout, auto_set_db)
+                raise
+
+    @timed
     def close(self):
-… (1 removed line not captured in this view)
+        """
+        Close the connection resource if this instance owns it.
 
+        - If pool was injected (`shared_pool`), do nothing.
+        - If pool was acquired via shared cache, decrement refcount and close
+          when the last owner releases it.
+        """
+        if not self._owns_pool:
+            logger.debug("[NebulaGraphDB] close() skipped (injected pool).")
+            return
+        if self._pool_key:
+            self._release_shared_pool(self._pool_key)
+            self._pool_key = None
+        self.pool = None
+
+    # NOTE: __del__ is best-effort; do not rely on GC order.
+    def __del__(self):
+        with suppress(Exception):
+            self.close()
+
+    @timed
     def create_index(
         self,
         label: str = "Memory",
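`execute_query` now pins the session to the configured graph before each statement and, on `Session not found` / `Connection not established` errors, swaps in a fresh client and retries by calling itself; note the shipped version recurses with no attempt cap. A bounded sketch of the same idea (the `pool` argument is assumed to expose the `SessionPool`-style methods from this diff):

```python
# Retry-on-dead-session sketch: bounded, unlike the recursive original.
def execute_with_retry(pool, gql: str, retries: int = 1):
    for attempt in range(retries + 1):
        client = pool.get_client()
        try:
            result = client.execute(gql)
            pool.return_client(client)   # healthy client goes back to the pool
            return result
        except Exception as e:
            if attempt < retries and "Session not found" in str(e):
                pool.replace_client(client)  # drop dead client, pool adds a fresh one
                continue
            raise
```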
@@ -280,6 +435,7 @@ class NebulaGraphDB(BaseGraphDB):
         # Create indexes
         self._create_basic_property_indexes()
 
+    @timed
     def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None:
         """
         Remove all WorkingMemory nodes except the latest `keep_latest` entries.
@@ -302,6 +458,7 @@ class NebulaGraphDB(BaseGraphDB):
         """
         self.execute_query(query)
 
+    @timed
     def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
         """
         Insert or update a Memory node in NebulaGraph.
@@ -318,10 +475,14 @@ class NebulaGraphDB(BaseGraphDB):
         metadata["memory"] = memory
 
         if "embedding" in metadata and isinstance(metadata["embedding"], list):
-            metadata["embedding"]…
+            assert len(metadata["embedding"]) == self.embedding_dimension, (
+                f"input embedding dimension must equal to {self.embedding_dimension}"
+            )
+            embedding = metadata.pop("embedding")
+            metadata[self.dim_field] = _normalize(embedding)
 
-        metadata = _metadata_filter(metadata)
-        properties = ", ".join(f"{k}: {_format_value(v, k)}" for k, v in metadata.items())
+        metadata = self._metadata_filter(metadata)
+        properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items())
         gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
 
         try:
@@ -332,16 +493,16 @@ class NebulaGraphDB(BaseGraphDB):
                 f"Failed to insert vertex {id}: gql: {gql}, {e}\ntrace: {traceback.format_exc()}"
             )
 
+    @timed
     def node_not_exist(self, scope: str) -> int:
         if not self.config.use_multi_db and self.config.user_name:
             filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"'
         else:
             filter_clause = f'n.memory_type = "{scope}"'
-
         query = f"""
         MATCH (n@Memory)
         WHERE {filter_clause}
-        RETURN n
+        RETURN n.id AS id
         LIMIT 1
         """
 
@@ -352,6 +513,7 @@ class NebulaGraphDB(BaseGraphDB):
             logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True)
             raise
 
+    @timed
     def update_node(self, id: str, fields: dict[str, Any]) -> None:
         """
         Update node fields in Nebular, auto-converting `created_at` and `updated_at` to datetime type if present.
@@ -359,7 +521,7 @@ class NebulaGraphDB(BaseGraphDB):
         fields = fields.copy()
         set_clauses = []
         for k, v in fields.items():
-            set_clauses.append(f"n.{k} = {_format_value(v, k)}")
+            set_clauses.append(f"n.{k} = {self._format_value(v, k)}")
 
         set_clause_str = ",\n    ".join(set_clauses)
 
@@ -373,6 +535,7 @@ class NebulaGraphDB(BaseGraphDB):
         query += f"\nSET {set_clause_str}"
         self.execute_query(query)
 
+    @timed
     def delete_node(self, id: str) -> None:
         """
         Delete a node from the graph.
@@ -384,10 +547,11 @@ class NebulaGraphDB(BaseGraphDB):
         """
         if not self.config.use_multi_db and self.config.user_name:
             user_name = self.config.user_name
-            query += f" WHERE n.user_name = {_format_value(user_name)}"
+            query += f" WHERE n.user_name = {self._format_value(user_name)}"
         query += "\n DETACH DELETE n"
         self.execute_query(query)
 
+    @timed
     def add_edge(self, source_id: str, target_id: str, type: str):
         """
         Create an edge from source node to target node.
@@ -412,6 +576,7 @@ class NebulaGraphDB(BaseGraphDB):
         except Exception as e:
             logger.error(f"Failed to insert edge: {e}", exc_info=True)
 
+    @timed
     def delete_edge(self, source_id: str, target_id: str, type: str) -> None:
         """
         Delete a specific edge between two nodes.
@@ -422,16 +587,17 @@ class NebulaGraphDB(BaseGraphDB):
         """
         query = f"""
         MATCH (a@Memory) -[r@{type}]-> (b@Memory)
-        WHERE a.id = {_format_value(source_id)} AND b.id = {_format_value(target_id)}
+        WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)}
         """
 
         if not self.config.use_multi_db and self.config.user_name:
             user_name = self.config.user_name
-            query += f" AND a.user_name = {_format_value(user_name)} AND b.user_name = {_format_value(user_name)}"
+            query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}"
 
         query += "\nDELETE r"
         self.execute_query(query)
 
+    @timed
     def get_memory_count(self, memory_type: str) -> int:
         query = f"""
         MATCH (n@Memory)
@@ -449,6 +615,7 @@ class NebulaGraphDB(BaseGraphDB):
             logger.error(f"[get_memory_count] Failed: {e}")
             return -1
 
+    @timed
     def count_nodes(self, scope: str) -> int:
         query = f"""
         MATCH (n@Memory)
@@ -462,6 +629,7 @@ class NebulaGraphDB(BaseGraphDB):
         result = self.execute_query(query)
         return result.one_or_none()["count"].value
 
+    @timed
     def edge_exists(
         self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING"
     ) -> bool:
@@ -503,43 +671,53 @@ class NebulaGraphDB(BaseGraphDB):
             return False
         return record.values() is not None
 
+    @timed
     # Graph Query & Reasoning
-    def get_node(self, id: str) -> dict[str, Any] | None:
+    def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | None:
         """
         Retrieve a Memory node by its unique ID.
 
         Args:
             id (str): Node ID (Memory.id)
+            include_embedding: with/without embedding
 
         Returns:
             dict: Node properties as key-value pairs, or None if not found.
         """
+        if not self.config.use_multi_db and self.config.user_name:
+            filter_clause = f'n.user_name = "{self.config.user_name}" AND n.id = "{id}"'
+        else:
+            filter_clause = f'n.id = "{id}"'
+
+        return_fields = self._build_return_fields(include_embedding)
         gql = f"""
-… (4 removed lines not captured in this view)
+        MATCH (n@Memory)
+        WHERE {filter_clause}
+        RETURN {return_fields}
+        """
 
         try:
             result = self.execute_query(gql)
-… (4 removed lines not captured in this view)
-            node_wrapper = record["v"].as_node()
-            props = node_wrapper.get_properties()
-            node = self._parse_node(props)
-            return node
+            for row in result:
+                props = {k: v.value for k, v in row.items()}
+                node = self._parse_node(props)
+                return node
 
         except Exception as e:
-            logger.error(
+            logger.error(
+                f"[get_node] Failed to retrieve node '{id}': {e}, trace: {traceback.format_exc()}"
+            )
             return None
 
-… (1 removed line not captured in this view)
+    @timed
+    def get_nodes(
+        self, ids: list[str], include_embedding: bool = False, **kwargs
+    ) -> list[dict[str, Any]]:
         """
         Retrieve the metadata and memory of a list of nodes.
         Args:
             ids: List of Node identifier.
+            include_embedding: with/without embedding
         Returns:
             list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.
 
@@ -552,18 +730,33 @@ class NebulaGraphDB(BaseGraphDB):
 
         where_user = ""
         if not self.config.use_multi_db and self.config.user_name:
-… (1 removed line not captured in this view)
+            if kwargs.get("cube_name"):
+                where_user = f" AND n.user_name = '{kwargs['cube_name']}'"
+            else:
+                where_user = f" AND n.user_name = '{self.config.user_name}'"
 
-… (1 removed line not captured in this view)
+        # Safe formatting of the ID list
+        id_list = ",".join(f'"{_id}"' for _id in ids)
 
-… (1 removed line not captured in this view)
+        return_fields = self._build_return_fields(include_embedding)
+        query = f"""
+        MATCH (n@Memory)
+        WHERE n.id IN [{id_list}] {where_user}
+        RETURN {return_fields}
+        """
         nodes = []
-… (4 removed lines not captured in this view)
+        try:
+            results = self.execute_query(query)
+            for row in results:
+                props = {k: v.value for k, v in row.items()}
+                nodes.append(self._parse_node(props))
+        except Exception as e:
+            logger.error(
+                f"[get_nodes] Failed to retrieve nodes {ids}: {e}, trace: {traceback.format_exc()}"
+            )
         return nodes
 
+    @timed
     def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[dict[str, str]]:
         """
         Get edges connected to a node, with optional type and direction filter.
@@ -617,12 +810,14 @@ class NebulaGraphDB(BaseGraphDB):
             )
         return edges
 
+    @timed
     def get_neighbors_by_tag(
         self,
         tags: list[str],
         exclude_ids: list[str],
         top_k: int = 5,
         min_overlap: int = 1,
+        include_embedding: bool = False,
     ) -> list[dict[str, Any]]:
         """
         Find top-K neighbor nodes with maximum tag overlap.
@@ -632,6 +827,7 @@ class NebulaGraphDB(BaseGraphDB):
             exclude_ids: Node IDs to exclude (e.g., local cluster).
             top_k: Max number of neighbors to return.
             min_overlap: Minimum number of overlapping tags required.
+            include_embedding: with/without embedding
 
         Returns:
             List of dicts with node details and overlap count.
@@ -653,12 +849,13 @@ class NebulaGraphDB(BaseGraphDB):
         where_clause = " AND ".join(where_clauses)
         tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"
 
+        return_fields = self._build_return_fields(include_embedding)
         query = f"""
         LET tag_list = {tag_list_literal}
 
         MATCH (n@Memory)
         WHERE {where_clause}
-        RETURN…
+        RETURN {return_fields},
             size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
         ORDER BY overlap_count DESC
         LIMIT {top_k}
@@ -667,9 +864,8 @@ class NebulaGraphDB(BaseGraphDB):
         result = self.execute_query(query)
         neighbors: list[dict[str, Any]] = []
         for r in result:
-… (1 removed line not captured in this view)
-            parsed = self._parse_node(…
-… (1 removed line not captured in this view)
+            props = {k: v.value for k, v in r.items() if k != "overlap_count"}
+            parsed = self._parse_node(props)
             parsed["overlap_count"] = r["overlap_count"].value
             neighbors.append(parsed)
 
@@ -681,6 +877,7 @@ class NebulaGraphDB(BaseGraphDB):
                 result.append(neighbor)
         return result
 
+    @timed
     def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]:
         where_user = ""
 
@@ -691,19 +888,20 @@ class NebulaGraphDB(BaseGraphDB):
         query = f"""
         MATCH (p@Memory)-[@PARENT]->(c@Memory)
         WHERE p.id = "{id}" {where_user}
-        RETURN c.id AS id, c.…
+        RETURN c.id AS id, c.{self.dim_field} AS {self.dim_field}, c.memory AS memory
         """
         result = self.execute_query(query)
         children = []
         for row in result:
             eid = row["id"].value  # STRING
-            emb_v = row[…
+            emb_v = row[self.dim_field].value  # NVector
             emb = list(emb_v.values) if emb_v else []
             mem = row["memory"].value  # STRING
 
             children.append({"id": eid, "embedding": emb, "memory": mem})
         return children
 
+    @timed
     def get_subgraph(
         self, center_id: str, depth: int = 2, center_status: str = "activated"
     ) -> dict[str, Any]:
@@ -765,6 +963,7 @@ class NebulaGraphDB(BaseGraphDB):
 
         return {"core_node": core_node, "neighbors": neighbors, "edges": edges}
 
+    @timed
     # Search / recall operations
     def search_by_embedding(
         self,
@@ -773,6 +972,7 @@ class NebulaGraphDB(BaseGraphDB):
         scope: str | None = None,
         status: str | None = None,
         threshold: float | None = None,
+        **kwargs,
     ) -> list[dict]:
         """
         Retrieve node IDs based on vector similarity.
@@ -807,7 +1007,10 @@ class NebulaGraphDB(BaseGraphDB):
         if status:
             where_clauses.append(f'n.status = "{status}"')
         if not self.config.use_multi_db and self.config.user_name:
-… (1 removed line not captured in this view)
+            if kwargs.get("cube_name"):
+                where_clauses.append(f'n.user_name = "{kwargs["cube_name"]}"')
+            else:
+                where_clauses.append(f'n.user_name = "{self.config.user_name}"')
 
         where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
 
@@ -815,11 +1018,11 @@ class NebulaGraphDB(BaseGraphDB):
         USE `{self.db_name}`
         MATCH (n@Memory)
         {where_clause}
-        ORDER BY inner_product(n.…
+        ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC
         APPROXIMATE
         LIMIT {top_k}
         OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }}
-        RETURN n.id AS id, inner_product(n.…
+        RETURN n.id AS id, inner_product(n.{self.dim_field}, {gql_vector}) AS score
         """
 
         try:
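Because the schema now stores one vector column per embedding size (`embedding` for the 3072-dim default, `embedding_<dim>` otherwise), the search query interpolates `self.dim_field` into both the `ORDER BY` and the `RETURN`. A sketch of how the versioned field name and the inline GQL vector literal compose (dimension and values are made up for the demo):

```python
# Versioned vector-field naming plus the VECTOR literal used by
# search_by_embedding.
DEFAULT_DIM = 3072
dim = 4
dim_field = "embedding" if dim == DEFAULT_DIM else f"embedding_{dim}"

vec = [0.1, 0.2, 0.3, 0.4]
gql_vector = f"VECTOR<{len(vec)}, FLOAT>([{','.join(str(float(x)) for x in vec)}])"

query = (
    "MATCH (n@Memory)\n"
    f"ORDER BY inner_product(n.{dim_field}, {gql_vector}) DESC\n"
    "APPROXIMATE LIMIT 5\n"
    "OPTIONS { METRIC: IP, TYPE: IVF, NPROBE: 8 }\n"
    f"RETURN n.id AS id, inner_product(n.{dim_field}, {gql_vector}) AS score"
)
print(query)
```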
@@ -842,6 +1045,7 @@ class NebulaGraphDB(BaseGraphDB):
             logger.error(f"[search_by_embedding] Result parse failed: {e}")
             return []
 
+    @timed
     def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:
         """
         1. ADD logic: "AND" vs "OR"(support logic combination);
@@ -868,20 +1072,12 @@ class NebulaGraphDB(BaseGraphDB):
         """
         where_clauses = []
 
-        def _escape_value(value):
-            if isinstance(value, str):
-                return f'"{value}"'
-            elif isinstance(value, list):
-                return "[" + ", ".join(_escape_value(v) for v in value) + "]"
-            else:
-                return str(value)
-
         for _i, f in enumerate(filters):
             field = f["field"]
             op = f.get("op", "=")
             value = f["value"]
 
-            escaped_value =…
+            escaped_value = self._format_value(value)
 
             # Build WHERE clause
             if op == "=":
@@ -912,6 +1108,7 @@ class NebulaGraphDB(BaseGraphDB):
             logger.error(f"Failed to get metadata: {e}, gql is {gql}")
         return ids
 
+    @timed
     def get_grouped_counts(
         self,
         group_fields: list[str],
@@ -980,6 +1177,7 @@ class NebulaGraphDB(BaseGraphDB):
 
         return output
 
+    @timed
     def clear(self) -> None:
         """
         Clear the entire graph if the target database exists.
@@ -996,9 +1194,12 @@ class NebulaGraphDB(BaseGraphDB):
         except Exception as e:
             logger.error(f"[ERROR] Failed to clear database: {e}")
 
-… (1 removed line not captured in this view)
+    @timed
+    def export_graph(self, include_embedding: bool = False) -> dict[str, Any]:
         """
         Export all graph nodes and edges in a structured form.
+        Args:
+            include_embedding (bool): Whether to include the large embedding field.
 
         Returns:
             {
@@ -1015,13 +1216,41 @@ class NebulaGraphDB(BaseGraphDB):
             edge_query += f' WHERE r.user_name = "{username}"'
 
         try:
-… (2 removed lines not captured in this view)
+            if include_embedding:
+                return_fields = "n"
+            else:
+                return_fields = ",".join(
+                    [
+                        "n.id AS id",
+                        "n.memory AS memory",
+                        "n.user_name AS user_name",
+                        "n.user_id AS user_id",
+                        "n.session_id AS session_id",
+                        "n.status AS status",
+                        "n.key AS key",
+                        "n.confidence AS confidence",
+                        "n.tags AS tags",
+                        "n.created_at AS created_at",
+                        "n.updated_at AS updated_at",
+                        "n.memory_type AS memory_type",
+                        "n.sources AS sources",
+                        "n.source AS source",
+                        "n.node_type AS node_type",
+                        "n.visibility AS visibility",
+                        "n.usage AS usage",
+                        "n.background AS background",
+                    ]
+                )
+
+            full_node_query = f"{node_query} RETURN {return_fields}"
+            node_result = self.execute_query(full_node_query, timeout=20)
             nodes = []
+            logger.debug(f"Debugging: {node_result}")
             for row in node_result:
-… (3 removed lines not captured in this view)
+                if include_embedding:
+                    props = row.values()[0].as_node().get_properties()
+                else:
+                    props = {k: v.value for k, v in row.items()}
                 node = self._parse_node(props)
                 nodes.append(node)
         except Exception as e:
@@ -1029,7 +1258,7 @@ class NebulaGraphDB(BaseGraphDB):
 
         try:
             full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) as edge"
-            edge_result = self.execute_query(full_edge_query)
+            edge_result = self.execute_query(full_edge_query, timeout=20)
             edges = [
                 {
                     "source": row.values()[0].value,
@@ -1043,6 +1272,7 @@ class NebulaGraphDB(BaseGraphDB):
 
         return {"nodes": nodes, "edges": edges}
 
+    @timed
     def import_graph(self, data: dict[str, Any]) -> None:
         """
         Import the entire graph from a serialized dictionary.
@@ -1051,35 +1281,45 @@ class NebulaGraphDB(BaseGraphDB):
             data: A dictionary containing all nodes and edges to be loaded.
         """
         for node in data.get("nodes", []):
-… (1 removed line not captured in this view)
+            try:
+                id, memory, metadata = _compose_node(node)
 
-… (2 removed lines not captured in this view)
+                if not self.config.use_multi_db and self.config.user_name:
+                    metadata["user_name"] = self.config.user_name
 
-… (5 removed lines not captured in this view)
+                metadata = self._prepare_node_metadata(metadata)
+                metadata.update({"id": id, "memory": memory})
+                properties = ", ".join(
+                    f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()
+                )
+                node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
+                self.execute_query(node_gql)
+            except Exception as e:
+                logger.error(f"Fail to load node: {node}, error: {e}")
 
         for edge in data.get("edges", []):
-… (10 removed lines not captured in this view)
+            try:
+                source_id, target_id = edge["source"], edge["target"]
+                edge_type = edge["type"]
+                props = ""
+                if not self.config.use_multi_db and self.config.user_name:
+                    props = f'{{user_name: "{self.config.user_name}"}}'
+                edge_gql = f'''
+                MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
+                INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
+                '''
+                self.execute_query(edge_gql)
+            except Exception as e:
+                logger.error(f"Fail to load edge: {edge}, error: {e}")
 
-… (1 removed line not captured in this view)
+    @timed
+    def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (list)[dict]:
         """
         Retrieve all memory items of a specific memory_type.
 
         Args:
             scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
+            include_embedding: with/without embedding
 
         Returns:
             list[dict]: Full list of memory items under this scope.
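Export and import are now symmetric around the node dict shape `{"id", "memory", "metadata"}`, with per-node and per-edge try/except so one bad record no longer aborts the whole load. A hypothetical round trip between two instances (method names follow the diff; `source_db` and `target_db` are assumed, already-connected `NebulaGraphDB` instances):

```python
# Hypothetical migration between two NebulaGraphDB instances.
snapshot = source_db.export_graph(include_embedding=True)  # {"nodes": [...], "edges": [...]}
target_db.import_graph(snapshot)  # inserts nodes first, then edges, logging failures per record
```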
@@ -1092,22 +1332,28 @@ class NebulaGraphDB(BaseGraphDB):
         if not self.config.use_multi_db and self.config.user_name:
             where_clause += f" AND n.user_name = '{self.config.user_name}'"
 
+        return_fields = self._build_return_fields(include_embedding)
+
         query = f"""
         MATCH (n@Memory)
         {where_clause}
-        RETURN…
+        RETURN {return_fields}
+        LIMIT 100
         """
         nodes = []
         try:
             results = self.execute_query(query)
-            for…
-… (1 removed line not captured in this view)
-            nodes.append(self._parse_node(…
+            for row in results:
+                props = {k: v.value for k, v in row.items()}
+                nodes.append(self._parse_node(props))
         except Exception as e:
             logger.error(f"Failed to get memories: {e}")
         return nodes
 
-… (1 removed line not captured in this view)
+    @timed
+    def get_structure_optimization_candidates(
+        self, scope: str, include_embedding: bool = False
+    ) -> list[dict]:
         """
         Find nodes that are likely candidates for structure optimization:
         - Isolated nodes, nodes with empty background, or nodes with exactly one child.
@@ -1121,6 +1367,8 @@ class NebulaGraphDB(BaseGraphDB):
         if not self.config.use_multi_db and self.config.user_name:
             where_clause += f' AND n.user_name = "{self.config.user_name}"'
 
+        return_fields = self._build_return_fields(include_embedding)
+
         query = f"""
         USE `{self.db_name}`
         MATCH (n@Memory)
@@ -1128,19 +1376,20 @@ class NebulaGraphDB(BaseGraphDB):
         OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
         OPTIONAL MATCH (p@Memory)-[@PARENT]->(n)
         WHERE c IS NULL AND p IS NULL
-        RETURN…
+        RETURN {return_fields}
         """
 
         candidates = []
         try:
             results = self.execute_query(query)
-            for…
-… (1 removed line not captured in this view)
-            candidates.append(self._parse_node(…
+            for row in results:
+                props = {k: v.value for k, v in row.items()}
+                candidates.append(self._parse_node(props))
         except Exception as e:
-            logger.error(f"Failed : {e}")
+            logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
         return candidates
 
+    @timed
     def drop_database(self) -> None:
         """
         Permanently delete the entire database this instance is using.
@@ -1155,6 +1404,7 @@ class NebulaGraphDB(BaseGraphDB):
                 f"Shared Database Multi-Tenant mode"
             )
 
+    @timed
     def detect_conflicts(self) -> list[tuple[str, str]]:
         """
         Detect conflicting nodes based on logical or semantic inconsistency.
@@ -1163,6 +1413,7 @@ class NebulaGraphDB(BaseGraphDB):
         """
         raise NotImplementedError
 
+    @timed
     # Structure Maintenance
     def deduplicate_nodes(self) -> None:
         """
@@ -1171,6 +1422,7 @@ class NebulaGraphDB(BaseGraphDB):
         """
         raise NotImplementedError
 
+    @timed
     def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
         """
         Get the ordered context chain starting from a node, following a relationship type.
@@ -1182,6 +1434,7 @@ class NebulaGraphDB(BaseGraphDB):
         """
         raise NotImplementedError
 
+    @timed
     def get_neighbors(
         self, id: str, type: str, direction: Literal["in", "out", "both"] = "out"
     ) -> list[str]:
@@ -1196,6 +1449,7 @@ class NebulaGraphDB(BaseGraphDB):
         """
         raise NotImplementedError
 
+    @timed
     def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]:
         """
         Get the path of nodes from source to target within a limited depth.
@@ -1208,6 +1462,7 @@ class NebulaGraphDB(BaseGraphDB):
         """
         raise NotImplementedError
 
+    @timed
     def merge_nodes(self, id1: str, id2: str) -> str:
         """
         Merge two similar or duplicate nodes into one.
@@ -1219,70 +1474,112 @@ class NebulaGraphDB(BaseGraphDB):
         """
         raise NotImplementedError
 
+    @timed
     def _ensure_database_exists(self):
-… (33 removed lines not captured in this view)
+        graph_type_name = "MemOSBgeM3Type"
+
+        check_type_query = "SHOW GRAPH TYPES"
+        result = self.execute_query(check_type_query, auto_set_db=False)
+
+        type_exists = any(row["graph_type"].as_string() == graph_type_name for row in result)
+
+        if not type_exists:
+            create_tag = f"""
+            CREATE GRAPH TYPE IF NOT EXISTS {graph_type_name} AS {{
+                NODE Memory (:MemoryTag {{
+                    id STRING,
+                    memory STRING,
+                    user_name STRING,
+                    user_id STRING,
+                    session_id STRING,
+                    status STRING,
+                    key STRING,
+                    confidence FLOAT,
+                    tags LIST<STRING>,
+                    created_at STRING,
+                    updated_at STRING,
+                    memory_type STRING,
+                    sources LIST<STRING>,
+                    source STRING,
+                    node_type STRING,
+                    visibility STRING,
+                    usage LIST<STRING>,
+                    background STRING,
+                    {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT>,
+                    PRIMARY KEY(id)
+                }}),
+                EDGE RELATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE PARENT (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE AGGREGATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE MERGED_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE INFERS (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE FOLLOWS (Memory) -[{{user_name STRING}}]-> (Memory)
+            }}
+            """
+            self.execute_query(create_tag, auto_set_db=False)
+        else:
+            describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name};"
+            desc_result = self.execute_query(describe_query, auto_set_db=False)
+
+            memory_fields = []
+            for row in desc_result:
+                field_name = row.values()[0].as_string()
+                memory_fields.append(field_name)
+
+            if self.dim_field not in memory_fields:
+                alter_query = f"""
+                ALTER GRAPH TYPE {graph_type_name} {{
+                    ALTER NODE TYPE Memory ADD PROPERTIES {{ {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT> }}
+                }}
+                """
+                self.execute_query(alter_query, auto_set_db=False)
+                logger.info(f"✅ Add new vector search {self.dim_field} to {graph_type_name}")
+            else:
+                logger.info(f"✅ Graph Type {graph_type_name} already include {self.dim_field}")
+
+        create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
         set_graph_working = f"SESSION SET GRAPH `{self.db_name}`"
 
         try:
-            self.execute_query(create_tag, auto_set_db=False)
             self.execute_query(create_graph, auto_set_db=False)
             self.execute_query(set_graph_working)
             logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.")
         except Exception as e:
             logger.error(f"❌ Failed to create tag: {e} trace: {traceback.format_exc()}")
 
+    @timed
     def _create_vector_index(
         self, label: str, vector_property: str, dimensions: int, index_name: str
     ) -> None:
         """
         Create a vector index for the specified property in the label.
         """
+        if str(dimensions) == str(self.default_memory_dimension):
+            index_name = f"idx_{vector_property}"
+            vector_name = vector_property
+        else:
+            index_name = f"idx_{vector_property}_{dimensions}"
+            vector_name = f"{vector_property}_{dimensions}"
+
         create_vector_index = f"""
-… (11 removed lines not captured in this view)
+        CREATE VECTOR INDEX IF NOT EXISTS {index_name}
+        ON NODE {label}::{vector_name}
+        OPTIONS {{
+            DIM: {dimensions},
+            METRIC: IP,
+            TYPE: IVF,
+            NLIST: 100,
+            TRAINSIZE: 1000
+        }}
+        FOR `{self.db_name}`
+        """
         self.execute_query(create_vector_index)
+        logger.info(
+            f"✅ Ensure {label}::{vector_property} vector index {index_name} "
+            f"exists (DIM={dimensions})"
+        )
 
+    @timed
     def _create_basic_property_indexes(self) -> None:
         """
         Create standard B-tree indexes on status, memory_type, created_at
@@ -1304,8 +1601,11 @@ class NebulaGraphDB(BaseGraphDB):
                 self.execute_query(gql)
                 logger.info(f"✅ Created index: {index_name} on field {field}")
             except Exception as e:
-                logger.error(…
+                logger.error(
+                    f"❌ Failed to create index {index_name}: {e}, trace: {traceback.format_exc()}"
+                )
 
+    @timed
     def _index_exists(self, index_name: str) -> bool:
         """
         Check if an index with the given name exists.
@@ -1327,6 +1627,7 @@ class NebulaGraphDB(BaseGraphDB):
             logger.error(f"[Nebula] Failed to check index existence: {e}")
             return False
 
+    @timed
     def _parse_value(self, value: Any) -> Any:
         """turn Nebula ValueWrapper to Python type"""
         from nebulagraph_python.value_wrapper import ValueWrapper
@@ -1352,8 +1653,8 @@ class NebulaGraphDB(BaseGraphDB):
         parsed = {k: self._parse_value(v) for k, v in props.items()}
 
         for tf in ("created_at", "updated_at"):
-            if tf in parsed and…
-                parsed[tf] = parsed[tf]…
+            if tf in parsed and parsed[tf] is not None:
+                parsed[tf] = _normalize_datetime(parsed[tf])
 
         node_id = parsed.pop("id")
         memory = parsed.pop("memory", "")
@@ -1361,4 +1662,98 @@ class NebulaGraphDB(BaseGraphDB):
         metadata = parsed
         metadata["type"] = metadata.pop("node_type")
 
+        if self.dim_field in metadata:
+            metadata["embedding"] = metadata.pop(self.dim_field)
+
         return {"id": node_id, "memory": memory, "metadata": metadata}
+
+    @timed
+    def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
+        """
+        Ensure metadata has proper datetime fields and normalized types.
+
+        - Fill `created_at` and `updated_at` if missing (in ISO 8601 format).
+        - Convert embedding to list of float if present.
+        """
+        now = datetime.utcnow().isoformat()
+        metadata["node_type"] = metadata.pop("type")
+
+        # Fill timestamps if missing
+        metadata.setdefault("created_at", now)
+        metadata.setdefault("updated_at", now)
+
+        # Normalize embedding type
+        embedding = metadata.get("embedding")
+        if embedding and isinstance(embedding, list):
+            metadata.pop("embedding")
+            metadata[self.dim_field] = _normalize([float(x) for x in embedding])
+
+        return metadata
+
+    @timed
+    def _format_value(self, val: Any, key: str = "") -> str:
+        from nebulagraph_python.py_data_types import NVector
+
+        # None
+        if val is None:
+            return "NULL"
+        # bool
+        if isinstance(val, bool):
+            return "true" if val else "false"
+        # str
+        if isinstance(val, str):
+            return f'"{_escape_str(val)}"'
+        # num
+        elif isinstance(val, (int | float)):
+            return str(val)
+        # time
+        elif isinstance(val, datetime):
+            return f'datetime("{val.isoformat()}")'
+        # list
+        elif isinstance(val, list):
+            if key == self.dim_field:
+                dim = len(val)
+                joined = ",".join(str(float(x)) for x in val)
+                return f"VECTOR<{dim}, FLOAT>([{joined}])"
+            else:
+                return f"[{', '.join(self._format_value(v) for v in val)}]"
+        # NVector
+        elif isinstance(val, NVector):
+            if key == self.dim_field:
+                dim = len(val)
+                joined = ",".join(str(float(x)) for x in val)
+                return f"VECTOR<{dim}, FLOAT>([{joined}])"
+            else:
+                logger.warning("Invalid NVector")
+        # dict
+        if isinstance(val, dict):
+            j = json.dumps(val, ensure_ascii=False, separators=(",", ":"))
+            return f'"{_escape_str(j)}"'
+        else:
+            return f'"{_escape_str(str(val))}"'
+
+    @timed
+    def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]:
+        """
+        Filter and validate metadata dictionary against the Memory node schema.
+        - Removes keys not in schema.
+        - Warns if required fields are missing.
+        """
+
+        dim_fields = {self.dim_field}
+
+        allowed_fields = self.common_fields | dim_fields
+
+        missing_fields = allowed_fields - metadata.keys()
+        if missing_fields:
+            logger.info(f"Metadata missing required fields: {sorted(missing_fields)}")
+
+        filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields}
+
+        return filtered_metadata
+
+    def _build_return_fields(self, include_embedding: bool = False) -> str:
+        fields = set(self.base_fields)
+        if include_embedding:
+            fields.add(self.dim_field)
+        return ", ".join(f"n.{f} AS {f}" for f in fields)