graphiti-core 0.21.0rc1__py3-none-any.whl → 0.21.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +97 -43
- graphiti_core/driver/neo4j_driver.py +18 -7
- graphiti_core/driver/neptune_driver.py +2 -2
- graphiti_core/edges.py +47 -5
- graphiti_core/graphiti.py +21 -5
- graphiti_core/nodes.py +99 -9
- graphiti_core/search/search_filters.py +8 -0
- graphiti_core/search/search_utils.py +130 -106
- graphiti_core/utils/bulk_utils.py +30 -11
- graphiti_core/utils/maintenance/edge_operations.py +39 -5
- graphiti_core/utils/maintenance/graph_data_operations.py +5 -3
- {graphiti_core-0.21.0rc1.dist-info → graphiti_core-0.21.0rc3.dist-info}/METADATA +1 -1
- {graphiti_core-0.21.0rc1.dist-info → graphiti_core-0.21.0rc3.dist-info}/RECORD +15 -15
- {graphiti_core-0.21.0rc1.dist-info → graphiti_core-0.21.0rc3.dist-info}/WHEEL +0 -0
- {graphiti_core-0.21.0rc1.dist-info → graphiti_core-0.21.0rc3.dist-info}/licenses/LICENSE +0 -0
|
@@ -23,7 +23,13 @@ import numpy as np
|
|
|
23
23
|
from numpy._typing import NDArray
|
|
24
24
|
from typing_extensions import LiteralString
|
|
25
25
|
|
|
26
|
-
from graphiti_core.driver.driver import
|
|
26
|
+
from graphiti_core.driver.driver import (
|
|
27
|
+
ENTITY_EDGE_INDEX_NAME,
|
|
28
|
+
ENTITY_INDEX_NAME,
|
|
29
|
+
EPISODE_INDEX_NAME,
|
|
30
|
+
GraphDriver,
|
|
31
|
+
GraphProvider,
|
|
32
|
+
)
|
|
27
33
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
|
28
34
|
from graphiti_core.graph_queries import (
|
|
29
35
|
get_nodes_query,
|
|
@@ -209,11 +215,11 @@ async def edge_fulltext_search(
|
|
|
209
215
|
# Match the edge ids and return the values
|
|
210
216
|
query = (
|
|
211
217
|
"""
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
218
|
+
UNWIND $ids as id
|
|
219
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
220
|
+
WHERE e.group_id IN $group_ids
|
|
221
|
+
AND id(e)=id
|
|
222
|
+
"""
|
|
217
223
|
+ filter_query
|
|
218
224
|
+ """
|
|
219
225
|
AND id(e)=id
|
|
@@ -248,17 +254,21 @@ async def edge_fulltext_search(
|
|
|
248
254
|
elif driver.aoss_client:
|
|
249
255
|
route = group_ids[0] if group_ids else None
|
|
250
256
|
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
|
251
|
-
res = driver.aoss_client.search(
|
|
252
|
-
index=
|
|
253
|
-
routing
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
'
|
|
257
|
-
|
|
258
|
-
'
|
|
259
|
-
|
|
257
|
+
res = await driver.aoss_client.search(
|
|
258
|
+
index=ENTITY_EDGE_INDEX_NAME,
|
|
259
|
+
params={'routing': route},
|
|
260
|
+
body={
|
|
261
|
+
'size': limit,
|
|
262
|
+
'_source': ['uuid'],
|
|
263
|
+
'query': {
|
|
264
|
+
'bool': {
|
|
265
|
+
'filter': filters,
|
|
266
|
+
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
|
|
267
|
+
}
|
|
268
|
+
},
|
|
260
269
|
},
|
|
261
270
|
)
|
|
271
|
+
|
|
262
272
|
if res['hits']['total']['value'] > 0:
|
|
263
273
|
input_uuids = {}
|
|
264
274
|
for r in res['hits']['hits']:
|
|
@@ -344,8 +354,8 @@ async def edge_similarity_search(
|
|
|
344
354
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
345
355
|
query = (
|
|
346
356
|
"""
|
|
347
|
-
|
|
348
|
-
|
|
357
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
358
|
+
"""
|
|
349
359
|
+ filter_query
|
|
350
360
|
+ """
|
|
351
361
|
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
|
@@ -406,17 +416,22 @@ async def edge_similarity_search(
|
|
|
406
416
|
elif driver.aoss_client:
|
|
407
417
|
route = group_ids[0] if group_ids else None
|
|
408
418
|
filters = build_aoss_edge_filters(group_ids or [], search_filter)
|
|
409
|
-
res = driver.aoss_client.search(
|
|
410
|
-
index=
|
|
411
|
-
routing
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
'
|
|
415
|
-
'
|
|
416
|
-
|
|
417
|
-
|
|
419
|
+
res = await driver.aoss_client.search(
|
|
420
|
+
index=ENTITY_EDGE_INDEX_NAME,
|
|
421
|
+
params={'routing': route},
|
|
422
|
+
body={
|
|
423
|
+
'size': limit,
|
|
424
|
+
'_source': ['uuid'],
|
|
425
|
+
'query': {
|
|
426
|
+
'knn': {
|
|
427
|
+
'fact_embedding': {
|
|
428
|
+
'vector': list(map(float, search_vector)),
|
|
429
|
+
'k': limit,
|
|
430
|
+
'filter': {'bool': {'filter': filters}},
|
|
431
|
+
}
|
|
432
|
+
}
|
|
433
|
+
},
|
|
418
434
|
},
|
|
419
|
-
query={'bool': {'filter': filters}},
|
|
420
435
|
)
|
|
421
436
|
|
|
422
437
|
if res['hits']['total']['value'] > 0:
|
|
@@ -428,6 +443,7 @@ async def edge_similarity_search(
|
|
|
428
443
|
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
|
|
429
444
|
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
|
430
445
|
return entity_edges
|
|
446
|
+
return []
|
|
431
447
|
|
|
432
448
|
else:
|
|
433
449
|
query = (
|
|
@@ -622,11 +638,11 @@ async def node_fulltext_search(
|
|
|
622
638
|
# Match the edge ides and return the values
|
|
623
639
|
query = (
|
|
624
640
|
"""
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
641
|
+
UNWIND $ids as i
|
|
642
|
+
MATCH (n:Entity)
|
|
643
|
+
WHERE n.uuid=i.id
|
|
644
|
+
RETURN
|
|
645
|
+
"""
|
|
630
646
|
+ get_entity_node_return_query(driver.provider)
|
|
631
647
|
+ """
|
|
632
648
|
ORDER BY i.score DESC
|
|
@@ -646,25 +662,27 @@ async def node_fulltext_search(
|
|
|
646
662
|
elif driver.aoss_client:
|
|
647
663
|
route = group_ids[0] if group_ids else None
|
|
648
664
|
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
|
649
|
-
res = driver.aoss_client.search(
|
|
650
|
-
|
|
651
|
-
routing
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
'
|
|
655
|
-
|
|
656
|
-
'
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
'
|
|
661
|
-
|
|
665
|
+
res = await driver.aoss_client.search(
|
|
666
|
+
index=ENTITY_INDEX_NAME,
|
|
667
|
+
params={'routing': route},
|
|
668
|
+
body={
|
|
669
|
+
'_source': ['uuid'],
|
|
670
|
+
'size': limit,
|
|
671
|
+
'query': {
|
|
672
|
+
'bool': {
|
|
673
|
+
'filter': filters,
|
|
674
|
+
'must': [
|
|
675
|
+
{
|
|
676
|
+
'multi_match': {
|
|
677
|
+
'query': query,
|
|
678
|
+
'fields': ['name', 'summary'],
|
|
679
|
+
'operator': 'or',
|
|
680
|
+
}
|
|
662
681
|
}
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
}
|
|
682
|
+
],
|
|
683
|
+
}
|
|
684
|
+
},
|
|
666
685
|
},
|
|
667
|
-
limit=limit,
|
|
668
686
|
)
|
|
669
687
|
|
|
670
688
|
if res['hits']['total']['value'] > 0:
|
|
@@ -734,8 +752,8 @@ async def node_similarity_search(
|
|
|
734
752
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
735
753
|
query = (
|
|
736
754
|
"""
|
|
737
|
-
|
|
738
|
-
|
|
755
|
+
MATCH (n:Entity)
|
|
756
|
+
"""
|
|
739
757
|
+ filter_query
|
|
740
758
|
+ """
|
|
741
759
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
@@ -764,11 +782,11 @@ async def node_similarity_search(
|
|
|
764
782
|
# Match the edge ides and return the values
|
|
765
783
|
query = (
|
|
766
784
|
"""
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
785
|
+
UNWIND $ids as i
|
|
786
|
+
MATCH (n:Entity)
|
|
787
|
+
WHERE id(n)=i.id
|
|
788
|
+
RETURN
|
|
789
|
+
"""
|
|
772
790
|
+ get_entity_node_return_query(driver.provider)
|
|
773
791
|
+ """
|
|
774
792
|
ORDER BY i.score DESC
|
|
@@ -789,17 +807,22 @@ async def node_similarity_search(
|
|
|
789
807
|
elif driver.aoss_client:
|
|
790
808
|
route = group_ids[0] if group_ids else None
|
|
791
809
|
filters = build_aoss_node_filters(group_ids or [], search_filter)
|
|
792
|
-
res = driver.aoss_client.search(
|
|
793
|
-
index=
|
|
794
|
-
routing
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
'
|
|
798
|
-
'
|
|
799
|
-
|
|
800
|
-
|
|
810
|
+
res = await driver.aoss_client.search(
|
|
811
|
+
index=ENTITY_INDEX_NAME,
|
|
812
|
+
params={'routing': route},
|
|
813
|
+
body={
|
|
814
|
+
'size': limit,
|
|
815
|
+
'_source': ['uuid'],
|
|
816
|
+
'query': {
|
|
817
|
+
'knn': {
|
|
818
|
+
'name_embedding': {
|
|
819
|
+
'vector': list(map(float, search_vector)),
|
|
820
|
+
'k': limit,
|
|
821
|
+
'filter': {'bool': {'filter': filters}},
|
|
822
|
+
}
|
|
823
|
+
}
|
|
824
|
+
},
|
|
801
825
|
},
|
|
802
|
-
query={'bool': {'filter': filters}},
|
|
803
826
|
)
|
|
804
827
|
|
|
805
828
|
if res['hits']['total']['value'] > 0:
|
|
@@ -811,11 +834,12 @@ async def node_similarity_search(
|
|
|
811
834
|
entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
|
|
812
835
|
entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
|
|
813
836
|
return entity_nodes
|
|
837
|
+
return []
|
|
814
838
|
else:
|
|
815
839
|
query = (
|
|
816
840
|
"""
|
|
817
|
-
|
|
818
|
-
|
|
841
|
+
MATCH (n:Entity)
|
|
842
|
+
"""
|
|
819
843
|
+ filter_query
|
|
820
844
|
+ """
|
|
821
845
|
WITH n, """
|
|
@@ -988,11 +1012,12 @@ async def episode_fulltext_search(
|
|
|
988
1012
|
return []
|
|
989
1013
|
elif driver.aoss_client:
|
|
990
1014
|
route = group_ids[0] if group_ids else None
|
|
991
|
-
res = driver.aoss_client.search(
|
|
992
|
-
|
|
993
|
-
routing
|
|
994
|
-
|
|
995
|
-
|
|
1015
|
+
res = await driver.aoss_client.search(
|
|
1016
|
+
index=EPISODE_INDEX_NAME,
|
|
1017
|
+
params={'routing': route},
|
|
1018
|
+
body={
|
|
1019
|
+
'size': limit,
|
|
1020
|
+
'_source': ['uuid'],
|
|
996
1021
|
'bool': {
|
|
997
1022
|
'filter': {'terms': group_ids},
|
|
998
1023
|
'must': [
|
|
@@ -1004,9 +1029,8 @@ async def episode_fulltext_search(
|
|
|
1004
1029
|
}
|
|
1005
1030
|
}
|
|
1006
1031
|
],
|
|
1007
|
-
}
|
|
1032
|
+
},
|
|
1008
1033
|
},
|
|
1009
|
-
limit=limit,
|
|
1010
1034
|
)
|
|
1011
1035
|
|
|
1012
1036
|
if res['hits']['total']['value'] > 0:
|
|
@@ -1147,8 +1171,8 @@ async def community_similarity_search(
|
|
|
1147
1171
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
1148
1172
|
query = (
|
|
1149
1173
|
"""
|
|
1150
|
-
|
|
1151
|
-
|
|
1174
|
+
MATCH (n:Community)
|
|
1175
|
+
"""
|
|
1152
1176
|
+ group_filter_query
|
|
1153
1177
|
+ """
|
|
1154
1178
|
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
@@ -1207,8 +1231,8 @@ async def community_similarity_search(
|
|
|
1207
1231
|
|
|
1208
1232
|
query = (
|
|
1209
1233
|
"""
|
|
1210
|
-
|
|
1211
|
-
|
|
1234
|
+
MATCH (c:Community)
|
|
1235
|
+
"""
|
|
1212
1236
|
+ group_filter_query
|
|
1213
1237
|
+ """
|
|
1214
1238
|
WITH c,
|
|
@@ -1350,9 +1374,9 @@ async def get_relevant_nodes(
|
|
|
1350
1374
|
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
|
1351
1375
|
query = (
|
|
1352
1376
|
"""
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
|
|
1377
|
+
UNWIND $nodes AS node
|
|
1378
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
1379
|
+
"""
|
|
1356
1380
|
+ filter_query
|
|
1357
1381
|
+ """
|
|
1358
1382
|
WITH node, n, """
|
|
@@ -1397,9 +1421,9 @@ async def get_relevant_nodes(
|
|
|
1397
1421
|
else:
|
|
1398
1422
|
query = (
|
|
1399
1423
|
"""
|
|
1400
|
-
|
|
1401
|
-
|
|
1402
|
-
|
|
1424
|
+
UNWIND $nodes AS node
|
|
1425
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
1426
|
+
"""
|
|
1403
1427
|
+ filter_query
|
|
1404
1428
|
+ """
|
|
1405
1429
|
WITH node, n, """
|
|
@@ -1488,9 +1512,9 @@ async def get_relevant_edges(
|
|
|
1488
1512
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
1489
1513
|
query = (
|
|
1490
1514
|
"""
|
|
1491
|
-
|
|
1492
|
-
|
|
1493
|
-
|
|
1515
|
+
UNWIND $edges AS edge
|
|
1516
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1517
|
+
"""
|
|
1494
1518
|
+ filter_query
|
|
1495
1519
|
+ """
|
|
1496
1520
|
WITH e, edge
|
|
@@ -1560,9 +1584,9 @@ async def get_relevant_edges(
|
|
|
1560
1584
|
|
|
1561
1585
|
query = (
|
|
1562
1586
|
"""
|
|
1563
|
-
|
|
1564
|
-
|
|
1565
|
-
|
|
1587
|
+
UNWIND $edges AS edge
|
|
1588
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1589
|
+
"""
|
|
1566
1590
|
+ filter_query
|
|
1567
1591
|
+ """
|
|
1568
1592
|
WITH e, edge, n, m, """
|
|
@@ -1598,9 +1622,9 @@ async def get_relevant_edges(
|
|
|
1598
1622
|
else:
|
|
1599
1623
|
query = (
|
|
1600
1624
|
"""
|
|
1601
|
-
|
|
1602
|
-
|
|
1603
|
-
|
|
1625
|
+
UNWIND $edges AS edge
|
|
1626
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1627
|
+
"""
|
|
1604
1628
|
+ filter_query
|
|
1605
1629
|
+ """
|
|
1606
1630
|
WITH e, edge, """
|
|
@@ -1673,10 +1697,10 @@ async def get_edge_invalidation_candidates(
|
|
|
1673
1697
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
1674
1698
|
query = (
|
|
1675
1699
|
"""
|
|
1676
|
-
|
|
1677
|
-
|
|
1678
|
-
|
|
1679
|
-
|
|
1700
|
+
UNWIND $edges AS edge
|
|
1701
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1702
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1703
|
+
"""
|
|
1680
1704
|
+ filter_query
|
|
1681
1705
|
+ """
|
|
1682
1706
|
WITH e, edge
|
|
@@ -1746,10 +1770,10 @@ async def get_edge_invalidation_candidates(
|
|
|
1746
1770
|
|
|
1747
1771
|
query = (
|
|
1748
1772
|
"""
|
|
1749
|
-
|
|
1750
|
-
|
|
1751
|
-
|
|
1752
|
-
|
|
1773
|
+
UNWIND $edges AS edge
|
|
1774
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
|
1775
|
+
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
|
1776
|
+
"""
|
|
1753
1777
|
+ filter_query
|
|
1754
1778
|
+ """
|
|
1755
1779
|
WITH edge, e, n, m, """
|
|
@@ -1785,10 +1809,10 @@ async def get_edge_invalidation_candidates(
|
|
|
1785
1809
|
else:
|
|
1786
1810
|
query = (
|
|
1787
1811
|
"""
|
|
1788
|
-
|
|
1789
|
-
|
|
1790
|
-
|
|
1791
|
-
|
|
1812
|
+
UNWIND $edges AS edge
|
|
1813
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1814
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1815
|
+
"""
|
|
1792
1816
|
+ filter_query
|
|
1793
1817
|
+ """
|
|
1794
1818
|
WITH edge, e, """
|
|
@@ -23,7 +23,14 @@ import numpy as np
|
|
|
23
23
|
from pydantic import BaseModel, Field
|
|
24
24
|
from typing_extensions import Any
|
|
25
25
|
|
|
26
|
-
from graphiti_core.driver.driver import
|
|
26
|
+
from graphiti_core.driver.driver import (
|
|
27
|
+
ENTITY_EDGE_INDEX_NAME,
|
|
28
|
+
ENTITY_INDEX_NAME,
|
|
29
|
+
EPISODE_INDEX_NAME,
|
|
30
|
+
GraphDriver,
|
|
31
|
+
GraphDriverSession,
|
|
32
|
+
GraphProvider,
|
|
33
|
+
)
|
|
27
34
|
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
|
|
28
35
|
from graphiti_core.embedder import EmbedderClient
|
|
29
36
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
@@ -129,12 +136,14 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
129
136
|
entity_data: dict[str, Any] = {
|
|
130
137
|
'uuid': node.uuid,
|
|
131
138
|
'name': node.name,
|
|
132
|
-
'name_embedding': node.name_embedding,
|
|
133
139
|
'group_id': node.group_id,
|
|
134
140
|
'summary': node.summary,
|
|
135
141
|
'created_at': node.created_at,
|
|
136
142
|
}
|
|
137
143
|
|
|
144
|
+
if not bool(driver.aoss_client):
|
|
145
|
+
entity_data['name_embedding'] = node.name_embedding
|
|
146
|
+
|
|
138
147
|
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
|
139
148
|
if driver.provider == GraphProvider.KUZU:
|
|
140
149
|
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
|
|
@@ -154,7 +163,6 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
154
163
|
'target_node_uuid': edge.target_node_uuid,
|
|
155
164
|
'name': edge.name,
|
|
156
165
|
'fact': edge.fact,
|
|
157
|
-
'fact_embedding': edge.fact_embedding,
|
|
158
166
|
'group_id': edge.group_id,
|
|
159
167
|
'episodes': edge.episodes,
|
|
160
168
|
'created_at': edge.created_at,
|
|
@@ -163,6 +171,9 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
163
171
|
'invalid_at': edge.invalid_at,
|
|
164
172
|
}
|
|
165
173
|
|
|
174
|
+
if not bool(driver.aoss_client):
|
|
175
|
+
edge_data['fact_embedding'] = edge.fact_embedding
|
|
176
|
+
|
|
166
177
|
if driver.provider == GraphProvider.KUZU:
|
|
167
178
|
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
|
|
168
179
|
edge_data['attributes'] = json.dumps(attributes)
|
|
@@ -188,24 +199,32 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
188
199
|
else:
|
|
189
200
|
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
|
190
201
|
await tx.run(
|
|
191
|
-
get_entity_node_save_bulk_query(
|
|
202
|
+
get_entity_node_save_bulk_query(
|
|
203
|
+
driver.provider, nodes, has_aoss=bool(driver.aoss_client)
|
|
204
|
+
),
|
|
192
205
|
nodes=nodes,
|
|
193
|
-
has_aoss=bool(driver.aoss_client),
|
|
194
206
|
)
|
|
195
207
|
await tx.run(
|
|
196
208
|
get_episodic_edge_save_bulk_query(driver.provider),
|
|
197
209
|
episodic_edges=[edge.model_dump() for edge in episodic_edges],
|
|
198
210
|
)
|
|
199
211
|
await tx.run(
|
|
200
|
-
get_entity_edge_save_bulk_query(driver.provider),
|
|
212
|
+
get_entity_edge_save_bulk_query(driver.provider, has_aoss=bool(driver.aoss_client)),
|
|
201
213
|
entity_edges=edges,
|
|
202
|
-
has_aoss=bool(driver.aoss_client),
|
|
203
214
|
)
|
|
204
215
|
|
|
205
|
-
if driver.aoss_client:
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
216
|
+
if bool(driver.aoss_client):
|
|
217
|
+
for node_data, entity_node in zip(nodes, entity_nodes, strict=True):
|
|
218
|
+
if node_data.get('uuid') == entity_node.uuid:
|
|
219
|
+
node_data['name_embedding'] = entity_node.name_embedding
|
|
220
|
+
|
|
221
|
+
for edge_data, entity_edge in zip(edges, entity_edges, strict=True):
|
|
222
|
+
if edge_data.get('uuid') == entity_edge.uuid:
|
|
223
|
+
edge_data['fact_embedding'] = entity_edge.fact_embedding
|
|
224
|
+
|
|
225
|
+
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
|
|
226
|
+
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
|
|
227
|
+
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
|
|
209
228
|
|
|
210
229
|
|
|
211
230
|
async def extract_nodes_and_edges_bulk(
|
|
@@ -36,8 +36,10 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
|
|
36
36
|
from graphiti_core.prompts import prompt_library
|
|
37
37
|
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
|
|
38
38
|
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
|
39
|
+
from graphiti_core.search.search import search
|
|
40
|
+
from graphiti_core.search.search_config import SearchResults
|
|
41
|
+
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
|
|
39
42
|
from graphiti_core.search.search_filters import SearchFilters
|
|
40
|
-
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
|
|
41
43
|
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
|
|
42
44
|
|
|
43
45
|
logger = logging.getLogger(__name__)
|
|
@@ -258,12 +260,44 @@ async def resolve_extracted_edges(
|
|
|
258
260
|
embedder = clients.embedder
|
|
259
261
|
await create_entity_edge_embeddings(embedder, extracted_edges)
|
|
260
262
|
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
263
|
+
valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
|
|
264
|
+
*[
|
|
265
|
+
EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
|
|
266
|
+
for edge in extracted_edges
|
|
267
|
+
]
|
|
268
|
+
)
|
|
269
|
+
|
|
270
|
+
related_edges_results: list[SearchResults] = await semaphore_gather(
|
|
271
|
+
*[
|
|
272
|
+
search(
|
|
273
|
+
clients,
|
|
274
|
+
extracted_edge.fact,
|
|
275
|
+
group_ids=[extracted_edge.group_id],
|
|
276
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
277
|
+
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
|
|
278
|
+
)
|
|
279
|
+
for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
|
|
280
|
+
]
|
|
264
281
|
)
|
|
265
282
|
|
|
266
|
-
related_edges_lists
|
|
283
|
+
related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
|
|
284
|
+
|
|
285
|
+
edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
|
|
286
|
+
*[
|
|
287
|
+
search(
|
|
288
|
+
clients,
|
|
289
|
+
extracted_edge.fact,
|
|
290
|
+
group_ids=[extracted_edge.group_id],
|
|
291
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
292
|
+
search_filter=SearchFilters(),
|
|
293
|
+
)
|
|
294
|
+
for extracted_edge in extracted_edges
|
|
295
|
+
]
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
edge_invalidation_candidates: list[list[EntityEdge]] = [
|
|
299
|
+
result.edges for result in edge_invalidation_candidate_results
|
|
300
|
+
]
|
|
267
301
|
|
|
268
302
|
logger.debug(
|
|
269
303
|
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
|
|
@@ -95,6 +95,8 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
|
|
|
95
95
|
|
|
96
96
|
async def delete_all(tx):
|
|
97
97
|
await tx.run('MATCH (n) DETACH DELETE n')
|
|
98
|
+
if driver.aoss_client:
|
|
99
|
+
await driver.clear_aoss_indices()
|
|
98
100
|
|
|
99
101
|
async def delete_group_ids(tx):
|
|
100
102
|
labels = ['Entity', 'Episodic', 'Community']
|
|
@@ -151,9 +153,9 @@ async def retrieve_episodes(
|
|
|
151
153
|
|
|
152
154
|
query: LiteralString = (
|
|
153
155
|
"""
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
156
|
+
MATCH (e:Episodic)
|
|
157
|
+
WHERE e.valid_at <= $reference_time
|
|
158
|
+
"""
|
|
157
159
|
+ query_filter
|
|
158
160
|
+ """
|
|
159
161
|
RETURN
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: graphiti-core
|
|
3
|
-
Version: 0.21.
|
|
3
|
+
Version: 0.21.0rc3
|
|
4
4
|
Summary: A temporal graph building library
|
|
5
5
|
Project-URL: Homepage, https://help.getzep.com/graphiti/graphiti/overview
|
|
6
6
|
Project-URL: Repository, https://github.com/getzep/graphiti
|