MemoryOS 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of MemoryOS might be problematic. Click here for more details.

Files changed (42)
  1. {memoryos-1.0.0.dist-info → memoryos-1.0.1.dist-info}/METADATA +2 -1
  2. {memoryos-1.0.0.dist-info → memoryos-1.0.1.dist-info}/RECORD +42 -33
  3. memos/__init__.py +1 -1
  4. memos/api/config.py +25 -0
  5. memos/api/context/context_thread.py +96 -0
  6. memos/api/context/dependencies.py +0 -11
  7. memos/api/middleware/request_context.py +94 -0
  8. memos/api/product_api.py +5 -1
  9. memos/api/product_models.py +16 -0
  10. memos/api/routers/product_router.py +39 -3
  11. memos/api/start_api.py +3 -0
  12. memos/configs/memory.py +13 -0
  13. memos/configs/reranker.py +18 -0
  14. memos/graph_dbs/base.py +4 -2
  15. memos/graph_dbs/nebular.py +215 -68
  16. memos/graph_dbs/neo4j.py +14 -12
  17. memos/graph_dbs/neo4j_community.py +6 -3
  18. memos/llms/vllm.py +2 -0
  19. memos/log.py +120 -8
  20. memos/mem_os/core.py +30 -2
  21. memos/mem_os/product.py +386 -146
  22. memos/mem_os/utils/reference_utils.py +20 -0
  23. memos/mem_reader/simple_struct.py +112 -43
  24. memos/mem_user/mysql_user_manager.py +4 -2
  25. memos/memories/textual/item.py +1 -1
  26. memos/memories/textual/tree.py +31 -1
  27. memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +3 -1
  28. memos/memories/textual/tree_text_memory/retrieve/recall.py +53 -3
  29. memos/memories/textual/tree_text_memory/retrieve/searcher.py +74 -14
  30. memos/memories/textual/tree_text_memory/retrieve/utils.py +6 -4
  31. memos/memos_tools/notification_utils.py +46 -0
  32. memos/reranker/__init__.py +4 -0
  33. memos/reranker/base.py +24 -0
  34. memos/reranker/cosine_local.py +95 -0
  35. memos/reranker/factory.py +43 -0
  36. memos/reranker/http_bge.py +99 -0
  37. memos/reranker/noop.py +16 -0
  38. memos/templates/mem_reader_prompts.py +289 -40
  39. memos/templates/mos_prompts.py +133 -60
  40. {memoryos-1.0.0.dist-info → memoryos-1.0.1.dist-info}/LICENSE +0 -0
  41. {memoryos-1.0.0.dist-info → memoryos-1.0.1.dist-info}/WHEEL +0 -0
  42. {memoryos-1.0.0.dist-info → memoryos-1.0.1.dist-info}/entry_points.txt +0 -0
memos/configs/reranker.py ADDED
@@ -0,0 +1,18 @@
1
+ # memos/configs/reranker.py
2
+ from __future__ import annotations
3
+
4
+ from typing import Any
5
+
6
+ from pydantic import BaseModel, Field
7
+
8
+
9
+ class RerankerConfigFactory(BaseModel):
10
+ """
11
+ {
12
+ "backend": "http_bge" | "cosine_local" | "noop",
13
+ "config": { ... backend-specific ... }
14
+ }
15
+ """
16
+
17
+ backend: str = Field(..., description="Reranker backend id")
18
+ config: dict[str, Any] = Field(default_factory=dict, description="Backend-specific options")
memos/graph_dbs/base.py CHANGED
@@ -81,7 +81,9 @@ class BaseGraphDB(ABC):
81
81
  """
82
82
 
83
83
  @abstractmethod
84
- def get_nodes(self, id: str, include_embedding: bool = False) -> dict[str, Any] | None:
84
+ def get_nodes(
85
+ self, id: str, include_embedding: bool = False, **kwargs
86
+ ) -> dict[str, Any] | None:
85
87
  """
86
88
  Retrieve the metadata and memory of a list of nodes.
87
89
  Args:
@@ -141,7 +143,7 @@ class BaseGraphDB(ABC):
141
143
 
142
144
  # Search / recall operations
143
145
  @abstractmethod
144
- def search_by_embedding(self, vector: list[float], top_k: int = 5) -> list[dict]:
146
+ def search_by_embedding(self, vector: list[float], top_k: int = 5, **kwargs) -> list[dict]:
145
147
  """
146
148
  Retrieve node IDs based on vector similarity.
147
149
 
memos/graph_dbs/nebular.py CHANGED
@@ -1,10 +1,11 @@
1
+ import json
1
2
  import traceback
2
3
 
3
4
  from contextlib import suppress
4
5
  from datetime import datetime
5
6
  from queue import Empty, Queue
6
7
  from threading import Lock
7
- from typing import Any, Literal
8
+ from typing import TYPE_CHECKING, Any, ClassVar, Literal
8
9
 
9
10
  import numpy as np
10
11
 
@@ -15,6 +16,10 @@ from memos.log import get_logger
15
16
  from memos.utils import timed
16
17
 
17
18
 
19
+ if TYPE_CHECKING:
20
+ from nebulagraph_python.client.pool import NebulaPool
21
+
22
+
18
23
  logger = get_logger(__name__)
19
24
 
20
25
 
@@ -35,7 +40,28 @@ def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
35
40
 
36
41
  @timed
37
42
  def _escape_str(value: str) -> str:
38
- return value.replace('"', '\\"')
43
+ out = []
44
+ for ch in value:
45
+ code = ord(ch)
46
+ if ch == "\\":
47
+ out.append("\\\\")
48
+ elif ch == '"':
49
+ out.append('\\"')
50
+ elif ch == "\n":
51
+ out.append("\\n")
52
+ elif ch == "\r":
53
+ out.append("\\r")
54
+ elif ch == "\t":
55
+ out.append("\\t")
56
+ elif ch == "\b":
57
+ out.append("\\b")
58
+ elif ch == "\f":
59
+ out.append("\\f")
60
+ elif code < 0x20 or code in (0x2028, 0x2029):
61
+ out.append(f"\\u{code:04x}")
62
+ else:
63
+ out.append(ch)
64
+ return "".join(out)
39
65
 
40
66
 
41
67
  @timed
@@ -197,6 +223,94 @@ class NebulaGraphDB(BaseGraphDB):
197
223
  NebulaGraph-based implementation of a graph memory store.
198
224
  """
199
225
 
226
+ # ====== shared pool cache & refcount ======
227
+ # These are process-local; in a multi-process model each process will
228
+ # have its own cache.
229
+ _POOL_CACHE: ClassVar[dict[str, "NebulaPool"]] = {}
230
+ _POOL_REFCOUNT: ClassVar[dict[str, int]] = {}
231
+ _POOL_LOCK: ClassVar[Lock] = Lock()
232
+
233
+ @staticmethod
234
+ def _make_pool_key(cfg: NebulaGraphDBConfig) -> str:
235
+ """
236
+ Build a cache key that captures all connection-affecting options.
237
+ Keep this key stable and include fields that change the underlying pool behavior.
238
+ """
239
+ # NOTE: Do not include tenant-like or query-scope-only fields here.
240
+ # Only include things that affect the actual TCP/auth/session pool.
241
+ return "|".join(
242
+ [
243
+ "nebula",
244
+ str(getattr(cfg, "uri", "")),
245
+ str(getattr(cfg, "user", "")),
246
+ str(getattr(cfg, "password", "")),
247
+ # pool sizing / tls / timeouts if you have them in config:
248
+ str(getattr(cfg, "max_client", 1000)),
249
+ # multi-db mode can impact how we use sessions; keep it to be safe
250
+ str(getattr(cfg, "use_multi_db", False)),
251
+ ]
252
+ )
253
+
254
+ @classmethod
255
+ def _get_or_create_shared_pool(cls, cfg: NebulaGraphDBConfig):
256
+ """
257
+ Get a shared NebulaPool from cache or create one if missing.
258
+ Thread-safe with a lock; maintains a simple refcount.
259
+ """
260
+ key = cls._make_pool_key(cfg)
261
+
262
+ with cls._POOL_LOCK:
263
+ pool = cls._POOL_CACHE.get(key)
264
+ if pool is None:
265
+ # Create a new pool and put into cache
266
+ pool = SessionPool(
267
+ hosts=cfg.get("uri"),
268
+ user=cfg.get("user"),
269
+ password=cfg.get("password"),
270
+ minsize=1,
271
+ maxsize=cfg.get("max_client", 1000),
272
+ )
273
+ cls._POOL_CACHE[key] = pool
274
+ cls._POOL_REFCOUNT[key] = 0
275
+ logger.info(f"[NebulaGraphDB] Created new shared NebulaPool for key={key}")
276
+
277
+ # Increase refcount for the caller
278
+ cls._POOL_REFCOUNT[key] = cls._POOL_REFCOUNT.get(key, 0) + 1
279
+ return key, pool
280
+
281
+ @classmethod
282
+ def _release_shared_pool(cls, key: str):
283
+ """
284
+ Decrease refcount for the given pool key; only close when refcount hits zero.
285
+ """
286
+ with cls._POOL_LOCK:
287
+ if key not in cls._POOL_CACHE:
288
+ return
289
+ cls._POOL_REFCOUNT[key] = max(0, cls._POOL_REFCOUNT.get(key, 0) - 1)
290
+ if cls._POOL_REFCOUNT[key] == 0:
291
+ try:
292
+ cls._POOL_CACHE[key].close()
293
+ except Exception as e:
294
+ logger.warning(f"[NebulaGraphDB] Error closing shared pool: {e}")
295
+ finally:
296
+ cls._POOL_CACHE.pop(key, None)
297
+ cls._POOL_REFCOUNT.pop(key, None)
298
+ logger.info(f"[NebulaGraphDB] Closed and removed shared pool key={key}")
299
+
300
+ @classmethod
301
+ def close_all_shared_pools(cls):
302
+ """Force close all cached pools. Call this on graceful shutdown."""
303
+ with cls._POOL_LOCK:
304
+ for key, pool in list(cls._POOL_CACHE.items()):
305
+ try:
306
+ pool.close()
307
+ except Exception as e:
308
+ logger.warning(f"[NebulaGraphDB] Error closing pool key={key}: {e}")
309
+ finally:
310
+ logger.info(f"[NebulaGraphDB] Closed pool key={key}")
311
+ cls._POOL_CACHE.clear()
312
+ cls._POOL_REFCOUNT.clear()
313
+
200
314
  @require_python_package(
201
315
  import_name="nebulagraph_python",
202
316
  install_command="pip install ... @Tianxing",
@@ -246,20 +360,21 @@ class NebulaGraphDB(BaseGraphDB):
246
360
  "usage",
247
361
  "background",
248
362
  }
363
+ self.base_fields = set(self.common_fields) - {"usage"}
364
+ self.heavy_fields = {"usage"}
249
365
  self.dim_field = (
250
366
  f"embedding_{self.embedding_dimension}"
251
367
  if (str(self.embedding_dimension) != str(self.default_memory_dimension))
252
368
  else "embedding"
253
369
  )
254
370
  self.system_db_name = "system" if config.use_multi_db else config.space
255
- self.pool = SessionPool(
256
- hosts=config.get("uri"),
257
- user=config.get("user"),
258
- password=config.get("password"),
259
- minsize=1,
260
- maxsize=config.get("max_client", 1000),
261
- )
262
371
 
372
+ # ---- NEW: pool acquisition strategy
373
+ # Get or create a shared pool from the class-level cache
374
+ self._pool_key, self.pool = self._get_or_create_shared_pool(config)
375
+ self._owns_pool = True # We manage refcount for this instance
376
+
377
+ # auto-create graph type / graph / index if needed
263
378
  if config.auto_create:
264
379
  self._ensure_database_exists()
265
380
 
@@ -271,7 +386,7 @@ class NebulaGraphDB(BaseGraphDB):
271
386
  logger.info("Connected to NebulaGraph successfully.")
272
387
 
273
388
  @timed
274
- def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True):
389
+ def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True):
275
390
  with self.pool.get() as client:
276
391
  try:
277
392
  if auto_set_db and self.db_name:
@@ -287,7 +402,25 @@ class NebulaGraphDB(BaseGraphDB):
287
402
 
288
403
  @timed
289
404
  def close(self):
290
- self.pool.close()
405
+ """
406
+ Close the connection resource if this instance owns it.
407
+
408
+ - If pool was injected (`shared_pool`), do nothing.
409
+ - If pool was acquired via shared cache, decrement refcount and close
410
+ when the last owner releases it.
411
+ """
412
+ if not self._owns_pool:
413
+ logger.debug("[NebulaGraphDB] close() skipped (injected pool).")
414
+ return
415
+ if self._pool_key:
416
+ self._release_shared_pool(self._pool_key)
417
+ self._pool_key = None
418
+ self.pool = None
419
+
420
+ # NOTE: __del__ is best-effort; do not rely on GC order.
421
+ def __del__(self):
422
+ with suppress(Exception):
423
+ self.close()
291
424
 
292
425
  @timed
293
426
  def create_index(
@@ -366,12 +499,10 @@ class NebulaGraphDB(BaseGraphDB):
366
499
  filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"'
367
500
  else:
368
501
  filter_clause = f'n.memory_type = "{scope}"'
369
- return_fields = ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
370
-
371
502
  query = f"""
372
503
  MATCH (n@Memory)
373
504
  WHERE {filter_clause}
374
- RETURN {return_fields}
505
+ RETURN n.id AS id
375
506
  LIMIT 1
376
507
  """
377
508
 
@@ -568,10 +699,7 @@ class NebulaGraphDB(BaseGraphDB):
568
699
  try:
569
700
  result = self.execute_query(gql)
570
701
  for row in result:
571
- if include_embedding:
572
- props = row.values()[0].as_node().get_properties()
573
- else:
574
- props = {k: v.value for k, v in row.items()}
702
+ props = {k: v.value for k, v in row.items()}
575
703
  node = self._parse_node(props)
576
704
  return node
577
705
 
@@ -582,7 +710,9 @@ class NebulaGraphDB(BaseGraphDB):
582
710
  return None
583
711
 
584
712
  @timed
585
- def get_nodes(self, ids: list[str], include_embedding: bool = False) -> list[dict[str, Any]]:
713
+ def get_nodes(
714
+ self, ids: list[str], include_embedding: bool = False, **kwargs
715
+ ) -> list[dict[str, Any]]:
586
716
  """
587
717
  Retrieve the metadata and memory of a list of nodes.
588
718
  Args:
@@ -600,7 +730,10 @@ class NebulaGraphDB(BaseGraphDB):
600
730
 
601
731
  where_user = ""
602
732
  if not self.config.use_multi_db and self.config.user_name:
603
- where_user = f" AND n.user_name = '{self.config.user_name}'"
733
+ if kwargs.get("cube_name"):
734
+ where_user = f" AND n.user_name = '{kwargs['cube_name']}'"
735
+ else:
736
+ where_user = f" AND n.user_name = '{self.config.user_name}'"
604
737
 
605
738
  # Safe formatting of the ID list
606
739
  id_list = ",".join(f'"{_id}"' for _id in ids)
@@ -615,10 +748,7 @@ class NebulaGraphDB(BaseGraphDB):
615
748
  try:
616
749
  results = self.execute_query(query)
617
750
  for row in results:
618
- if include_embedding:
619
- props = row.values()[0].as_node().get_properties()
620
- else:
621
- props = {k: v.value for k, v in row.items()}
751
+ props = {k: v.value for k, v in row.items()}
622
752
  nodes.append(self._parse_node(props))
623
753
  except Exception as e:
624
754
  logger.error(
@@ -687,6 +817,7 @@ class NebulaGraphDB(BaseGraphDB):
687
817
  exclude_ids: list[str],
688
818
  top_k: int = 5,
689
819
  min_overlap: int = 1,
820
+ include_embedding: bool = False,
690
821
  ) -> list[dict[str, Any]]:
691
822
  """
692
823
  Find top-K neighbor nodes with maximum tag overlap.
@@ -696,6 +827,7 @@ class NebulaGraphDB(BaseGraphDB):
696
827
  exclude_ids: Node IDs to exclude (e.g., local cluster).
697
828
  top_k: Max number of neighbors to return.
698
829
  min_overlap: Minimum number of overlapping tags required.
830
+ include_embedding: with/without embedding
699
831
 
700
832
  Returns:
701
833
  List of dicts with node details and overlap count.
@@ -717,12 +849,13 @@ class NebulaGraphDB(BaseGraphDB):
717
849
  where_clause = " AND ".join(where_clauses)
718
850
  tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"
719
851
 
852
+ return_fields = self._build_return_fields(include_embedding)
720
853
  query = f"""
721
854
  LET tag_list = {tag_list_literal}
722
855
 
723
856
  MATCH (n@Memory)
724
857
  WHERE {where_clause}
725
- RETURN n,
858
+ RETURN {return_fields},
726
859
  size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
727
860
  ORDER BY overlap_count DESC
728
861
  LIMIT {top_k}
@@ -731,9 +864,8 @@ class NebulaGraphDB(BaseGraphDB):
731
864
  result = self.execute_query(query)
732
865
  neighbors: list[dict[str, Any]] = []
733
866
  for r in result:
734
- node_props = r["n"].as_node().get_properties()
735
- parsed = self._parse_node(node_props) # --> {id, memory, metadata}
736
-
867
+ props = {k: v.value for k, v in r.items() if k != "overlap_count"}
868
+ parsed = self._parse_node(props)
737
869
  parsed["overlap_count"] = r["overlap_count"].value
738
870
  neighbors.append(parsed)
739
871
 
@@ -840,6 +972,7 @@ class NebulaGraphDB(BaseGraphDB):
840
972
  scope: str | None = None,
841
973
  status: str | None = None,
842
974
  threshold: float | None = None,
975
+ **kwargs,
843
976
  ) -> list[dict]:
844
977
  """
845
978
  Retrieve node IDs based on vector similarity.
@@ -874,7 +1007,10 @@ class NebulaGraphDB(BaseGraphDB):
874
1007
  if status:
875
1008
  where_clauses.append(f'n.status = "{status}"')
876
1009
  if not self.config.use_multi_db and self.config.user_name:
877
- where_clauses.append(f'n.user_name = "{self.config.user_name}"')
1010
+ if kwargs.get("cube_name"):
1011
+ where_clauses.append(f'n.user_name = "{kwargs["cube_name"]}"')
1012
+ else:
1013
+ where_clauses.append(f'n.user_name = "{self.config.user_name}"')
878
1014
 
879
1015
  where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
880
1016
 
@@ -936,20 +1072,12 @@ class NebulaGraphDB(BaseGraphDB):
936
1072
  """
937
1073
  where_clauses = []
938
1074
 
939
- def _escape_value(value):
940
- if isinstance(value, str):
941
- return f'"{value}"'
942
- elif isinstance(value, list):
943
- return "[" + ", ".join(_escape_value(v) for v in value) + "]"
944
- else:
945
- return str(value)
946
-
947
1075
  for _i, f in enumerate(filters):
948
1076
  field = f["field"]
949
1077
  op = f.get("op", "=")
950
1078
  value = f["value"]
951
1079
 
952
- escaped_value = _escape_value(value)
1080
+ escaped_value = self._format_value(value)
953
1081
 
954
1082
  # Build WHERE clause
955
1083
  if op == "=":
@@ -1153,28 +1281,36 @@ class NebulaGraphDB(BaseGraphDB):
1153
1281
  data: A dictionary containing all nodes and edges to be loaded.
1154
1282
  """
1155
1283
  for node in data.get("nodes", []):
1156
- id, memory, metadata = _compose_node(node)
1284
+ try:
1285
+ id, memory, metadata = _compose_node(node)
1157
1286
 
1158
- if not self.config.use_multi_db and self.config.user_name:
1159
- metadata["user_name"] = self.config.user_name
1287
+ if not self.config.use_multi_db and self.config.user_name:
1288
+ metadata["user_name"] = self.config.user_name
1160
1289
 
1161
- metadata = self._prepare_node_metadata(metadata)
1162
- metadata.update({"id": id, "memory": memory})
1163
- properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items())
1164
- node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
1165
- self.execute_query(node_gql)
1290
+ metadata = self._prepare_node_metadata(metadata)
1291
+ metadata.update({"id": id, "memory": memory})
1292
+ properties = ", ".join(
1293
+ f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()
1294
+ )
1295
+ node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
1296
+ self.execute_query(node_gql)
1297
+ except Exception as e:
1298
+ logger.error(f"Fail to load node: {node}, error: {e}")
1166
1299
 
1167
1300
  for edge in data.get("edges", []):
1168
- source_id, target_id = edge["source"], edge["target"]
1169
- edge_type = edge["type"]
1170
- props = ""
1171
- if not self.config.use_multi_db and self.config.user_name:
1172
- props = f'{{user_name: "{self.config.user_name}"}}'
1173
- edge_gql = f'''
1174
- MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
1175
- INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
1176
- '''
1177
- self.execute_query(edge_gql)
1301
+ try:
1302
+ source_id, target_id = edge["source"], edge["target"]
1303
+ edge_type = edge["type"]
1304
+ props = ""
1305
+ if not self.config.use_multi_db and self.config.user_name:
1306
+ props = f'{{user_name: "{self.config.user_name}"}}'
1307
+ edge_gql = f'''
1308
+ MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
1309
+ INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
1310
+ '''
1311
+ self.execute_query(edge_gql)
1312
+ except Exception as e:
1313
+ logger.error(f"Fail to load edge: {edge}, error: {e}")
1178
1314
 
1179
1315
  @timed
1180
1316
  def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (list)[dict]:
@@ -1208,10 +1344,7 @@ class NebulaGraphDB(BaseGraphDB):
1208
1344
  try:
1209
1345
  results = self.execute_query(query)
1210
1346
  for row in results:
1211
- if include_embedding:
1212
- props = row.values()[0].as_node().get_properties()
1213
- else:
1214
- props = {k: v.value for k, v in row.items()}
1347
+ props = {k: v.value for k, v in row.items()}
1215
1348
  nodes.append(self._parse_node(props))
1216
1349
  except Exception as e:
1217
1350
  logger.error(f"Failed to get memories: {e}")
@@ -1250,10 +1383,7 @@ class NebulaGraphDB(BaseGraphDB):
1250
1383
  try:
1251
1384
  results = self.execute_query(query)
1252
1385
  for row in results:
1253
- if include_embedding:
1254
- props = row.values()[0].as_node().get_properties()
1255
- else:
1256
- props = {k: v.value for k, v in row.items()}
1386
+ props = {k: v.value for k, v in row.items()}
1257
1387
  candidates.append(self._parse_node(props))
1258
1388
  except Exception as e:
1259
1389
  logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
@@ -1555,6 +1685,7 @@ class NebulaGraphDB(BaseGraphDB):
1555
1685
  # Normalize embedding type
1556
1686
  embedding = metadata.get("embedding")
1557
1687
  if embedding and isinstance(embedding, list):
1688
+ metadata.pop("embedding")
1558
1689
  metadata[self.dim_field] = _normalize([float(x) for x in embedding])
1559
1690
 
1560
1691
  return metadata
@@ -1563,12 +1694,22 @@ class NebulaGraphDB(BaseGraphDB):
1563
1694
  def _format_value(self, val: Any, key: str = "") -> str:
1564
1695
  from nebulagraph_python.py_data_types import NVector
1565
1696
 
1697
+ # None
1698
+ if val is None:
1699
+ return "NULL"
1700
+ # bool
1701
+ if isinstance(val, bool):
1702
+ return "true" if val else "false"
1703
+ # str
1566
1704
  if isinstance(val, str):
1567
1705
  return f'"{_escape_str(val)}"'
1706
+ # num
1568
1707
  elif isinstance(val, (int | float)):
1569
1708
  return str(val)
1709
+ # time
1570
1710
  elif isinstance(val, datetime):
1571
1711
  return f'datetime("{val.isoformat()}")'
1712
+ # list
1572
1713
  elif isinstance(val, list):
1573
1714
  if key == self.dim_field:
1574
1715
  dim = len(val)
@@ -1576,13 +1717,18 @@ class NebulaGraphDB(BaseGraphDB):
1576
1717
  return f"VECTOR<{dim}, FLOAT>([{joined}])"
1577
1718
  else:
1578
1719
  return f"[{', '.join(self._format_value(v) for v in val)}]"
1720
+ # NVector
1579
1721
  elif isinstance(val, NVector):
1580
1722
  if key == self.dim_field:
1581
1723
  dim = len(val)
1582
1724
  joined = ",".join(str(float(x)) for x in val)
1583
1725
  return f"VECTOR<{dim}, FLOAT>([{joined}])"
1584
- elif val is None:
1585
- return "NULL"
1726
+ else:
1727
+ logger.warning("Invalid NVector")
1728
+ # dict
1729
+ if isinstance(val, dict):
1730
+ j = json.dumps(val, ensure_ascii=False, separators=(",", ":"))
1731
+ return f'"{_escape_str(j)}"'
1586
1732
  else:
1587
1733
  return f'"{_escape_str(str(val))}"'
1588
1734
 
@@ -1607,6 +1753,7 @@ class NebulaGraphDB(BaseGraphDB):
1607
1753
  return filtered_metadata
1608
1754
 
1609
1755
  def _build_return_fields(self, include_embedding: bool = False) -> str:
1756
+ fields = set(self.base_fields)
1610
1757
  if include_embedding:
1611
- return "n"
1612
- return ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
1758
+ fields.add(self.dim_field)
1759
+ return ", ".join(f"n.{f} AS {f}" for f in fields)
memos/graph_dbs/neo4j.py CHANGED
@@ -323,12 +323,11 @@ class Neo4jGraphDB(BaseGraphDB):
323
323
  return result.single() is not None
324
324
 
325
325
  # Graph Query & Reasoning
326
- def get_node(self, id: str, include_embedding: bool = True) -> dict[str, Any] | None:
326
+ def get_node(self, id: str, **kwargs) -> dict[str, Any] | None:
327
327
  """
328
328
  Retrieve the metadata and memory of a node.
329
329
  Args:
330
330
  id: Node identifier.
331
- include_embedding (bool): Whether to include the large embedding field.
332
331
  Returns:
333
332
  Dictionary of node fields, or None if not found.
334
333
  """
@@ -345,12 +344,11 @@ class Neo4jGraphDB(BaseGraphDB):
345
344
  record = session.run(query, params).single()
346
345
  return self._parse_node(dict(record["n"])) if record else None
347
346
 
348
- def get_nodes(self, ids: list[str], include_embedding: bool = True) -> list[dict[str, Any]]:
347
+ def get_nodes(self, ids: list[str], **kwargs) -> list[dict[str, Any]]:
349
348
  """
350
349
  Retrieve the metadata and memory of a list of nodes.
351
350
  Args:
352
351
  ids: List of Node identifier.
353
- include_embedding (bool): Whether to include the large embedding field.
354
352
  Returns:
355
353
  list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.
356
354
 
@@ -367,7 +365,10 @@ class Neo4jGraphDB(BaseGraphDB):
367
365
 
368
366
  if not self.config.use_multi_db and self.config.user_name:
369
367
  where_user = " AND n.user_name = $user_name"
370
- params["user_name"] = self.config.user_name
368
+ if kwargs.get("cube_name"):
369
+ params["user_name"] = kwargs["cube_name"]
370
+ else:
371
+ params["user_name"] = self.config.user_name
371
372
 
372
373
  query = f"MATCH (n:Memory) WHERE n.id IN $ids{where_user} RETURN n"
373
374
 
@@ -605,6 +606,7 @@ class Neo4jGraphDB(BaseGraphDB):
605
606
  scope: str | None = None,
606
607
  status: str | None = None,
607
608
  threshold: float | None = None,
609
+ **kwargs,
608
610
  ) -> list[dict]:
609
611
  """
610
612
  Retrieve node IDs based on vector similarity.
@@ -654,7 +656,10 @@ class Neo4jGraphDB(BaseGraphDB):
654
656
  if status:
655
657
  parameters["status"] = status
656
658
  if not self.config.use_multi_db and self.config.user_name:
657
- parameters["user_name"] = self.config.user_name
659
+ if kwargs.get("cube_name"):
660
+ parameters["user_name"] = kwargs["cube_name"]
661
+ else:
662
+ parameters["user_name"] = self.config.user_name
658
663
 
659
664
  with self.driver.session(database=self.db_name) as session:
660
665
  result = session.run(query, parameters)
@@ -833,7 +838,7 @@ class Neo4jGraphDB(BaseGraphDB):
833
838
  logger.error(f"[ERROR] Failed to clear database '{self.db_name}': {e}")
834
839
  raise
835
840
 
836
- def export_graph(self, include_embedding: bool = True) -> dict[str, Any]:
841
+ def export_graph(self, **kwargs) -> dict[str, Any]:
837
842
  """
838
843
  Export all graph nodes and edges in a structured form.
839
844
 
@@ -914,13 +919,12 @@ class Neo4jGraphDB(BaseGraphDB):
914
919
  target_id=edge["target"],
915
920
  )
916
921
 
917
- def get_all_memory_items(self, scope: str, include_embedding: bool = True) -> list[dict]:
922
+ def get_all_memory_items(self, scope: str, **kwargs) -> list[dict]:
918
923
  """
919
924
  Retrieve all memory items of a specific memory_type.
920
925
 
921
926
  Args:
922
927
  scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
923
- include_embedding (bool): Whether to include the large embedding field.
924
928
  Returns:
925
929
 
926
930
  Returns:
@@ -946,9 +950,7 @@ class Neo4jGraphDB(BaseGraphDB):
946
950
  results = session.run(query, params)
947
951
  return [self._parse_node(dict(record["n"])) for record in results]
948
952
 
949
- def get_structure_optimization_candidates(
950
- self, scope: str, include_embedding: bool = True
951
- ) -> list[dict]:
953
+ def get_structure_optimization_candidates(self, scope: str, **kwargs) -> list[dict]:
952
954
  """
953
955
  Find nodes that are likely candidates for structure optimization:
954
956
  - Isolated nodes, nodes with empty background, or nodes with exactly one child.
memos/graph_dbs/neo4j_community.py CHANGED
@@ -129,6 +129,7 @@ class Neo4jCommunityGraphDB(Neo4jGraphDB):
129
129
  scope: str | None = None,
130
130
  status: str | None = None,
131
131
  threshold: float | None = None,
132
+ **kwargs,
132
133
  ) -> list[dict]:
133
134
  """
134
135
  Retrieve node IDs based on vector similarity using external vector DB.
@@ -157,7 +158,10 @@ class Neo4jCommunityGraphDB(Neo4jGraphDB):
157
158
  if status:
158
159
  vec_filter["status"] = status
159
160
  vec_filter["vector_sync"] = "success"
160
- vec_filter["user_name"] = self.config.user_name
161
+ if kwargs.get("cube_name"):
162
+ vec_filter["user_name"] = kwargs["cube_name"]
163
+ else:
164
+ vec_filter["user_name"] = self.config.user_name
161
165
 
162
166
  # Perform vector search
163
167
  results = self.vec_db.search(query_vector=vector, top_k=top_k, filter=vec_filter)
@@ -169,13 +173,12 @@ class Neo4jCommunityGraphDB(Neo4jGraphDB):
169
173
  # Return consistent format
170
174
  return [{"id": r.id, "score": r.score} for r in results]
171
175
 
172
- def get_all_memory_items(self, scope: str) -> list[dict]:
176
+ def get_all_memory_items(self, scope: str, **kwargs) -> list[dict]:
173
177
  """
174
178
  Retrieve all memory items of a specific memory_type.
175
179
 
176
180
  Args:
177
181
  scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
178
-
179
182
  Returns:
180
183
  list[dict]: Full list of memory items under this scope.
181
184
  """
memos/llms/vllm.py CHANGED
@@ -105,6 +105,7 @@ class VLLMLLM(BaseLLM):
105
105
  "temperature": float(getattr(self.config, "temperature", 0.8)),
106
106
  "max_tokens": int(getattr(self.config, "max_tokens", 1024)),
107
107
  "top_p": float(getattr(self.config, "top_p", 0.9)),
108
+ "extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
108
109
  }
109
110
 
110
111
  response = self.client.chat.completions.create(**completion_kwargs)
@@ -142,6 +143,7 @@ class VLLMLLM(BaseLLM):
142
143
  "max_tokens": int(getattr(self.config, "max_tokens", 1024)),
143
144
  "top_p": float(getattr(self.config, "top_p", 0.9)),
144
145
  "stream": True, # Enable streaming
146
+ "extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
145
147
  }
146
148
 
147
149
  stream = self.client.chat.completions.create(**completion_kwargs)