graphiti-core 0.21.0rc6__py3-none-any.whl → 0.21.0rc8__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

This version of graphiti-core has been flagged as a potentially problematic release.

@@ -0,0 +1,262 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ from __future__ import annotations
+
+ import math
+ import re
+ from collections import defaultdict
+ from collections.abc import Iterable
+ from dataclasses import dataclass, field
+ from functools import lru_cache
+ from hashlib import blake2b
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from graphiti_core.nodes import EntityNode
+
+ _NAME_ENTROPY_THRESHOLD = 1.5
+ _MIN_NAME_LENGTH = 6
+ _MIN_TOKEN_COUNT = 2
+ _FUZZY_JACCARD_THRESHOLD = 0.9
+ _MINHASH_PERMUTATIONS = 32
+ _MINHASH_BAND_SIZE = 4
+
+
+ def _normalize_string_exact(name: str) -> str:
+     """Lowercase text and collapse whitespace so equal names map to the same key."""
+     normalized = re.sub(r'[\s]+', ' ', name.lower())
+     return normalized.strip()
+
+
+ def _normalize_name_for_fuzzy(name: str) -> str:
+     """Produce a fuzzier form that keeps alphanumerics and apostrophes for n-gram shingles."""
+     normalized = re.sub(r"[^a-z0-9' ]", ' ', _normalize_string_exact(name))
+     normalized = normalized.strip()
+     return re.sub(r'[\s]+', ' ', normalized)
+
+
+ def _name_entropy(normalized_name: str) -> float:
+     """Approximate text specificity using Shannon entropy over characters.
+
+     We strip spaces, count how often each character appears, and sum
+     probability * -log2(probability). Short or repetitive names yield low
+     entropy, which signals we should defer resolution to the LLM instead of
+     trusting fuzzy similarity.
+     """
+     if not normalized_name:
+         return 0.0
+
+     counts: dict[str, int] = {}
+     for char in normalized_name.replace(' ', ''):
+         counts[char] = counts.get(char, 0) + 1
+
+     total = sum(counts.values())
+     if total == 0:
+         return 0.0
+
+     entropy = 0.0
+     for count in counts.values():
+         probability = count / total
+         entropy -= probability * math.log2(probability)
+
+     return entropy
+
+
+ def _has_high_entropy(normalized_name: str) -> bool:
+     """Filter out very short or low-entropy names that are unreliable for fuzzy matching."""
+     token_count = len(normalized_name.split())
+     if len(normalized_name) < _MIN_NAME_LENGTH and token_count < _MIN_TOKEN_COUNT:
+         return False
+
+     return _name_entropy(normalized_name) >= _NAME_ENTROPY_THRESHOLD
+
+
+ def _shingles(normalized_name: str) -> set[str]:
+     """Create 3-gram shingles from the normalized name for MinHash calculations."""
+     cleaned = normalized_name.replace(' ', '')
+     if len(cleaned) < 2:
+         return {cleaned} if cleaned else set()
+
+     return {cleaned[i : i + 3] for i in range(len(cleaned) - 2)}
+
+
+ def _hash_shingle(shingle: str, seed: int) -> int:
+     """Generate a deterministic 64-bit hash for a shingle given the permutation seed."""
+     digest = blake2b(f'{seed}:{shingle}'.encode(), digest_size=8)
+     return int.from_bytes(digest.digest(), 'big')
+
+
+ def _minhash_signature(shingles: Iterable[str]) -> tuple[int, ...]:
+     """Compute the MinHash signature for the shingle set across predefined permutations."""
+     if not shingles:
+         return tuple()
+
+     seeds = range(_MINHASH_PERMUTATIONS)
+     signature: list[int] = []
+     for seed in seeds:
+         min_hash = min(_hash_shingle(shingle, seed) for shingle in shingles)
+         signature.append(min_hash)
+
+     return tuple(signature)
+
+
+ def _lsh_bands(signature: Iterable[int]) -> list[tuple[int, ...]]:
+     """Split the MinHash signature into fixed-size bands for locality-sensitive hashing."""
+     signature_list = list(signature)
+     if not signature_list:
+         return []
+
+     bands: list[tuple[int, ...]] = []
+     for start in range(0, len(signature_list), _MINHASH_BAND_SIZE):
+         band = tuple(signature_list[start : start + _MINHASH_BAND_SIZE])
+         if len(band) == _MINHASH_BAND_SIZE:
+             bands.append(band)
+     return bands
+
+
+ def _jaccard_similarity(a: set[str], b: set[str]) -> float:
+     """Return the Jaccard similarity between two shingle sets, handling empty edge cases."""
+     if not a and not b:
+         return 1.0
+     if not a or not b:
+         return 0.0
+
+     intersection = len(a.intersection(b))
+     union = len(a.union(b))
+     return intersection / union if union else 0.0
+
+
+ @lru_cache(maxsize=512)
+ def _cached_shingles(name: str) -> set[str]:
+     """Cache shingle sets per normalized name to avoid recomputation within a worker."""
+     return _shingles(name)
+
+
+ @dataclass
+ class DedupCandidateIndexes:
+     """Precomputed lookup structures that drive entity deduplication heuristics."""
+
+     existing_nodes: list[EntityNode]
+     nodes_by_uuid: dict[str, EntityNode]
+     normalized_existing: defaultdict[str, list[EntityNode]]
+     shingles_by_candidate: dict[str, set[str]]
+     lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]]
+
+
+ @dataclass
+ class DedupResolutionState:
+     """Mutable resolution bookkeeping shared across deterministic and LLM passes."""
+
+     resolved_nodes: list[EntityNode | None]
+     uuid_map: dict[str, str]
+     unresolved_indices: list[int]
+     duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)
+
+
+ def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
+     """Precompute exact and fuzzy lookup structures once per dedupe run."""
+     normalized_existing: defaultdict[str, list[EntityNode]] = defaultdict(list)
+     nodes_by_uuid: dict[str, EntityNode] = {}
+     shingles_by_candidate: dict[str, set[str]] = {}
+     lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]] = defaultdict(list)
+
+     for candidate in existing_nodes:
+         normalized = _normalize_string_exact(candidate.name)
+         normalized_existing[normalized].append(candidate)
+         nodes_by_uuid[candidate.uuid] = candidate
+
+         shingles = _cached_shingles(_normalize_name_for_fuzzy(candidate.name))
+         shingles_by_candidate[candidate.uuid] = shingles
+
+         signature = _minhash_signature(shingles)
+         for band_index, band in enumerate(_lsh_bands(signature)):
+             lsh_buckets[(band_index, band)].append(candidate.uuid)
+
+     return DedupCandidateIndexes(
+         existing_nodes=existing_nodes,
+         nodes_by_uuid=nodes_by_uuid,
+         normalized_existing=normalized_existing,
+         shingles_by_candidate=shingles_by_candidate,
+         lsh_buckets=lsh_buckets,
+     )
+
+
+ def _resolve_with_similarity(
+     extracted_nodes: list[EntityNode],
+     indexes: DedupCandidateIndexes,
+     state: DedupResolutionState,
+ ) -> None:
+     """Attempt deterministic resolution using exact name hits and fuzzy MinHash comparisons."""
+     for idx, node in enumerate(extracted_nodes):
+         normalized_exact = _normalize_string_exact(node.name)
+         normalized_fuzzy = _normalize_name_for_fuzzy(node.name)
+
+         if not _has_high_entropy(normalized_fuzzy):
+             state.unresolved_indices.append(idx)
+             continue
+
+         existing_matches = indexes.normalized_existing.get(normalized_exact, [])
+         if len(existing_matches) == 1:
+             match = existing_matches[0]
+             state.resolved_nodes[idx] = match
+             state.uuid_map[node.uuid] = match.uuid
+             if match.uuid != node.uuid:
+                 state.duplicate_pairs.append((node, match))
+             continue
+         if len(existing_matches) > 1:
+             state.unresolved_indices.append(idx)
+             continue
+
+         shingles = _cached_shingles(normalized_fuzzy)
+         signature = _minhash_signature(shingles)
+         candidate_ids: set[str] = set()
+         for band_index, band in enumerate(_lsh_bands(signature)):
+             candidate_ids.update(indexes.lsh_buckets.get((band_index, band), []))
+
+         best_candidate: EntityNode | None = None
+         best_score = 0.0
+         for candidate_id in candidate_ids:
+             candidate_shingles = indexes.shingles_by_candidate.get(candidate_id, set())
+             score = _jaccard_similarity(shingles, candidate_shingles)
+             if score > best_score:
+                 best_score = score
+                 best_candidate = indexes.nodes_by_uuid.get(candidate_id)
+
+         if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
+             state.resolved_nodes[idx] = best_candidate
+             state.uuid_map[node.uuid] = best_candidate.uuid
+             if best_candidate.uuid != node.uuid:
+                 state.duplicate_pairs.append((node, best_candidate))
+             continue
+
+         state.unresolved_indices.append(idx)
+
+
+ __all__ = [
+     'DedupCandidateIndexes',
+     'DedupResolutionState',
+     '_normalize_string_exact',
+     '_normalize_name_for_fuzzy',
+     '_has_high_entropy',
+     '_minhash_signature',
+     '_lsh_bands',
+     '_jaccard_similarity',
+     '_cached_shingles',
+     '_FUZZY_JACCARD_THRESHOLD',
+     '_build_candidate_indexes',
+     '_resolve_with_similarity',
+ ]
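The new module is self-contained, so its flow can be sketched end to end. The example below is a minimal illustration, assuming only that `EntityNode` exposes `name` and `uuid` (the `FakeNode` dataclass is a hypothetical stand-in, not library code): build the indexes once, run the deterministic pass, and anything it cannot settle lands in `unresolved_indices` for the LLM. With 32 permutations split into 8 bands of 4 rows, two names at the 0.9 Jaccard threshold share at least one LSH bucket with probability about 1 - (1 - 0.9^4)^8 ≈ 0.9998.

```python
from dataclasses import dataclass, field
from uuid import uuid4

from graphiti_core.utils.maintenance.dedup_helpers import (
    DedupResolutionState,
    _build_candidate_indexes,
    _resolve_with_similarity,
)

@dataclass
class FakeNode:  # hypothetical stand-in; only .name and .uuid are touched here
    name: str
    uuid: str = field(default_factory=lambda: uuid4().hex)

existing = [FakeNode('Alan Turing'), FakeNode('Ada Lovelace')]
extracted = [FakeNode('alan  turing'), FakeNode('Grace Hopper')]

indexes = _build_candidate_indexes(existing)  # type: ignore[arg-type]
state = DedupResolutionState(
    resolved_nodes=[None] * len(extracted),  # pre-sized; the pass fills slots by index
    uuid_map={},
    unresolved_indices=[],
)
_resolve_with_similarity(extracted, indexes, state)  # type: ignore[arg-type]

# 'alan  turing' hits the exact normalized-name index and resolves to the
# existing 'Alan Turing' node; 'Grace Hopper' matches no LSH bucket and is
# queued in state.unresolved_indices for the LLM pass.
assert state.resolved_nodes[0] is existing[0]
assert state.unresolved_indices == [1]
```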
@@ -41,6 +41,9 @@ from graphiti_core.search.search_config import SearchResults
  from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
  from graphiti_core.search.search_filters import SearchFilters
  from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
+ from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
+
+ DEFAULT_EDGE_NAME = 'RELATES_TO'

  logger = logging.getLogger(__name__)
@@ -229,6 +232,22 @@ async def resolve_extracted_edges(
      edge_types: dict[str, type[BaseModel]],
      edge_type_map: dict[tuple[str, str], list[str]],
  ) -> tuple[list[EntityEdge], list[EntityEdge]]:
+     # Fast path: deduplicate exact matches within the extracted edges before parallel processing
+     seen: dict[tuple[str, str, str], EntityEdge] = {}
+     deduplicated_edges: list[EntityEdge] = []
+
+     for edge in extracted_edges:
+         key = (
+             edge.source_node_uuid,
+             edge.target_node_uuid,
+             _normalize_string_exact(edge.fact),
+         )
+         if key not in seen:
+             seen[key] = edge
+             deduplicated_edges.append(edge)
+
+     extracted_edges = deduplicated_edges
+
      driver = clients.driver
      llm_client = clients.llm_client
      embedder = clients.embedder
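As a quick illustration of the fast-path key, two extracted edges with the same endpoints whose facts differ only in case or spacing collapse to a single entry (the `'u1'`/`'u2'` uuids below are made up):

```python
from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact

# Both facts normalize to 'alice works at acme', so the second edge would be
# skipped by the seen-dict above.
key_a = ('u1', 'u2', _normalize_string_exact('Alice  WORKS at Acme'))
key_b = ('u1', 'u2', _normalize_string_exact('alice works at acme'))
assert key_a == key_b
```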
@@ -280,8 +299,12 @@ async def resolve_extracted_edges(
      # Build entity hash table
      uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}

-     # Determine which edge types are relevant for each edge
+     # Determine which edge types are relevant for each edge.
+     # `edge_types_lst` stores the subset of custom edge definitions whose
+     # node signature matches each extracted edge. Anything outside this subset
+     # should only stay on the edge if it is a non-custom (LLM generated) label.
      edge_types_lst: list[dict[str, type[BaseModel]]] = []
+     custom_type_names = set(edge_types or {})
      for extracted_edge in extracted_edges:
          source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
          target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
@@ -309,6 +332,20 @@ async def resolve_extracted_edges(

          edge_types_lst.append(extracted_edge_types)

+     for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
+         allowed_type_names = set(extracted_edge_types)
+         is_custom_name = extracted_edge.name in custom_type_names
+         if not allowed_type_names:
+             # No custom types are valid for this node pairing. Keep LLM generated
+             # labels, but flip disallowed custom names back to the default.
+             if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
+                 extracted_edge.name = DEFAULT_EDGE_NAME
+             continue
+         if is_custom_name and extracted_edge.name not in allowed_type_names:
+             # Custom name exists but it is not permitted for this source/target
+             # signature, so fall back to the default edge label.
+             extracted_edge.name = DEFAULT_EDGE_NAME
+
      # resolve edges with related edges in the graph and find invalidation candidates
      results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
          await semaphore_gather(
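Condensed, the gating loop above implements a small decision rule. The sketch below is an illustrative reimplementation (not library code) with the behavior spelled out as assertions; the type names are invented:

```python
DEFAULT_EDGE_NAME = 'RELATES_TO'

def gate_label(name: str, allowed: set[str], custom: set[str]) -> str:
    """Return the edge label the loop above would leave in place."""
    if name in custom and name not in allowed:
        return DEFAULT_EDGE_NAME  # disallowed custom type falls back to the default
    return name  # allowed custom types and ad-hoc LLM labels pass through

# Custom type not valid for this node pairing -> reset to the default label.
assert gate_label('WORKS_AT', allowed=set(), custom={'WORKS_AT'}) == 'RELATES_TO'
# Ad-hoc LLM label survives even when no custom types apply.
assert gate_label('MENTIONS', allowed=set(), custom={'WORKS_AT'}) == 'MENTIONS'
# Custom type that matches the signature is kept.
assert gate_label('WORKS_AT', allowed={'WORKS_AT'}, custom={'WORKS_AT'}) == 'WORKS_AT'
```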
@@ -320,6 +357,7 @@ async def resolve_extracted_edges(
                      existing_edges,
                      episode,
                      extracted_edge_types,
+                     custom_type_names,
                      clients.ensure_ascii,
                  )
                  for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
@@ -391,17 +429,59 @@ async def resolve_extracted_edge(
      related_edges: list[EntityEdge],
      existing_edges: list[EntityEdge],
      episode: EpisodicNode,
-     edge_types: dict[str, type[BaseModel]] | None = None,
+     edge_type_candidates: dict[str, type[BaseModel]] | None = None,
+     custom_edge_type_names: set[str] | None = None,
      ensure_ascii: bool = True,
  ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
+     """Resolve an extracted edge against existing graph context.
+
+     Parameters
+     ----------
+     llm_client : LLMClient
+         Client used to invoke the LLM for deduplication and attribute extraction.
+     extracted_edge : EntityEdge
+         Newly extracted edge whose canonical representation is being resolved.
+     related_edges : list[EntityEdge]
+         Candidate edges with identical endpoints used for duplicate detection.
+     existing_edges : list[EntityEdge]
+         Broader set of edges evaluated for contradiction / invalidation.
+     episode : EpisodicNode
+         Episode providing content context when extracting edge attributes.
+     edge_type_candidates : dict[str, type[BaseModel]] | None
+         Custom edge types permitted for the current source/target signature.
+     custom_edge_type_names : set[str] | None
+         Full catalog of registered custom edge names. Used to distinguish
+         between disallowed custom types (which fall back to the default label)
+         and ad-hoc labels emitted by the LLM.
+     ensure_ascii : bool
+         Whether prompt payloads should coerce ASCII output.
+
+     Returns
+     -------
+     tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
+         The resolved edge, any duplicates, and edges to invalidate.
+     """
      if len(related_edges) == 0 and len(existing_edges) == 0:
          return extracted_edge, [], []

+     # Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
+     normalized_fact = _normalize_string_exact(extracted_edge.fact)
+     for edge in related_edges:
+         if (
+             edge.source_node_uuid == extracted_edge.source_node_uuid
+             and edge.target_node_uuid == extracted_edge.target_node_uuid
+             and _normalize_string_exact(edge.fact) == normalized_fact
+         ):
+             resolved = edge
+             if episode is not None and episode.uuid not in resolved.episodes:
+                 resolved.episodes.append(episode.uuid)
+             return resolved, [], []
+
      start = time()

      # Prepare context for LLM
      related_edges_context = [
-         {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges)
+         {'id': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)
      ]

      invalidation_edge_candidates_context = [
@@ -415,9 +495,9 @@ async def resolve_extracted_edge(
              'fact_type_name': type_name,
              'fact_type_description': type_model.__doc__,
          }
-         for i, (type_name, type_model) in enumerate(edge_types.items())
+         for i, (type_name, type_model) in enumerate(edge_type_candidates.items())
      ]
-     if edge_types is not None
+     if edge_type_candidates is not None
      else []
  )
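For reference, that comprehension yields one entry per candidate type. A hypothetical example with a stand-in Pydantic model (the `WorksAt` class and `WORKS_AT` key are invented for illustration):

```python
from pydantic import BaseModel

class WorksAt(BaseModel):
    """Employment relationship between a person and an organization."""

edge_type_candidates = {'WORKS_AT': WorksAt}
edge_types_context = [
    {
        'fact_type_id': i,
        'fact_type_name': type_name,
        'fact_type_description': type_model.__doc__,
    }
    for i, (type_name, type_model) in enumerate(edge_type_candidates.items())
]
# -> [{'fact_type_id': 0, 'fact_type_name': 'WORKS_AT',
#      'fact_type_description': 'Employment relationship between a person and an organization.'}]
```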
@@ -454,7 +534,16 @@ async def resolve_extracted_edge(
      ]

      fact_type: str = response_object.fact_type
-     if fact_type.upper() != 'DEFAULT' and edge_types is not None:
+     candidate_type_names = set(edge_type_candidates or {})
+     custom_type_names = custom_edge_type_names or set()
+
+     is_default_type = fact_type.upper() == 'DEFAULT'
+     is_custom_type = fact_type in custom_type_names
+     is_allowed_custom_type = fact_type in candidate_type_names
+
+     if is_allowed_custom_type:
+         # The LLM selected a custom type that is allowed for the node pair.
+         # Adopt the custom type and, if needed, extract its structured attributes.
          resolved_edge.name = fact_type

          edge_attributes_context = {
@@ -464,7 +553,7 @@ async def resolve_extracted_edge(
              'ensure_ascii': ensure_ascii,
          }

-         edge_model = edge_types.get(fact_type)
+         edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
          if edge_model is not None and len(edge_model.model_fields) != 0:
              edge_attributes_response = await llm_client.generate_response(
                  prompt_library.extract_edges.extract_attributes(edge_attributes_context),
@@ -473,6 +562,16 @@ async def resolve_extracted_edge(
              )

              resolved_edge.attributes = edge_attributes_response
+     elif not is_default_type and is_custom_type:
+         # The LLM picked a custom type that is not allowed for this signature.
+         # Reset to the default label and drop any structured attributes.
+         resolved_edge.name = DEFAULT_EDGE_NAME
+         resolved_edge.attributes = {}
+     elif not is_default_type:
+         # Non-custom labels are allowed to pass through so long as the LLM does
+         # not return the sentinel DEFAULT value.
+         resolved_edge.name = fact_type
+         resolved_edge.attributes = {}

      end = time()
      logger.debug(
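Taken together, the branches above reduce to a four-way rule on `fact_type`. The function below is an illustrative restatement (assuming the default label `'RELATES_TO'`), not the library API:

```python
def resolve_name(current: str, fact_type: str,
                 candidates: set[str], custom: set[str]) -> str:
    if fact_type in candidates:
        return fact_type      # allowed custom type: adopt it (attributes may be extracted)
    if fact_type.upper() == 'DEFAULT':
        return current        # sentinel DEFAULT: leave the edge name untouched
    if fact_type in custom:
        return 'RELATES_TO'   # disallowed custom type: reset, attributes dropped
    return fact_type          # any other ad-hoc label passes through, attributes dropped
```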