graphiti_core-0.21.0rc6-py3-none-any.whl → graphiti_core-0.21.0rc8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/graphiti.py +1 -0
- graphiti_core/llm_client/client.py +14 -4
- graphiti_core/llm_client/gemini_client.py +2 -2
- graphiti_core/llm_client/openai_base_client.py +2 -2
- graphiti_core/llm_client/openai_generic_client.py +2 -2
- graphiti_core/prompts/dedupe_nodes.py +42 -26
- graphiti_core/prompts/extract_nodes.py +2 -1
- graphiti_core/utils/bulk_utils.py +131 -63
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +106 -7
- graphiti_core/utils/maintenance/node_operations.py +171 -64
- {graphiti_core-0.21.0rc6.dist-info → graphiti_core-0.21.0rc8.dist-info}/METADATA +4 -1
- {graphiti_core-0.21.0rc6.dist-info → graphiti_core-0.21.0rc8.dist-info}/RECORD +15 -14
- {graphiti_core-0.21.0rc6.dist-info → graphiti_core-0.21.0rc8.dist-info}/WHEEL +0 -0
- {graphiti_core-0.21.0rc6.dist-info → graphiti_core-0.21.0rc8.dist-info}/licenses/LICENSE +0 -0
graphiti_core/utils/maintenance/dedup_helpers.py
@@ -0,0 +1,262 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from __future__ import annotations
+
+import math
+import re
+from collections import defaultdict
+from collections.abc import Iterable
+from dataclasses import dataclass, field
+from functools import lru_cache
+from hashlib import blake2b
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from graphiti_core.nodes import EntityNode
+
+_NAME_ENTROPY_THRESHOLD = 1.5
+_MIN_NAME_LENGTH = 6
+_MIN_TOKEN_COUNT = 2
+_FUZZY_JACCARD_THRESHOLD = 0.9
+_MINHASH_PERMUTATIONS = 32
+_MINHASH_BAND_SIZE = 4
+
+
+def _normalize_string_exact(name: str) -> str:
+    """Lowercase text and collapse whitespace so equal names map to the same key."""
+    normalized = re.sub(r'[\s]+', ' ', name.lower())
+    return normalized.strip()
+
+
+def _normalize_name_for_fuzzy(name: str) -> str:
+    """Produce a fuzzier form that keeps alphanumerics and apostrophes for n-gram shingles."""
+    normalized = re.sub(r"[^a-z0-9' ]", ' ', _normalize_string_exact(name))
+    normalized = normalized.strip()
+    return re.sub(r'[\s]+', ' ', normalized)
+
+
+def _name_entropy(normalized_name: str) -> float:
+    """Approximate text specificity using Shannon entropy over characters.
+
+    We strip spaces, count how often each character appears, and sum
+    probability * -log2(probability). Short or repetitive names yield low
+    entropy, which signals we should defer resolution to the LLM instead of
+    trusting fuzzy similarity.
+    """
+    if not normalized_name:
+        return 0.0
+
+    counts: dict[str, int] = {}
+    for char in normalized_name.replace(' ', ''):
+        counts[char] = counts.get(char, 0) + 1
+
+    total = sum(counts.values())
+    if total == 0:
+        return 0.0
+
+    entropy = 0.0
+    for count in counts.values():
+        probability = count / total
+        entropy -= probability * math.log2(probability)
+
+    return entropy
+
+
+def _has_high_entropy(normalized_name: str) -> bool:
+    """Filter out very short or low-entropy names that are unreliable for fuzzy matching."""
+    token_count = len(normalized_name.split())
+    if len(normalized_name) < _MIN_NAME_LENGTH and token_count < _MIN_TOKEN_COUNT:
+        return False
+
+    return _name_entropy(normalized_name) >= _NAME_ENTROPY_THRESHOLD
+
+
+def _shingles(normalized_name: str) -> set[str]:
+    """Create 3-gram shingles from the normalized name for MinHash calculations."""
+    cleaned = normalized_name.replace(' ', '')
+    if len(cleaned) < 2:
+        return {cleaned} if cleaned else set()
+
+    return {cleaned[i : i + 3] for i in range(len(cleaned) - 2)}
+
+
+def _hash_shingle(shingle: str, seed: int) -> int:
+    """Generate a deterministic 64-bit hash for a shingle given the permutation seed."""
+    digest = blake2b(f'{seed}:{shingle}'.encode(), digest_size=8)
+    return int.from_bytes(digest.digest(), 'big')
+
+
+def _minhash_signature(shingles: Iterable[str]) -> tuple[int, ...]:
+    """Compute the MinHash signature for the shingle set across predefined permutations."""
+    if not shingles:
+        return tuple()
+
+    seeds = range(_MINHASH_PERMUTATIONS)
+    signature: list[int] = []
+    for seed in seeds:
+        min_hash = min(_hash_shingle(shingle, seed) for shingle in shingles)
+        signature.append(min_hash)
+
+    return tuple(signature)
+
+
+def _lsh_bands(signature: Iterable[int]) -> list[tuple[int, ...]]:
+    """Split the MinHash signature into fixed-size bands for locality-sensitive hashing."""
+    signature_list = list(signature)
+    if not signature_list:
+        return []
+
+    bands: list[tuple[int, ...]] = []
+    for start in range(0, len(signature_list), _MINHASH_BAND_SIZE):
+        band = tuple(signature_list[start : start + _MINHASH_BAND_SIZE])
+        if len(band) == _MINHASH_BAND_SIZE:
+            bands.append(band)
+    return bands
+
+
+def _jaccard_similarity(a: set[str], b: set[str]) -> float:
+    """Return the Jaccard similarity between two shingle sets, handling empty edge cases."""
+    if not a and not b:
+        return 1.0
+    if not a or not b:
+        return 0.0
+
+    intersection = len(a.intersection(b))
+    union = len(a.union(b))
+    return intersection / union if union else 0.0
+
+
+@lru_cache(maxsize=512)
+def _cached_shingles(name: str) -> set[str]:
+    """Cache shingle sets per normalized name to avoid recomputation within a worker."""
+    return _shingles(name)
+
+
+@dataclass
+class DedupCandidateIndexes:
+    """Precomputed lookup structures that drive entity deduplication heuristics."""
+
+    existing_nodes: list[EntityNode]
+    nodes_by_uuid: dict[str, EntityNode]
+    normalized_existing: defaultdict[str, list[EntityNode]]
+    shingles_by_candidate: dict[str, set[str]]
+    lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]]
+
+
+@dataclass
+class DedupResolutionState:
+    """Mutable resolution bookkeeping shared across deterministic and LLM passes."""
+
+    resolved_nodes: list[EntityNode | None]
+    uuid_map: dict[str, str]
+    unresolved_indices: list[int]
+    duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)
+
+
+def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
+    """Precompute exact and fuzzy lookup structures once per dedupe run."""
+    normalized_existing: defaultdict[str, list[EntityNode]] = defaultdict(list)
+    nodes_by_uuid: dict[str, EntityNode] = {}
+    shingles_by_candidate: dict[str, set[str]] = {}
+    lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]] = defaultdict(list)
+
+    for candidate in existing_nodes:
+        normalized = _normalize_string_exact(candidate.name)
+        normalized_existing[normalized].append(candidate)
+        nodes_by_uuid[candidate.uuid] = candidate
+
+        shingles = _cached_shingles(_normalize_name_for_fuzzy(candidate.name))
+        shingles_by_candidate[candidate.uuid] = shingles
+
+        signature = _minhash_signature(shingles)
+        for band_index, band in enumerate(_lsh_bands(signature)):
+            lsh_buckets[(band_index, band)].append(candidate.uuid)
+
+    return DedupCandidateIndexes(
+        existing_nodes=existing_nodes,
+        nodes_by_uuid=nodes_by_uuid,
+        normalized_existing=normalized_existing,
+        shingles_by_candidate=shingles_by_candidate,
+        lsh_buckets=lsh_buckets,
+    )
+
+
+def _resolve_with_similarity(
+    extracted_nodes: list[EntityNode],
+    indexes: DedupCandidateIndexes,
+    state: DedupResolutionState,
+) -> None:
+    """Attempt deterministic resolution using exact name hits and fuzzy MinHash comparisons."""
+    for idx, node in enumerate(extracted_nodes):
+        normalized_exact = _normalize_string_exact(node.name)
+        normalized_fuzzy = _normalize_name_for_fuzzy(node.name)
+
+        if not _has_high_entropy(normalized_fuzzy):
+            state.unresolved_indices.append(idx)
+            continue
+
+        existing_matches = indexes.normalized_existing.get(normalized_exact, [])
+        if len(existing_matches) == 1:
+            match = existing_matches[0]
+            state.resolved_nodes[idx] = match
+            state.uuid_map[node.uuid] = match.uuid
+            if match.uuid != node.uuid:
+                state.duplicate_pairs.append((node, match))
+            continue
+        if len(existing_matches) > 1:
+            state.unresolved_indices.append(idx)
+            continue
+
+        shingles = _cached_shingles(normalized_fuzzy)
+        signature = _minhash_signature(shingles)
+        candidate_ids: set[str] = set()
+        for band_index, band in enumerate(_lsh_bands(signature)):
+            candidate_ids.update(indexes.lsh_buckets.get((band_index, band), []))
+
+        best_candidate: EntityNode | None = None
+        best_score = 0.0
+        for candidate_id in candidate_ids:
+            candidate_shingles = indexes.shingles_by_candidate.get(candidate_id, set())
+            score = _jaccard_similarity(shingles, candidate_shingles)
+            if score > best_score:
+                best_score = score
+                best_candidate = indexes.nodes_by_uuid.get(candidate_id)
+
+        if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
+            state.resolved_nodes[idx] = best_candidate
+            state.uuid_map[node.uuid] = best_candidate.uuid
+            if best_candidate.uuid != node.uuid:
+                state.duplicate_pairs.append((node, best_candidate))
+            continue

+        state.unresolved_indices.append(idx)
+
+
+__all__ = [
+    'DedupCandidateIndexes',
+    'DedupResolutionState',
+    '_normalize_string_exact',
+    '_normalize_name_for_fuzzy',
+    '_has_high_entropy',
+    '_minhash_signature',
+    '_lsh_bands',
+    '_jaccard_similarity',
+    '_cached_shingles',
+    '_FUZZY_JACCARD_THRESHOLD',
+    '_build_candidate_indexes',
+    '_resolve_with_similarity',
+]
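The new helpers form a standard MinHash/LSH pipeline: names are normalized, shingled into character 3-grams, hashed under 32 permutations, and split into 8 bands of 4 hashes. Two names whose shingle sets have Jaccard similarity s collide in at least one band with probability 1 - (1 - s^4)^8, so near-duplicates almost always reach the exact Jaccard check while unrelated names rarely do. A minimal usage sketch, assuming the 0.21.0rc8 wheel is installed (the helpers are underscore-prefixed internals exported via `__all__`, so treat the imports as illustrative):

```python
from graphiti_core.utils.maintenance.dedup_helpers import (
    _cached_shingles,
    _jaccard_similarity,
    _lsh_bands,
    _minhash_signature,
    _normalize_name_for_fuzzy,
)

# Punctuation differs, so the exact-match key misses, but the fuzzy
# normalized forms coincide and the fuzzy path can connect them.
a = _normalize_name_for_fuzzy('Graphiti-Core')   # -> 'graphiti core'
b = _normalize_name_for_fuzzy('graphiti core')   # -> 'graphiti core'

shingles_a = _cached_shingles(a)
shingles_b = _cached_shingles(b)

# LSH candidate lookup: the pair is only compared at all if some band matches.
bands_a = set(enumerate(_lsh_bands(_minhash_signature(shingles_a))))
bands_b = set(enumerate(_lsh_bands(_minhash_signature(shingles_b))))
print('band collision:', bool(bands_a & bands_b))  # True

# Acceptance requires the true Jaccard score to clear _FUZZY_JACCARD_THRESHOLD (0.9),
# so in practice mostly punctuation/whitespace variants are auto-merged.
print('jaccard:', _jaccard_similarity(shingles_a, shingles_b))  # 1.0
```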
graphiti_core/utils/maintenance/edge_operations.py
@@ -41,6 +41,9 @@ from graphiti_core.search.search_config import SearchResults
 from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
+from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
+
+DEFAULT_EDGE_NAME = 'RELATES_TO'
 
 logger = logging.getLogger(__name__)
 
@@ -229,6 +232,22 @@ async def resolve_extracted_edges(
     edge_types: dict[str, type[BaseModel]],
     edge_type_map: dict[tuple[str, str], list[str]],
 ) -> tuple[list[EntityEdge], list[EntityEdge]]:
+    # Fast path: deduplicate exact matches within the extracted edges before parallel processing
+    seen: dict[tuple[str, str, str], EntityEdge] = {}
+    deduplicated_edges: list[EntityEdge] = []
+
+    for edge in extracted_edges:
+        key = (
+            edge.source_node_uuid,
+            edge.target_node_uuid,
+            _normalize_string_exact(edge.fact),
+        )
+        if key not in seen:
+            seen[key] = edge
+            deduplicated_edges.append(edge)
+
+    extracted_edges = deduplicated_edges
+
     driver = clients.driver
     llm_client = clients.llm_client
     embedder = clients.embedder
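The fast path keys each edge on (source uuid, target uuid, normalized fact) and keeps only the first occurrence, so verbatim re-extractions never reach the parallel LLM resolution step. A standalone sketch of the same keying idea, using a hypothetical `Edge` stand-in rather than the library's EntityEdge model:

```python
from dataclasses import dataclass

from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact


@dataclass
class Edge:  # hypothetical stand-in for EntityEdge, for illustration only
    source_node_uuid: str
    target_node_uuid: str
    fact: str


edges = [
    Edge('u1', 'u2', 'Alice works at Acme'),
    Edge('u1', 'u2', 'alice  works at ACME'),  # same fact modulo case/whitespace
    Edge('u1', 'u3', 'Alice works at Acme'),   # different endpoint, kept
]

seen: set[tuple[str, str, str]] = set()
deduped: list[Edge] = []
for e in edges:
    key = (e.source_node_uuid, e.target_node_uuid, _normalize_string_exact(e.fact))
    if key not in seen:
        seen.add(key)
        deduped.append(e)

print(len(deduped))  # 2: the second edge collapses into the first
```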
@@ -280,8 +299,12 @@ async def resolve_extracted_edges(
     # Build entity hash table
     uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
 
-    # Determine which edge types are relevant for each edge
+    # Determine which edge types are relevant for each edge.
+    # `edge_types_lst` stores the subset of custom edge definitions whose
+    # node signature matches each extracted edge. Anything outside this subset
+    # should only stay on the edge if it is a non-custom (LLM generated) label.
     edge_types_lst: list[dict[str, type[BaseModel]]] = []
+    custom_type_names = set(edge_types or {})
     for extracted_edge in extracted_edges:
         source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
         target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
@@ -309,6 +332,20 @@ async def resolve_extracted_edges(
 
         edge_types_lst.append(extracted_edge_types)
 
+    for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
+        allowed_type_names = set(extracted_edge_types)
+        is_custom_name = extracted_edge.name in custom_type_names
+        if not allowed_type_names:
+            # No custom types are valid for this node pairing. Keep LLM generated
+            # labels, but flip disallowed custom names back to the default.
+            if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
+                extracted_edge.name = DEFAULT_EDGE_NAME
+            continue
+        if is_custom_name and extracted_edge.name not in allowed_type_names:
+            # Custom name exists but it is not permitted for this source/target
+            # signature, so fall back to the default edge label.
+            extracted_edge.name = DEFAULT_EDGE_NAME
+
     # resolve edges with related edges in the graph and find invalidation candidates
     results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
         await semaphore_gather(
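The loop above enforces a simple label policy: a registered custom name survives only when the pair's node signature allows it, while ad-hoc LLM labels always pass through untouched. A hedged restatement as a pure function (hypothetical helper, not part of the package):

```python
DEFAULT_EDGE_NAME = 'RELATES_TO'


def resolve_label(name: str, allowed: set[str], custom: set[str]) -> str:
    """Custom names survive only where the node signature allows them;
    ad-hoc LLM labels always pass through unchanged."""
    if name in custom and name not in allowed:
        return DEFAULT_EDGE_NAME
    return name


# Custom type, but not allowed for this node pairing: falls back to default.
assert resolve_label('EMPLOYED_BY', allowed=set(), custom={'EMPLOYED_BY'}) == 'RELATES_TO'
# Custom type allowed for this pairing: kept.
assert resolve_label('EMPLOYED_BY', {'EMPLOYED_BY'}, {'EMPLOYED_BY'}) == 'EMPLOYED_BY'
# Non-custom (LLM generated) label: kept.
assert resolve_label('MENTIONS', set(), {'EMPLOYED_BY'}) == 'MENTIONS'
```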
@@ -320,6 +357,7 @@ async def resolve_extracted_edges(
                     existing_edges,
                     episode,
                     extracted_edge_types,
+                    custom_type_names,
                     clients.ensure_ascii,
                 )
                 for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
@@ -391,17 +429,59 @@ async def resolve_extracted_edge(
     related_edges: list[EntityEdge],
     existing_edges: list[EntityEdge],
     episode: EpisodicNode,
-
+    edge_type_candidates: dict[str, type[BaseModel]] | None = None,
+    custom_edge_type_names: set[str] | None = None,
     ensure_ascii: bool = True,
 ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
+    """Resolve an extracted edge against existing graph context.
+
+    Parameters
+    ----------
+    llm_client : LLMClient
+        Client used to invoke the LLM for deduplication and attribute extraction.
+    extracted_edge : EntityEdge
+        Newly extracted edge whose canonical representation is being resolved.
+    related_edges : list[EntityEdge]
+        Candidate edges with identical endpoints used for duplicate detection.
+    existing_edges : list[EntityEdge]
+        Broader set of edges evaluated for contradiction / invalidation.
+    episode : EpisodicNode
+        Episode providing content context when extracting edge attributes.
+    edge_type_candidates : dict[str, type[BaseModel]] | None
+        Custom edge types permitted for the current source/target signature.
+    custom_edge_type_names : set[str] | None
+        Full catalog of registered custom edge names. Used to distinguish
+        between disallowed custom types (which fall back to the default label)
+        and ad-hoc labels emitted by the LLM.
+    ensure_ascii : bool
+        Whether prompt payloads should coerce ASCII output.
+
+    Returns
+    -------
+    tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
+        The resolved edge, any duplicates, and edges to invalidate.
+    """
     if len(related_edges) == 0 and len(existing_edges) == 0:
         return extracted_edge, [], []
 
+    # Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
+    normalized_fact = _normalize_string_exact(extracted_edge.fact)
+    for edge in related_edges:
+        if (
+            edge.source_node_uuid == extracted_edge.source_node_uuid
+            and edge.target_node_uuid == extracted_edge.target_node_uuid
+            and _normalize_string_exact(edge.fact) == normalized_fact
+        ):
+            resolved = edge
+            if episode is not None and episode.uuid not in resolved.episodes:
+                resolved.episodes.append(episode.uuid)
+            return resolved, [], []
+
     start = time()
 
     # Prepare context for LLM
     related_edges_context = [
-        {'id':
+        {'id': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)
     ]
 
     invalidation_edge_candidates_context = [
@@ -415,9 +495,9 @@ async def resolve_extracted_edge(
                 'fact_type_name': type_name,
                 'fact_type_description': type_model.__doc__,
             }
-            for i, (type_name, type_model) in enumerate(
+            for i, (type_name, type_model) in enumerate(edge_type_candidates.items())
         ]
-        if
+        if edge_type_candidates is not None
         else []
     )
 
@@ -454,7 +534,16 @@ async def resolve_extracted_edge(
     ]
 
     fact_type: str = response_object.fact_type
-
+    candidate_type_names = set(edge_type_candidates or {})
+    custom_type_names = custom_edge_type_names or set()
+
+    is_default_type = fact_type.upper() == 'DEFAULT'
+    is_custom_type = fact_type in custom_type_names
+    is_allowed_custom_type = fact_type in candidate_type_names
+
+    if is_allowed_custom_type:
+        # The LLM selected a custom type that is allowed for the node pair.
+        # Adopt the custom type and, if needed, extract its structured attributes.
         resolved_edge.name = fact_type
 
         edge_attributes_context = {
@@ -464,7 +553,7 @@ async def resolve_extracted_edge(
             'ensure_ascii': ensure_ascii,
         }
 
-        edge_model =
+        edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
         if edge_model is not None and len(edge_model.model_fields) != 0:
             edge_attributes_response = await llm_client.generate_response(
                 prompt_library.extract_edges.extract_attributes(edge_attributes_context),
@@ -473,6 +562,16 @@ async def resolve_extracted_edge(
             )
 
             resolved_edge.attributes = edge_attributes_response
+    elif not is_default_type and is_custom_type:
+        # The LLM picked a custom type that is not allowed for this signature.
+        # Reset to the default label and drop any structured attributes.
+        resolved_edge.name = DEFAULT_EDGE_NAME
+        resolved_edge.attributes = {}
+    elif not is_default_type:
+        # Non-custom labels are allowed to pass through so long as the LLM does
+        # not return the sentinel DEFAULT value.
+        resolved_edge.name = fact_type
+        resolved_edge.attributes = {}
 
     end = time()
     logger.debug(
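Taken together, the new branches implement a four-way decision on the LLM's `fact_type`. A condensed, hypothetical restatement of that policy (the real code additionally only runs attribute extraction when the custom edge model declares fields):

```python
DEFAULT_EDGE_NAME = 'RELATES_TO'


def apply_fact_type(
    fact_type: str, candidates: set[str], custom: set[str]
) -> tuple[str | None, bool]:
    """Return (new_name, extract_attributes); None keeps the current edge name."""
    if fact_type.upper() == 'DEFAULT':
        return None, False  # sentinel: the LLM found no better label
    if fact_type in candidates:
        return fact_type, True  # allowed custom type: also extract attributes
    if fact_type in custom:
        return DEFAULT_EDGE_NAME, False  # custom but disallowed here: fall back
    return fact_type, False  # ad-hoc LLM label passes through without attributes


assert apply_fact_type('DEFAULT', set(), set()) == (None, False)
assert apply_fact_type('EMPLOYED_BY', {'EMPLOYED_BY'}, {'EMPLOYED_BY'}) == ('EMPLOYED_BY', True)
assert apply_fact_type('EMPLOYED_BY', set(), {'EMPLOYED_BY'}) == ('RELATES_TO', False)
assert apply_fact_type('MENTIONS', set(), {'EMPLOYED_BY'}) == ('MENTIONS', False)
```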