graphiti-core 0.20.3__py3-none-any.whl → 0.21.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core may be problematic; see the registry's advisory page for more details.

graphiti_core/nodes.py CHANGED
@@ -273,20 +273,6 @@ class EpisodicNode(Node):
273
273
  )
274
274
 
275
275
  async def save(self, driver: GraphDriver):
276
- if driver.provider == GraphProvider.NEPTUNE:
277
- driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
278
- 'episode_content',
279
- [
280
- {
281
- 'uuid': self.uuid,
282
- 'group_id': self.group_id,
283
- 'source': self.source.value,
284
- 'content': self.content,
285
- 'source_description': self.source_description,
286
- }
287
- ],
288
- )
289
-
290
276
  episode_args = {
291
277
  'uuid': self.uuid,
292
278
  'name': self.name,
@@ -299,6 +285,12 @@ class EpisodicNode(Node):
299
285
  'source': self.source.value,
300
286
  }
301
287
 
288
+ if driver.aoss_client:
289
+ driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
290
+ 'episodes',
291
+ [episode_args],
292
+ )
293
+
302
294
  result = await driver.execute_query(
303
295
  get_episode_node_save_query(driver.provider), **episode_args
304
296
  )
@@ -433,6 +425,22 @@ class EntityNode(Node):
433
425
  MATCH (n:Entity {uuid: $uuid})
434
426
  RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
435
427
  """
428
+ elif driver.aoss_client:
429
+ resp = driver.aoss_client.search(
430
+ body={
431
+ 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
432
+ 'size': 1,
433
+ },
434
+ index='entities',
435
+ routing=self.group_id,
436
+ )
437
+
438
+ if resp['hits']['hits']:
439
+ self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
440
+ return
441
+ else:
442
+ raise NodeNotFoundError(self.uuid)
443
+
436
444
  else:
437
445
  query: LiteralString = """
438
446
  MATCH (n:Entity {uuid: $uuid})
@@ -470,11 +478,11 @@ class EntityNode(Node):
470
478
  entity_data.update(self.attributes or {})
471
479
  labels = ':'.join(self.labels + ['Entity'])
472
480
 
473
- if driver.provider == GraphProvider.NEPTUNE:
474
- driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
481
+ if driver.aoss_client:
482
+ driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue
475
483
 
476
484
  result = await driver.execute_query(
477
- get_entity_node_save_query(driver.provider, labels),
485
+ get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
478
486
  entity_data=entity_data,
479
487
  )
480
488
 
@@ -570,7 +578,7 @@ class CommunityNode(Node):
570
578
  async def save(self, driver: GraphDriver):
571
579
  if driver.provider == GraphProvider.NEPTUNE:
572
580
  driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
573
- 'community_name',
581
+ 'communities',
574
582
  [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
575
583
  )
576
584
  result = await driver.execute_query(
@@ -54,6 +54,16 @@ class SearchFilters(BaseModel):
54
54
  expired_at: list[list[DateFilter]] | None = Field(default=None)
55
55
 
56
56
 
57
def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
    """Translate a Cypher comparison operator into an OpenSearch ``range`` keyword.

    The four ordering operators map to ``gt``, ``lt``, ``gte`` and ``lte``.
    Any other operator is returned as its raw ``value`` unchanged.

    NOTE(review): OpenSearch ``range`` queries only understand gt/gte/lt/lte, so
    the pass-through value for any other operator will most likely be rejected
    by the server. Confirm whether non-ordering operators can reach this path.
    """
    range_keywords: dict[ComparisonOperator, str] = {
        ComparisonOperator.greater_than: 'gt',
        ComparisonOperator.less_than: 'lt',
        ComparisonOperator.greater_than_equal: 'gte',
        ComparisonOperator.less_than_equal: 'lte',
    }
    if op in range_keywords:
        return range_keywords[op]
    # Not an ordering operator: hand back the operator's own value.
    return op.value
65
+
66
+
57
67
  def node_search_filter_query_constructor(
58
68
  filters: SearchFilters,
59
69
  provider: GraphProvider,
@@ -234,3 +244,38 @@ def edge_search_filter_query_constructor(
234
244
  filter_queries.append(expired_at_filter)
235
245
 
236
246
  return filter_queries, filter_params
247
+
248
+
249
def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
    """Build the OpenSearch ``bool.filter`` clauses for an entity-node search.

    Args:
        group_ids: Group ids to restrict the search to. An empty list means
            "no group restriction". Callers pass ``group_ids or []`` when no
            groups were requested.
        search_filters: Caller-supplied search filters. Only ``node_labels`` is
            applied here.

    Returns:
        A list of OpenSearch filter clauses. It may be empty, in which case the
        enclosing ``bool`` query is not narrowed by any filter.
    """
    filters: list[dict] = []

    # An empty ``terms`` list matches no documents in OpenSearch. Always emitting
    # the clause would therefore make every un-scoped search return nothing, so
    # the clause is added only when there is something to filter on.
    if group_ids:
        filters.append({'terms': {'group_id': group_ids}})

    if search_filters.node_labels:
        filters.append({'terms': {'node_labels': search_filters.node_labels}})

    return filters
256
+
257
+
258
def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
    """Build the OpenSearch ``bool.filter`` clauses for an entity-edge search.

    Args:
        group_ids: Group ids to restrict the search to. An empty list means
            "no group restriction". Callers pass ``group_ids or []`` when no
            groups were requested.
        search_filters: Caller-supplied search filters. The following are applied:
            - ``edge_types``
            - the date filters ``valid_at``, ``invalid_at``, ``created_at`` and
              ``expired_at``

    Returns:
        A list of OpenSearch filter clauses. It may be empty.
    """
    filters: list[dict] = []

    # An empty ``terms`` list matches no documents in OpenSearch. Only restrict by
    # group when groups were actually requested.
    if group_ids:
        filters.append({'terms': {'group_id': group_ids}})

    if search_filters.edge_types:
        filters.append({'terms': {'edge_types': search_filters.edge_types}})

    for field in ('valid_at', 'invalid_at', 'created_at', 'expired_at'):
        ranges = getattr(search_filters, field)
        if not ranges:
            continue

        # ``ranges`` is an OR of AND-groups:
        #   - every DateFilter inside one group must hold (bool.filter);
        #   - at least one group must match (should + minimum_should_match=1).
        # NOTE(review): a DateFilter whose ``date`` is None would yield a range
        # with a null bound. Confirm such filters cannot reach this path.
        should_clauses = []
        for and_group in ranges:
            and_filters = [
                {'range': {field: {cypher_to_opensearch_operator(df.comparison_operator): df.date}}}
                for df in and_group
            ]
            should_clauses.append({'bool': {'filter': and_filters}})

        filters.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}})

    return filters
@@ -51,6 +51,8 @@ from graphiti_core.nodes import (
51
51
  )
52
52
  from graphiti_core.search.search_filters import (
53
53
  SearchFilters,
54
+ build_aoss_edge_filters,
55
+ build_aoss_node_filters,
54
56
  edge_search_filter_query_constructor,
55
57
  node_search_filter_query_constructor,
56
58
  )
@@ -200,7 +202,6 @@ async def edge_fulltext_search(
200
202
  if driver.provider == GraphProvider.NEPTUNE:
201
203
  res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
202
204
  if res['hits']['total']['value'] > 0:
203
- # Calculate Cosine similarity then return the edge ids
204
205
  input_ids = []
205
206
  for r in res['hits']['hits']:
206
207
  input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@@ -208,11 +209,11 @@ async def edge_fulltext_search(
208
209
  # Match the edge ids and return the values
209
210
  query = (
210
211
  """
211
- UNWIND $ids as id
212
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
213
- WHERE e.group_id IN $group_ids
214
- AND id(e)=id
215
- """
212
+ UNWIND $ids as id
213
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
214
+ WHERE e.group_id IN $group_ids
215
+ AND id(e)=id
216
+ """
216
217
  + filter_query
217
218
  + """
218
219
  AND id(e)=id
@@ -244,6 +245,31 @@ async def edge_fulltext_search(
244
245
  )
245
246
  else:
246
247
  return []
248
+ elif driver.aoss_client:
249
+ route = group_ids[0] if group_ids else None
250
+ filters = build_aoss_edge_filters(group_ids or [], search_filter)
251
+ res = driver.aoss_client.search(
252
+ index='entity_edges',
253
+ routing=route,
254
+ _source=['uuid'],
255
+ query={
256
+ 'bool': {
257
+ 'filter': filters,
258
+ 'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
259
+ }
260
+ },
261
+ )
262
+ if res['hits']['total']['value'] > 0:
263
+ input_uuids = {}
264
+ for r in res['hits']['hits']:
265
+ input_uuids[r['_source']['uuid']] = r['_score']
266
+
267
+ # Get edges
268
+ entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
269
+ entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
270
+ return entity_edges
271
+ else:
272
+ return []
247
273
  else:
248
274
  query = (
249
275
  get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
@@ -318,8 +344,8 @@ async def edge_similarity_search(
318
344
  if driver.provider == GraphProvider.NEPTUNE:
319
345
  query = (
320
346
  """
321
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
322
- """
347
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
348
+ """
323
349
  + filter_query
324
350
  + """
325
351
  RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@@ -377,6 +403,32 @@ async def edge_similarity_search(
377
403
  )
378
404
  else:
379
405
  return []
406
+ elif driver.aoss_client:
407
+ route = group_ids[0] if group_ids else None
408
+ filters = build_aoss_edge_filters(group_ids or [], search_filter)
409
+ res = driver.aoss_client.search(
410
+ index='entity_edges',
411
+ routing=route,
412
+ _source=['uuid'],
413
+ knn={
414
+ 'field': 'fact_embedding',
415
+ 'query_vector': search_vector,
416
+ 'k': limit,
417
+ 'num_candidates': 1000,
418
+ },
419
+ query={'bool': {'filter': filters}},
420
+ )
421
+
422
+ if res['hits']['total']['value'] > 0:
423
+ input_uuids = {}
424
+ for r in res['hits']['hits']:
425
+ input_uuids[r['_source']['uuid']] = r['_score']
426
+
427
+ # Get edges
428
+ entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
429
+ entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
430
+ return entity_edges
431
+
380
432
  else:
381
433
  query = (
382
434
  match_query
@@ -563,7 +615,6 @@ async def node_fulltext_search(
563
615
  if driver.provider == GraphProvider.NEPTUNE:
564
616
  res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
565
617
  if res['hits']['total']['value'] > 0:
566
- # Calculate Cosine similarity then return the edge ids
567
618
  input_ids = []
568
619
  for r in res['hits']['hits']:
569
620
  input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@@ -571,11 +622,11 @@ async def node_fulltext_search(
571
622
  # Match the edge ides and return the values
572
623
  query = (
573
624
  """
574
- UNWIND $ids as i
575
- MATCH (n:Entity)
576
- WHERE n.uuid=i.id
577
- RETURN
578
- """
625
+ UNWIND $ids as i
626
+ MATCH (n:Entity)
627
+ WHERE n.uuid=i.id
628
+ RETURN
629
+ """
579
630
  + get_entity_node_return_query(driver.provider)
580
631
  + """
581
632
  ORDER BY i.score DESC
@@ -592,6 +643,41 @@ async def node_fulltext_search(
592
643
  )
593
644
  else:
594
645
  return []
646
+ elif driver.aoss_client:
647
+ route = group_ids[0] if group_ids else None
648
+ filters = build_aoss_node_filters(group_ids or [], search_filter)
649
+ res = driver.aoss_client.search(
650
+ 'entities',
651
+ routing=route,
652
+ _source=['uuid'],
653
+ query={
654
+ 'bool': {
655
+ 'filter': filters,
656
+ 'must': [
657
+ {
658
+ 'multi_match': {
659
+ 'query': query,
660
+ 'field': ['name', 'summary'],
661
+ 'operator': 'or',
662
+ }
663
+ }
664
+ ],
665
+ }
666
+ },
667
+ limit=limit,
668
+ )
669
+
670
+ if res['hits']['total']['value'] > 0:
671
+ input_uuids = {}
672
+ for r in res['hits']['hits']:
673
+ input_uuids[r['_source']['uuid']] = r['_score']
674
+
675
+ # Get nodes
676
+ entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
677
+ entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
678
+ return entities
679
+ else:
680
+ return []
595
681
  else:
596
682
  query = (
597
683
  get_nodes_query(
@@ -648,8 +734,8 @@ async def node_similarity_search(
648
734
  if driver.provider == GraphProvider.NEPTUNE:
649
735
  query = (
650
736
  """
651
- MATCH (n:Entity)
652
- """
737
+ MATCH (n:Entity)
738
+ """
653
739
  + filter_query
654
740
  + """
655
741
  RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@@ -678,11 +764,11 @@ async def node_similarity_search(
678
764
  # Match the edge ides and return the values
679
765
  query = (
680
766
  """
681
- UNWIND $ids as i
682
- MATCH (n:Entity)
683
- WHERE id(n)=i.id
684
- RETURN
685
- """
767
+ UNWIND $ids as i
768
+ MATCH (n:Entity)
769
+ WHERE id(n)=i.id
770
+ RETURN
771
+ """
686
772
  + get_entity_node_return_query(driver.provider)
687
773
  + """
688
774
  ORDER BY i.score DESC
@@ -700,11 +786,36 @@ async def node_similarity_search(
700
786
  )
701
787
  else:
702
788
  return []
789
+ elif driver.aoss_client:
790
+ route = group_ids[0] if group_ids else None
791
+ filters = build_aoss_node_filters(group_ids or [], search_filter)
792
+ res = driver.aoss_client.search(
793
+ index='entities',
794
+ routing=route,
795
+ _source=['uuid'],
796
+ knn={
797
+ 'field': 'fact_embedding',
798
+ 'query_vector': search_vector,
799
+ 'k': limit,
800
+ 'num_candidates': 1000,
801
+ },
802
+ query={'bool': {'filter': filters}},
803
+ )
804
+
805
+ if res['hits']['total']['value'] > 0:
806
+ input_uuids = {}
807
+ for r in res['hits']['hits']:
808
+ input_uuids[r['_source']['uuid']] = r['_score']
809
+
810
+ # Get edges
811
+ entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
812
+ entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
813
+ return entity_nodes
703
814
  else:
704
815
  query = (
705
816
  """
706
- MATCH (n:Entity)
707
- """
817
+ MATCH (n:Entity)
818
+ """
708
819
  + filter_query
709
820
  + """
710
821
  WITH n, """
@@ -843,7 +954,6 @@ async def episode_fulltext_search(
843
954
  if driver.provider == GraphProvider.NEPTUNE:
844
955
  res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
845
956
  if res['hits']['total']['value'] > 0:
846
- # Calculate Cosine similarity then return the edge ids
847
957
  input_ids = []
848
958
  for r in res['hits']['hits']:
849
959
  input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
@@ -852,7 +962,7 @@ async def episode_fulltext_search(
852
962
  query = """
853
963
  UNWIND $ids as i
854
964
  MATCH (e:Episodic)
855
- WHERE e.uuid=i.id
965
+ WHERE e.uuid=i.uuid
856
966
  RETURN
857
967
  e.content AS content,
858
968
  e.created_at AS created_at,
@@ -876,6 +986,40 @@ async def episode_fulltext_search(
876
986
  )
877
987
  else:
878
988
  return []
989
+ elif driver.aoss_client:
990
+ route = group_ids[0] if group_ids else None
991
+ res = driver.aoss_client.search(
992
+ 'episodes',
993
+ routing=route,
994
+ _source=['uuid'],
995
+ query={
996
+ 'bool': {
997
+ 'filter': {'terms': group_ids},
998
+ 'must': [
999
+ {
1000
+ 'multi_match': {
1001
+ 'query': query,
1002
+ 'field': ['name', 'content'],
1003
+ 'operator': 'or',
1004
+ }
1005
+ }
1006
+ ],
1007
+ }
1008
+ },
1009
+ limit=limit,
1010
+ )
1011
+
1012
+ if res['hits']['total']['value'] > 0:
1013
+ input_uuids = {}
1014
+ for r in res['hits']['hits']:
1015
+ input_uuids[r['_source']['uuid']] = r['_score']
1016
+
1017
+ # Get nodes
1018
+ episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys()))
1019
+ episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
1020
+ return episodes
1021
+ else:
1022
+ return []
879
1023
  else:
880
1024
  query = (
881
1025
  get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
@@ -1003,8 +1147,8 @@ async def community_similarity_search(
1003
1147
  if driver.provider == GraphProvider.NEPTUNE:
1004
1148
  query = (
1005
1149
  """
1006
- MATCH (n:Community)
1007
- """
1150
+ MATCH (n:Community)
1151
+ """
1008
1152
  + group_filter_query
1009
1153
  + """
1010
1154
  RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@@ -1063,8 +1207,8 @@ async def community_similarity_search(
1063
1207
 
1064
1208
  query = (
1065
1209
  """
1066
- MATCH (c:Community)
1067
- """
1210
+ MATCH (c:Community)
1211
+ """
1068
1212
  + group_filter_query
1069
1213
  + """
1070
1214
  WITH c,
@@ -1206,9 +1350,9 @@ async def get_relevant_nodes(
1206
1350
  # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
1207
1351
  query = (
1208
1352
  """
1209
- UNWIND $nodes AS node
1210
- MATCH (n:Entity {group_id: $group_id})
1211
- """
1353
+ UNWIND $nodes AS node
1354
+ MATCH (n:Entity {group_id: $group_id})
1355
+ """
1212
1356
  + filter_query
1213
1357
  + """
1214
1358
  WITH node, n, """
@@ -1253,9 +1397,9 @@ async def get_relevant_nodes(
1253
1397
  else:
1254
1398
  query = (
1255
1399
  """
1256
- UNWIND $nodes AS node
1257
- MATCH (n:Entity {group_id: $group_id})
1258
- """
1400
+ UNWIND $nodes AS node
1401
+ MATCH (n:Entity {group_id: $group_id})
1402
+ """
1259
1403
  + filter_query
1260
1404
  + """
1261
1405
  WITH node, n, """
@@ -1344,9 +1488,9 @@ async def get_relevant_edges(
1344
1488
  if driver.provider == GraphProvider.NEPTUNE:
1345
1489
  query = (
1346
1490
  """
1347
- UNWIND $edges AS edge
1348
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1349
- """
1491
+ UNWIND $edges AS edge
1492
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1493
+ """
1350
1494
  + filter_query
1351
1495
  + """
1352
1496
  WITH e, edge
@@ -1416,9 +1560,9 @@ async def get_relevant_edges(
1416
1560
 
1417
1561
  query = (
1418
1562
  """
1419
- UNWIND $edges AS edge
1420
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
1421
- """
1563
+ UNWIND $edges AS edge
1564
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
1565
+ """
1422
1566
  + filter_query
1423
1567
  + """
1424
1568
  WITH e, edge, n, m, """
@@ -1454,9 +1598,9 @@ async def get_relevant_edges(
1454
1598
  else:
1455
1599
  query = (
1456
1600
  """
1457
- UNWIND $edges AS edge
1458
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1459
- """
1601
+ UNWIND $edges AS edge
1602
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1603
+ """
1460
1604
  + filter_query
1461
1605
  + """
1462
1606
  WITH e, edge, """
@@ -1529,10 +1673,10 @@ async def get_edge_invalidation_candidates(
1529
1673
  if driver.provider == GraphProvider.NEPTUNE:
1530
1674
  query = (
1531
1675
  """
1532
- UNWIND $edges AS edge
1533
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1534
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1535
- """
1676
+ UNWIND $edges AS edge
1677
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1678
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1679
+ """
1536
1680
  + filter_query
1537
1681
  + """
1538
1682
  WITH e, edge
@@ -1602,10 +1746,10 @@ async def get_edge_invalidation_candidates(
1602
1746
 
1603
1747
  query = (
1604
1748
  """
1605
- UNWIND $edges AS edge
1606
- MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
1607
- WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
1608
- """
1749
+ UNWIND $edges AS edge
1750
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
1751
+ WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
1752
+ """
1609
1753
  + filter_query
1610
1754
  + """
1611
1755
  WITH edge, e, n, m, """
@@ -1641,10 +1785,10 @@ async def get_edge_invalidation_candidates(
1641
1785
  else:
1642
1786
  query = (
1643
1787
  """
1644
- UNWIND $edges AS edge
1645
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1646
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1647
- """
1788
+ UNWIND $edges AS edge
1789
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1790
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1791
+ """
1648
1792
  + filter_query
1649
1793
  + """
1650
1794
  WITH edge, e, """
@@ -187,12 +187,25 @@ async def add_nodes_and_edges_bulk_tx(
187
187
  await tx.run(episodic_edge_query, **edge.model_dump())
188
188
  else:
189
189
  await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
190
- await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
190
+ await tx.run(
191
+ get_entity_node_save_bulk_query(driver.provider, nodes),
192
+ nodes=nodes,
193
+ has_aoss=bool(driver.aoss_client),
194
+ )
191
195
  await tx.run(
192
196
  get_episodic_edge_save_bulk_query(driver.provider),
193
197
  episodic_edges=[edge.model_dump() for edge in episodic_edges],
194
198
  )
195
- await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
199
+ await tx.run(
200
+ get_entity_edge_save_bulk_query(driver.provider),
201
+ entity_edges=edges,
202
+ has_aoss=bool(driver.aoss_client),
203
+ )
204
+
205
+ if driver.aoss_client:
206
+ driver.save_to_aoss('episodes', episodes)
207
+ driver.save_to_aoss('entities', nodes)
208
+ driver.save_to_aoss('entity_edges', edges)
196
209
 
197
210
 
198
211
  async def extract_nodes_and_edges_bulk(
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
34
34
 
35
35
 
36
36
  async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
37
- if driver.provider == GraphProvider.NEPTUNE:
37
+ if driver.aoss_client:
38
38
  await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
39
39
  return
40
40
  if delete_existing:
@@ -56,7 +56,9 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
56
56
 
57
57
  range_indices: list[LiteralString] = get_range_indices(driver.provider)
58
58
 
59
- fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
59
+ # Don't create fulltext indices if OpenSearch is being used
60
+ if not driver.aoss_client:
61
+ fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
60
62
 
61
63
  if driver.provider == GraphProvider.KUZU:
62
64
  # Skip creating fulltext indices if they already exist. Need to do this manually
@@ -149,9 +151,9 @@ async def retrieve_episodes(
149
151
 
150
152
  query: LiteralString = (
151
153
  """
152
- MATCH (e:Episodic)
153
- WHERE e.valid_at <= $reference_time
154
- """
154
+ MATCH (e:Episodic)
155
+ WHERE e.valid_at <= $reference_time
156
+ """
155
157
  + query_filter
156
158
  + """
157
159
  RETURN
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphiti-core
3
- Version: 0.20.3
3
+ Version: 0.21.0rc1
4
4
  Summary: A temporal graph building library
5
5
  Project-URL: Homepage, https://help.getzep.com/graphiti/graphiti/overview
6
6
  Project-URL: Repository, https://github.com/getzep/graphiti
@@ -47,6 +47,9 @@ Provides-Extra: groq
47
47
  Requires-Dist: groq>=0.2.0; extra == 'groq'
48
48
  Provides-Extra: kuzu
49
49
  Requires-Dist: kuzu>=0.11.2; extra == 'kuzu'
50
+ Provides-Extra: neo4j-opensearch
51
+ Requires-Dist: boto3>=1.39.16; extra == 'neo4j-opensearch'
52
+ Requires-Dist: opensearch-py>=3.0.0; extra == 'neo4j-opensearch'
50
53
  Provides-Extra: neptune
51
54
  Requires-Dist: boto3>=1.39.16; extra == 'neptune'
52
55
  Requires-Dist: langchain-aws>=0.2.29; extra == 'neptune'