graphiti-core 0.11.6rc7__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core may be problematic; see the package's registry page for more details.

Files changed (33)
  1. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  2. graphiti_core/driver/__init__.py +17 -0
  3. graphiti_core/driver/driver.py +66 -0
  4. graphiti_core/driver/falkordb_driver.py +132 -0
  5. graphiti_core/driver/neo4j_driver.py +61 -0
  6. graphiti_core/edges.py +66 -40
  7. graphiti_core/embedder/azure_openai.py +64 -0
  8. graphiti_core/embedder/gemini.py +14 -3
  9. graphiti_core/graph_queries.py +149 -0
  10. graphiti_core/graphiti.py +41 -14
  11. graphiti_core/graphiti_types.py +2 -2
  12. graphiti_core/helpers.py +17 -30
  13. graphiti_core/llm_client/__init__.py +16 -0
  14. graphiti_core/llm_client/azure_openai_client.py +73 -0
  15. graphiti_core/llm_client/gemini_client.py +4 -1
  16. graphiti_core/models/edges/edge_db_queries.py +2 -4
  17. graphiti_core/nodes.py +31 -31
  18. graphiti_core/prompts/dedupe_edges.py +52 -1
  19. graphiti_core/prompts/dedupe_nodes.py +79 -4
  20. graphiti_core/prompts/extract_edges.py +50 -5
  21. graphiti_core/prompts/invalidate_edges.py +1 -1
  22. graphiti_core/search/search.py +25 -55
  23. graphiti_core/search/search_filters.py +23 -9
  24. graphiti_core/search/search_utils.py +360 -195
  25. graphiti_core/utils/bulk_utils.py +38 -11
  26. graphiti_core/utils/maintenance/community_operations.py +6 -7
  27. graphiti_core/utils/maintenance/edge_operations.py +149 -19
  28. graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
  29. graphiti_core/utils/maintenance/node_operations.py +52 -71
  30. {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/METADATA +14 -5
  31. {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/RECORD +33 -26
  32. {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/LICENSE +0 -0
  33. {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/WHEEL +0 -0
@@ -20,10 +20,16 @@ from time import time
20
20
  from typing import Any
21
21
 
22
22
  import numpy as np
23
- from neo4j import AsyncDriver, Query
23
+ from numpy._typing import NDArray
24
24
  from typing_extensions import LiteralString
25
25
 
26
+ from graphiti_core.driver.driver import GraphDriver
26
27
  from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
28
+ from graphiti_core.graph_queries import (
29
+ get_nodes_query,
30
+ get_relationships_query,
31
+ get_vector_cosine_func_query,
32
+ )
27
33
  from graphiti_core.helpers import (
28
34
  DEFAULT_DATABASE,
29
35
  RUNTIME_QUERY,
@@ -57,7 +63,7 @@ MAX_QUERY_LENGTH = 32
57
63
 
58
64
  def fulltext_query(query: str, group_ids: list[str] | None = None):
59
65
  group_ids_filter_list = (
60
- [f'group_id:"{lucene_sanitize(g)}"' for g in group_ids] if group_ids is not None else []
66
+ [f"group_id-'{lucene_sanitize(g)}'" for g in group_ids] if group_ids is not None else []
61
67
  )
62
68
  group_ids_filter = ''
63
69
  for f in group_ids_filter_list:
@@ -76,7 +82,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
76
82
 
77
83
 
78
84
  async def get_episodes_by_mentions(
79
- driver: AsyncDriver,
85
+ driver: GraphDriver,
80
86
  nodes: list[EntityNode],
81
87
  edges: list[EntityEdge],
82
88
  limit: int = RELEVANT_SCHEMA_LIMIT,
@@ -91,11 +97,11 @@ async def get_episodes_by_mentions(
91
97
 
92
98
 
93
99
  async def get_mentioned_nodes(
94
- driver: AsyncDriver, episodes: list[EpisodicNode]
100
+ driver: GraphDriver, episodes: list[EpisodicNode]
95
101
  ) -> list[EntityNode]:
96
102
  episode_uuids = [episode.uuid for episode in episodes]
97
- records, _, _ = await driver.execute_query(
98
- """
103
+
104
+ query = """
99
105
  MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
100
106
  RETURN DISTINCT
101
107
  n.uuid As uuid,
@@ -105,7 +111,10 @@ async def get_mentioned_nodes(
105
111
  n.summary AS summary,
106
112
  labels(n) AS labels,
107
113
  properties(n) AS attributes
108
- """,
114
+ """
115
+
116
+ records, _, _ = await driver.execute_query(
117
+ query,
109
118
  uuids=episode_uuids,
110
119
  database_=DEFAULT_DATABASE,
111
120
  routing_='r',
@@ -117,11 +126,11 @@ async def get_mentioned_nodes(
117
126
 
118
127
 
119
128
  async def get_communities_by_nodes(
120
- driver: AsyncDriver, nodes: list[EntityNode]
129
+ driver: GraphDriver, nodes: list[EntityNode]
121
130
  ) -> list[CommunityNode]:
122
131
  node_uuids = [node.uuid for node in nodes]
123
- records, _, _ = await driver.execute_query(
124
- """
132
+
133
+ query = """
125
134
  MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
126
135
  RETURN DISTINCT
127
136
  c.uuid As uuid,
@@ -129,7 +138,10 @@ async def get_communities_by_nodes(
129
138
  c.name AS name,
130
139
  c.created_at AS created_at,
131
140
  c.summary AS summary
132
- """,
141
+ """
142
+
143
+ records, _, _ = await driver.execute_query(
144
+ query,
133
145
  uuids=node_uuids,
134
146
  database_=DEFAULT_DATABASE,
135
147
  routing_='r',
@@ -141,7 +153,7 @@ async def get_communities_by_nodes(
141
153
 
142
154
 
143
155
  async def edge_fulltext_search(
144
- driver: AsyncDriver,
156
+ driver: GraphDriver,
145
157
  query: str,
146
158
  search_filter: SearchFilters,
147
159
  group_ids: list[str] | None = None,
@@ -154,33 +166,35 @@ async def edge_fulltext_search(
154
166
 
155
167
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
156
168
 
157
- cypher_query = Query(
158
- """
159
- CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit})
160
- YIELD relationship AS rel, score
161
- MATCH (:Entity)-[r:RELATES_TO]->(:Entity)
162
- WHERE r.group_id IN $group_ids"""
169
+ query = (
170
+ get_relationships_query('edge_name_and_fact', db_type=driver.provider)
171
+ + """
172
+ YIELD relationship AS rel, score
173
+ MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
174
+ WHERE r.group_id IN $group_ids """
163
175
  + filter_query
164
- + """\nWITH r, score, startNode(r) AS n, endNode(r) AS m
165
- RETURN
166
- r.uuid AS uuid,
167
- r.group_id AS group_id,
168
- n.uuid AS source_node_uuid,
169
- m.uuid AS target_node_uuid,
170
- r.created_at AS created_at,
171
- r.name AS name,
172
- r.fact AS fact,
173
- r.episodes AS episodes,
174
- r.expired_at AS expired_at,
175
- r.valid_at AS valid_at,
176
- r.invalid_at AS invalid_at
177
- ORDER BY score DESC LIMIT $limit
178
- """
176
+ + """
177
+ WITH r, score, startNode(r) AS n, endNode(r) AS m
178
+ RETURN
179
+ r.uuid AS uuid,
180
+ r.group_id AS group_id,
181
+ n.uuid AS source_node_uuid,
182
+ m.uuid AS target_node_uuid,
183
+ r.created_at AS created_at,
184
+ r.name AS name,
185
+ r.fact AS fact,
186
+ r.episodes AS episodes,
187
+ r.expired_at AS expired_at,
188
+ r.valid_at AS valid_at,
189
+ r.invalid_at AS invalid_at,
190
+ properties(r) AS attributes
191
+ ORDER BY score DESC LIMIT $limit
192
+ """
179
193
  )
180
194
 
181
195
  records, _, _ = await driver.execute_query(
182
- cypher_query,
183
- filter_params,
196
+ query,
197
+ params=filter_params,
184
198
  query=fuzzy_query,
185
199
  group_ids=group_ids,
186
200
  limit=limit,
@@ -194,7 +208,7 @@ async def edge_fulltext_search(
194
208
 
195
209
 
196
210
  async def edge_similarity_search(
197
- driver: AsyncDriver,
211
+ driver: GraphDriver,
198
212
  search_vector: list[float],
199
213
  source_node_uuid: str | None,
200
214
  target_node_uuid: str | None,
@@ -209,9 +223,9 @@ async def edge_similarity_search(
209
223
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
210
224
  query_params.update(filter_params)
211
225
 
212
- group_filter_query: LiteralString = ''
226
+ group_filter_query: LiteralString = 'WHERE r.group_id IS NOT NULL'
213
227
  if group_ids is not None:
214
- group_filter_query += 'WHERE r.group_id IN $group_ids'
228
+ group_filter_query += '\nAND r.group_id IN $group_ids'
215
229
  query_params['group_ids'] = group_ids
216
230
  query_params['source_node_uuid'] = source_node_uuid
217
231
  query_params['target_node_uuid'] = target_node_uuid
@@ -222,35 +236,38 @@ async def edge_similarity_search(
222
236
  if target_node_uuid is not None:
223
237
  group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
224
238
 
225
- query: LiteralString = (
239
+ query = (
226
240
  RUNTIME_QUERY
227
241
  + """
228
- MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
229
- """
242
+ MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
243
+ """
230
244
  + group_filter_query
231
245
  + filter_query
232
- + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
233
- WHERE score > $min_score
234
- RETURN
235
- r.uuid AS uuid,
236
- r.group_id AS group_id,
237
- startNode(r).uuid AS source_node_uuid,
238
- endNode(r).uuid AS target_node_uuid,
239
- r.created_at AS created_at,
240
- r.name AS name,
241
- r.fact AS fact,
242
- r.episodes AS episodes,
243
- r.expired_at AS expired_at,
244
- r.valid_at AS valid_at,
245
- r.invalid_at AS invalid_at
246
- ORDER BY score DESC
247
- LIMIT $limit
246
+ + """
247
+ WITH DISTINCT r, """
248
+ + get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
249
+ + """ AS score
250
+ WHERE score > $min_score
251
+ RETURN
252
+ r.uuid AS uuid,
253
+ r.group_id AS group_id,
254
+ startNode(r).uuid AS source_node_uuid,
255
+ endNode(r).uuid AS target_node_uuid,
256
+ r.created_at AS created_at,
257
+ r.name AS name,
258
+ r.fact AS fact,
259
+ r.episodes AS episodes,
260
+ r.expired_at AS expired_at,
261
+ r.valid_at AS valid_at,
262
+ r.invalid_at AS invalid_at,
263
+ properties(r) AS attributes
264
+ ORDER BY score DESC
265
+ LIMIT $limit
248
266
  """
249
267
  )
250
-
251
- records, _, _ = await driver.execute_query(
268
+ records, header, _ = await driver.execute_query(
252
269
  query,
253
- query_params,
270
+ params=query_params,
254
271
  search_vector=search_vector,
255
272
  source_uuid=source_node_uuid,
256
273
  target_uuid=target_node_uuid,
@@ -261,13 +278,16 @@ async def edge_similarity_search(
261
278
  routing_='r',
262
279
  )
263
280
 
281
+ if driver.provider == 'falkordb':
282
+ records = [dict(zip(header, row, strict=True)) for row in records]
283
+
264
284
  edges = [get_entity_edge_from_record(record) for record in records]
265
285
 
266
286
  return edges
267
287
 
268
288
 
269
289
  async def edge_bfs_search(
270
- driver: AsyncDriver,
290
+ driver: GraphDriver,
271
291
  bfs_origin_node_uuids: list[str] | None,
272
292
  bfs_max_depth: int,
273
293
  search_filter: SearchFilters,
@@ -279,14 +299,14 @@ async def edge_bfs_search(
279
299
 
280
300
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
281
301
 
282
- query = Query(
302
+ query = (
283
303
  """
284
- UNWIND $bfs_origin_node_uuids AS origin_uuid
285
- MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
286
- UNWIND relationships(path) AS rel
287
- MATCH ()-[r:RELATES_TO]-()
288
- WHERE r.uuid = rel.uuid
289
- """
304
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
305
+ MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
306
+ UNWIND relationships(path) AS rel
307
+ MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
308
+ WHERE r.uuid = rel.uuid
309
+ """
290
310
  + filter_query
291
311
  + """
292
312
  RETURN DISTINCT
@@ -300,14 +320,15 @@ async def edge_bfs_search(
300
320
  r.episodes AS episodes,
301
321
  r.expired_at AS expired_at,
302
322
  r.valid_at AS valid_at,
303
- r.invalid_at AS invalid_at
323
+ r.invalid_at AS invalid_at,
324
+ properties(r) AS attributes
304
325
  LIMIT $limit
305
326
  """
306
327
  )
307
328
 
308
329
  records, _, _ = await driver.execute_query(
309
330
  query,
310
- filter_params,
331
+ params=filter_params,
311
332
  bfs_origin_node_uuids=bfs_origin_node_uuids,
312
333
  depth=bfs_max_depth,
313
334
  limit=limit,
@@ -321,7 +342,7 @@ async def edge_bfs_search(
321
342
 
322
343
 
323
344
  async def node_fulltext_search(
324
- driver: AsyncDriver,
345
+ driver: GraphDriver,
325
346
  query: str,
326
347
  search_filter: SearchFilters,
327
348
  group_ids: list[str] | None = None,
@@ -331,38 +352,41 @@ async def node_fulltext_search(
331
352
  fuzzy_query = fulltext_query(query, group_ids)
332
353
  if fuzzy_query == '':
333
354
  return []
334
-
335
355
  filter_query, filter_params = node_search_filter_query_constructor(search_filter)
336
356
 
337
357
  query = (
358
+ get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
359
+ + """
360
+ YIELD node AS n, score
361
+ WITH n, score
362
+ LIMIT $limit
363
+ WHERE n:Entity
338
364
  """
339
- CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
340
- YIELD node AS n, score
341
- WHERE n:Entity
342
- """
343
365
  + filter_query
344
366
  + ENTITY_NODE_RETURN
345
367
  + """
346
368
  ORDER BY score DESC
347
369
  """
348
370
  )
349
-
350
- records, _, _ = await driver.execute_query(
371
+ records, header, _ = await driver.execute_query(
351
372
  query,
352
- filter_params,
373
+ params=filter_params,
353
374
  query=fuzzy_query,
354
375
  group_ids=group_ids,
355
376
  limit=limit,
356
377
  database_=DEFAULT_DATABASE,
357
378
  routing_='r',
358
379
  )
380
+ if driver.provider == 'falkordb':
381
+ records = [dict(zip(header, row, strict=True)) for row in records]
382
+
359
383
  nodes = [get_entity_node_from_record(record) for record in records]
360
384
 
361
385
  return nodes
362
386
 
363
387
 
364
388
  async def node_similarity_search(
365
- driver: AsyncDriver,
389
+ driver: GraphDriver,
366
390
  search_vector: list[float],
367
391
  search_filter: SearchFilters,
368
392
  group_ids: list[str] | None = None,
@@ -372,30 +396,36 @@ async def node_similarity_search(
372
396
  # vector similarity search over entity names
373
397
  query_params: dict[str, Any] = {}
374
398
 
375
- group_filter_query: LiteralString = ''
399
+ group_filter_query: LiteralString = 'WHERE n.group_id IS NOT NULL'
376
400
  if group_ids is not None:
377
- group_filter_query += 'WHERE n.group_id IN $group_ids'
401
+ group_filter_query += ' AND n.group_id IN $group_ids'
378
402
  query_params['group_ids'] = group_ids
379
403
 
380
404
  filter_query, filter_params = node_search_filter_query_constructor(search_filter)
381
405
  query_params.update(filter_params)
382
406
 
383
- records, _, _ = await driver.execute_query(
407
+ query = (
384
408
  RUNTIME_QUERY
385
409
  + """
386
- MATCH (n:Entity)
387
- """
410
+ MATCH (n:Entity)
411
+ """
388
412
  + group_filter_query
389
413
  + filter_query
390
414
  + """
391
- WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
392
- WHERE score > $min_score"""
415
+ WITH n, """
416
+ + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
417
+ + """ AS score
418
+ WHERE score > $min_score"""
393
419
  + ENTITY_NODE_RETURN
394
420
  + """
395
421
  ORDER BY score DESC
396
422
  LIMIT $limit
397
- """,
398
- query_params,
423
+ """
424
+ )
425
+
426
+ records, header, _ = await driver.execute_query(
427
+ query,
428
+ params=query_params,
399
429
  search_vector=search_vector,
400
430
  group_ids=group_ids,
401
431
  limit=limit,
@@ -403,13 +433,15 @@ async def node_similarity_search(
403
433
  database_=DEFAULT_DATABASE,
404
434
  routing_='r',
405
435
  )
436
+ if driver.provider == 'falkordb':
437
+ records = [dict(zip(header, row, strict=True)) for row in records]
406
438
  nodes = [get_entity_node_from_record(record) for record in records]
407
439
 
408
440
  return nodes
409
441
 
410
442
 
411
443
  async def node_bfs_search(
412
- driver: AsyncDriver,
444
+ driver: GraphDriver,
413
445
  bfs_origin_node_uuids: list[str] | None,
414
446
  search_filter: SearchFilters,
415
447
  bfs_max_depth: int,
@@ -421,18 +453,21 @@ async def node_bfs_search(
421
453
 
422
454
  filter_query, filter_params = node_search_filter_query_constructor(search_filter)
423
455
 
424
- records, _, _ = await driver.execute_query(
456
+ query = (
425
457
  """
426
- UNWIND $bfs_origin_node_uuids AS origin_uuid
427
- MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
428
- WHERE n.group_id = origin.group_id
429
- """
458
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
459
+ MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
460
+ WHERE n.group_id = origin.group_id
461
+ """
430
462
  + filter_query
431
463
  + ENTITY_NODE_RETURN
432
464
  + """
433
465
  LIMIT $limit
434
- """,
435
- filter_params,
466
+ """
467
+ )
468
+ records, _, _ = await driver.execute_query(
469
+ query,
470
+ params=filter_params,
436
471
  bfs_origin_node_uuids=bfs_origin_node_uuids,
437
472
  depth=bfs_max_depth,
438
473
  limit=limit,
@@ -445,7 +480,7 @@ async def node_bfs_search(
445
480
 
446
481
 
447
482
  async def episode_fulltext_search(
448
- driver: AsyncDriver,
483
+ driver: GraphDriver,
449
484
  query: str,
450
485
  _search_filter: SearchFilters,
451
486
  group_ids: list[str] | None = None,
@@ -456,9 +491,9 @@ async def episode_fulltext_search(
456
491
  if fuzzy_query == '':
457
492
  return []
458
493
 
459
- records, _, _ = await driver.execute_query(
460
- """
461
- CALL db.index.fulltext.queryNodes("episode_content", $query, {limit: $limit})
494
+ query = (
495
+ get_nodes_query(driver.provider, 'episode_content', '$query')
496
+ + """
462
497
  YIELD node AS episode, score
463
498
  MATCH (e:Episodic)
464
499
  WHERE e.uuid = episode.uuid
@@ -474,7 +509,11 @@ async def episode_fulltext_search(
474
509
  e.entity_edges AS entity_edges
475
510
  ORDER BY score DESC
476
511
  LIMIT $limit
477
- """,
512
+ """
513
+ )
514
+
515
+ records, _, _ = await driver.execute_query(
516
+ query,
478
517
  query=fuzzy_query,
479
518
  group_ids=group_ids,
480
519
  limit=limit,
@@ -487,7 +526,7 @@ async def episode_fulltext_search(
487
526
 
488
527
 
489
528
  async def community_fulltext_search(
490
- driver: AsyncDriver,
529
+ driver: GraphDriver,
491
530
  query: str,
492
531
  group_ids: list[str] | None = None,
493
532
  limit=RELEVANT_SCHEMA_LIMIT,
@@ -497,9 +536,9 @@ async def community_fulltext_search(
497
536
  if fuzzy_query == '':
498
537
  return []
499
538
 
500
- records, _, _ = await driver.execute_query(
501
- """
502
- CALL db.index.fulltext.queryNodes("community_name", $query, {limit: $limit})
539
+ query = (
540
+ get_nodes_query(driver.provider, 'community_name', '$query')
541
+ + """
503
542
  YIELD node AS comm, score
504
543
  RETURN
505
544
  comm.uuid AS uuid,
@@ -509,7 +548,11 @@ async def community_fulltext_search(
509
548
  comm.summary AS summary
510
549
  ORDER BY score DESC
511
550
  LIMIT $limit
512
- """,
551
+ """
552
+ )
553
+
554
+ records, _, _ = await driver.execute_query(
555
+ query,
513
556
  query=fuzzy_query,
514
557
  group_ids=group_ids,
515
558
  limit=limit,
@@ -522,7 +565,7 @@ async def community_fulltext_search(
522
565
 
523
566
 
524
567
  async def community_similarity_search(
525
- driver: AsyncDriver,
568
+ driver: GraphDriver,
526
569
  search_vector: list[float],
527
570
  group_ids: list[str] | None = None,
528
571
  limit=RELEVANT_SCHEMA_LIMIT,
@@ -536,14 +579,16 @@ async def community_similarity_search(
536
579
  group_filter_query += 'WHERE comm.group_id IN $group_ids'
537
580
  query_params['group_ids'] = group_ids
538
581
 
539
- records, _, _ = await driver.execute_query(
582
+ query = (
540
583
  RUNTIME_QUERY
541
584
  + """
542
585
  MATCH (comm:Community)
543
586
  """
544
587
  + group_filter_query
545
588
  + """
546
- WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
589
+ WITH comm, """
590
+ + get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
591
+ + """ AS score
547
592
  WHERE score > $min_score
548
593
  RETURN
549
594
  comm.uuid As uuid,
@@ -553,7 +598,11 @@ async def community_similarity_search(
553
598
  comm.summary AS summary
554
599
  ORDER BY score DESC
555
600
  LIMIT $limit
556
- """,
601
+ """
602
+ )
603
+
604
+ records, _, _ = await driver.execute_query(
605
+ query,
557
606
  search_vector=search_vector,
558
607
  group_ids=group_ids,
559
608
  limit=limit,
@@ -569,7 +618,7 @@ async def community_similarity_search(
569
618
  async def hybrid_node_search(
570
619
  queries: list[str],
571
620
  embeddings: list[list[float]],
572
- driver: AsyncDriver,
621
+ driver: GraphDriver,
573
622
  search_filter: SearchFilters,
574
623
  group_ids: list[str] | None = None,
575
624
  limit: int = RELEVANT_SCHEMA_LIMIT,
@@ -586,7 +635,7 @@ async def hybrid_node_search(
586
635
  A list of text queries to search for.
587
636
  embeddings : list[list[float]]
588
637
  A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
589
- driver : AsyncDriver
638
+ driver : GraphDriver
590
639
  The Neo4j driver instance for database operations.
591
640
  group_ids : list[str] | None, optional
592
641
  The list of group ids to retrieve nodes from.
@@ -641,7 +690,7 @@ async def hybrid_node_search(
641
690
 
642
691
 
643
692
  async def get_relevant_nodes(
644
- driver: AsyncDriver,
693
+ driver: GraphDriver,
645
694
  nodes: list[EntityNode],
646
695
  search_filter: SearchFilters,
647
696
  min_score: float = DEFAULT_MIN_SCORE,
@@ -660,29 +709,33 @@ async def get_relevant_nodes(
660
709
 
661
710
  query = (
662
711
  RUNTIME_QUERY
663
- + """UNWIND $nodes AS node
664
- MATCH (n:Entity {group_id: $group_id})
665
- """
712
+ + """
713
+ UNWIND $nodes AS node
714
+ MATCH (n:Entity {group_id: $group_id})
715
+ """
666
716
  + filter_query
667
717
  + """
668
- WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
718
+ WITH node, n, """
719
+ + get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
720
+ + """ AS score
669
721
  WHERE score > $min_score
670
722
  WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
671
-
672
- CALL db.index.fulltext.queryNodes("node_name_and_summary", node.fulltext_query, {limit: $limit})
723
+ """
724
+ + get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
725
+ + """
673
726
  YIELD node AS m
674
727
  WHERE m.group_id = $group_id
675
728
  WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
676
-
729
+
677
730
  WITH node,
678
731
  top_vector_nodes,
679
732
  [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
680
-
733
+
681
734
  WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
682
-
735
+
683
736
  UNWIND combined_nodes AS combined_node
684
737
  WITH node, collect(DISTINCT combined_node) AS deduped_nodes
685
-
738
+
686
739
  RETURN
687
740
  node.uuid AS search_node_uuid,
688
741
  [x IN deduped_nodes | {
@@ -710,7 +763,7 @@ async def get_relevant_nodes(
710
763
 
711
764
  results, _, _ = await driver.execute_query(
712
765
  query,
713
- query_params,
766
+ params=query_params,
714
767
  nodes=query_nodes,
715
768
  group_id=group_id,
716
769
  limit=limit,
@@ -732,7 +785,7 @@ async def get_relevant_nodes(
732
785
 
733
786
 
734
787
  async def get_relevant_edges(
735
- driver: AsyncDriver,
788
+ driver: GraphDriver,
736
789
  edges: list[EntityEdge],
737
790
  search_filter: SearchFilters,
738
791
  min_score: float = DEFAULT_MIN_SCORE,
@@ -748,42 +801,47 @@ async def get_relevant_edges(
748
801
 
749
802
  query = (
750
803
  RUNTIME_QUERY
751
- + """UNWIND $edges AS edge
752
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
753
- """
804
+ + """
805
+ UNWIND $edges AS edge
806
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
807
+ """
754
808
  + filter_query
755
809
  + """
756
- WITH e, edge, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
757
- WHERE score > $min_score
758
- WITH edge, e, score
759
- ORDER BY score DESC
760
- RETURN edge.uuid AS search_edge_uuid,
761
- collect({
762
- uuid: e.uuid,
763
- source_node_uuid: startNode(e).uuid,
764
- target_node_uuid: endNode(e).uuid,
765
- created_at: e.created_at,
766
- name: e.name,
767
- group_id: e.group_id,
768
- fact: e.fact,
769
- fact_embedding: e.fact_embedding,
770
- episodes: e.episodes,
771
- expired_at: e.expired_at,
772
- valid_at: e.valid_at,
773
- invalid_at: e.invalid_at
774
- })[..$limit] AS matches
810
+ WITH e, edge, """
811
+ + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
812
+ + """ AS score
813
+ WHERE score > $min_score
814
+ WITH edge, e, score
815
+ ORDER BY score DESC
816
+ RETURN edge.uuid AS search_edge_uuid,
817
+ collect({
818
+ uuid: e.uuid,
819
+ source_node_uuid: startNode(e).uuid,
820
+ target_node_uuid: endNode(e).uuid,
821
+ created_at: e.created_at,
822
+ name: e.name,
823
+ group_id: e.group_id,
824
+ fact: e.fact,
825
+ fact_embedding: e.fact_embedding,
826
+ episodes: e.episodes,
827
+ expired_at: e.expired_at,
828
+ valid_at: e.valid_at,
829
+ invalid_at: e.invalid_at,
830
+ attributes: properties(e)
831
+ })[..$limit] AS matches
775
832
  """
776
833
  )
777
834
 
778
835
  results, _, _ = await driver.execute_query(
779
836
  query,
780
- query_params,
837
+ params=query_params,
781
838
  edges=[edge.model_dump() for edge in edges],
782
839
  limit=limit,
783
840
  min_score=min_score,
784
841
  database_=DEFAULT_DATABASE,
785
842
  routing_='r',
786
843
  )
844
+
787
845
  relevant_edges_dict: dict[str, list[EntityEdge]] = {
788
846
  result['search_edge_uuid']: [
789
847
  get_entity_edge_from_record(record) for record in result['matches']
@@ -797,7 +855,7 @@ async def get_relevant_edges(
797
855
 
798
856
 
799
857
  async def get_edge_invalidation_candidates(
800
- driver: AsyncDriver,
858
+ driver: GraphDriver,
801
859
  edges: list[EntityEdge],
802
860
  search_filter: SearchFilters,
803
861
  min_score: float = DEFAULT_MIN_SCORE,
@@ -813,37 +871,41 @@ async def get_edge_invalidation_candidates(
813
871
 
814
872
  query = (
815
873
  RUNTIME_QUERY
816
- + """UNWIND $edges AS edge
817
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
818
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
819
- """
874
+ + """
875
+ UNWIND $edges AS edge
876
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
877
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
878
+ """
820
879
  + filter_query
821
880
  + """
822
- WITH edge, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
823
- WHERE score > $min_score
824
- WITH edge, e, score
825
- ORDER BY score DESC
826
- RETURN edge.uuid AS search_edge_uuid,
827
- collect({
828
- uuid: e.uuid,
829
- source_node_uuid: startNode(e).uuid,
830
- target_node_uuid: endNode(e).uuid,
831
- created_at: e.created_at,
832
- name: e.name,
833
- group_id: e.group_id,
834
- fact: e.fact,
835
- fact_embedding: e.fact_embedding,
836
- episodes: e.episodes,
837
- expired_at: e.expired_at,
838
- valid_at: e.valid_at,
839
- invalid_at: e.invalid_at
840
- })[..$limit] AS matches
881
+ WITH edge, e, """
882
+ + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
883
+ + """ AS score
884
+ WHERE score > $min_score
885
+ WITH edge, e, score
886
+ ORDER BY score DESC
887
+ RETURN edge.uuid AS search_edge_uuid,
888
+ collect({
889
+ uuid: e.uuid,
890
+ source_node_uuid: startNode(e).uuid,
891
+ target_node_uuid: endNode(e).uuid,
892
+ created_at: e.created_at,
893
+ name: e.name,
894
+ group_id: e.group_id,
895
+ fact: e.fact,
896
+ fact_embedding: e.fact_embedding,
897
+ episodes: e.episodes,
898
+ expired_at: e.expired_at,
899
+ valid_at: e.valid_at,
900
+ invalid_at: e.invalid_at,
901
+ attributes: properties(e)
902
+ })[..$limit] AS matches
841
903
  """
842
904
  )
843
905
 
844
906
  results, _, _ = await driver.execute_query(
845
907
  query,
846
- query_params,
908
+ params=query_params,
847
909
  edges=[edge.model_dump() for edge in edges],
848
910
  limit=limit,
849
911
  min_score=min_score,
@@ -878,7 +940,7 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
878
940
 
879
941
 
880
942
  async def node_distance_reranker(
881
- driver: AsyncDriver,
943
+ driver: GraphDriver,
882
944
  node_uuids: list[str],
883
945
  center_node_uuid: str,
884
946
  min_score: float = 0,
@@ -888,20 +950,22 @@ async def node_distance_reranker(
888
950
  scores: dict[str, float] = {center_node_uuid: 0.0}
889
951
 
890
952
  # Find the shortest path to center node
891
- query = Query("""
953
+ query = """
892
954
  UNWIND $node_uuids AS node_uuid
893
- MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid})
894
- RETURN length(p) AS score, node_uuid AS uuid
895
- """)
896
-
897
- path_results, _, _ = await driver.execute_query(
955
+ MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
956
+ RETURN 1 AS score, node_uuid AS uuid
957
+ """
958
+ results, header, _ = await driver.execute_query(
898
959
  query,
899
960
  node_uuids=filtered_uuids,
900
961
  center_uuid=center_node_uuid,
901
962
  database_=DEFAULT_DATABASE,
963
+ routing_='r',
902
964
  )
965
+ if driver.provider == 'falkordb':
966
+ results = [dict(zip(header, row, strict=True)) for row in results]
903
967
 
904
- for result in path_results:
968
+ for result in results:
905
969
  uuid = result['uuid']
906
970
  score = result['score']
907
971
  scores[uuid] = score
@@ -922,23 +986,23 @@ async def node_distance_reranker(
922
986
 
923
987
 
924
988
  async def episode_mentions_reranker(
925
- driver: AsyncDriver, node_uuids: list[list[str]], min_score: float = 0
989
+ driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
926
990
  ) -> list[str]:
927
991
  # use rrf as a preliminary ranker
928
992
  sorted_uuids = rrf(node_uuids)
929
993
  scores: dict[str, float] = {}
930
994
 
931
995
  # Find the shortest path to center node
932
- query = Query("""
996
+ query = """
933
997
  UNWIND $node_uuids AS node_uuid
934
998
  MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
935
999
  RETURN count(*) AS score, n.uuid AS uuid
936
- """)
937
-
1000
+ """
938
1001
  results, _, _ = await driver.execute_query(
939
1002
  query,
940
1003
  node_uuids=sorted_uuids,
941
1004
  database_=DEFAULT_DATABASE,
1005
+ routing_='r',
942
1006
  )
943
1007
 
944
1008
  for result in results:
@@ -952,15 +1016,116 @@ async def episode_mentions_reranker(
952
1016
 
953
1017
  def maximal_marginal_relevance(
954
1018
  query_vector: list[float],
955
- candidates: list[tuple[str, list[float]]],
1019
+ candidates: dict[str, list[float]],
956
1020
  mmr_lambda: float = DEFAULT_MMR_LAMBDA,
957
- ):
958
- candidates_with_mmr: list[tuple[str, float]] = []
959
- for candidate in candidates:
960
- max_sim = max([np.dot(normalize_l2(candidate[1]), normalize_l2(c[1])) for c in candidates])
961
- mmr = mmr_lambda * np.dot(candidate[1], query_vector) - (1 - mmr_lambda) * max_sim
962
- candidates_with_mmr.append((candidate[0], mmr))
1021
+ min_score: float = -2.0,
1022
+ ) -> list[str]:
1023
+ start = time()
1024
+ query_array = np.array(query_vector)
1025
+ candidate_arrays: dict[str, NDArray] = {}
1026
+ for uuid, embedding in candidates.items():
1027
+ candidate_arrays[uuid] = normalize_l2(embedding)
1028
+
1029
+ uuids: list[str] = list(candidate_arrays.keys())
1030
+
1031
+ similarity_matrix = np.zeros((len(uuids), len(uuids)))
1032
+
1033
+ for i, uuid_1 in enumerate(uuids):
1034
+ for j, uuid_2 in enumerate(uuids[:i]):
1035
+ u = candidate_arrays[uuid_1]
1036
+ v = candidate_arrays[uuid_2]
1037
+ similarity = np.dot(u, v)
1038
+
1039
+ similarity_matrix[i, j] = similarity
1040
+ similarity_matrix[j, i] = similarity
1041
+
1042
+ mmr_scores: dict[str, float] = {}
1043
+ for i, uuid in enumerate(uuids):
1044
+ max_sim = np.max(similarity_matrix[i, :])
1045
+ mmr = mmr_lambda * np.dot(query_array, candidate_arrays[uuid]) + (mmr_lambda - 1) * max_sim
1046
+ mmr_scores[uuid] = mmr
1047
+
1048
+ uuids.sort(reverse=True, key=lambda c: mmr_scores[c])
1049
+
1050
+ end = time()
1051
+ logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
1052
+
1053
+ return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
1054
+
1055
+
1056
+ async def get_embeddings_for_nodes(
1057
+ driver: GraphDriver, nodes: list[EntityNode]
1058
+ ) -> dict[str, list[float]]:
1059
+ query: LiteralString = """MATCH (n:Entity)
1060
+ WHERE n.uuid IN $node_uuids
1061
+ RETURN DISTINCT
1062
+ n.uuid AS uuid,
1063
+ n.name_embedding AS name_embedding
1064
+ """
1065
+
1066
+ results, _, _ = await driver.execute_query(
1067
+ query, node_uuids=[node.uuid for node in nodes], database_=DEFAULT_DATABASE, routing_='r'
1068
+ )
1069
+
1070
+ embeddings_dict: dict[str, list[float]] = {}
1071
+ for result in results:
1072
+ uuid: str = result.get('uuid')
1073
+ embedding: list[float] = result.get('name_embedding')
1074
+ if uuid is not None and embedding is not None:
1075
+ embeddings_dict[uuid] = embedding
1076
+
1077
+ return embeddings_dict
1078
+
963
1079
 
964
- candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
1080
+ async def get_embeddings_for_communities(
1081
+ driver: GraphDriver, communities: list[CommunityNode]
1082
+ ) -> dict[str, list[float]]:
1083
+ query: LiteralString = """MATCH (c:Community)
1084
+ WHERE c.uuid IN $community_uuids
1085
+ RETURN DISTINCT
1086
+ c.uuid AS uuid,
1087
+ c.name_embedding AS name_embedding
1088
+ """
1089
+
1090
+ results, _, _ = await driver.execute_query(
1091
+ query,
1092
+ community_uuids=[community.uuid for community in communities],
1093
+ database_=DEFAULT_DATABASE,
1094
+ routing_='r',
1095
+ )
1096
+
1097
+ embeddings_dict: dict[str, list[float]] = {}
1098
+ for result in results:
1099
+ uuid: str = result.get('uuid')
1100
+ embedding: list[float] = result.get('name_embedding')
1101
+ if uuid is not None and embedding is not None:
1102
+ embeddings_dict[uuid] = embedding
1103
+
1104
+ return embeddings_dict
1105
+
1106
+
1107
+ async def get_embeddings_for_edges(
1108
+ driver: GraphDriver, edges: list[EntityEdge]
1109
+ ) -> dict[str, list[float]]:
1110
+ query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1111
+ WHERE e.uuid IN $edge_uuids
1112
+ RETURN DISTINCT
1113
+ e.uuid AS uuid,
1114
+ e.fact_embedding AS fact_embedding
1115
+ """
1116
+
1117
+ results, _, _ = await driver.execute_query(
1118
+ query,
1119
+ edge_uuids=[edge.uuid for edge in edges],
1120
+ database_=DEFAULT_DATABASE,
1121
+ routing_='r',
1122
+ )
1123
+
1124
+ embeddings_dict: dict[str, list[float]] = {}
1125
+ for result in results:
1126
+ uuid: str = result.get('uuid')
1127
+ embedding: list[float] = result.get('fact_embedding')
1128
+ if uuid is not None and embedding is not None:
1129
+ embeddings_dict[uuid] = embedding
965
1130
 
966
- return list(set([candidate[0] for candidate in candidates_with_mmr]))
1131
+ return embeddings_dict