agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
agentrun_mem0/memory/kuzu_memory.py
@@ -0,0 +1,713 @@
import logging

from agentrun_mem0.memory.utils import format_entities

try:
    import kuzu
except ImportError:
    raise ImportError("kuzu is not installed. Please install it using pip install kuzu")

try:
    from rank_bm25 import BM25Okapi
except ImportError:
    raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")

from agentrun_mem0.graphs.tools import (
    DELETE_MEMORY_STRUCT_TOOL_GRAPH,
    DELETE_MEMORY_TOOL_GRAPH,
    EXTRACT_ENTITIES_STRUCT_TOOL,
    EXTRACT_ENTITIES_TOOL,
    RELATIONS_STRUCT_TOOL,
    RELATIONS_TOOL,
)
from agentrun_mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
from agentrun_mem0.utils.factory import EmbedderFactory, LlmFactory

logger = logging.getLogger(__name__)


class MemoryGraph:
    def __init__(self, config):
        self.config = config

        self.embedding_model = EmbedderFactory.create(
            self.config.embedder.provider,
            self.config.embedder.config,
            self.config.vector_store.config,
        )
        self.embedding_dims = self.embedding_model.config.embedding_dims

        if self.embedding_dims is None or self.embedding_dims <= 0:
            raise ValueError(f"embedding_dims must be a positive integer. Given: {self.embedding_dims}")

        self.db = kuzu.Database(self.config.graph_store.config.db)
        self.graph = kuzu.Connection(self.db)

        self.node_label = ":Entity"
        self.rel_label = ":CONNECTED_TO"
        self.kuzu_create_schema()

        # Default to openai if no specific provider is configured
        self.llm_provider = "openai"
        if self.config.llm and self.config.llm.provider:
            self.llm_provider = self.config.llm.provider
        if self.config.graph_store and self.config.graph_store.llm and self.config.graph_store.llm.provider:
            self.llm_provider = self.config.graph_store.llm.provider
        # Get LLM config with proper null checks
        llm_config = None
        if self.config.graph_store and self.config.graph_store.llm and hasattr(self.config.graph_store.llm, "config"):
            llm_config = self.config.graph_store.llm.config
        elif hasattr(self.config.llm, "config"):
            llm_config = self.config.llm.config
        self.llm = LlmFactory.create(self.llm_provider, llm_config)

        self.user_id = None
        self.threshold = 0.7

    def kuzu_create_schema(self):
        self.kuzu_execute(
            """
            CREATE NODE TABLE IF NOT EXISTS Entity(
                id SERIAL PRIMARY KEY,
                user_id STRING,
                agent_id STRING,
                run_id STRING,
                name STRING,
                mentions INT64,
                created TIMESTAMP,
                embedding FLOAT[]);
            """
        )
        self.kuzu_execute(
            """
            CREATE REL TABLE IF NOT EXISTS CONNECTED_TO(
                FROM Entity TO Entity,
                name STRING,
                mentions INT64,
                created TIMESTAMP,
                updated TIMESTAMP
            );
            """
        )

    def kuzu_execute(self, query, parameters=None):
        results = self.graph.execute(query, parameters)
        return list(results.rows_as_dict())

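    # Hypothetical construction sketch (not from the diff): assumes a
    # mem0-style config object whose graph_store.config.db names a Kuzu
    # database path, e.g.
    #
    #     graph = MemoryGraph(config)   # opens the Kuzu DB and creates the schema
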
    def add(self, data, filters):
        """
        Adds data to the graph.

        Args:
            data (str): The data to add to the graph.
            filters (dict): A dictionary containing filters to be applied during the addition.
        """
        entity_type_map = self._retrieve_nodes_from_data(data, filters)
        to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
        to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)

        deleted_entities = self._delete_entities(to_be_deleted, filters)
        added_entities = self._add_entities(to_be_added, filters, entity_type_map)

        return {"deleted_entities": deleted_entities, "added_entities": added_entities}

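    # Illustrative call (hypothetical text and ids, not from the diff):
    #
    #     graph.add("Alice works at Acme", filters={"user_id": "alice"})
    #     # -> {"deleted_entities": [...], "added_entities": [...]}
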
    def search(self, query, filters, limit=5):
        """
        Search for memories and related graph data.

        Args:
            query (str): Query to search for.
            filters (dict): A dictionary containing filters to be applied during the search.
            limit (int): The maximum number of relationships to return. Defaults to 5.

        Returns:
            list: A list of dictionaries, each containing "source", "relationship",
                and "destination" for a matching relation, reranked by BM25.
        """
        entity_type_map = self._retrieve_nodes_from_data(query, filters)
        search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)

        if not search_output:
            return []

        search_outputs_sequence = [
            [item["source"], item["relationship"], item["destination"]] for item in search_output
        ]
        bm25 = BM25Okapi(search_outputs_sequence)

        tokenized_query = query.split(" ")
        reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=limit)

        search_results = []
        for item in reranked_results:
            search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})

        logger.info(f"Returned {len(search_results)} search results")

        return search_results

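    # Illustrative call (hypothetical text and ids, not from the diff):
    #
    #     graph.search("Where does Alice work?", filters={"user_id": "alice"})
    #     # -> e.g. [{"source": "alice", "relationship": "works_at", "destination": "acme"}]
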
    def delete_all(self, filters):
        # Build node properties for filtering
        node_props = ["user_id: $user_id"]
        if filters.get("agent_id"):
            node_props.append("agent_id: $agent_id")
        if filters.get("run_id"):
            node_props.append("run_id: $run_id")
        node_props_str = ", ".join(node_props)

        cypher = f"""
        MATCH (n {self.node_label} {{{node_props_str}}})
        DETACH DELETE n
        """
        params = {"user_id": filters["user_id"]}
        if filters.get("agent_id"):
            params["agent_id"] = filters["agent_id"]
        if filters.get("run_id"):
            params["run_id"] = filters["run_id"]
        self.kuzu_execute(cypher, parameters=params)

    def get_all(self, filters, limit=100):
        """
        Retrieves all nodes and relationships from the graph database based on optional filtering criteria.

        Args:
            filters (dict): A dictionary containing filters to be applied during the retrieval.
            limit (int): The maximum number of relationships to retrieve. Defaults to 100.

        Returns:
            list: A list of dictionaries, each containing "source", "relationship", and "target".
        """

        params = {
            "user_id": filters["user_id"],
            "limit": limit,
        }
        # Build node properties based on filters
        node_props = ["user_id: $user_id"]
        if filters.get("agent_id"):
            node_props.append("agent_id: $agent_id")
            params["agent_id"] = filters["agent_id"]
        if filters.get("run_id"):
            node_props.append("run_id: $run_id")
            params["run_id"] = filters["run_id"]
        node_props_str = ", ".join(node_props)

        query = f"""
        MATCH (n {self.node_label} {{{node_props_str}}})-[r]->(m {self.node_label} {{{node_props_str}}})
        RETURN
            n.name AS source,
            r.name AS relationship,
            m.name AS target
        LIMIT $limit
        """
        results = self.kuzu_execute(query, parameters=params)

        final_results = []
        for result in results:
            final_results.append(
                {
                    "source": result["source"],
                    "relationship": result["relationship"],
                    "target": result["target"],
                }
            )

        logger.info(f"Retrieved {len(final_results)} relationships")

        return final_results

    def _retrieve_nodes_from_data(self, data, filters):
        """Extracts all the entities mentioned in the query."""
        _tools = [EXTRACT_ENTITIES_TOOL]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
            _tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
        search_results = self.llm.generate_response(
            messages=[
                {
                    "role": "system",
                    "content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
                },
                {"role": "user", "content": data},
            ],
            tools=_tools,
        )

        entity_type_map = {}

        try:
            for tool_call in search_results["tool_calls"]:
                if tool_call["name"] != "extract_entities":
                    continue
                for item in tool_call["arguments"]["entities"]:
                    entity_type_map[item["entity"]] = item["entity_type"]
        except Exception as e:
            logger.exception(
                f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
            )

        entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
        logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
        return entity_type_map

    def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
        """Establish relations among the extracted nodes."""

        # Compose user identification string for prompt
        user_identity = f"user_id: {filters['user_id']}"
        if filters.get("agent_id"):
            user_identity += f", agent_id: {filters['agent_id']}"
        if filters.get("run_id"):
            user_identity += f", run_id: {filters['run_id']}"

        if self.config.graph_store.custom_prompt:
            system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
            # Add the custom prompt line if configured
            system_content = system_content.replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")
            messages = [
                {"role": "system", "content": system_content},
                {"role": "user", "content": data},
            ]
        else:
            system_content = EXTRACT_RELATIONS_PROMPT.replace("USER_ID", user_identity)
            messages = [
                {"role": "system", "content": system_content},
                {"role": "user", "content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}"},
            ]

        _tools = [RELATIONS_TOOL]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
            _tools = [RELATIONS_STRUCT_TOOL]

        extracted_entities = self.llm.generate_response(
            messages=messages,
            tools=_tools,
        )

        entities = []
        if extracted_entities.get("tool_calls"):
            entities = extracted_entities["tool_calls"][0].get("arguments", {}).get("entities", [])

        entities = self._remove_spaces_from_entities(entities)
        logger.debug(f"Extracted entities: {entities}")
        return entities

    def _search_graph_db(self, node_list, filters, limit=100, threshold=None):
        """Search for nodes similar to the given ones, along with their incoming and outgoing relations."""
        result_relations = []

        params = {
            "threshold": threshold if threshold else self.threshold,
            "user_id": filters["user_id"],
            "limit": limit,
        }
        # Build node properties for filtering
        node_props = ["user_id: $user_id"]
        if filters.get("agent_id"):
            node_props.append("agent_id: $agent_id")
            params["agent_id"] = filters["agent_id"]
        if filters.get("run_id"):
            node_props.append("run_id: $run_id")
            params["run_id"] = filters["run_id"]
        node_props_str = ", ".join(node_props)

        for node in node_list:
            n_embedding = self.embedding_model.embed(node)
            params["n_embedding"] = n_embedding

            results = []
            for match_fragment in [
                f"(n)-[r]->(m {self.node_label} {{{node_props_str}}}) WITH n as src, r, m as dst, similarity",
                f"(m {self.node_label} {{{node_props_str}}})-[r]->(n) WITH m as src, r, n as dst, similarity",
            ]:
                results.extend(self.kuzu_execute(
                    f"""
                    MATCH (n {self.node_label} {{{node_props_str}}})
                    WHERE n.embedding IS NOT NULL
                    WITH n, array_cosine_similarity(n.embedding, CAST($n_embedding,'FLOAT[{self.embedding_dims}]')) AS similarity
                    WHERE similarity >= CAST($threshold, 'DOUBLE')
                    MATCH {match_fragment}
                    RETURN
                        src.name AS source,
                        id(src) AS source_id,
                        r.name AS relationship,
                        id(r) AS relation_id,
                        dst.name AS destination,
                        id(dst) AS destination_id,
                        similarity
                    LIMIT $limit
                    """,
                    parameters=params))

            # Kuzu does not support sort/limit over unions. Do it manually for now.
            result_relations.extend(sorted(results, key=lambda x: x["similarity"], reverse=True)[:limit])

        return result_relations

    def _get_delete_entities_from_search_output(self, search_output, data, filters):
        """Get the entities to be deleted from the search output."""
        search_output_string = format_entities(search_output)

        # Compose user identification string for prompt
        user_identity = f"user_id: {filters['user_id']}"
        if filters.get("agent_id"):
            user_identity += f", agent_id: {filters['agent_id']}"
        if filters.get("run_id"):
            user_identity += f", run_id: {filters['run_id']}"

        system_prompt, user_prompt = get_delete_messages(search_output_string, data, user_identity)

        _tools = [DELETE_MEMORY_TOOL_GRAPH]
        if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
            _tools = [
                DELETE_MEMORY_STRUCT_TOOL_GRAPH,
            ]

        memory_updates = self.llm.generate_response(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            tools=_tools,
        )

        to_be_deleted = []
        for item in memory_updates.get("tool_calls", []):
            if item.get("name") == "delete_graph_memory":
                to_be_deleted.append(item.get("arguments"))
        # Clean entities formatting
        to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
        logger.debug(f"Relationships to be deleted: {to_be_deleted}")
        return to_be_deleted

    def _delete_entities(self, to_be_deleted, filters):
        """Delete the entities from the graph."""
        user_id = filters["user_id"]
        agent_id = filters.get("agent_id", None)
        run_id = filters.get("run_id", None)
        results = []

        for item in to_be_deleted:
            source = item["source"]
            destination = item["destination"]
            relationship = item["relationship"]

            params = {
                "source_name": source,
                "dest_name": destination,
                "user_id": user_id,
                "relationship_name": relationship,
            }
            # Build node properties for filtering
            source_props = ["name: $source_name", "user_id: $user_id"]
            dest_props = ["name: $dest_name", "user_id: $user_id"]
            if agent_id:
                source_props.append("agent_id: $agent_id")
                dest_props.append("agent_id: $agent_id")
                params["agent_id"] = agent_id
            if run_id:
                source_props.append("run_id: $run_id")
                dest_props.append("run_id: $run_id")
                params["run_id"] = run_id
            source_props_str = ", ".join(source_props)
            dest_props_str = ", ".join(dest_props)

            # Delete the specific relationship between nodes
            cypher = f"""
            MATCH (n {self.node_label} {{{source_props_str}}})
            -[r {self.rel_label} {{name: $relationship_name}}]->
            (m {self.node_label} {{{dest_props_str}}})
            DELETE r
            RETURN
                n.name AS source,
                r.name AS relationship,
                m.name AS target
            """

            result = self.kuzu_execute(cypher, parameters=params)
            results.append(result)

        return results

    def _add_entities(self, to_be_added, filters, entity_type_map):
        """Add the new entities to the graph. Merge the nodes if they already exist."""
        user_id = filters["user_id"]
        agent_id = filters.get("agent_id", None)
        run_id = filters.get("run_id", None)
        results = []
        for item in to_be_added:
            # entities
            source = item["source"]
            source_label = self.node_label

            destination = item["destination"]
            destination_label = self.node_label

            relationship = item["relationship"]
            relationship_label = self.rel_label

            # embeddings
            source_embedding = self.embedding_model.embed(source)
            dest_embedding = self.embedding_model.embed(destination)

            # search for the nodes with the closest embeddings
            source_node_search_result = self._search_source_node(source_embedding, filters, threshold=0.9)
            destination_node_search_result = self._search_destination_node(dest_embedding, filters, threshold=0.9)

            if not destination_node_search_result and source_node_search_result:
                params = {
                    "table_id": source_node_search_result[0]["id"]["table"],
                    "offset_id": source_node_search_result[0]["id"]["offset"],
                    "destination_name": destination,
                    "destination_embedding": dest_embedding,
                    "relationship_name": relationship,
                    "user_id": user_id,
                }
                # Build destination MERGE properties
                merge_props = ["name: $destination_name", "user_id: $user_id"]
                if agent_id:
                    merge_props.append("agent_id: $agent_id")
                    params["agent_id"] = agent_id
                if run_id:
                    merge_props.append("run_id: $run_id")
                    params["run_id"] = run_id
                merge_props_str = ", ".join(merge_props)

                cypher = f"""
                MATCH (source)
                WHERE id(source) = internal_id($table_id, $offset_id)
                SET source.mentions = coalesce(source.mentions, 0) + 1
                WITH source
                MERGE (destination {destination_label} {{{merge_props_str}}})
                ON CREATE SET
                    destination.created = current_timestamp(),
                    destination.mentions = 1,
                    destination.embedding = CAST($destination_embedding,'FLOAT[{self.embedding_dims}]')
                ON MATCH SET
                    destination.mentions = coalesce(destination.mentions, 0) + 1,
                    destination.embedding = CAST($destination_embedding,'FLOAT[{self.embedding_dims}]')
                WITH source, destination
                MERGE (source)-[r {relationship_label} {{name: $relationship_name}}]->(destination)
                ON CREATE SET
                    r.created = current_timestamp(),
                    r.mentions = 1
                ON MATCH SET
                    r.mentions = coalesce(r.mentions, 0) + 1
                RETURN
                    source.name AS source,
                    r.name AS relationship,
                    destination.name AS target
                """
            elif destination_node_search_result and not source_node_search_result:
                params = {
                    "table_id": destination_node_search_result[0]["id"]["table"],
                    "offset_id": destination_node_search_result[0]["id"]["offset"],
                    "source_name": source,
                    "source_embedding": source_embedding,
                    "user_id": user_id,
                    "relationship_name": relationship,
                }
                # Build source MERGE properties
                merge_props = ["name: $source_name", "user_id: $user_id"]
                if agent_id:
                    merge_props.append("agent_id: $agent_id")
                    params["agent_id"] = agent_id
                if run_id:
                    merge_props.append("run_id: $run_id")
                    params["run_id"] = run_id
                merge_props_str = ", ".join(merge_props)

                cypher = f"""
                MATCH (destination)
                WHERE id(destination) = internal_id($table_id, $offset_id)
                SET destination.mentions = coalesce(destination.mentions, 0) + 1
                WITH destination
                MERGE (source {source_label} {{{merge_props_str}}})
                ON CREATE SET
                    source.created = current_timestamp(),
                    source.mentions = 1,
                    source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
                ON MATCH SET
                    source.mentions = coalesce(source.mentions, 0) + 1,
                    source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
                WITH source, destination
                MERGE (source)-[r {relationship_label} {{name: $relationship_name}}]->(destination)
                ON CREATE SET
                    r.created = current_timestamp(),
                    r.mentions = 1
                ON MATCH SET
                    r.mentions = coalesce(r.mentions, 0) + 1
                RETURN
                    source.name AS source,
                    r.name AS relationship,
                    destination.name AS target
                """
            elif source_node_search_result and destination_node_search_result:
                cypher = f"""
                MATCH (source)
                WHERE id(source) = internal_id($src_table, $src_offset)
                SET source.mentions = coalesce(source.mentions, 0) + 1
                WITH source
                MATCH (destination)
                WHERE id(destination) = internal_id($dst_table, $dst_offset)
                SET destination.mentions = coalesce(destination.mentions, 0) + 1
                MERGE (source)-[r {relationship_label} {{name: $relationship_name}}]->(destination)
                ON CREATE SET
                    r.created = current_timestamp(),
                    r.updated = current_timestamp(),
                    r.mentions = 1
                ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
                RETURN
                    source.name AS source,
                    r.name AS relationship,
                    destination.name AS target
                """

                params = {
                    "src_table": source_node_search_result[0]["id"]["table"],
                    "src_offset": source_node_search_result[0]["id"]["offset"],
                    "dst_table": destination_node_search_result[0]["id"]["table"],
                    "dst_offset": destination_node_search_result[0]["id"]["offset"],
                    "relationship_name": relationship,
                }
            else:
                params = {
                    "source_name": source,
                    "dest_name": destination,
                    "relationship_name": relationship,
                    "source_embedding": source_embedding,
                    "dest_embedding": dest_embedding,
                    "user_id": user_id,
                }
                # Build dynamic MERGE props for both source and destination
                source_props = ["name: $source_name", "user_id: $user_id"]
                dest_props = ["name: $dest_name", "user_id: $user_id"]
                if agent_id:
                    source_props.append("agent_id: $agent_id")
                    dest_props.append("agent_id: $agent_id")
                    params["agent_id"] = agent_id
                if run_id:
                    source_props.append("run_id: $run_id")
                    dest_props.append("run_id: $run_id")
                    params["run_id"] = run_id
                source_props_str = ", ".join(source_props)
                dest_props_str = ", ".join(dest_props)

                cypher = f"""
                MERGE (source {source_label} {{{source_props_str}}})
                ON CREATE SET
                    source.created = current_timestamp(),
                    source.mentions = 1,
                    source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
                ON MATCH SET
                    source.mentions = coalesce(source.mentions, 0) + 1,
                    source.embedding = CAST($source_embedding,'FLOAT[{self.embedding_dims}]')
                WITH source
                MERGE (destination {destination_label} {{{dest_props_str}}})
                ON CREATE SET
                    destination.created = current_timestamp(),
                    destination.mentions = 1,
                    destination.embedding = CAST($dest_embedding,'FLOAT[{self.embedding_dims}]')
                ON MATCH SET
                    destination.mentions = coalesce(destination.mentions, 0) + 1,
                    destination.embedding = CAST($dest_embedding,'FLOAT[{self.embedding_dims}]')
                WITH source, destination
                MERGE (source)-[rel {relationship_label} {{name: $relationship_name}}]->(destination)
                ON CREATE SET
                    rel.created = current_timestamp(),
                    rel.mentions = 1
                ON MATCH SET
                    rel.mentions = coalesce(rel.mentions, 0) + 1
                RETURN
                    source.name AS source,
                    rel.name AS relationship,
                    destination.name AS target
                """

            result = self.kuzu_execute(cypher, parameters=params)
            results.append(result)

        return results

    def _remove_spaces_from_entities(self, entity_list):
        for item in entity_list:
            item["source"] = item["source"].lower().replace(" ", "_")
            item["relationship"] = item["relationship"].lower().replace(" ", "_")
            item["destination"] = item["destination"].lower().replace(" ", "_")
        return entity_list

    def _search_source_node(self, source_embedding, filters, threshold=0.9):
        params = {
            "source_embedding": source_embedding,
            "user_id": filters["user_id"],
            "threshold": threshold,
        }
        where_conditions = ["source_candidate.embedding IS NOT NULL", "source_candidate.user_id = $user_id"]
        if filters.get("agent_id"):
            where_conditions.append("source_candidate.agent_id = $agent_id")
            params["agent_id"] = filters["agent_id"]
        if filters.get("run_id"):
            where_conditions.append("source_candidate.run_id = $run_id")
            params["run_id"] = filters["run_id"]
        where_clause = " AND ".join(where_conditions)

        cypher = f"""
            MATCH (source_candidate {self.node_label})
            WHERE {where_clause}

            WITH source_candidate,
                array_cosine_similarity(source_candidate.embedding, CAST($source_embedding,'FLOAT[{self.embedding_dims}]')) AS source_similarity

            WHERE source_similarity >= $threshold

            WITH source_candidate, source_similarity
            ORDER BY source_similarity DESC
            LIMIT 2

            RETURN id(source_candidate) as id, source_similarity
            """

        return self.kuzu_execute(cypher, parameters=params)

    def _search_destination_node(self, destination_embedding, filters, threshold=0.9):
        params = {
            "destination_embedding": destination_embedding,
            "user_id": filters["user_id"],
            "threshold": threshold,
        }
        where_conditions = ["destination_candidate.embedding IS NOT NULL", "destination_candidate.user_id = $user_id"]
        if filters.get("agent_id"):
            where_conditions.append("destination_candidate.agent_id = $agent_id")
            params["agent_id"] = filters["agent_id"]
        if filters.get("run_id"):
            where_conditions.append("destination_candidate.run_id = $run_id")
            params["run_id"] = filters["run_id"]
        where_clause = " AND ".join(where_conditions)

        cypher = f"""
            MATCH (destination_candidate {self.node_label})
            WHERE {where_clause}

            WITH destination_candidate,
                array_cosine_similarity(destination_candidate.embedding, CAST($destination_embedding,'FLOAT[{self.embedding_dims}]')) AS destination_similarity

            WHERE destination_similarity >= $threshold

            WITH destination_candidate, destination_similarity
            ORDER BY destination_similarity DESC
            LIMIT 2

            RETURN id(destination_candidate) as id, destination_similarity
            """

        return self.kuzu_execute(cypher, parameters=params)

    # Reset is not defined in base.py
    def reset(self):
        """Reset the graph by clearing all nodes and relationships."""
        logger.warning("Clearing graph...")
        cypher_query = """
        MATCH (n) DETACH DELETE n
        """
        return self.kuzu_execute(cypher_query)
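
A minimal end-to-end sketch of driving this module, assuming only the attributes `__init__` actually reads (`embedder.provider`, `embedder.config`, `vector_store.config`, `llm.provider`, `llm.config`, `graph_store.config.db`, `graph_store.llm`, `graph_store.custom_prompt`). The `SimpleNamespace` scaffolding stands in for the package's real config classes, and the provider settings are illustrative, not taken from the diff:

    from types import SimpleNamespace

    from agentrun_mem0.memory.kuzu_memory import MemoryGraph

    # Hypothetical stand-in for the package's config object; real code would
    # use the library's own config classes. The chosen providers require the
    # matching credentials (e.g. OPENAI_API_KEY) at runtime.
    config = SimpleNamespace(
        embedder=SimpleNamespace(provider="openai", config={"model": "text-embedding-3-small"}),
        vector_store=SimpleNamespace(config={}),
        llm=SimpleNamespace(provider="openai", config={"model": "gpt-4o-mini"}),
        graph_store=SimpleNamespace(
            config=SimpleNamespace(db="/tmp/agentrun_mem0_kuzu"),
            llm=None,
            custom_prompt=None,
        ),
    )

    graph = MemoryGraph(config)          # opens the Kuzu DB and creates the schema
    filters = {"user_id": "alice"}

    graph.add("Alice works at Acme and lives in Berlin.", filters)
    print(graph.search("Where does Alice work?", filters))   # BM25-reranked triplets
    print(graph.get_all(filters, limit=10))                  # source/relationship/target rows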