corp-extractor 0.2.2__tar.gz → 0.2.8__tar.gz
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- {corp_extractor-0.2.2 → corp_extractor-0.2.8}/PKG-INFO +147 -5
- {corp_extractor-0.2.2 → corp_extractor-0.2.8}/README.md +144 -3
- {corp_extractor-0.2.2 → corp_extractor-0.2.8}/pyproject.toml +7 -2
- {corp_extractor-0.2.2 → corp_extractor-0.2.8}/src/statement_extractor/__init__.py +1 -1
- {corp_extractor-0.2.2 → corp_extractor-0.2.8}/src/statement_extractor/canonicalization.py +31 -5
- corp_extractor-0.2.8/src/statement_extractor/cli.py +215 -0
- {corp_extractor-0.2.2 → corp_extractor-0.2.8}/src/statement_extractor/extractor.py +24 -6
- {corp_extractor-0.2.2 → corp_extractor-0.2.8}/src/statement_extractor/models.py +59 -0
- {corp_extractor-0.2.2 → corp_extractor-0.2.8}/src/statement_extractor/predicate_comparer.py +202 -25
- {corp_extractor-0.2.2 → corp_extractor-0.2.8}/.gitignore +0 -0
- {corp_extractor-0.2.2 → corp_extractor-0.2.8}/src/statement_extractor/scoring.py +0 -0
{corp_extractor-0.2.2 → corp_extractor-0.2.8}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: corp-extractor
-Version: 0.2.2
+Version: 0.2.8
 Summary: Extract structured statements from text using T5-Gemma 2 and Diverse Beam Search
 Project-URL: Homepage, https://github.com/corp-o-rate/statement-extractor
 Project-URL: Documentation, https://github.com/corp-o-rate/statement-extractor#readme
@@ -23,10 +23,11 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Scientific/Engineering :: Information Analysis
 Classifier: Topic :: Text Processing :: Linguistic
 Requires-Python: >=3.10
+Requires-Dist: click>=8.0.0
 Requires-Dist: numpy>=1.24.0
 Requires-Dist: pydantic>=2.0.0
 Requires-Dist: torch>=2.0.0
-Requires-Dist: transformers>=
+Requires-Dist: transformers>=5.0.0rc3
 Provides-Extra: all
 Requires-Dist: sentence-transformers>=2.2.0; extra == 'all'
 Provides-Extra: dev
@@ -55,22 +56,30 @@ Extract structured subject-predicate-object statements from unstructured text us
 - **Embedding-based Dedup** *(v0.2.0)*: Uses semantic similarity to detect near-duplicate predicates
 - **Predicate Taxonomies** *(v0.2.0)*: Normalize predicates to canonical forms via embeddings
 - **Contextualized Matching** *(v0.2.2)*: Compares full "Subject Predicate Object" against source text for better accuracy
+- **Entity Type Merging** *(v0.2.3)*: Automatically merges UNKNOWN entity types with specific types during deduplication
+- **Reversal Detection** *(v0.2.3)*: Detects and corrects subject-object reversals using embedding comparison
+- **Command Line Interface** *(v0.2.4)*: Full-featured CLI for terminal usage
 - **Multiple Output Formats**: Get results as Pydantic models, JSON, XML, or dictionaries
 
 ## Installation
 
 ```bash
 # Recommended: include embedding support for smart deduplication
-pip install corp-extractor[embeddings]
+pip install "corp-extractor[embeddings]"
 
 # Minimal installation (no embedding features)
 pip install corp-extractor
 ```
 
-**Note**:
+**Note**: This package requires `transformers>=5.0.0` (pre-release) for T5-Gemma2 model support. Install with `--pre` flag if needed:
+```bash
+pip install --pre "corp-extractor[embeddings]"
+```
+
+**For GPU support**, install PyTorch with CUDA first:
 ```bash
 pip install torch --index-url https://download.pytorch.org/whl/cu121
-pip install corp-extractor[embeddings]
+pip install "corp-extractor[embeddings]"
 ```
 
 ## Quick Start
@@ -89,6 +98,96 @@ for stmt in result:
     print(f"  Confidence: {stmt.confidence_score:.2f}")  # NEW in v0.2.0
 ```
 
+## Command Line Interface
+
+The library includes a CLI for quick extraction from the terminal.
+
+### Install Globally (Recommended)
+
+For best results, install globally first:
+
+```bash
+# Using uv (recommended)
+uv tool install "corp-extractor[embeddings]"
+
+# Using pipx
+pipx install "corp-extractor[embeddings]"
+
+# Using pip
+pip install "corp-extractor[embeddings]"
+
+# Then use anywhere
+corp-extractor "Your text here"
+```
+
+### Quick Run with uvx
+
+Run directly without installing using [uv](https://docs.astral.sh/uv/):
+
+```bash
+uvx corp-extractor "Apple announced a new iPhone."
+```
+
+**Note**: First run downloads the model (~1.5GB) which may take a few minutes.
+
+### Usage Examples
+
+```bash
+# Extract from text argument
+corp-extractor "Apple Inc. announced the iPhone 15 at their September event."
+
+# Extract from file
+corp-extractor -f article.txt
+
+# Pipe from stdin
+cat article.txt | corp-extractor -
+
+# Output as JSON
+corp-extractor "Tim Cook is CEO of Apple." --json
+
+# Output as XML
+corp-extractor -f article.txt --xml
+
+# Verbose output with confidence scores
+corp-extractor -f article.txt --verbose
+
+# Use more beams for better quality
+corp-extractor -f article.txt --beams 8
+
+# Use custom predicate taxonomy
+corp-extractor -f article.txt --taxonomy predicates.txt
+
+# Use GPU explicitly
+corp-extractor -f article.txt --device cuda
+```
+
+### CLI Options
+
+```
+Usage: corp-extractor [OPTIONS] [TEXT]
+
+Options:
+  -f, --file PATH                Read input from file
+  -o, --output [table|json|xml]  Output format (default: table)
+  --json                         Output as JSON (shortcut)
+  --xml                          Output as XML (shortcut)
+  -b, --beams INTEGER            Number of beams (default: 4)
+  --diversity FLOAT              Diversity penalty (default: 1.0)
+  --max-tokens INTEGER           Max tokens to generate (default: 2048)
+  --no-dedup                     Disable deduplication
+  --no-embeddings                Disable embedding-based dedup (faster)
+  --no-merge                     Disable beam merging
+  --dedup-threshold FLOAT        Deduplication threshold (default: 0.65)
+  --min-confidence FLOAT         Min confidence filter (default: 0)
+  --taxonomy PATH                Load predicate taxonomy from file
+  --taxonomy-threshold FLOAT     Taxonomy matching threshold (default: 0.5)
+  --device [auto|cuda|cpu]       Device to use (default: auto)
+  -v, --verbose                  Show confidence scores and metadata
+  -q, --quiet                    Suppress progress messages
+  --version                      Show version
+  --help                         Show this message
+```
+
 ## New in v0.2.0: Quality Scoring & Beam Merging
 
 By default, the library now:
@@ -139,6 +238,47 @@ Predicate canonicalization and deduplication now use **contextualized matching**
 
 This means "Apple bought Beats" vs "Apple acquired Beats" are compared holistically, not just "bought" vs "acquired".
 
+## New in v0.2.3: Entity Type Merging & Reversal Detection
+
+### Entity Type Merging
+
+When deduplicating statements, entity types are now automatically merged. If one statement has `UNKNOWN` type and a duplicate has a specific type (like `ORG` or `PERSON`), the specific type is preserved:
+
+```python
+# Before deduplication:
+# Statement 1: AtlasBio Labs (UNKNOWN) --sued by--> CuraPharm (ORG)
+# Statement 2: AtlasBio Labs (ORG) --sued by--> CuraPharm (ORG)
+
+# After deduplication:
+# Single statement: AtlasBio Labs (ORG) --sued by--> CuraPharm (ORG)
+```
+
+### Subject-Object Reversal Detection
+
+The library now detects when subject and object may have been extracted in the wrong order by comparing embeddings against source text:
+
+```python
+from statement_extractor import PredicateComparer
+
+comparer = PredicateComparer()
+
+# Automatically detect and fix reversals
+fixed_statements = comparer.detect_and_fix_reversals(statements)
+
+for stmt in fixed_statements:
+    if stmt.was_reversed:
+        print(f"Fixed reversal: {stmt}")
+```
+
+**How it works:**
+1. For each statement with source text, compares:
+   - "Subject Predicate Object" embedding vs source text
+   - "Object Predicate Subject" embedding vs source text
+2. If the reversed form has higher similarity, swaps subject and object
+3. Sets `was_reversed=True` to indicate the correction
+
+During deduplication, reversed duplicates (e.g., "A -> P -> B" and "B -> P -> A") are now detected and merged, with the correct orientation determined by source text similarity.
+
 ## Disable Embeddings (Faster, No Extra Dependencies)
 
 ```python
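The three-step check described above is compact enough to show end to end. Below is a minimal standalone sketch of the same idea, assuming `sentence-transformers` is installed; the embedding model named here is an arbitrary choice for illustration, not necessarily the one the package configures internally:

```python
import numpy as np
from sentence_transformers import SentenceTransformer

# Arbitrary small embedding model, chosen only for this sketch.
model = SentenceTransformer("all-MiniLM-L6-v2")

def looks_reversed(subject: str, predicate: str, obj: str, source: str) -> bool:
    """Return True if 'Object Predicate Subject' matches the source better."""
    normal = f"{subject} {predicate} {obj}"
    swapped = f"{obj} {predicate} {subject}"
    # Unit-normalized embeddings make the dot product a cosine similarity.
    embs = model.encode([normal, swapped, source], normalize_embeddings=True)
    return float(np.dot(embs[1], embs[2])) > float(np.dot(embs[0], embs[2]))

# A deliberately reversed extraction; a reasonable model should flag it.
print(looks_reversed(
    "CuraPharm", "sued by", "AtlasBio Labs",
    "AtlasBio Labs was sued by CuraPharm over patent claims.",
))
```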
@@ -213,6 +353,8 @@ This library uses the T5-Gemma 2 statement extraction model with **Diverse Beam
 4. **Embedding Dedup** *(v0.2.0)*: Semantic similarity removes near-duplicate predicates
 5. **Predicate Normalization** *(v0.2.0)*: Optional taxonomy matching via embeddings
 6. **Contextualized Matching** *(v0.2.2)*: Full statement context used for canonicalization and dedup
+7. **Entity Type Merging** *(v0.2.3)*: UNKNOWN types merged with specific types during dedup
+8. **Reversal Detection** *(v0.2.3)*: Subject-object reversals detected and corrected via embedding comparison
 
 ## Requirements
 
{corp_extractor-0.2.2 → corp_extractor-0.2.8}/README.md

@@ -15,22 +15,30 @@ Extract structured subject-predicate-object statements from unstructured text us
 - **Embedding-based Dedup** *(v0.2.0)*: Uses semantic similarity to detect near-duplicate predicates
 - **Predicate Taxonomies** *(v0.2.0)*: Normalize predicates to canonical forms via embeddings
 - **Contextualized Matching** *(v0.2.2)*: Compares full "Subject Predicate Object" against source text for better accuracy
+- **Entity Type Merging** *(v0.2.3)*: Automatically merges UNKNOWN entity types with specific types during deduplication
+- **Reversal Detection** *(v0.2.3)*: Detects and corrects subject-object reversals using embedding comparison
+- **Command Line Interface** *(v0.2.4)*: Full-featured CLI for terminal usage
 - **Multiple Output Formats**: Get results as Pydantic models, JSON, XML, or dictionaries
 
 ## Installation
 
 ```bash
 # Recommended: include embedding support for smart deduplication
-pip install corp-extractor[embeddings]
+pip install "corp-extractor[embeddings]"
 
 # Minimal installation (no embedding features)
 pip install corp-extractor
 ```
 
-**Note**:
+**Note**: This package requires `transformers>=5.0.0` (pre-release) for T5-Gemma2 model support. Install with `--pre` flag if needed:
+```bash
+pip install --pre "corp-extractor[embeddings]"
+```
+
+**For GPU support**, install PyTorch with CUDA first:
 ```bash
 pip install torch --index-url https://download.pytorch.org/whl/cu121
-pip install corp-extractor[embeddings]
+pip install "corp-extractor[embeddings]"
 ```
 
 ## Quick Start
@@ -49,6 +57,96 @@ for stmt in result:
     print(f"  Confidence: {stmt.confidence_score:.2f}")  # NEW in v0.2.0
 ```
 
+## Command Line Interface
+
+The library includes a CLI for quick extraction from the terminal.
+
+### Install Globally (Recommended)
+
+For best results, install globally first:
+
+```bash
+# Using uv (recommended)
+uv tool install "corp-extractor[embeddings]"
+
+# Using pipx
+pipx install "corp-extractor[embeddings]"
+
+# Using pip
+pip install "corp-extractor[embeddings]"
+
+# Then use anywhere
+corp-extractor "Your text here"
+```
+
+### Quick Run with uvx
+
+Run directly without installing using [uv](https://docs.astral.sh/uv/):
+
+```bash
+uvx corp-extractor "Apple announced a new iPhone."
+```
+
+**Note**: First run downloads the model (~1.5GB) which may take a few minutes.
+
+### Usage Examples
+
+```bash
+# Extract from text argument
+corp-extractor "Apple Inc. announced the iPhone 15 at their September event."
+
+# Extract from file
+corp-extractor -f article.txt
+
+# Pipe from stdin
+cat article.txt | corp-extractor -
+
+# Output as JSON
+corp-extractor "Tim Cook is CEO of Apple." --json
+
+# Output as XML
+corp-extractor -f article.txt --xml
+
+# Verbose output with confidence scores
+corp-extractor -f article.txt --verbose
+
+# Use more beams for better quality
+corp-extractor -f article.txt --beams 8
+
+# Use custom predicate taxonomy
+corp-extractor -f article.txt --taxonomy predicates.txt
+
+# Use GPU explicitly
+corp-extractor -f article.txt --device cuda
+```
+
+### CLI Options
+
+```
+Usage: corp-extractor [OPTIONS] [TEXT]
+
+Options:
+  -f, --file PATH                Read input from file
+  -o, --output [table|json|xml]  Output format (default: table)
+  --json                         Output as JSON (shortcut)
+  --xml                          Output as XML (shortcut)
+  -b, --beams INTEGER            Number of beams (default: 4)
+  --diversity FLOAT              Diversity penalty (default: 1.0)
+  --max-tokens INTEGER           Max tokens to generate (default: 2048)
+  --no-dedup                     Disable deduplication
+  --no-embeddings                Disable embedding-based dedup (faster)
+  --no-merge                     Disable beam merging
+  --dedup-threshold FLOAT        Deduplication threshold (default: 0.65)
+  --min-confidence FLOAT         Min confidence filter (default: 0)
+  --taxonomy PATH                Load predicate taxonomy from file
+  --taxonomy-threshold FLOAT     Taxonomy matching threshold (default: 0.5)
+  --device [auto|cuda|cpu]       Device to use (default: auto)
+  -v, --verbose                  Show confidence scores and metadata
+  -q, --quiet                    Suppress progress messages
+  --version                      Show version
+  --help                         Show this message
+```
+
 ## New in v0.2.0: Quality Scoring & Beam Merging
 
 By default, the library now:
@@ -99,6 +197,47 @@ Predicate canonicalization and deduplication now use **contextualized matching**
 
 This means "Apple bought Beats" vs "Apple acquired Beats" are compared holistically, not just "bought" vs "acquired".
 
+## New in v0.2.3: Entity Type Merging & Reversal Detection
+
+### Entity Type Merging
+
+When deduplicating statements, entity types are now automatically merged. If one statement has `UNKNOWN` type and a duplicate has a specific type (like `ORG` or `PERSON`), the specific type is preserved:
+
+```python
+# Before deduplication:
+# Statement 1: AtlasBio Labs (UNKNOWN) --sued by--> CuraPharm (ORG)
+# Statement 2: AtlasBio Labs (ORG) --sued by--> CuraPharm (ORG)
+
+# After deduplication:
+# Single statement: AtlasBio Labs (ORG) --sued by--> CuraPharm (ORG)
+```
+
+### Subject-Object Reversal Detection
+
+The library now detects when subject and object may have been extracted in the wrong order by comparing embeddings against source text:
+
+```python
+from statement_extractor import PredicateComparer
+
+comparer = PredicateComparer()
+
+# Automatically detect and fix reversals
+fixed_statements = comparer.detect_and_fix_reversals(statements)
+
+for stmt in fixed_statements:
+    if stmt.was_reversed:
+        print(f"Fixed reversal: {stmt}")
+```
+
+**How it works:**
+1. For each statement with source text, compares:
+   - "Subject Predicate Object" embedding vs source text
+   - "Object Predicate Subject" embedding vs source text
+2. If the reversed form has higher similarity, swaps subject and object
+3. Sets `was_reversed=True` to indicate the correction
+
+During deduplication, reversed duplicates (e.g., "A -> P -> B" and "B -> P -> A") are now detected and merged, with the correct orientation determined by source text similarity.
+
 ## Disable Embeddings (Faster, No Extra Dependencies)
 
 ```python
@@ -173,6 +312,8 @@ This library uses the T5-Gemma 2 statement extraction model with **Diverse Beam
 4. **Embedding Dedup** *(v0.2.0)*: Semantic similarity removes near-duplicate predicates
 5. **Predicate Normalization** *(v0.2.0)*: Optional taxonomy matching via embeddings
 6. **Contextualized Matching** *(v0.2.2)*: Full statement context used for canonicalization and dedup
+7. **Entity Type Merging** *(v0.2.3)*: UNKNOWN types merged with specific types during dedup
+8. **Reversal Detection** *(v0.2.3)*: Subject-object reversals detected and corrected via embedding comparison
 
 ## Requirements
 
{corp_extractor-0.2.2 → corp_extractor-0.2.8}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "corp-extractor"
-version = "0.2.2"
+version = "0.2.8"
 description = "Extract structured statements from text using T5-Gemma 2 and Diverse Beam Search"
 readme = "README.md"
 requires-python = ">=3.10"
@@ -46,8 +46,9 @@ classifiers = [
 dependencies = [
     "pydantic>=2.0.0",
     "torch>=2.0.0",
-    "transformers>=
+    "transformers>=5.0.0rc3",
     "numpy>=1.24.0",
+    "click>=8.0.0",
 ]
 
 [project.optional-dependencies]
@@ -66,6 +67,10 @@ all = [
     "sentence-transformers>=2.2.0",
 ]
 
+[project.scripts]
+statement-extractor = "statement_extractor.cli:main"
+corp-extractor = "statement_extractor.cli:main"
+
 [project.urls]
 Homepage = "https://github.com/corp-o-rate/statement-extractor"
 Documentation = "https://github.com/corp-o-rate/statement-extractor#readme"
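Both entries point at the same Click function, so after installation the `corp-extractor` and `statement-extractor` commands are interchangeable; for example, `corp-extractor --version` and `statement-extractor --version` should print the same version string.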
{corp_extractor-0.2.2 → corp_extractor-0.2.8}/src/statement_extractor/canonicalization.py

@@ -139,32 +139,58 @@ class Canonicalizer:
 
 def deduplicate_statements_exact(
     statements: list[Statement],
-    entity_canonicalizer: Optional[Callable[[str], str]] = None
+    entity_canonicalizer: Optional[Callable[[str], str]] = None,
+    detect_reversals: bool = True,
 ) -> list[Statement]:
     """
     Deduplicate statements using exact text matching.
 
     Use this when embedding-based deduplication is disabled.
+    When duplicates are found, entity types are merged - specific types
+    (ORG, PERSON, etc.) take precedence over UNKNOWN.
+
+    When detect_reversals=True, also detects reversed duplicates where
+    subject and object are swapped. The first occurrence determines the
+    canonical orientation.
 
     Args:
         statements: List of statements to deduplicate
         entity_canonicalizer: Optional custom canonicalization function
+        detect_reversals: Whether to detect reversed duplicates (default True)
 
     Returns:
-        Deduplicated list
+        Deduplicated list with merged entity types
     """
     if len(statements) <= 1:
         return statements
 
     canonicalizer = Canonicalizer(entity_fn=entity_canonicalizer)
 
-
+    # Map from dedup key to index in unique list
+    seen: dict[tuple[str, str, str], int] = {}
     unique: list[Statement] = []
 
     for stmt in statements:
        key = canonicalizer.create_dedup_key(stmt)
-
-
+        # Also compute reversed key (object, predicate, subject)
+        reversed_key = (key[2], key[1], key[0])
+
+        if key in seen:
+            # Direct duplicate found - merge entity types
+            existing_idx = seen[key]
+            existing_stmt = unique[existing_idx]
+            merged_stmt = existing_stmt.merge_entity_types_from(stmt)
+            unique[existing_idx] = merged_stmt
+        elif detect_reversals and reversed_key in seen:
+            # Reversed duplicate found - merge entity types (accounting for reversal)
+            existing_idx = seen[reversed_key]
+            existing_stmt = unique[existing_idx]
+            # Merge types from the reversed statement
+            merged_stmt = existing_stmt.merge_entity_types_from(stmt.reversed())
+            unique[existing_idx] = merged_stmt
+        else:
+            # New unique statement
+            seen[key] = len(unique)
            unique.append(stmt)
 
    return unique
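A usage sketch of the reworked function. This assumes the `Statement`, `Entity`, and `EntityType` models accept exactly the fields visible in the models.py hunks further down; if the validators require other fields (for example `source_text`), those would need to be supplied as well:

```python
from statement_extractor.models import Entity, EntityType, Statement
from statement_extractor.canonicalization import deduplicate_statements_exact

a = Statement(
    subject=Entity(text="AtlasBio Labs", type=EntityType.UNKNOWN),
    object=Entity(text="CuraPharm", type=EntityType.ORG),
    predicate="sued by",
)
# The same fact with subject and object swapped and a typed subject.
b = Statement(
    subject=Entity(text="CuraPharm", type=EntityType.ORG),
    object=Entity(text="AtlasBio Labs", type=EntityType.ORG),
    predicate="sued by",
)

unique = deduplicate_statements_exact([a, b], detect_reversals=True)
# One statement remains. 'a' came first, so its orientation is canonical,
# and its UNKNOWN subject type is upgraded to ORG by merging b.reversed().
print(len(unique), unique[0])
```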
corp_extractor-0.2.8/src/statement_extractor/cli.py (new file)

@@ -0,0 +1,215 @@
+"""
+Command-line interface for statement extraction.
+
+Usage:
+    corp-extractor "Your text here"
+    corp-extractor -f input.txt
+    cat input.txt | corp-extractor -
+"""
+
+import sys
+from typing import Optional
+
+import click
+
+from . import __version__
+from .models import (
+    ExtractionOptions,
+    PredicateComparisonConfig,
+    PredicateTaxonomy,
+    ScoringConfig,
+)
+
+
+@click.command()
+@click.argument("text", required=False)
+@click.option("-f", "--file", "input_file", type=click.Path(exists=True), help="Read input from file")
+@click.option(
+    "-o", "--output",
+    type=click.Choice(["table", "json", "xml"], case_sensitive=False),
+    default="table",
+    help="Output format (default: table)"
+)
+@click.option("--json", "output_json", is_flag=True, help="Output as JSON (shortcut for -o json)")
+@click.option("--xml", "output_xml", is_flag=True, help="Output as XML (shortcut for -o xml)")
+# Beam search options
+@click.option("-b", "--beams", type=int, default=4, help="Number of beams for diverse beam search (default: 4)")
+@click.option("--diversity", type=float, default=1.0, help="Diversity penalty for beam search (default: 1.0)")
+@click.option("--max-tokens", type=int, default=2048, help="Maximum tokens to generate (default: 2048)")
+# Deduplication options
+@click.option("--no-dedup", is_flag=True, help="Disable deduplication")
+@click.option("--no-embeddings", is_flag=True, help="Disable embedding-based deduplication (faster)")
+@click.option("--no-merge", is_flag=True, help="Disable beam merging (select single best beam)")
+@click.option("--dedup-threshold", type=float, default=0.65, help="Similarity threshold for deduplication (default: 0.65)")
+# Quality options
+@click.option("--min-confidence", type=float, default=0.0, help="Minimum confidence threshold 0-1 (default: 0)")
+# Taxonomy options
+@click.option("--taxonomy", type=click.Path(exists=True), help="Load predicate taxonomy from file (one per line)")
+@click.option("--taxonomy-threshold", type=float, default=0.5, help="Similarity threshold for taxonomy matching (default: 0.5)")
+# Device options
+@click.option("--device", type=click.Choice(["auto", "cuda", "mps", "cpu"]), default="auto", help="Device to use (default: auto)")
+# Output options
+@click.option("-v", "--verbose", is_flag=True, help="Show verbose output with confidence scores")
+@click.option("-q", "--quiet", is_flag=True, help="Suppress progress messages")
+@click.version_option(version=__version__)
+def main(
+    text: Optional[str],
+    input_file: Optional[str],
+    output: str,
+    output_json: bool,
+    output_xml: bool,
+    beams: int,
+    diversity: float,
+    max_tokens: int,
+    no_dedup: bool,
+    no_embeddings: bool,
+    no_merge: bool,
+    dedup_threshold: float,
+    min_confidence: float,
+    taxonomy: Optional[str],
+    taxonomy_threshold: float,
+    device: str,
+    verbose: bool,
+    quiet: bool,
+):
+    """
+    Extract structured statements from text.
+
+    TEXT can be provided as an argument, read from a file with -f, or piped via stdin.
+
+    \b
+    Examples:
+        corp-extractor "Apple announced a new iPhone."
+        corp-extractor -f article.txt --json
+        corp-extractor -f article.txt -o json --beams 8
+        cat article.txt | corp-extractor -
+        echo "Tim Cook is CEO of Apple." | corp-extractor - --verbose
+
+    \b
+    Output formats:
+        table  Human-readable table (default)
+        json   JSON with full metadata
+        xml    Raw XML from model
+    """
+    # Determine output format
+    if output_json:
+        output = "json"
+    elif output_xml:
+        output = "xml"
+
+    # Get input text
+    input_text = _get_input_text(text, input_file)
+    if not input_text:
+        raise click.UsageError(
+            "No input provided. Use: statement-extractor \"text\", "
+            "statement-extractor -f file.txt, or pipe via stdin."
+        )
+
+    if not quiet:
+        click.echo(f"Processing {len(input_text)} characters...", err=True)
+
+    # Load taxonomy if provided
+    predicate_taxonomy = None
+    if taxonomy:
+        predicate_taxonomy = PredicateTaxonomy.from_file(taxonomy)
+        if not quiet:
+            click.echo(f"Loaded taxonomy with {len(predicate_taxonomy.predicates)} predicates", err=True)
+
+    # Configure predicate comparison
+    predicate_config = PredicateComparisonConfig(
+        similarity_threshold=taxonomy_threshold,
+        dedup_threshold=dedup_threshold,
+    )
+
+    # Configure scoring
+    scoring_config = ScoringConfig(min_confidence=min_confidence)
+
+    # Configure extraction options
+    options = ExtractionOptions(
+        num_beams=beams,
+        diversity_penalty=diversity,
+        max_new_tokens=max_tokens,
+        deduplicate=not no_dedup,
+        embedding_dedup=not no_embeddings,
+        merge_beams=not no_merge,
+        predicate_taxonomy=predicate_taxonomy,
+        predicate_config=predicate_config,
+        scoring_config=scoring_config,
+    )
+
+    # Import here to allow --help without loading torch
+    from .extractor import StatementExtractor
+
+    # Create extractor with specified device
+    device_arg = None if device == "auto" else device
+    extractor = StatementExtractor(device=device_arg)
+
+    if not quiet:
+        click.echo(f"Using device: {extractor.device}", err=True)
+
+    # Run extraction
+    try:
+        if output == "xml":
+            result = extractor.extract_as_xml(input_text, options)
+            click.echo(result)
+        elif output == "json":
+            result = extractor.extract_as_json(input_text, options)
+            click.echo(result)
+        else:
+            # Table format
+            result = extractor.extract(input_text, options)
+            _print_table(result, verbose)
+    except Exception as e:
+        raise click.ClickException(f"Extraction failed: {e}")
+
+
+def _get_input_text(text: Optional[str], input_file: Optional[str]) -> Optional[str]:
+    """Get input text from argument, file, or stdin."""
+    if text == "-" or (text is None and input_file is None and not sys.stdin.isatty()):
+        # Read from stdin
+        return sys.stdin.read().strip()
+    elif input_file:
+        # Read from file
+        with open(input_file, "r", encoding="utf-8") as f:
+            return f.read().strip()
+    elif text:
+        return text.strip()
+    return None
+
+
+def _print_table(result, verbose: bool):
+    """Print statements in a human-readable table format."""
+    if not result.statements:
+        click.echo("No statements extracted.")
+        return
+
+    click.echo(f"\nExtracted {len(result.statements)} statement(s):\n")
+    click.echo("-" * 80)
+
+    for i, stmt in enumerate(result.statements, 1):
+        subject_type = f" ({stmt.subject.type.value})" if stmt.subject.type.value != "UNKNOWN" else ""
+        object_type = f" ({stmt.object.type.value})" if stmt.object.type.value != "UNKNOWN" else ""
+
+        click.echo(f"{i}. {stmt.subject.text}{subject_type}")
+        click.echo(f"   --[{stmt.predicate}]-->")
+        click.echo(f"   {stmt.object.text}{object_type}")
+
+        if verbose:
+            if stmt.confidence_score is not None:
+                click.echo(f"   Confidence: {stmt.confidence_score:.2f}")
+
+            if stmt.canonical_predicate:
+                click.echo(f"   Canonical: {stmt.canonical_predicate}")
+
+            if stmt.was_reversed:
+                click.echo(f"   (subject/object were swapped)")
+
+            if stmt.source_text:
+                source = stmt.source_text[:60] + "..." if len(stmt.source_text) > 60 else stmt.source_text
+                click.echo(f"   Source: \"{source}\"")
+
+        click.echo("-" * 80)
+
+
+if __name__ == "__main__":
+    main()
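One design note on this file: `StatementExtractor` is imported inside `main()` rather than at module top, so `--help` and `--version` respond without paying the torch and model import cost. And since the module ends with a standard `__main__` guard, it can presumably also be run without the console scripts:

```bash
python -m statement_extractor.cli "Tim Cook is CEO of Apple." --json
```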
{corp_extractor-0.2.2 → corp_extractor-0.2.8}/src/statement_extractor/extractor.py

@@ -80,11 +80,16 @@ class StatementExtractor:
 
         # Auto-detect device
         if device is None:
-
+            if torch.cuda.is_available():
+                self.device = "cuda"
+            elif torch.backends.mps.is_available():
+                self.device = "mps"
+            else:
+                self.device = "cpu"
         else:
             self.device = device
 
-        # Auto-detect dtype
+        # Auto-detect dtype (bfloat16 only for CUDA, float32 for MPS/CPU)
         if torch_dtype is None:
             self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
         else:
@@ -143,7 +148,7 @@ class StatementExtractor:
 
         taxonomy = options.predicate_taxonomy or self._predicate_taxonomy
         config = options.predicate_config or self._predicate_config or PredicateComparisonConfig()
-        return PredicateComparer(taxonomy=taxonomy, config=config)
+        return PredicateComparer(taxonomy=taxonomy, config=config, device=self.device)
 
     @property
     def model(self) -> AutoModelForSeq2SeqLM:
@@ -350,30 +355,41 @@ class StatementExtractor:
         outputs = self.model.generate(
             **inputs,
             max_new_tokens=options.max_new_tokens,
+            max_length=None,  # Override model default, use max_new_tokens only
             num_beams=num_seqs,
             num_beam_groups=num_seqs,
             num_return_sequences=num_seqs,
             diversity_penalty=options.diversity_penalty,
             do_sample=False,
+            top_p=None,  # Override model config to suppress warning
+            top_k=None,  # Override model config to suppress warning
             trust_remote_code=True,
+            custom_generate="transformers-community/group-beam-search",
         )
 
         # Decode and process candidates
         end_tag = "</statements>"
         candidates: list[str] = []
 
-        for output in outputs:
+        for i, output in enumerate(outputs):
             decoded = self.tokenizer.decode(output, skip_special_tokens=True)
+            output_len = len(output)
 
             # Truncate at </statements>
             if end_tag in decoded:
                 end_pos = decoded.find(end_tag) + len(end_tag)
                 decoded = decoded[:end_pos]
                 candidates.append(decoded)
+                logger.debug(f"Beam {i}: {output_len} tokens, found end tag, {len(decoded)} chars")
+            else:
+                # Log the issue - likely truncated
+                logger.warning(f"Beam {i}: {output_len} tokens, NO end tag found (truncated?)")
+                logger.warning(f"Beam {i} full output ({len(decoded)} chars):\n{decoded}")
 
         # Include fallback if no valid candidates
         if not candidates and len(outputs) > 0:
             fallback = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            logger.warning(f"Using fallback beam (no valid candidates found), {len(fallback)} chars")
             candidates.append(fallback)
 
         return candidates
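The `custom_generate="transformers-community/group-beam-search"` argument ties this hunk to the `transformers>=5.0.0rc3` bump in pyproject.toml: newer transformers releases moved group (diverse) beam search out of the core `generate()` method into a Hub-hosted decoding implementation that is loaded by name via `custom_generate` and requires `trust_remote_code=True`. The `top_p=None`/`top_k=None` overrides simply silence sampling-parameter warnings, since decoding here is fully deterministic (`do_sample=False`).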
@@ -467,8 +483,10 @@ class StatementExtractor:
 
         try:
             root = ET.fromstring(xml_output)
-        except ET.ParseError:
-
+        except ET.ParseError as e:
+            # Log full output for debugging
+            logger.warning(f"Failed to parse XML output: {e}")
+            logger.warning(f"Full XML output ({len(xml_output)} chars):\n{xml_output}")
             return statements
 
         if root.tag != 'statements':
{corp_extractor-0.2.2 → corp_extractor-0.2.8}/src/statement_extractor/models.py

@@ -32,6 +32,18 @@ class Entity(BaseModel):
     def __str__(self) -> str:
         return f"{self.text} ({self.type.value})"
 
+    def merge_type_from(self, other: "Entity") -> "Entity":
+        """
+        Return a new Entity with the more specific type.
+
+        If this entity has UNKNOWN type and other has a specific type,
+        returns a new entity with this text but other's type.
+        Otherwise returns self unchanged.
+        """
+        if self.type == EntityType.UNKNOWN and other.type != EntityType.UNKNOWN:
+            return Entity(text=self.text, type=other.type)
+        return self
+
 
 class Statement(BaseModel):
     """A single extracted statement (subject-predicate-object triple)."""
@@ -55,6 +67,10 @@ class Statement(BaseModel):
         None,
         description="Canonical form of the predicate if taxonomy matching was used"
     )
+    was_reversed: bool = Field(
+        default=False,
+        description="True if subject/object were swapped during reversal detection"
+    )
 
     def __str__(self) -> str:
         return f"{self.subject.text} -- {self.predicate} --> {self.object.text}"
@@ -63,6 +79,49 @@ class Statement(BaseModel):
         """Return as a simple (subject, predicate, object) tuple."""
         return (self.subject.text, self.predicate, self.object.text)
 
+    def merge_entity_types_from(self, other: "Statement") -> "Statement":
+        """
+        Return a new Statement with more specific entity types merged from other.
+
+        If this statement has UNKNOWN entity types and other has specific types,
+        the returned statement will use the specific types from other.
+        All other fields come from self.
+        """
+        merged_subject = self.subject.merge_type_from(other.subject)
+        merged_object = self.object.merge_type_from(other.object)
+
+        # Only create new statement if something changed
+        if merged_subject is self.subject and merged_object is self.object:
+            return self
+
+        return Statement(
+            subject=merged_subject,
+            object=merged_object,
+            predicate=self.predicate,
+            source_text=self.source_text,
+            confidence_score=self.confidence_score,
+            evidence_span=self.evidence_span,
+            canonical_predicate=self.canonical_predicate,
+            was_reversed=self.was_reversed,
+        )
+
+    def reversed(self) -> "Statement":
+        """
+        Return a new Statement with subject and object swapped.
+
+        Sets was_reversed=True to indicate the swap occurred.
+        """
+        return Statement(
+            subject=self.object,
+            object=self.subject,
+            predicate=self.predicate,
+            source_text=self.source_text,
+            confidence_score=self.confidence_score,
+            evidence_span=self.evidence_span,
+            canonical_predicate=self.canonical_predicate,
+            was_reversed=True,
+        )
+
 
 class ExtractionResult(BaseModel):
     """The result of statement extraction from text."""
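A quick sketch of how the two new `Statement` helpers compose, under the same model-construction assumption as the earlier sketch (no other required fields):

```python
from statement_extractor.models import Entity, EntityType, Statement

stmt = Statement(
    subject=Entity(text="Beats", type=EntityType.UNKNOWN),
    object=Entity(text="Apple", type=EntityType.ORG),
    predicate="acquired by",
)

# merge_type_from only upgrades UNKNOWN; a specific type is never overwritten.
typed = Entity(text="Beats", type=EntityType.ORG)
assert stmt.subject.merge_type_from(typed).type == EntityType.ORG

# reversed() swaps the entities and flags the statement.
flipped = stmt.reversed()
assert flipped.subject.text == "Apple" and flipped.was_reversed
```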
{corp_extractor-0.2.2 → corp_extractor-0.2.8}/src/statement_extractor/predicate_comparer.py

@@ -62,6 +62,7 @@ class PredicateComparer:
         self,
         taxonomy: Optional[PredicateTaxonomy] = None,
         config: Optional[PredicateComparisonConfig] = None,
+        device: Optional[str] = None,
     ):
         """
         Initialize the predicate comparer.
@@ -69,6 +70,7 @@ class PredicateComparer:
         Args:
             taxonomy: Optional canonical predicate taxonomy for normalization
             config: Comparison configuration (uses defaults if not provided)
+            device: Device to use ('cuda', 'cpu', or None for auto-detect)
 
         Raises:
             EmbeddingDependencyError: If sentence-transformers is not installed
@@ -78,6 +80,18 @@ class PredicateComparer:
         self.taxonomy = taxonomy
         self.config = config or PredicateComparisonConfig()
 
+        # Auto-detect device
+        if device is None:
+            import torch
+            if torch.cuda.is_available():
+                self.device = "cuda"
+            elif torch.backends.mps.is_available():
+                self.device = "mps"
+            else:
+                self.device = "cpu"
+        else:
+            self.device = device
+
         # Lazy-loaded resources
         self._model = None
         self._taxonomy_embeddings: Optional[np.ndarray] = None
@@ -89,9 +103,9 @@ class PredicateComparer:
 
         from sentence_transformers import SentenceTransformer
 
-        logger.info(f"Loading embedding model: {self.config.embedding_model}")
-        self._model = SentenceTransformer(self.config.embedding_model)
-        logger.info("Embedding model loaded")
+        logger.info(f"Loading embedding model: {self.config.embedding_model} on {self.device}")
+        self._model = SentenceTransformer(self.config.embedding_model, device=self.device)
+        logger.info(f"Embedding model loaded on {self.device}")
 
     def _normalize_text(self, text: str) -> str:
         """Normalize text before embedding."""
@@ -256,21 +270,26 @@ class PredicateComparer:
         self,
         statements: list[Statement],
         entity_canonicalizer: Optional[callable] = None,
+        detect_reversals: bool = True,
     ) -> list[Statement]:
         """
         Remove duplicate statements using embedding-based predicate comparison.
 
         Two statements are considered duplicates if:
-        - Canonicalized subjects match
+        - Canonicalized subjects match AND canonicalized objects match, OR
+        - Canonicalized subjects match objects (reversed) when detect_reversals=True
         - Predicates are similar (embedding-based)
-        - Canonicalized objects match
 
         When duplicates are found, keeps the statement with better contextualized
         match (comparing "Subject Predicate Object" against source text).
 
+        For reversed duplicates, the correct orientation is determined by comparing
+        both "S P O" and "O P S" against source text.
+
         Args:
             statements: List of Statement objects
             entity_canonicalizer: Optional function to canonicalize entity text
+            detect_reversals: Whether to detect reversed duplicates (default True)
 
         Returns:
             Deduplicated list of statements (keeps best contextualized match)
@@ -293,6 +312,12 @@ class PredicateComparer:
         ]
         contextualized_embeddings = self._compute_embeddings(contextualized_texts)
 
+        # Compute reversed contextualized embeddings: "Object Predicate Subject"
+        reversed_texts = [
+            f"{s.object.text} {s.predicate} {s.subject.text}" for s in statements
+        ]
+        reversed_embeddings = self._compute_embeddings(reversed_texts)
+
         # Compute source text embeddings for scoring which duplicate to keep
         source_embeddings = []
         for stmt in statements:
@@ -302,6 +327,7 @@ class PredicateComparer:
         unique_statements: list[Statement] = []
         unique_pred_embeddings: list[np.ndarray] = []
         unique_context_embeddings: list[np.ndarray] = []
+        unique_reversed_embeddings: list[np.ndarray] = []
         unique_source_embeddings: list[np.ndarray] = []
         unique_indices: list[int] = []
 
@@ -310,13 +336,23 @@ class PredicateComparer:
             obj_canon = canonicalize(stmt.object.text)
 
             duplicate_idx = None
+            is_reversed_match = False
 
             for j, unique_stmt in enumerate(unique_statements):
                 unique_subj = canonicalize(unique_stmt.subject.text)
                 unique_obj = canonicalize(unique_stmt.object.text)
 
-                # Check subject
-
+                # Check direct match: subject->subject, object->object
+                direct_match = (subj_canon == unique_subj and obj_canon == unique_obj)
+
+                # Check reversed match: subject->object, object->subject
+                reversed_match = (
+                    detect_reversals and
+                    subj_canon == unique_obj and
+                    obj_canon == unique_subj
+                )
+
+                if not direct_match and not reversed_match:
                     continue
 
                 # Check predicate similarity
@@ -326,6 +362,7 @@ class PredicateComparer:
                 )
                 if similarity >= self.config.dedup_threshold:
                     duplicate_idx = j
+                    is_reversed_match = reversed_match and not direct_match
                     break
 
             if duplicate_idx is None:
@@ -333,27 +370,91 @@ class PredicateComparer:
                 unique_statements.append(stmt)
                 unique_pred_embeddings.append(pred_embeddings[i])
                 unique_context_embeddings.append(contextualized_embeddings[i])
+                unique_reversed_embeddings.append(reversed_embeddings[i])
                 unique_source_embeddings.append(source_embeddings[i])
                 unique_indices.append(i)
             else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                existing_stmt = unique_statements[duplicate_idx]
+
+                if is_reversed_match:
+                    # Reversed duplicate - determine correct orientation using source text
+                    # Compare current's normal vs reversed against its source
+                    current_normal_score = self._cosine_similarity(
+                        contextualized_embeddings[i], source_embeddings[i]
+                    )
+                    current_reversed_score = self._cosine_similarity(
+                        reversed_embeddings[i], source_embeddings[i]
+                    )
+                    # Compare existing's normal vs reversed against its source
+                    existing_normal_score = self._cosine_similarity(
+                        unique_context_embeddings[duplicate_idx],
+                        unique_source_embeddings[duplicate_idx]
+                    )
+                    existing_reversed_score = self._cosine_similarity(
+                        unique_reversed_embeddings[duplicate_idx],
+                        unique_source_embeddings[duplicate_idx]
+                    )
+
+                    # Determine best orientation for current
+                    current_best = max(current_normal_score, current_reversed_score)
+                    current_should_reverse = current_reversed_score > current_normal_score
+
+                    # Determine best orientation for existing
+                    existing_best = max(existing_normal_score, existing_reversed_score)
+                    existing_should_reverse = existing_reversed_score > existing_normal_score
+
+                    if current_best > existing_best:
+                        # Current is better - use it (possibly reversed)
+                        if current_should_reverse:
+                            best_stmt = stmt.reversed()
+                        else:
+                            best_stmt = stmt
+                        # Merge entity types from existing (accounting for reversal)
+                        if existing_should_reverse:
+                            best_stmt = best_stmt.merge_entity_types_from(existing_stmt.reversed())
+                        else:
+                            best_stmt = best_stmt.merge_entity_types_from(existing_stmt)
+                        unique_statements[duplicate_idx] = best_stmt
+                        unique_pred_embeddings[duplicate_idx] = pred_embeddings[i]
+                        unique_context_embeddings[duplicate_idx] = contextualized_embeddings[i]
+                        unique_reversed_embeddings[duplicate_idx] = reversed_embeddings[i]
+                        unique_source_embeddings[duplicate_idx] = source_embeddings[i]
+                        unique_indices[duplicate_idx] = i
+                    else:
+                        # Existing is better - possibly fix its orientation
+                        if existing_should_reverse and not existing_stmt.was_reversed:
+                            best_stmt = existing_stmt.reversed()
+                        else:
+                            best_stmt = existing_stmt
+                        # Merge entity types from current (accounting for reversal)
+                        if current_should_reverse:
+                            best_stmt = best_stmt.merge_entity_types_from(stmt.reversed())
+                        else:
+                            best_stmt = best_stmt.merge_entity_types_from(stmt)
+                        unique_statements[duplicate_idx] = best_stmt
+                else:
+                    # Direct duplicate - keep the one with better contextualized match
+                    current_score = self._cosine_similarity(
+                        contextualized_embeddings[i], source_embeddings[i]
+                    )
+                    existing_score = self._cosine_similarity(
+                        unique_context_embeddings[duplicate_idx],
+                        unique_source_embeddings[duplicate_idx]
+                    )
+
+                    if current_score > existing_score:
+                        # Current statement is a better match - replace
+                        merged_stmt = stmt.merge_entity_types_from(existing_stmt)
+                        unique_statements[duplicate_idx] = merged_stmt
+                        unique_pred_embeddings[duplicate_idx] = pred_embeddings[i]
+                        unique_context_embeddings[duplicate_idx] = contextualized_embeddings[i]
+                        unique_reversed_embeddings[duplicate_idx] = reversed_embeddings[i]
+                        unique_source_embeddings[duplicate_idx] = source_embeddings[i]
+                        unique_indices[duplicate_idx] = i
+                    else:
+                        # Existing statement is better - merge entity types from current
+                        merged_stmt = existing_stmt.merge_entity_types_from(stmt)
+                        unique_statements[duplicate_idx] = merged_stmt
 
         return unique_statements
 
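Concretely, suppose "A sued-by B" and "B sued-by A" collide as a reversed pair, and (picking illustrative numbers) the current statement scores 0.82 normal / 0.71 reversed against its source while the existing one scores 0.68 / 0.77. The current statement wins (0.82 > 0.77), keeps its normal orientation, and absorbs entity types from the existing statement's reversed form so that the types line up with the winning orientation.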
@@ -437,3 +538,79 @@ class PredicateComparer:
             similarity=best_score,
             matched=False,
         )
+
+    def detect_and_fix_reversals(
+        self,
+        statements: list[Statement],
+        threshold: float = 0.05,
+    ) -> list[Statement]:
+        """
+        Detect and fix subject-object reversals using embedding comparison.
+
+        For each statement, compares:
+        - "Subject Predicate Object" embedding against source_text
+        - "Object Predicate Subject" embedding against source_text
+
+        If the reversed version has significantly higher similarity to the source,
+        the subject and object are swapped and was_reversed is set to True.
+
+        Args:
+            statements: List of Statement objects
+            threshold: Minimum similarity difference to trigger reversal (default 0.05)
+
+        Returns:
+            List of statements with reversals corrected
+        """
+        if not statements:
+            return statements
+
+        result = []
+        for stmt in statements:
+            # Skip if no source_text to compare against
+            if not stmt.source_text:
+                result.append(stmt)
+                continue
+
+            # Build normal and reversed triple strings
+            normal_text = f"{stmt.subject.text} {stmt.predicate} {stmt.object.text}"
+            reversed_text = f"{stmt.object.text} {stmt.predicate} {stmt.subject.text}"
+
+            # Compute embeddings for normal, reversed, and source
+            embeddings = self._compute_embeddings([normal_text, reversed_text, stmt.source_text])
+            normal_emb, reversed_emb, source_emb = embeddings[0], embeddings[1], embeddings[2]
+
+            # Compute similarities to source
+            normal_sim = self._cosine_similarity(normal_emb, source_emb)
+            reversed_sim = self._cosine_similarity(reversed_emb, source_emb)
+
+            # If reversed is significantly better, swap subject and object
+            if reversed_sim > normal_sim + threshold:
+                result.append(stmt.reversed())
+            else:
+                result.append(stmt)
+
+        return result
+
+    def check_reversal(self, statement: Statement) -> tuple[bool, float, float]:
+        """
+        Check if a single statement should be reversed.
+
+        Args:
+            statement: Statement to check
+
+        Returns:
+            Tuple of (should_reverse, normal_similarity, reversed_similarity)
+        """
+        if not statement.source_text:
+            return (False, 0.0, 0.0)
+
+        normal_text = f"{statement.subject.text} {statement.predicate} {statement.object.text}"
+        reversed_text = f"{statement.object.text} {statement.predicate} {statement.subject.text}"
+
+        embeddings = self._compute_embeddings([normal_text, reversed_text, statement.source_text])
+        normal_emb, reversed_emb, source_emb = embeddings[0], embeddings[1], embeddings[2]
+
+        normal_sim = self._cosine_similarity(normal_emb, source_emb)
+        reversed_sim = self._cosine_similarity(reversed_emb, source_emb)
+
+        return (reversed_sim > normal_sim, normal_sim, reversed_sim)
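Both new methods lean on the comparer's existing `_compute_embeddings` and `_cosine_similarity` helpers, which are untouched by this release and therefore absent from the diff. For orientation only, a typical cosine helper over numpy vectors looks like this (a sketch, not the package's actual code):

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two 1-D embedding vectors."""
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    return float(np.dot(a, b) / denom) if denom else 0.0
```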
{corp_extractor-0.2.2 → corp_extractor-0.2.8}/.gitignore: file without changes

{corp_extractor-0.2.2 → corp_extractor-0.2.8}/src/statement_extractor/scoring.py: file without changes