agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,689 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from agentrun_mem0.memory.utils import format_entities, sanitize_relationship_for_cypher
|
|
4
|
+
|
|
5
|
+
try:
|
|
6
|
+
from langchain_memgraph.graphs.memgraph import Memgraph
|
|
7
|
+
except ImportError:
|
|
8
|
+
raise ImportError("langchain_memgraph is not installed. Please install it using pip install langchain-memgraph")
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from rank_bm25 import BM25Okapi
|
|
12
|
+
except ImportError:
|
|
13
|
+
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
|
|
14
|
+
|
|
15
|
+
from agentrun_mem0.graphs.tools import (
|
|
16
|
+
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
|
|
17
|
+
DELETE_MEMORY_TOOL_GRAPH,
|
|
18
|
+
EXTRACT_ENTITIES_STRUCT_TOOL,
|
|
19
|
+
EXTRACT_ENTITIES_TOOL,
|
|
20
|
+
RELATIONS_STRUCT_TOOL,
|
|
21
|
+
RELATIONS_TOOL,
|
|
22
|
+
)
|
|
23
|
+
from agentrun_mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
|
|
24
|
+
from agentrun_mem0.utils.factory import EmbedderFactory, LlmFactory
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MemoryGraph:
|
|
30
|
+
def __init__(self, config):
    """Build the Memgraph-backed graph memory.

    Connects to Memgraph using the URL and credentials in ``config.graph_store.config``,
    creates the embedder and the LLM, and then makes sure the ``memzero`` vector index on
    ``:Entity(embedding)``, the ``:Entity(user_id)`` label-property index and the
    ``:Entity`` label index exist (each is created only when missing).

    LLM provider precedence: ``config.graph_store.llm.provider`` if set, otherwise
    ``config.llm.provider`` if set, otherwise ``"openai"``. The LLM config comes from
    ``config.graph_store.llm`` when that object exposes a ``config`` attribute, otherwise
    from ``config.llm``.

    Args:
        config: Memory configuration providing ``graph_store``, ``embedder`` (whose
            ``config`` mapping must contain ``embedding_dims``) and optionally ``llm``.
    """
    self.config = config

    connection = self.config.graph_store.config
    self.graph = Memgraph(connection.url, connection.username, connection.password)

    embedder = self.config.embedder
    self.embedding_model = EmbedderFactory.create(
        embedder.provider,
        embedder.config,
        {"enable_embeddings": True},
    )

    base_llm = self.config.llm
    # Evaluates to a falsy graph_store when it is unset, so the checks below short-circuit.
    graph_llm = self.config.graph_store and self.config.graph_store.llm

    # The graph-store LLM wins over the top-level LLM; "openai" is the fallback.
    if graph_llm and graph_llm.provider:
        self.llm_provider = graph_llm.provider
    elif base_llm and base_llm.provider:
        self.llm_provider = base_llm.provider
    else:
        self.llm_provider = "openai"

    if graph_llm and hasattr(graph_llm, "config"):
        llm_config = graph_llm.config
    elif hasattr(base_llm, "config"):
        llm_config = base_llm.config
    else:
        llm_config = None

    self.llm = LlmFactory.create(self.llm_provider, llm_config)
    self.user_id = None
    self.threshold = 0.7

    # Memgraph setup: the vector index is declared on :Entity, which is why every node this
    # class writes also carries the Entity label; the other two indexes speed up lookups.
    dimension = embedder.config["embedding_dims"]
    existing_indexes = self._fetch_existing_indexes()

    if not self._vector_index_exists(existing_indexes, "memzero"):
        self.graph.query(
            f"CREATE VECTOR INDEX memzero ON :Entity(embedding) WITH CONFIG {{'dimension': {dimension}, 'capacity': 1000, 'metric': 'cos'}};"
        )

    if not self._label_property_index_exists(existing_indexes, "Entity", "user_id"):
        self.graph.query("CREATE INDEX ON :Entity(user_id);")

    if not self._label_index_exists(existing_indexes, "Entity"):
        self.graph.query("CREATE INDEX ON :Entity;")
|
|
79
|
+
|
|
80
|
+
def add(self, data, filters):
    """
    Add the knowledge contained in ``data`` to the graph.

    Entities are extracted from ``data``, relations among them are derived, existing
    relations that the LLM judges outdated are deleted, and the new relations are merged in.

    Args:
        data (str): The data to add to the graph.
        filters (dict): Scope of the operation; ``user_id`` is required, ``agent_id`` optional.

    Returns:
        dict: ``{"deleted_entities": [...], "added_entities": [...]}``, each a list holding
        one graph query result per processed relation.
    """
    entity_type_map = self._retrieve_nodes_from_data(data, filters)
    new_relations = self._establish_nodes_relations_from_data(data, filters, entity_type_map)

    existing_relations = self._search_graph_db(node_list=list(entity_type_map), filters=filters)
    stale_relations = self._get_delete_entities_from_search_output(existing_relations, data, filters)

    # TODO: Batch queries with APOC plugin
    # TODO: Add more filter support
    # Deletions are applied before additions.
    deleted = self._delete_entities(stale_relations, filters)
    added = self._add_entities(new_relations, filters, entity_type_map)

    return {"deleted_entities": deleted, "added_entities": added}
|
|
99
|
+
|
|
100
|
+
def search(self, query, filters, limit=100):
    """
    Search the graph for relationships related to ``query``.

    Entities are extracted from the query with the LLM, nearby nodes are found through the
    ``memzero`` vector index together with their incoming and outgoing relationships, and the
    candidate triples are reranked against the query text with BM25.

    Args:
        query (str): Query to search for.
        filters (dict): Scope of the search; ``user_id`` is required, ``agent_id`` optional.
        limit (int): Maximum number of candidate nodes/relationships fetched from the graph
            per extracted entity. Defaults to 100.

    Returns:
        list: Up to 5 dicts with keys ``"source"``, ``"relationship"`` and
        ``"destination"``, best BM25 match first; an empty list when nothing matches.
    """
    entity_type_map = self._retrieve_nodes_from_data(query, filters)
    # Forward ``limit``; previously it was accepted but silently ignored.
    search_output = self._search_graph_db(
        node_list=list(entity_type_map.keys()), filters=filters, limit=limit
    )

    if not search_output:
        return []

    search_outputs_sequence = [
        [item["source"], item["relationship"], item["destination"]] for item in search_output
    ]
    bm25 = BM25Okapi(search_outputs_sequence)

    # Stored names and relationship types are lower-cased before insertion, so lower-case the
    # query tokens too; split() without an argument also drops empty tokens from extra spaces.
    tokenized_query = query.lower().split()
    reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)

    search_results = [
        {"source": source, "relationship": relationship, "destination": destination}
        for source, relationship, destination in reranked_results
    ]

    logger.info(f"Returned {len(search_results)} search results")

    return search_results
|
|
135
|
+
|
|
136
|
+
def delete_all(self, filters):
|
|
137
|
+
"""Delete all nodes and relationships for a user or specific agent."""
|
|
138
|
+
if filters.get("agent_id"):
|
|
139
|
+
cypher = """
|
|
140
|
+
MATCH (n:Entity {user_id: $user_id, agent_id: $agent_id})
|
|
141
|
+
DETACH DELETE n
|
|
142
|
+
"""
|
|
143
|
+
params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"]}
|
|
144
|
+
else:
|
|
145
|
+
cypher = """
|
|
146
|
+
MATCH (n:Entity {user_id: $user_id})
|
|
147
|
+
DETACH DELETE n
|
|
148
|
+
"""
|
|
149
|
+
params = {"user_id": filters["user_id"]}
|
|
150
|
+
self.graph.query(cypher, params=params)
|
|
151
|
+
|
|
152
|
+
def get_all(self, filters, limit=100):
    """
    Retrieve relationships stored for a user (optionally narrowed to one agent).

    Args:
        filters (dict): ``user_id`` is required; a truthy ``agent_id`` restricts both
            endpoints of each relationship to that agent.
        limit (int): Maximum number of relationships to return. Defaults to 100.

    Returns:
        list: Dicts with keys ``'source'`` (source node name), ``'relationship'``
        (relationship type) and ``'target'`` (target node name).
    """
    # Both endpoints must match the same ownership pattern.
    if filters.get("agent_id"):
        scope = "{user_id: $user_id, agent_id: $agent_id}"
        params = {"user_id": filters["user_id"], "agent_id": filters["agent_id"], "limit": limit}
    else:
        scope = "{user_id: $user_id}"
        params = {"user_id": filters["user_id"], "limit": limit}

    query = f"""
    MATCH (n:Entity {scope})-[r]->(m:Entity {scope})
    RETURN n.name AS source, type(r) AS relationship, m.name AS target
    LIMIT $limit
    """

    rows = self.graph.query(query, params=params)
    final_results = [
        {"source": row["source"], "relationship": row["relationship"], "target": row["target"]}
        for row in rows
    ]

    logger.info(f"Retrieved {len(final_results)} relationships")

    return final_results
|
|
197
|
+
|
|
198
|
+
def _retrieve_nodes_from_data(self, data, filters):
    """Extract all the entities mentioned in ``data`` with the LLM entity-extraction tool.

    Args:
        data (str): Text (a message to store or a search query) to extract entities from.
        filters (dict): Must contain ``user_id``; it is substituted into the system prompt so
            self-references ("I", "me", "my") resolve to that user.

    Returns:
        dict: Normalised entity name -> normalised entity type (both lower-cased, spaces
        replaced by underscores). Best-effort: malformed LLM output is logged and skipped, so
        the map may be empty.
    """
    _tools = [EXTRACT_ENTITIES_TOOL]
    if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
        _tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
    search_results = self.llm.generate_response(
        messages=[
            {
                "role": "system",
                "content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
            },
            {"role": "user", "content": data},
        ],
        tools=_tools,
    )

    def _normalize(text):
        # Same normalisation as applied to relation endpoints elsewhere in this class.
        return text.lower().replace(" ", "_")

    entity_type_map = {}

    try:
        for tool_call in search_results.get("tool_calls") or []:
            if tool_call["name"] != "extract_entities":
                continue
            for item in tool_call["arguments"].get("entities") or []:
                entity = item.get("entity")
                entity_type = item.get("entity_type")
                # Normalise per item (previously done after the try), so one malformed entry
                # no longer crashes the call or discards the entries that follow it.
                if not (isinstance(entity, str) and isinstance(entity_type, str)):
                    logger.warning(f"Skipping malformed entity from LLM output: {item}")
                    continue
                entity_type_map[_normalize(entity)] = _normalize(entity_type)
    except Exception as e:
        logger.exception(
            f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
        )

    logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
    return entity_type_map
|
|
230
|
+
|
|
231
|
+
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
    """Establish relations among the extracted nodes using the LLM relations tool.

    Args:
        data (str): Source text.
        filters (dict): Must contain ``user_id``; it replaces ``USER_ID`` in the prompt.
        entity_type_map (dict): Extracted entities; their names are listed for the LLM when
            no custom prompt is configured.

    Returns:
        list[dict]: Relations (``source``/``relationship``/``destination``) from the first
        tool call, normalised by ``_remove_spaces_from_entities``; empty when the LLM
        returned no usable tool call.
    """
    custom_prompt = self.config.graph_store.custom_prompt
    system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", filters["user_id"])

    if custom_prompt:
        messages = [
            {
                "role": "system",
                "content": system_content.replace("CUSTOM_PROMPT", f"4. {custom_prompt}"),
            },
            {"role": "user", "content": data},
        ]
    else:
        messages = [
            {
                "role": "system",
                # Remove the unused placeholder so the literal token "CUSTOM_PROMPT" is not
                # sent to the LLM (a no-op if the template does not contain it).
                "content": system_content.replace("CUSTOM_PROMPT", ""),
            },
            {
                "role": "user",
                "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}",
            },
        ]

    _tools = [RELATIONS_TOOL]
    if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
        _tools = [RELATIONS_STRUCT_TOOL]

    extracted_entities = self.llm.generate_response(
        messages=messages,
        tools=_tools,
    )

    # Tolerate responses without tool calls (missing key, None, or a plain-text response)
    # instead of raising KeyError/TypeError.
    if isinstance(extracted_entities, dict):
        tool_calls = extracted_entities.get("tool_calls")
    else:
        logger.warning(f"Unexpected relations response from LLM: {extracted_entities}")
        tool_calls = None

    entities = []
    if tool_calls:
        arguments = tool_calls[0].get("arguments") or {}
        entities = arguments.get("entities") or []

    entities = self._remove_spaces_from_entities(entities)
    logger.debug(f"Extracted entities: {entities}")
    return entities
|
|
271
|
+
|
|
272
|
+
def _search_graph_db(self, node_list, filters, limit=100):
|
|
273
|
+
"""Search similar nodes among and their respective incoming and outgoing relations."""
|
|
274
|
+
result_relations = []
|
|
275
|
+
|
|
276
|
+
for node in node_list:
|
|
277
|
+
n_embedding = self.embedding_model.embed(node)
|
|
278
|
+
|
|
279
|
+
# Build query based on whether agent_id is provided
|
|
280
|
+
if filters.get("agent_id"):
|
|
281
|
+
cypher_query = """
|
|
282
|
+
CALL vector_search.search("memzero", $limit, $n_embedding)
|
|
283
|
+
YIELD distance, node, similarity
|
|
284
|
+
WITH node AS n, similarity
|
|
285
|
+
WHERE n:Entity AND n.user_id = $user_id AND n.agent_id = $agent_id AND n.embedding IS NOT NULL AND similarity >= $threshold
|
|
286
|
+
MATCH (n)-[r]->(m:Entity)
|
|
287
|
+
RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id, similarity
|
|
288
|
+
UNION
|
|
289
|
+
CALL vector_search.search("memzero", $limit, $n_embedding)
|
|
290
|
+
YIELD distance, node, similarity
|
|
291
|
+
WITH node AS n, similarity
|
|
292
|
+
WHERE n:Entity AND n.user_id = $user_id AND n.agent_id = $agent_id AND n.embedding IS NOT NULL AND similarity >= $threshold
|
|
293
|
+
MATCH (m:Entity)-[r]->(n)
|
|
294
|
+
RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id, similarity
|
|
295
|
+
ORDER BY similarity DESC
|
|
296
|
+
LIMIT $limit;
|
|
297
|
+
"""
|
|
298
|
+
params = {
|
|
299
|
+
"n_embedding": n_embedding,
|
|
300
|
+
"threshold": self.threshold,
|
|
301
|
+
"user_id": filters["user_id"],
|
|
302
|
+
"agent_id": filters["agent_id"],
|
|
303
|
+
"limit": limit,
|
|
304
|
+
}
|
|
305
|
+
else:
|
|
306
|
+
cypher_query = """
|
|
307
|
+
CALL vector_search.search("memzero", $limit, $n_embedding)
|
|
308
|
+
YIELD distance, node, similarity
|
|
309
|
+
WITH node AS n, similarity
|
|
310
|
+
WHERE n:Entity AND n.user_id = $user_id AND n.embedding IS NOT NULL AND similarity >= $threshold
|
|
311
|
+
MATCH (n)-[r]->(m:Entity)
|
|
312
|
+
RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id, similarity
|
|
313
|
+
UNION
|
|
314
|
+
CALL vector_search.search("memzero", $limit, $n_embedding)
|
|
315
|
+
YIELD distance, node, similarity
|
|
316
|
+
WITH node AS n, similarity
|
|
317
|
+
WHERE n:Entity AND n.user_id = $user_id AND n.embedding IS NOT NULL AND similarity >= $threshold
|
|
318
|
+
MATCH (m:Entity)-[r]->(n)
|
|
319
|
+
RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id, similarity
|
|
320
|
+
ORDER BY similarity DESC
|
|
321
|
+
LIMIT $limit;
|
|
322
|
+
"""
|
|
323
|
+
params = {
|
|
324
|
+
"n_embedding": n_embedding,
|
|
325
|
+
"threshold": self.threshold,
|
|
326
|
+
"user_id": filters["user_id"],
|
|
327
|
+
"limit": limit,
|
|
328
|
+
}
|
|
329
|
+
|
|
330
|
+
ans = self.graph.query(cypher_query, params=params)
|
|
331
|
+
result_relations.extend(ans)
|
|
332
|
+
|
|
333
|
+
return result_relations
|
|
334
|
+
|
|
335
|
+
def _get_delete_entities_from_search_output(self, search_output, data, filters):
    """Ask the LLM which existing relations in ``search_output`` should be deleted.

    Args:
        search_output (list): Rows returned by ``_search_graph_db``.
        data (str): New text being added, against which existing relations are judged.
        filters (dict): Must contain ``user_id``.

    Returns:
        list[dict]: Arguments of every ``delete_graph_memory`` tool call, normalised by
        ``_remove_spaces_from_entities``; empty when the LLM made no such call.
    """
    search_output_string = format_entities(search_output)
    system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])

    _tools = [DELETE_MEMORY_TOOL_GRAPH]
    if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
        _tools = [
            DELETE_MEMORY_STRUCT_TOOL_GRAPH,
        ]

    memory_updates = self.llm.generate_response(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        tools=_tools,
    )

    # A response without tool calls (missing key, None, or non-dict) means "delete nothing"
    # rather than a KeyError/TypeError that would abort the whole add().
    if isinstance(memory_updates, dict):
        tool_calls = memory_updates.get("tool_calls") or []
    else:
        logger.warning(f"Unexpected delete response from LLM: {memory_updates}")
        tool_calls = []

    to_be_deleted = [call["arguments"] for call in tool_calls if call.get("name") == "delete_graph_memory"]
    # in case it is not in the correct format
    to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
    logger.debug(f"Deleted relationships: {to_be_deleted}")
    return to_be_deleted
|
|
361
|
+
|
|
362
|
+
def _delete_entities(self, to_be_deleted, filters):
|
|
363
|
+
"""Delete the entities from the graph."""
|
|
364
|
+
user_id = filters["user_id"]
|
|
365
|
+
agent_id = filters.get("agent_id", None)
|
|
366
|
+
results = []
|
|
367
|
+
|
|
368
|
+
for item in to_be_deleted:
|
|
369
|
+
source = item["source"]
|
|
370
|
+
destination = item["destination"]
|
|
371
|
+
relationship = item["relationship"]
|
|
372
|
+
|
|
373
|
+
# Build the agent filter for the query
|
|
374
|
+
agent_filter = ""
|
|
375
|
+
params = {
|
|
376
|
+
"source_name": source,
|
|
377
|
+
"dest_name": destination,
|
|
378
|
+
"user_id": user_id,
|
|
379
|
+
}
|
|
380
|
+
|
|
381
|
+
if agent_id:
|
|
382
|
+
agent_filter = "AND n.agent_id = $agent_id AND m.agent_id = $agent_id"
|
|
383
|
+
params["agent_id"] = agent_id
|
|
384
|
+
|
|
385
|
+
# Delete the specific relationship between nodes
|
|
386
|
+
cypher = f"""
|
|
387
|
+
MATCH (n:Entity {{name: $source_name, user_id: $user_id}})
|
|
388
|
+
-[r:{relationship}]->
|
|
389
|
+
(m:Entity {{name: $dest_name, user_id: $user_id}})
|
|
390
|
+
WHERE 1=1 {agent_filter}
|
|
391
|
+
DELETE r
|
|
392
|
+
RETURN
|
|
393
|
+
n.name AS source,
|
|
394
|
+
m.name AS target,
|
|
395
|
+
type(r) AS relationship
|
|
396
|
+
"""
|
|
397
|
+
|
|
398
|
+
result = self.graph.query(cypher, params=params)
|
|
399
|
+
results.append(result)
|
|
400
|
+
|
|
401
|
+
return results
|
|
402
|
+
|
|
403
|
+
# added Entity label to all nodes for vector search to work
|
|
404
|
+
def _add_entities(self, to_be_added, filters, entity_type_map):
    """Add the new entities to the graph. Merge the nodes if they already exist.

    For each relation, the source and destination names are embedded and matched against
    existing nodes (similarity >= 0.9). A matched node is reused by internal id. An unmatched
    node is MERGEd by ``name``/``user_id`` (plus ``agent_id`` when given) with its entity type
    as a label next to ``Entity``. The relationship is then MERGEd between the two nodes.

    Args:
        to_be_added (list[dict]): Relations with ``source``, ``destination`` and
            ``relationship`` keys.
        filters (dict): ``user_id`` is required, ``agent_id`` optional.
        entity_type_map (dict): Entity name -> entity type; names missing from it get the
            ``__User__`` label.

    Returns:
        list: One query result per relation, in input order.
    """
    user_id = filters["user_id"]
    agent_id = filters.get("agent_id", None)
    results = []

    def _quote(identifier):
        # Labels and relationship types cannot be query parameters, so they are interpolated.
        # Entity types come straight from LLM output and are only lower-cased/underscored, so
        # '-', '.', a leading digit, etc. would produce invalid (or injected) Cypher.
        # Backtick-quote them and double any embedded backtick; the stored name is unchanged.
        return "`" + str(identifier).replace("`", "``") + "`"

    # agent_id becomes part of a new node's identity only when it is provided.
    agent_id_clause = ", agent_id: $agent_id" if agent_id else ""

    for item in to_be_added:
        # entities
        source = item["source"]
        destination = item["destination"]
        relationship = _quote(item["relationship"])

        # types (as quoted labels)
        source_type = _quote(entity_type_map.get(source, "__User__"))
        destination_type = _quote(entity_type_map.get(destination, "__User__"))

        # embeddings
        source_embedding = self.embedding_model.embed(source)
        dest_embedding = self.embedding_model.embed(destination)

        # search for the nodes with the closest embeddings
        source_node_search_result = self._search_source_node(source_embedding, filters, threshold=0.9)
        destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=0.9)

        # TODO: Create a cypher query and common params for all the cases
        if not destination_node_search_result and source_node_search_result:
            # Reuse the matched source node; merge the destination by name.
            cypher = f"""
            MATCH (source:Entity)
            WHERE id(source) = $source_id
            MERGE (destination:{destination_type}:Entity {{name: $destination_name, user_id: $user_id{agent_id_clause}}})
            ON CREATE SET
                destination.created = timestamp(),
                destination.embedding = $destination_embedding,
                destination:Entity
            MERGE (source)-[r:{relationship}]->(destination)
            ON CREATE SET
                r.created = timestamp()
            RETURN source.name AS source, type(r) AS relationship, destination.name AS target
            """
            params = {
                "source_id": source_node_search_result[0]["id(source_candidate)"],
                "destination_name": destination,
                "destination_embedding": dest_embedding,
                "user_id": user_id,
            }

        elif destination_node_search_result and not source_node_search_result:
            # Reuse the matched destination node; merge the source by name.
            cypher = f"""
            MATCH (destination:Entity)
            WHERE id(destination) = $destination_id
            MERGE (source:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
            ON CREATE SET
                source.created = timestamp(),
                source.embedding = $source_embedding,
                source:Entity
            MERGE (source)-[r:{relationship}]->(destination)
            ON CREATE SET
                r.created = timestamp()
            RETURN source.name AS source, type(r) AS relationship, destination.name AS target
            """
            params = {
                "destination_id": destination_node_search_result[0]["id(destination_candidate)"],
                "source_name": source,
                "source_embedding": source_embedding,
                "user_id": user_id,
            }

        elif source_node_search_result and destination_node_search_result:
            # Both endpoints already exist; only the relationship is merged.
            # NOTE(review): this branch stamps created_at/updated_at while the others use
            # `created`; kept as-is to avoid changing the stored schema.
            cypher = f"""
            MATCH (source:Entity)
            WHERE id(source) = $source_id
            MATCH (destination:Entity)
            WHERE id(destination) = $destination_id
            MERGE (source)-[r:{relationship}]->(destination)
            ON CREATE SET
                r.created_at = timestamp(),
                r.updated_at = timestamp()
            RETURN source.name AS source, type(r) AS relationship, destination.name AS target
            """
            params = {
                "source_id": source_node_search_result[0]["id(source_candidate)"],
                "destination_id": destination_node_search_result[0]["id(destination_candidate)"],
                "user_id": user_id,
            }

        else:
            # Neither endpoint matched by embedding; merge both by name.
            cypher = f"""
            MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id{agent_id_clause}}})
            ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding, n:Entity
            ON MATCH SET n.embedding = $source_embedding
            MERGE (m:{destination_type}:Entity {{name: $dest_name, user_id: $user_id{agent_id_clause}}})
            ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding, m:Entity
            ON MATCH SET m.embedding = $dest_embedding
            MERGE (n)-[rel:{relationship}]->(m)
            ON CREATE SET rel.created = timestamp()
            RETURN n.name AS source, type(rel) AS relationship, m.name AS target
            """
            params = {
                "source_name": source,
                "dest_name": destination,
                "source_embedding": source_embedding,
                "dest_embedding": dest_embedding,
                "user_id": user_id,
            }

        if agent_id:
            params["agent_id"] = agent_id

        result = self.graph.query(cypher, params=params)
        results.append(result)
    return results
|
|
527
|
+
|
|
528
|
+
def _remove_spaces_from_entities(self, entity_list):
|
|
529
|
+
for item in entity_list:
|
|
530
|
+
item["source"] = item["source"].lower().replace(" ", "_")
|
|
531
|
+
# Use the sanitization function for relationships to handle special characters
|
|
532
|
+
item["relationship"] = sanitize_relationship_for_cypher(item["relationship"].lower().replace(" ", "_"))
|
|
533
|
+
item["destination"] = item["destination"].lower().replace(" ", "_")
|
|
534
|
+
return entity_list
|
|
535
|
+
|
|
536
|
+
def _search_source_node(self, source_embedding, filters, threshold=0.9, candidate_limit=10):
    """Search for an existing source node whose embedding is similar to ``source_embedding``.

    The ``memzero`` vector index is asked for the ``candidate_limit`` nearest
    nodes. Those are restricted to the caller's ``user_id`` (and ``agent_id``
    when one is given) and to ``similarity >= threshold``. The single most
    similar survivor is returned.

    Args:
        source_embedding: Query vector for the source entity.
        filters (dict): Must contain ``"user_id"``; may contain ``"agent_id"``.
        threshold (float): Minimum similarity a candidate must reach.
        candidate_limit (int): How many nearest neighbours to pull from the
            vector index before the per-user/agent filtering is applied.

    Returns:
        The result of ``self.graph.query``: at most one row, keyed by
        ``"id(source_candidate)"``.

    Raises:
        KeyError: If ``filters`` has no ``"user_id"``.
        ValueError: If ``candidate_limit`` is smaller than 1.
    """
    if candidate_limit < 1:
        raise ValueError("candidate_limit must be at least 1")

    user_id = filters["user_id"]
    agent_id = filters.get("agent_id", None)

    # The vector index is shared by every user/agent. Asking it for only the
    # single global nearest neighbour and filtering afterwards misses this
    # user's node whenever another tenant's node happens to be closer. So fetch
    # several candidates, filter them, and keep the best remaining one.
    agent_clause = "AND source_candidate.agent_id = $agent_id" if agent_id else ""
    cypher = f"""
        CALL vector_search.search("memzero", $candidate_limit, $source_embedding)
        YIELD distance, node, similarity
        WITH node AS source_candidate, similarity
        WHERE source_candidate.user_id = $user_id
        {agent_clause}
        AND similarity >= $threshold
        WITH source_candidate, similarity
        ORDER BY similarity DESC
        LIMIT 1
        RETURN id(source_candidate);
        """
    params = {
        "source_embedding": source_embedding,
        "user_id": user_id,
        "threshold": threshold,
        "candidate_limit": candidate_limit,
    }
    if agent_id:
        params["agent_id"] = agent_id

    return self.graph.query(cypher, params=params)
|
|
574
|
+
|
|
575
|
+
def _search_destination_node(self, destination_embedding, filters, threshold=0.9, candidate_limit=10):
    """Search for an existing destination node whose embedding is similar to ``destination_embedding``.

    The ``memzero`` vector index is asked for the ``candidate_limit`` nearest
    nodes. Those are restricted to the caller's ``user_id`` (and ``agent_id``
    when one is given) and to ``similarity >= threshold``. The single most
    similar survivor is returned.

    Args:
        destination_embedding: Query vector for the destination entity.
        filters (dict): Must contain ``"user_id"``; may contain ``"agent_id"``.
        threshold (float): Minimum similarity a candidate must reach.
        candidate_limit (int): How many nearest neighbours to pull from the
            vector index before the per-user/agent filtering is applied.

    Returns:
        The result of ``self.graph.query``: at most one row, keyed by
        ``"id(destination_candidate)"``.

    Raises:
        KeyError: If ``filters`` has no ``"user_id"``.
        ValueError: If ``candidate_limit`` is smaller than 1.
    """
    if candidate_limit < 1:
        raise ValueError("candidate_limit must be at least 1")

    user_id = filters["user_id"]
    agent_id = filters.get("agent_id", None)

    # Filter on the projected alias, not `node`: after
    # `WITH node AS destination_candidate` only the projected names are
    # reliably in scope (and this matches the source-node search).
    # Several candidates are fetched because the vector index is shared by all
    # users/agents; a global top-1 would often belong to someone else.
    agent_clause = "AND destination_candidate.agent_id = $agent_id" if agent_id else ""
    cypher = f"""
        CALL vector_search.search("memzero", $candidate_limit, $destination_embedding)
        YIELD distance, node, similarity
        WITH node AS destination_candidate, similarity
        WHERE destination_candidate.user_id = $user_id
        {agent_clause}
        AND similarity >= $threshold
        WITH destination_candidate, similarity
        ORDER BY similarity DESC
        LIMIT 1
        RETURN id(destination_candidate);
        """
    params = {
        "destination_embedding": destination_embedding,
        "user_id": user_id,
        "threshold": threshold,
        "candidate_limit": candidate_limit,
    }
    if agent_id:
        params["agent_id"] = agent_id

    return self.graph.query(cypher, params=params)
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
def _vector_index_exists(self, index_info, index_name):
    """Tell whether a vector index named ``index_name`` is already present.

    Args:
        index_info (dict): Mapping produced by ``_fetch_existing_indexes``. Its
            ``"vector_index_exists"`` entry is a list of per-index dicts.
        index_name (str): Vector index name to look for.

    Returns:
        bool: ``True`` as soon as any entry reports ``index_name`` under one of
        the keys ``"index_name"``, ``"index name"`` or ``"name"``; otherwise
        ``False``.
    """
    # The column holding the index name is spelled differently between
    # Memgraph versions, so every known spelling is tried in order.
    name_columns = ("index_name", "index name", "name")
    for entry in index_info.get("vector_index_exists", []):
        if any(entry.get(column) == index_name for column in name_columns):
            return True
    return False
|
|
635
|
+
|
|
636
|
+
def _label_property_index_exists(self, index_info, label, property_name):
    """Check whether a label+property index exists for ``label`` and ``property_name``.

    Rows are accepted when their index type is ``"label+property"`` (under
    either ``"index type"`` or ``"index_type"``), their ``"label"`` equals
    ``label``, and ``property_name`` appears as a whole name in the row's
    ``"property"`` or ``"properties"`` field.

    Property fields may be a plain string, a list/tuple/set of names, or a
    stringified list such as ``"[user_id, name]"``. Names are compared whole,
    so ``"id"`` does not match ``"user_id"``.

    Args:
        index_info (dict): Index information from ``_fetch_existing_indexes``.
        label (str): Label name.
        property_name (str): Property name.

    Returns:
        bool: True if a matching index exists, False otherwise.
    """

    def _mentions(value):
        # Return True if `value` names `property_name` exactly.
        if value is None:
            return False
        if value == property_name:
            return True
        if isinstance(value, (list, tuple, set, frozenset)):
            return any(str(item) == property_name for item in value)
        # Stringified list form: split into whole names rather than doing a
        # substring test (which made "id" match "user_id").
        text = str(value)
        for ch in "[](),'\"":
            text = text.replace(ch, " ")
        return property_name in text.split()

    for idx in index_info.get("index_exists", []):
        if "label+property" not in (idx.get("index type"), idx.get("index_type")):
            continue
        if idx.get("label") != label:
            continue
        if _mentions(idx.get("property")) or _mentions(idx.get("properties")):
            return True
    return False
|
|
656
|
+
|
|
657
|
+
def _label_index_exists(self, index_info, label):
    """Tell whether a plain label index exists for ``label``.

    Args:
        index_info (dict): Mapping produced by ``_fetch_existing_indexes``. Its
            ``"index_exists"`` entry is a list of per-index dicts.
        label (str): Label name to look for.

    Returns:
        bool: ``True`` if some entry has index type ``"label"`` (reported under
        either ``"index type"`` or ``"index_type"``) and ``"label"`` equal to
        ``label``; otherwise ``False``.
    """
    for entry in index_info.get("index_exists", []):
        # Accept either spelling of the index-type column.
        is_label_index = entry.get("index type") == "label" or entry.get("index_type") == "label"
        if is_label_index and entry.get("label") == label:
            return True
    return False
|
|
675
|
+
|
|
676
|
+
def _fetch_existing_indexes(self):
    """Retrieve the existing regular and vector indexes from the Memgraph database.

    ``SHOW INDEX INFO;`` and ``SHOW VECTOR INDEX INFO;`` are each run
    independently and on a best-effort basis. If one of them raises, the error
    is logged and only that list is left empty. The rows already fetched by the
    other query are no longer discarded.

    Returns:
        dict: ``{"index_exists": [...], "vector_index_exists": [...]}``, each a
        list of the rows returned by the corresponding query (empty on failure).
    """

    def _rows(query):
        # Run one introspection query; on any error log it and report no rows
        # for this query only (best-effort, as before).
        try:
            return list(self.graph.query(query))
        except Exception as e:
            logger.warning("Error fetching indexes with %r: %s. Returning empty index info for it.", query, e)
            return []

    return {
        "index_exists": _rows("SHOW INDEX INFO;"),
        "vector_index_exists": _rows("SHOW VECTOR INDEX INFO;"),
    }
|