MemoryOS 1.0.0 (py3-none-any.whl) → 1.1.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of MemoryOS might be problematic. See the release's page on the package registry for more details.
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/METADATA +8 -2
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/RECORD +92 -69
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/WHEEL +1 -1
- memos/__init__.py +1 -1
- memos/api/client.py +109 -0
- memos/api/config.py +35 -8
- memos/api/context/dependencies.py +15 -66
- memos/api/middleware/request_context.py +63 -0
- memos/api/product_api.py +5 -2
- memos/api/product_models.py +107 -16
- memos/api/routers/product_router.py +62 -19
- memos/api/start_api.py +13 -0
- memos/configs/graph_db.py +4 -0
- memos/configs/mem_scheduler.py +38 -3
- memos/configs/memory.py +13 -0
- memos/configs/reranker.py +18 -0
- memos/context/context.py +255 -0
- memos/embedders/factory.py +2 -0
- memos/graph_dbs/base.py +4 -2
- memos/graph_dbs/nebular.py +368 -223
- memos/graph_dbs/neo4j.py +49 -13
- memos/graph_dbs/neo4j_community.py +13 -3
- memos/llms/factory.py +2 -0
- memos/llms/openai.py +74 -2
- memos/llms/vllm.py +2 -0
- memos/log.py +128 -4
- memos/mem_cube/general.py +3 -1
- memos/mem_os/core.py +89 -23
- memos/mem_os/main.py +3 -6
- memos/mem_os/product.py +418 -154
- memos/mem_os/utils/reference_utils.py +20 -0
- memos/mem_reader/factory.py +2 -0
- memos/mem_reader/simple_struct.py +204 -82
- memos/mem_scheduler/analyzer/__init__.py +0 -0
- memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +569 -0
- memos/mem_scheduler/analyzer/scheduler_for_eval.py +280 -0
- memos/mem_scheduler/base_scheduler.py +126 -56
- memos/mem_scheduler/general_modules/dispatcher.py +2 -2
- memos/mem_scheduler/general_modules/misc.py +99 -1
- memos/mem_scheduler/general_modules/scheduler_logger.py +17 -11
- memos/mem_scheduler/general_scheduler.py +40 -88
- memos/mem_scheduler/memory_manage_modules/__init__.py +5 -0
- memos/mem_scheduler/memory_manage_modules/memory_filter.py +308 -0
- memos/mem_scheduler/{general_modules → memory_manage_modules}/retriever.py +34 -7
- memos/mem_scheduler/monitors/dispatcher_monitor.py +9 -8
- memos/mem_scheduler/monitors/general_monitor.py +119 -39
- memos/mem_scheduler/optimized_scheduler.py +124 -0
- memos/mem_scheduler/orm_modules/__init__.py +0 -0
- memos/mem_scheduler/orm_modules/base_model.py +635 -0
- memos/mem_scheduler/orm_modules/monitor_models.py +261 -0
- memos/mem_scheduler/scheduler_factory.py +2 -0
- memos/mem_scheduler/schemas/monitor_schemas.py +96 -29
- memos/mem_scheduler/utils/config_utils.py +100 -0
- memos/mem_scheduler/utils/db_utils.py +33 -0
- memos/mem_scheduler/utils/filter_utils.py +1 -1
- memos/mem_scheduler/webservice_modules/__init__.py +0 -0
- memos/mem_user/mysql_user_manager.py +4 -2
- memos/memories/activation/kv.py +2 -1
- memos/memories/textual/item.py +96 -17
- memos/memories/textual/naive.py +1 -1
- memos/memories/textual/tree.py +57 -3
- memos/memories/textual/tree_text_memory/organize/handler.py +4 -2
- memos/memories/textual/tree_text_memory/organize/manager.py +28 -14
- memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py +1 -2
- memos/memories/textual/tree_text_memory/organize/reorganizer.py +75 -23
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +10 -6
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +6 -2
- memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +2 -0
- memos/memories/textual/tree_text_memory/retrieve/recall.py +119 -21
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +172 -44
- memos/memories/textual/tree_text_memory/retrieve/utils.py +6 -4
- memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +5 -4
- memos/memos_tools/notification_utils.py +46 -0
- memos/memos_tools/singleton.py +174 -0
- memos/memos_tools/thread_safe_dict.py +22 -0
- memos/memos_tools/thread_safe_dict_segment.py +382 -0
- memos/parsers/factory.py +2 -0
- memos/reranker/__init__.py +4 -0
- memos/reranker/base.py +24 -0
- memos/reranker/concat.py +59 -0
- memos/reranker/cosine_local.py +96 -0
- memos/reranker/factory.py +48 -0
- memos/reranker/http_bge.py +312 -0
- memos/reranker/noop.py +16 -0
- memos/templates/mem_reader_prompts.py +289 -40
- memos/templates/mem_scheduler_prompts.py +242 -0
- memos/templates/mos_prompts.py +133 -60
- memos/types.py +4 -1
- memos/api/context/context.py +0 -147
- memos/mem_scheduler/mos_for_test_scheduler.py +0 -146
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info}/entry_points.txt +0 -0
- {memoryos-1.0.0.dist-info → memoryos-1.1.1.dist-info/licenses}/LICENSE +0 -0
- /memos/mem_scheduler/{general_modules → webservice_modules}/rabbitmq_service.py +0 -0
- /memos/mem_scheduler/{general_modules → webservice_modules}/redis_service.py +0 -0
memos/graph_dbs/nebular.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
+
import json
|
|
1
2
|
import traceback
|
|
2
3
|
|
|
3
4
|
from contextlib import suppress
|
|
4
5
|
from datetime import datetime
|
|
5
|
-
from queue import Empty, Queue
|
|
6
6
|
from threading import Lock
|
|
7
|
-
from typing import Any, Literal
|
|
7
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
|
|
@@ -15,9 +15,28 @@ from memos.log import get_logger
|
|
|
15
15
|
from memos.utils import timed
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from nebulagraph_python import (
|
|
20
|
+
NebulaClient,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
18
24
|
logger = get_logger(__name__)
|
|
19
25
|
|
|
20
26
|
|
|
27
|
+
_TRANSIENT_ERR_KEYS = (
|
|
28
|
+
"Session not found",
|
|
29
|
+
"Connection not established",
|
|
30
|
+
"timeout",
|
|
31
|
+
"deadline exceeded",
|
|
32
|
+
"Broken pipe",
|
|
33
|
+
"EOFError",
|
|
34
|
+
"socket closed",
|
|
35
|
+
"connection reset",
|
|
36
|
+
"connection refused",
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
21
40
|
@timed
|
|
22
41
|
def _normalize(vec: list[float]) -> list[float]:
|
|
23
42
|
v = np.asarray(vec, dtype=np.float32)
|
|
@@ -35,7 +54,28 @@ def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
|
|
|
35
54
|
|
|
36
55
|
@timed
|
|
37
56
|
def _escape_str(value: str) -> str:
|
|
38
|
-
|
|
57
|
+
out = []
|
|
58
|
+
for ch in value:
|
|
59
|
+
code = ord(ch)
|
|
60
|
+
if ch == "\\":
|
|
61
|
+
out.append("\\\\")
|
|
62
|
+
elif ch == '"':
|
|
63
|
+
out.append('\\"')
|
|
64
|
+
elif ch == "\n":
|
|
65
|
+
out.append("\\n")
|
|
66
|
+
elif ch == "\r":
|
|
67
|
+
out.append("\\r")
|
|
68
|
+
elif ch == "\t":
|
|
69
|
+
out.append("\\t")
|
|
70
|
+
elif ch == "\b":
|
|
71
|
+
out.append("\\b")
|
|
72
|
+
elif ch == "\f":
|
|
73
|
+
out.append("\\f")
|
|
74
|
+
elif code < 0x20 or code in (0x2028, 0x2029):
|
|
75
|
+
out.append(f"\\u{code:04x}")
|
|
76
|
+
else:
|
|
77
|
+
out.append(ch)
|
|
78
|
+
return "".join(out)
|
|
39
79
|
|
|
40
80
|
|
|
41
81
|
@timed
|
|
@@ -61,145 +101,202 @@ def _normalize_datetime(val):
|
|
|
61
101
|
return str(val)
|
|
62
102
|
|
|
63
103
|
|
|
64
|
-
class
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
class SessionPool:
|
|
69
|
-
@require_python_package(
|
|
70
|
-
import_name="nebulagraph_python",
|
|
71
|
-
install_command="pip install ... @Tianxing",
|
|
72
|
-
install_link=".....",
|
|
73
|
-
)
|
|
74
|
-
def __init__(
|
|
75
|
-
self,
|
|
76
|
-
hosts: list[str],
|
|
77
|
-
user: str,
|
|
78
|
-
password: str,
|
|
79
|
-
minsize: int = 1,
|
|
80
|
-
maxsize: int = 10000,
|
|
81
|
-
):
|
|
82
|
-
self.hosts = hosts
|
|
83
|
-
self.user = user
|
|
84
|
-
self.password = password
|
|
85
|
-
self.minsize = minsize
|
|
86
|
-
self.maxsize = maxsize
|
|
87
|
-
self.pool = Queue(maxsize)
|
|
88
|
-
self.lock = Lock()
|
|
89
|
-
|
|
90
|
-
self.clients = []
|
|
91
|
-
|
|
92
|
-
for _ in range(minsize):
|
|
93
|
-
self._create_and_add_client()
|
|
94
|
-
|
|
95
|
-
@timed
|
|
96
|
-
def _create_and_add_client(self):
|
|
97
|
-
from nebulagraph_python import NebulaClient
|
|
98
|
-
|
|
99
|
-
client = NebulaClient(self.hosts, self.user, self.password)
|
|
100
|
-
self.pool.put(client)
|
|
101
|
-
self.clients.append(client)
|
|
102
|
-
|
|
103
|
-
@timed
|
|
104
|
-
def get_client(self, timeout: float = 5.0):
|
|
105
|
-
try:
|
|
106
|
-
return self.pool.get(timeout=timeout)
|
|
107
|
-
except Empty:
|
|
108
|
-
with self.lock:
|
|
109
|
-
if len(self.clients) < self.maxsize:
|
|
110
|
-
from nebulagraph_python import NebulaClient
|
|
111
|
-
|
|
112
|
-
client = NebulaClient(self.hosts, self.user, self.password)
|
|
113
|
-
self.clients.append(client)
|
|
114
|
-
return client
|
|
115
|
-
raise RuntimeError("NebulaClientPool exhausted") from None
|
|
116
|
-
|
|
117
|
-
@timed
|
|
118
|
-
def return_client(self, client):
|
|
119
|
-
try:
|
|
120
|
-
client.execute("YIELD 1")
|
|
121
|
-
self.pool.put(client)
|
|
122
|
-
except Exception:
|
|
123
|
-
logger.info("[Pool] Client dead, replacing...")
|
|
124
|
-
self.replace_client(client)
|
|
125
|
-
|
|
126
|
-
@timed
|
|
127
|
-
def close(self):
|
|
128
|
-
for client in self.clients:
|
|
129
|
-
with suppress(Exception):
|
|
130
|
-
client.close()
|
|
131
|
-
self.clients.clear()
|
|
104
|
+
class NebulaGraphDB(BaseGraphDB):
|
|
105
|
+
"""
|
|
106
|
+
NebulaGraph-based implementation of a graph memory store.
|
|
107
|
+
"""
|
|
132
108
|
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
109
|
+
# ====== shared pool cache & refcount ======
|
|
110
|
+
# These are process-local; in a multi-process model each process will
|
|
111
|
+
# have its own cache.
|
|
112
|
+
_CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {}
|
|
113
|
+
_CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {}
|
|
114
|
+
_CLIENT_LOCK: ClassVar[Lock] = Lock()
|
|
115
|
+
_CLIENT_INIT_DONE: ClassVar[set[str]] = set()
|
|
116
|
+
|
|
117
|
+
@staticmethod
|
|
118
|
+
def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]:
|
|
119
|
+
hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None)
|
|
120
|
+
if isinstance(hosts, str):
|
|
121
|
+
return [hosts]
|
|
122
|
+
return list(hosts or [])
|
|
123
|
+
|
|
124
|
+
@staticmethod
|
|
125
|
+
def _make_client_key(cfg: NebulaGraphDBConfig) -> str:
|
|
126
|
+
hosts = NebulaGraphDB._get_hosts_from_cfg(cfg)
|
|
127
|
+
return "|".join(
|
|
128
|
+
[
|
|
129
|
+
"nebula-sync",
|
|
130
|
+
",".join(hosts),
|
|
131
|
+
str(getattr(cfg, "user", "")),
|
|
132
|
+
str(getattr(cfg, "use_multi_db", False)),
|
|
133
|
+
str(getattr(cfg, "space", "")),
|
|
134
|
+
]
|
|
135
|
+
)
|
|
138
136
|
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
137
|
+
@classmethod
|
|
138
|
+
def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB":
|
|
139
|
+
tmp = object.__new__(NebulaGraphDB)
|
|
140
|
+
tmp.config = cfg
|
|
141
|
+
tmp.db_name = cfg.space
|
|
142
|
+
tmp.user_name = getattr(cfg, "user_name", None)
|
|
143
|
+
tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072)
|
|
144
|
+
tmp.default_memory_dimension = 3072
|
|
145
|
+
tmp.common_fields = {
|
|
146
|
+
"id",
|
|
147
|
+
"memory",
|
|
148
|
+
"user_name",
|
|
149
|
+
"user_id",
|
|
150
|
+
"session_id",
|
|
151
|
+
"status",
|
|
152
|
+
"key",
|
|
153
|
+
"confidence",
|
|
154
|
+
"tags",
|
|
155
|
+
"created_at",
|
|
156
|
+
"updated_at",
|
|
157
|
+
"memory_type",
|
|
158
|
+
"sources",
|
|
159
|
+
"source",
|
|
160
|
+
"node_type",
|
|
161
|
+
"visibility",
|
|
162
|
+
"usage",
|
|
163
|
+
"background",
|
|
164
|
+
}
|
|
165
|
+
tmp.base_fields = set(tmp.common_fields) - {"usage"}
|
|
166
|
+
tmp.heavy_fields = {"usage"}
|
|
167
|
+
tmp.dim_field = (
|
|
168
|
+
f"embedding_{tmp.embedding_dimension}"
|
|
169
|
+
if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension)
|
|
170
|
+
else "embedding"
|
|
171
|
+
)
|
|
172
|
+
tmp.system_db_name = "system" if getattr(cfg, "use_multi_db", False) else cfg.space
|
|
173
|
+
tmp._client = client
|
|
174
|
+
tmp._owns_client = False
|
|
175
|
+
return tmp
|
|
176
|
+
|
|
177
|
+
@classmethod
|
|
178
|
+
def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]:
|
|
179
|
+
from nebulagraph_python import (
|
|
180
|
+
ConnectionConfig,
|
|
181
|
+
NebulaClient,
|
|
182
|
+
SessionConfig,
|
|
183
|
+
SessionPoolConfig,
|
|
184
|
+
)
|
|
143
185
|
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
186
|
+
key = cls._make_client_key(cfg)
|
|
187
|
+
with cls._CLIENT_LOCK:
|
|
188
|
+
client = cls._CLIENT_CACHE.get(key)
|
|
189
|
+
if client is None:
|
|
190
|
+
# Connection setting
|
|
191
|
+
conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None)
|
|
192
|
+
if conn_conf is None:
|
|
193
|
+
conn_conf = ConnectionConfig.from_defults(
|
|
194
|
+
cls._get_hosts_from_cfg(cfg),
|
|
195
|
+
getattr(cfg, "ssl_param", None),
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
sess_conf = SessionConfig(graph=getattr(cfg, "space", None))
|
|
199
|
+
pool_conf = SessionPoolConfig(
|
|
200
|
+
size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000
|
|
201
|
+
)
|
|
147
202
|
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
203
|
+
client = NebulaClient(
|
|
204
|
+
hosts=conn_conf.hosts,
|
|
205
|
+
username=cfg.user,
|
|
206
|
+
password=cfg.password,
|
|
207
|
+
conn_config=conn_conf,
|
|
208
|
+
session_config=sess_conf,
|
|
209
|
+
session_pool_config=pool_conf,
|
|
210
|
+
)
|
|
211
|
+
cls._CLIENT_CACHE[key] = client
|
|
212
|
+
cls._CLIENT_REFCOUNT[key] = 0
|
|
213
|
+
logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}")
|
|
151
214
|
|
|
152
|
-
|
|
215
|
+
cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1
|
|
153
216
|
|
|
154
|
-
|
|
155
|
-
def reset_pool(self):
|
|
156
|
-
"""⚠️ Emergency reset: Close all clients and clear the pool."""
|
|
157
|
-
logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
|
|
158
|
-
with self.lock:
|
|
159
|
-
for client in self.clients:
|
|
217
|
+
if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
|
|
160
218
|
try:
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
219
|
+
pass
|
|
220
|
+
finally:
|
|
221
|
+
pass
|
|
222
|
+
|
|
223
|
+
if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
|
|
224
|
+
with cls._CLIENT_LOCK:
|
|
225
|
+
if key not in cls._CLIENT_INIT_DONE:
|
|
226
|
+
admin = cls._bootstrap_admin(cfg, client)
|
|
227
|
+
try:
|
|
228
|
+
admin._ensure_database_exists()
|
|
229
|
+
admin._create_basic_property_indexes()
|
|
230
|
+
admin._create_vector_index(
|
|
231
|
+
dimensions=int(
|
|
232
|
+
admin.embedding_dimension or admin.default_memory_dimension
|
|
233
|
+
),
|
|
234
|
+
)
|
|
235
|
+
cls._CLIENT_INIT_DONE.add(key)
|
|
236
|
+
logger.info("[NebulaGraphDBSync] One-time init done")
|
|
237
|
+
except Exception:
|
|
238
|
+
logger.exception("[NebulaGraphDBSync] One-time init failed")
|
|
239
|
+
|
|
240
|
+
return key, client
|
|
241
|
+
|
|
242
|
+
def _refresh_client(self):
|
|
243
|
+
"""
|
|
244
|
+
refresh NebulaClient:
|
|
245
|
+
"""
|
|
246
|
+
old_key = getattr(self, "_client_key", None)
|
|
247
|
+
if not old_key:
|
|
248
|
+
return
|
|
249
|
+
|
|
250
|
+
cls = self.__class__
|
|
251
|
+
with cls._CLIENT_LOCK:
|
|
252
|
+
try:
|
|
253
|
+
if old_key in cls._CLIENT_CACHE:
|
|
254
|
+
try:
|
|
255
|
+
cls._CLIENT_CACHE[old_key].close()
|
|
256
|
+
except Exception as e:
|
|
257
|
+
logger.warning(f"[refresh_client] close old client error: {e}")
|
|
258
|
+
finally:
|
|
259
|
+
cls._CLIENT_CACHE.pop(old_key, None)
|
|
260
|
+
finally:
|
|
261
|
+
cls._CLIENT_REFCOUNT[old_key] = 0
|
|
262
|
+
|
|
263
|
+
new_key, new_client = cls._get_or_create_shared_client(self.config)
|
|
264
|
+
self._client_key = new_key
|
|
265
|
+
self._client = new_client
|
|
266
|
+
logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}")
|
|
267
|
+
|
|
268
|
+
@classmethod
|
|
269
|
+
def _release_shared_client(cls, key: str):
|
|
270
|
+
with cls._CLIENT_LOCK:
|
|
271
|
+
if key not in cls._CLIENT_CACHE:
|
|
272
|
+
return
|
|
273
|
+
cls._CLIENT_REFCOUNT[key] = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1)
|
|
274
|
+
if cls._CLIENT_REFCOUNT[key] == 0:
|
|
166
275
|
try:
|
|
167
|
-
|
|
168
|
-
except
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
self.clients.append(new_client)
|
|
188
|
-
|
|
189
|
-
self.pool.put(new_client)
|
|
190
|
-
|
|
191
|
-
logger.info("[Pool] Replaced dead client with a new one.")
|
|
192
|
-
return new_client
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
class NebulaGraphDB(BaseGraphDB):
|
|
196
|
-
"""
|
|
197
|
-
NebulaGraph-based implementation of a graph memory store.
|
|
198
|
-
"""
|
|
276
|
+
cls._CLIENT_CACHE[key].close()
|
|
277
|
+
except Exception as e:
|
|
278
|
+
logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}")
|
|
279
|
+
finally:
|
|
280
|
+
cls._CLIENT_CACHE.pop(key, None)
|
|
281
|
+
cls._CLIENT_REFCOUNT.pop(key, None)
|
|
282
|
+
logger.info(f"[NebulaGraphDBSync] Closed & removed client key={key}")
|
|
283
|
+
|
|
284
|
+
@classmethod
|
|
285
|
+
def close_all_shared_clients(cls):
|
|
286
|
+
with cls._CLIENT_LOCK:
|
|
287
|
+
for key, client in list(cls._CLIENT_CACHE.items()):
|
|
288
|
+
try:
|
|
289
|
+
client.close()
|
|
290
|
+
except Exception as e:
|
|
291
|
+
logger.warning(f"[NebulaGraphDBSync] Error closing client {key}: {e}")
|
|
292
|
+
finally:
|
|
293
|
+
logger.info(f"[NebulaGraphDBSync] Closed client key={key}")
|
|
294
|
+
cls._CLIENT_CACHE.clear()
|
|
295
|
+
cls._CLIENT_REFCOUNT.clear()
|
|
199
296
|
|
|
200
297
|
@require_python_package(
|
|
201
298
|
import_name="nebulagraph_python",
|
|
202
|
-
install_command="pip install
|
|
299
|
+
install_command="pip install nebulagraph-python>=5.1.1",
|
|
203
300
|
install_link=".....",
|
|
204
301
|
)
|
|
205
302
|
def __init__(self, config: NebulaGraphDBConfig):
|
|
@@ -246,48 +343,65 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
246
343
|
"usage",
|
|
247
344
|
"background",
|
|
248
345
|
}
|
|
346
|
+
self.base_fields = set(self.common_fields) - {"usage"}
|
|
347
|
+
self.heavy_fields = {"usage"}
|
|
249
348
|
self.dim_field = (
|
|
250
349
|
f"embedding_{self.embedding_dimension}"
|
|
251
350
|
if (str(self.embedding_dimension) != str(self.default_memory_dimension))
|
|
252
351
|
else "embedding"
|
|
253
352
|
)
|
|
254
353
|
self.system_db_name = "system" if config.use_multi_db else config.space
|
|
255
|
-
self.pool = SessionPool(
|
|
256
|
-
hosts=config.get("uri"),
|
|
257
|
-
user=config.get("user"),
|
|
258
|
-
password=config.get("password"),
|
|
259
|
-
minsize=1,
|
|
260
|
-
maxsize=config.get("max_client", 1000),
|
|
261
|
-
)
|
|
262
|
-
|
|
263
|
-
if config.auto_create:
|
|
264
|
-
self._ensure_database_exists()
|
|
265
|
-
|
|
266
|
-
self.execute_query(f"SESSION SET GRAPH `{self.db_name}`")
|
|
267
354
|
|
|
268
|
-
#
|
|
269
|
-
|
|
355
|
+
# ---- NEW: pool acquisition strategy
|
|
356
|
+
# Get or create a shared pool from the class-level cache
|
|
357
|
+
self._client_key, self._client = self._get_or_create_shared_client(config)
|
|
358
|
+
self._owns_client = True
|
|
270
359
|
|
|
271
360
|
logger.info("Connected to NebulaGraph successfully.")
|
|
272
361
|
|
|
273
362
|
@timed
|
|
274
|
-
def execute_query(self, gql: str, timeout: float =
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
return client.execute(gql, timeout=timeout)
|
|
363
|
+
def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True):
|
|
364
|
+
def _wrap_use_db(q: str) -> str:
|
|
365
|
+
if auto_set_db and self.db_name:
|
|
366
|
+
return f"USE `{self.db_name}`\n{q}"
|
|
367
|
+
return q
|
|
280
368
|
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
369
|
+
try:
|
|
370
|
+
return self._client.execute(_wrap_use_db(gql), timeout=timeout)
|
|
371
|
+
|
|
372
|
+
except Exception as e:
|
|
373
|
+
emsg = str(e)
|
|
374
|
+
if any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS):
|
|
375
|
+
logger.warning(f"[execute_query] {e!s} → refreshing session pool and retry once...")
|
|
376
|
+
try:
|
|
377
|
+
self._refresh_client()
|
|
378
|
+
return self._client.execute(_wrap_use_db(gql), timeout=timeout)
|
|
379
|
+
except Exception:
|
|
380
|
+
logger.exception("[execute_query] retry after refresh failed")
|
|
381
|
+
raise
|
|
382
|
+
raise
|
|
287
383
|
|
|
288
384
|
@timed
|
|
289
385
|
def close(self):
|
|
290
|
-
|
|
386
|
+
"""
|
|
387
|
+
Close the connection resource if this instance owns it.
|
|
388
|
+
|
|
389
|
+
- If pool was injected (`shared_pool`), do nothing.
|
|
390
|
+
- If pool was acquired via shared cache, decrement refcount and close
|
|
391
|
+
when the last owner releases it.
|
|
392
|
+
"""
|
|
393
|
+
if not self._owns_client:
|
|
394
|
+
logger.debug("[NebulaGraphDBSync] close() skipped (injected client).")
|
|
395
|
+
return
|
|
396
|
+
if self._client_key:
|
|
397
|
+
self._release_shared_client(self._client_key)
|
|
398
|
+
self._client_key = None
|
|
399
|
+
self._client = None
|
|
400
|
+
|
|
401
|
+
# NOTE: __del__ is best-effort; do not rely on GC order.
|
|
402
|
+
def __del__(self):
|
|
403
|
+
with suppress(Exception):
|
|
404
|
+
self.close()
|
|
291
405
|
|
|
292
406
|
@timed
|
|
293
407
|
def create_index(
|
|
@@ -366,12 +480,10 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
366
480
|
filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"'
|
|
367
481
|
else:
|
|
368
482
|
filter_clause = f'n.memory_type = "{scope}"'
|
|
369
|
-
return_fields = ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
|
|
370
|
-
|
|
371
483
|
query = f"""
|
|
372
484
|
MATCH (n@Memory)
|
|
373
485
|
WHERE {filter_clause}
|
|
374
|
-
RETURN
|
|
486
|
+
RETURN n.id AS id
|
|
375
487
|
LIMIT 1
|
|
376
488
|
"""
|
|
377
489
|
|
|
@@ -568,10 +680,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
568
680
|
try:
|
|
569
681
|
result = self.execute_query(gql)
|
|
570
682
|
for row in result:
|
|
571
|
-
|
|
572
|
-
props = row.values()[0].as_node().get_properties()
|
|
573
|
-
else:
|
|
574
|
-
props = {k: v.value for k, v in row.items()}
|
|
683
|
+
props = {k: v.value for k, v in row.items()}
|
|
575
684
|
node = self._parse_node(props)
|
|
576
685
|
return node
|
|
577
686
|
|
|
@@ -582,7 +691,9 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
582
691
|
return None
|
|
583
692
|
|
|
584
693
|
@timed
|
|
585
|
-
def get_nodes(
|
|
694
|
+
def get_nodes(
|
|
695
|
+
self, ids: list[str], include_embedding: bool = False, **kwargs
|
|
696
|
+
) -> list[dict[str, Any]]:
|
|
586
697
|
"""
|
|
587
698
|
Retrieve the metadata and memory of a list of nodes.
|
|
588
699
|
Args:
|
|
@@ -600,7 +711,10 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
600
711
|
|
|
601
712
|
where_user = ""
|
|
602
713
|
if not self.config.use_multi_db and self.config.user_name:
|
|
603
|
-
|
|
714
|
+
if kwargs.get("cube_name"):
|
|
715
|
+
where_user = f" AND n.user_name = '{kwargs['cube_name']}'"
|
|
716
|
+
else:
|
|
717
|
+
where_user = f" AND n.user_name = '{self.config.user_name}'"
|
|
604
718
|
|
|
605
719
|
# Safe formatting of the ID list
|
|
606
720
|
id_list = ",".join(f'"{_id}"' for _id in ids)
|
|
@@ -615,10 +729,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
615
729
|
try:
|
|
616
730
|
results = self.execute_query(query)
|
|
617
731
|
for row in results:
|
|
618
|
-
|
|
619
|
-
props = row.values()[0].as_node().get_properties()
|
|
620
|
-
else:
|
|
621
|
-
props = {k: v.value for k, v in row.items()}
|
|
732
|
+
props = {k: v.value for k, v in row.items()}
|
|
622
733
|
nodes.append(self._parse_node(props))
|
|
623
734
|
except Exception as e:
|
|
624
735
|
logger.error(
|
|
@@ -687,6 +798,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
687
798
|
exclude_ids: list[str],
|
|
688
799
|
top_k: int = 5,
|
|
689
800
|
min_overlap: int = 1,
|
|
801
|
+
include_embedding: bool = False,
|
|
690
802
|
) -> list[dict[str, Any]]:
|
|
691
803
|
"""
|
|
692
804
|
Find top-K neighbor nodes with maximum tag overlap.
|
|
@@ -696,6 +808,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
696
808
|
exclude_ids: Node IDs to exclude (e.g., local cluster).
|
|
697
809
|
top_k: Max number of neighbors to return.
|
|
698
810
|
min_overlap: Minimum number of overlapping tags required.
|
|
811
|
+
include_embedding: with/without embedding
|
|
699
812
|
|
|
700
813
|
Returns:
|
|
701
814
|
List of dicts with node details and overlap count.
|
|
@@ -717,12 +830,13 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
717
830
|
where_clause = " AND ".join(where_clauses)
|
|
718
831
|
tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"
|
|
719
832
|
|
|
833
|
+
return_fields = self._build_return_fields(include_embedding)
|
|
720
834
|
query = f"""
|
|
721
835
|
LET tag_list = {tag_list_literal}
|
|
722
836
|
|
|
723
837
|
MATCH (n@Memory)
|
|
724
838
|
WHERE {where_clause}
|
|
725
|
-
RETURN
|
|
839
|
+
RETURN {return_fields},
|
|
726
840
|
size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
|
|
727
841
|
ORDER BY overlap_count DESC
|
|
728
842
|
LIMIT {top_k}
|
|
@@ -731,9 +845,8 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
731
845
|
result = self.execute_query(query)
|
|
732
846
|
neighbors: list[dict[str, Any]] = []
|
|
733
847
|
for r in result:
|
|
734
|
-
|
|
735
|
-
parsed = self._parse_node(
|
|
736
|
-
|
|
848
|
+
props = {k: v.value for k, v in r.items() if k != "overlap_count"}
|
|
849
|
+
parsed = self._parse_node(props)
|
|
737
850
|
parsed["overlap_count"] = r["overlap_count"].value
|
|
738
851
|
neighbors.append(parsed)
|
|
739
852
|
|
|
@@ -840,6 +953,8 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
840
953
|
scope: str | None = None,
|
|
841
954
|
status: str | None = None,
|
|
842
955
|
threshold: float | None = None,
|
|
956
|
+
search_filter: dict | None = None,
|
|
957
|
+
**kwargs,
|
|
843
958
|
) -> list[dict]:
|
|
844
959
|
"""
|
|
845
960
|
Retrieve node IDs based on vector similarity.
|
|
@@ -851,6 +966,8 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
851
966
|
status (str, optional): Node status filter (e.g., 'active', 'archived').
|
|
852
967
|
If provided, restricts results to nodes with matching status.
|
|
853
968
|
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
|
|
969
|
+
search_filter (dict, optional): Additional metadata filters for search results.
|
|
970
|
+
Keys should match node properties, values are the expected values.
|
|
854
971
|
|
|
855
972
|
Returns:
|
|
856
973
|
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
|
|
@@ -860,6 +977,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
860
977
|
- If scope is provided, it restricts results to nodes with matching memory_type.
|
|
861
978
|
- If 'status' is provided, only nodes with the matching status will be returned.
|
|
862
979
|
- If threshold is provided, only results with score >= threshold will be returned.
|
|
980
|
+
- If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
|
|
863
981
|
- Typical use case: restrict to 'status = activated' to avoid
|
|
864
982
|
matching archived or merged nodes.
|
|
865
983
|
"""
|
|
@@ -874,12 +992,22 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
874
992
|
if status:
|
|
875
993
|
where_clauses.append(f'n.status = "{status}"')
|
|
876
994
|
if not self.config.use_multi_db and self.config.user_name:
|
|
877
|
-
|
|
995
|
+
if kwargs.get("cube_name"):
|
|
996
|
+
where_clauses.append(f'n.user_name = "{kwargs["cube_name"]}"')
|
|
997
|
+
else:
|
|
998
|
+
where_clauses.append(f'n.user_name = "{self.config.user_name}"')
|
|
999
|
+
|
|
1000
|
+
# Add search_filter conditions
|
|
1001
|
+
if search_filter:
|
|
1002
|
+
for key, value in search_filter.items():
|
|
1003
|
+
if isinstance(value, str):
|
|
1004
|
+
where_clauses.append(f'n.{key} = "{value}"')
|
|
1005
|
+
else:
|
|
1006
|
+
where_clauses.append(f"n.{key} = {value}")
|
|
878
1007
|
|
|
879
1008
|
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
|
|
880
1009
|
|
|
881
1010
|
gql = f"""
|
|
882
|
-
USE `{self.db_name}`
|
|
883
1011
|
MATCH (n@Memory)
|
|
884
1012
|
{where_clause}
|
|
885
1013
|
ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC
|
|
@@ -902,7 +1030,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
902
1030
|
id_val = values[0].as_string()
|
|
903
1031
|
score_val = values[1].as_double()
|
|
904
1032
|
score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score
|
|
905
|
-
if threshold is None or score_val
|
|
1033
|
+
if threshold is None or score_val >= threshold:
|
|
906
1034
|
output.append({"id": id_val, "score": score_val})
|
|
907
1035
|
return output
|
|
908
1036
|
except Exception as e:
|
|
@@ -936,20 +1064,12 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
936
1064
|
"""
|
|
937
1065
|
where_clauses = []
|
|
938
1066
|
|
|
939
|
-
def _escape_value(value):
|
|
940
|
-
if isinstance(value, str):
|
|
941
|
-
return f'"{value}"'
|
|
942
|
-
elif isinstance(value, list):
|
|
943
|
-
return "[" + ", ".join(_escape_value(v) for v in value) + "]"
|
|
944
|
-
else:
|
|
945
|
-
return str(value)
|
|
946
|
-
|
|
947
1067
|
for _i, f in enumerate(filters):
|
|
948
1068
|
field = f["field"]
|
|
949
1069
|
op = f.get("op", "=")
|
|
950
1070
|
value = f["value"]
|
|
951
1071
|
|
|
952
|
-
escaped_value =
|
|
1072
|
+
escaped_value = self._format_value(value)
|
|
953
1073
|
|
|
954
1074
|
# Build WHERE clause
|
|
955
1075
|
if op == "=":
|
|
@@ -1153,28 +1273,36 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1153
1273
|
data: A dictionary containing all nodes and edges to be loaded.
|
|
1154
1274
|
"""
|
|
1155
1275
|
for node in data.get("nodes", []):
|
|
1156
|
-
|
|
1276
|
+
try:
|
|
1277
|
+
id, memory, metadata = _compose_node(node)
|
|
1157
1278
|
|
|
1158
|
-
|
|
1159
|
-
|
|
1279
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
1280
|
+
metadata["user_name"] = self.config.user_name
|
|
1160
1281
|
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
|
|
1282
|
+
metadata = self._prepare_node_metadata(metadata)
|
|
1283
|
+
metadata.update({"id": id, "memory": memory})
|
|
1284
|
+
properties = ", ".join(
|
|
1285
|
+
f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()
|
|
1286
|
+
)
|
|
1287
|
+
node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
|
|
1288
|
+
self.execute_query(node_gql)
|
|
1289
|
+
except Exception as e:
|
|
1290
|
+
logger.error(f"Fail to load node: {node}, error: {e}")
|
|
1166
1291
|
|
|
1167
1292
|
for edge in data.get("edges", []):
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1293
|
+
try:
|
|
1294
|
+
source_id, target_id = edge["source"], edge["target"]
|
|
1295
|
+
edge_type = edge["type"]
|
|
1296
|
+
props = ""
|
|
1297
|
+
if not self.config.use_multi_db and self.config.user_name:
|
|
1298
|
+
props = f'{{user_name: "{self.config.user_name}"}}'
|
|
1299
|
+
edge_gql = f'''
|
|
1300
|
+
MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
|
|
1301
|
+
INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
|
|
1302
|
+
'''
|
|
1303
|
+
self.execute_query(edge_gql)
|
|
1304
|
+
except Exception as e:
|
|
1305
|
+
logger.error(f"Fail to load edge: {edge}, error: {e}")
|
|
1178
1306
|
|
|
1179
1307
|
@timed
|
|
1180
1308
|
def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (list)[dict]:
|
|
@@ -1208,10 +1336,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1208
1336
|
try:
|
|
1209
1337
|
results = self.execute_query(query)
|
|
1210
1338
|
for row in results:
|
|
1211
|
-
|
|
1212
|
-
props = row.values()[0].as_node().get_properties()
|
|
1213
|
-
else:
|
|
1214
|
-
props = {k: v.value for k, v in row.items()}
|
|
1339
|
+
props = {k: v.value for k, v in row.items()}
|
|
1215
1340
|
nodes.append(self._parse_node(props))
|
|
1216
1341
|
except Exception as e:
|
|
1217
1342
|
logger.error(f"Failed to get memories: {e}")
|
|
@@ -1235,9 +1360,9 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1235
1360
|
where_clause += f' AND n.user_name = "{self.config.user_name}"'
|
|
1236
1361
|
|
|
1237
1362
|
return_fields = self._build_return_fields(include_embedding)
|
|
1363
|
+
return_fields += f", n.{self.dim_field} AS {self.dim_field}"
|
|
1238
1364
|
|
|
1239
1365
|
query = f"""
|
|
1240
|
-
USE `{self.db_name}`
|
|
1241
1366
|
MATCH (n@Memory)
|
|
1242
1367
|
WHERE {where_clause}
|
|
1243
1368
|
OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
|
|
@@ -1247,14 +1372,16 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1247
1372
|
"""
|
|
1248
1373
|
|
|
1249
1374
|
candidates = []
|
|
1375
|
+
node_ids = set()
|
|
1250
1376
|
try:
|
|
1251
1377
|
results = self.execute_query(query)
|
|
1252
1378
|
for row in results:
|
|
1253
|
-
|
|
1254
|
-
|
|
1255
|
-
|
|
1256
|
-
|
|
1257
|
-
|
|
1379
|
+
props = {k: v.value for k, v in row.items()}
|
|
1380
|
+
node = self._parse_node(props)
|
|
1381
|
+
node_id = node["id"]
|
|
1382
|
+
if node_id not in node_ids:
|
|
1383
|
+
candidates.append(node)
|
|
1384
|
+
node_ids.add(node_id)
|
|
1258
1385
|
except Exception as e:
|
|
1259
1386
|
logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
|
|
1260
1387
|
return candidates
|
|
@@ -1408,18 +1535,19 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1408
1535
|
logger.info(f"✅ Graph Type {graph_type_name} already include {self.dim_field}")
|
|
1409
1536
|
|
|
1410
1537
|
create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
|
|
1411
|
-
set_graph_working = f"SESSION SET GRAPH `{self.db_name}`"
|
|
1412
|
-
|
|
1413
1538
|
try:
|
|
1414
1539
|
self.execute_query(create_graph, auto_set_db=False)
|
|
1415
|
-
self.execute_query(set_graph_working)
|
|
1416
1540
|
logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.")
|
|
1417
1541
|
except Exception as e:
|
|
1418
1542
|
logger.error(f"❌ Failed to create tag: {e} trace: {traceback.format_exc()}")
|
|
1419
1543
|
|
|
1420
1544
|
@timed
|
|
1421
1545
|
def _create_vector_index(
|
|
1422
|
-
self,
|
|
1546
|
+
self,
|
|
1547
|
+
label: str = "Memory",
|
|
1548
|
+
vector_property: str = "embedding",
|
|
1549
|
+
dimensions: int = 3072,
|
|
1550
|
+
index_name: str = "memory_vector_index",
|
|
1423
1551
|
) -> None:
|
|
1424
1552
|
"""
|
|
1425
1553
|
Create a vector index for the specified property in the label.
|
|
@@ -1555,6 +1683,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1555
1683
|
# Normalize embedding type
|
|
1556
1684
|
embedding = metadata.get("embedding")
|
|
1557
1685
|
if embedding and isinstance(embedding, list):
|
|
1686
|
+
metadata.pop("embedding")
|
|
1558
1687
|
metadata[self.dim_field] = _normalize([float(x) for x in embedding])
|
|
1559
1688
|
|
|
1560
1689
|
return metadata
|
|
@@ -1563,12 +1692,22 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1563
1692
|
def _format_value(self, val: Any, key: str = "") -> str:
|
|
1564
1693
|
from nebulagraph_python.py_data_types import NVector
|
|
1565
1694
|
|
|
1695
|
+
# None
|
|
1696
|
+
if val is None:
|
|
1697
|
+
return "NULL"
|
|
1698
|
+
# bool
|
|
1699
|
+
if isinstance(val, bool):
|
|
1700
|
+
return "true" if val else "false"
|
|
1701
|
+
# str
|
|
1566
1702
|
if isinstance(val, str):
|
|
1567
1703
|
return f'"{_escape_str(val)}"'
|
|
1704
|
+
# num
|
|
1568
1705
|
elif isinstance(val, (int | float)):
|
|
1569
1706
|
return str(val)
|
|
1707
|
+
# time
|
|
1570
1708
|
elif isinstance(val, datetime):
|
|
1571
1709
|
return f'datetime("{val.isoformat()}")'
|
|
1710
|
+
# list
|
|
1572
1711
|
elif isinstance(val, list):
|
|
1573
1712
|
if key == self.dim_field:
|
|
1574
1713
|
dim = len(val)
|
|
@@ -1576,13 +1715,18 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1576
1715
|
return f"VECTOR<{dim}, FLOAT>([{joined}])"
|
|
1577
1716
|
else:
|
|
1578
1717
|
return f"[{', '.join(self._format_value(v) for v in val)}]"
|
|
1718
|
+
# NVector
|
|
1579
1719
|
elif isinstance(val, NVector):
|
|
1580
1720
|
if key == self.dim_field:
|
|
1581
1721
|
dim = len(val)
|
|
1582
1722
|
joined = ",".join(str(float(x)) for x in val)
|
|
1583
1723
|
return f"VECTOR<{dim}, FLOAT>([{joined}])"
|
|
1584
|
-
|
|
1585
|
-
|
|
1724
|
+
else:
|
|
1725
|
+
logger.warning("Invalid NVector")
|
|
1726
|
+
# dict
|
|
1727
|
+
if isinstance(val, dict):
|
|
1728
|
+
j = json.dumps(val, ensure_ascii=False, separators=(",", ":"))
|
|
1729
|
+
return f'"{_escape_str(j)}"'
|
|
1586
1730
|
else:
|
|
1587
1731
|
return f'"{_escape_str(str(val))}"'
|
|
1588
1732
|
|
|
@@ -1607,6 +1751,7 @@ class NebulaGraphDB(BaseGraphDB):
|
|
|
1607
1751
|
return filtered_metadata
|
|
1608
1752
|
|
|
1609
1753
|
def _build_return_fields(self, include_embedding: bool = False) -> str:
|
|
1754
|
+
fields = set(self.base_fields)
|
|
1610
1755
|
if include_embedding:
|
|
1611
|
-
|
|
1612
|
-
return ", ".join(f"n.{
|
|
1756
|
+
fields.add(self.dim_field)
|
|
1757
|
+
return ", ".join(f"n.{f} AS {f}" for f in fields)
|