MemoryOS 1.0.0-py3-none-any.whl → 1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of MemoryOS has been flagged as potentially problematic; see the registry listing for details.
- {memoryos-1.0.0.dist-info → memoryos-1.0.1.dist-info}/METADATA +2 -1
- {memoryos-1.0.0.dist-info → memoryos-1.0.1.dist-info}/RECORD +42 -33
- memos/__init__.py +1 -1
- memos/api/config.py +25 -0
- memos/api/context/context_thread.py +96 -0
- memos/api/context/dependencies.py +0 -11
- memos/api/middleware/request_context.py +94 -0
- memos/api/product_api.py +5 -1
- memos/api/product_models.py +16 -0
- memos/api/routers/product_router.py +39 -3
- memos/api/start_api.py +3 -0
- memos/configs/memory.py +13 -0
- memos/configs/reranker.py +18 -0
- memos/graph_dbs/base.py +4 -2
- memos/graph_dbs/nebular.py +215 -68
- memos/graph_dbs/neo4j.py +14 -12
- memos/graph_dbs/neo4j_community.py +6 -3
- memos/llms/vllm.py +2 -0
- memos/log.py +120 -8
- memos/mem_os/core.py +30 -2
- memos/mem_os/product.py +386 -146
- memos/mem_os/utils/reference_utils.py +20 -0
- memos/mem_reader/simple_struct.py +112 -43
- memos/mem_user/mysql_user_manager.py +4 -2
- memos/memories/textual/item.py +1 -1
- memos/memories/textual/tree.py +31 -1
- memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +3 -1
- memos/memories/textual/tree_text_memory/retrieve/recall.py +53 -3
- memos/memories/textual/tree_text_memory/retrieve/searcher.py +74 -14
- memos/memories/textual/tree_text_memory/retrieve/utils.py +6 -4
- memos/memos_tools/notification_utils.py +46 -0
- memos/reranker/__init__.py +4 -0
- memos/reranker/base.py +24 -0
- memos/reranker/cosine_local.py +95 -0
- memos/reranker/factory.py +43 -0
- memos/reranker/http_bge.py +99 -0
- memos/reranker/noop.py +16 -0
- memos/templates/mem_reader_prompts.py +289 -40
- memos/templates/mos_prompts.py +133 -60
- {memoryos-1.0.0.dist-info → memoryos-1.0.1.dist-info}/LICENSE +0 -0
- {memoryos-1.0.0.dist-info → memoryos-1.0.1.dist-info}/WHEEL +0 -0
- {memoryos-1.0.0.dist-info → memoryos-1.0.1.dist-info}/entry_points.txt +0 -0
memos/configs/reranker.py
ADDED
@@ -0,0 +1,18 @@
+# memos/configs/reranker.py
+from __future__ import annotations
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+
+class RerankerConfigFactory(BaseModel):
+    """
+    {
+        "backend": "http_bge" | "cosine_local" | "noop",
+        "config": { ... backend-specific ... }
+    }
+    """
+
+    backend: str = Field(..., description="Reranker backend id")
+    config: dict[str, Any] = Field(default_factory=dict, description="Backend-specific options")
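The new factory is a thin Pydantic wrapper around a backend id plus a free-form options dict. A minimal construction sketch: the backend id comes from the docstring above, but the option keys ("url", "model") are illustrative assumptions, not taken from the package.

    from memos.configs.reranker import RerankerConfigFactory

    # "http_bge" is one of the backend ids listed in the docstring; the option
    # names below are hypothetical examples of backend-specific settings.
    cfg = RerankerConfigFactory(
        backend="http_bge",
        config={"url": "http://localhost:8080/rerank", "model": "bge-reranker-v2-m3"},
    )
    assert cfg.backend == "http_bge"
    assert cfg.config["model"] == "bge-reranker-v2-m3"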
memos/graph_dbs/base.py
CHANGED
@@ -81,7 +81,9 @@ class BaseGraphDB(ABC):
        """

    @abstractmethod
-    def get_nodes(
+    def get_nodes(
+        self, id: str, include_embedding: bool = False, **kwargs
+    ) -> dict[str, Any] | None:
        """
        Retrieve the metadata and memory of a list of nodes.
        Args:
@@ -141,7 +143,7 @@ class BaseGraphDB(ABC):

    # Search / recall operations
    @abstractmethod
-    def search_by_embedding(self, vector: list[float], top_k: int = 5) -> list[dict]:
+    def search_by_embedding(self, vector: list[float], top_k: int = 5, **kwargs) -> list[dict]:
        """
        Retrieve node IDs based on vector similarity.

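Both abstract signatures now tolerate extra keyword arguments; the Nebula and Neo4j implementations later in this diff use them to carry optional scoping hints such as cube_name, alongside the new include_embedding flag. A caller-side sketch (the graph_db object, node id, query vector, and cube name are placeholders):

    # graph_db is any BaseGraphDB implementation from this package; values are illustrative.
    node = graph_db.get_nodes("node-123", include_embedding=False, cube_name="team_cube")
    hits = graph_db.search_by_embedding(query_vector, top_k=5, cube_name="team_cube")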
memos/graph_dbs/nebular.py
CHANGED
@@ -1,10 +1,11 @@
+import json
 import traceback

 from contextlib import suppress
 from datetime import datetime
 from queue import Empty, Queue
 from threading import Lock
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, ClassVar, Literal

 import numpy as np

@@ -15,6 +16,10 @@ from memos.log import get_logger
 from memos.utils import timed


+if TYPE_CHECKING:
+    from nebulagraph_python.client.pool import NebulaPool
+
+
 logger = get_logger(__name__)


@@ -35,7 +40,28 @@ def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:

 @timed
 def _escape_str(value: str) -> str:
-
+    out = []
+    for ch in value:
+        code = ord(ch)
+        if ch == "\\":
+            out.append("\\\\")
+        elif ch == '"':
+            out.append('\\"')
+        elif ch == "\n":
+            out.append("\\n")
+        elif ch == "\r":
+            out.append("\\r")
+        elif ch == "\t":
+            out.append("\\t")
+        elif ch == "\b":
+            out.append("\\b")
+        elif ch == "\f":
+            out.append("\\f")
+        elif code < 0x20 or code in (0x2028, 0x2029):
+            out.append(f"\\u{code:04x}")
+        else:
+            out.append(ch)
+    return "".join(out)


 @timed
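The rewritten helper escapes backslashes, quotes, the usual control characters, and any remaining code point below 0x20 (plus U+2028/U+2029) as \uXXXX, so values can be embedded inside double-quoted GQL string literals. A few checks that follow directly from the branches above, assuming the module-level helper is importable as shown:

    from memos.graph_dbs.nebular import _escape_str

    assert _escape_str('say "hi"') == 'say \\"hi\\"'       # double quotes escaped
    assert _escape_str("line1\nline2") == "line1\\nline2"   # newline -> the two characters \n
    assert _escape_str("a\\b") == "a\\\\b"                  # backslash doubled
    assert _escape_str("\x01") == "\\u0001"                 # other control chars -> \uXXXX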
@@ -197,6 +223,94 @@ class NebulaGraphDB(BaseGraphDB):
     NebulaGraph-based implementation of a graph memory store.
     """

+    # ====== shared pool cache & refcount ======
+    # These are process-local; in a multi-process model each process will
+    # have its own cache.
+    _POOL_CACHE: ClassVar[dict[str, "NebulaPool"]] = {}
+    _POOL_REFCOUNT: ClassVar[dict[str, int]] = {}
+    _POOL_LOCK: ClassVar[Lock] = Lock()
+
+    @staticmethod
+    def _make_pool_key(cfg: NebulaGraphDBConfig) -> str:
+        """
+        Build a cache key that captures all connection-affecting options.
+        Keep this key stable and include fields that change the underlying pool behavior.
+        """
+        # NOTE: Do not include tenant-like or query-scope-only fields here.
+        # Only include things that affect the actual TCP/auth/session pool.
+        return "|".join(
+            [
+                "nebula",
+                str(getattr(cfg, "uri", "")),
+                str(getattr(cfg, "user", "")),
+                str(getattr(cfg, "password", "")),
+                # pool sizing / tls / timeouts if you have them in config:
+                str(getattr(cfg, "max_client", 1000)),
+                # multi-db mode can impact how we use sessions; keep it to be safe
+                str(getattr(cfg, "use_multi_db", False)),
+            ]
+        )
+
+    @classmethod
+    def _get_or_create_shared_pool(cls, cfg: NebulaGraphDBConfig):
+        """
+        Get a shared NebulaPool from cache or create one if missing.
+        Thread-safe with a lock; maintains a simple refcount.
+        """
+        key = cls._make_pool_key(cfg)
+
+        with cls._POOL_LOCK:
+            pool = cls._POOL_CACHE.get(key)
+            if pool is None:
+                # Create a new pool and put into cache
+                pool = SessionPool(
+                    hosts=cfg.get("uri"),
+                    user=cfg.get("user"),
+                    password=cfg.get("password"),
+                    minsize=1,
+                    maxsize=cfg.get("max_client", 1000),
+                )
+                cls._POOL_CACHE[key] = pool
+                cls._POOL_REFCOUNT[key] = 0
+                logger.info(f"[NebulaGraphDB] Created new shared NebulaPool for key={key}")
+
+            # Increase refcount for the caller
+            cls._POOL_REFCOUNT[key] = cls._POOL_REFCOUNT.get(key, 0) + 1
+            return key, pool
+
+    @classmethod
+    def _release_shared_pool(cls, key: str):
+        """
+        Decrease refcount for the given pool key; only close when refcount hits zero.
+        """
+        with cls._POOL_LOCK:
+            if key not in cls._POOL_CACHE:
+                return
+            cls._POOL_REFCOUNT[key] = max(0, cls._POOL_REFCOUNT.get(key, 0) - 1)
+            if cls._POOL_REFCOUNT[key] == 0:
+                try:
+                    cls._POOL_CACHE[key].close()
+                except Exception as e:
+                    logger.warning(f"[NebulaGraphDB] Error closing shared pool: {e}")
+                finally:
+                    cls._POOL_CACHE.pop(key, None)
+                    cls._POOL_REFCOUNT.pop(key, None)
+                    logger.info(f"[NebulaGraphDB] Closed and removed shared pool key={key}")
+
+    @classmethod
+    def close_all_shared_pools(cls):
+        """Force close all cached pools. Call this on graceful shutdown."""
+        with cls._POOL_LOCK:
+            for key, pool in list(cls._POOL_CACHE.items()):
+                try:
+                    pool.close()
+                except Exception as e:
+                    logger.warning(f"[NebulaGraphDB] Error closing pool key={key}: {e}")
+                finally:
+                    logger.info(f"[NebulaGraphDB] Closed pool key={key}")
+            cls._POOL_CACHE.clear()
+            cls._POOL_REFCOUNT.clear()
+
     @require_python_package(
         import_name="nebulagraph_python",
         install_command="pip install ... @Tianxing",
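A quick sketch of the lifecycle this cache is meant to support; config is a placeholder NebulaGraphDBConfig, and the comments restate the refcount behavior defined above:

    # Two instances built from equivalent configs resolve to the same cache key
    # and therefore share one SessionPool.
    db_a = NebulaGraphDB(config)   # cache miss: pool created, refcount -> 1
    db_b = NebulaGraphDB(config)   # cache hit:  same pool,    refcount -> 2

    db_a.close()                   # refcount -> 1, pool stays open
    db_b.close()                   # refcount -> 0, pool closed and evicted

    NebulaGraphDB.close_all_shared_pools()   # belt-and-braces cleanup at shutdown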
@@ -246,20 +360,21 @@ class NebulaGraphDB(BaseGraphDB):
             "usage",
             "background",
         }
+        self.base_fields = set(self.common_fields) - {"usage"}
+        self.heavy_fields = {"usage"}
         self.dim_field = (
             f"embedding_{self.embedding_dimension}"
             if (str(self.embedding_dimension) != str(self.default_memory_dimension))
             else "embedding"
         )
         self.system_db_name = "system" if config.use_multi_db else config.space
-        self.pool = SessionPool(
-            hosts=config.get("uri"),
-            user=config.get("user"),
-            password=config.get("password"),
-            minsize=1,
-            maxsize=config.get("max_client", 1000),
-        )

+        # ---- NEW: pool acquisition strategy
+        # Get or create a shared pool from the class-level cache
+        self._pool_key, self.pool = self._get_or_create_shared_pool(config)
+        self._owns_pool = True  # We manage refcount for this instance
+
+        # auto-create graph type / graph / index if needed
         if config.auto_create:
             self._ensure_database_exists()

@@ -271,7 +386,7 @@ class NebulaGraphDB(BaseGraphDB):
         logger.info("Connected to NebulaGraph successfully.")

     @timed
-    def execute_query(self, gql: str, timeout: float =
+    def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True):
         with self.pool.get() as client:
             try:
                 if auto_set_db and self.db_name:
@@ -287,7 +402,25 @@ class NebulaGraphDB(BaseGraphDB):

     @timed
     def close(self):
-
+        """
+        Close the connection resource if this instance owns it.
+
+        - If pool was injected (`shared_pool`), do nothing.
+        - If pool was acquired via shared cache, decrement refcount and close
+          when the last owner releases it.
+        """
+        if not self._owns_pool:
+            logger.debug("[NebulaGraphDB] close() skipped (injected pool).")
+            return
+        if self._pool_key:
+            self._release_shared_pool(self._pool_key)
+            self._pool_key = None
+            self.pool = None
+
+    # NOTE: __del__ is best-effort; do not rely on GC order.
+    def __del__(self):
+        with suppress(Exception):
+            self.close()

     @timed
     def create_index(
@@ -366,12 +499,10 @@ class NebulaGraphDB(BaseGraphDB):
             filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"'
         else:
             filter_clause = f'n.memory_type = "{scope}"'
-        return_fields = ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
-
         query = f"""
             MATCH (n@Memory)
             WHERE {filter_clause}
-            RETURN
+            RETURN n.id AS id
             LIMIT 1
         """

@@ -568,10 +699,7 @@ class NebulaGraphDB(BaseGraphDB):
         try:
             result = self.execute_query(gql)
             for row in result:
-
-                    props = row.values()[0].as_node().get_properties()
-                else:
-                    props = {k: v.value for k, v in row.items()}
+                props = {k: v.value for k, v in row.items()}
                 node = self._parse_node(props)
                 return node

@@ -582,7 +710,9 @@ class NebulaGraphDB(BaseGraphDB):
         return None

     @timed
-    def get_nodes(
+    def get_nodes(
+        self, ids: list[str], include_embedding: bool = False, **kwargs
+    ) -> list[dict[str, Any]]:
         """
         Retrieve the metadata and memory of a list of nodes.
         Args:
@@ -600,7 +730,10 @@ class NebulaGraphDB(BaseGraphDB):

         where_user = ""
         if not self.config.use_multi_db and self.config.user_name:
-
+            if kwargs.get("cube_name"):
+                where_user = f" AND n.user_name = '{kwargs['cube_name']}'"
+            else:
+                where_user = f" AND n.user_name = '{self.config.user_name}'"

         # Safe formatting of the ID list
         id_list = ",".join(f'"{_id}"' for _id in ids)
@@ -615,10 +748,7 @@ class NebulaGraphDB(BaseGraphDB):
         try:
             results = self.execute_query(query)
             for row in results:
-
-                    props = row.values()[0].as_node().get_properties()
-                else:
-                    props = {k: v.value for k, v in row.items()}
+                props = {k: v.value for k, v in row.items()}
                 nodes.append(self._parse_node(props))
         except Exception as e:
             logger.error(
@@ -687,6 +817,7 @@ class NebulaGraphDB(BaseGraphDB):
         exclude_ids: list[str],
         top_k: int = 5,
         min_overlap: int = 1,
+        include_embedding: bool = False,
     ) -> list[dict[str, Any]]:
         """
         Find top-K neighbor nodes with maximum tag overlap.
@@ -696,6 +827,7 @@ class NebulaGraphDB(BaseGraphDB):
             exclude_ids: Node IDs to exclude (e.g., local cluster).
             top_k: Max number of neighbors to return.
             min_overlap: Minimum number of overlapping tags required.
+            include_embedding: with/without embedding

         Returns:
             List of dicts with node details and overlap count.
@@ -717,12 +849,13 @@ class NebulaGraphDB(BaseGraphDB):
         where_clause = " AND ".join(where_clauses)
         tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"

+        return_fields = self._build_return_fields(include_embedding)
         query = f"""
             LET tag_list = {tag_list_literal}

             MATCH (n@Memory)
             WHERE {where_clause}
-            RETURN
+            RETURN {return_fields},
                    size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
             ORDER BY overlap_count DESC
             LIMIT {top_k}
@@ -731,9 +864,8 @@ class NebulaGraphDB(BaseGraphDB):
             result = self.execute_query(query)
             neighbors: list[dict[str, Any]] = []
             for r in result:
-
-                parsed = self._parse_node(
-
+                props = {k: v.value for k, v in r.items() if k != "overlap_count"}
+                parsed = self._parse_node(props)
                 parsed["overlap_count"] = r["overlap_count"].value
                 neighbors.append(parsed)

@@ -840,6 +972,7 @@ class NebulaGraphDB(BaseGraphDB):
         scope: str | None = None,
         status: str | None = None,
         threshold: float | None = None,
+        **kwargs,
     ) -> list[dict]:
         """
         Retrieve node IDs based on vector similarity.
@@ -874,7 +1007,10 @@ class NebulaGraphDB(BaseGraphDB):
         if status:
             where_clauses.append(f'n.status = "{status}"')
         if not self.config.use_multi_db and self.config.user_name:
-
+            if kwargs.get("cube_name"):
+                where_clauses.append(f'n.user_name = "{kwargs["cube_name"]}"')
+            else:
+                where_clauses.append(f'n.user_name = "{self.config.user_name}"')

         where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

@@ -936,20 +1072,12 @@ class NebulaGraphDB(BaseGraphDB):
         """
         where_clauses = []

-        def _escape_value(value):
-            if isinstance(value, str):
-                return f'"{value}"'
-            elif isinstance(value, list):
-                return "[" + ", ".join(_escape_value(v) for v in value) + "]"
-            else:
-                return str(value)
-
         for _i, f in enumerate(filters):
             field = f["field"]
             op = f.get("op", "=")
             value = f["value"]

-            escaped_value =
+            escaped_value = self._format_value(value)

             # Build WHERE clause
             if op == "=":
@@ -1153,28 +1281,36 @@ class NebulaGraphDB(BaseGraphDB):
             data: A dictionary containing all nodes and edges to be loaded.
         """
         for node in data.get("nodes", []):
-
+            try:
+                id, memory, metadata = _compose_node(node)

-
-
+                if not self.config.use_multi_db and self.config.user_name:
+                    metadata["user_name"] = self.config.user_name

-
-
-
-
-
+                metadata = self._prepare_node_metadata(metadata)
+                metadata.update({"id": id, "memory": memory})
+                properties = ", ".join(
+                    f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()
+                )
+                node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
+                self.execute_query(node_gql)
+            except Exception as e:
+                logger.error(f"Fail to load node: {node}, error: {e}")

         for edge in data.get("edges", []):
-
-
-
-
-
-
-
-
-
-
+            try:
+                source_id, target_id = edge["source"], edge["target"]
+                edge_type = edge["type"]
+                props = ""
+                if not self.config.use_multi_db and self.config.user_name:
+                    props = f'{{user_name: "{self.config.user_name}"}}'
+                edge_gql = f'''
+                    MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
+                    INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
+                '''
+                self.execute_query(edge_gql)
+            except Exception as e:
+                logger.error(f"Fail to load edge: {edge}, error: {e}")

     @timed
     def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> list[dict]:
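For reference, the statements this loop now sends to execute_query() look roughly like the Python strings below. The ids, property values, and the RELATE_TO edge type are made up for illustration; the real property list is produced by _prepare_node_metadata plus _format_value.

    # Illustrative only: roughly one node insert and one edge insert as rendered by the f-strings above.
    node_gql = 'INSERT OR IGNORE (n@Memory {id: "m-001", memory: "prefers dark mode", user_name: "alice"})'
    edge_gql = '''
        MATCH (a@Memory {id: "m-001"}), (b@Memory {id: "m-002"})
        INSERT OR IGNORE (a) -[e@RELATE_TO {user_name: "alice"}]-> (b)
    '''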
@@ -1208,10 +1344,7 @@ class NebulaGraphDB(BaseGraphDB):
         try:
             results = self.execute_query(query)
             for row in results:
-
-                    props = row.values()[0].as_node().get_properties()
-                else:
-                    props = {k: v.value for k, v in row.items()}
+                props = {k: v.value for k, v in row.items()}
                 nodes.append(self._parse_node(props))
         except Exception as e:
             logger.error(f"Failed to get memories: {e}")
@@ -1250,10 +1383,7 @@ class NebulaGraphDB(BaseGraphDB):
         try:
             results = self.execute_query(query)
             for row in results:
-
-                    props = row.values()[0].as_node().get_properties()
-                else:
-                    props = {k: v.value for k, v in row.items()}
+                props = {k: v.value for k, v in row.items()}
                 candidates.append(self._parse_node(props))
         except Exception as e:
             logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
@@ -1555,6 +1685,7 @@ class NebulaGraphDB(BaseGraphDB):
         # Normalize embedding type
         embedding = metadata.get("embedding")
         if embedding and isinstance(embedding, list):
+            metadata.pop("embedding")
            metadata[self.dim_field] = _normalize([float(x) for x in embedding])

         return metadata
@@ -1563,12 +1694,22 @@ class NebulaGraphDB(BaseGraphDB):
     def _format_value(self, val: Any, key: str = "") -> str:
         from nebulagraph_python.py_data_types import NVector

+        # None
+        if val is None:
+            return "NULL"
+        # bool
+        if isinstance(val, bool):
+            return "true" if val else "false"
+        # str
         if isinstance(val, str):
             return f'"{_escape_str(val)}"'
+        # num
         elif isinstance(val, (int | float)):
             return str(val)
+        # time
         elif isinstance(val, datetime):
             return f'datetime("{val.isoformat()}")'
+        # list
         elif isinstance(val, list):
             if key == self.dim_field:
                 dim = len(val)
@@ -1576,13 +1717,18 @@ class NebulaGraphDB(BaseGraphDB):
                 return f"VECTOR<{dim}, FLOAT>([{joined}])"
             else:
                 return f"[{', '.join(self._format_value(v) for v in val)}]"
+        # NVector
         elif isinstance(val, NVector):
             if key == self.dim_field:
                 dim = len(val)
                 joined = ",".join(str(float(x)) for x in val)
                 return f"VECTOR<{dim}, FLOAT>([{joined}])"
-
-
+            else:
+                logger.warning("Invalid NVector")
+        # dict
+        if isinstance(val, dict):
+            j = json.dumps(val, ensure_ascii=False, separators=(",", ":"))
+            return f'"{_escape_str(j)}"'
         else:
             return f'"{_escape_str(str(val))}"'

@@ -1607,6 +1753,7 @@ class NebulaGraphDB(BaseGraphDB):
         return filtered_metadata

     def _build_return_fields(self, include_embedding: bool = False) -> str:
+        fields = set(self.base_fields)
         if include_embedding:
-
-            return ", ".join(f"n.{
+            fields.add(self.dim_field)
+        return ", ".join(f"n.{f} AS {f}" for f in fields)
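Taken together, the extended value formatter now covers the cases below. The right-hand sides are what the branches above produce; the concrete inputs, and the assumption that dim_field is "embedding", are illustrative:

    # val                              -> _format_value(val)                    (branch)
    # None                             -> NULL                                  (new None branch)
    # True                             -> true                                  (new bool branch)
    # 'say "hi"'                       -> "say \"hi\""                          (str, via _escape_str)
    # 3.14                             -> 3.14                                  (number)
    # datetime(2024, 1, 1)             -> datetime("2024-01-01T00:00:00")       (datetime)
    # {"source": "chat"}               -> "{\"source\":\"chat\"}"               (new dict branch: JSON, then escaped)
    # [0.1, 0.2] with key="embedding"  -> VECTOR<2, FLOAT>([0.1,0.2])           (vector field)
    #
    # _build_return_fields(False) now returns "n.<field> AS <field>" for the base fields
    # (common_fields minus "usage"); passing True additionally adds the embedding column.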
memos/graph_dbs/neo4j.py
CHANGED
@@ -323,12 +323,11 @@ class Neo4jGraphDB(BaseGraphDB):
             return result.single() is not None

     # Graph Query & Reasoning
-    def get_node(self, id: str,
+    def get_node(self, id: str, **kwargs) -> dict[str, Any] | None:
        """
        Retrieve the metadata and memory of a node.
        Args:
            id: Node identifier.
-            include_embedding (bool): Whether to include the large embedding field.
        Returns:
            Dictionary of node fields, or None if not found.
        """
@@ -345,12 +344,11 @@ class Neo4jGraphDB(BaseGraphDB):
            record = session.run(query, params).single()
            return self._parse_node(dict(record["n"])) if record else None

-    def get_nodes(self, ids: list[str],
+    def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]:
        """
        Retrieve the metadata and memory of a list of nodes.
        Args:
            ids: List of Node identifier.
-            include_embedding (bool): Whether to include the large embedding field.
        Returns:
            list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.

@@ -367,7 +365,10 @@ class Neo4jGraphDB(BaseGraphDB):

        if not self.config.use_multi_db and self.config.user_name:
            where_user = " AND n.user_name = $user_name"
-
+            if kwargs.get("cube_name"):
+                params["user_name"] = kwargs["cube_name"]
+            else:
+                params["user_name"] = self.config.user_name

        query = f"MATCH (n:Memory) WHERE n.id IN $ids{where_user} RETURN n"

@@ -605,6 +606,7 @@ class Neo4jGraphDB(BaseGraphDB):
        scope: str | None = None,
        status: str | None = None,
        threshold: float | None = None,
+        **kwargs,
    ) -> list[dict]:
        """
        Retrieve node IDs based on vector similarity.
@@ -654,7 +656,10 @@ class Neo4jGraphDB(BaseGraphDB):
        if status:
            parameters["status"] = status
        if not self.config.use_multi_db and self.config.user_name:
-
+            if kwargs.get("cube_name"):
+                parameters["user_name"] = kwargs["cube_name"]
+            else:
+                parameters["user_name"] = self.config.user_name

        with self.driver.session(database=self.db_name) as session:
            result = session.run(query, parameters)
@@ -833,7 +838,7 @@ class Neo4jGraphDB(BaseGraphDB):
            logger.error(f"[ERROR] Failed to clear database '{self.db_name}': {e}")
            raise

-    def export_graph(self,
+    def export_graph(self, **kwargs) -> dict[str, Any]:
        """
        Export all graph nodes and edges in a structured form.

@@ -914,13 +919,12 @@ class Neo4jGraphDB(BaseGraphDB):
                target_id=edge["target"],
            )

-    def get_all_memory_items(self, scope: str,
+    def get_all_memory_items(self, scope: str, **kwargs) -> list[dict]:
        """
        Retrieve all memory items of a specific memory_type.

        Args:
            scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
-            include_embedding (bool): Whether to include the large embedding field.
        Returns:

        Returns:
@@ -946,9 +950,7 @@ class Neo4jGraphDB(BaseGraphDB):
            results = session.run(query, params)
            return [self._parse_node(dict(record["n"])) for record in results]

-    def get_structure_optimization_candidates(
-        self, scope: str, include_embedding: bool = True
-    ) -> list[dict]:
+    def get_structure_optimization_candidates(self, scope: str, **kwargs) -> list[dict]:
        """
        Find nodes that are likely candidates for structure optimization:
        - Isolated nodes, nodes with empty background, or nodes with exactly one child.
memos/graph_dbs/neo4j_community.py
CHANGED
@@ -129,6 +129,7 @@ class Neo4jCommunityGraphDB(Neo4jGraphDB):
        scope: str | None = None,
        status: str | None = None,
        threshold: float | None = None,
+        **kwargs,
    ) -> list[dict]:
        """
        Retrieve node IDs based on vector similarity using external vector DB.
@@ -157,7 +158,10 @@ class Neo4jCommunityGraphDB(Neo4jGraphDB):
        if status:
            vec_filter["status"] = status
        vec_filter["vector_sync"] = "success"
-
+        if kwargs.get("cube_name"):
+            vec_filter["user_name"] = kwargs["cube_name"]
+        else:
+            vec_filter["user_name"] = self.config.user_name

        # Perform vector search
        results = self.vec_db.search(query_vector=vector, top_k=top_k, filter=vec_filter)
@@ -169,13 +173,12 @@ class Neo4jCommunityGraphDB(Neo4jGraphDB):
        # Return consistent format
        return [{"id": r.id, "score": r.score} for r in results]

-    def get_all_memory_items(self, scope: str) -> list[dict]:
+    def get_all_memory_items(self, scope: str, **kwargs) -> list[dict]:
        """
        Retrieve all memory items of a specific memory_type.

        Args:
            scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
-
        Returns:
            list[dict]: Full list of memory items under this scope.
        """
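All three graph backends repeat the same precedence rule when scoping a query: an explicit cube_name keyword wins, otherwise the configured user_name is used. A hypothetical helper distilling the branch that each of these hunks adds (this function does not exist in the package; it only restates the pattern):

    # Hypothetical distillation of the repeated cube_name/user_name branch.
    def resolve_user_scope(config_user_name: str, **kwargs) -> str:
        return kwargs["cube_name"] if kwargs.get("cube_name") else config_user_name

    assert resolve_user_scope("alice") == "alice"
    assert resolve_user_scope("alice", cube_name="team_cube") == "team_cube"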
memos/llms/vllm.py
CHANGED
@@ -105,6 +105,7 @@ class VLLMLLM(BaseLLM):
            "temperature": float(getattr(self.config, "temperature", 0.8)),
            "max_tokens": int(getattr(self.config, "max_tokens", 1024)),
            "top_p": float(getattr(self.config, "top_p", 0.9)),
+            "extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
        }

        response = self.client.chat.completions.create(**completion_kwargs)
@@ -142,6 +143,7 @@ class VLLMLLM(BaseLLM):
            "max_tokens": int(getattr(self.config, "max_tokens", 1024)),
            "top_p": float(getattr(self.config, "top_p", 0.9)),
            "stream": True,  # Enable streaming
+            "extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
        }

        stream = self.client.chat.completions.create(**completion_kwargs)