graphiti-core 0.17.11__py3-none-any.whl → 0.18.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. See the package registry's page for this release for more details.

@@ -23,7 +23,7 @@ import numpy as np
23
23
  from numpy._typing import NDArray
24
24
  from typing_extensions import LiteralString
25
25
 
26
- from graphiti_core.driver.driver import GraphDriver
26
+ from graphiti_core.driver.driver import GraphDriver, GraphProvider
27
27
  from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
28
28
  from graphiti_core.graph_queries import (
29
29
  get_nodes_query,
@@ -36,6 +36,8 @@ from graphiti_core.helpers import (
36
36
  normalize_l2,
37
37
  semaphore_gather,
38
38
  )
39
+ from graphiti_core.models.edges.edge_db_queries import ENTITY_EDGE_RETURN
40
+ from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN, EPISODIC_NODE_RETURN
39
41
  from graphiti_core.nodes import (
40
42
  ENTITY_NODE_RETURN,
41
43
  CommunityNode,
@@ -100,20 +102,13 @@ async def get_mentioned_nodes(
100
102
  ) -> list[EntityNode]:
101
103
  episode_uuids = [episode.uuid for episode in episodes]
102
104
 
103
- query = """
104
- MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
105
+ records, _, _ = await driver.execute_query(
106
+ """
107
+ MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
108
+ WHERE episode.uuid IN $uuids
105
109
  RETURN DISTINCT
106
- n.uuid As uuid,
107
- n.group_id AS group_id,
108
- n.name AS name,
109
- n.created_at AS created_at,
110
- n.summary AS summary,
111
- labels(n) AS labels,
112
- properties(n) AS attributes
113
110
  """
114
-
115
- records, _, _ = await driver.execute_query(
116
- query,
111
+ + ENTITY_NODE_RETURN,
117
112
  uuids=episode_uuids,
118
113
  routing_='r',
119
114
  )
@@ -128,18 +123,13 @@ async def get_communities_by_nodes(
128
123
  ) -> list[CommunityNode]:
129
124
  node_uuids = [node.uuid for node in nodes]
130
125
 
131
- query = """
132
- MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
133
- RETURN DISTINCT
134
- c.uuid As uuid,
135
- c.group_id AS group_id,
136
- c.name AS name,
137
- c.created_at AS created_at,
138
- c.summary AS summary
139
- """
140
-
141
126
  records, _, _ = await driver.execute_query(
142
- query,
127
+ """
128
+ MATCH (n:Community)-[:HAS_MEMBER]->(m:Entity)
129
+ WHERE m.uuid IN $uuids
130
+ RETURN DISTINCT
131
+ """
132
+ + COMMUNITY_NODE_RETURN,
143
133
  uuids=node_uuids,
144
134
  routing_='r',
145
135
  )
@@ -164,38 +154,30 @@ async def edge_fulltext_search(
164
154
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
165
155
 
166
156
  query = (
167
- get_relationships_query('edge_name_and_fact', db_type=driver.provider)
157
+ get_relationships_query('edge_name_and_fact', provider=driver.provider)
168
158
  + """
169
159
  YIELD relationship AS rel, score
170
- MATCH (n:Entity)-[r:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
171
- WHERE r.group_id IN $group_ids """
160
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
161
+ WHERE e.group_id IN $group_ids """
172
162
  + filter_query
173
163
  + """
174
- WITH r, score, startNode(r) AS n, endNode(r) AS m
164
+ WITH e, score, n, m
175
165
  RETURN
176
- r.uuid AS uuid,
177
- r.group_id AS group_id,
178
- n.uuid AS source_node_uuid,
179
- m.uuid AS target_node_uuid,
180
- r.created_at AS created_at,
181
- r.name AS name,
182
- r.fact AS fact,
183
- r.episodes AS episodes,
184
- r.expired_at AS expired_at,
185
- r.valid_at AS valid_at,
186
- r.invalid_at AS invalid_at,
187
- properties(r) AS attributes
188
- ORDER BY score DESC LIMIT $limit
166
+ """
167
+ + ENTITY_EDGE_RETURN
168
+ + """
169
+ ORDER BY score DESC
170
+ LIMIT $limit
189
171
  """
190
172
  )
191
173
 
192
174
  records, _, _ = await driver.execute_query(
193
175
  query,
194
- params=filter_params,
195
176
  query=fuzzy_query,
196
177
  group_ids=group_ids,
197
178
  limit=limit,
198
179
  routing_='r',
180
+ **filter_params,
199
181
  )
200
182
 
201
183
  edges = [get_entity_edge_from_record(record) for record in records]
@@ -219,58 +201,47 @@ async def edge_similarity_search(
219
201
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
220
202
  query_params.update(filter_params)
221
203
 
222
- group_filter_query: LiteralString = 'WHERE r.group_id IS NOT NULL'
204
+ group_filter_query: LiteralString = 'WHERE e.group_id IS NOT NULL'
223
205
  if group_ids is not None:
224
- group_filter_query += '\nAND r.group_id IN $group_ids'
206
+ group_filter_query += '\nAND e.group_id IN $group_ids'
225
207
  query_params['group_ids'] = group_ids
226
- query_params['source_node_uuid'] = source_node_uuid
227
- query_params['target_node_uuid'] = target_node_uuid
228
208
 
229
209
  if source_node_uuid is not None:
230
- group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
210
+ query_params['source_uuid'] = source_node_uuid
211
+ group_filter_query += '\nAND (n.uuid = $source_uuid)'
231
212
 
232
213
  if target_node_uuid is not None:
233
- group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
214
+ query_params['target_uuid'] = target_node_uuid
215
+ group_filter_query += '\nAND (m.uuid = $target_uuid)'
234
216
 
235
217
  query = (
236
218
  RUNTIME_QUERY
237
219
  + """
238
- MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
220
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
239
221
  """
240
222
  + group_filter_query
241
223
  + filter_query
242
224
  + """
243
- WITH DISTINCT r, """
244
- + get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
225
+ WITH DISTINCT e, n, m, """
226
+ + get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider)
245
227
  + """ AS score
246
228
  WHERE score > $min_score
247
229
  RETURN
248
- r.uuid AS uuid,
249
- r.group_id AS group_id,
250
- startNode(r).uuid AS source_node_uuid,
251
- endNode(r).uuid AS target_node_uuid,
252
- r.created_at AS created_at,
253
- r.name AS name,
254
- r.fact AS fact,
255
- r.episodes AS episodes,
256
- r.expired_at AS expired_at,
257
- r.valid_at AS valid_at,
258
- r.invalid_at AS invalid_at,
259
- properties(r) AS attributes
230
+ """
231
+ + ENTITY_EDGE_RETURN
232
+ + """
260
233
  ORDER BY score DESC
261
234
  LIMIT $limit
262
235
  """
263
236
  )
264
- records, header, _ = await driver.execute_query(
237
+
238
+ records, _, _ = await driver.execute_query(
265
239
  query,
266
- params=query_params,
267
240
  search_vector=search_vector,
268
- source_uuid=source_node_uuid,
269
- target_uuid=target_node_uuid,
270
- group_ids=group_ids,
271
241
  limit=limit,
272
242
  min_score=min_score,
273
243
  routing_='r',
244
+ **query_params,
274
245
  )
275
246
 
276
247
  edges = [get_entity_edge_from_record(record) for record in records]
@@ -293,41 +264,31 @@ async def edge_bfs_search(
293
264
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
294
265
 
295
266
  query = (
296
- """
267
+ f"""
297
268
  UNWIND $bfs_origin_node_uuids AS origin_uuid
298
- MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
269
+ MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
299
270
  UNWIND relationships(path) AS rel
300
- MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
301
- WHERE r.uuid = rel.uuid
302
- AND r.group_id IN $group_ids
271
+ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
272
+ WHERE e.uuid = rel.uuid
273
+ AND e.group_id IN $group_ids
303
274
  """
304
275
  + filter_query
305
- + """
306
- RETURN DISTINCT
307
- r.uuid AS uuid,
308
- r.group_id AS group_id,
309
- startNode(r).uuid AS source_node_uuid,
310
- endNode(r).uuid AS target_node_uuid,
311
- r.created_at AS created_at,
312
- r.name AS name,
313
- r.fact AS fact,
314
- r.episodes AS episodes,
315
- r.expired_at AS expired_at,
316
- r.valid_at AS valid_at,
317
- r.invalid_at AS invalid_at,
318
- properties(r) AS attributes
319
- LIMIT $limit
276
+ + """
277
+ RETURN DISTINCT
278
+ """
279
+ + ENTITY_EDGE_RETURN
280
+ + """
281
+ LIMIT $limit
320
282
  """
321
283
  )
322
284
 
323
285
  records, _, _ = await driver.execute_query(
324
286
  query,
325
- params=filter_params,
326
287
  bfs_origin_node_uuids=bfs_origin_node_uuids,
327
- depth=bfs_max_depth,
328
288
  group_ids=group_ids,
329
289
  limit=limit,
330
290
  routing_='r',
291
+ **filter_params,
331
292
  )
332
293
 
333
294
  edges = [get_entity_edge_from_record(record) for record in records]
@@ -352,23 +313,27 @@ async def node_fulltext_search(
352
313
  get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
353
314
  + """
354
315
  YIELD node AS n, score
355
- WITH n, score
356
- LIMIT $limit
357
- WHERE n:Entity AND n.group_id IN $group_ids
316
+ WHERE n:Entity AND n.group_id IN $group_ids
317
+ WITH n, score
318
+ LIMIT $limit
358
319
  """
359
320
  + filter_query
321
+ + """
322
+ RETURN
323
+ """
360
324
  + ENTITY_NODE_RETURN
361
325
  + """
362
326
  ORDER BY score DESC
363
327
  """
364
328
  )
365
- records, header, _ = await driver.execute_query(
329
+
330
+ records, _, _ = await driver.execute_query(
366
331
  query,
367
- params=filter_params,
368
332
  query=fuzzy_query,
369
333
  group_ids=group_ids,
370
334
  limit=limit,
371
335
  routing_='r',
336
+ **filter_params,
372
337
  )
373
338
 
374
339
  nodes = [get_entity_node_from_record(record) for record in records]
@@ -406,22 +371,23 @@ async def node_similarity_search(
406
371
  WITH n, """
407
372
  + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
408
373
  + """ AS score
409
- WHERE score > $min_score"""
374
+ WHERE score > $min_score
375
+ RETURN
376
+ """
410
377
  + ENTITY_NODE_RETURN
411
378
  + """
412
379
  ORDER BY score DESC
413
380
  LIMIT $limit
414
- """
381
+ """
415
382
  )
416
383
 
417
- records, header, _ = await driver.execute_query(
384
+ records, _, _ = await driver.execute_query(
418
385
  query,
419
- params=query_params,
420
386
  search_vector=search_vector,
421
- group_ids=group_ids,
422
387
  limit=limit,
423
388
  min_score=min_score,
424
389
  routing_='r',
390
+ **query_params,
425
391
  )
426
392
 
427
393
  nodes = [get_entity_node_from_record(record) for record in records]
@@ -444,26 +410,29 @@ async def node_bfs_search(
444
410
  filter_query, filter_params = node_search_filter_query_constructor(search_filter)
445
411
 
446
412
  query = (
447
- """
413
+ f"""
448
414
  UNWIND $bfs_origin_node_uuids AS origin_uuid
449
- MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
415
+ MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
450
416
  WHERE n.group_id = origin.group_id
451
417
  AND origin.group_id IN $group_ids
452
418
  """
453
419
  + filter_query
420
+ + """
421
+ RETURN
422
+ """
454
423
  + ENTITY_NODE_RETURN
455
424
  + """
456
425
  LIMIT $limit
457
426
  """
458
427
  )
428
+
459
429
  records, _, _ = await driver.execute_query(
460
430
  query,
461
- params=filter_params,
462
431
  bfs_origin_node_uuids=bfs_origin_node_uuids,
463
- depth=bfs_max_depth,
464
432
  group_ids=group_ids,
465
433
  limit=limit,
466
434
  routing_='r',
435
+ **filter_params,
467
436
  )
468
437
  nodes = [get_entity_node_from_record(record) for record in records]
469
438
 
@@ -489,16 +458,10 @@ async def episode_fulltext_search(
489
458
  MATCH (e:Episodic)
490
459
  WHERE e.uuid = episode.uuid
491
460
  AND e.group_id IN $group_ids
492
- RETURN
493
- e.content AS content,
494
- e.created_at AS created_at,
495
- e.valid_at AS valid_at,
496
- e.uuid AS uuid,
497
- e.name AS name,
498
- e.group_id AS group_id,
499
- e.source_description AS source_description,
500
- e.source AS source,
501
- e.entity_edges AS entity_edges
461
+ RETURN
462
+ """
463
+ + EPISODIC_NODE_RETURN
464
+ + """
502
465
  ORDER BY score DESC
503
466
  LIMIT $limit
504
467
  """
@@ -530,15 +493,12 @@ async def community_fulltext_search(
530
493
  query = (
531
494
  get_nodes_query(driver.provider, 'community_name', '$query')
532
495
  + """
533
- YIELD node AS comm, score
534
- WHERE comm.group_id IN $group_ids
496
+ YIELD node AS n, score
497
+ WHERE n.group_id IN $group_ids
535
498
  RETURN
536
- comm.uuid AS uuid,
537
- comm.group_id AS group_id,
538
- comm.name AS name,
539
- comm.created_at AS created_at,
540
- comm.summary AS summary,
541
- comm.name_embedding AS name_embedding
499
+ """
500
+ + COMMUNITY_NODE_RETURN
501
+ + """
542
502
  ORDER BY score DESC
543
503
  LIMIT $limit
544
504
  """
@@ -568,39 +528,37 @@ async def community_similarity_search(
568
528
 
569
529
  group_filter_query: LiteralString = ''
570
530
  if group_ids is not None:
571
- group_filter_query += 'WHERE comm.group_id IN $group_ids'
531
+ group_filter_query += 'WHERE n.group_id IN $group_ids'
572
532
  query_params['group_ids'] = group_ids
573
533
 
574
534
  query = (
575
535
  RUNTIME_QUERY
576
536
  + """
577
- MATCH (comm:Community)
578
- """
537
+ MATCH (n:Community)
538
+ """
579
539
  + group_filter_query
580
540
  + """
581
- WITH comm, """
582
- + get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
541
+ WITH n,
542
+ """
543
+ + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
583
544
  + """ AS score
584
- WHERE score > $min_score
585
- RETURN
586
- comm.uuid As uuid,
587
- comm.group_id AS group_id,
588
- comm.name AS name,
589
- comm.created_at AS created_at,
590
- comm.summary AS summary,
591
- comm.name_embedding AS name_embedding
592
- ORDER BY score DESC
593
- LIMIT $limit
545
+ WHERE score > $min_score
546
+ RETURN
547
+ """
548
+ + COMMUNITY_NODE_RETURN
549
+ + """
550
+ ORDER BY score DESC
551
+ LIMIT $limit
594
552
  """
595
553
  )
596
554
 
597
555
  records, _, _ = await driver.execute_query(
598
556
  query,
599
557
  search_vector=search_vector,
600
- group_ids=group_ids,
601
558
  limit=limit,
602
559
  min_score=min_score,
603
560
  routing_='r',
561
+ **query_params,
604
562
  )
605
563
  communities = [get_community_node_from_record(record) for record in records]
606
564
 
@@ -672,7 +630,7 @@ async def hybrid_node_search(
672
630
  }
673
631
  result_uuids = [[node.uuid for node in result] for result in results]
674
632
 
675
- ranked_uuids = rrf(result_uuids)
633
+ ranked_uuids, _ = rrf(result_uuids)
676
634
 
677
635
  relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
678
636
 
@@ -719,8 +677,8 @@ async def get_relevant_nodes(
719
677
  WHERE m.group_id = $group_id
720
678
  WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
721
679
 
722
- WITH node,
723
- top_vector_nodes,
680
+ WITH node,
681
+ top_vector_nodes,
724
682
  [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
725
683
 
726
684
  WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
@@ -728,10 +686,10 @@ async def get_relevant_nodes(
728
686
  UNWIND combined_nodes AS combined_node
729
687
  WITH node, collect(DISTINCT combined_node) AS deduped_nodes
730
688
 
731
- RETURN
689
+ RETURN
732
690
  node.uuid AS search_node_uuid,
733
691
  [x IN deduped_nodes | {
734
- uuid: x.uuid,
692
+ uuid: x.uuid,
735
693
  name: x.name,
736
694
  name_embedding: x.name_embedding,
737
695
  group_id: x.group_id,
@@ -755,12 +713,12 @@ async def get_relevant_nodes(
755
713
 
756
714
  results, _, _ = await driver.execute_query(
757
715
  query,
758
- params=query_params,
759
716
  nodes=query_nodes,
760
717
  group_id=group_id,
761
718
  limit=limit,
762
719
  min_score=min_score,
763
720
  routing_='r',
721
+ **query_params,
764
722
  )
765
723
 
766
724
  relevant_nodes_dict: dict[str, list[EntityNode]] = {
@@ -825,11 +783,11 @@ async def get_relevant_edges(
825
783
 
826
784
  results, _, _ = await driver.execute_query(
827
785
  query,
828
- params=query_params,
829
786
  edges=[edge.model_dump() for edge in edges],
830
787
  limit=limit,
831
788
  min_score=min_score,
832
789
  routing_='r',
790
+ **query_params,
833
791
  )
834
792
 
835
793
  relevant_edges_dict: dict[str, list[EntityEdge]] = {
@@ -895,11 +853,11 @@ async def get_edge_invalidation_candidates(
895
853
 
896
854
  results, _, _ = await driver.execute_query(
897
855
  query,
898
- params=query_params,
899
856
  edges=[edge.model_dump() for edge in edges],
900
857
  limit=limit,
901
858
  min_score=min_score,
902
859
  routing_='r',
860
+ **query_params,
903
861
  )
904
862
  invalidation_edges_dict: dict[str, list[EntityEdge]] = {
905
863
  result['search_edge_uuid']: [
@@ -914,7 +872,9 @@ async def get_edge_invalidation_candidates(
914
872
 
915
873
 
916
874
  # takes in a list of rankings of uuids
917
- def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
875
+ def rrf(
876
+ results: list[list[str]], rank_const=1, min_score: float = 0
877
+ ) -> tuple[list[str], list[float]]:
918
878
  scores: dict[str, float] = defaultdict(float)
919
879
  for result in results:
920
880
  for i, uuid in enumerate(result):
@@ -925,7 +885,9 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
925
885
 
926
886
  sorted_uuids = [term[0] for term in scored_uuids]
927
887
 
928
- return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
888
+ return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
889
+ scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
890
+ ]
929
891
 
930
892
 
931
893
  async def node_distance_reranker(
@@ -933,24 +895,23 @@ async def node_distance_reranker(
933
895
  node_uuids: list[str],
934
896
  center_node_uuid: str,
935
897
  min_score: float = 0,
936
- ) -> list[str]:
898
+ ) -> tuple[list[str], list[float]]:
937
899
  # filter out node_uuid center node node uuid
938
900
  filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
939
901
  scores: dict[str, float] = {center_node_uuid: 0.0}
940
902
 
941
903
  # Find the shortest path to center node
942
- query = """
904
+ results, header, _ = await driver.execute_query(
905
+ """
943
906
  UNWIND $node_uuids AS node_uuid
944
907
  MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
945
908
  RETURN 1 AS score, node_uuid AS uuid
946
- """
947
- results, header, _ = await driver.execute_query(
948
- query,
909
+ """,
949
910
  node_uuids=filtered_uuids,
950
911
  center_uuid=center_node_uuid,
951
912
  routing_='r',
952
913
  )
953
- if driver.provider == 'falkordb':
914
+ if driver.provider == GraphProvider.FALKORDB:
954
915
  results = [dict(zip(header, row, strict=True)) for row in results]
955
916
 
956
917
  for result in results:
@@ -970,24 +931,25 @@ async def node_distance_reranker(
970
931
  scores[center_node_uuid] = 0.1
971
932
  filtered_uuids = [center_node_uuid] + filtered_uuids
972
933
 
973
- return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score]
934
+ return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score], [
935
+ 1 / scores[uuid] for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score
936
+ ]
974
937
 
975
938
 
976
939
  async def episode_mentions_reranker(
977
940
  driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
978
- ) -> list[str]:
941
+ ) -> tuple[list[str], list[float]]:
979
942
  # use rrf as a preliminary ranker
980
- sorted_uuids = rrf(node_uuids)
943
+ sorted_uuids, _ = rrf(node_uuids)
981
944
  scores: dict[str, float] = {}
982
945
 
983
946
  # Find the shortest path to center node
984
- query = """
985
- UNWIND $node_uuids AS node_uuid
947
+ results, _, _ = await driver.execute_query(
948
+ """
949
+ UNWIND $node_uuids AS node_uuid
986
950
  MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
987
951
  RETURN count(*) AS score, n.uuid AS uuid
988
- """
989
- results, _, _ = await driver.execute_query(
990
- query,
952
+ """,
991
953
  node_uuids=sorted_uuids,
992
954
  routing_='r',
993
955
  )
@@ -998,7 +960,9 @@ async def episode_mentions_reranker(
998
960
  # rerank on shortest distance
999
961
  sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
1000
962
 
1001
- return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
963
+ return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
964
+ scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
965
+ ]
1002
966
 
1003
967
 
1004
968
  def maximal_marginal_relevance(
@@ -1006,7 +970,7 @@ def maximal_marginal_relevance(
1006
970
  candidates: dict[str, list[float]],
1007
971
  mmr_lambda: float = DEFAULT_MMR_LAMBDA,
1008
972
  min_score: float = -2.0,
1009
- ) -> list[str]:
973
+ ) -> tuple[list[str], list[float]]:
1010
974
  start = time()
1011
975
  query_array = np.array(query_vector)
1012
976
  candidate_arrays: dict[str, NDArray] = {}
@@ -1037,21 +1001,24 @@ def maximal_marginal_relevance(
1037
1001
  end = time()
1038
1002
  logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
1039
1003
 
1040
- return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
1004
+ return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score], [
1005
+ mmr_scores[uuid] for uuid in uuids if mmr_scores[uuid] >= min_score
1006
+ ]
1041
1007
 
1042
1008
 
1043
1009
  async def get_embeddings_for_nodes(
1044
1010
  driver: GraphDriver, nodes: list[EntityNode]
1045
1011
  ) -> dict[str, list[float]]:
1046
- query: LiteralString = """MATCH (n:Entity)
1047
- WHERE n.uuid IN $node_uuids
1048
- RETURN DISTINCT
1049
- n.uuid AS uuid,
1050
- n.name_embedding AS name_embedding
1051
- """
1052
-
1053
1012
  results, _, _ = await driver.execute_query(
1054
- query, node_uuids=[node.uuid for node in nodes], routing_='r'
1013
+ """
1014
+ MATCH (n:Entity)
1015
+ WHERE n.uuid IN $node_uuids
1016
+ RETURN DISTINCT
1017
+ n.uuid AS uuid,
1018
+ n.name_embedding AS name_embedding
1019
+ """,
1020
+ node_uuids=[node.uuid for node in nodes],
1021
+ routing_='r',
1055
1022
  )
1056
1023
 
1057
1024
  embeddings_dict: dict[str, list[float]] = {}
@@ -1067,15 +1034,14 @@ async def get_embeddings_for_nodes(
1067
1034
  async def get_embeddings_for_communities(
1068
1035
  driver: GraphDriver, communities: list[CommunityNode]
1069
1036
  ) -> dict[str, list[float]]:
1070
- query: LiteralString = """MATCH (c:Community)
1071
- WHERE c.uuid IN $community_uuids
1072
- RETURN DISTINCT
1073
- c.uuid AS uuid,
1074
- c.name_embedding AS name_embedding
1075
- """
1076
-
1077
1037
  results, _, _ = await driver.execute_query(
1078
- query,
1038
+ """
1039
+ MATCH (c:Community)
1040
+ WHERE c.uuid IN $community_uuids
1041
+ RETURN DISTINCT
1042
+ c.uuid AS uuid,
1043
+ c.name_embedding AS name_embedding
1044
+ """,
1079
1045
  community_uuids=[community.uuid for community in communities],
1080
1046
  routing_='r',
1081
1047
  )
@@ -1093,15 +1059,14 @@ async def get_embeddings_for_communities(
1093
1059
  async def get_embeddings_for_edges(
1094
1060
  driver: GraphDriver, edges: list[EntityEdge]
1095
1061
  ) -> dict[str, list[float]]:
1096
- query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1097
- WHERE e.uuid IN $edge_uuids
1098
- RETURN DISTINCT
1099
- e.uuid AS uuid,
1100
- e.fact_embedding AS fact_embedding
1101
- """
1102
-
1103
1062
  results, _, _ = await driver.execute_query(
1104
- query,
1063
+ """
1064
+ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1065
+ WHERE e.uuid IN $edge_uuids
1066
+ RETURN DISTINCT
1067
+ e.uuid AS uuid,
1068
+ e.fact_embedding AS fact_embedding
1069
+ """,
1105
1070
  edge_uuids=[edge.uuid for edge in edges],
1106
1071
  routing_='r',
1107
1072
  )