graphiti-core 0.18.9__py3-none-any.whl → 0.19.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

@@ -15,6 +15,7 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import logging
18
+ import os
18
19
  from collections import defaultdict
19
20
  from time import time
20
21
  from typing import Any
@@ -54,6 +55,7 @@ from graphiti_core.search.search_filters import (
54
55
  )
55
56
 
56
57
  logger = logging.getLogger(__name__)
58
+ USE_HNSW = os.getenv('USE_HNSW', '').lower() in ('true', '1', 'yes')
57
59
 
58
60
  RELEVANT_SCHEMA_LIMIT = 10
59
61
  DEFAULT_MIN_SCORE = 0.6
@@ -62,6 +64,20 @@ MAX_SEARCH_DEPTH = 3
62
64
  MAX_QUERY_LENGTH = 128
63
65
 
64
66
 
67
def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> float:
    """
    Calculate the cosine similarity between two vectors using NumPy.

    Args:
        vector1: First vector.
        vector2: Second vector; must have the same length as ``vector1``.

    Returns:
        The cosine of the angle between the two vectors as a built-in
        ``float``, or ``0.0`` when either vector has zero magnitude (the
        similarity is undefined there, so it is treated as "no similarity").

    Raises:
        ValueError: Propagated from ``np.dot`` when the vectors differ in length.
    """
    # Convert each input once so the dot product and the norms share the same
    # arrays instead of re-converting the Python lists on every NumPy call.
    array1 = np.asarray(vector1, dtype=float)
    array2 = np.asarray(vector2, dtype=float)

    norm_vector1 = np.linalg.norm(array1)
    norm_vector2 = np.linalg.norm(array2)

    # Guard against division by zero for zero (or empty) vectors.
    if norm_vector1 == 0 or norm_vector2 == 0:
        return 0.0

    # Return a plain Python float so callers always get the annotated type,
    # never an int or a NumPy scalar.
    return float(np.dot(array1, array2) / (norm_vector1 * norm_vector2))
79
+
80
+
65
81
  def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_syntax: str = ''):
66
82
  group_ids_filter_list = (
67
83
  [fulltext_syntax + f'group_id:"{g}"' for g in group_ids] if group_ids is not None else []
@@ -153,32 +169,80 @@ async def edge_fulltext_search(
153
169
 
154
170
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
155
171
 
156
- query = (
157
- get_relationships_query('edge_name_and_fact', provider=driver.provider)
158
- + """
159
- YIELD relationship AS rel, score
160
- MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
161
- WHERE e.group_id IN $group_ids """
162
- + filter_query
163
- + """
164
- WITH e, score, n, m
165
- RETURN
166
- """
167
- + ENTITY_EDGE_RETURN
168
- + """
169
- ORDER BY score DESC
170
- LIMIT $limit
171
- """
172
- )
172
+ if driver.provider == GraphProvider.NEPTUNE:
173
+ res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
174
+ if res['hits']['total']['value'] > 0:
175
+ # Calculate Cosine similarity then return the edge ids
176
+ input_ids = []
177
+ for r in res['hits']['hits']:
178
+ input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
179
+
180
+ # Match the edge ids and return the values
181
+ query = (
182
+ """
183
+ UNWIND $ids as id
184
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
185
+ WHERE e.group_id IN $group_ids
186
+ AND id(e)=id
187
+ """
188
+ + filter_query
189
+ + """
190
+ WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
191
+ RETURN
192
+ e.uuid AS uuid,
193
+ e.group_id AS group_id,
194
+ n.uuid AS source_node_uuid,
195
+ m.uuid AS target_node_uuid,
196
+ e.created_at AS created_at,
197
+ e.name AS name,
198
+ e.fact AS fact,
199
+ split(e.episodes, ",") AS episodes,
200
+ e.expired_at AS expired_at,
201
+ e.valid_at AS valid_at,
202
+ e.invalid_at AS invalid_at,
203
+ properties(e) AS attributes
204
+ ORDER BY score DESC LIMIT $limit
205
+ """
206
+ )
207
+
208
+ records, _, _ = await driver.execute_query(
209
+ query,
210
+ query=fuzzy_query,
211
+ group_ids=group_ids,
212
+ ids=input_ids,
213
+ limit=limit,
214
+ routing_='r',
215
+ **filter_params,
216
+ )
217
+ else:
218
+ return []
219
+ else:
220
+ query = (
221
+ get_relationships_query('edge_name_and_fact', provider=driver.provider)
222
+ + """
223
+ YIELD relationship AS rel, score
224
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
225
+ WHERE e.group_id IN $group_ids """
226
+ + filter_query
227
+ + """
228
+ WITH e, score, n, m
229
+ RETURN
230
+ """
231
+ + ENTITY_EDGE_RETURN
232
+ + """
233
+ ORDER BY score DESC
234
+ LIMIT $limit
235
+ """
236
+ )
173
237
 
174
- records, _, _ = await driver.execute_query(
175
- query,
176
- query=fuzzy_query,
177
- group_ids=group_ids,
178
- limit=limit,
179
- routing_='r',
180
- **filter_params,
181
- )
238
+ records, _, _ = await driver.execute_query(
239
+ query,
240
+ query=fuzzy_query,
241
+ group_ids=group_ids,
242
+ limit=limit,
243
+ routing_='r',
244
+ **filter_params,
245
+ )
182
246
 
183
247
  edges = [get_entity_edge_from_record(record) for record in records]
184
248
 
@@ -214,35 +278,100 @@ async def edge_similarity_search(
214
278
  query_params['target_uuid'] = target_node_uuid
215
279
  group_filter_query += '\nAND (m.uuid = $target_uuid)'
216
280
 
217
- query = (
218
- RUNTIME_QUERY
219
- + """
220
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
221
- """
222
- + group_filter_query
223
- + filter_query
224
- + """
225
- WITH DISTINCT e, n, m, """
226
- + get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider)
227
- + """ AS score
228
- WHERE score > $min_score
229
- RETURN
230
- """
231
- + ENTITY_EDGE_RETURN
232
- + """
233
- ORDER BY score DESC
234
- LIMIT $limit
235
- """
236
- )
281
+ if driver.provider == GraphProvider.NEPTUNE:
282
+ query = (
283
+ RUNTIME_QUERY
284
+ + """
285
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
286
+ """
287
+ + group_filter_query
288
+ + filter_query
289
+ + """
290
+ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
291
+ """
292
+ )
293
+ resp, header, _ = await driver.execute_query(
294
+ query,
295
+ search_vector=search_vector,
296
+ limit=limit,
297
+ min_score=min_score,
298
+ routing_='r',
299
+ **query_params,
300
+ )
237
301
 
238
- records, _, _ = await driver.execute_query(
239
- query,
240
- search_vector=search_vector,
241
- limit=limit,
242
- min_score=min_score,
243
- routing_='r',
244
- **query_params,
245
- )
302
+ if len(resp) > 0:
303
+ # Calculate Cosine similarity then return the edge ids
304
+ input_ids = []
305
+ for r in resp:
306
+ if r['embedding']:
307
+ score = calculate_cosine_similarity(
308
+ search_vector, list(map(float, r['embedding'].split(',')))
309
+ )
310
+ if score > min_score:
311
+ input_ids.append({'id': r['id'], 'score': score})
312
+
313
+ # Match the edge ides and return the values
314
+ query = """
315
+ UNWIND $ids as i
316
+ MATCH ()-[r]->()
317
+ WHERE id(r) = i.id
318
+ RETURN
319
+ r.uuid AS uuid,
320
+ r.group_id AS group_id,
321
+ startNode(r).uuid AS source_node_uuid,
322
+ endNode(r).uuid AS target_node_uuid,
323
+ r.created_at AS created_at,
324
+ r.name AS name,
325
+ r.fact AS fact,
326
+ split(r.episodes, ",") AS episodes,
327
+ r.expired_at AS expired_at,
328
+ r.valid_at AS valid_at,
329
+ r.invalid_at AS invalid_at,
330
+ properties(r) AS attributes
331
+ ORDER BY i.score DESC
332
+ LIMIT $limit
333
+ """
334
+ records, _, _ = await driver.execute_query(
335
+ query,
336
+ ids=input_ids,
337
+ search_vector=search_vector,
338
+ limit=limit,
339
+ min_score=min_score,
340
+ routing_='r',
341
+ **query_params,
342
+ )
343
+ else:
344
+ return []
345
+ else:
346
+ query = (
347
+ RUNTIME_QUERY
348
+ + """
349
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
350
+ """
351
+ + group_filter_query
352
+ + filter_query
353
+ + """
354
+ WITH DISTINCT e, n, m, """
355
+ + get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider)
356
+ + """ AS score
357
+ WHERE score > $min_score
358
+ RETURN
359
+ """
360
+ + ENTITY_EDGE_RETURN
361
+ + """
362
+ ORDER BY score DESC
363
+ LIMIT $limit
364
+ """
365
+ )
366
+
367
+ records, _, _ = await driver.execute_query(
368
+ query,
369
+ search_vector=search_vector,
370
+ limit=limit,
371
+ min_score=min_score,
372
+ routing_='r',
373
+ **query_params,
374
+ )
246
375
 
247
376
  edges = [get_entity_edge_from_record(record) for record in records]
248
377
 
@@ -263,28 +392,58 @@ async def edge_bfs_search(
263
392
 
264
393
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
265
394
 
266
- query = (
267
- f"""
268
- UNWIND $bfs_origin_node_uuids AS origin_uuid
269
- MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
270
- UNWIND relationships(path) AS rel
271
- MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
272
- WHERE e.uuid = rel.uuid
273
- AND e.group_id IN $group_ids
274
- """
275
- + filter_query
276
- + """
277
- RETURN DISTINCT
278
- """
279
- + ENTITY_EDGE_RETURN
280
- + """
281
- LIMIT $limit
282
- """
283
- )
395
+ if driver.provider == GraphProvider.NEPTUNE:
396
+ query = (
397
+ f"""
398
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
399
+ MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
400
+ WHERE origin:Entity OR origin:Episodic
401
+ UNWIND relationships(path) AS rel
402
+ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
403
+ WHERE e.uuid = rel.uuid
404
+ """
405
+ + filter_query
406
+ + """
407
+ RETURN DISTINCT
408
+ e.uuid AS uuid,
409
+ e.group_id AS group_id,
410
+ startNode(e).uuid AS source_node_uuid,
411
+ endNode(e).uuid AS target_node_uuid,
412
+ e.created_at AS created_at,
413
+ e.name AS name,
414
+ e.fact AS fact,
415
+ split(e.episodes, ',') AS episodes,
416
+ e.expired_at AS expired_at,
417
+ e.valid_at AS valid_at,
418
+ e.invalid_at AS invalid_at,
419
+ properties(e) AS attributes
420
+ LIMIT $limit
421
+ """
422
+ )
423
+ else:
424
+ query = (
425
+ f"""
426
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
427
+ MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
428
+ UNWIND relationships(path) AS rel
429
+ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
430
+ WHERE e.uuid = rel.uuid
431
+ AND e.group_id IN $group_ids
432
+ """
433
+ + filter_query
434
+ + """
435
+ RETURN DISTINCT
436
+ """
437
+ + ENTITY_EDGE_RETURN
438
+ + """
439
+ LIMIT $limit
440
+ """
441
+ )
284
442
 
285
443
  records, _, _ = await driver.execute_query(
286
444
  query,
287
445
  bfs_origin_node_uuids=bfs_origin_node_uuids,
446
+ depth=bfs_max_depth,
288
447
  group_ids=group_ids,
289
448
  limit=limit,
290
449
  routing_='r',
@@ -309,30 +468,70 @@ async def node_fulltext_search(
309
468
  return []
310
469
  filter_query, filter_params = node_search_filter_query_constructor(search_filter)
311
470
 
312
- query = (
313
- get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
314
- + """
315
- YIELD node AS n, score
316
- WHERE n:Entity AND n.group_id IN $group_ids
317
- """
318
- + filter_query
319
- + """
320
- WITH n, score
321
- ORDER BY score DESC
322
- LIMIT $limit
323
- RETURN
324
- """
325
- + ENTITY_NODE_RETURN
326
- )
471
+ if driver.provider == GraphProvider.NEPTUNE:
472
+ res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
473
+ if res['hits']['total']['value'] > 0:
474
+ # Calculate Cosine similarity then return the edge ids
475
+ input_ids = []
476
+ for r in res['hits']['hits']:
477
+ input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
478
+
479
+ # Match the edge ides and return the values
480
+ query = (
481
+ """
482
+ UNWIND $ids as i
483
+ MATCH (n:Entity)
484
+ WHERE n.uuid=i.id
485
+ RETURN
486
+ """
487
+ + ENTITY_NODE_RETURN
488
+ + """
489
+ ORDER BY i.score DESC
490
+ LIMIT $limit
491
+ """
492
+ )
493
+ records, _, _ = await driver.execute_query(
494
+ query,
495
+ ids=input_ids,
496
+ query=fuzzy_query,
497
+ group_ids=group_ids,
498
+ limit=limit,
499
+ routing_='r',
500
+ **filter_params,
501
+ )
502
+ else:
503
+ return []
504
+ else:
505
+ index_name = (
506
+ 'node_name_and_summary'
507
+ if not USE_HNSW
508
+ else 'node_name_and_summary_'
509
+ + (group_ids[0].replace('-', '') if group_ids is not None else '')
510
+ )
511
+ query = (
512
+ get_nodes_query(driver.provider, index_name, '$query')
513
+ + """
514
+ YIELD node AS n, score
515
+ WHERE n:Entity AND n.group_id IN $group_ids
516
+ """
517
+ + filter_query
518
+ + """
519
+ WITH n, score
520
+ ORDER BY score DESC
521
+ LIMIT $limit
522
+ RETURN
523
+ """
524
+ + ENTITY_NODE_RETURN
525
+ )
327
526
 
328
- records, _, _ = await driver.execute_query(
329
- query,
330
- query=fuzzy_query,
331
- group_ids=group_ids,
332
- limit=limit,
333
- routing_='r',
334
- **filter_params,
335
- )
527
+ records, _, _ = await driver.execute_query(
528
+ query,
529
+ query=fuzzy_query,
530
+ group_ids=group_ids,
531
+ limit=limit,
532
+ routing_='r',
533
+ **filter_params,
534
+ )
336
535
 
337
536
  nodes = [get_entity_node_from_record(record) for record in records]
338
537
 
@@ -358,35 +557,124 @@ async def node_similarity_search(
358
557
  filter_query, filter_params = node_search_filter_query_constructor(search_filter)
359
558
  query_params.update(filter_params)
360
559
 
361
- query = (
362
- RUNTIME_QUERY
363
- + """
364
- MATCH (n:Entity)
365
- """
366
- + group_filter_query
367
- + filter_query
368
- + """
369
- WITH n, """
370
- + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
371
- + """ AS score
372
- WHERE score > $min_score
373
- RETURN
374
- """
375
- + ENTITY_NODE_RETURN
376
- + """
377
- ORDER BY score DESC
378
- LIMIT $limit
379
- """
380
- )
560
+ if driver.provider == GraphProvider.NEPTUNE:
561
+ query = (
562
+ RUNTIME_QUERY
563
+ + """
564
+ MATCH (n:Entity)
565
+ """
566
+ + group_filter_query
567
+ + filter_query
568
+ + """
569
+ RETURN DISTINCT id(n) as id, n.name_embedding as embedding
570
+ """
571
+ )
572
+ resp, header, _ = await driver.execute_query(
573
+ query,
574
+ params=query_params,
575
+ search_vector=search_vector,
576
+ group_ids=group_ids,
577
+ limit=limit,
578
+ min_score=min_score,
579
+ routing_='r',
580
+ )
381
581
 
382
- records, _, _ = await driver.execute_query(
383
- query,
384
- search_vector=search_vector,
385
- limit=limit,
386
- min_score=min_score,
387
- routing_='r',
388
- **query_params,
389
- )
582
+ if len(resp) > 0:
583
+ # Calculate Cosine similarity then return the edge ids
584
+ input_ids = []
585
+ for r in resp:
586
+ if r['embedding']:
587
+ score = calculate_cosine_similarity(
588
+ search_vector, list(map(float, r['embedding'].split(',')))
589
+ )
590
+ if score > min_score:
591
+ input_ids.append({'id': r['id'], 'score': score})
592
+
593
+ # Match the edge ides and return the values
594
+ query = (
595
+ """
596
+ UNWIND $ids as i
597
+ MATCH (n:Entity)
598
+ WHERE id(n)=i.id
599
+ RETURN
600
+ """
601
+ + ENTITY_NODE_RETURN
602
+ + """
603
+ ORDER BY i.score DESC
604
+ LIMIT $limit
605
+ """
606
+ )
607
+ records, header, _ = await driver.execute_query(
608
+ query,
609
+ ids=input_ids,
610
+ search_vector=search_vector,
611
+ limit=limit,
612
+ min_score=min_score,
613
+ routing_='r',
614
+ **query_params,
615
+ )
616
+ else:
617
+ return []
618
+ elif driver.provider == GraphProvider.NEO4J and USE_HNSW:
619
+ index_name = 'group_entity_vector_' + (
620
+ group_ids[0].replace('-', '') if group_ids is not None else ''
621
+ )
622
+ query = (
623
+ f"""
624
+ CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score
625
+ """
626
+ + group_filter_query
627
+ + filter_query
628
+ + """
629
+ AND score > $min_score
630
+ RETURN
631
+ """
632
+ + ENTITY_NODE_RETURN
633
+ + """
634
+ ORDER BY score DESC
635
+ LIMIT $limit
636
+ """
637
+ )
638
+
639
+ records, _, _ = await driver.execute_query(
640
+ query,
641
+ search_vector=search_vector,
642
+ limit=limit,
643
+ min_score=min_score,
644
+ routing_='r',
645
+ **query_params,
646
+ )
647
+
648
+ else:
649
+ query = (
650
+ RUNTIME_QUERY
651
+ + """
652
+ MATCH (n:Entity)
653
+ """
654
+ + group_filter_query
655
+ + filter_query
656
+ + """
657
+ WITH n, """
658
+ + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
659
+ + """ AS score
660
+ WHERE score > $min_score
661
+ RETURN
662
+ """
663
+ + ENTITY_NODE_RETURN
664
+ + """
665
+ ORDER BY score DESC
666
+ LIMIT $limit
667
+ """
668
+ )
669
+
670
+ records, _, _ = await driver.execute_query(
671
+ query,
672
+ search_vector=search_vector,
673
+ limit=limit,
674
+ min_score=min_score,
675
+ routing_='r',
676
+ **query_params,
677
+ )
390
678
 
391
679
  nodes = [get_entity_node_from_record(record) for record in records]
392
680
 
@@ -407,22 +695,40 @@ async def node_bfs_search(
407
695
 
408
696
  filter_query, filter_params = node_search_filter_query_constructor(search_filter)
409
697
 
410
- query = (
411
- f"""
412
- UNWIND $bfs_origin_node_uuids AS origin_uuid
413
- MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
414
- WHERE n.group_id = origin.group_id
415
- AND origin.group_id IN $group_ids
416
- """
417
- + filter_query
418
- + """
419
- RETURN
420
- """
421
- + ENTITY_NODE_RETURN
422
- + """
423
- LIMIT $limit
424
- """
425
- )
698
+ if driver.provider == GraphProvider.NEPTUNE:
699
+ query = (
700
+ f"""
701
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
702
+ MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
703
+ WHERE origin:Entity OR origin.Episode
704
+ AND n.group_id = origin.group_id
705
+ """
706
+ + filter_query
707
+ + """
708
+ RETURN
709
+ """
710
+ + ENTITY_NODE_RETURN
711
+ + """
712
+ LIMIT $limit
713
+ """
714
+ )
715
+ else:
716
+ query = (
717
+ f"""
718
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
719
+ MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
720
+ WHERE n.group_id = origin.group_id
721
+ AND origin.group_id IN $group_ids
722
+ """
723
+ + filter_query
724
+ + """
725
+ RETURN
726
+ """
727
+ + ENTITY_NODE_RETURN
728
+ + """
729
+ LIMIT $limit
730
+ """
731
+ )
426
732
 
427
733
  records, _, _ = await driver.execute_query(
428
734
  query,
@@ -449,29 +755,72 @@ async def episode_fulltext_search(
449
755
  if fuzzy_query == '':
450
756
  return []
451
757
 
452
- query = (
453
- get_nodes_query(driver.provider, 'episode_content', '$query')
454
- + """
455
- YIELD node AS episode, score
456
- MATCH (e:Episodic)
457
- WHERE e.uuid = episode.uuid
458
- AND e.group_id IN $group_ids
459
- RETURN
460
- """
461
- + EPISODIC_NODE_RETURN
462
- + """
463
- ORDER BY score DESC
464
- LIMIT $limit
465
- """
466
- )
758
+ if driver.provider == GraphProvider.NEPTUNE:
759
+ res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
760
+ if res['hits']['total']['value'] > 0:
761
+ # Calculate Cosine similarity then return the edge ids
762
+ input_ids = []
763
+ for r in res['hits']['hits']:
764
+ input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
765
+
766
+ # Match the edge ides and return the values
767
+ query = """
768
+ UNWIND $ids as i
769
+ MATCH (e:Episodic)
770
+ WHERE e.uuid=i.id
771
+ RETURN
772
+ e.content AS content,
773
+ e.created_at AS created_at,
774
+ e.valid_at AS valid_at,
775
+ e.uuid AS uuid,
776
+ e.name AS name,
777
+ e.group_id AS group_id,
778
+ e.source_description AS source_description,
779
+ e.source AS source,
780
+ e.entity_edges AS entity_edges
781
+ ORDER BY i.score DESC
782
+ LIMIT $limit
783
+ """
784
+ records, _, _ = await driver.execute_query(
785
+ query,
786
+ ids=input_ids,
787
+ query=fuzzy_query,
788
+ group_ids=group_ids,
789
+ limit=limit,
790
+ routing_='r',
791
+ )
792
+ else:
793
+ return []
794
+ else:
795
+ index_name = (
796
+ 'episode_content'
797
+ if not USE_HNSW
798
+ else 'episode_content_'
799
+ + (group_ids[0].replace('-', '') if group_ids is not None else '')
800
+ )
801
+ query = (
802
+ get_nodes_query(driver.provider, index_name, '$query')
803
+ + """
804
+ YIELD node AS episode, score
805
+ MATCH (e:Episodic)
806
+ WHERE e.uuid = episode.uuid
807
+ AND e.group_id IN $group_ids
808
+ RETURN
809
+ """
810
+ + EPISODIC_NODE_RETURN
811
+ + """
812
+ ORDER BY score DESC
813
+ LIMIT $limit
814
+ """
815
+ )
467
816
 
468
- records, _, _ = await driver.execute_query(
469
- query,
470
- query=fuzzy_query,
471
- group_ids=group_ids,
472
- limit=limit,
473
- routing_='r',
474
- )
817
+ records, _, _ = await driver.execute_query(
818
+ query,
819
+ query=fuzzy_query,
820
+ group_ids=group_ids,
821
+ limit=limit,
822
+ routing_='r',
823
+ )
475
824
  episodes = [get_episodic_node_from_record(record) for record in records]
476
825
 
477
826
  return episodes
@@ -488,27 +837,61 @@ async def community_fulltext_search(
488
837
  if fuzzy_query == '':
489
838
  return []
490
839
 
491
- query = (
492
- get_nodes_query(driver.provider, 'community_name', '$query')
493
- + """
494
- YIELD node AS n, score
495
- WHERE n.group_id IN $group_ids
496
- RETURN
497
- """
498
- + COMMUNITY_NODE_RETURN
499
- + """
500
- ORDER BY score DESC
501
- LIMIT $limit
502
- """
503
- )
840
+ if driver.provider == GraphProvider.NEPTUNE:
841
+ res = driver.run_aoss_query('community_name', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
842
+ if res['hits']['total']['value'] > 0:
843
+ # Calculate Cosine similarity then return the edge ids
844
+ input_ids = []
845
+ for r in res['hits']['hits']:
846
+ input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
847
+
848
+ # Match the edge ides and return the values
849
+ query = """
850
+ UNWIND $ids as i
851
+ MATCH (comm:Community)
852
+ WHERE comm.uuid=i.id
853
+ RETURN
854
+ comm.uuid AS uuid,
855
+ comm.group_id AS group_id,
856
+ comm.name AS name,
857
+ comm.created_at AS created_at,
858
+ comm.summary AS summary,
859
+ [x IN split(comm.name_embedding, ",") | toFloat(x)]AS name_embedding
860
+ ORDER BY i.score DESC
861
+ LIMIT $limit
862
+ """
863
+ records, _, _ = await driver.execute_query(
864
+ query,
865
+ ids=input_ids,
866
+ query=fuzzy_query,
867
+ group_ids=group_ids,
868
+ limit=limit,
869
+ routing_='r',
870
+ )
871
+ else:
872
+ return []
873
+ else:
874
+ query = (
875
+ get_nodes_query(driver.provider, 'community_name', '$query')
876
+ + """
877
+ YIELD node AS n, score
878
+ WHERE n.group_id IN $group_ids
879
+ RETURN
880
+ """
881
+ + COMMUNITY_NODE_RETURN
882
+ + """
883
+ ORDER BY score DESC
884
+ LIMIT $limit
885
+ """
886
+ )
504
887
 
505
- records, _, _ = await driver.execute_query(
506
- query,
507
- query=fuzzy_query,
508
- group_ids=group_ids,
509
- limit=limit,
510
- routing_='r',
511
- )
888
+ records, _, _ = await driver.execute_query(
889
+ query,
890
+ query=fuzzy_query,
891
+ group_ids=group_ids,
892
+ limit=limit,
893
+ routing_='r',
894
+ )
512
895
  communities = [get_community_node_from_record(record) for record in records]
513
896
 
514
897
  return communities
@@ -529,35 +912,93 @@ async def community_similarity_search(
529
912
  group_filter_query += 'WHERE n.group_id IN $group_ids'
530
913
  query_params['group_ids'] = group_ids
531
914
 
532
- query = (
533
- RUNTIME_QUERY
534
- + """
535
- MATCH (n:Community)
536
- """
537
- + group_filter_query
538
- + """
539
- WITH n,
540
- """
541
- + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
542
- + """ AS score
543
- WHERE score > $min_score
544
- RETURN
545
- """
546
- + COMMUNITY_NODE_RETURN
547
- + """
548
- ORDER BY score DESC
549
- LIMIT $limit
550
- """
551
- )
915
+ if driver.provider == GraphProvider.NEPTUNE:
916
+ query = (
917
+ RUNTIME_QUERY
918
+ + """
919
+ MATCH (n:Community)
920
+ """
921
+ + group_filter_query
922
+ + """
923
+ RETURN DISTINCT id(n) as id, n.name_embedding as embedding
924
+ """
925
+ )
926
+ resp, header, _ = await driver.execute_query(
927
+ query,
928
+ search_vector=search_vector,
929
+ limit=limit,
930
+ min_score=min_score,
931
+ routing_='r',
932
+ **query_params,
933
+ )
552
934
 
553
- records, _, _ = await driver.execute_query(
554
- query,
555
- search_vector=search_vector,
556
- limit=limit,
557
- min_score=min_score,
558
- routing_='r',
559
- **query_params,
560
- )
935
+ if len(resp) > 0:
936
+ # Calculate Cosine similarity then return the edge ids
937
+ input_ids = []
938
+ for r in resp:
939
+ if r['embedding']:
940
+ score = calculate_cosine_similarity(
941
+ search_vector, list(map(float, r['embedding'].split(',')))
942
+ )
943
+ if score > min_score:
944
+ input_ids.append({'id': r['id'], 'score': score})
945
+
946
+ # Match the edge ides and return the values
947
+ query = """
948
+ UNWIND $ids as i
949
+ MATCH (comm:Community)
950
+ WHERE id(comm)=i.id
951
+ RETURN
952
+ comm.uuid As uuid,
953
+ comm.group_id AS group_id,
954
+ comm.name AS name,
955
+ comm.created_at AS created_at,
956
+ comm.summary AS summary,
957
+ comm.name_embedding AS name_embedding
958
+ ORDER BY i.score DESC
959
+ LIMIT $limit
960
+ """
961
+ records, header, _ = await driver.execute_query(
962
+ query,
963
+ ids=input_ids,
964
+ search_vector=search_vector,
965
+ limit=limit,
966
+ min_score=min_score,
967
+ routing_='r',
968
+ **query_params,
969
+ )
970
+ else:
971
+ return []
972
+ else:
973
+ query = (
974
+ RUNTIME_QUERY
975
+ + """
976
+ MATCH (n:Community)
977
+ """
978
+ + group_filter_query
979
+ + """
980
+ WITH n,
981
+ """
982
+ + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
983
+ + """ AS score
984
+ WHERE score > $min_score
985
+ RETURN
986
+ """
987
+ + COMMUNITY_NODE_RETURN
988
+ + """
989
+ ORDER BY score DESC
990
+ LIMIT $limit
991
+ """
992
+ )
993
+
994
+ records, _, _ = await driver.execute_query(
995
+ query,
996
+ search_vector=search_vector,
997
+ limit=limit,
998
+ min_score=min_score,
999
+ routing_='r',
1000
+ **query_params,
1001
+ )
561
1002
  communities = [get_community_node_from_record(record) for record in records]
562
1003
 
563
1004
  return communities
@@ -746,20 +1187,45 @@ async def get_relevant_edges(
746
1187
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
747
1188
  query_params.update(filter_params)
748
1189
 
749
- query = (
750
- RUNTIME_QUERY
751
- + """
752
- UNWIND $edges AS edge
753
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
754
- """
755
- + filter_query
756
- + """
757
- WITH e, edge, """
758
- + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
759
- + """ AS score
760
- WHERE score > $min_score
761
- WITH edge, e, score
762
- ORDER BY score DESC
1190
+ if driver.provider == GraphProvider.NEPTUNE:
1191
+ query = (
1192
+ RUNTIME_QUERY
1193
+ + """
1194
+ UNWIND $edges AS edge
1195
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1196
+ """
1197
+ + filter_query
1198
+ + """
1199
+ WITH e, edge
1200
+ RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
1201
+ edge.fact_embedding as target_embedding
1202
+ """
1203
+ )
1204
+ resp, _, _ = await driver.execute_query(
1205
+ query,
1206
+ edges=[edge.model_dump() for edge in edges],
1207
+ limit=limit,
1208
+ min_score=min_score,
1209
+ routing_='r',
1210
+ **query_params,
1211
+ )
1212
+
1213
+ # Calculate Cosine similarity then return the edge ids
1214
+ input_ids = []
1215
+ for r in resp:
1216
+ score = calculate_cosine_similarity(
1217
+ list(map(float, r['source_embedding'].split(','))), r['target_embedding']
1218
+ )
1219
+ if score > min_score:
1220
+ input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
1221
+
1222
+ # Match the edge ides and return the values
1223
+ query = """
1224
+ UNWIND $ids AS edge
1225
+ MATCH ()-[e]->()
1226
+ WHERE id(e) = edge.id
1227
+ WITH edge, e
1228
+ ORDER BY edge.score DESC
763
1229
  RETURN edge.uuid AS search_edge_uuid,
764
1230
  collect({
765
1231
  uuid: e.uuid,
@@ -769,24 +1235,69 @@ async def get_relevant_edges(
769
1235
  name: e.name,
770
1236
  group_id: e.group_id,
771
1237
  fact: e.fact,
772
- fact_embedding: e.fact_embedding,
773
- episodes: e.episodes,
1238
+ fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
1239
+ episodes: split(e.episodes, ","),
774
1240
  expired_at: e.expired_at,
775
1241
  valid_at: e.valid_at,
776
1242
  invalid_at: e.invalid_at,
777
1243
  attributes: properties(e)
778
1244
  })[..$limit] AS matches
779
- """
780
- )
1245
+ """
1246
+
1247
+ results, _, _ = await driver.execute_query(
1248
+ query,
1249
+ params=query_params,
1250
+ ids=input_ids,
1251
+ edges=[edge.model_dump() for edge in edges],
1252
+ limit=limit,
1253
+ min_score=min_score,
1254
+ routing_='r',
1255
+ **query_params,
1256
+ )
1257
+ else:
1258
+ query = (
1259
+ RUNTIME_QUERY
1260
+ + """
1261
+ UNWIND $edges AS edge
1262
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1263
+ """
1264
+ + filter_query
1265
+ + """
1266
+ WITH e, edge, """
1267
+ + get_vector_cosine_func_query(
1268
+ 'e.fact_embedding', 'edge.fact_embedding', driver.provider
1269
+ )
1270
+ + """ AS score
1271
+ WHERE score > $min_score
1272
+ WITH edge, e, score
1273
+ ORDER BY score DESC
1274
+ RETURN edge.uuid AS search_edge_uuid,
1275
+ collect({
1276
+ uuid: e.uuid,
1277
+ source_node_uuid: startNode(e).uuid,
1278
+ target_node_uuid: endNode(e).uuid,
1279
+ created_at: e.created_at,
1280
+ name: e.name,
1281
+ group_id: e.group_id,
1282
+ fact: e.fact,
1283
+ fact_embedding: e.fact_embedding,
1284
+ episodes: e.episodes,
1285
+ expired_at: e.expired_at,
1286
+ valid_at: e.valid_at,
1287
+ invalid_at: e.invalid_at,
1288
+ attributes: properties(e)
1289
+ })[..$limit] AS matches
1290
+ """
1291
+ )
781
1292
 
782
- results, _, _ = await driver.execute_query(
783
- query,
784
- edges=[edge.model_dump() for edge in edges],
785
- limit=limit,
786
- min_score=min_score,
787
- routing_='r',
788
- **query_params,
789
- )
1293
+ results, _, _ = await driver.execute_query(
1294
+ query,
1295
+ edges=[edge.model_dump() for edge in edges],
1296
+ limit=limit,
1297
+ min_score=min_score,
1298
+ routing_='r',
1299
+ **query_params,
1300
+ )
790
1301
 
791
1302
  relevant_edges_dict: dict[str, list[EntityEdge]] = {
792
1303
  result['search_edge_uuid']: [
@@ -815,21 +1326,47 @@ async def get_edge_invalidation_candidates(
815
1326
  filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
816
1327
  query_params.update(filter_params)
817
1328
 
818
- query = (
819
- RUNTIME_QUERY
820
- + """
821
- UNWIND $edges AS edge
822
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
823
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
824
- """
825
- + filter_query
826
- + """
827
- WITH edge, e, """
828
- + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
829
- + """ AS score
830
- WHERE score > $min_score
831
- WITH edge, e, score
832
- ORDER BY score DESC
1329
+ if driver.provider == GraphProvider.NEPTUNE:
1330
+ query = (
1331
+ RUNTIME_QUERY
1332
+ + """
1333
+ UNWIND $edges AS edge
1334
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1335
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1336
+ """
1337
+ + filter_query
1338
+ + """
1339
+ WITH e, edge
1340
+ RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding,
1341
+ edge.fact_embedding as target_embedding,
1342
+ edge.uuid as search_edge_uuid
1343
+ """
1344
+ )
1345
+ resp, _, _ = await driver.execute_query(
1346
+ query,
1347
+ edges=[edge.model_dump() for edge in edges],
1348
+ limit=limit,
1349
+ min_score=min_score,
1350
+ routing_='r',
1351
+ **query_params,
1352
+ )
1353
+
1354
+ # Calculate Cosine similarity then return the edge ids
1355
+ input_ids = []
1356
+ for r in resp:
1357
+ score = calculate_cosine_similarity(
1358
+ list(map(float, r['source_embedding'].split(','))), r['target_embedding']
1359
+ )
1360
+ if score > min_score:
1361
+ input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
1362
+
1363
+ # Match the edge ides and return the values
1364
+ query = """
1365
+ UNWIND $ids AS edge
1366
+ MATCH ()-[e]->()
1367
+ WHERE id(e) = edge.id
1368
+ WITH edge, e
1369
+ ORDER BY edge.score DESC
833
1370
  RETURN edge.uuid AS search_edge_uuid,
834
1371
  collect({
835
1372
  uuid: e.uuid,
@@ -839,24 +1376,68 @@ async def get_edge_invalidation_candidates(
839
1376
  name: e.name,
840
1377
  group_id: e.group_id,
841
1378
  fact: e.fact,
842
- fact_embedding: e.fact_embedding,
843
- episodes: e.episodes,
1379
+ fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
1380
+ episodes: split(e.episodes, ","),
844
1381
  expired_at: e.expired_at,
845
1382
  valid_at: e.valid_at,
846
1383
  invalid_at: e.invalid_at,
847
1384
  attributes: properties(e)
848
1385
  })[..$limit] AS matches
849
- """
850
- )
1386
+ """
1387
+ results, _, _ = await driver.execute_query(
1388
+ query,
1389
+ ids=input_ids,
1390
+ edges=[edge.model_dump() for edge in edges],
1391
+ limit=limit,
1392
+ min_score=min_score,
1393
+ routing_='r',
1394
+ **query_params,
1395
+ )
1396
+ else:
1397
+ query = (
1398
+ RUNTIME_QUERY
1399
+ + """
1400
+ UNWIND $edges AS edge
1401
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1402
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1403
+ """
1404
+ + filter_query
1405
+ + """
1406
+ WITH edge, e, """
1407
+ + get_vector_cosine_func_query(
1408
+ 'e.fact_embedding', 'edge.fact_embedding', driver.provider
1409
+ )
1410
+ + """ AS score
1411
+ WHERE score > $min_score
1412
+ WITH edge, e, score
1413
+ ORDER BY score DESC
1414
+ RETURN edge.uuid AS search_edge_uuid,
1415
+ collect({
1416
+ uuid: e.uuid,
1417
+ source_node_uuid: startNode(e).uuid,
1418
+ target_node_uuid: endNode(e).uuid,
1419
+ created_at: e.created_at,
1420
+ name: e.name,
1421
+ group_id: e.group_id,
1422
+ fact: e.fact,
1423
+ fact_embedding: e.fact_embedding,
1424
+ episodes: e.episodes,
1425
+ expired_at: e.expired_at,
1426
+ valid_at: e.valid_at,
1427
+ invalid_at: e.invalid_at,
1428
+ attributes: properties(e)
1429
+ })[..$limit] AS matches
1430
+ """
1431
+ )
851
1432
 
852
- results, _, _ = await driver.execute_query(
853
- query,
854
- edges=[edge.model_dump() for edge in edges],
855
- limit=limit,
856
- min_score=min_score,
857
- routing_='r',
858
- **query_params,
859
- )
1433
+ results, _, _ = await driver.execute_query(
1434
+ query,
1435
+ edges=[edge.model_dump() for edge in edges],
1436
+ limit=limit,
1437
+ min_score=min_score,
1438
+ routing_='r',
1439
+ **query_params,
1440
+ )
860
1441
  invalidation_edges_dict: dict[str, list[EntityEdge]] = {
861
1442
  result['search_edge_uuid']: [
862
1443
  get_entity_edge_from_record(record) for record in result['matches']
@@ -1007,14 +1588,24 @@ def maximal_marginal_relevance(
1007
1588
  async def get_embeddings_for_nodes(
1008
1589
  driver: GraphDriver, nodes: list[EntityNode]
1009
1590
  ) -> dict[str, list[float]]:
1010
- results, _, _ = await driver.execute_query(
1591
+ if driver.provider == GraphProvider.NEPTUNE:
1592
+ query = """
1593
+ MATCH (n:Entity)
1594
+ WHERE n.uuid IN $node_uuids
1595
+ RETURN DISTINCT
1596
+ n.uuid AS uuid,
1597
+ split(n.name_embedding, ",") AS name_embedding
1011
1598
  """
1599
+ else:
1600
+ query = """
1012
1601
  MATCH (n:Entity)
1013
1602
  WHERE n.uuid IN $node_uuids
1014
1603
  RETURN DISTINCT
1015
1604
  n.uuid AS uuid,
1016
1605
  n.name_embedding AS name_embedding
1017
- """,
1606
+ """
1607
+ results, _, _ = await driver.execute_query(
1608
+ query,
1018
1609
  node_uuids=[node.uuid for node in nodes],
1019
1610
  routing_='r',
1020
1611
  )
@@ -1032,14 +1623,24 @@ async def get_embeddings_for_nodes(
1032
1623
  async def get_embeddings_for_communities(
1033
1624
  driver: GraphDriver, communities: list[CommunityNode]
1034
1625
  ) -> dict[str, list[float]]:
1035
- results, _, _ = await driver.execute_query(
1626
+ if driver.provider == GraphProvider.NEPTUNE:
1627
+ query = """
1628
+ MATCH (c:Community)
1629
+ WHERE c.uuid IN $community_uuids
1630
+ RETURN DISTINCT
1631
+ c.uuid AS uuid,
1632
+ split(c.name_embedding, ",") AS name_embedding
1036
1633
  """
1634
+ else:
1635
+ query = """
1037
1636
  MATCH (c:Community)
1038
1637
  WHERE c.uuid IN $community_uuids
1039
1638
  RETURN DISTINCT
1040
1639
  c.uuid AS uuid,
1041
1640
  c.name_embedding AS name_embedding
1042
- """,
1641
+ """
1642
+ results, _, _ = await driver.execute_query(
1643
+ query,
1043
1644
  community_uuids=[community.uuid for community in communities],
1044
1645
  routing_='r',
1045
1646
  )
@@ -1057,14 +1658,24 @@ async def get_embeddings_for_communities(
1057
1658
  async def get_embeddings_for_edges(
1058
1659
  driver: GraphDriver, edges: list[EntityEdge]
1059
1660
  ) -> dict[str, list[float]]:
1060
- results, _, _ = await driver.execute_query(
1661
+ if driver.provider == GraphProvider.NEPTUNE:
1662
+ query = """
1663
+ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1664
+ WHERE e.uuid IN $edge_uuids
1665
+ RETURN DISTINCT
1666
+ e.uuid AS uuid,
1667
+ split(e.fact_embedding, ",") AS fact_embedding
1061
1668
  """
1669
+ else:
1670
+ query = """
1062
1671
  MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1063
1672
  WHERE e.uuid IN $edge_uuids
1064
1673
  RETURN DISTINCT
1065
1674
  e.uuid AS uuid,
1066
1675
  e.fact_embedding AS fact_embedding
1067
- """,
1676
+ """
1677
+ results, _, _ = await driver.execute_query(
1678
+ query,
1068
1679
  edge_uuids=[edge.uuid for edge in edges],
1069
1680
  routing_='r',
1070
1681
  )