graphiti-core 0.17.4__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (58)
  1. graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
  2. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  3. graphiti_core/decorators.py +110 -0
  4. graphiti_core/driver/driver.py +62 -2
  5. graphiti_core/driver/falkordb_driver.py +215 -23
  6. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  7. graphiti_core/driver/kuzu_driver.py +182 -0
  8. graphiti_core/driver/neo4j_driver.py +61 -8
  9. graphiti_core/driver/neptune_driver.py +305 -0
  10. graphiti_core/driver/search_interface/search_interface.py +89 -0
  11. graphiti_core/edges.py +264 -132
  12. graphiti_core/embedder/azure_openai.py +10 -3
  13. graphiti_core/embedder/client.py +2 -1
  14. graphiti_core/graph_queries.py +114 -101
  15. graphiti_core/graphiti.py +582 -255
  16. graphiti_core/graphiti_types.py +2 -0
  17. graphiti_core/helpers.py +21 -14
  18. graphiti_core/llm_client/anthropic_client.py +142 -52
  19. graphiti_core/llm_client/azure_openai_client.py +57 -19
  20. graphiti_core/llm_client/client.py +83 -21
  21. graphiti_core/llm_client/config.py +1 -1
  22. graphiti_core/llm_client/gemini_client.py +75 -57
  23. graphiti_core/llm_client/openai_base_client.py +94 -50
  24. graphiti_core/llm_client/openai_client.py +28 -8
  25. graphiti_core/llm_client/openai_generic_client.py +91 -56
  26. graphiti_core/models/edges/edge_db_queries.py +259 -35
  27. graphiti_core/models/nodes/node_db_queries.py +311 -32
  28. graphiti_core/nodes.py +388 -164
  29. graphiti_core/prompts/dedupe_edges.py +42 -31
  30. graphiti_core/prompts/dedupe_nodes.py +56 -39
  31. graphiti_core/prompts/eval.py +4 -4
  32. graphiti_core/prompts/extract_edges.py +23 -14
  33. graphiti_core/prompts/extract_nodes.py +73 -32
  34. graphiti_core/prompts/prompt_helpers.py +39 -0
  35. graphiti_core/prompts/snippets.py +29 -0
  36. graphiti_core/prompts/summarize_nodes.py +23 -25
  37. graphiti_core/search/search.py +154 -74
  38. graphiti_core/search/search_config.py +39 -4
  39. graphiti_core/search/search_filters.py +109 -31
  40. graphiti_core/search/search_helpers.py +5 -6
  41. graphiti_core/search/search_utils.py +1360 -473
  42. graphiti_core/tracer.py +193 -0
  43. graphiti_core/utils/bulk_utils.py +216 -90
  44. graphiti_core/utils/datetime_utils.py +13 -0
  45. graphiti_core/utils/maintenance/community_operations.py +62 -38
  46. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  47. graphiti_core/utils/maintenance/edge_operations.py +286 -126
  48. graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
  49. graphiti_core/utils/maintenance/node_operations.py +320 -158
  50. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  51. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  52. graphiti_core/utils/text_utils.py +53 -0
  53. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/METADATA +221 -87
  54. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  55. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  56. graphiti_core-0.17.4.dist-info/RECORD +0 -77
  57. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  58. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/licenses/LICENSE +0 -0
@@ -21,7 +21,7 @@ from time import time
21
21
  from pydantic import BaseModel
22
22
  from typing_extensions import LiteralString
23
23
 
24
- from graphiti_core.driver.driver import GraphDriver
24
+ from graphiti_core.driver.driver import GraphDriver, GraphProvider
25
25
  from graphiti_core.edges import (
26
26
  CommunityEdge,
27
27
  EntityEdge,
@@ -34,11 +34,16 @@ from graphiti_core.llm_client import LLMClient
34
34
  from graphiti_core.llm_client.config import ModelSize
35
35
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
36
36
  from graphiti_core.prompts import prompt_library
37
- from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
37
+ from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
38
38
  from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
39
+ from graphiti_core.search.search import search
40
+ from graphiti_core.search.search_config import SearchResults
41
+ from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
39
42
  from graphiti_core.search.search_filters import SearchFilters
40
- from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
41
43
  from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
44
+ from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
45
+
46
+ DEFAULT_EDGE_NAME = 'RELATES_TO'
42
47
 
43
48
  logger = logging.getLogger(__name__)
44
49
 
@@ -63,32 +68,6 @@ def build_episodic_edges(
63
68
  return episodic_edges
64
69
 
65
70
 
66
- def build_duplicate_of_edges(
67
- episode: EpisodicNode,
68
- created_at: datetime,
69
- duplicate_nodes: list[tuple[EntityNode, EntityNode]],
70
- ) -> list[EntityEdge]:
71
- is_duplicate_of_edges: list[EntityEdge] = []
72
- for source_node, target_node in duplicate_nodes:
73
- if source_node.uuid == target_node.uuid:
74
- continue
75
-
76
- is_duplicate_of_edges.append(
77
- EntityEdge(
78
- source_node_uuid=source_node.uuid,
79
- target_node_uuid=target_node.uuid,
80
- name='IS_DUPLICATE_OF',
81
- group_id=episode.group_id,
82
- fact=f'{source_node.name} is a duplicate of {target_node.name}',
83
- episodes=[episode.uuid],
84
- created_at=created_at,
85
- valid_at=created_at,
86
- )
87
- )
88
-
89
- return is_duplicate_of_edges
90
-
91
-
92
71
  def build_community_edges(
93
72
  entity_nodes: list[EntityNode],
94
73
  community_node: CommunityNode,
@@ -114,7 +93,7 @@ async def extract_edges(
114
93
  previous_episodes: list[EpisodicNode],
115
94
  edge_type_map: dict[tuple[str, str], list[str]],
116
95
  group_id: str = '',
117
- edge_types: dict[str, BaseModel] | None = None,
96
+ edge_types: dict[str, type[BaseModel]] | None = None,
118
97
  ) -> list[EntityEdge]:
119
98
  start = time()
120
99
 
@@ -160,10 +139,12 @@ async def extract_edges(
160
139
  prompt_library.extract_edges.edge(context),
161
140
  response_model=ExtractedEdges,
162
141
  max_tokens=extract_edges_max_tokens,
142
+ group_id=group_id,
143
+ prompt_name='extract_edges.edge',
163
144
  )
164
- edges_data = llm_response.get('edges', [])
145
+ edges_data = ExtractedEdges(**llm_response).edges
165
146
 
166
- context['extracted_facts'] = [edge_data.get('fact', '') for edge_data in edges_data]
147
+ context['extracted_facts'] = [edge_data.fact for edge_data in edges_data]
167
148
 
168
149
  reflexion_iterations += 1
169
150
  if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
@@ -171,6 +152,8 @@ async def extract_edges(
171
152
  prompt_library.extract_edges.reflexion(context),
172
153
  response_model=MissingFacts,
173
154
  max_tokens=extract_edges_max_tokens,
155
+ group_id=group_id,
156
+ prompt_name='extract_edges.reflexion',
174
157
  )
175
158
 
176
159
  missing_facts = reflexion_response.get('missing_facts', [])
@@ -193,20 +176,31 @@ async def extract_edges(
193
176
  edges = []
194
177
  for edge_data in edges_data:
195
178
  # Validate Edge Date information
196
- valid_at = edge_data.get('valid_at', None)
197
- invalid_at = edge_data.get('invalid_at', None)
179
+ valid_at = edge_data.valid_at
180
+ invalid_at = edge_data.invalid_at
198
181
  valid_at_datetime = None
199
182
  invalid_at_datetime = None
200
183
 
201
- source_node_idx = edge_data.get('source_entity_id', -1)
202
- target_node_idx = edge_data.get('target_entity_id', -1)
203
- if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
184
+ # Filter out empty edges
185
+ if not edge_data.fact.strip():
186
+ continue
187
+
188
+ source_node_idx = edge_data.source_entity_id
189
+ target_node_idx = edge_data.target_entity_id
190
+
191
+ if len(nodes) == 0:
192
+ logger.warning('No entities provided for edge extraction')
193
+ continue
194
+
195
+ if not (0 <= source_node_idx < len(nodes) and 0 <= target_node_idx < len(nodes)):
204
196
  logger.warning(
205
- f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
197
+ f'Invalid entity IDs in edge extraction for {edge_data.relation_type}. '
198
+ f'source_entity_id: {source_node_idx}, target_entity_id: {target_node_idx}, '
199
+ f'but only {len(nodes)} entities available (valid range: 0-{len(nodes) - 1})'
206
200
  )
207
201
  continue
208
202
  source_node_uuid = nodes[source_node_idx].uuid
209
- target_node_uuid = nodes[edge_data.get('target_entity_id')].uuid
203
+ target_node_uuid = nodes[target_node_idx].uuid
210
204
 
211
205
  if valid_at:
212
206
  try:
@@ -226,9 +220,9 @@ async def extract_edges(
226
220
  edge = EntityEdge(
227
221
  source_node_uuid=source_node_uuid,
228
222
  target_node_uuid=target_node_uuid,
229
- name=edge_data.get('relation_type', ''),
223
+ name=edge_data.relation_type,
230
224
  group_id=group_id,
231
- fact=edge_data.get('fact', ''),
225
+ fact=edge_data.fact,
232
226
  episodes=[episode.uuid],
233
227
  created_at=utc_now(),
234
228
  valid_at=valid_at_datetime,
@@ -249,20 +243,68 @@ async def resolve_extracted_edges(
249
243
  extracted_edges: list[EntityEdge],
250
244
  episode: EpisodicNode,
251
245
  entities: list[EntityNode],
252
- edge_types: dict[str, BaseModel],
246
+ edge_types: dict[str, type[BaseModel]],
253
247
  edge_type_map: dict[tuple[str, str], list[str]],
254
248
  ) -> tuple[list[EntityEdge], list[EntityEdge]]:
249
+ # Fast path: deduplicate exact matches within the extracted edges before parallel processing
250
+ seen: dict[tuple[str, str, str], EntityEdge] = {}
251
+ deduplicated_edges: list[EntityEdge] = []
252
+
253
+ for edge in extracted_edges:
254
+ key = (
255
+ edge.source_node_uuid,
256
+ edge.target_node_uuid,
257
+ _normalize_string_exact(edge.fact),
258
+ )
259
+ if key not in seen:
260
+ seen[key] = edge
261
+ deduplicated_edges.append(edge)
262
+
263
+ extracted_edges = deduplicated_edges
264
+
255
265
  driver = clients.driver
256
266
  llm_client = clients.llm_client
257
267
  embedder = clients.embedder
258
268
  await create_entity_edge_embeddings(embedder, extracted_edges)
259
269
 
260
- search_results = await semaphore_gather(
261
- get_relevant_edges(driver, extracted_edges, SearchFilters()),
262
- get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
270
+ valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
271
+ *[
272
+ EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
273
+ for edge in extracted_edges
274
+ ]
275
+ )
276
+
277
+ related_edges_results: list[SearchResults] = await semaphore_gather(
278
+ *[
279
+ search(
280
+ clients,
281
+ extracted_edge.fact,
282
+ group_ids=[extracted_edge.group_id],
283
+ config=EDGE_HYBRID_SEARCH_RRF,
284
+ search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
285
+ )
286
+ for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
287
+ ]
263
288
  )
264
289
 
265
- related_edges_lists, edge_invalidation_candidates = search_results
290
+ related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
291
+
292
+ edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
293
+ *[
294
+ search(
295
+ clients,
296
+ extracted_edge.fact,
297
+ group_ids=[extracted_edge.group_id],
298
+ config=EDGE_HYBRID_SEARCH_RRF,
299
+ search_filter=SearchFilters(),
300
+ )
301
+ for extracted_edge in extracted_edges
302
+ ]
303
+ )
304
+
305
+ edge_invalidation_candidates: list[list[EntityEdge]] = [
306
+ result.edges for result in edge_invalidation_candidate_results
307
+ ]
266
308
 
267
309
  logger.debug(
268
310
  f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
@@ -271,11 +313,21 @@ async def resolve_extracted_edges(
271
313
  # Build entity hash table
272
314
  uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
273
315
 
274
- # Determine which edge types are relevant for each edge
275
- edge_types_lst: list[dict[str, BaseModel]] = []
316
+ # Determine which edge types are relevant for each edge.
317
+ # `edge_types_lst` stores the subset of custom edge definitions whose
318
+ # node signature matches each extracted edge. Anything outside this subset
319
+ # should only stay on the edge if it is a non-custom (LLM generated) label.
320
+ edge_types_lst: list[dict[str, type[BaseModel]]] = []
321
+ custom_type_names = set(edge_types or {})
276
322
  for extracted_edge in extracted_edges:
277
- source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels + ['Entity']
278
- target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels + ['Entity']
323
+ source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
324
+ target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
325
+ source_node_labels = (
326
+ source_node.labels + ['Entity'] if source_node is not None else ['Entity']
327
+ )
328
+ target_node_labels = (
329
+ target_node.labels + ['Entity'] if target_node is not None else ['Entity']
330
+ )
279
331
  label_tuples = [
280
332
  (source_label, target_label)
281
333
  for source_label in source_node_labels
@@ -294,6 +346,20 @@ async def resolve_extracted_edges(
294
346
 
295
347
  edge_types_lst.append(extracted_edge_types)
296
348
 
349
+ for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
350
+ allowed_type_names = set(extracted_edge_types)
351
+ is_custom_name = extracted_edge.name in custom_type_names
352
+ if not allowed_type_names:
353
+ # No custom types are valid for this node pairing. Keep LLM generated
354
+ # labels, but flip disallowed custom names back to the default.
355
+ if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
356
+ extracted_edge.name = DEFAULT_EDGE_NAME
357
+ continue
358
+ if is_custom_name and extracted_edge.name not in allowed_type_names:
359
+ # Custom name exists but it is not permitted for this source/target
360
+ # signature, so fall back to the default edge label.
361
+ extracted_edge.name = DEFAULT_EDGE_NAME
362
+
297
363
  # resolve edges with related edges in the graph and find invalidation candidates
298
364
  results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
299
365
  await semaphore_gather(
@@ -305,6 +371,7 @@ async def resolve_extracted_edges(
305
371
  existing_edges,
306
372
  episode,
307
373
  extracted_edge_types,
374
+ custom_type_names,
308
375
  )
309
376
  for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
310
377
  extracted_edges,
@@ -346,21 +413,26 @@ def resolve_edge_contradictions(
346
413
  invalidated_edges: list[EntityEdge] = []
347
414
  for edge in invalidation_candidates:
348
415
  # (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
416
+ edge_invalid_at_utc = ensure_utc(edge.invalid_at)
417
+ resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
418
+ edge_valid_at_utc = ensure_utc(edge.valid_at)
419
+ resolved_edge_invalid_at_utc = ensure_utc(resolved_edge.invalid_at)
420
+
349
421
  if (
350
- edge.invalid_at is not None
351
- and resolved_edge.valid_at is not None
352
- and edge.invalid_at <= resolved_edge.valid_at
422
+ edge_invalid_at_utc is not None
423
+ and resolved_edge_valid_at_utc is not None
424
+ and edge_invalid_at_utc <= resolved_edge_valid_at_utc
353
425
  ) or (
354
- edge.valid_at is not None
355
- and resolved_edge.invalid_at is not None
356
- and resolved_edge.invalid_at <= edge.valid_at
426
+ edge_valid_at_utc is not None
427
+ and resolved_edge_invalid_at_utc is not None
428
+ and resolved_edge_invalid_at_utc <= edge_valid_at_utc
357
429
  ):
358
430
  continue
359
431
  # New edge invalidates edge
360
432
  elif (
361
- edge.valid_at is not None
362
- and resolved_edge.valid_at is not None
363
- and edge.valid_at < resolved_edge.valid_at
433
+ edge_valid_at_utc is not None
434
+ and resolved_edge_valid_at_utc is not None
435
+ and edge_valid_at_utc < resolved_edge_valid_at_utc
364
436
  ):
365
437
  edge.invalid_at = resolved_edge.valid_at
366
438
  edge.expired_at = edge.expired_at if edge.expired_at is not None else utc_now()
@@ -375,32 +447,69 @@ async def resolve_extracted_edge(
375
447
  related_edges: list[EntityEdge],
376
448
  existing_edges: list[EntityEdge],
377
449
  episode: EpisodicNode,
378
- edge_types: dict[str, BaseModel] | None = None,
450
+ edge_type_candidates: dict[str, type[BaseModel]] | None = None,
451
+ custom_edge_type_names: set[str] | None = None,
379
452
  ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
453
+ """Resolve an extracted edge against existing graph context.
454
+
455
+ Parameters
456
+ ----------
457
+ llm_client : LLMClient
458
+ Client used to invoke the LLM for deduplication and attribute extraction.
459
+ extracted_edge : EntityEdge
460
+ Newly extracted edge whose canonical representation is being resolved.
461
+ related_edges : list[EntityEdge]
462
+ Candidate edges with identical endpoints used for duplicate detection.
463
+ existing_edges : list[EntityEdge]
464
+ Broader set of edges evaluated for contradiction / invalidation.
465
+ episode : EpisodicNode
466
+ Episode providing content context when extracting edge attributes.
467
+ edge_type_candidates : dict[str, type[BaseModel]] | None
468
+ Custom edge types permitted for the current source/target signature.
469
+ custom_edge_type_names : set[str] | None
470
+ Full catalog of registered custom edge names. Used to distinguish
471
+ between disallowed custom types (which fall back to the default label)
472
+ and ad-hoc labels emitted by the LLM.
473
+
474
+ Returns
475
+ -------
476
+ tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
477
+ The resolved edge, any duplicates, and edges to invalidate.
478
+ """
380
479
  if len(related_edges) == 0 and len(existing_edges) == 0:
381
480
  return extracted_edge, [], []
382
481
 
482
+ # Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
483
+ normalized_fact = _normalize_string_exact(extracted_edge.fact)
484
+ for edge in related_edges:
485
+ if (
486
+ edge.source_node_uuid == extracted_edge.source_node_uuid
487
+ and edge.target_node_uuid == extracted_edge.target_node_uuid
488
+ and _normalize_string_exact(edge.fact) == normalized_fact
489
+ ):
490
+ resolved = edge
491
+ if episode is not None and episode.uuid not in resolved.episodes:
492
+ resolved.episodes.append(episode.uuid)
493
+ return resolved, [], []
494
+
383
495
  start = time()
384
496
 
385
497
  # Prepare context for LLM
386
- related_edges_context = [
387
- {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
388
- ]
498
+ related_edges_context = [{'idx': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)]
389
499
 
390
500
  invalidation_edge_candidates_context = [
391
- {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
501
+ {'idx': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
392
502
  ]
393
503
 
394
504
  edge_types_context = (
395
505
  [
396
506
  {
397
- 'fact_type_id': i,
398
507
  'fact_type_name': type_name,
399
508
  'fact_type_description': type_model.__doc__,
400
509
  }
401
- for i, (type_name, type_model) in enumerate(edge_types.items())
510
+ for type_name, type_model in edge_type_candidates.items()
402
511
  ]
403
- if edge_types is not None
512
+ if edge_type_candidates is not None
404
513
  else []
405
514
  )
406
515
 
@@ -411,15 +520,34 @@ async def resolve_extracted_edge(
411
520
  'edge_types': edge_types_context,
412
521
  }
413
522
 
523
+ if related_edges or existing_edges:
524
+ logger.debug(
525
+ 'Resolving edge: sent %d EXISTING FACTS%s and %d INVALIDATION CANDIDATES%s',
526
+ len(related_edges),
527
+ f' (idx 0-{len(related_edges) - 1})' if related_edges else '',
528
+ len(existing_edges),
529
+ f' (idx 0-{len(existing_edges) - 1})' if existing_edges else '',
530
+ )
531
+
414
532
  llm_response = await llm_client.generate_response(
415
533
  prompt_library.dedupe_edges.resolve_edge(context),
416
534
  response_model=EdgeDuplicate,
417
535
  model_size=ModelSize.small,
536
+ prompt_name='dedupe_edges.resolve_edge',
418
537
  )
538
+ response_object = EdgeDuplicate(**llm_response)
539
+ duplicate_facts = response_object.duplicate_facts
540
+
541
+ # Validate duplicate_facts are in valid range for EXISTING FACTS
542
+ invalid_duplicates = [i for i in duplicate_facts if i < 0 or i >= len(related_edges)]
543
+ if invalid_duplicates:
544
+ logger.warning(
545
+ 'LLM returned invalid duplicate_facts idx values %s (valid range: 0-%d for EXISTING FACTS)',
546
+ invalid_duplicates,
547
+ len(related_edges) - 1,
548
+ )
419
549
 
420
- duplicate_fact_ids: list[int] = list(
421
- filter(lambda i: 0 <= i < len(related_edges), llm_response.get('duplicate_facts', []))
422
- )
550
+ duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]
423
551
 
424
552
  resolved_edge = extracted_edge
425
553
  for duplicate_fact_id in duplicate_fact_ids:
@@ -429,12 +557,32 @@ async def resolve_extracted_edge(
429
557
  if duplicate_fact_ids and episode is not None:
430
558
  resolved_edge.episodes.append(episode.uuid)
431
559
 
432
- contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
560
+ contradicted_facts: list[int] = response_object.contradicted_facts
433
561
 
434
- invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
562
+ # Validate contradicted_facts are in valid range for INVALIDATION CANDIDATES
563
+ invalid_contradictions = [i for i in contradicted_facts if i < 0 or i >= len(existing_edges)]
564
+ if invalid_contradictions:
565
+ logger.warning(
566
+ 'LLM returned invalid contradicted_facts idx values %s (valid range: 0-%d for INVALIDATION CANDIDATES)',
567
+ invalid_contradictions,
568
+ len(existing_edges) - 1,
569
+ )
435
570
 
436
- fact_type: str = str(llm_response.get('fact_type'))
437
- if fact_type.upper() != 'DEFAULT' and edge_types is not None:
571
+ invalidation_candidates: list[EntityEdge] = [
572
+ existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
573
+ ]
574
+
575
+ fact_type: str = response_object.fact_type
576
+ candidate_type_names = set(edge_type_candidates or {})
577
+ custom_type_names = custom_edge_type_names or set()
578
+
579
+ is_default_type = fact_type.upper() == 'DEFAULT'
580
+ is_custom_type = fact_type in custom_type_names
581
+ is_allowed_custom_type = fact_type in candidate_type_names
582
+
583
+ if is_allowed_custom_type:
584
+ # The LLM selected a custom type that is allowed for the node pair.
585
+ # Adopt the custom type and, if needed, extract its structured attributes.
438
586
  resolved_edge.name = fact_type
439
587
 
440
588
  edge_attributes_context = {
@@ -443,15 +591,26 @@ async def resolve_extracted_edge(
443
591
  'fact': resolved_edge.fact,
444
592
  }
445
593
 
446
- edge_model = edge_types.get(fact_type)
594
+ edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
447
595
  if edge_model is not None and len(edge_model.model_fields) != 0:
448
596
  edge_attributes_response = await llm_client.generate_response(
449
597
  prompt_library.extract_edges.extract_attributes(edge_attributes_context),
450
598
  response_model=edge_model, # type: ignore
451
599
  model_size=ModelSize.small,
600
+ prompt_name='extract_edges.extract_attributes',
452
601
  )
453
602
 
454
603
  resolved_edge.attributes = edge_attributes_response
604
+ elif not is_default_type and is_custom_type:
605
+ # The LLM picked a custom type that is not allowed for this signature.
606
+ # Reset to the default label and drop any structured attributes.
607
+ resolved_edge.name = DEFAULT_EDGE_NAME
608
+ resolved_edge.attributes = {}
609
+ elif not is_default_type:
610
+ # Non-custom labels are allowed to pass through so long as the LLM does
611
+ # not return the sentinel DEFAULT value.
612
+ resolved_edge.name = fact_type
613
+ resolved_edge.attributes = {}
455
614
 
456
615
  end = time()
457
616
  logger.debug(
@@ -465,14 +624,14 @@ async def resolve_extracted_edge(
465
624
 
466
625
  # Determine if the new_edge needs to be expired
467
626
  if resolved_edge.expired_at is None:
468
- invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
627
+ invalidation_candidates.sort(key=lambda c: (c.valid_at is None, ensure_utc(c.valid_at)))
469
628
  for candidate in invalidation_candidates:
629
+ candidate_valid_at_utc = ensure_utc(candidate.valid_at)
630
+ resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
470
631
  if (
471
- candidate.valid_at
472
- and resolved_edge.valid_at
473
- and candidate.valid_at.tzinfo
474
- and resolved_edge.valid_at.tzinfo
475
- and candidate.valid_at > resolved_edge.valid_at
632
+ candidate_valid_at_utc is not None
633
+ and resolved_edge_valid_at_utc is not None
634
+ and candidate_valid_at_utc > resolved_edge_valid_at_utc
476
635
  ):
477
636
  # Expire new edge since we have information about more recent events
478
637
  resolved_edge.invalid_at = candidate.valid_at
@@ -488,59 +647,60 @@ async def resolve_extracted_edge(
488
647
  return resolved_edge, invalidated_edges, duplicate_edges
489
648
 
490
649
 
491
- async def dedupe_edge_list(
492
- llm_client: LLMClient,
493
- edges: list[EntityEdge],
494
- ) -> list[EntityEdge]:
495
- start = time()
496
-
497
- # Create edge map
498
- edge_map = {}
499
- for edge in edges:
500
- edge_map[edge.uuid] = edge
501
-
502
- # Prepare context for LLM
503
- context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}
504
-
505
- llm_response = await llm_client.generate_response(
506
- prompt_library.dedupe_edges.edge_list(context), response_model=UniqueFacts
507
- )
508
- unique_edges_data = llm_response.get('unique_facts', [])
509
-
510
- end = time()
511
- logger.debug(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
512
-
513
- # Get full edge data
514
- unique_edges = []
515
- for edge_data in unique_edges_data:
516
- uuid = edge_data['uuid']
517
- edge = edge_map[uuid]
518
- edge.fact = edge_data['fact']
519
- unique_edges.append(edge)
520
-
521
- return unique_edges
522
-
523
-
524
650
  async def filter_existing_duplicate_of_edges(
525
651
  driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
526
652
  ) -> list[tuple[EntityNode, EntityNode]]:
527
- query: LiteralString = """
528
- UNWIND $duplicate_node_uuids AS duplicate_tuple
529
- MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
530
- RETURN DISTINCT
531
- n.uuid AS source_uuid,
532
- m.uuid AS target_uuid
533
- """
653
+ if not duplicates_node_tuples:
654
+ return []
534
655
 
535
656
  duplicate_nodes_map = {
536
657
  (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
537
658
  }
538
659
 
539
- records, _, _ = await driver.execute_query(
540
- query,
541
- duplicate_node_uuids=list(duplicate_nodes_map.keys()),
542
- routing_='r',
543
- )
660
+ if driver.provider == GraphProvider.NEPTUNE:
661
+ query: LiteralString = """
662
+ UNWIND $duplicate_node_uuids AS duplicate_tuple
663
+ MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
664
+ RETURN DISTINCT
665
+ n.uuid AS source_uuid,
666
+ m.uuid AS target_uuid
667
+ """
668
+
669
+ duplicate_nodes = [
670
+ {'source': source.uuid, 'target': target.uuid}
671
+ for source, target in duplicates_node_tuples
672
+ ]
673
+
674
+ records, _, _ = await driver.execute_query(
675
+ query,
676
+ duplicate_node_uuids=duplicate_nodes,
677
+ routing_='r',
678
+ )
679
+ else:
680
+ if driver.provider == GraphProvider.KUZU:
681
+ query = """
682
+ UNWIND $duplicate_node_uuids AS duplicate
683
+ MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
684
+ RETURN DISTINCT
685
+ n.uuid AS source_uuid,
686
+ m.uuid AS target_uuid
687
+ """
688
+ duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
689
+ else:
690
+ query: LiteralString = """
691
+ UNWIND $duplicate_node_uuids AS duplicate_tuple
692
+ MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
693
+ RETURN DISTINCT
694
+ n.uuid AS source_uuid,
695
+ m.uuid AS target_uuid
696
+ """
697
+ duplicate_node_uuids = list(duplicate_nodes_map.keys())
698
+
699
+ records, _, _ = await driver.execute_query(
700
+ query,
701
+ duplicate_node_uuids=duplicate_node_uuids,
702
+ routing_='r',
703
+ )
544
704
 
545
705
  # Remove duplicates that already have the IS_DUPLICATE_OF edge
546
706
  for record in records: