agentrun-mem0ai 0.0.11 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentrun_mem0/__init__.py +6 -0
- agentrun_mem0/client/__init__.py +0 -0
- agentrun_mem0/client/main.py +1747 -0
- agentrun_mem0/client/project.py +931 -0
- agentrun_mem0/client/utils.py +115 -0
- agentrun_mem0/configs/__init__.py +0 -0
- agentrun_mem0/configs/base.py +90 -0
- agentrun_mem0/configs/embeddings/__init__.py +0 -0
- agentrun_mem0/configs/embeddings/base.py +110 -0
- agentrun_mem0/configs/enums.py +7 -0
- agentrun_mem0/configs/llms/__init__.py +0 -0
- agentrun_mem0/configs/llms/anthropic.py +56 -0
- agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
- agentrun_mem0/configs/llms/azure.py +57 -0
- agentrun_mem0/configs/llms/base.py +62 -0
- agentrun_mem0/configs/llms/deepseek.py +56 -0
- agentrun_mem0/configs/llms/lmstudio.py +59 -0
- agentrun_mem0/configs/llms/ollama.py +56 -0
- agentrun_mem0/configs/llms/openai.py +79 -0
- agentrun_mem0/configs/llms/vllm.py +56 -0
- agentrun_mem0/configs/prompts.py +459 -0
- agentrun_mem0/configs/rerankers/__init__.py +0 -0
- agentrun_mem0/configs/rerankers/base.py +17 -0
- agentrun_mem0/configs/rerankers/cohere.py +15 -0
- agentrun_mem0/configs/rerankers/config.py +12 -0
- agentrun_mem0/configs/rerankers/huggingface.py +17 -0
- agentrun_mem0/configs/rerankers/llm.py +48 -0
- agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
- agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
- agentrun_mem0/configs/vector_stores/__init__.py +0 -0
- agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
- agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
- agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
- agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
- agentrun_mem0/configs/vector_stores/baidu.py +27 -0
- agentrun_mem0/configs/vector_stores/chroma.py +58 -0
- agentrun_mem0/configs/vector_stores/databricks.py +61 -0
- agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
- agentrun_mem0/configs/vector_stores/faiss.py +37 -0
- agentrun_mem0/configs/vector_stores/langchain.py +30 -0
- agentrun_mem0/configs/vector_stores/milvus.py +42 -0
- agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
- agentrun_mem0/configs/vector_stores/neptune.py +27 -0
- agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
- agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
- agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
- agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
- agentrun_mem0/configs/vector_stores/redis.py +24 -0
- agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
- agentrun_mem0/configs/vector_stores/supabase.py +44 -0
- agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
- agentrun_mem0/configs/vector_stores/valkey.py +15 -0
- agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
- agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
- agentrun_mem0/embeddings/__init__.py +0 -0
- agentrun_mem0/embeddings/aws_bedrock.py +100 -0
- agentrun_mem0/embeddings/azure_openai.py +55 -0
- agentrun_mem0/embeddings/base.py +31 -0
- agentrun_mem0/embeddings/configs.py +30 -0
- agentrun_mem0/embeddings/gemini.py +39 -0
- agentrun_mem0/embeddings/huggingface.py +44 -0
- agentrun_mem0/embeddings/langchain.py +35 -0
- agentrun_mem0/embeddings/lmstudio.py +29 -0
- agentrun_mem0/embeddings/mock.py +11 -0
- agentrun_mem0/embeddings/ollama.py +53 -0
- agentrun_mem0/embeddings/openai.py +49 -0
- agentrun_mem0/embeddings/together.py +31 -0
- agentrun_mem0/embeddings/vertexai.py +64 -0
- agentrun_mem0/exceptions.py +503 -0
- agentrun_mem0/graphs/__init__.py +0 -0
- agentrun_mem0/graphs/configs.py +105 -0
- agentrun_mem0/graphs/neptune/__init__.py +0 -0
- agentrun_mem0/graphs/neptune/base.py +497 -0
- agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
- agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
- agentrun_mem0/graphs/tools.py +371 -0
- agentrun_mem0/graphs/utils.py +97 -0
- agentrun_mem0/llms/__init__.py +0 -0
- agentrun_mem0/llms/anthropic.py +87 -0
- agentrun_mem0/llms/aws_bedrock.py +665 -0
- agentrun_mem0/llms/azure_openai.py +141 -0
- agentrun_mem0/llms/azure_openai_structured.py +91 -0
- agentrun_mem0/llms/base.py +131 -0
- agentrun_mem0/llms/configs.py +34 -0
- agentrun_mem0/llms/deepseek.py +107 -0
- agentrun_mem0/llms/gemini.py +201 -0
- agentrun_mem0/llms/groq.py +88 -0
- agentrun_mem0/llms/langchain.py +94 -0
- agentrun_mem0/llms/litellm.py +87 -0
- agentrun_mem0/llms/lmstudio.py +114 -0
- agentrun_mem0/llms/ollama.py +117 -0
- agentrun_mem0/llms/openai.py +147 -0
- agentrun_mem0/llms/openai_structured.py +52 -0
- agentrun_mem0/llms/sarvam.py +89 -0
- agentrun_mem0/llms/together.py +88 -0
- agentrun_mem0/llms/vllm.py +107 -0
- agentrun_mem0/llms/xai.py +52 -0
- agentrun_mem0/memory/__init__.py +0 -0
- agentrun_mem0/memory/base.py +63 -0
- agentrun_mem0/memory/graph_memory.py +698 -0
- agentrun_mem0/memory/kuzu_memory.py +713 -0
- agentrun_mem0/memory/main.py +2229 -0
- agentrun_mem0/memory/memgraph_memory.py +689 -0
- agentrun_mem0/memory/setup.py +56 -0
- agentrun_mem0/memory/storage.py +218 -0
- agentrun_mem0/memory/telemetry.py +90 -0
- agentrun_mem0/memory/utils.py +208 -0
- agentrun_mem0/proxy/__init__.py +0 -0
- agentrun_mem0/proxy/main.py +189 -0
- agentrun_mem0/reranker/__init__.py +9 -0
- agentrun_mem0/reranker/base.py +20 -0
- agentrun_mem0/reranker/cohere_reranker.py +85 -0
- agentrun_mem0/reranker/huggingface_reranker.py +147 -0
- agentrun_mem0/reranker/llm_reranker.py +142 -0
- agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
- agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
- agentrun_mem0/utils/factory.py +283 -0
- agentrun_mem0/utils/gcp_auth.py +167 -0
- agentrun_mem0/vector_stores/__init__.py +0 -0
- agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
- agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
- agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
- agentrun_mem0/vector_stores/azure_mysql.py +463 -0
- agentrun_mem0/vector_stores/baidu.py +368 -0
- agentrun_mem0/vector_stores/base.py +58 -0
- agentrun_mem0/vector_stores/chroma.py +332 -0
- agentrun_mem0/vector_stores/configs.py +67 -0
- agentrun_mem0/vector_stores/databricks.py +761 -0
- agentrun_mem0/vector_stores/elasticsearch.py +237 -0
- agentrun_mem0/vector_stores/faiss.py +479 -0
- agentrun_mem0/vector_stores/langchain.py +180 -0
- agentrun_mem0/vector_stores/milvus.py +250 -0
- agentrun_mem0/vector_stores/mongodb.py +310 -0
- agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
- agentrun_mem0/vector_stores/opensearch.py +292 -0
- agentrun_mem0/vector_stores/pgvector.py +404 -0
- agentrun_mem0/vector_stores/pinecone.py +382 -0
- agentrun_mem0/vector_stores/qdrant.py +270 -0
- agentrun_mem0/vector_stores/redis.py +295 -0
- agentrun_mem0/vector_stores/s3_vectors.py +176 -0
- agentrun_mem0/vector_stores/supabase.py +237 -0
- agentrun_mem0/vector_stores/upstash_vector.py +293 -0
- agentrun_mem0/vector_stores/valkey.py +824 -0
- agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
- agentrun_mem0/vector_stores/weaviate.py +343 -0
- agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
- agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
- agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
- agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
- agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,511 @@
+import logging
+import uuid
+from datetime import datetime
+import pytz
+
+from .base import NeptuneBase
+
+try:
+    from langchain_aws import NeptuneGraph
+except ImportError:
+    raise ImportError("langchain_aws is not installed. Please install it using 'make install_all'.")
+
+logger = logging.getLogger(__name__)
+
+class MemoryGraph(NeptuneBase):
+    def __init__(self, config):
+        """
+        Initialize the Neptune DB memory store.
+        """
+
+        self.config = config
+
+        self.graph = None
+        endpoint = self.config.graph_store.config.endpoint
+        if endpoint and endpoint.startswith("neptune-db://"):
+            host = endpoint.replace("neptune-db://", "")
+            port = 8182
+            self.graph = NeptuneGraph(host, port)
+
+        if not self.graph:
+            raise ValueError("Unable to create a Neptune-DB client: missing 'endpoint' in config")
+
+        self.node_label = ":`__Entity__`" if self.config.graph_store.config.base_label else ""
+
+        self.embedding_model = NeptuneBase._create_embedding_model(self.config)
+
+        # Default to openai if no specific provider is configured
+        self.llm_provider = "openai"
+        if self.config.graph_store.llm:
+            self.llm_provider = self.config.graph_store.llm.provider
+        elif self.config.llm.provider:
+            self.llm_provider = self.config.llm.provider
+
+        # fetch the vector store as a provider
+        self.vector_store_provider = self.config.vector_store.provider
+        if self.config.graph_store.config.collection_name:
+            vector_store_collection_name = self.config.graph_store.config.collection_name
+        else:
+            vector_store_config = self.config.vector_store.config
+            if vector_store_config.collection_name:
+                vector_store_collection_name = vector_store_config.collection_name + "_neptune_vector_store"
+            else:
+                vector_store_collection_name = "mem0_neptune_vector_store"
+        self.config.vector_store.config.collection_name = vector_store_collection_name
+        self.vector_store = NeptuneBase._create_vector_store(self.vector_store_provider, self.config)
+
+        self.llm = NeptuneBase._create_llm(self.config, self.llm_provider)
+        self.user_id = None
+        self.threshold = 0.7
+        self.vector_store_limit=5
+
+    def _delete_entities_cypher(self, source, destination, relationship, user_id):
+        """
+        Returns the OpenCypher query and parameters for deleting entities in the graph DB
+
+        :param source: source node
+        :param destination: destination node
+        :param relationship: relationship label
+        :param user_id: user_id to use
+        :return: str, dict
+        """
+
+        cypher = f"""
+            MATCH (n {self.node_label} {{name: $source_name, user_id: $user_id}})
+            -[r:{relationship}]->
+            (m {self.node_label} {{name: $dest_name, user_id: $user_id}})
+            DELETE r
+            RETURN
+                n.name AS source,
+                m.name AS target,
+                type(r) AS relationship
+            """
+        params = {
+            "source_name": source,
+            "dest_name": destination,
+            "user_id": user_id,
+        }
+        logger.debug(f"_delete_entities\n query={cypher}")
+        return cypher, params
+
+    def _add_entities_by_source_cypher(
+        self,
+        source_node_list,
+        destination,
+        dest_embedding,
+        destination_type,
+        relationship,
+        user_id,
+    ):
+        """
+        Returns the OpenCypher query and parameters for adding entities in the graph DB
+
+        :param source_node_list: list of source nodes
+        :param destination: destination name
+        :param dest_embedding: destination embedding
+        :param destination_type: destination node label
+        :param relationship: relationship label
+        :param user_id: user id to use
+        :return: str, dict
+        """
+        destination_id = str(uuid.uuid4())
+        destination_payload = {
+            "name": destination,
+            "type": destination_type,
+            "user_id": user_id,
+            "created_at": datetime.now(pytz.timezone("US/Pacific")).isoformat(),
+        }
+        self.vector_store.insert(
+            vectors=[dest_embedding],
+            payloads=[destination_payload],
+            ids=[destination_id],
+        )
+
+        destination_label = self.node_label if self.node_label else f":`{destination_type}`"
+        destination_extra_set = f", destination:`{destination_type}`" if self.node_label else ""
+
+        cypher = f"""
+            MATCH (source {{user_id: $user_id}})
+            WHERE id(source) = $source_id
+            SET source.mentions = coalesce(source.mentions, 0) + 1
+            WITH source
+            MERGE (destination {destination_label} {{`~id`: $destination_id, name: $destination_name, user_id: $user_id}})
+            ON CREATE SET
+                destination.created = timestamp(),
+                destination.updated = timestamp(),
+                destination.mentions = 1
+                {destination_extra_set}
+            ON MATCH SET
+                destination.mentions = coalesce(destination.mentions, 0) + 1,
+                destination.updated = timestamp()
+            WITH source, destination
+            MERGE (source)-[r:{relationship}]->(destination)
+            ON CREATE SET
+                r.created = timestamp(),
+                r.updated = timestamp(),
+                r.mentions = 1
+            ON MATCH SET
+                r.mentions = coalesce(r.mentions, 0) + 1,
+                r.updated = timestamp()
+            RETURN source.name AS source, type(r) AS relationship, destination.name AS target, id(destination) AS destination_id
+            """
+
+        params = {
+            "source_id": source_node_list[0]["id(source_candidate)"],
+            "destination_id": destination_id,
+            "destination_name": destination,
+            "dest_embedding": dest_embedding,
+            "user_id": user_id,
+        }
+
+        logger.debug(
+            f"_add_entities:\n source_node_search_result={source_node_list[0]}\n query={cypher}"
+        )
+        return cypher, params
+
+    def _add_entities_by_destination_cypher(
+        self,
+        source,
+        source_embedding,
+        source_type,
+        destination_node_list,
+        relationship,
+        user_id,
+    ):
+        """
+        Returns the OpenCypher query and parameters for adding entities in the graph DB
+
+        :param source: source node name
+        :param source_embedding: source node embedding
+        :param source_type: source node label
+        :param destination_node_list: list of dest nodes
+        :param relationship: relationship label
+        :param user_id: user id to use
+        :return: str, dict
+        """
+        source_id = str(uuid.uuid4())
+        source_payload = {
+            "name": source,
+            "type": source_type,
+            "user_id": user_id,
+            "created_at": datetime.now(pytz.timezone("US/Pacific")).isoformat(),
+        }
+        self.vector_store.insert(
+            vectors=[source_embedding],
+            payloads=[source_payload],
+            ids=[source_id],
+        )
+
+        source_label = self.node_label if self.node_label else f":`{source_type}`"
+        source_extra_set = f", source:`{source_type}`" if self.node_label else ""
+
+        cypher = f"""
+            MATCH (destination {{user_id: $user_id}})
+            WHERE id(destination) = $destination_id
+            SET
+                destination.mentions = coalesce(destination.mentions, 0) + 1,
+                destination.updated = timestamp()
+            WITH destination
+            MERGE (source {source_label} {{`~id`: $source_id, name: $source_name, user_id: $user_id}})
+            ON CREATE SET
+                source.created = timestamp(),
+                source.updated = timestamp(),
+                source.mentions = 1
+                {source_extra_set}
+            ON MATCH SET
+                source.mentions = coalesce(source.mentions, 0) + 1,
+                source.updated = timestamp()
+            WITH source, destination
+            MERGE (source)-[r:{relationship}]->(destination)
+            ON CREATE SET
+                r.created = timestamp(),
+                r.updated = timestamp(),
+                r.mentions = 1
+            ON MATCH SET
+                r.mentions = coalesce(r.mentions, 0) + 1,
+                r.updated = timestamp()
+            RETURN source.name AS source, type(r) AS relationship, destination.name AS target
+            """
+
+        params = {
+            "destination_id": destination_node_list[0]["id(destination_candidate)"],
+            "source_id": source_id,
+            "source_name": source,
+            "source_embedding": source_embedding,
+            "user_id": user_id,
+        }
+        logger.debug(
+            f"_add_entities:\n destination_node_search_result={destination_node_list[0]}\n query={cypher}"
+        )
+        return cypher, params
+
+    def _add_relationship_entities_cypher(
+        self,
+        source_node_list,
+        destination_node_list,
+        relationship,
+        user_id,
+    ):
+        """
+        Returns the OpenCypher query and parameters for adding entities in the graph DB
+
+        :param source_node_list: list of source node ids
+        :param destination_node_list: list of dest node ids
+        :param relationship: relationship label
+        :param user_id: user id to use
+        :return: str, dict
+        """
+
+        cypher = f"""
+            MATCH (source {{user_id: $user_id}})
+            WHERE id(source) = $source_id
+            SET
+                source.mentions = coalesce(source.mentions, 0) + 1,
+                source.updated = timestamp()
+            WITH source
+            MATCH (destination {{user_id: $user_id}})
+            WHERE id(destination) = $destination_id
+            SET
+                destination.mentions = coalesce(destination.mentions) + 1,
+                destination.updated = timestamp()
+            MERGE (source)-[r:{relationship}]->(destination)
+            ON CREATE SET
+                r.created_at = timestamp(),
+                r.updated_at = timestamp(),
+                r.mentions = 1
+            ON MATCH SET r.mentions = coalesce(r.mentions, 0) + 1
+            RETURN source.name AS source, type(r) AS relationship, destination.name AS target
+            """
+        params = {
+            "source_id": source_node_list[0]["id(source_candidate)"],
+            "destination_id": destination_node_list[0]["id(destination_candidate)"],
+            "user_id": user_id,
+        }
+        logger.debug(
+            f"_add_entities:\n destination_node_search_result={destination_node_list[0]}\n source_node_search_result={source_node_list[0]}\n query={cypher}"
+        )
+        return cypher, params
+
+    def _add_new_entities_cypher(
+        self,
+        source,
+        source_embedding,
+        source_type,
+        destination,
+        dest_embedding,
+        destination_type,
+        relationship,
+        user_id,
+    ):
+        """
+        Returns the OpenCypher query and parameters for adding entities in the graph DB
+
+        :param source: source node name
+        :param source_embedding: source node embedding
+        :param source_type: source node label
+        :param destination: destination name
+        :param dest_embedding: destination embedding
+        :param destination_type: destination node label
+        :param relationship: relationship label
+        :param user_id: user id to use
+        :return: str, dict
+        """
+        source_id = str(uuid.uuid4())
+        source_payload = {
+            "name": source,
+            "type": source_type,
+            "user_id": user_id,
+            "created_at": datetime.now(pytz.timezone("US/Pacific")).isoformat(),
+        }
+        destination_id = str(uuid.uuid4())
+        destination_payload = {
+            "name": destination,
+            "type": destination_type,
+            "user_id": user_id,
+            "created_at": datetime.now(pytz.timezone("US/Pacific")).isoformat(),
+        }
+        self.vector_store.insert(
+            vectors=[source_embedding, dest_embedding],
+            payloads=[source_payload, destination_payload],
+            ids=[source_id, destination_id],
+        )
+
+        source_label = self.node_label if self.node_label else f":`{source_type}`"
+        source_extra_set = f", source:`{source_type}`" if self.node_label else ""
+        destination_label = self.node_label if self.node_label else f":`{destination_type}`"
+        destination_extra_set = f", destination:`{destination_type}`" if self.node_label else ""
+
+        cypher = f"""
+            MERGE (n {source_label} {{name: $source_name, user_id: $user_id, `~id`: $source_id}})
+            ON CREATE SET n.created = timestamp(),
+                n.mentions = 1
+                {source_extra_set}
+            ON MATCH SET n.mentions = coalesce(n.mentions, 0) + 1
+            WITH n
+            MERGE (m {destination_label} {{name: $dest_name, user_id: $user_id, `~id`: $dest_id}})
+            ON CREATE SET m.created = timestamp(),
+                m.mentions = 1
+                {destination_extra_set}
+            ON MATCH SET m.mentions = coalesce(m.mentions, 0) + 1
+            WITH n, m
+            MERGE (n)-[rel:{relationship}]->(m)
+            ON CREATE SET rel.created = timestamp(), rel.mentions = 1
+            ON MATCH SET rel.mentions = coalesce(rel.mentions, 0) + 1
+            RETURN n.name AS source, type(rel) AS relationship, m.name AS target
+            """
+        params = {
+            "source_id": source_id,
+            "dest_id": destination_id,
+            "source_name": source,
+            "dest_name": destination,
+            "source_embedding": source_embedding,
+            "dest_embedding": dest_embedding,
+            "user_id": user_id,
+        }
+        logger.debug(
+            f"_add_new_entities_cypher:\n query={cypher}"
+        )
+        return cypher, params
+
+    def _search_source_node_cypher(self, source_embedding, user_id, threshold):
+        """
+        Returns the OpenCypher query and parameters to search for source nodes
+
+        :param source_embedding: source vector
+        :param user_id: user_id to use
+        :param threshold: the threshold for similarity
+        :return: str, dict
+        """
+
+        source_nodes = self.vector_store.search(
+            query="",
+            vectors=source_embedding,
+            limit=self.vector_store_limit,
+            filters={"user_id": user_id},
+        )
+
+        ids = [n.id for n in filter(lambda s: s.score > threshold, source_nodes)]
+
+        cypher = f"""
+            MATCH (source_candidate {self.node_label})
+            WHERE source_candidate.user_id = $user_id AND id(source_candidate) IN $ids
+            RETURN id(source_candidate)
+            """
+
+        params = {
+            "ids": ids,
+            "source_embedding": source_embedding,
+            "user_id": user_id,
+            "threshold": threshold,
+        }
+        logger.debug(f"_search_source_node\n query={cypher}")
+        return cypher, params
+
+    def _search_destination_node_cypher(self, destination_embedding, user_id, threshold):
+        """
+        Returns the OpenCypher query and parameters to search for destination nodes
+
+        :param source_embedding: source vector
+        :param user_id: user_id to use
+        :param threshold: the threshold for similarity
+        :return: str, dict
+        """
+        destination_nodes = self.vector_store.search(
+            query="",
+            vectors=destination_embedding,
+            limit=self.vector_store_limit,
+            filters={"user_id": user_id},
+        )
+
+        ids = [n.id for n in filter(lambda d: d.score > threshold, destination_nodes)]
+
+        cypher = f"""
+            MATCH (destination_candidate {self.node_label})
+            WHERE destination_candidate.user_id = $user_id AND id(destination_candidate) IN $ids
+            RETURN id(destination_candidate)
+            """
+
+        params = {
+            "ids": ids,
+            "destination_embedding": destination_embedding,
+            "user_id": user_id,
+        }
+
+        logger.debug(f"_search_destination_node\n query={cypher}")
+        return cypher, params
+
+    def _delete_all_cypher(self, filters):
+        """
+        Returns the OpenCypher query and parameters to delete all edges/nodes in the memory store
+
+        :param filters: search filters
+        :return: str, dict
+        """
+
+        # remove the vector store index
+        self.vector_store.reset()
+
+        # create a query that: deletes the nodes of the graph_store
+        cypher = f"""
+            MATCH (n {self.node_label} {{user_id: $user_id}})
+            DETACH DELETE n
+            """
+        params = {"user_id": filters["user_id"]}
+
+        logger.debug(f"delete_all query={cypher}")
+        return cypher, params
+
+    def _get_all_cypher(self, filters, limit):
+        """
+        Returns the OpenCypher query and parameters to get all edges/nodes in the memory store
+
+        :param filters: search filters
+        :param limit: return limit
+        :return: str, dict
+        """
+
+        cypher = f"""
+            MATCH (n {self.node_label} {{user_id: $user_id}})-[r]->(m {self.node_label} {{user_id: $user_id}})
+            RETURN n.name AS source, type(r) AS relationship, m.name AS target
+            LIMIT $limit
+            """
+        params = {"user_id": filters["user_id"], "limit": limit}
+        return cypher, params
+
+    def _search_graph_db_cypher(self, n_embedding, filters, limit):
+        """
+        Returns the OpenCypher query and parameters to search for similar nodes in the memory store
+
+        :param n_embedding: node vector
+        :param filters: search filters
+        :param limit: return limit
+        :return: str, dict
+        """
+
+        # search vector store for applicable nodes using cosine similarity
+        search_nodes = self.vector_store.search(
+            query="",
+            vectors=n_embedding,
+            limit=self.vector_store_limit,
+            filters=filters,
+        )
+
+        ids = [n.id for n in search_nodes]
+
+        cypher_query = f"""
+            MATCH (n {self.node_label})-[r]->(m)
+            WHERE n.user_id = $user_id AND id(n) IN $n_ids
+            RETURN n.name AS source, id(n) AS source_id, type(r) AS relationship, id(r) AS relation_id, m.name AS destination, id(m) AS destination_id
+            UNION
+            MATCH (m)-[r]->(n {self.node_label})
+            RETURN m.name AS source, id(m) AS source_id, type(r) AS relationship, id(r) AS relation_id, n.name AS destination, id(n) AS destination_id
+            LIMIT $limit
+            """
+        params = {
+            "n_ids": ids,
+            "user_id": filters["user_id"],
+            "limit": limit,
+        }
+        logger.debug(f"_search_graph_db\n query={cypher_query}")
+
+        return cypher_query, params