graphiti-core 0.21.0rc5__py3-none-any.whl → 0.21.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



@@ -0,0 +1,262 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ from __future__ import annotations
+
+ import math
+ import re
+ from collections import defaultdict
+ from collections.abc import Iterable
+ from dataclasses import dataclass, field
+ from functools import lru_cache
+ from hashlib import blake2b
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from graphiti_core.nodes import EntityNode
+
+ _NAME_ENTROPY_THRESHOLD = 1.5
+ _MIN_NAME_LENGTH = 6
+ _MIN_TOKEN_COUNT = 2
+ _FUZZY_JACCARD_THRESHOLD = 0.9
+ _MINHASH_PERMUTATIONS = 32
+ _MINHASH_BAND_SIZE = 4
+
+
+ def _normalize_string_exact(name: str) -> str:
+     """Lowercase text and collapse whitespace so equal names map to the same key."""
+     normalized = re.sub(r'[\s]+', ' ', name.lower())
+     return normalized.strip()
+
+
+ def _normalize_name_for_fuzzy(name: str) -> str:
+     """Produce a fuzzier form that keeps alphanumerics and apostrophes for n-gram shingles."""
+     normalized = re.sub(r"[^a-z0-9' ]", ' ', _normalize_string_exact(name))
+     normalized = normalized.strip()
+     return re.sub(r'[\s]+', ' ', normalized)
+
+
+ def _name_entropy(normalized_name: str) -> float:
+     """Approximate text specificity using Shannon entropy over characters.
+
+     We strip spaces, count how often each character appears, and sum
+     probability * -log2(probability). Short or repetitive names yield low
+     entropy, which signals we should defer resolution to the LLM instead of
+     trusting fuzzy similarity.
+     """
+     if not normalized_name:
+         return 0.0
+
+     counts: dict[str, int] = {}
+     for char in normalized_name.replace(' ', ''):
+         counts[char] = counts.get(char, 0) + 1
+
+     total = sum(counts.values())
+     if total == 0:
+         return 0.0
+
+     entropy = 0.0
+     for count in counts.values():
+         probability = count / total
+         entropy -= probability * math.log2(probability)
+
+     return entropy
+
+
+ def _has_high_entropy(normalized_name: str) -> bool:
+     """Filter out very short or low-entropy names that are unreliable for fuzzy matching."""
+     token_count = len(normalized_name.split())
+     if len(normalized_name) < _MIN_NAME_LENGTH and token_count < _MIN_TOKEN_COUNT:
+         return False
+
+     return _name_entropy(normalized_name) >= _NAME_ENTROPY_THRESHOLD
+
+
+ def _shingles(normalized_name: str) -> set[str]:
+     """Create 3-gram shingles from the normalized name for MinHash calculations."""
+     cleaned = normalized_name.replace(' ', '')
+     if len(cleaned) < 2:
+         return {cleaned} if cleaned else set()
+
+     return {cleaned[i : i + 3] for i in range(len(cleaned) - 2)}
+
+
+ def _hash_shingle(shingle: str, seed: int) -> int:
+     """Generate a deterministic 64-bit hash for a shingle given the permutation seed."""
+     digest = blake2b(f'{seed}:{shingle}'.encode(), digest_size=8)
+     return int.from_bytes(digest.digest(), 'big')
+
+
+ def _minhash_signature(shingles: Iterable[str]) -> tuple[int, ...]:
+     """Compute the MinHash signature for the shingle set across predefined permutations."""
+     if not shingles:
+         return tuple()
+
+     seeds = range(_MINHASH_PERMUTATIONS)
+     signature: list[int] = []
+     for seed in seeds:
+         min_hash = min(_hash_shingle(shingle, seed) for shingle in shingles)
+         signature.append(min_hash)
+
+     return tuple(signature)
+
+
+ def _lsh_bands(signature: Iterable[int]) -> list[tuple[int, ...]]:
+     """Split the MinHash signature into fixed-size bands for locality-sensitive hashing."""
+     signature_list = list(signature)
+     if not signature_list:
+         return []
+
+     bands: list[tuple[int, ...]] = []
+     for start in range(0, len(signature_list), _MINHASH_BAND_SIZE):
+         band = tuple(signature_list[start : start + _MINHASH_BAND_SIZE])
+         if len(band) == _MINHASH_BAND_SIZE:
+             bands.append(band)
+     return bands
+
+
+ def _jaccard_similarity(a: set[str], b: set[str]) -> float:
+     """Return the Jaccard similarity between two shingle sets, handling empty edge cases."""
+     if not a and not b:
+         return 1.0
+     if not a or not b:
+         return 0.0
+
+     intersection = len(a.intersection(b))
+     union = len(a.union(b))
+     return intersection / union if union else 0.0
+
+
+ @lru_cache(maxsize=512)
+ def _cached_shingles(name: str) -> set[str]:
+     """Cache shingle sets per normalized name to avoid recomputation within a worker."""
+     return _shingles(name)
+
+
+ @dataclass
+ class DedupCandidateIndexes:
+     """Precomputed lookup structures that drive entity deduplication heuristics."""
+
+     existing_nodes: list[EntityNode]
+     nodes_by_uuid: dict[str, EntityNode]
+     normalized_existing: defaultdict[str, list[EntityNode]]
+     shingles_by_candidate: dict[str, set[str]]
+     lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]]
+
+
+ @dataclass
+ class DedupResolutionState:
+     """Mutable resolution bookkeeping shared across deterministic and LLM passes."""
+
+     resolved_nodes: list[EntityNode | None]
+     uuid_map: dict[str, str]
+     unresolved_indices: list[int]
+     duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)
+
+
+ def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
+     """Precompute exact and fuzzy lookup structures once per dedupe run."""
+     normalized_existing: defaultdict[str, list[EntityNode]] = defaultdict(list)
+     nodes_by_uuid: dict[str, EntityNode] = {}
+     shingles_by_candidate: dict[str, set[str]] = {}
+     lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]] = defaultdict(list)
+
+     for candidate in existing_nodes:
+         normalized = _normalize_string_exact(candidate.name)
+         normalized_existing[normalized].append(candidate)
+         nodes_by_uuid[candidate.uuid] = candidate
+
+         shingles = _cached_shingles(_normalize_name_for_fuzzy(candidate.name))
+         shingles_by_candidate[candidate.uuid] = shingles
+
+         signature = _minhash_signature(shingles)
+         for band_index, band in enumerate(_lsh_bands(signature)):
+             lsh_buckets[(band_index, band)].append(candidate.uuid)
+
+     return DedupCandidateIndexes(
+         existing_nodes=existing_nodes,
+         nodes_by_uuid=nodes_by_uuid,
+         normalized_existing=normalized_existing,
+         shingles_by_candidate=shingles_by_candidate,
+         lsh_buckets=lsh_buckets,
+     )
+
+
+ def _resolve_with_similarity(
+     extracted_nodes: list[EntityNode],
+     indexes: DedupCandidateIndexes,
+     state: DedupResolutionState,
+ ) -> None:
+     """Attempt deterministic resolution using exact name hits and fuzzy MinHash comparisons."""
+     for idx, node in enumerate(extracted_nodes):
+         normalized_exact = _normalize_string_exact(node.name)
+         normalized_fuzzy = _normalize_name_for_fuzzy(node.name)
+
+         if not _has_high_entropy(normalized_fuzzy):
+             state.unresolved_indices.append(idx)
+             continue
+
+         existing_matches = indexes.normalized_existing.get(normalized_exact, [])
+         if len(existing_matches) == 1:
+             match = existing_matches[0]
+             state.resolved_nodes[idx] = match
+             state.uuid_map[node.uuid] = match.uuid
+             if match.uuid != node.uuid:
+                 state.duplicate_pairs.append((node, match))
+             continue
+         if len(existing_matches) > 1:
+             state.unresolved_indices.append(idx)
+             continue
+
+         shingles = _cached_shingles(normalized_fuzzy)
+         signature = _minhash_signature(shingles)
+         candidate_ids: set[str] = set()
+         for band_index, band in enumerate(_lsh_bands(signature)):
+             candidate_ids.update(indexes.lsh_buckets.get((band_index, band), []))
+
+         best_candidate: EntityNode | None = None
+         best_score = 0.0
+         for candidate_id in candidate_ids:
+             candidate_shingles = indexes.shingles_by_candidate.get(candidate_id, set())
+             score = _jaccard_similarity(shingles, candidate_shingles)
+             if score > best_score:
+                 best_score = score
+                 best_candidate = indexes.nodes_by_uuid.get(candidate_id)
+
+         if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
+             state.resolved_nodes[idx] = best_candidate
+             state.uuid_map[node.uuid] = best_candidate.uuid
+             if best_candidate.uuid != node.uuid:
+                 state.duplicate_pairs.append((node, best_candidate))
+             continue
+
+         state.unresolved_indices.append(idx)
+
+
+ __all__ = [
+     'DedupCandidateIndexes',
+     'DedupResolutionState',
+     '_normalize_string_exact',
+     '_normalize_name_for_fuzzy',
+     '_has_high_entropy',
+     '_minhash_signature',
+     '_lsh_bands',
+     '_jaccard_similarity',
+     '_cached_shingles',
+     '_FUZZY_JACCARD_THRESHOLD',
+     '_build_candidate_indexes',
+     '_resolve_with_similarity',
+ ]
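
The new module implements a classic MinHash/LSH pipeline: normalize the name, gate on entropy, shingle into character 3-grams, hash through 32 permutations, band the signature for candidate lookup, then confirm with an exact Jaccard comparison. The sketch below walks those helpers end to end; it assumes this release is installed (the module path matches the import added later in this diff), and the entity names are invented for illustration.

from graphiti_core.utils.maintenance.dedup_helpers import (
    _FUZZY_JACCARD_THRESHOLD,
    _cached_shingles,
    _has_high_entropy,
    _jaccard_similarity,
    _lsh_bands,
    _minhash_signature,
    _normalize_name_for_fuzzy,
)

left = _normalize_name_for_fuzzy('Acme Corporation')
right = _normalize_name_for_fuzzy('ACME Corporation Inc.')

# Short or low-entropy names fail this gate and are deferred to the LLM pass.
assert _has_high_entropy(left)

# 32-hash signatures split into 4-hash bands: 8 buckets per candidate, and any
# shared band makes a candidate eligible for the exact comparison below.
signature = _minhash_signature(_cached_shingles(left))
assert len(signature) == 32 and len(_lsh_bands(signature)) == 8

# The final merge decision uses true Jaccard similarity against the 0.9 bar.
score = _jaccard_similarity(_cached_shingles(left), _cached_shingles(right))
print(round(score, 2), score >= _FUZZY_JACCARD_THRESHOLD)
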
@@ -41,6 +41,9 @@ from graphiti_core.search.search_config import SearchResults
  from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
  from graphiti_core.search.search_filters import SearchFilters
  from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
+ from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
+
+ DEFAULT_EDGE_NAME = 'RELATES_TO'
 
  logger = logging.getLogger(__name__)
 
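
Both files now share one definition of exact-name normalization. A quick check of what `_normalize_string_exact` guarantees, assuming the package is importable (the sample string is arbitrary): lowercasing plus whitespace collapse, so visually identical strings map to the same key.

from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact

assert _normalize_string_exact('  Alice\t WORKS   at ACME ') == 'alice works at acme'
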
@@ -65,32 +68,6 @@ def build_episodic_edges(
      return episodic_edges
 
 
- def build_duplicate_of_edges(
-     episode: EpisodicNode,
-     created_at: datetime,
-     duplicate_nodes: list[tuple[EntityNode, EntityNode]],
- ) -> list[EntityEdge]:
-     is_duplicate_of_edges: list[EntityEdge] = []
-     for source_node, target_node in duplicate_nodes:
-         if source_node.uuid == target_node.uuid:
-             continue
-
-         is_duplicate_of_edges.append(
-             EntityEdge(
-                 source_node_uuid=source_node.uuid,
-                 target_node_uuid=target_node.uuid,
-                 name='IS_DUPLICATE_OF',
-                 group_id=episode.group_id,
-                 fact=f'{source_node.name} is a duplicate of {target_node.name}',
-                 episodes=[episode.uuid],
-                 created_at=created_at,
-                 valid_at=created_at,
-             )
-         )
-
-     return is_duplicate_of_edges
-
-
  def build_community_edges(
      entity_nodes: list[EntityNode],
      community_node: CommunityNode,
@@ -306,8 +283,12 @@ async def resolve_extracted_edges(
      # Build entity hash table
      uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
 
-     # Determine which edge types are relevant for each edge
+     # Determine which edge types are relevant for each edge.
+     # `edge_types_lst` stores the subset of custom edge definitions whose
+     # node signature matches each extracted edge. Anything outside this subset
+     # should only stay on the edge if it is a non-custom (LLM generated) label.
      edge_types_lst: list[dict[str, type[BaseModel]]] = []
+     custom_type_names = set(edge_types or {})
      for extracted_edge in extracted_edges:
          source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
          target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
@@ -335,6 +316,20 @@
 
          edge_types_lst.append(extracted_edge_types)
 
+     for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
+         allowed_type_names = set(extracted_edge_types)
+         is_custom_name = extracted_edge.name in custom_type_names
+         if not allowed_type_names:
+             # No custom types are valid for this node pairing. Keep LLM generated
+             # labels, but flip disallowed custom names back to the default.
+             if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
+                 extracted_edge.name = DEFAULT_EDGE_NAME
+             continue
+         if is_custom_name and extracted_edge.name not in allowed_type_names:
+             # Custom name exists but it is not permitted for this source/target
+             # signature, so fall back to the default edge label.
+             extracted_edge.name = DEFAULT_EDGE_NAME
+
      # resolve edges with related edges in the graph and find invalidation candidates
      results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
          await semaphore_gather(
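
This post-filtering pass reduces to a single rule: a label that belongs to the registered custom-type catalog but is not allowed for the edge's endpoint signature is flipped back to DEFAULT_EDGE_NAME, while ad-hoc labels the LLM invented pass through untouched. A toy restatement (the sample type names are invented, not part of graphiti-core):

DEFAULT_EDGE_NAME = 'RELATES_TO'
custom_type_names = {'WORKS_AT', 'FOUNDED'}   # everything registered by the caller
allowed_type_names = {'WORKS_AT'}             # subset valid for this node pair

def filter_edge_name(name: str) -> str:
    # Disallowed custom types fall back to the default relationship label;
    # anything the LLM made up on the fly is left alone.
    if name in custom_type_names and name not in allowed_type_names:
        return DEFAULT_EDGE_NAME
    return name

assert filter_edge_name('WORKS_AT') == 'WORKS_AT'   # allowed custom type
assert filter_edge_name('FOUNDED') == 'RELATES_TO'  # registered, but not for this pair
assert filter_edge_name('MENTIONS') == 'MENTIONS'   # ad-hoc LLM label survives
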
@@ -346,6 +341,7 @@
                  existing_edges,
                  episode,
                  extracted_edge_types,
+                 custom_type_names,
                  clients.ensure_ascii,
              )
              for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
@@ -417,12 +413,54 @@ async def resolve_extracted_edge(
      related_edges: list[EntityEdge],
      existing_edges: list[EntityEdge],
      episode: EpisodicNode,
-     edge_types: dict[str, type[BaseModel]] | None = None,
+     edge_type_candidates: dict[str, type[BaseModel]] | None = None,
+     custom_edge_type_names: set[str] | None = None,
      ensure_ascii: bool = True,
  ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
+     """Resolve an extracted edge against existing graph context.
+
+     Parameters
+     ----------
+     llm_client : LLMClient
+         Client used to invoke the LLM for deduplication and attribute extraction.
+     extracted_edge : EntityEdge
+         Newly extracted edge whose canonical representation is being resolved.
+     related_edges : list[EntityEdge]
+         Candidate edges with identical endpoints used for duplicate detection.
+     existing_edges : list[EntityEdge]
+         Broader set of edges evaluated for contradiction / invalidation.
+     episode : EpisodicNode
+         Episode providing content context when extracting edge attributes.
+     edge_type_candidates : dict[str, type[BaseModel]] | None
+         Custom edge types permitted for the current source/target signature.
+     custom_edge_type_names : set[str] | None
+         Full catalog of registered custom edge names. Used to distinguish
+         between disallowed custom types (which fall back to the default label)
+         and ad-hoc labels emitted by the LLM.
+     ensure_ascii : bool
+         Whether prompt payloads should coerce ASCII output.
+
+     Returns
+     -------
+     tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
+         The resolved edge, any duplicates, and edges to invalidate.
+     """
      if len(related_edges) == 0 and len(existing_edges) == 0:
          return extracted_edge, [], []
 
+     # Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
+     normalized_fact = _normalize_string_exact(extracted_edge.fact)
+     for edge in related_edges:
+         if (
+             edge.source_node_uuid == extracted_edge.source_node_uuid
+             and edge.target_node_uuid == extracted_edge.target_node_uuid
+             and _normalize_string_exact(edge.fact) == normalized_fact
+         ):
+             resolved = edge
+             if episode is not None and episode.uuid not in resolved.episodes:
+                 resolved.episodes.append(episode.uuid)
+             return resolved, [], []
+
      start = time()
 
      # Prepare context for LLM
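
The new fast path skips the LLM entirely when an extracted fact restates a stored edge verbatim after normalization. A self-contained sketch of the same check, using a stand-in dataclass and invented uuids rather than the real EntityEdge model:

from dataclasses import dataclass, field

from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact

@dataclass
class FakeEdge:  # stand-in for EntityEdge, just the fields the fast path reads
    source_node_uuid: str
    target_node_uuid: str
    fact: str
    episodes: list[str] = field(default_factory=list)

stored = FakeEdge('s-1', 't-1', 'Alice   works at ACME', episodes=['ep-1'])
extracted = FakeEdge('s-1', 't-1', 'alice works at acme')

# Same endpoints plus the same normalized fact text: reuse the stored edge and
# just record the new episode, with no LLM dedupe call.
if (
    stored.source_node_uuid == extracted.source_node_uuid
    and stored.target_node_uuid == extracted.target_node_uuid
    and _normalize_string_exact(stored.fact) == _normalize_string_exact(extracted.fact)
):
    if 'ep-2' not in stored.episodes:
        stored.episodes.append('ep-2')

print(stored.episodes)  # ['ep-1', 'ep-2']
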
@@ -441,9 +479,9 @@
                  'fact_type_name': type_name,
                  'fact_type_description': type_model.__doc__,
              }
-             for i, (type_name, type_model) in enumerate(edge_types.items())
+             for i, (type_name, type_model) in enumerate(edge_type_candidates.items())
          ]
-         if edge_types is not None
+         if edge_type_candidates is not None
          else []
      )
 
@@ -480,7 +518,16 @@
      ]
 
      fact_type: str = response_object.fact_type
-     if fact_type.upper() != 'DEFAULT' and edge_types is not None:
+     candidate_type_names = set(edge_type_candidates or {})
+     custom_type_names = custom_edge_type_names or set()
+
+     is_default_type = fact_type.upper() == 'DEFAULT'
+     is_custom_type = fact_type in custom_type_names
+     is_allowed_custom_type = fact_type in candidate_type_names
+
+     if is_allowed_custom_type:
+         # The LLM selected a custom type that is allowed for the node pair.
+         # Adopt the custom type and, if needed, extract its structured attributes.
          resolved_edge.name = fact_type
 
          edge_attributes_context = {
@@ -490,7 +537,7 @@
              'ensure_ascii': ensure_ascii,
          }
 
-         edge_model = edge_types.get(fact_type)
+         edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
          if edge_model is not None and len(edge_model.model_fields) != 0:
              edge_attributes_response = await llm_client.generate_response(
                  prompt_library.extract_edges.extract_attributes(edge_attributes_context),
@@ -499,6 +546,16 @@
              )
 
              resolved_edge.attributes = edge_attributes_response
+     elif not is_default_type and is_custom_type:
+         # The LLM picked a custom type that is not allowed for this signature.
+         # Reset to the default label and drop any structured attributes.
+         resolved_edge.name = DEFAULT_EDGE_NAME
+         resolved_edge.attributes = {}
+     elif not is_default_type:
+         # Non-custom labels are allowed to pass through so long as the LLM does
+         # not return the sentinel DEFAULT value.
+         resolved_edge.name = fact_type
+         resolved_edge.attributes = {}
 
      end = time()
      logger.debug(
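
Together with the is_allowed_custom_type branch above, the fact_type handling is a four-way decision. A compact restatement as a toy function (the helper name and sample types are invented; None means "leave resolved_edge.name as it is", which is what happens when the LLM returns the DEFAULT sentinel):

DEFAULT_EDGE_NAME = 'RELATES_TO'

def resolve_edge_name(fact_type: str, allowed: set[str], registered: set[str]) -> str | None:
    if fact_type in allowed:
        return fact_type          # allowed custom type: adopt it (and extract attributes)
    if fact_type.upper() == 'DEFAULT':
        return None               # sentinel: keep the existing edge name
    if fact_type in registered:
        return DEFAULT_EDGE_NAME  # custom type not valid for this node pair: reset
    return fact_type              # ad-hoc LLM label: passes through, attributes dropped

assert resolve_edge_name('WORKS_AT', {'WORKS_AT'}, {'WORKS_AT', 'FOUNDED'}) == 'WORKS_AT'
assert resolve_edge_name('FOUNDED', {'WORKS_AT'}, {'WORKS_AT', 'FOUNDED'}) == 'RELATES_TO'
assert resolve_edge_name('DEFAULT', {'WORKS_AT'}, {'WORKS_AT', 'FOUNDED'}) is None
assert resolve_edge_name('MENTIONS', set(), {'FOUNDED'}) == 'MENTIONS'
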