corp-extractor 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of corp-extractor might be problematic. Click here for more details.
- corp_extractor-0.2.7.dist-info/METADATA +377 -0
- corp_extractor-0.2.7.dist-info/RECORD +11 -0
- corp_extractor-0.2.7.dist-info/WHEEL +4 -0
- corp_extractor-0.2.7.dist-info/entry_points.txt +3 -0
- statement_extractor/__init__.py +110 -0
- statement_extractor/canonicalization.py +196 -0
- statement_extractor/cli.py +215 -0
- statement_extractor/extractor.py +649 -0
- statement_extractor/models.py +284 -0
- statement_extractor/predicate_comparer.py +611 -0
- statement_extractor/scoring.py +419 -0
|
@@ -0,0 +1,611 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Embedding-based predicate comparison and normalization.
|
|
3
|
+
|
|
4
|
+
Uses sentence-transformers for local, offline embedding computation.
|
|
5
|
+
Provides semantic similarity for deduplication and taxonomy matching.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
from typing import Callable, Optional

import numpy as np

from .models import (
    PredicateComparisonConfig,
    PredicateMatch,
    PredicateTaxonomy,
    Statement,
)
|
|
19
|
+
|
|
20
|
+
# Module-level logger named after this module (``__name__``); used for the
# model-loading info messages and taxonomy-embedding debug messages below.
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class EmbeddingDependencyError(Exception):
    """Signals that an embedding feature was requested while sentence-transformers is unavailable."""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _check_embedding_dependency() -> None:
    """
    Ensure sentence-transformers can be imported.

    Raises:
        EmbeddingDependencyError: If importing ``sentence_transformers`` raises
            ImportError. The original ImportError is attached as ``__cause__``.
    """
    try:
        import sentence_transformers  # noqa: F401
    except ImportError as err:
        # Chain explicitly: an ImportError here can also originate from a module
        # that sentence_transformers itself imports, and the traceback of the
        # underlying failure is what a user needs to diagnose that case.
        raise EmbeddingDependencyError(
            "Embedding-based comparison requires sentence-transformers.\n\n"
            "Install with:\n"
            " pip install corp-extractor[embeddings]\n"
            " or: pip install sentence-transformers\n\n"
            "To disable embeddings, set embedding_dedup=False in ExtractionOptions."
        ) from err
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class PredicateComparer:
    """
    Embedding-based predicate comparison and normalization.

    Features:
    - Map extracted predicates to canonical forms from a taxonomy
    - Detect duplicate/similar predicates for deduplication
    - Detect subject/object reversals by comparing against source text
    - Fully offline using sentence-transformers
    - Lazy model loading to avoid startup cost
    - Caches taxonomy embeddings for efficiency (recomputed automatically if
      the taxonomy's predicate list changes)

    Example:
        >>> taxonomy = PredicateTaxonomy(predicates=["acquired", "founded", "works_for"])
        >>> comparer = PredicateComparer(taxonomy=taxonomy)
        >>> match = comparer.match_to_canonical("bought")
        >>> print(match.canonical)  # "acquired"
        >>> print(match.similarity)  # ~0.82
    """

    def __init__(
        self,
        taxonomy: Optional[PredicateTaxonomy] = None,
        config: Optional[PredicateComparisonConfig] = None,
        device: Optional[str] = None,
    ):
        """
        Initialize the predicate comparer.

        The embedding model is not loaded here; it is loaded on the first call
        that needs embeddings.

        Args:
            taxonomy: Optional canonical predicate taxonomy for normalization
            config: Comparison configuration (uses defaults if not provided)
            device: Device to use ('cuda', 'cpu', or None for auto-detect:
                'cuda' when torch reports CUDA available, otherwise 'cpu')

        Raises:
            EmbeddingDependencyError: If sentence-transformers is not installed
        """
        _check_embedding_dependency()

        self.taxonomy = taxonomy
        self.config = config or PredicateComparisonConfig()

        if device is None:
            import torch

            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # Lazy-loaded resources
        self._model = None
        self._taxonomy_embeddings: Optional[np.ndarray] = None
        # Predicates that _taxonomy_embeddings were computed from, so the cache
        # is invalidated if self.taxonomy is replaced or its list is mutated.
        self._taxonomy_cache_key: Optional[tuple[str, ...]] = None

    def _load_model(self) -> None:
        """Load the sentence-transformers model on first use (no-op afterwards)."""
        if self._model is not None:
            return

        from sentence_transformers import SentenceTransformer

        logger.info("Loading embedding model: %s on %s", self.config.embedding_model, self.device)
        self._model = SentenceTransformer(self.config.embedding_model, device=self.device)
        logger.info("Embedding model loaded on %s", self.device)

    def _normalize_text(self, text: str) -> str:
        """Strip text before embedding, and lowercase it if config.normalize_text is set."""
        if self.config.normalize_text:
            return text.lower().strip()
        return text.strip()

    def _compute_embeddings(self, texts: list[str]) -> np.ndarray:
        """Encode texts (after normalization) into a numpy array, one row per text."""
        self._load_model()
        normalized = [self._normalize_text(t) for t in texts]
        return self._model.encode(normalized, convert_to_numpy=True)

    def _get_taxonomy_embeddings(self) -> np.ndarray:
        """
        Get cached taxonomy embeddings, (re)computing them when needed.

        Raises:
            ValueError: If no taxonomy is set.
        """
        if self.taxonomy is None:
            raise ValueError("No taxonomy provided")

        key = tuple(self.taxonomy.predicates)
        if self._taxonomy_embeddings is None or self._taxonomy_cache_key != key:
            logger.debug("Computing embeddings for %d taxonomy predicates", len(key))
            self._taxonomy_embeddings = self._compute_embeddings(list(key))
            self._taxonomy_cache_key = key

        return self._taxonomy_embeddings

    def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Compute cosine similarity between two vectors (0.0 if either is all zeros)."""
        dot = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return float(dot / (norm1 * norm2))

    def _cosine_similarity_batch(self, vec: np.ndarray, matrix: np.ndarray) -> np.ndarray:
        """Compute cosine similarity between a vector and all rows of a matrix."""
        # The 1e-8 terms avoid division by zero for all-zero vectors.
        vec_norm = vec / (np.linalg.norm(vec) + 1e-8)
        matrix_norms = np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-8
        matrix_normalized = matrix / matrix_norms
        return np.dot(matrix_normalized, vec_norm)

    @staticmethod
    def _triple_texts(statement: Statement) -> tuple[str, str]:
        """Return ("Subject Predicate Object", "Object Predicate Subject") for a statement."""
        subj = statement.subject.text
        pred = statement.predicate
        obj = statement.object.text
        return f"{subj} {pred} {obj}", f"{obj} {pred} {subj}"

    def _best_canonical_match(self, original: str, similarities: np.ndarray) -> PredicateMatch:
        """
        Build a PredicateMatch from similarities aligned with self.taxonomy.predicates.

        The highest-scoring taxonomy entry is reported as canonical only when its
        score reaches config.similarity_threshold; the score is reported either way.
        """
        best_idx = int(np.argmax(similarities))
        best_score = float(similarities[best_idx])
        if best_score >= self.config.similarity_threshold:
            return PredicateMatch(
                original=original,
                canonical=self.taxonomy.predicates[best_idx],
                similarity=best_score,
                matched=True,
            )
        return PredicateMatch(
            original=original,
            canonical=None,
            similarity=best_score,
            matched=False,
        )

    def _orientation_similarities(self, statements: list[Statement]) -> list[tuple[float, float]]:
        """
        Score both orientations of each statement against its source_text.

        All texts are encoded in a single batch. Callers must pass a non-empty
        list of statements whose source_text is set.

        Returns:
            One (normal_similarity, reversed_similarity) pair per statement,
            where "normal" is "S P O" and "reversed" is "O P S".
        """
        count = len(statements)
        pairs = [self._triple_texts(s) for s in statements]
        texts = (
            [normal for normal, _ in pairs]
            + [flipped for _, flipped in pairs]
            + [s.source_text for s in statements]
        )
        embeddings = self._compute_embeddings(texts)
        return [
            (
                self._cosine_similarity(embeddings[k], embeddings[2 * count + k]),
                self._cosine_similarity(embeddings[count + k], embeddings[2 * count + k]),
            )
            for k in range(count)
        ]

    # =========================================================================
    # Public API
    # =========================================================================

    def match_to_canonical(self, predicate: str) -> PredicateMatch:
        """
        Match a predicate to the closest canonical form in the taxonomy.

        Args:
            predicate: The extracted predicate to match

        Returns:
            PredicateMatch with canonical form and similarity score. Without a
            (non-empty) taxonomy, only ``original`` is set.
        """
        return self.match_batch([predicate])[0]

    def match_batch(self, predicates: list[str]) -> list[PredicateMatch]:
        """
        Match multiple predicates to canonical forms efficiently.

        Uses batch embedding computation for better performance.

        Args:
            predicates: List of predicates to match

        Returns:
            List of PredicateMatch results, in the same order as ``predicates``
        """
        if self.taxonomy is None or len(self.taxonomy.predicates) == 0:
            return [PredicateMatch(original=p) for p in predicates]
        if not predicates:
            # Nothing to match; avoid loading the model just to encode [].
            return []

        pred_embeddings = self._compute_embeddings(predicates)
        taxonomy_embeddings = self._get_taxonomy_embeddings()

        return [
            self._best_canonical_match(
                predicate,
                self._cosine_similarity_batch(embedding, taxonomy_embeddings),
            )
            for predicate, embedding in zip(predicates, pred_embeddings)
        ]

    def are_similar(
        self,
        pred1: str,
        pred2: str,
        threshold: Optional[float] = None
    ) -> bool:
        """
        Check if two predicates are semantically similar.

        Args:
            pred1: First predicate
            pred2: Second predicate
            threshold: Similarity threshold (uses config.dedup_threshold if not provided)

        Returns:
            True if the predicates' similarity is at or above the threshold
        """
        limit = threshold if threshold is not None else self.config.dedup_threshold
        return self.compute_similarity(pred1, pred2) >= limit

    def compute_similarity(self, pred1: str, pred2: str) -> float:
        """
        Compute similarity score between two predicates.

        Args:
            pred1: First predicate
            pred2: Second predicate

        Returns:
            Cosine similarity of the two predicate embeddings
        """
        embeddings = self._compute_embeddings([pred1, pred2])
        return self._cosine_similarity(embeddings[0], embeddings[1])

    def deduplicate_statements(
        self,
        statements: list[Statement],
        entity_canonicalizer: Optional[Callable[[str], str]] = None,
        detect_reversals: bool = True,
    ) -> list[Statement]:
        """
        Remove duplicate statements using embedding-based predicate comparison.

        Two statements are considered duplicates if their predicates are similar
        (cosine similarity of predicate embeddings >= config.dedup_threshold) AND
        either:
        - canonicalized subjects match and canonicalized objects match, or
        - (when detect_reversals=True) each one's subject matches the other's
          object, i.e. the same pair of entities in the opposite order.

        When duplicates are found, keeps the statement whose "Subject Predicate
        Object" text best matches its source text (the triple text itself is
        used when source_text is empty), merging entity types from the other.

        For reversed duplicates, both "S P O" and "O P S" are scored against
        source text to pick the orientation to keep. The other statement is
        aligned to that orientation before its entity types are merged.

        Args:
            statements: List of Statement objects
            entity_canonicalizer: Optional function to canonicalize entity text
                (defaults to lowercasing and stripping)
            detect_reversals: Whether to detect reversed duplicates (default True)

        Returns:
            Deduplicated list of statements (keeps best contextualized match).
            A list with at most one statement is returned as-is.
        """
        if len(statements) <= 1:
            return statements

        def canonicalize(text: str) -> str:
            if entity_canonicalizer:
                return entity_canonicalizer(text)
            return text.lower().strip()

        def entity_key(statement: Statement) -> tuple[str, str]:
            return canonicalize(statement.subject.text), canonicalize(statement.object.text)

        # Encode everything up front, one batch per kind of text.
        triples = [self._triple_texts(s) for s in statements]
        forward_texts = [normal for normal, _ in triples]
        pred_embeddings = self._compute_embeddings([s.predicate for s in statements])
        forward_embeddings = self._compute_embeddings(forward_texts)
        backward_embeddings = self._compute_embeddings([flipped for _, flipped in triples])
        source_embeddings = self._compute_embeddings(
            [s.source_text or normal for s, normal in zip(statements, forward_texts)]
        )

        # Parallel per-slot state for the statements kept so far. Invariant:
        # kept_forward[j] / kept_backward[j] embed "S P O" / "O P S" of kept[j]
        # in the orientation kept[j] is currently stored in.
        kept: list[Statement] = []
        kept_keys: list[tuple[str, str]] = []
        kept_pred: list[np.ndarray] = []
        kept_forward: list[np.ndarray] = []
        kept_backward: list[np.ndarray] = []
        kept_source: list[np.ndarray] = []
        columns = (kept, kept_keys, kept_pred, kept_forward, kept_backward, kept_source)

        def store(slot: int, statement: Statement, idx: int, flipped: bool) -> None:
            """Write statement (derived from input idx; reversed if flipped) into slot."""
            forward, backward = forward_embeddings[idx], backward_embeddings[idx]
            if flipped:
                forward, backward = backward, forward
            row = (
                statement,
                entity_key(statement),
                pred_embeddings[idx],
                forward,
                backward,
                source_embeddings[idx],
            )
            for column, value in zip(columns, row):
                if slot == len(column):
                    column.append(value)
                else:
                    column[slot] = value

        for i, stmt in enumerate(statements):
            subj_key, obj_key = entity_key(stmt)

            # Find the first kept statement over the same entities (in either
            # order, if enabled) whose predicate is similar enough.
            duplicate_idx: Optional[int] = None
            is_reversed_match = False
            for j, (kept_subj, kept_obj) in enumerate(kept_keys):
                direct_match = subj_key == kept_subj and obj_key == kept_obj
                reversed_match = (
                    detect_reversals and subj_key == kept_obj and obj_key == kept_subj
                )
                if not direct_match and not reversed_match:
                    continue
                similarity = self._cosine_similarity(pred_embeddings[i], kept_pred[j])
                if similarity >= self.config.dedup_threshold:
                    duplicate_idx = j
                    is_reversed_match = reversed_match and not direct_match
                    break

            if duplicate_idx is None:
                store(len(kept), stmt, i, flipped=False)
                continue

            j = duplicate_idx
            existing_stmt = kept[j]
            current_normal = self._cosine_similarity(forward_embeddings[i], source_embeddings[i])

            if not is_reversed_match:
                # Direct duplicate - keep the one with the better contextualized match.
                existing_normal = self._cosine_similarity(kept_forward[j], kept_source[j])
                if current_normal > existing_normal:
                    store(j, stmt.merge_entity_types_from(existing_stmt), i, flipped=False)
                else:
                    kept[j] = existing_stmt.merge_entity_types_from(stmt)
                    kept_keys[j] = entity_key(kept[j])
                continue

            # Reversed duplicate: stmt is stored opposite to existing_stmt. Score
            # each statement in both orientations against its own source text.
            current_reversed = self._cosine_similarity(backward_embeddings[i], source_embeddings[i])
            existing_normal = self._cosine_similarity(kept_forward[j], kept_source[j])
            existing_reversed = self._cosine_similarity(kept_backward[j], kept_source[j])

            current_should_reverse = current_reversed > current_normal
            existing_should_reverse = existing_reversed > existing_normal

            if max(current_normal, current_reversed) > max(existing_normal, existing_reversed):
                # Current is better - keep it in its best orientation.
                best_stmt = stmt.reversed() if current_should_reverse else stmt
                # existing_stmt is opposite to stmt, so it lines up with best_stmt
                # as-is only when stmt itself was flipped.
                donor = existing_stmt if current_should_reverse else existing_stmt.reversed()
                store(
                    j,
                    best_stmt.merge_entity_types_from(donor),
                    i,
                    flipped=current_should_reverse,
                )
            else:
                # Existing is better - possibly fix its orientation. A statement
                # already marked was_reversed is not flipped again.
                flip_existing = existing_should_reverse and not existing_stmt.was_reversed
                best_stmt = existing_stmt.reversed() if flip_existing else existing_stmt
                # stmt is opposite to existing_stmt, so it lines up with best_stmt
                # as-is only when existing_stmt was flipped.
                donor = stmt if flip_existing else stmt.reversed()
                kept[j] = best_stmt.merge_entity_types_from(donor)
                kept_keys[j] = entity_key(kept[j])
                if flip_existing:
                    # Keep the cached orientation embeddings in sync with kept[j].
                    kept_forward[j], kept_backward[j] = kept_backward[j], kept_forward[j]

        return kept

    def normalize_predicates(
        self,
        statements: list[Statement]
    ) -> list[Statement]:
        """
        Normalize all predicates in statements to canonical forms.

        Uses contextualized matching: compares "Subject CanonicalPredicate Object"
        against the statement's source text for better semantic matching.

        Sets the canonical_predicate field on each statement (in place) if a
        match is found; statements without a match are left unchanged.

        Args:
            statements: List of Statement objects

        Returns:
            The same list, with canonical_predicate populated where matched
        """
        if self.taxonomy is None or len(self.taxonomy.predicates) == 0:
            return statements

        for stmt in statements:
            match = self._match_predicate_contextualized(stmt)
            if match.matched and match.canonical:
                stmt.canonical_predicate = match.canonical

        return statements

    def _match_predicate_contextualized(self, statement: Statement) -> PredicateMatch:
        """
        Match a statement's predicate to canonical form using full context.

        Compares "Subject CanonicalPredicate Object" strings against the
        statement's source text (or its own "S P O" text when source_text is
        empty) for better semantic matching.

        Args:
            statement: The statement to match

        Returns:
            PredicateMatch with best canonical form
        """
        if self.taxonomy is None or len(self.taxonomy.predicates) == 0:
            return PredicateMatch(original=statement.predicate)

        normal_text, _ = self._triple_texts(statement)
        reference_text = statement.source_text or normal_text
        reference_embedding = self._compute_embeddings([reference_text])[0]

        # One candidate sentence per canonical predicate, in taxonomy order.
        canonical_statements = [
            f"{statement.subject.text} {canonical_pred} {statement.object.text}"
            for canonical_pred in self.taxonomy.predicates
        ]
        canonical_embeddings = self._compute_embeddings(canonical_statements)

        similarities = self._cosine_similarity_batch(reference_embedding, canonical_embeddings)
        return self._best_canonical_match(statement.predicate, similarities)

    def detect_and_fix_reversals(
        self,
        statements: list[Statement],
        threshold: float = 0.05,
    ) -> list[Statement]:
        """
        Detect and fix subject-object reversals using embedding comparison.

        For each statement with source_text, compares:
        - "Subject Predicate Object" embedding against source_text
        - "Object Predicate Subject" embedding against source_text

        If the reversed version's similarity exceeds the normal one by more than
        ``threshold``, the statement is replaced by ``statement.reversed()``.
        Statements without source_text are passed through unchanged. All
        embeddings are computed in a single batch.

        Args:
            statements: List of Statement objects
            threshold: Minimum similarity difference to trigger reversal (default 0.05)

        Returns:
            New list of statements with reversals corrected (the input list is
            returned as-is when empty)
        """
        if not statements:
            return statements

        result = list(statements)
        checkable = [k for k, stmt in enumerate(statements) if stmt.source_text]
        if not checkable:
            return result

        scores = self._orientation_similarities([statements[k] for k in checkable])
        for k, (normal_sim, reversed_sim) in zip(checkable, scores):
            if reversed_sim > normal_sim + threshold:
                result[k] = statements[k].reversed()

        return result

    def check_reversal(
        self,
        statement: Statement,
        threshold: float = 0.0,
    ) -> tuple[bool, float, float]:
        """
        Check if a single statement should be reversed.

        Args:
            statement: Statement to check
            threshold: Margin by which the reversed similarity must exceed the
                normal one (default 0.0, i.e. any improvement)

        Returns:
            Tuple of (should_reverse, normal_similarity, reversed_similarity);
            (False, 0.0, 0.0) when the statement has no source_text
        """
        if not statement.source_text:
            return (False, 0.0, 0.0)

        normal_sim, reversed_sim = self._orientation_similarities([statement])[0]
        return (reversed_sim > normal_sim + threshold, normal_sim, reversed_sim)
|