graphiti-core 0.20.4__py3-none-any.whl → 0.21.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +224 -0
- graphiti_core/driver/falkordb_driver.py +1 -0
- graphiti_core/driver/kuzu_driver.py +1 -0
- graphiti_core/driver/neo4j_driver.py +59 -2
- graphiti_core/driver/neptune_driver.py +26 -45
- graphiti_core/edges.py +61 -4
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graphiti.py +21 -5
- graphiti_core/models/edges/edge_db_queries.py +36 -16
- graphiti_core/models/nodes/node_db_queries.py +30 -10
- graphiti_core/nodes.py +120 -22
- graphiti_core/search/search_filters.py +53 -0
- graphiti_core/search/search_utils.py +225 -57
- graphiti_core/utils/bulk_utils.py +23 -3
- graphiti_core/utils/maintenance/edge_operations.py +39 -5
- graphiti_core/utils/maintenance/graph_data_operations.py +9 -5
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0rc2.dist-info}/METADATA +4 -1
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0rc2.dist-info}/RECORD +20 -20
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0rc2.dist-info}/WHEEL +0 -0
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0rc2.dist-info}/licenses/LICENSE +0 -0
|
@@ -23,7 +23,13 @@ import numpy as np
|
|
|
23
23
|
from numpy._typing import NDArray
|
|
24
24
|
from typing_extensions import LiteralString
|
|
25
25
|
|
|
26
|
-
from graphiti_core.driver.driver import
|
|
26
|
+
from graphiti_core.driver.driver import (
|
|
27
|
+
ENTITY_EDGE_INDEX_NAME,
|
|
28
|
+
ENTITY_INDEX_NAME,
|
|
29
|
+
EPISODE_INDEX_NAME,
|
|
30
|
+
GraphDriver,
|
|
31
|
+
GraphProvider,
|
|
32
|
+
)
|
|
27
33
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
|
28
34
|
from graphiti_core.graph_queries import (
|
|
29
35
|
get_nodes_query,
|
|
@@ -51,6 +57,8 @@ from graphiti_core.nodes import (
|
|
|
51
57
|
)
|
|
52
58
|
from graphiti_core.search.search_filters import (
|
|
53
59
|
SearchFilters,
|
|
60
|
+
build_aoss_edge_filters,
|
|
61
|
+
build_aoss_node_filters,
|
|
54
62
|
edge_search_filter_query_constructor,
|
|
55
63
|
node_search_filter_query_constructor,
|
|
56
64
|
)
|
|
@@ -200,7 +208,6 @@ async def edge_fulltext_search(
|
|
|
200
208
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
201
209
|
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
|
|
202
210
|
if res['hits']['total']['value'] > 0:
|
|
203
|
-
# Calculate Cosine similarity then return the edge ids
|
|
204
211
|
input_ids = []
|
|
205
212
|
for r in res['hits']['hits']:
|
|
206
213
|
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
@@ -208,11 +215,11 @@ async def edge_fulltext_search(
|
|
|
208
215
|
# Match the edge ids and return the values
|
|
209
216
|
query = (
|
|
210
217
|
"""
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
218
|
+
UNWIND $ids as id
|
|
219
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
220
|
+
WHERE e.group_id IN $group_ids
|
|
221
|
+
AND id(e)=id
|
|
222
|
+
"""
|
|
216
223
|
+ filter_query
|
|
217
224
|
+ """
|
|
218
225
|
AND id(e)=id
|
|
@@ -244,6 +251,35 @@ async def edge_fulltext_search(
|
|
|
244
251
|
)
|
|
245
252
|
else:
|
|
246
253
|
return []
|
|
254
|
+
elif driver.aoss_client:
|
|
255
|
+
route = group_ids[0] if group_ids else None
|
|
256
|
+
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
|
257
|
+
res = await driver.aoss_client.search(
|
|
258
|
+
index=ENTITY_EDGE_INDEX_NAME,
|
|
259
|
+
params={'routing': route},
|
|
260
|
+
body={
|
|
261
|
+
'size': limit,
|
|
262
|
+
'_source': ['uuid'],
|
|
263
|
+
'query': {
|
|
264
|
+
'bool': {
|
|
265
|
+
'filter': filters,
|
|
266
|
+
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
|
|
267
|
+
}
|
|
268
|
+
},
|
|
269
|
+
},
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
if res['hits']['total']['value'] > 0:
|
|
273
|
+
input_uuids = {}
|
|
274
|
+
for r in res['hits']['hits']:
|
|
275
|
+
input_uuids[r['_source']['uuid']] = r['_score']
|
|
276
|
+
|
|
277
|
+
# Get edges
|
|
278
|
+
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
|
279
|
+
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
|
280
|
+
return entity_edges
|
|
281
|
+
else:
|
|
282
|
+
return []
|
|
247
283
|
else:
|
|
248
284
|
query = (
|
|
249
285
|
get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
|
|
@@ -318,8 +354,8 @@ async def edge_similarity_search(
|
|
|
318
354
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
319
355
|
query = (
|
|
320
356
|
"""
|
|
321
|
-
|
|
322
|
-
|
|
357
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
358
|
+
"""
|
|
323
359
|
+ filter_query
|
|
324
360
|
+ """
|
|
325
361
|
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
|
@@ -377,6 +413,38 @@ async def edge_similarity_search(
|
|
|
377
413
|
)
|
|
378
414
|
else:
|
|
379
415
|
return []
|
|
416
|
+
elif driver.aoss_client:
|
|
417
|
+
route = group_ids[0] if group_ids else None
|
|
418
|
+
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
|
419
|
+
res = await driver.aoss_client.search(
|
|
420
|
+
index=ENTITY_EDGE_INDEX_NAME,
|
|
421
|
+
params={'routing': route},
|
|
422
|
+
body={
|
|
423
|
+
'size': limit,
|
|
424
|
+
'_source': ['uuid'],
|
|
425
|
+
'query': {
|
|
426
|
+
'knn': {
|
|
427
|
+
'fact_embedding': {
|
|
428
|
+
'vector': list(map(float, search_vector)),
|
|
429
|
+
'k': limit,
|
|
430
|
+
'filter': {'bool': {'filter': filters}},
|
|
431
|
+
}
|
|
432
|
+
}
|
|
433
|
+
},
|
|
434
|
+
},
|
|
435
|
+
)
|
|
436
|
+
|
|
437
|
+
if res['hits']['total']['value'] > 0:
|
|
438
|
+
input_uuids = {}
|
|
439
|
+
for r in res['hits']['hits']:
|
|
440
|
+
input_uuids[r['_source']['uuid']] = r['_score']
|
|
441
|
+
|
|
442
|
+
# Get edges
|
|
443
|
+
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
|
444
|
+
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
|
445
|
+
return entity_edges
|
|
446
|
+
return []
|
|
447
|
+
|
|
380
448
|
else:
|
|
381
449
|
query = (
|
|
382
450
|
match_query
|
|
@@ -563,7 +631,6 @@ async def node_fulltext_search(
|
|
|
563
631
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
564
632
|
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
565
633
|
if res['hits']['total']['value'] > 0:
|
|
566
|
-
# Calculate Cosine similarity then return the edge ids
|
|
567
634
|
input_ids = []
|
|
568
635
|
for r in res['hits']['hits']:
|
|
569
636
|
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
@@ -571,11 +638,11 @@ async def node_fulltext_search(
|
|
|
571
638
|
# Match the edge ides and return the values
|
|
572
639
|
query = (
|
|
573
640
|
"""
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
641
|
+
UNWIND $ids as i
|
|
642
|
+
MATCH (n:Entity)
|
|
643
|
+
WHERE n.uuid=i.id
|
|
644
|
+
RETURN
|
|
645
|
+
"""
|
|
579
646
|
+ get_entity_node_return_query(driver.provider)
|
|
580
647
|
+ """
|
|
581
648
|
ORDER BY i.score DESC
|
|
@@ -592,6 +659,43 @@ async def node_fulltext_search(
|
|
|
592
659
|
)
|
|
593
660
|
else:
|
|
594
661
|
return []
|
|
662
|
+
elif driver.aoss_client:
|
|
663
|
+
route = group_ids[0] if group_ids else None
|
|
664
|
+
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
|
665
|
+
res = await driver.aoss_client.search(
|
|
666
|
+
index=ENTITY_INDEX_NAME,
|
|
667
|
+
params={'routing': route},
|
|
668
|
+
body={
|
|
669
|
+
'_source': ['uuid'],
|
|
670
|
+
'size': limit,
|
|
671
|
+
'query': {
|
|
672
|
+
'bool': {
|
|
673
|
+
'filter': filters,
|
|
674
|
+
'must': [
|
|
675
|
+
{
|
|
676
|
+
'multi_match': {
|
|
677
|
+
'query': query,
|
|
678
|
+
'fields': ['name', 'summary'],
|
|
679
|
+
'operator': 'or',
|
|
680
|
+
}
|
|
681
|
+
}
|
|
682
|
+
],
|
|
683
|
+
}
|
|
684
|
+
},
|
|
685
|
+
},
|
|
686
|
+
)
|
|
687
|
+
|
|
688
|
+
if res['hits']['total']['value'] > 0:
|
|
689
|
+
input_uuids = {}
|
|
690
|
+
for r in res['hits']['hits']:
|
|
691
|
+
input_uuids[r['_source']['uuid']] = r['_score']
|
|
692
|
+
|
|
693
|
+
# Get nodes
|
|
694
|
+
entities = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
|
|
695
|
+
entities.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
|
696
|
+
return entities
|
|
697
|
+
else:
|
|
698
|
+
return []
|
|
595
699
|
else:
|
|
596
700
|
query = (
|
|
597
701
|
get_nodes_query(
|
|
@@ -648,8 +752,8 @@ async def node_similarity_search(
|
|
|
648
752
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
649
753
|
query = (
|
|
650
754
|
"""
|
|
651
|
-
|
|
652
|
-
|
|
755
|
+
MATCH (n:Entity)
|
|
756
|
+
"""
|
|
653
757
|
+ filter_query
|
|
654
758
|
+ """
|
|
655
759
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
@@ -678,11 +782,11 @@ async def node_similarity_search(
|
|
|
678
782
|
# Match the edge ides and return the values
|
|
679
783
|
query = (
|
|
680
784
|
"""
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
785
|
+
UNWIND $ids as i
|
|
786
|
+
MATCH (n:Entity)
|
|
787
|
+
WHERE id(n)=i.id
|
|
788
|
+
RETURN
|
|
789
|
+
"""
|
|
686
790
|
+ get_entity_node_return_query(driver.provider)
|
|
687
791
|
+ """
|
|
688
792
|
ORDER BY i.score DESC
|
|
@@ -700,11 +804,42 @@ async def node_similarity_search(
|
|
|
700
804
|
)
|
|
701
805
|
else:
|
|
702
806
|
return []
|
|
807
|
+
elif driver.aoss_client:
|
|
808
|
+
route = group_ids[0] if group_ids else None
|
|
809
|
+
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
|
810
|
+
res = await driver.aoss_client.search(
|
|
811
|
+
index=ENTITY_INDEX_NAME,
|
|
812
|
+
params={'routing': route},
|
|
813
|
+
body={
|
|
814
|
+
'size': limit,
|
|
815
|
+
'_source': ['uuid'],
|
|
816
|
+
'query': {
|
|
817
|
+
'knn': {
|
|
818
|
+
'name_embedding': {
|
|
819
|
+
'vector': list(map(float, search_vector)),
|
|
820
|
+
'k': limit,
|
|
821
|
+
'filter': {'bool': {'filter': filters}},
|
|
822
|
+
}
|
|
823
|
+
}
|
|
824
|
+
},
|
|
825
|
+
},
|
|
826
|
+
)
|
|
827
|
+
|
|
828
|
+
if res['hits']['total']['value'] > 0:
|
|
829
|
+
input_uuids = {}
|
|
830
|
+
for r in res['hits']['hits']:
|
|
831
|
+
input_uuids[r['_source']['uuid']] = r['_score']
|
|
832
|
+
|
|
833
|
+
# Get edges
|
|
834
|
+
entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
|
|
835
|
+
entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
|
836
|
+
return entity_nodes
|
|
837
|
+
return []
|
|
703
838
|
else:
|
|
704
839
|
query = (
|
|
705
840
|
"""
|
|
706
|
-
|
|
707
|
-
|
|
841
|
+
MATCH (n:Entity)
|
|
842
|
+
"""
|
|
708
843
|
+ filter_query
|
|
709
844
|
+ """
|
|
710
845
|
WITH n, """
|
|
@@ -843,7 +978,6 @@ async def episode_fulltext_search(
|
|
|
843
978
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
844
979
|
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
845
980
|
if res['hits']['total']['value'] > 0:
|
|
846
|
-
# Calculate Cosine similarity then return the edge ids
|
|
847
981
|
input_ids = []
|
|
848
982
|
for r in res['hits']['hits']:
|
|
849
983
|
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
@@ -852,7 +986,7 @@ async def episode_fulltext_search(
|
|
|
852
986
|
query = """
|
|
853
987
|
UNWIND $ids as i
|
|
854
988
|
MATCH (e:Episodic)
|
|
855
|
-
WHERE e.uuid=i.
|
|
989
|
+
WHERE e.uuid=i.uuid
|
|
856
990
|
RETURN
|
|
857
991
|
e.content AS content,
|
|
858
992
|
e.created_at AS created_at,
|
|
@@ -876,6 +1010,40 @@ async def episode_fulltext_search(
|
|
|
876
1010
|
)
|
|
877
1011
|
else:
|
|
878
1012
|
return []
|
|
1013
|
+
elif driver.aoss_client:
|
|
1014
|
+
route = group_ids[0] if group_ids else None
|
|
1015
|
+
res = await driver.aoss_client.search(
|
|
1016
|
+
index=EPISODE_INDEX_NAME,
|
|
1017
|
+
params={'routing': route},
|
|
1018
|
+
body={
|
|
1019
|
+
'size': limit,
|
|
1020
|
+
'_source': ['uuid'],
|
|
1021
|
+
'bool': {
|
|
1022
|
+
'filter': {'terms': group_ids},
|
|
1023
|
+
'must': [
|
|
1024
|
+
{
|
|
1025
|
+
'multi_match': {
|
|
1026
|
+
'query': query,
|
|
1027
|
+
'field': ['name', 'content'],
|
|
1028
|
+
'operator': 'or',
|
|
1029
|
+
}
|
|
1030
|
+
}
|
|
1031
|
+
],
|
|
1032
|
+
},
|
|
1033
|
+
},
|
|
1034
|
+
)
|
|
1035
|
+
|
|
1036
|
+
if res['hits']['total']['value'] > 0:
|
|
1037
|
+
input_uuids = {}
|
|
1038
|
+
for r in res['hits']['hits']:
|
|
1039
|
+
input_uuids[r['_source']['uuid']] = r['_score']
|
|
1040
|
+
|
|
1041
|
+
# Get nodes
|
|
1042
|
+
episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys()))
|
|
1043
|
+
episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
|
1044
|
+
return episodes
|
|
1045
|
+
else:
|
|
1046
|
+
return []
|
|
879
1047
|
else:
|
|
880
1048
|
query = (
|
|
881
1049
|
get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
|
|
@@ -1003,8 +1171,8 @@ async def community_similarity_search(
|
|
|
1003
1171
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
1004
1172
|
query = (
|
|
1005
1173
|
"""
|
|
1006
|
-
|
|
1007
|
-
|
|
1174
|
+
MATCH (n:Community)
|
|
1175
|
+
"""
|
|
1008
1176
|
+ group_filter_query
|
|
1009
1177
|
+ """
|
|
1010
1178
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
@@ -1063,8 +1231,8 @@ async def community_similarity_search(
|
|
|
1063
1231
|
|
|
1064
1232
|
query = (
|
|
1065
1233
|
"""
|
|
1066
|
-
|
|
1067
|
-
|
|
1234
|
+
MATCH (c:Community)
|
|
1235
|
+
"""
|
|
1068
1236
|
+ group_filter_query
|
|
1069
1237
|
+ """
|
|
1070
1238
|
WITH c,
|
|
@@ -1206,9 +1374,9 @@ async def get_relevant_nodes(
|
|
|
1206
1374
|
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
|
1207
1375
|
query = (
|
|
1208
1376
|
"""
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1377
|
+
UNWIND $nodes AS node
|
|
1378
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
1379
|
+
"""
|
|
1212
1380
|
+ filter_query
|
|
1213
1381
|
+ """
|
|
1214
1382
|
WITH node, n, """
|
|
@@ -1253,9 +1421,9 @@ async def get_relevant_nodes(
|
|
|
1253
1421
|
else:
|
|
1254
1422
|
query = (
|
|
1255
1423
|
"""
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
|
|
1424
|
+
UNWIND $nodes AS node
|
|
1425
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
1426
|
+
"""
|
|
1259
1427
|
+ filter_query
|
|
1260
1428
|
+ """
|
|
1261
1429
|
WITH node, n, """
|
|
@@ -1344,9 +1512,9 @@ async def get_relevant_edges(
|
|
|
1344
1512
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
1345
1513
|
query = (
|
|
1346
1514
|
"""
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1515
|
+
UNWIND $edges AS edge
|
|
1516
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1517
|
+
"""
|
|
1350
1518
|
+ filter_query
|
|
1351
1519
|
+ """
|
|
1352
1520
|
WITH e, edge
|
|
@@ -1416,9 +1584,9 @@ async def get_relevant_edges(
|
|
|
1416
1584
|
|
|
1417
1585
|
query = (
|
|
1418
1586
|
"""
|
|
1419
|
-
|
|
1420
|
-
|
|
1421
|
-
|
|
1587
|
+
UNWIND $edges AS edge
|
|
1588
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1589
|
+
"""
|
|
1422
1590
|
+ filter_query
|
|
1423
1591
|
+ """
|
|
1424
1592
|
WITH e, edge, n, m, """
|
|
@@ -1454,9 +1622,9 @@ async def get_relevant_edges(
|
|
|
1454
1622
|
else:
|
|
1455
1623
|
query = (
|
|
1456
1624
|
"""
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
|
|
1625
|
+
UNWIND $edges AS edge
|
|
1626
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1627
|
+
"""
|
|
1460
1628
|
+ filter_query
|
|
1461
1629
|
+ """
|
|
1462
1630
|
WITH e, edge, """
|
|
@@ -1529,10 +1697,10 @@ async def get_edge_invalidation_candidates(
|
|
|
1529
1697
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
1530
1698
|
query = (
|
|
1531
1699
|
"""
|
|
1532
|
-
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
|
|
1700
|
+
UNWIND $edges AS edge
|
|
1701
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1702
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1703
|
+
"""
|
|
1536
1704
|
+ filter_query
|
|
1537
1705
|
+ """
|
|
1538
1706
|
WITH e, edge
|
|
@@ -1602,10 +1770,10 @@ async def get_edge_invalidation_candidates(
|
|
|
1602
1770
|
|
|
1603
1771
|
query = (
|
|
1604
1772
|
"""
|
|
1605
|
-
|
|
1606
|
-
|
|
1607
|
-
|
|
1608
|
-
|
|
1773
|
+
UNWIND $edges AS edge
|
|
1774
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
|
1775
|
+
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
|
1776
|
+
"""
|
|
1609
1777
|
+ filter_query
|
|
1610
1778
|
+ """
|
|
1611
1779
|
WITH edge, e, n, m, """
|
|
@@ -1641,10 +1809,10 @@ async def get_edge_invalidation_candidates(
|
|
|
1641
1809
|
else:
|
|
1642
1810
|
query = (
|
|
1643
1811
|
"""
|
|
1644
|
-
|
|
1645
|
-
|
|
1646
|
-
|
|
1647
|
-
|
|
1812
|
+
UNWIND $edges AS edge
|
|
1813
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1814
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1815
|
+
"""
|
|
1648
1816
|
+ filter_query
|
|
1649
1817
|
+ """
|
|
1650
1818
|
WITH edge, e, """
|
|
@@ -23,7 +23,14 @@ import numpy as np
|
|
|
23
23
|
from pydantic import BaseModel, Field
|
|
24
24
|
from typing_extensions import Any
|
|
25
25
|
|
|
26
|
-
from graphiti_core.driver.driver import
|
|
26
|
+
from graphiti_core.driver.driver import (
|
|
27
|
+
ENTITY_EDGE_INDEX_NAME,
|
|
28
|
+
ENTITY_INDEX_NAME,
|
|
29
|
+
EPISODE_INDEX_NAME,
|
|
30
|
+
GraphDriver,
|
|
31
|
+
GraphDriverSession,
|
|
32
|
+
GraphProvider,
|
|
33
|
+
)
|
|
27
34
|
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
|
|
28
35
|
from graphiti_core.embedder import EmbedderClient
|
|
29
36
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
@@ -187,12 +194,25 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
187
194
|
await tx.run(episodic_edge_query, **edge.model_dump())
|
|
188
195
|
else:
|
|
189
196
|
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
|
190
|
-
await tx.run(
|
|
197
|
+
await tx.run(
|
|
198
|
+
get_entity_node_save_bulk_query(driver.provider, nodes),
|
|
199
|
+
nodes=nodes,
|
|
200
|
+
has_aoss=bool(driver.aoss_client),
|
|
201
|
+
)
|
|
191
202
|
await tx.run(
|
|
192
203
|
get_episodic_edge_save_bulk_query(driver.provider),
|
|
193
204
|
episodic_edges=[edge.model_dump() for edge in episodic_edges],
|
|
194
205
|
)
|
|
195
|
-
await tx.run(
|
|
206
|
+
await tx.run(
|
|
207
|
+
get_entity_edge_save_bulk_query(driver.provider),
|
|
208
|
+
entity_edges=edges,
|
|
209
|
+
has_aoss=bool(driver.aoss_client),
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
if driver.aoss_client:
|
|
213
|
+
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
|
|
214
|
+
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
|
|
215
|
+
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
|
|
196
216
|
|
|
197
217
|
|
|
198
218
|
async def extract_nodes_and_edges_bulk(
|
|
@@ -36,8 +36,10 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
|
|
36
36
|
from graphiti_core.prompts import prompt_library
|
|
37
37
|
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
|
|
38
38
|
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
|
39
|
+
from graphiti_core.search.search import search
|
|
40
|
+
from graphiti_core.search.search_config import SearchResults
|
|
41
|
+
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
|
|
39
42
|
from graphiti_core.search.search_filters import SearchFilters
|
|
40
|
-
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
|
|
41
43
|
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
|
|
42
44
|
|
|
43
45
|
logger = logging.getLogger(__name__)
|
|
@@ -258,12 +260,44 @@ async def resolve_extracted_edges(
|
|
|
258
260
|
embedder = clients.embedder
|
|
259
261
|
await create_entity_edge_embeddings(embedder, extracted_edges)
|
|
260
262
|
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
263
|
+
valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
|
|
264
|
+
*[
|
|
265
|
+
EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
|
|
266
|
+
for edge in extracted_edges
|
|
267
|
+
]
|
|
268
|
+
)
|
|
269
|
+
|
|
270
|
+
related_edges_results: list[SearchResults] = await semaphore_gather(
|
|
271
|
+
*[
|
|
272
|
+
search(
|
|
273
|
+
clients,
|
|
274
|
+
extracted_edge.fact,
|
|
275
|
+
group_ids=[extracted_edge.group_id],
|
|
276
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
277
|
+
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
|
|
278
|
+
)
|
|
279
|
+
for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
|
|
280
|
+
]
|
|
264
281
|
)
|
|
265
282
|
|
|
266
|
-
related_edges_lists
|
|
283
|
+
related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
|
|
284
|
+
|
|
285
|
+
edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
|
|
286
|
+
*[
|
|
287
|
+
search(
|
|
288
|
+
clients,
|
|
289
|
+
extracted_edge.fact,
|
|
290
|
+
group_ids=[extracted_edge.group_id],
|
|
291
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
292
|
+
search_filter=SearchFilters(),
|
|
293
|
+
)
|
|
294
|
+
for extracted_edge in extracted_edges
|
|
295
|
+
]
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
edge_invalidation_candidates: list[list[EntityEdge]] = [
|
|
299
|
+
result.edges for result in edge_invalidation_candidate_results
|
|
300
|
+
]
|
|
267
301
|
|
|
268
302
|
logger.debug(
|
|
269
303
|
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
|
|
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
|
|
37
|
-
if driver.
|
|
37
|
+
if driver.aoss_client:
|
|
38
38
|
await driver.create_aoss_indices() # pyright: ignore[reportAttributeAccessIssue]
|
|
39
39
|
return
|
|
40
40
|
if delete_existing:
|
|
@@ -56,7 +56,9 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
|
|
|
56
56
|
|
|
57
57
|
range_indices: list[LiteralString] = get_range_indices(driver.provider)
|
|
58
58
|
|
|
59
|
-
|
|
59
|
+
# Don't create fulltext indices if OpenSearch is being used
|
|
60
|
+
if not driver.aoss_client:
|
|
61
|
+
fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
|
|
60
62
|
|
|
61
63
|
if driver.provider == GraphProvider.KUZU:
|
|
62
64
|
# Skip creating fulltext indices if they already exist. Need to do this manually
|
|
@@ -93,6 +95,8 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
|
|
93
95
|
|
|
94
96
|
async def delete_all(tx):
|
|
95
97
|
await tx.run('MATCH (n) DETACH DELETE n')
|
|
98
|
+
if driver.aoss_client:
|
|
99
|
+
await driver.clear_aoss_indices()
|
|
96
100
|
|
|
97
101
|
async def delete_group_ids(tx):
|
|
98
102
|
labels = ['Entity', 'Episodic', 'Community']
|
|
@@ -149,9 +153,9 @@ async def retrieve_episodes(
|
|
|
149
153
|
|
|
150
154
|
query: LiteralString = (
|
|
151
155
|
"""
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
156
|
+
MATCH (e:Episodic)
|
|
157
|
+
WHERE e.valid_at <= $reference_time
|
|
158
|
+
"""
|
|
155
159
|
+ query_filter
|
|
156
160
|
+ """
|
|
157
161
|
RETURN
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: graphiti-core
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.21.0rc2
|
|
4
4
|
Summary: A temporal graph building library
|
|
5
5
|
Project-URL: Homepage, https://help.getzep.com/graphiti/graphiti/overview
|
|
6
6
|
Project-URL: Repository, https://github.com/getzep/graphiti
|
|
@@ -47,6 +47,9 @@ Provides-Extra: groq
|
|
|
47
47
|
Requires-Dist: groq>=0.2.0; extra == 'groq'
|
|
48
48
|
Provides-Extra: kuzu
|
|
49
49
|
Requires-Dist: kuzu>=0.11.2; extra == 'kuzu'
|
|
50
|
+
Provides-Extra: neo4j-opensearch
|
|
51
|
+
Requires-Dist: boto3>=1.39.16; extra == 'neo4j-opensearch'
|
|
52
|
+
Requires-Dist: opensearch-py>=3.0.0; extra == 'neo4j-opensearch'
|
|
50
53
|
Provides-Extra: neptune
|
|
51
54
|
Requires-Dist: boto3>=1.39.16; extra == 'neptune'
|
|
52
55
|
Requires-Dist: langchain-aws>=0.2.29; extra == 'neptune'
|