MemoryOS 0.0.1__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of MemoryOS might be problematic.

Files changed (124)
  1. memoryos-0.1.13.dist-info/METADATA +288 -0
  2. memoryos-0.1.13.dist-info/RECORD +122 -0
  3. memos/__init__.py +20 -1
  4. memos/api/start_api.py +420 -0
  5. memos/chunkers/__init__.py +4 -0
  6. memos/chunkers/base.py +24 -0
  7. memos/chunkers/factory.py +22 -0
  8. memos/chunkers/sentence_chunker.py +35 -0
  9. memos/configs/__init__.py +0 -0
  10. memos/configs/base.py +82 -0
  11. memos/configs/chunker.py +45 -0
  12. memos/configs/embedder.py +53 -0
  13. memos/configs/graph_db.py +45 -0
  14. memos/configs/internet_retriever.py +81 -0
  15. memos/configs/llm.py +71 -0
  16. memos/configs/mem_chat.py +81 -0
  17. memos/configs/mem_cube.py +89 -0
  18. memos/configs/mem_os.py +74 -0
  19. memos/configs/mem_reader.py +53 -0
  20. memos/configs/mem_scheduler.py +78 -0
  21. memos/configs/memory.py +195 -0
  22. memos/configs/parser.py +38 -0
  23. memos/configs/utils.py +8 -0
  24. memos/configs/vec_db.py +64 -0
  25. memos/deprecation.py +262 -0
  26. memos/embedders/__init__.py +0 -0
  27. memos/embedders/base.py +15 -0
  28. memos/embedders/factory.py +23 -0
  29. memos/embedders/ollama.py +74 -0
  30. memos/embedders/sentence_transformer.py +40 -0
  31. memos/exceptions.py +30 -0
  32. memos/graph_dbs/__init__.py +0 -0
  33. memos/graph_dbs/base.py +215 -0
  34. memos/graph_dbs/factory.py +21 -0
  35. memos/graph_dbs/neo4j.py +827 -0
  36. memos/hello_world.py +97 -0
  37. memos/llms/__init__.py +0 -0
  38. memos/llms/base.py +16 -0
  39. memos/llms/factory.py +25 -0
  40. memos/llms/hf.py +231 -0
  41. memos/llms/ollama.py +82 -0
  42. memos/llms/openai.py +34 -0
  43. memos/llms/utils.py +14 -0
  44. memos/log.py +78 -0
  45. memos/mem_chat/__init__.py +0 -0
  46. memos/mem_chat/base.py +30 -0
  47. memos/mem_chat/factory.py +21 -0
  48. memos/mem_chat/simple.py +200 -0
  49. memos/mem_cube/__init__.py +0 -0
  50. memos/mem_cube/base.py +29 -0
  51. memos/mem_cube/general.py +146 -0
  52. memos/mem_cube/utils.py +24 -0
  53. memos/mem_os/client.py +5 -0
  54. memos/mem_os/core.py +819 -0
  55. memos/mem_os/main.py +503 -0
  56. memos/mem_os/product.py +89 -0
  57. memos/mem_reader/__init__.py +0 -0
  58. memos/mem_reader/base.py +27 -0
  59. memos/mem_reader/factory.py +21 -0
  60. memos/mem_reader/memory.py +298 -0
  61. memos/mem_reader/simple_struct.py +241 -0
  62. memos/mem_scheduler/__init__.py +0 -0
  63. memos/mem_scheduler/base_scheduler.py +164 -0
  64. memos/mem_scheduler/general_scheduler.py +305 -0
  65. memos/mem_scheduler/modules/__init__.py +0 -0
  66. memos/mem_scheduler/modules/base.py +74 -0
  67. memos/mem_scheduler/modules/dispatcher.py +103 -0
  68. memos/mem_scheduler/modules/monitor.py +82 -0
  69. memos/mem_scheduler/modules/redis_service.py +146 -0
  70. memos/mem_scheduler/modules/retriever.py +41 -0
  71. memos/mem_scheduler/modules/schemas.py +146 -0
  72. memos/mem_scheduler/scheduler_factory.py +21 -0
  73. memos/mem_scheduler/utils.py +26 -0
  74. memos/mem_user/user_manager.py +488 -0
  75. memos/memories/__init__.py +0 -0
  76. memos/memories/activation/__init__.py +0 -0
  77. memos/memories/activation/base.py +42 -0
  78. memos/memories/activation/item.py +25 -0
  79. memos/memories/activation/kv.py +232 -0
  80. memos/memories/base.py +19 -0
  81. memos/memories/factory.py +34 -0
  82. memos/memories/parametric/__init__.py +0 -0
  83. memos/memories/parametric/base.py +19 -0
  84. memos/memories/parametric/item.py +11 -0
  85. memos/memories/parametric/lora.py +41 -0
  86. memos/memories/textual/__init__.py +0 -0
  87. memos/memories/textual/base.py +89 -0
  88. memos/memories/textual/general.py +286 -0
  89. memos/memories/textual/item.py +167 -0
  90. memos/memories/textual/naive.py +185 -0
  91. memos/memories/textual/tree.py +321 -0
  92. memos/memories/textual/tree_text_memory/__init__.py +0 -0
  93. memos/memories/textual/tree_text_memory/organize/__init__.py +0 -0
  94. memos/memories/textual/tree_text_memory/organize/manager.py +305 -0
  95. memos/memories/textual/tree_text_memory/retrieve/__init__.py +0 -0
  96. memos/memories/textual/tree_text_memory/retrieve/internet_retriever.py +263 -0
  97. memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +89 -0
  98. memos/memories/textual/tree_text_memory/retrieve/reasoner.py +61 -0
  99. memos/memories/textual/tree_text_memory/retrieve/recall.py +158 -0
  100. memos/memories/textual/tree_text_memory/retrieve/reranker.py +111 -0
  101. memos/memories/textual/tree_text_memory/retrieve/retrieval_mid_structs.py +13 -0
  102. memos/memories/textual/tree_text_memory/retrieve/searcher.py +208 -0
  103. memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +68 -0
  104. memos/memories/textual/tree_text_memory/retrieve/utils.py +48 -0
  105. memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py +335 -0
  106. memos/parsers/__init__.py +0 -0
  107. memos/parsers/base.py +15 -0
  108. memos/parsers/factory.py +19 -0
  109. memos/parsers/markitdown.py +22 -0
  110. memos/settings.py +8 -0
  111. memos/templates/__init__.py +0 -0
  112. memos/templates/mem_reader_prompts.py +98 -0
  113. memos/templates/mem_scheduler_prompts.py +65 -0
  114. memos/templates/mos_prompts.py +63 -0
  115. memos/types.py +55 -0
  116. memos/vec_dbs/__init__.py +0 -0
  117. memos/vec_dbs/base.py +105 -0
  118. memos/vec_dbs/factory.py +21 -0
  119. memos/vec_dbs/item.py +43 -0
  120. memos/vec_dbs/qdrant.py +292 -0
  121. memoryos-0.0.1.dist-info/METADATA +0 -53
  122. memoryos-0.0.1.dist-info/RECORD +0 -5
  123. {memoryos-0.0.1.dist-info → memoryos-0.1.13.dist-info}/LICENSE +0 -0
  124. {memoryos-0.0.1.dist-info → memoryos-0.1.13.dist-info}/WHEEL +0 -0
memos/graph_dbs/neo4j.py
@@ -0,0 +1,827 @@
+ import time
+
+ from datetime import datetime
+ from typing import Any, Literal
+
+ from neo4j import GraphDatabase
+
+ from memos.configs.graph_db import Neo4jGraphDBConfig
+ from memos.graph_dbs.base import BaseGraphDB
+ from memos.log import get_logger
+
+
+ logger = get_logger(__name__)
+
+
+ def _parse_node(node_data: dict[str, Any]) -> dict[str, Any]:
+     node = node_data.copy()
+
+     # Convert Neo4j datetime to string
+     for time_field in ("created_at", "updated_at"):
+         if time_field in node and hasattr(node[time_field], "isoformat"):
+             node[time_field] = node[time_field].isoformat()
+
+     return {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node}
+
+
+ def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
+     node_id = item["id"]
+     memory = item["memory"]
+     metadata = item.get("metadata", {})
+     return node_id, memory, metadata
+
+
+ def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
+     """
+     Ensure metadata has proper datetime fields and normalized types.
+
+     - Fill `created_at` and `updated_at` if missing (in ISO 8601 format).
+     - Convert embedding to list of float if present.
+     """
+     now = datetime.utcnow().isoformat()
+
+     # Fill timestamps if missing
+     metadata.setdefault("created_at", now)
+     metadata.setdefault("updated_at", now)
+
+     # Normalize embedding type
+     embedding = metadata.get("embedding")
+     if embedding and isinstance(embedding, list):
+         metadata["embedding"] = [float(x) for x in embedding]
+
+     return metadata
+
+
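As a quick illustration of these helpers (annotation, not part of the diff): `_prepare_node_metadata` back-fills timestamps and coerces embedding values to float, and `_parse_node` re-nests everything except `id` and `memory` under `metadata` on the way out. The values below are hypothetical:

    meta = _prepare_node_metadata({"embedding": [1, 2], "memory_type": "WorkingMemory"})
    # -> {"embedding": [1.0, 2.0], "memory_type": "WorkingMemory",
    #     "created_at": "<now>", "updated_at": "<now>"}
    _parse_node({"id": "n1", "memory": "hello", **meta})
    # -> {"id": "n1", "memory": "hello", "metadata": {"embedding": [1.0, 2.0], ...}}
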
+ class Neo4jGraphDB(BaseGraphDB):
+     """Neo4j-based implementation of a graph memory store."""
+
+     def __init__(self, config: Neo4jGraphDBConfig):
+         self.config = config
+         self.driver = GraphDatabase.driver(config.uri, auth=(config.user, config.password))
+         self.db_name = config.db_name
+
+         if config.auto_create:
+             self._ensure_database_exists()
+
+         # Create only if not exists
+         self.create_index(dimensions=config.embedding_dimension)
+
+     def create_index(
+         self,
+         label: str = "Memory",
+         vector_property: str = "embedding",
+         dimensions: int = 1536,
+         index_name: str = "memory_vector_index",
+     ) -> None:
+         """
+         Create the vector index for embedding and datetime indexes for created_at and updated_at fields.
+         """
+         # Create vector index if it doesn't exist
+         if not self._vector_index_exists(index_name):
+             self._create_vector_index(label, vector_property, dimensions, index_name)
+         # Create indexes
+         self._create_basic_property_indexes()
+
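For orientation, constructing the store looks roughly like the sketch below. The field names follow the constructor above; the exact `Neo4jGraphDBConfig` schema is defined in memos/configs/graph_db.py and may differ, so treat the values as assumptions:

    from memos.configs.graph_db import Neo4jGraphDBConfig
    from memos.graph_dbs.neo4j import Neo4jGraphDB

    config = Neo4jGraphDBConfig(
        uri="bolt://localhost:7687",   # assumed local instance
        user="neo4j",
        password="password",
        db_name="memos",
        auto_create=True,              # create the database if missing
        embedding_dimension=1536,      # must match the embedder's output size
    )
    db = Neo4jGraphDB(config)          # also ensures vector + property indexes
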
+     def get_memory_count(self, memory_type: str) -> int:
+         query = """
+             MATCH (n:Memory)
+             WHERE n.memory_type = $memory_type
+             RETURN COUNT(n) AS count
+         """
+         with self.driver.session(database=self.db_name) as session:
+             result = session.run(query, memory_type=memory_type)
+             return result.single()["count"]
+
+     def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None:
+         """
+         Remove all nodes of the given memory type except the latest `keep_latest` entries.
+
+         Args:
+             memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
+             keep_latest (int): Number of latest entries to keep.
+         """
+         query = f"""
+             MATCH (n:Memory)
+             WHERE n.memory_type = '{memory_type}'
+             WITH n ORDER BY n.updated_at DESC
+             SKIP {keep_latest}
+             DETACH DELETE n
+         """
+         with self.driver.session(database=self.db_name) as session:
+             session.run(query)
+
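Note that `remove_oldest_memory` interpolates `memory_type` and `keep_latest` straight into the query text. Both can be ordinary query parameters (Cypher accepts a parameter for SKIP), so a safer sketch of the same statement, assuming a `session` opened as in the surrounding methods, would be:

    query = """
        MATCH (n:Memory)
        WHERE n.memory_type = $memory_type
        WITH n ORDER BY n.updated_at DESC
        SKIP $keep_latest
        DETACH DELETE n
    """
    session.run(query, memory_type=memory_type, keep_latest=keep_latest)
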
+     def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
+         # Safely process metadata
+         metadata = _prepare_node_metadata(metadata)
+
+         # Merge node and set metadata
+         created_at = metadata.pop("created_at")
+         updated_at = metadata.pop("updated_at")
+
+         query = """
+             MERGE (n:Memory {id: $id})
+             SET n.memory = $memory,
+                 n.created_at = datetime($created_at),
+                 n.updated_at = datetime($updated_at),
+                 n += $metadata
+         """
+         with self.driver.session(database=self.db_name) as session:
+             session.run(
+                 query,
+                 id=id,
+                 memory=memory,
+                 created_at=created_at,
+                 updated_at=updated_at,
+                 metadata=metadata,
+             )
+
+     def update_node(self, id: str, fields: dict[str, Any]) -> None:
+         """
+         Update node fields in Neo4j, auto-converting `created_at` and `updated_at` to datetime type if present.
+         """
+         fields = fields.copy()  # Avoid mutating external dict
+         set_clauses = []
+         params = {"id": id, "fields": fields}
+
+         for time_field in ("created_at", "updated_at"):
+             if time_field in fields:
+                 # Set clause like: n.created_at = datetime($created_at)
+                 set_clauses.append(f"n.{time_field} = datetime(${time_field})")
+                 params[time_field] = fields.pop(time_field)
+
+         set_clauses.append("n += $fields")  # Merge remaining fields
+         set_clause_str = ",\n    ".join(set_clauses)
+
+         query = f"""
+             MATCH (n:Memory {{id: $id}})
+             SET {set_clause_str}
+         """
+
+         with self.driver.session(database=self.db_name) as session:
+             session.run(query, **params)
+
+     def delete_node(self, id: str) -> None:
+         """
+         Delete a node from the graph.
+
+         Args:
+             id: Node identifier to delete.
+         """
+         with self.driver.session(database=self.db_name) as session:
+             session.run("MATCH (n:Memory {id: $id}) DETACH DELETE n", id=id)
+
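A minimal usage sketch for the node CRUD above, continuing from the hypothetical `db` instance sketched earlier (IDs and metadata are invented):

    from datetime import datetime

    db.add_node(
        id="mem-001",
        memory="User prefers dark mode.",
        metadata={"memory_type": "LongTermMemory", "status": "activated",
                  "embedding": [0.1] * 1536},
    )
    db.update_node("mem-001", {"status": "archived",
                               "updated_at": datetime.utcnow().isoformat()})
    db.delete_node("mem-001")
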
+     # Edge (Relationship) Management
+     def add_edge(self, source_id: str, target_id: str, type: str) -> None:
+         """
+         Create an edge from source node to target node.
+
+         Args:
+             source_id: ID of the source node.
+             target_id: ID of the target node.
+             type: Relationship type (e.g., 'RELATE_TO', 'PARENT').
+         """
+         with self.driver.session(database=self.db_name) as session:
+             session.run(
+                 f"""
+                 MATCH (a:Memory {{id: $source_id}})
+                 MATCH (b:Memory {{id: $target_id}})
+                 MERGE (a)-[:{type}]->(b)
+                 """,
+                 {"source_id": source_id, "target_id": target_id},
+             )
+
+     def delete_edge(self, source_id: str, target_id: str, type: str) -> None:
+         """
+         Delete a specific edge between two nodes.
+
+         Args:
+             source_id: ID of the source node.
+             target_id: ID of the target node.
+             type: Relationship type to remove.
+         """
+         with self.driver.session(database=self.db_name) as session:
+             session.run(
+                 f"MATCH (a:Memory {{id: $source}})-[r:{type}]->(b:Memory {{id: $target}})\nDELETE r",
+                 source=source_id,
+                 target=target_id,
+             )
+
+     def edge_exists(
+         self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING"
+     ) -> bool:
+         """
+         Check if an edge exists between two nodes.
+
+         Args:
+             source_id: ID of the source node.
+             target_id: ID of the target node.
+             type: Relationship type. Use "ANY" to match any relationship type.
+             direction: Direction of the edge.
+                 Use "OUTGOING" (default), "INCOMING", or "ANY".
+
+         Returns:
+             True if the edge exists, otherwise False.
+         """
+         # Prepare the relationship pattern
+         rel = "r" if type == "ANY" else f"r:{type}"
+
+         # Prepare the match pattern with direction
+         if direction == "OUTGOING":
+             pattern = f"(a:Memory {{id: $source}})-[{rel}]->(b:Memory {{id: $target}})"
+         elif direction == "INCOMING":
+             pattern = f"(a:Memory {{id: $source}})<-[{rel}]-(b:Memory {{id: $target}})"
+         elif direction == "ANY":
+             pattern = f"(a:Memory {{id: $source}})-[{rel}]-(b:Memory {{id: $target}})"
+         else:
+             raise ValueError(
+                 f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'."
+             )
+
+         # Run the Cypher query
+         with self.driver.session(database=self.db_name) as session:
+             result = session.run(
+                 f"MATCH {pattern} RETURN r",
+                 source=source_id,
+                 target=target_id,
+             )
+             return result.single() is not None
+
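Because Cypher cannot take a relationship type as a query parameter, `add_edge`, `delete_edge`, and `edge_exists` format `type` into the pattern with f-strings, so callers must pass trusted values only. One defensive option, not present in this diff, is a whitelist check before building the query:

    ALLOWED_EDGE_TYPES = {"RELATE_TO", "PARENT", "FOLLOWS"}  # assumed set of types

    def _check_edge_type(type: str) -> None:
        if type not in ALLOWED_EDGE_TYPES:
            raise ValueError(f"Unknown relationship type: {type}")
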
+     # Graph Query & Reasoning
+     def get_node(self, id: str) -> dict[str, Any] | None:
+         """
+         Retrieve the metadata and memory of a node.
+
+         Args:
+             id: Node identifier.
+
+         Returns:
+             Dictionary of node fields, or None if not found.
+         """
+         with self.driver.session(database=self.db_name) as session:
+             result = session.run("MATCH (n:Memory {id: $id}) RETURN n", id=id)
+             record = result.single()
+             return _parse_node(dict(record["n"])) if record else None
+
+     def get_nodes(self, ids: list[str]) -> list[dict[str, Any]]:
+         """
+         Retrieve the metadata and memory of a list of nodes.
+
+         Args:
+             ids: List of node identifiers.
+
+         Returns:
+             list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.
+
+         Notes:
+             - Assumes all provided IDs are valid and exist.
+             - Returns an empty list if the input is empty.
+         """
+         if not ids:
+             return []
+
+         query = "MATCH (n:Memory) WHERE n.id IN $ids RETURN n"
+         with self.driver.session(database=self.db_name) as session:
+             results = session.run(query, {"ids": ids})
+             return [_parse_node(dict(record["n"])) for record in results]
+
+     def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[dict[str, str]]:
+         """
+         Get edges connected to a node, with optional type and direction filter.
+
+         Args:
+             id: Node ID to retrieve edges for.
+             type: Relationship type to match, or 'ANY' to match all.
+             direction: 'OUTGOING', 'INCOMING', or 'ANY'.
+
+         Returns:
+             List of edges:
+             [
+                 {"from": "source_id", "to": "target_id", "type": "RELATE"},
+                 ...
+             ]
+         """
+         # Build relationship type filter
+         rel_type = "" if type == "ANY" else f":{type}"
+
+         # Build Cypher pattern based on direction
+         if direction == "OUTGOING":
+             pattern = f"(a:Memory)-[r{rel_type}]->(b:Memory)"
+             where_clause = "a.id = $id"
+         elif direction == "INCOMING":
+             pattern = f"(a:Memory)<-[r{rel_type}]-(b:Memory)"
+             where_clause = "a.id = $id"
+         elif direction == "ANY":
+             pattern = f"(a:Memory)-[r{rel_type}]-(b:Memory)"
+             where_clause = "a.id = $id OR b.id = $id"
+         else:
+             raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.")
+
+         query = f"""
+             MATCH {pattern}
+             WHERE {where_clause}
+             RETURN a.id AS from_id, b.id AS to_id, type(r) AS type
+         """
+
+         with self.driver.session(database=self.db_name) as session:
+             result = session.run(query, id=id)
+             edges = []
+             for record in result:
+                 edges.append(
+                     {"from": record["from_id"], "to": record["to_id"], "type": record["type"]}
+                 )
+             return edges
+
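For example, with a single PARENT edge from mem-001 to mem-002, a hypothetical call would return:

    db.get_edges("mem-001", type="ANY", direction="ANY")
    # -> [{"from": "mem-001", "to": "mem-002", "type": "PARENT"}]
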
+     def get_neighbors(
+         self, id: str, type: str, direction: Literal["in", "out", "both"] = "out"
+     ) -> list[str]:
+         """
+         Get connected node IDs in a specific direction and relationship type.
+
+         Args:
+             id: Source node ID.
+             type: Relationship type.
+             direction: Edge direction to follow ('out', 'in', or 'both').
+
+         Returns:
+             List of neighboring node IDs.
+         """
+         raise NotImplementedError
+
+     def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]:
+         """Return each child's id, embedding, and memory for the given parent node."""
+         query = """
+             MATCH (p:Memory)-[:PARENT]->(c:Memory)
+             WHERE p.id = $id
+             RETURN c.id AS id, c.embedding AS embedding, c.memory AS memory
+         """
+         with self.driver.session(database=self.db_name) as session:
+             # Materialize records as plain dicts before the session closes
+             return [dict(record) for record in session.run(query, id=id)]
+
+     def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]:
+         """
+         Get the path of nodes from source to target within a limited depth.
+
+         Args:
+             source_id: Starting node ID.
+             target_id: Target node ID.
+             max_depth: Maximum path length to traverse.
+
+         Returns:
+             Ordered list of node IDs along the path.
+         """
+         raise NotImplementedError
+
+     def get_subgraph(
+         self, center_id: str, depth: int = 2, center_status: str = "activated"
+     ) -> dict[str, Any]:
+         """
+         Retrieve a local subgraph centered at a given node.
+
+         Args:
+             center_id: The ID of the center node.
+             depth: The hop distance for neighbors.
+             center_status: Required status for the center node.
+
+         Returns:
+             {
+                 "core_node": {...},
+                 "neighbors": [...],
+                 "edges": [...]
+             }
+         """
+         with self.driver.session(database=self.db_name) as session:
+             status_clause = f", status: '{center_status}'" if center_status else ""
+             query = f"""
+                 MATCH (center:Memory {{id: $center_id{status_clause}}})
+                 OPTIONAL MATCH (center)-[r*1..{depth}]-(neighbor:Memory)
+                 WITH collect(DISTINCT center) AS centers,
+                      collect(DISTINCT neighbor) AS neighbors,
+                      collect(DISTINCT r) AS rels
+                 RETURN centers, neighbors, rels
+             """
+             record = session.run(query, {"center_id": center_id}).single()
+
+             if not record:
+                 logger.warning(
+                     f"No active node found for center_id={center_id} with status={center_status}"
+                 )
+                 return {"core_node": None, "neighbors": [], "edges": []}
+
+             centers = record["centers"]
+             if not centers or centers[0] is None:
+                 logger.warning(f"Center node not found or inactive for id={center_id}")
+                 return {"core_node": None, "neighbors": [], "edges": []}
+
+             core_node = _parse_node(dict(centers[0]))
+             neighbors = [_parse_node(dict(n)) for n in record["neighbors"] if n]
+             edges = []
+             for rel_chain in record["rels"]:
+                 for rel in rel_chain:
+                     edges.append(
+                         {
+                             "type": rel.type,
+                             "source": rel.start_node["id"],
+                             "target": rel.end_node["id"],
+                         }
+                     )
+
+             return {"core_node": core_node, "neighbors": neighbors, "edges": edges}
+
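`depth` is likewise interpolated into the variable-length pattern `[r*1..{depth}]`, since Cypher does not parameterize path bounds, so it should come from trusted code. A hypothetical call and the shape of its result:

    sub = db.get_subgraph("mem-001", depth=2)
    sub["core_node"]   # parsed center node, or None if missing/inactive
    sub["neighbors"]   # parsed nodes within 2 hops
    sub["edges"]       # [{"type": ..., "source": ..., "target": ...}, ...]
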
+     def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
+         """
+         Get the ordered context chain starting from a node, following a relationship type.
+
+         Args:
+             id: Starting node ID.
+             type: Relationship type to follow (e.g., 'FOLLOWS').
+
+         Returns:
+             List of ordered node IDs in the chain.
+         """
+         raise NotImplementedError
+
+     # Search / recall operations
+     def search_by_embedding(
+         self,
+         vector: list[float],
+         top_k: int = 5,
+         scope: str | None = None,
+         status: str | None = None,
+         threshold: float | None = None,
+     ) -> list[dict]:
+         """
+         Retrieve node IDs based on vector similarity.
+
+         Args:
+             vector (list[float]): The embedding vector representing query semantics.
+             top_k (int): Number of top similar nodes to retrieve.
+             scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
+             status (str, optional): Node status filter (e.g., 'active', 'archived').
+                 If provided, restricts results to nodes with matching status.
+             threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
+
+         Returns:
+             list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
+
+         Notes:
+             - This method uses Neo4j native vector indexing to search for similar nodes.
+             - If scope is provided, it restricts results to nodes with matching memory_type.
+             - If status is provided, only nodes with the matching status will be returned.
+             - If threshold is provided, only results with score >= threshold will be returned.
+             - Typical use case: restrict to 'status = activated' to avoid
+               matching archived or merged nodes.
+         """
+         # Build WHERE clause dynamically
+         where_clauses = []
+         if scope:
+             where_clauses.append("node.memory_type = $scope")
+         if status:
+             where_clauses.append("node.status = $status")
+
+         where_clause = ""
+         if where_clauses:
+             where_clause = "WHERE " + " AND ".join(where_clauses)
+
+         query = f"""
+             CALL db.index.vector.queryNodes('memory_vector_index', $k, $embedding)
+             YIELD node, score
+             {where_clause}
+             RETURN node.id AS id, score
+         """
+
+         parameters = {"embedding": vector, "k": top_k}
+         if scope:
+             parameters["scope"] = scope
+         if status:
+             parameters["status"] = status
+
+         with self.driver.session(database=self.db_name) as session:
+             result = session.run(query, parameters)
+             records = [{"id": record["id"], "score": record["score"]} for record in result]
+
+         # Threshold filtering after retrieval
+         if threshold is not None:
+             records = [r for r in records if r["score"] >= threshold]
+
+         return records
+
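A plausible recall flow on top of `search_by_embedding`, assuming some `embedder` (hypothetical here) that produces vectors of the indexed dimension:

    query_vec = embedder.embed("What theme does the user prefer?")
    hits = db.search_by_embedding(
        query_vec, top_k=5, scope="LongTermMemory", status="activated", threshold=0.75
    )
    nodes = db.get_nodes([hit["id"] for hit in hits])
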
+     def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:
+         """
+         TODO:
+             1. Add "AND" vs "OR" logic (support logical combinations);
+             2. Support nested conditional expressions.
+
+         Retrieve node IDs that match given metadata filters.
+         Supports exact match.
+
+         Args:
+             filters: List of filter dicts like:
+                 [
+                     {"field": "key", "op": "in", "value": ["A", "B"]},
+                     {"field": "confidence", "op": ">=", "value": 80},
+                     {"field": "tags", "op": "contains", "value": "AI"},
+                     ...
+                 ]
+
+         Returns:
+             list[str]: Node IDs whose metadata match the filter conditions (AND logic).
+
+         Notes:
+             - Supports structured querying such as tag/category/importance/time filtering.
+             - Can be used for faceted recall or prefiltering before embedding rerank.
+         """
+         where_clauses = []
+         params = {}
+
+         for i, f in enumerate(filters):
+             field = f["field"]
+             op = f.get("op", "=")
+             value = f["value"]
+             param_key = f"val{i}"
+
+             # Build WHERE clause
+             if op == "=":
+                 where_clauses.append(f"n.{field} = ${param_key}")
+             elif op == "in":
+                 where_clauses.append(f"n.{field} IN ${param_key}")
+             elif op == "contains":
+                 where_clauses.append(f"ANY(x IN ${param_key} WHERE x IN n.{field})")
+             elif op == "starts_with":
+                 where_clauses.append(f"n.{field} STARTS WITH ${param_key}")
+             elif op == "ends_with":
+                 where_clauses.append(f"n.{field} ENDS WITH ${param_key}")
+             elif op in [">", ">=", "<", "<="]:
+                 where_clauses.append(f"n.{field} {op} ${param_key}")
+             else:
+                 raise ValueError(f"Unsupported operator: {op}")
+             params[param_key] = value
+
+         where_str = " AND ".join(where_clauses)
+         query = f"MATCH (n:Memory) WHERE {where_str} RETURN n.id AS id"
+
+         with self.driver.session(database=self.db_name) as session:
+             result = session.run(query, params)
+             return [record["id"] for record in result]
+
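Using the filter shapes from the docstring, a faceted lookup might read as below. Note that the 'contains' operator expands to `ANY(x IN $param WHERE x IN n.field)`, so its value must be a list even for a single tag:

    ids = db.get_by_metadata([
        {"field": "memory_type", "op": "=", "value": "LongTermMemory"},
        {"field": "confidence", "op": ">=", "value": 80},
        {"field": "tags", "op": "contains", "value": ["AI"]},
    ])
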
+     def get_grouped_counts(
+         self,
+         group_fields: list[str],
+         where_clause: str = "",
+         params: dict[str, Any] | None = None,
+     ) -> list[dict[str, Any]]:
+         """
+         Count nodes grouped by arbitrary fields.
+
+         Args:
+             group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"]
+             where_clause (str, optional): Extra WHERE condition, e.g.,
+                 "WHERE n.status = 'activated'"
+             params (dict, optional): Parameters for the WHERE clause.
+
+         Returns:
+             list[dict]: e.g., [{'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10}, ...]
+         """
+         if not group_fields:
+             raise ValueError("group_fields cannot be empty")
+
+         # Force RETURN field AS field to guarantee key match
+         group_fields_cypher = ", ".join([f"n.{field} AS {field}" for field in group_fields])
+
+         query = f"""
+             MATCH (n:Memory)
+             {where_clause}
+             RETURN {group_fields_cypher}, COUNT(n) AS count
+         """
+
+         with self.driver.session(database=self.db_name) as session:
+             result = session.run(query, params or {})
+             return [
+                 {**{field: record[field] for field in group_fields}, "count": record["count"]}
+                 for record in result
+             ]
+
+     # Structure Maintenance
+     def deduplicate_nodes(self) -> None:
+         """
+         Deduplicate redundant or semantically similar nodes.
+         This typically involves identifying nodes with identical or near-identical memory.
+         """
+         raise NotImplementedError
+
+     def detect_conflicts(self) -> list[tuple[str, str]]:
+         """
+         Detect conflicting nodes based on logical or semantic inconsistency.
+
+         Returns:
+             A list of (node_id1, node_id2) tuples that conflict.
+         """
+         raise NotImplementedError
+
+     def merge_nodes(self, id1: str, id2: str) -> str:
+         """
+         Merge two similar or duplicate nodes into one.
+
+         Args:
+             id1: First node ID.
+             id2: Second node ID.
+
+         Returns:
+             ID of the resulting merged node.
+         """
+         raise NotImplementedError
+
+     # Utilities
+     def clear(self) -> None:
+         """
+         Clear the entire graph if the target database exists.
+         """
+         try:
+             # Step 1: Check if the database exists
+             with self.driver.session(database="system") as session:
+                 result = session.run("SHOW DATABASES YIELD name RETURN name")
+                 db_names = [record["name"] for record in result]
+                 if self.db_name not in db_names:
+                     logger.info(f"[Skip] Database '{self.db_name}' does not exist.")
+                     return
+
+             # Step 2: Clear the graph in that database
+             with self.driver.session(database=self.db_name) as session:
+                 session.run("MATCH (n) DETACH DELETE n")
+                 logger.info(f"Cleared all nodes from database '{self.db_name}'.")
+
+         except Exception as e:
+             logger.error(f"[ERROR] Failed to clear database '{self.db_name}': {e}")
+             raise
+
+     def export_graph(self) -> dict[str, Any]:
+         """
+         Export all graph nodes and edges in a structured form.
+
+         Returns:
+             {
+                 "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
+                 "edges": [ { "source": ..., "target": ..., "type": ... }, ... ]
+             }
+         """
+         with self.driver.session(database=self.db_name) as session:
+             # Export nodes
+             node_result = session.run("MATCH (n:Memory) RETURN n")
+             nodes = [_parse_node(dict(record["n"])) for record in node_result]
+
+             # Export edges
+             edge_result = session.run("""
+                 MATCH (a:Memory)-[r]->(b:Memory)
+                 RETURN a.id AS source, b.id AS target, type(r) AS type
+             """)
+             edges = [
+                 {"source": record["source"], "target": record["target"], "type": record["type"]}
+                 for record in edge_result
+             ]
+
+         return {"nodes": nodes, "edges": edges}
+
+     def import_graph(self, data: dict[str, Any]) -> None:
+         """
+         Import the entire graph from a serialized dictionary.
+
+         Args:
+             data: A dictionary containing all nodes and edges to be loaded.
+         """
+         with self.driver.session(database=self.db_name) as session:
+             for node in data.get("nodes", []):
+                 id, memory, metadata = _compose_node(node)
+
+                 metadata = _prepare_node_metadata(metadata)
+
+                 # Merge node and set metadata
+                 created_at = metadata.pop("created_at")
+                 updated_at = metadata.pop("updated_at")
+
+                 session.run(
+                     """
+                     MERGE (n:Memory {id: $id})
+                     SET n.memory = $memory,
+                         n.created_at = datetime($created_at),
+                         n.updated_at = datetime($updated_at),
+                         n += $metadata
+                     """,
+                     id=id,
+                     memory=memory,
+                     created_at=created_at,
+                     updated_at=updated_at,
+                     metadata=metadata,
+                 )
+
+             for edge in data.get("edges", []):
+                 session.run(
+                     f"""
+                     MATCH (a:Memory {{id: $source_id}})
+                     MATCH (b:Memory {{id: $target_id}})
+                     MERGE (a)-[:{edge["type"]}]->(b)
+                     """,
+                     source_id=edge["source"],
+                     target_id=edge["target"],
+                 )
+
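Taken together, `export_graph` and `import_graph` give a simple backup path; since `_parse_node` converts Neo4j datetimes to ISO strings, the export is JSON-serializable:

    import json

    snapshot = db.export_graph()
    with open("graph_backup.json", "w") as fp:
        json.dump(snapshot, fp)

    # Later, against a fresh database:
    with open("graph_backup.json") as fp:
        db.import_graph(json.load(fp))
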
+     def get_all_memory_items(self, scope: str) -> list[dict]:
+         """
+         Retrieve all memory items of a specific memory_type.
+
+         Args:
+             scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
+
+         Returns:
+             list[dict]: Full list of memory items under this scope.
+         """
+         if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory"}:
+             raise ValueError(f"Unsupported memory type scope: {scope}")
+
+         query = """
+             MATCH (n:Memory)
+             WHERE n.memory_type = $scope
+             RETURN n
+         """
+
+         with self.driver.session(database=self.db_name) as session:
+             results = session.run(query, {"scope": scope})
+             return [_parse_node(dict(record["n"])) for record in results]
+
+     def drop_database(self) -> None:
+         """
+         Permanently delete the entire database this instance is using.
+         WARNING: This operation is destructive and cannot be undone.
+         """
+         if self.db_name in ("system", "neo4j"):
+             raise ValueError(f"Refusing to drop protected database: {self.db_name}")
+
+         with self.driver.session(database="system") as session:
+             session.run(f"DROP DATABASE {self.db_name} IF EXISTS")
+             logger.info(f"Database '{self.db_name}' has been dropped.")
+
+     def _ensure_database_exists(self):
+         with self.driver.session(database="system") as session:
+             session.run(f"CREATE DATABASE {self.db_name} IF NOT EXISTS")
+
+         # Wait until the database is available
+         for _ in range(10):
+             with self.driver.session(database="system") as session:
+                 result = session.run(
+                     "SHOW DATABASES YIELD name, currentStatus RETURN name, currentStatus"
+                 )
+                 status_map = {r["name"]: r["currentStatus"] for r in result}
+                 if self.db_name in status_map and status_map[self.db_name] == "online":
+                     return
+             time.sleep(1)
+
+         raise RuntimeError(f"Database {self.db_name} not ready after waiting.")
+
+     def _vector_index_exists(self, index_name: str = "memory_vector_index") -> bool:
+         query = "SHOW INDEXES YIELD name WHERE name = $name RETURN name"
+         with self.driver.session(database=self.db_name) as session:
+             result = session.run(query, name=index_name)
+             return result.single() is not None
+
+     def _create_vector_index(
+         self, label: str, vector_property: str, dimensions: int, index_name: str
+     ) -> None:
+         """
+         Create a vector index for the specified property in the label.
+         """
+         try:
+             query = f"""
+                 CREATE VECTOR INDEX {index_name} IF NOT EXISTS
+                 FOR (n:{label}) ON (n.{vector_property})
+                 OPTIONS {{
+                     indexConfig: {{
+                         `vector.dimensions`: {dimensions},
+                         `vector.similarity_function`: 'cosine'
+                     }}
+                 }}
+             """
+             with self.driver.session(database=self.db_name) as session:
+                 session.run(query)
+                 logger.debug(f"Vector index '{index_name}' ensured.")
+         except Exception as e:
+             logger.warning(f"Failed to create vector index '{index_name}': {e}")
+
+     def _create_basic_property_indexes(self) -> None:
+         """
+         Create standard B-tree indexes on memory_type, created_at, and updated_at fields.
+         """
+         try:
+             with self.driver.session(database=self.db_name) as session:
+                 session.run("""
+                     CREATE INDEX memory_type_index IF NOT EXISTS
+                     FOR (n:Memory) ON (n.memory_type)
+                 """)
+                 logger.debug("Index 'memory_type_index' ensured.")
+
+                 session.run("""
+                     CREATE INDEX memory_created_at_index IF NOT EXISTS
+                     FOR (n:Memory) ON (n.created_at)
+                 """)
+                 logger.debug("Index 'memory_created_at_index' ensured.")
+
+                 session.run("""
+                     CREATE INDEX memory_updated_at_index IF NOT EXISTS
+                     FOR (n:Memory) ON (n.updated_at)
+                 """)
+                 logger.debug("Index 'memory_updated_at_index' ensured.")
+         except Exception as e:
+             logger.warning(f"Failed to create basic property indexes: {e}")
+
+     def _index_exists(self, index_name: str) -> bool:
+         """
+         Check if an index with the given name exists.
+         """
+         query = "SHOW INDEXES"
+         with self.driver.session(database=self.db_name) as session:
+             result = session.run(query)
+             for record in result:
+                 if record["name"] == index_name:
+                     return True
+             return False