graphiti-core 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective registries. It is provided for informational purposes only.

Potentially problematic release: this version of graphiti-core might be problematic.

Files changed (39)
  1. graphiti_core/driver/driver.py +28 -0
  2. graphiti_core/driver/falkordb_driver.py +112 -0
  3. graphiti_core/driver/kuzu_driver.py +1 -0
  4. graphiti_core/driver/neo4j_driver.py +10 -2
  5. graphiti_core/driver/neptune_driver.py +4 -6
  6. graphiti_core/edges.py +67 -7
  7. graphiti_core/embedder/client.py +2 -1
  8. graphiti_core/graph_queries.py +35 -6
  9. graphiti_core/graphiti.py +36 -24
  10. graphiti_core/graphiti_types.py +0 -1
  11. graphiti_core/helpers.py +2 -2
  12. graphiti_core/llm_client/client.py +19 -4
  13. graphiti_core/llm_client/gemini_client.py +4 -2
  14. graphiti_core/llm_client/openai_base_client.py +3 -2
  15. graphiti_core/llm_client/openai_generic_client.py +3 -2
  16. graphiti_core/models/edges/edge_db_queries.py +36 -16
  17. graphiti_core/models/nodes/node_db_queries.py +30 -10
  18. graphiti_core/nodes.py +126 -25
  19. graphiti_core/prompts/dedupe_edges.py +40 -29
  20. graphiti_core/prompts/dedupe_nodes.py +51 -34
  21. graphiti_core/prompts/eval.py +3 -3
  22. graphiti_core/prompts/extract_edges.py +17 -9
  23. graphiti_core/prompts/extract_nodes.py +10 -9
  24. graphiti_core/prompts/prompt_helpers.py +3 -3
  25. graphiti_core/prompts/summarize_nodes.py +5 -5
  26. graphiti_core/search/search_filters.py +53 -0
  27. graphiti_core/search/search_helpers.py +5 -7
  28. graphiti_core/search/search_utils.py +227 -57
  29. graphiti_core/utils/bulk_utils.py +168 -69
  30. graphiti_core/utils/maintenance/community_operations.py +8 -20
  31. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  32. graphiti_core/utils/maintenance/edge_operations.py +187 -50
  33. graphiti_core/utils/maintenance/graph_data_operations.py +9 -5
  34. graphiti_core/utils/maintenance/node_operations.py +244 -88
  35. graphiti_core/utils/maintenance/temporal_operations.py +0 -4
  36. {graphiti_core-0.20.3.dist-info → graphiti_core-0.21.0.dist-info}/METADATA +7 -1
  37. {graphiti_core-0.20.3.dist-info → graphiti_core-0.21.0.dist-info}/RECORD +39 -38
  38. {graphiti_core-0.20.3.dist-info → graphiti_core-0.21.0.dist-info}/WHEEL +0 -0
  39. {graphiti_core-0.20.3.dist-info → graphiti_core-0.21.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/utils/maintenance/edge_operations.py

@@ -36,9 +36,14 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
 from graphiti_core.prompts import prompt_library
 from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
 from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
+from graphiti_core.search.search import search
+from graphiti_core.search.search_config import SearchResults
+from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
 from graphiti_core.search.search_filters import SearchFilters
-from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
 from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
+from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
+
+DEFAULT_EDGE_NAME = 'RELATES_TO'
 
 logger = logging.getLogger(__name__)
 
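The imports above pull in the hybrid-search entry points and `_normalize_string_exact` from the new `dedup_helpers` module (added in this release but not shown in this diff); `DEFAULT_EDGE_NAME` is the fallback label used whenever a custom edge type is rejected later in the file. As a rough illustration of the kind of key normalization the dedup fast paths below rely on, here is a minimal sketch assuming a simple lowercase-and-collapse-whitespace behavior (hypothetical, not the actual helper):

    # Hypothetical sketch only: the real _normalize_string_exact lives in
    # graphiti_core/utils/maintenance/dedup_helpers.py and may differ.
    def _normalize_string_exact(value: str) -> str:
        # Lowercase and collapse whitespace so trivially different renderings
        # of the same fact produce identical dedup keys.
        return ' '.join(value.lower().split())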
@@ -63,32 +68,6 @@ def build_episodic_edges(
     return episodic_edges
 
 
-def build_duplicate_of_edges(
-    episode: EpisodicNode,
-    created_at: datetime,
-    duplicate_nodes: list[tuple[EntityNode, EntityNode]],
-) -> list[EntityEdge]:
-    is_duplicate_of_edges: list[EntityEdge] = []
-    for source_node, target_node in duplicate_nodes:
-        if source_node.uuid == target_node.uuid:
-            continue
-
-        is_duplicate_of_edges.append(
-            EntityEdge(
-                source_node_uuid=source_node.uuid,
-                target_node_uuid=target_node.uuid,
-                name='IS_DUPLICATE_OF',
-                group_id=episode.group_id,
-                fact=f'{source_node.name} is a duplicate of {target_node.name}',
-                episodes=[episode.uuid],
-                created_at=created_at,
-                valid_at=created_at,
-            )
-        )
-
-    return is_duplicate_of_edges
-
-
 def build_community_edges(
     entity_nodes: list[EntityNode],
     community_node: CommunityNode,
@@ -151,7 +130,6 @@ async def extract_edges(
        'reference_time': episode.valid_at,
        'edge_types': edge_types_context,
        'custom_prompt': '',
-       'ensure_ascii': clients.ensure_ascii,
    }
 
    facts_missed = True
@@ -161,6 +139,7 @@
            prompt_library.extract_edges.edge(context),
            response_model=ExtractedEdges,
            max_tokens=extract_edges_max_tokens,
+           group_id=group_id,
        )
        edges_data = ExtractedEdges(**llm_response).edges
 
@@ -172,6 +151,7 @@
                prompt_library.extract_edges.reflexion(context),
                response_model=MissingFacts,
                max_tokens=extract_edges_max_tokens,
+               group_id=group_id,
            )
 
            missing_facts = reflexion_response.get('missing_facts', [])
@@ -199,15 +179,26 @@
        valid_at_datetime = None
        invalid_at_datetime = None
 
+       # Filter out empty edges
+       if not edge_data.fact.strip():
+           continue
+
        source_node_idx = edge_data.source_entity_id
        target_node_idx = edge_data.target_entity_id
-       if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
+
+       if len(nodes) == 0:
+           logger.warning('No entities provided for edge extraction')
+           continue
+
+       if not (0 <= source_node_idx < len(nodes) and 0 <= target_node_idx < len(nodes)):
            logger.warning(
-               f'WARNING: source or target node not filled {edge_data.relation_type}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
+               f'Invalid entity IDs in edge extraction for {edge_data.relation_type}. '
+               f'source_entity_id: {source_node_idx}, target_entity_id: {target_node_idx}, '
+               f'but only {len(nodes)} entities available (valid range: 0-{len(nodes) - 1})'
            )
            continue
        source_node_uuid = nodes[source_node_idx].uuid
-       target_node_uuid = nodes[edge_data.target_entity_id].uuid
+       target_node_uuid = nodes[target_node_idx].uuid
 
        if valid_at:
            try:
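Worth noting: for integer indices, the old bounds check `-1 < idx < len(nodes)` and the new `0 <= idx < len(nodes)` accept exactly the same values, so the functional additions in this hunk are the empty-fact filter and the empty-entity-list guard; the range check itself mainly gains a clearer log message. A tiny self-contained check of that equivalence (stand-in values, not from the package):

    # The two predicates agree for every integer index; only the logging changed.
    nodes = ['a', 'b', 'c']  # stand-in for a list of 3 extracted entities
    for idx in (-2, -1, 0, 1, 2, 3):
        assert (-1 < idx < len(nodes)) == (0 <= idx < len(nodes))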
@@ -253,17 +244,65 @@ async def resolve_extracted_edges(
     edge_types: dict[str, type[BaseModel]],
     edge_type_map: dict[tuple[str, str], list[str]],
 ) -> tuple[list[EntityEdge], list[EntityEdge]]:
+    # Fast path: deduplicate exact matches within the extracted edges before parallel processing
+    seen: dict[tuple[str, str, str], EntityEdge] = {}
+    deduplicated_edges: list[EntityEdge] = []
+
+    for edge in extracted_edges:
+        key = (
+            edge.source_node_uuid,
+            edge.target_node_uuid,
+            _normalize_string_exact(edge.fact),
+        )
+        if key not in seen:
+            seen[key] = edge
+            deduplicated_edges.append(edge)
+
+    extracted_edges = deduplicated_edges
+
     driver = clients.driver
     llm_client = clients.llm_client
     embedder = clients.embedder
     await create_entity_edge_embeddings(embedder, extracted_edges)
 
-    search_results = await semaphore_gather(
-        get_relevant_edges(driver, extracted_edges, SearchFilters()),
-        get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
+    valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
+        *[
+            EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
+            for edge in extracted_edges
+        ]
+    )
+
+    related_edges_results: list[SearchResults] = await semaphore_gather(
+        *[
+            search(
+                clients,
+                extracted_edge.fact,
+                group_ids=[extracted_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
+            )
+            for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
+        ]
     )
 
-    related_edges_lists, edge_invalidation_candidates = search_results
+    related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
+
+    edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
+        *[
+            search(
+                clients,
+                extracted_edge.fact,
+                group_ids=[extracted_edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(),
+            )
+            for extracted_edge in extracted_edges
+        ]
+    )
+
+    edge_invalidation_candidates: list[list[EntityEdge]] = [
+        result.edges for result in edge_invalidation_candidate_results
+    ]
 
     logger.debug(
         f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
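With `get_relevant_edges` and `get_edge_invalidation_candidates` removed, candidate retrieval now runs through the regular hybrid search pipeline: duplicate candidates are searched only among edges that already connect the same two nodes (`EntityEdge.get_between_nodes` feeding an `edge_uuids` filter), while invalidation candidates are searched without that restriction. A condensed per-edge sketch of the flow above, using only the calls shown in this hunk (simplified; the library batches these with `semaphore_gather`):

    # Illustrative per-edge view of the gathered calls above (simplified sketch).
    async def candidates_for(clients, extracted_edge):
        valid_edges = await EntityEdge.get_between_nodes(
            clients.driver, extracted_edge.source_node_uuid, extracted_edge.target_node_uuid
        )
        related = await search(
            clients,
            extracted_edge.fact,
            group_ids=[extracted_edge.group_id],
            config=EDGE_HYBRID_SEARCH_RRF,
            search_filter=SearchFilters(edge_uuids=[e.uuid for e in valid_edges]),
        )
        invalidation = await search(
            clients,
            extracted_edge.fact,
            group_ids=[extracted_edge.group_id],
            config=EDGE_HYBRID_SEARCH_RRF,
            search_filter=SearchFilters(),
        )
        # Duplicate candidates vs. contradiction/invalidation candidates.
        return related.edges, invalidation.edges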
@@ -272,8 +311,12 @@ async def resolve_extracted_edges(
     # Build entity hash table
     uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
 
-    # Determine which edge types are relevant for each edge
+    # Determine which edge types are relevant for each edge.
+    # `edge_types_lst` stores the subset of custom edge definitions whose
+    # node signature matches each extracted edge. Anything outside this subset
+    # should only stay on the edge if it is a non-custom (LLM generated) label.
     edge_types_lst: list[dict[str, type[BaseModel]]] = []
+    custom_type_names = set(edge_types or {})
     for extracted_edge in extracted_edges:
         source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
         target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
@@ -301,6 +344,20 @@
 
         edge_types_lst.append(extracted_edge_types)
 
+    for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
+        allowed_type_names = set(extracted_edge_types)
+        is_custom_name = extracted_edge.name in custom_type_names
+        if not allowed_type_names:
+            # No custom types are valid for this node pairing. Keep LLM generated
+            # labels, but flip disallowed custom names back to the default.
+            if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
+                extracted_edge.name = DEFAULT_EDGE_NAME
+            continue
+        if is_custom_name and extracted_edge.name not in allowed_type_names:
+            # Custom name exists but it is not permitted for this source/target
+            # signature, so fall back to the default edge label.
+            extracted_edge.name = DEFAULT_EDGE_NAME
+
     # resolve edges with related edges in the graph and find invalidation candidates
     results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
         await semaphore_gather(
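The new loop reconciles the edge's extracted name with the custom types allowed for its source/target signature: registered custom names that are not permitted for the pairing are reset to `DEFAULT_EDGE_NAME`, while ad-hoc labels the LLM invented are left alone. A compact sketch of that per-edge decision, with hypothetical type names:

    # Hedged sketch of the decision the loop above applies to each edge.
    def adjust_name(name: str, allowed: set[str], custom: set[str]) -> str:
        # A registered custom type that is not allowed for this node pairing
        # falls back to the default label.
        if name in custom and name not in allowed:
            return 'RELATES_TO'  # DEFAULT_EDGE_NAME
        # Allowed custom types and ad-hoc LLM labels pass through unchanged.
        return name

    # Hypothetical examples: 'WORKS_AT' is a registered custom type,
    # 'MENTIONS' is an ad-hoc label produced by the LLM.
    assert adjust_name('WORKS_AT', allowed=set(), custom={'WORKS_AT'}) == 'RELATES_TO'
    assert adjust_name('MENTIONS', allowed=set(), custom={'WORKS_AT'}) == 'MENTIONS'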
@@ -312,7 +369,7 @@ async def resolve_extracted_edges(
                     existing_edges,
                     episode,
                     extracted_edge_types,
-                    clients.ensure_ascii,
+                    custom_type_names,
                 )
                 for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
                     extracted_edges,
@@ -383,33 +440,69 @@ async def resolve_extracted_edge(
     related_edges: list[EntityEdge],
     existing_edges: list[EntityEdge],
     episode: EpisodicNode,
-    edge_types: dict[str, type[BaseModel]] | None = None,
-    ensure_ascii: bool = True,
+    edge_type_candidates: dict[str, type[BaseModel]] | None = None,
+    custom_edge_type_names: set[str] | None = None,
 ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
+    """Resolve an extracted edge against existing graph context.
+
+    Parameters
+    ----------
+    llm_client : LLMClient
+        Client used to invoke the LLM for deduplication and attribute extraction.
+    extracted_edge : EntityEdge
+        Newly extracted edge whose canonical representation is being resolved.
+    related_edges : list[EntityEdge]
+        Candidate edges with identical endpoints used for duplicate detection.
+    existing_edges : list[EntityEdge]
+        Broader set of edges evaluated for contradiction / invalidation.
+    episode : EpisodicNode
+        Episode providing content context when extracting edge attributes.
+    edge_type_candidates : dict[str, type[BaseModel]] | None
+        Custom edge types permitted for the current source/target signature.
+    custom_edge_type_names : set[str] | None
+        Full catalog of registered custom edge names. Used to distinguish
+        between disallowed custom types (which fall back to the default label)
+        and ad-hoc labels emitted by the LLM.
+
+    Returns
+    -------
+    tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
+        The resolved edge, any duplicates, and edges to invalidate.
+    """
     if len(related_edges) == 0 and len(existing_edges) == 0:
         return extracted_edge, [], []
 
+    # Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
+    normalized_fact = _normalize_string_exact(extracted_edge.fact)
+    for edge in related_edges:
+        if (
+            edge.source_node_uuid == extracted_edge.source_node_uuid
+            and edge.target_node_uuid == extracted_edge.target_node_uuid
+            and _normalize_string_exact(edge.fact) == normalized_fact
+        ):
+            resolved = edge
+            if episode is not None and episode.uuid not in resolved.episodes:
+                resolved.episodes.append(episode.uuid)
+            return resolved, [], []
+
     start = time()
 
     # Prepare context for LLM
-    related_edges_context = [
-        {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
-    ]
+    related_edges_context = [{'idx': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)]
 
     invalidation_edge_candidates_context = [
-        {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
+        {'idx': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
    ]
 
     edge_types_context = (
         [
             {
-                'fact_type_id': i,
                 'fact_type_name': type_name,
                 'fact_type_description': type_model.__doc__,
             }
-            for i, (type_name, type_model) in enumerate(edge_types.items())
+            for type_name, type_model in edge_type_candidates.items()
         ]
-        if edge_types is not None
+        if edge_type_candidates is not None
         else []
     )
 
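`resolve_extracted_edge` now returns early, before any LLM call, when a related edge already has the same endpoints and a normalized-identical fact, simply appending the episode UUID to the reused edge. The prompt context also switches from edge UUIDs to list positions (`idx`), which is what the new range validation further down checks against. A minimal stand-alone sketch of the fast path, assuming duck-typed edge objects with `source_node_uuid`, `target_node_uuid`, `fact`, and `episodes` attributes:

    # Simplified, stand-alone version of the exact-match fast path above
    # (illustrative; not the library's exact code).
    def find_verbatim_match(extracted_edge, related_edges, episode_uuid):
        normalized_fact = _normalize_string_exact(extracted_edge.fact)
        for edge in related_edges:
            if (
                edge.source_node_uuid == extracted_edge.source_node_uuid
                and edge.target_node_uuid == extracted_edge.target_node_uuid
                and _normalize_string_exact(edge.fact) == normalized_fact
            ):
                # Reuse the stored edge and credit the current episode.
                if episode_uuid and episode_uuid not in edge.episodes:
                    edge.episodes.append(episode_uuid)
                return edge
        return None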
@@ -418,9 +511,17 @@
         'new_edge': extracted_edge.fact,
         'edge_invalidation_candidates': invalidation_edge_candidates_context,
         'edge_types': edge_types_context,
-        'ensure_ascii': ensure_ascii,
     }
 
+    if related_edges or existing_edges:
+        logger.debug(
+            'Resolving edge: sent %d EXISTING FACTS%s and %d INVALIDATION CANDIDATES%s',
+            len(related_edges),
+            f' (idx 0-{len(related_edges) - 1})' if related_edges else '',
+            len(existing_edges),
+            f' (idx 0-{len(existing_edges) - 1})' if existing_edges else '',
+        )
+
     llm_response = await llm_client.generate_response(
         prompt_library.dedupe_edges.resolve_edge(context),
         response_model=EdgeDuplicate,
@@ -429,6 +530,15 @@
     response_object = EdgeDuplicate(**llm_response)
     duplicate_facts = response_object.duplicate_facts
 
+    # Validate duplicate_facts are in valid range for EXISTING FACTS
+    invalid_duplicates = [i for i in duplicate_facts if i < 0 or i >= len(related_edges)]
+    if invalid_duplicates:
+        logger.warning(
+            'LLM returned invalid duplicate_facts idx values %s (valid range: 0-%d for EXISTING FACTS)',
+            invalid_duplicates,
+            len(related_edges) - 1,
+        )
+
     duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]
 
     resolved_edge = extracted_edge
@@ -441,22 +551,39 @@
 
     contradicted_facts: list[int] = response_object.contradicted_facts
 
+    # Validate contradicted_facts are in valid range for INVALIDATION CANDIDATES
+    invalid_contradictions = [i for i in contradicted_facts if i < 0 or i >= len(existing_edges)]
+    if invalid_contradictions:
+        logger.warning(
+            'LLM returned invalid contradicted_facts idx values %s (valid range: 0-%d for INVALIDATION CANDIDATES)',
+            invalid_contradictions,
+            len(existing_edges) - 1,
+        )
+
     invalidation_candidates: list[EntityEdge] = [
         existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
     ]
 
     fact_type: str = response_object.fact_type
-    if fact_type.upper() != 'DEFAULT' and edge_types is not None:
+    candidate_type_names = set(edge_type_candidates or {})
+    custom_type_names = custom_edge_type_names or set()
+
+    is_default_type = fact_type.upper() == 'DEFAULT'
+    is_custom_type = fact_type in custom_type_names
+    is_allowed_custom_type = fact_type in candidate_type_names
+
+    if is_allowed_custom_type:
+        # The LLM selected a custom type that is allowed for the node pair.
+        # Adopt the custom type and, if needed, extract its structured attributes.
         resolved_edge.name = fact_type
 
         edge_attributes_context = {
             'episode_content': episode.content,
             'reference_time': episode.valid_at,
             'fact': resolved_edge.fact,
-            'ensure_ascii': ensure_ascii,
         }
 
-        edge_model = edge_types.get(fact_type)
+        edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
         if edge_model is not None and len(edge_model.model_fields) != 0:
             edge_attributes_response = await llm_client.generate_response(
                 prompt_library.extract_edges.extract_attributes(edge_attributes_context),
@@ -465,6 +592,16 @@
             )
 
             resolved_edge.attributes = edge_attributes_response
+    elif not is_default_type and is_custom_type:
+        # The LLM picked a custom type that is not allowed for this signature.
+        # Reset to the default label and drop any structured attributes.
+        resolved_edge.name = DEFAULT_EDGE_NAME
+        resolved_edge.attributes = {}
+    elif not is_default_type:
+        # Non-custom labels are allowed to pass through so long as the LLM does
+        # not return the sentinel DEFAULT value.
+        resolved_edge.name = fact_type
+        resolved_edge.attributes = {}
 
     end = time()
     logger.debug(
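Combined with the attribute-extraction block above, `fact_type` handling now has three outcomes: an allowed custom type keeps its name and may get structured attributes, a registered-but-disallowed custom type falls back to `DEFAULT_EDGE_NAME` with attributes cleared, and any other non-DEFAULT label is kept as-is with attributes cleared. A hedged summary of that branching as a stand-alone helper (hypothetical, mirrors the diff):

    # Illustrative restatement of the branch logic above (not the library's code).
    def resolve_edge_name(current_name: str, fact_type: str,
                          candidates: set[str], custom: set[str]) -> str:
        if fact_type in candidates:
            return fact_type      # allowed custom type; attributes may be extracted
        if fact_type.upper() != 'DEFAULT' and fact_type in custom:
            return 'RELATES_TO'   # disallowed custom type -> DEFAULT_EDGE_NAME
        if fact_type.upper() != 'DEFAULT':
            return fact_type      # ad-hoc LLM label passes through
        return current_name       # DEFAULT sentinel: keep the existing name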
graphiti_core/utils/maintenance/graph_data_operations.py

@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
 
 
 async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bool = False):
-    if driver.provider == GraphProvider.NEPTUNE:
+    if driver.aoss_client:
         await driver.create_aoss_indices()  # pyright: ignore[reportAttributeAccessIssue]
         return
     if delete_existing:
@@ -56,7 +56,9 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
 
     range_indices: list[LiteralString] = get_range_indices(driver.provider)
 
-    fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
+    # Don't create fulltext indices if OpenSearch is being used
+    if not driver.aoss_client:
+        fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
 
     if driver.provider == GraphProvider.KUZU:
         # Skip creating fulltext indices if they already exist. Need to do this manually
@@ -93,6 +95,8 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
 
     async def delete_all(tx):
         await tx.run('MATCH (n) DETACH DELETE n')
+        if driver.aoss_client:
+            await driver.clear_aoss_indices()
 
     async def delete_group_ids(tx):
         labels = ['Entity', 'Episodic', 'Community']
@@ -149,9 +153,9 @@ async def retrieve_episodes(
 
     query: LiteralString = (
         """
-    MATCH (e:Episodic)
-    WHERE e.valid_at <= $reference_time
-    """
+        MATCH (e:Episodic)
+        WHERE e.valid_at <= $reference_time
+        """
         + query_filter
         + """
         RETURN