MemoryOS 0.2.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of MemoryOS might be problematic.
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/METADATA +7 -1
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/RECORD +87 -64
- memos/__init__.py +1 -1
- memos/api/config.py +158 -69
- memos/api/context/context.py +147 -0
- memos/api/context/dependencies.py +101 -0
- memos/api/product_models.py +5 -1
- memos/api/routers/product_router.py +54 -26
- memos/configs/graph_db.py +49 -1
- memos/configs/internet_retriever.py +19 -0
- memos/configs/mem_os.py +5 -0
- memos/configs/mem_reader.py +9 -0
- memos/configs/mem_scheduler.py +54 -18
- memos/configs/mem_user.py +58 -0
- memos/graph_dbs/base.py +38 -3
- memos/graph_dbs/factory.py +2 -0
- memos/graph_dbs/nebular.py +1612 -0
- memos/graph_dbs/neo4j.py +18 -9
- memos/log.py +6 -1
- memos/mem_cube/utils.py +13 -6
- memos/mem_os/core.py +157 -37
- memos/mem_os/main.py +2 -2
- memos/mem_os/product.py +252 -201
- memos/mem_os/utils/default_config.py +1 -1
- memos/mem_os/utils/format_utils.py +281 -70
- memos/mem_os/utils/reference_utils.py +133 -0
- memos/mem_reader/simple_struct.py +13 -5
- memos/mem_scheduler/base_scheduler.py +239 -266
- memos/mem_scheduler/{modules → general_modules}/base.py +4 -5
- memos/mem_scheduler/{modules → general_modules}/dispatcher.py +57 -21
- memos/mem_scheduler/general_modules/misc.py +104 -0
- memos/mem_scheduler/{modules → general_modules}/rabbitmq_service.py +12 -10
- memos/mem_scheduler/{modules → general_modules}/redis_service.py +1 -1
- memos/mem_scheduler/general_modules/retriever.py +199 -0
- memos/mem_scheduler/general_modules/scheduler_logger.py +261 -0
- memos/mem_scheduler/general_scheduler.py +243 -80
- memos/mem_scheduler/monitors/__init__.py +0 -0
- memos/mem_scheduler/monitors/dispatcher_monitor.py +305 -0
- memos/mem_scheduler/{modules/monitor.py → monitors/general_monitor.py} +106 -57
- memos/mem_scheduler/mos_for_test_scheduler.py +23 -20
- memos/mem_scheduler/schemas/__init__.py +0 -0
- memos/mem_scheduler/schemas/general_schemas.py +44 -0
- memos/mem_scheduler/schemas/message_schemas.py +149 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +337 -0
- memos/mem_scheduler/utils/__init__.py +0 -0
- memos/mem_scheduler/utils/filter_utils.py +176 -0
- memos/mem_scheduler/utils/misc_utils.py +102 -0
- memos/mem_user/factory.py +94 -0
- memos/mem_user/mysql_persistent_user_manager.py +271 -0
- memos/mem_user/mysql_user_manager.py +500 -0
- memos/mem_user/persistent_factory.py +96 -0
- memos/mem_user/user_manager.py +4 -4
- memos/memories/activation/item.py +5 -1
- memos/memories/activation/kv.py +20 -8
- memos/memories/textual/base.py +2 -2
- memos/memories/textual/general.py +36 -92
- memos/memories/textual/item.py +5 -33
- memos/memories/textual/tree.py +13 -7
- memos/memories/textual/tree_text_memory/organize/{conflict.py → handler.py} +34 -50
- memos/memories/textual/tree_text_memory/organize/manager.py +8 -96
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +49 -43
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +107 -142
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +229 -0
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -3
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +11 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +15 -8
- memos/memories/textual/tree_text_memory/retrieve/reranker.py +1 -1
- memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +2 -0
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +191 -116
- memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +47 -15
- memos/memories/textual/tree_text_memory/retrieve/utils.py +11 -7
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +62 -58
- memos/memos_tools/dinding_report_bot.py +422 -0
- memos/memos_tools/lockfree_dict.py +120 -0
- memos/memos_tools/notification_service.py +44 -0
- memos/memos_tools/notification_utils.py +96 -0
- memos/memos_tools/thread_safe_dict.py +288 -0
- memos/settings.py +3 -1
- memos/templates/mem_reader_prompts.py +4 -1
- memos/templates/mem_scheduler_prompts.py +62 -15
- memos/templates/mos_prompts.py +116 -0
- memos/templates/tree_reorganize_prompts.py +24 -17
- memos/utils.py +19 -0
- memos/mem_scheduler/modules/misc.py +0 -39
- memos/mem_scheduler/modules/retriever.py +0 -268
- memos/mem_scheduler/modules/schemas.py +0 -328
- memos/mem_scheduler/utils.py +0 -75
- memos/memories/textual/tree_text_memory/organize/redundancy.py +0 -193
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/LICENSE +0 -0
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/WHEEL +0 -0
- {memoryos-0.2.1.dist-info → memoryos-1.0.0.dist-info}/entry_points.txt +0 -0
- /memos/mem_scheduler/{modules → general_modules}/__init__.py +0 -0
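The largest single addition is the NebulaGraph backend in memos/graph_dbs/nebular.py, whose content is reproduced below. For orientation, here is a minimal, hypothetical usage sketch; the `NebulaGraphDBConfig` fields and the metadata keys are inferred from how `NebulaGraphDB.__init__` and `add_node()` read their inputs in the code that follows, not from any documented schema, so treat every name in the sketch as an assumption.

```python
# Hypothetical sketch only -- config fields and metadata keys are inferred from
# NebulaGraphDB.__init__ and add_node() in the file reproduced below.
from memos.configs.graph_db import NebulaGraphDBConfig
from memos.graph_dbs.nebular import NebulaGraphDB

config = NebulaGraphDBConfig(
    uri=["127.0.0.1:9669"],       # handed to the NebulaClient pool as `hosts`
    user="root",
    password="nebula",
    space="memos_demo",           # becomes the working graph (self.db_name)
    user_name="alice",            # tenant filter appended to queries when use_multi_db=False
    use_multi_db=False,
    auto_create=True,             # runs _ensure_database_exists() on startup
    embedding_dimension=3072,
)

graph = NebulaGraphDB(config)

# add_node() expects a "type" key (renamed to node_type internally) and an
# embedding whose length matches config.embedding_dimension.
graph.add_node(
    id="mem-0001",
    memory="User prefers NebulaGraph for very large memory spaces.",
    metadata={
        "type": "fact",
        "memory_type": "LongTermMemory",
        "status": "activated",
        "tags": ["preference", "graph-db"],
        "embedding": [0.0] * config.embedding_dimension,  # stand-in vector
    },
)

hits = graph.search_by_embedding(
    [0.0] * config.embedding_dimension, top_k=5, scope="LongTermMemory"
)
print(hits)  # [{"id": ..., "score": ...}, ...]

graph.close()  # closes the pooled clients
```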
|
@@ -0,0 +1,1612 @@
|
|
|
1
|
+
import traceback
|
|
2
|
+
|
|
3
|
+
from contextlib import suppress
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from queue import Empty, Queue
|
|
6
|
+
from threading import Lock
|
|
7
|
+
from typing import Any, Literal
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from memos.configs.graph_db import NebulaGraphDBConfig
|
|
12
|
+
from memos.dependency import require_python_package
|
|
13
|
+
from memos.graph_dbs.base import BaseGraphDB
|
|
14
|
+
from memos.log import get_logger
|
|
15
|
+
from memos.utils import timed
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
logger = get_logger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@timed
|
|
22
|
+
def _normalize(vec: list[float]) -> list[float]:
|
|
23
|
+
v = np.asarray(vec, dtype=np.float32)
|
|
24
|
+
norm = np.linalg.norm(v)
|
|
25
|
+
return (v / (norm if norm else 1.0)).tolist()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@timed
|
|
29
|
+
def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
|
|
30
|
+
node_id = item["id"]
|
|
31
|
+
memory = item["memory"]
|
|
32
|
+
metadata = item.get("metadata", {})
|
|
33
|
+
return node_id, memory, metadata
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@timed
|
|
37
|
+
def _escape_str(value: str) -> str:
|
|
38
|
+
return value.replace('"', '\\"')
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@timed
|
|
42
|
+
def _format_datetime(value: str | datetime) -> str:
|
|
43
|
+
"""Ensure datetime is in ISO 8601 format string."""
|
|
44
|
+
if isinstance(value, datetime):
|
|
45
|
+
return value.isoformat()
|
|
46
|
+
return str(value)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@timed
|
|
50
|
+
def _normalize_datetime(val):
|
|
51
|
+
"""
|
|
52
|
+
Normalize datetime to ISO 8601 UTC string with +00:00.
|
|
53
|
+
- If val is datetime object -> keep isoformat() (Neo4j)
|
|
54
|
+
- If val is string without timezone -> append +00:00 (Nebula)
|
|
55
|
+
- Otherwise just str()
|
|
56
|
+
"""
|
|
57
|
+
if hasattr(val, "isoformat"):
|
|
58
|
+
return val.isoformat()
|
|
59
|
+
if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")):
|
|
60
|
+
return val + "+08:00"
|
|
61
|
+
return str(val)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class SessionPoolError(Exception):
|
|
65
|
+
pass
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class SessionPool:
|
|
69
|
+
@require_python_package(
|
|
70
|
+
import_name="nebulagraph_python",
|
|
71
|
+
install_command="pip install ... @Tianxing",
|
|
72
|
+
install_link=".....",
|
|
73
|
+
)
|
|
74
|
+
def __init__(
|
|
75
|
+
self,
|
|
76
|
+
hosts: list[str],
|
|
77
|
+
user: str,
|
|
78
|
+
password: str,
|
|
79
|
+
minsize: int = 1,
|
|
80
|
+
maxsize: int = 10000,
|
|
81
|
+
):
|
|
82
|
+
self.hosts = hosts
|
|
83
|
+
self.user = user
|
|
84
|
+
self.password = password
|
|
85
|
+
self.minsize = minsize
|
|
86
|
+
self.maxsize = maxsize
|
|
87
|
+
self.pool = Queue(maxsize)
|
|
88
|
+
self.lock = Lock()
|
|
89
|
+
|
|
90
|
+
self.clients = []
|
|
91
|
+
|
|
92
|
+
for _ in range(minsize):
|
|
93
|
+
self._create_and_add_client()
|
|
94
|
+
|
|
95
|
+
@timed
|
|
96
|
+
def _create_and_add_client(self):
|
|
97
|
+
from nebulagraph_python import NebulaClient
|
|
98
|
+
|
|
99
|
+
client = NebulaClient(self.hosts, self.user, self.password)
|
|
100
|
+
self.pool.put(client)
|
|
101
|
+
self.clients.append(client)
|
|
102
|
+
|
|
103
|
+
@timed
|
|
104
|
+
def get_client(self, timeout: float = 5.0):
|
|
105
|
+
try:
|
|
106
|
+
return self.pool.get(timeout=timeout)
|
|
107
|
+
except Empty:
|
|
108
|
+
with self.lock:
|
|
109
|
+
if len(self.clients) < self.maxsize:
|
|
110
|
+
from nebulagraph_python import NebulaClient
|
|
111
|
+
|
|
112
|
+
client = NebulaClient(self.hosts, self.user, self.password)
|
|
113
|
+
self.clients.append(client)
|
|
114
|
+
return client
|
|
115
|
+
raise RuntimeError("NebulaClientPool exhausted") from None
|
|
116
|
+
|
|
117
|
+
@timed
|
|
118
|
+
def return_client(self, client):
|
|
119
|
+
try:
|
|
120
|
+
client.execute("YIELD 1")
|
|
121
|
+
self.pool.put(client)
|
|
122
|
+
except Exception:
|
|
123
|
+
logger.info("[Pool] Client dead, replacing...")
|
|
124
|
+
self.replace_client(client)
|
|
125
|
+
|
|
126
|
+
@timed
|
|
127
|
+
def close(self):
|
|
128
|
+
for client in self.clients:
|
|
129
|
+
with suppress(Exception):
|
|
130
|
+
client.close()
|
|
131
|
+
self.clients.clear()
|
|
132
|
+
|
|
133
|
+
@timed
|
|
134
|
+
def get(self):
|
|
135
|
+
"""
|
|
136
|
+
Context manager: with pool.get() as client:
|
|
137
|
+
"""
|
|
138
|
+
|
|
139
|
+
class _ClientContext:
|
|
140
|
+
def __init__(self, outer):
|
|
141
|
+
self.outer = outer
|
|
142
|
+
self.client = None
|
|
143
|
+
|
|
144
|
+
def __enter__(self):
|
|
145
|
+
self.client = self.outer.get_client()
|
|
146
|
+
return self.client
|
|
147
|
+
|
|
148
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
149
|
+
if self.client:
|
|
150
|
+
self.outer.return_client(self.client)
|
|
151
|
+
|
|
152
|
+
return _ClientContext(self)
|
|
153
|
+
|
|
154
|
+
@timed
|
|
155
|
+
def reset_pool(self):
|
|
156
|
+
"""⚠️ Emergency reset: Close all clients and clear the pool."""
|
|
157
|
+
logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
|
|
158
|
+
with self.lock:
|
|
159
|
+
for client in self.clients:
|
|
160
|
+
try:
|
|
161
|
+
client.close()
|
|
162
|
+
except Exception:
|
|
163
|
+
logger.error("Fail to close!!!")
|
|
164
|
+
self.clients.clear()
|
|
165
|
+
while not self.pool.empty():
|
|
166
|
+
try:
|
|
167
|
+
self.pool.get_nowait()
|
|
168
|
+
except Empty:
|
|
169
|
+
break
|
|
170
|
+
for _ in range(self.minsize):
|
|
171
|
+
self._create_and_add_client()
|
|
172
|
+
logger.info("[Pool] Pool has been reset successfully.")
|
|
173
|
+
|
|
174
|
+
@timed
|
|
175
|
+
def replace_client(self, client):
|
|
176
|
+
try:
|
|
177
|
+
client.close()
|
|
178
|
+
except Exception:
|
|
179
|
+
logger.error("Fail to close client")
|
|
180
|
+
|
|
181
|
+
if client in self.clients:
|
|
182
|
+
self.clients.remove(client)
|
|
183
|
+
|
|
184
|
+
from nebulagraph_python import NebulaClient
|
|
185
|
+
|
|
186
|
+
new_client = NebulaClient(self.hosts, self.user, self.password)
|
|
187
|
+
self.clients.append(new_client)
|
|
188
|
+
|
|
189
|
+
self.pool.put(new_client)
|
|
190
|
+
|
|
191
|
+
logger.info("[Pool] Replaced dead client with a new one.")
|
|
192
|
+
return new_client
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class NebulaGraphDB(BaseGraphDB):
|
|
196
|
+
"""
|
|
197
|
+
NebulaGraph-based implementation of a graph memory store.
|
|
198
|
+
"""
|
|
199
|
+
|
|
200
|
+
@require_python_package(
|
|
201
|
+
import_name="nebulagraph_python",
|
|
202
|
+
install_command="pip install ... @Tianxing",
|
|
203
|
+
install_link=".....",
|
|
204
|
+
)
|
|
205
|
+
def __init__(self, config: NebulaGraphDBConfig):
|
|
206
|
+
"""
|
|
207
|
+
NebulaGraph DB client initialization.
|
|
208
|
+
|
|
209
|
+
Required config attributes:
|
|
210
|
+
- hosts: list[str] like ["host1:port", "host2:port"]
|
|
211
|
+
- user: str
|
|
212
|
+
- password: str
|
|
213
|
+
- db_name: str (optional for basic commands)
|
|
214
|
+
|
|
215
|
+
Example config:
|
|
216
|
+
{
|
|
217
|
+
"hosts": ["xxx.xx.xx.xxx:xxxx"],
|
|
218
|
+
"user": "root",
|
|
219
|
+
"password": "nebula",
|
|
220
|
+
"space": "test"
|
|
221
|
+
}
|
|
222
|
+
"""
|
|
223
|
+
|
|
224
|
+
self.config = config
|
|
225
|
+
self.db_name = config.space
|
|
226
|
+
self.user_name = config.user_name
|
|
227
|
+
self.embedding_dimension = config.embedding_dimension
|
|
228
|
+
self.default_memory_dimension = 3072
|
|
229
|
+
self.common_fields = {
|
|
230
|
+
"id",
|
|
231
|
+
"memory",
|
|
232
|
+
"user_name",
|
|
233
|
+
"user_id",
|
|
234
|
+
"session_id",
|
|
235
|
+
"status",
|
|
236
|
+
"key",
|
|
237
|
+
"confidence",
|
|
238
|
+
"tags",
|
|
239
|
+
"created_at",
|
|
240
|
+
"updated_at",
|
|
241
|
+
"memory_type",
|
|
242
|
+
"sources",
|
|
243
|
+
"source",
|
|
244
|
+
"node_type",
|
|
245
|
+
"visibility",
|
|
246
|
+
"usage",
|
|
247
|
+
"background",
|
|
248
|
+
}
|
|
249
|
+
self.dim_field = (
|
|
250
|
+
f"embedding_{self.embedding_dimension}"
|
|
251
|
+
if (str(self.embedding_dimension) != str(self.default_memory_dimension))
|
|
252
|
+
else "embedding"
|
|
253
|
+
)
|
|
254
|
+
self.system_db_name = "system" if config.use_multi_db else config.space
|
|
255
|
+
self.pool = SessionPool(
|
|
256
|
+
hosts=config.get("uri"),
|
|
257
|
+
user=config.get("user"),
|
|
258
|
+
password=config.get("password"),
|
|
259
|
+
minsize=1,
|
|
260
|
+
maxsize=config.get("max_client", 1000),
|
|
261
|
+
)
|
|
262
|
+
|
|
263
|
+
if config.auto_create:
|
|
264
|
+
self._ensure_database_exists()
|
|
265
|
+
|
|
266
|
+
self.execute_query(f"SESSION SET GRAPH `{self.db_name}`")
|
|
267
|
+
|
|
268
|
+
# Create only if not exists
|
|
269
|
+
self.create_index(dimensions=config.embedding_dimension)
|
|
270
|
+
|
|
271
|
+
logger.info("Connected to NebulaGraph successfully.")
|
|
272
|
+
|
|
273
|
+
@timed
|
|
274
|
+
def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True):
|
|
275
|
+
with self.pool.get() as client:
|
|
276
|
+
try:
|
|
277
|
+
if auto_set_db and self.db_name:
|
|
278
|
+
client.execute(f"SESSION SET GRAPH `{self.db_name}`")
|
|
279
|
+
return client.execute(gql, timeout=timeout)
|
|
280
|
+
|
|
281
|
+
except Exception as e:
|
|
282
|
+
if "Session not found" in str(e) or "Connection not established" in str(e):
|
|
283
|
+
logger.warning(f"[execute_query] {e!s}, replacing client...")
|
|
284
|
+
self.pool.replace_client(client)
|
|
285
|
+
return self.execute_query(gql, timeout, auto_set_db)
|
|
286
|
+
raise
|
|
287
|
+
|
|
288
|
+
@timed
|
|
289
|
+
def close(self):
|
|
290
|
+
self.pool.close()
|
|
291
|
+
|
|
292
|
+
@timed
|
|
293
|
+
def create_index(
|
|
294
|
+
self,
|
|
295
|
+
label: str = "Memory",
|
|
296
|
+
vector_property: str = "embedding",
|
|
297
|
+
dimensions: int = 3072,
|
|
298
|
+
index_name: str = "memory_vector_index",
|
|
299
|
+
) -> None:
|
|
300
|
+
# Create vector index
|
|
301
|
+
self._create_vector_index(label, vector_property, dimensions, index_name)
|
|
302
|
+
# Create indexes
|
|
303
|
+
self._create_basic_property_indexes()
|
|
304
|
+
|
|
305
|
+
@timed
|
|
306
|
+
def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None:
|
|
307
|
+
"""
|
|
308
|
+
Remove all WorkingMemory nodes except the latest `keep_latest` entries.
|
|
309
|
+
|
|
310
|
+
Args:
|
|
311
|
+
memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
|
|
312
|
+
keep_latest (int): Number of latest WorkingMemory entries to keep.
|
|
313
|
+
"""
|
|
314
|
+
optional_condition = ""
|
|
315
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
316
|
+
optional_condition = f"AND n.user_name = '{self.config.user_name}'"
|
|
317
|
+
|
|
318
|
+
query = f"""
|
|
319
|
+
MATCH (n@Memory)
|
|
320
|
+
WHERE n.memory_type = '{memory_type}'
|
|
321
|
+
{optional_condition}
|
|
322
|
+
ORDER BY n.updated_at DESC
|
|
323
|
+
OFFSET {keep_latest}
|
|
324
|
+
DETACH DELETE n
|
|
325
|
+
"""
|
|
326
|
+
self.execute_query(query)
|
|
327
|
+
|
|
328
|
+
@timed
|
|
329
|
+
def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
|
|
330
|
+
"""
|
|
331
|
+
Insert or update a Memory node in NebulaGraph.
|
|
332
|
+
"""
|
|
333
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
334
|
+
metadata["user_name"] = self.config.user_name
|
|
335
|
+
|
|
336
|
+
now = datetime.utcnow()
|
|
337
|
+
metadata = metadata.copy()
|
|
338
|
+
metadata.setdefault("created_at", now)
|
|
339
|
+
metadata.setdefault("updated_at", now)
|
|
340
|
+
metadata["node_type"] = metadata.pop("type")
|
|
341
|
+
metadata["id"] = id
|
|
342
|
+
metadata["memory"] = memory
|
|
343
|
+
|
|
344
|
+
if "embedding" in metadata and isinstance(metadata["embedding"], list):
|
|
345
|
+
assert len(metadata["embedding"]) == self.embedding_dimension, (
|
|
346
|
+
f"input embedding dimension must equal to {self.embedding_dimension}"
|
|
347
|
+
)
|
|
348
|
+
embedding = metadata.pop("embedding")
|
|
349
|
+
metadata[self.dim_field] = _normalize(embedding)
|
|
350
|
+
|
|
351
|
+
metadata = self._metadata_filter(metadata)
|
|
352
|
+
properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items())
|
|
353
|
+
gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
|
|
354
|
+
|
|
355
|
+
try:
|
|
356
|
+
self.execute_query(gql)
|
|
357
|
+
logger.info("insert success")
|
|
358
|
+
except Exception as e:
|
|
359
|
+
logger.error(
|
|
360
|
+
f"Failed to insert vertex {id}: gql: {gql}, {e}\ntrace: {traceback.format_exc()}"
|
|
361
|
+
)
|
|
362
|
+
|
|
363
|
+
@timed
|
|
364
|
+
def node_not_exist(self, scope: str) -> int:
|
|
365
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
366
|
+
filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"'
|
|
367
|
+
else:
|
|
368
|
+
filter_clause = f'n.memory_type = "{scope}"'
|
|
369
|
+
return_fields = ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
|
|
370
|
+
|
|
371
|
+
query = f"""
|
|
372
|
+
MATCH (n@Memory)
|
|
373
|
+
WHERE {filter_clause}
|
|
374
|
+
RETURN {return_fields}
|
|
375
|
+
LIMIT 1
|
|
376
|
+
"""
|
|
377
|
+
|
|
378
|
+
try:
|
|
379
|
+
result = self.execute_query(query)
|
|
380
|
+
return result.size == 0
|
|
381
|
+
except Exception as e:
|
|
382
|
+
logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True)
|
|
383
|
+
raise
|
|
384
|
+
|
|
385
|
+
@timed
|
|
386
|
+
def update_node(self, id: str, fields: dict[str, Any]) -> None:
|
|
387
|
+
"""
|
|
388
|
+
Update node fields in Nebular, auto-converting `created_at` and `updated_at` to datetime type if present.
|
|
389
|
+
"""
|
|
390
|
+
fields = fields.copy()
|
|
391
|
+
set_clauses = []
|
|
392
|
+
for k, v in fields.items():
|
|
393
|
+
set_clauses.append(f"n.{k} = {self._format_value(v, k)}")
|
|
394
|
+
|
|
395
|
+
set_clause_str = ",\n ".join(set_clauses)
|
|
396
|
+
|
|
397
|
+
query = f"""
|
|
398
|
+
MATCH (n@Memory {{id: "{id}"}})
|
|
399
|
+
"""
|
|
400
|
+
|
|
401
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
402
|
+
query += f'WHERE n.user_name = "{self.config.user_name}"'
|
|
403
|
+
|
|
404
|
+
query += f"\nSET {set_clause_str}"
|
|
405
|
+
self.execute_query(query)
|
|
406
|
+
|
|
407
|
+
@timed
|
|
408
|
+
def delete_node(self, id: str) -> None:
|
|
409
|
+
"""
|
|
410
|
+
Delete a node from the graph.
|
|
411
|
+
Args:
|
|
412
|
+
id: Node identifier to delete.
|
|
413
|
+
"""
|
|
414
|
+
query = f"""
|
|
415
|
+
MATCH (n@Memory {{id: "{id}"}})
|
|
416
|
+
"""
|
|
417
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
418
|
+
user_name = self.config.user_name
|
|
419
|
+
query += f" WHERE n.user_name = {self._format_value(user_name)}"
|
|
420
|
+
query += "\n DETACH DELETE n"
|
|
421
|
+
self.execute_query(query)
|
|
422
|
+
|
|
423
|
+
@timed
|
|
424
|
+
def add_edge(self, source_id: str, target_id: str, type: str):
|
|
425
|
+
"""
|
|
426
|
+
Create an edge from source node to target node.
|
|
427
|
+
Args:
|
|
428
|
+
source_id: ID of the source node.
|
|
429
|
+
target_id: ID of the target node.
|
|
430
|
+
type: Relationship type (e.g., 'RELATE_TO', 'PARENT').
|
|
431
|
+
"""
|
|
432
|
+
if not source_id or not target_id:
|
|
433
|
+
raise ValueError("[add_edge] source_id and target_id must be provided")
|
|
434
|
+
|
|
435
|
+
props = ""
|
|
436
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
437
|
+
props = f'{{user_name: "{self.config.user_name}"}}'
|
|
438
|
+
|
|
439
|
+
insert_stmt = f'''
|
|
440
|
+
MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
|
|
441
|
+
INSERT (a) -[e@{type} {props}]-> (b)
|
|
442
|
+
'''
|
|
443
|
+
try:
|
|
444
|
+
self.execute_query(insert_stmt)
|
|
445
|
+
except Exception as e:
|
|
446
|
+
logger.error(f"Failed to insert edge: {e}", exc_info=True)
|
|
447
|
+
|
|
448
|
+
@timed
|
|
449
|
+
def delete_edge(self, source_id: str, target_id: str, type: str) -> None:
|
|
450
|
+
"""
|
|
451
|
+
Delete a specific edge between two nodes.
|
|
452
|
+
Args:
|
|
453
|
+
source_id: ID of the source node.
|
|
454
|
+
target_id: ID of the target node.
|
|
455
|
+
type: Relationship type to remove.
|
|
456
|
+
"""
|
|
457
|
+
query = f"""
|
|
458
|
+
MATCH (a@Memory) -[r@{type}]-> (b@Memory)
|
|
459
|
+
WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)}
|
|
460
|
+
"""
|
|
461
|
+
|
|
462
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
463
|
+
user_name = self.config.user_name
|
|
464
|
+
query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}"
|
|
465
|
+
|
|
466
|
+
query += "\nDELETE r"
|
|
467
|
+
self.execute_query(query)
|
|
468
|
+
|
|
469
|
+
@timed
|
|
470
|
+
def get_memory_count(self, memory_type: str) -> int:
|
|
471
|
+
query = f"""
|
|
472
|
+
MATCH (n@Memory)
|
|
473
|
+
WHERE n.memory_type = "{memory_type}"
|
|
474
|
+
"""
|
|
475
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
476
|
+
user_name = self.config.user_name
|
|
477
|
+
query += f"\nAND n.user_name = '{user_name}'"
|
|
478
|
+
query += "\nRETURN COUNT(n) AS count"
|
|
479
|
+
|
|
480
|
+
try:
|
|
481
|
+
result = self.execute_query(query)
|
|
482
|
+
return result.one_or_none()["count"].value
|
|
483
|
+
except Exception as e:
|
|
484
|
+
logger.error(f"[get_memory_count] Failed: {e}")
|
|
485
|
+
return -1
|
|
486
|
+
|
|
487
|
+
@timed
|
|
488
|
+
def count_nodes(self, scope: str) -> int:
|
|
489
|
+
query = f"""
|
|
490
|
+
MATCH (n@Memory)
|
|
491
|
+
WHERE n.memory_type = "{scope}"
|
|
492
|
+
"""
|
|
493
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
494
|
+
user_name = self.config.user_name
|
|
495
|
+
query += f"\nAND n.user_name = '{user_name}'"
|
|
496
|
+
query += "\nRETURN count(n) AS count"
|
|
497
|
+
|
|
498
|
+
result = self.execute_query(query)
|
|
499
|
+
return result.one_or_none()["count"].value
|
|
500
|
+
|
|
501
|
+
@timed
|
|
502
|
+
def edge_exists(
|
|
503
|
+
self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING"
|
|
504
|
+
) -> bool:
|
|
505
|
+
"""
|
|
506
|
+
Check if an edge exists between two nodes.
|
|
507
|
+
Args:
|
|
508
|
+
source_id: ID of the source node.
|
|
509
|
+
target_id: ID of the target node.
|
|
510
|
+
type: Relationship type. Use "ANY" to match any relationship type.
|
|
511
|
+
direction: Direction of the edge.
|
|
512
|
+
Use "OUTGOING" (default), "INCOMING", or "ANY".
|
|
513
|
+
Returns:
|
|
514
|
+
True if the edge exists, otherwise False.
|
|
515
|
+
"""
|
|
516
|
+
# Prepare the relationship pattern
|
|
517
|
+
rel = "r" if type == "ANY" else f"r@{type}"
|
|
518
|
+
|
|
519
|
+
# Prepare the match pattern with direction
|
|
520
|
+
if direction == "OUTGOING":
|
|
521
|
+
pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]->(b@Memory {{id: '{target_id}'}})"
|
|
522
|
+
elif direction == "INCOMING":
|
|
523
|
+
pattern = f"(a@Memory {{id: '{source_id}'}})<-[{rel}]-(b@Memory {{id: '{target_id}'}})"
|
|
524
|
+
elif direction == "ANY":
|
|
525
|
+
pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]-(b@Memory {{id: '{target_id}'}})"
|
|
526
|
+
else:
|
|
527
|
+
raise ValueError(
|
|
528
|
+
f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'."
|
|
529
|
+
)
|
|
530
|
+
query = f"MATCH {pattern}"
|
|
531
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
532
|
+
user_name = self.config.user_name
|
|
533
|
+
query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
|
|
534
|
+
query += "\nRETURN r"
|
|
535
|
+
|
|
536
|
+
# Run the Cypher query
|
|
537
|
+
result = self.execute_query(query)
|
|
538
|
+
record = result.one_or_none()
|
|
539
|
+
if record is None:
|
|
540
|
+
return False
|
|
541
|
+
return record.values() is not None
|
|
542
|
+
|
|
543
|
+
@timed
|
|
544
|
+
# Graph Query & Reasoning
|
|
545
|
+
def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | None:
|
|
546
|
+
"""
|
|
547
|
+
Retrieve a Memory node by its unique ID.
|
|
548
|
+
|
|
549
|
+
Args:
|
|
550
|
+
id (str): Node ID (Memory.id)
|
|
551
|
+
include_embedding: with/without embedding
|
|
552
|
+
|
|
553
|
+
Returns:
|
|
554
|
+
dict: Node properties as key-value pairs, or None if not found.
|
|
555
|
+
"""
|
|
556
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
557
|
+
filter_clause = f'n.user_name = "{self.config.user_name}" AND n.id = "{id}"'
|
|
558
|
+
else:
|
|
559
|
+
filter_clause = f'n.id = "{id}"'
|
|
560
|
+
|
|
561
|
+
return_fields = self._build_return_fields(include_embedding)
|
|
562
|
+
gql = f"""
|
|
563
|
+
MATCH (n@Memory)
|
|
564
|
+
WHERE {filter_clause}
|
|
565
|
+
RETURN {return_fields}
|
|
566
|
+
"""
|
|
567
|
+
|
|
568
|
+
try:
|
|
569
|
+
result = self.execute_query(gql)
|
|
570
|
+
for row in result:
|
|
571
|
+
if include_embedding:
|
|
572
|
+
props = row.values()[0].as_node().get_properties()
|
|
573
|
+
else:
|
|
574
|
+
props = {k: v.value for k, v in row.items()}
|
|
575
|
+
node = self._parse_node(props)
|
|
576
|
+
return node
|
|
577
|
+
|
|
578
|
+
except Exception as e:
|
|
579
|
+
logger.error(
|
|
580
|
+
f"[get_node] Failed to retrieve node '{id}': {e}, trace: {traceback.format_exc()}"
|
|
581
|
+
)
|
|
582
|
+
return None
|
|
583
|
+
|
|
584
|
+
@timed
|
|
585
|
+
def get_nodes(self, ids: list[str], include_embedding: bool = False) -> list[dict[str, Any]]:
|
|
586
|
+
"""
|
|
587
|
+
Retrieve the metadata and memory of a list of nodes.
|
|
588
|
+
Args:
|
|
589
|
+
ids: List of Node identifier.
|
|
590
|
+
include_embedding: with/without embedding
|
|
591
|
+
Returns:
|
|
592
|
+
list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.
|
|
593
|
+
|
|
594
|
+
Notes:
|
|
595
|
+
- Assumes all provided IDs are valid and exist.
|
|
596
|
+
- Returns empty list if input is empty.
|
|
597
|
+
"""
|
|
598
|
+
if not ids:
|
|
599
|
+
return []
|
|
600
|
+
|
|
601
|
+
where_user = ""
|
|
602
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
603
|
+
where_user = f" AND n.user_name = '{self.config.user_name}'"
|
|
604
|
+
|
|
605
|
+
# Safe formatting of the ID list
|
|
606
|
+
id_list = ",".join(f'"{_id}"' for _id in ids)
|
|
607
|
+
|
|
608
|
+
return_fields = self._build_return_fields(include_embedding)
|
|
609
|
+
query = f"""
|
|
610
|
+
MATCH (n@Memory)
|
|
611
|
+
WHERE n.id IN [{id_list}] {where_user}
|
|
612
|
+
RETURN {return_fields}
|
|
613
|
+
"""
|
|
614
|
+
nodes = []
|
|
615
|
+
try:
|
|
616
|
+
results = self.execute_query(query)
|
|
617
|
+
for row in results:
|
|
618
|
+
if include_embedding:
|
|
619
|
+
props = row.values()[0].as_node().get_properties()
|
|
620
|
+
else:
|
|
621
|
+
props = {k: v.value for k, v in row.items()}
|
|
622
|
+
nodes.append(self._parse_node(props))
|
|
623
|
+
except Exception as e:
|
|
624
|
+
logger.error(
|
|
625
|
+
f"[get_nodes] Failed to retrieve nodes {ids}: {e}, trace: {traceback.format_exc()}"
|
|
626
|
+
)
|
|
627
|
+
return nodes
|
|
628
|
+
|
|
629
|
+
@timed
|
|
630
|
+
def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[dict[str, str]]:
|
|
631
|
+
"""
|
|
632
|
+
Get edges connected to a node, with optional type and direction filter.
|
|
633
|
+
|
|
634
|
+
Args:
|
|
635
|
+
id: Node ID to retrieve edges for.
|
|
636
|
+
type: Relationship type to match, or 'ANY' to match all.
|
|
637
|
+
direction: 'OUTGOING', 'INCOMING', or 'ANY'.
|
|
638
|
+
|
|
639
|
+
Returns:
|
|
640
|
+
List of edges:
|
|
641
|
+
[
|
|
642
|
+
{"from": "source_id", "to": "target_id", "type": "RELATE"},
|
|
643
|
+
...
|
|
644
|
+
]
|
|
645
|
+
"""
|
|
646
|
+
# Build relationship type filter
|
|
647
|
+
rel_type = "" if type == "ANY" else f"@{type}"
|
|
648
|
+
|
|
649
|
+
# Build Cypher pattern based on direction
|
|
650
|
+
if direction == "OUTGOING":
|
|
651
|
+
pattern = f"(a@Memory)-[r{rel_type}]->(b@Memory)"
|
|
652
|
+
where_clause = f"a.id = '{id}'"
|
|
653
|
+
elif direction == "INCOMING":
|
|
654
|
+
pattern = f"(a@Memory)<-[r{rel_type}]-(b@Memory)"
|
|
655
|
+
where_clause = f"a.id = '{id}'"
|
|
656
|
+
elif direction == "ANY":
|
|
657
|
+
pattern = f"(a@Memory)-[r{rel_type}]-(b@Memory)"
|
|
658
|
+
where_clause = f"a.id = '{id}' OR b.id = '{id}'"
|
|
659
|
+
else:
|
|
660
|
+
raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.")
|
|
661
|
+
|
|
662
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
663
|
+
where_clause += f" AND a.user_name = '{self.config.user_name}' AND b.user_name = '{self.config.user_name}'"
|
|
664
|
+
|
|
665
|
+
query = f"""
|
|
666
|
+
MATCH {pattern}
|
|
667
|
+
WHERE {where_clause}
|
|
668
|
+
RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
|
|
669
|
+
"""
|
|
670
|
+
|
|
671
|
+
result = self.execute_query(query)
|
|
672
|
+
edges = []
|
|
673
|
+
for record in result:
|
|
674
|
+
edges.append(
|
|
675
|
+
{
|
|
676
|
+
"from": record["from_id"].value,
|
|
677
|
+
"to": record["to_id"].value,
|
|
678
|
+
"type": record["edge_type"].value,
|
|
679
|
+
}
|
|
680
|
+
)
|
|
681
|
+
return edges
|
|
682
|
+
|
|
683
|
+
@timed
|
|
684
|
+
def get_neighbors_by_tag(
|
|
685
|
+
self,
|
|
686
|
+
tags: list[str],
|
|
687
|
+
exclude_ids: list[str],
|
|
688
|
+
top_k: int = 5,
|
|
689
|
+
min_overlap: int = 1,
|
|
690
|
+
) -> list[dict[str, Any]]:
|
|
691
|
+
"""
|
|
692
|
+
Find top-K neighbor nodes with maximum tag overlap.
|
|
693
|
+
|
|
694
|
+
Args:
|
|
695
|
+
tags: The list of tags to match.
|
|
696
|
+
exclude_ids: Node IDs to exclude (e.g., local cluster).
|
|
697
|
+
top_k: Max number of neighbors to return.
|
|
698
|
+
min_overlap: Minimum number of overlapping tags required.
|
|
699
|
+
|
|
700
|
+
Returns:
|
|
701
|
+
List of dicts with node details and overlap count.
|
|
702
|
+
"""
|
|
703
|
+
if not tags:
|
|
704
|
+
return []
|
|
705
|
+
|
|
706
|
+
where_clauses = [
|
|
707
|
+
'n.status = "activated"',
|
|
708
|
+
'NOT (n.node_type = "reasoning")',
|
|
709
|
+
'NOT (n.memory_type = "WorkingMemory")',
|
|
710
|
+
]
|
|
711
|
+
if exclude_ids:
|
|
712
|
+
where_clauses.append(f"NOT (n.id IN {exclude_ids})")
|
|
713
|
+
|
|
714
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
715
|
+
where_clauses.append(f'n.user_name = "{self.config.user_name}"')
|
|
716
|
+
|
|
717
|
+
where_clause = " AND ".join(where_clauses)
|
|
718
|
+
tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"
|
|
719
|
+
|
|
720
|
+
query = f"""
|
|
721
|
+
LET tag_list = {tag_list_literal}
|
|
722
|
+
|
|
723
|
+
MATCH (n@Memory)
|
|
724
|
+
WHERE {where_clause}
|
|
725
|
+
RETURN n,
|
|
726
|
+
size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
|
|
727
|
+
ORDER BY overlap_count DESC
|
|
728
|
+
LIMIT {top_k}
|
|
729
|
+
"""
|
|
730
|
+
|
|
731
|
+
result = self.execute_query(query)
|
|
732
|
+
neighbors: list[dict[str, Any]] = []
|
|
733
|
+
for r in result:
|
|
734
|
+
node_props = r["n"].as_node().get_properties()
|
|
735
|
+
parsed = self._parse_node(node_props) # --> {id, memory, metadata}
|
|
736
|
+
|
|
737
|
+
parsed["overlap_count"] = r["overlap_count"].value
|
|
738
|
+
neighbors.append(parsed)
|
|
739
|
+
|
|
740
|
+
neighbors.sort(key=lambda x: x["overlap_count"], reverse=True)
|
|
741
|
+
neighbors = neighbors[:top_k]
|
|
742
|
+
result = []
|
|
743
|
+
for neighbor in neighbors[:top_k]:
|
|
744
|
+
neighbor.pop("overlap_count")
|
|
745
|
+
result.append(neighbor)
|
|
746
|
+
return result
|
|
747
|
+
|
|
748
|
+
@timed
|
|
749
|
+
def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]:
|
|
750
|
+
where_user = ""
|
|
751
|
+
|
|
752
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
753
|
+
user_name = self.config.user_name
|
|
754
|
+
where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'"
|
|
755
|
+
|
|
756
|
+
query = f"""
|
|
757
|
+
MATCH (p@Memory)-[@PARENT]->(c@Memory)
|
|
758
|
+
WHERE p.id = "{id}" {where_user}
|
|
759
|
+
RETURN c.id AS id, c.{self.dim_field} AS {self.dim_field}, c.memory AS memory
|
|
760
|
+
"""
|
|
761
|
+
result = self.execute_query(query)
|
|
762
|
+
children = []
|
|
763
|
+
for row in result:
|
|
764
|
+
eid = row["id"].value # STRING
|
|
765
|
+
emb_v = row[self.dim_field].value # NVector
|
|
766
|
+
emb = list(emb_v.values) if emb_v else []
|
|
767
|
+
mem = row["memory"].value # STRING
|
|
768
|
+
|
|
769
|
+
children.append({"id": eid, "embedding": emb, "memory": mem})
|
|
770
|
+
return children
|
|
771
|
+
|
|
772
|
+
@timed
|
|
773
|
+
def get_subgraph(
|
|
774
|
+
self, center_id: str, depth: int = 2, center_status: str = "activated"
|
|
775
|
+
) -> dict[str, Any]:
|
|
776
|
+
"""
|
|
777
|
+
Retrieve a local subgraph centered at a given node.
|
|
778
|
+
Args:
|
|
779
|
+
center_id: The ID of the center node.
|
|
780
|
+
depth: The hop distance for neighbors.
|
|
781
|
+
center_status: Required status for center node.
|
|
782
|
+
Returns:
|
|
783
|
+
{
|
|
784
|
+
"core_node": {...},
|
|
785
|
+
"neighbors": [...],
|
|
786
|
+
"edges": [...]
|
|
787
|
+
}
|
|
788
|
+
"""
|
|
789
|
+
if not 1 <= depth <= 5:
|
|
790
|
+
raise ValueError("depth must be 1-5")
|
|
791
|
+
|
|
792
|
+
user_name = self.config.user_name
|
|
793
|
+
gql = f"""
|
|
794
|
+
MATCH (center@Memory)
|
|
795
|
+
WHERE center.id = '{center_id}'
|
|
796
|
+
AND center.status = '{center_status}'
|
|
797
|
+
AND center.user_name = '{user_name}'
|
|
798
|
+
OPTIONAL MATCH p = (center)-[e]->{{1,{depth}}}(neighbor@Memory)
|
|
799
|
+
WHERE neighbor.user_name = '{user_name}'
|
|
800
|
+
RETURN center,
|
|
801
|
+
collect(DISTINCT neighbor) AS neighbors,
|
|
802
|
+
collect(EDGES(p)) AS edge_chains
|
|
803
|
+
"""
|
|
804
|
+
|
|
805
|
+
result = self.execute_query(gql).one_or_none()
|
|
806
|
+
if not result or result.size == 0:
|
|
807
|
+
return {"core_node": None, "neighbors": [], "edges": []}
|
|
808
|
+
|
|
809
|
+
core_node_props = result["center"].as_node().get_properties()
|
|
810
|
+
core_node = self._parse_node(core_node_props)
|
|
811
|
+
neighbors = []
|
|
812
|
+
vid_to_id_map = {result["center"].as_node().node_id: core_node["id"]}
|
|
813
|
+
for n in result["neighbors"].value:
|
|
814
|
+
n_node = n.as_node()
|
|
815
|
+
n_props = n_node.get_properties()
|
|
816
|
+
node_parsed = self._parse_node(n_props)
|
|
817
|
+
neighbors.append(node_parsed)
|
|
818
|
+
vid_to_id_map[n_node.node_id] = node_parsed["id"]
|
|
819
|
+
|
|
820
|
+
edges = []
|
|
821
|
+
for chain_group in result["edge_chains"].value:
|
|
822
|
+
for edge_wr in chain_group.value:
|
|
823
|
+
edge = edge_wr.value
|
|
824
|
+
edges.append(
|
|
825
|
+
{
|
|
826
|
+
"type": edge.get_type(),
|
|
827
|
+
"source": vid_to_id_map.get(edge.get_src_id()),
|
|
828
|
+
"target": vid_to_id_map.get(edge.get_dst_id()),
|
|
829
|
+
}
|
|
830
|
+
)
|
|
831
|
+
|
|
832
|
+
return {"core_node": core_node, "neighbors": neighbors, "edges": edges}
|
|
833
|
+
|
|
834
|
+
@timed
|
|
835
|
+
# Search / recall operations
|
|
836
|
+
def search_by_embedding(
|
|
837
|
+
self,
|
|
838
|
+
vector: list[float],
|
|
839
|
+
top_k: int = 5,
|
|
840
|
+
scope: str | None = None,
|
|
841
|
+
status: str | None = None,
|
|
842
|
+
threshold: float | None = None,
|
|
843
|
+
) -> list[dict]:
|
|
844
|
+
"""
|
|
845
|
+
Retrieve node IDs based on vector similarity.
|
|
846
|
+
|
|
847
|
+
Args:
|
|
848
|
+
vector (list[float]): The embedding vector representing query semantics.
|
|
849
|
+
top_k (int): Number of top similar nodes to retrieve.
|
|
850
|
+
scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
|
|
851
|
+
status (str, optional): Node status filter (e.g., 'active', 'archived').
|
|
852
|
+
If provided, restricts results to nodes with matching status.
|
|
853
|
+
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
|
|
854
|
+
|
|
855
|
+
Returns:
|
|
856
|
+
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
|
|
857
|
+
|
|
858
|
+
Notes:
|
|
859
|
+
- This method uses Neo4j native vector indexing to search for similar nodes.
|
|
860
|
+
- If scope is provided, it restricts results to nodes with matching memory_type.
|
|
861
|
+
- If 'status' is provided, only nodes with the matching status will be returned.
|
|
862
|
+
- If threshold is provided, only results with score >= threshold will be returned.
|
|
863
|
+
- Typical use case: restrict to 'status = activated' to avoid
|
|
864
|
+
matching archived or merged nodes.
|
|
865
|
+
"""
|
|
866
|
+
vector = _normalize(vector)
|
|
867
|
+
dim = len(vector)
|
|
868
|
+
vector_str = ",".join(f"{float(x)}" for x in vector)
|
|
869
|
+
gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])"
|
|
870
|
+
|
|
871
|
+
where_clauses = []
|
|
872
|
+
if scope:
|
|
873
|
+
where_clauses.append(f'n.memory_type = "{scope}"')
|
|
874
|
+
if status:
|
|
875
|
+
where_clauses.append(f'n.status = "{status}"')
|
|
876
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
877
|
+
where_clauses.append(f'n.user_name = "{self.config.user_name}"')
|
|
878
|
+
|
|
879
|
+
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
|
|
880
|
+
|
|
881
|
+
gql = f"""
|
|
882
|
+
USE `{self.db_name}`
|
|
883
|
+
MATCH (n@Memory)
|
|
884
|
+
{where_clause}
|
|
885
|
+
ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC
|
|
886
|
+
APPROXIMATE
|
|
887
|
+
LIMIT {top_k}
|
|
888
|
+
OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }}
|
|
889
|
+
RETURN n.id AS id, inner_product(n.{self.dim_field}, {gql_vector}) AS score
|
|
890
|
+
"""
|
|
891
|
+
|
|
892
|
+
try:
|
|
893
|
+
result = self.execute_query(gql)
|
|
894
|
+
except Exception as e:
|
|
895
|
+
logger.error(f"[search_by_embedding] Query failed: {e}")
|
|
896
|
+
return []
|
|
897
|
+
|
|
898
|
+
try:
|
|
899
|
+
output = []
|
|
900
|
+
for row in result:
|
|
901
|
+
values = row.values()
|
|
902
|
+
id_val = values[0].as_string()
|
|
903
|
+
score_val = values[1].as_double()
|
|
904
|
+
score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score
|
|
905
|
+
if threshold is None or score_val <= threshold:
|
|
906
|
+
output.append({"id": id_val, "score": score_val})
|
|
907
|
+
return output
|
|
908
|
+
except Exception as e:
|
|
909
|
+
logger.error(f"[search_by_embedding] Result parse failed: {e}")
|
|
910
|
+
return []
|
|
911
|
+
|
|
912
|
+
@timed
|
|
913
|
+
def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:
|
|
914
|
+
"""
|
|
915
|
+
1. ADD logic: "AND" vs "OR"(support logic combination);
|
|
916
|
+
2. Support nested conditional expressions;
|
|
917
|
+
|
|
918
|
+
Retrieve node IDs that match given metadata filters.
|
|
919
|
+
Supports exact match.
|
|
920
|
+
|
|
921
|
+
Args:
|
|
922
|
+
filters: List of filter dicts like:
|
|
923
|
+
[
|
|
924
|
+
{"field": "key", "op": "in", "value": ["A", "B"]},
|
|
925
|
+
{"field": "confidence", "op": ">=", "value": 80},
|
|
926
|
+
{"field": "tags", "op": "contains", "value": "AI"},
|
|
927
|
+
...
|
|
928
|
+
]
|
|
929
|
+
|
|
930
|
+
Returns:
|
|
931
|
+
list[str]: Node IDs whose metadata match the filter conditions. (AND logic).
|
|
932
|
+
|
|
933
|
+
Notes:
|
|
934
|
+
- Supports structured querying such as tag/category/importance/time filtering.
|
|
935
|
+
- Can be used for faceted recall or prefiltering before embedding rerank.
|
|
936
|
+
"""
|
|
937
|
+
where_clauses = []
|
|
938
|
+
|
|
939
|
+
def _escape_value(value):
|
|
940
|
+
if isinstance(value, str):
|
|
941
|
+
return f'"{value}"'
|
|
942
|
+
elif isinstance(value, list):
|
|
943
|
+
return "[" + ", ".join(_escape_value(v) for v in value) + "]"
|
|
944
|
+
else:
|
|
945
|
+
return str(value)
|
|
946
|
+
|
|
947
|
+
for _i, f in enumerate(filters):
|
|
948
|
+
field = f["field"]
|
|
949
|
+
op = f.get("op", "=")
|
|
950
|
+
value = f["value"]
|
|
951
|
+
|
|
952
|
+
escaped_value = _escape_value(value)
|
|
953
|
+
|
|
954
|
+
# Build WHERE clause
|
|
955
|
+
if op == "=":
|
|
956
|
+
where_clauses.append(f"n.{field} = {escaped_value}")
|
|
957
|
+
elif op == "in":
|
|
958
|
+
where_clauses.append(f"n.{field} IN {escaped_value}")
|
|
959
|
+
elif op == "contains":
|
|
960
|
+
where_clauses.append(f"size(filter(n.{field}, t -> t IN {escaped_value})) > 0")
|
|
961
|
+
elif op == "starts_with":
|
|
962
|
+
where_clauses.append(f"n.{field} STARTS WITH {escaped_value}")
|
|
963
|
+
elif op == "ends_with":
|
|
964
|
+
where_clauses.append(f"n.{field} ENDS WITH {escaped_value}")
|
|
965
|
+
elif op in [">", ">=", "<", "<="]:
|
|
966
|
+
where_clauses.append(f"n.{field} {op} {escaped_value}")
|
|
967
|
+
else:
|
|
968
|
+
raise ValueError(f"Unsupported operator: {op}")
|
|
969
|
+
|
|
970
|
+
if not self.config.use_multi_db and self.user_name:
|
|
971
|
+
where_clauses.append(f'n.user_name = "{self.config.user_name}"')
|
|
972
|
+
|
|
973
|
+
where_str = " AND ".join(where_clauses)
|
|
974
|
+
gql = f"MATCH (n@Memory) WHERE {where_str} RETURN n.id AS id"
|
|
975
|
+
ids = []
|
|
976
|
+
try:
|
|
977
|
+
result = self.execute_query(gql)
|
|
978
|
+
ids = [record["id"].value for record in result]
|
|
979
|
+
except Exception as e:
|
|
980
|
+
logger.error(f"Failed to get metadata: {e}, gql is {gql}")
|
|
981
|
+
return ids
|
|
982
|
+
|
|
983
|
+
@timed
|
|
984
|
+
def get_grouped_counts(
|
|
985
|
+
self,
|
|
986
|
+
group_fields: list[str],
|
|
987
|
+
where_clause: str = "",
|
|
988
|
+
params: dict[str, Any] | None = None,
|
|
989
|
+
) -> list[dict[str, Any]]:
|
|
990
|
+
"""
|
|
991
|
+
Count nodes grouped by any fields.
|
|
992
|
+
|
|
993
|
+
Args:
|
|
994
|
+
group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"]
|
|
995
|
+
where_clause (str, optional): Extra WHERE condition. E.g.,
|
|
996
|
+
"WHERE n.status = 'activated'"
|
|
997
|
+
params (dict, optional): Parameters for WHERE clause.
|
|
998
|
+
|
|
999
|
+
Returns:
|
|
1000
|
+
list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...]
|
|
1001
|
+
"""
|
|
1002
|
+
if not group_fields:
|
|
1003
|
+
raise ValueError("group_fields cannot be empty")
|
|
1004
|
+
|
|
1005
|
+
# GQL-specific modifications
|
|
1006
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
1007
|
+
user_clause = f"n.user_name = '{self.config.user_name}'"
|
|
1008
|
+
if where_clause:
|
|
1009
|
+
where_clause = where_clause.strip()
|
|
1010
|
+
if where_clause.upper().startswith("WHERE"):
|
|
1011
|
+
where_clause += f" AND {user_clause}"
|
|
1012
|
+
else:
|
|
1013
|
+
where_clause = f"WHERE {where_clause} AND {user_clause}"
|
|
1014
|
+
else:
|
|
1015
|
+
where_clause = f"WHERE {user_clause}"
|
|
1016
|
+
|
|
1017
|
+
# Inline parameters if provided
|
|
1018
|
+
if params:
|
|
1019
|
+
for key, value in params.items():
|
|
1020
|
+
# Handle different value types appropriately
|
|
1021
|
+
if isinstance(value, str):
|
|
1022
|
+
value = f"'{value}'"
|
|
1023
|
+
where_clause = where_clause.replace(f"${key}", str(value))
|
|
1024
|
+
|
|
1025
|
+
return_fields = []
|
|
1026
|
+
group_by_fields = []
|
|
1027
|
+
|
|
1028
|
+
for field in group_fields:
|
|
1029
|
+
alias = field.replace(".", "_")
|
|
1030
|
+
return_fields.append(f"n.{field} AS {alias}")
|
|
1031
|
+
group_by_fields.append(alias)
|
|
1032
|
+
# Full GQL query construction
|
|
1033
|
+
gql = f"""
|
|
1034
|
+
MATCH (n)
|
|
1035
|
+
{where_clause}
|
|
1036
|
+
RETURN {", ".join(return_fields)}, COUNT(n) AS count
|
|
1037
|
+
GROUP BY {", ".join(group_by_fields)}
|
|
1038
|
+
"""
|
|
1039
|
+
result = self.execute_query(gql) # Pure GQL string execution
|
|
1040
|
+
|
|
1041
|
+
output = []
|
|
1042
|
+
for record in result:
|
|
1043
|
+
group_values = {}
|
|
1044
|
+
for i, field in enumerate(group_fields):
|
|
1045
|
+
value = record.values()[i].as_string()
|
|
1046
|
+
group_values[field] = value
|
|
1047
|
+
count_value = record["count"].value
|
|
1048
|
+
output.append({**group_values, "count": count_value})
|
|
1049
|
+
|
|
1050
|
+
return output
|
|
1051
|
+
|
|
1052
|
+
@timed
|
|
1053
|
+
def clear(self) -> None:
|
|
1054
|
+
"""
|
|
1055
|
+
Clear the entire graph if the target database exists.
|
|
1056
|
+
"""
|
|
1057
|
+
try:
|
|
1058
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
1059
|
+
query = f"MATCH (n@Memory) WHERE n.user_name = '{self.config.user_name}' DETACH DELETE n"
|
|
1060
|
+
else:
|
|
1061
|
+
query = "MATCH (n) DETACH DELETE n"
|
|
1062
|
+
|
|
1063
|
+
self.execute_query(query)
|
|
1064
|
+
logger.info("Cleared all nodes from database.")
|
|
1065
|
+
|
|
1066
|
+
except Exception as e:
|
|
1067
|
+
logger.error(f"[ERROR] Failed to clear database: {e}")
|
|
1068
|
+
|
|
1069
|
+
@timed
|
|
1070
|
+
def export_graph(self, include_embedding: bool = False) -> dict[str, Any]:
|
|
1071
|
+
"""
|
|
1072
|
+
Export all graph nodes and edges in a structured form.
|
|
1073
|
+
Args:
|
|
1074
|
+
include_embedding (bool): Whether to include the large embedding field.
|
|
1075
|
+
|
|
1076
|
+
Returns:
|
|
1077
|
+
{
|
|
1078
|
+
"nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
|
|
1079
|
+
"edges": [ { "source": ..., "target": ..., "type": ... }, ... ]
|
|
1080
|
+
}
|
|
1081
|
+
"""
|
|
1082
|
+
node_query = "MATCH (n@Memory)"
|
|
1083
|
+
edge_query = "MATCH (a@Memory)-[r]->(b@Memory)"
|
|
1084
|
+
|
|
1085
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
1086
|
+
username = self.config.user_name
|
|
1087
|
+
node_query += f' WHERE n.user_name = "{username}"'
|
|
1088
|
+
edge_query += f' WHERE r.user_name = "{username}"'
|
|
1089
|
+
|
|
1090
|
+
try:
|
|
1091
|
+
if include_embedding:
|
|
1092
|
+
return_fields = "n"
|
|
1093
|
+
else:
|
|
1094
|
+
return_fields = ",".join(
|
|
1095
|
+
[
|
|
1096
|
+
"n.id AS id",
|
|
1097
|
+
"n.memory AS memory",
|
|
1098
|
+
"n.user_name AS user_name",
|
|
1099
|
+
"n.user_id AS user_id",
|
|
1100
|
+
"n.session_id AS session_id",
|
|
1101
|
+
"n.status AS status",
|
|
1102
|
+
"n.key AS key",
|
|
1103
|
+
"n.confidence AS confidence",
|
|
1104
|
+
"n.tags AS tags",
|
|
1105
|
+
"n.created_at AS created_at",
|
|
1106
|
+
"n.updated_at AS updated_at",
|
|
1107
|
+
"n.memory_type AS memory_type",
|
|
1108
|
+
"n.sources AS sources",
|
|
1109
|
+
"n.source AS source",
|
|
1110
|
+
"n.node_type AS node_type",
|
|
1111
|
+
"n.visibility AS visibility",
|
|
1112
|
+
"n.usage AS usage",
|
|
1113
|
+
"n.background AS background",
|
|
1114
|
+
]
|
|
1115
|
+
)
|
|
1116
|
+
|
|
1117
|
+
full_node_query = f"{node_query} RETURN {return_fields}"
|
|
1118
|
+
node_result = self.execute_query(full_node_query, timeout=20)
|
|
1119
|
+
nodes = []
|
|
1120
|
+
logger.debug(f"Debugging: {node_result}")
|
|
1121
|
+
for row in node_result:
|
|
1122
|
+
if include_embedding:
|
|
1123
|
+
props = row.values()[0].as_node().get_properties()
|
|
1124
|
+
else:
|
|
1125
|
+
props = {k: v.value for k, v in row.items()}
|
|
1126
|
+
node = self._parse_node(props)
|
|
1127
|
+
nodes.append(node)
|
|
1128
|
+
except Exception as e:
|
|
1129
|
+
raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e
|
|
1130
|
+
|
|
1131
|
+
try:
|
|
1132
|
+
full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) as edge"
|
|
1133
|
+
edge_result = self.execute_query(full_edge_query, timeout=20)
|
|
1134
|
+
edges = [
|
|
1135
|
+
{
|
|
1136
|
+
"source": row.values()[0].value,
|
|
1137
|
+
"target": row.values()[1].value,
|
|
1138
|
+
"type": row.values()[2].value,
|
|
1139
|
+
}
|
|
1140
|
+
for row in edge_result
|
|
1141
|
+
]
|
|
1142
|
+
except Exception as e:
|
|
1143
|
+
raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e
|
|
1144
|
+
|
|
1145
|
+
return {"nodes": nodes, "edges": edges}
|
|
1146
|
+
|
|
1147
|
+
@timed
|
|
1148
|
+
def import_graph(self, data: dict[str, Any]) -> None:
|
|
1149
|
+
"""
|
|
1150
|
+
Import the entire graph from a serialized dictionary.
|
|
1151
|
+
|
|
1152
|
+
Args:
|
|
1153
|
+
data: A dictionary containing all nodes and edges to be loaded.
|
|
1154
|
+
"""
|
|
1155
|
+
for node in data.get("nodes", []):
|
|
1156
|
+
id, memory, metadata = _compose_node(node)
|
|
1157
|
+
|
|
1158
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
1159
|
+
metadata["user_name"] = self.config.user_name
|
|
1160
|
+
|
|
1161
|
+
metadata = self._prepare_node_metadata(metadata)
|
|
1162
|
+
metadata.update({"id": id, "memory": memory})
|
|
1163
|
+
properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items())
|
|
1164
|
+
node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
|
|
1165
|
+
self.execute_query(node_gql)
|
|
1166
|
+
|
|
1167
|
+
for edge in data.get("edges", []):
|
|
1168
|
+
source_id, target_id = edge["source"], edge["target"]
|
|
1169
|
+
edge_type = edge["type"]
|
|
1170
|
+
props = ""
|
|
1171
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
1172
|
+
props = f'{{user_name: "{self.config.user_name}"}}'
|
|
1173
|
+
edge_gql = f'''
|
|
1174
|
+
MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
|
|
1175
|
+
INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
|
|
1176
|
+
'''
|
|
1177
|
+
self.execute_query(edge_gql)
|
|
1178
|
+
|
|
1179
|
+
@timed
|
|
1180
|
+
def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (list)[dict]:
|
|
1181
|
+
"""
|
|
1182
|
+
Retrieve all memory items of a specific memory_type.
|
|
1183
|
+
|
|
1184
|
+
Args:
|
|
1185
|
+
scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
|
|
1186
|
+
include_embedding: with/without embedding
|
|
1187
|
+
|
|
1188
|
+
Returns:
|
|
1189
|
+
list[dict]: Full list of memory items under this scope.
|
|
1190
|
+
"""
|
|
1191
|
+
if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}:
|
|
1192
|
+
raise ValueError(f"Unsupported memory type scope: {scope}")
|
|
1193
|
+
|
|
1194
|
+
where_clause = f"WHERE n.memory_type = '{scope}'"
|
|
1195
|
+
|
|
1196
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
1197
|
+
where_clause += f" AND n.user_name = '{self.config.user_name}'"
|
|
1198
|
+
|
|
1199
|
+
return_fields = self._build_return_fields(include_embedding)
|
|
1200
|
+
|
|
1201
|
+
query = f"""
|
|
1202
|
+
MATCH (n@Memory)
|
|
1203
|
+
{where_clause}
|
|
1204
|
+
RETURN {return_fields}
|
|
1205
|
+
LIMIT 100
|
|
1206
|
+
"""
|
|
1207
|
+
nodes = []
|
|
1208
|
+
try:
|
|
1209
|
+
results = self.execute_query(query)
|
|
1210
|
+
for row in results:
|
|
1211
|
+
if include_embedding:
|
|
1212
|
+
props = row.values()[0].as_node().get_properties()
|
|
1213
|
+
else:
|
|
1214
|
+
props = {k: v.value for k, v in row.items()}
|
|
1215
|
+
nodes.append(self._parse_node(props))
|
|
1216
|
+
except Exception as e:
|
|
1217
|
+
logger.error(f"Failed to get memories: {e}")
|
|
1218
|
+
return nodes
|
|
1219
|
+
|
|
1220
|
+
@timed
|
|
1221
|
+
def get_structure_optimization_candidates(
|
|
1222
|
+
self, scope: str, include_embedding: bool = False
|
|
1223
|
+
) -> list[dict]:
|
|
1224
|
+
"""
|
|
1225
|
+
Find nodes that are likely candidates for structure optimization:
|
|
1226
|
+
- Isolated nodes, nodes with empty background, or nodes with exactly one child.
|
|
1227
|
+
- Plus: the child of any parent node that has exactly one child.
|
|
1228
|
+
"""
|
|
1229
|
+
|
|
1230
|
+
where_clause = f'''
|
|
1231
|
+
n.memory_type = "{scope}"
|
|
1232
|
+
AND n.status = "activated"
|
|
1233
|
+
'''
|
|
1234
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
1235
|
+
where_clause += f' AND n.user_name = "{self.config.user_name}"'
|
|
1236
|
+
|
|
1237
|
+
return_fields = self._build_return_fields(include_embedding)
|
|
1238
|
+
|
|
1239
|
+
query = f"""
|
|
1240
|
+
USE `{self.db_name}`
|
|
1241
|
+
MATCH (n@Memory)
|
|
1242
|
+
WHERE {where_clause}
|
|
1243
|
+
OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
|
|
1244
|
+
OPTIONAL MATCH (p@Memory)-[@PARENT]->(n)
|
|
1245
|
+
WHERE c IS NULL AND p IS NULL
|
|
1246
|
+
RETURN {return_fields}
|
|
1247
|
+
"""
|
|
1248
|
+
|
|
1249
|
+
candidates = []
|
|
1250
|
+
try:
|
|
1251
|
+
results = self.execute_query(query)
|
|
1252
|
+
for row in results:
|
|
1253
|
+
if include_embedding:
|
|
1254
|
+
props = row.values()[0].as_node().get_properties()
|
|
1255
|
+
else:
|
|
1256
|
+
props = {k: v.value for k, v in row.items()}
|
|
1257
|
+
candidates.append(self._parse_node(props))
|
|
1258
|
+
except Exception as e:
|
|
1259
|
+
logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
|
|
1260
|
+
return candidates
|
|
1261
|
+
|
|
1262
|
+
+    @timed
+    def drop_database(self) -> None:
+        """
+        Permanently delete the entire database this instance is using.
+        WARNING: This operation is destructive and cannot be undone.
+        """
+        if self.config.use_multi_db:
+            self.execute_query(f"DROP GRAPH `{self.db_name}`")
+            logger.info(f"Database '`{self.db_name}`' has been dropped.")
+        else:
+            raise ValueError(
+                f"Refusing to drop protected database: `{self.db_name}` in "
+                f"Shared Database Multi-Tenant mode"
+            )
+
+    @timed
+    def detect_conflicts(self) -> list[tuple[str, str]]:
+        """
+        Detect conflicting nodes based on logical or semantic inconsistency.
+        Returns:
+            A list of (node_id1, node_id2) tuples that conflict.
+        """
+        raise NotImplementedError
+
+    @timed
+    # Structure Maintenance
+    def deduplicate_nodes(self) -> None:
+        """
+        Deduplicate redundant or semantically similar nodes.
+        This typically involves identifying nodes with identical or near-identical memory.
+        """
+        raise NotImplementedError
+
+    @timed
+    def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
+        """
+        Get the ordered context chain starting from a node, following a relationship type.
+        Args:
+            id: Starting node ID.
+            type: Relationship type to follow (e.g., 'FOLLOWS').
+        Returns:
+            List of ordered node IDs in the chain.
+        """
+        raise NotImplementedError
+
+    @timed
+    def get_neighbors(
+        self, id: str, type: str, direction: Literal["in", "out", "both"] = "out"
+    ) -> list[str]:
+        """
+        Get connected node IDs in a specific direction and relationship type.
+        Args:
+            id: Source node ID.
+            type: Relationship type.
+            direction: Edge direction to follow ('out', 'in', or 'both').
+        Returns:
+            List of neighboring node IDs.
+        """
+        raise NotImplementedError
+
+    @timed
+    def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]:
+        """
+        Get the path of nodes from source to target within a limited depth.
+        Args:
+            source_id: Starting node ID.
+            target_id: Target node ID.
+            max_depth: Maximum path length to traverse.
+        Returns:
+            Ordered list of node IDs along the path.
+        """
+        raise NotImplementedError
+
+    @timed
+    def merge_nodes(self, id1: str, id2: str) -> str:
+        """
+        Merge two similar or duplicate nodes into one.
+        Args:
+            id1: First node ID.
+            id2: Second node ID.
+        Returns:
+            ID of the resulting merged node.
+        """
+        raise NotImplementedError
+
+    @timed
+    def _ensure_database_exists(self):
+        graph_type_name = "MemOSBgeM3Type"
+
+        check_type_query = "SHOW GRAPH TYPES"
+        result = self.execute_query(check_type_query, auto_set_db=False)
+
+        type_exists = any(row["graph_type"].as_string() == graph_type_name for row in result)
+
+        if not type_exists:
+            create_tag = f"""
+            CREATE GRAPH TYPE IF NOT EXISTS {graph_type_name} AS {{
+                NODE Memory (:MemoryTag {{
+                    id STRING,
+                    memory STRING,
+                    user_name STRING,
+                    user_id STRING,
+                    session_id STRING,
+                    status STRING,
+                    key STRING,
+                    confidence FLOAT,
+                    tags LIST<STRING>,
+                    created_at STRING,
+                    updated_at STRING,
+                    memory_type STRING,
+                    sources LIST<STRING>,
+                    source STRING,
+                    node_type STRING,
+                    visibility STRING,
+                    usage LIST<STRING>,
+                    background STRING,
+                    {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT>,
+                    PRIMARY KEY(id)
+                }}),
+                EDGE RELATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE PARENT (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE AGGREGATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE MERGED_TO (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE INFERS (Memory) -[{{user_name STRING}}]-> (Memory),
+                EDGE FOLLOWS (Memory) -[{{user_name STRING}}]-> (Memory)
+            }}
+            """
+            self.execute_query(create_tag, auto_set_db=False)
+        else:
+            describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name};"
+            desc_result = self.execute_query(describe_query, auto_set_db=False)
+
+            memory_fields = []
+            for row in desc_result:
+                field_name = row.values()[0].as_string()
+                memory_fields.append(field_name)
+
+            if self.dim_field not in memory_fields:
+                alter_query = f"""
+                ALTER GRAPH TYPE {graph_type_name} {{
+                    ALTER NODE TYPE Memory ADD PROPERTIES {{ {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT> }}
+                }}
+                """
+                self.execute_query(alter_query, auto_set_db=False)
+                logger.info(f"✅ Add new vector search {self.dim_field} to {graph_type_name}")
+            else:
+                logger.info(f"✅ Graph Type {graph_type_name} already include {self.dim_field}")
+
+        create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
+        set_graph_working = f"SESSION SET GRAPH `{self.db_name}`"
+
+        try:
+            self.execute_query(create_graph, auto_set_db=False)
+            self.execute_query(set_graph_working)
+            logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.")
+        except Exception as e:
+            logger.error(f"❌ Failed to create tag: {e} trace: {traceback.format_exc()}")
+
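To make the bootstrap sequence concrete, the following is an editor's sketch (not package code) of statements this method renders. The graph name and the 768-dimensional field are hypothetical stand-ins for self.db_name, self.dim_field, and self.embedding_dimension.

    # Editor's sketch: statements _ensure_database_exists would render, with
    # hypothetical values standing in for self.db_name / self.dim_field /
    # self.embedding_dimension.
    graph_type_name = "MemOSBgeM3Type"
    db_name = "memos_demo"          # hypothetical
    dim_field = "embedding_768"     # hypothetical
    embedding_dimension = 768       # hypothetical

    # Schema-evolution path: add a new vector property to the existing graph type.
    alter_query = f"""
    ALTER GRAPH TYPE {graph_type_name} {{
        ALTER NODE TYPE Memory ADD PROPERTIES {{ {dim_field} VECTOR<{embedding_dimension}, FLOAT> }}
    }}
    """

    # Always issued afterwards: create the graph from the type and make it the
    # session's working graph.
    create_graph = f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED {graph_type_name}"
    set_graph_working = f"SESSION SET GRAPH `{db_name}`"
    print(alter_query, create_graph, set_graph_working, sep="\n")
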
+    @timed
+    def _create_vector_index(
+        self, label: str, vector_property: str, dimensions: int, index_name: str
+    ) -> None:
+        """
+        Create a vector index for the specified property in the label.
+        """
+        if str(dimensions) == str(self.default_memory_dimension):
+            index_name = f"idx_{vector_property}"
+            vector_name = vector_property
+        else:
+            index_name = f"idx_{vector_property}_{dimensions}"
+            vector_name = f"{vector_property}_{dimensions}"
+
+        create_vector_index = f"""
+        CREATE VECTOR INDEX IF NOT EXISTS {index_name}
+        ON NODE {label}::{vector_name}
+        OPTIONS {{
+            DIM: {dimensions},
+            METRIC: IP,
+            TYPE: IVF,
+            NLIST: 100,
+            TRAINSIZE: 1000
+        }}
+        FOR `{self.db_name}`
+        """
+        self.execute_query(create_vector_index)
+        logger.info(
+            f"✅ Ensure {label}::{vector_property} vector index {index_name} "
+            f"exists (DIM={dimensions})"
+        )
+
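As an illustration of the naming rule above, here is a small editor's sketch (not package code); the 1024 default dimension is a hypothetical stand-in for self.default_memory_dimension.

    # Editor's sketch of the index-naming rule in _create_vector_index.
    # default_memory_dimension=1024 is a hypothetical default value.
    def index_names(vector_property: str, dimensions: int, default_memory_dimension: int = 1024):
        if str(dimensions) == str(default_memory_dimension):
            return f"idx_{vector_property}", vector_property
        return f"idx_{vector_property}_{dimensions}", f"{vector_property}_{dimensions}"

    print(index_names("embedding", 1024))  # ('idx_embedding', 'embedding')
    print(index_names("embedding", 768))   # ('idx_embedding_768', 'embedding_768')
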
+    @timed
+    def _create_basic_property_indexes(self) -> None:
+        """
+        Create standard B-tree indexes on the status, memory_type, created_at
+        and updated_at fields.
+        Also create a standard B-tree index on user_name when using Shared
+        Database Multi-Tenant mode.
+        """
+        fields = ["status", "memory_type", "created_at", "updated_at"]
+        if not self.config.use_multi_db:
+            fields.append("user_name")
+
+        for field in fields:
+            index_name = f"idx_memory_{field}"
+            gql = f"""
+            CREATE INDEX IF NOT EXISTS {index_name} ON NODE Memory({field})
+            FOR `{self.db_name}`
+            """
+            try:
+                self.execute_query(gql)
+                logger.info(f"✅ Created index: {index_name} on field {field}")
+            except Exception as e:
+                logger.error(
+                    f"❌ Failed to create index {index_name}: {e}, trace: {traceback.format_exc()}"
+                )
+
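A short editor's sketch of the statements this loop emits; the graph name and the mode flag are hypothetical inputs standing in for self.db_name and self.config.use_multi_db.

    # Editor's sketch: property indexes created per field. use_multi_db and
    # db_name are hypothetical inputs here.
    use_multi_db = False            # shared-database multi-tenant mode
    db_name = "memos_demo"
    fields = ["status", "memory_type", "created_at", "updated_at"]
    if not use_multi_db:
        fields.append("user_name")
    for field in fields:
        print(f"CREATE INDEX IF NOT EXISTS idx_memory_{field} ON NODE Memory({field}) FOR `{db_name}`")
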
+    @timed
+    def _index_exists(self, index_name: str) -> bool:
+        """
+        Check if a vector index with the given name exists in NebulaGraph.
+
+        Args:
+            index_name (str): The name of the index to check.
+
+        Returns:
+            bool: True if the index exists, False otherwise.
+        """
+        query = "SHOW VECTOR INDEXES"
+        try:
+            result = self.execute_query(query)
+            return any(row.values()[0].as_string() == index_name for row in result)
+        except Exception as e:
+            logger.error(f"[Nebula] Failed to check index existence: {e}")
+            return False
+
+    @timed
+    def _parse_value(self, value: Any) -> Any:
+        """Turn a Nebula ValueWrapper into a Python type."""
+        from nebulagraph_python.value_wrapper import ValueWrapper
+
+        if value is None or (hasattr(value, "is_null") and value.is_null()):
+            return None
+        try:
+            prim = value.cast_primitive() if isinstance(value, ValueWrapper) else value
+        except Exception as e:
+            logger.warning(f"Error when decode Nebula ValueWrapper: {e}")
+            prim = value.cast() if isinstance(value, ValueWrapper) else value
+
+        if isinstance(prim, ValueWrapper):
+            return self._parse_value(prim)
+        if isinstance(prim, list):
+            return [self._parse_value(v) for v in prim]
+        if type(prim).__name__ == "NVector":
+            return list(prim.values)
+
+        return prim  # already a Python primitive
+
+    def _parse_node(self, props: dict[str, Any]) -> dict[str, Any]:
+        parsed = {k: self._parse_value(v) for k, v in props.items()}
+
+        for tf in ("created_at", "updated_at"):
+            if tf in parsed and parsed[tf] is not None:
+                parsed[tf] = _normalize_datetime(parsed[tf])
+
+        node_id = parsed.pop("id")
+        memory = parsed.pop("memory", "")
+        parsed.pop("user_name", None)
+        metadata = parsed
+        metadata["type"] = metadata.pop("node_type")
+
+        if self.dim_field in metadata:
+            metadata["embedding"] = metadata.pop(self.dim_field)
+
+        return {"id": node_id, "memory": memory, "metadata": metadata}
+
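To show the reshaping with plain Python values, here is an editor's sketch; the embedding_768 field name and the sample values are hypothetical, and ValueWrapper decoding plus datetime normalization are omitted.

    # Editor's sketch of _parse_node's reshaping, using plain Python values.
    # "embedding_768" stands in for self.dim_field; decoding/normalization omitted.
    props = {
        "id": "node-1",
        "memory": "User prefers dark mode",
        "user_name": "alice",
        "node_type": "fact",
        "status": "activated",
        "embedding_768": [0.1, 0.2, 0.3],
    }
    node_id = props.pop("id")
    memory = props.pop("memory", "")
    props.pop("user_name", None)                  # user_name is dropped from the returned metadata
    metadata = props
    metadata["type"] = metadata.pop("node_type")
    metadata["embedding"] = metadata.pop("embedding_768")
    print({"id": node_id, "memory": memory, "metadata": metadata})
    # {'id': 'node-1', 'memory': 'User prefers dark mode',
    #  'metadata': {'status': 'activated', 'type': 'fact', 'embedding': [0.1, 0.2, 0.3]}}
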
+    @timed
+    def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
+        """
+        Ensure metadata has proper datetime fields and normalized types.
+
+        - Fill `created_at` and `updated_at` if missing (in ISO 8601 format).
+        - Convert embedding to list of float if present.
+        """
+        now = datetime.utcnow().isoformat()
+        metadata["node_type"] = metadata.pop("type")
+
+        # Fill timestamps if missing
+        metadata.setdefault("created_at", now)
+        metadata.setdefault("updated_at", now)
+
+        # Normalize embedding type
+        embedding = metadata.get("embedding")
+        if embedding and isinstance(embedding, list):
+            metadata[self.dim_field] = _normalize([float(x) for x in embedding])
+
+        return metadata
+
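A minimal editor's sketch of the same preparation with plain values; embedding_768 stands in for self.dim_field, and the package's _normalize helper (defined elsewhere in nebular.py) is replaced here by a plain float conversion.

    # Editor's sketch of _prepare_node_metadata with plain values.
    # "embedding_768" stands in for self.dim_field; the _normalize helper is
    # approximated by a float cast.
    from datetime import datetime

    metadata = {"type": "fact", "embedding": [1, 2, 3]}
    now = datetime.utcnow().isoformat()

    metadata["node_type"] = metadata.pop("type")
    metadata.setdefault("created_at", now)
    metadata.setdefault("updated_at", now)

    embedding = metadata.get("embedding")
    if embedding and isinstance(embedding, list):
        metadata["embedding_768"] = [float(x) for x in embedding]  # _normalize omitted

    print(metadata["node_type"], metadata["created_at"][:10], metadata["embedding_768"])
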
+    @timed
+    def _format_value(self, val: Any, key: str = "") -> str:
+        from nebulagraph_python.py_data_types import NVector
+
+        if isinstance(val, str):
+            return f'"{_escape_str(val)}"'
+        elif isinstance(val, (int | float)):
+            return str(val)
+        elif isinstance(val, datetime):
+            return f'datetime("{val.isoformat()}")'
+        elif isinstance(val, list):
+            if key == self.dim_field:
+                dim = len(val)
+                joined = ",".join(str(float(x)) for x in val)
+                return f"VECTOR<{dim}, FLOAT>([{joined}])"
+            else:
+                return f"[{', '.join(self._format_value(v) for v in val)}]"
+        elif isinstance(val, NVector):
+            if key == self.dim_field:
+                dim = len(val)
+                joined = ",".join(str(float(x)) for x in val)
+                return f"VECTOR<{dim}, FLOAT>([{joined}])"
+        elif val is None:
+            return "NULL"
+        else:
+            return f'"{_escape_str(str(val))}"'
+
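An editor's sketch of the literal-rendering rules above, standalone and with a simplified escape step; the package's _escape_str helper lives elsewhere in nebular.py and is only approximated here, and embedding_768 stands in for self.dim_field.

    # Editor's sketch of _format_value's rendering rules (simplified, standalone).
    # The real _escape_str is defined elsewhere; a basic escape is assumed.
    from datetime import datetime

    def format_value(val, key="", dim_field="embedding_768"):
        if isinstance(val, str):
            escaped = val.replace("\\", "\\\\").replace('"', '\\"')
            return f'"{escaped}"'
        if isinstance(val, (int, float)):
            return str(val)
        if isinstance(val, datetime):
            return f'datetime("{val.isoformat()}")'
        if isinstance(val, list) and key == dim_field:
            return f"VECTOR<{len(val)}, FLOAT>([{','.join(str(float(x)) for x in val)}])"
        if isinstance(val, list):
            return "[" + ", ".join(format_value(v) for v in val) + "]"
        if val is None:
            return "NULL"
        return f'"{val}"'

    print(format_value('say "hi"'))                       # "say \"hi\""
    print(format_value(["a", "b"]))                       # ["a", "b"]
    print(format_value([0.1, 0.2], key="embedding_768"))  # VECTOR<2, FLOAT>([0.1,0.2])
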
+    @timed
+    def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]:
+        """
+        Filter and validate metadata dictionary against the Memory node schema.
+        - Removes keys not in schema.
+        - Warns if required fields are missing.
+        """
+
+        dim_fields = {self.dim_field}
+
+        allowed_fields = self.common_fields | dim_fields
+
+        missing_fields = allowed_fields - metadata.keys()
+        if missing_fields:
+            logger.info(f"Metadata missing required fields: {sorted(missing_fields)}")
+
+        filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields}
+
+        return filtered_metadata
+
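The same set arithmetic in a short editor's sketch; the names used for common_fields are hypothetical stand-ins for the class attribute.

    # Editor's sketch of _metadata_filter's set arithmetic. The common_fields
    # shown here are hypothetical stand-ins for the class attribute.
    common_fields = {"status", "memory_type", "created_at", "updated_at"}
    dim_fields = {"embedding_768"}
    allowed_fields = common_fields | dim_fields

    metadata = {"status": "activated", "unknown_key": 1}
    missing_fields = allowed_fields - metadata.keys()
    filtered = {k: v for k, v in metadata.items() if k in allowed_fields}
    print(sorted(missing_fields))  # ['created_at', 'embedding_768', 'memory_type', 'updated_at']
    print(filtered)                # {'status': 'activated'}
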
+    def _build_return_fields(self, include_embedding: bool = False) -> str:
+        if include_embedding:
+            return "n"
+        return ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
|