graphiti-core 0.17.4__py3-none-any.whl → 0.25.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/driver.py +62 -2
- graphiti_core/driver/falkordb_driver.py +215 -23
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +70 -8
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +264 -132
- graphiti_core/embedder/azure_openai.py +10 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +114 -101
- graphiti_core/graphiti.py +635 -260
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/helpers.py +37 -15
- graphiti_core/llm_client/anthropic_client.py +142 -52
- graphiti_core/llm_client/azure_openai_client.py +57 -19
- graphiti_core/llm_client/client.py +83 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +75 -57
- graphiti_core/llm_client/openai_base_client.py +92 -48
- graphiti_core/llm_client/openai_client.py +39 -9
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +388 -164
- graphiti_core/prompts/dedupe_edges.py +42 -31
- graphiti_core/prompts/dedupe_nodes.py +56 -39
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +24 -15
- graphiti_core/prompts/extract_nodes.py +76 -35
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +154 -74
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +110 -31
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1360 -473
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +216 -90
- graphiti_core/utils/content_chunking.py +702 -0
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +62 -38
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +306 -156
- graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
- graphiti_core/utils/maintenance/node_operations.py +466 -206
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/METADATA +221 -87
- graphiti_core-0.25.3.dist-info/RECORD +87 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.17.4.dist-info/RECORD +0 -77
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/utils/maintenance/edge_operations.py
@@ -21,7 +21,7 @@ from time import time
 from pydantic import BaseModel
 from typing_extensions import LiteralString
 
-from graphiti_core.driver.driver import GraphDriver
+from graphiti_core.driver.driver import GraphDriver, GraphProvider
 from graphiti_core.edges import (
     CommunityEdge,
     EntityEdge,
@@ -29,16 +29,21 @@ from graphiti_core.edges import (
     create_entity_edge_embeddings,
 )
 from graphiti_core.graphiti_types import GraphitiClients
-from graphiti_core.helpers import
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
 from graphiti_core.prompts import prompt_library
-from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
-from graphiti_core.prompts.extract_edges import ExtractedEdges
+from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
+from graphiti_core.prompts.extract_edges import ExtractedEdges
+from graphiti_core.search.search import search
+from graphiti_core.search.search_config import SearchResults
+from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
 from graphiti_core.search.search_filters import SearchFilters
-from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
 from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
+from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
+
+DEFAULT_EDGE_NAME = 'RELATES_TO'
 
 logger = logging.getLogger(__name__)
 
@@ -63,32 +68,6 @@ def build_episodic_edges(
     return episodic_edges
 
 
-def build_duplicate_of_edges(
-    episode: EpisodicNode,
-    created_at: datetime,
-    duplicate_nodes: list[tuple[EntityNode, EntityNode]],
-) -> list[EntityEdge]:
-    is_duplicate_of_edges: list[EntityEdge] = []
-    for source_node, target_node in duplicate_nodes:
-        if source_node.uuid == target_node.uuid:
-            continue
-
-        is_duplicate_of_edges.append(
-            EntityEdge(
-                source_node_uuid=source_node.uuid,
-                target_node_uuid=target_node.uuid,
-                name='IS_DUPLICATE_OF',
-                group_id=episode.group_id,
-                fact=f'{source_node.name} is a duplicate of {target_node.name}',
-                episodes=[episode.uuid],
-                created_at=created_at,
-                valid_at=created_at,
-            )
-        )
-
-    return is_duplicate_of_edges
-
-
 def build_community_edges(
     entity_nodes: list[EntityNode],
     community_node: CommunityNode,
@@ -114,7 +93,8 @@ async def extract_edges(
     previous_episodes: list[EpisodicNode],
     edge_type_map: dict[tuple[str, str], list[str]],
     group_id: str = '',
-    edge_types: dict[str, BaseModel] | None = None,
+    edge_types: dict[str, type[BaseModel]] | None = None,
+    custom_extraction_instructions: str | None = None,
 ) -> list[EntityEdge]:
     start = time()
 
@@ -150,38 +130,17 @@ async def extract_edges(
         'previous_episodes': [ep.content for ep in previous_episodes],
         'reference_time': episode.valid_at,
         'edge_types': edge_types_context,
-        '
+        'custom_extraction_instructions': custom_extraction_instructions or '',
     }
 
-
-
-
-
-
-
-
-
-        edges_data = llm_response.get('edges', [])
-
-        context['extracted_facts'] = [edge_data.get('fact', '') for edge_data in edges_data]
-
-        reflexion_iterations += 1
-        if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
-            reflexion_response = await llm_client.generate_response(
-                prompt_library.extract_edges.reflexion(context),
-                response_model=MissingFacts,
-                max_tokens=extract_edges_max_tokens,
-            )
-
-            missing_facts = reflexion_response.get('missing_facts', [])
-
-            custom_prompt = 'The following facts were missed in a previous extraction: '
-            for fact in missing_facts:
-                custom_prompt += f'\n{fact},'
-
-            context['custom_prompt'] = custom_prompt
-
-            facts_missed = len(missing_facts) != 0
+    llm_response = await llm_client.generate_response(
+        prompt_library.extract_edges.edge(context),
+        response_model=ExtractedEdges,
+        max_tokens=extract_edges_max_tokens,
+        group_id=group_id,
+        prompt_name='extract_edges.edge',
+    )
+    edges_data = ExtractedEdges(**llm_response).edges
 
     end = time()
     logger.debug(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')
@@ -193,20 +152,31 @@ async def extract_edges(
     edges = []
     for edge_data in edges_data:
         # Validate Edge Date information
-        valid_at = edge_data.
-        invalid_at = edge_data.
+        valid_at = edge_data.valid_at
+        invalid_at = edge_data.invalid_at
         valid_at_datetime = None
         invalid_at_datetime = None
 
-
-
-
+        # Filter out empty edges
+        if not edge_data.fact.strip():
+            continue
+
+        source_node_idx = edge_data.source_entity_id
+        target_node_idx = edge_data.target_entity_id
+
+        if len(nodes) == 0:
+            logger.warning('No entities provided for edge extraction')
+            continue
+
+        if not (0 <= source_node_idx < len(nodes) and 0 <= target_node_idx < len(nodes)):
             logger.warning(
-                f'
+                f'Invalid entity IDs in edge extraction for {edge_data.relation_type}. '
+                f'source_entity_id: {source_node_idx}, target_entity_id: {target_node_idx}, '
+                f'but only {len(nodes)} entities available (valid range: 0-{len(nodes) - 1})'
             )
             continue
         source_node_uuid = nodes[source_node_idx].uuid
-        target_node_uuid = nodes[
+        target_node_uuid = nodes[target_node_idx].uuid
 
         if valid_at:
             try:
@@ -226,9 +196,9 @@ async def extract_edges(
         edge = EntityEdge(
             source_node_uuid=source_node_uuid,
             target_node_uuid=target_node_uuid,
-            name=edge_data.
+            name=edge_data.relation_type,
             group_id=group_id,
-            fact=edge_data.
+            fact=edge_data.fact,
             episodes=[episode.uuid],
             created_at=utc_now(),
             valid_at=valid_at_datetime,
@@ -249,20 +219,68 @@ async def resolve_extracted_edges(
     extracted_edges: list[EntityEdge],
     episode: EpisodicNode,
    entities: list[EntityNode],
-    edge_types: dict[str, BaseModel],
+    edge_types: dict[str, type[BaseModel]],
     edge_type_map: dict[tuple[str, str], list[str]],
 ) -> tuple[list[EntityEdge], list[EntityEdge]]:
+    # Fast path: deduplicate exact matches within the extracted edges before parallel processing
+    seen: dict[tuple[str, str, str], EntityEdge] = {}
+    deduplicated_edges: list[EntityEdge] = []
+
+    for edge in extracted_edges:
+        key = (
+            edge.source_node_uuid,
+            edge.target_node_uuid,
+            _normalize_string_exact(edge.fact),
+        )
+        if key not in seen:
+            seen[key] = edge
+            deduplicated_edges.append(edge)
+
+    extracted_edges = deduplicated_edges
+
     driver = clients.driver
     llm_client = clients.llm_client
     embedder = clients.embedder
     await create_entity_edge_embeddings(embedder, extracted_edges)
 
-
-
-
+    valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
+        *[
+            EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
+            for edge in extracted_edges
+        ]
     )
 
-
+    related_edges_results: list[SearchResults] = await semaphore_gather(
+        *[
+            search(
+                clients,
+                extracted_edge.fact,
+                group_ids=[extracted_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
+            )
+            for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
+        ]
+    )
+
+    related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
+
+    edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
+        *[
+            search(
+                clients,
+                extracted_edge.fact,
+                group_ids=[extracted_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(),
+            )
+            for extracted_edge in extracted_edges
+        ]
+    )
+
+    edge_invalidation_candidates: list[list[EntityEdge]] = [
+        result.edges for result in edge_invalidation_candidate_results
+    ]
 
     logger.debug(
         f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
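The fast path added at the top of resolve_extracted_edges drops exact repeats before any database or LLM work by keying each edge on its endpoints plus a normalized fact string. A minimal sketch of that pattern, assuming a simple lowercase-and-collapse-whitespace normalizer in place of _normalize_string_exact (whose exact rules live in dedup_helpers and are not shown in this diff):

```python
from dataclasses import dataclass


def normalize(text: str) -> str:
    # Stand-in normalizer: lowercase and collapse runs of whitespace.
    return ' '.join(text.lower().split())


@dataclass
class Edge:
    source_node_uuid: str
    target_node_uuid: str
    fact: str


def dedupe_exact(edges: list[Edge]) -> list[Edge]:
    seen: dict[tuple[str, str, str], Edge] = {}
    unique: list[Edge] = []
    for edge in edges:
        key = (edge.source_node_uuid, edge.target_node_uuid, normalize(edge.fact))
        if key not in seen:  # first occurrence wins; later exact repeats are dropped
            seen[key] = edge
            unique.append(edge)
    return unique


edges = [
    Edge('a', 'b', 'Alice works at Acme'),
    Edge('a', 'b', 'alice  works at acme'),  # exact duplicate after normalization
]
assert len(dedupe_exact(edges)) == 1
```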
@@ -271,11 +289,35 @@ async def resolve_extracted_edges(
     # Build entity hash table
     uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
 
-    #
-
+    # Collect all node UUIDs referenced by edges that are not in the entities list
+    referenced_node_uuids = set()
+    for extracted_edge in extracted_edges:
+        if extracted_edge.source_node_uuid not in uuid_entity_map:
+            referenced_node_uuids.add(extracted_edge.source_node_uuid)
+        if extracted_edge.target_node_uuid not in uuid_entity_map:
+            referenced_node_uuids.add(extracted_edge.target_node_uuid)
+
+    # Fetch missing nodes from the database
+    if referenced_node_uuids:
+        missing_nodes = await EntityNode.get_by_uuids(driver, list(referenced_node_uuids))
+        for node in missing_nodes:
+            uuid_entity_map[node.uuid] = node
+
+    # Determine which edge types are relevant for each edge.
+    # `edge_types_lst` stores the subset of custom edge definitions whose
+    # node signature matches each extracted edge. Anything outside this subset
+    # should only stay on the edge if it is a non-custom (LLM generated) label.
+    edge_types_lst: list[dict[str, type[BaseModel]]] = []
+    custom_type_names = set(edge_types or {})
     for extracted_edge in extracted_edges:
-
-
+        source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
+        target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
+        source_node_labels = (
+            source_node.labels + ['Entity'] if source_node is not None else ['Entity']
+        )
+        target_node_labels = (
+            target_node.labels + ['Entity'] if target_node is not None else ['Entity']
+        )
         label_tuples = [
             (source_label, target_label)
             for source_label in source_node_labels
@@ -294,6 +336,20 @@ async def resolve_extracted_edges(
 
         edge_types_lst.append(extracted_edge_types)
 
+    for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
+        allowed_type_names = set(extracted_edge_types)
+        is_custom_name = extracted_edge.name in custom_type_names
+        if not allowed_type_names:
+            # No custom types are valid for this node pairing. Keep LLM generated
+            # labels, but flip disallowed custom names back to the default.
+            if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
+                extracted_edge.name = DEFAULT_EDGE_NAME
+            continue
+        if is_custom_name and extracted_edge.name not in allowed_type_names:
+            # Custom name exists but it is not permitted for this source/target
+            # signature, so fall back to the default edge label.
+            extracted_edge.name = DEFAULT_EDGE_NAME
+
     # resolve edges with related edges in the graph and find invalidation candidates
     results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
         await semaphore_gather(
@@ -305,6 +361,7 @@ async def resolve_extracted_edges(
                     existing_edges,
                     episode,
                     extracted_edge_types,
+                    custom_type_names,
                 )
                 for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
                     extracted_edges,
@@ -346,21 +403,26 @@ def resolve_edge_contradictions(
     invalidated_edges: list[EntityEdge] = []
     for edge in invalidation_candidates:
         # (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
+        edge_invalid_at_utc = ensure_utc(edge.invalid_at)
+        resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
+        edge_valid_at_utc = ensure_utc(edge.valid_at)
+        resolved_edge_invalid_at_utc = ensure_utc(resolved_edge.invalid_at)
+
         if (
-
-            and
-            and
+            edge_invalid_at_utc is not None
+            and resolved_edge_valid_at_utc is not None
+            and edge_invalid_at_utc <= resolved_edge_valid_at_utc
         ) or (
-
-            and
-            and
+            edge_valid_at_utc is not None
+            and resolved_edge_invalid_at_utc is not None
+            and resolved_edge_invalid_at_utc <= edge_valid_at_utc
         ):
             continue
         # New edge invalidates edge
         elif (
-
-            and
-            and
+            edge_valid_at_utc is not None
+            and resolved_edge_valid_at_utc is not None
+            and edge_valid_at_utc < resolved_edge_valid_at_utc
        ):
             edge.invalid_at = resolved_edge.valid_at
             edge.expired_at = edge.expired_at if edge.expired_at is not None else utc_now()
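resolve_edge_contradictions now routes every timestamp through ensure_utc before comparing, so naive and timezone-aware datetimes are never compared directly. A small illustration of why that matters, using a stand-in ensure_utc (not graphiti's implementation) that assumes naive values are already in UTC:

```python
from datetime import datetime, timezone


def ensure_utc(dt: datetime | None) -> datetime | None:
    # Stand-in for graphiti_core.utils.datetime_utils.ensure_utc: attach UTC to
    # naive datetimes and convert aware ones to UTC.
    if dt is None:
        return None
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(timezone.utc)


naive = datetime(2024, 1, 1, 12, 0)                       # no tzinfo
aware = datetime(2024, 1, 1, 13, 0, tzinfo=timezone.utc)  # aware

# Comparing naive with aware directly raises TypeError; normalizing first is safe.
assert ensure_utc(naive) < ensure_utc(aware)
```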
@@ -375,32 +437,69 @@ async def resolve_extracted_edge(
     related_edges: list[EntityEdge],
     existing_edges: list[EntityEdge],
     episode: EpisodicNode,
-
+    edge_type_candidates: dict[str, type[BaseModel]] | None = None,
+    custom_edge_type_names: set[str] | None = None,
 ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
+    """Resolve an extracted edge against existing graph context.
+
+    Parameters
+    ----------
+    llm_client : LLMClient
+        Client used to invoke the LLM for deduplication and attribute extraction.
+    extracted_edge : EntityEdge
+        Newly extracted edge whose canonical representation is being resolved.
+    related_edges : list[EntityEdge]
+        Candidate edges with identical endpoints used for duplicate detection.
+    existing_edges : list[EntityEdge]
+        Broader set of edges evaluated for contradiction / invalidation.
+    episode : EpisodicNode
+        Episode providing content context when extracting edge attributes.
+    edge_type_candidates : dict[str, type[BaseModel]] | None
+        Custom edge types permitted for the current source/target signature.
+    custom_edge_type_names : set[str] | None
+        Full catalog of registered custom edge names. Used to distinguish
+        between disallowed custom types (which fall back to the default label)
+        and ad-hoc labels emitted by the LLM.
+
+    Returns
+    -------
+    tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
+        The resolved edge, any duplicates, and edges to invalidate.
+    """
     if len(related_edges) == 0 and len(existing_edges) == 0:
         return extracted_edge, [], []
 
+    # Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
+    normalized_fact = _normalize_string_exact(extracted_edge.fact)
+    for edge in related_edges:
+        if (
+            edge.source_node_uuid == extracted_edge.source_node_uuid
+            and edge.target_node_uuid == extracted_edge.target_node_uuid
+            and _normalize_string_exact(edge.fact) == normalized_fact
+        ):
+            resolved = edge
+            if episode is not None and episode.uuid not in resolved.episodes:
+                resolved.episodes.append(episode.uuid)
+            return resolved, [], []
+
     start = time()
 
     # Prepare context for LLM
-    related_edges_context = [
-        {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
-    ]
+    related_edges_context = [{'idx': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)]
 
     invalidation_edge_candidates_context = [
-        {'
+        {'idx': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
     ]
 
     edge_types_context = (
         [
             {
-                'fact_type_id': i,
                 'fact_type_name': type_name,
                 'fact_type_description': type_model.__doc__,
             }
-            for
+            for type_name, type_model in edge_type_candidates.items()
         ]
-        if
+        if edge_type_candidates is not None
         else []
     )
 
@@ -411,15 +510,34 @@ async def resolve_extracted_edge(
         'edge_types': edge_types_context,
     }
 
+    if related_edges or existing_edges:
+        logger.debug(
+            'Resolving edge: sent %d EXISTING FACTS%s and %d INVALIDATION CANDIDATES%s',
+            len(related_edges),
+            f' (idx 0-{len(related_edges) - 1})' if related_edges else '',
+            len(existing_edges),
+            f' (idx 0-{len(existing_edges) - 1})' if existing_edges else '',
+        )
+
     llm_response = await llm_client.generate_response(
         prompt_library.dedupe_edges.resolve_edge(context),
         response_model=EdgeDuplicate,
         model_size=ModelSize.small,
+        prompt_name='dedupe_edges.resolve_edge',
     )
+    response_object = EdgeDuplicate(**llm_response)
+    duplicate_facts = response_object.duplicate_facts
+
+    # Validate duplicate_facts are in valid range for EXISTING FACTS
+    invalid_duplicates = [i for i in duplicate_facts if i < 0 or i >= len(related_edges)]
+    if invalid_duplicates:
+        logger.warning(
+            'LLM returned invalid duplicate_facts idx values %s (valid range: 0-%d for EXISTING FACTS)',
+            invalid_duplicates,
+            len(related_edges) - 1,
+        )
 
-    duplicate_fact_ids: list[int] =
-        filter(lambda i: 0 <= i < len(related_edges), llm_response.get('duplicate_facts', []))
-    )
+    duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]
 
     resolved_edge = extracted_edge
     for duplicate_fact_id in duplicate_fact_ids:
@@ -429,12 +547,32 @@ async def resolve_extracted_edge(
     if duplicate_fact_ids and episode is not None:
         resolved_edge.episodes.append(episode.uuid)
 
-    contradicted_facts: list[int] =
+    contradicted_facts: list[int] = response_object.contradicted_facts
+
+    # Validate contradicted_facts are in valid range for INVALIDATION CANDIDATES
+    invalid_contradictions = [i for i in contradicted_facts if i < 0 or i >= len(existing_edges)]
+    if invalid_contradictions:
+        logger.warning(
+            'LLM returned invalid contradicted_facts idx values %s (valid range: 0-%d for INVALIDATION CANDIDATES)',
+            invalid_contradictions,
+            len(existing_edges) - 1,
+        )
+
+    invalidation_candidates: list[EntityEdge] = [
+        existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
+    ]
+
+    fact_type: str = response_object.fact_type
+    candidate_type_names = set(edge_type_candidates or {})
+    custom_type_names = custom_edge_type_names or set()
 
-
+    is_default_type = fact_type.upper() == 'DEFAULT'
+    is_custom_type = fact_type in custom_type_names
+    is_allowed_custom_type = fact_type in candidate_type_names
 
-
-
+    if is_allowed_custom_type:
+        # The LLM selected a custom type that is allowed for the node pair.
+        # Adopt the custom type and, if needed, extract its structured attributes.
         resolved_edge.name = fact_type
 
         edge_attributes_context = {
@@ -443,15 +581,26 @@ async def resolve_extracted_edge(
             'fact': resolved_edge.fact,
         }
 
-        edge_model =
+        edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
         if edge_model is not None and len(edge_model.model_fields) != 0:
             edge_attributes_response = await llm_client.generate_response(
                 prompt_library.extract_edges.extract_attributes(edge_attributes_context),
                 response_model=edge_model,  # type: ignore
                 model_size=ModelSize.small,
+                prompt_name='extract_edges.extract_attributes',
             )
 
             resolved_edge.attributes = edge_attributes_response
+    elif not is_default_type and is_custom_type:
+        # The LLM picked a custom type that is not allowed for this signature.
+        # Reset to the default label and drop any structured attributes.
+        resolved_edge.name = DEFAULT_EDGE_NAME
+        resolved_edge.attributes = {}
+    elif not is_default_type:
+        # Non-custom labels are allowed to pass through so long as the LLM does
+        # not return the sentinel DEFAULT value.
+        resolved_edge.name = fact_type
+        resolved_edge.attributes = {}
 
     end = time()
     logger.debug(
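The branches above decide whether the LLM's fact_type is adopted, passed through, or reset to DEFAULT_EDGE_NAME. A compact sketch of that decision order, using illustrative type names (the real code additionally extracts structured attributes when an allowed custom type is chosen):

```python
DEFAULT_EDGE_NAME = 'RELATES_TO'


def resolve_edge_name(current_name: str, fact_type: str, allowed: set[str], registered: set[str]) -> str:
    # Mirrors the branch order in the diff: an allowed custom type wins; a registered
    # custom type that is not allowed for this node pair falls back to the default
    # label; any other non-DEFAULT label passes through; the DEFAULT sentinel keeps
    # whatever name the edge already carries.
    if fact_type in allowed:
        return fact_type
    if fact_type.upper() != 'DEFAULT' and fact_type in registered:
        return DEFAULT_EDGE_NAME
    if fact_type.upper() != 'DEFAULT':
        return fact_type
    return current_name


registered = {'WORKS_AT', 'MARRIED_TO'}
assert resolve_edge_name('RELATES_TO', 'WORKS_AT', {'WORKS_AT'}, registered) == 'WORKS_AT'     # allowed custom type
assert resolve_edge_name('RELATES_TO', 'MARRIED_TO', {'WORKS_AT'}, registered) == 'RELATES_TO'  # disallowed custom type
assert resolve_edge_name('RELATES_TO', 'MENTIONS', set(), registered) == 'MENTIONS'             # ad-hoc label passes through
assert resolve_edge_name('KNOWS', 'DEFAULT', set(), registered) == 'KNOWS'                      # sentinel keeps current name
```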
@@ -465,14 +614,14 @@ async def resolve_extracted_edge(
 
     # Determine if the new_edge needs to be expired
     if resolved_edge.expired_at is None:
-        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
+        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, ensure_utc(c.valid_at)))
         for candidate in invalidation_candidates:
+            candidate_valid_at_utc = ensure_utc(candidate.valid_at)
+            resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
             if (
-
-                and
-                and
-                and resolved_edge.valid_at.tzinfo
-                and candidate.valid_at > resolved_edge.valid_at
+                candidate_valid_at_utc is not None
+                and resolved_edge_valid_at_utc is not None
+                and candidate_valid_at_utc > resolved_edge_valid_at_utc
             ):
                 # Expire new edge since we have information about more recent events
                 resolved_edge.invalid_at = candidate.valid_at
@@ -488,59 +637,60 @@ async def resolve_extracted_edge(
     return resolved_edge, invalidated_edges, duplicate_edges
 
 
-async def dedupe_edge_list(
-    llm_client: LLMClient,
-    edges: list[EntityEdge],
-) -> list[EntityEdge]:
-    start = time()
-
-    # Create edge map
-    edge_map = {}
-    for edge in edges:
-        edge_map[edge.uuid] = edge
-
-    # Prepare context for LLM
-    context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}
-
-    llm_response = await llm_client.generate_response(
-        prompt_library.dedupe_edges.edge_list(context), response_model=UniqueFacts
-    )
-    unique_edges_data = llm_response.get('unique_facts', [])
-
-    end = time()
-    logger.debug(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
-
-    # Get full edge data
-    unique_edges = []
-    for edge_data in unique_edges_data:
-        uuid = edge_data['uuid']
-        edge = edge_map[uuid]
-        edge.fact = edge_data['fact']
-        unique_edges.append(edge)
-
-    return unique_edges
-
-
 async def filter_existing_duplicate_of_edges(
     driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
 ) -> list[tuple[EntityNode, EntityNode]]:
-
-
-        MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
-        RETURN DISTINCT
-            n.uuid AS source_uuid,
-            m.uuid AS target_uuid
-    """
+    if not duplicates_node_tuples:
+        return []
 
     duplicate_nodes_map = {
         (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
     }
 
-
-    query
-
-
-
+    if driver.provider == GraphProvider.NEPTUNE:
+        query: LiteralString = """
+            UNWIND $duplicate_node_uuids AS duplicate_tuple
+            MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
+            RETURN DISTINCT
+                n.uuid AS source_uuid,
+                m.uuid AS target_uuid
+        """
+
+        duplicate_nodes = [
+            {'source': source.uuid, 'target': target.uuid}
+            for source, target in duplicates_node_tuples
+        ]
+
+        records, _, _ = await driver.execute_query(
+            query,
+            duplicate_node_uuids=duplicate_nodes,
+            routing_='r',
+        )
+    else:
+        if driver.provider == GraphProvider.KUZU:
+            query = """
+                UNWIND $duplicate_node_uuids AS duplicate
+                MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
+                RETURN DISTINCT
+                    n.uuid AS source_uuid,
+                    m.uuid AS target_uuid
+            """
+            duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
+        else:
+            query: LiteralString = """
+                UNWIND $duplicate_node_uuids AS duplicate_tuple
+                MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
+                RETURN DISTINCT
+                    n.uuid AS source_uuid,
+                    m.uuid AS target_uuid
+            """
+            duplicate_node_uuids = list(duplicate_nodes_map.keys())
+
+        records, _, _ = await driver.execute_query(
+            query,
+            duplicate_node_uuids=duplicate_node_uuids,
+            routing_='r',
+        )
 
     # Remove duplicates that already have the IS_DUPLICATE_OF edge
     for record in records:
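The rewritten filter_existing_duplicate_of_edges passes a different parameter shape to the UNWIND query for each provider: dicts keyed 'source'/'target' for Neptune, dicts keyed 'src'/'dst' for Kuzu, and plain (source, target) tuples otherwise. A small sketch of how the three shapes are derived from the same pairs (illustrative only, not the driver API):

```python
duplicates = [('uuid-a', 'uuid-b'), ('uuid-c', 'uuid-d')]  # (source, target) pairs

# Default branch: tuples are passed through, so the query indexes them
# positionally as duplicate_tuple[0] / duplicate_tuple[1].
default_params = list(duplicates)

# Neptune branch: dicts keyed 'source'/'target', matching duplicate_tuple.source / .target.
neptune_params = [{'source': s, 'target': t} for s, t in duplicates]

# Kuzu branch: dicts keyed 'src'/'dst', matching duplicate.src / .dst in its query.
kuzu_params = [{'src': s, 'dst': t} for s, t in duplicates]

assert default_params[0][0] == neptune_params[0]['source'] == kuzu_params[0]['src']
```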
|