graphiti-core 0.20.4__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of graphiti-core might be problematic.
- graphiti_core/driver/driver.py +28 -0
- graphiti_core/driver/falkordb_driver.py +112 -0
- graphiti_core/driver/kuzu_driver.py +1 -0
- graphiti_core/driver/neo4j_driver.py +10 -2
- graphiti_core/driver/neptune_driver.py +4 -6
- graphiti_core/edges.py +67 -7
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +35 -6
- graphiti_core/graphiti.py +27 -23
- graphiti_core/graphiti_types.py +0 -1
- graphiti_core/helpers.py +2 -2
- graphiti_core/llm_client/client.py +19 -4
- graphiti_core/llm_client/gemini_client.py +4 -2
- graphiti_core/llm_client/openai_base_client.py +3 -2
- graphiti_core/llm_client/openai_generic_client.py +3 -2
- graphiti_core/models/edges/edge_db_queries.py +36 -16
- graphiti_core/models/nodes/node_db_queries.py +30 -10
- graphiti_core/nodes.py +126 -25
- graphiti_core/prompts/dedupe_edges.py +40 -29
- graphiti_core/prompts/dedupe_nodes.py +51 -34
- graphiti_core/prompts/eval.py +3 -3
- graphiti_core/prompts/extract_edges.py +17 -9
- graphiti_core/prompts/extract_nodes.py +10 -9
- graphiti_core/prompts/prompt_helpers.py +3 -3
- graphiti_core/prompts/summarize_nodes.py +5 -5
- graphiti_core/search/search_filters.py +53 -0
- graphiti_core/search/search_helpers.py +5 -7
- graphiti_core/search/search_utils.py +227 -57
- graphiti_core/utils/bulk_utils.py +168 -69
- graphiti_core/utils/maintenance/community_operations.py +8 -20
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +187 -50
- graphiti_core/utils/maintenance/graph_data_operations.py +9 -5
- graphiti_core/utils/maintenance/node_operations.py +244 -88
- graphiti_core/utils/maintenance/temporal_operations.py +0 -4
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/METADATA +7 -1
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/RECORD +39 -38
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/utils/bulk_utils.py

@@ -23,7 +23,14 @@ import numpy as np
 from pydantic import BaseModel, Field
 from typing_extensions import Any

-from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
+from graphiti_core.driver.driver import (
+    ENTITY_EDGE_INDEX_NAME,
+    ENTITY_INDEX_NAME,
+    EPISODE_INDEX_NAME,
+    GraphDriver,
+    GraphDriverSession,
+    GraphProvider,
+)
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.graphiti_types import GraphitiClients
@@ -36,8 +43,14 @@ from graphiti_core.models.nodes.node_db_queries import (
     get_entity_node_save_bulk_query,
     get_episode_node_save_bulk_query,
 )
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
+from graphiti_core.utils.maintenance.dedup_helpers import (
+    DedupResolutionState,
+    _build_candidate_indexes,
+    _normalize_string_exact,
+    _resolve_with_similarity,
+)
 from graphiti_core.utils.maintenance.edge_operations import (
     extract_edges,
     resolve_extracted_edge,
@@ -56,6 +69,38 @@ logger = logging.getLogger(__name__)
 CHUNK_SIZE = 10


+def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
+    """Collapse alias -> canonical chains while preserving direction.
+
+    The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
+    union-find with iterative path compression to ensure every source UUID resolves to its ultimate
+    canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
+    """
+
+    parent: dict[str, str] = {}
+
+    def find(uuid: str) -> str:
+        """Directed union-find lookup using iterative path compression."""
+        parent.setdefault(uuid, uuid)
+        root = uuid
+        while parent[root] != root:
+            root = parent[root]
+
+        while parent[uuid] != root:
+            next_uuid = parent[uuid]
+            parent[uuid] = root
+            uuid = next_uuid
+
+        return root
+
+    for source_uuid, target_uuid in pairs:
+        parent.setdefault(source_uuid, source_uuid)
+        parent.setdefault(target_uuid, target_uuid)
+        parent[find(source_uuid)] = find(target_uuid)
+
+    return {uuid: find(uuid) for uuid in parent}
+
+
 class RawEpisode(BaseModel):
     name: str
     uuid: str | None = Field(default=None)
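A quick illustration of the new helper (not part of the diff; the pair values are made up): directed dedupe pairs form alias chains, and the union-find collapses every alias to its final canonical target even when the chain is several hops long.

    # 'a' was merged into 'b', 'b' into 'c'; 'x' also resolved to 'c'
    pairs = [('a', 'b'), ('b', 'c'), ('x', 'c')]
    assert _build_directed_uuid_map(pairs) == {'a': 'c', 'b': 'c', 'c': 'c', 'x': 'c'}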
@@ -129,12 +174,14 @@ async def add_nodes_and_edges_bulk_tx(
         entity_data: dict[str, Any] = {
             'uuid': node.uuid,
             'name': node.name,
-            'name_embedding': node.name_embedding,
             'group_id': node.group_id,
             'summary': node.summary,
             'created_at': node.created_at,
         }

+        if not bool(driver.aoss_client):
+            entity_data['name_embedding'] = node.name_embedding
+
         entity_data['labels'] = list(set(node.labels + ['Entity']))
         if driver.provider == GraphProvider.KUZU:
             attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
@@ -154,7 +201,6 @@ async def add_nodes_and_edges_bulk_tx(
             'target_node_uuid': edge.target_node_uuid,
             'name': edge.name,
             'fact': edge.fact,
-            'fact_embedding': edge.fact_embedding,
             'group_id': edge.group_id,
             'episodes': edge.episodes,
             'created_at': edge.created_at,
@@ -163,6 +209,9 @@ async def add_nodes_and_edges_bulk_tx(
             'invalid_at': edge.invalid_at,
         }

+        if not bool(driver.aoss_client):
+            edge_data['fact_embedding'] = edge.fact_embedding
+
         if driver.provider == GraphProvider.KUZU:
             attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
             edge_data['attributes'] = json.dumps(attributes)
|
@@ -187,12 +236,33 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
187
236
|
await tx.run(episodic_edge_query, **edge.model_dump())
|
|
188
237
|
else:
|
|
189
238
|
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
|
|
190
|
-
await tx.run(
|
|
239
|
+
await tx.run(
|
|
240
|
+
get_entity_node_save_bulk_query(
|
|
241
|
+
driver.provider, nodes, has_aoss=bool(driver.aoss_client)
|
|
242
|
+
),
|
|
243
|
+
nodes=nodes,
|
|
244
|
+
)
|
|
191
245
|
await tx.run(
|
|
192
246
|
get_episodic_edge_save_bulk_query(driver.provider),
|
|
193
247
|
episodic_edges=[edge.model_dump() for edge in episodic_edges],
|
|
194
248
|
)
|
|
195
|
-
await tx.run(
|
|
249
|
+
await tx.run(
|
|
250
|
+
get_entity_edge_save_bulk_query(driver.provider, has_aoss=bool(driver.aoss_client)),
|
|
251
|
+
entity_edges=edges,
|
|
252
|
+
)
|
|
253
|
+
|
|
254
|
+
if bool(driver.aoss_client):
|
|
255
|
+
for node_data, entity_node in zip(nodes, entity_nodes, strict=True):
|
|
256
|
+
if node_data.get('uuid') == entity_node.uuid:
|
|
257
|
+
node_data['name_embedding'] = entity_node.name_embedding
|
|
258
|
+
|
|
259
|
+
for edge_data, entity_edge in zip(edges, entity_edges, strict=True):
|
|
260
|
+
if edge_data.get('uuid') == entity_edge.uuid:
|
|
261
|
+
edge_data['fact_embedding'] = entity_edge.fact_embedding
|
|
262
|
+
|
|
263
|
+
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
|
|
264
|
+
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
|
|
265
|
+
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
|
|
196
266
|
|
|
197
267
|
|
|
198
268
|
async def extract_nodes_and_edges_bulk(
|
|
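The routing rule this hunk implements, condensed into a hypothetical standalone helper (the `has_aoss` flag and payload keys mirror the diff; the function itself is illustrative, not graphiti API): when an OpenSearch (AOSS) client owns vector storage, embeddings stay out of the graph payloads and are shipped to the `save_to_aoss` indices instead.

    def build_entity_payload(node, has_aoss: bool) -> dict:
        # vectors live in OpenSearch when AOSS is configured; otherwise they
        # are persisted as regular node properties in the graph
        payload = {'uuid': node.uuid, 'name': node.name, 'group_id': node.group_id}
        if not has_aoss:
            payload['name_embedding'] = node.name_embedding
        return payload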
@@ -234,83 +304,111 @@ async def dedupe_nodes_bulk(
     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
     entity_types: dict[str, type[BaseModel]] | None = None,
 ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
-
-    min_score = 0.8
+    """Resolve entity duplicates across an in-memory batch using a two-pass strategy.

-
-
-
-
-
-
-    dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
-    for i, nodes_i in enumerate(extracted_nodes):
-        existing_nodes: list[EntityNode] = []
-        for j, nodes_j in enumerate(extracted_nodes):
-            if i == j:
-                continue
-            existing_nodes += nodes_j
-
-        candidates_i: list[EntityNode] = []
-        for node in nodes_i:
-            for existing_node in existing_nodes:
-                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
-                # This approach will cast a wider net than BM25, which is ideal for this use case
-                node_words = set(node.name.lower().split())
-                existing_node_words = set(existing_node.name.lower().split())
-                has_overlap = not node_words.isdisjoint(existing_node_words)
-                if has_overlap:
-                    candidates_i.append(existing_node)
-                    continue
-
-                # Check for semantic similarity even if there is no overlap
-                similarity = np.dot(
-                    normalize_l2(node.name_embedding or []),
-                    normalize_l2(existing_node.name_embedding or []),
-                )
-                if similarity >= min_score:
-                    candidates_i.append(existing_node)
-
-        dedupe_tuples.append((nodes_i, candidates_i))
+    1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
+       reconciled against the live graph just like the non-batch flow.
+    2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
+       duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
+       can apply to edges and persistence.
+    """

-
-    bulk_node_resolutions: list[
-        tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
-    ] = await semaphore_gather(
+    first_pass_results = await semaphore_gather(
         *[
             resolve_extracted_nodes(
                 clients,
-
+                nodes,
                 episode_tuples[i][0],
                 episode_tuples[i][1],
                 entity_types,
-                existing_nodes_override=dedupe_tuples[i][1],
             )
-            for i,
+            for i, nodes in enumerate(extracted_nodes)
         ]
     )

-
+    episode_resolutions: list[tuple[str, list[EntityNode]]] = []
+    per_episode_uuid_maps: list[dict[str, str]] = []
     duplicate_pairs: list[tuple[str, str]] = []
-    for _, _, duplicates in bulk_node_resolutions:
-        for duplicate in duplicates:
-            n, m = duplicate
-            duplicate_pairs.append((n.uuid, m.uuid))

-
-
+    for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
+        first_pass_results, episode_tuples, strict=True
+    ):
+        episode_resolutions.append((episode.uuid, resolved_nodes))
+        per_episode_uuid_maps.append(uuid_map)
+        duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
+
+    canonical_nodes: dict[str, EntityNode] = {}
+    for _, resolved_nodes in episode_resolutions:
+        for node in resolved_nodes:
+            # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
+            # the MinHash index for the accumulated canonical pool each time. The LRU-backed
+            # shingle cache keeps the constant factors low for typical batch sizes (<= CHUNK_SIZE),
+            # but if batches grow significantly we should switch to an incremental index or chunked
+            # processing.
+            if not canonical_nodes:
+                canonical_nodes[node.uuid] = node
+                continue

-
-
-
+            existing_candidates = list(canonical_nodes.values())
+            normalized = _normalize_string_exact(node.name)
+            exact_match = next(
+                (
+                    candidate
+                    for candidate in existing_candidates
+                    if _normalize_string_exact(candidate.name) == normalized
+                ),
+                None,
+            )
+            if exact_match is not None:
+                if exact_match.uuid != node.uuid:
+                    duplicate_pairs.append((node.uuid, exact_match.uuid))
+                continue
+
+            indexes = _build_candidate_indexes(existing_candidates)
+            state = DedupResolutionState(
+                resolved_nodes=[None],
+                uuid_map={},
+                unresolved_indices=[],
+            )
+            _resolve_with_similarity([node], indexes, state)
+
+            resolved = state.resolved_nodes[0]
+            if resolved is None:
+                canonical_nodes[node.uuid] = node
+                continue
+
+            canonical_uuid = resolved.uuid
+            canonical_nodes.setdefault(canonical_uuid, resolved)
+            if canonical_uuid != node.uuid:
+                duplicate_pairs.append((node.uuid, canonical_uuid))
+
+    union_pairs: list[tuple[str, str]] = []
+    for uuid_map in per_episode_uuid_maps:
+        union_pairs.extend(uuid_map.items())
+    union_pairs.extend(duplicate_pairs)
+
+    compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)

     nodes_by_episode: dict[str, list[EntityNode]] = {}
-    for
-
+    for episode_uuid, resolved_nodes in episode_resolutions:
+        deduped_nodes: list[EntityNode] = []
+        seen: set[str] = set()
+        for node in resolved_nodes:
+            canonical_uuid = compressed_map.get(node.uuid, node.uuid)
+            if canonical_uuid in seen:
+                continue
+            seen.add(canonical_uuid)
+            canonical_node = canonical_nodes.get(canonical_uuid)
+            if canonical_node is None:
+                logger.error(
+                    'Canonical node %s missing during batch dedupe; falling back to %s',
+                    canonical_uuid,
+                    node.uuid,
+                )
+                canonical_node = node
+            deduped_nodes.append(canonical_node)

-        nodes_by_episode[
-            node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
-        ]
+        nodes_by_episode[episode_uuid] = deduped_nodes

     return nodes_by_episode, compressed_map

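dedupe_nodes_bulk now returns the compressed canonical map alongside the per-episode node lists. A sketch of how a caller might apply it to edge endpoints (hypothetical helper, not part of this diff; the endpoint attribute names come from the edge payload hunks above):

    def remap_edge_endpoints(edges, compressed_map: dict[str, str]) -> None:
        # point both endpoints at the canonical UUIDs chosen during dedupe
        for edge in edges:
            edge.source_node_uuid = compressed_map.get(edge.source_node_uuid, edge.source_node_uuid)
            edge.target_node_uuid = compressed_map.get(edge.target_node_uuid, edge.target_node_uuid)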
@@ -335,14 +433,15 @@ async def dedupe_edges_bulk(
     dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
     for i, edges_i in enumerate(extracted_edges):
         existing_edges: list[EntityEdge] = []
-        for j, edges_j in enumerate(extracted_edges):
-            if i == j:
-                continue
+        for edges_j in extracted_edges:
             existing_edges += edges_j

         for edge in edges_i:
             candidates: list[EntityEdge] = []
             for existing_edge in existing_edges:
+                # Skip self-comparison
+                if edge.uuid == existing_edge.uuid:
+                    continue
                 # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
                 # This approach will cast a wider net than BM25, which is ideal for this use case
                 if (
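The word-overlap gate this hunk keeps, in isolation (illustrative strings): it is deliberately looser than BM25, so near-matches still reach the LLM resolver downstream.

    fact_a = set('Alice works at Acme'.lower().split())
    fact_b = set('Acme employs Alice'.lower().split())
    assert not fact_a.isdisjoint(fact_b)  # shared tokens -> keep as a candidate pair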
@@ -379,7 +478,7 @@ async def dedupe_edges_bulk(
             candidates,
             episode,
             edge_types,
-
+            set(edge_types),
         )
         for episode, edge, candidates in dedupe_tuples
     ]
graphiti_core/utils/maintenance/community_operations.py

@@ -131,13 +131,10 @@ def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
     return clusters


-async def summarize_pair(
-    llm_client: LLMClient, summary_pair: tuple[str, str], ensure_ascii: bool = True
-) -> str:
+async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
     # Prepare context for LLM
     context = {
         'node_summaries': [{'summary': summary} for summary in summary_pair],
-        'ensure_ascii': ensure_ascii,
     }

     llm_response = await llm_client.generate_response(
@@ -149,12 +146,9 @@ async def summarize_pair(
     return pair_summary


-async def generate_summary_description(
-    llm_client: LLMClient, summary: str, ensure_ascii: bool = True
-) -> str:
+async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
     context = {
         'summary': summary,
-        'ensure_ascii': ensure_ascii,
     }

     llm_response = await llm_client.generate_response(
@@ -168,7 +162,7 @@ async def generate_summary_description(


 async def build_community(
-    llm_client: LLMClient, community_cluster: list[EntityNode], ensure_ascii: bool = True
+    llm_client: LLMClient, community_cluster: list[EntityNode]
 ) -> tuple[CommunityNode, list[CommunityEdge]]:
     summaries = [entity.summary for entity in community_cluster]
     length = len(summaries)

@@ -180,9 +174,7 @@ async def build_community(
     new_summaries: list[str] = list(
         await semaphore_gather(
             *[
-                summarize_pair(
-                    llm_client, (str(left_summary), str(right_summary)), ensure_ascii
-                )
+                summarize_pair(llm_client, (str(left_summary), str(right_summary)))
                 for left_summary, right_summary in zip(
                     summaries[: int(length / 2)], summaries[int(length / 2) :], strict=False
                 )
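For reference, the surrounding loop halves the summary list each round by merging index-aligned pairs. A minimal sketch of that reduction, with string concatenation standing in for the summarize_pair LLM call (illustrative only):

    def merge_round(summaries: list[str]) -> list[str]:
        half = len(summaries) // 2
        # mirrors zip(summaries[:half], summaries[half:], strict=False):
        # pairing stops when the shorter half runs out
        return [f'{left} + {right}' for left, right in zip(summaries[:half], summaries[half:])]

    assert merge_round(['a', 'b', 'c', 'd']) == ['a + c', 'b + d']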
@@ -195,7 +187,7 @@ async def build_community(
     length = len(summaries)

     summary = summaries[0]
-    name = await generate_summary_description(llm_client, summary, ensure_ascii)
+    name = await generate_summary_description(llm_client, summary)
     now = utc_now()
     community_node = CommunityNode(
         name=name,
@@ -215,7 +207,6 @@ async def build_communities(
     driver: GraphDriver,
     llm_client: LLMClient,
     group_ids: list[str] | None,
-    ensure_ascii: bool = True,
 ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
     community_clusters = await get_community_clusters(driver, group_ids)

@@ -223,7 +214,7 @@ async def build_communities(

     async def limited_build_community(cluster):
         async with semaphore:
-            return await build_community(llm_client, cluster, ensure_ascii)
+            return await build_community(llm_client, cluster)

     communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
         await semaphore_gather(
@@ -312,17 +303,14 @@ async def update_community(
     llm_client: LLMClient,
     embedder: EmbedderClient,
     entity: EntityNode,
-    ensure_ascii: bool = True,
 ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
     community, is_new = await determine_entity_community(driver, entity)

     if community is None:
         return [], []

-    new_summary = await summarize_pair(
-        llm_client, (entity.summary, community.summary), ensure_ascii
-    )
-    new_name = await generate_summary_description(llm_client, new_summary, ensure_ascii)
+    new_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
+    new_name = await generate_summary_description(llm_client, new_summary)

     community.summary = new_summary
     community.name = new_name
graphiti_core/utils/maintenance/dedup_helpers.py (new file)

@@ -0,0 +1,262 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from __future__ import annotations
+
+import math
+import re
+from collections import defaultdict
+from collections.abc import Iterable
+from dataclasses import dataclass, field
+from functools import lru_cache
+from hashlib import blake2b
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from graphiti_core.nodes import EntityNode
+
+_NAME_ENTROPY_THRESHOLD = 1.5
+_MIN_NAME_LENGTH = 6
+_MIN_TOKEN_COUNT = 2
+_FUZZY_JACCARD_THRESHOLD = 0.9
+_MINHASH_PERMUTATIONS = 32
+_MINHASH_BAND_SIZE = 4
+
+
+def _normalize_string_exact(name: str) -> str:
+    """Lowercase text and collapse whitespace so equal names map to the same key."""
+    normalized = re.sub(r'[\s]+', ' ', name.lower())
+    return normalized.strip()
+
+
+def _normalize_name_for_fuzzy(name: str) -> str:
+    """Produce a fuzzier form that keeps alphanumerics and apostrophes for n-gram shingles."""
+    normalized = re.sub(r"[^a-z0-9' ]", ' ', _normalize_string_exact(name))
+    normalized = normalized.strip()
+    return re.sub(r'[\s]+', ' ', normalized)
+
+
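Illustrative behavior of the two normalizers (made-up inputs): the exact form only canonicalizes case and whitespace, while the fuzzy form also drops punctuation before shingling.

    assert _normalize_string_exact('  Tony\tStark ') == 'tony stark'
    assert _normalize_name_for_fuzzy('Tony Stark, Jr.') == 'tony stark jr'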
+def _name_entropy(normalized_name: str) -> float:
+    """Approximate text specificity using Shannon entropy over characters.
+
+    We strip spaces, count how often each character appears, and sum
+    probability * -log2(probability). Short or repetitive names yield low
+    entropy, which signals we should defer resolution to the LLM instead of
+    trusting fuzzy similarity.
+    """
+    if not normalized_name:
+        return 0.0
+
+    counts: dict[str, int] = {}
+    for char in normalized_name.replace(' ', ''):
+        counts[char] = counts.get(char, 0) + 1
+
+    total = sum(counts.values())
+    if total == 0:
+        return 0.0
+
+    entropy = 0.0
+    for count in counts.values():
+        probability = count / total
+        entropy -= probability * math.log2(probability)
+
+    return entropy
+
+
+def _has_high_entropy(normalized_name: str) -> bool:
+    """Filter out very short or low-entropy names that are unreliable for fuzzy matching."""
+    token_count = len(normalized_name.split())
+    if len(normalized_name) < _MIN_NAME_LENGTH and token_count < _MIN_TOKEN_COUNT:
+        return False
+
+    return _name_entropy(normalized_name) >= _NAME_ENTROPY_THRESHOLD
+
+
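Worked examples of the entropy gate (inputs are made up): a one-character alphabet has zero entropy, and short single-token names are rejected outright, so both fall through to the LLM pass.

    assert _name_entropy('aaa') == 0.0          # one distinct character
    assert not _has_high_entropy('aaa')         # len < 6 and a single token
    assert _has_high_entropy('tony stark')      # ~2.95 bits, above the 1.5 threshold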
+def _shingles(normalized_name: str) -> set[str]:
+    """Create 3-gram shingles from the normalized name for MinHash calculations."""
+    cleaned = normalized_name.replace(' ', '')
+    if len(cleaned) < 2:
+        return {cleaned} if cleaned else set()
+
+    return {cleaned[i : i + 3] for i in range(len(cleaned) - 2)}
+
+
+def _hash_shingle(shingle: str, seed: int) -> int:
+    """Generate a deterministic 64-bit hash for a shingle given the permutation seed."""
+    digest = blake2b(f'{seed}:{shingle}'.encode(), digest_size=8)
+    return int.from_bytes(digest.digest(), 'big')
+
+
+def _minhash_signature(shingles: Iterable[str]) -> tuple[int, ...]:
+    """Compute the MinHash signature for the shingle set across predefined permutations."""
+    if not shingles:
+        return tuple()
+
+    seeds = range(_MINHASH_PERMUTATIONS)
+    signature: list[int] = []
+    for seed in seeds:
+        min_hash = min(_hash_shingle(shingle, seed) for shingle in shingles)
+        signature.append(min_hash)
+
+    return tuple(signature)
+
+
+def _lsh_bands(signature: Iterable[int]) -> list[tuple[int, ...]]:
+    """Split the MinHash signature into fixed-size bands for locality-sensitive hashing."""
+    signature_list = list(signature)
+    if not signature_list:
+        return []
+
+    bands: list[tuple[int, ...]] = []
+    for start in range(0, len(signature_list), _MINHASH_BAND_SIZE):
+        band = tuple(signature_list[start : start + _MINHASH_BAND_SIZE])
+        if len(band) == _MINHASH_BAND_SIZE:
+            bands.append(band)
+    return bands
+
+
+def _jaccard_similarity(a: set[str], b: set[str]) -> float:
+    """Return the Jaccard similarity between two shingle sets, handling empty edge cases."""
+    if not a and not b:
+        return 1.0
+    if not a or not b:
+        return 0.0
+
+    intersection = len(a.intersection(b))
+    union = len(a.union(b))
+    return intersection / union if union else 0.0
+
+
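A worked example of the fuzzy pipeline on two made-up names: 'Acme Corporation' and 'Acme Corporations' share 13 of 14 three-gram shingles, clearing the 0.9 Jaccard threshold, and each 32-hash signature splits into 8 bands of 4 for LSH bucketing.

    a = _shingles(_normalize_name_for_fuzzy('Acme Corporation'))
    b = _shingles(_normalize_name_for_fuzzy('Acme Corporations'))
    assert _jaccard_similarity(a, b) >= _FUZZY_JACCARD_THRESHOLD  # 13/14 ~= 0.93

    signature = _minhash_signature(a)
    assert len(signature) == _MINHASH_PERMUTATIONS    # 32 per-permutation minima
    assert len(_lsh_bands(signature)) == 32 // 4      # 8 bands of size 4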
+@lru_cache(maxsize=512)
+def _cached_shingles(name: str) -> set[str]:
+    """Cache shingle sets per normalized name to avoid recomputation within a worker."""
+    return _shingles(name)
+
+
+@dataclass
+class DedupCandidateIndexes:
+    """Precomputed lookup structures that drive entity deduplication heuristics."""
+
+    existing_nodes: list[EntityNode]
+    nodes_by_uuid: dict[str, EntityNode]
+    normalized_existing: defaultdict[str, list[EntityNode]]
+    shingles_by_candidate: dict[str, set[str]]
+    lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]]
+
+
+@dataclass
+class DedupResolutionState:
+    """Mutable resolution bookkeeping shared across deterministic and LLM passes."""
+
+    resolved_nodes: list[EntityNode | None]
+    uuid_map: dict[str, str]
+    unresolved_indices: list[int]
+    duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)
+
+
+def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
+    """Precompute exact and fuzzy lookup structures once per dedupe run."""
+    normalized_existing: defaultdict[str, list[EntityNode]] = defaultdict(list)
+    nodes_by_uuid: dict[str, EntityNode] = {}
+    shingles_by_candidate: dict[str, set[str]] = {}
+    lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]] = defaultdict(list)
+
+    for candidate in existing_nodes:
+        normalized = _normalize_string_exact(candidate.name)
+        normalized_existing[normalized].append(candidate)
+        nodes_by_uuid[candidate.uuid] = candidate
+
+        shingles = _cached_shingles(_normalize_name_for_fuzzy(candidate.name))
+        shingles_by_candidate[candidate.uuid] = shingles
+
+        signature = _minhash_signature(shingles)
+        for band_index, band in enumerate(_lsh_bands(signature)):
+            lsh_buckets[(band_index, band)].append(candidate.uuid)
+
+    return DedupCandidateIndexes(
+        existing_nodes=existing_nodes,
+        nodes_by_uuid=nodes_by_uuid,
+        normalized_existing=normalized_existing,
+        shingles_by_candidate=shingles_by_candidate,
+        lsh_buckets=lsh_buckets,
+    )
+
+
+def _resolve_with_similarity(
+    extracted_nodes: list[EntityNode],
+    indexes: DedupCandidateIndexes,
+    state: DedupResolutionState,
+) -> None:
+    """Attempt deterministic resolution using exact name hits and fuzzy MinHash comparisons."""
+    for idx, node in enumerate(extracted_nodes):
+        normalized_exact = _normalize_string_exact(node.name)
+        normalized_fuzzy = _normalize_name_for_fuzzy(node.name)
+
+        if not _has_high_entropy(normalized_fuzzy):
+            state.unresolved_indices.append(idx)
+            continue
+
+        existing_matches = indexes.normalized_existing.get(normalized_exact, [])
+        if len(existing_matches) == 1:
+            match = existing_matches[0]
+            state.resolved_nodes[idx] = match
+            state.uuid_map[node.uuid] = match.uuid
+            if match.uuid != node.uuid:
+                state.duplicate_pairs.append((node, match))
+            continue
+        if len(existing_matches) > 1:
+            state.unresolved_indices.append(idx)
+            continue
+
+        shingles = _cached_shingles(normalized_fuzzy)
+        signature = _minhash_signature(shingles)
+        candidate_ids: set[str] = set()
+        for band_index, band in enumerate(_lsh_bands(signature)):
+            candidate_ids.update(indexes.lsh_buckets.get((band_index, band), []))
+
+        best_candidate: EntityNode | None = None
+        best_score = 0.0
+        for candidate_id in candidate_ids:
+            candidate_shingles = indexes.shingles_by_candidate.get(candidate_id, set())
+            score = _jaccard_similarity(shingles, candidate_shingles)
+            if score > best_score:
+                best_score = score
+                best_candidate = indexes.nodes_by_uuid.get(candidate_id)
+
+        if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
+            state.resolved_nodes[idx] = best_candidate
+            state.uuid_map[node.uuid] = best_candidate.uuid
+            if best_candidate.uuid != node.uuid:
+                state.duplicate_pairs.append((node, best_candidate))
+            continue
+
+        state.unresolved_indices.append(idx)
+
+
+__all__ = [
+    'DedupCandidateIndexes',
+    'DedupResolutionState',
+    '_normalize_string_exact',
+    '_normalize_name_for_fuzzy',
+    '_has_high_entropy',
+    '_minhash_signature',
+    '_lsh_bands',
+    '_jaccard_similarity',
+    '_cached_shingles',
+    '_FUZZY_JACCARD_THRESHOLD',
+    '_build_candidate_indexes',
+    '_resolve_with_similarity',
+]
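End-to-end, the helpers are meant to be wired together roughly as below (a minimal sketch; EntityNode construction details are assumptions, and only uuid/name are read by the heuristics):

    from graphiti_core.nodes import EntityNode
    from graphiti_core.utils.maintenance.dedup_helpers import (
        DedupResolutionState,
        _build_candidate_indexes,
        _resolve_with_similarity,
    )

    existing = [EntityNode(name='Tony Stark', group_id='g', labels=['Entity'], summary='')]
    extracted = [EntityNode(name='tony stark', group_id='g', labels=['Entity'], summary='')]

    indexes = _build_candidate_indexes(existing)
    state = DedupResolutionState(
        resolved_nodes=[None] * len(extracted), uuid_map={}, unresolved_indices=[]
    )
    _resolve_with_similarity(extracted, indexes, state)

    # the lowercase duplicate resolves to the existing node via the exact-name
    # index; anything the heuristics cannot settle is queued for the LLM pass
    assert state.resolved_nodes[0] is existing[0]
    assert state.unresolved_indices == []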