graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
  2. graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
  3. graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
  4. graphiti_core/decorators.py +110 -0
  5. graphiti_core/driver/__init__.py +19 -0
  6. graphiti_core/driver/driver.py +124 -0
  7. graphiti_core/driver/falkordb_driver.py +362 -0
  8. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  9. graphiti_core/driver/kuzu_driver.py +182 -0
  10. graphiti_core/driver/neo4j_driver.py +117 -0
  11. graphiti_core/driver/neptune_driver.py +305 -0
  12. graphiti_core/driver/search_interface/search_interface.py +89 -0
  13. graphiti_core/edges.py +287 -172
  14. graphiti_core/embedder/azure_openai.py +71 -0
  15. graphiti_core/embedder/client.py +2 -1
  16. graphiti_core/embedder/gemini.py +116 -22
  17. graphiti_core/embedder/voyage.py +13 -2
  18. graphiti_core/errors.py +8 -0
  19. graphiti_core/graph_queries.py +162 -0
  20. graphiti_core/graphiti.py +705 -193
  21. graphiti_core/graphiti_types.py +4 -2
  22. graphiti_core/helpers.py +87 -10
  23. graphiti_core/llm_client/__init__.py +16 -0
  24. graphiti_core/llm_client/anthropic_client.py +159 -56
  25. graphiti_core/llm_client/azure_openai_client.py +115 -0
  26. graphiti_core/llm_client/client.py +98 -21
  27. graphiti_core/llm_client/config.py +1 -1
  28. graphiti_core/llm_client/gemini_client.py +290 -41
  29. graphiti_core/llm_client/groq_client.py +14 -3
  30. graphiti_core/llm_client/openai_base_client.py +261 -0
  31. graphiti_core/llm_client/openai_client.py +56 -132
  32. graphiti_core/llm_client/openai_generic_client.py +91 -56
  33. graphiti_core/models/edges/edge_db_queries.py +259 -35
  34. graphiti_core/models/nodes/node_db_queries.py +311 -32
  35. graphiti_core/nodes.py +420 -205
  36. graphiti_core/prompts/dedupe_edges.py +46 -32
  37. graphiti_core/prompts/dedupe_nodes.py +67 -42
  38. graphiti_core/prompts/eval.py +4 -4
  39. graphiti_core/prompts/extract_edges.py +27 -16
  40. graphiti_core/prompts/extract_nodes.py +74 -31
  41. graphiti_core/prompts/prompt_helpers.py +39 -0
  42. graphiti_core/prompts/snippets.py +29 -0
  43. graphiti_core/prompts/summarize_nodes.py +23 -25
  44. graphiti_core/search/search.py +158 -82
  45. graphiti_core/search/search_config.py +39 -4
  46. graphiti_core/search/search_filters.py +126 -35
  47. graphiti_core/search/search_helpers.py +5 -6
  48. graphiti_core/search/search_utils.py +1405 -485
  49. graphiti_core/telemetry/__init__.py +9 -0
  50. graphiti_core/telemetry/telemetry.py +117 -0
  51. graphiti_core/tracer.py +193 -0
  52. graphiti_core/utils/bulk_utils.py +364 -285
  53. graphiti_core/utils/datetime_utils.py +13 -0
  54. graphiti_core/utils/maintenance/community_operations.py +67 -49
  55. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  56. graphiti_core/utils/maintenance/edge_operations.py +339 -197
  57. graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
  58. graphiti_core/utils/maintenance/node_operations.py +319 -238
  59. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  60. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  61. graphiti_core/utils/text_utils.py +53 -0
  62. graphiti_core-0.24.3.dist-info/METADATA +726 -0
  63. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  64. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  65. graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
  66. graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
  67. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  68. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
@@ -19,7 +19,9 @@ from datetime import datetime
19
19
  from time import time
20
20
 
21
21
  from pydantic import BaseModel
22
+ from typing_extensions import LiteralString
22
23
 
24
+ from graphiti_core.driver.driver import GraphDriver, GraphProvider
23
25
  from graphiti_core.edges import (
24
26
  CommunityEdge,
25
27
  EntityEdge,
@@ -32,26 +34,31 @@ from graphiti_core.llm_client import LLMClient
32
34
  from graphiti_core.llm_client.config import ModelSize
33
35
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
34
36
  from graphiti_core.prompts import prompt_library
35
- from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
37
+ from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
36
38
  from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
39
+ from graphiti_core.search.search import search
40
+ from graphiti_core.search.search_config import SearchResults
41
+ from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
37
42
  from graphiti_core.search.search_filters import SearchFilters
38
- from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
39
43
  from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
44
+ from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
45
+
46
+ DEFAULT_EDGE_NAME = 'RELATES_TO'
40
47
 
41
48
  logger = logging.getLogger(__name__)
42
49
 
43
50
 
44
51
  def build_episodic_edges(
45
52
  entity_nodes: list[EntityNode],
46
- episode: EpisodicNode,
53
+ episode_uuid: str,
47
54
  created_at: datetime,
48
55
  ) -> list[EpisodicEdge]:
49
56
  episodic_edges: list[EpisodicEdge] = [
50
57
  EpisodicEdge(
51
- source_node_uuid=episode.uuid,
58
+ source_node_uuid=episode_uuid,
52
59
  target_node_uuid=node.uuid,
53
60
  created_at=created_at,
54
- group_id=episode.group_id,
61
+ group_id=node.group_id,
55
62
  )
56
63
  for node in entity_nodes
57
64
  ]
@@ -84,20 +91,26 @@ async def extract_edges(
84
91
  episode: EpisodicNode,
85
92
  nodes: list[EntityNode],
86
93
  previous_episodes: list[EpisodicNode],
94
+ edge_type_map: dict[tuple[str, str], list[str]],
87
95
  group_id: str = '',
88
- edge_types: dict[str, BaseModel] | None = None,
96
+ edge_types: dict[str, type[BaseModel]] | None = None,
89
97
  ) -> list[EntityEdge]:
90
98
  start = time()
91
99
 
92
100
  extract_edges_max_tokens = 16384
93
101
  llm_client = clients.llm_client
94
102
 
95
- node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
103
+ edge_type_signature_map: dict[str, tuple[str, str]] = {
104
+ edge_type: signature
105
+ for signature, edge_types in edge_type_map.items()
106
+ for edge_type in edge_types
107
+ }
96
108
 
97
109
  edge_types_context = (
98
110
  [
99
111
  {
100
112
  'fact_type_name': type_name,
113
+ 'fact_type_signature': edge_type_signature_map.get(type_name, ('Entity', 'Entity')),
101
114
  'fact_type_description': type_model.__doc__,
102
115
  }
103
116
  for type_name, type_model in edge_types.items()
@@ -109,7 +122,10 @@ async def extract_edges(
109
122
  # Prepare context for LLM
110
123
  context = {
111
124
  'episode_content': episode.content,
112
- 'nodes': [node.name for node in nodes],
125
+ 'nodes': [
126
+ {'id': idx, 'name': node.name, 'entity_types': node.labels}
127
+ for idx, node in enumerate(nodes)
128
+ ],
113
129
  'previous_episodes': [ep.content for ep in previous_episodes],
114
130
  'reference_time': episode.valid_at,
115
131
  'edge_types': edge_types_context,
@@ -123,10 +139,12 @@ async def extract_edges(
123
139
  prompt_library.extract_edges.edge(context),
124
140
  response_model=ExtractedEdges,
125
141
  max_tokens=extract_edges_max_tokens,
142
+ group_id=group_id,
143
+ prompt_name='extract_edges.edge',
126
144
  )
127
- edges_data = llm_response.get('edges', [])
145
+ edges_data = ExtractedEdges(**llm_response).edges
128
146
 
129
- context['extracted_facts'] = [edge_data.get('fact', '') for edge_data in edges_data]
147
+ context['extracted_facts'] = [edge_data.fact for edge_data in edges_data]
130
148
 
131
149
  reflexion_iterations += 1
132
150
  if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
@@ -134,6 +152,8 @@ async def extract_edges(
134
152
  prompt_library.extract_edges.reflexion(context),
135
153
  response_model=MissingFacts,
136
154
  max_tokens=extract_edges_max_tokens,
155
+ group_id=group_id,
156
+ prompt_name='extract_edges.reflexion',
137
157
  )
138
158
 
139
159
  missing_facts = reflexion_response.get('missing_facts', [])
@@ -156,11 +176,32 @@ async def extract_edges(
156
176
  edges = []
157
177
  for edge_data in edges_data:
158
178
  # Validate Edge Date information
159
- valid_at = edge_data.get('valid_at', None)
160
- invalid_at = edge_data.get('invalid_at', None)
179
+ valid_at = edge_data.valid_at
180
+ invalid_at = edge_data.invalid_at
161
181
  valid_at_datetime = None
162
182
  invalid_at_datetime = None
163
183
 
184
+ # Filter out empty edges
185
+ if not edge_data.fact.strip():
186
+ continue
187
+
188
+ source_node_idx = edge_data.source_entity_id
189
+ target_node_idx = edge_data.target_entity_id
190
+
191
+ if len(nodes) == 0:
192
+ logger.warning('No entities provided for edge extraction')
193
+ continue
194
+
195
+ if not (0 <= source_node_idx < len(nodes) and 0 <= target_node_idx < len(nodes)):
196
+ logger.warning(
197
+ f'Invalid entity IDs in edge extraction for {edge_data.relation_type}. '
198
+ f'source_entity_id: {source_node_idx}, target_entity_id: {target_node_idx}, '
199
+ f'but only {len(nodes)} entities available (valid range: 0-{len(nodes) - 1})'
200
+ )
201
+ continue
202
+ source_node_uuid = nodes[source_node_idx].uuid
203
+ target_node_uuid = nodes[target_node_idx].uuid
204
+
164
205
  if valid_at:
165
206
  try:
166
207
  valid_at_datetime = ensure_utc(
@@ -177,15 +218,11 @@ async def extract_edges(
177
218
  except ValueError as e:
178
219
  logger.warning(f'WARNING: Error parsing invalid_at date: {e}. Input: {invalid_at}')
179
220
  edge = EntityEdge(
180
- source_node_uuid=node_uuids_by_name_map.get(
181
- edge_data.get('source_entity_name', ''), ''
182
- ),
183
- target_node_uuid=node_uuids_by_name_map.get(
184
- edge_data.get('target_entity_name', ''), ''
185
- ),
186
- name=edge_data.get('relation_type', ''),
221
+ source_node_uuid=source_node_uuid,
222
+ target_node_uuid=target_node_uuid,
223
+ name=edge_data.relation_type,
187
224
  group_id=group_id,
188
- fact=edge_data.get('fact', ''),
225
+ fact=edge_data.fact,
189
226
  episodes=[episode.uuid],
190
227
  created_at=utc_now(),
191
228
  valid_at=valid_at_datetime,
@@ -201,70 +238,73 @@ async def extract_edges(
201
238
  return edges
202
239
 
203
240
 
204
- async def dedupe_extracted_edges(
205
- llm_client: LLMClient,
206
- extracted_edges: list[EntityEdge],
207
- existing_edges: list[EntityEdge],
208
- ) -> list[EntityEdge]:
209
- # Create edge map
210
- edge_map: dict[str, EntityEdge] = {}
211
- for edge in existing_edges:
212
- edge_map[edge.uuid] = edge
213
-
214
- # Prepare context for LLM
215
- context = {
216
- 'extracted_edges': [
217
- {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges
218
- ],
219
- 'existing_edges': [
220
- {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
221
- ],
222
- }
223
-
224
- llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.edge(context))
225
- duplicate_data = llm_response.get('duplicates', [])
226
- logger.debug(f'Extracted unique edges: {duplicate_data}')
227
-
228
- duplicate_uuid_map: dict[str, str] = {}
229
- for duplicate in duplicate_data:
230
- uuid_value = duplicate['duplicate_of']
231
- duplicate_uuid_map[duplicate['uuid']] = uuid_value
232
-
233
- # Get full edge data
234
- edges: list[EntityEdge] = []
235
- for edge in extracted_edges:
236
- if edge.uuid in duplicate_uuid_map:
237
- existing_uuid = duplicate_uuid_map[edge.uuid]
238
- existing_edge = edge_map[existing_uuid]
239
- # Add current episode to the episodes list
240
- existing_edge.episodes += edge.episodes
241
- edges.append(existing_edge)
242
- else:
243
- edges.append(edge)
244
-
245
- return edges
246
-
247
-
248
241
  async def resolve_extracted_edges(
249
242
  clients: GraphitiClients,
250
243
  extracted_edges: list[EntityEdge],
251
244
  episode: EpisodicNode,
252
245
  entities: list[EntityNode],
253
- edge_types: dict[str, BaseModel],
246
+ edge_types: dict[str, type[BaseModel]],
254
247
  edge_type_map: dict[tuple[str, str], list[str]],
255
248
  ) -> tuple[list[EntityEdge], list[EntityEdge]]:
249
+ # Fast path: deduplicate exact matches within the extracted edges before parallel processing
250
+ seen: dict[tuple[str, str, str], EntityEdge] = {}
251
+ deduplicated_edges: list[EntityEdge] = []
252
+
253
+ for edge in extracted_edges:
254
+ key = (
255
+ edge.source_node_uuid,
256
+ edge.target_node_uuid,
257
+ _normalize_string_exact(edge.fact),
258
+ )
259
+ if key not in seen:
260
+ seen[key] = edge
261
+ deduplicated_edges.append(edge)
262
+
263
+ extracted_edges = deduplicated_edges
264
+
256
265
  driver = clients.driver
257
266
  llm_client = clients.llm_client
258
267
  embedder = clients.embedder
259
-
260
268
  await create_entity_edge_embeddings(embedder, extracted_edges)
261
269
 
262
- search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
263
- get_relevant_edges(driver, extracted_edges, SearchFilters()),
264
- get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
270
+ valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
271
+ *[
272
+ EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
273
+ for edge in extracted_edges
274
+ ]
275
+ )
276
+
277
+ related_edges_results: list[SearchResults] = await semaphore_gather(
278
+ *[
279
+ search(
280
+ clients,
281
+ extracted_edge.fact,
282
+ group_ids=[extracted_edge.group_id],
283
+ config=EDGE_HYBRID_SEARCH_RRF,
284
+ search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
285
+ )
286
+ for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
287
+ ]
265
288
  )
266
289
 
267
- related_edges_lists, edge_invalidation_candidates = search_results
290
+ related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
291
+
292
+ edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
293
+ *[
294
+ search(
295
+ clients,
296
+ extracted_edge.fact,
297
+ group_ids=[extracted_edge.group_id],
298
+ config=EDGE_HYBRID_SEARCH_RRF,
299
+ search_filter=SearchFilters(),
300
+ )
301
+ for extracted_edge in extracted_edges
302
+ ]
303
+ )
304
+
305
+ edge_invalidation_candidates: list[list[EntityEdge]] = [
306
+ result.edges for result in edge_invalidation_candidate_results
307
+ ]
268
308
 
269
309
  logger.debug(
270
310
  f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
@@ -273,11 +313,21 @@ async def resolve_extracted_edges(
273
313
  # Build entity hash table
274
314
  uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
275
315
 
276
- # Determine which edge types are relevant for each edge
277
- edge_types_lst: list[dict[str, BaseModel]] = []
316
+ # Determine which edge types are relevant for each edge.
317
+ # `edge_types_lst` stores the subset of custom edge definitions whose
318
+ # node signature matches each extracted edge. Anything outside this subset
319
+ # should only stay on the edge if it is a non-custom (LLM generated) label.
320
+ edge_types_lst: list[dict[str, type[BaseModel]]] = []
321
+ custom_type_names = set(edge_types or {})
278
322
  for extracted_edge in extracted_edges:
279
- source_node_labels = uuid_entity_map[extracted_edge.source_node_uuid].labels
280
- target_node_labels = uuid_entity_map[extracted_edge.target_node_uuid].labels
323
+ source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
324
+ target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
325
+ source_node_labels = (
326
+ source_node.labels + ['Entity'] if source_node is not None else ['Entity']
327
+ )
328
+ target_node_labels = (
329
+ target_node.labels + ['Entity'] if target_node is not None else ['Entity']
330
+ )
281
331
  label_tuples = [
282
332
  (source_label, target_label)
283
333
  for source_label in source_node_labels
@@ -296,8 +346,22 @@ async def resolve_extracted_edges(
296
346
 
297
347
  edge_types_lst.append(extracted_edge_types)
298
348
 
349
+ for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
350
+ allowed_type_names = set(extracted_edge_types)
351
+ is_custom_name = extracted_edge.name in custom_type_names
352
+ if not allowed_type_names:
353
+ # No custom types are valid for this node pairing. Keep LLM generated
354
+ # labels, but flip disallowed custom names back to the default.
355
+ if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
356
+ extracted_edge.name = DEFAULT_EDGE_NAME
357
+ continue
358
+ if is_custom_name and extracted_edge.name not in allowed_type_names:
359
+ # Custom name exists but it is not permitted for this source/target
360
+ # signature, so fall back to the default edge label.
361
+ extracted_edge.name = DEFAULT_EDGE_NAME
362
+
299
363
  # resolve edges with related edges in the graph and find invalidation candidates
300
- results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
364
+ results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
301
365
  await semaphore_gather(
302
366
  *[
303
367
  resolve_extracted_edge(
@@ -307,6 +371,7 @@ async def resolve_extracted_edges(
307
371
  existing_edges,
308
372
  episode,
309
373
  extracted_edge_types,
374
+ custom_type_names,
310
375
  )
311
376
  for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
312
377
  extracted_edges,
@@ -348,21 +413,26 @@ def resolve_edge_contradictions(
348
413
  invalidated_edges: list[EntityEdge] = []
349
414
  for edge in invalidation_candidates:
350
415
  # (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
416
+ edge_invalid_at_utc = ensure_utc(edge.invalid_at)
417
+ resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
418
+ edge_valid_at_utc = ensure_utc(edge.valid_at)
419
+ resolved_edge_invalid_at_utc = ensure_utc(resolved_edge.invalid_at)
420
+
351
421
  if (
352
- edge.invalid_at is not None
353
- and resolved_edge.valid_at is not None
354
- and edge.invalid_at <= resolved_edge.valid_at
422
+ edge_invalid_at_utc is not None
423
+ and resolved_edge_valid_at_utc is not None
424
+ and edge_invalid_at_utc <= resolved_edge_valid_at_utc
355
425
  ) or (
356
- edge.valid_at is not None
357
- and resolved_edge.invalid_at is not None
358
- and resolved_edge.invalid_at <= edge.valid_at
426
+ edge_valid_at_utc is not None
427
+ and resolved_edge_invalid_at_utc is not None
428
+ and resolved_edge_invalid_at_utc <= edge_valid_at_utc
359
429
  ):
360
430
  continue
361
431
  # New edge invalidates edge
362
432
  elif (
363
- edge.valid_at is not None
364
- and resolved_edge.valid_at is not None
365
- and edge.valid_at < resolved_edge.valid_at
433
+ edge_valid_at_utc is not None
434
+ and resolved_edge_valid_at_utc is not None
435
+ and edge_valid_at_utc < resolved_edge_valid_at_utc
366
436
  ):
367
437
  edge.invalid_at = resolved_edge.valid_at
368
438
  edge.expired_at = edge.expired_at if edge.expired_at is not None else utc_now()
@@ -377,32 +447,69 @@ async def resolve_extracted_edge(
377
447
  related_edges: list[EntityEdge],
378
448
  existing_edges: list[EntityEdge],
379
449
  episode: EpisodicNode,
380
- edge_types: dict[str, BaseModel] | None = None,
381
- ) -> tuple[EntityEdge, list[EntityEdge]]:
450
+ edge_type_candidates: dict[str, type[BaseModel]] | None = None,
451
+ custom_edge_type_names: set[str] | None = None,
452
+ ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
453
+ """Resolve an extracted edge against existing graph context.
454
+
455
+ Parameters
456
+ ----------
457
+ llm_client : LLMClient
458
+ Client used to invoke the LLM for deduplication and attribute extraction.
459
+ extracted_edge : EntityEdge
460
+ Newly extracted edge whose canonical representation is being resolved.
461
+ related_edges : list[EntityEdge]
462
+ Candidate edges with identical endpoints used for duplicate detection.
463
+ existing_edges : list[EntityEdge]
464
+ Broader set of edges evaluated for contradiction / invalidation.
465
+ episode : EpisodicNode
466
+ Episode providing content context when extracting edge attributes.
467
+ edge_type_candidates : dict[str, type[BaseModel]] | None
468
+ Custom edge types permitted for the current source/target signature.
469
+ custom_edge_type_names : set[str] | None
470
+ Full catalog of registered custom edge names. Used to distinguish
471
+ between disallowed custom types (which fall back to the default label)
472
+ and ad-hoc labels emitted by the LLM.
473
+
474
+ Returns
475
+ -------
476
+ tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
477
+ The resolved edge, any duplicates, and edges to invalidate.
478
+ """
382
479
  if len(related_edges) == 0 and len(existing_edges) == 0:
383
- return extracted_edge, []
480
+ return extracted_edge, [], []
481
+
482
+ # Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
483
+ normalized_fact = _normalize_string_exact(extracted_edge.fact)
484
+ for edge in related_edges:
485
+ if (
486
+ edge.source_node_uuid == extracted_edge.source_node_uuid
487
+ and edge.target_node_uuid == extracted_edge.target_node_uuid
488
+ and _normalize_string_exact(edge.fact) == normalized_fact
489
+ ):
490
+ resolved = edge
491
+ if episode is not None and episode.uuid not in resolved.episodes:
492
+ resolved.episodes.append(episode.uuid)
493
+ return resolved, [], []
384
494
 
385
495
  start = time()
386
496
 
387
497
  # Prepare context for LLM
388
- related_edges_context = [
389
- {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
390
- ]
498
+ related_edges_context = [{'idx': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)]
391
499
 
392
500
  invalidation_edge_candidates_context = [
393
- {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
501
+ {'idx': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
394
502
  ]
395
503
 
396
504
  edge_types_context = (
397
505
  [
398
506
  {
399
- 'fact_type_id': i,
400
507
  'fact_type_name': type_name,
401
508
  'fact_type_description': type_model.__doc__,
402
509
  }
403
- for i, (type_name, type_model) in enumerate(edge_types.items())
510
+ for type_name, type_model in edge_type_candidates.items()
404
511
  ]
405
- if edge_types is not None
512
+ if edge_type_candidates is not None
406
513
  else []
407
514
  )
408
515
 
@@ -413,46 +520,97 @@ async def resolve_extracted_edge(
413
520
  'edge_types': edge_types_context,
414
521
  }
415
522
 
523
+ if related_edges or existing_edges:
524
+ logger.debug(
525
+ 'Resolving edge: sent %d EXISTING FACTS%s and %d INVALIDATION CANDIDATES%s',
526
+ len(related_edges),
527
+ f' (idx 0-{len(related_edges) - 1})' if related_edges else '',
528
+ len(existing_edges),
529
+ f' (idx 0-{len(existing_edges) - 1})' if existing_edges else '',
530
+ )
531
+
416
532
  llm_response = await llm_client.generate_response(
417
533
  prompt_library.dedupe_edges.resolve_edge(context),
418
534
  response_model=EdgeDuplicate,
419
535
  model_size=ModelSize.small,
536
+ prompt_name='dedupe_edges.resolve_edge',
420
537
  )
538
+ response_object = EdgeDuplicate(**llm_response)
539
+ duplicate_facts = response_object.duplicate_facts
540
+
541
+ # Validate duplicate_facts are in valid range for EXISTING FACTS
542
+ invalid_duplicates = [i for i in duplicate_facts if i < 0 or i >= len(related_edges)]
543
+ if invalid_duplicates:
544
+ logger.warning(
545
+ 'LLM returned invalid duplicate_facts idx values %s (valid range: 0-%d for EXISTING FACTS)',
546
+ invalid_duplicates,
547
+ len(related_edges) - 1,
548
+ )
421
549
 
422
- duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
550
+ duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]
423
551
 
424
- resolved_edge = (
425
- related_edges[duplicate_fact_id]
426
- if 0 <= duplicate_fact_id < len(related_edges)
427
- else extracted_edge
428
- )
552
+ resolved_edge = extracted_edge
553
+ for duplicate_fact_id in duplicate_fact_ids:
554
+ resolved_edge = related_edges[duplicate_fact_id]
555
+ break
429
556
 
430
- if duplicate_fact_id >= 0 and episode is not None:
557
+ if duplicate_fact_ids and episode is not None:
431
558
  resolved_edge.episodes.append(episode.uuid)
432
559
 
433
- contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
560
+ contradicted_facts: list[int] = response_object.contradicted_facts
434
561
 
435
- invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
562
+ # Validate contradicted_facts are in valid range for INVALIDATION CANDIDATES
563
+ invalid_contradictions = [i for i in contradicted_facts if i < 0 or i >= len(existing_edges)]
564
+ if invalid_contradictions:
565
+ logger.warning(
566
+ 'LLM returned invalid contradicted_facts idx values %s (valid range: 0-%d for INVALIDATION CANDIDATES)',
567
+ invalid_contradictions,
568
+ len(existing_edges) - 1,
569
+ )
436
570
 
437
- fact_type: str = str(llm_response.get('fact_type'))
438
- if fact_type.upper() != 'DEFAULT' and edge_types is not None:
571
+ invalidation_candidates: list[EntityEdge] = [
572
+ existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
573
+ ]
574
+
575
+ fact_type: str = response_object.fact_type
576
+ candidate_type_names = set(edge_type_candidates or {})
577
+ custom_type_names = custom_edge_type_names or set()
578
+
579
+ is_default_type = fact_type.upper() == 'DEFAULT'
580
+ is_custom_type = fact_type in custom_type_names
581
+ is_allowed_custom_type = fact_type in candidate_type_names
582
+
583
+ if is_allowed_custom_type:
584
+ # The LLM selected a custom type that is allowed for the node pair.
585
+ # Adopt the custom type and, if needed, extract its structured attributes.
439
586
  resolved_edge.name = fact_type
440
587
 
441
588
  edge_attributes_context = {
442
- 'message': episode.content,
589
+ 'episode_content': episode.content,
443
590
  'reference_time': episode.valid_at,
444
591
  'fact': resolved_edge.fact,
445
592
  }
446
593
 
447
- edge_model = edge_types.get(fact_type)
448
-
449
- edge_attributes_response = await llm_client.generate_response(
450
- prompt_library.extract_edges.extract_attributes(edge_attributes_context),
451
- response_model=edge_model, # type: ignore
452
- model_size=ModelSize.small,
453
- )
594
+ edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
595
+ if edge_model is not None and len(edge_model.model_fields) != 0:
596
+ edge_attributes_response = await llm_client.generate_response(
597
+ prompt_library.extract_edges.extract_attributes(edge_attributes_context),
598
+ response_model=edge_model, # type: ignore
599
+ model_size=ModelSize.small,
600
+ prompt_name='extract_edges.extract_attributes',
601
+ )
454
602
 
455
- resolved_edge.attributes = edge_attributes_response
603
+ resolved_edge.attributes = edge_attributes_response
604
+ elif not is_default_type and is_custom_type:
605
+ # The LLM picked a custom type that is not allowed for this signature.
606
+ # Reset to the default label and drop any structured attributes.
607
+ resolved_edge.name = DEFAULT_EDGE_NAME
608
+ resolved_edge.attributes = {}
609
+ elif not is_default_type:
610
+ # Non-custom labels are allowed to pass through so long as the LLM does
611
+ # not return the sentinel DEFAULT value.
612
+ resolved_edge.name = fact_type
613
+ resolved_edge.attributes = {}
456
614
 
457
615
  end = time()
458
616
  logger.debug(
@@ -466,14 +624,14 @@ async def resolve_extracted_edge(
466
624
 
467
625
  # Determine if the new_edge needs to be expired
468
626
  if resolved_edge.expired_at is None:
469
- invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
627
+ invalidation_candidates.sort(key=lambda c: (c.valid_at is None, ensure_utc(c.valid_at)))
470
628
  for candidate in invalidation_candidates:
629
+ candidate_valid_at_utc = ensure_utc(candidate.valid_at)
630
+ resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
471
631
  if (
472
- candidate.valid_at
473
- and resolved_edge.valid_at
474
- and candidate.valid_at.tzinfo
475
- and resolved_edge.valid_at.tzinfo
476
- and candidate.valid_at > resolved_edge.valid_at
632
+ candidate_valid_at_utc is not None
633
+ and resolved_edge_valid_at_utc is not None
634
+ and candidate_valid_at_utc > resolved_edge_valid_at_utc
477
635
  ):
478
636
  # Expire new edge since we have information about more recent events
479
637
  resolved_edge.invalid_at = candidate.valid_at
@@ -481,89 +639,73 @@ async def resolve_extracted_edge(
481
639
  break
482
640
 
483
641
  # Determine which contradictory edges need to be expired
484
- invalidated_edges = resolve_edge_contradictions(resolved_edge, invalidation_candidates)
485
-
486
- return resolved_edge, invalidated_edges
487
-
488
-
489
- async def dedupe_extracted_edge(
490
- llm_client: LLMClient,
491
- extracted_edge: EntityEdge,
492
- related_edges: list[EntityEdge],
493
- episode: EpisodicNode | None = None,
494
- ) -> EntityEdge:
495
- if len(related_edges) == 0:
496
- return extracted_edge
497
-
498
- start = time()
499
-
500
- # Prepare context for LLM
501
- related_edges_context = [
502
- {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
503
- ]
504
-
505
- extracted_edge_context = {
506
- 'fact': extracted_edge.fact,
507
- }
508
-
509
- context = {
510
- 'related_edges': related_edges_context,
511
- 'extracted_edges': extracted_edge_context,
512
- }
513
-
514
- llm_response = await llm_client.generate_response(
515
- prompt_library.dedupe_edges.edge(context),
516
- response_model=EdgeDuplicate,
517
- model_size=ModelSize.small,
518
- )
519
-
520
- duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
521
-
522
- edge = (
523
- related_edges[duplicate_fact_id]
524
- if 0 <= duplicate_fact_id < len(related_edges)
525
- else extracted_edge
526
- )
527
-
528
- if duplicate_fact_id >= 0 and episode is not None:
529
- edge.episodes.append(episode.uuid)
530
-
531
- end = time()
532
- logger.debug(
533
- f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
642
+ invalidated_edges: list[EntityEdge] = resolve_edge_contradictions(
643
+ resolved_edge, invalidation_candidates
534
644
  )
645
+ duplicate_edges: list[EntityEdge] = [related_edges[idx] for idx in duplicate_fact_ids]
535
646
 
536
- return edge
647
+ return resolved_edge, invalidated_edges, duplicate_edges
537
648
 
538
649
 
539
- async def dedupe_edge_list(
540
- llm_client: LLMClient,
541
- edges: list[EntityEdge],
542
- ) -> list[EntityEdge]:
543
- start = time()
544
-
545
- # Create edge map
546
- edge_map = {}
547
- for edge in edges:
548
- edge_map[edge.uuid] = edge
650
+ async def filter_existing_duplicate_of_edges(
651
+ driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
652
+ ) -> list[tuple[EntityNode, EntityNode]]:
653
+ if not duplicates_node_tuples:
654
+ return []
549
655
 
550
- # Prepare context for LLM
551
- context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}
656
+ duplicate_nodes_map = {
657
+ (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
658
+ }
552
659
 
553
- llm_response = await llm_client.generate_response(
554
- prompt_library.dedupe_edges.edge_list(context), response_model=UniqueFacts
555
- )
556
- unique_edges_data = llm_response.get('unique_facts', [])
660
+ if driver.provider == GraphProvider.NEPTUNE:
661
+ query: LiteralString = """
662
+ UNWIND $duplicate_node_uuids AS duplicate_tuple
663
+ MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
664
+ RETURN DISTINCT
665
+ n.uuid AS source_uuid,
666
+ m.uuid AS target_uuid
667
+ """
668
+
669
+ duplicate_nodes = [
670
+ {'source': source.uuid, 'target': target.uuid}
671
+ for source, target in duplicates_node_tuples
672
+ ]
557
673
 
558
- end = time()
559
- logger.debug(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
674
+ records, _, _ = await driver.execute_query(
675
+ query,
676
+ duplicate_node_uuids=duplicate_nodes,
677
+ routing_='r',
678
+ )
679
+ else:
680
+ if driver.provider == GraphProvider.KUZU:
681
+ query = """
682
+ UNWIND $duplicate_node_uuids AS duplicate
683
+ MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
684
+ RETURN DISTINCT
685
+ n.uuid AS source_uuid,
686
+ m.uuid AS target_uuid
687
+ """
688
+ duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
689
+ else:
690
+ query: LiteralString = """
691
+ UNWIND $duplicate_node_uuids AS duplicate_tuple
692
+ MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
693
+ RETURN DISTINCT
694
+ n.uuid AS source_uuid,
695
+ m.uuid AS target_uuid
696
+ """
697
+ duplicate_node_uuids = list(duplicate_nodes_map.keys())
698
+
699
+ records, _, _ = await driver.execute_query(
700
+ query,
701
+ duplicate_node_uuids=duplicate_node_uuids,
702
+ routing_='r',
703
+ )
560
704
 
561
- # Get full edge data
562
- unique_edges = []
563
- for edge_data in unique_edges_data:
564
- uuid = edge_data['uuid']
565
- edge = edge_map[uuid]
566
- edge.fact = edge_data['fact']
567
- unique_edges.append(edge)
705
+ # Remove duplicates that already have the IS_DUPLICATE_OF edge
706
+ for record in records:
707
+ duplicate_tuple = (record.get('source_uuid'), record.get('target_uuid'))
708
+ if duplicate_nodes_map.get(duplicate_tuple):
709
+ duplicate_nodes_map.pop(duplicate_tuple)
568
710
 
569
- return unique_edges
711
+ return list(duplicate_nodes_map.values())