graphiti-core 0.12.0rc5__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

@@ -20,11 +20,16 @@ from time import time
20
20
  from typing import Any
21
21
 
22
22
  import numpy as np
23
- from neo4j import AsyncDriver, Query
24
23
  from numpy._typing import NDArray
25
24
  from typing_extensions import LiteralString
26
25
 
26
+ from graphiti_core.driver.driver import GraphDriver
27
27
  from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
28
+ from graphiti_core.graph_queries import (
29
+ get_nodes_query,
30
+ get_relationships_query,
31
+ get_vector_cosine_func_query,
32
+ )
28
33
  from graphiti_core.helpers import (
29
34
  DEFAULT_DATABASE,
30
35
  RUNTIME_QUERY,
@@ -58,7 +63,7 @@ MAX_QUERY_LENGTH = 32
58
63
 
59
64
  def fulltext_query(query: str, group_ids: list[str] | None = None):
60
65
  group_ids_filter_list = (
61
- [f'group_id:"{lucene_sanitize(g)}"' for g in group_ids] if group_ids is not None else []
66
+ [f"group_id-'{lucene_sanitize(g)}'" for g in group_ids] if group_ids is not None else []
62
67
  )
63
68
  group_ids_filter = ''
64
69
  for f in group_ids_filter_list:
@@ -77,7 +82,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
77
82
 
78
83
 
79
84
  async def get_episodes_by_mentions(
80
- driver: AsyncDriver,
85
+ driver: GraphDriver,
81
86
  nodes: list[EntityNode],
82
87
  edges: list[EntityEdge],
83
88
  limit: int = RELEVANT_SCHEMA_LIMIT,
@@ -92,11 +97,11 @@ async def get_episodes_by_mentions(
92
97
 
93
98
 
94
99
  async def get_mentioned_nodes(
95
- driver: AsyncDriver, episodes: list[EpisodicNode]
100
+ driver: GraphDriver, episodes: list[EpisodicNode]
96
101
  ) -> list[EntityNode]:
97
102
  episode_uuids = [episode.uuid for episode in episodes]
98
- records, _, _ = await driver.execute_query(
99
- """
103
+
104
+ query = """
100
105
  MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
101
106
  RETURN DISTINCT
102
107
  n.uuid As uuid,
@@ -106,7 +111,10 @@ async def get_mentioned_nodes(
106
111
  n.summary AS summary,
107
112
  labels(n) AS labels,
108
113
  properties(n) AS attributes
109
- """,
114
+ """
115
+
116
+ records, _, _ = await driver.execute_query(
117
+ query,
110
118
  uuids=episode_uuids,
111
119
  database_=DEFAULT_DATABASE,
112
120
  routing_='r',
@@ -118,11 +126,11 @@ async def get_mentioned_nodes(
118
126
 
119
127
 
120
128
  async def get_communities_by_nodes(
121
- driver: AsyncDriver, nodes: list[EntityNode]
129
+ driver: GraphDriver, nodes: list[EntityNode]
122
130
  ) -> list[CommunityNode]:
123
131
  node_uuids = [node.uuid for node in nodes]
124
- records, _, _ = await driver.execute_query(
125
- """
132
+
133
+ query = """
126
134
  MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
127
135
  RETURN DISTINCT
128
136
  c.uuid As uuid,
@@ -130,7 +138,10 @@ async def get_communities_by_nodes(
130
138
  c.name AS name,
131
139
  c.created_at AS created_at,
132
140
  c.summary AS summary
133
- """,
141
+ """
142
+
143
+ records, _, _ = await driver.execute_query(
144
+ query,
134
145
  uuids=node_uuids,
135
146
  database_=DEFAULT_DATABASE,
136
147
  routing_='r',
@@ -142,7 +153,7 @@ async def get_communities_by_nodes(
142
153
 
143
154
 
144
155
  async def edge_fulltext_search(
145
- driver: AsyncDriver,
156
+ driver: GraphDriver,
146
157
  query: str,
147
158
  search_filter: SearchFilters,
148
159
  group_ids: list[str] | None = None,
@@ -155,34 +166,35 @@ async def edge_fulltext_search(
155
166
 
156
167
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
157
168
 
158
- cypher_query = Query(
159
- """
160
- CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit})
161
- YIELD relationship AS rel, score
162
- MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
163
- WHERE r.group_id IN $group_ids"""
169
+ query = (
170
+ get_relationships_query('edge_name_and_fact', db_type=driver.provider)
171
+ + """
172
+ YIELD relationship AS rel, score
173
+ MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
174
+ WHERE r.group_id IN $group_ids """
164
175
  + filter_query
165
- + """\nWITH r, score, startNode(r) AS n, endNode(r) AS m
166
- RETURN
167
- r.uuid AS uuid,
168
- r.group_id AS group_id,
169
- n.uuid AS source_node_uuid,
170
- m.uuid AS target_node_uuid,
171
- r.created_at AS created_at,
172
- r.name AS name,
173
- r.fact AS fact,
174
- r.episodes AS episodes,
175
- r.expired_at AS expired_at,
176
- r.valid_at AS valid_at,
177
- r.invalid_at AS invalid_at,
178
- properties(r) AS attributes
179
- ORDER BY score DESC LIMIT $limit
180
- """
176
+ + """
177
+ WITH r, score, startNode(r) AS n, endNode(r) AS m
178
+ RETURN
179
+ r.uuid AS uuid,
180
+ r.group_id AS group_id,
181
+ n.uuid AS source_node_uuid,
182
+ m.uuid AS target_node_uuid,
183
+ r.created_at AS created_at,
184
+ r.name AS name,
185
+ r.fact AS fact,
186
+ r.episodes AS episodes,
187
+ r.expired_at AS expired_at,
188
+ r.valid_at AS valid_at,
189
+ r.invalid_at AS invalid_at,
190
+ properties(r) AS attributes
191
+ ORDER BY score DESC LIMIT $limit
192
+ """
181
193
  )
182
194
 
183
195
  records, _, _ = await driver.execute_query(
184
- cypher_query,
185
- filter_params,
196
+ query,
197
+ params=filter_params,
186
198
  query=fuzzy_query,
187
199
  group_ids=group_ids,
188
200
  limit=limit,
@@ -196,7 +208,7 @@ async def edge_fulltext_search(
196
208
 
197
209
 
198
210
  async def edge_similarity_search(
199
- driver: AsyncDriver,
211
+ driver: GraphDriver,
200
212
  search_vector: list[float],
201
213
  source_node_uuid: str | None,
202
214
  target_node_uuid: str | None,
@@ -224,36 +236,38 @@ async def edge_similarity_search(
224
236
  if target_node_uuid is not None:
225
237
  group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
226
238
 
227
- query: LiteralString = (
239
+ query = (
228
240
  RUNTIME_QUERY
229
241
  + """
230
242
  MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
231
- """
243
+ """
232
244
  + group_filter_query
233
245
  + filter_query
234
- + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
235
- WHERE score > $min_score
236
- RETURN
237
- r.uuid AS uuid,
238
- r.group_id AS group_id,
239
- startNode(r).uuid AS source_node_uuid,
240
- endNode(r).uuid AS target_node_uuid,
241
- r.created_at AS created_at,
242
- r.name AS name,
243
- r.fact AS fact,
244
- r.episodes AS episodes,
245
- r.expired_at AS expired_at,
246
- r.valid_at AS valid_at,
247
- r.invalid_at AS invalid_at,
248
- properties(r) AS attributes
249
- ORDER BY score DESC
250
- LIMIT $limit
246
+ + """
247
+ WITH DISTINCT r, """
248
+ + get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
249
+ + """ AS score
250
+ WHERE score > $min_score
251
+ RETURN
252
+ r.uuid AS uuid,
253
+ r.group_id AS group_id,
254
+ startNode(r).uuid AS source_node_uuid,
255
+ endNode(r).uuid AS target_node_uuid,
256
+ r.created_at AS created_at,
257
+ r.name AS name,
258
+ r.fact AS fact,
259
+ r.episodes AS episodes,
260
+ r.expired_at AS expired_at,
261
+ r.valid_at AS valid_at,
262
+ r.invalid_at AS invalid_at,
263
+ properties(r) AS attributes
264
+ ORDER BY score DESC
265
+ LIMIT $limit
251
266
  """
252
267
  )
253
-
254
- records, _, _ = await driver.execute_query(
268
+ records, header, _ = await driver.execute_query(
255
269
  query,
256
- query_params,
270
+ params=query_params,
257
271
  search_vector=search_vector,
258
272
  source_uuid=source_node_uuid,
259
273
  target_uuid=target_node_uuid,
@@ -264,13 +278,16 @@ async def edge_similarity_search(
264
278
  routing_='r',
265
279
  )
266
280
 
281
+ if driver.provider == 'falkordb':
282
+ records = [dict(zip(header, row, strict=True)) for row in records]
283
+
267
284
  edges = [get_entity_edge_from_record(record) for record in records]
268
285
 
269
286
  return edges
270
287
 
271
288
 
272
289
  async def edge_bfs_search(
273
- driver: AsyncDriver,
290
+ driver: GraphDriver,
274
291
  bfs_origin_node_uuids: list[str] | None,
275
292
  bfs_max_depth: int,
276
293
  search_filter: SearchFilters,
@@ -282,14 +299,14 @@ async def edge_bfs_search(
282
299
 
283
300
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
284
301
 
285
- query = Query(
302
+ query = (
286
303
  """
287
- UNWIND $bfs_origin_node_uuids AS origin_uuid
288
- MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
289
- UNWIND relationships(path) AS rel
290
- MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
291
- WHERE r.uuid = rel.uuid
292
- """
304
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
305
+ MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
306
+ UNWIND relationships(path) AS rel
307
+ MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
308
+ WHERE r.uuid = rel.uuid
309
+ """
293
310
  + filter_query
294
311
  + """
295
312
  RETURN DISTINCT
@@ -311,7 +328,7 @@ async def edge_bfs_search(
311
328
 
312
329
  records, _, _ = await driver.execute_query(
313
330
  query,
314
- filter_params,
331
+ params=filter_params,
315
332
  bfs_origin_node_uuids=bfs_origin_node_uuids,
316
333
  depth=bfs_max_depth,
317
334
  limit=limit,
@@ -325,7 +342,7 @@ async def edge_bfs_search(
325
342
 
326
343
 
327
344
  async def node_fulltext_search(
328
- driver: AsyncDriver,
345
+ driver: GraphDriver,
329
346
  query: str,
330
347
  search_filter: SearchFilters,
331
348
  group_ids: list[str] | None = None,
@@ -335,38 +352,41 @@ async def node_fulltext_search(
335
352
  fuzzy_query = fulltext_query(query, group_ids)
336
353
  if fuzzy_query == '':
337
354
  return []
338
-
339
355
  filter_query, filter_params = node_search_filter_query_constructor(search_filter)
340
356
 
341
357
  query = (
358
+ get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
359
+ + """
360
+ YIELD node AS n, score
361
+ WITH n, score
362
+ LIMIT $limit
363
+ WHERE n:Entity
342
364
  """
343
- CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
344
- YIELD node AS n, score
345
- WHERE n:Entity
346
- """
347
365
  + filter_query
348
366
  + ENTITY_NODE_RETURN
349
367
  + """
350
368
  ORDER BY score DESC
351
369
  """
352
370
  )
353
-
354
- records, _, _ = await driver.execute_query(
371
+ records, header, _ = await driver.execute_query(
355
372
  query,
356
- filter_params,
373
+ params=filter_params,
357
374
  query=fuzzy_query,
358
375
  group_ids=group_ids,
359
376
  limit=limit,
360
377
  database_=DEFAULT_DATABASE,
361
378
  routing_='r',
362
379
  )
380
+ if driver.provider == 'falkordb':
381
+ records = [dict(zip(header, row, strict=True)) for row in records]
382
+
363
383
  nodes = [get_entity_node_from_record(record) for record in records]
364
384
 
365
385
  return nodes
366
386
 
367
387
 
368
388
  async def node_similarity_search(
369
- driver: AsyncDriver,
389
+ driver: GraphDriver,
370
390
  search_vector: list[float],
371
391
  search_filter: SearchFilters,
372
392
  group_ids: list[str] | None = None,
@@ -384,22 +404,28 @@ async def node_similarity_search(
384
404
  filter_query, filter_params = node_search_filter_query_constructor(search_filter)
385
405
  query_params.update(filter_params)
386
406
 
387
- records, _, _ = await driver.execute_query(
407
+ query = (
388
408
  RUNTIME_QUERY
389
409
  + """
390
- MATCH (n:Entity)
391
- """
410
+ MATCH (n:Entity)
411
+ """
392
412
  + group_filter_query
393
413
  + filter_query
394
414
  + """
395
- WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
396
- WHERE score > $min_score"""
415
+ WITH n, """
416
+ + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
417
+ + """ AS score
418
+ WHERE score > $min_score"""
397
419
  + ENTITY_NODE_RETURN
398
420
  + """
399
421
  ORDER BY score DESC
400
422
  LIMIT $limit
401
- """,
402
- query_params,
423
+ """
424
+ )
425
+
426
+ records, header, _ = await driver.execute_query(
427
+ query,
428
+ params=query_params,
403
429
  search_vector=search_vector,
404
430
  group_ids=group_ids,
405
431
  limit=limit,
@@ -407,13 +433,15 @@ async def node_similarity_search(
407
433
  database_=DEFAULT_DATABASE,
408
434
  routing_='r',
409
435
  )
436
+ if driver.provider == 'falkordb':
437
+ records = [dict(zip(header, row, strict=True)) for row in records]
410
438
  nodes = [get_entity_node_from_record(record) for record in records]
411
439
 
412
440
  return nodes
413
441
 
414
442
 
415
443
  async def node_bfs_search(
416
- driver: AsyncDriver,
444
+ driver: GraphDriver,
417
445
  bfs_origin_node_uuids: list[str] | None,
418
446
  search_filter: SearchFilters,
419
447
  bfs_max_depth: int,
@@ -425,18 +453,21 @@ async def node_bfs_search(
425
453
 
426
454
  filter_query, filter_params = node_search_filter_query_constructor(search_filter)
427
455
 
428
- records, _, _ = await driver.execute_query(
456
+ query = (
429
457
  """
430
- UNWIND $bfs_origin_node_uuids AS origin_uuid
431
- MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
432
- WHERE n.group_id = origin.group_id
433
- """
458
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
459
+ MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
460
+ WHERE n.group_id = origin.group_id
461
+ """
434
462
  + filter_query
435
463
  + ENTITY_NODE_RETURN
436
464
  + """
437
465
  LIMIT $limit
438
- """,
439
- filter_params,
466
+ """
467
+ )
468
+ records, _, _ = await driver.execute_query(
469
+ query,
470
+ params=filter_params,
440
471
  bfs_origin_node_uuids=bfs_origin_node_uuids,
441
472
  depth=bfs_max_depth,
442
473
  limit=limit,
@@ -449,7 +480,7 @@ async def node_bfs_search(
449
480
 
450
481
 
451
482
  async def episode_fulltext_search(
452
- driver: AsyncDriver,
483
+ driver: GraphDriver,
453
484
  query: str,
454
485
  _search_filter: SearchFilters,
455
486
  group_ids: list[str] | None = None,
@@ -460,9 +491,9 @@ async def episode_fulltext_search(
460
491
  if fuzzy_query == '':
461
492
  return []
462
493
 
463
- records, _, _ = await driver.execute_query(
464
- """
465
- CALL db.index.fulltext.queryNodes("episode_content", $query, {limit: $limit})
494
+ query = (
495
+ get_nodes_query(driver.provider, 'episode_content', '$query')
496
+ + """
466
497
  YIELD node AS episode, score
467
498
  MATCH (e:Episodic)
468
499
  WHERE e.uuid = episode.uuid
@@ -478,7 +509,11 @@ async def episode_fulltext_search(
478
509
  e.entity_edges AS entity_edges
479
510
  ORDER BY score DESC
480
511
  LIMIT $limit
481
- """,
512
+ """
513
+ )
514
+
515
+ records, _, _ = await driver.execute_query(
516
+ query,
482
517
  query=fuzzy_query,
483
518
  group_ids=group_ids,
484
519
  limit=limit,
@@ -491,7 +526,7 @@ async def episode_fulltext_search(
491
526
 
492
527
 
493
528
  async def community_fulltext_search(
494
- driver: AsyncDriver,
529
+ driver: GraphDriver,
495
530
  query: str,
496
531
  group_ids: list[str] | None = None,
497
532
  limit=RELEVANT_SCHEMA_LIMIT,
@@ -501,9 +536,9 @@ async def community_fulltext_search(
501
536
  if fuzzy_query == '':
502
537
  return []
503
538
 
504
- records, _, _ = await driver.execute_query(
505
- """
506
- CALL db.index.fulltext.queryNodes("community_name", $query, {limit: $limit})
539
+ query = (
540
+ get_nodes_query(driver.provider, 'community_name', '$query')
541
+ + """
507
542
  YIELD node AS comm, score
508
543
  RETURN
509
544
  comm.uuid AS uuid,
@@ -513,7 +548,11 @@ async def community_fulltext_search(
513
548
  comm.summary AS summary
514
549
  ORDER BY score DESC
515
550
  LIMIT $limit
516
- """,
551
+ """
552
+ )
553
+
554
+ records, _, _ = await driver.execute_query(
555
+ query,
517
556
  query=fuzzy_query,
518
557
  group_ids=group_ids,
519
558
  limit=limit,
@@ -526,7 +565,7 @@ async def community_fulltext_search(
526
565
 
527
566
 
528
567
  async def community_similarity_search(
529
- driver: AsyncDriver,
568
+ driver: GraphDriver,
530
569
  search_vector: list[float],
531
570
  group_ids: list[str] | None = None,
532
571
  limit=RELEVANT_SCHEMA_LIMIT,
@@ -540,14 +579,16 @@ async def community_similarity_search(
540
579
  group_filter_query += 'WHERE comm.group_id IN $group_ids'
541
580
  query_params['group_ids'] = group_ids
542
581
 
543
- records, _, _ = await driver.execute_query(
582
+ query = (
544
583
  RUNTIME_QUERY
545
584
  + """
546
585
  MATCH (comm:Community)
547
586
  """
548
587
  + group_filter_query
549
588
  + """
550
- WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
589
+ WITH comm, """
590
+ + get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
591
+ + """ AS score
551
592
  WHERE score > $min_score
552
593
  RETURN
553
594
  comm.uuid As uuid,
@@ -557,7 +598,11 @@ async def community_similarity_search(
557
598
  comm.summary AS summary
558
599
  ORDER BY score DESC
559
600
  LIMIT $limit
560
- """,
601
+ """
602
+ )
603
+
604
+ records, _, _ = await driver.execute_query(
605
+ query,
561
606
  search_vector=search_vector,
562
607
  group_ids=group_ids,
563
608
  limit=limit,
@@ -573,7 +618,7 @@ async def community_similarity_search(
573
618
  async def hybrid_node_search(
574
619
  queries: list[str],
575
620
  embeddings: list[list[float]],
576
- driver: AsyncDriver,
621
+ driver: GraphDriver,
577
622
  search_filter: SearchFilters,
578
623
  group_ids: list[str] | None = None,
579
624
  limit: int = RELEVANT_SCHEMA_LIMIT,
@@ -590,7 +635,7 @@ async def hybrid_node_search(
590
635
  A list of text queries to search for.
591
636
  embeddings : list[list[float]]
592
637
  A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
593
- driver : AsyncDriver
638
+ driver : GraphDriver
594
639
  The Neo4j driver instance for database operations.
595
640
  group_ids : list[str] | None, optional
596
641
  The list of group ids to retrieve nodes from.
@@ -645,7 +690,7 @@ async def hybrid_node_search(
645
690
 
646
691
 
647
692
  async def get_relevant_nodes(
648
- driver: AsyncDriver,
693
+ driver: GraphDriver,
649
694
  nodes: list[EntityNode],
650
695
  search_filter: SearchFilters,
651
696
  min_score: float = DEFAULT_MIN_SCORE,
@@ -664,29 +709,33 @@ async def get_relevant_nodes(
664
709
 
665
710
  query = (
666
711
  RUNTIME_QUERY
667
- + """UNWIND $nodes AS node
668
- MATCH (n:Entity {group_id: $group_id})
669
- """
712
+ + """
713
+ UNWIND $nodes AS node
714
+ MATCH (n:Entity {group_id: $group_id})
715
+ """
670
716
  + filter_query
671
717
  + """
672
- WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
718
+ WITH node, n, """
719
+ + get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
720
+ + """ AS score
673
721
  WHERE score > $min_score
674
722
  WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
675
-
676
- CALL db.index.fulltext.queryNodes("node_name_and_summary", node.fulltext_query, {limit: $limit})
723
+ """
724
+ + get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
725
+ + """
677
726
  YIELD node AS m
678
727
  WHERE m.group_id = $group_id
679
728
  WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
680
-
729
+
681
730
  WITH node,
682
731
  top_vector_nodes,
683
732
  [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
684
-
733
+
685
734
  WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
686
-
735
+
687
736
  UNWIND combined_nodes AS combined_node
688
737
  WITH node, collect(DISTINCT combined_node) AS deduped_nodes
689
-
738
+
690
739
  RETURN
691
740
  node.uuid AS search_node_uuid,
692
741
  [x IN deduped_nodes | {
@@ -714,7 +763,7 @@ async def get_relevant_nodes(
714
763
 
715
764
  results, _, _ = await driver.execute_query(
716
765
  query,
717
- query_params,
766
+ params=query_params,
718
767
  nodes=query_nodes,
719
768
  group_id=group_id,
720
769
  limit=limit,
@@ -736,7 +785,7 @@ async def get_relevant_nodes(
736
785
 
737
786
 
738
787
  async def get_relevant_edges(
739
- driver: AsyncDriver,
788
+ driver: GraphDriver,
740
789
  edges: list[EntityEdge],
741
790
  search_filter: SearchFilters,
742
791
  min_score: float = DEFAULT_MIN_SCORE,
@@ -752,43 +801,47 @@ async def get_relevant_edges(
752
801
 
753
802
  query = (
754
803
  RUNTIME_QUERY
755
- + """UNWIND $edges AS edge
756
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
757
- """
804
+ + """
805
+ UNWIND $edges AS edge
806
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
807
+ """
758
808
  + filter_query
759
809
  + """
760
- WITH e, edge, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
761
- WHERE score > $min_score
762
- WITH edge, e, score
763
- ORDER BY score DESC
764
- RETURN edge.uuid AS search_edge_uuid,
765
- collect({
766
- uuid: e.uuid,
767
- source_node_uuid: startNode(e).uuid,
768
- target_node_uuid: endNode(e).uuid,
769
- created_at: e.created_at,
770
- name: e.name,
771
- group_id: e.group_id,
772
- fact: e.fact,
773
- fact_embedding: e.fact_embedding,
774
- episodes: e.episodes,
775
- expired_at: e.expired_at,
776
- valid_at: e.valid_at,
777
- invalid_at: e.invalid_at,
778
- attributes: properties(e)
779
- })[..$limit] AS matches
810
+ WITH e, edge, """
811
+ + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
812
+ + """ AS score
813
+ WHERE score > $min_score
814
+ WITH edge, e, score
815
+ ORDER BY score DESC
816
+ RETURN edge.uuid AS search_edge_uuid,
817
+ collect({
818
+ uuid: e.uuid,
819
+ source_node_uuid: startNode(e).uuid,
820
+ target_node_uuid: endNode(e).uuid,
821
+ created_at: e.created_at,
822
+ name: e.name,
823
+ group_id: e.group_id,
824
+ fact: e.fact,
825
+ fact_embedding: e.fact_embedding,
826
+ episodes: e.episodes,
827
+ expired_at: e.expired_at,
828
+ valid_at: e.valid_at,
829
+ invalid_at: e.invalid_at,
830
+ attributes: properties(e)
831
+ })[..$limit] AS matches
780
832
  """
781
833
  )
782
834
 
783
835
  results, _, _ = await driver.execute_query(
784
836
  query,
785
- query_params,
837
+ params=query_params,
786
838
  edges=[edge.model_dump() for edge in edges],
787
839
  limit=limit,
788
840
  min_score=min_score,
789
841
  database_=DEFAULT_DATABASE,
790
842
  routing_='r',
791
843
  )
844
+
792
845
  relevant_edges_dict: dict[str, list[EntityEdge]] = {
793
846
  result['search_edge_uuid']: [
794
847
  get_entity_edge_from_record(record) for record in result['matches']
@@ -802,7 +855,7 @@ async def get_relevant_edges(
802
855
 
803
856
 
804
857
  async def get_edge_invalidation_candidates(
805
- driver: AsyncDriver,
858
+ driver: GraphDriver,
806
859
  edges: list[EntityEdge],
807
860
  search_filter: SearchFilters,
808
861
  min_score: float = DEFAULT_MIN_SCORE,
@@ -818,38 +871,41 @@ async def get_edge_invalidation_candidates(
818
871
 
819
872
  query = (
820
873
  RUNTIME_QUERY
821
- + """UNWIND $edges AS edge
822
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
823
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
824
- """
874
+ + """
875
+ UNWIND $edges AS edge
876
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
877
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
878
+ """
825
879
  + filter_query
826
880
  + """
827
- WITH edge, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
828
- WHERE score > $min_score
829
- WITH edge, e, score
830
- ORDER BY score DESC
831
- RETURN edge.uuid AS search_edge_uuid,
832
- collect({
833
- uuid: e.uuid,
834
- source_node_uuid: startNode(e).uuid,
835
- target_node_uuid: endNode(e).uuid,
836
- created_at: e.created_at,
837
- name: e.name,
838
- group_id: e.group_id,
839
- fact: e.fact,
840
- fact_embedding: e.fact_embedding,
841
- episodes: e.episodes,
842
- expired_at: e.expired_at,
843
- valid_at: e.valid_at,
844
- invalid_at: e.invalid_at,
845
- attributes: properties(e)
846
- })[..$limit] AS matches
881
+ WITH edge, e, """
882
+ + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
883
+ + """ AS score
884
+ WHERE score > $min_score
885
+ WITH edge, e, score
886
+ ORDER BY score DESC
887
+ RETURN edge.uuid AS search_edge_uuid,
888
+ collect({
889
+ uuid: e.uuid,
890
+ source_node_uuid: startNode(e).uuid,
891
+ target_node_uuid: endNode(e).uuid,
892
+ created_at: e.created_at,
893
+ name: e.name,
894
+ group_id: e.group_id,
895
+ fact: e.fact,
896
+ fact_embedding: e.fact_embedding,
897
+ episodes: e.episodes,
898
+ expired_at: e.expired_at,
899
+ valid_at: e.valid_at,
900
+ invalid_at: e.invalid_at,
901
+ attributes: properties(e)
902
+ })[..$limit] AS matches
847
903
  """
848
904
  )
849
905
 
850
906
  results, _, _ = await driver.execute_query(
851
907
  query,
852
- query_params,
908
+ params=query_params,
853
909
  edges=[edge.model_dump() for edge in edges],
854
910
  limit=limit,
855
911
  min_score=min_score,
@@ -884,7 +940,7 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
884
940
 
885
941
 
886
942
  async def node_distance_reranker(
887
- driver: AsyncDriver,
943
+ driver: GraphDriver,
888
944
  node_uuids: list[str],
889
945
  center_node_uuid: str,
890
946
  min_score: float = 0,
@@ -894,21 +950,22 @@ async def node_distance_reranker(
894
950
  scores: dict[str, float] = {center_node_uuid: 0.0}
895
951
 
896
952
  # Find the shortest path to center node
897
- query = Query("""
953
+ query = """
898
954
  UNWIND $node_uuids AS node_uuid
899
- MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid})
900
- RETURN length(p) AS score, node_uuid AS uuid
901
- """)
902
-
903
- path_results, _, _ = await driver.execute_query(
955
+ MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
956
+ RETURN 1 AS score, node_uuid AS uuid
957
+ """
958
+ results, header, _ = await driver.execute_query(
904
959
  query,
905
960
  node_uuids=filtered_uuids,
906
961
  center_uuid=center_node_uuid,
907
962
  database_=DEFAULT_DATABASE,
908
963
  routing_='r',
909
964
  )
965
+ if driver.provider == 'falkordb':
966
+ results = [dict(zip(header, row, strict=True)) for row in results]
910
967
 
911
- for result in path_results:
968
+ for result in results:
912
969
  uuid = result['uuid']
913
970
  score = result['score']
914
971
  scores[uuid] = score
@@ -929,19 +986,18 @@ async def node_distance_reranker(
929
986
 
930
987
 
931
988
  async def episode_mentions_reranker(
932
- driver: AsyncDriver, node_uuids: list[list[str]], min_score: float = 0
989
+ driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
933
990
  ) -> list[str]:
934
991
  # use rrf as a preliminary ranker
935
992
  sorted_uuids = rrf(node_uuids)
936
993
  scores: dict[str, float] = {}
937
994
 
938
995
  # Find the shortest path to center node
939
- query = Query("""
996
+ query = """
940
997
  UNWIND $node_uuids AS node_uuid
941
998
  MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
942
999
  RETURN count(*) AS score, n.uuid AS uuid
943
- """)
944
-
1000
+ """
945
1001
  results, _, _ = await driver.execute_query(
946
1002
  query,
947
1003
  node_uuids=sorted_uuids,
@@ -998,7 +1054,7 @@ def maximal_marginal_relevance(
998
1054
 
999
1055
 
1000
1056
  async def get_embeddings_for_nodes(
1001
- driver: AsyncDriver, nodes: list[EntityNode]
1057
+ driver: GraphDriver, nodes: list[EntityNode]
1002
1058
  ) -> dict[str, list[float]]:
1003
1059
  query: LiteralString = """MATCH (n:Entity)
1004
1060
  WHERE n.uuid IN $node_uuids
@@ -1022,7 +1078,7 @@ async def get_embeddings_for_nodes(
1022
1078
 
1023
1079
 
1024
1080
  async def get_embeddings_for_communities(
1025
- driver: AsyncDriver, communities: list[CommunityNode]
1081
+ driver: GraphDriver, communities: list[CommunityNode]
1026
1082
  ) -> dict[str, list[float]]:
1027
1083
  query: LiteralString = """MATCH (c:Community)
1028
1084
  WHERE c.uuid IN $community_uuids
@@ -1049,7 +1105,7 @@ async def get_embeddings_for_communities(
1049
1105
 
1050
1106
 
1051
1107
  async def get_embeddings_for_edges(
1052
- driver: AsyncDriver, edges: list[EntityEdge]
1108
+ driver: GraphDriver, edges: list[EntityEdge]
1053
1109
  ) -> dict[str, list[float]]:
1054
1110
  query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1055
1111
  WHERE e.uuid IN $edge_uuids