corp_extractor-0.2.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of corp-extractor has been flagged as potentially problematic.

@@ -0,0 +1,611 @@
+"""
+Embedding-based predicate comparison and normalization.
+
+Uses sentence-transformers for local, offline embedding computation.
+Provides semantic similarity for deduplication and taxonomy matching.
+"""
+
+import logging
+from typing import Callable, Optional
+
+import numpy as np
+
+from .models import (
+    PredicateComparisonConfig,
+    PredicateMatch,
+    PredicateTaxonomy,
+    Statement,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingDependencyError(Exception):
+    """Raised when sentence-transformers is required but not installed."""
+    pass
+
+
+def _check_embedding_dependency():
+    """Check that sentence-transformers is installed; raise a helpful error if not."""
+    try:
+        import sentence_transformers  # noqa: F401
+    except ImportError:
+        raise EmbeddingDependencyError(
+            "Embedding-based comparison requires sentence-transformers.\n\n"
+            "Install with:\n"
+            "  pip install corp-extractor[embeddings]\n"
+            "  or: pip install sentence-transformers\n\n"
+            "To disable embeddings, set embedding_dedup=False in ExtractionOptions."
+        )
+
+
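A minimal sketch of how a caller might handle this dependency check (the fallback-to-None pattern here is illustrative, not part of the package API):

    try:
        comparer = PredicateComparer()
    except EmbeddingDependencyError as exc:
        # Degrade gracefully instead of aborting extraction (illustrative)
        logging.getLogger(__name__).warning("Embedding comparison unavailable: %s", exc)
        comparer = None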
+class PredicateComparer:
+    """
+    Embedding-based predicate comparison and normalization.
+
+    Features:
+    - Maps extracted predicates to canonical forms from a taxonomy
+    - Detects duplicate/similar predicates for deduplication
+    - Runs fully offline using sentence-transformers
+    - Loads the embedding model lazily to avoid startup cost
+    - Caches taxonomy embeddings for efficiency
+
+    Example:
+        >>> taxonomy = PredicateTaxonomy(predicates=["acquired", "founded", "works_for"])
+        >>> comparer = PredicateComparer(taxonomy=taxonomy)
+        >>> match = comparer.match_to_canonical("bought")
+        >>> print(match.canonical)   # "acquired"
+        >>> print(match.similarity)  # ~0.82
+    """
+
+    def __init__(
+        self,
+        taxonomy: Optional[PredicateTaxonomy] = None,
+        config: Optional[PredicateComparisonConfig] = None,
+        device: Optional[str] = None,
+    ):
+        """
+        Initialize the predicate comparer.
+
+        Args:
+            taxonomy: Optional canonical predicate taxonomy for normalization
+            config: Comparison configuration (uses defaults if not provided)
+            device: Device to use ('cuda', 'cpu', or None for auto-detect)
+
+        Raises:
+            EmbeddingDependencyError: If sentence-transformers is not installed
+        """
+        _check_embedding_dependency()
+
+        self.taxonomy = taxonomy
+        self.config = config or PredicateComparisonConfig()
+
+        # Auto-detect device
+        if device is None:
+            import torch
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
+
+        # Lazy-loaded resources
+        self._model = None
+        self._taxonomy_embeddings: Optional[np.ndarray] = None
+
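A construction sketch; the PredicateComparisonConfig keyword below is assumed from how self.config.similarity_threshold is read later in this module (the real field names live in .models):

    # Pin inference to CPU; the similarity_threshold kwarg is assumed from usage below.
    config = PredicateComparisonConfig(similarity_threshold=0.75)
    comparer = PredicateComparer(taxonomy=taxonomy, config=config, device="cpu")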
+    def _load_model(self):
+        """Load the sentence-transformers model lazily."""
+        if self._model is not None:
+            return
+
+        from sentence_transformers import SentenceTransformer
+
+        logger.info(f"Loading embedding model: {self.config.embedding_model} on {self.device}")
+        self._model = SentenceTransformer(self.config.embedding_model, device=self.device)
+        logger.info(f"Embedding model loaded on {self.device}")
+
+    def _normalize_text(self, text: str) -> str:
+        """Normalize text before embedding."""
+        if self.config.normalize_text:
+            return text.lower().strip()
+        return text.strip()
+
+    def _compute_embeddings(self, texts: list[str]) -> np.ndarray:
+        """Compute embeddings for a list of texts."""
+        self._load_model()
+        normalized = [self._normalize_text(t) for t in texts]
+        return self._model.encode(normalized, convert_to_numpy=True)
+
+    def _get_taxonomy_embeddings(self) -> np.ndarray:
+        """Get or compute cached taxonomy embeddings."""
+        if self.taxonomy is None:
+            raise ValueError("No taxonomy provided")
+
+        if self._taxonomy_embeddings is None:
+            logger.debug(f"Computing embeddings for {len(self.taxonomy.predicates)} taxonomy predicates")
+            self._taxonomy_embeddings = self._compute_embeddings(self.taxonomy.predicates)
+
+        return self._taxonomy_embeddings
+
+    def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
+        """Compute cosine similarity between two vectors."""
+        dot = np.dot(vec1, vec2)
+        norm1 = np.linalg.norm(vec1)
+        norm2 = np.linalg.norm(vec2)
+        if norm1 == 0 or norm2 == 0:
+            return 0.0
+        return float(dot / (norm1 * norm2))
+
+    def _cosine_similarity_batch(self, vec: np.ndarray, matrix: np.ndarray) -> np.ndarray:
+        """Compute cosine similarity between a vector and all rows of a matrix."""
+        vec_norm = vec / (np.linalg.norm(vec) + 1e-8)
+        matrix_norms = np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-8
+        matrix_normalized = matrix / matrix_norms
+        return np.dot(matrix_normalized, vec_norm)
+
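A standalone numpy sanity sketch (not part of the module) showing that the batched helper agrees with the pairwise cosine up to the 1e-8 stabilizer:

    import numpy as np

    rng = np.random.default_rng(0)
    vec, matrix = rng.normal(size=8), rng.normal(size=(5, 8))
    # Plain cosine per row, as _cosine_similarity computes it
    pairwise = matrix @ vec / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(vec))
    # Stabilized batched form, as _cosine_similarity_batch computes it
    batched = (matrix / (np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-8)) @ (
        vec / (np.linalg.norm(vec) + 1e-8)
    )
    assert np.allclose(pairwise, batched, atol=1e-6)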
+    # =========================================================================
+    # Public API
+    # =========================================================================
+
+    def match_to_canonical(self, predicate: str) -> PredicateMatch:
+        """
+        Match a predicate to the closest canonical form in the taxonomy.
+
+        Args:
+            predicate: The extracted predicate to match
+
+        Returns:
+            PredicateMatch with canonical form and similarity score
+        """
+        if self.taxonomy is None or len(self.taxonomy.predicates) == 0:
+            return PredicateMatch(original=predicate)
+
+        pred_embedding = self._compute_embeddings([predicate])[0]
+        taxonomy_embeddings = self._get_taxonomy_embeddings()
+
+        similarities = self._cosine_similarity_batch(pred_embedding, taxonomy_embeddings)
+        best_idx = int(np.argmax(similarities))
+        best_score = float(similarities[best_idx])
+
+        if best_score >= self.config.similarity_threshold:
+            return PredicateMatch(
+                original=predicate,
+                canonical=self.taxonomy.predicates[best_idx],
+                similarity=best_score,
+                matched=True,
+            )
+        else:
+            return PredicateMatch(
+                original=predicate,
+                canonical=None,
+                similarity=best_score,
+                matched=False,
+            )
+
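A usage sketch covering both branches (predicate values illustrative):

    match = comparer.match_to_canonical("divested")
    if match.matched:
        print(match.original, "->", match.canonical, match.similarity)
    else:
        # Below config.similarity_threshold: canonical is None, but the best
        # score is still reported
        print("no canonical match", match.similarity)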
+    def match_batch(self, predicates: list[str]) -> list[PredicateMatch]:
+        """
+        Match multiple predicates to canonical forms efficiently.
+
+        Uses batch embedding computation for better performance.
+
+        Args:
+            predicates: List of predicates to match
+
+        Returns:
+            List of PredicateMatch results
+        """
+        if self.taxonomy is None or len(self.taxonomy.predicates) == 0:
+            return [PredicateMatch(original=p) for p in predicates]
+
+        # Batch embedding computation
+        pred_embeddings = self._compute_embeddings(predicates)
+        taxonomy_embeddings = self._get_taxonomy_embeddings()
+
+        results = []
+        for i, predicate in enumerate(predicates):
+            similarities = self._cosine_similarity_batch(pred_embeddings[i], taxonomy_embeddings)
+            best_idx = int(np.argmax(similarities))
+            best_score = float(similarities[best_idx])
+
+            if best_score >= self.config.similarity_threshold:
+                results.append(PredicateMatch(
+                    original=predicate,
+                    canonical=self.taxonomy.predicates[best_idx],
+                    similarity=best_score,
+                    matched=True,
+                ))
+            else:
+                results.append(PredicateMatch(
+                    original=predicate,
+                    canonical=None,
+                    similarity=best_score,
+                    matched=False,
+                ))
+
+        return results
+
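A batch usage sketch (inputs illustrative); one encode call covers all predicates, so this is preferable to repeated match_to_canonical calls:

    matches = comparer.match_batch(["bought", "hired", "established"])
    for m in matches:
        print(m.original, "->", m.canonical if m.matched else "(unmatched)")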
+    def are_similar(
+        self,
+        pred1: str,
+        pred2: str,
+        threshold: Optional[float] = None,
+    ) -> bool:
+        """
+        Check if two predicates are semantically similar.
+
+        Args:
+            pred1: First predicate
+            pred2: Second predicate
+            threshold: Similarity threshold (uses config.dedup_threshold if not provided)
+
+        Returns:
+            True if the predicates' similarity meets or exceeds the threshold
+        """
+        embeddings = self._compute_embeddings([pred1, pred2])
+        similarity = self._cosine_similarity(embeddings[0], embeddings[1])
+        threshold = threshold if threshold is not None else self.config.dedup_threshold
+        return similarity >= threshold
+
+    def compute_similarity(self, pred1: str, pred2: str) -> float:
+        """
+        Compute the similarity score between two predicates.
+
+        Args:
+            pred1: First predicate
+            pred2: Second predicate
+
+        Returns:
+            Cosine similarity score (-1.0 to 1.0; typically 0.0 to 1.0
+            for sentence embeddings)
+        """
+        embeddings = self._compute_embeddings([pred1, pred2])
+        return self._cosine_similarity(embeddings[0], embeddings[1])
+
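A pairwise usage sketch; are_similar falls back to config.dedup_threshold when no threshold is passed:

    comparer.are_similar("acquired", "bought")        # uses config.dedup_threshold
    comparer.are_similar("acquired", "bought", 0.95)  # stricter explicit cutoff
    score = comparer.compute_similarity("acquired", "founded")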
+    def deduplicate_statements(
+        self,
+        statements: list[Statement],
+        entity_canonicalizer: Optional[Callable[[str], str]] = None,
+        detect_reversals: bool = True,
+    ) -> list[Statement]:
+        """
+        Remove duplicate statements using embedding-based predicate comparison.
+
+        Two statements are considered duplicates if their predicates are
+        similar (embedding-based) and either:
+        - canonicalized subjects match AND canonicalized objects match, or
+        - the subject of each matches the object of the other (reversed),
+          when detect_reversals=True
+
+        When duplicates are found, keeps the statement with the better
+        contextualized match (comparing "Subject Predicate Object" against
+        the source text).
+
+        For reversed duplicates, the correct orientation is determined by
+        comparing both "S P O" and "O P S" against the source text.
+
+        Args:
+            statements: List of Statement objects
+            entity_canonicalizer: Optional function to canonicalize entity text
+            detect_reversals: Whether to detect reversed duplicates (default True)
+
+        Returns:
+            Deduplicated list of statements (keeps best contextualized match)
+        """
+        if len(statements) <= 1:
+            return statements
+
+        def canonicalize(text: str) -> str:
+            if entity_canonicalizer:
+                return entity_canonicalizer(text)
+            return text.lower().strip()
+
+        # Compute all predicate embeddings at once for efficiency
+        predicates = [s.predicate for s in statements]
+        pred_embeddings = self._compute_embeddings(predicates)
+
+        # Compute contextualized embeddings: "Subject Predicate Object" for each statement
+        contextualized_texts = [
+            f"{s.subject.text} {s.predicate} {s.object.text}" for s in statements
+        ]
+        contextualized_embeddings = self._compute_embeddings(contextualized_texts)
+
+        # Compute reversed contextualized embeddings: "Object Predicate Subject"
+        reversed_texts = [
+            f"{s.object.text} {s.predicate} {s.subject.text}" for s in statements
+        ]
+        reversed_embeddings = self._compute_embeddings(reversed_texts)
+
+        # Compute source text embeddings for scoring which duplicate to keep
+        source_embeddings = []
+        for stmt in statements:
+            source_text = stmt.source_text or f"{stmt.subject.text} {stmt.predicate} {stmt.object.text}"
+            source_embeddings.append(self._compute_embeddings([source_text])[0])
+
+        unique_statements: list[Statement] = []
+        unique_pred_embeddings: list[np.ndarray] = []
+        unique_context_embeddings: list[np.ndarray] = []
+        unique_reversed_embeddings: list[np.ndarray] = []
+        unique_source_embeddings: list[np.ndarray] = []
+        unique_indices: list[int] = []
+
+        for i, stmt in enumerate(statements):
+            subj_canon = canonicalize(stmt.subject.text)
+            obj_canon = canonicalize(stmt.object.text)
+
+            duplicate_idx = None
+            is_reversed_match = False
+
+            for j, unique_stmt in enumerate(unique_statements):
+                unique_subj = canonicalize(unique_stmt.subject.text)
+                unique_obj = canonicalize(unique_stmt.object.text)
+
+                # Check direct match: subject->subject, object->object
+                direct_match = (subj_canon == unique_subj and obj_canon == unique_obj)
+
+                # Check reversed match: subject->object, object->subject
+                reversed_match = (
+                    detect_reversals and
+                    subj_canon == unique_obj and
+                    obj_canon == unique_subj
+                )
+
+                if not direct_match and not reversed_match:
+                    continue
+
+                # Check predicate similarity
+                similarity = self._cosine_similarity(
+                    pred_embeddings[i],
+                    unique_pred_embeddings[j],
+                )
+                if similarity >= self.config.dedup_threshold:
+                    duplicate_idx = j
+                    is_reversed_match = reversed_match and not direct_match
+                    break
+
+            if duplicate_idx is None:
+                # Not a duplicate - add to unique list
+                unique_statements.append(stmt)
+                unique_pred_embeddings.append(pred_embeddings[i])
+                unique_context_embeddings.append(contextualized_embeddings[i])
+                unique_reversed_embeddings.append(reversed_embeddings[i])
+                unique_source_embeddings.append(source_embeddings[i])
+                unique_indices.append(i)
+            else:
+                existing_stmt = unique_statements[duplicate_idx]
+
+                if is_reversed_match:
+                    # Reversed duplicate - determine correct orientation using source text
+                    # Compare current's normal vs reversed against its source
+                    current_normal_score = self._cosine_similarity(
+                        contextualized_embeddings[i], source_embeddings[i]
+                    )
+                    current_reversed_score = self._cosine_similarity(
+                        reversed_embeddings[i], source_embeddings[i]
+                    )
+                    # Compare existing's normal vs reversed against its source
+                    existing_normal_score = self._cosine_similarity(
+                        unique_context_embeddings[duplicate_idx],
+                        unique_source_embeddings[duplicate_idx],
+                    )
+                    existing_reversed_score = self._cosine_similarity(
+                        unique_reversed_embeddings[duplicate_idx],
+                        unique_source_embeddings[duplicate_idx],
+                    )
+
+                    # Determine best orientation for current
+                    current_best = max(current_normal_score, current_reversed_score)
+                    current_should_reverse = current_reversed_score > current_normal_score
+
+                    # Determine best orientation for existing
+                    existing_best = max(existing_normal_score, existing_reversed_score)
+                    existing_should_reverse = existing_reversed_score > existing_normal_score
+
+                    if current_best > existing_best:
+                        # Current is better - use it (possibly reversed)
+                        if current_should_reverse:
+                            best_stmt = stmt.reversed()
+                        else:
+                            best_stmt = stmt
+                        # Merge entity types from existing (accounting for reversal)
+                        if existing_should_reverse:
+                            best_stmt = best_stmt.merge_entity_types_from(existing_stmt.reversed())
+                        else:
+                            best_stmt = best_stmt.merge_entity_types_from(existing_stmt)
+                        unique_statements[duplicate_idx] = best_stmt
+                        unique_pred_embeddings[duplicate_idx] = pred_embeddings[i]
+                        unique_context_embeddings[duplicate_idx] = contextualized_embeddings[i]
+                        unique_reversed_embeddings[duplicate_idx] = reversed_embeddings[i]
+                        unique_source_embeddings[duplicate_idx] = source_embeddings[i]
+                        unique_indices[duplicate_idx] = i
+                    else:
+                        # Existing is better - possibly fix its orientation
+                        if existing_should_reverse and not existing_stmt.was_reversed:
+                            best_stmt = existing_stmt.reversed()
+                        else:
+                            best_stmt = existing_stmt
+                        # Merge entity types from current (accounting for reversal)
+                        if current_should_reverse:
+                            best_stmt = best_stmt.merge_entity_types_from(stmt.reversed())
+                        else:
+                            best_stmt = best_stmt.merge_entity_types_from(stmt)
+                        unique_statements[duplicate_idx] = best_stmt
+                else:
+                    # Direct duplicate - keep the one with the better contextualized match
+                    current_score = self._cosine_similarity(
+                        contextualized_embeddings[i], source_embeddings[i]
+                    )
+                    existing_score = self._cosine_similarity(
+                        unique_context_embeddings[duplicate_idx],
+                        unique_source_embeddings[duplicate_idx],
+                    )
+
+                    if current_score > existing_score:
+                        # Current statement is a better match - replace
+                        merged_stmt = stmt.merge_entity_types_from(existing_stmt)
+                        unique_statements[duplicate_idx] = merged_stmt
+                        unique_pred_embeddings[duplicate_idx] = pred_embeddings[i]
+                        unique_context_embeddings[duplicate_idx] = contextualized_embeddings[i]
+                        unique_reversed_embeddings[duplicate_idx] = reversed_embeddings[i]
+                        unique_source_embeddings[duplicate_idx] = source_embeddings[i]
+                        unique_indices[duplicate_idx] = i
+                    else:
+                        # Existing statement is better - merge entity types from current
+                        merged_stmt = existing_stmt.merge_entity_types_from(stmt)
+                        unique_statements[duplicate_idx] = merged_stmt
+
+        return unique_statements
+
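A dedup usage sketch; extract_statements is a hypothetical upstream producer of Statement objects, and the lambda is just one possible canonicalizer:

    statements = extract_statements(text)  # hypothetical producer
    deduped = comparer.deduplicate_statements(
        statements,
        entity_canonicalizer=lambda t: t.casefold().strip(),
        detect_reversals=True,
    )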
+    def normalize_predicates(
+        self,
+        statements: list[Statement],
+    ) -> list[Statement]:
+        """
+        Normalize all predicates in statements to canonical forms.
+
+        Uses contextualized matching: compares "Subject CanonicalPredicate Object"
+        against the statement's source text for better semantic matching.
+
+        Sets the canonical_predicate field on each statement if a match is found.
+
+        Args:
+            statements: List of Statement objects
+
+        Returns:
+            Statements with the canonical_predicate field populated
+        """
+        if self.taxonomy is None or len(self.taxonomy.predicates) == 0:
+            return statements
+
+        for stmt in statements:
+            match = self._match_predicate_contextualized(stmt)
+            if match.matched and match.canonical:
+                stmt.canonical_predicate = match.canonical
+
+        return statements
+
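A usage sketch; canonical_predicate is only set where a contextualized match clears the threshold:

    for stmt in comparer.normalize_predicates(statements):
        print(stmt.predicate, "->", stmt.canonical_predicate)  # None/unset if unmatched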
+    def _match_predicate_contextualized(self, statement: Statement) -> PredicateMatch:
+        """
+        Match a statement's predicate to canonical form using full context.
+
+        Compares "Subject CanonicalPredicate Object" strings against the
+        statement's source text for better semantic matching.
+
+        Args:
+            statement: The statement to match
+
+        Returns:
+            PredicateMatch with best canonical form
+        """
+        if self.taxonomy is None or len(self.taxonomy.predicates) == 0:
+            return PredicateMatch(original=statement.predicate)
+
+        # Get the reference text to compare against.
+        # Use source_text if available, otherwise construct from components.
+        reference_text = statement.source_text or (
+            f"{statement.subject.text} {statement.predicate} {statement.object.text}"
+        )
+
+        # Compute embedding for the reference text
+        reference_embedding = self._compute_embeddings([reference_text])[0]
+
+        # Construct contextualized strings for each canonical predicate.
+        # Format: "Subject CanonicalPredicate Object"
+        canonical_statements = [
+            f"{statement.subject.text} {canonical_pred} {statement.object.text}"
+            for canonical_pred in self.taxonomy.predicates
+        ]
+
+        # Compute embeddings for all canonical statement forms
+        canonical_embeddings = self._compute_embeddings(canonical_statements)
+
+        # Find the best match
+        similarities = self._cosine_similarity_batch(reference_embedding, canonical_embeddings)
+        best_idx = int(np.argmax(similarities))
+        best_score = float(similarities[best_idx])
+
+        if best_score >= self.config.similarity_threshold:
+            return PredicateMatch(
+                original=statement.predicate,
+                canonical=self.taxonomy.predicates[best_idx],
+                similarity=best_score,
+                matched=True,
+            )
+        else:
+            return PredicateMatch(
+                original=statement.predicate,
+                canonical=None,
+                similarity=best_score,
+                matched=False,
+            )
+
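To make the contextualized comparison concrete, a sketch with hypothetical values:

    # subject "Acme", object "Widget Co", taxonomy ["acquired", "founded", "works_for"]
    candidates = [f"Acme {p} Widget Co" for p in ["acquired", "founded", "works_for"]]
    # The candidate whose embedding is closest to the source sentence wins,
    # provided it clears config.similarity_threshold.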
+    def detect_and_fix_reversals(
+        self,
+        statements: list[Statement],
+        threshold: float = 0.05,
+    ) -> list[Statement]:
+        """
+        Detect and fix subject-object reversals using embedding comparison.
+
+        For each statement, compares:
+        - the "Subject Predicate Object" embedding against source_text
+        - the "Object Predicate Subject" embedding against source_text
+
+        If the reversed version has significantly higher similarity to the source,
+        the subject and object are swapped and was_reversed is set to True.
+
+        Args:
+            statements: List of Statement objects
+            threshold: Minimum similarity difference to trigger reversal (default 0.05)
+
+        Returns:
+            List of statements with reversals corrected
+        """
+        if not statements:
+            return statements
+
+        result = []
+        for stmt in statements:
+            # Skip if there is no source_text to compare against
+            if not stmt.source_text:
+                result.append(stmt)
+                continue
+
+            # Build normal and reversed triple strings
+            normal_text = f"{stmt.subject.text} {stmt.predicate} {stmt.object.text}"
+            reversed_text = f"{stmt.object.text} {stmt.predicate} {stmt.subject.text}"
+
+            # Compute embeddings for normal, reversed, and source
+            embeddings = self._compute_embeddings([normal_text, reversed_text, stmt.source_text])
+            normal_emb, reversed_emb, source_emb = embeddings[0], embeddings[1], embeddings[2]
+
+            # Compute similarities to the source
+            normal_sim = self._cosine_similarity(normal_emb, source_emb)
+            reversed_sim = self._cosine_similarity(reversed_emb, source_emb)
+
+            # If reversed is significantly better, swap subject and object
+            if reversed_sim > normal_sim + threshold:
+                result.append(stmt.reversed())
+            else:
+                result.append(stmt)
+
+        return result
+
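A usage sketch; statements without source_text pass through unchanged, and the margin keeps near-symmetric phrasings from flip-flopping:

    fixed = comparer.detect_and_fix_reversals(statements, threshold=0.05)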
+    def check_reversal(self, statement: Statement) -> tuple[bool, float, float]:
+        """
+        Check whether a single statement should be reversed.
+
+        Args:
+            statement: Statement to check
+
+        Returns:
+            Tuple of (should_reverse, normal_similarity, reversed_similarity)
+        """
+        if not statement.source_text:
+            return (False, 0.0, 0.0)
+
+        normal_text = f"{statement.subject.text} {statement.predicate} {statement.object.text}"
+        reversed_text = f"{statement.object.text} {statement.predicate} {statement.subject.text}"
+
+        embeddings = self._compute_embeddings([normal_text, reversed_text, statement.source_text])
+        normal_emb, reversed_emb, source_emb = embeddings[0], embeddings[1], embeddings[2]
+
+        normal_sim = self._cosine_similarity(normal_emb, source_emb)
+        reversed_sim = self._cosine_similarity(reversed_emb, source_emb)
+
+        return (reversed_sim > normal_sim, normal_sim, reversed_sim)
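A per-statement usage sketch. Unlike detect_and_fix_reversals, check_reversal applies no margin, so a caller wanting equivalent behavior re-applies it:

    should_reverse, normal_sim, reversed_sim = comparer.check_reversal(stmt)
    if should_reverse and (reversed_sim - normal_sim) > 0.05:
        stmt = stmt.reversed()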