mem0ai-azure-mysql 0.1.115-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mem0/__init__.py +6 -0
  2. mem0/client/__init__.py +0 -0
  3. mem0/client/main.py +1535 -0
  4. mem0/client/project.py +860 -0
  5. mem0/client/utils.py +29 -0
  6. mem0/configs/__init__.py +0 -0
  7. mem0/configs/base.py +90 -0
  8. mem0/configs/dbs/__init__.py +4 -0
  9. mem0/configs/dbs/base.py +41 -0
  10. mem0/configs/dbs/mysql.py +25 -0
  11. mem0/configs/embeddings/__init__.py +0 -0
  12. mem0/configs/embeddings/base.py +108 -0
  13. mem0/configs/enums.py +7 -0
  14. mem0/configs/llms/__init__.py +0 -0
  15. mem0/configs/llms/base.py +152 -0
  16. mem0/configs/prompts.py +333 -0
  17. mem0/configs/vector_stores/__init__.py +0 -0
  18. mem0/configs/vector_stores/azure_ai_search.py +59 -0
  19. mem0/configs/vector_stores/baidu.py +29 -0
  20. mem0/configs/vector_stores/chroma.py +40 -0
  21. mem0/configs/vector_stores/elasticsearch.py +47 -0
  22. mem0/configs/vector_stores/faiss.py +39 -0
  23. mem0/configs/vector_stores/langchain.py +32 -0
  24. mem0/configs/vector_stores/milvus.py +43 -0
  25. mem0/configs/vector_stores/mongodb.py +25 -0
  26. mem0/configs/vector_stores/opensearch.py +41 -0
  27. mem0/configs/vector_stores/pgvector.py +37 -0
  28. mem0/configs/vector_stores/pinecone.py +56 -0
  29. mem0/configs/vector_stores/qdrant.py +49 -0
  30. mem0/configs/vector_stores/redis.py +26 -0
  31. mem0/configs/vector_stores/supabase.py +44 -0
  32. mem0/configs/vector_stores/upstash_vector.py +36 -0
  33. mem0/configs/vector_stores/vertex_ai_vector_search.py +27 -0
  34. mem0/configs/vector_stores/weaviate.py +43 -0
  35. mem0/dbs/__init__.py +4 -0
  36. mem0/dbs/base.py +68 -0
  37. mem0/dbs/configs.py +21 -0
  38. mem0/dbs/mysql.py +321 -0
  39. mem0/embeddings/__init__.py +0 -0
  40. mem0/embeddings/aws_bedrock.py +100 -0
  41. mem0/embeddings/azure_openai.py +43 -0
  42. mem0/embeddings/base.py +31 -0
  43. mem0/embeddings/configs.py +30 -0
  44. mem0/embeddings/gemini.py +39 -0
  45. mem0/embeddings/huggingface.py +41 -0
  46. mem0/embeddings/langchain.py +35 -0
  47. mem0/embeddings/lmstudio.py +29 -0
  48. mem0/embeddings/mock.py +11 -0
  49. mem0/embeddings/ollama.py +53 -0
  50. mem0/embeddings/openai.py +49 -0
  51. mem0/embeddings/together.py +31 -0
  52. mem0/embeddings/vertexai.py +54 -0
  53. mem0/graphs/__init__.py +0 -0
  54. mem0/graphs/configs.py +96 -0
  55. mem0/graphs/neptune/__init__.py +0 -0
  56. mem0/graphs/neptune/base.py +410 -0
  57. mem0/graphs/neptune/main.py +372 -0
  58. mem0/graphs/tools.py +371 -0
  59. mem0/graphs/utils.py +97 -0
  60. mem0/llms/__init__.py +0 -0
  61. mem0/llms/anthropic.py +64 -0
  62. mem0/llms/aws_bedrock.py +270 -0
  63. mem0/llms/azure_openai.py +114 -0
  64. mem0/llms/azure_openai_structured.py +76 -0
  65. mem0/llms/base.py +32 -0
  66. mem0/llms/configs.py +34 -0
  67. mem0/llms/deepseek.py +85 -0
  68. mem0/llms/gemini.py +201 -0
  69. mem0/llms/groq.py +88 -0
  70. mem0/llms/langchain.py +65 -0
  71. mem0/llms/litellm.py +87 -0
  72. mem0/llms/lmstudio.py +53 -0
  73. mem0/llms/ollama.py +94 -0
  74. mem0/llms/openai.py +124 -0
  75. mem0/llms/openai_structured.py +52 -0
  76. mem0/llms/sarvam.py +89 -0
  77. mem0/llms/together.py +88 -0
  78. mem0/llms/vllm.py +89 -0
  79. mem0/llms/xai.py +52 -0
  80. mem0/memory/__init__.py +0 -0
  81. mem0/memory/base.py +63 -0
  82. mem0/memory/graph_memory.py +632 -0
  83. mem0/memory/main.py +1843 -0
  84. mem0/memory/memgraph_memory.py +630 -0
  85. mem0/memory/setup.py +56 -0
  86. mem0/memory/storage.py +218 -0
  87. mem0/memory/telemetry.py +90 -0
  88. mem0/memory/utils.py +133 -0
  89. mem0/proxy/__init__.py +0 -0
  90. mem0/proxy/main.py +194 -0
  91. mem0/utils/factory.py +132 -0
  92. mem0/vector_stores/__init__.py +0 -0
  93. mem0/vector_stores/azure_ai_search.py +383 -0
  94. mem0/vector_stores/baidu.py +368 -0
  95. mem0/vector_stores/base.py +58 -0
  96. mem0/vector_stores/chroma.py +229 -0
  97. mem0/vector_stores/configs.py +60 -0
  98. mem0/vector_stores/elasticsearch.py +235 -0
  99. mem0/vector_stores/faiss.py +473 -0
  100. mem0/vector_stores/langchain.py +179 -0
  101. mem0/vector_stores/milvus.py +245 -0
  102. mem0/vector_stores/mongodb.py +293 -0
  103. mem0/vector_stores/opensearch.py +281 -0
  104. mem0/vector_stores/pgvector.py +294 -0
  105. mem0/vector_stores/pinecone.py +373 -0
  106. mem0/vector_stores/qdrant.py +240 -0
  107. mem0/vector_stores/redis.py +295 -0
  108. mem0/vector_stores/supabase.py +237 -0
  109. mem0/vector_stores/upstash_vector.py +293 -0
  110. mem0/vector_stores/vertex_ai_vector_search.py +629 -0
  111. mem0/vector_stores/weaviate.py +316 -0
  112. mem0ai_azure_mysql-0.1.115.data/data/README.md +169 -0
  113. mem0ai_azure_mysql-0.1.115.dist-info/METADATA +224 -0
  114. mem0ai_azure_mysql-0.1.115.dist-info/RECORD +116 -0
  115. mem0ai_azure_mysql-0.1.115.dist-info/WHEEL +4 -0
  116. mem0ai_azure_mysql-0.1.115.dist-info/licenses/LICENSE +201 -0
mem0/memory/graph_memory.py
@@ -0,0 +1,632 @@
+import logging
+
+from mem0.memory.utils import format_entities
+
+try:
+    from langchain_neo4j import Neo4jGraph
+except ImportError:
+    raise ImportError("langchain_neo4j is not installed. Please install it using pip install langchain-neo4j")
+
+try:
+    from rank_bm25 import BM25Okapi
+except ImportError:
+    raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
+
+from mem0.graphs.tools import (
+    DELETE_MEMORY_STRUCT_TOOL_GRAPH,
+    DELETE_MEMORY_TOOL_GRAPH,
+    EXTRACT_ENTITIES_STRUCT_TOOL,
+    EXTRACT_ENTITIES_TOOL,
+    RELATIONS_STRUCT_TOOL,
+    RELATIONS_TOOL,
+)
+from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
+from mem0.utils.factory import EmbedderFactory, LlmFactory
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryGraph:
+    def __init__(self, config):
+        self.config = config
+        self.graph = Neo4jGraph(
+            self.config.graph_store.config.url,
+            self.config.graph_store.config.username,
+            self.config.graph_store.config.password,
+            self.config.graph_store.config.database,
+            refresh_schema=False,
+            driver_config={"notifications_min_severity": "OFF"},
+        )
+        self.embedding_model = EmbedderFactory.create(
+            self.config.embedder.provider, self.config.embedder.config, self.config.vector_store.config
+        )
+        self.node_label = ":`__Entity__`" if self.config.graph_store.config.base_label else ""
+
+        if self.config.graph_store.config.base_label:
+            # Safely add user_id index
+            try:
+                self.graph.query(f"CREATE INDEX entity_single IF NOT EXISTS FOR (n {self.node_label}) ON (n.user_id)")
+            except Exception:
+                pass
+            try:  # Safely try to add composite index (Enterprise only)
+                self.graph.query(
+                    f"CREATE INDEX entity_composite IF NOT EXISTS FOR (n {self.node_label}) ON (n.name, n.user_id)"
+                )
+            except Exception:
+                pass
+
+        self.llm_provider = "openai_structured"
+        if self.config.llm.provider:
+            self.llm_provider = self.config.llm.provider
+        if self.config.graph_store.llm:
+            self.llm_provider = self.config.graph_store.llm.provider
+
+        self.llm = LlmFactory.create(self.llm_provider, self.config.llm.config)
+        self.user_id = None
+        self.threshold = 0.7
+
+    def add(self, data, filters):
+        """
+        Adds data to the graph.
+
+        Args:
+            data (str): The data to add to the graph.
+            filters (dict): A dictionary containing filters to be applied during the addition.
+        """
+        entity_type_map = self._retrieve_nodes_from_data(data, filters)
+        to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
+        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
+        to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
+
+        # TODO: Batch queries with APOC plugin
+        # TODO: Add more filter support
+        deleted_entities = self._delete_entities(to_be_deleted, filters)
+        added_entities = self._add_entities(to_be_added, filters, entity_type_map)
+
+        return {"deleted_entities": deleted_entities, "added_entities": added_entities}
+
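Note: add() above and every method that follows read the same filters shape, inferred from the filters["user_id"] and filters.get("agent_id") accesses in this file: user_id is required and agent_id is optional. For example:

filters = {"user_id": "alice"}                      # user scope only
filters = {"user_id": "alice", "agent_id": "bot1"}  # user + agent scope
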
+    def search(self, query, filters, limit=100):
+        """
+        Search for memories and related graph data.
+
+        Args:
+            query (str): Query to search for.
+            filters (dict): A dictionary containing filters to be applied during the search.
+            limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
+
+        Returns:
+            list: A list of dictionaries, each containing "source", "relationship", and
+                "destination" keys for a matching relationship, reranked by BM25 relevance.
+                Returns an empty list if the graph search yields no results.
+        """
+        entity_type_map = self._retrieve_nodes_from_data(query, filters)
+        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
+
+        if not search_output:
+            return []
+
+        search_outputs_sequence = [
+            [item["source"], item["relationship"], item["destination"]] for item in search_output
+        ]
+        bm25 = BM25Okapi(search_outputs_sequence)
+
+        tokenized_query = query.split(" ")
+        reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)
+
+        search_results = []
+        for item in reranked_results:
+            search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})
+
+        logger.info(f"Returned {len(search_results)} search results")
+
+        return search_results
+
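A note on the reranking step in search(): each (source, relationship, destination) triple is treated as a pre-tokenized document for BM25. A minimal standalone sketch of the same idea, using the rank_bm25 API imported above (the corpus values are made up for illustration):

from rank_bm25 import BM25Okapi

# Triples come back from _search_graph_db already lowercased and
# underscore-joined, so each one doubles as a tokenized "document".
corpus = [
    ["alice", "likes", "pizza"],
    ["alice", "works_at", "acme"],
    ["bob", "lives_in", "paris"],
]
bm25 = BM25Okapi(corpus)

# Whitespace tokenization, mirroring query.split(" ") in search() above.
top = bm25.get_top_n("who likes pizza".split(" "), corpus, n=1)
print(top)  # [['alice', 'likes', 'pizza']]

Because _remove_spaces_from_entities lowercases and underscore-joins names, the plain whitespace split only matches whole node or relationship tokens; a multi-word entity such as "works at" behaves as the single token "works_at".
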
+    def delete_all(self, filters):
+        """Delete all nodes and relationships for the given user_id (and agent_id, if provided)."""
+        if filters.get("agent_id"):
+            cypher = f"""
+            MATCH (n {self.node_label} {{user_id: $user_id, agent_id: $agent_id}})
+            DETACH DELETE n
+            """
+            params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"]}
+        else:
+            cypher = f"""
+            MATCH (n {self.node_label} {{user_id: $user_id}})
+            DETACH DELETE n
+            """
+            params = {"user_id": filters["user_id"]}
+        self.graph.query(cypher, params=params)
+
+    def get_all(self, filters, limit=100):
+        """
+        Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
+
+        Args:
+            filters (dict): A dictionary containing filters to be applied during the retrieval.
+            limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
+
+        Returns:
+            list: A list of dictionaries, each containing "source", "relationship", and
+                "target" keys for one relationship in the graph.
+        """
+        agent_filter = ""
+        params = {"user_id": filters["user_id"], "limit": limit}
+        if filters.get("agent_id"):
+            agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
+            params["agent_id"] = filters["agent_id"]
+
+        query = f"""
+        MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
+        WHERE 1=1 {agent_filter}
+        RETURN n.name AS source, type(r) AS relationship, m.name AS target
+        LIMIT $limit
+        """
+        results = self.graph.query(query, params=params)
+
+        final_results = []
+        for result in results:
+            final_results.append(
+                {
+                    "source": result["source"],
+                    "relationship": result["relationship"],
+                    "target": result["target"],
+                }
+            )
+
+        logger.info(f"Retrieved {len(final_results)} relationships")
+
+        return final_results
+
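For reference, get_all() returns plain relationship rows; a call like get_all({"user_id": "alice"}) might produce (values illustrative):

[
    {"source": "alice", "relationship": "works_at", "target": "acme"},
    {"source": "alice", "relationship": "likes", "target": "pizza"},
]
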
+    def _retrieve_nodes_from_data(self, data, filters):
+        """Extracts all the entities mentioned in the query."""
+        _tools = [EXTRACT_ENTITIES_TOOL]
+        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
+            _tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
+        search_results = self.llm.generate_response(
+            messages=[
+                {
+                    "role": "system",
+                    "content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
+                },
+                {"role": "user", "content": data},
+            ],
+            tools=_tools,
+        )
+
+        entity_type_map = {}
+
+        try:
+            for tool_call in search_results["tool_calls"]:
+                if tool_call["name"] != "extract_entities":
+                    continue
+                for item in tool_call["arguments"]["entities"]:
+                    entity_type_map[item["entity"]] = item["entity_type"]
+        except Exception as e:
+            logger.exception(
+                f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
+            )
+
+        entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
+        logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
+        return entity_type_map
+
+    def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
+        """Establish relations among the extracted nodes."""
+
+        # Compose user identification string for prompt
+        user_identity = f"user_id: {filters['user_id']}"
+        if filters.get("agent_id"):
+            user_identity += f", agent_id: {filters['agent_id']}"
+
+        if self.config.graph_store.custom_prompt:
+            system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
+            # Add the custom prompt line if configured
+            system_content = system_content.replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")
+            messages = [
+                {"role": "system", "content": system_content},
+                {"role": "user", "content": data},
+            ]
+        else:
+            system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
+            messages = [
+                {"role": "system", "content": system_content},
+                {"role": "user", "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"},
+            ]
+
+        _tools = [RELATIONS_TOOL]
+        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
+            _tools = [RELATIONS_STRUCT_TOOL]
+
+        extracted_entities = self.llm.generate_response(
+            messages=messages,
+            tools=_tools,
+        )
+
+        entities = []
+        if extracted_entities.get("tool_calls"):
+            entities = extracted_entities["tool_calls"][0].get("arguments", {}).get("entities", [])
+
+        entities = self._remove_spaces_from_entities(entities)
+        logger.debug(f"Extracted entities: {entities}")
+        return entities
+
+    def _search_graph_db(self, node_list, filters, limit=100):
+        """Search for nodes similar to the given ones, along with their incoming and outgoing relations."""
+        result_relations = []
+        agent_filter = ""
+        if filters.get("agent_id"):
+            agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
+
+        for node in node_list:
+            n_embedding = self.embedding_model.embed(node)
+
+            cypher_query = f"""
+            MATCH (n {self.node_label})
+            WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
+            {agent_filter}
+            WITH n, round(2 * vector.similarity.cosine(n.embedding, $n_embedding) - 1, 4) AS similarity // denormalize for backward compatibility
+            WHERE similarity >= $threshold
+            CALL {{
+                MATCH (n)-[r]->(m)
+                WHERE m.user_id = $user_id {agent_filter.replace("n.", "m.")}
+                RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id
+                UNION
+                MATCH (m)-[r]->(n)
+                WHERE m.user_id = $user_id {agent_filter.replace("n.", "m.")}
+                RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id
+            }}
+            WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity
+            RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
+            ORDER BY similarity DESC
+            LIMIT $limit
+            """
+
+            params = {
+                "n_embedding": n_embedding,
+                "threshold": self.threshold,
+                "user_id": filters["user_id"],
+                "limit": limit,
+            }
+            if filters.get("agent_id"):
+                params["agent_id"] = filters["agent_id"]
+
+            ans = self.graph.query(cypher_query, params=params)
+            result_relations.extend(ans)
+
+        return result_relations
+
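A note on the similarity expression in the query above: Neo4j's vector.similarity.cosine returns a score normalized to [0, 1], and 2 * score - 1 maps it back to the raw cosine range [-1, 1] so the class-level 0.7 threshold keeps its original meaning. A one-function sketch of that mapping (assuming the documented (1 + cos) / 2 normalization):

def denormalize(neo4j_score: float) -> float:
    # Maps Neo4j's [0, 1] cosine score back to the raw [-1, 1] cosine range,
    # mirroring the round(2 * score - 1, 4) expression in the Cypher above.
    return round(2 * neo4j_score - 1, 4)

assert denormalize(1.0) == 1.0   # identical vectors
assert denormalize(0.5) == 0.0   # orthogonal vectors
assert denormalize(0.0) == -1.0  # opposite vectors
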
+    def _get_delete_entities_from_search_output(self, search_output, data, filters):
+        """Get the entities to be deleted from the search output."""
+        search_output_string = format_entities(search_output)
+
+        # Compose user identification string for prompt
+        user_identity = f"user_id: {filters['user_id']}"
+        if filters.get("agent_id"):
+            user_identity += f", agent_id: {filters['agent_id']}"
+
+        system_prompt, user_prompt = get_delete_messages(search_output_string, data, user_identity)
+
+        _tools = [DELETE_MEMORY_TOOL_GRAPH]
+        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
+            _tools = [
+                DELETE_MEMORY_STRUCT_TOOL_GRAPH,
+            ]
+
+        memory_updates = self.llm.generate_response(
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ],
+            tools=_tools,
+        )
+
+        to_be_deleted = []
+        for item in memory_updates.get("tool_calls", []):
+            if item.get("name") == "delete_graph_memory":
+                to_be_deleted.append(item.get("arguments"))
+        # Clean entities formatting
+        to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
+        logger.debug(f"Relationships to be deleted: {to_be_deleted}")
+        return to_be_deleted
+
+    def _delete_entities(self, to_be_deleted, filters):
+        """Delete the entities from the graph."""
+        user_id = filters["user_id"]
+        agent_id = filters.get("agent_id", None)
+        results = []
+
+        for item in to_be_deleted:
+            source = item["source"]
+            destination = item["destination"]
+            relationship = item["relationship"]
+
+            # Build the agent filter for the query
+            agent_filter = ""
+            params = {
+                "source_name": source,
+                "dest_name": destination,
+                "user_id": user_id,
+            }
+
+            if agent_id:
+                agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
+                params["agent_id"] = agent_id
+
+            # Delete the specific relationship between nodes
+            cypher = f"""
+            MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
+            -[r:{relationship}]->
+            (m {self.node_label} {{name: $dest_name, user_id: $user_id}})
+            WHERE 1=1 {agent_filter}
+            DELETE r
+            RETURN
+                n.name AS source,
+                m.name AS target,
+                type(r) AS relationship
+            """
+
+            result = self.graph.query(cypher, params=params)
+            results.append(result)
+
+        return results
+
+    def _add_entities(self, to_be_added, filters, entity_type_map):
+        """Add the new entities to the graph. Merge the nodes if they already exist."""
+        user_id = filters["user_id"]
+        agent_id = filters.get("agent_id", None)
+        results = []
+        for item in to_be_added:
+            # entities
+            source = item["source"]
+            destination = item["destination"]
+            relationship = item["relationship"]
+
+            # types
+            source_type = entity_type_map.get(source, "__User__")
+            source_label = self.node_label if self.node_label else f":`{source_type}`"
+            source_extra_set = f", source:`{source_type}`" if self.node_label else ""
+            destination_type = entity_type_map.get(destination, "__User__")
+            destination_label = self.node_label if self.node_label else f":`{destination_type}`"
+            destination_extra_set = f", destination:`{destination_type}`" if self.node_label else ""
+
+            # embeddings
+            source_embedding = self.embedding_model.embed(source)
+            dest_embedding = self.embedding_model.embed(destination)
+
+            # search for the nodes with the closest embeddings
+            source_node_search_result = self._search_source_node(source_embedding, filters, threshold=0.9)
+            destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=0.9)
+
+            # TODO: Create a cypher query and common params for all the cases
+            if not destination_node_search_result and source_node_search_result:
+                # Build destination MERGE properties
+                merge_props = ["name: $destination_name", "user_id: $user_id"]
+                if agent_id:
+                    merge_props.append("agent_id: $agent_id")
+                merge_props_str = ", ".join(merge_props)
+
+                cypher = f"""
+                MATCH (source)
+                WHERE elementId(source) = $source_id
+                SET source.mentions = coalesce(source.mentions, 0) + 1
+                WITH source
+                MERGE (destination {destination_label} {{{merge_props_str}}})
+                ON CREATE SET
+                    destination.created = timestamp(),
+                    destination.mentions = 1
+                    {destination_extra_set}
+                ON MATCH SET
+                    destination.mentions = coalesce(destination.mentions, 0) + 1
+                WITH source, destination
+                CALL db.create.setNodeVectorProperty(destination, 'embedding', $destination_embedding)
+                WITH source, destination
+                MERGE (source)-[r:{relationship}]->(destination)
+                ON CREATE SET
+                    r.created = timestamp(),
+                    r.mentions = 1
+                ON MATCH SET
+                    r.mentions = coalesce(r.mentions, 0) + 1
+                RETURN source.name AS source, type(r) AS relationship, destination.name AS target
+                """
+
+                params = {
+                    "source_id": source_node_search_result[0]["elementId(source_candidate)"],
+                    "destination_name": destination,
+                    "destination_embedding": dest_embedding,
+                    "user_id": user_id,
+                }
+                if agent_id:
+                    params["agent_id"] = agent_id
+
+            elif destination_node_search_result and not source_node_search_result:
+                # Build source MERGE properties
+                merge_props = ["name: $source_name", "user_id: $user_id"]
+                if agent_id:
+                    merge_props.append("agent_id: $agent_id")
+                merge_props_str = ", ".join(merge_props)
+
+                cypher = f"""
+                MATCH (destination)
+                WHERE elementId(destination) = $destination_id
+                SET destination.mentions = coalesce(destination.mentions, 0) + 1
+                WITH destination
+                MERGE (source {source_label} {{{merge_props_str}}})
+                ON CREATE SET
+                    source.created = timestamp(),
+                    source.mentions = 1
+                    {source_extra_set}
+                ON MATCH SET
+                    source.mentions = coalesce(source.mentions, 0) + 1
+                WITH source, destination
+                CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
+                WITH source, destination
+                MERGE (source)-[r:{relationship}]->(destination)
+                ON CREATE SET
+                    r.created = timestamp(),
+                    r.mentions = 1
+                ON MATCH SET
+                    r.mentions = coalesce(r.mentions, 0) + 1
+                RETURN source.name AS source, type(r) AS relationship, destination.name AS target
+                """
+
+                params = {
+                    "destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
+                    "source_name": source,
+                    "source_embedding": source_embedding,
+                    "user_id": user_id,
+                }
+                if agent_id:
+                    params["agent_id"] = agent_id
+
+            elif source_node_search_result and destination_node_search_result:
+                cypher = f"""
+                MATCH (source)
+                WHERE elementId(source) = $source_id
+                SET source.mentions = coalesce(source.mentions, 0) + 1
+                WITH source
+                MATCH (destination)
+                WHERE elementId(destination) = $destination_id
+                SET destination.mentions = coalesce(destination.mentions, 0) + 1
+                MERGE (source)-[r:{relationship}]->(destination)
+                ON CREATE SET
+                    r.created_at = timestamp(),
+                    r.updated_at = timestamp(),
+                    r.mentions = 1
+                ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
+                RETURN source.name AS source, type(r) AS relationship, destination.name AS target
+                """
+
+                params = {
+                    "source_id": source_node_search_result[0]["elementId(source_candidate)"],
+                    "destination_id": destination_node_search_result[0]["elementId(destination_candidate)"],
+                    "user_id": user_id,
+                }
+                if agent_id:
+                    params["agent_id"] = agent_id
+
+            else:
+                # Build dynamic MERGE props for both source and destination
+                source_props = ["name: $source_name", "user_id: $user_id"]
+                dest_props = ["name: $dest_name", "user_id: $user_id"]
+                if agent_id:
+                    source_props.append("agent_id: $agent_id")
+                    dest_props.append("agent_id: $agent_id")
+                source_props_str = ", ".join(source_props)
+                dest_props_str = ", ".join(dest_props)
+
+                cypher = f"""
+                MERGE (source {source_label} {{{source_props_str}}})
+                ON CREATE SET source.created = timestamp(),
+                    source.mentions = 1
+                    {source_extra_set}
+                ON MATCH SET source.mentions = coalesce(source.mentions, 0) + 1
+                WITH source
+                CALL db.create.setNodeVectorProperty(source, 'embedding', $source_embedding)
+                WITH source
+                MERGE (destination {destination_label} {{{dest_props_str}}})
+                ON CREATE SET destination.created = timestamp(),
+                    destination.mentions = 1
+                    {destination_extra_set}
+                ON MATCH SET destination.mentions = coalesce(destination.mentions, 0) + 1
+                WITH source, destination
+                CALL db.create.setNodeVectorProperty(destination, 'embedding', $dest_embedding)
+                WITH source, destination
+                MERGE (source)-[rel:{relationship}]->(destination)
+                ON CREATE SET rel.created = timestamp(), rel.mentions = 1
+                ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
+                RETURN source.name AS source, type(rel) AS relationship, destination.name AS target
+                """
+
+                params = {
+                    "source_name": source,
+                    "dest_name": destination,
+                    "source_embedding": source_embedding,
+                    "dest_embedding": dest_embedding,
+                    "user_id": user_id,
+                }
+                if agent_id:
+                    params["agent_id"] = agent_id
+            result = self.graph.query(cypher, params=params)
+            results.append(result)
+        return results
+
+    def _remove_spaces_from_entities(self, entity_list):
+        for item in entity_list:
+            item["source"] = item["source"].lower().replace(" ", "_")
+            item["relationship"] = item["relationship"].lower().replace(" ", "_")
+            item["destination"] = item["destination"].lower().replace(" ", "_")
+        return entity_list
+
+    def _search_source_node(self, source_embedding, filters, threshold=0.9):
+        agent_filter = ""
+        if filters.get("agent_id"):
+            agent_filter = "AND source_candidate.agent_id = $agent_id"
+
+        cypher = f"""
+            MATCH (source_candidate {self.node_label})
+            WHERE source_candidate.embedding IS NOT NULL
+            AND source_candidate.user_id = $user_id
+            {agent_filter}
+
+            WITH source_candidate,
+            round(2 * vector.similarity.cosine(source_candidate.embedding, $source_embedding) - 1, 4) AS source_similarity // denormalize for backward compatibility
+            WHERE source_similarity >= $threshold
+
+            WITH source_candidate, source_similarity
+            ORDER BY source_similarity DESC
+            LIMIT 1
+
+            RETURN elementId(source_candidate)
+            """
+
+        params = {
+            "source_embedding": source_embedding,
+            "user_id": filters["user_id"],
+            "threshold": threshold,
+        }
+        if filters.get("agent_id"):
+            params["agent_id"] = filters["agent_id"]
+
+        result = self.graph.query(cypher, params=params)
+        return result
+
+    def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
+        agent_filter = ""
+        if filters.get("agent_id"):
+            agent_filter = "AND destination_candidate.agent_id = $agent_id"
+
+        cypher = f"""
+            MATCH (destination_candidate {self.node_label})
+            WHERE destination_candidate.embedding IS NOT NULL
+            AND destination_candidate.user_id = $user_id
+            {agent_filter}
+
+            WITH destination_candidate,
+            round(2 * vector.similarity.cosine(destination_candidate.embedding, $destination_embedding) - 1, 4) AS destination_similarity // denormalize for backward compatibility
+            WHERE destination_similarity >= $threshold
+
+            WITH destination_candidate, destination_similarity
+            ORDER BY destination_similarity DESC
+            LIMIT 1
+
+            RETURN elementId(destination_candidate)
+            """
+
+        params = {
+            "destination_embedding": destination_embedding,
+            "user_id": filters["user_id"],
+            "threshold": threshold,
+        }
+        if filters.get("agent_id"):
+            params["agent_id"] = filters["agent_id"]
+
+        result = self.graph.query(cypher, params=params)
+        return result
+
+    # Reset is not defined in base.py
+    def reset(self):
+        """Reset the graph by clearing all nodes and relationships."""
+        logger.warning("Clearing graph...")
+        cypher_query = """
+        MATCH (n) DETACH DELETE n
+        """
+        return self.graph.query(cypher_query)
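
For orientation, a minimal sketch of how this class is normally reached through mem0's public Memory API rather than constructed directly (the Neo4j URL and credentials are placeholders, and exact config keys may vary by version):

from mem0 import Memory

# Hypothetical local setup: a Neo4j instance plus the package's default
# LLM and embedder providers; MemoryGraph is created internally when a
# graph_store is configured.
config = {
    "graph_store": {
        "provider": "neo4j",
        "config": {
            "url": "neo4j://localhost:7687",
            "username": "neo4j",
            "password": "password",
        },
    },
}

m = Memory.from_config(config)
m.add("Alice works at Acme and likes pizza.", user_id="alice")
print(m.search("Where does Alice work?", user_id="alice"))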