MemoryOS 0.2.2__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic. Click here for more details.
- {memoryos-0.2.2.dist-info → memoryos-1.0.0.dist-info}/METADATA +6 -1
- {memoryos-0.2.2.dist-info → memoryos-1.0.0.dist-info}/RECORD +61 -55
- memos/__init__.py +1 -1
- memos/api/config.py +6 -8
- memos/api/context/context.py +1 -1
- memos/api/context/dependencies.py +11 -0
- memos/configs/internet_retriever.py +13 -0
- memos/configs/mem_scheduler.py +38 -16
- memos/graph_dbs/base.py +30 -3
- memos/graph_dbs/nebular.py +442 -194
- memos/graph_dbs/neo4j.py +14 -5
- memos/log.py +5 -0
- memos/mem_os/core.py +19 -9
- memos/mem_os/main.py +1 -1
- memos/mem_os/product.py +6 -69
- memos/mem_os/utils/default_config.py +1 -1
- memos/mem_os/utils/format_utils.py +11 -47
- memos/mem_os/utils/reference_utils.py +133 -0
- memos/mem_scheduler/base_scheduler.py +58 -55
- memos/mem_scheduler/{modules → general_modules}/base.py +1 -2
- memos/mem_scheduler/{modules → general_modules}/dispatcher.py +54 -15
- memos/mem_scheduler/{modules → general_modules}/rabbitmq_service.py +4 -4
- memos/mem_scheduler/{modules → general_modules}/redis_service.py +1 -1
- memos/mem_scheduler/{modules → general_modules}/retriever.py +19 -5
- memos/mem_scheduler/{modules → general_modules}/scheduler_logger.py +10 -4
- memos/mem_scheduler/general_scheduler.py +110 -67
- memos/mem_scheduler/monitors/__init__.py +0 -0
- memos/mem_scheduler/monitors/dispatcher_monitor.py +305 -0
- memos/mem_scheduler/{modules/monitor.py → monitors/general_monitor.py} +57 -19
- memos/mem_scheduler/mos_for_test_scheduler.py +7 -1
- memos/mem_scheduler/schemas/general_schemas.py +3 -2
- memos/mem_scheduler/schemas/message_schemas.py +2 -1
- memos/mem_scheduler/schemas/monitor_schemas.py +10 -2
- memos/mem_scheduler/utils/misc_utils.py +43 -2
- memos/memories/activation/item.py +1 -1
- memos/memories/activation/kv.py +20 -8
- memos/memories/textual/base.py +1 -1
- memos/memories/textual/general.py +1 -1
- memos/memories/textual/tree_text_memory/organize/{conflict.py → handler.py} +30 -48
- memos/memories/textual/tree_text_memory/organize/manager.py +8 -96
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +2 -0
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +102 -140
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +229 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +9 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +15 -8
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +1 -1
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +177 -125
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +7 -2
- memos/memories/textual/tree_text_memory/retrieve/utils.py +1 -1
- memos/memos_tools/lockfree_dict.py +120 -0
- memos/memos_tools/thread_safe_dict.py +288 -0
- memos/templates/mem_reader_prompts.py +2 -0
- memos/templates/mem_scheduler_prompts.py +23 -10
- memos/templates/mos_prompts.py +40 -11
- memos/templates/tree_reorganize_prompts.py +24 -17
- memos/utils.py +19 -0
- memos/memories/textual/tree_text_memory/organize/redundancy.py +0 -193
- {memoryos-0.2.2.dist-info → memoryos-1.0.0.dist-info}/LICENSE +0 -0
- {memoryos-0.2.2.dist-info → memoryos-1.0.0.dist-info}/WHEEL +0 -0
- {memoryos-0.2.2.dist-info → memoryos-1.0.0.dist-info}/entry_points.txt +0 -0
- /memos/mem_scheduler/{modules → general_modules}/__init__.py +0 -0
- /memos/mem_scheduler/{modules → general_modules}/misc.py +0 -0
memos/graph_dbs/nebular.py
CHANGED
|
@@ -12,17 +12,20 @@ from memos.configs.graph_db import NebulaGraphDBConfig
|
|
|
12
12
|
from memos.dependency import require_python_package
|
|
13
13
|
from memos.graph_dbs.base import BaseGraphDB
|
|
14
14
|
from memos.log import get_logger
|
|
15
|
+
from memos.utils import timed
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
logger = get_logger(__name__)
|
|
18
19
|
|
|
19
20
|
|
|
21
|
+
@timed
|
|
20
22
|
def _normalize(vec: list[float]) -> list[float]:
|
|
21
23
|
v = np.asarray(vec, dtype=np.float32)
|
|
22
24
|
norm = np.linalg.norm(v)
|
|
23
25
|
return (v / (norm if norm else 1.0)).tolist()
|
|
24
26
|
|
|
25
27
|
|
|
28
|
+
@timed
|
|
26
29
|
def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
|
|
27
30
|
node_id = item["id"]
|
|
28
31
|
memory = item["memory"]
|
|
@@ -30,97 +33,12 @@ def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
|
|
|
30
33
|
return node_id, memory, metadata
|
|
31
34
|
|
|
32
35
|
|
|
33
|
-
|
|
34
|
-
"""
|
|
35
|
-
Ensure metadata has proper datetime fields and normalized types.
|
|
36
|
-
|
|
37
|
-
- Fill `created_at` and `updated_at` if missing (in ISO 8601 format).
|
|
38
|
-
- Convert embedding to list of float if present.
|
|
39
|
-
"""
|
|
40
|
-
now = datetime.utcnow().isoformat()
|
|
41
|
-
metadata["node_type"] = metadata.pop("type")
|
|
42
|
-
|
|
43
|
-
# Fill timestamps if missing
|
|
44
|
-
metadata.setdefault("created_at", now)
|
|
45
|
-
metadata.setdefault("updated_at", now)
|
|
46
|
-
|
|
47
|
-
# Normalize embedding type
|
|
48
|
-
embedding = metadata.get("embedding")
|
|
49
|
-
if embedding and isinstance(embedding, list):
|
|
50
|
-
metadata["embedding"] = _normalize([float(x) for x in embedding])
|
|
51
|
-
|
|
52
|
-
return metadata
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def _metadata_filter(metadata: dict[str, Any]) -> dict[str, Any]:
|
|
56
|
-
"""
|
|
57
|
-
Filter and validate metadata dictionary against the Memory node schema.
|
|
58
|
-
- Removes keys not in schema.
|
|
59
|
-
- Warns if required fields are missing.
|
|
60
|
-
"""
|
|
61
|
-
|
|
62
|
-
allowed_fields = {
|
|
63
|
-
"id",
|
|
64
|
-
"memory",
|
|
65
|
-
"user_name",
|
|
66
|
-
"user_id",
|
|
67
|
-
"session_id",
|
|
68
|
-
"status",
|
|
69
|
-
"key",
|
|
70
|
-
"confidence",
|
|
71
|
-
"tags",
|
|
72
|
-
"created_at",
|
|
73
|
-
"updated_at",
|
|
74
|
-
"memory_type",
|
|
75
|
-
"sources",
|
|
76
|
-
"source",
|
|
77
|
-
"node_type",
|
|
78
|
-
"visibility",
|
|
79
|
-
"usage",
|
|
80
|
-
"background",
|
|
81
|
-
"embedding",
|
|
82
|
-
}
|
|
83
|
-
|
|
84
|
-
missing_fields = allowed_fields - metadata.keys()
|
|
85
|
-
if missing_fields:
|
|
86
|
-
logger.warning(f"Metadata missing required fields: {sorted(missing_fields)}")
|
|
87
|
-
|
|
88
|
-
filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields}
|
|
89
|
-
|
|
90
|
-
return filtered_metadata
|
|
91
|
-
|
|
92
|
-
|
|
36
|
+
@timed
|
|
93
37
|
def _escape_str(value: str) -> str:
|
|
94
38
|
return value.replace('"', '\\"')
|
|
95
39
|
|
|
96
40
|
|
|
97
|
-
|
|
98
|
-
from nebulagraph_python.py_data_types import NVector
|
|
99
|
-
|
|
100
|
-
if isinstance(val, str):
|
|
101
|
-
return f'"{_escape_str(val)}"'
|
|
102
|
-
elif isinstance(val, (int | float)):
|
|
103
|
-
return str(val)
|
|
104
|
-
elif isinstance(val, datetime):
|
|
105
|
-
return f'datetime("{val.isoformat()}")'
|
|
106
|
-
elif isinstance(val, list):
|
|
107
|
-
if key == "embedding":
|
|
108
|
-
dim = len(val)
|
|
109
|
-
joined = ",".join(str(float(x)) for x in val)
|
|
110
|
-
return f"VECTOR<{dim}, FLOAT>([{joined}])"
|
|
111
|
-
else:
|
|
112
|
-
return f"[{', '.join(_format_value(v) for v in val)}]"
|
|
113
|
-
elif isinstance(val, NVector):
|
|
114
|
-
if key == "embedding":
|
|
115
|
-
dim = len(val)
|
|
116
|
-
joined = ",".join(str(float(x)) for x in val)
|
|
117
|
-
return f"VECTOR<{dim}, FLOAT>([{joined}])"
|
|
118
|
-
elif val is None:
|
|
119
|
-
return "NULL"
|
|
120
|
-
else:
|
|
121
|
-
return f'"{_escape_str(str(val))}"'
|
|
122
|
-
|
|
123
|
-
|
|
41
|
+
@timed
|
|
124
42
|
def _format_datetime(value: str | datetime) -> str:
|
|
125
43
|
"""Ensure datetime is in ISO 8601 format string."""
|
|
126
44
|
if isinstance(value, datetime):
|
|
@@ -128,6 +46,21 @@ def _format_datetime(value: str | datetime) -> str:
|
|
|
128
46
|
return str(value)
|
|
129
47
|
|
|
130
48
|
|
|
49
|
+
@timed
|
|
50
|
+
def _normalize_datetime(val):
|
|
51
|
+
"""
|
|
52
|
+
Normalize datetime to ISO 8601 UTC string with +00:00.
|
|
53
|
+
- If val is datetime object -> keep isoformat() (Neo4j)
|
|
54
|
+
- If val is string without timezone -> append +00:00 (Nebula)
|
|
55
|
+
- Otherwise just str()
|
|
56
|
+
"""
|
|
57
|
+
if hasattr(val, "isoformat"):
|
|
58
|
+
return val.isoformat()
|
|
59
|
+
if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")):
|
|
60
|
+
return val + "+08:00"
|
|
61
|
+
return str(val)
|
|
62
|
+
|
|
63
|
+
|
|
131
64
|
class SessionPoolError(Exception):
|
|
132
65
|
pass
|
|
133
66
|
|
|
@@ -149,6 +82,7 @@ class SessionPool:
|
|
|
149
82
|
self.hosts = hosts
|
|
150
83
|
self.user = user
|
|
151
84
|
self.password = password
|
|
85
|
+
self.minsize = minsize
|
|
152
86
|
self.maxsize = maxsize
|
|
153
87
|
self.pool = Queue(maxsize)
|
|
154
88
|
self.lock = Lock()
|
|
@@ -158,6 +92,7 @@ class SessionPool:
|
|
|
158
92
|
for _ in range(minsize):
|
|
159
93
|
self._create_and_add_client()
|
|
160
94
|
|
|
95
|
+
@timed
|
|
161
96
|
def _create_and_add_client(self):
|
|
162
97
|
from nebulagraph_python import NebulaClient
|
|
163
98
|
|
|
@@ -165,28 +100,37 @@ class SessionPool:
|
|
|
165
100
|
self.pool.put(client)
|
|
166
101
|
self.clients.append(client)
|
|
167
102
|
|
|
103
|
+
@timed
|
|
168
104
|
def get_client(self, timeout: float = 5.0):
|
|
169
|
-
from nebulagraph_python import NebulaClient
|
|
170
|
-
|
|
171
105
|
try:
|
|
172
106
|
return self.pool.get(timeout=timeout)
|
|
173
107
|
except Empty:
|
|
174
108
|
with self.lock:
|
|
175
109
|
if len(self.clients) < self.maxsize:
|
|
110
|
+
from nebulagraph_python import NebulaClient
|
|
111
|
+
|
|
176
112
|
client = NebulaClient(self.hosts, self.user, self.password)
|
|
177
113
|
self.clients.append(client)
|
|
178
114
|
return client
|
|
179
115
|
raise RuntimeError("NebulaClientPool exhausted") from None
|
|
180
116
|
|
|
117
|
+
@timed
|
|
181
118
|
def return_client(self, client):
|
|
182
|
-
|
|
119
|
+
try:
|
|
120
|
+
client.execute("YIELD 1")
|
|
121
|
+
self.pool.put(client)
|
|
122
|
+
except Exception:
|
|
123
|
+
logger.info("[Pool] Client dead, replacing...")
|
|
124
|
+
self.replace_client(client)
|
|
183
125
|
|
|
126
|
+
@timed
|
|
184
127
|
def close(self):
|
|
185
128
|
for client in self.clients:
|
|
186
129
|
with suppress(Exception):
|
|
187
130
|
client.close()
|
|
188
131
|
self.clients.clear()
|
|
189
132
|
|
|
133
|
+
@timed
|
|
190
134
|
def get(self):
|
|
191
135
|
"""
|
|
192
136
|
Context manager: with pool.get() as client:
|
|
@@ -207,6 +151,46 @@ class SessionPool:
|
|
|
207
151
|
|
|
208
152
|
return _ClientContext(self)
|
|
209
153
|
|
|
154
|
+
@timed
|
|
155
|
+
def reset_pool(self):
|
|
156
|
+
"""⚠️ Emergency reset: Close all clients and clear the pool."""
|
|
157
|
+
logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
|
|
158
|
+
with self.lock:
|
|
159
|
+
for client in self.clients:
|
|
160
|
+
try:
|
|
161
|
+
client.close()
|
|
162
|
+
except Exception:
|
|
163
|
+
logger.error("Fail to close!!!")
|
|
164
|
+
self.clients.clear()
|
|
165
|
+
while not self.pool.empty():
|
|
166
|
+
try:
|
|
167
|
+
self.pool.get_nowait()
|
|
168
|
+
except Empty:
|
|
169
|
+
break
|
|
170
|
+
for _ in range(self.minsize):
|
|
171
|
+
self._create_and_add_client()
|
|
172
|
+
logger.info("[Pool] Pool has been reset successfully.")
|
|
173
|
+
|
|
174
|
+
@timed
|
|
175
|
+
def replace_client(self, client):
|
|
176
|
+
try:
|
|
177
|
+
client.close()
|
|
178
|
+
except Exception:
|
|
179
|
+
logger.error("Fail to close client")
|
|
180
|
+
|
|
181
|
+
if client in self.clients:
|
|
182
|
+
self.clients.remove(client)
|
|
183
|
+
|
|
184
|
+
from nebulagraph_python import NebulaClient
|
|
185
|
+
|
|
186
|
+
new_client = NebulaClient(self.hosts, self.user, self.password)
|
|
187
|
+
self.clients.append(new_client)
|
|
188
|
+
|
|
189
|
+
self.pool.put(new_client)
|
|
190
|
+
|
|
191
|
+
logger.info("[Pool] Replaced dead client with a new one.")
|
|
192
|
+
return new_client
|
|
193
|
+
|
|
210
194
|
|
|
211
195
|
class NebulaGraphDB(BaseGraphDB):
|
|
212
196
|
"""
|
|
@@ -240,6 +224,33 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
240
224
|
self.config = config
|
|
241
225
|
self.db_name = config.space
|
|
242
226
|
self.user_name = config.user_name
|
|
227
|
+
self.embedding_dimension = config.embedding_dimension
|
|
228
|
+
self.default_memory_dimension = 3072
|
|
229
|
+
self.common_fields = {
|
|
230
|
+
"id",
|
|
231
|
+
"memory",
|
|
232
|
+
"user_name",
|
|
233
|
+
"user_id",
|
|
234
|
+
"session_id",
|
|
235
|
+
"status",
|
|
236
|
+
"key",
|
|
237
|
+
"confidence",
|
|
238
|
+
"tags",
|
|
239
|
+
"created_at",
|
|
240
|
+
"updated_at",
|
|
241
|
+
"memory_type",
|
|
242
|
+
"sources",
|
|
243
|
+
"source",
|
|
244
|
+
"node_type",
|
|
245
|
+
"visibility",
|
|
246
|
+
"usage",
|
|
247
|
+
"background",
|
|
248
|
+
}
|
|
249
|
+
self.dim_field = (
|
|
250
|
+
f"embedding_{self.embedding_dimension}"
|
|
251
|
+
if (str(self.embedding_dimension) != str(self.default_memory_dimension))
|
|
252
|
+
else "embedding"
|
|
253
|
+
)
|
|
243
254
|
self.system_db_name = "system" if config.use_multi_db else config.space
|
|
244
255
|
self.pool = SessionPool(
|
|
245
256
|
hosts=config.get("uri"),
|
|
@@ -259,15 +270,26 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
259
270
|
|
|
260
271
|
logger.info("Connected to NebulaGraph successfully.")
|
|
261
272
|
|
|
273
|
+
@timed
|
|
262
274
|
def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True):
|
|
263
275
|
with self.pool.get() as client:
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
276
|
+
try:
|
|
277
|
+
if auto_set_db and self.db_name:
|
|
278
|
+
client.execute(f"SESSION SET GRAPH `{self.db_name}`")
|
|
279
|
+
return client.execute(gql, timeout=timeout)
|
|
280
|
+
|
|
281
|
+
except Exception as e:
|
|
282
|
+
if "Session not found" in str(e) or "Connection not established" in str(e):
|
|
283
|
+
logger.warning(f"[execute_query] {e!s}, replacing client...")
|
|
284
|
+
self.pool.replace_client(client)
|
|
285
|
+
return self.execute_query(gql, timeout, auto_set_db)
|
|
286
|
+
raise
|
|
267
287
|
|
|
288
|
+
@timed
|
|
268
289
|
def close(self):
|
|
269
290
|
self.pool.close()
|
|
270
291
|
|
|
292
|
+
@timed
|
|
271
293
|
def create_index(
|
|
272
294
|
self,
|
|
273
295
|
label: str = "Memory",
|
|
@@ -280,6 +302,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
280
302
|
# Create indexes
|
|
281
303
|
self._create_basic_property_indexes()
|
|
282
304
|
|
|
305
|
+
@timed
|
|
283
306
|
def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None:
|
|
284
307
|
"""
|
|
285
308
|
Remove all WorkingMemory nodes except the latest `keep_latest` entries.
|
|
@@ -302,6 +325,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
302
325
|
"""
|
|
303
326
|
self.execute_query(query)
|
|
304
327
|
|
|
328
|
+
@timed
|
|
305
329
|
def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
|
|
306
330
|
"""
|
|
307
331
|
Insert or update a Memory node in NebulaGraph.
|
|
@@ -318,10 +342,14 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
318
342
|
metadata["memory"] = memory
|
|
319
343
|
|
|
320
344
|
if "embedding" in metadata and isinstance(metadata["embedding"], list):
|
|
321
|
-
metadata["embedding"]
|
|
345
|
+
assert len(metadata["embedding"]) == self.embedding_dimension, (
|
|
346
|
+
f"input embedding dimension must equal to {self.embedding_dimension}"
|
|
347
|
+
)
|
|
348
|
+
embedding = metadata.pop("embedding")
|
|
349
|
+
metadata[self.dim_field] = _normalize(embedding)
|
|
322
350
|
|
|
323
|
-
metadata = _metadata_filter(metadata)
|
|
324
|
-
properties = ", ".join(f"{k}: {_format_value(v, k)}" for k, v in metadata.items())
|
|
351
|
+
metadata = self._metadata_filter(metadata)
|
|
352
|
+
properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items())
|
|
325
353
|
gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
|
|
326
354
|
|
|
327
355
|
try:
|
|
@@ -332,16 +360,18 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
332
360
|
f"Failed to insert vertex {id}: gql: {gql}, {e}\ntrace: {traceback.format_exc()}"
|
|
333
361
|
)
|
|
334
362
|
|
|
363
|
+
@timed
|
|
335
364
|
def node_not_exist(self, scope: str) -> int:
|
|
336
365
|
if not self.config.use_multi_db and self.config.user_name:
|
|
337
366
|
filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"'
|
|
338
367
|
else:
|
|
339
368
|
filter_clause = f'n.memory_type = "{scope}"'
|
|
369
|
+
return_fields = ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
|
|
340
370
|
|
|
341
371
|
query = f"""
|
|
342
372
|
MATCH (n@Memory)
|
|
343
373
|
WHERE {filter_clause}
|
|
344
|
-
RETURN
|
|
374
|
+
RETURN {return_fields}
|
|
345
375
|
LIMIT 1
|
|
346
376
|
"""
|
|
347
377
|
|
|
@@ -352,6 +382,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
352
382
|
logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True)
|
|
353
383
|
raise
|
|
354
384
|
|
|
385
|
+
@timed
|
|
355
386
|
def update_node(self, id: str, fields: dict[str, Any]) -> None:
|
|
356
387
|
"""
|
|
357
388
|
Update node fields in Nebular, auto-converting `created_at` and `updated_at` to datetime type if present.
|
|
@@ -359,7 +390,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
359
390
|
fields = fields.copy()
|
|
360
391
|
set_clauses = []
|
|
361
392
|
for k, v in fields.items():
|
|
362
|
-
set_clauses.append(f"n.{k} = {_format_value(v, k)}")
|
|
393
|
+
set_clauses.append(f"n.{k} = {self._format_value(v, k)}")
|
|
363
394
|
|
|
364
395
|
set_clause_str = ",\n ".join(set_clauses)
|
|
365
396
|
|
|
@@ -373,6 +404,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
373
404
|
query += f"\nSET {set_clause_str}"
|
|
374
405
|
self.execute_query(query)
|
|
375
406
|
|
|
407
|
+
@timed
|
|
376
408
|
def delete_node(self, id: str) -> None:
|
|
377
409
|
"""
|
|
378
410
|
Delete a node from the graph.
|
|
@@ -384,10 +416,11 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
384
416
|
"""
|
|
385
417
|
if not self.config.use_multi_db and self.config.user_name:
|
|
386
418
|
user_name = self.config.user_name
|
|
387
|
-
query += f" WHERE n.user_name = {_format_value(user_name)}"
|
|
419
|
+
query += f" WHERE n.user_name = {self._format_value(user_name)}"
|
|
388
420
|
query += "\n DETACH DELETE n"
|
|
389
421
|
self.execute_query(query)
|
|
390
422
|
|
|
423
|
+
@timed
|
|
391
424
|
def add_edge(self, source_id: str, target_id: str, type: str):
|
|
392
425
|
"""
|
|
393
426
|
Create an edge from source node to target node.
|
|
@@ -412,6 +445,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
412
445
|
except Exception as e:
|
|
413
446
|
logger.error(f"Failed to insert edge: {e}", exc_info=True)
|
|
414
447
|
|
|
448
|
+
@timed
|
|
415
449
|
def delete_edge(self, source_id: str, target_id: str, type: str) -> None:
|
|
416
450
|
"""
|
|
417
451
|
Delete a specific edge between two nodes.
|
|
@@ -422,16 +456,17 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
422
456
|
"""
|
|
423
457
|
query = f"""
|
|
424
458
|
MATCH (a@Memory) -[r@{type}]-> (b@Memory)
|
|
425
|
-
WHERE a.id = {_format_value(source_id)} AND b.id = {_format_value(target_id)}
|
|
459
|
+
WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)}
|
|
426
460
|
"""
|
|
427
461
|
|
|
428
462
|
if not self.config.use_multi_db and self.config.user_name:
|
|
429
463
|
user_name = self.config.user_name
|
|
430
|
-
query += f" AND a.user_name = {_format_value(user_name)} AND b.user_name = {_format_value(user_name)}"
|
|
464
|
+
query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}"
|
|
431
465
|
|
|
432
466
|
query += "\nDELETE r"
|
|
433
467
|
self.execute_query(query)
|
|
434
468
|
|
|
469
|
+
@timed
|
|
435
470
|
def get_memory_count(self, memory_type: str) -> int:
|
|
436
471
|
query = f"""
|
|
437
472
|
MATCH (n@Memory)
|
|
@@ -449,6 +484,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
449
484
|
logger.error(f"[get_memory_count] Failed: {e}")
|
|
450
485
|
return -1
|
|
451
486
|
|
|
487
|
+
@timed
|
|
452
488
|
def count_nodes(self, scope: str) -> int:
|
|
453
489
|
query = f"""
|
|
454
490
|
MATCH (n@Memory)
|
|
@@ -462,6 +498,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
462
498
|
result = self.execute_query(query)
|
|
463
499
|
return result.one_or_none()["count"].value
|
|
464
500
|
|
|
501
|
+
@timed
|
|
465
502
|
def edge_exists(
|
|
466
503
|
self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING"
|
|
467
504
|
) -> bool:
|
|
@@ -503,43 +540,54 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
503
540
|
return False
|
|
504
541
|
return record.values() is not None
|
|
505
542
|
|
|
543
|
+
@timed
|
|
506
544
|
# Graph Query & Reasoning
|
|
507
|
-
def get_node(self, id: str) -> dict[str, Any] | None:
|
|
545
|
+
def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | None:
|
|
508
546
|
"""
|
|
509
547
|
Retrieve a Memory node by its unique ID.
|
|
510
548
|
|
|
511
549
|
Args:
|
|
512
550
|
id (str): Node ID (Memory.id)
|
|
551
|
+
include_embedding: with/without embedding
|
|
513
552
|
|
|
514
553
|
Returns:
|
|
515
554
|
dict: Node properties as key-value pairs, or None if not found.
|
|
516
555
|
"""
|
|
556
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
557
|
+
filter_clause = f'n.user_name = "{self.config.user_name}" AND n.id = "{id}"'
|
|
558
|
+
else:
|
|
559
|
+
filter_clause = f'n.id = "{id}"'
|
|
560
|
+
|
|
561
|
+
return_fields = self._build_return_fields(include_embedding)
|
|
517
562
|
gql = f"""
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
563
|
+
MATCH (n@Memory)
|
|
564
|
+
WHERE {filter_clause}
|
|
565
|
+
RETURN {return_fields}
|
|
566
|
+
"""
|
|
522
567
|
|
|
523
568
|
try:
|
|
524
569
|
result = self.execute_query(gql)
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
return node
|
|
570
|
+
for row in result:
|
|
571
|
+
if include_embedding:
|
|
572
|
+
props = row.values()[0].as_node().get_properties()
|
|
573
|
+
else:
|
|
574
|
+
props = {k: v.value for k, v in row.items()}
|
|
575
|
+
node = self._parse_node(props)
|
|
576
|
+
return node
|
|
533
577
|
|
|
534
578
|
except Exception as e:
|
|
535
|
-
logger.error(
|
|
579
|
+
logger.error(
|
|
580
|
+
f"[get_node] Failed to retrieve node '{id}': {e}, trace: {traceback.format_exc()}"
|
|
581
|
+
)
|
|
536
582
|
return None
|
|
537
583
|
|
|
538
|
-
|
|
584
|
+
@timed
|
|
585
|
+
def get_nodes(self, ids: list[str], include_embedding: bool = False) -> list[dict[str, Any]]:
|
|
539
586
|
"""
|
|
540
587
|
Retrieve the metadata and memory of a list of nodes.
|
|
541
588
|
Args:
|
|
542
589
|
ids: List of Node identifier.
|
|
590
|
+
include_embedding: with/without embedding
|
|
543
591
|
Returns:
|
|
544
592
|
list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.
|
|
545
593
|
|
|
@@ -554,16 +602,31 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
554
602
|
if not self.config.use_multi_db and self.config.user_name:
|
|
555
603
|
where_user = f" AND n.user_name = '{self.config.user_name}'"
|
|
556
604
|
|
|
557
|
-
|
|
605
|
+
# Safe formatting of the ID list
|
|
606
|
+
id_list = ",".join(f'"{_id}"' for _id in ids)
|
|
558
607
|
|
|
559
|
-
|
|
608
|
+
return_fields = self._build_return_fields(include_embedding)
|
|
609
|
+
query = f"""
|
|
610
|
+
MATCH (n@Memory)
|
|
611
|
+
WHERE n.id IN [{id_list}] {where_user}
|
|
612
|
+
RETURN {return_fields}
|
|
613
|
+
"""
|
|
560
614
|
nodes = []
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
615
|
+
try:
|
|
616
|
+
results = self.execute_query(query)
|
|
617
|
+
for row in results:
|
|
618
|
+
if include_embedding:
|
|
619
|
+
props = row.values()[0].as_node().get_properties()
|
|
620
|
+
else:
|
|
621
|
+
props = {k: v.value for k, v in row.items()}
|
|
622
|
+
nodes.append(self._parse_node(props))
|
|
623
|
+
except Exception as e:
|
|
624
|
+
logger.error(
|
|
625
|
+
f"[get_nodes] Failed to retrieve nodes {ids}: {e}, trace: {traceback.format_exc()}"
|
|
626
|
+
)
|
|
565
627
|
return nodes
|
|
566
628
|
|
|
629
|
+
@timed
|
|
567
630
|
def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[dict[str, str]]:
|
|
568
631
|
"""
|
|
569
632
|
Get edges connected to a node, with optional type and direction filter.
|
|
@@ -617,6 +680,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
617
680
|
)
|
|
618
681
|
return edges
|
|
619
682
|
|
|
683
|
+
@timed
|
|
620
684
|
def get_neighbors_by_tag(
|
|
621
685
|
self,
|
|
622
686
|
tags: list[str],
|
|
@@ -681,6 +745,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
681
745
|
result.append(neighbor)
|
|
682
746
|
return result
|
|
683
747
|
|
|
748
|
+
@timed
|
|
684
749
|
def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]:
|
|
685
750
|
where_user = ""
|
|
686
751
|
|
|
@@ -691,19 +756,20 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
691
756
|
query = f"""
|
|
692
757
|
MATCH (p@Memory)-[@PARENT]->(c@Memory)
|
|
693
758
|
WHERE p.id = "{id}" {where_user}
|
|
694
|
-
RETURN c.id AS id, c.
|
|
759
|
+
RETURN c.id AS id, c.{self.dim_field} AS {self.dim_field}, c.memory AS memory
|
|
695
760
|
"""
|
|
696
761
|
result = self.execute_query(query)
|
|
697
762
|
children = []
|
|
698
763
|
for row in result:
|
|
699
764
|
eid = row["id"].value # STRING
|
|
700
|
-
emb_v = row[
|
|
765
|
+
emb_v = row[self.dim_field].value # NVector
|
|
701
766
|
emb = list(emb_v.values) if emb_v else []
|
|
702
767
|
mem = row["memory"].value # STRING
|
|
703
768
|
|
|
704
769
|
children.append({"id": eid, "embedding": emb, "memory": mem})
|
|
705
770
|
return children
|
|
706
771
|
|
|
772
|
+
@timed
|
|
707
773
|
def get_subgraph(
|
|
708
774
|
self, center_id: str, depth: int = 2, center_status: str = "activated"
|
|
709
775
|
) -> dict[str, Any]:
|
|
@@ -765,6 +831,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
765
831
|
|
|
766
832
|
return {"core_node": core_node, "neighbors": neighbors, "edges": edges}
|
|
767
833
|
|
|
834
|
+
@timed
|
|
768
835
|
# Search / recall operations
|
|
769
836
|
def search_by_embedding(
|
|
770
837
|
self,
|
|
@@ -815,11 +882,11 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
815
882
|
USE `{self.db_name}`
|
|
816
883
|
MATCH (n@Memory)
|
|
817
884
|
{where_clause}
|
|
818
|
-
ORDER BY inner_product(n.
|
|
885
|
+
ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC
|
|
819
886
|
APPROXIMATE
|
|
820
887
|
LIMIT {top_k}
|
|
821
888
|
OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }}
|
|
822
|
-
RETURN n.id AS id, inner_product(n.
|
|
889
|
+
RETURN n.id AS id, inner_product(n.{self.dim_field}, {gql_vector}) AS score
|
|
823
890
|
"""
|
|
824
891
|
|
|
825
892
|
try:
|
|
@@ -842,6 +909,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
842
909
|
logger.error(f"[search_by_embedding] Result parse failed: {e}")
|
|
843
910
|
return []
|
|
844
911
|
|
|
912
|
+
@timed
|
|
845
913
|
def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:
|
|
846
914
|
"""
|
|
847
915
|
1. ADD logic: "AND" vs "OR"(support logic combination);
|
|
@@ -912,6 +980,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
912
980
|
logger.error(f"Failed to get metadata: {e}, gql is {gql}")
|
|
913
981
|
return ids
|
|
914
982
|
|
|
983
|
+
@timed
|
|
915
984
|
def get_grouped_counts(
|
|
916
985
|
self,
|
|
917
986
|
group_fields: list[str],
|
|
@@ -980,6 +1049,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
980
1049
|
|
|
981
1050
|
return output
|
|
982
1051
|
|
|
1052
|
+
@timed
|
|
983
1053
|
def clear(self) -> None:
|
|
984
1054
|
"""
|
|
985
1055
|
Clear the entire graph if the target database exists.
|
|
@@ -996,9 +1066,12 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
996
1066
|
except Exception as e:
|
|
997
1067
|
logger.error(f"[ERROR] Failed to clear database: {e}")
|
|
998
1068
|
|
|
999
|
-
|
|
1069
|
+
@timed
|
|
1070
|
+
def export_graph(self, include_embedding: bool = False) -> dict[str, Any]:
|
|
1000
1071
|
"""
|
|
1001
1072
|
Export all graph nodes and edges in a structured form.
|
|
1073
|
+
Args:
|
|
1074
|
+
include_embedding (bool): Whether to include the large embedding field.
|
|
1002
1075
|
|
|
1003
1076
|
Returns:
|
|
1004
1077
|
{
|
|
@@ -1015,13 +1088,41 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1015
1088
|
edge_query += f' WHERE r.user_name = "{username}"'
|
|
1016
1089
|
|
|
1017
1090
|
try:
|
|
1018
|
-
|
|
1019
|
-
|
|
1091
|
+
if include_embedding:
|
|
1092
|
+
return_fields = "n"
|
|
1093
|
+
else:
|
|
1094
|
+
return_fields = ",".join(
|
|
1095
|
+
[
|
|
1096
|
+
"n.id AS id",
|
|
1097
|
+
"n.memory AS memory",
|
|
1098
|
+
"n.user_name AS user_name",
|
|
1099
|
+
"n.user_id AS user_id",
|
|
1100
|
+
"n.session_id AS session_id",
|
|
1101
|
+
"n.status AS status",
|
|
1102
|
+
"n.key AS key",
|
|
1103
|
+
"n.confidence AS confidence",
|
|
1104
|
+
"n.tags AS tags",
|
|
1105
|
+
"n.created_at AS created_at",
|
|
1106
|
+
"n.updated_at AS updated_at",
|
|
1107
|
+
"n.memory_type AS memory_type",
|
|
1108
|
+
"n.sources AS sources",
|
|
1109
|
+
"n.source AS source",
|
|
1110
|
+
"n.node_type AS node_type",
|
|
1111
|
+
"n.visibility AS visibility",
|
|
1112
|
+
"n.usage AS usage",
|
|
1113
|
+
"n.background AS background",
|
|
1114
|
+
]
|
|
1115
|
+
)
|
|
1116
|
+
|
|
1117
|
+
full_node_query = f"{node_query} RETURN {return_fields}"
|
|
1118
|
+
node_result = self.execute_query(full_node_query, timeout=20)
|
|
1020
1119
|
nodes = []
|
|
1120
|
+
logger.debug(f"Debugging: {node_result}")
|
|
1021
1121
|
for row in node_result:
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1122
|
+
if include_embedding:
|
|
1123
|
+
props = row.values()[0].as_node().get_properties()
|
|
1124
|
+
else:
|
|
1125
|
+
props = {k: v.value for k, v in row.items()}
|
|
1025
1126
|
node = self._parse_node(props)
|
|
1026
1127
|
nodes.append(node)
|
|
1027
1128
|
except Exception as e:
|
|
@@ -1029,7 +1130,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1029
1130
|
|
|
1030
1131
|
try:
|
|
1031
1132
|
full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) as edge"
|
|
1032
|
-
edge_result = self.execute_query(full_edge_query)
|
|
1133
|
+
edge_result = self.execute_query(full_edge_query, timeout=20)
|
|
1033
1134
|
edges = [
|
|
1034
1135
|
{
|
|
1035
1136
|
"source": row.values()[0].value,
|
|
@@ -1043,6 +1144,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1043
1144
|
|
|
1044
1145
|
return {"nodes": nodes, "edges": edges}
|
|
1045
1146
|
|
|
1147
|
+
@timed
|
|
1046
1148
|
def import_graph(self, data: dict[str, Any]) -> None:
|
|
1047
1149
|
"""
|
|
1048
1150
|
Import the entire graph from a serialized dictionary.
|
|
@@ -1056,9 +1158,9 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1056
1158
|
if not self.config.use_multi_db and self.config.user_name:
|
|
1057
1159
|
metadata["user_name"] = self.config.user_name
|
|
1058
1160
|
|
|
1059
|
-
metadata = _prepare_node_metadata(metadata)
|
|
1161
|
+
metadata = self._prepare_node_metadata(metadata)
|
|
1060
1162
|
metadata.update({"id": id, "memory": memory})
|
|
1061
|
-
properties = ", ".join(f"{k}: {_format_value(v, k)}" for k, v in metadata.items())
|
|
1163
|
+
properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items())
|
|
1062
1164
|
node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
|
|
1063
1165
|
self.execute_query(node_gql)
|
|
1064
1166
|
|
|
@@ -1074,12 +1176,14 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1074
1176
|
'''
|
|
1075
1177
|
self.execute_query(edge_gql)
|
|
1076
1178
|
|
|
1077
|
-
|
|
1179
|
+
@timed
|
|
1180
|
+
def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (list)[dict]:
|
|
1078
1181
|
"""
|
|
1079
1182
|
Retrieve all memory items of a specific memory_type.
|
|
1080
1183
|
|
|
1081
1184
|
Args:
|
|
1082
1185
|
scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
|
|
1186
|
+
include_embedding: with/without embedding
|
|
1083
1187
|
|
|
1084
1188
|
Returns:
|
|
1085
1189
|
list[dict]: Full list of memory items under this scope.
|
|
@@ -1092,22 +1196,31 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1092
1196
|
if not self.config.use_multi_db and self.config.user_name:
|
|
1093
1197
|
where_clause += f" AND n.user_name = '{self.config.user_name}'"
|
|
1094
1198
|
|
|
1199
|
+
return_fields = self._build_return_fields(include_embedding)
|
|
1200
|
+
|
|
1095
1201
|
query = f"""
|
|
1096
1202
|
MATCH (n@Memory)
|
|
1097
1203
|
{where_clause}
|
|
1098
|
-
RETURN
|
|
1204
|
+
RETURN {return_fields}
|
|
1205
|
+
LIMIT 100
|
|
1099
1206
|
"""
|
|
1100
1207
|
nodes = []
|
|
1101
1208
|
try:
|
|
1102
1209
|
results = self.execute_query(query)
|
|
1103
|
-
for
|
|
1104
|
-
|
|
1105
|
-
|
|
1210
|
+
for row in results:
|
|
1211
|
+
if include_embedding:
|
|
1212
|
+
props = row.values()[0].as_node().get_properties()
|
|
1213
|
+
else:
|
|
1214
|
+
props = {k: v.value for k, v in row.items()}
|
|
1215
|
+
nodes.append(self._parse_node(props))
|
|
1106
1216
|
except Exception as e:
|
|
1107
1217
|
logger.error(f"Failed to get memories: {e}")
|
|
1108
1218
|
return nodes
|
|
1109
1219
|
|
|
1110
|
-
|
|
1220
|
+
@timed
|
|
1221
|
+
def get_structure_optimization_candidates(
|
|
1222
|
+
self, scope: str, include_embedding: bool = False
|
|
1223
|
+
) -> list[dict]:
|
|
1111
1224
|
"""
|
|
1112
1225
|
Find nodes that are likely candidates for structure optimization:
|
|
1113
1226
|
- Isolated nodes, nodes with empty background, or nodes with exactly one child.
|
|
@@ -1121,6 +1234,8 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1121
1234
|
if not self.config.use_multi_db and self.config.user_name:
|
|
1122
1235
|
where_clause += f' AND n.user_name = "{self.config.user_name}"'
|
|
1123
1236
|
|
|
1237
|
+
return_fields = self._build_return_fields(include_embedding)
|
|
1238
|
+
|
|
1124
1239
|
query = f"""
|
|
1125
1240
|
USE `{self.db_name}`
|
|
1126
1241
|
MATCH (n@Memory)
|
|
@@ -1128,19 +1243,23 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1128
1243
|
OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
|
|
1129
1244
|
OPTIONAL MATCH (p@Memory)-[@PARENT]->(n)
|
|
1130
1245
|
WHERE c IS NULL AND p IS NULL
|
|
1131
|
-
RETURN
|
|
1246
|
+
RETURN {return_fields}
|
|
1132
1247
|
"""
|
|
1133
1248
|
|
|
1134
1249
|
candidates = []
|
|
1135
1250
|
try:
|
|
1136
1251
|
results = self.execute_query(query)
|
|
1137
|
-
for
|
|
1138
|
-
|
|
1139
|
-
|
|
1252
|
+
for row in results:
|
|
1253
|
+
if include_embedding:
|
|
1254
|
+
props = row.values()[0].as_node().get_properties()
|
|
1255
|
+
else:
|
|
1256
|
+
props = {k: v.value for k, v in row.items()}
|
|
1257
|
+
candidates.append(self._parse_node(props))
|
|
1140
1258
|
except Exception as e:
|
|
1141
|
-
logger.error(f"Failed : {e}")
|
|
1259
|
+
logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
|
|
1142
1260
|
return candidates
|
|
1143
1261
|
|
|
1262
|
+
@timed
|
|
1144
1263
|
def drop_database(self) -> None:
|
|
1145
1264
|
"""
|
|
1146
1265
|
Permanently delete the entire database this instance is using.
|
|
@@ -1155,6 +1274,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1155
1274
|
f"Shared Database Multi-Tenant mode"
|
|
1156
1275
|
)
|
|
1157
1276
|
|
|
1277
|
+
@timed
|
|
1158
1278
|
def detect_conflicts(self) -> list[tuple[str, str]]:
|
|
1159
1279
|
"""
|
|
1160
1280
|
Detect conflicting nodes based on logical or semantic inconsistency.
|
|
@@ -1163,6 +1283,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1163
1283
|
"""
|
|
1164
1284
|
raise NotImplementedError
|
|
1165
1285
|
|
|
1286
|
+
@timed
|
|
1166
1287
|
# Structure Maintenance
|
|
1167
1288
|
def deduplicate_nodes(self) -> None:
|
|
1168
1289
|
"""
|
|
@@ -1171,6 +1292,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1171
1292
|
"""
|
|
1172
1293
|
raise NotImplementedError
|
|
1173
1294
|
|
|
1295
|
+
@timed
|
|
1174
1296
|
def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
|
|
1175
1297
|
"""
|
|
1176
1298
|
Get the ordered context chain starting from a node, following a relationship type.
|
|
@@ -1182,6 +1304,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1182
1304
|
"""
|
|
1183
1305
|
raise NotImplementedError
|
|
1184
1306
|
|
|
1307
|
+
@timed
|
|
1185
1308
|
def get_neighbors(
|
|
1186
1309
|
self, id: str, type: str, direction: Literal["in", "out", "both"] = "out"
|
|
1187
1310
|
) -> list[str]:
|
|
@@ -1196,6 +1319,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1196
1319
|
"""
|
|
1197
1320
|
raise NotImplementedError
|
|
1198
1321
|
|
|
1322
|
+
@timed
|
|
1199
1323
|
def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]:
|
|
1200
1324
|
"""
|
|
1201
1325
|
Get the path of nodes from source to target within a limited depth.
|
|
@@ -1208,6 +1332,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1208
1332
|
"""
|
|
1209
1333
|
raise NotImplementedError
|
|
1210
1334
|
|
|
1335
|
+
@timed
|
|
1211
1336
|
def merge_nodes(self, id1: str, id2: str) -> str:
|
|
1212
1337
|
"""
|
|
1213
1338
|
Merge two similar or duplicate nodes into one.
|
|
@@ -1219,70 +1344,112 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1219
1344
|
"""
|
|
1220
1345
|
raise NotImplementedError
|
|
1221
1346
|
|
|
1347
|
+
@timed
|
|
1222
1348
|
def _ensure_database_exists(self):
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1255
|
-
|
|
1349
|
+
graph_type_name = "MemOSBgeM3Type"
|
|
1350
|
+
|
|
1351
|
+
check_type_query = "SHOW GRAPH TYPES"
|
|
1352
|
+
result = self.execute_query(check_type_query, auto_set_db=False)
|
|
1353
|
+
|
|
1354
|
+
type_exists = any(row["graph_type"].as_string() == graph_type_name for row in result)
|
|
1355
|
+
|
|
1356
|
+
if not type_exists:
|
|
1357
|
+
create_tag = f"""
|
|
1358
|
+
CREATE GRAPH TYPE IF NOT EXISTS {graph_type_name} AS {{
|
|
1359
|
+
NODE Memory (:MemoryTag {{
|
|
1360
|
+
id STRING,
|
|
1361
|
+
memory STRING,
|
|
1362
|
+
user_name STRING,
|
|
1363
|
+
user_id STRING,
|
|
1364
|
+
session_id STRING,
|
|
1365
|
+
status STRING,
|
|
1366
|
+
key STRING,
|
|
1367
|
+
confidence FLOAT,
|
|
1368
|
+
tags LIST<STRING>,
|
|
1369
|
+
created_at STRING,
|
|
1370
|
+
updated_at STRING,
|
|
1371
|
+
memory_type STRING,
|
|
1372
|
+
sources LIST<STRING>,
|
|
1373
|
+
source STRING,
|
|
1374
|
+
node_type STRING,
|
|
1375
|
+
visibility STRING,
|
|
1376
|
+
usage LIST<STRING>,
|
|
1377
|
+
background STRING,
|
|
1378
|
+
{self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT>,
|
|
1379
|
+
PRIMARY KEY(id)
|
|
1380
|
+
}}),
|
|
1381
|
+
EDGE RELATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
|
|
1382
|
+
EDGE PARENT (Memory) -[{{user_name STRING}}]-> (Memory),
|
|
1383
|
+
EDGE AGGREGATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
|
|
1384
|
+
EDGE MERGED_TO (Memory) -[{{user_name STRING}}]-> (Memory),
|
|
1385
|
+
EDGE INFERS (Memory) -[{{user_name STRING}}]-> (Memory),
|
|
1386
|
+
EDGE FOLLOWS (Memory) -[{{user_name STRING}}]-> (Memory)
|
|
1387
|
+
}}
|
|
1388
|
+
"""
|
|
1389
|
+
self.execute_query(create_tag, auto_set_db=False)
|
|
1390
|
+
else:
|
|
1391
|
+
describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name};"
|
|
1392
|
+
desc_result = self.execute_query(describe_query, auto_set_db=False)
|
|
1393
|
+
|
|
1394
|
+
memory_fields = []
|
|
1395
|
+
for row in desc_result:
|
|
1396
|
+
field_name = row.values()[0].as_string()
|
|
1397
|
+
memory_fields.append(field_name)
|
|
1398
|
+
|
|
1399
|
+
if self.dim_field not in memory_fields:
|
|
1400
|
+
alter_query = f"""
|
|
1401
|
+
ALTER GRAPH TYPE {graph_type_name} {{
|
|
1402
|
+
ALTER NODE TYPE Memory ADD PROPERTIES {{ {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT> }}
|
|
1403
|
+
}}
|
|
1404
|
+
"""
|
|
1405
|
+
self.execute_query(alter_query, auto_set_db=False)
|
|
1406
|
+
logger.info(f"✅ Add new vector search {self.dim_field} to {graph_type_name}")
|
|
1407
|
+
else:
|
|
1408
|
+
logger.info(f"✅ Graph Type {graph_type_name} already include {self.dim_field}")
|
|
1409
|
+
|
|
1410
|
+
create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
|
|
1256
1411
|
set_graph_working = f"SESSION SET GRAPH `{self.db_name}`"
|
|
1257
1412
|
|
|
1258
1413
|
try:
|
|
1259
|
-
self.execute_query(create_tag, auto_set_db=False)
|
|
1260
1414
|
self.execute_query(create_graph, auto_set_db=False)
|
|
1261
1415
|
self.execute_query(set_graph_working)
|
|
1262
1416
|
logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.")
|
|
1263
1417
|
except Exception as e:
|
|
1264
1418
|
logger.error(f"❌ Failed to create tag: {e} trace: {traceback.format_exc()}")
|
|
1265
1419
|
|
|
1420
|
+
@timed
|
|
1266
1421
|
def _create_vector_index(
|
|
1267
1422
|
self, label: str, vector_property: str, dimensions: int, index_name: str
|
|
1268
1423
|
) -> None:
|
|
1269
1424
|
"""
|
|
1270
1425
|
Create a vector index for the specified property in the label.
|
|
1271
1426
|
"""
|
|
1427
|
+
if str(dimensions) == str(self.default_memory_dimension):
|
|
1428
|
+
index_name = f"idx_{vector_property}"
|
|
1429
|
+
vector_name = vector_property
|
|
1430
|
+
else:
|
|
1431
|
+
index_name = f"idx_{vector_property}_{dimensions}"
|
|
1432
|
+
vector_name = f"{vector_property}_{dimensions}"
|
|
1433
|
+
|
|
1272
1434
|
create_vector_index = f"""
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1435
|
+
CREATE VECTOR INDEX IF NOT EXISTS {index_name}
|
|
1436
|
+
ON NODE {label}::{vector_name}
|
|
1437
|
+
OPTIONS {{
|
|
1438
|
+
DIM: {dimensions},
|
|
1439
|
+
METRIC: IP,
|
|
1440
|
+
TYPE: IVF,
|
|
1441
|
+
NLIST: 100,
|
|
1442
|
+
TRAINSIZE: 1000
|
|
1443
|
+
}}
|
|
1444
|
+
FOR `{self.db_name}`
|
|
1445
|
+
"""
|
|
1284
1446
|
self.execute_query(create_vector_index)
|
|
1447
|
+
logger.info(
|
|
1448
|
+
f"✅ Ensure {label}::{vector_property} vector index {index_name} "
|
|
1449
|
+
f"exists (DIM={dimensions})"
|
|
1450
|
+
)
|
|
1285
1451
|
|
|
1452
|
+
@timed
|
|
1286
1453
|
def _create_basic_property_indexes(self) -> None:
|
|
1287
1454
|
"""
|
|
1288
1455
|
Create standard B-tree indexes on status, memory_type, created_at
|
|
@@ -1304,8 +1471,11 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1304
1471
|
self.execute_query(gql)
|
|
1305
1472
|
logger.info(f"✅ Created index: {index_name} on field {field}")
|
|
1306
1473
|
except Exception as e:
|
|
1307
|
-
logger.error(
|
|
1474
|
+
logger.error(
|
|
1475
|
+
f"❌ Failed to create index {index_name}: {e}, trace: {traceback.format_exc()}"
|
|
1476
|
+
)
|
|
1308
1477
|
|
|
1478
|
+
@timed
|
|
1309
1479
|
def _index_exists(self, index_name: str) -> bool:
|
|
1310
1480
|
"""
|
|
1311
1481
|
Check if an index with the given name exists.
|
|
@@ -1327,6 +1497,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1327
1497
|
logger.error(f"[Nebula] Failed to check index existence: {e}")
|
|
1328
1498
|
return False
|
|
1329
1499
|
|
|
1500
|
+
@timed
|
|
1330
1501
|
def _parse_value(self, value: Any) -> Any:
|
|
1331
1502
|
"""turn Nebula ValueWrapper to Python type"""
|
|
1332
1503
|
from nebulagraph_python.value_wrapper import ValueWrapper
|
|
@@ -1352,8 +1523,8 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1352
1523
|
parsed = {k: self._parse_value(v) for k, v in props.items()}
|
|
1353
1524
|
|
|
1354
1525
|
for tf in ("created_at", "updated_at"):
|
|
1355
|
-
if tf in parsed and
|
|
1356
|
-
parsed[tf] = parsed[tf]
|
|
1526
|
+
if tf in parsed and parsed[tf] is not None:
|
|
1527
|
+
parsed[tf] = _normalize_datetime(parsed[tf])
|
|
1357
1528
|
|
|
1358
1529
|
node_id = parsed.pop("id")
|
|
1359
1530
|
memory = parsed.pop("memory", "")
|
|
@@ -1361,4 +1532,81 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1361
1532
|
metadata = parsed
|
|
1362
1533
|
metadata["type"] = metadata.pop("node_type")
|
|
1363
1534
|
|
|
1535
|
+
if self.dim_field in metadata:
|
|
1536
|
+
metadata["embedding"] = metadata.pop(self.dim_field)
|
|
1537
|
+
|
|
1364
1538
|
return {"id": node_id, "memory": memory, "metadata": metadata}
|
|
1539
|
+
|
|
1540
|
+
@timed
|
|
1541
|
+
def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
|
|
1542
|
+
"""
|
|
1543
|
+
Ensure metadata has proper datetime fields and normalized types.
|
|
1544
|
+
|
|
1545
|
+
- Fill `created_at` and `updated_at` if missing (in ISO 8601 format).
|
|
1546
|
+
- Convert embedding to list of float if present.
|
|
1547
|
+
"""
|
|
1548
|
+
now = datetime.utcnow().isoformat()
|
|
1549
|
+
metadata["node_type"] = metadata.pop("type")
|
|
1550
|
+
|
|
1551
|
+
# Fill timestamps if missing
|
|
1552
|
+
metadata.setdefault("created_at", now)
|
|
1553
|
+
metadata.setdefault("updated_at", now)
|
|
1554
|
+
|
|
1555
|
+
# Normalize embedding type
|
|
1556
|
+
embedding = metadata.get("embedding")
|
|
1557
|
+
if embedding and isinstance(embedding, list):
|
|
1558
|
+
metadata[self.dim_field] = _normalize([float(x) for x in embedding])
|
|
1559
|
+
|
|
1560
|
+
return metadata
|
|
1561
|
+
|
|
1562
|
+
@timed
|
|
1563
|
+
def _format_value(self, val: Any, key: str = "") -> str:
|
|
1564
|
+
from nebulagraph_python.py_data_types import NVector
|
|
1565
|
+
|
|
1566
|
+
if isinstance(val, str):
|
|
1567
|
+
return f'"{_escape_str(val)}"'
|
|
1568
|
+
elif isinstance(val, (int | float)):
|
|
1569
|
+
return str(val)
|
|
1570
|
+
elif isinstance(val, datetime):
|
|
1571
|
+
return f'datetime("{val.isoformat()}")'
|
|
1572
|
+
elif isinstance(val, list):
|
|
1573
|
+
if key == self.dim_field:
|
|
1574
|
+
dim = len(val)
|
|
1575
|
+
joined = ",".join(str(float(x)) for x in val)
|
|
1576
|
+
return f"VECTOR<{dim}, FLOAT>([{joined}])"
|
|
1577
|
+
else:
|
|
1578
|
+
return f"[{', '.join(self._format_value(v) for v in val)}]"
|
|
1579
|
+
elif isinstance(val, NVector):
|
|
1580
|
+
if key == self.dim_field:
|
|
1581
|
+
dim = len(val)
|
|
1582
|
+
joined = ",".join(str(float(x)) for x in val)
|
|
1583
|
+
return f"VECTOR<{dim}, FLOAT>([{joined}])"
|
|
1584
|
+
elif val is None:
|
|
1585
|
+
return "NULL"
|
|
1586
|
+
else:
|
|
1587
|
+
return f'"{_escape_str(str(val))}"'
|
|
1588
|
+
|
|
1589
|
+
@timed
|
|
1590
|
+
def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]:
|
|
1591
|
+
"""
|
|
1592
|
+
Filter and validate metadata dictionary against the Memory node schema.
|
|
1593
|
+
- Removes keys not in schema.
|
|
1594
|
+
- Warns if required fields are missing.
|
|
1595
|
+
"""
|
|
1596
|
+
|
|
1597
|
+
dim_fields = {self.dim_field}
|
|
1598
|
+
|
|
1599
|
+
allowed_fields = self.common_fields | dim_fields
|
|
1600
|
+
|
|
1601
|
+
missing_fields = allowed_fields - metadata.keys()
|
|
1602
|
+
if missing_fields:
|
|
1603
|
+
logger.info(f"Metadata missing required fields: {sorted(missing_fields)}")
|
|
1604
|
+
|
|
1605
|
+
filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields}
|
|
1606
|
+
|
|
1607
|
+
return filtered_metadata
|
|
1608
|
+
|
|
1609
|
+
def _build_return_fields(self, include_embedding: bool = False) -> str:
|
|
1610
|
+
if include_embedding:
|
|
1611
|
+
return "n"
|
|
1612
|
+
return ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
|