graphiti-core 0.20.4__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Files changed (39)
  1. graphiti_core/driver/driver.py +28 -0
  2. graphiti_core/driver/falkordb_driver.py +112 -0
  3. graphiti_core/driver/kuzu_driver.py +1 -0
  4. graphiti_core/driver/neo4j_driver.py +10 -2
  5. graphiti_core/driver/neptune_driver.py +4 -6
  6. graphiti_core/edges.py +67 -7
  7. graphiti_core/embedder/client.py +2 -1
  8. graphiti_core/graph_queries.py +35 -6
  9. graphiti_core/graphiti.py +27 -23
  10. graphiti_core/graphiti_types.py +0 -1
  11. graphiti_core/helpers.py +2 -2
  12. graphiti_core/llm_client/client.py +19 -4
  13. graphiti_core/llm_client/gemini_client.py +4 -2
  14. graphiti_core/llm_client/openai_base_client.py +3 -2
  15. graphiti_core/llm_client/openai_generic_client.py +3 -2
  16. graphiti_core/models/edges/edge_db_queries.py +36 -16
  17. graphiti_core/models/nodes/node_db_queries.py +30 -10
  18. graphiti_core/nodes.py +126 -25
  19. graphiti_core/prompts/dedupe_edges.py +40 -29
  20. graphiti_core/prompts/dedupe_nodes.py +51 -34
  21. graphiti_core/prompts/eval.py +3 -3
  22. graphiti_core/prompts/extract_edges.py +17 -9
  23. graphiti_core/prompts/extract_nodes.py +10 -9
  24. graphiti_core/prompts/prompt_helpers.py +3 -3
  25. graphiti_core/prompts/summarize_nodes.py +5 -5
  26. graphiti_core/search/search_filters.py +53 -0
  27. graphiti_core/search/search_helpers.py +5 -7
  28. graphiti_core/search/search_utils.py +227 -57
  29. graphiti_core/utils/bulk_utils.py +168 -69
  30. graphiti_core/utils/maintenance/community_operations.py +8 -20
  31. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  32. graphiti_core/utils/maintenance/edge_operations.py +187 -50
  33. graphiti_core/utils/maintenance/graph_data_operations.py +9 -5
  34. graphiti_core/utils/maintenance/node_operations.py +244 -88
  35. graphiti_core/utils/maintenance/temporal_operations.py +0 -4
  36. {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/METADATA +7 -1
  37. {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/RECORD +39 -38
  38. {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/WHEEL +0 -0
  39. {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/utils/bulk_utils.py

@@ -23,7 +23,14 @@ import numpy as np
 from pydantic import BaseModel, Field
 from typing_extensions import Any

-from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
+from graphiti_core.driver.driver import (
+    ENTITY_EDGE_INDEX_NAME,
+    ENTITY_INDEX_NAME,
+    EPISODE_INDEX_NAME,
+    GraphDriver,
+    GraphDriverSession,
+    GraphProvider,
+)
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.graphiti_types import GraphitiClients
@@ -36,8 +43,14 @@ from graphiti_core.models.nodes.node_db_queries import (
     get_entity_node_save_bulk_query,
     get_episode_node_save_bulk_query,
 )
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
+from graphiti_core.utils.maintenance.dedup_helpers import (
+    DedupResolutionState,
+    _build_candidate_indexes,
+    _normalize_string_exact,
+    _resolve_with_similarity,
+)
 from graphiti_core.utils.maintenance.edge_operations import (
     extract_edges,
     resolve_extracted_edge,
@@ -56,6 +69,38 @@ logger = logging.getLogger(__name__)
 CHUNK_SIZE = 10


+def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
+    """Collapse alias -> canonical chains while preserving direction.
+
+    The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
+    union-find with iterative path compression to ensure every source UUID resolves to its ultimate
+    canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
+    """
+
+    parent: dict[str, str] = {}
+
+    def find(uuid: str) -> str:
+        """Directed union-find lookup using iterative path compression."""
+        parent.setdefault(uuid, uuid)
+        root = uuid
+        while parent[root] != root:
+            root = parent[root]
+
+        while parent[uuid] != root:
+            next_uuid = parent[uuid]
+            parent[uuid] = root
+            uuid = next_uuid
+
+        return root
+
+    for source_uuid, target_uuid in pairs:
+        parent.setdefault(source_uuid, source_uuid)
+        parent.setdefault(target_uuid, target_uuid)
+        parent[find(source_uuid)] = find(target_uuid)
+
+    return {uuid: find(uuid) for uuid in parent}
+
+
 class RawEpisode(BaseModel):
     name: str
     uuid: str | None = Field(default=None)
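
A quick sanity check of the union-find helper above (hypothetical UUIDs, traced by hand from the function as written; not part of the diff):

# Chained aliases collapse onto the final canonical target, preserving direction.
pairs = [('uuid-a', 'uuid-b'), ('uuid-b', 'uuid-c')]
assert _build_directed_uuid_map(pairs) == {
    'uuid-a': 'uuid-c',
    'uuid-b': 'uuid-c',
    'uuid-c': 'uuid-c',
}
# An empty input yields an empty map.
assert _build_directed_uuid_map([]) == {}
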
@@ -129,12 +174,14 @@ async def add_nodes_and_edges_bulk_tx(
         entity_data: dict[str, Any] = {
             'uuid': node.uuid,
             'name': node.name,
-            'name_embedding': node.name_embedding,
             'group_id': node.group_id,
             'summary': node.summary,
             'created_at': node.created_at,
         }

+        if not bool(driver.aoss_client):
+            entity_data['name_embedding'] = node.name_embedding
+
         entity_data['labels'] = list(set(node.labels + ['Entity']))
         if driver.provider == GraphProvider.KUZU:
             attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
@@ -154,7 +201,6 @@ async def add_nodes_and_edges_bulk_tx(
             'target_node_uuid': edge.target_node_uuid,
             'name': edge.name,
             'fact': edge.fact,
-            'fact_embedding': edge.fact_embedding,
             'group_id': edge.group_id,
             'episodes': edge.episodes,
             'created_at': edge.created_at,
@@ -163,6 +209,9 @@
             'invalid_at': edge.invalid_at,
         }

+        if not bool(driver.aoss_client):
+            edge_data['fact_embedding'] = edge.fact_embedding
+
         if driver.provider == GraphProvider.KUZU:
             attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
             edge_data['attributes'] = json.dumps(attributes)
@@ -187,12 +236,33 @@
             await tx.run(episodic_edge_query, **edge.model_dump())
     else:
         await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
-        await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
+        await tx.run(
+            get_entity_node_save_bulk_query(
+                driver.provider, nodes, has_aoss=bool(driver.aoss_client)
+            ),
+            nodes=nodes,
+        )
         await tx.run(
             get_episodic_edge_save_bulk_query(driver.provider),
             episodic_edges=[edge.model_dump() for edge in episodic_edges],
         )
-        await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
+        await tx.run(
+            get_entity_edge_save_bulk_query(driver.provider, has_aoss=bool(driver.aoss_client)),
+            entity_edges=edges,
+        )
+
+    if bool(driver.aoss_client):
+        for node_data, entity_node in zip(nodes, entity_nodes, strict=True):
+            if node_data.get('uuid') == entity_node.uuid:
+                node_data['name_embedding'] = entity_node.name_embedding
+
+        for edge_data, entity_edge in zip(edges, entity_edges, strict=True):
+            if edge_data.get('uuid') == entity_edge.uuid:
+                edge_data['fact_embedding'] = entity_edge.fact_embedding
+
+        await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
+        await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
+        await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)


 async def extract_nodes_and_edges_bulk(
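
Note on the embedding backfill in the hunk above: it assumes nodes / entity_nodes (and edges / entity_edges) are parallel, same-length lists, and zip(..., strict=True) turns any drift into an immediate error rather than silently pairing an embedding with the wrong payload. A minimal sketch of that behavior (plain Python 3.10+, hypothetical data, not from the diff):

payloads = [{'uuid': 'n1'}, {'uuid': 'n2'}]
entities = ['node-n1']  # hypothetical shorter list

try:
    for payload, entity in zip(payloads, entities, strict=True):
        pass  # pair embeddings here
except ValueError:
    # strict=True raises once the shorter iterable is exhausted (Python 3.10+)
    print('payloads and entities are out of sync')
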
@@ -234,83 +304,111 @@ async def dedupe_nodes_bulk(
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
-    embedder = clients.embedder
-    min_score = 0.8
+    """Resolve entity duplicates across an in-memory batch using a two-pass strategy.

-    # generate embeddings
-    await semaphore_gather(
-        *[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
-    )
-
-    # Find similar results
-    dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
-    for i, nodes_i in enumerate(extracted_nodes):
-        existing_nodes: list[EntityNode] = []
-        for j, nodes_j in enumerate(extracted_nodes):
-            if i == j:
-                continue
-            existing_nodes += nodes_j
-
-        candidates_i: list[EntityNode] = []
-        for node in nodes_i:
-            for existing_node in existing_nodes:
-                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
-                # This approach will cast a wider net than BM25, which is ideal for this use case
-                node_words = set(node.name.lower().split())
-                existing_node_words = set(existing_node.name.lower().split())
-                has_overlap = not node_words.isdisjoint(existing_node_words)
-                if has_overlap:
-                    candidates_i.append(existing_node)
-                    continue
-
-                # Check for semantic similarity even if there is no overlap
-                similarity = np.dot(
-                    normalize_l2(node.name_embedding or []),
-                    normalize_l2(existing_node.name_embedding or []),
-                )
-                if similarity >= min_score:
-                    candidates_i.append(existing_node)
-
-        dedupe_tuples.append((nodes_i, candidates_i))
+    1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
+       reconciled against the live graph just like the non-batch flow.
+    2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
+       duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
+       can apply to edges and persistence.
+    """

-    # Determine Node Resolutions
-    bulk_node_resolutions: list[
-        tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
-    ] = await semaphore_gather(
+    first_pass_results = await semaphore_gather(
         *[
             resolve_extracted_nodes(
                 clients,
-                dedupe_tuple[0],
+                nodes,
                 episode_tuples[i][0],
                 episode_tuples[i][1],
                 entity_types,
-                existing_nodes_override=dedupe_tuples[i][1],
             )
-            for i, dedupe_tuple in enumerate(dedupe_tuples)
+            for i, nodes in enumerate(extracted_nodes)
         ]
     )

-    # Collect all duplicate pairs sorted by uuid
+    episode_resolutions: list[tuple[str, list[EntityNode]]] = []
+    per_episode_uuid_maps: list[dict[str, str]] = []
     duplicate_pairs: list[tuple[str, str]] = []
-    for _, _, duplicates in bulk_node_resolutions:
-        for duplicate in duplicates:
-            n, m = duplicate
-            duplicate_pairs.append((n.uuid, m.uuid))

-    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
-    compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
+    for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
+        first_pass_results, episode_tuples, strict=True
+    ):
+        episode_resolutions.append((episode.uuid, resolved_nodes))
+        per_episode_uuid_maps.append(uuid_map)
+        duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
+
+    canonical_nodes: dict[str, EntityNode] = {}
+    for _, resolved_nodes in episode_resolutions:
+        for node in resolved_nodes:
+            # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
+            # the MinHash index for the accumulated canonical pool each time. The LRU-backed
+            # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
+            # but if batches grow significantly we should switch to an incremental index or chunked
+            # processing.
+            if not canonical_nodes:
+                canonical_nodes[node.uuid] = node
+                continue

-    node_uuid_map: dict[str, EntityNode] = {
-        node.uuid: node for nodes in extracted_nodes for node in nodes
-    }
+            existing_candidates = list(canonical_nodes.values())
+            normalized = _normalize_string_exact(node.name)
+            exact_match = next(
+                (
+                    candidate
+                    for candidate in existing_candidates
+                    if _normalize_string_exact(candidate.name) == normalized
+                ),
+                None,
+            )
+            if exact_match is not None:
+                if exact_match.uuid != node.uuid:
+                    duplicate_pairs.append((node.uuid, exact_match.uuid))
+                continue
+
+            indexes = _build_candidate_indexes(existing_candidates)
+            state = DedupResolutionState(
+                resolved_nodes=[None],
+                uuid_map={},
+                unresolved_indices=[],
+            )
+            _resolve_with_similarity([node], indexes, state)
+
+            resolved = state.resolved_nodes[0]
+            if resolved is None:
+                canonical_nodes[node.uuid] = node
+                continue
+
+            canonical_uuid = resolved.uuid
+            canonical_nodes.setdefault(canonical_uuid, resolved)
+            if canonical_uuid != node.uuid:
+                duplicate_pairs.append((node.uuid, canonical_uuid))
+
+    union_pairs: list[tuple[str, str]] = []
+    for uuid_map in per_episode_uuid_maps:
+        union_pairs.extend(uuid_map.items())
+    union_pairs.extend(duplicate_pairs)
+
+    compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)

     nodes_by_episode: dict[str, list[EntityNode]] = {}
-    for i, nodes in enumerate(extracted_nodes):
-        episode = episode_tuples[i][0]
+    for episode_uuid, resolved_nodes in episode_resolutions:
+        deduped_nodes: list[EntityNode] = []
+        seen: set[str] = set()
+        for node in resolved_nodes:
+            canonical_uuid = compressed_map.get(node.uuid, node.uuid)
+            if canonical_uuid in seen:
+                continue
+            seen.add(canonical_uuid)
+            canonical_node = canonical_nodes.get(canonical_uuid)
+            if canonical_node is None:
+                logger.error(
+                    'Canonical node %s missing during batch dedupe; falling back to %s',
+                    canonical_uuid,
+                    node.uuid,
+                )
+                canonical_node = node
+            deduped_nodes.append(canonical_node)

-        nodes_by_episode[episode.uuid] = [
-            node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
-        ]
+        nodes_by_episode[episode_uuid] = deduped_nodes

     return nodes_by_episode, compressed_map

@@ -335,14 +433,15 @@
     dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
     for i, edges_i in enumerate(extracted_edges):
         existing_edges: list[EntityEdge] = []
-        for j, edges_j in enumerate(extracted_edges):
-            if i == j:
-                continue
+        for edges_j in extracted_edges:
             existing_edges += edges_j

         for edge in edges_i:
             candidates: list[EntityEdge] = []
             for existing_edge in existing_edges:
+                # Skip self-comparison
+                if edge.uuid == existing_edge.uuid:
+                    continue
                 # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
                 # This approach will cast a wider net than BM25, which is ideal for this use case
                 if (
@@ -379,7 +478,7 @@
                 candidates,
                 episode,
                 edge_types,
-                clients.ensure_ascii,
+                set(edge_types),
             )
             for episode, edge, candidates in dedupe_tuples
         ]
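
The compressed_map returned by dedupe_nodes_bulk (see the hunk above) maps every extracted node UUID to its canonical UUID. A hypothetical caller-side sketch of how that map could be applied to extracted edges before persistence (illustrative only; the real call sites live elsewhere in the bulk flow):

nodes_by_episode, uuid_map = await dedupe_nodes_bulk(clients, extracted_nodes, episode_tuples)
for edges in extracted_edges:
    for edge in edges:
        # Re-point both endpoints at the canonical node UUIDs.
        edge.source_node_uuid = uuid_map.get(edge.source_node_uuid, edge.source_node_uuid)
        edge.target_node_uuid = uuid_map.get(edge.target_node_uuid, edge.target_node_uuid)
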
graphiti_core/utils/maintenance/community_operations.py

@@ -131,13 +131,10 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
     return clusters


-async def summarize_pair(
-    llm_client: LLMClient, summary_pair: tuple[str, str], ensure_ascii: bool = True
-) -> str:
+async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
     # Prepare context for LLM
     context = {
         'node_summaries': [{'summary': summary} for summary in summary_pair],
-        'ensure_ascii': ensure_ascii,
     }

     llm_response = await llm_client.generate_response(
@@ -149,12 +146,9 @@ async def summarize_pair(
     return pair_summary


-async def generate_summary_description(
-    llm_client: LLMClient, summary: str, ensure_ascii: bool = True
-) -> str:
+async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
     context = {
         'summary': summary,
-        'ensure_ascii': ensure_ascii,
     }

     llm_response = await llm_client.generate_response(
@@ -168,7 +162,7 @@ async def generate_summary_description(


 async def build_community(
-    llm_client: LLMClient, community_cluster: list[EntityNode], ensure_ascii: bool = True
+    llm_client: LLMClient, community_cluster: list[EntityNode]
 ) -> tuple[CommunityNode, list[CommunityEdge]]:
     summaries = [entity.summary for entity in community_cluster]
     length = len(summaries)
@@ -180,9 +174,7 @@
         new_summaries: list[str] = list(
             await semaphore_gather(
                 *[
-                    summarize_pair(
-                        llm_client, (str(left_summary), str(right_summary)), ensure_ascii
-                    )
+                    summarize_pair(llm_client, (str(left_summary), str(right_summary)))
                     for left_summary, right_summary in zip(
                         summaries[: int(length / 2)], summaries[int(length / 2) :], strict=False
                     )
@@ -195,7 +187,7 @@
         length = len(summaries)

     summary = summaries[0]
-    name = await generate_summary_description(llm_client, summary, ensure_ascii)
+    name = await generate_summary_description(llm_client, summary)
     now = utc_now()
     community_node = CommunityNode(
         name=name,
@@ -215,7 +207,6 @@ async def build_communities(
     driver: GraphDriver,
     llm_client: LLMClient,
     group_ids: list[str] | None,
-    ensure_ascii: bool = True,
 ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
     community_clusters = await get_community_clusters(driver, group_ids)

@@ -223,7 +214,7 @@

     async def limited_build_community(cluster):
         async with semaphore:
-            return await build_community(llm_client, cluster, ensure_ascii)
+            return await build_community(llm_client, cluster)

     communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
         await semaphore_gather(
@@ -312,17 +303,14 @@ async def update_community(
     llm_client: LLMClient,
     embedder: EmbedderClient,
     entity: EntityNode,
-    ensure_ascii: bool = True,
 ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
     community, is_new = await determine_entity_community(driver, entity)

     if community is None:
         return [], []

-    new_summary = await summarize_pair(
-        llm_client, (entity.summary, community.summary), ensure_ascii
-    )
-    new_name = await generate_summary_description(llm_client, new_summary, ensure_ascii)
+    new_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
+    new_name = await generate_summary_description(llm_client, new_summary)

     community.summary = new_summary
     community.name = new_name
graphiti_core/utils/maintenance/dedup_helpers.py (new file)

@@ -0,0 +1,262 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from __future__ import annotations
+
+import math
+import re
+from collections import defaultdict
+from collections.abc import Iterable
+from dataclasses import dataclass, field
+from functools import lru_cache
+from hashlib import blake2b
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from graphiti_core.nodes import EntityNode
+
+_NAME_ENTROPY_THRESHOLD = 1.5
+_MIN_NAME_LENGTH = 6
+_MIN_TOKEN_COUNT = 2
+_FUZZY_JACCARD_THRESHOLD = 0.9
+_MINHASH_PERMUTATIONS = 32
+_MINHASH_BAND_SIZE = 4
+
+
+def _normalize_string_exact(name: str) -> str:
+    """Lowercase text and collapse whitespace so equal names map to the same key."""
+    normalized = re.sub(r'[\s]+', ' ', name.lower())
+    return normalized.strip()
+
+
+def _normalize_name_for_fuzzy(name: str) -> str:
+    """Produce a fuzzier form that keeps alphanumerics and apostrophes for n-gram shingles."""
+    normalized = re.sub(r"[^a-z0-9' ]", ' ', _normalize_string_exact(name))
+    normalized = normalized.strip()
+    return re.sub(r'[\s]+', ' ', normalized)
+
+
+def _name_entropy(normalized_name: str) -> float:
+    """Approximate text specificity using Shannon entropy over characters.
+
+    We strip spaces, count how often each character appears, and sum
+    probability * -log2(probability). Short or repetitive names yield low
+    entropy, which signals we should defer resolution to the LLM instead of
+    trusting fuzzy similarity.
+    """
+    if not normalized_name:
+        return 0.0
+
+    counts: dict[str, int] = {}
+    for char in normalized_name.replace(' ', ''):
+        counts[char] = counts.get(char, 0) + 1
+
+    total = sum(counts.values())
+    if total == 0:
+        return 0.0
+
+    entropy = 0.0
+    for count in counts.values():
+        probability = count / total
+        entropy -= probability * math.log2(probability)
+
+    return entropy
+
+
+def _has_high_entropy(normalized_name: str) -> bool:
+    """Filter out very short or low-entropy names that are unreliable for fuzzy matching."""
+    token_count = len(normalized_name.split())
+    if len(normalized_name) < _MIN_NAME_LENGTH and token_count < _MIN_TOKEN_COUNT:
+        return False
+
+    return _name_entropy(normalized_name) >= _NAME_ENTROPY_THRESHOLD
+
+
+def _shingles(normalized_name: str) -> set[str]:
+    """Create 3-gram shingles from the normalized name for MinHash calculations."""
+    cleaned = normalized_name.replace(' ', '')
+    if len(cleaned) < 2:
+        return {cleaned} if cleaned else set()
+
+    return {cleaned[i : i + 3] for i in range(len(cleaned) - 2)}
+
+
+def _hash_shingle(shingle: str, seed: int) -> int:
+    """Generate a deterministic 64-bit hash for a shingle given the permutation seed."""
+    digest = blake2b(f'{seed}:{shingle}'.encode(), digest_size=8)
+    return int.from_bytes(digest.digest(), 'big')
+
+
+def _minhash_signature(shingles: Iterable[str]) -> tuple[int, ...]:
+    """Compute the MinHash signature for the shingle set across predefined permutations."""
+    if not shingles:
+        return tuple()
+
+    seeds = range(_MINHASH_PERMUTATIONS)
+    signature: list[int] = []
+    for seed in seeds:
+        min_hash = min(_hash_shingle(shingle, seed) for shingle in shingles)
+        signature.append(min_hash)
+
+    return tuple(signature)
+
+
+def _lsh_bands(signature: Iterable[int]) -> list[tuple[int, ...]]:
+    """Split the MinHash signature into fixed-size bands for locality-sensitive hashing."""
+    signature_list = list(signature)
+    if not signature_list:
+        return []
+
+    bands: list[tuple[int, ...]] = []
+    for start in range(0, len(signature_list), _MINHASH_BAND_SIZE):
+        band = tuple(signature_list[start : start + _MINHASH_BAND_SIZE])
+        if len(band) == _MINHASH_BAND_SIZE:
+            bands.append(band)
+    return bands
+
+
+def _jaccard_similarity(a: set[str], b: set[str]) -> float:
+    """Return the Jaccard similarity between two shingle sets, handling empty edge cases."""
+    if not a and not b:
+        return 1.0
+    if not a or not b:
+        return 0.0
+
+    intersection = len(a.intersection(b))
+    union = len(a.union(b))
+    return intersection / union if union else 0.0
+
+
+@lru_cache(maxsize=512)
+def _cached_shingles(name: str) -> set[str]:
+    """Cache shingle sets per normalized name to avoid recomputation within a worker."""
+    return _shingles(name)
+
+
+@dataclass
+class DedupCandidateIndexes:
+    """Precomputed lookup structures that drive entity deduplication heuristics."""
+
+    existing_nodes: list[EntityNode]
+    nodes_by_uuid: dict[str, EntityNode]
+    normalized_existing: defaultdict[str, list[EntityNode]]
+    shingles_by_candidate: dict[str, set[str]]
+    lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]]
+
+
+@dataclass
+class DedupResolutionState:
+    """Mutable resolution bookkeeping shared across deterministic and LLM passes."""
+
+    resolved_nodes: list[EntityNode | None]
+    uuid_map: dict[str, str]
+    unresolved_indices: list[int]
+    duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)
+
+
+def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
+    """Precompute exact and fuzzy lookup structures once per dedupe run."""
+    normalized_existing: defaultdict[str, list[EntityNode]] = defaultdict(list)
+    nodes_by_uuid: dict[str, EntityNode] = {}
+    shingles_by_candidate: dict[str, set[str]] = {}
+    lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]] = defaultdict(list)
+
+    for candidate in existing_nodes:
+        normalized = _normalize_string_exact(candidate.name)
+        normalized_existing[normalized].append(candidate)
+        nodes_by_uuid[candidate.uuid] = candidate
+
+        shingles = _cached_shingles(_normalize_name_for_fuzzy(candidate.name))
+        shingles_by_candidate[candidate.uuid] = shingles
+
+        signature = _minhash_signature(shingles)
+        for band_index, band in enumerate(_lsh_bands(signature)):
+            lsh_buckets[(band_index, band)].append(candidate.uuid)
+
+    return DedupCandidateIndexes(
+        existing_nodes=existing_nodes,
+        nodes_by_uuid=nodes_by_uuid,
+        normalized_existing=normalized_existing,
+        shingles_by_candidate=shingles_by_candidate,
+        lsh_buckets=lsh_buckets,
+    )
+
+
+def _resolve_with_similarity(
+    extracted_nodes: list[EntityNode],
+    indexes: DedupCandidateIndexes,
+    state: DedupResolutionState,
+) -> None:
+    """Attempt deterministic resolution using exact name hits and fuzzy MinHash comparisons."""
+    for idx, node in enumerate(extracted_nodes):
+        normalized_exact = _normalize_string_exact(node.name)
+        normalized_fuzzy = _normalize_name_for_fuzzy(node.name)
+
+        if not _has_high_entropy(normalized_fuzzy):
+            state.unresolved_indices.append(idx)
+            continue
+
+        existing_matches = indexes.normalized_existing.get(normalized_exact, [])
+        if len(existing_matches) == 1:
+            match = existing_matches[0]
+            state.resolved_nodes[idx] = match
+            state.uuid_map[node.uuid] = match.uuid
+            if match.uuid != node.uuid:
+                state.duplicate_pairs.append((node, match))
+            continue
+        if len(existing_matches) > 1:
+            state.unresolved_indices.append(idx)
+            continue
+
+        shingles = _cached_shingles(normalized_fuzzy)
+        signature = _minhash_signature(shingles)
+        candidate_ids: set[str] = set()
+        for band_index, band in enumerate(_lsh_bands(signature)):
+            candidate_ids.update(indexes.lsh_buckets.get((band_index, band), []))
+
+        best_candidate: EntityNode | None = None
+        best_score = 0.0
+        for candidate_id in candidate_ids:
+            candidate_shingles = indexes.shingles_by_candidate.get(candidate_id, set())
+            score = _jaccard_similarity(shingles, candidate_shingles)
+            if score > best_score:
+                best_score = score
+                best_candidate = indexes.nodes_by_uuid.get(candidate_id)
+
+        if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
+            state.resolved_nodes[idx] = best_candidate
+            state.uuid_map[node.uuid] = best_candidate.uuid
+            if best_candidate.uuid != node.uuid:
+                state.duplicate_pairs.append((node, best_candidate))
+            continue
+
+        state.unresolved_indices.append(idx)
+
+
+__all__ = [
+    'DedupCandidateIndexes',
+    'DedupResolutionState',
+    '_normalize_string_exact',
+    '_normalize_name_for_fuzzy',
+    '_has_high_entropy',
+    '_minhash_signature',
+    '_lsh_bands',
+    '_jaccard_similarity',
+    '_cached_shingles',
+    '_FUZZY_JACCARD_THRESHOLD',
+    '_build_candidate_indexes',
+    '_resolve_with_similarity',
+]
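
To make the thresholds in this new module concrete, a small hand-worked check (hypothetical names; values follow from the definitions above, not from the diff):

a = _cached_shingles(_normalize_name_for_fuzzy('Alice Smith'))    # 3-grams of 'alicesmith'
b = _cached_shingles(_normalize_name_for_fuzzy('Alice  Smith '))  # extra whitespace normalizes away
c = _cached_shingles(_normalize_name_for_fuzzy('Alice Smyth'))

assert _jaccard_similarity(a, b) == 1.0                        # identical after normalization
assert _jaccard_similarity(a, c) < _FUZZY_JACCARD_THRESHOLD    # roughly 0.45, so near-miss spellings are not fuzzy-matched
assert not _has_high_entropy(_normalize_name_for_fuzzy('aaa'))  # short, repetitive names defer to the LLM pass

With _MINHASH_PERMUTATIONS = 32 and _MINHASH_BAND_SIZE = 4, each signature splits into 8 bands, and two names only become LSH candidates when at least one band matches exactly; the final _FUZZY_JACCARD_THRESHOLD = 0.9 check then keeps only near-identical names.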