graphiti-core 0.11.6rc7__py3-none-any.whl → 0.12.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/edges.py +42 -16
- graphiti_core/embedder/gemini.py +14 -3
- graphiti_core/graphiti.py +33 -10
- graphiti_core/helpers.py +8 -27
- graphiti_core/llm_client/gemini_client.py +4 -1
- graphiti_core/models/edges/edge_db_queries.py +2 -4
- graphiti_core/prompts/dedupe_edges.py +52 -1
- graphiti_core/prompts/dedupe_nodes.py +75 -4
- graphiti_core/prompts/extract_edges.py +46 -2
- graphiti_core/prompts/invalidate_edges.py +1 -1
- graphiti_core/search/search.py +19 -45
- graphiti_core/search/search_utils.py +127 -18
- graphiti_core/utils/bulk_utils.py +19 -1
- graphiti_core/utils/maintenance/edge_operations.py +137 -10
- graphiti_core/utils/maintenance/node_operations.py +58 -20
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0rc1.dist-info}/METADATA +1 -1
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0rc1.dist-info}/RECORD +19 -19
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0rc1.dist-info}/LICENSE +0 -0
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0rc1.dist-info}/WHEEL +0 -0
graphiti_core/search/search.py
CHANGED
@@ -50,6 +50,9 @@ from graphiti_core.search.search_utils import (
     edge_similarity_search,
     episode_fulltext_search,
     episode_mentions_reranker,
+    get_embeddings_for_communities,
+    get_embeddings_for_edges,
+    get_embeddings_for_nodes,
     maximal_marginal_relevance,
     node_bfs_search,
     node_distance_reranker,
@@ -209,26 +212,17 @@ async def edge_search(
 
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == EdgeReranker.mmr:
-        await semaphore_gather(
-            *[edge.load_fact_embedding(driver) for result in search_results for edge in result]
+        search_result_uuids_and_vectors = await get_embeddings_for_edges(
+            driver, list(edge_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
-            for result in search_results
-            for edge in result
-        ]
         reranked_uuids = maximal_marginal_relevance(
             query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
+            reranker_min_score,
         )
     elif config.reranker == EdgeReranker.cross_encoder:
-        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
-        rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
-
-        fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
+        fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
         reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
         reranked_uuids = [
             fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
@@ -311,30 +305,23 @@ async def node_search(
     if config.reranker == NodeReranker.rrf:
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == NodeReranker.mmr:
-        await semaphore_gather(
-            *[node.load_name_embedding(driver) for result in search_results for node in result]
+        search_result_uuids_and_vectors = await get_embeddings_for_nodes(
+            driver, list(node_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
-            for result in search_results
-            for node in result
-        ]
+
         reranked_uuids = maximal_marginal_relevance(
             query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
+            reranker_min_score,
         )
     elif config.reranker == NodeReranker.cross_encoder:
-        rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
-        rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
+        name_to_uuid_map = {node.name: node.uuid for node in list(node_uuid_map.values())}
 
-        summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
-        reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
+        reranked_node_names = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
         reranked_uuids = [
-            summary_to_uuid_map[summary]
-            for summary, score in reranked_summaries
+            name_to_uuid_map[name]
+            for name, score in reranked_node_names
             if score >= reranker_min_score
         ]
     elif config.reranker == NodeReranker.episode_mentions:
@@ -437,25 +424,12 @@ async def community_search(
     if config.reranker == CommunityReranker.rrf:
         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
     elif config.reranker == CommunityReranker.mmr:
-        await semaphore_gather(
-            *[
-                community.load_name_embedding(driver)
-                for result in search_results
-                for community in result
-            ]
+        search_result_uuids_and_vectors = await get_embeddings_for_communities(
+            driver, list(community_uuid_map.values())
         )
-        search_result_uuids_and_vectors = [
-            (
-                community.uuid,
-                community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
-            )
-            for result in search_results
-            for community in result
-        ]
+
         reranked_uuids = maximal_marginal_relevance(
-            query_vector,
-            search_result_uuids_and_vectors,
-            config.mmr_lambda,
+            query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
         )
     elif config.reranker == CommunityReranker.cross_encoder:
         name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}
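All three reranker branches now share one shape: instead of loading embeddings onto the in-memory objects and building lists of (uuid, vector) tuples, each branch fetches a dict[str, list[float]] straight from the database and forwards reranker_min_score into maximal_marginal_relevance. A minimal sketch of the new call shape; the uuids and 3-dimensional vectors below are made up, and in the library the dict comes from get_embeddings_for_edges and its siblings:

from graphiti_core.search.search_utils import maximal_marginal_relevance

# Made-up stand-ins for real embedding vectors.
query_vector = [0.0, 1.0, 0.0]
uuids_and_vectors = {
    'edge-aaa': [0.1, 0.9, 0.1],  # close to the query
    'edge-bbb': [0.2, 0.9, 0.2],  # close to the query and to edge-aaa (redundant)
    'edge-ccc': [0.9, 0.1, 0.0],  # off-topic
}

reranked_uuids = maximal_marginal_relevance(
    query_vector,
    uuids_and_vectors,
    0.5,   # config.mmr_lambda: balances relevance against redundancy
    -2.0,  # reranker_min_score: the default, low enough to keep every candidate
)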
graphiti_core/search/search_utils.py
CHANGED

@@ -21,6 +21,7 @@ from typing import Any
 
 import numpy as np
 from neo4j import AsyncDriver, Query
+from numpy._typing import NDArray
 from typing_extensions import LiteralString
 
 from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
@@ -173,7 +174,8 @@ async def edge_fulltext_search(
             r.episodes AS episodes,
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at
+            r.invalid_at AS invalid_at,
+            properties(r) AS attributes
         ORDER BY score DESC LIMIT $limit
         """
     )
@@ -242,7 +244,8 @@ async def edge_similarity_search(
             r.episodes AS episodes,
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at
+            r.invalid_at AS invalid_at,
+            properties(r) AS attributes
         ORDER BY score DESC
         LIMIT $limit
         """
@@ -300,7 +303,8 @@ async def edge_bfs_search(
             r.episodes AS episodes,
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at
+            r.invalid_at AS invalid_at,
+            properties(r) AS attributes
         LIMIT $limit
         """
     )
@@ -336,10 +340,10 @@ async def node_fulltext_search(
 
     query = (
         """
+        CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
+        YIELD node AS n, score
+        WHERE n:Entity
+        """
         + filter_query
         + ENTITY_NODE_RETURN
         + """
@@ -770,7 +774,8 @@ async def get_relevant_edges(
                 episodes: e.episodes,
                 expired_at: e.expired_at,
                 valid_at: e.valid_at,
-                invalid_at: e.invalid_at
+                invalid_at: e.invalid_at,
+                attributes: properties(e)
             })[..$limit] AS matches
         """
     )
@@ -836,7 +841,8 @@ async def get_edge_invalidation_candidates(
                 episodes: e.episodes,
                 expired_at: e.expired_at,
                 valid_at: e.valid_at,
-                invalid_at: e.invalid_at
+                invalid_at: e.invalid_at,
+                attributes: properties(e)
             })[..$limit] AS matches
         """
     )
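Every edge-returning query now also projects the full relationship property map as attributes, so custom attributes stored on a RELATES_TO relationship survive the round trip. A hedged sketch of what a consumer of that map might do; the built-in field names follow the RETURN clauses above, but the stripping step is illustrative rather than a quotation of get_entity_edge_from_record:

# Built-in edge fields already projected individually by the queries above;
# everything else in properties(e) is a user-defined attribute.
EDGE_BUILTIN_FIELDS = {
    'uuid', 'name', 'fact', 'fact_embedding', 'group_id',
    'episodes', 'created_at', 'expired_at', 'valid_at', 'invalid_at',
}

def split_custom_edge_attributes(attributes: dict) -> dict:
    """Return only the user-defined attributes from a properties(e) map."""
    return {k: v for k, v in attributes.items() if k not in EDGE_BUILTIN_FIELDS}

# Example with a made-up record:
record = {'uuid': 'e-1', 'fact': 'Alice works at Acme', 'role': 'engineer'}
assert split_custom_edge_attributes(record) == {'role': 'engineer'}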
@@ -899,6 +905,7 @@ async def node_distance_reranker(
         node_uuids=filtered_uuids,
         center_uuid=center_node_uuid,
         database_=DEFAULT_DATABASE,
+        routing_='r',
     )
 
     for result in path_results:
@@ -939,6 +946,7 @@ async def episode_mentions_reranker(
         query,
         node_uuids=sorted_uuids,
         database_=DEFAULT_DATABASE,
+        routing_='r',
     )
 
     for result in results:
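routing_='r' marks both reranker queries as reads, which lets the Neo4j driver send them to a read replica in a cluster instead of the leader. The string is shorthand for the driver's RoutingControl enum; a sketch of the equivalent explicit call, with illustrative connection details and query:

from neo4j import AsyncGraphDatabase, RoutingControl

async def count_entities() -> int:
    # Illustrative URI, credentials, and database name.
    driver = AsyncGraphDatabase.driver('neo4j://localhost:7687', auth=('neo4j', 'password'))
    records, _, _ = await driver.execute_query(
        'MATCH (n:Entity) RETURN count(n) AS c',
        database_='neo4j',
        routing_=RoutingControl.READ,  # same as routing_='r'
    )
    await driver.close()
    return records[0]['c']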
@@ -952,15 +960,116 @@ async def episode_mentions_reranker(
 
 def maximal_marginal_relevance(
     query_vector: list[float],
-    candidates: list[tuple[str, list[float]]],
+    candidates: dict[str, list[float]],
     mmr_lambda: float = DEFAULT_MMR_LAMBDA,
+    min_score: float = -2.0,
+) -> list[str]:
+    start = time()
+    query_array = np.array(query_vector)
+    candidate_arrays: dict[str, NDArray] = {}
+    for uuid, embedding in candidates.items():
+        candidate_arrays[uuid] = normalize_l2(embedding)
+
+    uuids: list[str] = list(candidate_arrays.keys())
+
+    similarity_matrix = np.zeros((len(uuids), len(uuids)))
+
+    for i, uuid_1 in enumerate(uuids):
+        for j, uuid_2 in enumerate(uuids[:i]):
+            u = candidate_arrays[uuid_1]
+            v = candidate_arrays[uuid_2]
+            similarity = np.dot(u, v)
+
+            similarity_matrix[i, j] = similarity
+            similarity_matrix[j, i] = similarity
+
+    mmr_scores: dict[str, float] = {}
+    for i, uuid in enumerate(uuids):
+        max_sim = np.max(similarity_matrix[i, :])
+        mmr = mmr_lambda * np.dot(query_array, candidate_arrays[uuid]) + (mmr_lambda - 1) * max_sim
+        mmr_scores[uuid] = mmr
+
+    uuids.sort(reverse=True, key=lambda c: mmr_scores[c])
+
+    end = time()
+    logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
+
+    return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
+
+
+async def get_embeddings_for_nodes(
+    driver: AsyncDriver, nodes: list[EntityNode]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (n:Entity)
+        WHERE n.uuid IN $node_uuids
+        RETURN DISTINCT
+            n.uuid AS uuid,
+            n.name_embedding AS name_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query, node_uuids=[node.uuid for node in nodes], database_=DEFAULT_DATABASE, routing_='r'
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('name_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
 
+    return embeddings_dict
+
+
+async def get_embeddings_for_communities(
+    driver: AsyncDriver, communities: list[CommunityNode]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (c:Community)
+        WHERE c.uuid IN $community_uuids
+        RETURN DISTINCT
+            c.uuid AS uuid,
+            c.name_embedding AS name_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query,
+        community_uuids=[community.uuid for community in communities],
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('name_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
+
+    return embeddings_dict
+
+
+async def get_embeddings_for_edges(
+    driver: AsyncDriver, edges: list[EntityEdge]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
+        WHERE e.uuid IN $edge_uuids
+        RETURN DISTINCT
+            e.uuid AS uuid,
+            e.fact_embedding AS fact_embedding
+        """
+
+    results, _, _ = await driver.execute_query(
+        query,
+        edge_uuids=[edge.uuid for edge in edges],
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('fact_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
 
-    return
+    return embeddings_dict
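The rewritten maximal_marginal_relevance normalizes each candidate, scores it as mmr_lambda * sim(query, candidate) + (mmr_lambda - 1) * max_sim_to_other_candidates (algebraically the classic lambda * relevance - (1 - lambda) * redundancy), sorts by that score, and drops anything below min_score. A worked example of the score with made-up similarities:

mmr_lambda = 0.5
sim_to_query = 0.9        # np.dot(query_array, candidate_arrays[uuid])
max_sim_to_others = 0.8   # np.max(similarity_matrix[i, :])

mmr = mmr_lambda * sim_to_query + (mmr_lambda - 1) * max_sim_to_others
# 0.5 * 0.9 + (-0.5) * 0.8 = 0.45 - 0.40 = 0.05
assert abs(mmr - 0.05) < 1e-12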
graphiti_core/utils/bulk_utils.py
CHANGED

@@ -137,16 +137,34 @@ async def add_nodes_and_edges_bulk_tx(
         entity_data['labels'] = list(set(node.labels + ['Entity']))
         nodes.append(entity_data)
 
+    edges: list[dict[str, Any]] = []
     for edge in entity_edges:
         if edge.fact_embedding is None:
             await edge.generate_embedding(embedder)
+        edge_data: dict[str, Any] = {
+            'uuid': edge.uuid,
+            'source_node_uuid': edge.source_node_uuid,
+            'target_node_uuid': edge.target_node_uuid,
+            'name': edge.name,
+            'fact': edge.fact,
+            'fact_embedding': edge.fact_embedding,
+            'group_id': edge.group_id,
+            'episodes': edge.episodes,
+            'created_at': edge.created_at,
+            'expired_at': edge.expired_at,
+            'valid_at': edge.valid_at,
+            'invalid_at': edge.invalid_at,
+        }
+
+        edge_data.update(edge.attributes or {})
+        edges.append(edge_data)
 
     await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
     await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
     await tx.run(
         EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
     )
-    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[edge.model_dump() for edge in entity_edges])
+    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=edges)
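Bulk edge saves now build an explicit parameter map per edge instead of calling model_dump(), then flatten edge.attributes into it, so custom attribute keys land as top-level relationship properties (matching the properties(e) AS attributes reads above). A hedged sketch of one resulting payload entry, with made-up values:

from datetime import datetime, timezone

edge_data = {
    'uuid': 'e-123',
    'source_node_uuid': 'n-1',
    'target_node_uuid': 'n-2',
    'name': 'WORKS_AT',
    'fact': 'Alice works at Acme',
    'fact_embedding': [0.1, 0.2, 0.3],  # truncated stand-in vector
    'group_id': 'group-1',
    'episodes': ['ep-1'],
    'created_at': datetime.now(timezone.utc),
    'expired_at': None,
    'valid_at': None,
    'invalid_at': None,
}
edge_data.update({'role': 'engineer'})  # edge.attributes flattened in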
graphiti_core/utils/maintenance/edge_operations.py
CHANGED

@@ -18,6 +18,8 @@ import logging
 from datetime import datetime
 from time import time
 
+from pydantic import BaseModel
+
 from graphiti_core.edges import (
     CommunityEdge,
     EntityEdge,
@@ -35,9 +37,6 @@ from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
 from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
-from graphiti_core.utils.maintenance.temporal_operations import (
-    get_edge_contradictions,
-)
 
 logger = logging.getLogger(__name__)
 
@@ -86,6 +85,7 @@ async def extract_edges(
     nodes: list[EntityNode],
     previous_episodes: list[EpisodicNode],
     group_id: str = '',
+    edge_types: dict[str, BaseModel] | None = None,
 ) -> list[EntityEdge]:
     start = time()
 
@@ -94,12 +94,25 @@ async def extract_edges(
 
     node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
 
+    edge_types_context = (
+        [
+            {
+                'fact_type_name': type_name,
+                'fact_type_description': type_model.__doc__,
+            }
+            for type_name, type_model in edge_types.items()
+        ]
+        if edge_types is not None
+        else []
+    )
+
     # Prepare context for LLM
     context = {
         'episode_content': episode.content,
         'nodes': [node.name for node in nodes],
         'previous_episodes': [ep.content for ep in previous_episodes],
         'reference_time': episode.valid_at,
+        'edge_types': edge_types_context,
         'custom_prompt': '',
     }
 
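The new edge_types parameter lets callers register custom fact types; extract_edges advertises each one to the LLM by name plus the model's docstring. A sketch of what such a registration might look like; the WorksAt model and its fields are hypothetical, not part of the library:

from pydantic import BaseModel, Field

class WorksAt(BaseModel):
    """Employment relationship between a person and an organization."""

    role: str | None = Field(None, description='Job title, if stated in the episode')

# Hypothetical registration; the keys become 'fact_type_name' in the prompt context.
edge_types: dict[str, BaseModel] = {'WORKS_AT': WorksAt}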
@@ -236,6 +249,9 @@ async def resolve_extracted_edges(
     clients: GraphitiClients,
     extracted_edges: list[EntityEdge],
     episode: EpisodicNode,
+    entities: list[EntityNode],
+    edge_types: dict[str, BaseModel],
+    edge_type_map: dict[tuple[str, str], list[str]],
 ) -> tuple[list[EntityEdge], list[EntityEdge]]:
     driver = clients.driver
     llm_client = clients.llm_client
@@ -245,7 +261,7 @@ async def resolve_extracted_edges(
 
     search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
         get_relevant_edges(driver, extracted_edges, SearchFilters()),
-        get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()),
+        get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
     )
 
     related_edges_lists, edge_invalidation_candidates = search_results
@@ -254,15 +270,50 @@ async def resolve_extracted_edges(
         f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
     )
 
+    # Build entity hash table
+    uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
+
+    # Determine which edge types are relevant for each edge
+    edge_types_lst: list[dict[str, BaseModel]] = []
+    for extracted_edge in extracted_edges:
+        source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels
+        target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels
+        label_tuples = [
+            (source_label, target_label)
+            for source_label in source_node_labels
+            for target_label in target_node_labels
+        ]
+
+        extracted_edge_types = {}
+        for label_tuple in label_tuples:
+            type_names = edge_type_map.get(label_tuple, [])
+            for type_name in type_names:
+                type_model = edge_types.get(type_name)
+                if type_model is None:
+                    continue
+
+                extracted_edge_types[type_name] = type_model
+
+        edge_types_lst.append(extracted_edge_types)
+
     # resolve edges with related edges in the graph and find invalidation candidates
     results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
         await semaphore_gather(
             *[
                 resolve_extracted_edge(
-                    llm_client, extracted_edge, related_edges, existing_edges, episode
+                    llm_client,
+                    extracted_edge,
+                    related_edges,
+                    existing_edges,
+                    episode,
+                    extracted_edge_types,
                 )
-                for extracted_edge, related_edges, existing_edges in zip(
-                    extracted_edges, related_edges_lists, edge_invalidation_candidates
+                for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
+                    extracted_edges,
+                    related_edges_lists,
+                    edge_invalidation_candidates,
+                    edge_types_lst,
+                    strict=True,
                 )
             ]
         )
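edge_type_map scopes each fact type to (source label, target label) pairs; the loop above takes the cross product of the two endpoint nodes' labels and collects every type registered for any matching pair. A sketch of the expected shape, with illustrative labels and type names:

edge_type_map: dict[tuple[str, str], list[str]] = {
    ('Person', 'Organization'): ['WORKS_AT'],
    ('Person', 'Person'): ['MARRIED_TO', 'WORKS_AT'],
    ('Entity', 'Entity'): [],  # no custom fact types for untyped pairs
}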
@@ -326,10 +377,86 @@ async def resolve_extracted_edge(
     related_edges: list[EntityEdge],
     existing_edges: list[EntityEdge],
     episode: EpisodicNode,
+    edge_types: dict[str, BaseModel] | None = None,
 ) -> tuple[EntityEdge, list[EntityEdge]]:
+    if len(related_edges) == 0 and len(existing_edges) == 0:
+        return extracted_edge, []
+
+    start = time()
+
+    # Prepare context for LLM
+    related_edges_context = [
+        {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
+    ]
+
+    invalidation_edge_candidates_context = [
+        {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
+    ]
+
+    edge_types_context = (
+        [
+            {
+                'fact_type_id': i,
+                'fact_type_name': type_name,
+                'fact_type_description': type_model.__doc__,
+            }
+            for i, (type_name, type_model) in enumerate(edge_types.items())
+        ]
+        if edge_types is not None
+        else []
+    )
+
+    context = {
+        'existing_edges': related_edges_context,
+        'new_edge': extracted_edge.fact,
+        'edge_invalidation_candidates': invalidation_edge_candidates_context,
+        'edge_types': edge_types_context,
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_edges.resolve_edge(context),
+        response_model=EdgeDuplicate,
+        model_size=ModelSize.small,
+    )
+
+    duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
+
+    resolved_edge = (
+        related_edges[duplicate_fact_id]
+        if 0 <= duplicate_fact_id < len(related_edges)
+        else extracted_edge
+    )
+
+    if duplicate_fact_id >= 0 and episode is not None:
+        resolved_edge.episodes.append(episode.uuid)
+
+    contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
+
+    invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
+
+    fact_type: str = str(llm_response.get('fact_type'))
+    if fact_type.upper() != 'DEFAULT' and edge_types is not None:
+        resolved_edge.name = fact_type
+
+        edge_attributes_context = {
+            'message': episode.content,
+            'reference_time': episode.valid_at,
+            'fact': resolved_edge.fact,
+        }
+
+        edge_model = edge_types.get(fact_type)
+
+        edge_attributes_response = await llm_client.generate_response(
+            prompt_library.extract_edges.extract_attributes(edge_attributes_context),
+            response_model=edge_model,  # type: ignore
+            model_size=ModelSize.small,
+        )
+
+        resolved_edge.attributes = edge_attributes_response
+
+    end = time()
+    logger.debug(
+        f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
     )
 
     now = utc_now()
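With the separate get_edge_contradictions pass removed, a single EdgeDuplicate response now carries duplicate, contradiction, and fact-type information at once. A hedged sketch of the response dict as the code above consumes it via llm_response.get(...); the values are made up:

llm_response = {
    'duplicate_fact_id': -1,    # index into related_edges; -1 means no duplicate found
    'contradicted_facts': [0],  # indices into existing_edges to invalidate
    'fact_type': 'WORKS_AT',    # 'DEFAULT' leaves the edge name and attributes alone
}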
graphiti_core/utils/maintenance/node_operations.py
CHANGED

@@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
 from graphiti_core.prompts import prompt_library
-from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
+from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
 from graphiti_core.prompts.extract_nodes import (
     ExtractedEntities,
     ExtractedEntity,
@@ -243,28 +243,65 @@ async def resolve_extracted_nodes(
 
     existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
 
+    entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
+
+    # Prepare context for LLM
+    extracted_nodes_context = [
+        {
+            'id': i,
+            'name': node.name,
+            'entity_type': node.labels,
+            'entity_type_description': entity_types_dict.get(
+                next((item for item in node.labels if item != 'Entity'), '')
+            ).__doc__
+            or 'Default Entity Type',
+            'duplication_candidates': [
+                {
+                    **{
+                        'idx': j,
+                        'name': candidate.name,
+                        'entity_types': candidate.labels,
+                    },
+                    **candidate.attributes,
+                }
+                for j, candidate in enumerate(existing_nodes_lists[i])
+            ],
+        }
+        for i, node in enumerate(extracted_nodes)
+    ]
+
+    context = {
+        'extracted_nodes': extracted_nodes_context,
+        'episode_content': episode.content if episode is not None else '',
+        'previous_episodes': [ep.content for ep in previous_episodes]
+        if previous_episodes is not None
+        else [],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_nodes.nodes(context),
+        response_model=NodeResolutions,
     )
 
+    node_resolutions: list = llm_response.get('entity_resolutions', [])
+
+    resolved_nodes: list[EntityNode] = []
     uuid_map: dict[str, str] = {}
+    for resolution in node_resolutions:
+        resolution_id = resolution.get('id', -1)
+        duplicate_idx = resolution.get('duplicate_idx', -1)
+
+        extracted_node = extracted_nodes[resolution_id]
+
+        resolved_node = (
+            existing_nodes_lists[resolution_id][duplicate_idx]
+            if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
+            else extracted_node
+        )
+
+        resolved_node.name = resolution.get('name')
+
+        resolved_nodes.append(resolved_node)
         uuid_map[extracted_node.uuid] = resolved_node.uuid
 
     logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
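Node deduplication likewise moves from per-node LLM calls to a single batched NodeResolutions response covering every extracted node. A hedged sketch of one entry in entity_resolutions as the loop above reads it; the values are made up:

resolution = {
    'id': 0,                # index into extracted_nodes
    'duplicate_idx': 2,     # index into existing_nodes_lists[0]; -1 means a brand-new node
    'name': 'Alice Smith',  # canonical name assigned to the resolved node
}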
@@ -410,6 +447,7 @@ async def extract_attributes_from_node(
     llm_response = await llm_client.generate_response(
         prompt_library.extract_nodes.extract_attributes(summary_context),
         response_model=entity_attributes_model,
+        model_size=ModelSize.small,
     )
 
     node.summary = llm_response.get('summary', node.summary)
|