graphiti-core 0.20.4__py3-none-any.whl → 0.21.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. See the registry's advisory page for this release for more details (the original link was not preserved in this copy).

@@ -23,7 +23,13 @@ import numpy as np
23
23
  from numpy._typing import NDArray
24
24
  from typing_extensions import LiteralString
25
25
 
26
- from graphiti_core.driver.driver import GraphDriver, GraphProvider
26
+ from graphiti_core.driver.driver import (
27
+ ENTITY_EDGE_INDEX_NAME,
28
+ ENTITY_INDEX_NAME,
29
+ EPISODE_INDEX_NAME,
30
+ GraphDriver,
31
+ GraphProvider,
32
+ )
27
33
  from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
28
34
  from graphiti_core.graph_queries import (
29
35
  get_nodes_query,
@@ -51,6 +57,8 @@ from graphiti_core.nodes import (
51
57
  )
52
58
  from graphiti_core.search.search_filters import (
53
59
  SearchFilters,
60
+ build_aoss_edge_filters,
61
+ build_aoss_node_filters,
54
62
  edge_search_filter_query_constructor,
55
63
  node_search_filter_query_constructor,
56
64
  )
@@ -200,7 +208,6 @@ async def edge_fulltext_search(
200
208
  if driver.provider == GraphProvider.NEPTUNE:
201
209
  res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
202
210
  if res['hits']['total']['value'] > 0:
203
- # Calculate Cosine similarity then return the edge ids
204
211
  input_ids = []
205
212
  for r in res['hits']['hits']:
206
213
  input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@@ -208,11 +215,11 @@ async def edge_fulltext_search(
208
215
  # Match the edge ids and return the values
209
216
  query = (
210
217
  """
211
- UNWIND $ids as id
212
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
213
- WHERE e.group_id IN $group_ids
214
- AND id(e)=id
215
- """
218
+ UNWIND $ids as id
219
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
220
+ WHERE e.group_id IN $group_ids
221
+ AND id(e)=id
222
+ """
216
223
  + filter_query
217
224
  + """
218
225
  AND id(e)=id
@@ -244,6 +251,35 @@ async def edge_fulltext_search(
244
251
  )
245
252
  else:
246
253
  return []
254
+ elif driver.aoss_client:
255
+ route = group_ids[0] if group_ids else None
256
+ filters = build_aoss_edge_filters(group_ids or [], search_filter)
257
+ res = await driver.aoss_client.search(
258
+ index=ENTITY_EDGE_INDEX_NAME,
259
+ params={'routing': route},
260
+ body={
261
+ 'size': limit,
262
+ '_source': ['uuid'],
263
+ 'query': {
264
+ 'bool': {
265
+ 'filter': filters,
266
+ 'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
267
+ }
268
+ },
269
+ },
270
+ )
271
+
272
+ if res['hits']['total']['value'] > 0:
273
+ input_uuids = {}
274
+ for r in res['hits']['hits']:
275
+ input_uuids[r['_source']['uuid']] = r['_score']
276
+
277
+ # Get edges
278
+ entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
279
+ entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
280
+ return entity_edges
281
+ else:
282
+ return []
247
283
  else:
248
284
  query = (
249
285
  get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
@@ -318,8 +354,8 @@ async def edge_similarity_search(
318
354
  if driver.provider == GraphProvider.NEPTUNE:
319
355
  query = (
320
356
  """
321
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
322
- """
357
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
358
+ """
323
359
  + filter_query
324
360
  + """
325
361
  RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@@ -377,6 +413,38 @@ async def edge_similarity_search(
377
413
  )
378
414
  else:
379
415
  return []
416
+ elif driver.aoss_client:
417
+ route = group_ids[0] if group_ids else None
418
+ filters = build_aoss_edge_filters(group_ids or [], search_filter)
419
+ res = await driver.aoss_client.search(
420
+ index=ENTITY_EDGE_INDEX_NAME,
421
+ params={'routing': route},
422
+ body={
423
+ 'size': limit,
424
+ '_source': ['uuid'],
425
+ 'query': {
426
+ 'knn': {
427
+ 'fact_embedding': {
428
+ 'vector': list(map(float, search_vector)),
429
+ 'k': limit,
430
+ 'filter': {'bool': {'filter': filters}},
431
+ }
432
+ }
433
+ },
434
+ },
435
+ )
436
+
437
+ if res['hits']['total']['value'] > 0:
438
+ input_uuids = {}
439
+ for r in res['hits']['hits']:
440
+ input_uuids[r['_source']['uuid']] = r['_score']
441
+
442
+ # Get edges
443
+ entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
444
+ entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
445
+ return entity_edges
446
+ return []
447
+
380
448
  else:
381
449
  query = (
382
450
  match_query
@@ -563,7 +631,6 @@ async def node_fulltext_search(
563
631
  if driver.provider == GraphProvider.NEPTUNE:
564
632
  res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
565
633
  if res['hits']['total']['value'] > 0:
566
- # Calculate Cosine similarity then return the edge ids
567
634
  input_ids = []
568
635
  for r in res['hits']['hits']:
569
636
  input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@@ -571,11 +638,11 @@ async def node_fulltext_search(
571
638
  # Match the edge ides and return the values
572
639
  query = (
573
640
  """
574
- UNWIND $ids as i
575
- MATCH (n:Entity)
576
- WHERE n.uuid=i.id
577
- RETURN
578
- """
641
+ UNWIND $ids as i
642
+ MATCH (n:Entity)
643
+ WHERE n.uuid=i.id
644
+ RETURN
645
+ """
579
646
  + get_entity_node_return_query(driver.provider)
580
647
  + """
581
648
  ORDER BY i.score DESC
@@ -592,6 +659,43 @@ async def node_fulltext_search(
592
659
  )
593
660
  else:
594
661
  return []
662
+ elif driver.aoss_client:
663
+ route = group_ids[0] if group_ids else None
664
+ filters = build_aoss_node_filters(group_ids or [], search_filter)
665
+ res = await driver.aoss_client.search(
666
+ index=ENTITY_INDEX_NAME,
667
+ params={'routing': route},
668
+ body={
669
+ '_source': ['uuid'],
670
+ 'size': limit,
671
+ 'query': {
672
+ 'bool': {
673
+ 'filter': filters,
674
+ 'must': [
675
+ {
676
+ 'multi_match': {
677
+ 'query': query,
678
+ 'fields': ['name', 'summary'],
679
+ 'operator': 'or',
680
+ }
681
+ }
682
+ ],
683
+ }
684
+ },
685
+ },
686
+ )
687
+
688
+ if res['hits']['total']['value'] > 0:
689
+ input_uuids = {}
690
+ for r in res['hits']['hits']:
691
+ input_uuids[r['_source']['uuid']] = r['_score']
692
+
693
+ # Get nodes
694
+ entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
695
+ entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
696
+ return entities
697
+ else:
698
+ return []
595
699
  else:
596
700
  query = (
597
701
  get_nodes_query(
@@ -648,8 +752,8 @@ async def node_similarity_search(
648
752
  if driver.provider == GraphProvider.NEPTUNE:
649
753
  query = (
650
754
  """
651
- MATCH (n:Entity)
652
- """
755
+ MATCH (n:Entity)
756
+ """
653
757
  + filter_query
654
758
  + """
655
759
  RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@@ -678,11 +782,11 @@ async def node_similarity_search(
678
782
  # Match the edge ides and return the values
679
783
  query = (
680
784
  """
681
- UNWIND $ids as i
682
- MATCH (n:Entity)
683
- WHERE id(n)=i.id
684
- RETURN
685
- """
785
+ UNWIND $ids as i
786
+ MATCH (n:Entity)
787
+ WHERE id(n)=i.id
788
+ RETURN
789
+ """
686
790
  + get_entity_node_return_query(driver.provider)
687
791
  + """
688
792
  ORDER BY i.score DESC
@@ -700,11 +804,42 @@ async def node_similarity_search(
700
804
  )
701
805
  else:
702
806
  return []
807
+ elif driver.aoss_client:
808
+ route = group_ids[0] if group_ids else None
809
+ filters = build_aoss_node_filters(group_ids or [], search_filter)
810
+ res = await driver.aoss_client.search(
811
+ index=ENTITY_INDEX_NAME,
812
+ params={'routing': route},
813
+ body={
814
+ 'size': limit,
815
+ '_source': ['uuid'],
816
+ 'query': {
817
+ 'knn': {
818
+ 'name_embedding': {
819
+ 'vector': list(map(float, search_vector)),
820
+ 'k': limit,
821
+ 'filter': {'bool': {'filter': filters}},
822
+ }
823
+ }
824
+ },
825
+ },
826
+ )
827
+
828
+ if res['hits']['total']['value'] > 0:
829
+ input_uuids = {}
830
+ for r in res['hits']['hits']:
831
+ input_uuids[r['_source']['uuid']] = r['_score']
832
+
833
+ # Get edges
834
+ entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
835
+ entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
836
+ return entity_nodes
837
+ return []
703
838
  else:
704
839
  query = (
705
840
  """
706
- MATCH (n:Entity)
707
- """
841
+ MATCH (n:Entity)
842
+ """
708
843
  + filter_query
709
844
  + """
710
845
  WITH n, """
@@ -843,7 +978,6 @@ async def episode_fulltext_search(
843
978
  if driver.provider == GraphProvider.NEPTUNE:
844
979
  res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
845
980
  if res['hits']['total']['value'] > 0:
846
- # Calculate Cosine similarity then return the edge ids
847
981
  input_ids = []
848
982
  for r in res['hits']['hits']:
849
983
  input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@@ -852,7 +986,7 @@ async def episode_fulltext_search(
852
986
  query = """
853
987
  UNWIND $ids as i
854
988
  MATCH (e:Episodic)
855
- WHERE e.uuid=i.id
989
+ WHERE e.uuid=i.uuid
856
990
  RETURN
857
991
  e.content AS content,
858
992
  e.created_at AS created_at,
@@ -876,6 +1010,40 @@ async def episode_fulltext_search(
876
1010
  )
877
1011
  else:
878
1012
  return []
1013
+ elif driver.aoss_client:
1014
+ route = group_ids[0] if group_ids else None
1015
+ res = await driver.aoss_client.search(
1016
+ index=EPISODE_INDEX_NAME,
1017
+ params={'routing': route},
1018
+ body={
1019
+ 'size': limit,
1020
+ '_source': ['uuid'],
1021
+ 'bool': {
1022
+ 'filter': {'terms': group_ids},
1023
+ 'must': [
1024
+ {
1025
+ 'multi_match': {
1026
+ 'query': query,
1027
+ 'field': ['name', 'content'],
1028
+ 'operator': 'or',
1029
+ }
1030
+ }
1031
+ ],
1032
+ },
1033
+ },
1034
+ )
1035
+
1036
+ if res['hits']['total']['value'] > 0:
1037
+ input_uuids = {}
1038
+ for r in res['hits']['hits']:
1039
+ input_uuids[r['_source']['uuid']] = r['_score']
1040
+
1041
+ # Get nodes
1042
+ episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys()))
1043
+ episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
1044
+ return episodes
1045
+ else:
1046
+ return []
879
1047
  else:
880
1048
  query = (
881
1049
  get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
@@ -1003,8 +1171,8 @@ async def community_similarity_search(
1003
1171
  if driver.provider == GraphProvider.NEPTUNE:
1004
1172
  query = (
1005
1173
  """
1006
- MATCH (n:Community)
1007
- """
1174
+ MATCH (n:Community)
1175
+ """
1008
1176
  + group_filter_query
1009
1177
  + """
1010
1178
  RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@@ -1063,8 +1231,8 @@ async def community_similarity_search(
1063
1231
 
1064
1232
  query = (
1065
1233
  """
1066
- MATCH (c:Community)
1067
- """
1234
+ MATCH (c:Community)
1235
+ """
1068
1236
  + group_filter_query
1069
1237
  + """
1070
1238
  WITH c,
@@ -1206,9 +1374,9 @@ async def get_relevant_nodes(
1206
1374
  # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
1207
1375
  query = (
1208
1376
  """
1209
- UNWIND $nodes AS node
1210
- MATCH (n:Entity {group_id: $group_id})
1211
- """
1377
+ UNWIND $nodes AS node
1378
+ MATCH (n:Entity {group_id: $group_id})
1379
+ """
1212
1380
  + filter_query
1213
1381
  + """
1214
1382
  WITH node, n, """
@@ -1253,9 +1421,9 @@ async def get_relevant_nodes(
1253
1421
  else:
1254
1422
  query = (
1255
1423
  """
1256
- UNWIND $nodes AS node
1257
- MATCH (n:Entity {group_id: $group_id})
1258
- """
1424
+ UNWIND $nodes AS node
1425
+ MATCH (n:Entity {group_id: $group_id})
1426
+ """
1259
1427
  + filter_query
1260
1428
  + """
1261
1429
  WITH node, n, """
@@ -1344,9 +1512,9 @@ async def get_relevant_edges(
1344
1512
  if driver.provider == GraphProvider.NEPTUNE:
1345
1513
  query = (
1346
1514
  """
1347
- UNWIND $edges AS edge
1348
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1349
- """
1515
+ UNWIND $edges AS edge
1516
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1517
+ """
1350
1518
  + filter_query
1351
1519
  + """
1352
1520
  WITH e, edge
@@ -1416,9 +1584,9 @@ async def get_relevant_edges(
1416
1584
 
1417
1585
  query = (
1418
1586
  """
1419
- UNWIND $edges AS edge
1420
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
1421
- """
1587
+ UNWIND $edges AS edge
1588
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
1589
+ """
1422
1590
  + filter_query
1423
1591
  + """
1424
1592
  WITH e, edge, n, m, """
@@ -1454,9 +1622,9 @@ async def get_relevant_edges(
1454
1622
  else:
1455
1623
  query = (
1456
1624
  """
1457
- UNWIND $edges AS edge
1458
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1459
- """
1625
+ UNWIND $edges AS edge
1626
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1627
+ """
1460
1628
  + filter_query
1461
1629
  + """
1462
1630
  WITH e, edge, """
@@ -1529,10 +1697,10 @@ async def get_edge_invalidation_candidates(
1529
1697
  if driver.provider == GraphProvider.NEPTUNE:
1530
1698
  query = (
1531
1699
  """
1532
- UNWIND $edges AS edge
1533
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1534
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1535
- """
1700
+ UNWIND $edges AS edge
1701
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1702
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1703
+ """
1536
1704
  + filter_query
1537
1705
  + """
1538
1706
  WITH e, edge
@@ -1602,10 +1770,10 @@ async def get_edge_invalidation_candidates(
1602
1770
 
1603
1771
  query = (
1604
1772
  """
1605
- UNWIND $edges AS edge
1606
- MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
1607
- WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
1608
- """
1773
+ UNWIND $edges AS edge
1774
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
1775
+ WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
1776
+ """
1609
1777
  + filter_query
1610
1778
  + """
1611
1779
  WITH edge, e, n, m, """
@@ -1641,10 +1809,10 @@ async def get_edge_invalidation_candidates(
1641
1809
  else:
1642
1810
  query = (
1643
1811
  """
1644
- UNWIND $edges AS edge
1645
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1646
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1647
- """
1812
+ UNWIND $edges AS edge
1813
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1814
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1815
+ """
1648
1816
  + filter_query
1649
1817
  + """
1650
1818
  WITH edge, e, """
@@ -23,7 +23,14 @@ import numpy as np
23
23
  from pydantic import BaseModel, Field
24
24
  from typing_extensions import Any
25
25
 
26
- from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
26
+ from graphiti_core.driver.driver import (
27
+ ENTITY_EDGE_INDEX_NAME,
28
+ ENTITY_INDEX_NAME,
29
+ EPISODE_INDEX_NAME,
30
+ GraphDriver,
31
+ GraphDriverSession,
32
+ GraphProvider,
33
+ )
27
34
  from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
28
35
  from graphiti_core.embedder import EmbedderClient
29
36
  from graphiti_core.graphiti_types import GraphitiClients
@@ -187,12 +194,25 @@ async def add_nodes_and_edges_bulk_tx(
187
194
  await tx.run(episodic_edge_query, **edge.model_dump())
188
195
  else:
189
196
  await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
190
- await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
197
+ await tx.run(
198
+ get_entity_node_save_bulk_query(driver.provider, nodes),
199
+ nodes=nodes,
200
+ has_aoss=bool(driver.aoss_client),
201
+ )
191
202
  await tx.run(
192
203
  get_episodic_edge_save_bulk_query(driver.provider),
193
204
  episodic_edges=[edge.model_dump() for edge in episodic_edges],
194
205
  )
195
- await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
206
+ await tx.run(
207
+ get_entity_edge_save_bulk_query(driver.provider),
208
+ entity_edges=edges,
209
+ has_aoss=bool(driver.aoss_client),
210
+ )
211
+
212
+ if driver.aoss_client:
213
+ await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
214
+ await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
215
+ await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
196
216
 
197
217
 
198
218
  async def extract_nodes_and_edges_bulk(
@@ -36,8 +36,10 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
36
36
  from graphiti_core.prompts import prompt_library
37
37
  from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
38
38
  from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
39
+ from graphiti_core.search.search import search
40
+ from graphiti_core.search.search_config import SearchResults
41
+ from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
39
42
  from graphiti_core.search.search_filters import SearchFilters
40
- from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
41
43
  from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
42
44
 
43
45
  logger = logging.getLogger(__name__)
@@ -258,12 +260,44 @@ async def resolve_extracted_edges(
258
260
  embedder = clients.embedder
259
261
  await create_entity_edge_embeddings(embedder, extracted_edges)
260
262
 
261
- search_results = await semaphore_gather(
262
- get_relevant_edges(driver, extracted_edges, SearchFilters()),
263
- get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
263
+ valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
264
+ *[
265
+ EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
266
+ for edge in extracted_edges
267
+ ]
268
+ )
269
+
270
+ related_edges_results: list[SearchResults] = await semaphore_gather(
271
+ *[
272
+ search(
273
+ clients,
274
+ extracted_edge.fact,
275
+ group_ids=[extracted_edge.group_id],
276
+ config=EDGE_HYBRID_SEARCH_RRF,
277
+ search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
278
+ )
279
+ for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
280
+ ]
264
281
  )
265
282
 
266
- related_edges_lists, edge_invalidation_candidates = search_results
283
+ related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
284
+
285
+ edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
286
+ *[
287
+ search(
288
+ clients,
289
+ extracted_edge.fact,
290
+ group_ids=[extracted_edge.group_id],
291
+ config=EDGE_HYBRID_SEARCH_RRF,
292
+ search_filter=SearchFilters(),
293
+ )
294
+ for extracted_edge in extracted_edges
295
+ ]
296
+ )
297
+
298
+ edge_invalidation_candidates: list[list[EntityEdge]] = [
299
+ result.edges for result in edge_invalidation_candidate_results
300
+ ]
267
301
 
268
302
  logger.debug(
269
303
  f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
34
34
 
35
35
 
36
36
  async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
37
- if driver.provider == GraphProvider.NEPTUNE:
37
+ if driver.aoss_client:
38
38
  await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
39
39
  return
40
40
  if delete_existing:
@@ -56,7 +56,9 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
56
56
 
57
57
  range_indices: list[LiteralString] = get_range_indices(driver.provider)
58
58
 
59
- fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
59
+ # Don't create fulltext indices if OpenSearch is being used
60
+ if not driver.aoss_client:
61
+ fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
60
62
 
61
63
  if driver.provider == GraphProvider.KUZU:
62
64
  # Skip creating fulltext indices if they already exist. Need to do this manually
@@ -93,6 +95,8 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
93
95
 
94
96
  async def delete_all(tx):
95
97
  await tx.run('MATCH (n) DETACH DELETE n')
98
+ if driver.aoss_client:
99
+ await driver.clear_aoss_indices()
96
100
 
97
101
  async def delete_group_ids(tx):
98
102
  labels = ['Entity', 'Episodic', 'Community']
@@ -149,9 +153,9 @@ async def retrieve_episodes(
149
153
 
150
154
  query: LiteralString = (
151
155
  """
152
- MATCH (e:Episodic)
153
- WHERE e.valid_at <= $reference_time
154
- """
156
+ MATCH (e:Episodic)
157
+ WHERE e.valid_at <= $reference_time
158
+ """
155
159
  + query_filter
156
160
  + """
157
161
  RETURN
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphiti-core
3
- Version: 0.20.4
3
+ Version: 0.21.0rc2
4
4
  Summary: A temporal graph building library
5
5
  Project-URL: Homepage, https://help.getzep.com/graphiti/graphiti/overview
6
6
  Project-URL: Repository, https://github.com/getzep/graphiti
@@ -47,6 +47,9 @@ Provides-Extra: groq
47
47
  Requires-Dist: groq>=0.2.0; extra == 'groq'
48
48
  Provides-Extra: kuzu
49
49
  Requires-Dist: kuzu>=0.11.2; extra == 'kuzu'
50
+ Provides-Extra: neo4j-opensearch
51
+ Requires-Dist: boto3>=1.39.16; extra == 'neo4j-opensearch'
52
+ Requires-Dist: opensearch-py>=3.0.0; extra == 'neo4j-opensearch'
50
53
  Provides-Extra: neptune
51
54
  Requires-Dist: boto3>=1.39.16; extra == 'neptune'
52
55
  Requires-Dist: langchain-aws>=0.2.29; extra == 'neptune'