graphiti-core 0.11.6rc9__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. More details are available on the original page (the link was not preserved in this text extract).

Files changed (33):
  1. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  2. graphiti_core/driver/__init__.py +17 -0
  3. graphiti_core/driver/driver.py +66 -0
  4. graphiti_core/driver/falkordb_driver.py +132 -0
  5. graphiti_core/driver/neo4j_driver.py +61 -0
  6. graphiti_core/edges.py +66 -40
  7. graphiti_core/embedder/azure_openai.py +64 -0
  8. graphiti_core/embedder/gemini.py +14 -3
  9. graphiti_core/graph_queries.py +149 -0
  10. graphiti_core/graphiti.py +41 -14
  11. graphiti_core/graphiti_types.py +2 -2
  12. graphiti_core/helpers.py +9 -4
  13. graphiti_core/llm_client/__init__.py +16 -0
  14. graphiti_core/llm_client/azure_openai_client.py +73 -0
  15. graphiti_core/llm_client/gemini_client.py +4 -1
  16. graphiti_core/models/edges/edge_db_queries.py +2 -4
  17. graphiti_core/nodes.py +31 -31
  18. graphiti_core/prompts/dedupe_edges.py +52 -1
  19. graphiti_core/prompts/dedupe_nodes.py +79 -4
  20. graphiti_core/prompts/extract_edges.py +50 -5
  21. graphiti_core/prompts/invalidate_edges.py +1 -1
  22. graphiti_core/search/search.py +6 -10
  23. graphiti_core/search/search_filters.py +23 -9
  24. graphiti_core/search/search_utils.py +250 -189
  25. graphiti_core/utils/bulk_utils.py +38 -11
  26. graphiti_core/utils/maintenance/community_operations.py +6 -7
  27. graphiti_core/utils/maintenance/edge_operations.py +149 -19
  28. graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
  29. graphiti_core/utils/maintenance/node_operations.py +52 -71
  30. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/METADATA +14 -5
  31. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/RECORD +33 -26
  32. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/LICENSE +0 -0
  33. {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/WHEEL +0 -0
@@ -20,11 +20,16 @@ from time import time
20
20
  from typing import Any
21
21
 
22
22
  import numpy as np
23
- from neo4j import AsyncDriver, Query
24
23
  from numpy._typing import NDArray
25
24
  from typing_extensions import LiteralString
26
25
 
26
+ from graphiti_core.driver.driver import GraphDriver
27
27
  from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
28
+ from graphiti_core.graph_queries import (
29
+ get_nodes_query,
30
+ get_relationships_query,
31
+ get_vector_cosine_func_query,
32
+ )
28
33
  from graphiti_core.helpers import (
29
34
  DEFAULT_DATABASE,
30
35
  RUNTIME_QUERY,
@@ -58,7 +63,7 @@ MAX_QUERY_LENGTH = 32
58
63
 
59
64
  def fulltext_query(query: str, group_ids: list[str] | None = None):
60
65
  group_ids_filter_list = (
61
- [f'group_id:"{lucene_sanitize(g)}"' for g in group_ids] if group_ids is not None else []
66
+ [f"group_id-'{lucene_sanitize(g)}'" for g in group_ids] if group_ids is not None else []
62
67
  )
63
68
  group_ids_filter = ''
64
69
  for f in group_ids_filter_list:
@@ -77,7 +82,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
77
82
 
78
83
 
79
84
  async def get_episodes_by_mentions(
80
- driver: AsyncDriver,
85
+ driver: GraphDriver,
81
86
  nodes: list[EntityNode],
82
87
  edges: list[EntityEdge],
83
88
  limit: int = RELEVANT_SCHEMA_LIMIT,
@@ -92,11 +97,11 @@ async def get_episodes_by_mentions(
92
97
 
93
98
 
94
99
  async def get_mentioned_nodes(
95
- driver: AsyncDriver, episodes: list[EpisodicNode]
100
+ driver: GraphDriver, episodes: list[EpisodicNode]
96
101
  ) -> list[EntityNode]:
97
102
  episode_uuids = [episode.uuid for episode in episodes]
98
- records, _, _ = await driver.execute_query(
99
- """
103
+
104
+ query = """
100
105
  MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
101
106
  RETURN DISTINCT
102
107
  n.uuid As uuid,
@@ -106,7 +111,10 @@ async def get_mentioned_nodes(
106
111
  n.summary AS summary,
107
112
  labels(n) AS labels,
108
113
  properties(n) AS attributes
109
- """,
114
+ """
115
+
116
+ records, _, _ = await driver.execute_query(
117
+ query,
110
118
  uuids=episode_uuids,
111
119
  database_=DEFAULT_DATABASE,
112
120
  routing_='r',
@@ -118,11 +126,11 @@ async def get_mentioned_nodes(
118
126
 
119
127
 
120
128
  async def get_communities_by_nodes(
121
- driver: AsyncDriver, nodes: list[EntityNode]
129
+ driver: GraphDriver, nodes: list[EntityNode]
122
130
  ) -> list[CommunityNode]:
123
131
  node_uuids = [node.uuid for node in nodes]
124
- records, _, _ = await driver.execute_query(
125
- """
132
+
133
+ query = """
126
134
  MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
127
135
  RETURN DISTINCT
128
136
  c.uuid As uuid,
@@ -130,7 +138,10 @@ async def get_communities_by_nodes(
130
138
  c.name AS name,
131
139
  c.created_at AS created_at,
132
140
  c.summary AS summary
133
- """,
141
+ """
142
+
143
+ records, _, _ = await driver.execute_query(
144
+ query,
134
145
  uuids=node_uuids,
135
146
  database_=DEFAULT_DATABASE,
136
147
  routing_='r',
@@ -142,7 +153,7 @@ async def get_communities_by_nodes(
142
153
 
143
154
 
144
155
  async def edge_fulltext_search(
145
- driver: AsyncDriver,
156
+ driver: GraphDriver,
146
157
  query: str,
147
158
  search_filter: SearchFilters,
148
159
  group_ids: list[str] | None = None,
@@ -155,33 +166,35 @@ async def edge_fulltext_search(
155
166
 
156
167
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
157
168
 
158
- cypher_query = Query(
159
- """
160
- CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit})
161
- YIELD relationship AS rel, score
162
- MATCH (:Entity)-[r:RELATES_TO]->(:Entity)
163
- WHERE r.group_id IN $group_ids"""
169
+ query = (
170
+ get_relationships_query('edge_name_and_fact', db_type=driver.provider)
171
+ + """
172
+ YIELD relationship AS rel, score
173
+ MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
174
+ WHERE r.group_id IN $group_ids """
164
175
  + filter_query
165
- + """\nWITH r, score, startNode(r) AS n, endNode(r) AS m
166
- RETURN
167
- r.uuid AS uuid,
168
- r.group_id AS group_id,
169
- n.uuid AS source_node_uuid,
170
- m.uuid AS target_node_uuid,
171
- r.created_at AS created_at,
172
- r.name AS name,
173
- r.fact AS fact,
174
- r.episodes AS episodes,
175
- r.expired_at AS expired_at,
176
- r.valid_at AS valid_at,
177
- r.invalid_at AS invalid_at
178
- ORDER BY score DESC LIMIT $limit
179
- """
176
+ + """
177
+ WITH r, score, startNode(r) AS n, endNode(r) AS m
178
+ RETURN
179
+ r.uuid AS uuid,
180
+ r.group_id AS group_id,
181
+ n.uuid AS source_node_uuid,
182
+ m.uuid AS target_node_uuid,
183
+ r.created_at AS created_at,
184
+ r.name AS name,
185
+ r.fact AS fact,
186
+ r.episodes AS episodes,
187
+ r.expired_at AS expired_at,
188
+ r.valid_at AS valid_at,
189
+ r.invalid_at AS invalid_at,
190
+ properties(r) AS attributes
191
+ ORDER BY score DESC LIMIT $limit
192
+ """
180
193
  )
181
194
 
182
195
  records, _, _ = await driver.execute_query(
183
- cypher_query,
184
- filter_params,
196
+ query,
197
+ params=filter_params,
185
198
  query=fuzzy_query,
186
199
  group_ids=group_ids,
187
200
  limit=limit,
@@ -195,7 +208,7 @@ async def edge_fulltext_search(
195
208
 
196
209
 
197
210
  async def edge_similarity_search(
198
- driver: AsyncDriver,
211
+ driver: GraphDriver,
199
212
  search_vector: list[float],
200
213
  source_node_uuid: str | None,
201
214
  target_node_uuid: str | None,
@@ -210,9 +223,9 @@ async def edge_similarity_search(
210
223
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
211
224
  query_params.update(filter_params)
212
225
 
213
- group_filter_query: LiteralString = ''
226
+ group_filter_query: LiteralString = 'WHERE r.group_id IS NOT NULL'
214
227
  if group_ids is not None:
215
- group_filter_query += 'WHERE r.group_id IN $group_ids'
228
+ group_filter_query += '\nAND r.group_id IN $group_ids'
216
229
  query_params['group_ids'] = group_ids
217
230
  query_params['source_node_uuid'] = source_node_uuid
218
231
  query_params['target_node_uuid'] = target_node_uuid
@@ -223,35 +236,38 @@ async def edge_similarity_search(
223
236
  if target_node_uuid is not None:
224
237
  group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
225
238
 
226
- query: LiteralString = (
239
+ query = (
227
240
  RUNTIME_QUERY
228
241
  + """
229
- MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
230
- """
242
+ MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
243
+ """
231
244
  + group_filter_query
232
245
  + filter_query
233
- + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
234
- WHERE score > $min_score
235
- RETURN
236
- r.uuid AS uuid,
237
- r.group_id AS group_id,
238
- startNode(r).uuid AS source_node_uuid,
239
- endNode(r).uuid AS target_node_uuid,
240
- r.created_at AS created_at,
241
- r.name AS name,
242
- r.fact AS fact,
243
- r.episodes AS episodes,
244
- r.expired_at AS expired_at,
245
- r.valid_at AS valid_at,
246
- r.invalid_at AS invalid_at
247
- ORDER BY score DESC
248
- LIMIT $limit
246
+ + """
247
+ WITH DISTINCT r, """
248
+ + get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
249
+ + """ AS score
250
+ WHERE score > $min_score
251
+ RETURN
252
+ r.uuid AS uuid,
253
+ r.group_id AS group_id,
254
+ startNode(r).uuid AS source_node_uuid,
255
+ endNode(r).uuid AS target_node_uuid,
256
+ r.created_at AS created_at,
257
+ r.name AS name,
258
+ r.fact AS fact,
259
+ r.episodes AS episodes,
260
+ r.expired_at AS expired_at,
261
+ r.valid_at AS valid_at,
262
+ r.invalid_at AS invalid_at,
263
+ properties(r) AS attributes
264
+ ORDER BY score DESC
265
+ LIMIT $limit
249
266
  """
250
267
  )
251
-
252
- records, _, _ = await driver.execute_query(
268
+ records, header, _ = await driver.execute_query(
253
269
  query,
254
- query_params,
270
+ params=query_params,
255
271
  search_vector=search_vector,
256
272
  source_uuid=source_node_uuid,
257
273
  target_uuid=target_node_uuid,
@@ -262,13 +278,16 @@ async def edge_similarity_search(
262
278
  routing_='r',
263
279
  )
264
280
 
281
+ if driver.provider == 'falkordb':
282
+ records = [dict(zip(header, row, strict=True)) for row in records]
283
+
265
284
  edges = [get_entity_edge_from_record(record) for record in records]
266
285
 
267
286
  return edges
268
287
 
269
288
 
270
289
  async def edge_bfs_search(
271
- driver: AsyncDriver,
290
+ driver: GraphDriver,
272
291
  bfs_origin_node_uuids: list[str] | None,
273
292
  bfs_max_depth: int,
274
293
  search_filter: SearchFilters,
@@ -280,14 +299,14 @@ async def edge_bfs_search(
280
299
 
281
300
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
282
301
 
283
- query = Query(
302
+ query = (
284
303
  """
285
- UNWIND $bfs_origin_node_uuids AS origin_uuid
286
- MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
287
- UNWIND relationships(path) AS rel
288
- MATCH ()-[r:RELATES_TO]-()
289
- WHERE r.uuid = rel.uuid
290
- """
304
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
305
+ MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
306
+ UNWIND relationships(path) AS rel
307
+ MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
308
+ WHERE r.uuid = rel.uuid
309
+ """
291
310
  + filter_query
292
311
  + """
293
312
  RETURN DISTINCT
@@ -301,14 +320,15 @@ async def edge_bfs_search(
301
320
  r.episodes AS episodes,
302
321
  r.expired_at AS expired_at,
303
322
  r.valid_at AS valid_at,
304
- r.invalid_at AS invalid_at
323
+ r.invalid_at AS invalid_at,
324
+ properties(r) AS attributes
305
325
  LIMIT $limit
306
326
  """
307
327
  )
308
328
 
309
329
  records, _, _ = await driver.execute_query(
310
330
  query,
311
- filter_params,
331
+ params=filter_params,
312
332
  bfs_origin_node_uuids=bfs_origin_node_uuids,
313
333
  depth=bfs_max_depth,
314
334
  limit=limit,
@@ -322,7 +342,7 @@ async def edge_bfs_search(
322
342
 
323
343
 
324
344
  async def node_fulltext_search(
325
- driver: AsyncDriver,
345
+ driver: GraphDriver,
326
346
  query: str,
327
347
  search_filter: SearchFilters,
328
348
  group_ids: list[str] | None = None,
@@ -332,38 +352,41 @@ async def node_fulltext_search(
332
352
  fuzzy_query = fulltext_query(query, group_ids)
333
353
  if fuzzy_query == '':
334
354
  return []
335
-
336
355
  filter_query, filter_params = node_search_filter_query_constructor(search_filter)
337
356
 
338
357
  query = (
358
+ get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
359
+ + """
360
+ YIELD node AS n, score
361
+ WITH n, score
362
+ LIMIT $limit
363
+ WHERE n:Entity
339
364
  """
340
- CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
341
- YIELD node AS n, score
342
- WHERE n:Entity
343
- """
344
365
  + filter_query
345
366
  + ENTITY_NODE_RETURN
346
367
  + """
347
368
  ORDER BY score DESC
348
369
  """
349
370
  )
350
-
351
- records, _, _ = await driver.execute_query(
371
+ records, header, _ = await driver.execute_query(
352
372
  query,
353
- filter_params,
373
+ params=filter_params,
354
374
  query=fuzzy_query,
355
375
  group_ids=group_ids,
356
376
  limit=limit,
357
377
  database_=DEFAULT_DATABASE,
358
378
  routing_='r',
359
379
  )
380
+ if driver.provider == 'falkordb':
381
+ records = [dict(zip(header, row, strict=True)) for row in records]
382
+
360
383
  nodes = [get_entity_node_from_record(record) for record in records]
361
384
 
362
385
  return nodes
363
386
 
364
387
 
365
388
  async def node_similarity_search(
366
- driver: AsyncDriver,
389
+ driver: GraphDriver,
367
390
  search_vector: list[float],
368
391
  search_filter: SearchFilters,
369
392
  group_ids: list[str] | None = None,
@@ -373,30 +396,36 @@ async def node_similarity_search(
373
396
  # vector similarity search over entity names
374
397
  query_params: dict[str, Any] = {}
375
398
 
376
- group_filter_query: LiteralString = ''
399
+ group_filter_query: LiteralString = 'WHERE n.group_id IS NOT NULL'
377
400
  if group_ids is not None:
378
- group_filter_query += 'WHERE n.group_id IN $group_ids'
401
+ group_filter_query += ' AND n.group_id IN $group_ids'
379
402
  query_params['group_ids'] = group_ids
380
403
 
381
404
  filter_query, filter_params = node_search_filter_query_constructor(search_filter)
382
405
  query_params.update(filter_params)
383
406
 
384
- records, _, _ = await driver.execute_query(
407
+ query = (
385
408
  RUNTIME_QUERY
386
409
  + """
387
- MATCH (n:Entity)
388
- """
410
+ MATCH (n:Entity)
411
+ """
389
412
  + group_filter_query
390
413
  + filter_query
391
414
  + """
392
- WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
393
- WHERE score > $min_score"""
415
+ WITH n, """
416
+ + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
417
+ + """ AS score
418
+ WHERE score > $min_score"""
394
419
  + ENTITY_NODE_RETURN
395
420
  + """
396
421
  ORDER BY score DESC
397
422
  LIMIT $limit
398
- """,
399
- query_params,
423
+ """
424
+ )
425
+
426
+ records, header, _ = await driver.execute_query(
427
+ query,
428
+ params=query_params,
400
429
  search_vector=search_vector,
401
430
  group_ids=group_ids,
402
431
  limit=limit,
@@ -404,13 +433,15 @@ async def node_similarity_search(
404
433
  database_=DEFAULT_DATABASE,
405
434
  routing_='r',
406
435
  )
436
+ if driver.provider == 'falkordb':
437
+ records = [dict(zip(header, row, strict=True)) for row in records]
407
438
  nodes = [get_entity_node_from_record(record) for record in records]
408
439
 
409
440
  return nodes
410
441
 
411
442
 
412
443
  async def node_bfs_search(
413
- driver: AsyncDriver,
444
+ driver: GraphDriver,
414
445
  bfs_origin_node_uuids: list[str] | None,
415
446
  search_filter: SearchFilters,
416
447
  bfs_max_depth: int,
@@ -422,18 +453,21 @@ async def node_bfs_search(
422
453
 
423
454
  filter_query, filter_params = node_search_filter_query_constructor(search_filter)
424
455
 
425
- records, _, _ = await driver.execute_query(
456
+ query = (
426
457
  """
427
- UNWIND $bfs_origin_node_uuids AS origin_uuid
428
- MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
429
- WHERE n.group_id = origin.group_id
430
- """
458
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
459
+ MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
460
+ WHERE n.group_id = origin.group_id
461
+ """
431
462
  + filter_query
432
463
  + ENTITY_NODE_RETURN
433
464
  + """
434
465
  LIMIT $limit
435
- """,
436
- filter_params,
466
+ """
467
+ )
468
+ records, _, _ = await driver.execute_query(
469
+ query,
470
+ params=filter_params,
437
471
  bfs_origin_node_uuids=bfs_origin_node_uuids,
438
472
  depth=bfs_max_depth,
439
473
  limit=limit,
@@ -446,7 +480,7 @@ async def node_bfs_search(
446
480
 
447
481
 
448
482
  async def episode_fulltext_search(
449
- driver: AsyncDriver,
483
+ driver: GraphDriver,
450
484
  query: str,
451
485
  _search_filter: SearchFilters,
452
486
  group_ids: list[str] | None = None,
@@ -457,9 +491,9 @@ async def episode_fulltext_search(
457
491
  if fuzzy_query == '':
458
492
  return []
459
493
 
460
- records, _, _ = await driver.execute_query(
461
- """
462
- CALL db.index.fulltext.queryNodes("episode_content", $query, {limit: $limit})
494
+ query = (
495
+ get_nodes_query(driver.provider, 'episode_content', '$query')
496
+ + """
463
497
  YIELD node AS episode, score
464
498
  MATCH (e:Episodic)
465
499
  WHERE e.uuid = episode.uuid
@@ -475,7 +509,11 @@ async def episode_fulltext_search(
475
509
  e.entity_edges AS entity_edges
476
510
  ORDER BY score DESC
477
511
  LIMIT $limit
478
- """,
512
+ """
513
+ )
514
+
515
+ records, _, _ = await driver.execute_query(
516
+ query,
479
517
  query=fuzzy_query,
480
518
  group_ids=group_ids,
481
519
  limit=limit,
@@ -488,7 +526,7 @@ async def episode_fulltext_search(
488
526
 
489
527
 
490
528
  async def community_fulltext_search(
491
- driver: AsyncDriver,
529
+ driver: GraphDriver,
492
530
  query: str,
493
531
  group_ids: list[str] | None = None,
494
532
  limit=RELEVANT_SCHEMA_LIMIT,
@@ -498,9 +536,9 @@ async def community_fulltext_search(
498
536
  if fuzzy_query == '':
499
537
  return []
500
538
 
501
- records, _, _ = await driver.execute_query(
502
- """
503
- CALL db.index.fulltext.queryNodes("community_name", $query, {limit: $limit})
539
+ query = (
540
+ get_nodes_query(driver.provider, 'community_name', '$query')
541
+ + """
504
542
  YIELD node AS comm, score
505
543
  RETURN
506
544
  comm.uuid AS uuid,
@@ -510,7 +548,11 @@ async def community_fulltext_search(
510
548
  comm.summary AS summary
511
549
  ORDER BY score DESC
512
550
  LIMIT $limit
513
- """,
551
+ """
552
+ )
553
+
554
+ records, _, _ = await driver.execute_query(
555
+ query,
514
556
  query=fuzzy_query,
515
557
  group_ids=group_ids,
516
558
  limit=limit,
@@ -523,7 +565,7 @@ async def community_fulltext_search(
523
565
 
524
566
 
525
567
  async def community_similarity_search(
526
- driver: AsyncDriver,
568
+ driver: GraphDriver,
527
569
  search_vector: list[float],
528
570
  group_ids: list[str] | None = None,
529
571
  limit=RELEVANT_SCHEMA_LIMIT,
@@ -537,14 +579,16 @@ async def community_similarity_search(
537
579
  group_filter_query += 'WHERE comm.group_id IN $group_ids'
538
580
  query_params['group_ids'] = group_ids
539
581
 
540
- records, _, _ = await driver.execute_query(
582
+ query = (
541
583
  RUNTIME_QUERY
542
584
  + """
543
585
  MATCH (comm:Community)
544
586
  """
545
587
  + group_filter_query
546
588
  + """
547
- WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
589
+ WITH comm, """
590
+ + get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
591
+ + """ AS score
548
592
  WHERE score > $min_score
549
593
  RETURN
550
594
  comm.uuid As uuid,
@@ -554,7 +598,11 @@ async def community_similarity_search(
554
598
  comm.summary AS summary
555
599
  ORDER BY score DESC
556
600
  LIMIT $limit
557
- """,
601
+ """
602
+ )
603
+
604
+ records, _, _ = await driver.execute_query(
605
+ query,
558
606
  search_vector=search_vector,
559
607
  group_ids=group_ids,
560
608
  limit=limit,
@@ -570,7 +618,7 @@ async def community_similarity_search(
570
618
  async def hybrid_node_search(
571
619
  queries: list[str],
572
620
  embeddings: list[list[float]],
573
- driver: AsyncDriver,
621
+ driver: GraphDriver,
574
622
  search_filter: SearchFilters,
575
623
  group_ids: list[str] | None = None,
576
624
  limit: int = RELEVANT_SCHEMA_LIMIT,
@@ -587,7 +635,7 @@ async def hybrid_node_search(
587
635
  A list of text queries to search for.
588
636
  embeddings : list[list[float]]
589
637
  A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
590
- driver : AsyncDriver
638
+ driver : GraphDriver
591
639
  The Neo4j driver instance for database operations.
592
640
  group_ids : list[str] | None, optional
593
641
  The list of group ids to retrieve nodes from.
@@ -642,7 +690,7 @@ async def hybrid_node_search(
642
690
 
643
691
 
644
692
  async def get_relevant_nodes(
645
- driver: AsyncDriver,
693
+ driver: GraphDriver,
646
694
  nodes: list[EntityNode],
647
695
  search_filter: SearchFilters,
648
696
  min_score: float = DEFAULT_MIN_SCORE,
@@ -661,29 +709,33 @@ async def get_relevant_nodes(
661
709
 
662
710
  query = (
663
711
  RUNTIME_QUERY
664
- + """UNWIND $nodes AS node
665
- MATCH (n:Entity {group_id: $group_id})
666
- """
712
+ + """
713
+ UNWIND $nodes AS node
714
+ MATCH (n:Entity {group_id: $group_id})
715
+ """
667
716
  + filter_query
668
717
  + """
669
- WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
718
+ WITH node, n, """
719
+ + get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
720
+ + """ AS score
670
721
  WHERE score > $min_score
671
722
  WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
672
-
673
- CALL db.index.fulltext.queryNodes("node_name_and_summary", node.fulltext_query, {limit: $limit})
723
+ """
724
+ + get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
725
+ + """
674
726
  YIELD node AS m
675
727
  WHERE m.group_id = $group_id
676
728
  WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
677
-
729
+
678
730
  WITH node,
679
731
  top_vector_nodes,
680
732
  [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
681
-
733
+
682
734
  WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
683
-
735
+
684
736
  UNWIND combined_nodes AS combined_node
685
737
  WITH node, collect(DISTINCT combined_node) AS deduped_nodes
686
-
738
+
687
739
  RETURN
688
740
  node.uuid AS search_node_uuid,
689
741
  [x IN deduped_nodes | {
@@ -711,7 +763,7 @@ async def get_relevant_nodes(
711
763
 
712
764
  results, _, _ = await driver.execute_query(
713
765
  query,
714
- query_params,
766
+ params=query_params,
715
767
  nodes=query_nodes,
716
768
  group_id=group_id,
717
769
  limit=limit,
@@ -733,7 +785,7 @@ async def get_relevant_nodes(
733
785
 
734
786
 
735
787
  async def get_relevant_edges(
736
- driver: AsyncDriver,
788
+ driver: GraphDriver,
737
789
  edges: list[EntityEdge],
738
790
  search_filter: SearchFilters,
739
791
  min_score: float = DEFAULT_MIN_SCORE,
@@ -749,42 +801,47 @@ async def get_relevant_edges(
749
801
 
750
802
  query = (
751
803
  RUNTIME_QUERY
752
- + """UNWIND $edges AS edge
753
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
754
- """
804
+ + """
805
+ UNWIND $edges AS edge
806
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
807
+ """
755
808
  + filter_query
756
809
  + """
757
- WITH e, edge, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
758
- WHERE score > $min_score
759
- WITH edge, e, score
760
- ORDER BY score DESC
761
- RETURN edge.uuid AS search_edge_uuid,
762
- collect({
763
- uuid: e.uuid,
764
- source_node_uuid: startNode(e).uuid,
765
- target_node_uuid: endNode(e).uuid,
766
- created_at: e.created_at,
767
- name: e.name,
768
- group_id: e.group_id,
769
- fact: e.fact,
770
- fact_embedding: e.fact_embedding,
771
- episodes: e.episodes,
772
- expired_at: e.expired_at,
773
- valid_at: e.valid_at,
774
- invalid_at: e.invalid_at
775
- })[..$limit] AS matches
810
+ WITH e, edge, """
811
+ + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
812
+ + """ AS score
813
+ WHERE score > $min_score
814
+ WITH edge, e, score
815
+ ORDER BY score DESC
816
+ RETURN edge.uuid AS search_edge_uuid,
817
+ collect({
818
+ uuid: e.uuid,
819
+ source_node_uuid: startNode(e).uuid,
820
+ target_node_uuid: endNode(e).uuid,
821
+ created_at: e.created_at,
822
+ name: e.name,
823
+ group_id: e.group_id,
824
+ fact: e.fact,
825
+ fact_embedding: e.fact_embedding,
826
+ episodes: e.episodes,
827
+ expired_at: e.expired_at,
828
+ valid_at: e.valid_at,
829
+ invalid_at: e.invalid_at,
830
+ attributes: properties(e)
831
+ })[..$limit] AS matches
776
832
  """
777
833
  )
778
834
 
779
835
  results, _, _ = await driver.execute_query(
780
836
  query,
781
- query_params,
837
+ params=query_params,
782
838
  edges=[edge.model_dump() for edge in edges],
783
839
  limit=limit,
784
840
  min_score=min_score,
785
841
  database_=DEFAULT_DATABASE,
786
842
  routing_='r',
787
843
  )
844
+
788
845
  relevant_edges_dict: dict[str, list[EntityEdge]] = {
789
846
  result['search_edge_uuid']: [
790
847
  get_entity_edge_from_record(record) for record in result['matches']
@@ -798,7 +855,7 @@ async def get_relevant_edges(
798
855
 
799
856
 
800
857
  async def get_edge_invalidation_candidates(
801
- driver: AsyncDriver,
858
+ driver: GraphDriver,
802
859
  edges: list[EntityEdge],
803
860
  search_filter: SearchFilters,
804
861
  min_score: float = DEFAULT_MIN_SCORE,
@@ -814,37 +871,41 @@ async def get_edge_invalidation_candidates(
814
871
 
815
872
  query = (
816
873
  RUNTIME_QUERY
817
- + """UNWIND $edges AS edge
818
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
819
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
820
- """
874
+ + """
875
+ UNWIND $edges AS edge
876
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
877
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
878
+ """
821
879
  + filter_query
822
880
  + """
823
- WITH edge, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
824
- WHERE score > $min_score
825
- WITH edge, e, score
826
- ORDER BY score DESC
827
- RETURN edge.uuid AS search_edge_uuid,
828
- collect({
829
- uuid: e.uuid,
830
- source_node_uuid: startNode(e).uuid,
831
- target_node_uuid: endNode(e).uuid,
832
- created_at: e.created_at,
833
- name: e.name,
834
- group_id: e.group_id,
835
- fact: e.fact,
836
- fact_embedding: e.fact_embedding,
837
- episodes: e.episodes,
838
- expired_at: e.expired_at,
839
- valid_at: e.valid_at,
840
- invalid_at: e.invalid_at
841
- })[..$limit] AS matches
881
+ WITH edge, e, """
882
+ + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
883
+ + """ AS score
884
+ WHERE score > $min_score
885
+ WITH edge, e, score
886
+ ORDER BY score DESC
887
+ RETURN edge.uuid AS search_edge_uuid,
888
+ collect({
889
+ uuid: e.uuid,
890
+ source_node_uuid: startNode(e).uuid,
891
+ target_node_uuid: endNode(e).uuid,
892
+ created_at: e.created_at,
893
+ name: e.name,
894
+ group_id: e.group_id,
895
+ fact: e.fact,
896
+ fact_embedding: e.fact_embedding,
897
+ episodes: e.episodes,
898
+ expired_at: e.expired_at,
899
+ valid_at: e.valid_at,
900
+ invalid_at: e.invalid_at,
901
+ attributes: properties(e)
902
+ })[..$limit] AS matches
842
903
  """
843
904
  )
844
905
 
845
906
  results, _, _ = await driver.execute_query(
846
907
  query,
847
- query_params,
908
+ params=query_params,
848
909
  edges=[edge.model_dump() for edge in edges],
849
910
  limit=limit,
850
911
  min_score=min_score,
@@ -879,7 +940,7 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
879
940
 
880
941
 
881
942
  async def node_distance_reranker(
882
- driver: AsyncDriver,
943
+ driver: GraphDriver,
883
944
  node_uuids: list[str],
884
945
  center_node_uuid: str,
885
946
  min_score: float = 0,
@@ -889,21 +950,22 @@ async def node_distance_reranker(
889
950
  scores: dict[str, float] = {center_node_uuid: 0.0}
890
951
 
891
952
  # Find the shortest path to center node
892
- query = Query("""
953
+ query = """
893
954
  UNWIND $node_uuids AS node_uuid
894
- MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid})
895
- RETURN length(p) AS score, node_uuid AS uuid
896
- """)
897
-
898
- path_results, _, _ = await driver.execute_query(
955
+ MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
956
+ RETURN 1 AS score, node_uuid AS uuid
957
+ """
958
+ results, header, _ = await driver.execute_query(
899
959
  query,
900
960
  node_uuids=filtered_uuids,
901
961
  center_uuid=center_node_uuid,
902
962
  database_=DEFAULT_DATABASE,
903
963
  routing_='r',
904
964
  )
965
+ if driver.provider == 'falkordb':
966
+ results = [dict(zip(header, row, strict=True)) for row in results]
905
967
 
906
- for result in path_results:
968
+ for result in results:
907
969
  uuid = result['uuid']
908
970
  score = result['score']
909
971
  scores[uuid] = score
@@ -924,19 +986,18 @@ async def node_distance_reranker(
924
986
 
925
987
 
926
988
  async def episode_mentions_reranker(
927
- driver: AsyncDriver, node_uuids: list[list[str]], min_score: float = 0
989
+ driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
928
990
  ) -> list[str]:
929
991
  # use rrf as a preliminary ranker
930
992
  sorted_uuids = rrf(node_uuids)
931
993
  scores: dict[str, float] = {}
932
994
 
933
995
  # Find the shortest path to center node
934
- query = Query("""
996
+ query = """
935
997
  UNWIND $node_uuids AS node_uuid
936
998
  MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
937
999
  RETURN count(*) AS score, n.uuid AS uuid
938
- """)
939
-
1000
+ """
940
1001
  results, _, _ = await driver.execute_query(
941
1002
  query,
942
1003
  node_uuids=sorted_uuids,
@@ -993,7 +1054,7 @@ def maximal_marginal_relevance(
993
1054
 
994
1055
 
995
1056
  async def get_embeddings_for_nodes(
996
- driver: AsyncDriver, nodes: list[EntityNode]
1057
+ driver: GraphDriver, nodes: list[EntityNode]
997
1058
  ) -> dict[str, list[float]]:
998
1059
  query: LiteralString = """MATCH (n:Entity)
999
1060
  WHERE n.uuid IN $node_uuids
@@ -1017,7 +1078,7 @@ async def get_embeddings_for_nodes(
1017
1078
 
1018
1079
 
1019
1080
  async def get_embeddings_for_communities(
1020
- driver: AsyncDriver, communities: list[CommunityNode]
1081
+ driver: GraphDriver, communities: list[CommunityNode]
1021
1082
  ) -> dict[str, list[float]]:
1022
1083
  query: LiteralString = """MATCH (c:Community)
1023
1084
  WHERE c.uuid IN $community_uuids
@@ -1044,7 +1105,7 @@ async def get_embeddings_for_communities(
1044
1105
 
1045
1106
 
1046
1107
  async def get_embeddings_for_edges(
1047
- driver: AsyncDriver, edges: list[EntityEdge]
1108
+ driver: GraphDriver, edges: list[EntityEdge]
1048
1109
  ) -> dict[str, list[float]]:
1049
1110
  query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1050
1111
  WHERE e.uuid IN $edge_uuids