agentrun-mem0ai 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,474 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from .base import NeptuneBase
|
|
4
|
+
|
|
5
|
+
try:
    from langchain_aws import NeptuneAnalyticsGraph
    from botocore.config import Config
except ImportError as e:
    # Neptune Analytics support is an optional dependency; fail with an
    # install hint that works for users of the published wheel and keep the
    # original import failure chained for debugging.
    raise ImportError(
        "langchain_aws is not installed. Please install it using 'pip install langchain-aws'."
    ) from e

logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MemoryGraph(NeptuneBase):
    """Graph memory store backed by an Amazon Neptune Analytics graph.

    Each ``_*_cypher`` method only *builds* an openCypher query string plus its
    parameter dict and returns them as ``(query, params)``. Executing them is
    presumably done by the inherited ``NeptuneBase`` logic (not visible here).
    Node embeddings are written and compared with the Neptune Analytics
    ``neptune.algo.vectors.*`` procedures.
    """

    # Scheme an ``endpoint`` must use to address a Neptune Analytics graph.
    _ENDPOINT_PREFIX = "neptune-graph://"

    def __init__(self, config):
        """Create the Neptune Analytics client, the embedder and the LLM.

        :param config: memory configuration. Reads ``graph_store.config.endpoint``,
            ``graph_store.config.app_id``, ``graph_store.config.base_label``,
            ``llm.provider`` and ``graph_store.llm``.
        :raises ValueError: if ``endpoint`` is missing or does not start with
            ``neptune-graph://``.
        """
        self.config = config

        self.graph = None
        graph_store_config = self.config.graph_store.config
        endpoint = graph_store_config.endpoint
        app_id = graph_store_config.app_id
        if not endpoint:
            raise ValueError("Unable to create a Neptune client: missing 'endpoint' in config")
        if not endpoint.startswith(self._ENDPOINT_PREFIX):
            raise ValueError(
                "Unable to create a Neptune client: 'endpoint' must start with "
                f"'{self._ENDPOINT_PREFIX}', got {endpoint!r}"
            )
        # Strip only the leading scheme; str.replace would also remove later occurrences.
        graph_identifier = endpoint[len(self._ENDPOINT_PREFIX):]
        self.graph = NeptuneAnalyticsGraph(
            graph_identifier=graph_identifier,
            config=Config(user_agent_appid=app_id),
        )

        # With base_label enabled, every node shares the `__Entity__` label and
        # the entity type is added as an extra label on creation. Otherwise
        # nodes are labelled by their entity type only (see the _add_* methods).
        self.node_label = ":`__Entity__`" if graph_store_config.base_label else ""

        self.embedding_model = NeptuneBase._create_embedding_model(self.config)

        # Default to openai if no specific provider is configured; a graph-store
        # specific LLM overrides the global LLM provider.
        self.llm_provider = "openai"
        if self.config.llm.provider:
            self.llm_provider = self.config.llm.provider
        if self.config.graph_store.llm:
            self.llm_provider = self.config.graph_store.llm.provider

        self.llm = NeptuneBase._create_llm(self.config, self.llm_provider)
        self.user_id = None
        # Minimum cosine similarity used by _search_graph_db_cypher.
        self.threshold = 0.7

    @staticmethod
    def _quote_identifier(name):
        """Return *name* as a backtick-quoted openCypher identifier.

        Embedded backticks are doubled so that label / relationship-type values
        (presumably produced by upstream entity extraction, not validated here)
        cannot break out of the quoting or inject query text.

        :param name: raw label or relationship type
        :return: str, e.g. ``works_at`` -> ```works_at```
        """
        return "`" + str(name).replace("`", "``") + "`"

    def _delete_entities_cypher(self, source, destination, relationship, user_id):
        """
        Returns the OpenCypher query and parameters for deleting entities in the graph DB.

        Only the matching relationship is deleted; both nodes are kept.

        :param source: source node name
        :param destination: destination node name
        :param relationship: relationship type (quoted before interpolation)
        :param user_id: user_id to use
        :return: str, dict
        """
        rel_type = self._quote_identifier(relationship)

        cypher = f"""
            MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
            -[r:{rel_type}]->
            (m {self.node_label} {{name: $dest_name, user_id: $user_id}})
            DELETE r
            RETURN
                n.name AS source,
                m.name AS target,
                type(r) AS relationship
            """
        params = {
            "source_name": source,
            "dest_name": destination,
            "user_id": user_id,
        }
        logger.debug("_delete_entities\n query=%s", cypher)
        return cypher, params

    def _add_entities_by_source_cypher(
        self,
        source_node_list,
        destination,
        dest_embedding,
        destination_type,
        relationship,
        user_id,
    ):
        """
        Returns the OpenCypher query and parameters for adding entities in the graph DB.

        The source node already exists and is addressed by the id found by
        ``_search_source_node_cypher``. The destination node is merged by name,
        its embedding is upserted, and the relationship is merged. Mention
        counters and timestamps are bumped on every touched element.

        :param source_node_list: source node search results; only the first row's
            ``id(source_candidate)`` is used
        :param destination: destination name
        :param dest_embedding: destination embedding
        :param destination_type: destination node label
        :param relationship: relationship type (quoted before interpolation)
        :param user_id: user id to use
        :return: str, dict
        """
        dest_type = self._quote_identifier(destination_type)
        destination_label = self.node_label if self.node_label else f":{dest_type}"
        # With the shared base label, the entity type is attached as an extra label on creation.
        destination_extra_set = f", destination:{dest_type}" if self.node_label else ""
        rel_type = self._quote_identifier(relationship)

        cypher = f"""
                MATCH (source {{user_id: $user_id}})
                WHERE id(source) = $source_id
                SET source.mentions = coalesce(source.mentions, 0) + 1
                WITH source
                MERGE (destination {destination_label} {{name: $destination_name, user_id: $user_id}})
                ON CREATE SET
                    destination.created = timestamp(),
                    destination.updated = timestamp(),
                    destination.mentions = 1
                    {destination_extra_set}
                ON MATCH SET
                    destination.mentions = coalesce(destination.mentions, 0) + 1,
                    destination.updated = timestamp()
                WITH source, destination, $dest_embedding as dest_embedding
                CALL neptune.algo.vectors.upsert(destination, dest_embedding)
                WITH source, destination
                MERGE (source)-[r:{rel_type}]->(destination)
                ON CREATE SET
                    r.created = timestamp(),
                    r.updated = timestamp(),
                    r.mentions = 1
                ON MATCH SET
                    r.mentions = coalesce(r.mentions, 0) + 1,
                    r.updated = timestamp()
                RETURN source.name AS source, type(r) AS relationship, destination.name AS target
                """

        params = {
            "source_id": source_node_list[0]["id(source_candidate)"],
            "destination_name": destination,
            "dest_embedding": dest_embedding,
            "user_id": user_id,
        }
        logger.debug(
            "_add_entities:\n source_node_search_result=%s\n query=%s",
            source_node_list[0],
            cypher,
        )
        return cypher, params

    def _add_entities_by_destination_cypher(
        self,
        source,
        source_embedding,
        source_type,
        destination_node_list,
        relationship,
        user_id,
    ):
        """
        Returns the OpenCypher query and parameters for adding entities in the graph DB.

        Mirror image of ``_add_entities_by_source_cypher``: the destination node
        already exists and is addressed by the id found by
        ``_search_destination_node_cypher``. The source node is merged by name,
        its embedding is upserted, and the relationship is merged.

        :param source: source node name
        :param source_embedding: source node embedding
        :param source_type: source node label
        :param destination_node_list: destination node search results; only the
            first row's ``id(destination_candidate)`` is used
        :param relationship: relationship type (quoted before interpolation)
        :param user_id: user id to use
        :return: str, dict
        """
        src_type = self._quote_identifier(source_type)
        source_label = self.node_label if self.node_label else f":{src_type}"
        # With the shared base label, the entity type is attached as an extra label on creation.
        source_extra_set = f", source:{src_type}" if self.node_label else ""
        rel_type = self._quote_identifier(relationship)

        cypher = f"""
                MATCH (destination {{user_id: $user_id}})
                WHERE id(destination) = $destination_id
                SET
                    destination.mentions = coalesce(destination.mentions, 0) + 1,
                    destination.updated = timestamp()
                WITH destination
                MERGE (source {source_label} {{name: $source_name, user_id: $user_id}})
                ON CREATE SET
                    source.created = timestamp(),
                    source.updated = timestamp(),
                    source.mentions = 1
                    {source_extra_set}
                ON MATCH SET
                    source.mentions = coalesce(source.mentions, 0) + 1,
                    source.updated = timestamp()
                WITH source, destination, $source_embedding as source_embedding
                CALL neptune.algo.vectors.upsert(source, source_embedding)
                WITH source, destination
                MERGE (source)-[r:{rel_type}]->(destination)
                ON CREATE SET
                    r.created = timestamp(),
                    r.updated = timestamp(),
                    r.mentions = 1
                ON MATCH SET
                    r.mentions = coalesce(r.mentions, 0) + 1,
                    r.updated = timestamp()
                RETURN source.name AS source, type(r) AS relationship, destination.name AS target
                """

        params = {
            "destination_id": destination_node_list[0]["id(destination_candidate)"],
            "source_name": source,
            "source_embedding": source_embedding,
            "user_id": user_id,
        }
        logger.debug(
            "_add_entities:\n destination_node_search_result=%s\n query=%s",
            destination_node_list[0],
            cypher,
        )
        return cypher, params

    def _add_relationship_entities_cypher(
        self,
        source_node_list,
        destination_node_list,
        relationship,
        user_id,
    ):
        """
        Returns the OpenCypher query and parameters for adding entities in the graph DB.

        Both nodes already exist and are addressed by the ids found by the
        source/destination similarity searches; only the relationship is merged.

        :param source_node_list: source node search results; only the first row's
            ``id(source_candidate)`` is used
        :param destination_node_list: destination node search results; only the
            first row's ``id(destination_candidate)`` is used
        :param relationship: relationship type (quoted before interpolation)
        :param user_id: user id to use
        :return: str, dict
        """
        rel_type = self._quote_identifier(relationship)

        # coalesce(..., 0) is required: coalesce(null) + 1 is null, so a node
        # without a mentions property would otherwise never start counting.
        # NOTE(review): this relationship uses created_at/updated_at while the
        # other _add_* methods use created/updated; names kept to avoid changing
        # existing data — confirm which naming readers expect.
        cypher = f"""
            MATCH (source {{user_id: $user_id}})
            WHERE id(source) = $source_id
            SET
                source.mentions = coalesce(source.mentions, 0) + 1,
                source.updated = timestamp()
            WITH source
            MATCH (destination {{user_id: $user_id}})
            WHERE id(destination) = $destination_id
            SET
                destination.mentions = coalesce(destination.mentions, 0) + 1,
                destination.updated = timestamp()
            MERGE (source)-[r:{rel_type}]->(destination)
            ON CREATE SET
                r.created_at = timestamp(),
                r.updated_at = timestamp(),
                r.mentions = 1
            ON MATCH SET
                r.mentions = coalesce(r.mentions, 0) + 1,
                r.updated_at = timestamp()
            RETURN source.name AS source, type(r) AS relationship, destination.name AS target
            """
        params = {
            "source_id": source_node_list[0]["id(source_candidate)"],
            "destination_id": destination_node_list[0]["id(destination_candidate)"],
            "user_id": user_id,
        }
        logger.debug(
            "_add_entities:\n destination_node_search_result=%s\n source_node_search_result=%s\n query=%s",
            destination_node_list[0],
            source_node_list[0],
            cypher,
        )
        return cypher, params

    def _add_new_entities_cypher(
        self,
        source,
        source_embedding,
        source_type,
        destination,
        dest_embedding,
        destination_type,
        relationship,
        user_id,
    ):
        """
        Returns the OpenCypher query and parameters for adding entities in the graph DB.

        Neither node was found by similarity search: both are merged by name,
        both embeddings are upserted, and the relationship is merged.

        :param source: source node name
        :param source_embedding: source node embedding
        :param source_type: source node label
        :param destination: destination name
        :param dest_embedding: destination embedding
        :param destination_type: destination node label
        :param relationship: relationship type (quoted before interpolation)
        :param user_id: user id to use
        :return: str, dict
        """
        src_type = self._quote_identifier(source_type)
        dest_type = self._quote_identifier(destination_type)
        source_label = self.node_label if self.node_label else f":{src_type}"
        source_extra_set = f", source:{src_type}" if self.node_label else ""
        destination_label = self.node_label if self.node_label else f":{dest_type}"
        destination_extra_set = f", destination:{dest_type}" if self.node_label else ""
        rel_type = self._quote_identifier(relationship)

        cypher = f"""
            MERGE (n {source_label} {{name: $source_name, user_id: $user_id}})
            ON CREATE SET n.created = timestamp(),
                          n.updated = timestamp(),
                          n.mentions = 1
                          {source_extra_set}
            ON MATCH SET
                        n.mentions = coalesce(n.mentions, 0) + 1,
                        n.updated = timestamp()
            WITH n, $source_embedding as source_embedding
            CALL neptune.algo.vectors.upsert(n, source_embedding)
            WITH n
            MERGE (m {destination_label} {{name: $dest_name, user_id: $user_id}})
            ON CREATE SET
                        m.created = timestamp(),
                        m.updated = timestamp(),
                        m.mentions = 1
                        {destination_extra_set}
            ON MATCH SET
                        m.updated = timestamp(),
                        m.mentions = coalesce(m.mentions, 0) + 1
            WITH n, m, $dest_embedding as dest_embedding
            CALL neptune.algo.vectors.upsert(m, dest_embedding)
            WITH n, m
            MERGE (n)-[rel:{rel_type}]->(m)
            ON CREATE SET
                        rel.created = timestamp(),
                        rel.updated = timestamp(),
                        rel.mentions = 1
            ON MATCH SET
                        rel.updated = timestamp(),
                        rel.mentions = coalesce(rel.mentions, 0) + 1
            RETURN n.name AS source, type(rel) AS relationship, m.name AS target
            """
        params = {
            "source_name": source,
            "dest_name": destination,
            "source_embedding": source_embedding,
            "dest_embedding": dest_embedding,
            "user_id": user_id,
        }
        logger.debug("_add_new_entities_cypher:\n query=%s", cypher)
        return cypher, params

    def _search_source_node_cypher(self, source_embedding, user_id, threshold):
        """
        Returns the OpenCypher query and parameters to search for source nodes.

        Returns at most one row: the user's node with the highest cosine
        similarity to ``source_embedding`` that is at or above ``threshold``.
        The node id is returned under the column ``id(source_candidate)``.

        :param source_embedding: source vector
        :param user_id: user_id to use
        :param threshold: the threshold for similarity
        :return: str, dict
        """
        cypher = f"""
            MATCH (source_candidate {self.node_label})
            WHERE source_candidate.user_id = $user_id

            WITH source_candidate, $source_embedding as v_embedding
            CALL neptune.algo.vectors.distanceByEmbedding(
                v_embedding,
                source_candidate,
                {{metric:"CosineSimilarity"}}
            ) YIELD distance
            WITH source_candidate, distance AS cosine_similarity
            WHERE cosine_similarity >= $threshold

            WITH source_candidate, cosine_similarity
            ORDER BY cosine_similarity DESC
            LIMIT 1

            RETURN id(source_candidate), cosine_similarity
            """

        params = {
            "source_embedding": source_embedding,
            "user_id": user_id,
            "threshold": threshold,
        }
        logger.debug("_search_source_node\n query=%s", cypher)
        return cypher, params

    def _search_destination_node_cypher(self, destination_embedding, user_id, threshold):
        """
        Returns the OpenCypher query and parameters to search for destination nodes.

        Returns at most one row: the user's node with the highest cosine
        similarity to ``destination_embedding`` that is at or above ``threshold``.
        The node id is returned under the column ``id(destination_candidate)``.

        :param destination_embedding: destination vector
        :param user_id: user_id to use
        :param threshold: the threshold for similarity
        :return: str, dict
        """
        cypher = f"""
            MATCH (destination_candidate {self.node_label})
            WHERE destination_candidate.user_id = $user_id

            WITH destination_candidate, $destination_embedding as v_embedding
            CALL neptune.algo.vectors.distanceByEmbedding(
                v_embedding,
                destination_candidate,
                {{metric:"CosineSimilarity"}}
            ) YIELD distance
            WITH destination_candidate, distance AS cosine_similarity
            WHERE cosine_similarity >= $threshold

            WITH destination_candidate, cosine_similarity
            ORDER BY cosine_similarity DESC
            LIMIT 1

            RETURN id(destination_candidate), cosine_similarity
            """
        params = {
            "destination_embedding": destination_embedding,
            "user_id": user_id,
            "threshold": threshold,
        }

        logger.debug("_search_destination_node\n query=%s", cypher)
        return cypher, params

    def _delete_all_cypher(self, filters):
        """
        Returns the OpenCypher query and parameters to delete all edges/nodes in the memory store.

        Every node matching ``filters["user_id"]`` is removed together with its
        relationships (DETACH DELETE). Other filter keys are ignored.

        :param filters: search filters; must contain ``user_id``
        :return: str, dict
        """
        cypher = f"""
            MATCH (n {self.node_label} {{user_id: $user_id}})
            DETACH DELETE n
            """
        params = {"user_id": filters["user_id"]}

        logger.debug("delete_all query=%s", cypher)
        return cypher, params

    def _get_all_cypher(self, filters, limit):
        """
        Returns the OpenCypher query and parameters to get all edges/nodes in the memory store.

        Only relationships whose two endpoints both belong to
        ``filters["user_id"]`` are returned; other filter keys are ignored.

        :param filters: search filters; must contain ``user_id``
        :param limit: maximum number of rows to return
        :return: str, dict
        """
        cypher = f"""
        MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
        RETURN n.name AS source, type(r) AS relationship, m.name AS target
        LIMIT $limit
        """
        params = {"user_id": filters["user_id"], "limit": limit}
        return cypher, params

    def _search_graph_db_cypher(self, n_embedding, filters, limit):
        """
        Returns the OpenCypher query and parameters to search for similar nodes in the memory store.

        Finds the user's nodes whose cosine similarity to ``n_embedding`` is at
        least ``self.threshold``. It then returns their outgoing and incoming
        relationships, de-duplicated and ordered by similarity.

        :param n_embedding: node vector
        :param filters: search filters; must contain ``user_id``
        :param limit: maximum number of rows to return
        :return: str, dict
        """
        cypher_query = f"""
        MATCH (n {self.node_label})
        WHERE n.user_id = $user_id
        WITH n, $n_embedding as n_embedding
        CALL neptune.algo.vectors.distanceByEmbedding(
            n_embedding,
            n,
            {{metric:"CosineSimilarity"}}
        ) YIELD distance
        WITH n, distance as similarity
        WHERE similarity >= $threshold
        CALL {{
            WITH n
            MATCH (n)-[r]->(m)
            RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id
            UNION ALL
            WITH n
            MATCH (m)-[r]->(n)
            RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id
        }}
        WITH distinct source, source_id, relationship, relation_id, destination, destination_id, similarity
        RETURN source, source_id, relationship, relation_id, destination, destination_id, similarity
        ORDER BY similarity DESC
        LIMIT $limit
        """
        params = {
            "n_embedding": n_embedding,
            "threshold": self.threshold,
            "user_id": filters["user_id"],
            "limit": limit,
        }
        logger.debug("_search_graph_db\n query=%s", cypher_query)

        return cypher_query, params
|