mem0ai-azure-mysql 0.1.115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mem0/__init__.py +6 -0
  2. mem0/client/__init__.py +0 -0
  3. mem0/client/main.py +1535 -0
  4. mem0/client/project.py +860 -0
  5. mem0/client/utils.py +29 -0
  6. mem0/configs/__init__.py +0 -0
  7. mem0/configs/base.py +90 -0
  8. mem0/configs/dbs/__init__.py +4 -0
  9. mem0/configs/dbs/base.py +41 -0
  10. mem0/configs/dbs/mysql.py +25 -0
  11. mem0/configs/embeddings/__init__.py +0 -0
  12. mem0/configs/embeddings/base.py +108 -0
  13. mem0/configs/enums.py +7 -0
  14. mem0/configs/llms/__init__.py +0 -0
  15. mem0/configs/llms/base.py +152 -0
  16. mem0/configs/prompts.py +333 -0
  17. mem0/configs/vector_stores/__init__.py +0 -0
  18. mem0/configs/vector_stores/azure_ai_search.py +59 -0
  19. mem0/configs/vector_stores/baidu.py +29 -0
  20. mem0/configs/vector_stores/chroma.py +40 -0
  21. mem0/configs/vector_stores/elasticsearch.py +47 -0
  22. mem0/configs/vector_stores/faiss.py +39 -0
  23. mem0/configs/vector_stores/langchain.py +32 -0
  24. mem0/configs/vector_stores/milvus.py +43 -0
  25. mem0/configs/vector_stores/mongodb.py +25 -0
  26. mem0/configs/vector_stores/opensearch.py +41 -0
  27. mem0/configs/vector_stores/pgvector.py +37 -0
  28. mem0/configs/vector_stores/pinecone.py +56 -0
  29. mem0/configs/vector_stores/qdrant.py +49 -0
  30. mem0/configs/vector_stores/redis.py +26 -0
  31. mem0/configs/vector_stores/supabase.py +44 -0
  32. mem0/configs/vector_stores/upstash_vector.py +36 -0
  33. mem0/configs/vector_stores/vertex_ai_vector_search.py +27 -0
  34. mem0/configs/vector_stores/weaviate.py +43 -0
  35. mem0/dbs/__init__.py +4 -0
  36. mem0/dbs/base.py +68 -0
  37. mem0/dbs/configs.py +21 -0
  38. mem0/dbs/mysql.py +321 -0
  39. mem0/embeddings/__init__.py +0 -0
  40. mem0/embeddings/aws_bedrock.py +100 -0
  41. mem0/embeddings/azure_openai.py +43 -0
  42. mem0/embeddings/base.py +31 -0
  43. mem0/embeddings/configs.py +30 -0
  44. mem0/embeddings/gemini.py +39 -0
  45. mem0/embeddings/huggingface.py +41 -0
  46. mem0/embeddings/langchain.py +35 -0
  47. mem0/embeddings/lmstudio.py +29 -0
  48. mem0/embeddings/mock.py +11 -0
  49. mem0/embeddings/ollama.py +53 -0
  50. mem0/embeddings/openai.py +49 -0
  51. mem0/embeddings/together.py +31 -0
  52. mem0/embeddings/vertexai.py +54 -0
  53. mem0/graphs/__init__.py +0 -0
  54. mem0/graphs/configs.py +96 -0
  55. mem0/graphs/neptune/__init__.py +0 -0
  56. mem0/graphs/neptune/base.py +410 -0
  57. mem0/graphs/neptune/main.py +372 -0
  58. mem0/graphs/tools.py +371 -0
  59. mem0/graphs/utils.py +97 -0
  60. mem0/llms/__init__.py +0 -0
  61. mem0/llms/anthropic.py +64 -0
  62. mem0/llms/aws_bedrock.py +270 -0
  63. mem0/llms/azure_openai.py +114 -0
  64. mem0/llms/azure_openai_structured.py +76 -0
  65. mem0/llms/base.py +32 -0
  66. mem0/llms/configs.py +34 -0
  67. mem0/llms/deepseek.py +85 -0
  68. mem0/llms/gemini.py +201 -0
  69. mem0/llms/groq.py +88 -0
  70. mem0/llms/langchain.py +65 -0
  71. mem0/llms/litellm.py +87 -0
  72. mem0/llms/lmstudio.py +53 -0
  73. mem0/llms/ollama.py +94 -0
  74. mem0/llms/openai.py +124 -0
  75. mem0/llms/openai_structured.py +52 -0
  76. mem0/llms/sarvam.py +89 -0
  77. mem0/llms/together.py +88 -0
  78. mem0/llms/vllm.py +89 -0
  79. mem0/llms/xai.py +52 -0
  80. mem0/memory/__init__.py +0 -0
  81. mem0/memory/base.py +63 -0
  82. mem0/memory/graph_memory.py +632 -0
  83. mem0/memory/main.py +1843 -0
  84. mem0/memory/memgraph_memory.py +630 -0
  85. mem0/memory/setup.py +56 -0
  86. mem0/memory/storage.py +218 -0
  87. mem0/memory/telemetry.py +90 -0
  88. mem0/memory/utils.py +133 -0
  89. mem0/proxy/__init__.py +0 -0
  90. mem0/proxy/main.py +194 -0
  91. mem0/utils/factory.py +132 -0
  92. mem0/vector_stores/__init__.py +0 -0
  93. mem0/vector_stores/azure_ai_search.py +383 -0
  94. mem0/vector_stores/baidu.py +368 -0
  95. mem0/vector_stores/base.py +58 -0
  96. mem0/vector_stores/chroma.py +229 -0
  97. mem0/vector_stores/configs.py +60 -0
  98. mem0/vector_stores/elasticsearch.py +235 -0
  99. mem0/vector_stores/faiss.py +473 -0
  100. mem0/vector_stores/langchain.py +179 -0
  101. mem0/vector_stores/milvus.py +245 -0
  102. mem0/vector_stores/mongodb.py +293 -0
  103. mem0/vector_stores/opensearch.py +281 -0
  104. mem0/vector_stores/pgvector.py +294 -0
  105. mem0/vector_stores/pinecone.py +373 -0
  106. mem0/vector_stores/qdrant.py +240 -0
  107. mem0/vector_stores/redis.py +295 -0
  108. mem0/vector_stores/supabase.py +237 -0
  109. mem0/vector_stores/upstash_vector.py +293 -0
  110. mem0/vector_stores/vertex_ai_vector_search.py +629 -0
  111. mem0/vector_stores/weaviate.py +316 -0
  112. mem0ai_azure_mysql-0.1.115.data/data/README.md +169 -0
  113. mem0ai_azure_mysql-0.1.115.dist-info/METADATA +224 -0
  114. mem0ai_azure_mysql-0.1.115.dist-info/RECORD +116 -0
  115. mem0ai_azure_mysql-0.1.115.dist-info/WHEEL +4 -0
  116. mem0ai_azure_mysql-0.1.115.dist-info/licenses/LICENSE +201 -0
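Before the per-file diffs, a brief orientation: the layout mirrors the upstream mem0ai package, so the usual entry point should apply. A minimal usage sketch, assuming this fork preserves the upstream Memory.from_config API and provider names; the Azure- and MySQL-specific options this fork adds (mem0/configs/dbs/mysql.py, mem0/dbs/mysql.py) are not shown in this listing and are omitted here:

    from mem0 import Memory

    # Provider names follow upstream mem0; credentials/endpoints are placeholders.
    config = {
        "llm": {"provider": "azure_openai", "config": {"model": "gpt-4o-mini"}},
        "embedder": {"provider": "azure_openai", "config": {"model": "text-embedding-3-small"}},
        "vector_store": {"provider": "qdrant", "config": {"collection_name": "mem0"}},
    }

    memory = Memory.from_config(config)
    memory.add("I prefer vegetarian food", user_id="alice")
    print(memory.search("food preferences", user_id="alice"))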
mem0/memory/memgraph_memory.py ADDED
@@ -0,0 +1,630 @@
+import logging
+
+from mem0.memory.utils import format_entities
+
+try:
+    from langchain_memgraph.graphs.memgraph import Memgraph
+except ImportError:
+    raise ImportError("langchain_memgraph is not installed. Please install it using pip install langchain-memgraph")
+
+try:
+    from rank_bm25 import BM25Okapi
+except ImportError:
+    raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
+
+from mem0.graphs.tools import (
+    DELETE_MEMORY_STRUCT_TOOL_GRAPH,
+    DELETE_MEMORY_TOOL_GRAPH,
+    EXTRACT_ENTITIES_STRUCT_TOOL,
+    EXTRACT_ENTITIES_TOOL,
+    RELATIONS_STRUCT_TOOL,
+    RELATIONS_TOOL,
+)
+from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
+from mem0.utils.factory import EmbedderFactory, LlmFactory
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryGraph:
+    def __init__(self, config):
+        self.config = config
+        self.graph = Memgraph(
+            self.config.graph_store.config.url,
+            self.config.graph_store.config.username,
+            self.config.graph_store.config.password,
+        )
+        self.embedding_model = EmbedderFactory.create(
+            self.config.embedder.provider,
+            self.config.embedder.config,
+            {"enable_embeddings": True},
+        )
+
+        self.llm_provider = "openai_structured"
+        if self.config.llm.provider:
+            self.llm_provider = self.config.llm.provider
+        if self.config.graph_store.llm:
+            self.llm_provider = self.config.graph_store.llm.provider
+
+        self.llm = LlmFactory.create(self.llm_provider, self.config.llm.config)
+        self.user_id = None
+        self.threshold = 0.7
+
+        # Set up Memgraph:
+        # 1. Create a vector index (the Entity label is added to all nodes)
+        # 2. Create label and label+property indexes for performance optimizations
+        embedding_dims = self.config.embedder.config["embedding_dims"]
+        index_info = self._fetch_existing_indexes()
+        # Create the vector index if it does not exist
+        if not any(idx.get("index_name") == "memzero" for idx in index_info["vector_index_exists"]):
+            self.graph.query(
+                f"CREATE VECTOR INDEX memzero ON :Entity(embedding) WITH CONFIG {{'dimension': {embedding_dims}, 'capacity': 1000, 'metric': 'cos'}};"
+            )
+        # Create the label+property index if it does not exist
+        if not any(
+            idx.get("index type") == "label+property" and idx.get("label") == "Entity"
+            for idx in index_info["index_exists"]
+        ):
+            self.graph.query("CREATE INDEX ON :Entity(user_id);")
+        # Create the label index if it does not exist
+        if not any(
+            idx.get("index type") == "label" and idx.get("label") == "Entity"
+            for idx in index_info["index_exists"]
+        ):
+            self.graph.query("CREATE INDEX ON :Entity;")
+
+    def add(self, data, filters):
+        """
+        Adds data to the graph.
+
+        Args:
+            data (str): The data to add to the graph.
+            filters (dict): A dictionary containing filters to be applied during the addition.
+        """
+        entity_type_map = self._retrieve_nodes_from_data(data, filters)
+        to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
+        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
+        to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
+
+        # TODO: Batch queries with APOC plugin
+        # TODO: Add more filter support
+        deleted_entities = self._delete_entities(to_be_deleted, filters)
+        added_entities = self._add_entities(to_be_added, filters, entity_type_map)
+
+        return {"deleted_entities": deleted_entities, "added_entities": added_entities}
+
+    def search(self, query, filters, limit=100):
+        """
+        Search for memories and related graph data.
+
+        Args:
+            query (str): Query to search for.
+            filters (dict): A dictionary containing filters to be applied during the search.
+            limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
+
+        Returns:
+            list: BM25-reranked relationships, each a dictionary containing:
+                - "source": The source node name.
+                - "relationship": The relationship type.
+                - "destination": The destination node name.
+        """
+        entity_type_map = self._retrieve_nodes_from_data(query, filters)
+        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
+
+        if not search_output:
+            return []
+
+        search_outputs_sequence = [
+            [item["source"], item["relationship"], item["destination"]] for item in search_output
+        ]
+        bm25 = BM25Okapi(search_outputs_sequence)
+
+        tokenized_query = query.split(" ")
+        reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)
+
+        search_results = []
+        for item in reranked_results:
+            search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})
+
+        logger.info(f"Returned {len(search_results)} search results")
+
+        return search_results
+
+    def delete_all(self, filters):
+        """Delete all nodes and relationships for a user or specific agent."""
+        if filters.get("agent_id"):
+            cypher = """
+            MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})
+            DETACH DELETE n
+            """
+            params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"]}
+        else:
+            cypher = """
+            MATCH (n:Entity {user_id: $user_id})
+            DETACH DELETE n
+            """
+            params = {"user_id": filters["user_id"]}
+        self.graph.query(cypher, params=params)
+
+    def get_all(self, filters, limit=100):
+        """
+        Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
+
+        Args:
+            filters (dict): A dictionary containing filters to be applied during the retrieval.
+                Supports 'user_id' (required) and 'agent_id' (optional).
+            limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
+
+        Returns:
+            list: A list of dictionaries, each containing:
+                - 'source': The source node name.
+                - 'relationship': The relationship type.
+                - 'target': The target node name.
+        """
+        # Build query based on whether agent_id is provided
+        if filters.get("agent_id"):
+            query = """
+            MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})-[r]->(m:Entity {user_id: $user_id, agent_id: $agent_id})
+            RETURN n.name AS source, type(r) AS relationship, m.name AS target
+            LIMIT $limit
+            """
+            params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"], "limit": limit}
+        else:
+            query = """
+            MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity {user_id: $user_id})
+            RETURN n.name AS source, type(r) AS relationship, m.name AS target
+            LIMIT $limit
+            """
+            params = {"user_id": filters["user_id"], "limit": limit}
+
+        results = self.graph.query(query, params=params)
+
+        final_results = []
+        for result in results:
+            final_results.append(
+                {
+                    "source": result["source"],
+                    "relationship": result["relationship"],
+                    "target": result["target"],
+                }
+            )
+
+        logger.info(f"Retrieved {len(final_results)} relationships")
+
+        return final_results
+
+    def _retrieve_nodes_from_data(self, data, filters):
+        """Extracts all the entities mentioned in the query."""
+        _tools = [EXTRACT_ENTITIES_TOOL]
+        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
+            _tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
+        search_results = self.llm.generate_response(
+            messages=[
+                {
+                    "role": "system",
+                    "content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
+                },
+                {"role": "user", "content": data},
+            ],
+            tools=_tools,
+        )
+
+        entity_type_map = {}
+
+        try:
+            for tool_call in search_results["tool_calls"]:
+                if tool_call["name"] != "extract_entities":
+                    continue
+                for item in tool_call["arguments"]["entities"]:
+                    entity_type_map[item["entity"]] = item["entity_type"]
+        except Exception as e:
+            logger.exception(
+                f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
+            )
+
+        entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
+        logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
+        return entity_type_map
+
+    def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
+        """Establish relations among the extracted nodes."""
+        if self.config.graph_store.custom_prompt:
+            messages = [
+                {
+                    "role": "system",
+                    "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]).replace(
+                        "CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
+                    ),
+                },
+                {"role": "user", "content": data},
+            ]
+        else:
+            messages = [
+                {
+                    "role": "system",
+                    "content": EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"]),
+                },
+                {
+                    "role": "user",
+                    "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}",
+                },
+            ]
+
+        _tools = [RELATIONS_TOOL]
+        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
+            _tools = [RELATIONS_STRUCT_TOOL]
+
+        extracted_entities = self.llm.generate_response(
+            messages=messages,
+            tools=_tools,
+        )
+
+        entities = []
+        if extracted_entities["tool_calls"]:
+            entities = extracted_entities["tool_calls"][0]["arguments"]["entities"]
+
+        entities = self._remove_spaces_from_entities(entities)
+        logger.debug(f"Extracted entities: {entities}")
+        return entities
+
+    def _search_graph_db(self, node_list, filters, limit=100):
+        """Search for nodes similar to the given ones, along with their incoming and outgoing relations."""
+        result_relations = []
+
+        for node in node_list:
+            n_embedding = self.embedding_model.embed(node)
+
+            # Build query based on whether agent_id is provided
+            if filters.get("agent_id"):
+                cypher_query = """
+                MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})-[r]->(m:Entity)
+                WHERE n.embedding IS NOT NULL
+                WITH collect(n) AS nodes1, collect(m) AS nodes2, r
+                CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
+                YIELD node1, node2, similarity
+                WITH node1, node2, similarity, r
+                WHERE similarity >= $threshold
+                RETURN node1.name AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.name AS destination, id(node2) AS destination_id, similarity
+                UNION
+                MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})<-[r]-(m:Entity)
+                WHERE n.embedding IS NOT NULL
+                WITH collect(n) AS nodes1, collect(m) AS nodes2, r
+                CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
+                YIELD node1, node2, similarity
+                WITH node1, node2, similarity, r
+                WHERE similarity >= $threshold
+                RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity
+                ORDER BY similarity DESC
+                LIMIT $limit;
+                """
+                params = {
+                    "n_embedding": n_embedding,
+                    "threshold": self.threshold,
+                    "user_id": filters["user_id"],
+                    "agent_id": filters["agent_id"],
+                    "limit": limit,
+                }
+            else:
+                cypher_query = """
+                MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity)
+                WHERE n.embedding IS NOT NULL
+                WITH collect(n) AS nodes1, collect(m) AS nodes2, r
+                CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
+                YIELD node1, node2, similarity
+                WITH node1, node2, similarity, r
+                WHERE similarity >= $threshold
+                RETURN node1.name AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.name AS destination, id(node2) AS destination_id, similarity
+                UNION
+                MATCH (n:Entity {user_id: $user_id})<-[r]-(m:Entity)
+                WHERE n.embedding IS NOT NULL
+                WITH collect(n) AS nodes1, collect(m) AS nodes2, r
+                CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
+                YIELD node1, node2, similarity
+                WITH node1, node2, similarity, r
+                WHERE similarity >= $threshold
+                RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity
+                ORDER BY similarity DESC
+                LIMIT $limit;
+                """
+                params = {
+                    "n_embedding": n_embedding,
+                    "threshold": self.threshold,
+                    "user_id": filters["user_id"],
+                    "limit": limit,
+                }
+
+            ans = self.graph.query(cypher_query, params=params)
+            result_relations.extend(ans)
+
+        return result_relations
+
+    def _get_delete_entities_from_search_output(self, search_output, data, filters):
+        """Get the entities to be deleted from the search output."""
+        search_output_string = format_entities(search_output)
+        system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])
+
+        _tools = [DELETE_MEMORY_TOOL_GRAPH]
+        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
+            _tools = [
+                DELETE_MEMORY_STRUCT_TOOL_GRAPH,
+            ]
+
+        memory_updates = self.llm.generate_response(
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ],
+            tools=_tools,
+        )
+        to_be_deleted = []
+        for item in memory_updates["tool_calls"]:
+            if item["name"] == "delete_graph_memory":
+                to_be_deleted.append(item["arguments"])
+        # Normalize in case the entities are not in the expected format
+        to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
+        logger.debug(f"Relationships to be deleted: {to_be_deleted}")
+        return to_be_deleted
+
+    def _delete_entities(self, to_be_deleted, filters):
+        """Delete the entities from the graph."""
+        user_id = filters["user_id"]
+        agent_id = filters.get("agent_id", None)
+        results = []
+
+        for item in to_be_deleted:
+            source = item["source"]
+            destination = item["destination"]
+            relationship = item["relationship"]
+
+            # Build the agent filter for the query
+            agent_filter = ""
+            params = {
+                "source_name": source,
+                "dest_name": destination,
+                "user_id": user_id,
+            }
+
+            if agent_id:
+                agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
+                params["agent_id"] = agent_id
+
+            # Delete the specific relationship between nodes
+            cypher = f"""
+            MATCH (n:Entity {{name: $source_name, user_id: $user_id}})
+            -[r:{relationship}]->
+            (m:Entity {{name: $dest_name, user_id: $user_id}})
+            WHERE 1=1 {agent_filter}
+            DELETE r
+            RETURN
+                n.name AS source,
+                m.name AS target,
+                type(r) AS relationship
+            """
+
+            result = self.graph.query(cypher, params=params)
+            results.append(result)
+
+        return results
+
+    # added Entity label to all nodes for vector search to work
+    def _add_entities(self, to_be_added, filters, entity_type_map):
+        """Add the new entities to the graph. Merge the nodes if they already exist."""
+        user_id = filters["user_id"]
+        agent_id = filters.get("agent_id", None)
+        results = []
+
+        for item in to_be_added:
+            # entities
+            source = item["source"]
+            destination = item["destination"]
+            relationship = item["relationship"]
+
+            # types
+            source_type = entity_type_map.get(source, "__User__")
+            destination_type = entity_type_map.get(destination, "__User__")
+
+            # embeddings
+            source_embedding = self.embedding_model.embed(source)
+            dest_embedding = self.embedding_model.embed(destination)
+
+            # search for the nodes with the closest embeddings
+            source_node_search_result = self._search_source_node(source_embedding, filters, threshold=0.9)
+            destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=0.9)
+
+            # Prepare agent_id for node creation
+            agent_id_clause = ""
+            if agent_id:
+                agent_id_clause = ", agent_id: $agent_id"
+
+            # TODO: Create a cypher query and common params for all the cases
+            if not destination_node_search_result and source_node_search_result:
+                cypher = f"""
+                MATCH (source:Entity)
+                WHERE id(source) = $source_id
+                MERGE (destination:{destination_type}:Entity {{name: $destination_name, user_id: $user_id{agent_id_clause}}})
+                ON CREATE SET
+                    destination.created = timestamp(),
+                    destination.embedding = $destination_embedding,
+                    destination:Entity
+                MERGE (source)-[r:{relationship}]->(destination)
+                ON CREATE SET
+                    r.created = timestamp()
+                RETURN source.name AS source, type(r) AS relationship, destination.name AS target
+                """
+
+                params = {
+                    "source_id": source_node_search_result[0]["id(source_candidate)"],
+                    "destination_name": destination,
+                    "destination_embedding": dest_embedding,
+                    "user_id": user_id,
+                }
+                if agent_id:
+                    params["agent_id"] = agent_id
+
+            elif destination_node_search_result and not source_node_search_result:
+                cypher = f"""
+                MATCH (destination:Entity)
+                WHERE id(destination) = $destination_id
+                MERGE (source:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
+                ON CREATE SET
+                    source.created = timestamp(),
+                    source.embedding = $source_embedding,
+                    source:Entity
+                MERGE (source)-[r:{relationship}]->(destination)
+                ON CREATE SET
+                    r.created = timestamp()
+                RETURN source.name AS source, type(r) AS relationship, destination.name AS target
+                """
+
+                params = {
+                    "destination_id": destination_node_search_result[0]["id(destination_candidate)"],
+                    "source_name": source,
+                    "source_embedding": source_embedding,
+                    "user_id": user_id,
+                }
+                if agent_id:
+                    params["agent_id"] = agent_id
+
+            elif source_node_search_result and destination_node_search_result:
+                cypher = f"""
+                MATCH (source:Entity)
+                WHERE id(source) = $source_id
+                MATCH (destination:Entity)
+                WHERE id(destination) = $destination_id
+                MERGE (source)-[r:{relationship}]->(destination)
+                ON CREATE SET
+                    r.created_at = timestamp(),
+                    r.updated_at = timestamp()
+                RETURN source.name AS source, type(r) AS relationship, destination.name AS target
+                """
+                params = {
+                    "source_id": source_node_search_result[0]["id(source_candidate)"],
+                    "destination_id": destination_node_search_result[0]["id(destination_candidate)"],
+                    "user_id": user_id,
+                }
+                if agent_id:
+                    params["agent_id"] = agent_id
+
+            else:
+                cypher = f"""
+                MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
+                ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding, n:Entity
+                ON MATCH SET n.embedding = $source_embedding
+                MERGE (m:{destination_type}:Entity {{name: $dest_name, user_id: $user_id{agent_id_clause}}})
+                ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding, m:Entity
+                ON MATCH SET m.embedding = $dest_embedding
+                MERGE (n)-[rel:{relationship}]->(m)
+                ON CREATE SET rel.created = timestamp()
+                RETURN n.name AS source, type(rel) AS relationship, m.name AS target
+                """
+                params = {
+                    "source_name": source,
+                    "dest_name": destination,
+                    "source_embedding": source_embedding,
+                    "dest_embedding": dest_embedding,
+                    "user_id": user_id,
+                }
+                if agent_id:
+                    params["agent_id"] = agent_id
+
+            result = self.graph.query(cypher, params=params)
+            results.append(result)
+        return results
+
+    def _remove_spaces_from_entities(self, entity_list):
+        for item in entity_list:
+            item["source"] = item["source"].lower().replace(" ", "_")
+            item["relationship"] = item["relationship"].lower().replace(" ", "_")
+            item["destination"] = item["destination"].lower().replace(" ", "_")
+        return entity_list
+
+    def _search_source_node(self, source_embedding, filters, threshold=0.9):
+        """Search for source nodes with similar embeddings."""
+        user_id = filters["user_id"]
+        agent_id = filters.get("agent_id", None)
+
+        if agent_id:
+            cypher = """
+            CALL vector_search.search("memzero", 1, $source_embedding)
+            YIELD distance, node, similarity
+            WITH node AS source_candidate, similarity
+            WHERE source_candidate.user_id = $user_id
+            AND source_candidate.agent_id = $agent_id
+            AND similarity >= $threshold
+            RETURN id(source_candidate);
+            """
+            params = {
+                "source_embedding": source_embedding,
+                "user_id": user_id,
+                "agent_id": agent_id,
+                "threshold": threshold,
+            }
+        else:
+            cypher = """
+            CALL vector_search.search("memzero", 1, $source_embedding)
+            YIELD distance, node, similarity
+            WITH node AS source_candidate, similarity
+            WHERE source_candidate.user_id = $user_id
+            AND similarity >= $threshold
+            RETURN id(source_candidate);
+            """
+            params = {
+                "source_embedding": source_embedding,
+                "user_id": user_id,
+                "threshold": threshold,
+            }
+
+        result = self.graph.query(cypher, params=params)
+        return result
+
+    def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
+        """Search for destination nodes with similar embeddings."""
+        user_id = filters["user_id"]
+        agent_id = filters.get("agent_id", None)
+
+        if agent_id:
+            cypher = """
+            CALL vector_search.search("memzero", 1, $destination_embedding)
+            YIELD distance, node, similarity
+            WITH node AS destination_candidate, similarity
+            WHERE destination_candidate.user_id = $user_id
+            AND destination_candidate.agent_id = $agent_id
+            AND similarity >= $threshold
+            RETURN id(destination_candidate);
+            """
+            params = {
+                "destination_embedding": destination_embedding,
+                "user_id": user_id,
+                "agent_id": agent_id,
+                "threshold": threshold,
+            }
+        else:
+            cypher = """
+            CALL vector_search.search("memzero", 1, $destination_embedding)
+            YIELD distance, node, similarity
+            WITH node AS destination_candidate, similarity
+            WHERE destination_candidate.user_id = $user_id
+            AND similarity >= $threshold
+            RETURN id(destination_candidate);
+            """
+            params = {
+                "destination_embedding": destination_embedding,
+                "user_id": user_id,
+                "threshold": threshold,
+            }
+
+        result = self.graph.query(cypher, params=params)
+        return result
+
+    def _fetch_existing_indexes(self):
+        """
+        Retrieves information about existing indexes and vector indexes in the Memgraph database.
+
+        Returns:
+            dict: A dictionary containing lists of existing indexes and vector indexes.
+        """
+        index_exists = list(self.graph.query("SHOW INDEX INFO;"))
+        vector_index_exists = list(self.graph.query("SHOW VECTOR INDEX INFO;"))
+        return {
+            "index_exists": index_exists,
+            "vector_index_exists": vector_index_exists,
+        }
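In upstream mem0, MemoryGraph is not used directly: the Memory class constructs it when a graph_store is configured, and add/search then fan out to both the vector store and the graph. A hedged end-to-end sketch along those lines; the connection values are placeholders, and embedding_dims is supplied because __init__ above reads it from the embedder config:

    from mem0 import Memory

    config = {
        "embedder": {
            "provider": "openai",
            # embedding_dims is read by MemoryGraph.__init__ to size the vector index
            "config": {"model": "text-embedding-3-small", "embedding_dims": 1536},
        },
        "llm": {"provider": "openai", "config": {"model": "gpt-4o-mini"}},
        "graph_store": {
            "provider": "memgraph",  # routes to the MemoryGraph class above in upstream mem0
            "config": {"url": "bolt://localhost:7687", "username": "memgraph", "password": ""},
        },
    }

    m = Memory.from_config(config)
    m.add("Alice works at Acme and lives in Berlin", user_id="alice")
    print(m.search("Where does Alice work?", user_id="alice"))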
mem0/memory/setup.py ADDED
@@ -0,0 +1,56 @@
+import json
+import os
+import uuid
+
+# Set up the directory path
+VECTOR_ID = str(uuid.uuid4())
+home_dir = os.path.expanduser("~")
+mem0_dir = os.environ.get("MEM0_DIR") or os.path.join(home_dir, ".mem0")
+os.makedirs(mem0_dir, exist_ok=True)
+
+
+def setup_config():
+    config_path = os.path.join(mem0_dir, "config.json")
+    if not os.path.exists(config_path):
+        user_id = str(uuid.uuid4())
+        config = {"user_id": user_id}
+        with open(config_path, "w") as config_file:
+            json.dump(config, config_file, indent=4)
+
+
+def get_user_id():
+    config_path = os.path.join(mem0_dir, "config.json")
+    if not os.path.exists(config_path):
+        return "anonymous_user"
+
+    try:
+        with open(config_path, "r") as config_file:
+            config = json.load(config_file)
+        user_id = config.get("user_id")
+        return user_id
+    except Exception:
+        return "anonymous_user"
+
+
+def get_or_create_user_id(vector_store):
+    """Store user_id in vector store and return it."""
+    user_id = get_user_id()
+
+    # Try to get existing user_id from vector store
+    try:
+        existing = vector_store.get(vector_id=user_id)
+        if existing and hasattr(existing, "payload") and existing.payload and "user_id" in existing.payload:
+            return existing.payload["user_id"]
+    except Exception:
+        pass
+
+    # If we get here, we need to insert the user_id
+    try:
+        dims = getattr(vector_store, "embedding_model_dims", 1536)
+        vector_store.insert(
+            vectors=[[0.1] * dims], payloads=[{"user_id": user_id, "type": "user_identity"}], ids=[user_id]
+        )
+    except Exception:
+        pass
+
+    return user_id
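A short sketch of how these helpers compose; note that MEM0_DIR is read at module import time, so the override must be set before the first import to take effect:

    import os
    os.environ["MEM0_DIR"] = "/tmp/mem0-demo"  # optional override; must precede the import

    from mem0.memory.setup import setup_config, get_user_id

    setup_config()        # writes config.json with a fresh UUID on first run only
    print(get_user_id())  # the stored UUID, or "anonymous_user" if the file is missing/unreadable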