graphiti-core 0.19.0rc2__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

@@ -32,15 +32,17 @@ from graphiti_core.graph_queries import (
32
32
  get_vector_cosine_func_query,
33
33
  )
34
34
  from graphiti_core.helpers import (
35
- RUNTIME_QUERY,
36
35
  lucene_sanitize,
37
36
  normalize_l2,
38
37
  semaphore_gather,
39
38
  )
40
- from graphiti_core.models.edges.edge_db_queries import ENTITY_EDGE_RETURN
41
- from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN, EPISODIC_NODE_RETURN
39
+ from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
40
+ from graphiti_core.models.nodes.node_db_queries import (
41
+ COMMUNITY_NODE_RETURN,
42
+ EPISODIC_NODE_RETURN,
43
+ get_entity_node_return_query,
44
+ )
42
45
  from graphiti_core.nodes import (
43
- ENTITY_NODE_RETURN,
44
46
  CommunityNode,
45
47
  EntityNode,
46
48
  EpisodicNode,
@@ -78,9 +80,16 @@ def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> f
78
80
  return dot_product / (norm_vector1 * norm_vector2)
79
81
 
80
82
 
81
- def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_syntax: str = ''):
83
+ def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
84
+ if driver.provider == GraphProvider.KUZU:
85
+ # Kuzu only supports simple queries.
86
+ if len(query.split(' ')) > MAX_QUERY_LENGTH:
87
+ return ''
88
+ return query
82
89
  group_ids_filter_list = (
83
- [fulltext_syntax + f'group_id:"{g}"' for g in group_ids] if group_ids is not None else []
90
+ [driver.fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
91
+ if group_ids is not None
92
+ else []
84
93
  )
85
94
  group_ids_filter = ''
86
95
  for f in group_ids_filter_list:
@@ -124,12 +133,12 @@ async def get_mentioned_nodes(
124
133
  WHERE episode.uuid IN $uuids
125
134
  RETURN DISTINCT
126
135
  """
127
- + ENTITY_NODE_RETURN,
136
+ + get_entity_node_return_query(driver.provider),
128
137
  uuids=episode_uuids,
129
138
  routing_='r',
130
139
  )
131
140
 
132
- nodes = [get_entity_node_from_record(record) for record in records]
141
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
133
142
 
134
143
  return nodes
135
144
 
@@ -141,7 +150,7 @@ async def get_communities_by_nodes(
141
150
 
142
151
  records, _, _ = await driver.execute_query(
143
152
  """
144
- MATCH (n:Community)-[:HAS_MEMBER]->(m:Entity)
153
+ MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)
145
154
  WHERE m.uuid IN $uuids
146
155
  RETURN DISTINCT
147
156
  """
@@ -163,11 +172,32 @@ async def edge_fulltext_search(
163
172
  limit=RELEVANT_SCHEMA_LIMIT,
164
173
  ) -> list[EntityEdge]:
165
174
  # fulltext search over facts
166
- fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
175
+ fuzzy_query = fulltext_query(query, group_ids, driver)
176
+
167
177
  if fuzzy_query == '':
168
178
  return []
169
179
 
170
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
180
+ match_query = """
181
+ YIELD relationship AS rel, score
182
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
183
+ """
184
+ if driver.provider == GraphProvider.KUZU:
185
+ match_query = """
186
+ YIELD node, score
187
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
188
+ """
189
+
190
+ filter_queries, filter_params = edge_search_filter_query_constructor(
191
+ search_filter, driver.provider
192
+ )
193
+
194
+ if group_ids is not None:
195
+ filter_queries.append('e.group_id IN $group_ids')
196
+ filter_params['group_ids'] = group_ids
197
+
198
+ filter_query = ''
199
+ if filter_queries:
200
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
171
201
 
172
202
  if driver.provider == GraphProvider.NEPTUNE:
173
203
  res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
@@ -180,13 +210,14 @@ async def edge_fulltext_search(
180
210
  # Match the edge ids and return the values
181
211
  query = (
182
212
  """
183
- UNWIND $ids as id
184
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
185
- WHERE e.group_id IN $group_ids
186
- AND id(e)=id
187
- """
213
+ UNWIND $ids as id
214
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
215
+ WHERE e.group_id IN $group_ids
216
+ AND id(e)=id
217
+ """
188
218
  + filter_query
189
219
  + """
220
+ AND id(e)=id
190
221
  WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
191
222
  RETURN
192
223
  e.uuid AS uuid,
@@ -208,7 +239,6 @@ async def edge_fulltext_search(
208
239
  records, _, _ = await driver.execute_query(
209
240
  query,
210
241
  query=fuzzy_query,
211
- group_ids=group_ids,
212
242
  ids=input_ids,
213
243
  limit=limit,
214
244
  routing_='r',
@@ -218,17 +248,14 @@ async def edge_fulltext_search(
218
248
  return []
219
249
  else:
220
250
  query = (
221
- get_relationships_query('edge_name_and_fact', provider=driver.provider)
222
- + """
223
- YIELD relationship AS rel, score
224
- MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
225
- WHERE e.group_id IN $group_ids """
251
+ get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
252
+ + match_query
226
253
  + filter_query
227
254
  + """
228
255
  WITH e, score, n, m
229
256
  RETURN
230
257
  """
231
- + ENTITY_EDGE_RETURN
258
+ + get_entity_edge_return_query(driver.provider)
232
259
  + """
233
260
  ORDER BY score DESC
234
261
  LIMIT $limit
@@ -238,13 +265,12 @@ async def edge_fulltext_search(
238
265
  records, _, _ = await driver.execute_query(
239
266
  query,
240
267
  query=fuzzy_query,
241
- group_ids=group_ids,
242
268
  limit=limit,
243
269
  routing_='r',
244
270
  **filter_params,
245
271
  )
246
272
 
247
- edges = [get_entity_edge_from_record(record) for record in records]
273
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
248
274
 
249
275
  return edges
250
276
 
@@ -259,32 +285,43 @@ async def edge_similarity_search(
259
285
  limit: int = RELEVANT_SCHEMA_LIMIT,
260
286
  min_score: float = DEFAULT_MIN_SCORE,
261
287
  ) -> list[EntityEdge]:
262
- # vector similarity search over embedded facts
263
- query_params: dict[str, Any] = {}
288
+ match_query = """
289
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
290
+ """
291
+ if driver.provider == GraphProvider.KUZU:
292
+ match_query = """
293
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
294
+ """
264
295
 
265
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
266
- query_params.update(filter_params)
296
+ filter_queries, filter_params = edge_search_filter_query_constructor(
297
+ search_filter, driver.provider
298
+ )
267
299
 
268
- group_filter_query: LiteralString = 'WHERE e.group_id IS NOT NULL'
269
300
  if group_ids is not None:
270
- group_filter_query += '\nAND e.group_id IN $group_ids'
271
- query_params['group_ids'] = group_ids
301
+ filter_queries.append('e.group_id IN $group_ids')
302
+ filter_params['group_ids'] = group_ids
272
303
 
273
304
  if source_node_uuid is not None:
274
- query_params['source_uuid'] = source_node_uuid
275
- group_filter_query += '\nAND (n.uuid = $source_uuid)'
305
+ filter_params['source_uuid'] = source_node_uuid
306
+ filter_queries.append('n.uuid = $source_uuid')
276
307
 
277
308
  if target_node_uuid is not None:
278
- query_params['target_uuid'] = target_node_uuid
279
- group_filter_query += '\nAND (m.uuid = $target_uuid)'
309
+ filter_params['target_uuid'] = target_node_uuid
310
+ filter_queries.append('m.uuid = $target_uuid')
311
+
312
+ filter_query = ''
313
+ if filter_queries:
314
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
315
+
316
+ search_vector_var = '$search_vector'
317
+ if driver.provider == GraphProvider.KUZU:
318
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
280
319
 
281
320
  if driver.provider == GraphProvider.NEPTUNE:
282
321
  query = (
283
- RUNTIME_QUERY
284
- + """
285
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
286
322
  """
287
- + group_filter_query
323
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
324
+ """
288
325
  + filter_query
289
326
  + """
290
327
  RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@@ -296,7 +333,7 @@ async def edge_similarity_search(
296
333
  limit=limit,
297
334
  min_score=min_score,
298
335
  routing_='r',
299
- **query_params,
336
+ **filter_params,
300
337
  )
301
338
 
302
339
  if len(resp) > 0:
@@ -338,26 +375,22 @@ async def edge_similarity_search(
338
375
  limit=limit,
339
376
  min_score=min_score,
340
377
  routing_='r',
341
- **query_params,
378
+ **filter_params,
342
379
  )
343
380
  else:
344
381
  return []
345
382
  else:
346
383
  query = (
347
- RUNTIME_QUERY
348
- + """
349
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
350
- """
351
- + group_filter_query
384
+ match_query
352
385
  + filter_query
353
386
  + """
354
387
  WITH DISTINCT e, n, m, """
355
- + get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider)
388
+ + get_vector_cosine_func_query('e.fact_embedding', search_vector_var, driver.provider)
356
389
  + """ AS score
357
390
  WHERE score > $min_score
358
391
  RETURN
359
392
  """
360
- + ENTITY_EDGE_RETURN
393
+ + get_entity_edge_return_query(driver.provider)
361
394
  + """
362
395
  ORDER BY score DESC
363
396
  LIMIT $limit
@@ -370,10 +403,10 @@ async def edge_similarity_search(
370
403
  limit=limit,
371
404
  min_score=min_score,
372
405
  routing_='r',
373
- **query_params,
406
+ **filter_params,
374
407
  )
375
408
 
376
- edges = [get_entity_edge_from_record(record) for record in records]
409
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
377
410
 
378
411
  return edges
379
412
 
@@ -387,70 +420,116 @@ async def edge_bfs_search(
387
420
  limit: int = RELEVANT_SCHEMA_LIMIT,
388
421
  ) -> list[EntityEdge]:
389
422
  # vector similarity search over embedded facts
390
- if bfs_origin_node_uuids is None:
423
+ if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0:
391
424
  return []
392
425
 
393
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
426
+ filter_queries, filter_params = edge_search_filter_query_constructor(
427
+ search_filter, driver.provider
428
+ )
394
429
 
395
- if driver.provider == GraphProvider.NEPTUNE:
396
- query = (
430
+ if group_ids is not None:
431
+ filter_queries.append('e.group_id IN $group_ids')
432
+ filter_params['group_ids'] = group_ids
433
+
434
+ filter_query = ''
435
+ if filter_queries:
436
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
437
+
438
+ if driver.provider == GraphProvider.KUZU:
439
+ # Kuzu stores entity edges twice with an intermediate node, so we need to match them
440
+ # separately for the correct BFS depth.
441
+ depth = bfs_max_depth * 2 - 1
442
+ match_queries = [
397
443
  f"""
398
- UNWIND $bfs_origin_node_uuids AS origin_uuid
399
- MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
400
- WHERE origin:Entity OR origin:Episodic
401
- UNWIND relationships(path) AS rel
402
- MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
403
- WHERE e.uuid = rel.uuid
404
- """
405
- + filter_query
406
- + """
407
- RETURN DISTINCT
408
- e.uuid AS uuid,
409
- e.group_id AS group_id,
410
- startNode(e).uuid AS source_node_uuid,
411
- endNode(e).uuid AS target_node_uuid,
412
- e.created_at AS created_at,
413
- e.name AS name,
414
- e.fact AS fact,
415
- split(e.episodes, ',') AS episodes,
416
- e.expired_at AS expired_at,
417
- e.valid_at AS valid_at,
418
- e.invalid_at AS invalid_at,
419
- properties(e) AS attributes
420
- LIMIT $limit
444
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
445
+ MATCH path = (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
446
+ UNWIND nodes(path) AS relNode
447
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
448
+ """,
449
+ ]
450
+ if bfs_max_depth > 1:
451
+ depth = (bfs_max_depth - 1) * 2 - 1
452
+ match_queries.append(f"""
453
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
454
+ MATCH path = (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
455
+ UNWIND nodes(path) AS relNode
456
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
457
+ """)
458
+
459
+ records = []
460
+ for match_query in match_queries:
461
+ sub_records, _, _ = await driver.execute_query(
462
+ match_query
463
+ + filter_query
464
+ + """
465
+ RETURN DISTINCT
421
466
  """
422
- )
467
+ + get_entity_edge_return_query(driver.provider)
468
+ + """
469
+ LIMIT $limit
470
+ """,
471
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
472
+ limit=limit,
473
+ routing_='r',
474
+ **filter_params,
475
+ )
476
+ records.extend(sub_records)
423
477
  else:
424
- query = (
425
- f"""
478
+ if driver.provider == GraphProvider.NEPTUNE:
479
+ query = (
480
+ f"""
426
481
  UNWIND $bfs_origin_node_uuids AS origin_uuid
427
- MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
482
+ MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
483
+ WHERE origin:Entity OR origin:Episodic
428
484
  UNWIND relationships(path) AS rel
429
- MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
430
- WHERE e.uuid = rel.uuid
431
- AND e.group_id IN $group_ids
432
- """
433
- + filter_query
434
- + """
435
- RETURN DISTINCT
436
- """
437
- + ENTITY_EDGE_RETURN
438
- + """
439
- LIMIT $limit
440
- """
441
- )
485
+ MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
486
+ """
487
+ + filter_query
488
+ + """
489
+ RETURN DISTINCT
490
+ e.uuid AS uuid,
491
+ e.group_id AS group_id,
492
+ startNode(e).uuid AS source_node_uuid,
493
+ endNode(e).uuid AS target_node_uuid,
494
+ e.created_at AS created_at,
495
+ e.name AS name,
496
+ e.fact AS fact,
497
+ split(e.episodes, ',') AS episodes,
498
+ e.expired_at AS expired_at,
499
+ e.valid_at AS valid_at,
500
+ e.invalid_at AS invalid_at,
501
+ properties(e) AS attributes
502
+ LIMIT $limit
503
+ """
504
+ )
505
+ else:
506
+ query = (
507
+ f"""
508
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
509
+ MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
510
+ UNWIND relationships(path) AS rel
511
+ MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
512
+ """
513
+ + filter_query
514
+ + """
515
+ RETURN DISTINCT
516
+ """
517
+ + get_entity_edge_return_query(driver.provider)
518
+ + """
519
+ LIMIT $limit
520
+ """
521
+ )
442
522
 
443
- records, _, _ = await driver.execute_query(
444
- query,
445
- bfs_origin_node_uuids=bfs_origin_node_uuids,
446
- depth=bfs_max_depth,
447
- group_ids=group_ids,
448
- limit=limit,
449
- routing_='r',
450
- **filter_params,
451
- )
523
+ records, _, _ = await driver.execute_query(
524
+ query,
525
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
526
+ depth=bfs_max_depth,
527
+ limit=limit,
528
+ routing_='r',
529
+ **filter_params,
530
+ )
452
531
 
453
- edges = [get_entity_edge_from_record(record) for record in records]
532
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
454
533
 
455
534
  return edges
456
535
 
@@ -461,12 +540,28 @@ async def node_fulltext_search(
461
540
  search_filter: SearchFilters,
462
541
  group_ids: list[str] | None = None,
463
542
  limit=RELEVANT_SCHEMA_LIMIT,
543
+ use_local_indexes: bool = False,
464
544
  ) -> list[EntityNode]:
465
545
  # BM25 search to get top nodes
466
- fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
546
+ fuzzy_query = fulltext_query(query, group_ids, driver)
467
547
  if fuzzy_query == '':
468
548
  return []
469
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
549
+
550
+ filter_queries, filter_params = node_search_filter_query_constructor(
551
+ search_filter, driver.provider
552
+ )
553
+
554
+ if group_ids is not None:
555
+ filter_queries.append('n.group_id IN $group_ids')
556
+ filter_params['group_ids'] = group_ids
557
+
558
+ filter_query = ''
559
+ if filter_queries:
560
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
561
+
562
+ yield_query = 'YIELD node AS n, score'
563
+ if driver.provider == GraphProvider.KUZU:
564
+ yield_query = 'WITH node AS n, score'
470
565
 
471
566
  if driver.provider == GraphProvider.NEPTUNE:
472
567
  res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
@@ -479,12 +574,12 @@ async def node_fulltext_search(
479
574
  # Match the edge ides and return the values
480
575
  query = (
481
576
  """
482
- UNWIND $ids as i
483
- MATCH (n:Entity)
484
- WHERE n.uuid=i.id
485
- RETURN
486
- """
487
- + ENTITY_NODE_RETURN
577
+ UNWIND $ids as i
578
+ MATCH (n:Entity)
579
+ WHERE n.uuid=i.id
580
+ RETURN
581
+ """
582
+ + get_entity_node_return_query(driver.provider)
488
583
  + """
489
584
  ORDER BY i.score DESC
490
585
  LIMIT $limit
@@ -494,7 +589,6 @@ async def node_fulltext_search(
494
589
  query,
495
590
  ids=input_ids,
496
591
  query=fuzzy_query,
497
- group_ids=group_ids,
498
592
  limit=limit,
499
593
  routing_='r',
500
594
  **filter_params,
@@ -504,36 +598,32 @@ async def node_fulltext_search(
504
598
  else:
505
599
  index_name = (
506
600
  'node_name_and_summary'
507
- if not USE_HNSW
601
+ if not use_local_indexes
508
602
  else 'node_name_and_summary_'
509
603
  + (group_ids[0].replace('-', '') if group_ids is not None else '')
510
604
  )
511
605
  query = (
512
- get_nodes_query(driver.provider, index_name, '$query')
513
- + """
514
- YIELD node AS n, score
515
- WHERE n:Entity AND n.group_id IN $group_ids
516
- """
606
+ get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
607
+ + yield_query
517
608
  + filter_query
518
609
  + """
519
- WITH n, score
520
- ORDER BY score DESC
521
- LIMIT $limit
522
- RETURN
523
- """
524
- + ENTITY_NODE_RETURN
610
+ WITH n, score
611
+ ORDER BY score DESC
612
+ LIMIT $limit
613
+ RETURN
614
+ """
615
+ + get_entity_node_return_query(driver.provider)
525
616
  )
526
617
 
527
618
  records, _, _ = await driver.execute_query(
528
619
  query,
529
620
  query=fuzzy_query,
530
- group_ids=group_ids,
531
621
  limit=limit,
532
622
  routing_='r',
533
623
  **filter_params,
534
624
  )
535
625
 
536
- nodes = [get_entity_node_from_record(record) for record in records]
626
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
537
627
 
538
628
  return nodes
539
629
 
@@ -545,25 +635,29 @@ async def node_similarity_search(
545
635
  group_ids: list[str] | None = None,
546
636
  limit=RELEVANT_SCHEMA_LIMIT,
547
637
  min_score: float = DEFAULT_MIN_SCORE,
638
+ use_local_indexes: bool = False,
548
639
  ) -> list[EntityNode]:
549
- # vector similarity search over entity names
550
- query_params: dict[str, Any] = {}
640
+ filter_queries, filter_params = node_search_filter_query_constructor(
641
+ search_filter, driver.provider
642
+ )
551
643
 
552
- group_filter_query: LiteralString = 'WHERE n.group_id IS NOT NULL'
553
644
  if group_ids is not None:
554
- group_filter_query += ' AND n.group_id IN $group_ids'
555
- query_params['group_ids'] = group_ids
645
+ filter_queries.append('n.group_id IN $group_ids')
646
+ filter_params['group_ids'] = group_ids
556
647
 
557
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
558
- query_params.update(filter_params)
648
+ filter_query = ''
649
+ if filter_queries:
650
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
651
+
652
+ search_vector_var = '$search_vector'
653
+ if driver.provider == GraphProvider.KUZU:
654
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
559
655
 
560
656
  if driver.provider == GraphProvider.NEPTUNE:
561
657
  query = (
562
- RUNTIME_QUERY
563
- + """
564
- MATCH (n:Entity)
565
658
  """
566
- + group_filter_query
659
+ MATCH (n:Entity)
660
+ """
567
661
  + filter_query
568
662
  + """
569
663
  RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@@ -571,9 +665,8 @@ async def node_similarity_search(
571
665
  )
572
666
  resp, header, _ = await driver.execute_query(
573
667
  query,
574
- params=query_params,
668
+ params=filter_params,
575
669
  search_vector=search_vector,
576
- group_ids=group_ids,
577
670
  limit=limit,
578
671
  min_score=min_score,
579
672
  routing_='r',
@@ -593,12 +686,12 @@ async def node_similarity_search(
593
686
  # Match the edge ides and return the values
594
687
  query = (
595
688
  """
596
- UNWIND $ids as i
597
- MATCH (n:Entity)
598
- WHERE id(n)=i.id
599
- RETURN
600
- """
601
- + ENTITY_NODE_RETURN
689
+ UNWIND $ids as i
690
+ MATCH (n:Entity)
691
+ WHERE id(n)=i.id
692
+ RETURN
693
+ """
694
+ + get_entity_node_return_query(driver.provider)
602
695
  + """
603
696
  ORDER BY i.score DESC
604
697
  LIMIT $limit
@@ -611,11 +704,11 @@ async def node_similarity_search(
611
704
  limit=limit,
612
705
  min_score=min_score,
613
706
  routing_='r',
614
- **query_params,
707
+ **filter_params,
615
708
  )
616
709
  else:
617
710
  return []
618
- elif driver.provider == GraphProvider.NEO4J and USE_HNSW:
711
+ elif driver.provider == GraphProvider.NEO4J and use_local_indexes:
619
712
  index_name = 'group_entity_vector_' + (
620
713
  group_ids[0].replace('-', '') if group_ids is not None else ''
621
714
  )
@@ -623,13 +716,12 @@ async def node_similarity_search(
623
716
  f"""
624
717
  CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score
625
718
  """
626
- + group_filter_query
627
719
  + filter_query
628
720
  + """
629
721
  AND score > $min_score
630
722
  RETURN
631
723
  """
632
- + ENTITY_NODE_RETURN
724
+ + get_entity_node_return_query(driver.provider)
633
725
  + """
634
726
  ORDER BY score DESC
635
727
  LIMIT $limit
@@ -642,25 +734,23 @@ async def node_similarity_search(
642
734
  limit=limit,
643
735
  min_score=min_score,
644
736
  routing_='r',
645
- **query_params,
737
+ **filter_params,
646
738
  )
647
739
 
648
740
  else:
649
741
  query = (
650
- RUNTIME_QUERY
651
- + """
652
- MATCH (n:Entity)
653
742
  """
654
- + group_filter_query
743
+ MATCH (n:Entity)
744
+ """
655
745
  + filter_query
656
746
  + """
657
747
  WITH n, """
658
- + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
748
+ + get_vector_cosine_func_query('n.name_embedding', search_vector_var, driver.provider)
659
749
  + """ AS score
660
750
  WHERE score > $min_score
661
751
  RETURN
662
752
  """
663
- + ENTITY_NODE_RETURN
753
+ + get_entity_node_return_query(driver.provider)
664
754
  + """
665
755
  ORDER BY score DESC
666
756
  LIMIT $limit
@@ -673,10 +763,10 @@ async def node_similarity_search(
673
763
  limit=limit,
674
764
  min_score=min_score,
675
765
  routing_='r',
676
- **query_params,
766
+ **filter_params,
677
767
  )
678
768
 
679
- nodes = [get_entity_node_from_record(record) for record in records]
769
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
680
770
 
681
771
  return nodes
682
772
 
@@ -689,56 +779,82 @@ async def node_bfs_search(
689
779
  group_ids: list[str] | None = None,
690
780
  limit: int = RELEVANT_SCHEMA_LIMIT,
691
781
  ) -> list[EntityNode]:
692
- # vector similarity search over entity names
693
- if bfs_origin_node_uuids is None:
782
+ if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0 or bfs_max_depth < 1:
694
783
  return []
695
784
 
696
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
785
+ filter_queries, filter_params = node_search_filter_query_constructor(
786
+ search_filter, driver.provider
787
+ )
788
+
789
+ if group_ids is not None:
790
+ filter_queries.append('n.group_id IN $group_ids')
791
+ filter_queries.append('origin.group_id IN $group_ids')
792
+ filter_params['group_ids'] = group_ids
793
+
794
+ filter_query = ''
795
+ if filter_queries:
796
+ filter_query = ' AND ' + (' AND '.join(filter_queries))
797
+
798
+ match_queries = [
799
+ f"""
800
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
801
+ MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
802
+ WHERE n.group_id = origin.group_id
803
+ """
804
+ ]
697
805
 
698
806
  if driver.provider == GraphProvider.NEPTUNE:
699
- query = (
807
+ match_queries = [
700
808
  f"""
701
- UNWIND $bfs_origin_node_uuids AS origin_uuid
702
- MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
703
- WHERE origin:Entity OR origin.Episode
704
- AND n.group_id = origin.group_id
705
- """
706
- + filter_query
707
- + """
708
- RETURN
809
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
810
+ MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
811
+ WHERE origin:Entity OR origin.Episode
812
+ AND n.group_id = origin.group_id
709
813
  """
710
- + ENTITY_NODE_RETURN
711
- + """
712
- LIMIT $limit
814
+ ]
815
+
816
+ if driver.provider == GraphProvider.KUZU:
817
+ depth = bfs_max_depth * 2
818
+ match_queries = [
713
819
  """
714
- )
715
- else:
716
- query = (
820
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
821
+ MATCH (origin:Episodic {uuid: origin_uuid})-[:MENTIONS]->(n:Entity)
822
+ WHERE n.group_id = origin.group_id
823
+ """,
717
824
  f"""
825
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
826
+ MATCH (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*2..{depth}]->(n:Entity)
827
+ WHERE n.group_id = origin.group_id
828
+ """,
829
+ ]
830
+ if bfs_max_depth > 1:
831
+ depth = (bfs_max_depth - 1) * 2
832
+ match_queries.append(f"""
718
833
  UNWIND $bfs_origin_node_uuids AS origin_uuid
719
- MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
834
+ MATCH (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*2..{depth}]->(n:Entity)
720
835
  WHERE n.group_id = origin.group_id
721
- AND origin.group_id IN $group_ids
722
- """
836
+ """)
837
+
838
+ records = []
839
+ for match_query in match_queries:
840
+ sub_records, _, _ = await driver.execute_query(
841
+ match_query
723
842
  + filter_query
724
843
  + """
725
844
  RETURN
726
845
  """
727
- + ENTITY_NODE_RETURN
846
+ + get_entity_node_return_query(driver.provider)
728
847
  + """
729
848
  LIMIT $limit
730
- """
849
+ """,
850
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
851
+ limit=limit,
852
+ routing_='r',
853
+ **filter_params,
731
854
  )
855
+ records.extend(sub_records)
732
856
 
733
- records, _, _ = await driver.execute_query(
734
- query,
735
- bfs_origin_node_uuids=bfs_origin_node_uuids,
736
- group_ids=group_ids,
737
- limit=limit,
738
- routing_='r',
739
- **filter_params,
740
- )
741
- nodes = [get_entity_node_from_record(record) for record in records]
857
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
742
858
 
743
859
  return nodes
744
860
 
@@ -749,12 +865,19 @@ async def episode_fulltext_search(
749
865
  _search_filter: SearchFilters,
750
866
  group_ids: list[str] | None = None,
751
867
  limit=RELEVANT_SCHEMA_LIMIT,
868
+ use_local_indexes: bool = False,
752
869
  ) -> list[EpisodicNode]:
753
870
  # BM25 search to get top episodes
754
- fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
871
+ fuzzy_query = fulltext_query(query, group_ids, driver)
755
872
  if fuzzy_query == '':
756
873
  return []
757
874
 
875
+ filter_params: dict[str, Any] = {}
876
+ group_filter_query: LiteralString = ''
877
+ if group_ids is not None:
878
+ group_filter_query += '\nAND e.group_id IN $group_ids'
879
+ filter_params['group_ids'] = group_ids
880
+
758
881
  if driver.provider == GraphProvider.NEPTUNE:
759
882
  res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
760
883
  if res['hits']['total']['value'] > 0:
@@ -768,7 +891,7 @@ async def episode_fulltext_search(
768
891
  UNWIND $ids as i
769
892
  MATCH (e:Episodic)
770
893
  WHERE e.uuid=i.id
771
- RETURN
894
+ RETURN
772
895
  e.content AS content,
773
896
  e.created_at AS created_at,
774
897
  e.valid_at AS valid_at,
@@ -785,26 +908,28 @@ async def episode_fulltext_search(
785
908
  query,
786
909
  ids=input_ids,
787
910
  query=fuzzy_query,
788
- group_ids=group_ids,
789
911
  limit=limit,
790
912
  routing_='r',
913
+ **filter_params,
791
914
  )
792
915
  else:
793
916
  return []
794
917
  else:
795
918
  index_name = (
796
919
  'episode_content'
797
- if not USE_HNSW
920
+ if not use_local_indexes
798
921
  else 'episode_content_'
799
922
  + (group_ids[0].replace('-', '') if group_ids is not None else '')
800
923
  )
801
924
  query = (
802
- get_nodes_query(driver.provider, index_name, '$query')
925
+ get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
803
926
  + """
804
927
  YIELD node AS episode, score
805
928
  MATCH (e:Episodic)
806
929
  WHERE e.uuid = episode.uuid
807
- AND e.group_id IN $group_ids
930
+ """
931
+ + group_filter_query
932
+ + """
808
933
  RETURN
809
934
  """
810
935
  + EPISODIC_NODE_RETURN
@@ -815,12 +940,9 @@ async def episode_fulltext_search(
815
940
  )
816
941
 
817
942
  records, _, _ = await driver.execute_query(
818
- query,
819
- query=fuzzy_query,
820
- group_ids=group_ids,
821
- limit=limit,
822
- routing_='r',
943
+ query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
823
944
  )
945
+
824
946
  episodes = [get_episodic_node_from_record(record) for record in records]
825
947
 
826
948
  return episodes
@@ -833,10 +955,20 @@ async def community_fulltext_search(
833
955
  limit=RELEVANT_SCHEMA_LIMIT,
834
956
  ) -> list[CommunityNode]:
835
957
  # BM25 search to get top communities
836
- fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
958
+ fuzzy_query = fulltext_query(query, group_ids, driver)
837
959
  if fuzzy_query == '':
838
960
  return []
839
961
 
962
+ filter_params: dict[str, Any] = {}
963
+ group_filter_query: LiteralString = ''
964
+ if group_ids is not None:
965
+ group_filter_query = 'WHERE c.group_id IN $group_ids'
966
+ filter_params['group_ids'] = group_ids
967
+
968
+ yield_query = 'YIELD node AS c, score'
969
+ if driver.provider == GraphProvider.KUZU:
970
+ yield_query = 'WITH node AS c, score'
971
+
840
972
  if driver.provider == GraphProvider.NEPTUNE:
841
973
  res = driver.run_aoss_query('community_name', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
842
974
  if res['hits']['total']['value'] > 0:
@@ -852,9 +984,9 @@ async def community_fulltext_search(
852
984
  WHERE comm.uuid=i.id
853
985
  RETURN
854
986
  comm.uuid AS uuid,
855
- comm.group_id AS group_id,
856
- comm.name AS name,
857
- comm.created_at AS created_at,
987
+ comm.group_id AS group_id,
988
+ comm.name AS name,
989
+ comm.created_at AS created_at,
858
990
  comm.summary AS summary,
859
991
  [x IN split(comm.name_embedding, ",") | toFloat(x)]AS name_embedding
860
992
  ORDER BY i.score DESC
@@ -864,18 +996,21 @@ async def community_fulltext_search(
864
996
  query,
865
997
  ids=input_ids,
866
998
  query=fuzzy_query,
867
- group_ids=group_ids,
868
999
  limit=limit,
869
1000
  routing_='r',
1001
+ **filter_params,
870
1002
  )
871
1003
  else:
872
1004
  return []
873
1005
  else:
874
1006
  query = (
875
- get_nodes_query(driver.provider, 'community_name', '$query')
1007
+ get_nodes_query('community_name', '$query', limit=limit, provider=driver.provider)
1008
+ + yield_query
1009
+ + """
1010
+ WITH c, score
1011
+ """
1012
+ + group_filter_query
876
1013
  + """
877
- YIELD node AS n, score
878
- WHERE n.group_id IN $group_ids
879
1014
  RETURN
880
1015
  """
881
1016
  + COMMUNITY_NODE_RETURN
@@ -886,12 +1021,9 @@ async def community_fulltext_search(
886
1021
  )
887
1022
 
888
1023
  records, _, _ = await driver.execute_query(
889
- query,
890
- query=fuzzy_query,
891
- group_ids=group_ids,
892
- limit=limit,
893
- routing_='r',
1024
+ query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
894
1025
  )
1026
+
895
1027
  communities = [get_community_node_from_record(record) for record in records]
896
1028
 
897
1029
  return communities
@@ -909,15 +1041,14 @@ async def community_similarity_search(
909
1041
 
910
1042
  group_filter_query: LiteralString = ''
911
1043
  if group_ids is not None:
912
- group_filter_query += 'WHERE n.group_id IN $group_ids'
1044
+ group_filter_query += ' WHERE c.group_id IN $group_ids'
913
1045
  query_params['group_ids'] = group_ids
914
1046
 
915
1047
  if driver.provider == GraphProvider.NEPTUNE:
916
1048
  query = (
917
- RUNTIME_QUERY
918
- + """
919
- MATCH (n:Community)
920
1049
  """
1050
+ MATCH (n:Community)
1051
+ """
921
1052
  + group_filter_query
922
1053
  + """
923
1054
  RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@@ -951,8 +1082,8 @@ async def community_similarity_search(
951
1082
  RETURN
952
1083
  comm.uuid As uuid,
953
1084
  comm.group_id AS group_id,
954
- comm.name AS name,
955
- comm.created_at AS created_at,
1085
+ comm.name AS name,
1086
+ comm.created_at AS created_at,
956
1087
  comm.summary AS summary,
957
1088
  comm.name_embedding AS name_embedding
958
1089
  ORDER BY i.score DESC
@@ -970,16 +1101,19 @@ async def community_similarity_search(
970
1101
  else:
971
1102
  return []
972
1103
  else:
1104
+ search_vector_var = '$search_vector'
1105
+ if driver.provider == GraphProvider.KUZU:
1106
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
1107
+
973
1108
  query = (
974
- RUNTIME_QUERY
975
- + """
976
- MATCH (n:Community)
977
1109
  """
1110
+ MATCH (c:Community)
1111
+ """
978
1112
  + group_filter_query
979
1113
  + """
980
- WITH n,
1114
+ WITH c,
981
1115
  """
982
- + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
1116
+ + get_vector_cosine_func_query('c.name_embedding', search_vector_var, driver.provider)
983
1117
  + """ AS score
984
1118
  WHERE score > $min_score
985
1119
  RETURN
@@ -999,6 +1133,7 @@ async def community_similarity_search(
999
1133
  routing_='r',
1000
1134
  **query_params,
1001
1135
  )
1136
+
1002
1137
  communities = [get_community_node_from_record(record) for record in records]
1003
1138
 
1004
1139
  return communities
@@ -1089,67 +1224,127 @@ async def get_relevant_nodes(
1089
1224
  return []
1090
1225
 
1091
1226
  group_id = nodes[0].group_id
1092
-
1093
- # vector similarity search over entity names
1094
- query_params: dict[str, Any] = {}
1095
-
1096
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
1097
- query_params.update(filter_params)
1098
-
1099
- query = (
1100
- RUNTIME_QUERY
1101
- + """
1102
- UNWIND $nodes AS node
1103
- MATCH (n:Entity {group_id: $group_id})
1104
- """
1105
- + filter_query
1106
- + """
1107
- WITH node, n, """
1108
- + get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
1109
- + """ AS score
1110
- WHERE score > $min_score
1111
- WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
1112
- """
1113
- + get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
1114
- + """
1115
- YIELD node AS m
1116
- WHERE m.group_id = $group_id
1117
- WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
1118
-
1119
- WITH node,
1120
- top_vector_nodes,
1121
- [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
1122
-
1123
- WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
1124
-
1125
- UNWIND combined_nodes AS combined_node
1126
- WITH node, collect(DISTINCT combined_node) AS deduped_nodes
1127
-
1128
- RETURN
1129
- node.uuid AS search_node_uuid,
1130
- [x IN deduped_nodes | {
1131
- uuid: x.uuid,
1132
- name: x.name,
1133
- name_embedding: x.name_embedding,
1134
- group_id: x.group_id,
1135
- created_at: x.created_at,
1136
- summary: x.summary,
1137
- labels: labels(x),
1138
- attributes: properties(x)
1139
- }] AS matches
1140
- """
1141
- )
1142
-
1143
1227
  query_nodes = [
1144
1228
  {
1145
1229
  'uuid': node.uuid,
1146
1230
  'name': node.name,
1147
1231
  'name_embedding': node.name_embedding,
1148
- 'fulltext_query': fulltext_query(node.name, [node.group_id], driver.fulltext_syntax),
1232
+ 'fulltext_query': fulltext_query(node.name, [node.group_id], driver),
1149
1233
  }
1150
1234
  for node in nodes
1151
1235
  ]
1152
1236
 
1237
+ filter_queries, filter_params = node_search_filter_query_constructor(
1238
+ search_filter, driver.provider
1239
+ )
1240
+
1241
+ filter_query = ''
1242
+ if filter_queries:
1243
+ filter_query = 'WHERE ' + (' AND '.join(filter_queries))
1244
+
1245
+ if driver.provider == GraphProvider.KUZU:
1246
+ embedding_size = len(nodes[0].name_embedding) if nodes[0].name_embedding is not None else 0
1247
+ if embedding_size == 0:
1248
+ return []
1249
+
1250
+ # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
1251
+ query = (
1252
+ """
1253
+ UNWIND $nodes AS node
1254
+ MATCH (n:Entity {group_id: $group_id})
1255
+ """
1256
+ + filter_query
1257
+ + """
1258
+ WITH node, n, """
1259
+ + get_vector_cosine_func_query(
1260
+ 'n.name_embedding',
1261
+ f'CAST(node.name_embedding AS FLOAT[{embedding_size}])',
1262
+ driver.provider,
1263
+ )
1264
+ + """ AS score
1265
+ WHERE score > $min_score
1266
+ WITH node, collect(n)[:$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
1267
+ """
1268
+ + get_nodes_query(
1269
+ 'node_name_and_summary',
1270
+ 'node.fulltext_query',
1271
+ limit=limit,
1272
+ provider=driver.provider,
1273
+ )
1274
+ + """
1275
+ WITH node AS m
1276
+ WHERE m.group_id = $group_id AND NOT m.uuid IN vector_node_uuids
1277
+ WITH node, top_vector_nodes, collect(m) AS fulltext_nodes
1278
+
1279
+ WITH node, list_concat(top_vector_nodes, fulltext_nodes) AS combined_nodes
1280
+
1281
+ UNWIND combined_nodes AS x
1282
+ WITH node, collect(DISTINCT {
1283
+ uuid: x.uuid,
1284
+ name: x.name,
1285
+ name_embedding: x.name_embedding,
1286
+ group_id: x.group_id,
1287
+ created_at: x.created_at,
1288
+ summary: x.summary,
1289
+ labels: x.labels,
1290
+ attributes: x.attributes
1291
+ }) AS matches
1292
+
1293
+ RETURN
1294
+ node.uuid AS search_node_uuid, matches
1295
+ """
1296
+ )
1297
+ else:
1298
+ query = (
1299
+ """
1300
+ UNWIND $nodes AS node
1301
+ MATCH (n:Entity {group_id: $group_id})
1302
+ """
1303
+ + filter_query
1304
+ + """
1305
+ WITH node, n, """
1306
+ + get_vector_cosine_func_query(
1307
+ 'n.name_embedding', 'node.name_embedding', driver.provider
1308
+ )
1309
+ + """ AS score
1310
+ WHERE score > $min_score
1311
+ WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
1312
+ """
1313
+ + get_nodes_query(
1314
+ 'node_name_and_summary',
1315
+ 'node.fulltext_query',
1316
+ limit=limit,
1317
+ provider=driver.provider,
1318
+ )
1319
+ + """
1320
+ YIELD node AS m
1321
+ WHERE m.group_id = $group_id
1322
+ WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
1323
+
1324
+ WITH node,
1325
+ top_vector_nodes,
1326
+ [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
1327
+
1328
+ WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
1329
+
1330
+ UNWIND combined_nodes AS combined_node
1331
+ WITH node, collect(DISTINCT combined_node) AS deduped_nodes
1332
+
1333
+ RETURN
1334
+ node.uuid AS search_node_uuid,
1335
+ [x IN deduped_nodes | {
1336
+ uuid: x.uuid,
1337
+ name: x.name,
1338
+ name_embedding: x.name_embedding,
1339
+ group_id: x.group_id,
1340
+ created_at: x.created_at,
1341
+ summary: x.summary,
1342
+ labels: labels(x),
1343
+ attributes: properties(x)
1344
+ }] AS matches
1345
+ """
1346
+ )
1347
+
1153
1348
  results, _, _ = await driver.execute_query(
1154
1349
  query,
1155
1350
  nodes=query_nodes,
@@ -1157,12 +1352,12 @@ async def get_relevant_nodes(
1157
1352
  limit=limit,
1158
1353
  min_score=min_score,
1159
1354
  routing_='r',
1160
- **query_params,
1355
+ **filter_params,
1161
1356
  )
1162
1357
 
1163
1358
  relevant_nodes_dict: dict[str, list[EntityNode]] = {
1164
1359
  result['search_node_uuid']: [
1165
- get_entity_node_from_record(record) for record in result['matches']
1360
+ get_entity_node_from_record(record, driver.provider) for record in result['matches']
1166
1361
  ]
1167
1362
  for result in results
1168
1363
  }
@@ -1182,22 +1377,24 @@ async def get_relevant_edges(
1182
1377
  if len(edges) == 0:
1183
1378
  return []
1184
1379
 
1185
- query_params: dict[str, Any] = {}
1380
+ filter_queries, filter_params = edge_search_filter_query_constructor(
1381
+ search_filter, driver.provider
1382
+ )
1186
1383
 
1187
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
1188
- query_params.update(filter_params)
1384
+ filter_query = ''
1385
+ if filter_queries:
1386
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
1189
1387
 
1190
1388
  if driver.provider == GraphProvider.NEPTUNE:
1191
1389
  query = (
1192
- RUNTIME_QUERY
1193
- + """
1194
- UNWIND $edges AS edge
1195
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1196
1390
  """
1391
+ UNWIND $edges AS edge
1392
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1393
+ """
1197
1394
  + filter_query
1198
1395
  + """
1199
1396
  WITH e, edge
1200
- RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
1397
+ RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
1201
1398
  edge.fact_embedding as target_embedding
1202
1399
  """
1203
1400
  )
@@ -1207,7 +1404,7 @@ async def get_relevant_edges(
1207
1404
  limit=limit,
1208
1405
  min_score=min_score,
1209
1406
  routing_='r',
1210
- **query_params,
1407
+ **filter_params,
1211
1408
  )
1212
1409
 
1213
1410
  # Calculate Cosine similarity then return the edge ids
@@ -1220,7 +1417,7 @@ async def get_relevant_edges(
1220
1417
  input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
1221
1418
 
1222
1419
  # Match the edge ids and return the values
1223
- query = """
1420
+ query = """
1224
1421
  UNWIND $ids AS edge
1225
1422
  MATCH ()-[e]->()
1226
1423
  WHERE id(e) = edge.id
@@ -1246,49 +1443,93 @@ async def get_relevant_edges(
1246
1443
 
1247
1444
  results, _, _ = await driver.execute_query(
1248
1445
  query,
1249
- params=query_params,
1250
1446
  ids=input_ids,
1251
1447
  edges=[edge.model_dump() for edge in edges],
1252
1448
  limit=limit,
1253
1449
  min_score=min_score,
1254
1450
  routing_='r',
1255
- **query_params,
1451
+ **filter_params,
1256
1452
  )
1257
1453
  else:
1258
- query = (
1259
- RUNTIME_QUERY
1260
- + """
1261
- UNWIND $edges AS edge
1262
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1263
- """
1264
- + filter_query
1265
- + """
1266
- WITH e, edge, """
1267
- + get_vector_cosine_func_query(
1268
- 'e.fact_embedding', 'edge.fact_embedding', driver.provider
1454
+ if driver.provider == GraphProvider.KUZU:
1455
+ embedding_size = (
1456
+ len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
1457
+ )
1458
+ if embedding_size == 0:
1459
+ return []
1460
+
1461
+ query = (
1462
+ """
1463
+ UNWIND $edges AS edge
1464
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
1465
+ """
1466
+ + filter_query
1467
+ + """
1468
+ WITH e, edge, n, m, """
1469
+ + get_vector_cosine_func_query(
1470
+ 'e.fact_embedding',
1471
+ f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
1472
+ driver.provider,
1473
+ )
1474
+ + """ AS score
1475
+ WHERE score > $min_score
1476
+ WITH e, edge, n, m, score
1477
+ ORDER BY score DESC
1478
+ LIMIT $limit
1479
+ RETURN
1480
+ edge.uuid AS search_edge_uuid,
1481
+ collect({
1482
+ uuid: e.uuid,
1483
+ source_node_uuid: n.uuid,
1484
+ target_node_uuid: m.uuid,
1485
+ created_at: e.created_at,
1486
+ name: e.name,
1487
+ group_id: e.group_id,
1488
+ fact: e.fact,
1489
+ fact_embedding: e.fact_embedding,
1490
+ episodes: e.episodes,
1491
+ expired_at: e.expired_at,
1492
+ valid_at: e.valid_at,
1493
+ invalid_at: e.invalid_at,
1494
+ attributes: e.attributes
1495
+ }) AS matches
1496
+ """
1497
+ )
1498
+ else:
1499
+ query = (
1500
+ """
1501
+ UNWIND $edges AS edge
1502
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1503
+ """
1504
+ + filter_query
1505
+ + """
1506
+ WITH e, edge, """
1507
+ + get_vector_cosine_func_query(
1508
+ 'e.fact_embedding', 'edge.fact_embedding', driver.provider
1509
+ )
1510
+ + """ AS score
1511
+ WHERE score > $min_score
1512
+ WITH edge, e, score
1513
+ ORDER BY score DESC
1514
+ RETURN
1515
+ edge.uuid AS search_edge_uuid,
1516
+ collect({
1517
+ uuid: e.uuid,
1518
+ source_node_uuid: startNode(e).uuid,
1519
+ target_node_uuid: endNode(e).uuid,
1520
+ created_at: e.created_at,
1521
+ name: e.name,
1522
+ group_id: e.group_id,
1523
+ fact: e.fact,
1524
+ fact_embedding: e.fact_embedding,
1525
+ episodes: e.episodes,
1526
+ expired_at: e.expired_at,
1527
+ valid_at: e.valid_at,
1528
+ invalid_at: e.invalid_at,
1529
+ attributes: properties(e)
1530
+ })[..$limit] AS matches
1531
+ """
1269
1532
  )
1270
- + """ AS score
1271
- WHERE score > $min_score
1272
- WITH edge, e, score
1273
- ORDER BY score DESC
1274
- RETURN edge.uuid AS search_edge_uuid,
1275
- collect({
1276
- uuid: e.uuid,
1277
- source_node_uuid: startNode(e).uuid,
1278
- target_node_uuid: endNode(e).uuid,
1279
- created_at: e.created_at,
1280
- name: e.name,
1281
- group_id: e.group_id,
1282
- fact: e.fact,
1283
- fact_embedding: e.fact_embedding,
1284
- episodes: e.episodes,
1285
- expired_at: e.expired_at,
1286
- valid_at: e.valid_at,
1287
- invalid_at: e.invalid_at,
1288
- attributes: properties(e)
1289
- })[..$limit] AS matches
1290
- """
1291
- )
1292
1533
 
1293
1534
  results, _, _ = await driver.execute_query(
1294
1535
  query,
@@ -1296,12 +1537,12 @@ async def get_relevant_edges(
1296
1537
  limit=limit,
1297
1538
  min_score=min_score,
1298
1539
  routing_='r',
1299
- **query_params,
1540
+ **filter_params,
1300
1541
  )
1301
1542
 
1302
1543
  relevant_edges_dict: dict[str, list[EntityEdge]] = {
1303
1544
  result['search_edge_uuid']: [
1304
- get_entity_edge_from_record(record) for record in result['matches']
1545
+ get_entity_edge_from_record(record, driver.provider) for record in result['matches']
1305
1546
  ]
1306
1547
  for result in results
1307
1548
  }
@@ -1321,19 +1562,21 @@ async def get_edge_invalidation_candidates(
1321
1562
  if len(edges) == 0:
1322
1563
  return []
1323
1564
 
1324
- query_params: dict[str, Any] = {}
1565
+ filter_queries, filter_params = edge_search_filter_query_constructor(
1566
+ search_filter, driver.provider
1567
+ )
1325
1568
 
1326
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
1327
- query_params.update(filter_params)
1569
+ filter_query = ''
1570
+ if filter_queries:
1571
+ filter_query = ' AND ' + (' AND '.join(filter_queries))
1328
1572
 
1329
1573
  if driver.provider == GraphProvider.NEPTUNE:
1330
1574
  query = (
1331
- RUNTIME_QUERY
1332
- + """
1333
- UNWIND $edges AS edge
1334
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1335
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1336
1575
  """
1576
+ UNWIND $edges AS edge
1577
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1578
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1579
+ """
1337
1580
  + filter_query
1338
1581
  + """
1339
1582
  WITH e, edge
@@ -1348,7 +1591,7 @@ async def get_edge_invalidation_candidates(
1348
1591
  limit=limit,
1349
1592
  min_score=min_score,
1350
1593
  routing_='r',
1351
- **query_params,
1594
+ **filter_params,
1352
1595
  )
1353
1596
 
1354
1597
  # Calculate Cosine similarity then return the edge ids
@@ -1361,7 +1604,7 @@ async def get_edge_invalidation_candidates(
1361
1604
  input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
1362
1605
 
1363
1606
  # Match the edge ids and return the values
1364
- query = """
1607
+ query = """
1365
1608
  UNWIND $ids AS edge
1366
1609
  MATCH ()-[e]->()
1367
1610
  WHERE id(e) = edge.id
@@ -1391,44 +1634,90 @@ async def get_edge_invalidation_candidates(
1391
1634
  limit=limit,
1392
1635
  min_score=min_score,
1393
1636
  routing_='r',
1394
- **query_params,
1637
+ **filter_params,
1395
1638
  )
1396
1639
  else:
1397
- query = (
1398
- RUNTIME_QUERY
1399
- + """
1400
- UNWIND $edges AS edge
1401
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1402
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1403
- """
1404
- + filter_query
1405
- + """
1406
- WITH edge, e, """
1407
- + get_vector_cosine_func_query(
1408
- 'e.fact_embedding', 'edge.fact_embedding', driver.provider
1640
+ if driver.provider == GraphProvider.KUZU:
1641
+ embedding_size = (
1642
+ len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
1643
+ )
1644
+ if embedding_size == 0:
1645
+ return []
1646
+
1647
+ query = (
1648
+ """
1649
+ UNWIND $edges AS edge
1650
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
1651
+ WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
1652
+ """
1653
+ + filter_query
1654
+ + """
1655
+ WITH edge, e, n, m, """
1656
+ + get_vector_cosine_func_query(
1657
+ 'e.fact_embedding',
1658
+ f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
1659
+ driver.provider,
1660
+ )
1661
+ + """ AS score
1662
+ WHERE score > $min_score
1663
+ WITH edge, e, n, m, score
1664
+ ORDER BY score DESC
1665
+ LIMIT $limit
1666
+ RETURN
1667
+ edge.uuid AS search_edge_uuid,
1668
+ collect({
1669
+ uuid: e.uuid,
1670
+ source_node_uuid: n.uuid,
1671
+ target_node_uuid: m.uuid,
1672
+ created_at: e.created_at,
1673
+ name: e.name,
1674
+ group_id: e.group_id,
1675
+ fact: e.fact,
1676
+ fact_embedding: e.fact_embedding,
1677
+ episodes: e.episodes,
1678
+ expired_at: e.expired_at,
1679
+ valid_at: e.valid_at,
1680
+ invalid_at: e.invalid_at,
1681
+ attributes: e.attributes
1682
+ }) AS matches
1683
+ """
1684
+ )
1685
+ else:
1686
+ query = (
1687
+ """
1688
+ UNWIND $edges AS edge
1689
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1690
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1691
+ """
1692
+ + filter_query
1693
+ + """
1694
+ WITH edge, e, """
1695
+ + get_vector_cosine_func_query(
1696
+ 'e.fact_embedding', 'edge.fact_embedding', driver.provider
1697
+ )
1698
+ + """ AS score
1699
+ WHERE score > $min_score
1700
+ WITH edge, e, score
1701
+ ORDER BY score DESC
1702
+ RETURN
1703
+ edge.uuid AS search_edge_uuid,
1704
+ collect({
1705
+ uuid: e.uuid,
1706
+ source_node_uuid: startNode(e).uuid,
1707
+ target_node_uuid: endNode(e).uuid,
1708
+ created_at: e.created_at,
1709
+ name: e.name,
1710
+ group_id: e.group_id,
1711
+ fact: e.fact,
1712
+ fact_embedding: e.fact_embedding,
1713
+ episodes: e.episodes,
1714
+ expired_at: e.expired_at,
1715
+ valid_at: e.valid_at,
1716
+ invalid_at: e.invalid_at,
1717
+ attributes: properties(e)
1718
+ })[..$limit] AS matches
1719
+ """
1409
1720
  )
1410
- + """ AS score
1411
- WHERE score > $min_score
1412
- WITH edge, e, score
1413
- ORDER BY score DESC
1414
- RETURN edge.uuid AS search_edge_uuid,
1415
- collect({
1416
- uuid: e.uuid,
1417
- source_node_uuid: startNode(e).uuid,
1418
- target_node_uuid: endNode(e).uuid,
1419
- created_at: e.created_at,
1420
- name: e.name,
1421
- group_id: e.group_id,
1422
- fact: e.fact,
1423
- fact_embedding: e.fact_embedding,
1424
- episodes: e.episodes,
1425
- expired_at: e.expired_at,
1426
- valid_at: e.valid_at,
1427
- invalid_at: e.invalid_at,
1428
- attributes: properties(e)
1429
- })[..$limit] AS matches
1430
- """
1431
- )
1432
1721
 
1433
1722
  results, _, _ = await driver.execute_query(
1434
1723
  query,
@@ -1436,11 +1725,11 @@ async def get_edge_invalidation_candidates(
1436
1725
  limit=limit,
1437
1726
  min_score=min_score,
1438
1727
  routing_='r',
1439
- **query_params,
1728
+ **filter_params,
1440
1729
  )
1441
1730
  invalidation_edges_dict: dict[str, list[EntityEdge]] = {
1442
1731
  result['search_edge_uuid']: [
1443
- get_entity_edge_from_record(record) for record in result['matches']
1732
+ get_entity_edge_from_record(record, driver.provider) for record in result['matches']
1444
1733
  ]
1445
1734
  for result in results
1446
1735
  }
@@ -1479,13 +1768,21 @@ async def node_distance_reranker(
1479
1768
  filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
1480
1769
  scores: dict[str, float] = {center_node_uuid: 0.0}
1481
1770
 
1482
- # Find the shortest path to center node
1483
- results, header, _ = await driver.execute_query(
1484
- """
1771
+ query = """
1772
+ UNWIND $node_uuids AS node_uuid
1773
+ MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
1774
+ RETURN 1 AS score, node_uuid AS uuid
1775
+ """
1776
+ if driver.provider == GraphProvider.KUZU:
1777
+ query = """
1485
1778
  UNWIND $node_uuids AS node_uuid
1486
- MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
1779
+ MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(n:Entity {uuid: node_uuid})
1487
1780
  RETURN 1 AS score, node_uuid AS uuid
1488
- """,
1781
+ """
1782
+
1783
+ # Find the shortest path to center node
1784
+ results, header, _ = await driver.execute_query(
1785
+ query,
1489
1786
  node_uuids=filtered_uuids,
1490
1787
  center_uuid=center_node_uuid,
1491
1788
  routing_='r',
@@ -1536,6 +1833,10 @@ async def episode_mentions_reranker(
1536
1833
  for result in results:
1537
1834
  scores[result['uuid']] = result['score']
1538
1835
 
1836
+ for uuid in sorted_uuids:
1837
+ if uuid not in scores:
1838
+ scores[uuid] = float('inf')
1839
+
1539
1840
  # rerank on shortest distance
1540
1841
  sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
1541
1842
 
@@ -1667,13 +1968,23 @@ async def get_embeddings_for_edges(
1667
1968
  split(e.fact_embedding, ",") AS fact_embedding
1668
1969
  """
1669
1970
  else:
1670
- query = """
1671
- MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1971
+ match_query = """
1972
+ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1973
+ """
1974
+ if driver.provider == GraphProvider.KUZU:
1975
+ match_query = """
1976
+ MATCH (n:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m:Entity)
1977
+ """
1978
+
1979
+ query = (
1980
+ match_query
1981
+ + """
1672
1982
  WHERE e.uuid IN $edge_uuids
1673
1983
  RETURN DISTINCT
1674
1984
  e.uuid AS uuid,
1675
1985
  e.fact_embedding AS fact_embedding
1676
1986
  """
1987
+ )
1677
1988
  results, _, _ = await driver.execute_query(
1678
1989
  query,
1679
1990
  edge_uuids=[edge.uuid for edge in edges],