graphiti-core 0.17.4__py3-none-any.whl → 0.25.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (59)
  1. graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
  2. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  3. graphiti_core/decorators.py +110 -0
  4. graphiti_core/driver/driver.py +62 -2
  5. graphiti_core/driver/falkordb_driver.py +215 -23
  6. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  7. graphiti_core/driver/kuzu_driver.py +182 -0
  8. graphiti_core/driver/neo4j_driver.py +70 -8
  9. graphiti_core/driver/neptune_driver.py +305 -0
  10. graphiti_core/driver/search_interface/search_interface.py +89 -0
  11. graphiti_core/edges.py +264 -132
  12. graphiti_core/embedder/azure_openai.py +10 -3
  13. graphiti_core/embedder/client.py +2 -1
  14. graphiti_core/graph_queries.py +114 -101
  15. graphiti_core/graphiti.py +635 -260
  16. graphiti_core/graphiti_types.py +2 -0
  17. graphiti_core/helpers.py +37 -15
  18. graphiti_core/llm_client/anthropic_client.py +142 -52
  19. graphiti_core/llm_client/azure_openai_client.py +57 -19
  20. graphiti_core/llm_client/client.py +83 -21
  21. graphiti_core/llm_client/config.py +1 -1
  22. graphiti_core/llm_client/gemini_client.py +75 -57
  23. graphiti_core/llm_client/openai_base_client.py +92 -48
  24. graphiti_core/llm_client/openai_client.py +39 -9
  25. graphiti_core/llm_client/openai_generic_client.py +91 -56
  26. graphiti_core/models/edges/edge_db_queries.py +259 -35
  27. graphiti_core/models/nodes/node_db_queries.py +311 -32
  28. graphiti_core/nodes.py +388 -164
  29. graphiti_core/prompts/dedupe_edges.py +42 -31
  30. graphiti_core/prompts/dedupe_nodes.py +56 -39
  31. graphiti_core/prompts/eval.py +4 -4
  32. graphiti_core/prompts/extract_edges.py +24 -15
  33. graphiti_core/prompts/extract_nodes.py +76 -35
  34. graphiti_core/prompts/prompt_helpers.py +39 -0
  35. graphiti_core/prompts/snippets.py +29 -0
  36. graphiti_core/prompts/summarize_nodes.py +23 -25
  37. graphiti_core/search/search.py +154 -74
  38. graphiti_core/search/search_config.py +39 -4
  39. graphiti_core/search/search_filters.py +110 -31
  40. graphiti_core/search/search_helpers.py +5 -6
  41. graphiti_core/search/search_utils.py +1360 -473
  42. graphiti_core/tracer.py +193 -0
  43. graphiti_core/utils/bulk_utils.py +216 -90
  44. graphiti_core/utils/content_chunking.py +702 -0
  45. graphiti_core/utils/datetime_utils.py +13 -0
  46. graphiti_core/utils/maintenance/community_operations.py +62 -38
  47. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  48. graphiti_core/utils/maintenance/edge_operations.py +306 -156
  49. graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
  50. graphiti_core/utils/maintenance/node_operations.py +466 -206
  51. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  52. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  53. graphiti_core/utils/text_utils.py +53 -0
  54. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/METADATA +221 -87
  55. graphiti_core-0.25.3.dist-info/RECORD +87 -0
  56. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/WHEEL +1 -1
  57. graphiti_core-0.17.4.dist-info/RECORD +0 -77
  58. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  59. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/utils/maintenance/edge_operations.py
@@ -21,7 +21,7 @@ from time import time
 from pydantic import BaseModel
 from typing_extensions import LiteralString
 
-from graphiti_core.driver.driver import GraphDriver
+from graphiti_core.driver.driver import GraphDriver, GraphProvider
 from graphiti_core.edges import (
     CommunityEdge,
     EntityEdge,
@@ -29,16 +29,21 @@ from graphiti_core.edges import (
     create_entity_edge_embeddings,
 )
 from graphiti_core.graphiti_types import GraphitiClients
-from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
+from graphiti_core.helpers import semaphore_gather
 from graphiti_core.llm_client import LLMClient
 from graphiti_core.llm_client.config import ModelSize
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
 from graphiti_core.prompts import prompt_library
-from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
-from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
+from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
+from graphiti_core.prompts.extract_edges import ExtractedEdges
+from graphiti_core.search.search import search
+from graphiti_core.search.search_config import SearchResults
+from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
 from graphiti_core.search.search_filters import SearchFilters
-from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
 from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
+from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
+
+DEFAULT_EDGE_NAME = 'RELATES_TO'
 
 logger = logging.getLogger(__name__)
 
@@ -63,32 +68,6 @@ def build_episodic_edges(
     return episodic_edges
 
 
-def build_duplicate_of_edges(
-    episode: EpisodicNode,
-    created_at: datetime,
-    duplicate_nodes: list[tuple[EntityNode, EntityNode]],
-) -> list[EntityEdge]:
-    is_duplicate_of_edges: list[EntityEdge] = []
-    for source_node, target_node in duplicate_nodes:
-        if source_node.uuid == target_node.uuid:
-            continue
-
-        is_duplicate_of_edges.append(
-            EntityEdge(
-                source_node_uuid=source_node.uuid,
-                target_node_uuid=target_node.uuid,
-                name='IS_DUPLICATE_OF',
-                group_id=episode.group_id,
-                fact=f'{source_node.name} is a duplicate of {target_node.name}',
-                episodes=[episode.uuid],
-                created_at=created_at,
-                valid_at=created_at,
-            )
-        )
-
-    return is_duplicate_of_edges
-
-
 def build_community_edges(
     entity_nodes: list[EntityNode],
     community_node: CommunityNode,
@@ -114,7 +93,8 @@ async def extract_edges(
     previous_episodes: list[EpisodicNode],
     edge_type_map: dict[tuple[str, str], list[str]],
     group_id: str = '',
-    edge_types: dict[str, BaseModel] | None = None,
+    edge_types: dict[str, type[BaseModel]] | None = None,
+    custom_extraction_instructions: str | None = None,
 ) -> list[EntityEdge]:
     start = time()
 
@@ -150,38 +130,17 @@ async def extract_edges(
         'previous_episodes': [ep.content for ep in previous_episodes],
         'reference_time': episode.valid_at,
         'edge_types': edge_types_context,
-        'custom_prompt': '',
+        'custom_extraction_instructions': custom_extraction_instructions or '',
     }
 
-    facts_missed = True
-    reflexion_iterations = 0
-    while facts_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
-        llm_response = await llm_client.generate_response(
-            prompt_library.extract_edges.edge(context),
-            response_model=ExtractedEdges,
-            max_tokens=extract_edges_max_tokens,
-        )
-        edges_data = llm_response.get('edges', [])
-
-        context['extracted_facts'] = [edge_data.get('fact', '') for edge_data in edges_data]
-
-        reflexion_iterations += 1
-        if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
-            reflexion_response = await llm_client.generate_response(
-                prompt_library.extract_edges.reflexion(context),
-                response_model=MissingFacts,
-                max_tokens=extract_edges_max_tokens,
-            )
-
-            missing_facts = reflexion_response.get('missing_facts', [])
-
-            custom_prompt = 'The following facts were missed in a previous extraction: '
-            for fact in missing_facts:
-                custom_prompt += f'\n{fact},'
-
-            context['custom_prompt'] = custom_prompt
-
-            facts_missed = len(missing_facts) != 0
+    llm_response = await llm_client.generate_response(
+        prompt_library.extract_edges.edge(context),
+        response_model=ExtractedEdges,
+        max_tokens=extract_edges_max_tokens,
+        group_id=group_id,
+        prompt_name='extract_edges.edge',
+    )
+    edges_data = ExtractedEdges(**llm_response).edges
 
     end = time()
     logger.debug(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')
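
Note: the 0.17.x reflexion loop (re-prompting for facts flagged as missed, up to MAX_REFLEXION_ITERATIONS) is gone; 0.25.x issues a single structured-output call and validates the raw response dict through the ExtractedEdges model. A standalone sketch of that validation step, using a simplified hypothetical model rather than the real one defined in graphiti_core.prompts.extract_edges:

from pydantic import BaseModel


class Edge(BaseModel):
    # Simplified stand-in for the fields this diff reads off each edge entry.
    relation_type: str
    source_entity_id: int
    target_entity_id: int
    fact: str
    valid_at: str | None = None
    invalid_at: str | None = None


class ExtractedEdges(BaseModel):
    edges: list[Edge]


llm_response = {
    'edges': [
        {
            'relation_type': 'WORKS_AT',
            'source_entity_id': 0,
            'target_entity_id': 1,
            'fact': 'Alice works at Acme',
        }
    ]
}

# Same pattern as the new code: validate the dict once, then work with typed objects.
edges_data = ExtractedEdges(**llm_response).edges
assert edges_data[0].relation_type == 'WORKS_AT'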
@@ -193,20 +152,31 @@ async def extract_edges(
     edges = []
     for edge_data in edges_data:
         # Validate Edge Date information
-        valid_at = edge_data.get('valid_at', None)
-        invalid_at = edge_data.get('invalid_at', None)
+        valid_at = edge_data.valid_at
+        invalid_at = edge_data.invalid_at
         valid_at_datetime = None
        invalid_at_datetime = None
 
-        source_node_idx = edge_data.get('source_entity_id', -1)
-        target_node_idx = edge_data.get('target_entity_id', -1)
-        if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
+        # Filter out empty edges
+        if not edge_data.fact.strip():
+            continue
+
+        source_node_idx = edge_data.source_entity_id
+        target_node_idx = edge_data.target_entity_id
+
+        if len(nodes) == 0:
+            logger.warning('No entities provided for edge extraction')
+            continue
+
+        if not (0 <= source_node_idx < len(nodes) and 0 <= target_node_idx < len(nodes)):
             logger.warning(
-                f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
+                f'Invalid entity IDs in edge extraction for {edge_data.relation_type}. '
+                f'source_entity_id: {source_node_idx}, target_entity_id: {target_node_idx}, '
+                f'but only {len(nodes)} entities available (valid range: 0-{len(nodes) - 1})'
             )
             continue
         source_node_uuid = nodes[source_node_idx].uuid
-        target_node_uuid = nodes[edge_data.get('target_entity_id')].uuid
+        target_node_uuid = nodes[target_node_idx].uuid
 
         if valid_at:
             try:
@@ -226,9 +196,9 @@ async def extract_edges(
         edge = EntityEdge(
             source_node_uuid=source_node_uuid,
             target_node_uuid=target_node_uuid,
-            name=edge_data.get('relation_type', ''),
+            name=edge_data.relation_type,
             group_id=group_id,
-            fact=edge_data.get('fact', ''),
+            fact=edge_data.fact,
             episodes=[episode.uuid],
             created_at=utc_now(),
             valid_at=valid_at_datetime,
@@ -249,20 +219,68 @@ async def resolve_extracted_edges(
     extracted_edges: list[EntityEdge],
     episode: EpisodicNode,
     entities: list[EntityNode],
-    edge_types: dict[str, BaseModel],
+    edge_types: dict[str, type[BaseModel]],
     edge_type_map: dict[tuple[str, str], list[str]],
 ) -> tuple[list[EntityEdge], list[EntityEdge]]:
+    # Fast path: deduplicate exact matches within the extracted edges before parallel processing
+    seen: dict[tuple[str, str, str], EntityEdge] = {}
+    deduplicated_edges: list[EntityEdge] = []
+
+    for edge in extracted_edges:
+        key = (
+            edge.source_node_uuid,
+            edge.target_node_uuid,
+            _normalize_string_exact(edge.fact),
+        )
+        if key not in seen:
+            seen[key] = edge
+            deduplicated_edges.append(edge)
+
+    extracted_edges = deduplicated_edges
+
     driver = clients.driver
     llm_client = clients.llm_client
     embedder = clients.embedder
     await create_entity_edge_embeddings(embedder, extracted_edges)
 
-    search_results = await semaphore_gather(
-        get_relevant_edges(driver, extracted_edges, SearchFilters()),
-        get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
+    valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
+        *[
+            EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
+            for edge in extracted_edges
+        ]
     )
 
-    related_edges_lists, edge_invalidation_candidates = search_results
+    related_edges_results: list[SearchResults] = await semaphore_gather(
+        *[
+            search(
+                clients,
+                extracted_edge.fact,
+                group_ids=[extracted_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
+            )
+            for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
+        ]
+    )
+
+    related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
+
+    edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
+        *[
+            search(
+                clients,
+                extracted_edge.fact,
+                group_ids=[extracted_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(),
+            )
+            for extracted_edge in extracted_edges
+        ]
+    )
+
+    edge_invalidation_candidates: list[list[EntityEdge]] = [
+        result.edges for result in edge_invalidation_candidate_results
+    ]
 
     logger.debug(
         f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
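
Note: resolve_extracted_edges no longer calls get_relevant_edges / get_edge_invalidation_candidates; it now pre-deduplicates exact matches and then runs hybrid searches per edge. The exact-match fast path only needs a normalized key per edge. A minimal standalone sketch of that idea (the normalization shown is an assumption; the real _normalize_string_exact lives in graphiti_core.utils.maintenance.dedup_helpers and may differ in detail):

from dataclasses import dataclass, field
from uuid import uuid4


@dataclass
class Edge:
    source_node_uuid: str
    target_node_uuid: str
    fact: str
    uuid: str = field(default_factory=lambda: str(uuid4()))


def normalize_exact(text: str) -> str:
    # Assumed normalization: collapse whitespace and lowercase.
    return ' '.join(text.split()).lower()


def dedupe_exact(edges: list[Edge]) -> list[Edge]:
    # Keep the first edge seen for each (source, target, normalized fact) key.
    seen: dict[tuple[str, str, str], Edge] = {}
    deduplicated: list[Edge] = []
    for edge in edges:
        key = (edge.source_node_uuid, edge.target_node_uuid, normalize_exact(edge.fact))
        if key not in seen:
            seen[key] = edge
            deduplicated.append(edge)
    return deduplicated


edges = [
    Edge('a', 'b', 'Alice works at Acme'),
    Edge('a', 'b', 'alice works at  Acme'),  # duplicate after normalization
    Edge('a', 'c', 'Alice works at Acme'),   # different endpoints, kept
]
assert len(dedupe_exact(edges)) == 2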
@@ -271,11 +289,35 @@ async def resolve_extracted_edges(
     # Build entity hash table
     uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
 
-    # Determine which edge types are relevant for each edge
-    edge_types_lst: list[dict[str, BaseModel]] = []
+    # Collect all node UUIDs referenced by edges that are not in the entities list
+    referenced_node_uuids = set()
+    for extracted_edge in extracted_edges:
+        if extracted_edge.source_node_uuid not in uuid_entity_map:
+            referenced_node_uuids.add(extracted_edge.source_node_uuid)
+        if extracted_edge.target_node_uuid not in uuid_entity_map:
+            referenced_node_uuids.add(extracted_edge.target_node_uuid)
+
+    # Fetch missing nodes from the database
+    if referenced_node_uuids:
+        missing_nodes = await EntityNode.get_by_uuids(driver, list(referenced_node_uuids))
+        for node in missing_nodes:
+            uuid_entity_map[node.uuid] = node
+
+    # Determine which edge types are relevant for each edge.
+    # `edge_types_lst` stores the subset of custom edge definitions whose
+    # node signature matches each extracted edge. Anything outside this subset
+    # should only stay on the edge if it is a non-custom (LLM generated) label.
+    edge_types_lst: list[dict[str, type[BaseModel]]] = []
+    custom_type_names = set(edge_types or {})
     for extracted_edge in extracted_edges:
-        source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels + ['Entity']
-        target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels + ['Entity']
+        source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
+        target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
+        source_node_labels = (
+            source_node.labels + ['Entity'] if source_node is not None else ['Entity']
+        )
+        target_node_labels = (
+            target_node.labels + ['Entity'] if target_node is not None else ['Entity']
+        )
         label_tuples = [
             (source_label, target_label)
             for source_label in source_node_labels
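
Note: endpoints can now be missing from the in-memory entities list, so labels fall back to ['Entity'] instead of raising a KeyError, and the missing nodes are fetched by UUID. The label pairs built above feed the custom-type selection; the following standalone illustration shows how an edge_type_map keyed by (source_label, target_label) narrows the registered types for one edge (data shapes are illustrative, not the library's internals):

# Illustrative inputs: a caller-provided edge_type_map plus two resolved nodes' labels.
edge_type_map = {
    ('Person', 'Company'): ['WORKS_AT'],
    ('Entity', 'Entity'): ['RELATES_TO'],
}
registered_edge_types = {'WORKS_AT', 'FOUNDED'}

source_node_labels = ['Person', 'Entity']  # node labels plus the implicit 'Entity'
target_node_labels = ['Entity']            # unresolved node: falls back to just 'Entity'

label_tuples = [
    (source_label, target_label)
    for source_label in source_node_labels
    for target_label in target_node_labels
]

# Custom types allowed for any matching label pair.
allowed = {
    name
    for label_tuple in label_tuples
    for name in edge_type_map.get(label_tuple, [])
    if name in registered_edge_types
}
# With the target unresolved, only the ('Entity', 'Entity') signature matches,
# so no custom type is allowed and the edge keeps (or falls back to) the default label.
assert allowed == set()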
@@ -294,6 +336,20 @@ async def resolve_extracted_edges(
 
         edge_types_lst.append(extracted_edge_types)
 
+    for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
+        allowed_type_names = set(extracted_edge_types)
+        is_custom_name = extracted_edge.name in custom_type_names
+        if not allowed_type_names:
+            # No custom types are valid for this node pairing. Keep LLM generated
+            # labels, but flip disallowed custom names back to the default.
+            if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
+                extracted_edge.name = DEFAULT_EDGE_NAME
+            continue
+        if is_custom_name and extracted_edge.name not in allowed_type_names:
+            # Custom name exists but it is not permitted for this source/target
+            # signature, so fall back to the default edge label.
+            extracted_edge.name = DEFAULT_EDGE_NAME
+
     # resolve edges with related edges in the graph and find invalidation candidates
     results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
         await semaphore_gather(
@@ -305,6 +361,7 @@ async def resolve_extracted_edges(
                     existing_edges,
                     episode,
                     extracted_edge_types,
+                    custom_type_names,
                 )
                 for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
                     extracted_edges,
@@ -346,21 +403,26 @@ def resolve_edge_contradictions(
     invalidated_edges: list[EntityEdge] = []
     for edge in invalidation_candidates:
         # (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
+        edge_invalid_at_utc = ensure_utc(edge.invalid_at)
+        resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
+        edge_valid_at_utc = ensure_utc(edge.valid_at)
+        resolved_edge_invalid_at_utc = ensure_utc(resolved_edge.invalid_at)
+
         if (
-            edge.invalid_at is not None
-            and resolved_edge.valid_at is not None
-            and edge.invalid_at <= resolved_edge.valid_at
+            edge_invalid_at_utc is not None
+            and resolved_edge_valid_at_utc is not None
+            and edge_invalid_at_utc <= resolved_edge_valid_at_utc
         ) or (
-            edge.valid_at is not None
-            and resolved_edge.invalid_at is not None
-            and resolved_edge.invalid_at <= edge.valid_at
+            edge_valid_at_utc is not None
+            and resolved_edge_invalid_at_utc is not None
+            and resolved_edge_invalid_at_utc <= edge_valid_at_utc
         ):
             continue
         # New edge invalidates edge
         elif (
-            edge.valid_at is not None
-            and resolved_edge.valid_at is not None
-            and edge.valid_at < resolved_edge.valid_at
+            edge_valid_at_utc is not None
+            and resolved_edge_valid_at_utc is not None
+            and edge_valid_at_utc < resolved_edge_valid_at_utc
         ):
             edge.invalid_at = resolved_edge.valid_at
             edge.expired_at = edge.expired_at if edge.expired_at is not None else utc_now()
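
Note: the contradiction checks now pass both timestamps through ensure_utc before comparing, because Python refuses to order naive and aware datetimes. A minimal illustration with a stand-in ensure_utc (assumed semantics: naive values are treated as UTC, aware values are converted):

from datetime import datetime, timezone


def ensure_utc(dt: datetime | None) -> datetime | None:
    # Stand-in with assumed semantics; the real helper lives in graphiti_core.utils.datetime_utils.
    if dt is None:
        return None
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(timezone.utc)


naive = datetime(2024, 1, 1, 12, 0)
aware = datetime(2024, 1, 1, 13, 0, tzinfo=timezone.utc)

try:
    _ = naive < aware  # mixing naive and aware datetimes raises
except TypeError:
    pass

assert ensure_utc(naive) < ensure_utc(aware)  # comparison is safe after normalization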
@@ -375,32 +437,69 @@ async def resolve_extracted_edge(
     related_edges: list[EntityEdge],
     existing_edges: list[EntityEdge],
     episode: EpisodicNode,
-    edge_types: dict[str, BaseModel] | None = None,
+    edge_type_candidates: dict[str, type[BaseModel]] | None = None,
+    custom_edge_type_names: set[str] | None = None,
 ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
+    """Resolve an extracted edge against existing graph context.
+
+    Parameters
+    ----------
+    llm_client : LLMClient
+        Client used to invoke the LLM for deduplication and attribute extraction.
+    extracted_edge : EntityEdge
+        Newly extracted edge whose canonical representation is being resolved.
+    related_edges : list[EntityEdge]
+        Candidate edges with identical endpoints used for duplicate detection.
+    existing_edges : list[EntityEdge]
+        Broader set of edges evaluated for contradiction / invalidation.
+    episode : EpisodicNode
+        Episode providing content context when extracting edge attributes.
+    edge_type_candidates : dict[str, type[BaseModel]] | None
+        Custom edge types permitted for the current source/target signature.
+    custom_edge_type_names : set[str] | None
+        Full catalog of registered custom edge names. Used to distinguish
+        between disallowed custom types (which fall back to the default label)
+        and ad-hoc labels emitted by the LLM.
+
+    Returns
+    -------
+    tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
+        The resolved edge, any duplicates, and edges to invalidate.
+    """
     if len(related_edges) == 0 and len(existing_edges) == 0:
         return extracted_edge, [], []
 
+    # Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
+    normalized_fact = _normalize_string_exact(extracted_edge.fact)
+    for edge in related_edges:
+        if (
+            edge.source_node_uuid == extracted_edge.source_node_uuid
+            and edge.target_node_uuid == extracted_edge.target_node_uuid
+            and _normalize_string_exact(edge.fact) == normalized_fact
+        ):
+            resolved = edge
+            if episode is not None and episode.uuid not in resolved.episodes:
+                resolved.episodes.append(episode.uuid)
+            return resolved, [], []
+
     start = time()
 
     # Prepare context for LLM
-    related_edges_context = [
-        {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
-    ]
+    related_edges_context = [{'idx': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)]
 
     invalidation_edge_candidates_context = [
-        {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
+        {'idx': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
     ]
 
     edge_types_context = (
         [
             {
-                'fact_type_id': i,
                 'fact_type_name': type_name,
                 'fact_type_description': type_model.__doc__,
             }
-            for i, (type_name, type_model) in enumerate(edge_types.items())
+            for type_name, type_model in edge_type_candidates.items()
         ]
-        if edge_types is not None
+        if edge_type_candidates is not None
         else []
     )
 
@@ -411,15 +510,34 @@ async def resolve_extracted_edge(
         'edge_types': edge_types_context,
     }
 
+    if related_edges or existing_edges:
+        logger.debug(
+            'Resolving edge: sent %d EXISTING FACTS%s and %d INVALIDATION CANDIDATES%s',
+            len(related_edges),
+            f' (idx 0-{len(related_edges) - 1})' if related_edges else '',
+            len(existing_edges),
+            f' (idx 0-{len(existing_edges) - 1})' if existing_edges else '',
+        )
+
     llm_response = await llm_client.generate_response(
         prompt_library.dedupe_edges.resolve_edge(context),
         response_model=EdgeDuplicate,
         model_size=ModelSize.small,
+        prompt_name='dedupe_edges.resolve_edge',
     )
+    response_object = EdgeDuplicate(**llm_response)
+    duplicate_facts = response_object.duplicate_facts
+
+    # Validate duplicate_facts are in valid range for EXISTING FACTS
+    invalid_duplicates = [i for i in duplicate_facts if i < 0 or i >= len(related_edges)]
+    if invalid_duplicates:
+        logger.warning(
+            'LLM returned invalid duplicate_facts idx values %s (valid range: 0-%d for EXISTING FACTS)',
+            invalid_duplicates,
+            len(related_edges) - 1,
+        )
 
-    duplicate_fact_ids: list[int] = list(
-        filter(lambda i: 0 <= i < len(related_edges), llm_response.get('duplicate_facts', []))
-    )
+    duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]
 
     resolved_edge = extracted_edge
     for duplicate_fact_id in duplicate_fact_ids:
@@ -429,12 +547,32 @@ async def resolve_extracted_edge(
     if duplicate_fact_ids and episode is not None:
         resolved_edge.episodes.append(episode.uuid)
 
-    contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
+    contradicted_facts: list[int] = response_object.contradicted_facts
+
+    # Validate contradicted_facts are in valid range for INVALIDATION CANDIDATES
+    invalid_contradictions = [i for i in contradicted_facts if i < 0 or i >= len(existing_edges)]
+    if invalid_contradictions:
+        logger.warning(
+            'LLM returned invalid contradicted_facts idx values %s (valid range: 0-%d for INVALIDATION CANDIDATES)',
+            invalid_contradictions,
+            len(existing_edges) - 1,
+        )
+
+    invalidation_candidates: list[EntityEdge] = [
+        existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
+    ]
+
+    fact_type: str = response_object.fact_type
+    candidate_type_names = set(edge_type_candidates or {})
+    custom_type_names = custom_edge_type_names or set()
 
-    invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
+    is_default_type = fact_type.upper() == 'DEFAULT'
+    is_custom_type = fact_type in custom_type_names
+    is_allowed_custom_type = fact_type in candidate_type_names
 
-    fact_type: str = str(llm_response.get('fact_type'))
-    if fact_type.upper() != 'DEFAULT' and edge_types is not None:
+    if is_allowed_custom_type:
+        # The LLM selected a custom type that is allowed for the node pair.
+        # Adopt the custom type and, if needed, extract its structured attributes.
         resolved_edge.name = fact_type
 
         edge_attributes_context = {
@@ -443,15 +581,26 @@ async def resolve_extracted_edge(
             'fact': resolved_edge.fact,
         }
 
-        edge_model = edge_types.get(fact_type)
+        edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
         if edge_model is not None and len(edge_model.model_fields) != 0:
             edge_attributes_response = await llm_client.generate_response(
                 prompt_library.extract_edges.extract_attributes(edge_attributes_context),
                 response_model=edge_model,  # type: ignore
                 model_size=ModelSize.small,
+                prompt_name='extract_edges.extract_attributes',
             )
 
             resolved_edge.attributes = edge_attributes_response
+    elif not is_default_type and is_custom_type:
+        # The LLM picked a custom type that is not allowed for this signature.
+        # Reset to the default label and drop any structured attributes.
+        resolved_edge.name = DEFAULT_EDGE_NAME
+        resolved_edge.attributes = {}
+    elif not is_default_type:
+        # Non-custom labels are allowed to pass through so long as the LLM does
+        # not return the sentinel DEFAULT value.
+        resolved_edge.name = fact_type
+        resolved_edge.attributes = {}
 
     end = time()
     logger.debug(
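
Note: the fact_type returned by the LLM is now resolved against two sets: the types allowed for this source/target signature and the full catalog of registered custom types. A standalone sketch of that decision (not the library function itself):

DEFAULT_EDGE_NAME = 'RELATES_TO'


def resolve_fact_type(
    current_name: str, fact_type: str, allowed: set[str], registered: set[str]
) -> tuple[str, bool]:
    # Returns (edge name to keep, whether structured attributes may be extracted).
    if fact_type in allowed:
        return fact_type, True           # allowed custom type: adopt it
    if fact_type.upper() == 'DEFAULT':
        return current_name, False       # sentinel: leave the name untouched
    if fact_type in registered:
        return DEFAULT_EDGE_NAME, False  # disallowed custom type: fall back to the default
    return fact_type, False              # ad-hoc LLM label passes through


assert resolve_fact_type('RELATES_TO', 'WORKS_AT', {'WORKS_AT'}, {'WORKS_AT', 'FOUNDED'}) == ('WORKS_AT', True)
assert resolve_fact_type('RELATES_TO', 'FOUNDED', {'WORKS_AT'}, {'WORKS_AT', 'FOUNDED'}) == ('RELATES_TO', False)
assert resolve_fact_type('MENTORS', 'DEFAULT', set(), {'WORKS_AT'}) == ('MENTORS', False)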
@@ -465,14 +614,14 @@ async def resolve_extracted_edge(
 
     # Determine if the new_edge needs to be expired
     if resolved_edge.expired_at is None:
-        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
+        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, ensure_utc(c.valid_at)))
        for candidate in invalidation_candidates:
+            candidate_valid_at_utc = ensure_utc(candidate.valid_at)
+            resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
             if (
-                candidate.valid_at
-                and resolved_edge.valid_at
-                and candidate.valid_at.tzinfo
-                and resolved_edge.valid_at.tzinfo
-                and candidate.valid_at > resolved_edge.valid_at
+                candidate_valid_at_utc is not None
+                and resolved_edge_valid_at_utc is not None
+                and candidate_valid_at_utc > resolved_edge_valid_at_utc
             ):
                 # Expire new edge since we have information about more recent events
                 resolved_edge.invalid_at = candidate.valid_at
@@ -488,59 +637,60 @@ async def resolve_extracted_edge(
     return resolved_edge, invalidated_edges, duplicate_edges
 
 
-async def dedupe_edge_list(
-    llm_client: LLMClient,
-    edges: list[EntityEdge],
-) -> list[EntityEdge]:
-    start = time()
-
-    # Create edge map
-    edge_map = {}
-    for edge in edges:
-        edge_map[edge.uuid] = edge
-
-    # Prepare context for LLM
-    context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}
-
-    llm_response = await llm_client.generate_response(
-        prompt_library.dedupe_edges.edge_list(context), response_model=UniqueFacts
-    )
-    unique_edges_data = llm_response.get('unique_facts', [])
-
-    end = time()
-    logger.debug(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
-
-    # Get full edge data
-    unique_edges = []
-    for edge_data in unique_edges_data:
-        uuid = edge_data['uuid']
-        edge = edge_map[uuid]
-        edge.fact = edge_data['fact']
-        unique_edges.append(edge)
-
-    return unique_edges
-
-
 async def filter_existing_duplicate_of_edges(
     driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
 ) -> list[tuple[EntityNode, EntityNode]]:
-    query: LiteralString = """
-        UNWIND $duplicate_node_uuids AS duplicate_tuple
-        MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
-        RETURN DISTINCT
-            n.uuid AS source_uuid,
-            m.uuid AS target_uuid
-    """
+    if not duplicates_node_tuples:
+        return []
 
     duplicate_nodes_map = {
         (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
     }
 
-    records, _, _ = await driver.execute_query(
-        query,
-        duplicate_node_uuids=list(duplicate_nodes_map.keys()),
-        routing_='r',
-    )
+    if driver.provider == GraphProvider.NEPTUNE:
+        query: LiteralString = """
+            UNWIND $duplicate_node_uuids AS duplicate_tuple
+            MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
+            RETURN DISTINCT
+                n.uuid AS source_uuid,
+                m.uuid AS target_uuid
+        """
+
+        duplicate_nodes = [
+            {'source': source.uuid, 'target': target.uuid}
+            for source, target in duplicates_node_tuples
+        ]
+
+        records, _, _ = await driver.execute_query(
+            query,
+            duplicate_node_uuids=duplicate_nodes,
+            routing_='r',
+        )
+    else:
+        if driver.provider == GraphProvider.KUZU:
+            query = """
+                UNWIND $duplicate_node_uuids AS duplicate
+                MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
+                RETURN DISTINCT
+                    n.uuid AS source_uuid,
+                    m.uuid AS target_uuid
+            """
+            duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
+        else:
+            query: LiteralString = """
+                UNWIND $duplicate_node_uuids AS duplicate_tuple
+                MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
+                RETURN DISTINCT
+                    n.uuid AS source_uuid,
+                    m.uuid AS target_uuid
+            """
+            duplicate_node_uuids = list(duplicate_nodes_map.keys())
+
+        records, _, _ = await driver.execute_query(
+            query,
+            duplicate_node_uuids=duplicate_node_uuids,
+            routing_='r',
+        )
 
     # Remove duplicates that already have the IS_DUPLICATE_OF edge
     for record in records:
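
Note: filter_existing_duplicate_of_edges now branches on driver.provider because each backend expects the UNWIND parameter in a different shape: Neptune takes {'source', 'target'} dicts, Kuzu takes {'src', 'dst'} dicts against its reified RelatesToNode_ edge node, and other providers keep the original (source_uuid, target_uuid) tuples. A standalone sketch of just the payload construction (the enum values here are illustrative):

from enum import Enum


class GraphProvider(Enum):
    NEO4J = 'neo4j'
    FALKORDB = 'falkordb'
    KUZU = 'kuzu'
    NEPTUNE = 'neptune'


def duplicate_uuid_payload(provider: GraphProvider, pairs: list[tuple[str, str]]) -> list:
    # Shape the $duplicate_node_uuids parameter the way each query above expects it.
    if provider == GraphProvider.NEPTUNE:
        return [{'source': src, 'target': dst} for src, dst in pairs]
    if provider == GraphProvider.KUZU:
        return [{'src': src, 'dst': dst} for src, dst in pairs]
    return list(pairs)  # other providers: plain (source, target) tuples


pairs = [('uuid-a', 'uuid-b')]
assert duplicate_uuid_payload(GraphProvider.NEPTUNE, pairs) == [{'source': 'uuid-a', 'target': 'uuid-b'}]
assert duplicate_uuid_payload(GraphProvider.KUZU, pairs) == [{'src': 'uuid-a', 'dst': 'uuid-b'}]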