corp-extractor 0.2.2__tar.gz → 0.2.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: corp-extractor
- Version: 0.2.2
+ Version: 0.2.8
  Summary: Extract structured statements from text using T5-Gemma 2 and Diverse Beam Search
  Project-URL: Homepage, https://github.com/corp-o-rate/statement-extractor
  Project-URL: Documentation, https://github.com/corp-o-rate/statement-extractor#readme
@@ -23,10 +23,11 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Classifier: Topic :: Scientific/Engineering :: Information Analysis
  Classifier: Topic :: Text Processing :: Linguistic
  Requires-Python: >=3.10
+ Requires-Dist: click>=8.0.0
  Requires-Dist: numpy>=1.24.0
  Requires-Dist: pydantic>=2.0.0
  Requires-Dist: torch>=2.0.0
- Requires-Dist: transformers>=4.35.0
+ Requires-Dist: transformers>=5.0.0rc3
  Provides-Extra: all
  Requires-Dist: sentence-transformers>=2.2.0; extra == 'all'
  Provides-Extra: dev
@@ -55,22 +56,30 @@ Extract structured subject-predicate-object statements from unstructured text us
  - **Embedding-based Dedup** *(v0.2.0)*: Uses semantic similarity to detect near-duplicate predicates
  - **Predicate Taxonomies** *(v0.2.0)*: Normalize predicates to canonical forms via embeddings
  - **Contextualized Matching** *(v0.2.2)*: Compares the full "Subject Predicate Object" against source text for better accuracy
+ - **Entity Type Merging** *(v0.2.3)*: Automatically merges UNKNOWN entity types with specific types during deduplication
+ - **Reversal Detection** *(v0.2.3)*: Detects and corrects subject-object reversals using embedding comparison
+ - **Command Line Interface** *(v0.2.4)*: Full-featured CLI for terminal usage
  - **Multiple Output Formats**: Get results as Pydantic models, JSON, XML, or dictionaries

  ## Installation

  ```bash
  # Recommended: include embedding support for smart deduplication
- pip install corp-extractor[embeddings]
+ pip install "corp-extractor[embeddings]"

  # Minimal installation (no embedding features)
  pip install corp-extractor
  ```

- **Note**: For GPU support, install PyTorch with CUDA first:
+ **Note**: This package requires `transformers>=5.0.0` (pre-release) for T5-Gemma 2 model support. Install with the `--pre` flag if needed:
+ ```bash
+ pip install --pre "corp-extractor[embeddings]"
+ ```
+
+ **For GPU support**, install PyTorch with CUDA first:
  ```bash
  pip install torch --index-url https://download.pytorch.org/whl/cu121
- pip install corp-extractor[embeddings]
+ pip install "corp-extractor[embeddings]"
  ```

  ## Quick Start
@@ -89,6 +98,96 @@ for stmt in result:
      print(f" Confidence: {stmt.confidence_score:.2f}") # NEW in v0.2.0
  ```

+ ## Command Line Interface
+
+ The library includes a CLI for quick extraction from the terminal.
+
+ ### Install Globally (Recommended)
+
+ For best results, install globally first:
+
+ ```bash
+ # Using uv (recommended)
+ uv tool install "corp-extractor[embeddings]"
+
+ # Using pipx
+ pipx install "corp-extractor[embeddings]"
+
+ # Using pip
+ pip install "corp-extractor[embeddings]"
+
+ # Then use anywhere
+ corp-extractor "Your text here"
+ ```
+
+ ### Quick Run with uvx
+
+ Run directly without installing using [uv](https://docs.astral.sh/uv/):
+
+ ```bash
+ uvx corp-extractor "Apple announced a new iPhone."
+ ```
+
+ **Note**: The first run downloads the model (~1.5 GB), which may take a few minutes.
+
+ ### Usage Examples
+
+ ```bash
+ # Extract from text argument
+ corp-extractor "Apple Inc. announced the iPhone 15 at their September event."
+
+ # Extract from file
+ corp-extractor -f article.txt
+
+ # Pipe from stdin
+ cat article.txt | corp-extractor -
+
+ # Output as JSON
+ corp-extractor "Tim Cook is CEO of Apple." --json
+
+ # Output as XML
+ corp-extractor -f article.txt --xml
+
+ # Verbose output with confidence scores
+ corp-extractor -f article.txt --verbose
+
+ # Use more beams for better quality
+ corp-extractor -f article.txt --beams 8
+
+ # Use custom predicate taxonomy
+ corp-extractor -f article.txt --taxonomy predicates.txt
+
+ # Use GPU explicitly
+ corp-extractor -f article.txt --device cuda
+ ```
+
+ ### CLI Options
+
+ ```
+ Usage: corp-extractor [OPTIONS] [TEXT]
+
+ Options:
+   -f, --file PATH                Read input from file
+   -o, --output [table|json|xml]  Output format (default: table)
+   --json                         Output as JSON (shortcut)
+   --xml                          Output as XML (shortcut)
+   -b, --beams INTEGER            Number of beams (default: 4)
+   --diversity FLOAT              Diversity penalty (default: 1.0)
+   --max-tokens INTEGER           Max tokens to generate (default: 2048)
+   --no-dedup                     Disable deduplication
+   --no-embeddings                Disable embedding-based dedup (faster)
+   --no-merge                     Disable beam merging
+   --dedup-threshold FLOAT        Deduplication threshold (default: 0.65)
+   --min-confidence FLOAT         Min confidence filter (default: 0)
+   --taxonomy PATH                Load predicate taxonomy from file
+   --taxonomy-threshold FLOAT     Taxonomy matching threshold (default: 0.5)
+   --device [auto|cuda|cpu]       Device to use (default: auto)
+   -v, --verbose                  Show confidence scores and metadata
+   -q, --quiet                    Suppress progress messages
+   --version                      Show version
+   --help                         Show this message
+ ```
+
  ## New in v0.2.0: Quality Scoring & Beam Merging

  By default, the library now:
@@ -139,6 +238,47 @@ Predicate canonicalization and deduplication now use **contextualized matching**

  This means "Apple bought Beats" vs "Apple acquired Beats" are compared holistically, not just "bought" vs "acquired".

+ ## New in v0.2.3: Entity Type Merging & Reversal Detection
+
+ ### Entity Type Merging
+
+ When deduplicating statements, entity types are now automatically merged. If one statement has `UNKNOWN` type and a duplicate has a specific type (like `ORG` or `PERSON`), the specific type is preserved:
+
+ ```python
+ # Before deduplication:
+ # Statement 1: AtlasBio Labs (UNKNOWN) --sued by--> CuraPharm (ORG)
+ # Statement 2: AtlasBio Labs (ORG) --sued by--> CuraPharm (ORG)
+
+ # After deduplication:
+ # Single statement: AtlasBio Labs (ORG) --sued by--> CuraPharm (ORG)
+ ```
+
+ ### Subject-Object Reversal Detection
+
+ The library now detects when subject and object may have been extracted in the wrong order by comparing embeddings against source text:
+
+ ```python
+ from statement_extractor import PredicateComparer
+
+ comparer = PredicateComparer()
+
+ # Automatically detect and fix reversals
+ fixed_statements = comparer.detect_and_fix_reversals(statements)
+
+ for stmt in fixed_statements:
+     if stmt.was_reversed:
+         print(f"Fixed reversal: {stmt}")
+ ```
+
+ **How it works:**
+ 1. For each statement with source text, compares:
+    - "Subject Predicate Object" embedding vs source text
+    - "Object Predicate Subject" embedding vs source text
+ 2. If the reversed form has higher similarity, swaps subject and object
+ 3. Sets `was_reversed=True` to indicate the correction
+
+ During deduplication, reversed duplicates (e.g., "A -> P -> B" and "B -> P -> A") are now detected and merged, with the correct orientation determined by source text similarity.
+
  ## Disable Embeddings (Faster, No Extra Dependencies)

  ```python
@@ -213,6 +353,8 @@ This library uses the T5-Gemma 2 statement extraction model with **Diverse Beam
  4. **Embedding Dedup** *(v0.2.0)*: Semantic similarity removes near-duplicate predicates
  5. **Predicate Normalization** *(v0.2.0)*: Optional taxonomy matching via embeddings
  6. **Contextualized Matching** *(v0.2.2)*: Full statement context used for canonicalization and dedup
+ 7. **Entity Type Merging** *(v0.2.3)*: UNKNOWN types merged with specific types during dedup
+ 8. **Reversal Detection** *(v0.2.3)*: Subject-object reversals detected and corrected via embedding comparison

  ## Requirements

@@ -15,22 +15,30 @@ Extract structured subject-predicate-object statements from unstructured text us
  - **Embedding-based Dedup** *(v0.2.0)*: Uses semantic similarity to detect near-duplicate predicates
  - **Predicate Taxonomies** *(v0.2.0)*: Normalize predicates to canonical forms via embeddings
  - **Contextualized Matching** *(v0.2.2)*: Compares the full "Subject Predicate Object" against source text for better accuracy
+ - **Entity Type Merging** *(v0.2.3)*: Automatically merges UNKNOWN entity types with specific types during deduplication
+ - **Reversal Detection** *(v0.2.3)*: Detects and corrects subject-object reversals using embedding comparison
+ - **Command Line Interface** *(v0.2.4)*: Full-featured CLI for terminal usage
  - **Multiple Output Formats**: Get results as Pydantic models, JSON, XML, or dictionaries

  ## Installation

  ```bash
  # Recommended: include embedding support for smart deduplication
- pip install corp-extractor[embeddings]
+ pip install "corp-extractor[embeddings]"

  # Minimal installation (no embedding features)
  pip install corp-extractor
  ```

- **Note**: For GPU support, install PyTorch with CUDA first:
+ **Note**: This package requires `transformers>=5.0.0` (pre-release) for T5-Gemma 2 model support. Install with the `--pre` flag if needed:
+ ```bash
+ pip install --pre "corp-extractor[embeddings]"
+ ```
+
+ **For GPU support**, install PyTorch with CUDA first:
  ```bash
  pip install torch --index-url https://download.pytorch.org/whl/cu121
- pip install corp-extractor[embeddings]
+ pip install "corp-extractor[embeddings]"
  ```

  ## Quick Start
@@ -49,6 +57,96 @@ for stmt in result:
      print(f" Confidence: {stmt.confidence_score:.2f}") # NEW in v0.2.0
  ```

+ ## Command Line Interface
+
+ The library includes a CLI for quick extraction from the terminal.
+
+ ### Install Globally (Recommended)
+
+ For best results, install globally first:
+
+ ```bash
+ # Using uv (recommended)
+ uv tool install "corp-extractor[embeddings]"
+
+ # Using pipx
+ pipx install "corp-extractor[embeddings]"
+
+ # Using pip
+ pip install "corp-extractor[embeddings]"
+
+ # Then use anywhere
+ corp-extractor "Your text here"
+ ```
+
+ ### Quick Run with uvx
+
+ Run directly without installing using [uv](https://docs.astral.sh/uv/):
+
+ ```bash
+ uvx corp-extractor "Apple announced a new iPhone."
+ ```
+
+ **Note**: The first run downloads the model (~1.5 GB), which may take a few minutes.
+
+ ### Usage Examples
+
+ ```bash
+ # Extract from text argument
+ corp-extractor "Apple Inc. announced the iPhone 15 at their September event."
+
+ # Extract from file
+ corp-extractor -f article.txt
+
+ # Pipe from stdin
+ cat article.txt | corp-extractor -
+
+ # Output as JSON
+ corp-extractor "Tim Cook is CEO of Apple." --json
+
+ # Output as XML
+ corp-extractor -f article.txt --xml
+
+ # Verbose output with confidence scores
+ corp-extractor -f article.txt --verbose
+
+ # Use more beams for better quality
+ corp-extractor -f article.txt --beams 8
+
+ # Use custom predicate taxonomy
+ corp-extractor -f article.txt --taxonomy predicates.txt
+
+ # Use GPU explicitly
+ corp-extractor -f article.txt --device cuda
+ ```
+
+ ### CLI Options
+
+ ```
+ Usage: corp-extractor [OPTIONS] [TEXT]
+
+ Options:
+   -f, --file PATH                Read input from file
+   -o, --output [table|json|xml]  Output format (default: table)
+   --json                         Output as JSON (shortcut)
+   --xml                          Output as XML (shortcut)
+   -b, --beams INTEGER            Number of beams (default: 4)
+   --diversity FLOAT              Diversity penalty (default: 1.0)
+   --max-tokens INTEGER           Max tokens to generate (default: 2048)
+   --no-dedup                     Disable deduplication
+   --no-embeddings                Disable embedding-based dedup (faster)
+   --no-merge                     Disable beam merging
+   --dedup-threshold FLOAT        Deduplication threshold (default: 0.65)
+   --min-confidence FLOAT         Min confidence filter (default: 0)
+   --taxonomy PATH                Load predicate taxonomy from file
+   --taxonomy-threshold FLOAT     Taxonomy matching threshold (default: 0.5)
+   --device [auto|cuda|cpu]       Device to use (default: auto)
+   -v, --verbose                  Show confidence scores and metadata
+   -q, --quiet                    Suppress progress messages
+   --version                      Show version
+   --help                         Show this message
+ ```
+
  ## New in v0.2.0: Quality Scoring & Beam Merging

  By default, the library now:
@@ -99,6 +197,47 @@ Predicate canonicalization and deduplication now use **contextualized matching**

  This means "Apple bought Beats" vs "Apple acquired Beats" are compared holistically, not just "bought" vs "acquired".

+ ## New in v0.2.3: Entity Type Merging & Reversal Detection
+
+ ### Entity Type Merging
+
+ When deduplicating statements, entity types are now automatically merged. If one statement has `UNKNOWN` type and a duplicate has a specific type (like `ORG` or `PERSON`), the specific type is preserved:
+
+ ```python
+ # Before deduplication:
+ # Statement 1: AtlasBio Labs (UNKNOWN) --sued by--> CuraPharm (ORG)
+ # Statement 2: AtlasBio Labs (ORG) --sued by--> CuraPharm (ORG)
+
+ # After deduplication:
+ # Single statement: AtlasBio Labs (ORG) --sued by--> CuraPharm (ORG)
+ ```
+
+ ### Subject-Object Reversal Detection
+
+ The library now detects when subject and object may have been extracted in the wrong order by comparing embeddings against source text:
+
+ ```python
+ from statement_extractor import PredicateComparer
+
+ comparer = PredicateComparer()
+
+ # Automatically detect and fix reversals
+ fixed_statements = comparer.detect_and_fix_reversals(statements)
+
+ for stmt in fixed_statements:
+     if stmt.was_reversed:
+         print(f"Fixed reversal: {stmt}")
+ ```
+
+ **How it works:**
+ 1. For each statement with source text, compares:
+    - "Subject Predicate Object" embedding vs source text
+    - "Object Predicate Subject" embedding vs source text
+ 2. If the reversed form has higher similarity, swaps subject and object
+ 3. Sets `was_reversed=True` to indicate the correction
+
+ During deduplication, reversed duplicates (e.g., "A -> P -> B" and "B -> P -> A") are now detected and merged, with the correct orientation determined by source text similarity.
+
  ## Disable Embeddings (Faster, No Extra Dependencies)

  ```python
@@ -173,6 +312,8 @@ This library uses the T5-Gemma 2 statement extraction model with **Diverse Beam
  4. **Embedding Dedup** *(v0.2.0)*: Semantic similarity removes near-duplicate predicates
  5. **Predicate Normalization** *(v0.2.0)*: Optional taxonomy matching via embeddings
  6. **Contextualized Matching** *(v0.2.2)*: Full statement context used for canonicalization and dedup
+ 7. **Entity Type Merging** *(v0.2.3)*: UNKNOWN types merged with specific types during dedup
+ 8. **Reversal Detection** *(v0.2.3)*: Subject-object reversals detected and corrected via embedding comparison

  ## Requirements

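The pipeline toggles listed above map directly onto the `ExtractionOptions` model that the new CLI builds in `cli.py` below. A minimal sketch, using only field names that appear in this diff (the import path is an assumption, mirroring the `from .models import ...` in cli.py):

```python
# Sketch: configure the pipeline the same way the CLI does.
# Field names are taken from cli.py in this diff; the import path is assumed.
from statement_extractor.models import ExtractionOptions

options = ExtractionOptions(
    num_beams=8,            # more diverse beams, better recall
    diversity_penalty=1.0,
    max_new_tokens=2048,
    deduplicate=True,       # step 3-4: exact + embedding dedup
    embedding_dedup=True,
    merge_beams=True,       # merge statements across beams
)
```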
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

  [project]
  name = "corp-extractor"
- version = "0.2.2"
+ version = "0.2.8"
  description = "Extract structured statements from text using T5-Gemma 2 and Diverse Beam Search"
  readme = "README.md"
  requires-python = ">=3.10"
@@ -46,8 +46,9 @@ classifiers = [
  dependencies = [
      "pydantic>=2.0.0",
      "torch>=2.0.0",
-     "transformers>=4.35.0",
+     "transformers>=5.0.0rc3",
      "numpy>=1.24.0",
+     "click>=8.0.0",
  ]

  [project.optional-dependencies]
@@ -66,6 +67,10 @@ all = [
      "sentence-transformers>=2.2.0",
  ]

+ [project.scripts]
+ statement-extractor = "statement_extractor.cli:main"
+ corp-extractor = "statement_extractor.cli:main"
+
  [project.urls]
  Homepage = "https://github.com/corp-o-rate/statement-extractor"
  Documentation = "https://github.com/corp-o-rate/statement-extractor#readme"
@@ -29,7 +29,7 @@ Example:
      >>> data = extract_statements_as_dict("Some text...")
  """

- __version__ = "0.2.2"
+ __version__ = "0.2.5"

  # Core models
  from .models import (
@@ -139,32 +139,58 @@ class Canonicalizer:

  def deduplicate_statements_exact(
      statements: list[Statement],
-     entity_canonicalizer: Optional[Callable[[str], str]] = None
+     entity_canonicalizer: Optional[Callable[[str], str]] = None,
+     detect_reversals: bool = True,
  ) -> list[Statement]:
      """
      Deduplicate statements using exact text matching.

      Use this when embedding-based deduplication is disabled.
+     When duplicates are found, entity types are merged - specific types
+     (ORG, PERSON, etc.) take precedence over UNKNOWN.
+
+     When detect_reversals=True, also detects reversed duplicates where
+     subject and object are swapped. The first occurrence determines the
+     canonical orientation.

      Args:
          statements: List of statements to deduplicate
          entity_canonicalizer: Optional custom canonicalization function
+         detect_reversals: Whether to detect reversed duplicates (default True)

      Returns:
-         Deduplicated list (keeps first occurrence)
+         Deduplicated list with merged entity types
      """
      if len(statements) <= 1:
          return statements

      canonicalizer = Canonicalizer(entity_fn=entity_canonicalizer)

-     seen: set[tuple[str, str, str]] = set()
+     # Map from dedup key to index in unique list
+     seen: dict[tuple[str, str, str], int] = {}
      unique: list[Statement] = []

      for stmt in statements:
          key = canonicalizer.create_dedup_key(stmt)
-         if key not in seen:
-             seen.add(key)
+         # Also compute reversed key (object, predicate, subject)
+         reversed_key = (key[2], key[1], key[0])
+
+         if key in seen:
+             # Direct duplicate found - merge entity types
+             existing_idx = seen[key]
+             existing_stmt = unique[existing_idx]
+             merged_stmt = existing_stmt.merge_entity_types_from(stmt)
+             unique[existing_idx] = merged_stmt
+         elif detect_reversals and reversed_key in seen:
+             # Reversed duplicate found - merge entity types (accounting for reversal)
+             existing_idx = seen[reversed_key]
+             existing_stmt = unique[existing_idx]
+             # Merge types from the reversed statement
+             merged_stmt = existing_stmt.merge_entity_types_from(stmt.reversed())
+             unique[existing_idx] = merged_stmt
+         else:
+             # New unique statement
+             seen[key] = len(unique)
              unique.append(stmt)

      return unique
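To make the new behaviour concrete, here is a hedged usage sketch of the function above. The models appear elsewhere in this diff; the import path for the dedup helper and the assumption that `subject`/`predicate`/`object` are the only required `Statement` fields are not confirmed by the diff:

```python
# Sketch only: module paths and required model fields are assumptions.
from statement_extractor.dedup import deduplicate_statements_exact  # assumed path
from statement_extractor.models import Entity, EntityType, Statement

src = "CuraPharm sued AtlasBio Labs in May."
s1 = Statement(
    subject=Entity(text="AtlasBio Labs", type=EntityType.UNKNOWN),
    predicate="sued by",
    object=Entity(text="CuraPharm", type=EntityType.ORG),
    source_text=src,
)
s2 = Statement(  # same fact, subject and object swapped
    subject=Entity(text="CuraPharm", type=EntityType.ORG),
    predicate="sued by",
    object=Entity(text="AtlasBio Labs", type=EntityType.ORG),
    source_text=src,
)

# Swapping s2's subject and object yields s1's dedup key, so the two merge.
# s1 keeps its orientation (first occurrence wins) and its UNKNOWN subject
# type is upgraded to ORG via merge_entity_types_from(s2.reversed()).
unique = deduplicate_statements_exact([s1, s2], detect_reversals=True)
assert len(unique) == 1
```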
@@ -0,0 +1,215 @@
1
+ """
2
+ Command-line interface for statement extraction.
3
+
4
+ Usage:
5
+ corp-extractor "Your text here"
6
+ corp-extractor -f input.txt
7
+ cat input.txt | corp-extractor -
8
+ """
9
+
10
+ import sys
11
+ from typing import Optional
12
+
13
+ import click
14
+
15
+ from . import __version__
16
+ from .models import (
17
+ ExtractionOptions,
18
+ PredicateComparisonConfig,
19
+ PredicateTaxonomy,
20
+ ScoringConfig,
21
+ )
22
+
23
+
24
+ @click.command()
25
+ @click.argument("text", required=False)
26
+ @click.option("-f", "--file", "input_file", type=click.Path(exists=True), help="Read input from file")
27
+ @click.option(
28
+ "-o", "--output",
29
+ type=click.Choice(["table", "json", "xml"], case_sensitive=False),
30
+ default="table",
31
+ help="Output format (default: table)"
32
+ )
33
+ @click.option("--json", "output_json", is_flag=True, help="Output as JSON (shortcut for -o json)")
34
+ @click.option("--xml", "output_xml", is_flag=True, help="Output as XML (shortcut for -o xml)")
35
+ # Beam search options
36
+ @click.option("-b", "--beams", type=int, default=4, help="Number of beams for diverse beam search (default: 4)")
37
+ @click.option("--diversity", type=float, default=1.0, help="Diversity penalty for beam search (default: 1.0)")
38
+ @click.option("--max-tokens", type=int, default=2048, help="Maximum tokens to generate (default: 2048)")
39
+ # Deduplication options
40
+ @click.option("--no-dedup", is_flag=True, help="Disable deduplication")
41
+ @click.option("--no-embeddings", is_flag=True, help="Disable embedding-based deduplication (faster)")
42
+ @click.option("--no-merge", is_flag=True, help="Disable beam merging (select single best beam)")
43
+ @click.option("--dedup-threshold", type=float, default=0.65, help="Similarity threshold for deduplication (default: 0.65)")
44
+ # Quality options
45
+ @click.option("--min-confidence", type=float, default=0.0, help="Minimum confidence threshold 0-1 (default: 0)")
46
+ # Taxonomy options
47
+ @click.option("--taxonomy", type=click.Path(exists=True), help="Load predicate taxonomy from file (one per line)")
48
+ @click.option("--taxonomy-threshold", type=float, default=0.5, help="Similarity threshold for taxonomy matching (default: 0.5)")
49
+ # Device options
50
+ @click.option("--device", type=click.Choice(["auto", "cuda", "mps", "cpu"]), default="auto", help="Device to use (default: auto)")
51
+ # Output options
52
+ @click.option("-v", "--verbose", is_flag=True, help="Show verbose output with confidence scores")
53
+ @click.option("-q", "--quiet", is_flag=True, help="Suppress progress messages")
54
+ @click.version_option(version=__version__)
55
+ def main(
56
+ text: Optional[str],
57
+ input_file: Optional[str],
58
+ output: str,
59
+ output_json: bool,
60
+ output_xml: bool,
61
+ beams: int,
62
+ diversity: float,
63
+ max_tokens: int,
64
+ no_dedup: bool,
65
+ no_embeddings: bool,
66
+ no_merge: bool,
67
+ dedup_threshold: float,
68
+ min_confidence: float,
69
+ taxonomy: Optional[str],
70
+ taxonomy_threshold: float,
71
+ device: str,
72
+ verbose: bool,
73
+ quiet: bool,
74
+ ):
75
+ """
76
+ Extract structured statements from text.
77
+
78
+ TEXT can be provided as an argument, read from a file with -f, or piped via stdin.
79
+
80
+ \b
81
+ Examples:
82
+ corp-extractor "Apple announced a new iPhone."
83
+ corp-extractor -f article.txt --json
84
+ corp-extractor -f article.txt -o json --beams 8
85
+ cat article.txt | corp-extractor -
86
+ echo "Tim Cook is CEO of Apple." | corp-extractor - --verbose
87
+
88
+ \b
89
+ Output formats:
90
+ table Human-readable table (default)
91
+ json JSON with full metadata
92
+ xml Raw XML from model
93
+ """
94
+ # Determine output format
95
+ if output_json:
96
+ output = "json"
97
+ elif output_xml:
98
+ output = "xml"
99
+
100
+ # Get input text
101
+ input_text = _get_input_text(text, input_file)
102
+ if not input_text:
103
+ raise click.UsageError(
104
+ "No input provided. Use: statement-extractor \"text\", "
105
+ "statement-extractor -f file.txt, or pipe via stdin."
106
+ )
107
+
108
+ if not quiet:
109
+ click.echo(f"Processing {len(input_text)} characters...", err=True)
110
+
111
+ # Load taxonomy if provided
112
+ predicate_taxonomy = None
113
+ if taxonomy:
114
+ predicate_taxonomy = PredicateTaxonomy.from_file(taxonomy)
115
+ if not quiet:
116
+ click.echo(f"Loaded taxonomy with {len(predicate_taxonomy.predicates)} predicates", err=True)
117
+
118
+ # Configure predicate comparison
119
+ predicate_config = PredicateComparisonConfig(
120
+ similarity_threshold=taxonomy_threshold,
121
+ dedup_threshold=dedup_threshold,
122
+ )
123
+
124
+ # Configure scoring
125
+ scoring_config = ScoringConfig(min_confidence=min_confidence)
126
+
127
+ # Configure extraction options
128
+ options = ExtractionOptions(
129
+ num_beams=beams,
130
+ diversity_penalty=diversity,
131
+ max_new_tokens=max_tokens,
132
+ deduplicate=not no_dedup,
133
+ embedding_dedup=not no_embeddings,
134
+ merge_beams=not no_merge,
135
+ predicate_taxonomy=predicate_taxonomy,
136
+ predicate_config=predicate_config,
137
+ scoring_config=scoring_config,
138
+ )
139
+
140
+ # Import here to allow --help without loading torch
141
+ from .extractor import StatementExtractor
142
+
143
+ # Create extractor with specified device
144
+ device_arg = None if device == "auto" else device
145
+ extractor = StatementExtractor(device=device_arg)
146
+
147
+ if not quiet:
148
+ click.echo(f"Using device: {extractor.device}", err=True)
149
+
150
+ # Run extraction
151
+ try:
152
+ if output == "xml":
153
+ result = extractor.extract_as_xml(input_text, options)
154
+ click.echo(result)
155
+ elif output == "json":
156
+ result = extractor.extract_as_json(input_text, options)
157
+ click.echo(result)
158
+ else:
159
+ # Table format
160
+ result = extractor.extract(input_text, options)
161
+ _print_table(result, verbose)
162
+ except Exception as e:
163
+ raise click.ClickException(f"Extraction failed: {e}")
164
+
165
+
166
+ def _get_input_text(text: Optional[str], input_file: Optional[str]) -> Optional[str]:
167
+ """Get input text from argument, file, or stdin."""
168
+ if text == "-" or (text is None and input_file is None and not sys.stdin.isatty()):
169
+ # Read from stdin
170
+ return sys.stdin.read().strip()
171
+ elif input_file:
172
+ # Read from file
173
+ with open(input_file, "r", encoding="utf-8") as f:
174
+ return f.read().strip()
175
+ elif text:
176
+ return text.strip()
177
+ return None
178
+
179
+
180
+ def _print_table(result, verbose: bool):
181
+ """Print statements in a human-readable table format."""
182
+ if not result.statements:
183
+ click.echo("No statements extracted.")
184
+ return
185
+
186
+ click.echo(f"\nExtracted {len(result.statements)} statement(s):\n")
187
+ click.echo("-" * 80)
188
+
189
+ for i, stmt in enumerate(result.statements, 1):
190
+ subject_type = f" ({stmt.subject.type.value})" if stmt.subject.type.value != "UNKNOWN" else ""
191
+ object_type = f" ({stmt.object.type.value})" if stmt.object.type.value != "UNKNOWN" else ""
192
+
193
+ click.echo(f"{i}. {stmt.subject.text}{subject_type}")
194
+ click.echo(f" --[{stmt.predicate}]-->")
195
+ click.echo(f" {stmt.object.text}{object_type}")
196
+
197
+ if verbose:
198
+ if stmt.confidence_score is not None:
199
+ click.echo(f" Confidence: {stmt.confidence_score:.2f}")
200
+
201
+ if stmt.canonical_predicate:
202
+ click.echo(f" Canonical: {stmt.canonical_predicate}")
203
+
204
+ if stmt.was_reversed:
205
+ click.echo(f" (subject/object were swapped)")
206
+
207
+ if stmt.source_text:
208
+ source = stmt.source_text[:60] + "..." if len(stmt.source_text) > 60 else stmt.source_text
209
+ click.echo(f" Source: \"{source}\"")
210
+
211
+ click.echo("-" * 80)
212
+
213
+
214
+ if __name__ == "__main__":
215
+ main()
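Since `main` is a standard click command, it can be smoke-tested with click's built-in test runner. A minimal sketch, assuming the package is installed (`--help` avoids loading torch thanks to the deferred `StatementExtractor` import above):

```python
# Sketch: drive the CLI in-process with click's CliRunner.
from click.testing import CliRunner

from statement_extractor.cli import main

runner = CliRunner()
result = runner.invoke(main, ["--help"])
assert result.exit_code == 0
assert "Extract structured statements" in result.output
```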
@@ -80,11 +80,16 @@ class StatementExtractor:

          # Auto-detect device
          if device is None:
-             self.device = "cuda" if torch.cuda.is_available() else "cpu"
+             if torch.cuda.is_available():
+                 self.device = "cuda"
+             elif torch.backends.mps.is_available():
+                 self.device = "mps"
+             else:
+                 self.device = "cpu"
          else:
              self.device = device

-         # Auto-detect dtype
+         # Auto-detect dtype (bfloat16 only for CUDA, float32 for MPS/CPU)
          if torch_dtype is None:
              self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
          else:
@@ -143,7 +148,7 @@ class StatementExtractor:

          taxonomy = options.predicate_taxonomy or self._predicate_taxonomy
          config = options.predicate_config or self._predicate_config or PredicateComparisonConfig()
-         return PredicateComparer(taxonomy=taxonomy, config=config)
+         return PredicateComparer(taxonomy=taxonomy, config=config, device=self.device)

      @property
      def model(self) -> AutoModelForSeq2SeqLM:
@@ -350,30 +355,41 @@ class StatementExtractor:
          outputs = self.model.generate(
              **inputs,
              max_new_tokens=options.max_new_tokens,
+             max_length=None,  # Override model default, use max_new_tokens only
              num_beams=num_seqs,
              num_beam_groups=num_seqs,
              num_return_sequences=num_seqs,
              diversity_penalty=options.diversity_penalty,
              do_sample=False,
+             top_p=None,  # Override model config to suppress warning
+             top_k=None,  # Override model config to suppress warning
              trust_remote_code=True,
+             custom_generate="transformers-community/group-beam-search",
          )

          # Decode and process candidates
          end_tag = "</statements>"
          candidates: list[str] = []

-         for output in outputs:
+         for i, output in enumerate(outputs):
              decoded = self.tokenizer.decode(output, skip_special_tokens=True)
+             output_len = len(output)

              # Truncate at </statements>
              if end_tag in decoded:
                  end_pos = decoded.find(end_tag) + len(end_tag)
                  decoded = decoded[:end_pos]
                  candidates.append(decoded)
+                 logger.debug(f"Beam {i}: {output_len} tokens, found end tag, {len(decoded)} chars")
+             else:
+                 # Log the issue - likely truncated
+                 logger.warning(f"Beam {i}: {output_len} tokens, NO end tag found (truncated?)")
+                 logger.warning(f"Beam {i} full output ({len(decoded)} chars):\n{decoded}")

          # Include fallback if no valid candidates
          if not candidates and len(outputs) > 0:
              fallback = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+             logger.warning(f"Using fallback beam (no valid candidates found), {len(fallback)} chars")
              candidates.append(fallback)

          return candidates
@@ -467,8 +483,10 @@ class StatementExtractor:

          try:
              root = ET.fromstring(xml_output)
-         except ET.ParseError:
-             logger.warning("Failed to parse XML output")
+         except ET.ParseError as e:
+             # Log full output for debugging
+             logger.warning(f"Failed to parse XML output: {e}")
+             logger.warning(f"Full XML output ({len(xml_output)} chars):\n{xml_output}")
              return statements

          if root.tag != 'statements':
@@ -32,6 +32,18 @@ class Entity(BaseModel):
      def __str__(self) -> str:
          return f"{self.text} ({self.type.value})"

+     def merge_type_from(self, other: "Entity") -> "Entity":
+         """
+         Return a new Entity with the more specific type.
+
+         If this entity has UNKNOWN type and other has a specific type,
+         returns a new entity with this text but other's type.
+         Otherwise returns self unchanged.
+         """
+         if self.type == EntityType.UNKNOWN and other.type != EntityType.UNKNOWN:
+             return Entity(text=self.text, type=other.type)
+         return self
+

  class Statement(BaseModel):
      """A single extracted statement (subject-predicate-object triple)."""
@@ -55,6 +67,10 @@ class Statement(BaseModel):
          None,
          description="Canonical form of the predicate if taxonomy matching was used"
      )
+     was_reversed: bool = Field(
+         default=False,
+         description="True if subject/object were swapped during reversal detection"
+     )

      def __str__(self) -> str:
          return f"{self.subject.text} -- {self.predicate} --> {self.object.text}"
@@ -63,6 +79,49 @@ class Statement(BaseModel):
          """Return as a simple (subject, predicate, object) tuple."""
          return (self.subject.text, self.predicate, self.object.text)

+     def merge_entity_types_from(self, other: "Statement") -> "Statement":
+         """
+         Return a new Statement with more specific entity types merged from other.
+
+         If this statement has UNKNOWN entity types and other has specific types,
+         the returned statement will use the specific types from other.
+         All other fields come from self.
+         """
+         merged_subject = self.subject.merge_type_from(other.subject)
+         merged_object = self.object.merge_type_from(other.object)
+
+         # Only create new statement if something changed
+         if merged_subject is self.subject and merged_object is self.object:
+             return self
+
+         return Statement(
+             subject=merged_subject,
+             object=merged_object,
+             predicate=self.predicate,
+             source_text=self.source_text,
+             confidence_score=self.confidence_score,
+             evidence_span=self.evidence_span,
+             canonical_predicate=self.canonical_predicate,
+             was_reversed=self.was_reversed,
+         )
+
+     def reversed(self) -> "Statement":
+         """
+         Return a new Statement with subject and object swapped.
+
+         Sets was_reversed=True to indicate the swap occurred.
+         """
+         return Statement(
+             subject=self.object,
+             object=self.subject,
+             predicate=self.predicate,
+             source_text=self.source_text,
+             confidence_score=self.confidence_score,
+             evidence_span=self.evidence_span,
+             canonical_predicate=self.canonical_predicate,
+             was_reversed=True,
+         )
+

  class ExtractionResult(BaseModel):
      """The result of statement extraction from text."""
@@ -62,6 +62,7 @@ class PredicateComparer:
          self,
          taxonomy: Optional[PredicateTaxonomy] = None,
          config: Optional[PredicateComparisonConfig] = None,
+         device: Optional[str] = None,
      ):
          """
          Initialize the predicate comparer.
@@ -69,6 +70,7 @@ class PredicateComparer:
          Args:
              taxonomy: Optional canonical predicate taxonomy for normalization
              config: Comparison configuration (uses defaults if not provided)
+             device: Device to use ('cuda', 'cpu', or None for auto-detect)

          Raises:
              EmbeddingDependencyError: If sentence-transformers is not installed
@@ -78,6 +80,18 @@ class PredicateComparer:
          self.taxonomy = taxonomy
          self.config = config or PredicateComparisonConfig()

+         # Auto-detect device
+         if device is None:
+             import torch
+             if torch.cuda.is_available():
+                 self.device = "cuda"
+             elif torch.backends.mps.is_available():
+                 self.device = "mps"
+             else:
+                 self.device = "cpu"
+         else:
+             self.device = device
+
          # Lazy-loaded resources
          self._model = None
          self._taxonomy_embeddings: Optional[np.ndarray] = None
@@ -89,9 +103,9 @@ class PredicateComparer:

          from sentence_transformers import SentenceTransformer

-         logger.info(f"Loading embedding model: {self.config.embedding_model}")
-         self._model = SentenceTransformer(self.config.embedding_model)
-         logger.info("Embedding model loaded")
+         logger.info(f"Loading embedding model: {self.config.embedding_model} on {self.device}")
+         self._model = SentenceTransformer(self.config.embedding_model, device=self.device)
+         logger.info(f"Embedding model loaded on {self.device}")

      def _normalize_text(self, text: str) -> str:
          """Normalize text before embedding."""
@@ -256,21 +270,26 @@ class PredicateComparer:
          self,
          statements: list[Statement],
          entity_canonicalizer: Optional[callable] = None,
+         detect_reversals: bool = True,
      ) -> list[Statement]:
          """
          Remove duplicate statements using embedding-based predicate comparison.

          Two statements are considered duplicates if:
-         - Canonicalized subjects match
+         - Canonicalized subjects match AND canonicalized objects match, OR
+         - Canonicalized subjects match objects (reversed) when detect_reversals=True
          - Predicates are similar (embedding-based)
-         - Canonicalized objects match

          When duplicates are found, keeps the statement with better contextualized
          match (comparing "Subject Predicate Object" against source text).

+         For reversed duplicates, the correct orientation is determined by comparing
+         both "S P O" and "O P S" against source text.
+
          Args:
              statements: List of Statement objects
              entity_canonicalizer: Optional function to canonicalize entity text
+             detect_reversals: Whether to detect reversed duplicates (default True)

          Returns:
              Deduplicated list of statements (keeps best contextualized match)
@@ -293,6 +312,12 @@ class PredicateComparer:
          ]
          contextualized_embeddings = self._compute_embeddings(contextualized_texts)

+         # Compute reversed contextualized embeddings: "Object Predicate Subject"
+         reversed_texts = [
+             f"{s.object.text} {s.predicate} {s.subject.text}" for s in statements
+         ]
+         reversed_embeddings = self._compute_embeddings(reversed_texts)
+
          # Compute source text embeddings for scoring which duplicate to keep
          source_embeddings = []
          for stmt in statements:
@@ -302,6 +327,7 @@ class PredicateComparer:
          unique_statements: list[Statement] = []
          unique_pred_embeddings: list[np.ndarray] = []
          unique_context_embeddings: list[np.ndarray] = []
+         unique_reversed_embeddings: list[np.ndarray] = []
          unique_source_embeddings: list[np.ndarray] = []
          unique_indices: list[int] = []

@@ -310,13 +336,23 @@ class PredicateComparer:
              obj_canon = canonicalize(stmt.object.text)

              duplicate_idx = None
+             is_reversed_match = False

              for j, unique_stmt in enumerate(unique_statements):
                  unique_subj = canonicalize(unique_stmt.subject.text)
                  unique_obj = canonicalize(unique_stmt.object.text)

-                 # Check subject and object match
-                 if subj_canon != unique_subj or obj_canon != unique_obj:
+                 # Check direct match: subject->subject, object->object
+                 direct_match = (subj_canon == unique_subj and obj_canon == unique_obj)
+
+                 # Check reversed match: subject->object, object->subject
+                 reversed_match = (
+                     detect_reversals and
+                     subj_canon == unique_obj and
+                     obj_canon == unique_subj
+                 )
+
+                 if not direct_match and not reversed_match:
                      continue

                  # Check predicate similarity
@@ -326,6 +362,7 @@ class PredicateComparer:
                  )
                  if similarity >= self.config.dedup_threshold:
                      duplicate_idx = j
+                     is_reversed_match = reversed_match and not direct_match
                      break

              if duplicate_idx is None:
@@ -333,27 +370,91 @@ class PredicateComparer:
                  unique_statements.append(stmt)
                  unique_pred_embeddings.append(pred_embeddings[i])
                  unique_context_embeddings.append(contextualized_embeddings[i])
+                 unique_reversed_embeddings.append(reversed_embeddings[i])
                  unique_source_embeddings.append(source_embeddings[i])
                  unique_indices.append(i)
              else:
-                 # Duplicate found - keep the one with better contextualized match
-                 # Compare "Subject Predicate Object" against source text
-                 current_score = self._cosine_similarity(
-                     contextualized_embeddings[i],
-                     source_embeddings[i]
-                 )
-                 existing_score = self._cosine_similarity(
-                     unique_context_embeddings[duplicate_idx],
-                     unique_source_embeddings[duplicate_idx]
-                 )
-
-                 if current_score > existing_score:
-                     # Current statement is a better match - replace
-                     unique_statements[duplicate_idx] = stmt
-                     unique_pred_embeddings[duplicate_idx] = pred_embeddings[i]
-                     unique_context_embeddings[duplicate_idx] = contextualized_embeddings[i]
-                     unique_source_embeddings[duplicate_idx] = source_embeddings[i]
-                     unique_indices[duplicate_idx] = i
+                 existing_stmt = unique_statements[duplicate_idx]
+
+                 if is_reversed_match:
+                     # Reversed duplicate - determine correct orientation using source text
+                     # Compare current's normal vs reversed against its source
+                     current_normal_score = self._cosine_similarity(
+                         contextualized_embeddings[i], source_embeddings[i]
+                     )
+                     current_reversed_score = self._cosine_similarity(
+                         reversed_embeddings[i], source_embeddings[i]
+                     )
+                     # Compare existing's normal vs reversed against its source
+                     existing_normal_score = self._cosine_similarity(
+                         unique_context_embeddings[duplicate_idx],
+                         unique_source_embeddings[duplicate_idx]
+                     )
+                     existing_reversed_score = self._cosine_similarity(
+                         unique_reversed_embeddings[duplicate_idx],
+                         unique_source_embeddings[duplicate_idx]
+                     )
+
+                     # Determine best orientation for current
+                     current_best = max(current_normal_score, current_reversed_score)
+                     current_should_reverse = current_reversed_score > current_normal_score
+
+                     # Determine best orientation for existing
+                     existing_best = max(existing_normal_score, existing_reversed_score)
+                     existing_should_reverse = existing_reversed_score > existing_normal_score
+
+                     if current_best > existing_best:
+                         # Current is better - use it (possibly reversed)
+                         if current_should_reverse:
+                             best_stmt = stmt.reversed()
+                         else:
+                             best_stmt = stmt
+                         # Merge entity types from existing (accounting for reversal)
+                         if existing_should_reverse:
+                             best_stmt = best_stmt.merge_entity_types_from(existing_stmt.reversed())
+                         else:
+                             best_stmt = best_stmt.merge_entity_types_from(existing_stmt)
+                         unique_statements[duplicate_idx] = best_stmt
+                         unique_pred_embeddings[duplicate_idx] = pred_embeddings[i]
+                         unique_context_embeddings[duplicate_idx] = contextualized_embeddings[i]
+                         unique_reversed_embeddings[duplicate_idx] = reversed_embeddings[i]
+                         unique_source_embeddings[duplicate_idx] = source_embeddings[i]
+                         unique_indices[duplicate_idx] = i
+                     else:
+                         # Existing is better - possibly fix its orientation
+                         if existing_should_reverse and not existing_stmt.was_reversed:
+                             best_stmt = existing_stmt.reversed()
+                         else:
+                             best_stmt = existing_stmt
+                         # Merge entity types from current (accounting for reversal)
+                         if current_should_reverse:
+                             best_stmt = best_stmt.merge_entity_types_from(stmt.reversed())
+                         else:
+                             best_stmt = best_stmt.merge_entity_types_from(stmt)
+                         unique_statements[duplicate_idx] = best_stmt
+                 else:
+                     # Direct duplicate - keep the one with better contextualized match
+                     current_score = self._cosine_similarity(
+                         contextualized_embeddings[i], source_embeddings[i]
+                     )
+                     existing_score = self._cosine_similarity(
+                         unique_context_embeddings[duplicate_idx],
+                         unique_source_embeddings[duplicate_idx]
+                     )
+
+                     if current_score > existing_score:
+                         # Current statement is a better match - replace
+                         merged_stmt = stmt.merge_entity_types_from(existing_stmt)
+                         unique_statements[duplicate_idx] = merged_stmt
+                         unique_pred_embeddings[duplicate_idx] = pred_embeddings[i]
+                         unique_context_embeddings[duplicate_idx] = contextualized_embeddings[i]
+                         unique_reversed_embeddings[duplicate_idx] = reversed_embeddings[i]
+                         unique_source_embeddings[duplicate_idx] = source_embeddings[i]
+                         unique_indices[duplicate_idx] = i
+                     else:
+                         # Existing statement is better - merge entity types from current
+                         merged_stmt = existing_stmt.merge_entity_types_from(stmt)
+                         unique_statements[duplicate_idx] = merged_stmt

          return unique_statements

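The orientation decision above reduces to two cosine similarities per statement. A standalone sketch of that comparison, assuming a generic sentence-transformers model (the model name below is illustrative, not the library's configured default):

```python
# Sketch: score "S P O" vs "O P S" against the source sentence.
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative model choice

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

source = "CuraPharm sued AtlasBio Labs over a licensing dispute."
normal = "AtlasBio Labs sued by CuraPharm"    # "Subject Predicate Object"
flipped = "CuraPharm sued by AtlasBio Labs"   # "Object Predicate Subject"

n_emb, f_emb, s_emb = model.encode([normal, flipped, source])
# Keep whichever orientation sits closer to the source sentence.
keep_flipped = cosine(f_emb, s_emb) > cosine(n_emb, s_emb)
```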
@@ -437,3 +538,79 @@ class PredicateComparer:
              similarity=best_score,
              matched=False,
          )
+
+     def detect_and_fix_reversals(
+         self,
+         statements: list[Statement],
+         threshold: float = 0.05,
+     ) -> list[Statement]:
+         """
+         Detect and fix subject-object reversals using embedding comparison.
+
+         For each statement, compares:
+         - "Subject Predicate Object" embedding against source_text
+         - "Object Predicate Subject" embedding against source_text
+
+         If the reversed version has significantly higher similarity to the source,
+         the subject and object are swapped and was_reversed is set to True.
+
+         Args:
+             statements: List of Statement objects
+             threshold: Minimum similarity difference to trigger reversal (default 0.05)
+
+         Returns:
+             List of statements with reversals corrected
+         """
+         if not statements:
+             return statements
+
+         result = []
+         for stmt in statements:
+             # Skip if no source_text to compare against
+             if not stmt.source_text:
+                 result.append(stmt)
+                 continue
+
+             # Build normal and reversed triple strings
+             normal_text = f"{stmt.subject.text} {stmt.predicate} {stmt.object.text}"
+             reversed_text = f"{stmt.object.text} {stmt.predicate} {stmt.subject.text}"
+
+             # Compute embeddings for normal, reversed, and source
+             embeddings = self._compute_embeddings([normal_text, reversed_text, stmt.source_text])
+             normal_emb, reversed_emb, source_emb = embeddings[0], embeddings[1], embeddings[2]
+
+             # Compute similarities to source
+             normal_sim = self._cosine_similarity(normal_emb, source_emb)
+             reversed_sim = self._cosine_similarity(reversed_emb, source_emb)
+
+             # If reversed is significantly better, swap subject and object
+             if reversed_sim > normal_sim + threshold:
+                 result.append(stmt.reversed())
+             else:
+                 result.append(stmt)
+
+         return result
+
+     def check_reversal(self, statement: Statement) -> tuple[bool, float, float]:
+         """
+         Check if a single statement should be reversed.
+
+         Args:
+             statement: Statement to check
+
+         Returns:
+             Tuple of (should_reverse, normal_similarity, reversed_similarity)
+         """
+         if not statement.source_text:
+             return (False, 0.0, 0.0)
+
+         normal_text = f"{statement.subject.text} {statement.predicate} {statement.object.text}"
+         reversed_text = f"{statement.object.text} {statement.predicate} {statement.subject.text}"
+
+         embeddings = self._compute_embeddings([normal_text, reversed_text, statement.source_text])
+         normal_emb, reversed_emb, source_emb = embeddings[0], embeddings[1], embeddings[2]
+
+         normal_sim = self._cosine_similarity(normal_emb, source_emb)
+         reversed_sim = self._cosine_similarity(reversed_emb, source_emb)
+
+         return (reversed_sim > normal_sim, normal_sim, reversed_sim)
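A hedged usage sketch of the two new helpers, mirroring the README example above. `PredicateComparer` is imported from the package root as the README shows; the `StatementExtractor` export and extraction call are assumptions based on the rest of this diff:

```python
# Sketch: run an extraction, then apply the v0.2.3 reversal helpers.
from statement_extractor import PredicateComparer, StatementExtractor  # assumed exports

extractor = StatementExtractor(device="cpu")
result = extractor.extract("CuraPharm sued AtlasBio Labs over a licensing dispute.")

comparer = PredicateComparer(device="cpu")

# Batch correction (applies the 0.05 similarity margin by default)...
fixed = comparer.detect_and_fix_reversals(result.statements)

# ...or inspect the raw scores for each statement.
for stmt in fixed:
    should_reverse, normal_sim, reversed_sim = comparer.check_reversal(stmt)
    print(stmt, stmt.was_reversed, f"{normal_sim:.2f} vs {reversed_sim:.2f}")
```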