graphiti-core 0.11.4__py3-none-any.whl → 0.11.6__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- graphiti_core/edges.py +18 -4
- graphiti_core/graphiti.py +12 -10
- graphiti_core/helpers.py +8 -10
- graphiti_core/llm_client/anthropic_client.py +0 -13
- graphiti_core/nodes.py +30 -7
- graphiti_core/prompts/dedupe_edges.py +44 -1
- graphiti_core/prompts/dedupe_nodes.py +85 -12
- graphiti_core/prompts/extract_nodes.py +1 -1
- graphiti_core/prompts/invalidate_edges.py +1 -1
- graphiti_core/prompts/summarize_nodes.py +4 -4
- graphiti_core/search/search.py +25 -42
- graphiti_core/search/search_utils.py +117 -20
- graphiti_core/utils/bulk_utils.py +15 -1
- graphiti_core/utils/maintenance/community_operations.py +0 -2
- graphiti_core/utils/maintenance/edge_operations.py +63 -15
- graphiti_core/utils/maintenance/node_operations.py +78 -35
- {graphiti_core-0.11.4.dist-info → graphiti_core-0.11.6.dist-info}/METADATA +1 -2
- {graphiti_core-0.11.4.dist-info → graphiti_core-0.11.6.dist-info}/RECORD +20 -20
- {graphiti_core-0.11.4.dist-info → graphiti_core-0.11.6.dist-info}/LICENSE +0 -0
- {graphiti_core-0.11.4.dist-info → graphiti_core-0.11.6.dist-info}/WHEEL +0 -0
--- graphiti_core/search/search_utils.py
+++ graphiti_core/search/search_utils.py
@@ -21,6 +21,7 @@ from typing import Any
 
 import numpy as np
 from neo4j import AsyncDriver, Query
+from numpy._typing import NDArray
 from typing_extensions import LiteralString
 
 from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
@@ -101,7 +102,6 @@ async def get_mentioned_nodes(
             n.uuid As uuid,
             n.group_id AS group_id,
             n.name AS name,
-            n.name_embedding AS name_embedding,
             n.created_at AS created_at,
             n.summary AS summary,
             labels(n) AS labels,
@@ -128,7 +128,6 @@ async def get_communities_by_nodes(
             c.uuid As uuid,
             c.group_id AS group_id,
             c.name AS name,
-            c.name_embedding AS name_embedding
             c.created_at AS created_at,
             c.summary AS summary
         """,
@@ -172,7 +171,6 @@ async def edge_fulltext_search(
             r.created_at AS created_at,
             r.name AS name,
             r.fact AS fact,
-            r.fact_embedding AS fact_embedding,
             r.episodes AS episodes,
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
@@ -242,7 +240,6 @@ async def edge_similarity_search(
             r.created_at AS created_at,
             r.name AS name,
             r.fact AS fact,
-            r.fact_embedding AS fact_embedding,
             r.episodes AS episodes,
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
@@ -301,7 +298,6 @@ async def edge_bfs_search(
             r.created_at AS created_at,
             r.name AS name,
             r.fact AS fact,
-            r.fact_embedding AS fact_embedding,
             r.episodes AS episodes,
             r.expired_at AS expired_at,
             r.valid_at AS valid_at,
@@ -341,10 +337,10 @@ async def node_fulltext_search(
 
     query = (
         """
-
-
-
-
+        CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
+        YIELD node AS n, score
+        WHERE n:Entity
+        """
         + filter_query
         + ENTITY_NODE_RETURN
         + """
@@ -510,7 +506,6 @@ async def community_fulltext_search(
             comm.uuid AS uuid,
             comm.group_id AS group_id,
             comm.name AS name,
-            comm.name_embedding AS name_embedding,
             comm.created_at AS created_at,
             comm.summary AS summary
         ORDER BY score DESC
@@ -555,7 +550,6 @@ async def community_similarity_search(
             comm.uuid As uuid,
             comm.group_id AS group_id,
             comm.name AS name,
-            comm.name_embedding AS name_embedding,
             comm.created_at AS created_at,
             comm.summary AS summary
         ORDER BY score DESC
@@ -906,6 +900,7 @@ async def node_distance_reranker(
         node_uuids=filtered_uuids,
         center_uuid=center_node_uuid,
         database_=DEFAULT_DATABASE,
+        routing_='r',
     )
 
     for result in path_results:
@@ -946,6 +941,7 @@ async def episode_mentions_reranker(
         query,
         node_uuids=sorted_uuids,
         database_=DEFAULT_DATABASE,
+        routing_='r',
     )
 
     for result in results:
@@ -959,15 +955,116 @@ async def episode_mentions_reranker(
 
 def maximal_marginal_relevance(
     query_vector: list[float],
-    candidates:
+    candidates: dict[str, list[float]],
     mmr_lambda: float = DEFAULT_MMR_LAMBDA,
-
-
-
-
-
-
+    min_score: float = -2.0,
+) -> list[str]:
+    start = time()
+    query_array = np.array(query_vector)
+    candidate_arrays: dict[str, NDArray] = {}
+    for uuid, embedding in candidates.items():
+        candidate_arrays[uuid] = normalize_l2(embedding)
+
+    uuids: list[str] = list(candidate_arrays.keys())
+
+    similarity_matrix = np.zeros((len(uuids), len(uuids)))
+
+    for i, uuid_1 in enumerate(uuids):
+        for j, uuid_2 in enumerate(uuids[:i]):
+            u = candidate_arrays[uuid_1]
+            v = candidate_arrays[uuid_2]
+            similarity = np.dot(u, v)
+
+            similarity_matrix[i, j] = similarity
+            similarity_matrix[j, i] = similarity
+
+    mmr_scores: dict[str, float] = {}
+    for i, uuid in enumerate(uuids):
+        max_sim = np.max(similarity_matrix[i, :])
+        mmr = mmr_lambda * np.dot(query_array, candidate_arrays[uuid]) + (mmr_lambda - 1) * max_sim
+        mmr_scores[uuid] = mmr
+
+    uuids.sort(reverse=True, key=lambda c: mmr_scores[c])
+
+    end = time()
+    logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
+
+    return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
+
+
+async def get_embeddings_for_nodes(
+    driver: AsyncDriver, nodes: list[EntityNode]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (n:Entity)
+        WHERE n.uuid IN $node_uuids
+        RETURN DISTINCT
+            n.uuid AS uuid,
+            n.name_embedding AS name_embedding
+    """
+
+    results, _, _ = await driver.execute_query(
+        query, node_uuids=[node.uuid for node in nodes], database_=DEFAULT_DATABASE, routing_='r'
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('name_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
 
-
+    return embeddings_dict
+
+
+async def get_embeddings_for_communities(
+    driver: AsyncDriver, communities: list[CommunityNode]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (c:Community)
+        WHERE c.uuid IN $community_uuids
+        RETURN DISTINCT
+            c.uuid AS uuid,
+            c.name_embedding AS name_embedding
+    """
+
+    results, _, _ = await driver.execute_query(
+        query,
+        community_uuids=[community.uuid for community in communities],
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('name_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
+
+    return embeddings_dict
+
+
+async def get_embeddings_for_edges(
+    driver: AsyncDriver, edges: list[EntityEdge]
+) -> dict[str, list[float]]:
+    query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
+        WHERE e.uuid IN $edge_uuids
+        RETURN DISTINCT
+            e.uuid AS uuid,
+            e.fact_embedding AS fact_embedding
+    """
+
+    results, _, _ = await driver.execute_query(
+        query,
+        edge_uuids=[edge.uuid for edge in edges],
+        database_=DEFAULT_DATABASE,
+        routing_='r',
+    )
+
+    embeddings_dict: dict[str, list[float]] = {}
+    for result in results:
+        uuid: str = result.get('uuid')
+        embedding: list[float] = result.get('fact_embedding')
+        if uuid is not None and embedding is not None:
+            embeddings_dict[uuid] = embedding
 
-    return
+    return embeddings_dict
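The largest search_utils.py change is the rewritten maximal_marginal_relevance reranker plus three helpers that fetch stored embeddings on demand (get_embeddings_for_nodes / communities / edges), now that the search queries above no longer return embedding columns. MMR scores each candidate as mmr_lambda * sim(query, candidate) - (1 - mmr_lambda) * (max similarity to any other candidate), sorts by that score, and drops anything below min_score. A minimal standalone sketch of that scoring rule, not the library's implementation, assuming non-zero embedding vectors:

import numpy as np

def mmr_rerank(query: list[float], candidates: dict[str, list[float]],
               mmr_lambda: float = 0.5, min_score: float = -2.0) -> list[str]:
    # Unit-normalize everything so dot products are cosine similarities.
    vecs = {u: np.asarray(v) / np.linalg.norm(v) for u, v in candidates.items()}
    q = np.asarray(query)
    uuids = list(vecs)
    scores: dict[str, float] = {}
    for u in uuids:
        # Redundancy penalty: highest similarity to any *other* candidate.
        max_sim = max((float(np.dot(vecs[u], vecs[o])) for o in uuids if o != u), default=0.0)
        scores[u] = mmr_lambda * float(np.dot(q, vecs[u])) - (1 - mmr_lambda) * max_sim
    return [u for u in sorted(uuids, key=lambda c: scores[c], reverse=True) if scores[u] >= min_score]

With mmr_lambda = 1.0 this reduces to plain cosine ranking; smaller values trade relevance for diversity among the returned uuids.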
--- graphiti_core/utils/bulk_utils.py
+++ graphiti_core/utils/bulk_utils.py
@@ -26,6 +26,7 @@ from pydantic import BaseModel
 from typing_extensions import Any
 
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
+from graphiti_core.embedder import EmbedderClient
 from graphiti_core.graphiti_types import GraphitiClients
 from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.llm_client import LLMClient
@@ -95,10 +96,16 @@ async def add_nodes_and_edges_bulk(
     episodic_edges: list[EpisodicEdge],
     entity_nodes: list[EntityNode],
     entity_edges: list[EntityEdge],
+    embedder: EmbedderClient,
 ):
     async with driver.session(database=DEFAULT_DATABASE) as session:
         await session.execute_write(
-            add_nodes_and_edges_bulk_tx,
+            add_nodes_and_edges_bulk_tx,
+            episodic_nodes,
+            episodic_edges,
+            entity_nodes,
+            entity_edges,
+            embedder,
         )
 
 
@@ -108,12 +115,15 @@ async def add_nodes_and_edges_bulk_tx(
     episodic_edges: list[EpisodicEdge],
     entity_nodes: list[EntityNode],
     entity_edges: list[EntityEdge],
+    embedder: EmbedderClient,
 ):
     episodes = [dict(episode) for episode in episodic_nodes]
     for episode in episodes:
         episode['source'] = str(episode['source'].value)
     nodes: list[dict[str, Any]] = []
     for node in entity_nodes:
+        if node.name_embedding is None:
+            await node.generate_name_embedding(embedder)
         entity_data: dict[str, Any] = {
             'uuid': node.uuid,
             'name': node.name,
@@ -127,6 +137,10 @@ async def add_nodes_and_edges_bulk_tx(
         entity_data['labels'] = list(set(node.labels + ['Entity']))
         nodes.append(entity_data)
 
+    for edge in entity_edges:
+        if edge.fact_embedding is None:
+            await edge.generate_embedding(embedder)
+
     await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
     await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
     await tx.run(
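The bulk_utils.py change moves embedding generation into the bulk save path: add_nodes_and_edges_bulk and its transaction function now receive an EmbedderClient and backfill any node name embedding or edge fact embedding that is still None before writing, which pairs with the removal of the eager embedding calls from extract_nodes and extract_edges further down. A rough sketch of that backfill step, using only the attributes and methods that appear in the diff:

# Sketch, not the library code: backfill missing embeddings just before a bulk write.
async def backfill_embeddings(entity_nodes, entity_edges, embedder):
    for node in entity_nodes:
        if node.name_embedding is None:            # embed only nodes that still lack one
            await node.generate_name_embedding(embedder)
    for edge in entity_edges:
        if edge.fact_embedding is None:            # same for edge fact embeddings
            await edge.generate_embedding(embedder)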
--- graphiti_core/utils/maintenance/community_operations.py
+++ graphiti_core/utils/maintenance/community_operations.py
@@ -239,7 +239,6 @@ async def determine_entity_community(
         RETURN
             c.uuid As uuid,
             c.name AS name,
-            c.name_embedding AS name_embedding,
             c.group_id AS group_id,
             c.created_at AS created_at,
             c.summary AS summary
@@ -258,7 +257,6 @@ async def determine_entity_community(
         RETURN
             c.uuid As uuid,
             c.name AS name,
-            c.name_embedding AS name_embedding,
             c.group_id AS group_id,
             c.created_at AS created_at,
             c.summary AS summary
--- graphiti_core/utils/maintenance/edge_operations.py
+++ graphiti_core/utils/maintenance/edge_operations.py
@@ -35,9 +35,6 @@ from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
 from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
-from graphiti_core.utils.maintenance.temporal_operations import (
-    get_edge_contradictions,
-)
 
 logger = logging.getLogger(__name__)
 
@@ -91,7 +88,6 @@ async def extract_edges(
 
     extract_edges_max_tokens = 16384
     llm_client = clients.llm_client
-    embedder = clients.embedder
 
     node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
 
@@ -184,8 +180,6 @@ async def extract_edges(
                 f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
             )
 
-    await create_entity_edge_embeddings(embedder, edges)
-
     logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
 
     return edges
@@ -238,13 +232,17 @@ async def dedupe_extracted_edges(
 async def resolve_extracted_edges(
     clients: GraphitiClients,
     extracted_edges: list[EntityEdge],
+    episode: EpisodicNode,
 ) -> tuple[list[EntityEdge], list[EntityEdge]]:
     driver = clients.driver
     llm_client = clients.llm_client
+    embedder = clients.embedder
+
+    await create_entity_edge_embeddings(embedder, extracted_edges)
 
     search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
         get_relevant_edges(driver, extracted_edges, SearchFilters()),
-        get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()),
+        get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
     )
 
     related_edges_lists, edge_invalidation_candidates = search_results
@@ -258,10 +256,7 @@ async def resolve_extracted_edges(
         await semaphore_gather(
             *[
                 resolve_extracted_edge(
-                    llm_client,
-                    extracted_edge,
-                    related_edges,
-                    existing_edges,
+                    llm_client, extracted_edge, related_edges, existing_edges, episode
                 )
                 for extracted_edge, related_edges, existing_edges in zip(
                     extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=True
@@ -281,6 +276,11 @@ async def resolve_extracted_edges(
 
     logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
 
+    await semaphore_gather(
+        create_entity_edge_embeddings(embedder, resolved_edges),
+        create_entity_edge_embeddings(embedder, invalidated_edges),
+    )
+
     return resolved_edges, invalidated_edges
 
 
@@ -322,10 +322,52 @@ async def resolve_extracted_edge(
     extracted_edge: EntityEdge,
     related_edges: list[EntityEdge],
     existing_edges: list[EntityEdge],
+    episode: EpisodicNode | None = None,
 ) -> tuple[EntityEdge, list[EntityEdge]]:
-
-
-
+    if len(related_edges) == 0 and len(existing_edges) == 0:
+        return extracted_edge, []
+
+    start = time()
+
+    # Prepare context for LLM
+    related_edges_context = [
+        {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
+    ]
+
+    invalidation_edge_candidates_context = [
+        {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
+    ]
+
+    context = {
+        'existing_edges': related_edges_context,
+        'new_edge': extracted_edge.fact,
+        'edge_invalidation_candidates': invalidation_edge_candidates_context,
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_edges.resolve_edge(context),
+        response_model=EdgeDuplicate,
+        model_size=ModelSize.small,
+    )
+
+    duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
+
+    resolved_edge = (
+        related_edges[duplicate_fact_id]
+        if 0 <= duplicate_fact_id < len(related_edges)
+        else extracted_edge
+    )
+
+    if duplicate_fact_id >= 0 and episode is not None:
+        resolved_edge.episodes.append(episode.uuid)
+
+    contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
+
+    invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
+
+    end = time()
+    logger.debug(
+        f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
     )
 
     now = utc_now()
@@ -356,7 +398,10 @@ async def resolve_extracted_edge(
 
 
 async def dedupe_extracted_edge(
-    llm_client: LLMClient,
+    llm_client: LLMClient,
+    extracted_edge: EntityEdge,
+    related_edges: list[EntityEdge],
+    episode: EpisodicNode | None = None,
 ) -> EntityEdge:
     if len(related_edges) == 0:
         return extracted_edge
@@ -391,6 +436,9 @@ async def dedupe_extracted_edge(
         else extracted_edge
     )
 
+    if duplicate_fact_id >= 0 and episode is not None:
+        edge.episodes.append(episode.uuid)
+
     end = time()
     logger.debug(
         f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
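In edge_operations.py the separate get_edge_contradictions pass is gone: resolve_extracted_edge now issues a single small-model call whose EdgeDuplicate response carries both a duplicate_fact_id and the indices of contradicted_facts, and those indices are mapped back onto edge objects. A simplified sketch of that mapping (episode bookkeeping and timing omitted; the full logic is in the hunk above):

def apply_edge_resolution(extracted_edge, related_edges, existing_edges, llm_response):
    # Index into the duplicate candidates; -1 or any out-of-range id means "no duplicate found".
    duplicate_fact_id = llm_response.get('duplicate_fact_id', -1)
    resolved = (
        related_edges[duplicate_fact_id]
        if 0 <= duplicate_fact_id < len(related_edges)
        else extracted_edge
    )
    # Indices of previously stored facts the new fact contradicts become invalidation candidates.
    contradicted = [existing_edges[i] for i in llm_response.get('contradicted_facts', [])]
    return resolved, contradicted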
--- graphiti_core/utils/maintenance/node_operations.py
+++ graphiti_core/utils/maintenance/node_operations.py
@@ -18,6 +18,7 @@ import logging
 from contextlib import suppress
 from time import time
 from typing import Any
+from uuid import uuid4
 
 import pydantic
 from pydantic import BaseModel, Field
@@ -28,14 +29,16 @@ from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
 from graphiti_core.prompts import prompt_library
-from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
+from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
 from graphiti_core.prompts.extract_nodes import (
     ExtractedEntities,
     ExtractedEntity,
     MissedEntities,
 )
+from graphiti_core.search.search import search
+from graphiti_core.search.search_config import SearchResults
+from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
 from graphiti_core.search.search_filters import SearchFilters
-from graphiti_core.search.search_utils import get_relevant_nodes
 from graphiti_core.utils.datetime_utils import utc_now
 
 logger = logging.getLogger(__name__)
@@ -70,7 +73,6 @@ async def extract_nodes(
 ) -> list[EntityNode]:
     start = time()
     llm_client = clients.llm_client
-    embedder = clients.embedder
     llm_response = {}
     custom_prompt = ''
     entities_missed = True
@@ -163,8 +165,6 @@ async def extract_nodes(
         extracted_nodes.append(new_node)
         logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
 
-    await create_entity_node_embeddings(embedder, extracted_nodes)
-
     logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
     return extracted_nodes
 
@@ -227,35 +227,81 @@ async def resolve_extracted_nodes(
     entity_types: dict[str, BaseModel] | None = None,
 ) -> tuple[list[EntityNode], dict[str, str]]:
     llm_client = clients.llm_client
-    driver = clients.driver
 
-
-    existing_nodes_lists: list[list[EntityNode]] = await get_relevant_nodes(
-        driver, extracted_nodes, SearchFilters()
-    )
-
-    resolved_nodes: list[EntityNode] = await semaphore_gather(
+    search_results: list[SearchResults] = await semaphore_gather(
         *[
-
-
-
-
-
-
-            entity_types.get(
-                next((item for item in extracted_node.labels if item != 'Entity'), '')
-            )
-            if entity_types is not None
-            else None,
-            )
-            for extracted_node, existing_nodes in zip(
-                extracted_nodes, existing_nodes_lists, strict=True
+            search(
+                clients=clients,
+                query=node.name,
+                group_ids=[node.group_id],
+                search_filter=SearchFilters(),
+                config=NODE_HYBRID_SEARCH_RRF,
             )
+            for node in extracted_nodes
         ]
     )
 
+    existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
+
+    entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
+
+    # Prepare context for LLM
+    extracted_nodes_context = [
+        {
+            'id': i,
+            'name': node.name,
+            'entity_type': node.labels,
+            'entity_type_description': entity_types_dict.get(
+                next((item for item in node.labels if item != 'Entity'), '')
+            ).__doc__
+            or 'Default Entity Type',
+            'duplication_candidates': [
+                {
+                    **{
+                        'idx': j,
+                        'name': candidate.name,
+                        'entity_types': candidate.labels,
+                    },
+                    **candidate.attributes,
+                }
+                for j, candidate in enumerate(existing_nodes_lists[i])
+            ],
+        }
+        for i, node in enumerate(extracted_nodes)
+    ]
+
+    context = {
+        'extracted_nodes': extracted_nodes_context,
+        'episode_content': episode.content if episode is not None else '',
+        'previous_episodes': [ep.content for ep in previous_episodes]
+        if previous_episodes is not None
+        else [],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_nodes.nodes(context),
+        response_model=NodeResolutions,
+    )
+
+    node_resolutions: list = llm_response.get('entity_resolutions', [])
+
+    resolved_nodes: list[EntityNode] = []
     uuid_map: dict[str, str] = {}
-    for
+    for resolution in node_resolutions:
+        resolution_id = resolution.get('id', -1)
+        duplicate_idx = resolution.get('duplicate_idx', -1)
+
+        extracted_node = extracted_nodes[resolution_id]
+
+        resolved_node = (
+            existing_nodes_lists[resolution_id][duplicate_idx]
+            if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
+            else extracted_node
+        )
+
+        resolved_node.name = resolution.get('name')
+
+        resolved_nodes.append(resolved_node)
         uuid_map[extracted_node.uuid] = resolved_node.uuid
 
     logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
@@ -375,7 +421,7 @@ async def extract_attributes_from_node(
         'summary': (
             str,
             Field(
-                description='Summary containing the important information about the entity. Under
+                description='Summary containing the important information about the entity. Under 250 words',
            ),
         )
     }
@@ -387,7 +433,8 @@ async def extract_attributes_from_node(
             Field(description=field_info.description),
         )
 
-
+    unique_model_name = f'EntityAttributes_{uuid4().hex}'
+    entity_attributes_model = pydantic.create_model(unique_model_name, **attributes_definitions)
 
     summary_context: dict[str, Any] = {
         'node': node_context,
@@ -400,15 +447,14 @@ async def extract_attributes_from_node(
     llm_response = await llm_client.generate_response(
         prompt_library.extract_nodes.extract_attributes(summary_context),
         response_model=entity_attributes_model,
+        model_size=ModelSize.small,
    )
 
     node.summary = llm_response.get('summary', node.summary)
-    node.name = llm_response.get('name', node.name)
     node_attributes = {key: value for key, value in llm_response.items()}
 
     with suppress(KeyError):
         del node_attributes['summary']
-        del node_attributes['name']
 
     node.attributes.update(node_attributes)
 
@@ -427,10 +473,7 @@ async def dedupe_node_list(
         node_map[node.uuid] = node
 
     # Prepare context for LLM
-    nodes_context = [
-        {'uuid': node.uuid, 'name': node.name, 'summary': node.summary}.update(node.attributes)
-        for node in nodes
-    ]
+    nodes_context = [{'uuid': node.uuid, 'name': node.name, **node.attributes} for node in nodes]
 
     context = {
         'nodes': nodes_context,
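A small but easy-to-miss node_operations.py change: extract_attributes_from_node now gives each dynamically created pydantic model a unique, uuid4-suffixed class name, so repeated calls with different attribute sets never produce identically named model classes. A standalone sketch of that pattern (the 'summary' field here is illustrative, not the library's exact definition):

import pydantic
from uuid import uuid4

# Dynamically build a response model whose class name is unique per call.
attributes_definitions = {
    'summary': (str, pydantic.Field(description='Short entity summary')),
}
unique_model_name = f'EntityAttributes_{uuid4().hex}'
EntityAttributes = pydantic.create_model(unique_model_name, **attributes_definitions)

print(EntityAttributes(summary='example').summary)  # -> 'example'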
--- graphiti_core-0.11.4.dist-info/METADATA
+++ graphiti_core-0.11.6.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: graphiti-core
-Version: 0.11.4
+Version: 0.11.6
 Summary: A temporal graph building library
 License: Apache-2.0
 Author: Paul Paliychuk
@@ -18,7 +18,6 @@ Provides-Extra: groq
 Requires-Dist: anthropic (>=0.49.0) ; extra == "anthropic"
 Requires-Dist: diskcache (>=5.6.3)
 Requires-Dist: google-genai (>=1.8.0) ; extra == "google-genai"
-Requires-Dist: graph-service (>=1.0.0.7,<2.0.0.0)
 Requires-Dist: groq (>=0.2.0) ; extra == "groq"
 Requires-Dist: neo4j (>=5.23.0)
 Requires-Dist: numpy (>=1.0.0)