graphiti-core 0.19.0rc3__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +3 -0
- graphiti_core/driver/falkordb_driver.py +3 -14
- graphiti_core/driver/kuzu_driver.py +175 -0
- graphiti_core/driver/neptune_driver.py +2 -0
- graphiti_core/edges.py +148 -83
- graphiti_core/graph_queries.py +31 -2
- graphiti_core/graphiti.py +4 -1
- graphiti_core/helpers.py +7 -12
- graphiti_core/migrations/neo4j_node_group_labels.py +33 -4
- graphiti_core/models/edges/edge_db_queries.py +121 -42
- graphiti_core/models/nodes/node_db_queries.py +102 -23
- graphiti_core/nodes.py +169 -66
- graphiti_core/search/search.py +13 -3
- graphiti_core/search/search_config.py +4 -0
- graphiti_core/search/search_filters.py +35 -22
- graphiti_core/search/search_utils.py +693 -382
- graphiti_core/utils/bulk_utils.py +50 -18
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +39 -32
- graphiti_core/utils/maintenance/edge_operations.py +19 -8
- graphiti_core/utils/maintenance/graph_data_operations.py +77 -47
- {graphiti_core-0.19.0rc3.dist-info → graphiti_core-0.20.0.dist-info}/METADATA +116 -48
- {graphiti_core-0.19.0rc3.dist-info → graphiti_core-0.20.0.dist-info}/RECORD +25 -24
- {graphiti_core-0.19.0rc3.dist-info → graphiti_core-0.20.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.19.0rc3.dist-info → graphiti_core-0.20.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -32,15 +32,17 @@ from graphiti_core.graph_queries import (
|
|
|
32
32
|
get_vector_cosine_func_query,
|
|
33
33
|
)
|
|
34
34
|
from graphiti_core.helpers import (
|
|
35
|
-
RUNTIME_QUERY,
|
|
36
35
|
lucene_sanitize,
|
|
37
36
|
normalize_l2,
|
|
38
37
|
semaphore_gather,
|
|
39
38
|
)
|
|
40
|
-
from graphiti_core.models.edges.edge_db_queries import
|
|
41
|
-
from graphiti_core.models.nodes.node_db_queries import
|
|
39
|
+
from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
|
|
40
|
+
from graphiti_core.models.nodes.node_db_queries import (
|
|
41
|
+
COMMUNITY_NODE_RETURN,
|
|
42
|
+
EPISODIC_NODE_RETURN,
|
|
43
|
+
get_entity_node_return_query,
|
|
44
|
+
)
|
|
42
45
|
from graphiti_core.nodes import (
|
|
43
|
-
ENTITY_NODE_RETURN,
|
|
44
46
|
CommunityNode,
|
|
45
47
|
EntityNode,
|
|
46
48
|
EpisodicNode,
|
|
@@ -78,9 +80,16 @@ def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> f
|
|
|
78
80
|
return dot_product / (norm_vector1 * norm_vector2)
|
|
79
81
|
|
|
80
82
|
|
|
81
|
-
def fulltext_query(query: str, group_ids: list[str] | None
|
|
83
|
+
def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
|
|
84
|
+
if driver.provider == GraphProvider.KUZU:
|
|
85
|
+
# Kuzu only supports simple queries.
|
|
86
|
+
if len(query.split(' ')) > MAX_QUERY_LENGTH:
|
|
87
|
+
return ''
|
|
88
|
+
return query
|
|
82
89
|
group_ids_filter_list = (
|
|
83
|
-
[fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
|
|
90
|
+
[driver.fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
|
|
91
|
+
if group_ids is not None
|
|
92
|
+
else []
|
|
84
93
|
)
|
|
85
94
|
group_ids_filter = ''
|
|
86
95
|
for f in group_ids_filter_list:
|
|
@@ -124,12 +133,12 @@ async def get_mentioned_nodes(
|
|
|
124
133
|
WHERE episode.uuid IN $uuids
|
|
125
134
|
RETURN DISTINCT
|
|
126
135
|
"""
|
|
127
|
-
+
|
|
136
|
+
+ get_entity_node_return_query(driver.provider),
|
|
128
137
|
uuids=episode_uuids,
|
|
129
138
|
routing_='r',
|
|
130
139
|
)
|
|
131
140
|
|
|
132
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
141
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
133
142
|
|
|
134
143
|
return nodes
|
|
135
144
|
|
|
@@ -141,7 +150,7 @@ async def get_communities_by_nodes(
|
|
|
141
150
|
|
|
142
151
|
records, _, _ = await driver.execute_query(
|
|
143
152
|
"""
|
|
144
|
-
MATCH (
|
|
153
|
+
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)
|
|
145
154
|
WHERE m.uuid IN $uuids
|
|
146
155
|
RETURN DISTINCT
|
|
147
156
|
"""
|
|
@@ -163,11 +172,32 @@ async def edge_fulltext_search(
|
|
|
163
172
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
164
173
|
) -> list[EntityEdge]:
|
|
165
174
|
# fulltext search over facts
|
|
166
|
-
fuzzy_query = fulltext_query(query, group_ids, driver
|
|
175
|
+
fuzzy_query = fulltext_query(query, group_ids, driver)
|
|
176
|
+
|
|
167
177
|
if fuzzy_query == '':
|
|
168
178
|
return []
|
|
169
179
|
|
|
170
|
-
|
|
180
|
+
match_query = """
|
|
181
|
+
YIELD relationship AS rel, score
|
|
182
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
|
|
183
|
+
"""
|
|
184
|
+
if driver.provider == GraphProvider.KUZU:
|
|
185
|
+
match_query = """
|
|
186
|
+
YIELD node, score
|
|
187
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
|
|
188
|
+
"""
|
|
189
|
+
|
|
190
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
191
|
+
search_filter, driver.provider
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
if group_ids is not None:
|
|
195
|
+
filter_queries.append('e.group_id IN $group_ids')
|
|
196
|
+
filter_params['group_ids'] = group_ids
|
|
197
|
+
|
|
198
|
+
filter_query = ''
|
|
199
|
+
if filter_queries:
|
|
200
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
171
201
|
|
|
172
202
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
173
203
|
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
|
|
@@ -180,13 +210,14 @@ async def edge_fulltext_search(
|
|
|
180
210
|
# Match the edge ids and return the values
|
|
181
211
|
query = (
|
|
182
212
|
"""
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
213
|
+
UNWIND $ids as id
|
|
214
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
215
|
+
WHERE e.group_id IN $group_ids
|
|
216
|
+
AND id(e)=id
|
|
217
|
+
"""
|
|
188
218
|
+ filter_query
|
|
189
219
|
+ """
|
|
220
|
+
AND id(e)=id
|
|
190
221
|
WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
|
|
191
222
|
RETURN
|
|
192
223
|
e.uuid AS uuid,
|
|
@@ -208,7 +239,6 @@ async def edge_fulltext_search(
|
|
|
208
239
|
records, _, _ = await driver.execute_query(
|
|
209
240
|
query,
|
|
210
241
|
query=fuzzy_query,
|
|
211
|
-
group_ids=group_ids,
|
|
212
242
|
ids=input_ids,
|
|
213
243
|
limit=limit,
|
|
214
244
|
routing_='r',
|
|
@@ -218,17 +248,14 @@ async def edge_fulltext_search(
|
|
|
218
248
|
return []
|
|
219
249
|
else:
|
|
220
250
|
query = (
|
|
221
|
-
get_relationships_query('edge_name_and_fact', provider=driver.provider)
|
|
222
|
-
+
|
|
223
|
-
YIELD relationship AS rel, score
|
|
224
|
-
MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
|
|
225
|
-
WHERE e.group_id IN $group_ids """
|
|
251
|
+
get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
|
|
252
|
+
+ match_query
|
|
226
253
|
+ filter_query
|
|
227
254
|
+ """
|
|
228
255
|
WITH e, score, n, m
|
|
229
256
|
RETURN
|
|
230
257
|
"""
|
|
231
|
-
+
|
|
258
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
232
259
|
+ """
|
|
233
260
|
ORDER BY score DESC
|
|
234
261
|
LIMIT $limit
|
|
@@ -238,13 +265,12 @@ async def edge_fulltext_search(
|
|
|
238
265
|
records, _, _ = await driver.execute_query(
|
|
239
266
|
query,
|
|
240
267
|
query=fuzzy_query,
|
|
241
|
-
group_ids=group_ids,
|
|
242
268
|
limit=limit,
|
|
243
269
|
routing_='r',
|
|
244
270
|
**filter_params,
|
|
245
271
|
)
|
|
246
272
|
|
|
247
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
273
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
248
274
|
|
|
249
275
|
return edges
|
|
250
276
|
|
|
@@ -259,32 +285,43 @@ async def edge_similarity_search(
|
|
|
259
285
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
260
286
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
261
287
|
) -> list[EntityEdge]:
|
|
262
|
-
|
|
263
|
-
|
|
288
|
+
match_query = """
|
|
289
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
290
|
+
"""
|
|
291
|
+
if driver.provider == GraphProvider.KUZU:
|
|
292
|
+
match_query = """
|
|
293
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
294
|
+
"""
|
|
264
295
|
|
|
265
|
-
|
|
266
|
-
|
|
296
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
297
|
+
search_filter, driver.provider
|
|
298
|
+
)
|
|
267
299
|
|
|
268
|
-
group_filter_query: LiteralString = 'WHERE e.group_id IS NOT NULL'
|
|
269
300
|
if group_ids is not None:
|
|
270
|
-
|
|
271
|
-
|
|
301
|
+
filter_queries.append('e.group_id IN $group_ids')
|
|
302
|
+
filter_params['group_ids'] = group_ids
|
|
272
303
|
|
|
273
304
|
if source_node_uuid is not None:
|
|
274
|
-
|
|
275
|
-
|
|
305
|
+
filter_params['source_uuid'] = source_node_uuid
|
|
306
|
+
filter_queries.append('n.uuid = $source_uuid')
|
|
276
307
|
|
|
277
308
|
if target_node_uuid is not None:
|
|
278
|
-
|
|
279
|
-
|
|
309
|
+
filter_params['target_uuid'] = target_node_uuid
|
|
310
|
+
filter_queries.append('m.uuid = $target_uuid')
|
|
311
|
+
|
|
312
|
+
filter_query = ''
|
|
313
|
+
if filter_queries:
|
|
314
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
315
|
+
|
|
316
|
+
search_vector_var = '$search_vector'
|
|
317
|
+
if driver.provider == GraphProvider.KUZU:
|
|
318
|
+
search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
|
|
280
319
|
|
|
281
320
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
282
321
|
query = (
|
|
283
|
-
RUNTIME_QUERY
|
|
284
|
-
+ """
|
|
285
|
-
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
286
322
|
"""
|
|
287
|
-
|
|
323
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
324
|
+
"""
|
|
288
325
|
+ filter_query
|
|
289
326
|
+ """
|
|
290
327
|
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
|
@@ -296,7 +333,7 @@ async def edge_similarity_search(
|
|
|
296
333
|
limit=limit,
|
|
297
334
|
min_score=min_score,
|
|
298
335
|
routing_='r',
|
|
299
|
-
**
|
|
336
|
+
**filter_params,
|
|
300
337
|
)
|
|
301
338
|
|
|
302
339
|
if len(resp) > 0:
|
|
@@ -338,26 +375,22 @@ async def edge_similarity_search(
|
|
|
338
375
|
limit=limit,
|
|
339
376
|
min_score=min_score,
|
|
340
377
|
routing_='r',
|
|
341
|
-
**
|
|
378
|
+
**filter_params,
|
|
342
379
|
)
|
|
343
380
|
else:
|
|
344
381
|
return []
|
|
345
382
|
else:
|
|
346
383
|
query = (
|
|
347
|
-
|
|
348
|
-
+ """
|
|
349
|
-
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
350
|
-
"""
|
|
351
|
-
+ group_filter_query
|
|
384
|
+
match_query
|
|
352
385
|
+ filter_query
|
|
353
386
|
+ """
|
|
354
387
|
WITH DISTINCT e, n, m, """
|
|
355
|
-
+ get_vector_cosine_func_query('e.fact_embedding',
|
|
388
|
+
+ get_vector_cosine_func_query('e.fact_embedding', search_vector_var, driver.provider)
|
|
356
389
|
+ """ AS score
|
|
357
390
|
WHERE score > $min_score
|
|
358
391
|
RETURN
|
|
359
392
|
"""
|
|
360
|
-
+
|
|
393
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
361
394
|
+ """
|
|
362
395
|
ORDER BY score DESC
|
|
363
396
|
LIMIT $limit
|
|
@@ -370,10 +403,10 @@ async def edge_similarity_search(
|
|
|
370
403
|
limit=limit,
|
|
371
404
|
min_score=min_score,
|
|
372
405
|
routing_='r',
|
|
373
|
-
**
|
|
406
|
+
**filter_params,
|
|
374
407
|
)
|
|
375
408
|
|
|
376
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
409
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
377
410
|
|
|
378
411
|
return edges
|
|
379
412
|
|
|
@@ -387,70 +420,116 @@ async def edge_bfs_search(
|
|
|
387
420
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
388
421
|
) -> list[EntityEdge]:
|
|
389
422
|
# vector similarity search over embedded facts
|
|
390
|
-
if bfs_origin_node_uuids is None:
|
|
423
|
+
if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0:
|
|
391
424
|
return []
|
|
392
425
|
|
|
393
|
-
|
|
426
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
427
|
+
search_filter, driver.provider
|
|
428
|
+
)
|
|
394
429
|
|
|
395
|
-
if
|
|
396
|
-
|
|
430
|
+
if group_ids is not None:
|
|
431
|
+
filter_queries.append('e.group_id IN $group_ids')
|
|
432
|
+
filter_params['group_ids'] = group_ids
|
|
433
|
+
|
|
434
|
+
filter_query = ''
|
|
435
|
+
if filter_queries:
|
|
436
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
437
|
+
|
|
438
|
+
if driver.provider == GraphProvider.KUZU:
|
|
439
|
+
# Kuzu stores entity edges twice with an intermediate node, so we need to match them
|
|
440
|
+
# separately for the correct BFS depth.
|
|
441
|
+
depth = bfs_max_depth * 2 - 1
|
|
442
|
+
match_queries = [
|
|
397
443
|
f"""
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
LIMIT $limit
|
|
444
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
445
|
+
MATCH path = (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
|
|
446
|
+
UNWIND nodes(path) AS relNode
|
|
447
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
|
|
448
|
+
""",
|
|
449
|
+
]
|
|
450
|
+
if bfs_max_depth > 1:
|
|
451
|
+
depth = (bfs_max_depth - 1) * 2 - 1
|
|
452
|
+
match_queries.append(f"""
|
|
453
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
454
|
+
MATCH path = (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
|
|
455
|
+
UNWIND nodes(path) AS relNode
|
|
456
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
|
|
457
|
+
""")
|
|
458
|
+
|
|
459
|
+
records = []
|
|
460
|
+
for match_query in match_queries:
|
|
461
|
+
sub_records, _, _ = await driver.execute_query(
|
|
462
|
+
match_query
|
|
463
|
+
+ filter_query
|
|
464
|
+
+ """
|
|
465
|
+
RETURN DISTINCT
|
|
421
466
|
"""
|
|
422
|
-
|
|
467
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
468
|
+
+ """
|
|
469
|
+
LIMIT $limit
|
|
470
|
+
""",
|
|
471
|
+
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
472
|
+
limit=limit,
|
|
473
|
+
routing_='r',
|
|
474
|
+
**filter_params,
|
|
475
|
+
)
|
|
476
|
+
records.extend(sub_records)
|
|
423
477
|
else:
|
|
424
|
-
|
|
425
|
-
|
|
478
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
479
|
+
query = (
|
|
480
|
+
f"""
|
|
426
481
|
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
427
|
-
MATCH path = (origin
|
|
482
|
+
MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
|
|
483
|
+
WHERE origin:Entity OR origin:Episodic
|
|
428
484
|
UNWIND relationships(path) AS rel
|
|
429
|
-
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
485
|
+
MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
|
|
486
|
+
"""
|
|
487
|
+
+ filter_query
|
|
488
|
+
+ """
|
|
489
|
+
RETURN DISTINCT
|
|
490
|
+
e.uuid AS uuid,
|
|
491
|
+
e.group_id AS group_id,
|
|
492
|
+
startNode(e).uuid AS source_node_uuid,
|
|
493
|
+
endNode(e).uuid AS target_node_uuid,
|
|
494
|
+
e.created_at AS created_at,
|
|
495
|
+
e.name AS name,
|
|
496
|
+
e.fact AS fact,
|
|
497
|
+
split(e.episodes, ',') AS episodes,
|
|
498
|
+
e.expired_at AS expired_at,
|
|
499
|
+
e.valid_at AS valid_at,
|
|
500
|
+
e.invalid_at AS invalid_at,
|
|
501
|
+
properties(e) AS attributes
|
|
502
|
+
LIMIT $limit
|
|
503
|
+
"""
|
|
504
|
+
)
|
|
505
|
+
else:
|
|
506
|
+
query = (
|
|
507
|
+
f"""
|
|
508
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
509
|
+
MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
|
|
510
|
+
UNWIND relationships(path) AS rel
|
|
511
|
+
MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
|
|
512
|
+
"""
|
|
513
|
+
+ filter_query
|
|
514
|
+
+ """
|
|
515
|
+
RETURN DISTINCT
|
|
516
|
+
"""
|
|
517
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
518
|
+
+ """
|
|
519
|
+
LIMIT $limit
|
|
520
|
+
"""
|
|
521
|
+
)
|
|
442
522
|
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
)
|
|
523
|
+
records, _, _ = await driver.execute_query(
|
|
524
|
+
query,
|
|
525
|
+
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
526
|
+
depth=bfs_max_depth,
|
|
527
|
+
limit=limit,
|
|
528
|
+
routing_='r',
|
|
529
|
+
**filter_params,
|
|
530
|
+
)
|
|
452
531
|
|
|
453
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
532
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
454
533
|
|
|
455
534
|
return edges
|
|
456
535
|
|
|
@@ -461,12 +540,28 @@ async def node_fulltext_search(
|
|
|
461
540
|
search_filter: SearchFilters,
|
|
462
541
|
group_ids: list[str] | None = None,
|
|
463
542
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
543
|
+
use_local_indexes: bool = False,
|
|
464
544
|
) -> list[EntityNode]:
|
|
465
545
|
# BM25 search to get top nodes
|
|
466
|
-
fuzzy_query = fulltext_query(query, group_ids, driver
|
|
546
|
+
fuzzy_query = fulltext_query(query, group_ids, driver)
|
|
467
547
|
if fuzzy_query == '':
|
|
468
548
|
return []
|
|
469
|
-
|
|
549
|
+
|
|
550
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
551
|
+
search_filter, driver.provider
|
|
552
|
+
)
|
|
553
|
+
|
|
554
|
+
if group_ids is not None:
|
|
555
|
+
filter_queries.append('n.group_id IN $group_ids')
|
|
556
|
+
filter_params['group_ids'] = group_ids
|
|
557
|
+
|
|
558
|
+
filter_query = ''
|
|
559
|
+
if filter_queries:
|
|
560
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
561
|
+
|
|
562
|
+
yield_query = 'YIELD node AS n, score'
|
|
563
|
+
if driver.provider == GraphProvider.KUZU:
|
|
564
|
+
yield_query = 'WITH node AS n, score'
|
|
470
565
|
|
|
471
566
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
472
567
|
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
@@ -479,12 +574,12 @@ async def node_fulltext_search(
|
|
|
479
574
|
# Match the edge ides and return the values
|
|
480
575
|
query = (
|
|
481
576
|
"""
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
+
|
|
577
|
+
UNWIND $ids as i
|
|
578
|
+
MATCH (n:Entity)
|
|
579
|
+
WHERE n.uuid=i.id
|
|
580
|
+
RETURN
|
|
581
|
+
"""
|
|
582
|
+
+ get_entity_node_return_query(driver.provider)
|
|
488
583
|
+ """
|
|
489
584
|
ORDER BY i.score DESC
|
|
490
585
|
LIMIT $limit
|
|
@@ -494,7 +589,6 @@ async def node_fulltext_search(
|
|
|
494
589
|
query,
|
|
495
590
|
ids=input_ids,
|
|
496
591
|
query=fuzzy_query,
|
|
497
|
-
group_ids=group_ids,
|
|
498
592
|
limit=limit,
|
|
499
593
|
routing_='r',
|
|
500
594
|
**filter_params,
|
|
@@ -504,36 +598,32 @@ async def node_fulltext_search(
|
|
|
504
598
|
else:
|
|
505
599
|
index_name = (
|
|
506
600
|
'node_name_and_summary'
|
|
507
|
-
if not
|
|
601
|
+
if not use_local_indexes
|
|
508
602
|
else 'node_name_and_summary_'
|
|
509
603
|
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
|
510
604
|
)
|
|
511
605
|
query = (
|
|
512
|
-
get_nodes_query(
|
|
513
|
-
+
|
|
514
|
-
YIELD node AS n, score
|
|
515
|
-
WHERE n:Entity AND n.group_id IN $group_ids
|
|
516
|
-
"""
|
|
606
|
+
get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
|
|
607
|
+
+ yield_query
|
|
517
608
|
+ filter_query
|
|
518
609
|
+ """
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
+
|
|
610
|
+
WITH n, score
|
|
611
|
+
ORDER BY score DESC
|
|
612
|
+
LIMIT $limit
|
|
613
|
+
RETURN
|
|
614
|
+
"""
|
|
615
|
+
+ get_entity_node_return_query(driver.provider)
|
|
525
616
|
)
|
|
526
617
|
|
|
527
618
|
records, _, _ = await driver.execute_query(
|
|
528
619
|
query,
|
|
529
620
|
query=fuzzy_query,
|
|
530
|
-
group_ids=group_ids,
|
|
531
621
|
limit=limit,
|
|
532
622
|
routing_='r',
|
|
533
623
|
**filter_params,
|
|
534
624
|
)
|
|
535
625
|
|
|
536
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
626
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
537
627
|
|
|
538
628
|
return nodes
|
|
539
629
|
|
|
@@ -545,25 +635,29 @@ async def node_similarity_search(
|
|
|
545
635
|
group_ids: list[str] | None = None,
|
|
546
636
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
547
637
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
638
|
+
use_local_indexes: bool = False,
|
|
548
639
|
) -> list[EntityNode]:
|
|
549
|
-
|
|
550
|
-
|
|
640
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
641
|
+
search_filter, driver.provider
|
|
642
|
+
)
|
|
551
643
|
|
|
552
|
-
group_filter_query: LiteralString = 'WHERE n.group_id IS NOT NULL'
|
|
553
644
|
if group_ids is not None:
|
|
554
|
-
|
|
555
|
-
|
|
645
|
+
filter_queries.append('n.group_id IN $group_ids')
|
|
646
|
+
filter_params['group_ids'] = group_ids
|
|
556
647
|
|
|
557
|
-
filter_query
|
|
558
|
-
|
|
648
|
+
filter_query = ''
|
|
649
|
+
if filter_queries:
|
|
650
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
651
|
+
|
|
652
|
+
search_vector_var = '$search_vector'
|
|
653
|
+
if driver.provider == GraphProvider.KUZU:
|
|
654
|
+
search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
|
|
559
655
|
|
|
560
656
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
561
657
|
query = (
|
|
562
|
-
RUNTIME_QUERY
|
|
563
|
-
+ """
|
|
564
|
-
MATCH (n:Entity)
|
|
565
658
|
"""
|
|
566
|
-
|
|
659
|
+
MATCH (n:Entity)
|
|
660
|
+
"""
|
|
567
661
|
+ filter_query
|
|
568
662
|
+ """
|
|
569
663
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
@@ -571,9 +665,8 @@ async def node_similarity_search(
|
|
|
571
665
|
)
|
|
572
666
|
resp, header, _ = await driver.execute_query(
|
|
573
667
|
query,
|
|
574
|
-
params=
|
|
668
|
+
params=filter_params,
|
|
575
669
|
search_vector=search_vector,
|
|
576
|
-
group_ids=group_ids,
|
|
577
670
|
limit=limit,
|
|
578
671
|
min_score=min_score,
|
|
579
672
|
routing_='r',
|
|
@@ -593,12 +686,12 @@ async def node_similarity_search(
|
|
|
593
686
|
# Match the edge ides and return the values
|
|
594
687
|
query = (
|
|
595
688
|
"""
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
+
|
|
689
|
+
UNWIND $ids as i
|
|
690
|
+
MATCH (n:Entity)
|
|
691
|
+
WHERE id(n)=i.id
|
|
692
|
+
RETURN
|
|
693
|
+
"""
|
|
694
|
+
+ get_entity_node_return_query(driver.provider)
|
|
602
695
|
+ """
|
|
603
696
|
ORDER BY i.score DESC
|
|
604
697
|
LIMIT $limit
|
|
@@ -611,11 +704,11 @@ async def node_similarity_search(
|
|
|
611
704
|
limit=limit,
|
|
612
705
|
min_score=min_score,
|
|
613
706
|
routing_='r',
|
|
614
|
-
**
|
|
707
|
+
**filter_params,
|
|
615
708
|
)
|
|
616
709
|
else:
|
|
617
710
|
return []
|
|
618
|
-
elif driver.provider == GraphProvider.NEO4J and
|
|
711
|
+
elif driver.provider == GraphProvider.NEO4J and use_local_indexes:
|
|
619
712
|
index_name = 'group_entity_vector_' + (
|
|
620
713
|
group_ids[0].replace('-', '') if group_ids is not None else ''
|
|
621
714
|
)
|
|
@@ -623,13 +716,12 @@ async def node_similarity_search(
|
|
|
623
716
|
f"""
|
|
624
717
|
CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score
|
|
625
718
|
"""
|
|
626
|
-
+ group_filter_query
|
|
627
719
|
+ filter_query
|
|
628
720
|
+ """
|
|
629
721
|
AND score > $min_score
|
|
630
722
|
RETURN
|
|
631
723
|
"""
|
|
632
|
-
+
|
|
724
|
+
+ get_entity_node_return_query(driver.provider)
|
|
633
725
|
+ """
|
|
634
726
|
ORDER BY score DESC
|
|
635
727
|
LIMIT $limit
|
|
@@ -642,25 +734,23 @@ async def node_similarity_search(
|
|
|
642
734
|
limit=limit,
|
|
643
735
|
min_score=min_score,
|
|
644
736
|
routing_='r',
|
|
645
|
-
**
|
|
737
|
+
**filter_params,
|
|
646
738
|
)
|
|
647
739
|
|
|
648
740
|
else:
|
|
649
741
|
query = (
|
|
650
|
-
RUNTIME_QUERY
|
|
651
|
-
+ """
|
|
652
|
-
MATCH (n:Entity)
|
|
653
742
|
"""
|
|
654
|
-
|
|
743
|
+
MATCH (n:Entity)
|
|
744
|
+
"""
|
|
655
745
|
+ filter_query
|
|
656
746
|
+ """
|
|
657
747
|
WITH n, """
|
|
658
|
-
+ get_vector_cosine_func_query('n.name_embedding',
|
|
748
|
+
+ get_vector_cosine_func_query('n.name_embedding', search_vector_var, driver.provider)
|
|
659
749
|
+ """ AS score
|
|
660
750
|
WHERE score > $min_score
|
|
661
751
|
RETURN
|
|
662
752
|
"""
|
|
663
|
-
+
|
|
753
|
+
+ get_entity_node_return_query(driver.provider)
|
|
664
754
|
+ """
|
|
665
755
|
ORDER BY score DESC
|
|
666
756
|
LIMIT $limit
|
|
@@ -673,10 +763,10 @@ async def node_similarity_search(
|
|
|
673
763
|
limit=limit,
|
|
674
764
|
min_score=min_score,
|
|
675
765
|
routing_='r',
|
|
676
|
-
**
|
|
766
|
+
**filter_params,
|
|
677
767
|
)
|
|
678
768
|
|
|
679
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
769
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
680
770
|
|
|
681
771
|
return nodes
|
|
682
772
|
|
|
@@ -689,56 +779,82 @@ async def node_bfs_search(
|
|
|
689
779
|
group_ids: list[str] | None = None,
|
|
690
780
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
691
781
|
) -> list[EntityNode]:
|
|
692
|
-
|
|
693
|
-
if bfs_origin_node_uuids is None:
|
|
782
|
+
if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0 or bfs_max_depth < 1:
|
|
694
783
|
return []
|
|
695
784
|
|
|
696
|
-
|
|
785
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
786
|
+
search_filter, driver.provider
|
|
787
|
+
)
|
|
788
|
+
|
|
789
|
+
if group_ids is not None:
|
|
790
|
+
filter_queries.append('n.group_id IN $group_ids')
|
|
791
|
+
filter_queries.append('origin.group_id IN $group_ids')
|
|
792
|
+
filter_params['group_ids'] = group_ids
|
|
793
|
+
|
|
794
|
+
filter_query = ''
|
|
795
|
+
if filter_queries:
|
|
796
|
+
filter_query = ' AND ' + (' AND '.join(filter_queries))
|
|
797
|
+
|
|
798
|
+
match_queries = [
|
|
799
|
+
f"""
|
|
800
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
801
|
+
MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
|
|
802
|
+
WHERE n.group_id = origin.group_id
|
|
803
|
+
"""
|
|
804
|
+
]
|
|
697
805
|
|
|
698
806
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
699
|
-
|
|
807
|
+
match_queries = [
|
|
700
808
|
f"""
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
"""
|
|
706
|
-
+ filter_query
|
|
707
|
-
+ """
|
|
708
|
-
RETURN
|
|
809
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
810
|
+
MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
|
|
811
|
+
WHERE origin:Entity OR origin.Episode
|
|
812
|
+
AND n.group_id = origin.group_id
|
|
709
813
|
"""
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
814
|
+
]
|
|
815
|
+
|
|
816
|
+
if driver.provider == GraphProvider.KUZU:
|
|
817
|
+
depth = bfs_max_depth * 2
|
|
818
|
+
match_queries = [
|
|
713
819
|
"""
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
820
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
821
|
+
MATCH (origin:Episodic {uuid: origin_uuid})-[:MENTIONS]->(n:Entity)
|
|
822
|
+
WHERE n.group_id = origin.group_id
|
|
823
|
+
""",
|
|
717
824
|
f"""
|
|
825
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
826
|
+
MATCH (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*2..{depth}]->(n:Entity)
|
|
827
|
+
WHERE n.group_id = origin.group_id
|
|
828
|
+
""",
|
|
829
|
+
]
|
|
830
|
+
if bfs_max_depth > 1:
|
|
831
|
+
depth = (bfs_max_depth - 1) * 2
|
|
832
|
+
match_queries.append(f"""
|
|
718
833
|
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
719
|
-
MATCH (origin:
|
|
834
|
+
MATCH (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*2..{depth}]->(n:Entity)
|
|
720
835
|
WHERE n.group_id = origin.group_id
|
|
721
|
-
|
|
722
|
-
|
|
836
|
+
""")
|
|
837
|
+
|
|
838
|
+
records = []
|
|
839
|
+
for match_query in match_queries:
|
|
840
|
+
sub_records, _, _ = await driver.execute_query(
|
|
841
|
+
match_query
|
|
723
842
|
+ filter_query
|
|
724
843
|
+ """
|
|
725
844
|
RETURN
|
|
726
845
|
"""
|
|
727
|
-
+
|
|
846
|
+
+ get_entity_node_return_query(driver.provider)
|
|
728
847
|
+ """
|
|
729
848
|
LIMIT $limit
|
|
730
|
-
"""
|
|
849
|
+
""",
|
|
850
|
+
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
851
|
+
limit=limit,
|
|
852
|
+
routing_='r',
|
|
853
|
+
**filter_params,
|
|
731
854
|
)
|
|
855
|
+
records.extend(sub_records)
|
|
732
856
|
|
|
733
|
-
|
|
734
|
-
query,
|
|
735
|
-
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
736
|
-
group_ids=group_ids,
|
|
737
|
-
limit=limit,
|
|
738
|
-
routing_='r',
|
|
739
|
-
**filter_params,
|
|
740
|
-
)
|
|
741
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
857
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
742
858
|
|
|
743
859
|
return nodes
|
|
744
860
|
|
|
@@ -749,12 +865,19 @@ async def episode_fulltext_search(
|
|
|
749
865
|
_search_filter: SearchFilters,
|
|
750
866
|
group_ids: list[str] | None = None,
|
|
751
867
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
868
|
+
use_local_indexes: bool = False,
|
|
752
869
|
) -> list[EpisodicNode]:
|
|
753
870
|
# BM25 search to get top episodes
|
|
754
|
-
fuzzy_query = fulltext_query(query, group_ids, driver
|
|
871
|
+
fuzzy_query = fulltext_query(query, group_ids, driver)
|
|
755
872
|
if fuzzy_query == '':
|
|
756
873
|
return []
|
|
757
874
|
|
|
875
|
+
filter_params: dict[str, Any] = {}
|
|
876
|
+
group_filter_query: LiteralString = ''
|
|
877
|
+
if group_ids is not None:
|
|
878
|
+
group_filter_query += '\nAND e.group_id IN $group_ids'
|
|
879
|
+
filter_params['group_ids'] = group_ids
|
|
880
|
+
|
|
758
881
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
759
882
|
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
760
883
|
if res['hits']['total']['value'] > 0:
|
|
@@ -768,7 +891,7 @@ async def episode_fulltext_search(
|
|
|
768
891
|
UNWIND $ids as i
|
|
769
892
|
MATCH (e:Episodic)
|
|
770
893
|
WHERE e.uuid=i.id
|
|
771
|
-
RETURN
|
|
894
|
+
RETURN
|
|
772
895
|
e.content AS content,
|
|
773
896
|
e.created_at AS created_at,
|
|
774
897
|
e.valid_at AS valid_at,
|
|
@@ -785,26 +908,28 @@ async def episode_fulltext_search(
|
|
|
785
908
|
query,
|
|
786
909
|
ids=input_ids,
|
|
787
910
|
query=fuzzy_query,
|
|
788
|
-
group_ids=group_ids,
|
|
789
911
|
limit=limit,
|
|
790
912
|
routing_='r',
|
|
913
|
+
**filter_params,
|
|
791
914
|
)
|
|
792
915
|
else:
|
|
793
916
|
return []
|
|
794
917
|
else:
|
|
795
918
|
index_name = (
|
|
796
919
|
'episode_content'
|
|
797
|
-
if not
|
|
920
|
+
if not use_local_indexes
|
|
798
921
|
else 'episode_content_'
|
|
799
922
|
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
|
800
923
|
)
|
|
801
924
|
query = (
|
|
802
|
-
get_nodes_query(
|
|
925
|
+
get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
|
|
803
926
|
+ """
|
|
804
927
|
YIELD node AS episode, score
|
|
805
928
|
MATCH (e:Episodic)
|
|
806
929
|
WHERE e.uuid = episode.uuid
|
|
807
|
-
|
|
930
|
+
"""
|
|
931
|
+
+ group_filter_query
|
|
932
|
+
+ """
|
|
808
933
|
RETURN
|
|
809
934
|
"""
|
|
810
935
|
+ EPISODIC_NODE_RETURN
|
|
@@ -815,12 +940,9 @@ async def episode_fulltext_search(
|
|
|
815
940
|
)
|
|
816
941
|
|
|
817
942
|
records, _, _ = await driver.execute_query(
|
|
818
|
-
query,
|
|
819
|
-
query=fuzzy_query,
|
|
820
|
-
group_ids=group_ids,
|
|
821
|
-
limit=limit,
|
|
822
|
-
routing_='r',
|
|
943
|
+
query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
|
|
823
944
|
)
|
|
945
|
+
|
|
824
946
|
episodes = [get_episodic_node_from_record(record) for record in records]
|
|
825
947
|
|
|
826
948
|
return episodes
|
|
@@ -833,10 +955,20 @@ async def community_fulltext_search(
|
|
|
833
955
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
834
956
|
) -> list[CommunityNode]:
|
|
835
957
|
# BM25 search to get top communities
|
|
836
|
-
fuzzy_query = fulltext_query(query, group_ids, driver
|
|
958
|
+
fuzzy_query = fulltext_query(query, group_ids, driver)
|
|
837
959
|
if fuzzy_query == '':
|
|
838
960
|
return []
|
|
839
961
|
|
|
962
|
+
filter_params: dict[str, Any] = {}
|
|
963
|
+
group_filter_query: LiteralString = ''
|
|
964
|
+
if group_ids is not None:
|
|
965
|
+
group_filter_query = 'WHERE c.group_id IN $group_ids'
|
|
966
|
+
filter_params['group_ids'] = group_ids
|
|
967
|
+
|
|
968
|
+
yield_query = 'YIELD node AS c, score'
|
|
969
|
+
if driver.provider == GraphProvider.KUZU:
|
|
970
|
+
yield_query = 'WITH node AS c, score'
|
|
971
|
+
|
|
840
972
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
841
973
|
res = driver.run_aoss_query('community_name', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
842
974
|
if res['hits']['total']['value'] > 0:
|
|
@@ -852,9 +984,9 @@ async def community_fulltext_search(
|
|
|
852
984
|
WHERE comm.uuid=i.id
|
|
853
985
|
RETURN
|
|
854
986
|
comm.uuid AS uuid,
|
|
855
|
-
comm.group_id AS group_id,
|
|
856
|
-
comm.name AS name,
|
|
857
|
-
comm.created_at AS created_at,
|
|
987
|
+
comm.group_id AS group_id,
|
|
988
|
+
comm.name AS name,
|
|
989
|
+
comm.created_at AS created_at,
|
|
858
990
|
comm.summary AS summary,
|
|
859
991
|
[x IN split(comm.name_embedding, ",") | toFloat(x)]AS name_embedding
|
|
860
992
|
ORDER BY i.score DESC
|
|
@@ -864,18 +996,21 @@ async def community_fulltext_search(
|
|
|
864
996
|
query,
|
|
865
997
|
ids=input_ids,
|
|
866
998
|
query=fuzzy_query,
|
|
867
|
-
group_ids=group_ids,
|
|
868
999
|
limit=limit,
|
|
869
1000
|
routing_='r',
|
|
1001
|
+
**filter_params,
|
|
870
1002
|
)
|
|
871
1003
|
else:
|
|
872
1004
|
return []
|
|
873
1005
|
else:
|
|
874
1006
|
query = (
|
|
875
|
-
get_nodes_query(
|
|
1007
|
+
get_nodes_query('community_name', '$query', limit=limit, provider=driver.provider)
|
|
1008
|
+
+ yield_query
|
|
1009
|
+
+ """
|
|
1010
|
+
WITH c, score
|
|
1011
|
+
"""
|
|
1012
|
+
+ group_filter_query
|
|
876
1013
|
+ """
|
|
877
|
-
YIELD node AS n, score
|
|
878
|
-
WHERE n.group_id IN $group_ids
|
|
879
1014
|
RETURN
|
|
880
1015
|
"""
|
|
881
1016
|
+ COMMUNITY_NODE_RETURN
|
|
@@ -886,12 +1021,9 @@ async def community_fulltext_search(
|
|
|
886
1021
|
)
|
|
887
1022
|
|
|
888
1023
|
records, _, _ = await driver.execute_query(
|
|
889
|
-
query,
|
|
890
|
-
query=fuzzy_query,
|
|
891
|
-
group_ids=group_ids,
|
|
892
|
-
limit=limit,
|
|
893
|
-
routing_='r',
|
|
1024
|
+
query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
|
|
894
1025
|
)
|
|
1026
|
+
|
|
895
1027
|
communities = [get_community_node_from_record(record) for record in records]
|
|
896
1028
|
|
|
897
1029
|
return communities
|
|
@@ -909,15 +1041,14 @@ async def community_similarity_search(
|
|
|
909
1041
|
|
|
910
1042
|
group_filter_query: LiteralString = ''
|
|
911
1043
|
if group_ids is not None:
|
|
912
|
-
group_filter_query += 'WHERE
|
|
1044
|
+
group_filter_query += ' WHERE c.group_id IN $group_ids'
|
|
913
1045
|
query_params['group_ids'] = group_ids
|
|
914
1046
|
|
|
915
1047
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
916
1048
|
query = (
|
|
917
|
-
RUNTIME_QUERY
|
|
918
|
-
+ """
|
|
919
|
-
MATCH (n:Community)
|
|
920
1049
|
"""
|
|
1050
|
+
MATCH (n:Community)
|
|
1051
|
+
"""
|
|
921
1052
|
+ group_filter_query
|
|
922
1053
|
+ """
|
|
923
1054
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
@@ -951,8 +1082,8 @@ async def community_similarity_search(
|
|
|
951
1082
|
RETURN
|
|
952
1083
|
comm.uuid As uuid,
|
|
953
1084
|
comm.group_id AS group_id,
|
|
954
|
-
comm.name AS name,
|
|
955
|
-
comm.created_at AS created_at,
|
|
1085
|
+
comm.name AS name,
|
|
1086
|
+
comm.created_at AS created_at,
|
|
956
1087
|
comm.summary AS summary,
|
|
957
1088
|
comm.name_embedding AS name_embedding
|
|
958
1089
|
ORDER BY i.score DESC
|
|
@@ -970,16 +1101,19 @@ async def community_similarity_search(
|
|
|
970
1101
|
else:
|
|
971
1102
|
return []
|
|
972
1103
|
else:
|
|
1104
|
+
search_vector_var = '$search_vector'
|
|
1105
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1106
|
+
search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
|
|
1107
|
+
|
|
973
1108
|
query = (
|
|
974
|
-
RUNTIME_QUERY
|
|
975
|
-
+ """
|
|
976
|
-
MATCH (n:Community)
|
|
977
1109
|
"""
|
|
1110
|
+
MATCH (c:Community)
|
|
1111
|
+
"""
|
|
978
1112
|
+ group_filter_query
|
|
979
1113
|
+ """
|
|
980
|
-
WITH
|
|
1114
|
+
WITH c,
|
|
981
1115
|
"""
|
|
982
|
-
+ get_vector_cosine_func_query('
|
|
1116
|
+
+ get_vector_cosine_func_query('c.name_embedding', search_vector_var, driver.provider)
|
|
983
1117
|
+ """ AS score
|
|
984
1118
|
WHERE score > $min_score
|
|
985
1119
|
RETURN
|
|
@@ -999,6 +1133,7 @@ async def community_similarity_search(
|
|
|
999
1133
|
routing_='r',
|
|
1000
1134
|
**query_params,
|
|
1001
1135
|
)
|
|
1136
|
+
|
|
1002
1137
|
communities = [get_community_node_from_record(record) for record in records]
|
|
1003
1138
|
|
|
1004
1139
|
return communities
|
|
@@ -1089,67 +1224,127 @@ async def get_relevant_nodes(
|
|
|
1089
1224
|
return []
|
|
1090
1225
|
|
|
1091
1226
|
group_id = nodes[0].group_id
|
|
1092
|
-
|
|
1093
|
-
# vector similarity search over entity names
|
|
1094
|
-
query_params: dict[str, Any] = {}
|
|
1095
|
-
|
|
1096
|
-
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
1097
|
-
query_params.update(filter_params)
|
|
1098
|
-
|
|
1099
|
-
query = (
|
|
1100
|
-
RUNTIME_QUERY
|
|
1101
|
-
+ """
|
|
1102
|
-
UNWIND $nodes AS node
|
|
1103
|
-
MATCH (n:Entity {group_id: $group_id})
|
|
1104
|
-
"""
|
|
1105
|
-
+ filter_query
|
|
1106
|
-
+ """
|
|
1107
|
-
WITH node, n, """
|
|
1108
|
-
+ get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
|
|
1109
|
-
+ """ AS score
|
|
1110
|
-
WHERE score > $min_score
|
|
1111
|
-
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
1112
|
-
"""
|
|
1113
|
-
+ get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
|
|
1114
|
-
+ """
|
|
1115
|
-
YIELD node AS m
|
|
1116
|
-
WHERE m.group_id = $group_id
|
|
1117
|
-
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
|
1118
|
-
|
|
1119
|
-
WITH node,
|
|
1120
|
-
top_vector_nodes,
|
|
1121
|
-
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
|
1122
|
-
|
|
1123
|
-
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
|
1124
|
-
|
|
1125
|
-
UNWIND combined_nodes AS combined_node
|
|
1126
|
-
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
|
1127
|
-
|
|
1128
|
-
RETURN
|
|
1129
|
-
node.uuid AS search_node_uuid,
|
|
1130
|
-
[x IN deduped_nodes | {
|
|
1131
|
-
uuid: x.uuid,
|
|
1132
|
-
name: x.name,
|
|
1133
|
-
name_embedding: x.name_embedding,
|
|
1134
|
-
group_id: x.group_id,
|
|
1135
|
-
created_at: x.created_at,
|
|
1136
|
-
summary: x.summary,
|
|
1137
|
-
labels: labels(x),
|
|
1138
|
-
attributes: properties(x)
|
|
1139
|
-
}] AS matches
|
|
1140
|
-
"""
|
|
1141
|
-
)
|
|
1142
|
-
|
|
1143
1227
|
query_nodes = [
|
|
1144
1228
|
{
|
|
1145
1229
|
'uuid': node.uuid,
|
|
1146
1230
|
'name': node.name,
|
|
1147
1231
|
'name_embedding': node.name_embedding,
|
|
1148
|
-
'fulltext_query': fulltext_query(node.name, [node.group_id], driver
|
|
1232
|
+
'fulltext_query': fulltext_query(node.name, [node.group_id], driver),
|
|
1149
1233
|
}
|
|
1150
1234
|
for node in nodes
|
|
1151
1235
|
]
|
|
1152
1236
|
|
|
1237
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
1238
|
+
search_filter, driver.provider
|
|
1239
|
+
)
|
|
1240
|
+
|
|
1241
|
+
filter_query = ''
|
|
1242
|
+
if filter_queries:
|
|
1243
|
+
filter_query = 'WHERE ' + (' AND '.join(filter_queries))
|
|
1244
|
+
|
|
1245
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1246
|
+
embedding_size = len(nodes[0].name_embedding) if nodes[0].name_embedding is not None else 0
|
|
1247
|
+
if embedding_size == 0:
|
|
1248
|
+
return []
|
|
1249
|
+
|
|
1250
|
+
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
|
1251
|
+
query = (
|
|
1252
|
+
"""
|
|
1253
|
+
UNWIND $nodes AS node
|
|
1254
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
1255
|
+
"""
|
|
1256
|
+
+ filter_query
|
|
1257
|
+
+ """
|
|
1258
|
+
WITH node, n, """
|
|
1259
|
+
+ get_vector_cosine_func_query(
|
|
1260
|
+
'n.name_embedding',
|
|
1261
|
+
f'CAST(node.name_embedding AS FLOAT[{embedding_size}])',
|
|
1262
|
+
driver.provider,
|
|
1263
|
+
)
|
|
1264
|
+
+ """ AS score
|
|
1265
|
+
WHERE score > $min_score
|
|
1266
|
+
WITH node, collect(n)[:$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
1267
|
+
"""
|
|
1268
|
+
+ get_nodes_query(
|
|
1269
|
+
'node_name_and_summary',
|
|
1270
|
+
'node.fulltext_query',
|
|
1271
|
+
limit=limit,
|
|
1272
|
+
provider=driver.provider,
|
|
1273
|
+
)
|
|
1274
|
+
+ """
|
|
1275
|
+
WITH node AS m
|
|
1276
|
+
WHERE m.group_id = $group_id AND NOT m.uuid IN vector_node_uuids
|
|
1277
|
+
WITH node, top_vector_nodes, collect(m) AS fulltext_nodes
|
|
1278
|
+
|
|
1279
|
+
WITH node, list_concat(top_vector_nodes, fulltext_nodes) AS combined_nodes
|
|
1280
|
+
|
|
1281
|
+
UNWIND combined_nodes AS x
|
|
1282
|
+
WITH node, collect(DISTINCT {
|
|
1283
|
+
uuid: x.uuid,
|
|
1284
|
+
name: x.name,
|
|
1285
|
+
name_embedding: x.name_embedding,
|
|
1286
|
+
group_id: x.group_id,
|
|
1287
|
+
created_at: x.created_at,
|
|
1288
|
+
summary: x.summary,
|
|
1289
|
+
labels: x.labels,
|
|
1290
|
+
attributes: x.attributes
|
|
1291
|
+
}) AS matches
|
|
1292
|
+
|
|
1293
|
+
RETURN
|
|
1294
|
+
node.uuid AS search_node_uuid, matches
|
|
1295
|
+
"""
|
|
1296
|
+
)
|
|
1297
|
+
else:
|
|
1298
|
+
query = (
|
|
1299
|
+
"""
|
|
1300
|
+
UNWIND $nodes AS node
|
|
1301
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
1302
|
+
"""
|
|
1303
|
+
+ filter_query
|
|
1304
|
+
+ """
|
|
1305
|
+
WITH node, n, """
|
|
1306
|
+
+ get_vector_cosine_func_query(
|
|
1307
|
+
'n.name_embedding', 'node.name_embedding', driver.provider
|
|
1308
|
+
)
|
|
1309
|
+
+ """ AS score
|
|
1310
|
+
WHERE score > $min_score
|
|
1311
|
+
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
1312
|
+
"""
|
|
1313
|
+
+ get_nodes_query(
|
|
1314
|
+
'node_name_and_summary',
|
|
1315
|
+
'node.fulltext_query',
|
|
1316
|
+
limit=limit,
|
|
1317
|
+
provider=driver.provider,
|
|
1318
|
+
)
|
|
1319
|
+
+ """
|
|
1320
|
+
YIELD node AS m
|
|
1321
|
+
WHERE m.group_id = $group_id
|
|
1322
|
+
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
|
1323
|
+
|
|
1324
|
+
WITH node,
|
|
1325
|
+
top_vector_nodes,
|
|
1326
|
+
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
|
1327
|
+
|
|
1328
|
+
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
|
1329
|
+
|
|
1330
|
+
UNWIND combined_nodes AS combined_node
|
|
1331
|
+
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
|
1332
|
+
|
|
1333
|
+
RETURN
|
|
1334
|
+
node.uuid AS search_node_uuid,
|
|
1335
|
+
[x IN deduped_nodes | {
|
|
1336
|
+
uuid: x.uuid,
|
|
1337
|
+
name: x.name,
|
|
1338
|
+
name_embedding: x.name_embedding,
|
|
1339
|
+
group_id: x.group_id,
|
|
1340
|
+
created_at: x.created_at,
|
|
1341
|
+
summary: x.summary,
|
|
1342
|
+
labels: labels(x),
|
|
1343
|
+
attributes: properties(x)
|
|
1344
|
+
}] AS matches
|
|
1345
|
+
"""
|
|
1346
|
+
)
|
|
1347
|
+
|
|
1153
1348
|
results, _, _ = await driver.execute_query(
|
|
1154
1349
|
query,
|
|
1155
1350
|
nodes=query_nodes,
|
|
@@ -1157,12 +1352,12 @@ async def get_relevant_nodes(
|
|
|
1157
1352
|
limit=limit,
|
|
1158
1353
|
min_score=min_score,
|
|
1159
1354
|
routing_='r',
|
|
1160
|
-
**
|
|
1355
|
+
**filter_params,
|
|
1161
1356
|
)
|
|
1162
1357
|
|
|
1163
1358
|
relevant_nodes_dict: dict[str, list[EntityNode]] = {
|
|
1164
1359
|
result['search_node_uuid']: [
|
|
1165
|
-
get_entity_node_from_record(record) for record in result['matches']
|
|
1360
|
+
get_entity_node_from_record(record, driver.provider) for record in result['matches']
|
|
1166
1361
|
]
|
|
1167
1362
|
for result in results
|
|
1168
1363
|
}
|
|
@@ -1182,22 +1377,24 @@ async def get_relevant_edges(
|
|
|
1182
1377
|
if len(edges) == 0:
|
|
1183
1378
|
return []
|
|
1184
1379
|
|
|
1185
|
-
|
|
1380
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
1381
|
+
search_filter, driver.provider
|
|
1382
|
+
)
|
|
1186
1383
|
|
|
1187
|
-
filter_query
|
|
1188
|
-
|
|
1384
|
+
filter_query = ''
|
|
1385
|
+
if filter_queries:
|
|
1386
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
1189
1387
|
|
|
1190
1388
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
1191
1389
|
query = (
|
|
1192
|
-
RUNTIME_QUERY
|
|
1193
|
-
+ """
|
|
1194
|
-
UNWIND $edges AS edge
|
|
1195
|
-
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1196
1390
|
"""
|
|
1391
|
+
UNWIND $edges AS edge
|
|
1392
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1393
|
+
"""
|
|
1197
1394
|
+ filter_query
|
|
1198
1395
|
+ """
|
|
1199
1396
|
WITH e, edge
|
|
1200
|
-
RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
|
|
1397
|
+
RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
|
|
1201
1398
|
edge.fact_embedding as target_embedding
|
|
1202
1399
|
"""
|
|
1203
1400
|
)
|
|
@@ -1207,7 +1404,7 @@ async def get_relevant_edges(
|
|
|
1207
1404
|
limit=limit,
|
|
1208
1405
|
min_score=min_score,
|
|
1209
1406
|
routing_='r',
|
|
1210
|
-
**
|
|
1407
|
+
**filter_params,
|
|
1211
1408
|
)
|
|
1212
1409
|
|
|
1213
1410
|
# Calculate Cosine similarity then return the edge ids
|
|
@@ -1220,7 +1417,7 @@ async def get_relevant_edges(
|
|
|
1220
1417
|
input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
|
|
1221
1418
|
|
|
1222
1419
|
# Match the edge ides and return the values
|
|
1223
|
-
query = """
|
|
1420
|
+
query = """
|
|
1224
1421
|
UNWIND $ids AS edge
|
|
1225
1422
|
MATCH ()-[e]->()
|
|
1226
1423
|
WHERE id(e) = edge.id
|
|
@@ -1246,49 +1443,93 @@ async def get_relevant_edges(
|
|
|
1246
1443
|
|
|
1247
1444
|
results, _, _ = await driver.execute_query(
|
|
1248
1445
|
query,
|
|
1249
|
-
params=query_params,
|
|
1250
1446
|
ids=input_ids,
|
|
1251
1447
|
edges=[edge.model_dump() for edge in edges],
|
|
1252
1448
|
limit=limit,
|
|
1253
1449
|
min_score=min_score,
|
|
1254
1450
|
routing_='r',
|
|
1255
|
-
**
|
|
1451
|
+
**filter_params,
|
|
1256
1452
|
)
|
|
1257
1453
|
else:
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
1266
|
-
|
|
1267
|
-
|
|
1268
|
-
|
|
1454
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1455
|
+
embedding_size = (
|
|
1456
|
+
len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
|
|
1457
|
+
)
|
|
1458
|
+
if embedding_size == 0:
|
|
1459
|
+
return []
|
|
1460
|
+
|
|
1461
|
+
query = (
|
|
1462
|
+
"""
|
|
1463
|
+
UNWIND $edges AS edge
|
|
1464
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1465
|
+
"""
|
|
1466
|
+
+ filter_query
|
|
1467
|
+
+ """
|
|
1468
|
+
WITH e, edge, n, m, """
|
|
1469
|
+
+ get_vector_cosine_func_query(
|
|
1470
|
+
'e.fact_embedding',
|
|
1471
|
+
f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
|
|
1472
|
+
driver.provider,
|
|
1473
|
+
)
|
|
1474
|
+
+ """ AS score
|
|
1475
|
+
WHERE score > $min_score
|
|
1476
|
+
WITH e, edge, n, m, score
|
|
1477
|
+
ORDER BY score DESC
|
|
1478
|
+
LIMIT $limit
|
|
1479
|
+
RETURN
|
|
1480
|
+
edge.uuid AS search_edge_uuid,
|
|
1481
|
+
collect({
|
|
1482
|
+
uuid: e.uuid,
|
|
1483
|
+
source_node_uuid: n.uuid,
|
|
1484
|
+
target_node_uuid: m.uuid,
|
|
1485
|
+
created_at: e.created_at,
|
|
1486
|
+
name: e.name,
|
|
1487
|
+
group_id: e.group_id,
|
|
1488
|
+
fact: e.fact,
|
|
1489
|
+
fact_embedding: e.fact_embedding,
|
|
1490
|
+
episodes: e.episodes,
|
|
1491
|
+
expired_at: e.expired_at,
|
|
1492
|
+
valid_at: e.valid_at,
|
|
1493
|
+
invalid_at: e.invalid_at,
|
|
1494
|
+
attributes: e.attributes
|
|
1495
|
+
}) AS matches
|
|
1496
|
+
"""
|
|
1497
|
+
)
|
|
1498
|
+
else:
|
|
1499
|
+
query = (
|
|
1500
|
+
"""
|
|
1501
|
+
UNWIND $edges AS edge
|
|
1502
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1503
|
+
"""
|
|
1504
|
+
+ filter_query
|
|
1505
|
+
+ """
|
|
1506
|
+
WITH e, edge, """
|
|
1507
|
+
+ get_vector_cosine_func_query(
|
|
1508
|
+
'e.fact_embedding', 'edge.fact_embedding', driver.provider
|
|
1509
|
+
)
|
|
1510
|
+
+ """ AS score
|
|
1511
|
+
WHERE score > $min_score
|
|
1512
|
+
WITH edge, e, score
|
|
1513
|
+
ORDER BY score DESC
|
|
1514
|
+
RETURN
|
|
1515
|
+
edge.uuid AS search_edge_uuid,
|
|
1516
|
+
collect({
|
|
1517
|
+
uuid: e.uuid,
|
|
1518
|
+
source_node_uuid: startNode(e).uuid,
|
|
1519
|
+
target_node_uuid: endNode(e).uuid,
|
|
1520
|
+
created_at: e.created_at,
|
|
1521
|
+
name: e.name,
|
|
1522
|
+
group_id: e.group_id,
|
|
1523
|
+
fact: e.fact,
|
|
1524
|
+
fact_embedding: e.fact_embedding,
|
|
1525
|
+
episodes: e.episodes,
|
|
1526
|
+
expired_at: e.expired_at,
|
|
1527
|
+
valid_at: e.valid_at,
|
|
1528
|
+
invalid_at: e.invalid_at,
|
|
1529
|
+
attributes: properties(e)
|
|
1530
|
+
})[..$limit] AS matches
|
|
1531
|
+
"""
|
|
1269
1532
|
)
|
|
1270
|
-
+ """ AS score
|
|
1271
|
-
WHERE score > $min_score
|
|
1272
|
-
WITH edge, e, score
|
|
1273
|
-
ORDER BY score DESC
|
|
1274
|
-
RETURN edge.uuid AS search_edge_uuid,
|
|
1275
|
-
collect({
|
|
1276
|
-
uuid: e.uuid,
|
|
1277
|
-
source_node_uuid: startNode(e).uuid,
|
|
1278
|
-
target_node_uuid: endNode(e).uuid,
|
|
1279
|
-
created_at: e.created_at,
|
|
1280
|
-
name: e.name,
|
|
1281
|
-
group_id: e.group_id,
|
|
1282
|
-
fact: e.fact,
|
|
1283
|
-
fact_embedding: e.fact_embedding,
|
|
1284
|
-
episodes: e.episodes,
|
|
1285
|
-
expired_at: e.expired_at,
|
|
1286
|
-
valid_at: e.valid_at,
|
|
1287
|
-
invalid_at: e.invalid_at,
|
|
1288
|
-
attributes: properties(e)
|
|
1289
|
-
})[..$limit] AS matches
|
|
1290
|
-
"""
|
|
1291
|
-
)
|
|
1292
1533
|
|
|
1293
1534
|
results, _, _ = await driver.execute_query(
|
|
1294
1535
|
query,
|
|
@@ -1296,12 +1537,12 @@ async def get_relevant_edges(
|
|
|
1296
1537
|
limit=limit,
|
|
1297
1538
|
min_score=min_score,
|
|
1298
1539
|
routing_='r',
|
|
1299
|
-
**
|
|
1540
|
+
**filter_params,
|
|
1300
1541
|
)
|
|
1301
1542
|
|
|
1302
1543
|
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
|
1303
1544
|
result['search_edge_uuid']: [
|
|
1304
|
-
get_entity_edge_from_record(record) for record in result['matches']
|
|
1545
|
+
get_entity_edge_from_record(record, driver.provider) for record in result['matches']
|
|
1305
1546
|
]
|
|
1306
1547
|
for result in results
|
|
1307
1548
|
}
|
|
@@ -1321,19 +1562,21 @@ async def get_edge_invalidation_candidates(
|
|
|
1321
1562
|
if len(edges) == 0:
|
|
1322
1563
|
return []
|
|
1323
1564
|
|
|
1324
|
-
|
|
1565
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
1566
|
+
search_filter, driver.provider
|
|
1567
|
+
)
|
|
1325
1568
|
|
|
1326
|
-
filter_query
|
|
1327
|
-
|
|
1569
|
+
filter_query = ''
|
|
1570
|
+
if filter_queries:
|
|
1571
|
+
filter_query = ' AND ' + (' AND '.join(filter_queries))
|
|
1328
1572
|
|
|
1329
1573
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
1330
1574
|
query = (
|
|
1331
|
-
RUNTIME_QUERY
|
|
1332
|
-
+ """
|
|
1333
|
-
UNWIND $edges AS edge
|
|
1334
|
-
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1335
|
-
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1336
1575
|
"""
|
|
1576
|
+
UNWIND $edges AS edge
|
|
1577
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1578
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1579
|
+
"""
|
|
1337
1580
|
+ filter_query
|
|
1338
1581
|
+ """
|
|
1339
1582
|
WITH e, edge
|
|
@@ -1348,7 +1591,7 @@ async def get_edge_invalidation_candidates(
|
|
|
1348
1591
|
limit=limit,
|
|
1349
1592
|
min_score=min_score,
|
|
1350
1593
|
routing_='r',
|
|
1351
|
-
**
|
|
1594
|
+
**filter_params,
|
|
1352
1595
|
)
|
|
1353
1596
|
|
|
1354
1597
|
# Calculate Cosine similarity then return the edge ids
|
|
@@ -1361,7 +1604,7 @@ async def get_edge_invalidation_candidates(
|
|
|
1361
1604
|
input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
|
|
1362
1605
|
|
|
1363
1606
|
# Match the edge ides and return the values
|
|
1364
|
-
query = """
|
|
1607
|
+
query = """
|
|
1365
1608
|
UNWIND $ids AS edge
|
|
1366
1609
|
MATCH ()-[e]->()
|
|
1367
1610
|
WHERE id(e) = edge.id
|
|
@@ -1391,44 +1634,90 @@ async def get_edge_invalidation_candidates(
|
|
|
1391
1634
|
limit=limit,
|
|
1392
1635
|
min_score=min_score,
|
|
1393
1636
|
routing_='r',
|
|
1394
|
-
**
|
|
1637
|
+
**filter_params,
|
|
1395
1638
|
)
|
|
1396
1639
|
else:
|
|
1397
|
-
|
|
1398
|
-
|
|
1399
|
-
|
|
1400
|
-
|
|
1401
|
-
|
|
1402
|
-
|
|
1403
|
-
|
|
1404
|
-
|
|
1405
|
-
|
|
1406
|
-
|
|
1407
|
-
|
|
1408
|
-
|
|
1640
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1641
|
+
embedding_size = (
|
|
1642
|
+
len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
|
|
1643
|
+
)
|
|
1644
|
+
if embedding_size == 0:
|
|
1645
|
+
return []
|
|
1646
|
+
|
|
1647
|
+
query = (
|
|
1648
|
+
"""
|
|
1649
|
+
UNWIND $edges AS edge
|
|
1650
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
|
1651
|
+
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
|
1652
|
+
"""
|
|
1653
|
+
+ filter_query
|
|
1654
|
+
+ """
|
|
1655
|
+
WITH edge, e, n, m, """
|
|
1656
|
+
+ get_vector_cosine_func_query(
|
|
1657
|
+
'e.fact_embedding',
|
|
1658
|
+
f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
|
|
1659
|
+
driver.provider,
|
|
1660
|
+
)
|
|
1661
|
+
+ """ AS score
|
|
1662
|
+
WHERE score > $min_score
|
|
1663
|
+
WITH edge, e, n, m, score
|
|
1664
|
+
ORDER BY score DESC
|
|
1665
|
+
LIMIT $limit
|
|
1666
|
+
RETURN
|
|
1667
|
+
edge.uuid AS search_edge_uuid,
|
|
1668
|
+
collect({
|
|
1669
|
+
uuid: e.uuid,
|
|
1670
|
+
source_node_uuid: n.uuid,
|
|
1671
|
+
target_node_uuid: m.uuid,
|
|
1672
|
+
created_at: e.created_at,
|
|
1673
|
+
name: e.name,
|
|
1674
|
+
group_id: e.group_id,
|
|
1675
|
+
fact: e.fact,
|
|
1676
|
+
fact_embedding: e.fact_embedding,
|
|
1677
|
+
episodes: e.episodes,
|
|
1678
|
+
expired_at: e.expired_at,
|
|
1679
|
+
valid_at: e.valid_at,
|
|
1680
|
+
invalid_at: e.invalid_at,
|
|
1681
|
+
attributes: e.attributes
|
|
1682
|
+
}) AS matches
|
|
1683
|
+
"""
|
|
1684
|
+
)
|
|
1685
|
+
else:
|
|
1686
|
+
query = (
|
|
1687
|
+
"""
|
|
1688
|
+
UNWIND $edges AS edge
|
|
1689
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1690
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1691
|
+
"""
|
|
1692
|
+
+ filter_query
|
|
1693
|
+
+ """
|
|
1694
|
+
WITH edge, e, """
|
|
1695
|
+
+ get_vector_cosine_func_query(
|
|
1696
|
+
'e.fact_embedding', 'edge.fact_embedding', driver.provider
|
|
1697
|
+
)
|
|
1698
|
+
+ """ AS score
|
|
1699
|
+
WHERE score > $min_score
|
|
1700
|
+
WITH edge, e, score
|
|
1701
|
+
ORDER BY score DESC
|
|
1702
|
+
RETURN
|
|
1703
|
+
edge.uuid AS search_edge_uuid,
|
|
1704
|
+
collect({
|
|
1705
|
+
uuid: e.uuid,
|
|
1706
|
+
source_node_uuid: startNode(e).uuid,
|
|
1707
|
+
target_node_uuid: endNode(e).uuid,
|
|
1708
|
+
created_at: e.created_at,
|
|
1709
|
+
name: e.name,
|
|
1710
|
+
group_id: e.group_id,
|
|
1711
|
+
fact: e.fact,
|
|
1712
|
+
fact_embedding: e.fact_embedding,
|
|
1713
|
+
episodes: e.episodes,
|
|
1714
|
+
expired_at: e.expired_at,
|
|
1715
|
+
valid_at: e.valid_at,
|
|
1716
|
+
invalid_at: e.invalid_at,
|
|
1717
|
+
attributes: properties(e)
|
|
1718
|
+
})[..$limit] AS matches
|
|
1719
|
+
"""
|
|
1409
1720
|
)
|
|
1410
|
-
+ """ AS score
|
|
1411
|
-
WHERE score > $min_score
|
|
1412
|
-
WITH edge, e, score
|
|
1413
|
-
ORDER BY score DESC
|
|
1414
|
-
RETURN edge.uuid AS search_edge_uuid,
|
|
1415
|
-
collect({
|
|
1416
|
-
uuid: e.uuid,
|
|
1417
|
-
source_node_uuid: startNode(e).uuid,
|
|
1418
|
-
target_node_uuid: endNode(e).uuid,
|
|
1419
|
-
created_at: e.created_at,
|
|
1420
|
-
name: e.name,
|
|
1421
|
-
group_id: e.group_id,
|
|
1422
|
-
fact: e.fact,
|
|
1423
|
-
fact_embedding: e.fact_embedding,
|
|
1424
|
-
episodes: e.episodes,
|
|
1425
|
-
expired_at: e.expired_at,
|
|
1426
|
-
valid_at: e.valid_at,
|
|
1427
|
-
invalid_at: e.invalid_at,
|
|
1428
|
-
attributes: properties(e)
|
|
1429
|
-
})[..$limit] AS matches
|
|
1430
|
-
"""
|
|
1431
|
-
)
|
|
1432
1721
|
|
|
1433
1722
|
results, _, _ = await driver.execute_query(
|
|
1434
1723
|
query,
|
|
@@ -1436,11 +1725,11 @@ async def get_edge_invalidation_candidates(
|
|
|
1436
1725
|
limit=limit,
|
|
1437
1726
|
min_score=min_score,
|
|
1438
1727
|
routing_='r',
|
|
1439
|
-
**
|
|
1728
|
+
**filter_params,
|
|
1440
1729
|
)
|
|
1441
1730
|
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
|
|
1442
1731
|
result['search_edge_uuid']: [
|
|
1443
|
-
get_entity_edge_from_record(record) for record in result['matches']
|
|
1732
|
+
get_entity_edge_from_record(record, driver.provider) for record in result['matches']
|
|
1444
1733
|
]
|
|
1445
1734
|
for result in results
|
|
1446
1735
|
}
|
|
@@ -1479,13 +1768,21 @@ async def node_distance_reranker(
|
|
|
1479
1768
|
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
|
|
1480
1769
|
scores: dict[str, float] = {center_node_uuid: 0.0}
|
|
1481
1770
|
|
|
1482
|
-
|
|
1483
|
-
|
|
1484
|
-
|
|
1771
|
+
query = """
|
|
1772
|
+
UNWIND $node_uuids AS node_uuid
|
|
1773
|
+
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
|
1774
|
+
RETURN 1 AS score, node_uuid AS uuid
|
|
1775
|
+
"""
|
|
1776
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1777
|
+
query = """
|
|
1485
1778
|
UNWIND $node_uuids AS node_uuid
|
|
1486
|
-
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
|
1779
|
+
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(n:Entity {uuid: node_uuid})
|
|
1487
1780
|
RETURN 1 AS score, node_uuid AS uuid
|
|
1488
|
-
"""
|
|
1781
|
+
"""
|
|
1782
|
+
|
|
1783
|
+
# Find the shortest path to center node
|
|
1784
|
+
results, header, _ = await driver.execute_query(
|
|
1785
|
+
query,
|
|
1489
1786
|
node_uuids=filtered_uuids,
|
|
1490
1787
|
center_uuid=center_node_uuid,
|
|
1491
1788
|
routing_='r',
|
|
@@ -1536,6 +1833,10 @@ async def episode_mentions_reranker(
|
|
|
1536
1833
|
for result in results:
|
|
1537
1834
|
scores[result['uuid']] = result['score']
|
|
1538
1835
|
|
|
1836
|
+
for uuid in sorted_uuids:
|
|
1837
|
+
if uuid not in scores:
|
|
1838
|
+
scores[uuid] = float('inf')
|
|
1839
|
+
|
|
1539
1840
|
# rerank on shortest distance
|
|
1540
1841
|
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
|
1541
1842
|
|
|
@@ -1667,13 +1968,23 @@ async def get_embeddings_for_edges(
|
|
|
1667
1968
|
split(e.fact_embedding, ",") AS fact_embedding
|
|
1668
1969
|
"""
|
|
1669
1970
|
else:
|
|
1670
|
-
|
|
1671
|
-
|
|
1971
|
+
match_query = """
|
|
1972
|
+
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1973
|
+
"""
|
|
1974
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1975
|
+
match_query = """
|
|
1976
|
+
MATCH (n:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m:Entity)
|
|
1977
|
+
"""
|
|
1978
|
+
|
|
1979
|
+
query = (
|
|
1980
|
+
match_query
|
|
1981
|
+
+ """
|
|
1672
1982
|
WHERE e.uuid IN $edge_uuids
|
|
1673
1983
|
RETURN DISTINCT
|
|
1674
1984
|
e.uuid AS uuid,
|
|
1675
1985
|
e.fact_embedding AS fact_embedding
|
|
1676
1986
|
"""
|
|
1987
|
+
)
|
|
1677
1988
|
results, _, _ = await driver.execute_query(
|
|
1678
1989
|
query,
|
|
1679
1990
|
edge_uuids=[edge.uuid for edge in edges],
|