graphiti-core 0.17.4__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/driver.py +62 -2
- graphiti_core/driver/falkordb_driver.py +215 -23
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +61 -8
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +264 -132
- graphiti_core/embedder/azure_openai.py +10 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +114 -101
- graphiti_core/graphiti.py +582 -255
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/helpers.py +21 -14
- graphiti_core/llm_client/anthropic_client.py +142 -52
- graphiti_core/llm_client/azure_openai_client.py +57 -19
- graphiti_core/llm_client/client.py +83 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +75 -57
- graphiti_core/llm_client/openai_base_client.py +94 -50
- graphiti_core/llm_client/openai_client.py +28 -8
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +388 -164
- graphiti_core/prompts/dedupe_edges.py +42 -31
- graphiti_core/prompts/dedupe_nodes.py +56 -39
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +23 -14
- graphiti_core/prompts/extract_nodes.py +73 -32
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +154 -74
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +109 -31
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1360 -473
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +216 -90
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +62 -38
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +286 -126
- graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
- graphiti_core/utils/maintenance/node_operations.py +320 -158
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/METADATA +221 -87
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.17.4.dist-info/RECORD +0 -77
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -21,7 +21,7 @@ from time import time
|
|
|
21
21
|
from pydantic import BaseModel
|
|
22
22
|
from typing_extensions import LiteralString
|
|
23
23
|
|
|
24
|
-
from graphiti_core.driver.driver import GraphDriver
|
|
24
|
+
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
25
25
|
from graphiti_core.edges import (
|
|
26
26
|
CommunityEdge,
|
|
27
27
|
EntityEdge,
|
|
@@ -34,11 +34,16 @@ from graphiti_core.llm_client import LLMClient
|
|
|
34
34
|
from graphiti_core.llm_client.config import ModelSize
|
|
35
35
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
|
36
36
|
from graphiti_core.prompts import prompt_library
|
|
37
|
-
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
|
|
37
|
+
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
|
|
38
38
|
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
|
39
|
+
from graphiti_core.search.search import search
|
|
40
|
+
from graphiti_core.search.search_config import SearchResults
|
|
41
|
+
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
|
|
39
42
|
from graphiti_core.search.search_filters import SearchFilters
|
|
40
|
-
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
|
|
41
43
|
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
|
|
44
|
+
from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
|
|
45
|
+
|
|
46
|
+
DEFAULT_EDGE_NAME = 'RELATES_TO'
|
|
42
47
|
|
|
43
48
|
logger = logging.getLogger(__name__)
|
|
44
49
|
|
|
@@ -63,32 +68,6 @@ def build_episodic_edges(
|
|
|
63
68
|
return episodic_edges
|
|
64
69
|
|
|
65
70
|
|
|
66
|
-
def build_duplicate_of_edges(
|
|
67
|
-
episode: EpisodicNode,
|
|
68
|
-
created_at: datetime,
|
|
69
|
-
duplicate_nodes: list[tuple[EntityNode, EntityNode]],
|
|
70
|
-
) -> list[EntityEdge]:
|
|
71
|
-
is_duplicate_of_edges: list[EntityEdge] = []
|
|
72
|
-
for source_node, target_node in duplicate_nodes:
|
|
73
|
-
if source_node.uuid == target_node.uuid:
|
|
74
|
-
continue
|
|
75
|
-
|
|
76
|
-
is_duplicate_of_edges.append(
|
|
77
|
-
EntityEdge(
|
|
78
|
-
source_node_uuid=source_node.uuid,
|
|
79
|
-
target_node_uuid=target_node.uuid,
|
|
80
|
-
name='IS_DUPLICATE_OF',
|
|
81
|
-
group_id=episode.group_id,
|
|
82
|
-
fact=f'{source_node.name} is a duplicate of {target_node.name}',
|
|
83
|
-
episodes=[episode.uuid],
|
|
84
|
-
created_at=created_at,
|
|
85
|
-
valid_at=created_at,
|
|
86
|
-
)
|
|
87
|
-
)
|
|
88
|
-
|
|
89
|
-
return is_duplicate_of_edges
|
|
90
|
-
|
|
91
|
-
|
|
92
71
|
def build_community_edges(
|
|
93
72
|
entity_nodes: list[EntityNode],
|
|
94
73
|
community_node: CommunityNode,
|
|
@@ -114,7 +93,7 @@ async def extract_edges(
|
|
|
114
93
|
previous_episodes: list[EpisodicNode],
|
|
115
94
|
edge_type_map: dict[tuple[str, str], list[str]],
|
|
116
95
|
group_id: str = '',
|
|
117
|
-
edge_types: dict[str, BaseModel] | None = None,
|
|
96
|
+
edge_types: dict[str, type[BaseModel]] | None = None,
|
|
118
97
|
) -> list[EntityEdge]:
|
|
119
98
|
start = time()
|
|
120
99
|
|
|
@@ -160,10 +139,12 @@ async def extract_edges(
|
|
|
160
139
|
prompt_library.extract_edges.edge(context),
|
|
161
140
|
response_model=ExtractedEdges,
|
|
162
141
|
max_tokens=extract_edges_max_tokens,
|
|
142
|
+
group_id=group_id,
|
|
143
|
+
prompt_name='extract_edges.edge',
|
|
163
144
|
)
|
|
164
|
-
edges_data = llm_response.
|
|
145
|
+
edges_data = ExtractedEdges(**llm_response).edges
|
|
165
146
|
|
|
166
|
-
context['extracted_facts'] = [edge_data.
|
|
147
|
+
context['extracted_facts'] = [edge_data.fact for edge_data in edges_data]
|
|
167
148
|
|
|
168
149
|
reflexion_iterations += 1
|
|
169
150
|
if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
|
|
@@ -171,6 +152,8 @@ async def extract_edges(
|
|
|
171
152
|
prompt_library.extract_edges.reflexion(context),
|
|
172
153
|
response_model=MissingFacts,
|
|
173
154
|
max_tokens=extract_edges_max_tokens,
|
|
155
|
+
group_id=group_id,
|
|
156
|
+
prompt_name='extract_edges.reflexion',
|
|
174
157
|
)
|
|
175
158
|
|
|
176
159
|
missing_facts = reflexion_response.get('missing_facts', [])
|
|
@@ -193,20 +176,31 @@ async def extract_edges(
|
|
|
193
176
|
edges = []
|
|
194
177
|
for edge_data in edges_data:
|
|
195
178
|
# Validate Edge Date information
|
|
196
|
-
valid_at = edge_data.
|
|
197
|
-
invalid_at = edge_data.
|
|
179
|
+
valid_at = edge_data.valid_at
|
|
180
|
+
invalid_at = edge_data.invalid_at
|
|
198
181
|
valid_at_datetime = None
|
|
199
182
|
invalid_at_datetime = None
|
|
200
183
|
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
184
|
+
# Filter out empty edges
|
|
185
|
+
if not edge_data.fact.strip():
|
|
186
|
+
continue
|
|
187
|
+
|
|
188
|
+
source_node_idx = edge_data.source_entity_id
|
|
189
|
+
target_node_idx = edge_data.target_entity_id
|
|
190
|
+
|
|
191
|
+
if len(nodes) == 0:
|
|
192
|
+
logger.warning('No entities provided for edge extraction')
|
|
193
|
+
continue
|
|
194
|
+
|
|
195
|
+
if not (0 <= source_node_idx < len(nodes) and 0 <= target_node_idx < len(nodes)):
|
|
204
196
|
logger.warning(
|
|
205
|
-
f'
|
|
197
|
+
f'Invalid entity IDs in edge extraction for {edge_data.relation_type}. '
|
|
198
|
+
f'source_entity_id: {source_node_idx}, target_entity_id: {target_node_idx}, '
|
|
199
|
+
f'but only {len(nodes)} entities available (valid range: 0-{len(nodes) - 1})'
|
|
206
200
|
)
|
|
207
201
|
continue
|
|
208
202
|
source_node_uuid = nodes[source_node_idx].uuid
|
|
209
|
-
target_node_uuid = nodes[
|
|
203
|
+
target_node_uuid = nodes[target_node_idx].uuid
|
|
210
204
|
|
|
211
205
|
if valid_at:
|
|
212
206
|
try:
|
|
@@ -226,9 +220,9 @@ async def extract_edges(
|
|
|
226
220
|
edge = EntityEdge(
|
|
227
221
|
source_node_uuid=source_node_uuid,
|
|
228
222
|
target_node_uuid=target_node_uuid,
|
|
229
|
-
name=edge_data.
|
|
223
|
+
name=edge_data.relation_type,
|
|
230
224
|
group_id=group_id,
|
|
231
|
-
fact=edge_data.
|
|
225
|
+
fact=edge_data.fact,
|
|
232
226
|
episodes=[episode.uuid],
|
|
233
227
|
created_at=utc_now(),
|
|
234
228
|
valid_at=valid_at_datetime,
|
|
@@ -249,20 +243,68 @@ async def resolve_extracted_edges(
|
|
|
249
243
|
extracted_edges: list[EntityEdge],
|
|
250
244
|
episode: EpisodicNode,
|
|
251
245
|
entities: list[EntityNode],
|
|
252
|
-
edge_types: dict[str, BaseModel],
|
|
246
|
+
edge_types: dict[str, type[BaseModel]],
|
|
253
247
|
edge_type_map: dict[tuple[str, str], list[str]],
|
|
254
248
|
) -> tuple[list[EntityEdge], list[EntityEdge]]:
|
|
249
|
+
# Fast path: deduplicate exact matches within the extracted edges before parallel processing
|
|
250
|
+
seen: dict[tuple[str, str, str], EntityEdge] = {}
|
|
251
|
+
deduplicated_edges: list[EntityEdge] = []
|
|
252
|
+
|
|
253
|
+
for edge in extracted_edges:
|
|
254
|
+
key = (
|
|
255
|
+
edge.source_node_uuid,
|
|
256
|
+
edge.target_node_uuid,
|
|
257
|
+
_normalize_string_exact(edge.fact),
|
|
258
|
+
)
|
|
259
|
+
if key not in seen:
|
|
260
|
+
seen[key] = edge
|
|
261
|
+
deduplicated_edges.append(edge)
|
|
262
|
+
|
|
263
|
+
extracted_edges = deduplicated_edges
|
|
264
|
+
|
|
255
265
|
driver = clients.driver
|
|
256
266
|
llm_client = clients.llm_client
|
|
257
267
|
embedder = clients.embedder
|
|
258
268
|
await create_entity_edge_embeddings(embedder, extracted_edges)
|
|
259
269
|
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
270
|
+
valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
|
|
271
|
+
*[
|
|
272
|
+
EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
|
|
273
|
+
for edge in extracted_edges
|
|
274
|
+
]
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
related_edges_results: list[SearchResults] = await semaphore_gather(
|
|
278
|
+
*[
|
|
279
|
+
search(
|
|
280
|
+
clients,
|
|
281
|
+
extracted_edge.fact,
|
|
282
|
+
group_ids=[extracted_edge.group_id],
|
|
283
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
284
|
+
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
|
|
285
|
+
)
|
|
286
|
+
for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
|
|
287
|
+
]
|
|
263
288
|
)
|
|
264
289
|
|
|
265
|
-
related_edges_lists
|
|
290
|
+
related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
|
|
291
|
+
|
|
292
|
+
edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
|
|
293
|
+
*[
|
|
294
|
+
search(
|
|
295
|
+
clients,
|
|
296
|
+
extracted_edge.fact,
|
|
297
|
+
group_ids=[extracted_edge.group_id],
|
|
298
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
299
|
+
search_filter=SearchFilters(),
|
|
300
|
+
)
|
|
301
|
+
for extracted_edge in extracted_edges
|
|
302
|
+
]
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
edge_invalidation_candidates: list[list[EntityEdge]] = [
|
|
306
|
+
result.edges for result in edge_invalidation_candidate_results
|
|
307
|
+
]
|
|
266
308
|
|
|
267
309
|
logger.debug(
|
|
268
310
|
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
|
|
@@ -271,11 +313,21 @@ async def resolve_extracted_edges(
|
|
|
271
313
|
# Build entity hash table
|
|
272
314
|
uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
|
|
273
315
|
|
|
274
|
-
# Determine which edge types are relevant for each edge
|
|
275
|
-
edge_types_lst
|
|
316
|
+
# Determine which edge types are relevant for each edge.
|
|
317
|
+
# `edge_types_lst` stores the subset of custom edge definitions whose
|
|
318
|
+
# node signature matches each extracted edge. Anything outside this subset
|
|
319
|
+
# should only stay on the edge if it is a non-custom (LLM generated) label.
|
|
320
|
+
edge_types_lst: list[dict[str, type[BaseModel]]] = []
|
|
321
|
+
custom_type_names = set(edge_types or {})
|
|
276
322
|
for extracted_edge in extracted_edges:
|
|
277
|
-
|
|
278
|
-
|
|
323
|
+
source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
|
|
324
|
+
target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
|
|
325
|
+
source_node_labels = (
|
|
326
|
+
source_node.labels + ['Entity'] if source_node is not None else ['Entity']
|
|
327
|
+
)
|
|
328
|
+
target_node_labels = (
|
|
329
|
+
target_node.labels + ['Entity'] if target_node is not None else ['Entity']
|
|
330
|
+
)
|
|
279
331
|
label_tuples = [
|
|
280
332
|
(source_label, target_label)
|
|
281
333
|
for source_label in source_node_labels
|
|
@@ -294,6 +346,20 @@ async def resolve_extracted_edges(
|
|
|
294
346
|
|
|
295
347
|
edge_types_lst.append(extracted_edge_types)
|
|
296
348
|
|
|
349
|
+
for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
|
|
350
|
+
allowed_type_names = set(extracted_edge_types)
|
|
351
|
+
is_custom_name = extracted_edge.name in custom_type_names
|
|
352
|
+
if not allowed_type_names:
|
|
353
|
+
# No custom types are valid for this node pairing. Keep LLM generated
|
|
354
|
+
# labels, but flip disallowed custom names back to the default.
|
|
355
|
+
if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
|
|
356
|
+
extracted_edge.name = DEFAULT_EDGE_NAME
|
|
357
|
+
continue
|
|
358
|
+
if is_custom_name and extracted_edge.name not in allowed_type_names:
|
|
359
|
+
# Custom name exists but it is not permitted for this source/target
|
|
360
|
+
# signature, so fall back to the default edge label.
|
|
361
|
+
extracted_edge.name = DEFAULT_EDGE_NAME
|
|
362
|
+
|
|
297
363
|
# resolve edges with related edges in the graph and find invalidation candidates
|
|
298
364
|
results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
|
|
299
365
|
await semaphore_gather(
|
|
@@ -305,6 +371,7 @@ async def resolve_extracted_edges(
|
|
|
305
371
|
existing_edges,
|
|
306
372
|
episode,
|
|
307
373
|
extracted_edge_types,
|
|
374
|
+
custom_type_names,
|
|
308
375
|
)
|
|
309
376
|
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
|
|
310
377
|
extracted_edges,
|
|
@@ -346,21 +413,26 @@ def resolve_edge_contradictions(
|
|
|
346
413
|
invalidated_edges: list[EntityEdge] = []
|
|
347
414
|
for edge in invalidation_candidates:
|
|
348
415
|
# (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
|
|
416
|
+
edge_invalid_at_utc = ensure_utc(edge.invalid_at)
|
|
417
|
+
resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
|
|
418
|
+
edge_valid_at_utc = ensure_utc(edge.valid_at)
|
|
419
|
+
resolved_edge_invalid_at_utc = ensure_utc(resolved_edge.invalid_at)
|
|
420
|
+
|
|
349
421
|
if (
|
|
350
|
-
|
|
351
|
-
and
|
|
352
|
-
and
|
|
422
|
+
edge_invalid_at_utc is not None
|
|
423
|
+
and resolved_edge_valid_at_utc is not None
|
|
424
|
+
and edge_invalid_at_utc <= resolved_edge_valid_at_utc
|
|
353
425
|
) or (
|
|
354
|
-
|
|
355
|
-
and
|
|
356
|
-
and
|
|
426
|
+
edge_valid_at_utc is not None
|
|
427
|
+
and resolved_edge_invalid_at_utc is not None
|
|
428
|
+
and resolved_edge_invalid_at_utc <= edge_valid_at_utc
|
|
357
429
|
):
|
|
358
430
|
continue
|
|
359
431
|
# New edge invalidates edge
|
|
360
432
|
elif (
|
|
361
|
-
|
|
362
|
-
and
|
|
363
|
-
and
|
|
433
|
+
edge_valid_at_utc is not None
|
|
434
|
+
and resolved_edge_valid_at_utc is not None
|
|
435
|
+
and edge_valid_at_utc < resolved_edge_valid_at_utc
|
|
364
436
|
):
|
|
365
437
|
edge.invalid_at = resolved_edge.valid_at
|
|
366
438
|
edge.expired_at = edge.expired_at if edge.expired_at is not None else utc_now()
|
|
@@ -375,32 +447,69 @@ async def resolve_extracted_edge(
|
|
|
375
447
|
related_edges: list[EntityEdge],
|
|
376
448
|
existing_edges: list[EntityEdge],
|
|
377
449
|
episode: EpisodicNode,
|
|
378
|
-
|
|
450
|
+
edge_type_candidates: dict[str, type[BaseModel]] | None = None,
|
|
451
|
+
custom_edge_type_names: set[str] | None = None,
|
|
379
452
|
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
|
|
453
|
+
"""Resolve an extracted edge against existing graph context.
|
|
454
|
+
|
|
455
|
+
Parameters
|
|
456
|
+
----------
|
|
457
|
+
llm_client : LLMClient
|
|
458
|
+
Client used to invoke the LLM for deduplication and attribute extraction.
|
|
459
|
+
extracted_edge : EntityEdge
|
|
460
|
+
Newly extracted edge whose canonical representation is being resolved.
|
|
461
|
+
related_edges : list[EntityEdge]
|
|
462
|
+
Candidate edges with identical endpoints used for duplicate detection.
|
|
463
|
+
existing_edges : list[EntityEdge]
|
|
464
|
+
Broader set of edges evaluated for contradiction / invalidation.
|
|
465
|
+
episode : EpisodicNode
|
|
466
|
+
Episode providing content context when extracting edge attributes.
|
|
467
|
+
edge_type_candidates : dict[str, type[BaseModel]] | None
|
|
468
|
+
Custom edge types permitted for the current source/target signature.
|
|
469
|
+
custom_edge_type_names : set[str] | None
|
|
470
|
+
Full catalog of registered custom edge names. Used to distinguish
|
|
471
|
+
between disallowed custom types (which fall back to the default label)
|
|
472
|
+
and ad-hoc labels emitted by the LLM.
|
|
473
|
+
|
|
474
|
+
Returns
|
|
475
|
+
-------
|
|
476
|
+
tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
|
|
477
|
+
The resolved edge, any duplicates, and edges to invalidate.
|
|
478
|
+
"""
|
|
380
479
|
if len(related_edges) == 0 and len(existing_edges) == 0:
|
|
381
480
|
return extracted_edge, [], []
|
|
382
481
|
|
|
482
|
+
# Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
|
|
483
|
+
normalized_fact = _normalize_string_exact(extracted_edge.fact)
|
|
484
|
+
for edge in related_edges:
|
|
485
|
+
if (
|
|
486
|
+
edge.source_node_uuid == extracted_edge.source_node_uuid
|
|
487
|
+
and edge.target_node_uuid == extracted_edge.target_node_uuid
|
|
488
|
+
and _normalize_string_exact(edge.fact) == normalized_fact
|
|
489
|
+
):
|
|
490
|
+
resolved = edge
|
|
491
|
+
if episode is not None and episode.uuid not in resolved.episodes:
|
|
492
|
+
resolved.episodes.append(episode.uuid)
|
|
493
|
+
return resolved, [], []
|
|
494
|
+
|
|
383
495
|
start = time()
|
|
384
496
|
|
|
385
497
|
# Prepare context for LLM
|
|
386
|
-
related_edges_context = [
|
|
387
|
-
{'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
|
|
388
|
-
]
|
|
498
|
+
related_edges_context = [{'idx': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)]
|
|
389
499
|
|
|
390
500
|
invalidation_edge_candidates_context = [
|
|
391
|
-
{'
|
|
501
|
+
{'idx': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
|
|
392
502
|
]
|
|
393
503
|
|
|
394
504
|
edge_types_context = (
|
|
395
505
|
[
|
|
396
506
|
{
|
|
397
|
-
'fact_type_id': i,
|
|
398
507
|
'fact_type_name': type_name,
|
|
399
508
|
'fact_type_description': type_model.__doc__,
|
|
400
509
|
}
|
|
401
|
-
for
|
|
510
|
+
for type_name, type_model in edge_type_candidates.items()
|
|
402
511
|
]
|
|
403
|
-
if
|
|
512
|
+
if edge_type_candidates is not None
|
|
404
513
|
else []
|
|
405
514
|
)
|
|
406
515
|
|
|
@@ -411,15 +520,34 @@ async def resolve_extracted_edge(
|
|
|
411
520
|
'edge_types': edge_types_context,
|
|
412
521
|
}
|
|
413
522
|
|
|
523
|
+
if related_edges or existing_edges:
|
|
524
|
+
logger.debug(
|
|
525
|
+
'Resolving edge: sent %d EXISTING FACTS%s and %d INVALIDATION CANDIDATES%s',
|
|
526
|
+
len(related_edges),
|
|
527
|
+
f' (idx 0-{len(related_edges) - 1})' if related_edges else '',
|
|
528
|
+
len(existing_edges),
|
|
529
|
+
f' (idx 0-{len(existing_edges) - 1})' if existing_edges else '',
|
|
530
|
+
)
|
|
531
|
+
|
|
414
532
|
llm_response = await llm_client.generate_response(
|
|
415
533
|
prompt_library.dedupe_edges.resolve_edge(context),
|
|
416
534
|
response_model=EdgeDuplicate,
|
|
417
535
|
model_size=ModelSize.small,
|
|
536
|
+
prompt_name='dedupe_edges.resolve_edge',
|
|
418
537
|
)
|
|
538
|
+
response_object = EdgeDuplicate(**llm_response)
|
|
539
|
+
duplicate_facts = response_object.duplicate_facts
|
|
540
|
+
|
|
541
|
+
# Validate duplicate_facts are in valid range for EXISTING FACTS
|
|
542
|
+
invalid_duplicates = [i for i in duplicate_facts if i < 0 or i >= len(related_edges)]
|
|
543
|
+
if invalid_duplicates:
|
|
544
|
+
logger.warning(
|
|
545
|
+
'LLM returned invalid duplicate_facts idx values %s (valid range: 0-%d for EXISTING FACTS)',
|
|
546
|
+
invalid_duplicates,
|
|
547
|
+
len(related_edges) - 1,
|
|
548
|
+
)
|
|
419
549
|
|
|
420
|
-
duplicate_fact_ids: list[int] =
|
|
421
|
-
filter(lambda i: 0 <= i < len(related_edges), llm_response.get('duplicate_facts', []))
|
|
422
|
-
)
|
|
550
|
+
duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]
|
|
423
551
|
|
|
424
552
|
resolved_edge = extracted_edge
|
|
425
553
|
for duplicate_fact_id in duplicate_fact_ids:
|
|
@@ -429,12 +557,32 @@ async def resolve_extracted_edge(
|
|
|
429
557
|
if duplicate_fact_ids and episode is not None:
|
|
430
558
|
resolved_edge.episodes.append(episode.uuid)
|
|
431
559
|
|
|
432
|
-
contradicted_facts: list[int] =
|
|
560
|
+
contradicted_facts: list[int] = response_object.contradicted_facts
|
|
433
561
|
|
|
434
|
-
|
|
562
|
+
# Validate contradicted_facts are in valid range for INVALIDATION CANDIDATES
|
|
563
|
+
invalid_contradictions = [i for i in contradicted_facts if i < 0 or i >= len(existing_edges)]
|
|
564
|
+
if invalid_contradictions:
|
|
565
|
+
logger.warning(
|
|
566
|
+
'LLM returned invalid contradicted_facts idx values %s (valid range: 0-%d for INVALIDATION CANDIDATES)',
|
|
567
|
+
invalid_contradictions,
|
|
568
|
+
len(existing_edges) - 1,
|
|
569
|
+
)
|
|
435
570
|
|
|
436
|
-
|
|
437
|
-
|
|
571
|
+
invalidation_candidates: list[EntityEdge] = [
|
|
572
|
+
existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
|
|
573
|
+
]
|
|
574
|
+
|
|
575
|
+
fact_type: str = response_object.fact_type
|
|
576
|
+
candidate_type_names = set(edge_type_candidates or {})
|
|
577
|
+
custom_type_names = custom_edge_type_names or set()
|
|
578
|
+
|
|
579
|
+
is_default_type = fact_type.upper() == 'DEFAULT'
|
|
580
|
+
is_custom_type = fact_type in custom_type_names
|
|
581
|
+
is_allowed_custom_type = fact_type in candidate_type_names
|
|
582
|
+
|
|
583
|
+
if is_allowed_custom_type:
|
|
584
|
+
# The LLM selected a custom type that is allowed for the node pair.
|
|
585
|
+
# Adopt the custom type and, if needed, extract its structured attributes.
|
|
438
586
|
resolved_edge.name = fact_type
|
|
439
587
|
|
|
440
588
|
edge_attributes_context = {
|
|
@@ -443,15 +591,26 @@ async def resolve_extracted_edge(
|
|
|
443
591
|
'fact': resolved_edge.fact,
|
|
444
592
|
}
|
|
445
593
|
|
|
446
|
-
edge_model =
|
|
594
|
+
edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
|
|
447
595
|
if edge_model is not None and len(edge_model.model_fields) != 0:
|
|
448
596
|
edge_attributes_response = await llm_client.generate_response(
|
|
449
597
|
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
|
|
450
598
|
response_model=edge_model, # type: ignore
|
|
451
599
|
model_size=ModelSize.small,
|
|
600
|
+
prompt_name='extract_edges.extract_attributes',
|
|
452
601
|
)
|
|
453
602
|
|
|
454
603
|
resolved_edge.attributes = edge_attributes_response
|
|
604
|
+
elif not is_default_type and is_custom_type:
|
|
605
|
+
# The LLM picked a custom type that is not allowed for this signature.
|
|
606
|
+
# Reset to the default label and drop any structured attributes.
|
|
607
|
+
resolved_edge.name = DEFAULT_EDGE_NAME
|
|
608
|
+
resolved_edge.attributes = {}
|
|
609
|
+
elif not is_default_type:
|
|
610
|
+
# Non-custom labels are allowed to pass through so long as the LLM does
|
|
611
|
+
# not return the sentinel DEFAULT value.
|
|
612
|
+
resolved_edge.name = fact_type
|
|
613
|
+
resolved_edge.attributes = {}
|
|
455
614
|
|
|
456
615
|
end = time()
|
|
457
616
|
logger.debug(
|
|
@@ -465,14 +624,14 @@ async def resolve_extracted_edge(
|
|
|
465
624
|
|
|
466
625
|
# Determine if the new_edge needs to be expired
|
|
467
626
|
if resolved_edge.expired_at is None:
|
|
468
|
-
invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
|
|
627
|
+
invalidation_candidates.sort(key=lambda c: (c.valid_at is None, ensure_utc(c.valid_at)))
|
|
469
628
|
for candidate in invalidation_candidates:
|
|
629
|
+
candidate_valid_at_utc = ensure_utc(candidate.valid_at)
|
|
630
|
+
resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
|
|
470
631
|
if (
|
|
471
|
-
|
|
472
|
-
and
|
|
473
|
-
and
|
|
474
|
-
and resolved_edge.valid_at.tzinfo
|
|
475
|
-
and candidate.valid_at > resolved_edge.valid_at
|
|
632
|
+
candidate_valid_at_utc is not None
|
|
633
|
+
and resolved_edge_valid_at_utc is not None
|
|
634
|
+
and candidate_valid_at_utc > resolved_edge_valid_at_utc
|
|
476
635
|
):
|
|
477
636
|
# Expire new edge since we have information about more recent events
|
|
478
637
|
resolved_edge.invalid_at = candidate.valid_at
|
|
@@ -488,59 +647,60 @@ async def resolve_extracted_edge(
|
|
|
488
647
|
return resolved_edge, invalidated_edges, duplicate_edges
|
|
489
648
|
|
|
490
649
|
|
|
491
|
-
async def dedupe_edge_list(
|
|
492
|
-
llm_client: LLMClient,
|
|
493
|
-
edges: list[EntityEdge],
|
|
494
|
-
) -> list[EntityEdge]:
|
|
495
|
-
start = time()
|
|
496
|
-
|
|
497
|
-
# Create edge map
|
|
498
|
-
edge_map = {}
|
|
499
|
-
for edge in edges:
|
|
500
|
-
edge_map[edge.uuid] = edge
|
|
501
|
-
|
|
502
|
-
# Prepare context for LLM
|
|
503
|
-
context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}
|
|
504
|
-
|
|
505
|
-
llm_response = await llm_client.generate_response(
|
|
506
|
-
prompt_library.dedupe_edges.edge_list(context), response_model=UniqueFacts
|
|
507
|
-
)
|
|
508
|
-
unique_edges_data = llm_response.get('unique_facts', [])
|
|
509
|
-
|
|
510
|
-
end = time()
|
|
511
|
-
logger.debug(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
|
|
512
|
-
|
|
513
|
-
# Get full edge data
|
|
514
|
-
unique_edges = []
|
|
515
|
-
for edge_data in unique_edges_data:
|
|
516
|
-
uuid = edge_data['uuid']
|
|
517
|
-
edge = edge_map[uuid]
|
|
518
|
-
edge.fact = edge_data['fact']
|
|
519
|
-
unique_edges.append(edge)
|
|
520
|
-
|
|
521
|
-
return unique_edges
|
|
522
|
-
|
|
523
|
-
|
|
524
650
|
async def filter_existing_duplicate_of_edges(
|
|
525
651
|
driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
|
|
526
652
|
) -> list[tuple[EntityNode, EntityNode]]:
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
|
|
530
|
-
RETURN DISTINCT
|
|
531
|
-
n.uuid AS source_uuid,
|
|
532
|
-
m.uuid AS target_uuid
|
|
533
|
-
"""
|
|
653
|
+
if not duplicates_node_tuples:
|
|
654
|
+
return []
|
|
534
655
|
|
|
535
656
|
duplicate_nodes_map = {
|
|
536
657
|
(source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
|
|
537
658
|
}
|
|
538
659
|
|
|
539
|
-
|
|
540
|
-
query
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
660
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
661
|
+
query: LiteralString = """
|
|
662
|
+
UNWIND $duplicate_node_uuids AS duplicate_tuple
|
|
663
|
+
MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
|
|
664
|
+
RETURN DISTINCT
|
|
665
|
+
n.uuid AS source_uuid,
|
|
666
|
+
m.uuid AS target_uuid
|
|
667
|
+
"""
|
|
668
|
+
|
|
669
|
+
duplicate_nodes = [
|
|
670
|
+
{'source': source.uuid, 'target': target.uuid}
|
|
671
|
+
for source, target in duplicates_node_tuples
|
|
672
|
+
]
|
|
673
|
+
|
|
674
|
+
records, _, _ = await driver.execute_query(
|
|
675
|
+
query,
|
|
676
|
+
duplicate_node_uuids=duplicate_nodes,
|
|
677
|
+
routing_='r',
|
|
678
|
+
)
|
|
679
|
+
else:
|
|
680
|
+
if driver.provider == GraphProvider.KUZU:
|
|
681
|
+
query = """
|
|
682
|
+
UNWIND $duplicate_node_uuids AS duplicate
|
|
683
|
+
MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
|
|
684
|
+
RETURN DISTINCT
|
|
685
|
+
n.uuid AS source_uuid,
|
|
686
|
+
m.uuid AS target_uuid
|
|
687
|
+
"""
|
|
688
|
+
duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
|
|
689
|
+
else:
|
|
690
|
+
query: LiteralString = """
|
|
691
|
+
UNWIND $duplicate_node_uuids AS duplicate_tuple
|
|
692
|
+
MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
|
|
693
|
+
RETURN DISTINCT
|
|
694
|
+
n.uuid AS source_uuid,
|
|
695
|
+
m.uuid AS target_uuid
|
|
696
|
+
"""
|
|
697
|
+
duplicate_node_uuids = list(duplicate_nodes_map.keys())
|
|
698
|
+
|
|
699
|
+
records, _, _ = await driver.execute_query(
|
|
700
|
+
query,
|
|
701
|
+
duplicate_node_uuids=duplicate_node_uuids,
|
|
702
|
+
routing_='r',
|
|
703
|
+
)
|
|
544
704
|
|
|
545
705
|
# Remove duplicates that already have the IS_DUPLICATE_OF edge
|
|
546
706
|
for record in records:
|