graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
- graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
- graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/__init__.py +19 -0
- graphiti_core/driver/driver.py +124 -0
- graphiti_core/driver/falkordb_driver.py +362 -0
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +117 -0
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +287 -172
- graphiti_core/embedder/azure_openai.py +71 -0
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/embedder/gemini.py +116 -22
- graphiti_core/embedder/voyage.py +13 -2
- graphiti_core/errors.py +8 -0
- graphiti_core/graph_queries.py +162 -0
- graphiti_core/graphiti.py +705 -193
- graphiti_core/graphiti_types.py +4 -2
- graphiti_core/helpers.py +87 -10
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/anthropic_client.py +159 -56
- graphiti_core/llm_client/azure_openai_client.py +115 -0
- graphiti_core/llm_client/client.py +98 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +290 -41
- graphiti_core/llm_client/groq_client.py +14 -3
- graphiti_core/llm_client/openai_base_client.py +261 -0
- graphiti_core/llm_client/openai_client.py +56 -132
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +420 -205
- graphiti_core/prompts/dedupe_edges.py +46 -32
- graphiti_core/prompts/dedupe_nodes.py +67 -42
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +27 -16
- graphiti_core/prompts/extract_nodes.py +74 -31
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +158 -82
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +126 -35
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1405 -485
- graphiti_core/telemetry/__init__.py +9 -0
- graphiti_core/telemetry/telemetry.py +117 -0
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +364 -285
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +67 -49
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +339 -197
- graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
- graphiti_core/utils/maintenance/node_operations.py +319 -238
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- graphiti_core-0.24.3.dist-info/METADATA +726 -0
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
- graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
|
@@ -20,20 +20,31 @@ from time import time
|
|
|
20
20
|
from typing import Any
|
|
21
21
|
|
|
22
22
|
import numpy as np
|
|
23
|
-
from neo4j import AsyncDriver, Query
|
|
24
23
|
from numpy._typing import NDArray
|
|
25
24
|
from typing_extensions import LiteralString
|
|
26
25
|
|
|
26
|
+
from graphiti_core.driver.driver import (
|
|
27
|
+
GraphDriver,
|
|
28
|
+
GraphProvider,
|
|
29
|
+
)
|
|
27
30
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
|
31
|
+
from graphiti_core.graph_queries import (
|
|
32
|
+
get_nodes_query,
|
|
33
|
+
get_relationships_query,
|
|
34
|
+
get_vector_cosine_func_query,
|
|
35
|
+
)
|
|
28
36
|
from graphiti_core.helpers import (
|
|
29
|
-
DEFAULT_DATABASE,
|
|
30
|
-
RUNTIME_QUERY,
|
|
31
37
|
lucene_sanitize,
|
|
32
38
|
normalize_l2,
|
|
33
39
|
semaphore_gather,
|
|
34
40
|
)
|
|
41
|
+
from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
|
|
42
|
+
from graphiti_core.models.nodes.node_db_queries import (
|
|
43
|
+
COMMUNITY_NODE_RETURN,
|
|
44
|
+
EPISODIC_NODE_RETURN,
|
|
45
|
+
get_entity_node_return_query,
|
|
46
|
+
)
|
|
35
47
|
from graphiti_core.nodes import (
|
|
36
|
-
ENTITY_NODE_RETURN,
|
|
37
48
|
CommunityNode,
|
|
38
49
|
EntityNode,
|
|
39
50
|
EpisodicNode,
|
|
@@ -53,16 +64,39 @@ RELEVANT_SCHEMA_LIMIT = 10
|
|
|
53
64
|
DEFAULT_MIN_SCORE = 0.6
|
|
54
65
|
DEFAULT_MMR_LAMBDA = 0.5
|
|
55
66
|
MAX_SEARCH_DEPTH = 3
|
|
56
|
-
MAX_QUERY_LENGTH =
|
|
67
|
+
MAX_QUERY_LENGTH = 128
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> float:
|
|
71
|
+
"""
|
|
72
|
+
Calculates the cosine similarity between two vectors using NumPy.
|
|
73
|
+
"""
|
|
74
|
+
dot_product = np.dot(vector1, vector2)
|
|
75
|
+
norm_vector1 = np.linalg.norm(vector1)
|
|
76
|
+
norm_vector2 = np.linalg.norm(vector2)
|
|
77
|
+
|
|
78
|
+
if norm_vector1 == 0 or norm_vector2 == 0:
|
|
79
|
+
return 0 # Handle cases where one or both vectors are zero vectors
|
|
80
|
+
|
|
81
|
+
return dot_product / (norm_vector1 * norm_vector2)
|
|
57
82
|
|
|
58
83
|
|
|
59
|
-
def fulltext_query(query: str, group_ids: list[str] | None
|
|
84
|
+
def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
|
|
85
|
+
if driver.provider == GraphProvider.KUZU:
|
|
86
|
+
# Kuzu only supports simple queries.
|
|
87
|
+
if len(query.split(' ')) > MAX_QUERY_LENGTH:
|
|
88
|
+
return ''
|
|
89
|
+
return query
|
|
90
|
+
elif driver.provider == GraphProvider.FALKORDB:
|
|
91
|
+
return driver.build_fulltext_query(query, group_ids, MAX_QUERY_LENGTH)
|
|
60
92
|
group_ids_filter_list = (
|
|
61
|
-
[f'group_id:"{
|
|
93
|
+
[driver.fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
|
|
94
|
+
if group_ids is not None
|
|
95
|
+
else []
|
|
62
96
|
)
|
|
63
97
|
group_ids_filter = ''
|
|
64
98
|
for f in group_ids_filter_list:
|
|
65
|
-
group_ids_filter += f if not group_ids_filter else f'OR {f}'
|
|
99
|
+
group_ids_filter += f if not group_ids_filter else f' OR {f}'
|
|
66
100
|
|
|
67
101
|
group_ids_filter += ' AND ' if group_ids_filter else ''
|
|
68
102
|
|
|
@@ -77,7 +111,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
|
|
|
77
111
|
|
|
78
112
|
|
|
79
113
|
async def get_episodes_by_mentions(
|
|
80
|
-
driver:
|
|
114
|
+
driver: GraphDriver,
|
|
81
115
|
nodes: list[EntityNode],
|
|
82
116
|
edges: list[EntityEdge],
|
|
83
117
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
@@ -92,47 +126,39 @@ async def get_episodes_by_mentions(
|
|
|
92
126
|
|
|
93
127
|
|
|
94
128
|
async def get_mentioned_nodes(
|
|
95
|
-
driver:
|
|
129
|
+
driver: GraphDriver, episodes: list[EpisodicNode]
|
|
96
130
|
) -> list[EntityNode]:
|
|
97
131
|
episode_uuids = [episode.uuid for episode in episodes]
|
|
132
|
+
|
|
98
133
|
records, _, _ = await driver.execute_query(
|
|
99
134
|
"""
|
|
100
|
-
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
|
|
135
|
+
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
|
|
136
|
+
WHERE episode.uuid IN $uuids
|
|
101
137
|
RETURN DISTINCT
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
n.name AS name,
|
|
105
|
-
n.created_at AS created_at,
|
|
106
|
-
n.summary AS summary,
|
|
107
|
-
labels(n) AS labels,
|
|
108
|
-
properties(n) AS attributes
|
|
109
|
-
""",
|
|
138
|
+
"""
|
|
139
|
+
+ get_entity_node_return_query(driver.provider),
|
|
110
140
|
uuids=episode_uuids,
|
|
111
|
-
database_=DEFAULT_DATABASE,
|
|
112
141
|
routing_='r',
|
|
113
142
|
)
|
|
114
143
|
|
|
115
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
144
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
116
145
|
|
|
117
146
|
return nodes
|
|
118
147
|
|
|
119
148
|
|
|
120
149
|
async def get_communities_by_nodes(
|
|
121
|
-
driver:
|
|
150
|
+
driver: GraphDriver, nodes: list[EntityNode]
|
|
122
151
|
) -> list[CommunityNode]:
|
|
123
152
|
node_uuids = [node.uuid for node in nodes]
|
|
153
|
+
|
|
124
154
|
records, _, _ = await driver.execute_query(
|
|
125
155
|
"""
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
c.created_at AS created_at,
|
|
132
|
-
c.summary AS summary
|
|
133
|
-
""",
|
|
156
|
+
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)
|
|
157
|
+
WHERE m.uuid IN $uuids
|
|
158
|
+
RETURN DISTINCT
|
|
159
|
+
"""
|
|
160
|
+
+ COMMUNITY_NODE_RETURN,
|
|
134
161
|
uuids=node_uuids,
|
|
135
|
-
database_=DEFAULT_DATABASE,
|
|
136
162
|
routing_='r',
|
|
137
163
|
)
|
|
138
164
|
|
|
@@ -142,61 +168,122 @@ async def get_communities_by_nodes(
|
|
|
142
168
|
|
|
143
169
|
|
|
144
170
|
async def edge_fulltext_search(
|
|
145
|
-
driver:
|
|
171
|
+
driver: GraphDriver,
|
|
146
172
|
query: str,
|
|
147
173
|
search_filter: SearchFilters,
|
|
148
174
|
group_ids: list[str] | None = None,
|
|
149
175
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
150
176
|
) -> list[EntityEdge]:
|
|
177
|
+
if driver.search_interface:
|
|
178
|
+
return await driver.search_interface.edge_fulltext_search(
|
|
179
|
+
driver, query, search_filter, group_ids, limit
|
|
180
|
+
)
|
|
181
|
+
|
|
151
182
|
# fulltext search over facts
|
|
152
|
-
fuzzy_query = fulltext_query(query, group_ids)
|
|
183
|
+
fuzzy_query = fulltext_query(query, group_ids, driver)
|
|
184
|
+
|
|
153
185
|
if fuzzy_query == '':
|
|
154
186
|
return []
|
|
155
187
|
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
188
|
+
match_query = """
|
|
189
|
+
YIELD relationship AS rel, score
|
|
190
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
|
|
191
|
+
"""
|
|
192
|
+
if driver.provider == GraphProvider.KUZU:
|
|
193
|
+
match_query = """
|
|
194
|
+
YIELD node, score
|
|
195
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
|
|
159
196
|
"""
|
|
160
|
-
CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit})
|
|
161
|
-
YIELD relationship AS rel, score
|
|
162
|
-
MATCH (:Entity)-[r:RELATES_TO]->(:Entity)
|
|
163
|
-
WHERE r.group_id IN $group_ids"""
|
|
164
|
-
+ filter_query
|
|
165
|
-
+ """\nWITH r, score, startNode(r) AS n, endNode(r) AS m
|
|
166
|
-
RETURN
|
|
167
|
-
r.uuid AS uuid,
|
|
168
|
-
r.group_id AS group_id,
|
|
169
|
-
n.uuid AS source_node_uuid,
|
|
170
|
-
m.uuid AS target_node_uuid,
|
|
171
|
-
r.created_at AS created_at,
|
|
172
|
-
r.name AS name,
|
|
173
|
-
r.fact AS fact,
|
|
174
|
-
r.episodes AS episodes,
|
|
175
|
-
r.expired_at AS expired_at,
|
|
176
|
-
r.valid_at AS valid_at,
|
|
177
|
-
r.invalid_at AS invalid_at,
|
|
178
|
-
properties(r) AS attributes
|
|
179
|
-
ORDER BY score DESC LIMIT $limit
|
|
180
|
-
"""
|
|
181
|
-
)
|
|
182
197
|
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
filter_params,
|
|
186
|
-
query=fuzzy_query,
|
|
187
|
-
group_ids=group_ids,
|
|
188
|
-
limit=limit,
|
|
189
|
-
database_=DEFAULT_DATABASE,
|
|
190
|
-
routing_='r',
|
|
198
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
199
|
+
search_filter, driver.provider
|
|
191
200
|
)
|
|
192
201
|
|
|
193
|
-
|
|
202
|
+
if group_ids is not None:
|
|
203
|
+
filter_queries.append('e.group_id IN $group_ids')
|
|
204
|
+
filter_params['group_ids'] = group_ids
|
|
205
|
+
|
|
206
|
+
filter_query = ''
|
|
207
|
+
if filter_queries:
|
|
208
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
209
|
+
|
|
210
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
211
|
+
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
|
|
212
|
+
if res['hits']['total']['value'] > 0:
|
|
213
|
+
input_ids = []
|
|
214
|
+
for r in res['hits']['hits']:
|
|
215
|
+
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
216
|
+
|
|
217
|
+
# Match the edge ids and return the values
|
|
218
|
+
query = (
|
|
219
|
+
"""
|
|
220
|
+
UNWIND $ids as id
|
|
221
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
222
|
+
WHERE e.group_id IN $group_ids
|
|
223
|
+
AND id(e)=id
|
|
224
|
+
"""
|
|
225
|
+
+ filter_query
|
|
226
|
+
+ """
|
|
227
|
+
AND id(e)=id
|
|
228
|
+
WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
|
|
229
|
+
RETURN
|
|
230
|
+
e.uuid AS uuid,
|
|
231
|
+
e.group_id AS group_id,
|
|
232
|
+
n.uuid AS source_node_uuid,
|
|
233
|
+
m.uuid AS target_node_uuid,
|
|
234
|
+
e.created_at AS created_at,
|
|
235
|
+
e.name AS name,
|
|
236
|
+
e.fact AS fact,
|
|
237
|
+
split(e.episodes, ",") AS episodes,
|
|
238
|
+
e.expired_at AS expired_at,
|
|
239
|
+
e.valid_at AS valid_at,
|
|
240
|
+
e.invalid_at AS invalid_at,
|
|
241
|
+
properties(e) AS attributes
|
|
242
|
+
ORDER BY score DESC LIMIT $limit
|
|
243
|
+
"""
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
records, _, _ = await driver.execute_query(
|
|
247
|
+
query,
|
|
248
|
+
query=fuzzy_query,
|
|
249
|
+
ids=input_ids,
|
|
250
|
+
limit=limit,
|
|
251
|
+
routing_='r',
|
|
252
|
+
**filter_params,
|
|
253
|
+
)
|
|
254
|
+
else:
|
|
255
|
+
return []
|
|
256
|
+
else:
|
|
257
|
+
query = (
|
|
258
|
+
get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
|
|
259
|
+
+ match_query
|
|
260
|
+
+ filter_query
|
|
261
|
+
+ """
|
|
262
|
+
WITH e, score, n, m
|
|
263
|
+
RETURN
|
|
264
|
+
"""
|
|
265
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
266
|
+
+ """
|
|
267
|
+
ORDER BY score DESC
|
|
268
|
+
LIMIT $limit
|
|
269
|
+
"""
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
records, _, _ = await driver.execute_query(
|
|
273
|
+
query,
|
|
274
|
+
query=fuzzy_query,
|
|
275
|
+
limit=limit,
|
|
276
|
+
routing_='r',
|
|
277
|
+
**filter_params,
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
194
281
|
|
|
195
282
|
return edges
|
|
196
283
|
|
|
197
284
|
|
|
198
285
|
async def edge_similarity_search(
|
|
199
|
-
driver:
|
|
286
|
+
driver: GraphDriver,
|
|
200
287
|
search_vector: list[float],
|
|
201
288
|
source_node_uuid: str | None,
|
|
202
289
|
target_node_uuid: str | None,
|
|
@@ -205,34 +292,85 @@ async def edge_similarity_search(
|
|
|
205
292
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
206
293
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
207
294
|
) -> list[EntityEdge]:
|
|
208
|
-
|
|
209
|
-
|
|
295
|
+
if driver.search_interface:
|
|
296
|
+
return await driver.search_interface.edge_similarity_search(
|
|
297
|
+
driver,
|
|
298
|
+
search_vector,
|
|
299
|
+
source_node_uuid,
|
|
300
|
+
target_node_uuid,
|
|
301
|
+
search_filter,
|
|
302
|
+
group_ids,
|
|
303
|
+
limit,
|
|
304
|
+
min_score,
|
|
305
|
+
)
|
|
210
306
|
|
|
211
|
-
|
|
212
|
-
|
|
307
|
+
match_query = """
|
|
308
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
309
|
+
"""
|
|
310
|
+
if driver.provider == GraphProvider.KUZU:
|
|
311
|
+
match_query = """
|
|
312
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
313
|
+
"""
|
|
314
|
+
|
|
315
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
316
|
+
search_filter, driver.provider
|
|
317
|
+
)
|
|
213
318
|
|
|
214
|
-
group_filter_query: LiteralString = ''
|
|
215
319
|
if group_ids is not None:
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
query_params['source_node_uuid'] = source_node_uuid
|
|
219
|
-
query_params['target_node_uuid'] = target_node_uuid
|
|
320
|
+
filter_queries.append('e.group_id IN $group_ids')
|
|
321
|
+
filter_params['group_ids'] = group_ids
|
|
220
322
|
|
|
221
323
|
if source_node_uuid is not None:
|
|
222
|
-
|
|
324
|
+
filter_params['source_uuid'] = source_node_uuid
|
|
325
|
+
filter_queries.append('n.uuid = $source_uuid')
|
|
223
326
|
|
|
224
327
|
if target_node_uuid is not None:
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
328
|
+
filter_params['target_uuid'] = target_node_uuid
|
|
329
|
+
filter_queries.append('m.uuid = $target_uuid')
|
|
330
|
+
|
|
331
|
+
filter_query = ''
|
|
332
|
+
if filter_queries:
|
|
333
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
334
|
+
|
|
335
|
+
search_vector_var = '$search_vector'
|
|
336
|
+
if driver.provider == GraphProvider.KUZU:
|
|
337
|
+
search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
|
|
338
|
+
|
|
339
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
340
|
+
query = (
|
|
341
|
+
"""
|
|
342
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
343
|
+
"""
|
|
344
|
+
+ filter_query
|
|
345
|
+
+ """
|
|
346
|
+
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
|
347
|
+
"""
|
|
348
|
+
)
|
|
349
|
+
resp, header, _ = await driver.execute_query(
|
|
350
|
+
query,
|
|
351
|
+
search_vector=search_vector,
|
|
352
|
+
limit=limit,
|
|
353
|
+
min_score=min_score,
|
|
354
|
+
routing_='r',
|
|
355
|
+
**filter_params,
|
|
356
|
+
)
|
|
357
|
+
|
|
358
|
+
if len(resp) > 0:
|
|
359
|
+
# Calculate Cosine similarity then return the edge ids
|
|
360
|
+
input_ids = []
|
|
361
|
+
for r in resp:
|
|
362
|
+
if r['embedding']:
|
|
363
|
+
score = calculate_cosine_similarity(
|
|
364
|
+
search_vector, list(map(float, r['embedding'].split(',')))
|
|
365
|
+
)
|
|
366
|
+
if score > min_score:
|
|
367
|
+
input_ids.append({'id': r['id'], 'score': score})
|
|
368
|
+
|
|
369
|
+
# Match the edge ides and return the values
|
|
370
|
+
query = """
|
|
371
|
+
UNWIND $ids as i
|
|
372
|
+
MATCH ()-[r]->()
|
|
373
|
+
WHERE id(r) = i.id
|
|
236
374
|
RETURN
|
|
237
375
|
r.uuid AS uuid,
|
|
238
376
|
r.group_id AS group_id,
|
|
@@ -241,292 +379,648 @@ async def edge_similarity_search(
|
|
|
241
379
|
r.created_at AS created_at,
|
|
242
380
|
r.name AS name,
|
|
243
381
|
r.fact AS fact,
|
|
244
|
-
r.episodes AS episodes,
|
|
382
|
+
split(r.episodes, ",") AS episodes,
|
|
245
383
|
r.expired_at AS expired_at,
|
|
246
384
|
r.valid_at AS valid_at,
|
|
247
385
|
r.invalid_at AS invalid_at,
|
|
248
386
|
properties(r) AS attributes
|
|
249
|
-
ORDER BY score DESC
|
|
387
|
+
ORDER BY i.score DESC
|
|
250
388
|
LIMIT $limit
|
|
251
|
-
|
|
252
|
-
|
|
389
|
+
"""
|
|
390
|
+
records, _, _ = await driver.execute_query(
|
|
391
|
+
query,
|
|
392
|
+
ids=input_ids,
|
|
393
|
+
search_vector=search_vector,
|
|
394
|
+
limit=limit,
|
|
395
|
+
min_score=min_score,
|
|
396
|
+
routing_='r',
|
|
397
|
+
**filter_params,
|
|
398
|
+
)
|
|
399
|
+
else:
|
|
400
|
+
return []
|
|
401
|
+
else:
|
|
402
|
+
query = (
|
|
403
|
+
match_query
|
|
404
|
+
+ filter_query
|
|
405
|
+
+ """
|
|
406
|
+
WITH DISTINCT e, n, m, """
|
|
407
|
+
+ get_vector_cosine_func_query('e.fact_embedding', search_vector_var, driver.provider)
|
|
408
|
+
+ """ AS score
|
|
409
|
+
WHERE score > $min_score
|
|
410
|
+
RETURN
|
|
411
|
+
"""
|
|
412
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
413
|
+
+ """
|
|
414
|
+
ORDER BY score DESC
|
|
415
|
+
LIMIT $limit
|
|
416
|
+
"""
|
|
417
|
+
)
|
|
253
418
|
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
min_score=min_score,
|
|
263
|
-
database_=DEFAULT_DATABASE,
|
|
264
|
-
routing_='r',
|
|
265
|
-
)
|
|
419
|
+
records, _, _ = await driver.execute_query(
|
|
420
|
+
query,
|
|
421
|
+
search_vector=search_vector,
|
|
422
|
+
limit=limit,
|
|
423
|
+
min_score=min_score,
|
|
424
|
+
routing_='r',
|
|
425
|
+
**filter_params,
|
|
426
|
+
)
|
|
266
427
|
|
|
267
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
428
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
268
429
|
|
|
269
430
|
return edges
|
|
270
431
|
|
|
271
432
|
|
|
272
433
|
async def edge_bfs_search(
|
|
273
|
-
driver:
|
|
434
|
+
driver: GraphDriver,
|
|
274
435
|
bfs_origin_node_uuids: list[str] | None,
|
|
275
436
|
bfs_max_depth: int,
|
|
276
437
|
search_filter: SearchFilters,
|
|
277
|
-
|
|
438
|
+
group_ids: list[str] | None = None,
|
|
439
|
+
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
278
440
|
) -> list[EntityEdge]:
|
|
279
441
|
# vector similarity search over embedded facts
|
|
280
|
-
if bfs_origin_node_uuids is None:
|
|
442
|
+
if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0:
|
|
281
443
|
return []
|
|
282
444
|
|
|
283
|
-
|
|
445
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
446
|
+
search_filter, driver.provider
|
|
447
|
+
)
|
|
284
448
|
|
|
285
|
-
|
|
286
|
-
|
|
449
|
+
if group_ids is not None:
|
|
450
|
+
filter_queries.append('e.group_id IN $group_ids')
|
|
451
|
+
filter_params['group_ids'] = group_ids
|
|
452
|
+
|
|
453
|
+
filter_query = ''
|
|
454
|
+
if filter_queries:
|
|
455
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
456
|
+
|
|
457
|
+
if driver.provider == GraphProvider.KUZU:
|
|
458
|
+
# Kuzu stores entity edges twice with an intermediate node, so we need to match them
|
|
459
|
+
# separately for the correct BFS depth.
|
|
460
|
+
depth = bfs_max_depth * 2 - 1
|
|
461
|
+
match_queries = [
|
|
462
|
+
f"""
|
|
463
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
464
|
+
MATCH path = (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
|
|
465
|
+
UNWIND nodes(path) AS relNode
|
|
466
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
|
|
467
|
+
""",
|
|
468
|
+
]
|
|
469
|
+
if bfs_max_depth > 1:
|
|
470
|
+
depth = (bfs_max_depth - 1) * 2 - 1
|
|
471
|
+
match_queries.append(f"""
|
|
287
472
|
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
288
|
-
MATCH path = (origin:
|
|
473
|
+
MATCH path = (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
|
|
474
|
+
UNWIND nodes(path) AS relNode
|
|
475
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
|
|
476
|
+
""")
|
|
477
|
+
|
|
478
|
+
records = []
|
|
479
|
+
for match_query in match_queries:
|
|
480
|
+
sub_records, _, _ = await driver.execute_query(
|
|
481
|
+
match_query
|
|
482
|
+
+ filter_query
|
|
483
|
+
+ """
|
|
484
|
+
RETURN DISTINCT
|
|
485
|
+
"""
|
|
486
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
487
|
+
+ """
|
|
488
|
+
LIMIT $limit
|
|
489
|
+
""",
|
|
490
|
+
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
491
|
+
limit=limit,
|
|
492
|
+
routing_='r',
|
|
493
|
+
**filter_params,
|
|
494
|
+
)
|
|
495
|
+
records.extend(sub_records)
|
|
496
|
+
else:
|
|
497
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
498
|
+
query = (
|
|
499
|
+
f"""
|
|
500
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
501
|
+
MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
|
|
502
|
+
WHERE origin:Entity OR origin:Episodic
|
|
289
503
|
UNWIND relationships(path) AS rel
|
|
290
|
-
MATCH ()-[
|
|
291
|
-
WHERE r.uuid = rel.uuid
|
|
504
|
+
MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
|
|
292
505
|
"""
|
|
293
|
-
|
|
294
|
-
|
|
506
|
+
+ filter_query
|
|
507
|
+
+ """
|
|
295
508
|
RETURN DISTINCT
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
startNode(
|
|
299
|
-
endNode(
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
properties(
|
|
509
|
+
e.uuid AS uuid,
|
|
510
|
+
e.group_id AS group_id,
|
|
511
|
+
startNode(e).uuid AS source_node_uuid,
|
|
512
|
+
endNode(e).uuid AS target_node_uuid,
|
|
513
|
+
e.created_at AS created_at,
|
|
514
|
+
e.name AS name,
|
|
515
|
+
e.fact AS fact,
|
|
516
|
+
split(e.episodes, ',') AS episodes,
|
|
517
|
+
e.expired_at AS expired_at,
|
|
518
|
+
e.valid_at AS valid_at,
|
|
519
|
+
e.invalid_at AS invalid_at,
|
|
520
|
+
properties(e) AS attributes
|
|
308
521
|
LIMIT $limit
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
522
|
+
"""
|
|
523
|
+
)
|
|
524
|
+
else:
|
|
525
|
+
query = (
|
|
526
|
+
f"""
|
|
527
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
528
|
+
MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
|
|
529
|
+
UNWIND relationships(path) AS rel
|
|
530
|
+
MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
|
|
531
|
+
"""
|
|
532
|
+
+ filter_query
|
|
533
|
+
+ """
|
|
534
|
+
RETURN DISTINCT
|
|
535
|
+
"""
|
|
536
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
537
|
+
+ """
|
|
538
|
+
LIMIT $limit
|
|
539
|
+
"""
|
|
540
|
+
)
|
|
541
|
+
|
|
542
|
+
records, _, _ = await driver.execute_query(
|
|
543
|
+
query,
|
|
544
|
+
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
545
|
+
depth=bfs_max_depth,
|
|
546
|
+
limit=limit,
|
|
547
|
+
routing_='r',
|
|
548
|
+
**filter_params,
|
|
549
|
+
)
|
|
321
550
|
|
|
322
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
551
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
323
552
|
|
|
324
553
|
return edges
|
|
325
554
|
|
|
326
555
|
|
|
327
556
|
async def node_fulltext_search(
|
|
328
|
-
driver:
|
|
557
|
+
driver: GraphDriver,
|
|
329
558
|
query: str,
|
|
330
559
|
search_filter: SearchFilters,
|
|
331
560
|
group_ids: list[str] | None = None,
|
|
332
561
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
333
562
|
) -> list[EntityNode]:
|
|
563
|
+
if driver.search_interface:
|
|
564
|
+
return await driver.search_interface.node_fulltext_search(
|
|
565
|
+
driver, query, search_filter, group_ids, limit
|
|
566
|
+
)
|
|
567
|
+
|
|
334
568
|
# BM25 search to get top nodes
|
|
335
|
-
fuzzy_query = fulltext_query(query, group_ids)
|
|
569
|
+
fuzzy_query = fulltext_query(query, group_ids, driver)
|
|
336
570
|
if fuzzy_query == '':
|
|
337
571
|
return []
|
|
338
572
|
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
query = (
|
|
342
|
-
"""
|
|
343
|
-
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
|
344
|
-
YIELD node AS n, score
|
|
345
|
-
WHERE n:Entity
|
|
346
|
-
"""
|
|
347
|
-
+ filter_query
|
|
348
|
-
+ ENTITY_NODE_RETURN
|
|
349
|
-
+ """
|
|
350
|
-
ORDER BY score DESC
|
|
351
|
-
"""
|
|
573
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
574
|
+
search_filter, driver.provider
|
|
352
575
|
)
|
|
353
576
|
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
filter_params
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
577
|
+
if group_ids is not None:
|
|
578
|
+
filter_queries.append('n.group_id IN $group_ids')
|
|
579
|
+
filter_params['group_ids'] = group_ids
|
|
580
|
+
|
|
581
|
+
filter_query = ''
|
|
582
|
+
if filter_queries:
|
|
583
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
584
|
+
|
|
585
|
+
yield_query = 'YIELD node AS n, score'
|
|
586
|
+
if driver.provider == GraphProvider.KUZU:
|
|
587
|
+
yield_query = 'WITH node AS n, score'
|
|
588
|
+
|
|
589
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
590
|
+
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
591
|
+
if res['hits']['total']['value'] > 0:
|
|
592
|
+
input_ids = []
|
|
593
|
+
for r in res['hits']['hits']:
|
|
594
|
+
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
595
|
+
|
|
596
|
+
# Match the edge ides and return the values
|
|
597
|
+
query = (
|
|
598
|
+
"""
|
|
599
|
+
UNWIND $ids as i
|
|
600
|
+
MATCH (n:Entity)
|
|
601
|
+
WHERE n.uuid=i.id
|
|
602
|
+
RETURN
|
|
603
|
+
"""
|
|
604
|
+
+ get_entity_node_return_query(driver.provider)
|
|
605
|
+
+ """
|
|
606
|
+
ORDER BY i.score DESC
|
|
607
|
+
LIMIT $limit
|
|
608
|
+
"""
|
|
609
|
+
)
|
|
610
|
+
records, _, _ = await driver.execute_query(
|
|
611
|
+
query,
|
|
612
|
+
ids=input_ids,
|
|
613
|
+
query=fuzzy_query,
|
|
614
|
+
limit=limit,
|
|
615
|
+
routing_='r',
|
|
616
|
+
**filter_params,
|
|
617
|
+
)
|
|
618
|
+
else:
|
|
619
|
+
return []
|
|
620
|
+
else:
|
|
621
|
+
query = (
|
|
622
|
+
get_nodes_query(
|
|
623
|
+
'node_name_and_summary', '$query', limit=limit, provider=driver.provider
|
|
624
|
+
)
|
|
625
|
+
+ yield_query
|
|
626
|
+
+ filter_query
|
|
627
|
+
+ """
|
|
628
|
+
WITH n, score
|
|
629
|
+
ORDER BY score DESC
|
|
630
|
+
LIMIT $limit
|
|
631
|
+
RETURN
|
|
632
|
+
"""
|
|
633
|
+
+ get_entity_node_return_query(driver.provider)
|
|
634
|
+
)
|
|
635
|
+
|
|
636
|
+
records, _, _ = await driver.execute_query(
|
|
637
|
+
query,
|
|
638
|
+
query=fuzzy_query,
|
|
639
|
+
limit=limit,
|
|
640
|
+
routing_='r',
|
|
641
|
+
**filter_params,
|
|
642
|
+
)
|
|
643
|
+
|
|
644
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
364
645
|
|
|
365
646
|
return nodes
|
|
366
647
|
|
|
367
648
|
|
|
368
649
|
async def node_similarity_search(
|
|
369
|
-
driver:
|
|
650
|
+
driver: GraphDriver,
|
|
370
651
|
search_vector: list[float],
|
|
371
652
|
search_filter: SearchFilters,
|
|
372
653
|
group_ids: list[str] | None = None,
|
|
373
654
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
374
655
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
375
656
|
) -> list[EntityNode]:
|
|
376
|
-
|
|
377
|
-
|
|
657
|
+
if driver.search_interface:
|
|
658
|
+
return await driver.search_interface.node_similarity_search(
|
|
659
|
+
driver, search_vector, search_filter, group_ids, limit, min_score
|
|
660
|
+
)
|
|
661
|
+
|
|
662
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
663
|
+
search_filter, driver.provider
|
|
664
|
+
)
|
|
378
665
|
|
|
379
|
-
group_filter_query: LiteralString = ''
|
|
380
666
|
if group_ids is not None:
|
|
381
|
-
|
|
382
|
-
|
|
667
|
+
filter_queries.append('n.group_id IN $group_ids')
|
|
668
|
+
filter_params['group_ids'] = group_ids
|
|
383
669
|
|
|
384
|
-
filter_query
|
|
385
|
-
|
|
670
|
+
filter_query = ''
|
|
671
|
+
if filter_queries:
|
|
672
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
386
673
|
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
674
|
+
search_vector_var = '$search_vector'
|
|
675
|
+
if driver.provider == GraphProvider.KUZU:
|
|
676
|
+
search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
|
|
677
|
+
|
|
678
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
679
|
+
query = (
|
|
391
680
|
"""
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
681
|
+
MATCH (n:Entity)
|
|
682
|
+
"""
|
|
683
|
+
+ filter_query
|
|
684
|
+
+ """
|
|
685
|
+
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
686
|
+
"""
|
|
687
|
+
)
|
|
688
|
+
resp, header, _ = await driver.execute_query(
|
|
689
|
+
query,
|
|
690
|
+
params=filter_params,
|
|
691
|
+
search_vector=search_vector,
|
|
692
|
+
limit=limit,
|
|
693
|
+
min_score=min_score,
|
|
694
|
+
routing_='r',
|
|
695
|
+
)
|
|
696
|
+
|
|
697
|
+
if len(resp) > 0:
|
|
698
|
+
# Calculate Cosine similarity then return the edge ids
|
|
699
|
+
input_ids = []
|
|
700
|
+
for r in resp:
|
|
701
|
+
if r['embedding']:
|
|
702
|
+
score = calculate_cosine_similarity(
|
|
703
|
+
search_vector, list(map(float, r['embedding'].split(',')))
|
|
704
|
+
)
|
|
705
|
+
if score > min_score:
|
|
706
|
+
input_ids.append({'id': r['id'], 'score': score})
|
|
707
|
+
|
|
708
|
+
# Match the edge ides and return the values
|
|
709
|
+
query = (
|
|
710
|
+
"""
|
|
711
|
+
UNWIND $ids as i
|
|
712
|
+
MATCH (n:Entity)
|
|
713
|
+
WHERE id(n)=i.id
|
|
714
|
+
RETURN
|
|
715
|
+
"""
|
|
716
|
+
+ get_entity_node_return_query(driver.provider)
|
|
717
|
+
+ """
|
|
718
|
+
ORDER BY i.score DESC
|
|
719
|
+
LIMIT $limit
|
|
720
|
+
"""
|
|
721
|
+
)
|
|
722
|
+
records, header, _ = await driver.execute_query(
|
|
723
|
+
query,
|
|
724
|
+
ids=input_ids,
|
|
725
|
+
search_vector=search_vector,
|
|
726
|
+
limit=limit,
|
|
727
|
+
min_score=min_score,
|
|
728
|
+
routing_='r',
|
|
729
|
+
**filter_params,
|
|
730
|
+
)
|
|
731
|
+
else:
|
|
732
|
+
return []
|
|
733
|
+
else:
|
|
734
|
+
query = (
|
|
735
|
+
"""
|
|
736
|
+
MATCH (n:Entity)
|
|
737
|
+
"""
|
|
738
|
+
+ filter_query
|
|
739
|
+
+ """
|
|
740
|
+
WITH n, """
|
|
741
|
+
+ get_vector_cosine_func_query('n.name_embedding', search_vector_var, driver.provider)
|
|
742
|
+
+ """ AS score
|
|
743
|
+
WHERE score > $min_score
|
|
744
|
+
RETURN
|
|
745
|
+
"""
|
|
746
|
+
+ get_entity_node_return_query(driver.provider)
|
|
747
|
+
+ """
|
|
748
|
+
ORDER BY score DESC
|
|
749
|
+
LIMIT $limit
|
|
750
|
+
"""
|
|
751
|
+
)
|
|
752
|
+
|
|
753
|
+
records, _, _ = await driver.execute_query(
|
|
754
|
+
query,
|
|
755
|
+
search_vector=search_vector,
|
|
756
|
+
limit=limit,
|
|
757
|
+
min_score=min_score,
|
|
758
|
+
routing_='r',
|
|
759
|
+
**filter_params,
|
|
760
|
+
)
|
|
761
|
+
|
|
762
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
411
763
|
|
|
412
764
|
return nodes
|
|
413
765
|
|
|
414
766
|
|
|
415
767
|
async def node_bfs_search(
    driver: GraphDriver,
    bfs_origin_node_uuids: list[str] | None,
    search_filter: SearchFilters,
    bfs_max_depth: int,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """Find entity nodes reachable from a set of origin nodes by graph traversal.

    Parameters
    ----------
    driver : GraphDriver
        Graph driver used to run the traversal; ``driver.provider`` selects the
        provider-specific match pattern.
    bfs_origin_node_uuids : list[str] | None
        UUIDs of the nodes the traversal starts from. ``None`` or an empty list
        short-circuits to an empty result.
    search_filter : SearchFilters
        Extra node filters, converted to Cypher predicates by
        ``node_search_filter_query_constructor``.
    bfs_max_depth : int
        Maximum number of hops to traverse. Values below 1 return an empty result.
    group_ids : list[str] | None, optional
        When given, both the origin and the reached node must belong to one of
        these groups.
    limit : int, optional
        Row limit applied to *each* executed query.

    Returns
    -------
    list[EntityNode]
        Nodes from every executed query, concatenated in execution order. When
        several patterns run (Kuzu), the list may exceed ``limit`` and may contain
        duplicates.
    """
    if not bfs_origin_node_uuids or bfs_max_depth < 1:
        return []

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('n.group_id IN $group_ids')
        filter_queries.append('origin.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    # Every match pattern below already ends in a WHERE clause, so the extra
    # predicates are chained on with AND.
    filter_query = ''
    if filter_queries:
        filter_query = ' AND ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.NEPTUNE:
        # The label test is parenthesised: AND binds tighter than OR, so without
        # the parentheses the group constraint (and every appended filter) would
        # only apply to the second alternative. The episode label is `Episodic`
        # and must be tested with `:`; `origin.Episode` would read a property.
        match_queries = [
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
            WHERE (origin:Entity OR origin:Episodic)
            AND n.group_id = origin.group_id
            """
        ]
    elif driver.provider == GraphProvider.KUZU:
        # Hop counts are doubled for RELATES_TO on Kuzu - presumably because each
        # logical edge is stored as two hops through an intermediate node
        # (TODO confirm against the Kuzu schema).
        entity_depth = bfs_max_depth * 2
        match_queries = [
            """
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin:Episodic {uuid: origin_uuid})-[:MENTIONS]->(n:Entity)
            WHERE n.group_id = origin.group_id
            """,
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*2..{entity_depth}]->(n:Entity)
            WHERE n.group_id = origin.group_id
            """,
        ]
        if bfs_max_depth > 1:
            # One hop is consumed by MENTIONS, the rest by RELATES_TO.
            episodic_depth = (bfs_max_depth - 1) * 2
            match_queries.append(f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*2..{episodic_depth}]->(n:Entity)
                WHERE n.group_id = origin.group_id
            """)
    else:
        match_queries = [
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
            WHERE n.group_id = origin.group_id
            """
        ]

    return_clause = (
        """
        RETURN
        """
        + get_entity_node_return_query(driver.provider)
        + """
        LIMIT $limit
        """
    )

    records = []
    for match_query in match_queries:
        sub_records, _, _ = await driver.execute_query(
            match_query + filter_query + return_clause,
            bfs_origin_node_uuids=bfs_origin_node_uuids,
            limit=limit,
            routing_='r',
            **filter_params,
        )
        records.extend(sub_records)

    return [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
449
853
|
|
|
450
854
|
|
|
451
855
|
async def episode_fulltext_search(
    driver: GraphDriver,
    query: str,
    _search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
    """Run a full-text search over episode content.

    Parameters
    ----------
    driver : GraphDriver
        Graph driver to query. If it exposes a ``search_interface``, the whole
        search is delegated to it.
    query : str
        Free-text search string.
    _search_filter : SearchFilters
        Only forwarded to ``driver.search_interface``; unused by the built-in
        code paths.
    group_ids : list[str] | None, optional
        When given, only episodes whose ``group_id`` is in this list are returned.
    limit : int, optional
        Maximum number of episodes to return.

    Returns
    -------
    list[EpisodicNode]
        Matching episodes ordered by descending score, or an empty list when the
        query is empty or nothing matches.
    """
    if driver.search_interface:
        return await driver.search_interface.episode_fulltext_search(
            driver, query, _search_filter, group_ids, limit
        )

    # BM25 search to get top episodes
    fuzzy_query = fulltext_query(query, group_ids, driver)
    if fuzzy_query == '':
        return []

    filter_params: dict[str, Any] = {}
    group_filter_query: LiteralString = ''
    if group_ids is not None:
        group_filter_query += '\nAND e.group_id IN $group_ids'
        filter_params['group_ids'] = group_ids

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune: full-text scoring is done by the external search index, then
        # the graph is queried for the matching episodes.
        res = driver.run_aoss_query('episode_content', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] <= 0:
            return []

        scored_ids = [
            {'id': hit['_source']['uuid'], 'score': hit['_score']} for hit in res['hits']['hits']
        ]

        # Match the episode uuids and return the values. Each entry of $ids
        # carries the uuid under the key `id`, so the join must use `i.id`
        # (`i.uuid` does not exist and would never match). The group filter is
        # applied here because the index lookup above is not group-scoped.
        cypher_query = (
            """
            UNWIND $ids as i
            MATCH (e:Episodic)
            WHERE e.uuid = i.id
            """
            + group_filter_query
            + """
            RETURN
                e.content AS content,
                e.created_at AS created_at,
                e.valid_at AS valid_at,
                e.uuid AS uuid,
                e.name AS name,
                e.group_id AS group_id,
                e.source_description AS source_description,
                e.source AS source,
                e.entity_edges AS entity_edges
            ORDER BY i.score DESC
            LIMIT $limit
            """
        )
        records, _, _ = await driver.execute_query(
            cypher_query,
            ids=scored_ids,
            query=fuzzy_query,
            limit=limit,
            routing_='r',
            **filter_params,
        )
    else:
        cypher_query = (
            get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
            + """
            YIELD node AS episode, score
            MATCH (e:Episodic)
            WHERE e.uuid = episode.uuid
            """
            + group_filter_query
            + """
            RETURN
            """
            + EPISODIC_NODE_RETURN
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            cypher_query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
        )

    return [get_episodic_node_from_record(record) for record in records]
|
|
491
939
|
|
|
492
940
|
|
|
493
941
|
async def community_fulltext_search(
    driver: GraphDriver,
    query: str,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
    """Run a full-text search over community names.

    Parameters
    ----------
    driver : GraphDriver
        Graph driver to query; ``driver.provider`` selects the query dialect.
    query : str
        Free-text search string.
    group_ids : list[str] | None, optional
        When given, only communities whose ``group_id`` is in this list are
        returned.
    limit : int, optional
        Maximum number of communities to return.

    Returns
    -------
    list[CommunityNode]
        Matching communities ordered by descending score, or an empty list when
        the query is empty or nothing matches.
    """
    # BM25 search to get top communities
    fuzzy_query = fulltext_query(query, group_ids, driver)
    if fuzzy_query == '':
        return []

    filter_params: dict[str, Any] = {}
    group_filter_query: LiteralString = ''
    # The Neptune query binds the community as `comm`, not `c`, and already has
    # a WHERE clause, so it needs its own form of the group predicate.
    neptune_group_filter: LiteralString = ''
    if group_ids is not None:
        group_filter_query = 'WHERE c.group_id IN $group_ids'
        neptune_group_filter = 'AND comm.group_id IN $group_ids'
        filter_params['group_ids'] = group_ids

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune: full-text scoring is done by the external search index, then
        # the graph is queried for the matching communities.
        res = driver.run_aoss_query('community_name', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] <= 0:
            return []

        # Pair each hit's community uuid with its search score.
        scored_ids = [
            {'id': hit['_source']['uuid'], 'score': hit['_score']} for hit in res['hits']['hits']
        ]

        # Match the community uuids and return the values. The group filter is
        # applied here because the index lookup above is not group-scoped.
        cypher_query = (
            """
            UNWIND $ids as i
            MATCH (comm:Community)
            WHERE comm.uuid=i.id
            """
            + neptune_group_filter
            + """
            RETURN
                comm.uuid AS uuid,
                comm.group_id AS group_id,
                comm.name AS name,
                comm.created_at AS created_at,
                comm.summary AS summary,
                [x IN split(comm.name_embedding, ",") | toFloat(x)]AS name_embedding
            ORDER BY i.score DESC
            LIMIT $limit
            """
        )
        records, _, _ = await driver.execute_query(
            cypher_query,
            ids=scored_ids,
            query=fuzzy_query,
            limit=limit,
            routing_='r',
            **filter_params,
        )
    else:
        # Kuzu takes the full-text result through WITH rather than YIELD.
        yield_query = 'YIELD node AS c, score'
        if driver.provider == GraphProvider.KUZU:
            yield_query = 'WITH node AS c, score'

        cypher_query = (
            get_nodes_query('community_name', '$query', limit=limit, provider=driver.provider)
            + yield_query
            + """
            WITH c, score
            """
            + group_filter_query
            + """
            RETURN
            """
            + COMMUNITY_NODE_RETURN
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            cypher_query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
        )

    return [get_community_node_from_record(record) for record in records]
|
|
526
1020
|
|
|
527
1021
|
|
|
528
1022
|
async def community_similarity_search(
|
|
529
|
-
driver:
|
|
1023
|
+
driver: GraphDriver,
|
|
530
1024
|
search_vector: list[float],
|
|
531
1025
|
group_ids: list[str] | None = None,
|
|
532
1026
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
@@ -537,34 +1031,99 @@ async def community_similarity_search(
|
|
|
537
1031
|
|
|
538
1032
|
group_filter_query: LiteralString = ''
|
|
539
1033
|
if group_ids is not None:
|
|
540
|
-
group_filter_query += 'WHERE
|
|
1034
|
+
group_filter_query += ' WHERE c.group_id IN $group_ids'
|
|
541
1035
|
query_params['group_ids'] = group_ids
|
|
542
1036
|
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
1037
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1038
|
+
query = (
|
|
1039
|
+
"""
|
|
1040
|
+
MATCH (n:Community)
|
|
1041
|
+
"""
|
|
1042
|
+
+ group_filter_query
|
|
1043
|
+
+ """
|
|
1044
|
+
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
1045
|
+
"""
|
|
1046
|
+
)
|
|
1047
|
+
resp, header, _ = await driver.execute_query(
|
|
1048
|
+
query,
|
|
1049
|
+
search_vector=search_vector,
|
|
1050
|
+
limit=limit,
|
|
1051
|
+
min_score=min_score,
|
|
1052
|
+
routing_='r',
|
|
1053
|
+
**query_params,
|
|
1054
|
+
)
|
|
1055
|
+
|
|
1056
|
+
if len(resp) > 0:
|
|
1057
|
+
# Calculate cosine similarity, then collect the ids of communities scoring above min_score
|
|
1058
|
+
input_ids = []
|
|
1059
|
+
for r in resp:
|
|
1060
|
+
if r['embedding']:
|
|
1061
|
+
score = calculate_cosine_similarity(
|
|
1062
|
+
search_vector, list(map(float, r['embedding'].split(',')))
|
|
1063
|
+
)
|
|
1064
|
+
if score > min_score:
|
|
1065
|
+
input_ids.append({'id': r['id'], 'score': score})
|
|
1066
|
+
|
|
1067
|
+
# Match the community ids and return the values
|
|
1068
|
+
query = """
|
|
1069
|
+
UNWIND $ids as i
|
|
1070
|
+
MATCH (comm:Community)
|
|
1071
|
+
WHERE id(comm)=i.id
|
|
1072
|
+
RETURN
|
|
1073
|
+
comm.uuid As uuid,
|
|
1074
|
+
comm.group_id AS group_id,
|
|
1075
|
+
comm.name AS name,
|
|
1076
|
+
comm.created_at AS created_at,
|
|
1077
|
+
comm.summary AS summary,
|
|
1078
|
+
comm.name_embedding AS name_embedding
|
|
1079
|
+
ORDER BY i.score DESC
|
|
1080
|
+
LIMIT $limit
|
|
1081
|
+
"""
|
|
1082
|
+
records, header, _ = await driver.execute_query(
|
|
1083
|
+
query,
|
|
1084
|
+
ids=input_ids,
|
|
1085
|
+
search_vector=search_vector,
|
|
1086
|
+
limit=limit,
|
|
1087
|
+
min_score=min_score,
|
|
1088
|
+
routing_='r',
|
|
1089
|
+
**query_params,
|
|
1090
|
+
)
|
|
1091
|
+
else:
|
|
1092
|
+
return []
|
|
1093
|
+
else:
|
|
1094
|
+
search_vector_var = '$search_vector'
|
|
1095
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1096
|
+
search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
|
|
1097
|
+
|
|
1098
|
+
query = (
|
|
1099
|
+
"""
|
|
1100
|
+
MATCH (c:Community)
|
|
1101
|
+
"""
|
|
1102
|
+
+ group_filter_query
|
|
1103
|
+
+ """
|
|
1104
|
+
WITH c,
|
|
1105
|
+
"""
|
|
1106
|
+
+ get_vector_cosine_func_query('c.name_embedding', search_vector_var, driver.provider)
|
|
1107
|
+
+ """ AS score
|
|
1108
|
+
WHERE score > $min_score
|
|
1109
|
+
RETURN
|
|
1110
|
+
"""
|
|
1111
|
+
+ COMMUNITY_NODE_RETURN
|
|
1112
|
+
+ """
|
|
1113
|
+
ORDER BY score DESC
|
|
1114
|
+
LIMIT $limit
|
|
1115
|
+
"""
|
|
1116
|
+
)
|
|
1117
|
+
|
|
1118
|
+
records, _, _ = await driver.execute_query(
|
|
1119
|
+
query,
|
|
1120
|
+
search_vector=search_vector,
|
|
1121
|
+
limit=limit,
|
|
1122
|
+
min_score=min_score,
|
|
1123
|
+
routing_='r',
|
|
1124
|
+
**query_params,
|
|
1125
|
+
)
|
|
1126
|
+
|
|
568
1127
|
communities = [get_community_node_from_record(record) for record in records]
|
|
569
1128
|
|
|
570
1129
|
return communities
|
|
@@ -573,7 +1132,7 @@ async def community_similarity_search(
|
|
|
573
1132
|
async def hybrid_node_search(
|
|
574
1133
|
queries: list[str],
|
|
575
1134
|
embeddings: list[list[float]],
|
|
576
|
-
driver:
|
|
1135
|
+
driver: GraphDriver,
|
|
577
1136
|
search_filter: SearchFilters,
|
|
578
1137
|
group_ids: list[str] | None = None,
|
|
579
1138
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
@@ -590,7 +1149,7 @@ async def hybrid_node_search(
|
|
|
590
1149
|
A list of text queries to search for.
|
|
591
1150
|
embeddings : list[list[float]]
|
|
592
1151
|
A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
|
|
593
|
-
driver :
|
|
1152
|
+
driver : GraphDriver
|
|
594
1153
|
The graph driver instance for database operations.
|
|
595
1154
|
group_ids : list[str] | None, optional
|
|
596
1155
|
The list of group ids to retrieve nodes from.
|
|
@@ -635,7 +1194,7 @@ async def hybrid_node_search(
|
|
|
635
1194
|
}
|
|
636
1195
|
result_uuids = [[node.uuid for node in result] for result in results]
|
|
637
1196
|
|
|
638
|
-
ranked_uuids = rrf(result_uuids)
|
|
1197
|
+
ranked_uuids, _ = rrf(result_uuids)
|
|
639
1198
|
|
|
640
1199
|
relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
|
|
641
1200
|
|
|
@@ -645,7 +1204,7 @@ async def hybrid_node_search(
|
|
|
645
1204
|
|
|
646
1205
|
|
|
647
1206
|
async def get_relevant_nodes(
|
|
648
|
-
driver:
|
|
1207
|
+
driver: GraphDriver,
|
|
649
1208
|
nodes: list[EntityNode],
|
|
650
1209
|
search_filter: SearchFilters,
|
|
651
1210
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
@@ -655,77 +1214,140 @@ async def get_relevant_nodes(
|
|
|
655
1214
|
return []
|
|
656
1215
|
|
|
657
1216
|
group_id = nodes[0].group_id
|
|
658
|
-
|
|
659
|
-
# vector similarity search over entity names
|
|
660
|
-
query_params: dict[str, Any] = {}
|
|
661
|
-
|
|
662
|
-
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
663
|
-
query_params.update(filter_params)
|
|
664
|
-
|
|
665
|
-
query = (
|
|
666
|
-
RUNTIME_QUERY
|
|
667
|
-
+ """UNWIND $nodes AS node
|
|
668
|
-
MATCH (n:Entity {group_id: $group_id})
|
|
669
|
-
"""
|
|
670
|
-
+ filter_query
|
|
671
|
-
+ """
|
|
672
|
-
WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
|
|
673
|
-
WHERE score > $min_score
|
|
674
|
-
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
675
|
-
|
|
676
|
-
CALL db.index.fulltext.queryNodes("node_name_and_summary", node.fulltext_query, {limit: $limit})
|
|
677
|
-
YIELD node AS m
|
|
678
|
-
WHERE m.group_id = $group_id
|
|
679
|
-
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
|
680
|
-
|
|
681
|
-
WITH node,
|
|
682
|
-
top_vector_nodes,
|
|
683
|
-
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
|
684
|
-
|
|
685
|
-
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
|
686
|
-
|
|
687
|
-
UNWIND combined_nodes AS combined_node
|
|
688
|
-
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
|
689
|
-
|
|
690
|
-
RETURN
|
|
691
|
-
node.uuid AS search_node_uuid,
|
|
692
|
-
[x IN deduped_nodes | {
|
|
693
|
-
uuid: x.uuid,
|
|
694
|
-
name: x.name,
|
|
695
|
-
name_embedding: x.name_embedding,
|
|
696
|
-
group_id: x.group_id,
|
|
697
|
-
created_at: x.created_at,
|
|
698
|
-
summary: x.summary,
|
|
699
|
-
labels: labels(x),
|
|
700
|
-
attributes: properties(x)
|
|
701
|
-
}] AS matches
|
|
702
|
-
"""
|
|
703
|
-
)
|
|
704
|
-
|
|
705
1217
|
query_nodes = [
|
|
706
1218
|
{
|
|
707
1219
|
'uuid': node.uuid,
|
|
708
1220
|
'name': node.name,
|
|
709
1221
|
'name_embedding': node.name_embedding,
|
|
710
|
-
'fulltext_query': fulltext_query(node.name, [node.group_id]),
|
|
1222
|
+
'fulltext_query': fulltext_query(node.name, [node.group_id], driver),
|
|
711
1223
|
}
|
|
712
1224
|
for node in nodes
|
|
713
1225
|
]
|
|
714
1226
|
|
|
1227
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
1228
|
+
search_filter, driver.provider
|
|
1229
|
+
)
|
|
1230
|
+
|
|
1231
|
+
filter_query = ''
|
|
1232
|
+
if filter_queries:
|
|
1233
|
+
filter_query = 'WHERE ' + (' AND '.join(filter_queries))
|
|
1234
|
+
|
|
1235
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1236
|
+
embedding_size = len(nodes[0].name_embedding) if nodes[0].name_embedding is not None else 0
|
|
1237
|
+
if embedding_size == 0:
|
|
1238
|
+
return []
|
|
1239
|
+
|
|
1240
|
+
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
|
1241
|
+
query = (
|
|
1242
|
+
"""
|
|
1243
|
+
UNWIND $nodes AS node
|
|
1244
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
1245
|
+
"""
|
|
1246
|
+
+ filter_query
|
|
1247
|
+
+ """
|
|
1248
|
+
WITH node, n, """
|
|
1249
|
+
+ get_vector_cosine_func_query(
|
|
1250
|
+
'n.name_embedding',
|
|
1251
|
+
f'CAST(node.name_embedding AS FLOAT[{embedding_size}])',
|
|
1252
|
+
driver.provider,
|
|
1253
|
+
)
|
|
1254
|
+
+ """ AS score
|
|
1255
|
+
WHERE score > $min_score
|
|
1256
|
+
WITH node, collect(n)[:$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
1257
|
+
"""
|
|
1258
|
+
+ get_nodes_query(
|
|
1259
|
+
'node_name_and_summary',
|
|
1260
|
+
'node.fulltext_query',
|
|
1261
|
+
limit=limit,
|
|
1262
|
+
provider=driver.provider,
|
|
1263
|
+
)
|
|
1264
|
+
+ """
|
|
1265
|
+
WITH node AS m
|
|
1266
|
+
WHERE m.group_id = $group_id AND NOT m.uuid IN vector_node_uuids
|
|
1267
|
+
WITH node, top_vector_nodes, collect(m) AS fulltext_nodes
|
|
1268
|
+
|
|
1269
|
+
WITH node, list_concat(top_vector_nodes, fulltext_nodes) AS combined_nodes
|
|
1270
|
+
|
|
1271
|
+
UNWIND combined_nodes AS x
|
|
1272
|
+
WITH node, collect(DISTINCT {
|
|
1273
|
+
uuid: x.uuid,
|
|
1274
|
+
name: x.name,
|
|
1275
|
+
name_embedding: x.name_embedding,
|
|
1276
|
+
group_id: x.group_id,
|
|
1277
|
+
created_at: x.created_at,
|
|
1278
|
+
summary: x.summary,
|
|
1279
|
+
labels: x.labels,
|
|
1280
|
+
attributes: x.attributes
|
|
1281
|
+
}) AS matches
|
|
1282
|
+
|
|
1283
|
+
RETURN
|
|
1284
|
+
node.uuid AS search_node_uuid, matches
|
|
1285
|
+
"""
|
|
1286
|
+
)
|
|
1287
|
+
else:
|
|
1288
|
+
query = (
|
|
1289
|
+
"""
|
|
1290
|
+
UNWIND $nodes AS node
|
|
1291
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
1292
|
+
"""
|
|
1293
|
+
+ filter_query
|
|
1294
|
+
+ """
|
|
1295
|
+
WITH node, n, """
|
|
1296
|
+
+ get_vector_cosine_func_query(
|
|
1297
|
+
'n.name_embedding', 'node.name_embedding', driver.provider
|
|
1298
|
+
)
|
|
1299
|
+
+ """ AS score
|
|
1300
|
+
WHERE score > $min_score
|
|
1301
|
+
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
1302
|
+
"""
|
|
1303
|
+
+ get_nodes_query(
|
|
1304
|
+
'node_name_and_summary',
|
|
1305
|
+
'node.fulltext_query',
|
|
1306
|
+
limit=limit,
|
|
1307
|
+
provider=driver.provider,
|
|
1308
|
+
)
|
|
1309
|
+
+ """
|
|
1310
|
+
YIELD node AS m
|
|
1311
|
+
WHERE m.group_id = $group_id
|
|
1312
|
+
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
|
1313
|
+
|
|
1314
|
+
WITH node,
|
|
1315
|
+
top_vector_nodes,
|
|
1316
|
+
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
|
1317
|
+
|
|
1318
|
+
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
|
1319
|
+
|
|
1320
|
+
UNWIND combined_nodes AS combined_node
|
|
1321
|
+
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
|
1322
|
+
|
|
1323
|
+
RETURN
|
|
1324
|
+
node.uuid AS search_node_uuid,
|
|
1325
|
+
[x IN deduped_nodes | {
|
|
1326
|
+
uuid: x.uuid,
|
|
1327
|
+
name: x.name,
|
|
1328
|
+
name_embedding: x.name_embedding,
|
|
1329
|
+
group_id: x.group_id,
|
|
1330
|
+
created_at: x.created_at,
|
|
1331
|
+
summary: x.summary,
|
|
1332
|
+
labels: labels(x),
|
|
1333
|
+
attributes: properties(x)
|
|
1334
|
+
}] AS matches
|
|
1335
|
+
"""
|
|
1336
|
+
)
|
|
1337
|
+
|
|
715
1338
|
results, _, _ = await driver.execute_query(
|
|
716
1339
|
query,
|
|
717
|
-
query_params,
|
|
718
1340
|
nodes=query_nodes,
|
|
719
1341
|
group_id=group_id,
|
|
720
1342
|
limit=limit,
|
|
721
1343
|
min_score=min_score,
|
|
722
|
-
database_=DEFAULT_DATABASE,
|
|
723
1344
|
routing_='r',
|
|
1345
|
+
**filter_params,
|
|
724
1346
|
)
|
|
725
1347
|
|
|
726
1348
|
relevant_nodes_dict: dict[str, list[EntityNode]] = {
|
|
727
1349
|
result['search_node_uuid']: [
|
|
728
|
-
get_entity_node_from_record(record) for record in result['matches']
|
|
1350
|
+
get_entity_node_from_record(record, driver.provider) for record in result['matches']
|
|
729
1351
|
]
|
|
730
1352
|
for result in results
|
|
731
1353
|
}
|
|
@@ -736,7 +1358,7 @@ async def get_relevant_nodes(
|
|
|
736
1358
|
|
|
737
1359
|
|
|
738
1360
|
async def get_relevant_edges(
|
|
739
|
-
driver:
|
|
1361
|
+
driver: GraphDriver,
|
|
740
1362
|
edges: list[EntityEdge],
|
|
741
1363
|
search_filter: SearchFilters,
|
|
742
1364
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
@@ -745,53 +1367,172 @@ async def get_relevant_edges(
|
|
|
745
1367
|
if len(edges) == 0:
|
|
746
1368
|
return []
|
|
747
1369
|
|
|
748
|
-
|
|
1370
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
1371
|
+
search_filter, driver.provider
|
|
1372
|
+
)
|
|
749
1373
|
|
|
750
|
-
filter_query
|
|
751
|
-
|
|
1374
|
+
filter_query = ''
|
|
1375
|
+
if filter_queries:
|
|
1376
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
752
1377
|
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
+ """UNWIND $edges AS edge
|
|
756
|
-
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1378
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1379
|
+
query = (
|
|
757
1380
|
"""
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
RETURN edge.uuid
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
1381
|
+
UNWIND $edges AS edge
|
|
1382
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1383
|
+
"""
|
|
1384
|
+
+ filter_query
|
|
1385
|
+
+ """
|
|
1386
|
+
WITH e, edge
|
|
1387
|
+
RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
|
|
1388
|
+
edge.fact_embedding as target_embedding
|
|
1389
|
+
"""
|
|
1390
|
+
)
|
|
1391
|
+
resp, _, _ = await driver.execute_query(
|
|
1392
|
+
query,
|
|
1393
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1394
|
+
limit=limit,
|
|
1395
|
+
min_score=min_score,
|
|
1396
|
+
routing_='r',
|
|
1397
|
+
**filter_params,
|
|
1398
|
+
)
|
|
1399
|
+
|
|
1400
|
+
# Calculate Cosine similarity then return the edge ids
|
|
1401
|
+
input_ids = []
|
|
1402
|
+
for r in resp:
|
|
1403
|
+
score = calculate_cosine_similarity(
|
|
1404
|
+
list(map(float, r['source_embedding'].split(','))), r['target_embedding']
|
|
1405
|
+
)
|
|
1406
|
+
if score > min_score:
|
|
1407
|
+
input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
|
|
1408
|
+
|
|
1409
|
+
# Match the edge ids and return the values
|
|
1410
|
+
query = """
|
|
1411
|
+
UNWIND $ids AS edge
|
|
1412
|
+
MATCH ()-[e]->()
|
|
1413
|
+
WHERE id(e) = edge.id
|
|
1414
|
+
WITH edge, e
|
|
1415
|
+
ORDER BY edge.score DESC
|
|
1416
|
+
RETURN edge.uuid AS search_edge_uuid,
|
|
1417
|
+
collect({
|
|
1418
|
+
uuid: e.uuid,
|
|
1419
|
+
source_node_uuid: startNode(e).uuid,
|
|
1420
|
+
target_node_uuid: endNode(e).uuid,
|
|
1421
|
+
created_at: e.created_at,
|
|
1422
|
+
name: e.name,
|
|
1423
|
+
group_id: e.group_id,
|
|
1424
|
+
fact: e.fact,
|
|
1425
|
+
fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
|
|
1426
|
+
episodes: split(e.episodes, ","),
|
|
1427
|
+
expired_at: e.expired_at,
|
|
1428
|
+
valid_at: e.valid_at,
|
|
1429
|
+
invalid_at: e.invalid_at,
|
|
1430
|
+
attributes: properties(e)
|
|
1431
|
+
})[..$limit] AS matches
|
|
1432
|
+
"""
|
|
1433
|
+
|
|
1434
|
+
results, _, _ = await driver.execute_query(
|
|
1435
|
+
query,
|
|
1436
|
+
ids=input_ids,
|
|
1437
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1438
|
+
limit=limit,
|
|
1439
|
+
min_score=min_score,
|
|
1440
|
+
routing_='r',
|
|
1441
|
+
**filter_params,
|
|
1442
|
+
)
|
|
1443
|
+
else:
|
|
1444
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1445
|
+
embedding_size = (
|
|
1446
|
+
len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
|
|
1447
|
+
)
|
|
1448
|
+
if embedding_size == 0:
|
|
1449
|
+
return []
|
|
1450
|
+
|
|
1451
|
+
query = (
|
|
1452
|
+
"""
|
|
1453
|
+
UNWIND $edges AS edge
|
|
1454
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1455
|
+
"""
|
|
1456
|
+
+ filter_query
|
|
1457
|
+
+ """
|
|
1458
|
+
WITH e, edge, n, m, """
|
|
1459
|
+
+ get_vector_cosine_func_query(
|
|
1460
|
+
'e.fact_embedding',
|
|
1461
|
+
f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
|
|
1462
|
+
driver.provider,
|
|
1463
|
+
)
|
|
1464
|
+
+ """ AS score
|
|
1465
|
+
WHERE score > $min_score
|
|
1466
|
+
WITH e, edge, n, m, score
|
|
1467
|
+
ORDER BY score DESC
|
|
1468
|
+
LIMIT $limit
|
|
1469
|
+
RETURN
|
|
1470
|
+
edge.uuid AS search_edge_uuid,
|
|
1471
|
+
collect({
|
|
1472
|
+
uuid: e.uuid,
|
|
1473
|
+
source_node_uuid: n.uuid,
|
|
1474
|
+
target_node_uuid: m.uuid,
|
|
1475
|
+
created_at: e.created_at,
|
|
1476
|
+
name: e.name,
|
|
1477
|
+
group_id: e.group_id,
|
|
1478
|
+
fact: e.fact,
|
|
1479
|
+
fact_embedding: e.fact_embedding,
|
|
1480
|
+
episodes: e.episodes,
|
|
1481
|
+
expired_at: e.expired_at,
|
|
1482
|
+
valid_at: e.valid_at,
|
|
1483
|
+
invalid_at: e.invalid_at,
|
|
1484
|
+
attributes: e.attributes
|
|
1485
|
+
}) AS matches
|
|
1486
|
+
"""
|
|
1487
|
+
)
|
|
1488
|
+
else:
|
|
1489
|
+
query = (
|
|
1490
|
+
"""
|
|
1491
|
+
UNWIND $edges AS edge
|
|
1492
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1493
|
+
"""
|
|
1494
|
+
+ filter_query
|
|
1495
|
+
+ """
|
|
1496
|
+
WITH e, edge, """
|
|
1497
|
+
+ get_vector_cosine_func_query(
|
|
1498
|
+
'e.fact_embedding', 'edge.fact_embedding', driver.provider
|
|
1499
|
+
)
|
|
1500
|
+
+ """ AS score
|
|
1501
|
+
WHERE score > $min_score
|
|
1502
|
+
WITH edge, e, score
|
|
1503
|
+
ORDER BY score DESC
|
|
1504
|
+
RETURN
|
|
1505
|
+
edge.uuid AS search_edge_uuid,
|
|
1506
|
+
collect({
|
|
1507
|
+
uuid: e.uuid,
|
|
1508
|
+
source_node_uuid: startNode(e).uuid,
|
|
1509
|
+
target_node_uuid: endNode(e).uuid,
|
|
1510
|
+
created_at: e.created_at,
|
|
1511
|
+
name: e.name,
|
|
1512
|
+
group_id: e.group_id,
|
|
1513
|
+
fact: e.fact,
|
|
1514
|
+
fact_embedding: e.fact_embedding,
|
|
1515
|
+
episodes: e.episodes,
|
|
1516
|
+
expired_at: e.expired_at,
|
|
1517
|
+
valid_at: e.valid_at,
|
|
1518
|
+
invalid_at: e.invalid_at,
|
|
1519
|
+
attributes: properties(e)
|
|
1520
|
+
})[..$limit] AS matches
|
|
1521
|
+
"""
|
|
1522
|
+
)
|
|
1523
|
+
|
|
1524
|
+
results, _, _ = await driver.execute_query(
|
|
1525
|
+
query,
|
|
1526
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1527
|
+
limit=limit,
|
|
1528
|
+
min_score=min_score,
|
|
1529
|
+
routing_='r',
|
|
1530
|
+
**filter_params,
|
|
1531
|
+
)
|
|
782
1532
|
|
|
783
|
-
results, _, _ = await driver.execute_query(
|
|
784
|
-
query,
|
|
785
|
-
query_params,
|
|
786
|
-
edges=[edge.model_dump() for edge in edges],
|
|
787
|
-
limit=limit,
|
|
788
|
-
min_score=min_score,
|
|
789
|
-
database_=DEFAULT_DATABASE,
|
|
790
|
-
routing_='r',
|
|
791
|
-
)
|
|
792
1533
|
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
|
793
1534
|
result['search_edge_uuid']: [
|
|
794
|
-
get_entity_edge_from_record(record) for record in result['matches']
|
|
1535
|
+
get_entity_edge_from_record(record, driver.provider) for record in result['matches']
|
|
795
1536
|
]
|
|
796
1537
|
for result in results
|
|
797
1538
|
}
|
|
@@ -802,7 +1543,7 @@ async def get_relevant_edges(
|
|
|
802
1543
|
|
|
803
1544
|
|
|
804
1545
|
async def get_edge_invalidation_candidates(
|
|
805
|
-
driver:
|
|
1546
|
+
driver: GraphDriver,
|
|
806
1547
|
edges: list[EntityEdge],
|
|
807
1548
|
search_filter: SearchFilters,
|
|
808
1549
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
@@ -811,54 +1552,174 @@ async def get_edge_invalidation_candidates(
|
|
|
811
1552
|
if len(edges) == 0:
|
|
812
1553
|
return []
|
|
813
1554
|
|
|
814
|
-
|
|
1555
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
1556
|
+
search_filter, driver.provider
|
|
1557
|
+
)
|
|
815
1558
|
|
|
816
|
-
filter_query
|
|
817
|
-
|
|
1559
|
+
filter_query = ''
|
|
1560
|
+
if filter_queries:
|
|
1561
|
+
filter_query = ' AND ' + (' AND '.join(filter_queries))
|
|
818
1562
|
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
+ """UNWIND $edges AS edge
|
|
822
|
-
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
823
|
-
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1563
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1564
|
+
query = (
|
|
824
1565
|
"""
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
attributes: properties(e)
|
|
846
|
-
})[..$limit] AS matches
|
|
847
|
-
"""
|
|
848
|
-
)
|
|
1566
|
+
UNWIND $edges AS edge
|
|
1567
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1568
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1569
|
+
"""
|
|
1570
|
+
+ filter_query
|
|
1571
|
+
+ """
|
|
1572
|
+
WITH e, edge
|
|
1573
|
+
RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding,
|
|
1574
|
+
edge.fact_embedding as target_embedding,
|
|
1575
|
+
edge.uuid as search_edge_uuid
|
|
1576
|
+
"""
|
|
1577
|
+
)
|
|
1578
|
+
resp, _, _ = await driver.execute_query(
|
|
1579
|
+
query,
|
|
1580
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1581
|
+
limit=limit,
|
|
1582
|
+
min_score=min_score,
|
|
1583
|
+
routing_='r',
|
|
1584
|
+
**filter_params,
|
|
1585
|
+
)
|
|
849
1586
|
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
1587
|
+
# Calculate Cosine similarity then return the edge ids
|
|
1588
|
+
input_ids = []
|
|
1589
|
+
for r in resp:
|
|
1590
|
+
score = calculate_cosine_similarity(
|
|
1591
|
+
list(map(float, r['source_embedding'].split(','))), r['target_embedding']
|
|
1592
|
+
)
|
|
1593
|
+
if score > min_score:
|
|
1594
|
+
input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
|
|
1595
|
+
|
|
1596
|
+
# Match the edge ides and return the values
|
|
1597
|
+
query = """
|
|
1598
|
+
UNWIND $ids AS edge
|
|
1599
|
+
MATCH ()-[e]->()
|
|
1600
|
+
WHERE id(e) = edge.id
|
|
1601
|
+
WITH edge, e
|
|
1602
|
+
ORDER BY edge.score DESC
|
|
1603
|
+
RETURN edge.uuid AS search_edge_uuid,
|
|
1604
|
+
collect({
|
|
1605
|
+
uuid: e.uuid,
|
|
1606
|
+
source_node_uuid: startNode(e).uuid,
|
|
1607
|
+
target_node_uuid: endNode(e).uuid,
|
|
1608
|
+
created_at: e.created_at,
|
|
1609
|
+
name: e.name,
|
|
1610
|
+
group_id: e.group_id,
|
|
1611
|
+
fact: e.fact,
|
|
1612
|
+
fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
|
|
1613
|
+
episodes: split(e.episodes, ","),
|
|
1614
|
+
expired_at: e.expired_at,
|
|
1615
|
+
valid_at: e.valid_at,
|
|
1616
|
+
invalid_at: e.invalid_at,
|
|
1617
|
+
attributes: properties(e)
|
|
1618
|
+
})[..$limit] AS matches
|
|
1619
|
+
"""
|
|
1620
|
+
results, _, _ = await driver.execute_query(
|
|
1621
|
+
query,
|
|
1622
|
+
ids=input_ids,
|
|
1623
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1624
|
+
limit=limit,
|
|
1625
|
+
min_score=min_score,
|
|
1626
|
+
routing_='r',
|
|
1627
|
+
**filter_params,
|
|
1628
|
+
)
|
|
1629
|
+
else:
|
|
1630
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1631
|
+
embedding_size = (
|
|
1632
|
+
len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
|
|
1633
|
+
)
|
|
1634
|
+
if embedding_size == 0:
|
|
1635
|
+
return []
|
|
1636
|
+
|
|
1637
|
+
query = (
|
|
1638
|
+
"""
|
|
1639
|
+
UNWIND $edges AS edge
|
|
1640
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
|
1641
|
+
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
|
1642
|
+
"""
|
|
1643
|
+
+ filter_query
|
|
1644
|
+
+ """
|
|
1645
|
+
WITH edge, e, n, m, """
|
|
1646
|
+
+ get_vector_cosine_func_query(
|
|
1647
|
+
'e.fact_embedding',
|
|
1648
|
+
f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
|
|
1649
|
+
driver.provider,
|
|
1650
|
+
)
|
|
1651
|
+
+ """ AS score
|
|
1652
|
+
WHERE score > $min_score
|
|
1653
|
+
WITH edge, e, n, m, score
|
|
1654
|
+
ORDER BY score DESC
|
|
1655
|
+
LIMIT $limit
|
|
1656
|
+
RETURN
|
|
1657
|
+
edge.uuid AS search_edge_uuid,
|
|
1658
|
+
collect({
|
|
1659
|
+
uuid: e.uuid,
|
|
1660
|
+
source_node_uuid: n.uuid,
|
|
1661
|
+
target_node_uuid: m.uuid,
|
|
1662
|
+
created_at: e.created_at,
|
|
1663
|
+
name: e.name,
|
|
1664
|
+
group_id: e.group_id,
|
|
1665
|
+
fact: e.fact,
|
|
1666
|
+
fact_embedding: e.fact_embedding,
|
|
1667
|
+
episodes: e.episodes,
|
|
1668
|
+
expired_at: e.expired_at,
|
|
1669
|
+
valid_at: e.valid_at,
|
|
1670
|
+
invalid_at: e.invalid_at,
|
|
1671
|
+
attributes: e.attributes
|
|
1672
|
+
}) AS matches
|
|
1673
|
+
"""
|
|
1674
|
+
)
|
|
1675
|
+
else:
|
|
1676
|
+
query = (
|
|
1677
|
+
"""
|
|
1678
|
+
UNWIND $edges AS edge
|
|
1679
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1680
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1681
|
+
"""
|
|
1682
|
+
+ filter_query
|
|
1683
|
+
+ """
|
|
1684
|
+
WITH edge, e, """
|
|
1685
|
+
+ get_vector_cosine_func_query(
|
|
1686
|
+
'e.fact_embedding', 'edge.fact_embedding', driver.provider
|
|
1687
|
+
)
|
|
1688
|
+
+ """ AS score
|
|
1689
|
+
WHERE score > $min_score
|
|
1690
|
+
WITH edge, e, score
|
|
1691
|
+
ORDER BY score DESC
|
|
1692
|
+
RETURN
|
|
1693
|
+
edge.uuid AS search_edge_uuid,
|
|
1694
|
+
collect({
|
|
1695
|
+
uuid: e.uuid,
|
|
1696
|
+
source_node_uuid: startNode(e).uuid,
|
|
1697
|
+
target_node_uuid: endNode(e).uuid,
|
|
1698
|
+
created_at: e.created_at,
|
|
1699
|
+
name: e.name,
|
|
1700
|
+
group_id: e.group_id,
|
|
1701
|
+
fact: e.fact,
|
|
1702
|
+
fact_embedding: e.fact_embedding,
|
|
1703
|
+
episodes: e.episodes,
|
|
1704
|
+
expired_at: e.expired_at,
|
|
1705
|
+
valid_at: e.valid_at,
|
|
1706
|
+
invalid_at: e.invalid_at,
|
|
1707
|
+
attributes: properties(e)
|
|
1708
|
+
})[..$limit] AS matches
|
|
1709
|
+
"""
|
|
1710
|
+
)
|
|
1711
|
+
|
|
1712
|
+
results, _, _ = await driver.execute_query(
|
|
1713
|
+
query,
|
|
1714
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1715
|
+
limit=limit,
|
|
1716
|
+
min_score=min_score,
|
|
1717
|
+
routing_='r',
|
|
1718
|
+
**filter_params,
|
|
1719
|
+
)
|
|
859
1720
|
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
|
|
860
1721
|
result['search_edge_uuid']: [
|
|
861
|
-
get_entity_edge_from_record(record) for record in result['matches']
|
|
1722
|
+
get_entity_edge_from_record(record, driver.provider) for record in result['matches']
|
|
862
1723
|
]
|
|
863
1724
|
for result in results
|
|
864
1725
|
}
|
|
@@ -869,7 +1730,9 @@ async def get_edge_invalidation_candidates(
|
|
|
869
1730
|
|
|
870
1731
|
|
|
871
1732
|
# takes in a list of rankings of uuids
|
|
872
|
-
def rrf(
|
|
1733
|
+
def rrf(
|
|
1734
|
+
results: list[list[str]], rank_const=1, min_score: float = 0
|
|
1735
|
+
) -> tuple[list[str], list[float]]:
|
|
873
1736
|
scores: dict[str, float] = defaultdict(float)
|
|
874
1737
|
for result in results:
|
|
875
1738
|
for i, uuid in enumerate(result):
|
|
@@ -880,35 +1743,44 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
|
|
|
880
1743
|
|
|
881
1744
|
sorted_uuids = [term[0] for term in scored_uuids]
|
|
882
1745
|
|
|
883
|
-
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
|
|
1746
|
+
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
|
|
1747
|
+
scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
|
|
1748
|
+
]
|
|
884
1749
|
|
|
885
1750
|
|
|
886
1751
|
async def node_distance_reranker(
|
|
887
|
-
driver:
|
|
1752
|
+
driver: GraphDriver,
|
|
888
1753
|
node_uuids: list[str],
|
|
889
1754
|
center_node_uuid: str,
|
|
890
1755
|
min_score: float = 0,
|
|
891
|
-
) -> list[str]:
|
|
1756
|
+
) -> tuple[list[str], list[float]]:
|
|
892
1757
|
# filter out node_uuid center node node uuid
|
|
893
1758
|
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
|
|
894
1759
|
scores: dict[str, float] = {center_node_uuid: 0.0}
|
|
895
1760
|
|
|
896
|
-
|
|
897
|
-
|
|
1761
|
+
query = """
|
|
1762
|
+
UNWIND $node_uuids AS node_uuid
|
|
1763
|
+
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
|
1764
|
+
RETURN 1 AS score, node_uuid AS uuid
|
|
1765
|
+
"""
|
|
1766
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1767
|
+
query = """
|
|
898
1768
|
UNWIND $node_uuids AS node_uuid
|
|
899
|
-
MATCH
|
|
900
|
-
RETURN
|
|
901
|
-
"""
|
|
1769
|
+
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(n:Entity {uuid: node_uuid})
|
|
1770
|
+
RETURN 1 AS score, node_uuid AS uuid
|
|
1771
|
+
"""
|
|
902
1772
|
|
|
903
|
-
|
|
1773
|
+
# Find the shortest path to center node
|
|
1774
|
+
results, header, _ = await driver.execute_query(
|
|
904
1775
|
query,
|
|
905
1776
|
node_uuids=filtered_uuids,
|
|
906
1777
|
center_uuid=center_node_uuid,
|
|
907
|
-
database_=DEFAULT_DATABASE,
|
|
908
1778
|
routing_='r',
|
|
909
1779
|
)
|
|
1780
|
+
if driver.provider == GraphProvider.FALKORDB:
|
|
1781
|
+
results = [dict(zip(header, row, strict=True)) for row in results]
|
|
910
1782
|
|
|
911
|
-
for result in
|
|
1783
|
+
for result in results:
|
|
912
1784
|
uuid = result['uuid']
|
|
913
1785
|
score = result['score']
|
|
914
1786
|
scores[uuid] = score
|
|
@@ -925,37 +1797,42 @@ async def node_distance_reranker(
|
|
|
925
1797
|
scores[center_node_uuid] = 0.1
|
|
926
1798
|
filtered_uuids = [center_node_uuid] + filtered_uuids
|
|
927
1799
|
|
|
928
|
-
return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score]
|
|
1800
|
+
return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score], [
|
|
1801
|
+
1 / scores[uuid] for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score
|
|
1802
|
+
]
|
|
929
1803
|
|
|
930
1804
|
|
|
931
1805
|
async def episode_mentions_reranker(
|
|
932
|
-
driver:
|
|
933
|
-
) -> list[str]:
|
|
1806
|
+
driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
|
|
1807
|
+
) -> tuple[list[str], list[float]]:
|
|
934
1808
|
# use rrf as a preliminary ranker
|
|
935
|
-
sorted_uuids = rrf(node_uuids)
|
|
1809
|
+
sorted_uuids, _ = rrf(node_uuids)
|
|
936
1810
|
scores: dict[str, float] = {}
|
|
937
1811
|
|
|
938
1812
|
# Find the shortest path to center node
|
|
939
|
-
|
|
940
|
-
|
|
1813
|
+
results, _, _ = await driver.execute_query(
|
|
1814
|
+
"""
|
|
1815
|
+
UNWIND $node_uuids AS node_uuid
|
|
941
1816
|
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
|
942
1817
|
RETURN count(*) AS score, n.uuid AS uuid
|
|
943
|
-
"""
|
|
944
|
-
|
|
945
|
-
results, _, _ = await driver.execute_query(
|
|
946
|
-
query,
|
|
1818
|
+
""",
|
|
947
1819
|
node_uuids=sorted_uuids,
|
|
948
|
-
database_=DEFAULT_DATABASE,
|
|
949
1820
|
routing_='r',
|
|
950
1821
|
)
|
|
951
1822
|
|
|
952
1823
|
for result in results:
|
|
953
1824
|
scores[result['uuid']] = result['score']
|
|
954
1825
|
|
|
1826
|
+
for uuid in sorted_uuids:
|
|
1827
|
+
if uuid not in scores:
|
|
1828
|
+
scores[uuid] = float('inf')
|
|
1829
|
+
|
|
955
1830
|
# rerank on shortest distance
|
|
956
1831
|
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
|
957
1832
|
|
|
958
|
-
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
|
|
1833
|
+
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
|
|
1834
|
+
scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
|
|
1835
|
+
]
|
|
959
1836
|
|
|
960
1837
|
|
|
961
1838
|
def maximal_marginal_relevance(
|
|
@@ -963,7 +1840,7 @@ def maximal_marginal_relevance(
|
|
|
963
1840
|
candidates: dict[str, list[float]],
|
|
964
1841
|
mmr_lambda: float = DEFAULT_MMR_LAMBDA,
|
|
965
1842
|
min_score: float = -2.0,
|
|
966
|
-
) -> list[str]:
|
|
1843
|
+
) -> tuple[list[str], list[float]]:
|
|
967
1844
|
start = time()
|
|
968
1845
|
query_array = np.array(query_vector)
|
|
969
1846
|
candidate_arrays: dict[str, NDArray] = {}
|
|
@@ -994,21 +1871,36 @@ def maximal_marginal_relevance(
|
|
|
994
1871
|
end = time()
|
|
995
1872
|
logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
|
|
996
1873
|
|
|
997
|
-
return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
|
|
1874
|
+
return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score], [
|
|
1875
|
+
mmr_scores[uuid] for uuid in uuids if mmr_scores[uuid] >= min_score
|
|
1876
|
+
]
|
|
998
1877
|
|
|
999
1878
|
|
|
1000
1879
|
async def get_embeddings_for_nodes(
|
|
1001
|
-
driver:
|
|
1880
|
+
driver: GraphDriver, nodes: list[EntityNode]
|
|
1002
1881
|
) -> dict[str, list[float]]:
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1882
|
+
if driver.graph_operations_interface:
|
|
1883
|
+
return await driver.graph_operations_interface.node_load_embeddings_bulk(driver, nodes)
|
|
1884
|
+
elif driver.provider == GraphProvider.NEPTUNE:
|
|
1885
|
+
query = """
|
|
1886
|
+
MATCH (n:Entity)
|
|
1887
|
+
WHERE n.uuid IN $node_uuids
|
|
1888
|
+
RETURN DISTINCT
|
|
1889
|
+
n.uuid AS uuid,
|
|
1890
|
+
split(n.name_embedding, ",") AS name_embedding
|
|
1891
|
+
"""
|
|
1892
|
+
else:
|
|
1893
|
+
query = """
|
|
1894
|
+
MATCH (n:Entity)
|
|
1895
|
+
WHERE n.uuid IN $node_uuids
|
|
1896
|
+
RETURN DISTINCT
|
|
1897
|
+
n.uuid AS uuid,
|
|
1898
|
+
n.name_embedding AS name_embedding
|
|
1899
|
+
"""
|
|
1010
1900
|
results, _, _ = await driver.execute_query(
|
|
1011
|
-
query,
|
|
1901
|
+
query,
|
|
1902
|
+
node_uuids=[node.uuid for node in nodes],
|
|
1903
|
+
routing_='r',
|
|
1012
1904
|
)
|
|
1013
1905
|
|
|
1014
1906
|
embeddings_dict: dict[str, list[float]] = {}
|
|
@@ -1022,19 +1914,27 @@ async def get_embeddings_for_nodes(
|
|
|
1022
1914
|
|
|
1023
1915
|
|
|
1024
1916
|
async def get_embeddings_for_communities(
|
|
1025
|
-
driver:
|
|
1917
|
+
driver: GraphDriver, communities: list[CommunityNode]
|
|
1026
1918
|
) -> dict[str, list[float]]:
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1919
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1920
|
+
query = """
|
|
1921
|
+
MATCH (c:Community)
|
|
1922
|
+
WHERE c.uuid IN $community_uuids
|
|
1923
|
+
RETURN DISTINCT
|
|
1924
|
+
c.uuid AS uuid,
|
|
1925
|
+
split(c.name_embedding, ",") AS name_embedding
|
|
1926
|
+
"""
|
|
1927
|
+
else:
|
|
1928
|
+
query = """
|
|
1929
|
+
MATCH (c:Community)
|
|
1930
|
+
WHERE c.uuid IN $community_uuids
|
|
1931
|
+
RETURN DISTINCT
|
|
1932
|
+
c.uuid AS uuid,
|
|
1933
|
+
c.name_embedding AS name_embedding
|
|
1934
|
+
"""
|
|
1034
1935
|
results, _, _ = await driver.execute_query(
|
|
1035
1936
|
query,
|
|
1036
1937
|
community_uuids=[community.uuid for community in communities],
|
|
1037
|
-
database_=DEFAULT_DATABASE,
|
|
1038
1938
|
routing_='r',
|
|
1039
1939
|
)
|
|
1040
1940
|
|
|
@@ -1049,19 +1949,39 @@ async def get_embeddings_for_communities(
|
|
|
1049
1949
|
|
|
1050
1950
|
|
|
1051
1951
|
async def get_embeddings_for_edges(
|
|
1052
|
-
driver:
|
|
1952
|
+
driver: GraphDriver, edges: list[EntityEdge]
|
|
1053
1953
|
) -> dict[str, list[float]]:
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1954
|
+
if driver.graph_operations_interface:
|
|
1955
|
+
return await driver.graph_operations_interface.edge_load_embeddings_bulk(driver, edges)
|
|
1956
|
+
elif driver.provider == GraphProvider.NEPTUNE:
|
|
1957
|
+
query = """
|
|
1958
|
+
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1959
|
+
WHERE e.uuid IN $edge_uuids
|
|
1960
|
+
RETURN DISTINCT
|
|
1961
|
+
e.uuid AS uuid,
|
|
1962
|
+
split(e.fact_embedding, ",") AS fact_embedding
|
|
1963
|
+
"""
|
|
1964
|
+
else:
|
|
1965
|
+
match_query = """
|
|
1966
|
+
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1967
|
+
"""
|
|
1968
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1969
|
+
match_query = """
|
|
1970
|
+
MATCH (n:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m:Entity)
|
|
1971
|
+
"""
|
|
1060
1972
|
|
|
1973
|
+
query = (
|
|
1974
|
+
match_query
|
|
1975
|
+
+ """
|
|
1976
|
+
WHERE e.uuid IN $edge_uuids
|
|
1977
|
+
RETURN DISTINCT
|
|
1978
|
+
e.uuid AS uuid,
|
|
1979
|
+
e.fact_embedding AS fact_embedding
|
|
1980
|
+
"""
|
|
1981
|
+
)
|
|
1061
1982
|
results, _, _ = await driver.execute_query(
|
|
1062
1983
|
query,
|
|
1063
1984
|
edge_uuids=[edge.uuid for edge in edges],
|
|
1064
|
-
database_=DEFAULT_DATABASE,
|
|
1065
1985
|
routing_='r',
|
|
1066
1986
|
)
|
|
1067
1987
|
|