graphiti-core 0.18.9__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +4 -0
- graphiti_core/driver/falkordb_driver.py +3 -14
- graphiti_core/driver/kuzu_driver.py +175 -0
- graphiti_core/driver/neptune_driver.py +301 -0
- graphiti_core/edges.py +155 -62
- graphiti_core/graph_queries.py +31 -2
- graphiti_core/graphiti.py +6 -1
- graphiti_core/helpers.py +8 -8
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/openai_base_client.py +12 -2
- graphiti_core/llm_client/openai_client.py +10 -2
- graphiti_core/migrations/__init__.py +0 -0
- graphiti_core/migrations/neo4j_node_group_labels.py +114 -0
- graphiti_core/models/edges/edge_db_queries.py +205 -76
- graphiti_core/models/nodes/node_db_queries.py +253 -74
- graphiti_core/nodes.py +271 -98
- graphiti_core/search/search.py +42 -12
- graphiti_core/search/search_config.py +4 -0
- graphiti_core/search/search_filters.py +35 -22
- graphiti_core/search/search_utils.py +1329 -392
- graphiti_core/utils/bulk_utils.py +50 -15
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +39 -32
- graphiti_core/utils/maintenance/edge_operations.py +47 -13
- graphiti_core/utils/maintenance/graph_data_operations.py +100 -15
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/METADATA +87 -13
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/RECORD +29 -25
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -15,6 +15,7 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
import logging
|
|
18
|
+
import os
|
|
18
19
|
from collections import defaultdict
|
|
19
20
|
from time import time
|
|
20
21
|
from typing import Any
|
|
@@ -36,10 +37,13 @@ from graphiti_core.helpers import (
|
|
|
36
37
|
normalize_l2,
|
|
37
38
|
semaphore_gather,
|
|
38
39
|
)
|
|
39
|
-
from graphiti_core.models.edges.edge_db_queries import
|
|
40
|
-
from graphiti_core.models.nodes.node_db_queries import
|
|
40
|
+
from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
|
|
41
|
+
from graphiti_core.models.nodes.node_db_queries import (
|
|
42
|
+
COMMUNITY_NODE_RETURN,
|
|
43
|
+
EPISODIC_NODE_RETURN,
|
|
44
|
+
get_entity_node_return_query,
|
|
45
|
+
)
|
|
41
46
|
from graphiti_core.nodes import (
|
|
42
|
-
ENTITY_NODE_RETURN,
|
|
43
47
|
CommunityNode,
|
|
44
48
|
EntityNode,
|
|
45
49
|
EpisodicNode,
|
|
@@ -54,6 +58,7 @@ from graphiti_core.search.search_filters import (
|
|
|
54
58
|
)
|
|
55
59
|
|
|
56
60
|
logger = logging.getLogger(__name__)
|
|
61
|
+
USE_HNSW = os.getenv('USE_HNSW', '').lower() in ('true', '1', 'yes')
|
|
57
62
|
|
|
58
63
|
RELEVANT_SCHEMA_LIMIT = 10
|
|
59
64
|
DEFAULT_MIN_SCORE = 0.6
|
|
@@ -62,9 +67,30 @@ MAX_SEARCH_DEPTH = 3
|
|
|
62
67
|
MAX_QUERY_LENGTH = 128
|
|
63
68
|
|
|
64
69
|
|
|
65
|
-
def
|
|
70
|
+
def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> float:
|
|
71
|
+
"""
|
|
72
|
+
Calculates the cosine similarity between two vectors using NumPy.
|
|
73
|
+
"""
|
|
74
|
+
dot_product = np.dot(vector1, vector2)
|
|
75
|
+
norm_vector1 = np.linalg.norm(vector1)
|
|
76
|
+
norm_vector2 = np.linalg.norm(vector2)
|
|
77
|
+
|
|
78
|
+
if norm_vector1 == 0 or norm_vector2 == 0:
|
|
79
|
+
return 0 # Handle cases where one or both vectors are zero vectors
|
|
80
|
+
|
|
81
|
+
return dot_product / (norm_vector1 * norm_vector2)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
|
|
85
|
+
if driver.provider == GraphProvider.KUZU:
|
|
86
|
+
# Kuzu only supports simple queries.
|
|
87
|
+
if len(query.split(' ')) > MAX_QUERY_LENGTH:
|
|
88
|
+
return ''
|
|
89
|
+
return query
|
|
66
90
|
group_ids_filter_list = (
|
|
67
|
-
[fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
|
|
91
|
+
[driver.fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
|
|
92
|
+
if group_ids is not None
|
|
93
|
+
else []
|
|
68
94
|
)
|
|
69
95
|
group_ids_filter = ''
|
|
70
96
|
for f in group_ids_filter_list:
|
|
@@ -108,12 +134,12 @@ async def get_mentioned_nodes(
|
|
|
108
134
|
WHERE episode.uuid IN $uuids
|
|
109
135
|
RETURN DISTINCT
|
|
110
136
|
"""
|
|
111
|
-
+
|
|
137
|
+
+ get_entity_node_return_query(driver.provider),
|
|
112
138
|
uuids=episode_uuids,
|
|
113
139
|
routing_='r',
|
|
114
140
|
)
|
|
115
141
|
|
|
116
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
142
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
117
143
|
|
|
118
144
|
return nodes
|
|
119
145
|
|
|
@@ -125,7 +151,7 @@ async def get_communities_by_nodes(
|
|
|
125
151
|
|
|
126
152
|
records, _, _ = await driver.execute_query(
|
|
127
153
|
"""
|
|
128
|
-
MATCH (
|
|
154
|
+
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)
|
|
129
155
|
WHERE m.uuid IN $uuids
|
|
130
156
|
RETURN DISTINCT
|
|
131
157
|
"""
|
|
@@ -147,40 +173,105 @@ async def edge_fulltext_search(
|
|
|
147
173
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
148
174
|
) -> list[EntityEdge]:
|
|
149
175
|
# fulltext search over facts
|
|
150
|
-
fuzzy_query = fulltext_query(query, group_ids, driver
|
|
176
|
+
fuzzy_query = fulltext_query(query, group_ids, driver)
|
|
177
|
+
|
|
151
178
|
if fuzzy_query == '':
|
|
152
179
|
return []
|
|
153
180
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
+ filter_query
|
|
163
|
-
+ """
|
|
164
|
-
WITH e, score, n, m
|
|
165
|
-
RETURN
|
|
166
|
-
"""
|
|
167
|
-
+ ENTITY_EDGE_RETURN
|
|
168
|
-
+ """
|
|
169
|
-
ORDER BY score DESC
|
|
170
|
-
LIMIT $limit
|
|
181
|
+
match_query = """
|
|
182
|
+
YIELD relationship AS rel, score
|
|
183
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
|
|
184
|
+
"""
|
|
185
|
+
if driver.provider == GraphProvider.KUZU:
|
|
186
|
+
match_query = """
|
|
187
|
+
YIELD node, score
|
|
188
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
|
|
171
189
|
"""
|
|
172
|
-
)
|
|
173
190
|
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
query=fuzzy_query,
|
|
177
|
-
group_ids=group_ids,
|
|
178
|
-
limit=limit,
|
|
179
|
-
routing_='r',
|
|
180
|
-
**filter_params,
|
|
191
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
192
|
+
search_filter, driver.provider
|
|
181
193
|
)
|
|
182
194
|
|
|
183
|
-
|
|
195
|
+
if group_ids is not None:
|
|
196
|
+
filter_queries.append('e.group_id IN $group_ids')
|
|
197
|
+
filter_params['group_ids'] = group_ids
|
|
198
|
+
|
|
199
|
+
filter_query = ''
|
|
200
|
+
if filter_queries:
|
|
201
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
202
|
+
|
|
203
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
204
|
+
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
|
|
205
|
+
if res['hits']['total']['value'] > 0:
|
|
206
|
+
# Calculate Cosine similarity then return the edge ids
|
|
207
|
+
input_ids = []
|
|
208
|
+
for r in res['hits']['hits']:
|
|
209
|
+
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
210
|
+
|
|
211
|
+
# Match the edge ids and return the values
|
|
212
|
+
query = (
|
|
213
|
+
"""
|
|
214
|
+
UNWIND $ids as id
|
|
215
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
216
|
+
WHERE e.group_id IN $group_ids
|
|
217
|
+
AND id(e)=id
|
|
218
|
+
"""
|
|
219
|
+
+ filter_query
|
|
220
|
+
+ """
|
|
221
|
+
AND id(e)=id
|
|
222
|
+
WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
|
|
223
|
+
RETURN
|
|
224
|
+
e.uuid AS uuid,
|
|
225
|
+
e.group_id AS group_id,
|
|
226
|
+
n.uuid AS source_node_uuid,
|
|
227
|
+
m.uuid AS target_node_uuid,
|
|
228
|
+
e.created_at AS created_at,
|
|
229
|
+
e.name AS name,
|
|
230
|
+
e.fact AS fact,
|
|
231
|
+
split(e.episodes, ",") AS episodes,
|
|
232
|
+
e.expired_at AS expired_at,
|
|
233
|
+
e.valid_at AS valid_at,
|
|
234
|
+
e.invalid_at AS invalid_at,
|
|
235
|
+
properties(e) AS attributes
|
|
236
|
+
ORDER BY score DESC LIMIT $limit
|
|
237
|
+
"""
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
records, _, _ = await driver.execute_query(
|
|
241
|
+
query,
|
|
242
|
+
query=fuzzy_query,
|
|
243
|
+
ids=input_ids,
|
|
244
|
+
limit=limit,
|
|
245
|
+
routing_='r',
|
|
246
|
+
**filter_params,
|
|
247
|
+
)
|
|
248
|
+
else:
|
|
249
|
+
return []
|
|
250
|
+
else:
|
|
251
|
+
query = (
|
|
252
|
+
get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
|
|
253
|
+
+ match_query
|
|
254
|
+
+ filter_query
|
|
255
|
+
+ """
|
|
256
|
+
WITH e, score, n, m
|
|
257
|
+
RETURN
|
|
258
|
+
"""
|
|
259
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
260
|
+
+ """
|
|
261
|
+
ORDER BY score DESC
|
|
262
|
+
LIMIT $limit
|
|
263
|
+
"""
|
|
264
|
+
)
|
|
265
|
+
|
|
266
|
+
records, _, _ = await driver.execute_query(
|
|
267
|
+
query,
|
|
268
|
+
query=fuzzy_query,
|
|
269
|
+
limit=limit,
|
|
270
|
+
routing_='r',
|
|
271
|
+
**filter_params,
|
|
272
|
+
)
|
|
273
|
+
|
|
274
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
184
275
|
|
|
185
276
|
return edges
|
|
186
277
|
|
|
@@ -195,56 +286,130 @@ async def edge_similarity_search(
|
|
|
195
286
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
196
287
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
197
288
|
) -> list[EntityEdge]:
|
|
198
|
-
|
|
199
|
-
|
|
289
|
+
match_query = """
|
|
290
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
291
|
+
"""
|
|
292
|
+
if driver.provider == GraphProvider.KUZU:
|
|
293
|
+
match_query = """
|
|
294
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
295
|
+
"""
|
|
200
296
|
|
|
201
|
-
|
|
202
|
-
|
|
297
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
298
|
+
search_filter, driver.provider
|
|
299
|
+
)
|
|
203
300
|
|
|
204
|
-
group_filter_query: LiteralString = 'WHERE e.group_id IS NOT NULL'
|
|
205
301
|
if group_ids is not None:
|
|
206
|
-
|
|
207
|
-
|
|
302
|
+
filter_queries.append('e.group_id IN $group_ids')
|
|
303
|
+
filter_params['group_ids'] = group_ids
|
|
208
304
|
|
|
209
305
|
if source_node_uuid is not None:
|
|
210
|
-
|
|
211
|
-
|
|
306
|
+
filter_params['source_uuid'] = source_node_uuid
|
|
307
|
+
filter_queries.append('n.uuid = $source_uuid')
|
|
212
308
|
|
|
213
309
|
if target_node_uuid is not None:
|
|
214
|
-
|
|
215
|
-
|
|
310
|
+
filter_params['target_uuid'] = target_node_uuid
|
|
311
|
+
filter_queries.append('m.uuid = $target_uuid')
|
|
312
|
+
|
|
313
|
+
filter_query = ''
|
|
314
|
+
if filter_queries:
|
|
315
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
316
|
+
|
|
317
|
+
search_vector_var = '$search_vector'
|
|
318
|
+
if driver.provider == GraphProvider.KUZU:
|
|
319
|
+
search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
|
|
320
|
+
|
|
321
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
322
|
+
query = (
|
|
323
|
+
RUNTIME_QUERY
|
|
324
|
+
+ """
|
|
325
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
326
|
+
"""
|
|
327
|
+
+ filter_query
|
|
328
|
+
+ """
|
|
329
|
+
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
|
330
|
+
"""
|
|
331
|
+
)
|
|
332
|
+
resp, header, _ = await driver.execute_query(
|
|
333
|
+
query,
|
|
334
|
+
search_vector=search_vector,
|
|
335
|
+
limit=limit,
|
|
336
|
+
min_score=min_score,
|
|
337
|
+
routing_='r',
|
|
338
|
+
**filter_params,
|
|
339
|
+
)
|
|
216
340
|
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
341
|
+
if len(resp) > 0:
|
|
342
|
+
# Calculate Cosine similarity then return the edge ids
|
|
343
|
+
input_ids = []
|
|
344
|
+
for r in resp:
|
|
345
|
+
if r['embedding']:
|
|
346
|
+
score = calculate_cosine_similarity(
|
|
347
|
+
search_vector, list(map(float, r['embedding'].split(',')))
|
|
348
|
+
)
|
|
349
|
+
if score > min_score:
|
|
350
|
+
input_ids.append({'id': r['id'], 'score': score})
|
|
351
|
+
|
|
352
|
+
# Match the edge ides and return the values
|
|
353
|
+
query = """
|
|
354
|
+
UNWIND $ids as i
|
|
355
|
+
MATCH ()-[r]->()
|
|
356
|
+
WHERE id(r) = i.id
|
|
357
|
+
RETURN
|
|
358
|
+
r.uuid AS uuid,
|
|
359
|
+
r.group_id AS group_id,
|
|
360
|
+
startNode(r).uuid AS source_node_uuid,
|
|
361
|
+
endNode(r).uuid AS target_node_uuid,
|
|
362
|
+
r.created_at AS created_at,
|
|
363
|
+
r.name AS name,
|
|
364
|
+
r.fact AS fact,
|
|
365
|
+
split(r.episodes, ",") AS episodes,
|
|
366
|
+
r.expired_at AS expired_at,
|
|
367
|
+
r.valid_at AS valid_at,
|
|
368
|
+
r.invalid_at AS invalid_at,
|
|
369
|
+
properties(r) AS attributes
|
|
370
|
+
ORDER BY i.score DESC
|
|
371
|
+
LIMIT $limit
|
|
372
|
+
"""
|
|
373
|
+
records, _, _ = await driver.execute_query(
|
|
374
|
+
query,
|
|
375
|
+
ids=input_ids,
|
|
376
|
+
search_vector=search_vector,
|
|
377
|
+
limit=limit,
|
|
378
|
+
min_score=min_score,
|
|
379
|
+
routing_='r',
|
|
380
|
+
**filter_params,
|
|
381
|
+
)
|
|
382
|
+
else:
|
|
383
|
+
return []
|
|
384
|
+
else:
|
|
385
|
+
query = (
|
|
386
|
+
RUNTIME_QUERY
|
|
387
|
+
+ match_query
|
|
388
|
+
+ filter_query
|
|
389
|
+
+ """
|
|
390
|
+
WITH DISTINCT e, n, m, """
|
|
391
|
+
+ get_vector_cosine_func_query('e.fact_embedding', search_vector_var, driver.provider)
|
|
392
|
+
+ """ AS score
|
|
393
|
+
WHERE score > $min_score
|
|
394
|
+
RETURN
|
|
395
|
+
"""
|
|
396
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
397
|
+
+ """
|
|
398
|
+
ORDER BY score DESC
|
|
399
|
+
LIMIT $limit
|
|
400
|
+
"""
|
|
401
|
+
)
|
|
237
402
|
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
403
|
+
records, _, _ = await driver.execute_query(
|
|
404
|
+
query,
|
|
405
|
+
search_vector=search_vector,
|
|
406
|
+
limit=limit,
|
|
407
|
+
min_score=min_score,
|
|
408
|
+
routing_='r',
|
|
409
|
+
**filter_params,
|
|
410
|
+
)
|
|
246
411
|
|
|
247
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
412
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
248
413
|
|
|
249
414
|
return edges
|
|
250
415
|
|
|
@@ -258,40 +423,116 @@ async def edge_bfs_search(
|
|
|
258
423
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
259
424
|
) -> list[EntityEdge]:
|
|
260
425
|
# vector similarity search over embedded facts
|
|
261
|
-
if bfs_origin_node_uuids is None:
|
|
426
|
+
if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0:
|
|
262
427
|
return []
|
|
263
428
|
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
query = (
|
|
267
|
-
f"""
|
|
268
|
-
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
269
|
-
MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
|
|
270
|
-
UNWIND relationships(path) AS rel
|
|
271
|
-
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
272
|
-
WHERE e.uuid = rel.uuid
|
|
273
|
-
AND e.group_id IN $group_ids
|
|
274
|
-
"""
|
|
275
|
-
+ filter_query
|
|
276
|
-
+ """
|
|
277
|
-
RETURN DISTINCT
|
|
278
|
-
"""
|
|
279
|
-
+ ENTITY_EDGE_RETURN
|
|
280
|
-
+ """
|
|
281
|
-
LIMIT $limit
|
|
282
|
-
"""
|
|
429
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
430
|
+
search_filter, driver.provider
|
|
283
431
|
)
|
|
284
432
|
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
433
|
+
if group_ids is not None:
|
|
434
|
+
filter_queries.append('e.group_id IN $group_ids')
|
|
435
|
+
filter_params['group_ids'] = group_ids
|
|
436
|
+
|
|
437
|
+
filter_query = ''
|
|
438
|
+
if filter_queries:
|
|
439
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
440
|
+
|
|
441
|
+
if driver.provider == GraphProvider.KUZU:
|
|
442
|
+
# Kuzu stores entity edges twice with an intermediate node, so we need to match them
|
|
443
|
+
# separately for the correct BFS depth.
|
|
444
|
+
depth = bfs_max_depth * 2 - 1
|
|
445
|
+
match_queries = [
|
|
446
|
+
f"""
|
|
447
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
448
|
+
MATCH path = (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
|
|
449
|
+
UNWIND nodes(path) AS relNode
|
|
450
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
|
|
451
|
+
""",
|
|
452
|
+
]
|
|
453
|
+
if bfs_max_depth > 1:
|
|
454
|
+
depth = (bfs_max_depth - 1) * 2 - 1
|
|
455
|
+
match_queries.append(f"""
|
|
456
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
457
|
+
MATCH path = (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
|
|
458
|
+
UNWIND nodes(path) AS relNode
|
|
459
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
|
|
460
|
+
""")
|
|
461
|
+
|
|
462
|
+
records = []
|
|
463
|
+
for match_query in match_queries:
|
|
464
|
+
sub_records, _, _ = await driver.execute_query(
|
|
465
|
+
match_query
|
|
466
|
+
+ filter_query
|
|
467
|
+
+ """
|
|
468
|
+
RETURN DISTINCT
|
|
469
|
+
"""
|
|
470
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
471
|
+
+ """
|
|
472
|
+
LIMIT $limit
|
|
473
|
+
""",
|
|
474
|
+
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
475
|
+
limit=limit,
|
|
476
|
+
routing_='r',
|
|
477
|
+
**filter_params,
|
|
478
|
+
)
|
|
479
|
+
records.extend(sub_records)
|
|
480
|
+
else:
|
|
481
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
482
|
+
query = (
|
|
483
|
+
f"""
|
|
484
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
485
|
+
MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
|
|
486
|
+
WHERE origin:Entity OR origin:Episodic
|
|
487
|
+
UNWIND relationships(path) AS rel
|
|
488
|
+
MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
|
|
489
|
+
"""
|
|
490
|
+
+ filter_query
|
|
491
|
+
+ """
|
|
492
|
+
RETURN DISTINCT
|
|
493
|
+
e.uuid AS uuid,
|
|
494
|
+
e.group_id AS group_id,
|
|
495
|
+
startNode(e).uuid AS source_node_uuid,
|
|
496
|
+
endNode(e).uuid AS target_node_uuid,
|
|
497
|
+
e.created_at AS created_at,
|
|
498
|
+
e.name AS name,
|
|
499
|
+
e.fact AS fact,
|
|
500
|
+
split(e.episodes, ',') AS episodes,
|
|
501
|
+
e.expired_at AS expired_at,
|
|
502
|
+
e.valid_at AS valid_at,
|
|
503
|
+
e.invalid_at AS invalid_at,
|
|
504
|
+
properties(e) AS attributes
|
|
505
|
+
LIMIT $limit
|
|
506
|
+
"""
|
|
507
|
+
)
|
|
508
|
+
else:
|
|
509
|
+
query = (
|
|
510
|
+
f"""
|
|
511
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
512
|
+
MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
|
|
513
|
+
UNWIND relationships(path) AS rel
|
|
514
|
+
MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
|
|
515
|
+
"""
|
|
516
|
+
+ filter_query
|
|
517
|
+
+ """
|
|
518
|
+
RETURN DISTINCT
|
|
519
|
+
"""
|
|
520
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
521
|
+
+ """
|
|
522
|
+
LIMIT $limit
|
|
523
|
+
"""
|
|
524
|
+
)
|
|
525
|
+
|
|
526
|
+
records, _, _ = await driver.execute_query(
|
|
527
|
+
query,
|
|
528
|
+
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
529
|
+
depth=bfs_max_depth,
|
|
530
|
+
limit=limit,
|
|
531
|
+
routing_='r',
|
|
532
|
+
**filter_params,
|
|
533
|
+
)
|
|
293
534
|
|
|
294
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
535
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
295
536
|
|
|
296
537
|
return edges
|
|
297
538
|
|
|
@@ -302,39 +543,90 @@ async def node_fulltext_search(
|
|
|
302
543
|
search_filter: SearchFilters,
|
|
303
544
|
group_ids: list[str] | None = None,
|
|
304
545
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
546
|
+
use_local_indexes: bool = False,
|
|
305
547
|
) -> list[EntityNode]:
|
|
306
548
|
# BM25 search to get top nodes
|
|
307
|
-
fuzzy_query = fulltext_query(query, group_ids, driver
|
|
549
|
+
fuzzy_query = fulltext_query(query, group_ids, driver)
|
|
308
550
|
if fuzzy_query == '':
|
|
309
551
|
return []
|
|
310
|
-
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
311
552
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
+ """
|
|
315
|
-
YIELD node AS n, score
|
|
316
|
-
WHERE n:Entity AND n.group_id IN $group_ids
|
|
317
|
-
"""
|
|
318
|
-
+ filter_query
|
|
319
|
-
+ """
|
|
320
|
-
WITH n, score
|
|
321
|
-
ORDER BY score DESC
|
|
322
|
-
LIMIT $limit
|
|
323
|
-
RETURN
|
|
324
|
-
"""
|
|
325
|
-
+ ENTITY_NODE_RETURN
|
|
553
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
554
|
+
search_filter, driver.provider
|
|
326
555
|
)
|
|
327
556
|
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
557
|
+
if group_ids is not None:
|
|
558
|
+
filter_queries.append('n.group_id IN $group_ids')
|
|
559
|
+
filter_params['group_ids'] = group_ids
|
|
560
|
+
|
|
561
|
+
filter_query = ''
|
|
562
|
+
if filter_queries:
|
|
563
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
564
|
+
|
|
565
|
+
yield_query = 'YIELD node AS n, score'
|
|
566
|
+
if driver.provider == GraphProvider.KUZU:
|
|
567
|
+
yield_query = 'WITH node AS n, score'
|
|
568
|
+
|
|
569
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
570
|
+
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
571
|
+
if res['hits']['total']['value'] > 0:
|
|
572
|
+
# Calculate Cosine similarity then return the edge ids
|
|
573
|
+
input_ids = []
|
|
574
|
+
for r in res['hits']['hits']:
|
|
575
|
+
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
576
|
+
|
|
577
|
+
# Match the edge ides and return the values
|
|
578
|
+
query = (
|
|
579
|
+
"""
|
|
580
|
+
UNWIND $ids as i
|
|
581
|
+
MATCH (n:Entity)
|
|
582
|
+
WHERE n.uuid=i.id
|
|
583
|
+
RETURN
|
|
584
|
+
"""
|
|
585
|
+
+ get_entity_node_return_query(driver.provider)
|
|
586
|
+
+ """
|
|
587
|
+
ORDER BY i.score DESC
|
|
588
|
+
LIMIT $limit
|
|
589
|
+
"""
|
|
590
|
+
)
|
|
591
|
+
records, _, _ = await driver.execute_query(
|
|
592
|
+
query,
|
|
593
|
+
ids=input_ids,
|
|
594
|
+
query=fuzzy_query,
|
|
595
|
+
limit=limit,
|
|
596
|
+
routing_='r',
|
|
597
|
+
**filter_params,
|
|
598
|
+
)
|
|
599
|
+
else:
|
|
600
|
+
return []
|
|
601
|
+
else:
|
|
602
|
+
index_name = (
|
|
603
|
+
'node_name_and_summary'
|
|
604
|
+
if not use_local_indexes
|
|
605
|
+
else 'node_name_and_summary_'
|
|
606
|
+
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
|
607
|
+
)
|
|
608
|
+
query = (
|
|
609
|
+
get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
|
|
610
|
+
+ yield_query
|
|
611
|
+
+ filter_query
|
|
612
|
+
+ """
|
|
613
|
+
WITH n, score
|
|
614
|
+
ORDER BY score DESC
|
|
615
|
+
LIMIT $limit
|
|
616
|
+
RETURN
|
|
617
|
+
"""
|
|
618
|
+
+ get_entity_node_return_query(driver.provider)
|
|
619
|
+
)
|
|
336
620
|
|
|
337
|
-
|
|
621
|
+
records, _, _ = await driver.execute_query(
|
|
622
|
+
query,
|
|
623
|
+
query=fuzzy_query,
|
|
624
|
+
limit=limit,
|
|
625
|
+
routing_='r',
|
|
626
|
+
**filter_params,
|
|
627
|
+
)
|
|
628
|
+
|
|
629
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
338
630
|
|
|
339
631
|
return nodes
|
|
340
632
|
|
|
@@ -346,49 +638,140 @@ async def node_similarity_search(
|
|
|
346
638
|
group_ids: list[str] | None = None,
|
|
347
639
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
348
640
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
641
|
+
use_local_indexes: bool = False,
|
|
349
642
|
) -> list[EntityNode]:
|
|
350
|
-
|
|
351
|
-
|
|
643
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
644
|
+
search_filter, driver.provider
|
|
645
|
+
)
|
|
352
646
|
|
|
353
|
-
group_filter_query: LiteralString = 'WHERE n.group_id IS NOT NULL'
|
|
354
647
|
if group_ids is not None:
|
|
355
|
-
|
|
356
|
-
|
|
648
|
+
filter_queries.append('n.group_id IN $group_ids')
|
|
649
|
+
filter_params['group_ids'] = group_ids
|
|
650
|
+
|
|
651
|
+
filter_query = ''
|
|
652
|
+
if filter_queries:
|
|
653
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
654
|
+
|
|
655
|
+
search_vector_var = '$search_vector'
|
|
656
|
+
if driver.provider == GraphProvider.KUZU:
|
|
657
|
+
search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
|
|
658
|
+
|
|
659
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
660
|
+
query = (
|
|
661
|
+
RUNTIME_QUERY
|
|
662
|
+
+ """
|
|
663
|
+
MATCH (n:Entity)
|
|
664
|
+
"""
|
|
665
|
+
+ filter_query
|
|
666
|
+
+ """
|
|
667
|
+
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
668
|
+
"""
|
|
669
|
+
)
|
|
670
|
+
resp, header, _ = await driver.execute_query(
|
|
671
|
+
query,
|
|
672
|
+
params=filter_params,
|
|
673
|
+
search_vector=search_vector,
|
|
674
|
+
limit=limit,
|
|
675
|
+
min_score=min_score,
|
|
676
|
+
routing_='r',
|
|
677
|
+
)
|
|
357
678
|
|
|
358
|
-
|
|
359
|
-
|
|
679
|
+
if len(resp) > 0:
|
|
680
|
+
# Calculate Cosine similarity then return the edge ids
|
|
681
|
+
input_ids = []
|
|
682
|
+
for r in resp:
|
|
683
|
+
if r['embedding']:
|
|
684
|
+
score = calculate_cosine_similarity(
|
|
685
|
+
search_vector, list(map(float, r['embedding'].split(',')))
|
|
686
|
+
)
|
|
687
|
+
if score > min_score:
|
|
688
|
+
input_ids.append({'id': r['id'], 'score': score})
|
|
689
|
+
|
|
690
|
+
# Match the edge ides and return the values
|
|
691
|
+
query = (
|
|
692
|
+
"""
|
|
693
|
+
UNWIND $ids as i
|
|
694
|
+
MATCH (n:Entity)
|
|
695
|
+
WHERE id(n)=i.id
|
|
696
|
+
RETURN
|
|
697
|
+
"""
|
|
698
|
+
+ get_entity_node_return_query(driver.provider)
|
|
699
|
+
+ """
|
|
700
|
+
ORDER BY i.score DESC
|
|
701
|
+
LIMIT $limit
|
|
702
|
+
"""
|
|
703
|
+
)
|
|
704
|
+
records, header, _ = await driver.execute_query(
|
|
705
|
+
query,
|
|
706
|
+
ids=input_ids,
|
|
707
|
+
search_vector=search_vector,
|
|
708
|
+
limit=limit,
|
|
709
|
+
min_score=min_score,
|
|
710
|
+
routing_='r',
|
|
711
|
+
**filter_params,
|
|
712
|
+
)
|
|
713
|
+
else:
|
|
714
|
+
return []
|
|
715
|
+
elif driver.provider == GraphProvider.NEO4J and use_local_indexes:
|
|
716
|
+
index_name = 'group_entity_vector_' + (
|
|
717
|
+
group_ids[0].replace('-', '') if group_ids is not None else ''
|
|
718
|
+
)
|
|
719
|
+
query = (
|
|
720
|
+
f"""
|
|
721
|
+
CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score
|
|
722
|
+
"""
|
|
723
|
+
+ filter_query
|
|
724
|
+
+ """
|
|
725
|
+
AND score > $min_score
|
|
726
|
+
RETURN
|
|
727
|
+
"""
|
|
728
|
+
+ get_entity_node_return_query(driver.provider)
|
|
729
|
+
+ """
|
|
730
|
+
ORDER BY score DESC
|
|
731
|
+
LIMIT $limit
|
|
732
|
+
"""
|
|
733
|
+
)
|
|
360
734
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
WITH n, """
|
|
370
|
-
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
|
371
|
-
+ """ AS score
|
|
372
|
-
WHERE score > $min_score
|
|
373
|
-
RETURN
|
|
374
|
-
"""
|
|
375
|
-
+ ENTITY_NODE_RETURN
|
|
376
|
-
+ """
|
|
377
|
-
ORDER BY score DESC
|
|
378
|
-
LIMIT $limit
|
|
379
|
-
"""
|
|
380
|
-
)
|
|
735
|
+
records, _, _ = await driver.execute_query(
|
|
736
|
+
query,
|
|
737
|
+
search_vector=search_vector,
|
|
738
|
+
limit=limit,
|
|
739
|
+
min_score=min_score,
|
|
740
|
+
routing_='r',
|
|
741
|
+
**filter_params,
|
|
742
|
+
)
|
|
381
743
|
|
|
382
|
-
|
|
383
|
-
query
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
744
|
+
else:
|
|
745
|
+
query = (
|
|
746
|
+
RUNTIME_QUERY
|
|
747
|
+
+ """
|
|
748
|
+
MATCH (n:Entity)
|
|
749
|
+
"""
|
|
750
|
+
+ filter_query
|
|
751
|
+
+ """
|
|
752
|
+
WITH n, """
|
|
753
|
+
+ get_vector_cosine_func_query('n.name_embedding', search_vector_var, driver.provider)
|
|
754
|
+
+ """ AS score
|
|
755
|
+
WHERE score > $min_score
|
|
756
|
+
RETURN
|
|
757
|
+
"""
|
|
758
|
+
+ get_entity_node_return_query(driver.provider)
|
|
759
|
+
+ """
|
|
760
|
+
ORDER BY score DESC
|
|
761
|
+
LIMIT $limit
|
|
762
|
+
"""
|
|
763
|
+
)
|
|
764
|
+
|
|
765
|
+
records, _, _ = await driver.execute_query(
|
|
766
|
+
query,
|
|
767
|
+
search_vector=search_vector,
|
|
768
|
+
limit=limit,
|
|
769
|
+
min_score=min_score,
|
|
770
|
+
routing_='r',
|
|
771
|
+
**filter_params,
|
|
772
|
+
)
|
|
390
773
|
|
|
391
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
774
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
392
775
|
|
|
393
776
|
return nodes
|
|
394
777
|
|
|
@@ -401,38 +784,82 @@ async def node_bfs_search(
|
|
|
401
784
|
group_ids: list[str] | None = None,
|
|
402
785
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
403
786
|
) -> list[EntityNode]:
|
|
404
|
-
|
|
405
|
-
if bfs_origin_node_uuids is None:
|
|
787
|
+
if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0 or bfs_max_depth < 1:
|
|
406
788
|
return []
|
|
407
789
|
|
|
408
|
-
|
|
790
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
791
|
+
search_filter, driver.provider
|
|
792
|
+
)
|
|
793
|
+
|
|
794
|
+
if group_ids is not None:
|
|
795
|
+
filter_queries.append('n.group_id IN $group_ids')
|
|
796
|
+
filter_queries.append('origin.group_id IN $group_ids')
|
|
797
|
+
filter_params['group_ids'] = group_ids
|
|
798
|
+
|
|
799
|
+
filter_query = ''
|
|
800
|
+
if filter_queries:
|
|
801
|
+
filter_query = ' AND ' + (' AND '.join(filter_queries))
|
|
409
802
|
|
|
410
|
-
|
|
803
|
+
match_queries = [
|
|
411
804
|
f"""
|
|
805
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
806
|
+
MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
|
|
807
|
+
WHERE n.group_id = origin.group_id
|
|
808
|
+
"""
|
|
809
|
+
]
|
|
810
|
+
|
|
811
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
812
|
+
match_queries = [
|
|
813
|
+
f"""
|
|
412
814
|
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
413
|
-
MATCH (origin
|
|
815
|
+
MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
|
|
816
|
+
WHERE origin:Entity OR origin.Episode
|
|
817
|
+
AND n.group_id = origin.group_id
|
|
818
|
+
"""
|
|
819
|
+
]
|
|
820
|
+
|
|
821
|
+
if driver.provider == GraphProvider.KUZU:
|
|
822
|
+
depth = bfs_max_depth * 2
|
|
823
|
+
match_queries = [
|
|
824
|
+
"""
|
|
825
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
826
|
+
MATCH (origin:Episodic {uuid: origin_uuid})-[:MENTIONS]->(n:Entity)
|
|
414
827
|
WHERE n.group_id = origin.group_id
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
828
|
+
""",
|
|
829
|
+
f"""
|
|
830
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
831
|
+
MATCH (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*2..{depth}]->(n:Entity)
|
|
832
|
+
WHERE n.group_id = origin.group_id
|
|
833
|
+
""",
|
|
834
|
+
]
|
|
835
|
+
if bfs_max_depth > 1:
|
|
836
|
+
depth = (bfs_max_depth - 1) * 2
|
|
837
|
+
match_queries.append(f"""
|
|
838
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
839
|
+
MATCH (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*2..{depth}]->(n:Entity)
|
|
840
|
+
WHERE n.group_id = origin.group_id
|
|
841
|
+
""")
|
|
842
|
+
|
|
843
|
+
records = []
|
|
844
|
+
for match_query in match_queries:
|
|
845
|
+
sub_records, _, _ = await driver.execute_query(
|
|
846
|
+
match_query
|
|
847
|
+
+ filter_query
|
|
848
|
+
+ """
|
|
849
|
+
RETURN
|
|
850
|
+
"""
|
|
851
|
+
+ get_entity_node_return_query(driver.provider)
|
|
852
|
+
+ """
|
|
853
|
+
LIMIT $limit
|
|
854
|
+
""",
|
|
855
|
+
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
856
|
+
limit=limit,
|
|
857
|
+
routing_='r',
|
|
858
|
+
**filter_params,
|
|
859
|
+
)
|
|
860
|
+
records.extend(sub_records)
|
|
426
861
|
|
|
427
|
-
|
|
428
|
-
query,
|
|
429
|
-
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
430
|
-
group_ids=group_ids,
|
|
431
|
-
limit=limit,
|
|
432
|
-
routing_='r',
|
|
433
|
-
**filter_params,
|
|
434
|
-
)
|
|
435
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
862
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
436
863
|
|
|
437
864
|
return nodes
|
|
438
865
|
|
|
@@ -443,35 +870,84 @@ async def episode_fulltext_search(
|
|
|
443
870
|
_search_filter: SearchFilters,
|
|
444
871
|
group_ids: list[str] | None = None,
|
|
445
872
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
873
|
+
use_local_indexes: bool = False,
|
|
446
874
|
) -> list[EpisodicNode]:
|
|
447
875
|
# BM25 search to get top episodes
|
|
448
|
-
fuzzy_query = fulltext_query(query, group_ids, driver
|
|
876
|
+
fuzzy_query = fulltext_query(query, group_ids, driver)
|
|
449
877
|
if fuzzy_query == '':
|
|
450
878
|
return []
|
|
451
879
|
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
880
|
+
filter_params: dict[str, Any] = {}
|
|
881
|
+
group_filter_query: LiteralString = ''
|
|
882
|
+
if group_ids is not None:
|
|
883
|
+
group_filter_query += '\nAND e.group_id IN $group_ids'
|
|
884
|
+
filter_params['group_ids'] = group_ids
|
|
885
|
+
|
|
886
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
887
|
+
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
888
|
+
if res['hits']['total']['value'] > 0:
|
|
889
|
+
# Calculate Cosine similarity then return the edge ids
|
|
890
|
+
input_ids = []
|
|
891
|
+
for r in res['hits']['hits']:
|
|
892
|
+
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
893
|
+
|
|
894
|
+
# Match the edge ides and return the values
|
|
895
|
+
query = """
|
|
896
|
+
UNWIND $ids as i
|
|
897
|
+
MATCH (e:Episodic)
|
|
898
|
+
WHERE e.uuid=i.id
|
|
899
|
+
RETURN
|
|
900
|
+
e.content AS content,
|
|
901
|
+
e.created_at AS created_at,
|
|
902
|
+
e.valid_at AS valid_at,
|
|
903
|
+
e.uuid AS uuid,
|
|
904
|
+
e.name AS name,
|
|
905
|
+
e.group_id AS group_id,
|
|
906
|
+
e.source_description AS source_description,
|
|
907
|
+
e.source AS source,
|
|
908
|
+
e.entity_edges AS entity_edges
|
|
909
|
+
ORDER BY i.score DESC
|
|
910
|
+
LIMIT $limit
|
|
911
|
+
"""
|
|
912
|
+
records, _, _ = await driver.execute_query(
|
|
913
|
+
query,
|
|
914
|
+
ids=input_ids,
|
|
915
|
+
query=fuzzy_query,
|
|
916
|
+
limit=limit,
|
|
917
|
+
routing_='r',
|
|
918
|
+
**filter_params,
|
|
919
|
+
)
|
|
920
|
+
else:
|
|
921
|
+
return []
|
|
922
|
+
else:
|
|
923
|
+
index_name = (
|
|
924
|
+
'episode_content'
|
|
925
|
+
if not use_local_indexes
|
|
926
|
+
else 'episode_content_'
|
|
927
|
+
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
|
928
|
+
)
|
|
929
|
+
query = (
|
|
930
|
+
get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
|
|
931
|
+
+ """
|
|
932
|
+
YIELD node AS episode, score
|
|
933
|
+
MATCH (e:Episodic)
|
|
934
|
+
WHERE e.uuid = episode.uuid
|
|
935
|
+
"""
|
|
936
|
+
+ group_filter_query
|
|
937
|
+
+ """
|
|
938
|
+
RETURN
|
|
939
|
+
"""
|
|
940
|
+
+ EPISODIC_NODE_RETURN
|
|
941
|
+
+ """
|
|
942
|
+
ORDER BY score DESC
|
|
943
|
+
LIMIT $limit
|
|
944
|
+
"""
|
|
945
|
+
)
|
|
946
|
+
|
|
947
|
+
records, _, _ = await driver.execute_query(
|
|
948
|
+
query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
|
|
949
|
+
)
|
|
467
950
|
|
|
468
|
-
records, _, _ = await driver.execute_query(
|
|
469
|
-
query,
|
|
470
|
-
query=fuzzy_query,
|
|
471
|
-
group_ids=group_ids,
|
|
472
|
-
limit=limit,
|
|
473
|
-
routing_='r',
|
|
474
|
-
)
|
|
475
951
|
episodes = [get_episodic_node_from_record(record) for record in records]
|
|
476
952
|
|
|
477
953
|
return episodes
|
|
@@ -484,31 +960,75 @@ async def community_fulltext_search(
|
|
|
484
960
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
485
961
|
) -> list[CommunityNode]:
|
|
486
962
|
# BM25 search to get top communities
|
|
487
|
-
fuzzy_query = fulltext_query(query, group_ids, driver
|
|
963
|
+
fuzzy_query = fulltext_query(query, group_ids, driver)
|
|
488
964
|
if fuzzy_query == '':
|
|
489
965
|
return []
|
|
490
966
|
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
967
|
+
filter_params: dict[str, Any] = {}
|
|
968
|
+
group_filter_query: LiteralString = ''
|
|
969
|
+
if group_ids is not None:
|
|
970
|
+
group_filter_query = 'WHERE c.group_id IN $group_ids'
|
|
971
|
+
filter_params['group_ids'] = group_ids
|
|
972
|
+
|
|
973
|
+
yield_query = 'YIELD node AS c, score'
|
|
974
|
+
if driver.provider == GraphProvider.KUZU:
|
|
975
|
+
yield_query = 'WITH node AS c, score'
|
|
976
|
+
|
|
977
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
978
|
+
res = driver.run_aoss_query('community_name', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
979
|
+
if res['hits']['total']['value'] > 0:
|
|
980
|
+
# Calculate Cosine similarity then return the edge ids
|
|
981
|
+
input_ids = []
|
|
982
|
+
for r in res['hits']['hits']:
|
|
983
|
+
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
984
|
+
|
|
985
|
+
# Match the edge ides and return the values
|
|
986
|
+
query = """
|
|
987
|
+
UNWIND $ids as i
|
|
988
|
+
MATCH (comm:Community)
|
|
989
|
+
WHERE comm.uuid=i.id
|
|
990
|
+
RETURN
|
|
991
|
+
comm.uuid AS uuid,
|
|
992
|
+
comm.group_id AS group_id,
|
|
993
|
+
comm.name AS name,
|
|
994
|
+
comm.created_at AS created_at,
|
|
995
|
+
comm.summary AS summary,
|
|
996
|
+
[x IN split(comm.name_embedding, ",") | toFloat(x)]AS name_embedding
|
|
997
|
+
ORDER BY i.score DESC
|
|
998
|
+
LIMIT $limit
|
|
999
|
+
"""
|
|
1000
|
+
records, _, _ = await driver.execute_query(
|
|
1001
|
+
query,
|
|
1002
|
+
ids=input_ids,
|
|
1003
|
+
query=fuzzy_query,
|
|
1004
|
+
limit=limit,
|
|
1005
|
+
routing_='r',
|
|
1006
|
+
**filter_params,
|
|
1007
|
+
)
|
|
1008
|
+
else:
|
|
1009
|
+
return []
|
|
1010
|
+
else:
|
|
1011
|
+
query = (
|
|
1012
|
+
get_nodes_query('community_name', '$query', limit=limit, provider=driver.provider)
|
|
1013
|
+
+ yield_query
|
|
1014
|
+
+ """
|
|
1015
|
+
WITH c, score
|
|
1016
|
+
"""
|
|
1017
|
+
+ group_filter_query
|
|
1018
|
+
+ """
|
|
1019
|
+
RETURN
|
|
1020
|
+
"""
|
|
1021
|
+
+ COMMUNITY_NODE_RETURN
|
|
1022
|
+
+ """
|
|
1023
|
+
ORDER BY score DESC
|
|
1024
|
+
LIMIT $limit
|
|
1025
|
+
"""
|
|
1026
|
+
)
|
|
1027
|
+
|
|
1028
|
+
records, _, _ = await driver.execute_query(
|
|
1029
|
+
query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
|
|
1030
|
+
)
|
|
504
1031
|
|
|
505
|
-
records, _, _ = await driver.execute_query(
|
|
506
|
-
query,
|
|
507
|
-
query=fuzzy_query,
|
|
508
|
-
group_ids=group_ids,
|
|
509
|
-
limit=limit,
|
|
510
|
-
routing_='r',
|
|
511
|
-
)
|
|
512
1032
|
communities = [get_community_node_from_record(record) for record in records]
|
|
513
1033
|
|
|
514
1034
|
return communities
|
|
@@ -526,38 +1046,101 @@ async def community_similarity_search(
|
|
|
526
1046
|
|
|
527
1047
|
group_filter_query: LiteralString = ''
|
|
528
1048
|
if group_ids is not None:
|
|
529
|
-
group_filter_query += 'WHERE
|
|
1049
|
+
group_filter_query += ' WHERE c.group_id IN $group_ids'
|
|
530
1050
|
query_params['group_ids'] = group_ids
|
|
531
1051
|
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
1052
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1053
|
+
query = (
|
|
1054
|
+
RUNTIME_QUERY
|
|
1055
|
+
+ """
|
|
1056
|
+
MATCH (n:Community)
|
|
1057
|
+
"""
|
|
1058
|
+
+ group_filter_query
|
|
1059
|
+
+ """
|
|
1060
|
+
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
1061
|
+
"""
|
|
1062
|
+
)
|
|
1063
|
+
resp, header, _ = await driver.execute_query(
|
|
1064
|
+
query,
|
|
1065
|
+
search_vector=search_vector,
|
|
1066
|
+
limit=limit,
|
|
1067
|
+
min_score=min_score,
|
|
1068
|
+
routing_='r',
|
|
1069
|
+
**query_params,
|
|
1070
|
+
)
|
|
1071
|
+
|
|
1072
|
+
if len(resp) > 0:
|
|
1073
|
+
# Calculate Cosine similarity then return the edge ids
|
|
1074
|
+
input_ids = []
|
|
1075
|
+
for r in resp:
|
|
1076
|
+
if r['embedding']:
|
|
1077
|
+
score = calculate_cosine_similarity(
|
|
1078
|
+
search_vector, list(map(float, r['embedding'].split(',')))
|
|
1079
|
+
)
|
|
1080
|
+
if score > min_score:
|
|
1081
|
+
input_ids.append({'id': r['id'], 'score': score})
|
|
1082
|
+
|
|
1083
|
+
# Match the edge ides and return the values
|
|
1084
|
+
query = """
|
|
1085
|
+
UNWIND $ids as i
|
|
1086
|
+
MATCH (comm:Community)
|
|
1087
|
+
WHERE id(comm)=i.id
|
|
1088
|
+
RETURN
|
|
1089
|
+
comm.uuid As uuid,
|
|
1090
|
+
comm.group_id AS group_id,
|
|
1091
|
+
comm.name AS name,
|
|
1092
|
+
comm.created_at AS created_at,
|
|
1093
|
+
comm.summary AS summary,
|
|
1094
|
+
comm.name_embedding AS name_embedding
|
|
1095
|
+
ORDER BY i.score DESC
|
|
1096
|
+
LIMIT $limit
|
|
1097
|
+
"""
|
|
1098
|
+
records, header, _ = await driver.execute_query(
|
|
1099
|
+
query,
|
|
1100
|
+
ids=input_ids,
|
|
1101
|
+
search_vector=search_vector,
|
|
1102
|
+
limit=limit,
|
|
1103
|
+
min_score=min_score,
|
|
1104
|
+
routing_='r',
|
|
1105
|
+
**query_params,
|
|
1106
|
+
)
|
|
1107
|
+
else:
|
|
1108
|
+
return []
|
|
1109
|
+
else:
|
|
1110
|
+
search_vector_var = '$search_vector'
|
|
1111
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1112
|
+
search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
|
|
1113
|
+
|
|
1114
|
+
query = (
|
|
1115
|
+
RUNTIME_QUERY
|
|
1116
|
+
+ """
|
|
1117
|
+
MATCH (c:Community)
|
|
1118
|
+
"""
|
|
1119
|
+
+ group_filter_query
|
|
1120
|
+
+ """
|
|
1121
|
+
WITH c,
|
|
1122
|
+
"""
|
|
1123
|
+
+ get_vector_cosine_func_query('c.name_embedding', search_vector_var, driver.provider)
|
|
1124
|
+
+ """ AS score
|
|
1125
|
+
WHERE score > $min_score
|
|
1126
|
+
RETURN
|
|
1127
|
+
"""
|
|
1128
|
+
+ COMMUNITY_NODE_RETURN
|
|
1129
|
+
+ """
|
|
1130
|
+
ORDER BY score DESC
|
|
1131
|
+
LIMIT $limit
|
|
1132
|
+
"""
|
|
1133
|
+
)
|
|
1134
|
+
|
|
1135
|
+
records, _, _ = await driver.execute_query(
|
|
1136
|
+
query,
|
|
1137
|
+
search_vector=search_vector,
|
|
1138
|
+
limit=limit,
|
|
1139
|
+
min_score=min_score,
|
|
1140
|
+
routing_='r',
|
|
1141
|
+
**query_params,
|
|
1142
|
+
)
|
|
552
1143
|
|
|
553
|
-
records, _, _ = await driver.execute_query(
|
|
554
|
-
query,
|
|
555
|
-
search_vector=search_vector,
|
|
556
|
-
limit=limit,
|
|
557
|
-
min_score=min_score,
|
|
558
|
-
routing_='r',
|
|
559
|
-
**query_params,
|
|
560
|
-
)
|
|
561
1144
|
communities = [get_community_node_from_record(record) for record in records]
|
|
562
1145
|
|
|
563
1146
|
return communities
|
|
@@ -648,67 +1231,129 @@ async def get_relevant_nodes(
|
|
|
648
1231
|
return []
|
|
649
1232
|
|
|
650
1233
|
group_id = nodes[0].group_id
|
|
651
|
-
|
|
652
|
-
# vector similarity search over entity names
|
|
653
|
-
query_params: dict[str, Any] = {}
|
|
654
|
-
|
|
655
|
-
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
656
|
-
query_params.update(filter_params)
|
|
657
|
-
|
|
658
|
-
query = (
|
|
659
|
-
RUNTIME_QUERY
|
|
660
|
-
+ """
|
|
661
|
-
UNWIND $nodes AS node
|
|
662
|
-
MATCH (n:Entity {group_id: $group_id})
|
|
663
|
-
"""
|
|
664
|
-
+ filter_query
|
|
665
|
-
+ """
|
|
666
|
-
WITH node, n, """
|
|
667
|
-
+ get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
|
|
668
|
-
+ """ AS score
|
|
669
|
-
WHERE score > $min_score
|
|
670
|
-
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
671
|
-
"""
|
|
672
|
-
+ get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
|
|
673
|
-
+ """
|
|
674
|
-
YIELD node AS m
|
|
675
|
-
WHERE m.group_id = $group_id
|
|
676
|
-
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
|
677
|
-
|
|
678
|
-
WITH node,
|
|
679
|
-
top_vector_nodes,
|
|
680
|
-
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
|
681
|
-
|
|
682
|
-
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
|
683
|
-
|
|
684
|
-
UNWIND combined_nodes AS combined_node
|
|
685
|
-
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
|
686
|
-
|
|
687
|
-
RETURN
|
|
688
|
-
node.uuid AS search_node_uuid,
|
|
689
|
-
[x IN deduped_nodes | {
|
|
690
|
-
uuid: x.uuid,
|
|
691
|
-
name: x.name,
|
|
692
|
-
name_embedding: x.name_embedding,
|
|
693
|
-
group_id: x.group_id,
|
|
694
|
-
created_at: x.created_at,
|
|
695
|
-
summary: x.summary,
|
|
696
|
-
labels: labels(x),
|
|
697
|
-
attributes: properties(x)
|
|
698
|
-
}] AS matches
|
|
699
|
-
"""
|
|
700
|
-
)
|
|
701
|
-
|
|
702
1234
|
query_nodes = [
|
|
703
1235
|
{
|
|
704
1236
|
'uuid': node.uuid,
|
|
705
1237
|
'name': node.name,
|
|
706
1238
|
'name_embedding': node.name_embedding,
|
|
707
|
-
'fulltext_query': fulltext_query(node.name, [node.group_id], driver
|
|
1239
|
+
'fulltext_query': fulltext_query(node.name, [node.group_id], driver),
|
|
708
1240
|
}
|
|
709
1241
|
for node in nodes
|
|
710
1242
|
]
|
|
711
1243
|
|
|
1244
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
1245
|
+
search_filter, driver.provider
|
|
1246
|
+
)
|
|
1247
|
+
|
|
1248
|
+
filter_query = ''
|
|
1249
|
+
if filter_queries:
|
|
1250
|
+
filter_query = 'WHERE ' + (' AND '.join(filter_queries))
|
|
1251
|
+
|
|
1252
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1253
|
+
embedding_size = len(nodes[0].name_embedding) if nodes[0].name_embedding is not None else 0
|
|
1254
|
+
if embedding_size == 0:
|
|
1255
|
+
return []
|
|
1256
|
+
|
|
1257
|
+
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
|
1258
|
+
query = (
|
|
1259
|
+
RUNTIME_QUERY
|
|
1260
|
+
+ """
|
|
1261
|
+
UNWIND $nodes AS node
|
|
1262
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
1263
|
+
"""
|
|
1264
|
+
+ filter_query
|
|
1265
|
+
+ """
|
|
1266
|
+
WITH node, n, """
|
|
1267
|
+
+ get_vector_cosine_func_query(
|
|
1268
|
+
'n.name_embedding',
|
|
1269
|
+
f'CAST(node.name_embedding AS FLOAT[{embedding_size}])',
|
|
1270
|
+
driver.provider,
|
|
1271
|
+
)
|
|
1272
|
+
+ """ AS score
|
|
1273
|
+
WHERE score > $min_score
|
|
1274
|
+
WITH node, collect(n)[:$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
1275
|
+
"""
|
|
1276
|
+
+ get_nodes_query(
|
|
1277
|
+
'node_name_and_summary',
|
|
1278
|
+
'node.fulltext_query',
|
|
1279
|
+
limit=limit,
|
|
1280
|
+
provider=driver.provider,
|
|
1281
|
+
)
|
|
1282
|
+
+ """
|
|
1283
|
+
WITH node AS m
|
|
1284
|
+
WHERE m.group_id = $group_id AND NOT m.uuid IN vector_node_uuids
|
|
1285
|
+
WITH node, top_vector_nodes, collect(m) AS fulltext_nodes
|
|
1286
|
+
|
|
1287
|
+
WITH node, list_concat(top_vector_nodes, fulltext_nodes) AS combined_nodes
|
|
1288
|
+
|
|
1289
|
+
UNWIND combined_nodes AS x
|
|
1290
|
+
WITH node, collect(DISTINCT {
|
|
1291
|
+
uuid: x.uuid,
|
|
1292
|
+
name: x.name,
|
|
1293
|
+
name_embedding: x.name_embedding,
|
|
1294
|
+
group_id: x.group_id,
|
|
1295
|
+
created_at: x.created_at,
|
|
1296
|
+
summary: x.summary,
|
|
1297
|
+
labels: x.labels,
|
|
1298
|
+
attributes: x.attributes
|
|
1299
|
+
}) AS matches
|
|
1300
|
+
|
|
1301
|
+
RETURN
|
|
1302
|
+
node.uuid AS search_node_uuid, matches
|
|
1303
|
+
"""
|
|
1304
|
+
)
|
|
1305
|
+
else:
|
|
1306
|
+
query = (
|
|
1307
|
+
RUNTIME_QUERY
|
|
1308
|
+
+ """
|
|
1309
|
+
UNWIND $nodes AS node
|
|
1310
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
1311
|
+
"""
|
|
1312
|
+
+ filter_query
|
|
1313
|
+
+ """
|
|
1314
|
+
WITH node, n, """
|
|
1315
|
+
+ get_vector_cosine_func_query(
|
|
1316
|
+
'n.name_embedding', 'node.name_embedding', driver.provider
|
|
1317
|
+
)
|
|
1318
|
+
+ """ AS score
|
|
1319
|
+
WHERE score > $min_score
|
|
1320
|
+
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
1321
|
+
"""
|
|
1322
|
+
+ get_nodes_query(
|
|
1323
|
+
'node_name_and_summary',
|
|
1324
|
+
'node.fulltext_query',
|
|
1325
|
+
limit=limit,
|
|
1326
|
+
provider=driver.provider,
|
|
1327
|
+
)
|
|
1328
|
+
+ """
|
|
1329
|
+
YIELD node AS m
|
|
1330
|
+
WHERE m.group_id = $group_id
|
|
1331
|
+
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
|
1332
|
+
|
|
1333
|
+
WITH node,
|
|
1334
|
+
top_vector_nodes,
|
|
1335
|
+
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
|
1336
|
+
|
|
1337
|
+
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
|
1338
|
+
|
|
1339
|
+
UNWIND combined_nodes AS combined_node
|
|
1340
|
+
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
|
1341
|
+
|
|
1342
|
+
RETURN
|
|
1343
|
+
node.uuid AS search_node_uuid,
|
|
1344
|
+
[x IN deduped_nodes | {
|
|
1345
|
+
uuid: x.uuid,
|
|
1346
|
+
name: x.name,
|
|
1347
|
+
name_embedding: x.name_embedding,
|
|
1348
|
+
group_id: x.group_id,
|
|
1349
|
+
created_at: x.created_at,
|
|
1350
|
+
summary: x.summary,
|
|
1351
|
+
labels: labels(x),
|
|
1352
|
+
attributes: properties(x)
|
|
1353
|
+
}] AS matches
|
|
1354
|
+
"""
|
|
1355
|
+
)
|
|
1356
|
+
|
|
712
1357
|
results, _, _ = await driver.execute_query(
|
|
713
1358
|
query,
|
|
714
1359
|
nodes=query_nodes,
|
|
@@ -716,12 +1361,12 @@ async def get_relevant_nodes(
|
|
|
716
1361
|
limit=limit,
|
|
717
1362
|
min_score=min_score,
|
|
718
1363
|
routing_='r',
|
|
719
|
-
**
|
|
1364
|
+
**filter_params,
|
|
720
1365
|
)
|
|
721
1366
|
|
|
722
1367
|
relevant_nodes_dict: dict[str, list[EntityNode]] = {
|
|
723
1368
|
result['search_node_uuid']: [
|
|
724
|
-
get_entity_node_from_record(record) for record in result['matches']
|
|
1369
|
+
get_entity_node_from_record(record, driver.provider) for record in result['matches']
|
|
725
1370
|
]
|
|
726
1371
|
for result in results
|
|
727
1372
|
}
|
|
@@ -741,25 +1386,53 @@ async def get_relevant_edges(
|
|
|
741
1386
|
if len(edges) == 0:
|
|
742
1387
|
return []
|
|
743
1388
|
|
|
744
|
-
|
|
1389
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
1390
|
+
search_filter, driver.provider
|
|
1391
|
+
)
|
|
745
1392
|
|
|
746
|
-
filter_query
|
|
747
|
-
|
|
1393
|
+
filter_query = ''
|
|
1394
|
+
if filter_queries:
|
|
1395
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
1396
|
+
|
|
1397
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1398
|
+
query = (
|
|
1399
|
+
RUNTIME_QUERY
|
|
1400
|
+
+ """
|
|
1401
|
+
UNWIND $edges AS edge
|
|
1402
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1403
|
+
"""
|
|
1404
|
+
+ filter_query
|
|
1405
|
+
+ """
|
|
1406
|
+
WITH e, edge
|
|
1407
|
+
RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
|
|
1408
|
+
edge.fact_embedding as target_embedding
|
|
1409
|
+
"""
|
|
1410
|
+
)
|
|
1411
|
+
resp, _, _ = await driver.execute_query(
|
|
1412
|
+
query,
|
|
1413
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1414
|
+
limit=limit,
|
|
1415
|
+
min_score=min_score,
|
|
1416
|
+
routing_='r',
|
|
1417
|
+
**filter_params,
|
|
1418
|
+
)
|
|
748
1419
|
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
1420
|
+
# Calculate Cosine similarity then return the edge ids
|
|
1421
|
+
input_ids = []
|
|
1422
|
+
for r in resp:
|
|
1423
|
+
score = calculate_cosine_similarity(
|
|
1424
|
+
list(map(float, r['source_embedding'].split(','))), r['target_embedding']
|
|
1425
|
+
)
|
|
1426
|
+
if score > min_score:
|
|
1427
|
+
input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
|
|
1428
|
+
|
|
1429
|
+
# Match the edge ides and return the values
|
|
1430
|
+
query = """
|
|
1431
|
+
UNWIND $ids AS edge
|
|
1432
|
+
MATCH ()-[e]->()
|
|
1433
|
+
WHERE id(e) = edge.id
|
|
1434
|
+
WITH edge, e
|
|
1435
|
+
ORDER BY edge.score DESC
|
|
763
1436
|
RETURN edge.uuid AS search_edge_uuid,
|
|
764
1437
|
collect({
|
|
765
1438
|
uuid: e.uuid,
|
|
@@ -769,28 +1442,119 @@ async def get_relevant_edges(
|
|
|
769
1442
|
name: e.name,
|
|
770
1443
|
group_id: e.group_id,
|
|
771
1444
|
fact: e.fact,
|
|
772
|
-
fact_embedding: e.fact_embedding,
|
|
773
|
-
episodes: e.episodes,
|
|
1445
|
+
fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
|
|
1446
|
+
episodes: split(e.episodes, ","),
|
|
774
1447
|
expired_at: e.expired_at,
|
|
775
1448
|
valid_at: e.valid_at,
|
|
776
1449
|
invalid_at: e.invalid_at,
|
|
777
1450
|
attributes: properties(e)
|
|
778
1451
|
})[..$limit] AS matches
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
1452
|
+
"""
|
|
1453
|
+
|
|
1454
|
+
results, _, _ = await driver.execute_query(
|
|
1455
|
+
query,
|
|
1456
|
+
ids=input_ids,
|
|
1457
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1458
|
+
limit=limit,
|
|
1459
|
+
min_score=min_score,
|
|
1460
|
+
routing_='r',
|
|
1461
|
+
**filter_params,
|
|
1462
|
+
)
|
|
1463
|
+
else:
|
|
1464
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1465
|
+
embedding_size = (
|
|
1466
|
+
len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
|
|
1467
|
+
)
|
|
1468
|
+
if embedding_size == 0:
|
|
1469
|
+
return []
|
|
1470
|
+
|
|
1471
|
+
query = (
|
|
1472
|
+
RUNTIME_QUERY
|
|
1473
|
+
+ """
|
|
1474
|
+
UNWIND $edges AS edge
|
|
1475
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1476
|
+
"""
|
|
1477
|
+
+ filter_query
|
|
1478
|
+
+ """
|
|
1479
|
+
WITH e, edge, n, m, """
|
|
1480
|
+
+ get_vector_cosine_func_query(
|
|
1481
|
+
'e.fact_embedding',
|
|
1482
|
+
f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
|
|
1483
|
+
driver.provider,
|
|
1484
|
+
)
|
|
1485
|
+
+ """ AS score
|
|
1486
|
+
WHERE score > $min_score
|
|
1487
|
+
WITH e, edge, n, m, score
|
|
1488
|
+
ORDER BY score DESC
|
|
1489
|
+
LIMIT $limit
|
|
1490
|
+
RETURN
|
|
1491
|
+
edge.uuid AS search_edge_uuid,
|
|
1492
|
+
collect({
|
|
1493
|
+
uuid: e.uuid,
|
|
1494
|
+
source_node_uuid: n.uuid,
|
|
1495
|
+
target_node_uuid: m.uuid,
|
|
1496
|
+
created_at: e.created_at,
|
|
1497
|
+
name: e.name,
|
|
1498
|
+
group_id: e.group_id,
|
|
1499
|
+
fact: e.fact,
|
|
1500
|
+
fact_embedding: e.fact_embedding,
|
|
1501
|
+
episodes: e.episodes,
|
|
1502
|
+
expired_at: e.expired_at,
|
|
1503
|
+
valid_at: e.valid_at,
|
|
1504
|
+
invalid_at: e.invalid_at,
|
|
1505
|
+
attributes: e.attributes
|
|
1506
|
+
}) AS matches
|
|
1507
|
+
"""
|
|
1508
|
+
)
|
|
1509
|
+
else:
|
|
1510
|
+
query = (
|
|
1511
|
+
RUNTIME_QUERY
|
|
1512
|
+
+ """
|
|
1513
|
+
UNWIND $edges AS edge
|
|
1514
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1515
|
+
"""
|
|
1516
|
+
+ filter_query
|
|
1517
|
+
+ """
|
|
1518
|
+
WITH e, edge, """
|
|
1519
|
+
+ get_vector_cosine_func_query(
|
|
1520
|
+
'e.fact_embedding', 'edge.fact_embedding', driver.provider
|
|
1521
|
+
)
|
|
1522
|
+
+ """ AS score
|
|
1523
|
+
WHERE score > $min_score
|
|
1524
|
+
WITH edge, e, score
|
|
1525
|
+
ORDER BY score DESC
|
|
1526
|
+
RETURN
|
|
1527
|
+
edge.uuid AS search_edge_uuid,
|
|
1528
|
+
collect({
|
|
1529
|
+
uuid: e.uuid,
|
|
1530
|
+
source_node_uuid: startNode(e).uuid,
|
|
1531
|
+
target_node_uuid: endNode(e).uuid,
|
|
1532
|
+
created_at: e.created_at,
|
|
1533
|
+
name: e.name,
|
|
1534
|
+
group_id: e.group_id,
|
|
1535
|
+
fact: e.fact,
|
|
1536
|
+
fact_embedding: e.fact_embedding,
|
|
1537
|
+
episodes: e.episodes,
|
|
1538
|
+
expired_at: e.expired_at,
|
|
1539
|
+
valid_at: e.valid_at,
|
|
1540
|
+
invalid_at: e.invalid_at,
|
|
1541
|
+
attributes: properties(e)
|
|
1542
|
+
})[..$limit] AS matches
|
|
1543
|
+
"""
|
|
1544
|
+
)
|
|
1545
|
+
|
|
1546
|
+
results, _, _ = await driver.execute_query(
|
|
1547
|
+
query,
|
|
1548
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1549
|
+
limit=limit,
|
|
1550
|
+
min_score=min_score,
|
|
1551
|
+
routing_='r',
|
|
1552
|
+
**filter_params,
|
|
1553
|
+
)
|
|
790
1554
|
|
|
791
1555
|
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
|
792
1556
|
result['search_edge_uuid']: [
|
|
793
|
-
get_entity_edge_from_record(record) for record in result['matches']
|
|
1557
|
+
get_entity_edge_from_record(record, driver.provider) for record in result['matches']
|
|
794
1558
|
]
|
|
795
1559
|
for result in results
|
|
796
1560
|
}
|
|
@@ -810,26 +1574,55 @@ async def get_edge_invalidation_candidates(
|
|
|
810
1574
|
if len(edges) == 0:
|
|
811
1575
|
return []
|
|
812
1576
|
|
|
813
|
-
|
|
1577
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
1578
|
+
search_filter, driver.provider
|
|
1579
|
+
)
|
|
814
1580
|
|
|
815
|
-
filter_query
|
|
816
|
-
|
|
1581
|
+
filter_query = ''
|
|
1582
|
+
if filter_queries:
|
|
1583
|
+
filter_query = ' AND ' + (' AND '.join(filter_queries))
|
|
1584
|
+
|
|
1585
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1586
|
+
query = (
|
|
1587
|
+
RUNTIME_QUERY
|
|
1588
|
+
+ """
|
|
1589
|
+
UNWIND $edges AS edge
|
|
1590
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1591
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1592
|
+
"""
|
|
1593
|
+
+ filter_query
|
|
1594
|
+
+ """
|
|
1595
|
+
WITH e, edge
|
|
1596
|
+
RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding,
|
|
1597
|
+
edge.fact_embedding as target_embedding,
|
|
1598
|
+
edge.uuid as search_edge_uuid
|
|
1599
|
+
"""
|
|
1600
|
+
)
|
|
1601
|
+
resp, _, _ = await driver.execute_query(
|
|
1602
|
+
query,
|
|
1603
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1604
|
+
limit=limit,
|
|
1605
|
+
min_score=min_score,
|
|
1606
|
+
routing_='r',
|
|
1607
|
+
**filter_params,
|
|
1608
|
+
)
|
|
817
1609
|
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
1610
|
+
# Calculate Cosine similarity then return the edge ids
|
|
1611
|
+
input_ids = []
|
|
1612
|
+
for r in resp:
|
|
1613
|
+
score = calculate_cosine_similarity(
|
|
1614
|
+
list(map(float, r['source_embedding'].split(','))), r['target_embedding']
|
|
1615
|
+
)
|
|
1616
|
+
if score > min_score:
|
|
1617
|
+
input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
|
|
1618
|
+
|
|
1619
|
+
# Match the edge ides and return the values
|
|
1620
|
+
query = """
|
|
1621
|
+
UNWIND $ids AS edge
|
|
1622
|
+
MATCH ()-[e]->()
|
|
1623
|
+
WHERE id(e) = edge.id
|
|
1624
|
+
WITH edge, e
|
|
1625
|
+
ORDER BY edge.score DESC
|
|
833
1626
|
RETURN edge.uuid AS search_edge_uuid,
|
|
834
1627
|
collect({
|
|
835
1628
|
uuid: e.uuid,
|
|
@@ -839,27 +1632,119 @@ async def get_edge_invalidation_candidates(
|
|
|
839
1632
|
name: e.name,
|
|
840
1633
|
group_id: e.group_id,
|
|
841
1634
|
fact: e.fact,
|
|
842
|
-
fact_embedding: e.fact_embedding,
|
|
843
|
-
episodes: e.episodes,
|
|
1635
|
+
fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
|
|
1636
|
+
episodes: split(e.episodes, ","),
|
|
844
1637
|
expired_at: e.expired_at,
|
|
845
1638
|
valid_at: e.valid_at,
|
|
846
1639
|
invalid_at: e.invalid_at,
|
|
847
1640
|
attributes: properties(e)
|
|
848
1641
|
})[..$limit] AS matches
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
1642
|
+
"""
|
|
1643
|
+
results, _, _ = await driver.execute_query(
|
|
1644
|
+
query,
|
|
1645
|
+
ids=input_ids,
|
|
1646
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1647
|
+
limit=limit,
|
|
1648
|
+
min_score=min_score,
|
|
1649
|
+
routing_='r',
|
|
1650
|
+
**filter_params,
|
|
1651
|
+
)
|
|
1652
|
+
else:
|
|
1653
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1654
|
+
embedding_size = (
|
|
1655
|
+
len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
|
|
1656
|
+
)
|
|
1657
|
+
if embedding_size == 0:
|
|
1658
|
+
return []
|
|
1659
|
+
|
|
1660
|
+
query = (
|
|
1661
|
+
RUNTIME_QUERY
|
|
1662
|
+
+ """
|
|
1663
|
+
UNWIND $edges AS edge
|
|
1664
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
|
1665
|
+
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
|
1666
|
+
"""
|
|
1667
|
+
+ filter_query
|
|
1668
|
+
+ """
|
|
1669
|
+
WITH edge, e, n, m, """
|
|
1670
|
+
+ get_vector_cosine_func_query(
|
|
1671
|
+
'e.fact_embedding',
|
|
1672
|
+
f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
|
|
1673
|
+
driver.provider,
|
|
1674
|
+
)
|
|
1675
|
+
+ """ AS score
|
|
1676
|
+
WHERE score > $min_score
|
|
1677
|
+
WITH edge, e, n, m, score
|
|
1678
|
+
ORDER BY score DESC
|
|
1679
|
+
LIMIT $limit
|
|
1680
|
+
RETURN
|
|
1681
|
+
edge.uuid AS search_edge_uuid,
|
|
1682
|
+
collect({
|
|
1683
|
+
uuid: e.uuid,
|
|
1684
|
+
source_node_uuid: n.uuid,
|
|
1685
|
+
target_node_uuid: m.uuid,
|
|
1686
|
+
created_at: e.created_at,
|
|
1687
|
+
name: e.name,
|
|
1688
|
+
group_id: e.group_id,
|
|
1689
|
+
fact: e.fact,
|
|
1690
|
+
fact_embedding: e.fact_embedding,
|
|
1691
|
+
episodes: e.episodes,
|
|
1692
|
+
expired_at: e.expired_at,
|
|
1693
|
+
valid_at: e.valid_at,
|
|
1694
|
+
invalid_at: e.invalid_at,
|
|
1695
|
+
attributes: e.attributes
|
|
1696
|
+
}) AS matches
|
|
1697
|
+
"""
|
|
1698
|
+
)
|
|
1699
|
+
else:
|
|
1700
|
+
query = (
|
|
1701
|
+
RUNTIME_QUERY
|
|
1702
|
+
+ """
|
|
1703
|
+
UNWIND $edges AS edge
|
|
1704
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1705
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1706
|
+
"""
|
|
1707
|
+
+ filter_query
|
|
1708
|
+
+ """
|
|
1709
|
+
WITH edge, e, """
|
|
1710
|
+
+ get_vector_cosine_func_query(
|
|
1711
|
+
'e.fact_embedding', 'edge.fact_embedding', driver.provider
|
|
1712
|
+
)
|
|
1713
|
+
+ """ AS score
|
|
1714
|
+
WHERE score > $min_score
|
|
1715
|
+
WITH edge, e, score
|
|
1716
|
+
ORDER BY score DESC
|
|
1717
|
+
RETURN
|
|
1718
|
+
edge.uuid AS search_edge_uuid,
|
|
1719
|
+
collect({
|
|
1720
|
+
uuid: e.uuid,
|
|
1721
|
+
source_node_uuid: startNode(e).uuid,
|
|
1722
|
+
target_node_uuid: endNode(e).uuid,
|
|
1723
|
+
created_at: e.created_at,
|
|
1724
|
+
name: e.name,
|
|
1725
|
+
group_id: e.group_id,
|
|
1726
|
+
fact: e.fact,
|
|
1727
|
+
fact_embedding: e.fact_embedding,
|
|
1728
|
+
episodes: e.episodes,
|
|
1729
|
+
expired_at: e.expired_at,
|
|
1730
|
+
valid_at: e.valid_at,
|
|
1731
|
+
invalid_at: e.invalid_at,
|
|
1732
|
+
attributes: properties(e)
|
|
1733
|
+
})[..$limit] AS matches
|
|
1734
|
+
"""
|
|
1735
|
+
)
|
|
1736
|
+
|
|
1737
|
+
results, _, _ = await driver.execute_query(
|
|
1738
|
+
query,
|
|
1739
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1740
|
+
limit=limit,
|
|
1741
|
+
min_score=min_score,
|
|
1742
|
+
routing_='r',
|
|
1743
|
+
**filter_params,
|
|
1744
|
+
)
|
|
860
1745
|
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
|
|
861
1746
|
result['search_edge_uuid']: [
|
|
862
|
-
get_entity_edge_from_record(record) for record in result['matches']
|
|
1747
|
+
get_entity_edge_from_record(record, driver.provider) for record in result['matches']
|
|
863
1748
|
]
|
|
864
1749
|
for result in results
|
|
865
1750
|
}
|
|
@@ -898,13 +1783,21 @@ async def node_distance_reranker(
|
|
|
898
1783
|
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
|
|
899
1784
|
scores: dict[str, float] = {center_node_uuid: 0.0}
|
|
900
1785
|
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
1786
|
+
query = """
|
|
1787
|
+
UNWIND $node_uuids AS node_uuid
|
|
1788
|
+
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
|
1789
|
+
RETURN 1 AS score, node_uuid AS uuid
|
|
1790
|
+
"""
|
|
1791
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1792
|
+
query = """
|
|
904
1793
|
UNWIND $node_uuids AS node_uuid
|
|
905
|
-
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
|
1794
|
+
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(n:Entity {uuid: node_uuid})
|
|
906
1795
|
RETURN 1 AS score, node_uuid AS uuid
|
|
907
|
-
"""
|
|
1796
|
+
"""
|
|
1797
|
+
|
|
1798
|
+
# Find the shortest path to center node
|
|
1799
|
+
results, header, _ = await driver.execute_query(
|
|
1800
|
+
query,
|
|
908
1801
|
node_uuids=filtered_uuids,
|
|
909
1802
|
center_uuid=center_node_uuid,
|
|
910
1803
|
routing_='r',
|
|
@@ -955,6 +1848,10 @@ async def episode_mentions_reranker(
|
|
|
955
1848
|
for result in results:
|
|
956
1849
|
scores[result['uuid']] = result['score']
|
|
957
1850
|
|
|
1851
|
+
for uuid in sorted_uuids:
|
|
1852
|
+
if uuid not in scores:
|
|
1853
|
+
scores[uuid] = float('inf')
|
|
1854
|
+
|
|
958
1855
|
# rerank on shortest distance
|
|
959
1856
|
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
|
960
1857
|
|
|
@@ -1007,14 +1904,24 @@ def maximal_marginal_relevance(
|
|
|
1007
1904
|
async def get_embeddings_for_nodes(
|
|
1008
1905
|
driver: GraphDriver, nodes: list[EntityNode]
|
|
1009
1906
|
) -> dict[str, list[float]]:
|
|
1010
|
-
|
|
1907
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1908
|
+
query = """
|
|
1909
|
+
MATCH (n:Entity)
|
|
1910
|
+
WHERE n.uuid IN $node_uuids
|
|
1911
|
+
RETURN DISTINCT
|
|
1912
|
+
n.uuid AS uuid,
|
|
1913
|
+
split(n.name_embedding, ",") AS name_embedding
|
|
1011
1914
|
"""
|
|
1915
|
+
else:
|
|
1916
|
+
query = """
|
|
1012
1917
|
MATCH (n:Entity)
|
|
1013
1918
|
WHERE n.uuid IN $node_uuids
|
|
1014
1919
|
RETURN DISTINCT
|
|
1015
1920
|
n.uuid AS uuid,
|
|
1016
1921
|
n.name_embedding AS name_embedding
|
|
1017
|
-
"""
|
|
1922
|
+
"""
|
|
1923
|
+
results, _, _ = await driver.execute_query(
|
|
1924
|
+
query,
|
|
1018
1925
|
node_uuids=[node.uuid for node in nodes],
|
|
1019
1926
|
routing_='r',
|
|
1020
1927
|
)
|
|
@@ -1032,14 +1939,24 @@ async def get_embeddings_for_nodes(
|
|
|
1032
1939
|
async def get_embeddings_for_communities(
|
|
1033
1940
|
driver: GraphDriver, communities: list[CommunityNode]
|
|
1034
1941
|
) -> dict[str, list[float]]:
|
|
1035
|
-
|
|
1942
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1943
|
+
query = """
|
|
1944
|
+
MATCH (c:Community)
|
|
1945
|
+
WHERE c.uuid IN $community_uuids
|
|
1946
|
+
RETURN DISTINCT
|
|
1947
|
+
c.uuid AS uuid,
|
|
1948
|
+
split(c.name_embedding, ",") AS name_embedding
|
|
1036
1949
|
"""
|
|
1950
|
+
else:
|
|
1951
|
+
query = """
|
|
1037
1952
|
MATCH (c:Community)
|
|
1038
1953
|
WHERE c.uuid IN $community_uuids
|
|
1039
1954
|
RETURN DISTINCT
|
|
1040
1955
|
c.uuid AS uuid,
|
|
1041
1956
|
c.name_embedding AS name_embedding
|
|
1042
|
-
"""
|
|
1957
|
+
"""
|
|
1958
|
+
results, _, _ = await driver.execute_query(
|
|
1959
|
+
query,
|
|
1043
1960
|
community_uuids=[community.uuid for community in communities],
|
|
1044
1961
|
routing_='r',
|
|
1045
1962
|
)
|
|
@@ -1057,14 +1974,34 @@ async def get_embeddings_for_communities(
|
|
|
1057
1974
|
async def get_embeddings_for_edges(
|
|
1058
1975
|
driver: GraphDriver, edges: list[EntityEdge]
|
|
1059
1976
|
) -> dict[str, list[float]]:
|
|
1060
|
-
|
|
1061
|
-
"""
|
|
1977
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1978
|
+
query = """
|
|
1062
1979
|
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1063
1980
|
WHERE e.uuid IN $edge_uuids
|
|
1981
|
+
RETURN DISTINCT
|
|
1982
|
+
e.uuid AS uuid,
|
|
1983
|
+
split(e.fact_embedding, ",") AS fact_embedding
|
|
1984
|
+
"""
|
|
1985
|
+
else:
|
|
1986
|
+
match_query = """
|
|
1987
|
+
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1988
|
+
"""
|
|
1989
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1990
|
+
match_query = """
|
|
1991
|
+
MATCH (n:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m:Entity)
|
|
1992
|
+
"""
|
|
1993
|
+
|
|
1994
|
+
query = (
|
|
1995
|
+
match_query
|
|
1996
|
+
+ """
|
|
1997
|
+
WHERE e.uuid IN $edge_uuids
|
|
1064
1998
|
RETURN DISTINCT
|
|
1065
1999
|
e.uuid AS uuid,
|
|
1066
2000
|
e.fact_embedding AS fact_embedding
|
|
1067
|
-
"""
|
|
2001
|
+
"""
|
|
2002
|
+
)
|
|
2003
|
+
results, _, _ = await driver.execute_query(
|
|
2004
|
+
query,
|
|
1068
2005
|
edge_uuids=[edge.uuid for edge in edges],
|
|
1069
2006
|
routing_='r',
|
|
1070
2007
|
)
|