corp-extractor 0.2.2__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: corp-extractor
- Version: 0.2.2
+ Version: 0.2.3
  Summary: Extract structured statements from text using T5-Gemma 2 and Diverse Beam Search
  Project-URL: Homepage, https://github.com/corp-o-rate/statement-extractor
  Project-URL: Documentation, https://github.com/corp-o-rate/statement-extractor#readme
@@ -55,6 +55,8 @@ Extract structured subject-predicate-object statements from unstructured text us
  - **Embedding-based Dedup** *(v0.2.0)*: Uses semantic similarity to detect near-duplicate predicates
  - **Predicate Taxonomies** *(v0.2.0)*: Normalize predicates to canonical forms via embeddings
  - **Contextualized Matching** *(v0.2.2)*: Compares full "Subject Predicate Object" against source text for better accuracy
+ - **Entity Type Merging** *(v0.2.3)*: Automatically merges UNKNOWN entity types with specific types during deduplication
+ - **Reversal Detection** *(v0.2.3)*: Detects and corrects subject-object reversals using embedding comparison
  - **Multiple Output Formats**: Get results as Pydantic models, JSON, XML, or dictionaries
  
  ## Installation
@@ -139,6 +141,47 @@ Predicate canonicalization and deduplication now use **contextualized matching**
  
  This means "Apple bought Beats" vs "Apple acquired Beats" are compared holistically, not just "bought" vs "acquired".
  
+ ## New in v0.2.3: Entity Type Merging & Reversal Detection
+
+ ### Entity Type Merging
+
+ When deduplicating statements, entity types are now automatically merged. If one statement has `UNKNOWN` type and a duplicate has a specific type (like `ORG` or `PERSON`), the specific type is preserved:
+
+ ```python
+ # Before deduplication:
+ # Statement 1: AtlasBio Labs (UNKNOWN) --sued by--> CuraPharm (ORG)
+ # Statement 2: AtlasBio Labs (ORG) --sued by--> CuraPharm (ORG)
+
+ # After deduplication:
+ # Single statement: AtlasBio Labs (ORG) --sued by--> CuraPharm (ORG)
+ ```
+
+ ### Subject-Object Reversal Detection
+
+ The library now detects when subject and object may have been extracted in the wrong order by comparing embeddings against source text:
+
+ ```python
+ from statement_extractor import PredicateComparer
+
+ comparer = PredicateComparer()
+
+ # Automatically detect and fix reversals
+ fixed_statements = comparer.detect_and_fix_reversals(statements)
+
+ for stmt in fixed_statements:
+     if stmt.was_reversed:
+         print(f"Fixed reversal: {stmt}")
+ ```
+
+ **How it works:**
+ 1. For each statement with source text, compares:
+    - "Subject Predicate Object" embedding vs source text
+    - "Object Predicate Subject" embedding vs source text
+ 2. If the reversed form has higher similarity, swaps subject and object
+ 3. Sets `was_reversed=True` to indicate the correction
+
+ During deduplication, reversed duplicates (e.g., "A -> P -> B" and "B -> P -> A") are now detected and merged, with the correct orientation determined by source text similarity.
+
  ## Disable Embeddings (Faster, No Extra Dependencies)
  
  ```python
@@ -213,6 +256,8 @@ This library uses the T5-Gemma 2 statement extraction model with **Diverse Beam
  4. **Embedding Dedup** *(v0.2.0)*: Semantic similarity removes near-duplicate predicates
  5. **Predicate Normalization** *(v0.2.0)*: Optional taxonomy matching via embeddings
  6. **Contextualized Matching** *(v0.2.2)*: Full statement context used for canonicalization and dedup
+ 7. **Entity Type Merging** *(v0.2.3)*: UNKNOWN types merged with specific types during dedup
+ 8. **Reversal Detection** *(v0.2.3)*: Subject-object reversals detected and corrected via embedding comparison
  
  ## Requirements
  
@@ -15,6 +15,8 @@ Extract structured subject-predicate-object statements from unstructured text us
  - **Embedding-based Dedup** *(v0.2.0)*: Uses semantic similarity to detect near-duplicate predicates
  - **Predicate Taxonomies** *(v0.2.0)*: Normalize predicates to canonical forms via embeddings
  - **Contextualized Matching** *(v0.2.2)*: Compares full "Subject Predicate Object" against source text for better accuracy
+ - **Entity Type Merging** *(v0.2.3)*: Automatically merges UNKNOWN entity types with specific types during deduplication
+ - **Reversal Detection** *(v0.2.3)*: Detects and corrects subject-object reversals using embedding comparison
  - **Multiple Output Formats**: Get results as Pydantic models, JSON, XML, or dictionaries
  
  ## Installation
@@ -99,6 +101,47 @@ Predicate canonicalization and deduplication now use **contextualized matching**
  
  This means "Apple bought Beats" vs "Apple acquired Beats" are compared holistically, not just "bought" vs "acquired".
  
+ ## New in v0.2.3: Entity Type Merging & Reversal Detection
+
+ ### Entity Type Merging
+
+ When deduplicating statements, entity types are now automatically merged. If one statement has `UNKNOWN` type and a duplicate has a specific type (like `ORG` or `PERSON`), the specific type is preserved:
+
+ ```python
+ # Before deduplication:
+ # Statement 1: AtlasBio Labs (UNKNOWN) --sued by--> CuraPharm (ORG)
+ # Statement 2: AtlasBio Labs (ORG) --sued by--> CuraPharm (ORG)
+
+ # After deduplication:
+ # Single statement: AtlasBio Labs (ORG) --sued by--> CuraPharm (ORG)
+ ```
+
+ ### Subject-Object Reversal Detection
+
+ The library now detects when subject and object may have been extracted in the wrong order by comparing embeddings against source text:
+
+ ```python
+ from statement_extractor import PredicateComparer
+
+ comparer = PredicateComparer()
+
+ # Automatically detect and fix reversals
+ fixed_statements = comparer.detect_and_fix_reversals(statements)
+
+ for stmt in fixed_statements:
+     if stmt.was_reversed:
+         print(f"Fixed reversal: {stmt}")
+ ```
+
+ **How it works:**
+ 1. For each statement with source text, compares:
+    - "Subject Predicate Object" embedding vs source text
+    - "Object Predicate Subject" embedding vs source text
+ 2. If the reversed form has higher similarity, swaps subject and object
+ 3. Sets `was_reversed=True` to indicate the correction
+
+ During deduplication, reversed duplicates (e.g., "A -> P -> B" and "B -> P -> A") are now detected and merged, with the correct orientation determined by source text similarity.
+
  ## Disable Embeddings (Faster, No Extra Dependencies)
  
  ```python
@@ -173,6 +216,8 @@ This library uses the T5-Gemma 2 statement extraction model with **Diverse Beam
  4. **Embedding Dedup** *(v0.2.0)*: Semantic similarity removes near-duplicate predicates
  5. **Predicate Normalization** *(v0.2.0)*: Optional taxonomy matching via embeddings
  6. **Contextualized Matching** *(v0.2.2)*: Full statement context used for canonicalization and dedup
+ 7. **Entity Type Merging** *(v0.2.3)*: UNKNOWN types merged with specific types during dedup
+ 8. **Reversal Detection** *(v0.2.3)*: Subject-object reversals detected and corrected via embedding comparison
  
  ## Requirements
  
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
  
  [project]
  name = "corp-extractor"
- version = "0.2.2"
+ version = "0.2.3"
  description = "Extract structured statements from text using T5-Gemma 2 and Diverse Beam Search"
  readme = "README.md"
  requires-python = ">=3.10"
@@ -139,32 +139,58 @@ class Canonicalizer:
  
  def deduplicate_statements_exact(
      statements: list[Statement],
-     entity_canonicalizer: Optional[Callable[[str], str]] = None
+     entity_canonicalizer: Optional[Callable[[str], str]] = None,
+     detect_reversals: bool = True,
  ) -> list[Statement]:
      """
      Deduplicate statements using exact text matching.
  
      Use this when embedding-based deduplication is disabled.
+     When duplicates are found, entity types are merged - specific types
+     (ORG, PERSON, etc.) take precedence over UNKNOWN.
+
+     When detect_reversals=True, also detects reversed duplicates where
+     subject and object are swapped. The first occurrence determines the
+     canonical orientation.
  
      Args:
          statements: List of statements to deduplicate
          entity_canonicalizer: Optional custom canonicalization function
+         detect_reversals: Whether to detect reversed duplicates (default True)
  
      Returns:
-         Deduplicated list (keeps first occurrence)
+         Deduplicated list with merged entity types
      """
      if len(statements) <= 1:
          return statements
  
      canonicalizer = Canonicalizer(entity_fn=entity_canonicalizer)
  
-     seen: set[tuple[str, str, str]] = set()
+     # Map from dedup key to index in unique list
+     seen: dict[tuple[str, str, str], int] = {}
      unique: list[Statement] = []
  
      for stmt in statements:
          key = canonicalizer.create_dedup_key(stmt)
-         if key not in seen:
-             seen.add(key)
+         # Also compute reversed key (object, predicate, subject)
+         reversed_key = (key[2], key[1], key[0])
+
+         if key in seen:
+             # Direct duplicate found - merge entity types
+             existing_idx = seen[key]
+             existing_stmt = unique[existing_idx]
+             merged_stmt = existing_stmt.merge_entity_types_from(stmt)
+             unique[existing_idx] = merged_stmt
+         elif detect_reversals and reversed_key in seen:
+             # Reversed duplicate found - merge entity types (accounting for reversal)
+             existing_idx = seen[reversed_key]
+             existing_stmt = unique[existing_idx]
+             # Merge types from the reversed statement
+             merged_stmt = existing_stmt.merge_entity_types_from(stmt.reversed())
+             unique[existing_idx] = merged_stmt
+         else:
+             # New unique statement
+             seen[key] = len(unique)
              unique.append(stmt)
  
      return unique
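The merge-and-reverse bookkeeping in the hunk above can be exercised end to end with plain dataclasses standing in for the library's Pydantic `Entity`/`Statement` models (a simplified sketch: `lower()` stands in for the real canonicalizer, and `merge_type`/`dedup_exact` are illustrative names, not the package's API):

```python
from dataclasses import dataclass, replace

# Simplified stand-ins for the library's models (illustration only)
@dataclass(frozen=True)
class Entity:
    text: str
    type: str  # "UNKNOWN", "ORG", "PERSON", ...

@dataclass(frozen=True)
class Statement:
    subject: Entity
    predicate: str
    object: Entity

def merge_type(a: Entity, b: Entity) -> Entity:
    # Specific types win over UNKNOWN; otherwise keep a's type
    return Entity(a.text, b.type) if a.type == "UNKNOWN" and b.type != "UNKNOWN" else a

def dedup_exact(statements, detect_reversals=True):
    seen = {}    # dedup key -> index into unique
    unique = []
    for stmt in statements:
        # lower() stands in for the real canonicalizer
        key = (stmt.subject.text.lower(), stmt.predicate.lower(), stmt.object.text.lower())
        rev_key = (key[2], key[1], key[0])
        if key in seen:
            idx = seen[key]
            ex = unique[idx]
            unique[idx] = replace(ex, subject=merge_type(ex.subject, stmt.subject),
                                  object=merge_type(ex.object, stmt.object))
        elif detect_reversals and rev_key in seen:
            idx = seen[rev_key]
            ex = unique[idx]
            # Reversed duplicate: merge types with the roles swapped
            unique[idx] = replace(ex, subject=merge_type(ex.subject, stmt.object),
                                  object=merge_type(ex.object, stmt.subject))
        else:
            seen[key] = len(unique)
            unique.append(stmt)
    return unique

s1 = Statement(Entity("AtlasBio Labs", "UNKNOWN"), "sued by", Entity("CuraPharm", "ORG"))
s2 = Statement(Entity("AtlasBio Labs", "ORG"), "sued by", Entity("CuraPharm", "ORG"))
s3 = Statement(Entity("CuraPharm", "ORG"), "sued by", Entity("AtlasBio Labs", "ORG"))  # reversed

out = dedup_exact([s1, s2, s3])
print(len(out), out[0].subject.type)  # → 1 ORG
```

The first occurrence keeps the canonical orientation, matching the docstring above; only the entity types are upgraded in place.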
@@ -143,7 +143,7 @@ class StatementExtractor:
  
          taxonomy = options.predicate_taxonomy or self._predicate_taxonomy
          config = options.predicate_config or self._predicate_config or PredicateComparisonConfig()
-         return PredicateComparer(taxonomy=taxonomy, config=config)
+         return PredicateComparer(taxonomy=taxonomy, config=config, device=self.device)
  
      @property
      def model(self) -> AutoModelForSeq2SeqLM:
@@ -362,18 +362,25 @@ class StatementExtractor:
          end_tag = "</statements>"
          candidates: list[str] = []
  
-         for output in outputs:
+         for i, output in enumerate(outputs):
              decoded = self.tokenizer.decode(output, skip_special_tokens=True)
+             output_len = len(output)
  
              # Truncate at </statements>
              if end_tag in decoded:
                  end_pos = decoded.find(end_tag) + len(end_tag)
                  decoded = decoded[:end_pos]
                  candidates.append(decoded)
+                 logger.debug(f"Beam {i}: {output_len} tokens, found end tag, {len(decoded)} chars")
+             else:
+                 # Log the issue - likely truncated
+                 logger.warning(f"Beam {i}: {output_len} tokens, NO end tag found (truncated?)")
+                 logger.warning(f"Beam {i} full output ({len(decoded)} chars):\n{decoded}")
  
          # Include fallback if no valid candidates
          if not candidates and len(outputs) > 0:
              fallback = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+             logger.warning(f"Using fallback beam (no valid candidates found), {len(fallback)} chars")
              candidates.append(fallback)
  
          return candidates
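The truncate-at-end-tag step in the hunk above reduces to plain string slicing. A stand-alone sketch with toy decoded strings in place of real beam outputs (`truncate_candidates` is an illustrative name, not the library's API):

```python
END_TAG = "</statements>"

def truncate_candidates(decoded_beams):
    """Keep each beam up to and including </statements>; skip beams without it."""
    candidates = []
    for i, decoded in enumerate(decoded_beams):
        if END_TAG in decoded:
            # Slice off anything the model generated after the closing tag
            end = decoded.find(END_TAG) + len(END_TAG)
            candidates.append(decoded[:end])
        else:
            # No closing tag: generation was likely cut off at max length
            print(f"beam {i}: no end tag, skipping")
    # Fall back to the first beam if nothing parsed cleanly
    if not candidates and decoded_beams:
        candidates.append(decoded_beams[0])
    return candidates

beams = [
    "<statements><s>...</s></statements> trailing junk",
    "<statements><s>...",  # truncated mid-generation
]
print(truncate_candidates(beams))  # → ['<statements><s>...</s></statements>']
```

The fallback mirrors the library's behavior: a truncated first beam is still returned rather than an empty candidate list, so downstream XML parsing gets a chance to salvage it.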
@@ -467,8 +474,10 @@ class StatementExtractor:
  
          try:
              root = ET.fromstring(xml_output)
-         except ET.ParseError:
-             logger.warning("Failed to parse XML output")
+         except ET.ParseError as e:
+             # Log full output for debugging
+             logger.warning(f"Failed to parse XML output: {e}")
+             logger.warning(f"Full XML output ({len(xml_output)} chars):\n{xml_output}")
              return statements
  
          if root.tag != 'statements':
@@ -32,6 +32,18 @@ class Entity(BaseModel):
      def __str__(self) -> str:
          return f"{self.text} ({self.type.value})"
  
+     def merge_type_from(self, other: "Entity") -> "Entity":
+         """
+         Return a new Entity with the more specific type.
+
+         If this entity has UNKNOWN type and other has a specific type,
+         returns a new entity with this text but other's type.
+         Otherwise returns self unchanged.
+         """
+         if self.type == EntityType.UNKNOWN and other.type != EntityType.UNKNOWN:
+             return Entity(text=self.text, type=other.type)
+         return self
+
  
  class Statement(BaseModel):
      """A single extracted statement (subject-predicate-object triple)."""
@@ -55,6 +67,10 @@ class Statement(BaseModel):
          None,
          description="Canonical form of the predicate if taxonomy matching was used"
      )
+     was_reversed: bool = Field(
+         default=False,
+         description="True if subject/object were swapped during reversal detection"
+     )
  
      def __str__(self) -> str:
          return f"{self.subject.text} -- {self.predicate} --> {self.object.text}"
@@ -63,6 +79,49 @@ class Statement(BaseModel):
          """Return as a simple (subject, predicate, object) tuple."""
          return (self.subject.text, self.predicate, self.object.text)
  
+     def merge_entity_types_from(self, other: "Statement") -> "Statement":
+         """
+         Return a new Statement with more specific entity types merged from other.
+
+         If this statement has UNKNOWN entity types and other has specific types,
+         the returned statement will use the specific types from other.
+         All other fields come from self.
+         """
+         merged_subject = self.subject.merge_type_from(other.subject)
+         merged_object = self.object.merge_type_from(other.object)
+
+         # Only create new statement if something changed
+         if merged_subject is self.subject and merged_object is self.object:
+             return self
+
+         return Statement(
+             subject=merged_subject,
+             object=merged_object,
+             predicate=self.predicate,
+             source_text=self.source_text,
+             confidence_score=self.confidence_score,
+             evidence_span=self.evidence_span,
+             canonical_predicate=self.canonical_predicate,
+             was_reversed=self.was_reversed,
+         )
+
+     def reversed(self) -> "Statement":
+         """
+         Return a new Statement with subject and object swapped.
+
+         Sets was_reversed=True to indicate the swap occurred.
+         """
+         return Statement(
+             subject=self.object,
+             object=self.subject,
+             predicate=self.predicate,
+             source_text=self.source_text,
+             confidence_score=self.confidence_score,
+             evidence_span=self.evidence_span,
+             canonical_predicate=self.canonical_predicate,
+             was_reversed=True,
+         )
+
  
  class ExtractionResult(BaseModel):
      """The result of statement extraction from text."""
@@ -62,6 +62,7 @@ class PredicateComparer:
          self,
          taxonomy: Optional[PredicateTaxonomy] = None,
          config: Optional[PredicateComparisonConfig] = None,
+         device: Optional[str] = None,
      ):
          """
          Initialize the predicate comparer.
@@ -69,6 +70,7 @@ class PredicateComparer:
          Args:
              taxonomy: Optional canonical predicate taxonomy for normalization
              config: Comparison configuration (uses defaults if not provided)
+             device: Device to use ('cuda', 'cpu', or None for auto-detect)
  
          Raises:
              EmbeddingDependencyError: If sentence-transformers is not installed
@@ -78,6 +80,13 @@ class PredicateComparer:
          self.taxonomy = taxonomy
          self.config = config or PredicateComparisonConfig()
  
+         # Auto-detect device
+         if device is None:
+             import torch
+             self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         else:
+             self.device = device
+
          # Lazy-loaded resources
          self._model = None
          self._taxonomy_embeddings: Optional[np.ndarray] = None
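The device resolution in the hunk above can be isolated into a small helper. This sketch additionally guards the `torch` import so it runs even where torch is absent (the library itself imports torch unconditionally here, since it is already a dependency; `pick_device` is an illustrative name):

```python
from typing import Optional

def pick_device(device: Optional[str] = None) -> str:
    """Resolve an explicit device string, else auto-detect CUDA availability."""
    if device is not None:
        return device
    try:
        import torch
        return "cuda" if torch.cuda.is_available() else "cpu"
    except ImportError:
        # torch not installed: embeddings would fail anyway, default to CPU
        return "cpu"

print(pick_device("cpu"))  # → cpu
```

Passing the resolved device down to `SentenceTransformer` (as the later hunk does) keeps the embedding model on the same device as the extraction model instead of letting each library auto-detect independently.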
@@ -89,9 +98,9 @@ class PredicateComparer:
  
          from sentence_transformers import SentenceTransformer
  
-         logger.info(f"Loading embedding model: {self.config.embedding_model}")
-         self._model = SentenceTransformer(self.config.embedding_model)
-         logger.info("Embedding model loaded")
+         logger.info(f"Loading embedding model: {self.config.embedding_model} on {self.device}")
+         self._model = SentenceTransformer(self.config.embedding_model, device=self.device)
+         logger.info(f"Embedding model loaded on {self.device}")
  
      def _normalize_text(self, text: str) -> str:
          """Normalize text before embedding."""
@@ -256,21 +265,26 @@ class PredicateComparer:
          self,
          statements: list[Statement],
          entity_canonicalizer: Optional[callable] = None,
+         detect_reversals: bool = True,
      ) -> list[Statement]:
          """
          Remove duplicate statements using embedding-based predicate comparison.
  
          Two statements are considered duplicates if:
-         - Canonicalized subjects match
+         - Canonicalized subjects match AND canonicalized objects match, OR
+         - Canonicalized subjects match objects (reversed) when detect_reversals=True
          - Predicates are similar (embedding-based)
-         - Canonicalized objects match
  
          When duplicates are found, keeps the statement with better contextualized
          match (comparing "Subject Predicate Object" against source text).
  
+         For reversed duplicates, the correct orientation is determined by comparing
+         both "S P O" and "O P S" against source text.
+
          Args:
              statements: List of Statement objects
              entity_canonicalizer: Optional function to canonicalize entity text
+             detect_reversals: Whether to detect reversed duplicates (default True)
  
          Returns:
              Deduplicated list of statements (keeps best contextualized match)
@@ -293,6 +307,12 @@ class PredicateComparer:
          ]
          contextualized_embeddings = self._compute_embeddings(contextualized_texts)
  
+         # Compute reversed contextualized embeddings: "Object Predicate Subject"
+         reversed_texts = [
+             f"{s.object.text} {s.predicate} {s.subject.text}" for s in statements
+         ]
+         reversed_embeddings = self._compute_embeddings(reversed_texts)
+
          # Compute source text embeddings for scoring which duplicate to keep
          source_embeddings = []
          for stmt in statements:
@@ -302,6 +322,7 @@ class PredicateComparer:
          unique_statements: list[Statement] = []
          unique_pred_embeddings: list[np.ndarray] = []
          unique_context_embeddings: list[np.ndarray] = []
+         unique_reversed_embeddings: list[np.ndarray] = []
          unique_source_embeddings: list[np.ndarray] = []
          unique_indices: list[int] = []
@@ -310,13 +331,23 @@ class PredicateComparer:
              obj_canon = canonicalize(stmt.object.text)
  
              duplicate_idx = None
+             is_reversed_match = False
  
              for j, unique_stmt in enumerate(unique_statements):
                  unique_subj = canonicalize(unique_stmt.subject.text)
                  unique_obj = canonicalize(unique_stmt.object.text)
  
-                 # Check subject and object match
-                 if subj_canon != unique_subj or obj_canon != unique_obj:
+                 # Check direct match: subject->subject, object->object
+                 direct_match = (subj_canon == unique_subj and obj_canon == unique_obj)
+
+                 # Check reversed match: subject->object, object->subject
+                 reversed_match = (
+                     detect_reversals and
+                     subj_canon == unique_obj and
+                     obj_canon == unique_subj
+                 )
+
+                 if not direct_match and not reversed_match:
                      continue
  
                  # Check predicate similarity
@@ -326,6 +357,7 @@ class PredicateComparer:
                  )
                  if similarity >= self.config.dedup_threshold:
                      duplicate_idx = j
+                     is_reversed_match = reversed_match and not direct_match
                      break
  
              if duplicate_idx is None:
@@ -333,27 +365,91 @@ class PredicateComparer:
                  unique_statements.append(stmt)
                  unique_pred_embeddings.append(pred_embeddings[i])
                  unique_context_embeddings.append(contextualized_embeddings[i])
+                 unique_reversed_embeddings.append(reversed_embeddings[i])
                  unique_source_embeddings.append(source_embeddings[i])
                  unique_indices.append(i)
              else:
-                 # Duplicate found - keep the one with better contextualized match
-                 # Compare "Subject Predicate Object" against source text
-                 current_score = self._cosine_similarity(
-                     contextualized_embeddings[i],
-                     source_embeddings[i]
-                 )
-                 existing_score = self._cosine_similarity(
-                     unique_context_embeddings[duplicate_idx],
-                     unique_source_embeddings[duplicate_idx]
-                 )
-
-                 if current_score > existing_score:
-                     # Current statement is a better match - replace
-                     unique_statements[duplicate_idx] = stmt
-                     unique_pred_embeddings[duplicate_idx] = pred_embeddings[i]
-                     unique_context_embeddings[duplicate_idx] = contextualized_embeddings[i]
-                     unique_source_embeddings[duplicate_idx] = source_embeddings[i]
-                     unique_indices[duplicate_idx] = i
+                 existing_stmt = unique_statements[duplicate_idx]
+
+                 if is_reversed_match:
+                     # Reversed duplicate - determine correct orientation using source text
+                     # Compare current's normal vs reversed against its source
+                     current_normal_score = self._cosine_similarity(
+                         contextualized_embeddings[i], source_embeddings[i]
+                     )
+                     current_reversed_score = self._cosine_similarity(
+                         reversed_embeddings[i], source_embeddings[i]
+                     )
+                     # Compare existing's normal vs reversed against its source
+                     existing_normal_score = self._cosine_similarity(
+                         unique_context_embeddings[duplicate_idx],
+                         unique_source_embeddings[duplicate_idx]
+                     )
+                     existing_reversed_score = self._cosine_similarity(
+                         unique_reversed_embeddings[duplicate_idx],
+                         unique_source_embeddings[duplicate_idx]
+                     )
+
+                     # Determine best orientation for current
+                     current_best = max(current_normal_score, current_reversed_score)
+                     current_should_reverse = current_reversed_score > current_normal_score
+
+                     # Determine best orientation for existing
+                     existing_best = max(existing_normal_score, existing_reversed_score)
+                     existing_should_reverse = existing_reversed_score > existing_normal_score
+
+                     if current_best > existing_best:
+                         # Current is better - use it (possibly reversed)
+                         if current_should_reverse:
+                             best_stmt = stmt.reversed()
+                         else:
+                             best_stmt = stmt
+                         # Merge entity types from existing (accounting for reversal)
+                         if existing_should_reverse:
+                             best_stmt = best_stmt.merge_entity_types_from(existing_stmt.reversed())
+                         else:
+                             best_stmt = best_stmt.merge_entity_types_from(existing_stmt)
+                         unique_statements[duplicate_idx] = best_stmt
+                         unique_pred_embeddings[duplicate_idx] = pred_embeddings[i]
+                         unique_context_embeddings[duplicate_idx] = contextualized_embeddings[i]
+                         unique_reversed_embeddings[duplicate_idx] = reversed_embeddings[i]
+                         unique_source_embeddings[duplicate_idx] = source_embeddings[i]
+                         unique_indices[duplicate_idx] = i
+                     else:
+                         # Existing is better - possibly fix its orientation
+                         if existing_should_reverse and not existing_stmt.was_reversed:
+                             best_stmt = existing_stmt.reversed()
+                         else:
+                             best_stmt = existing_stmt
+                         # Merge entity types from current (accounting for reversal)
+                         if current_should_reverse:
+                             best_stmt = best_stmt.merge_entity_types_from(stmt.reversed())
+                         else:
+                             best_stmt = best_stmt.merge_entity_types_from(stmt)
+                         unique_statements[duplicate_idx] = best_stmt
+                 else:
+                     # Direct duplicate - keep the one with better contextualized match
+                     current_score = self._cosine_similarity(
+                         contextualized_embeddings[i], source_embeddings[i]
+                     )
+                     existing_score = self._cosine_similarity(
+                         unique_context_embeddings[duplicate_idx],
+                         unique_source_embeddings[duplicate_idx]
+                     )
+
+                     if current_score > existing_score:
+                         # Current statement is a better match - replace
+                         merged_stmt = stmt.merge_entity_types_from(existing_stmt)
+                         unique_statements[duplicate_idx] = merged_stmt
+                         unique_pred_embeddings[duplicate_idx] = pred_embeddings[i]
+                         unique_context_embeddings[duplicate_idx] = contextualized_embeddings[i]
+                         unique_reversed_embeddings[duplicate_idx] = reversed_embeddings[i]
+                         unique_source_embeddings[duplicate_idx] = source_embeddings[i]
+                         unique_indices[duplicate_idx] = i
+                     else:
+                         # Existing statement is better - merge entity types from current
+                         merged_stmt = existing_stmt.merge_entity_types_from(stmt)
+                         unique_statements[duplicate_idx] = merged_stmt
  
          return unique_statements
  
@@ -437,3 +533,79 @@ class PredicateComparer:
              similarity=best_score,
              matched=False,
          )
+
+     def detect_and_fix_reversals(
+         self,
+         statements: list[Statement],
+         threshold: float = 0.05,
+     ) -> list[Statement]:
+         """
+         Detect and fix subject-object reversals using embedding comparison.
+
+         For each statement, compares:
+         - "Subject Predicate Object" embedding against source_text
+         - "Object Predicate Subject" embedding against source_text
+
+         If the reversed version has significantly higher similarity to the source,
+         the subject and object are swapped and was_reversed is set to True.
+
+         Args:
+             statements: List of Statement objects
+             threshold: Minimum similarity difference to trigger reversal (default 0.05)
+
+         Returns:
+             List of statements with reversals corrected
+         """
+         if not statements:
+             return statements
+
+         result = []
+         for stmt in statements:
+             # Skip if no source_text to compare against
+             if not stmt.source_text:
+                 result.append(stmt)
+                 continue
+
+             # Build normal and reversed triple strings
+             normal_text = f"{stmt.subject.text} {stmt.predicate} {stmt.object.text}"
+             reversed_text = f"{stmt.object.text} {stmt.predicate} {stmt.subject.text}"
+
+             # Compute embeddings for normal, reversed, and source
+             embeddings = self._compute_embeddings([normal_text, reversed_text, stmt.source_text])
+             normal_emb, reversed_emb, source_emb = embeddings[0], embeddings[1], embeddings[2]
+
+             # Compute similarities to source
+             normal_sim = self._cosine_similarity(normal_emb, source_emb)
+             reversed_sim = self._cosine_similarity(reversed_emb, source_emb)
+
+             # If reversed is significantly better, swap subject and object
+             if reversed_sim > normal_sim + threshold:
+                 result.append(stmt.reversed())
+             else:
+                 result.append(stmt)
+
+         return result
+
+     def check_reversal(self, statement: Statement) -> tuple[bool, float, float]:
+         """
+         Check if a single statement should be reversed.
+
+         Args:
+             statement: Statement to check
+
+         Returns:
+             Tuple of (should_reverse, normal_similarity, reversed_similarity)
+         """
+         if not statement.source_text:
+             return (False, 0.0, 0.0)
+
+         normal_text = f"{statement.subject.text} {statement.predicate} {statement.object.text}"
+         reversed_text = f"{statement.object.text} {statement.predicate} {statement.subject.text}"
+
+         embeddings = self._compute_embeddings([normal_text, reversed_text, statement.source_text])
+         normal_emb, reversed_emb, source_emb = embeddings[0], embeddings[1], embeddings[2]
+
+         normal_sim = self._cosine_similarity(normal_emb, source_emb)
+         reversed_sim = self._cosine_similarity(reversed_emb, source_emb)
+
+         return (reversed_sim > normal_sim, normal_sim, reversed_sim)
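The S-P-O vs O-P-S comparison added above can be demonstrated without sentence-transformers by substituting a toy order-sensitive embedding (word-bigram counts). Only the comparison logic mirrors the method; the embedder and the `should_reverse` helper are stand-ins, not the library's API:

```python
from collections import Counter
import math

def embed(text: str) -> Counter:
    # Toy order-sensitive embedding: counts of adjacent word pairs.
    # A plain bag of words would score "A sued B" and "B sued A" identically,
    # so bigrams are the minimum structure that makes the comparison meaningful.
    words = text.lower().split()
    return Counter(zip(words, words[1:]))

def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[k] * b[k] for k in a)
    na = math.sqrt(sum(v * v for v in a.values()))
    nb = math.sqrt(sum(v * v for v in b.values()))
    return dot / (na * nb) if na and nb else 0.0

def should_reverse(subject, predicate, obj, source_text, threshold=0.05):
    """Return (flip, normal_sim, reversed_sim), mirroring check_reversal's shape."""
    normal = cosine(embed(f"{subject} {predicate} {obj}"), embed(source_text))
    rev = cosine(embed(f"{obj} {predicate} {subject}"), embed(source_text))
    return rev > normal + threshold, normal, rev

# Extraction put CuraPharm as subject, but the source says AtlasBio sued CuraPharm
flip, normal_sim, reversed_sim = should_reverse(
    "CuraPharm", "sued", "AtlasBio", "AtlasBio sued CuraPharm last year"
)
print(flip)  # → True
```

With real sentence embeddings the margin is softer, which is why the library gates the swap behind a `threshold` rather than a bare greater-than.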