corp_extractor-0.2.5-py3-none-any.whl → corp_extractor-0.2.11-py3-none-any.whl
- {corp_extractor-0.2.5.dist-info → corp_extractor-0.2.11.dist-info}/METADATA +19 -11
- corp_extractor-0.2.11.dist-info/RECORD +11 -0
- statement_extractor/cli.py +31 -1
- statement_extractor/extractor.py +77 -5
- statement_extractor/models.py +6 -0
- statement_extractor/predicate_comparer.py +23 -1
- statement_extractor/scoring.py +32 -10
- corp_extractor-0.2.5.dist-info/RECORD +0 -11
- {corp_extractor-0.2.5.dist-info → corp_extractor-0.2.11.dist-info}/WHEEL +0 -0
- {corp_extractor-0.2.5.dist-info → corp_extractor-0.2.11.dist-info}/entry_points.txt +0 -0
{corp_extractor-0.2.5.dist-info → corp_extractor-0.2.11.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: corp-extractor
-Version: 0.2.5
+Version: 0.2.11
 Summary: Extract structured statements from text using T5-Gemma 2 and Diverse Beam Search
 Project-URL: Homepage, https://github.com/corp-o-rate/statement-extractor
 Project-URL: Documentation, https://github.com/corp-o-rate/statement-extractor#readme
@@ -27,7 +27,7 @@ Requires-Dist: click>=8.0.0
 Requires-Dist: numpy>=1.24.0
 Requires-Dist: pydantic>=2.0.0
 Requires-Dist: torch>=2.0.0
-Requires-Dist: transformers>=5.0.0
+Requires-Dist: transformers>=5.0.0rc3
 Provides-Extra: all
 Requires-Dist: sentence-transformers>=2.2.0; extra == 'all'
 Provides-Extra: dev
@@ -65,18 +65,26 @@ Extract structured subject-predicate-object statements from unstructured text us
 
 ```bash
 # Recommended: include embedding support for smart deduplication
-pip install corp-extractor[embeddings]
+pip install "corp-extractor[embeddings]"
 
 # Minimal installation (no embedding features)
 pip install corp-extractor
 ```
 
-**Note**: This package requires
+**Note**: This package requires `transformers>=5.0.0` (pre-release) for T5-Gemma2 model support. Install with `--pre` flag if needed:
+```bash
+pip install --pre "corp-extractor[embeddings]"
+```
 
 **For GPU support**, install PyTorch with CUDA first:
 ```bash
 pip install torch --index-url https://download.pytorch.org/whl/cu121
-pip install corp-extractor[embeddings]
+pip install "corp-extractor[embeddings]"
+```
+
+**For Apple Silicon (M1/M2/M3)**, MPS acceleration is automatically detected:
+```bash
+pip install "corp-extractor[embeddings]"  # MPS used automatically
 ```
 
 ## Quick Start
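A quick way to sanity-check an environment against these new requirements — a minimal sketch, assuming only that `torch` and `transformers` are importable (the package itself is not needed for this check):

```python
# Which device would "auto" pick, and is transformers new enough for T5-Gemma 2?
import torch
import transformers

if torch.cuda.is_available():
    device = "cuda"          # NVIDIA GPU
elif torch.backends.mps.is_available():
    device = "mps"           # Apple Silicon
else:
    device = "cpu"

print(f"auto device: {device}")
print(f"transformers: {transformers.__version__}")  # should satisfy >=5.0.0rc3
```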
@@ -105,13 +113,13 @@ For best results, install globally first:
 
 ```bash
 # Using uv (recommended)
-uv tool install corp-extractor[embeddings]
+uv tool install "corp-extractor[embeddings]"
 
 # Using pipx
-pipx install corp-extractor[embeddings]
+pipx install "corp-extractor[embeddings]"
 
 # Using pip
-pip install corp-extractor[embeddings]
+pip install "corp-extractor[embeddings]"
 
 # Then use anywhere
 corp-extractor "Your text here"
@@ -125,7 +133,7 @@ Run directly without installing using [uv](https://docs.astral.sh/uv/):
 uvx corp-extractor "Apple announced a new iPhone."
 ```
 
-**Note**:
+**Note**: First run downloads the model (~1.5GB) which may take a few minutes.
 
 ### Usage Examples
 
@@ -178,7 +186,7 @@ Options:
   --min-confidence FLOAT       Min confidence filter (default: 0)
   --taxonomy PATH              Load predicate taxonomy from file
   --taxonomy-threshold FLOAT   Taxonomy matching threshold (default: 0.5)
-  --device [auto|cuda|cpu]
+  --device [auto|cuda|mps|cpu] Device to use (default: auto)
   -v, --verbose                Show confidence scores and metadata
   -q, --quiet                  Suppress progress messages
   --version                    Show version
@@ -314,7 +322,7 @@ dict_output = extract_statements_as_dict(text)
 ```python
 from statement_extractor import StatementExtractor
 
-extractor = StatementExtractor(device="cuda")  # or "cpu"
+extractor = StatementExtractor(device="cuda")  # or "mps" (Apple Silicon) or "cpu"
 
 texts = ["Text 1...", "Text 2...", "Text 3..."]
 for text in texts:
corp_extractor-0.2.11.dist-info/RECORD
ADDED

@@ -0,0 +1,11 @@
+statement_extractor/__init__.py,sha256=MIZgn-lD9-XGJapzdyYxMhEJFRrTzftbRklrhwA4e8w,2967
+statement_extractor/canonicalization.py,sha256=ZMLs6RLWJa_rOJ8XZ7PoHFU13-zeJkOMDnvK-ZaFa5s,5991
+statement_extractor/cli.py,sha256=NIGCpqcnzF42B16RCiSu4kN0RlnVne2ZAT8341Znt1g,8558
+statement_extractor/extractor.py,sha256=r2gcCfZT43Q8STPuzaXmhbjWXTAs4JwMeAtCjQxlsIQ,25870
+statement_extractor/models.py,sha256=IE3TyIiOl2CINPMroQnGT12rSeQFR0bV3y4BJ79wLmI,10877
+statement_extractor/predicate_comparer.py,sha256=jcuaBi5BYqD3TKoyj3pR9dxtX5ihfDJvjdhEd2LHCwc,26184
+statement_extractor/scoring.py,sha256=xs0SxrV42QNBULQguU1-HhcCc-HnS-ekbcdx7FqWGVk,15663
+corp_extractor-0.2.11.dist-info/METADATA,sha256=D-fs9i9kn4v5bRAHCHxI3cq_6vosNgDCN7uuYwVZztM,13775
+corp_extractor-0.2.11.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+corp_extractor-0.2.11.dist-info/entry_points.txt,sha256=i0iKFqPIusvb-QTQ1zNnFgAqatgVah-jIhahbs5TToQ,115
+corp_extractor-0.2.11.dist-info/RECORD,,
statement_extractor/cli.py
CHANGED

@@ -7,11 +7,36 @@ Usage:
     cat input.txt | corp-extractor -
 """
 
+import logging
 import sys
 from typing import Optional
 
 import click
 
+
+def _configure_logging(verbose: bool) -> None:
+    """Configure logging for the extraction pipeline."""
+    level = logging.DEBUG if verbose else logging.WARNING
+
+    # Configure root logger for statement_extractor package
+    logging.basicConfig(
+        level=level,
+        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
+        datefmt="%H:%M:%S",
+        stream=sys.stderr,
+        force=True,
+    )
+
+    # Set level for all statement_extractor loggers
+    for logger_name in [
+        "statement_extractor",
+        "statement_extractor.extractor",
+        "statement_extractor.scoring",
+        "statement_extractor.predicate_comparer",
+        "statement_extractor.canonicalization",
+    ]:
+        logging.getLogger(logger_name).setLevel(level)
+
 from . import __version__
 from .models import (
     ExtractionOptions,
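The `force=True` argument is what makes this safe to call from a CLI entry point: since Python 3.8, `logging.basicConfig(force=True)` removes any handlers already attached to the root logger before installing the new one, so repeated configuration never duplicates output. A minimal sketch of the same pattern in isolation (the `configure` helper below is illustrative, not part of the package):

```python
import logging
import sys

def configure(verbose: bool) -> None:
    # force=True (Python 3.8+) strips existing root handlers first,
    # so calling this repeatedly never stacks duplicate handlers.
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.WARNING,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        stream=sys.stderr,
        force=True,
    )

configure(verbose=False)
logging.getLogger("demo").debug("hidden")   # below WARNING, suppressed
configure(verbose=True)
logging.getLogger("demo").debug("visible")  # now emitted to stderr
```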
@@ -47,7 +72,7 @@ from .models import (
 @click.option("--taxonomy", type=click.Path(exists=True), help="Load predicate taxonomy from file (one per line)")
 @click.option("--taxonomy-threshold", type=float, default=0.5, help="Similarity threshold for taxonomy matching (default: 0.5)")
 # Device options
-@click.option("--device", type=click.Choice(["auto", "cuda", "cpu"]), default="auto", help="Device to use (default: auto)")
+@click.option("--device", type=click.Choice(["auto", "cuda", "mps", "cpu"]), default="auto", help="Device to use (default: auto)")
 # Output options
 @click.option("-v", "--verbose", is_flag=True, help="Show verbose output with confidence scores")
 @click.option("-q", "--quiet", is_flag=True, help="Suppress progress messages")
@@ -91,6 +116,9 @@ def main(
         json    JSON with full metadata
         xml     Raw XML from model
     """
+    # Configure logging based on verbose flag
+    _configure_logging(verbose)
+
     # Determine output format
     if output_json:
         output = "json"
@@ -135,6 +163,7 @@ def main(
         predicate_taxonomy=predicate_taxonomy,
         predicate_config=predicate_config,
         scoring_config=scoring_config,
+        verbose=verbose,
     )
 
     # Import here to allow --help without loading torch
@@ -160,6 +189,7 @@ def main(
         result = extractor.extract(input_text, options)
         _print_table(result, verbose)
     except Exception as e:
+        logging.exception("Error extracting statements:")
         raise click.ClickException(f"Extraction failed: {e}")
statement_extractor/extractor.py
CHANGED

@@ -80,11 +80,16 @@ class StatementExtractor:
 
         # Auto-detect device
         if device is None:
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                self.device = "cuda"
+            elif torch.backends.mps.is_available():
+                self.device = "mps"
+            else:
+                self.device = "cpu"
         else:
             self.device = device
 
-        # Auto-detect dtype
+        # Auto-detect dtype (bfloat16 only for CUDA, float32 for MPS/CPU)
         if torch_dtype is None:
            self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
         else:
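Pulled out of the hunk above, the device/dtype resolution is a small reusable pattern — a standalone sketch (the helper name `resolve_device_dtype` is illustrative, not part of the package API):

```python
import torch

def resolve_device_dtype(device: str | None = None,
                         dtype: torch.dtype | None = None) -> tuple[str, torch.dtype]:
    """Pick a device in preference order cuda > mps > cpu, then a matching dtype."""
    if device is None:
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
    # bfloat16 is used only on CUDA; MPS and CPU fall back to float32
    if dtype is None:
        dtype = torch.bfloat16 if device == "cuda" else torch.float32
    return device, dtype

device, dtype = resolve_device_dtype()
print(device, dtype)  # e.g. "mps" torch.float32 on Apple Silicon
```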
@@ -175,6 +180,14 @@ class StatementExtractor:
         if options is None:
             options = ExtractionOptions()
 
+        logger.debug("=" * 60)
+        logger.debug("EXTRACTION STARTED")
+        logger.debug("=" * 60)
+        logger.debug(f"Input text length: {len(text)} chars")
+        logger.debug(f"Options: num_beams={options.num_beams}, diversity={options.diversity_penalty}")
+        logger.debug(f"  merge_beams={options.merge_beams}, embedding_dedup={options.embedding_dedup}")
+        logger.debug(f"  deduplicate={options.deduplicate}, max_new_tokens={options.max_new_tokens}")
+
         # Store original text for scoring
         original_text = text
 
@@ -185,6 +198,10 @@ class StatementExtractor:
         # Run extraction with retry logic
         statements = self._extract_with_scoring(text, original_text, options)
 
+        logger.debug("=" * 60)
+        logger.debug(f"EXTRACTION COMPLETE: {len(statements)} statements")
+        logger.debug("=" * 60)
+
         return ExtractionResult(
             statements=statements,
             source_text=original_text,
@@ -270,6 +287,10 @@ class StatementExtractor:
         4. Merges top beams or selects best beam
         5. Deduplicates using embeddings (if enabled)
         """
+        logger.debug("-" * 40)
+        logger.debug("PHASE 1: Tokenization")
+        logger.debug("-" * 40)
+
         # Tokenize input
         inputs = self.tokenizer(
             text,
@@ -278,48 +299,77 @@ class StatementExtractor:
             truncation=True,
         ).to(self.device)
 
+        input_ids = inputs["input_ids"]
+        logger.debug(f"Tokenized: {input_ids.shape[1]} tokens")
+
         # Count sentences for quality check
         num_sentences = self._count_sentences(text)
         min_expected = int(num_sentences * options.min_statement_ratio)
 
-        logger.
+        logger.debug(f"Input has ~{num_sentences} sentences, min expected: {min_expected}")
 
         # Get beam scorer
         beam_scorer = self._get_beam_scorer(options)
 
+        logger.debug("-" * 40)
+        logger.debug("PHASE 2: Diverse Beam Search Generation")
+        logger.debug("-" * 40)
+
         all_candidates: list[list[Statement]] = []
 
         for attempt in range(options.max_attempts):
+            logger.debug(f"Attempt {attempt + 1}/{options.max_attempts}: Generating {options.num_beams} beams...")
+
             # Generate candidate beams
             candidates = self._generate_candidate_beams(inputs, options)
+            logger.debug(f"  Generated {len(candidates)} valid XML outputs")
 
             # Parse each candidate to statements
             parsed_candidates = []
-            for xml_output in candidates:
+            for i, xml_output in enumerate(candidates):
                 statements = self._parse_xml_to_statements(xml_output)
                 if statements:
                     parsed_candidates.append(statements)
+                    logger.debug(f"  Beam {i}: {len(statements)} statements parsed")
+                else:
+                    logger.debug(f"  Beam {i}: 0 statements (parse failed)")
 
             all_candidates.extend(parsed_candidates)
 
             # Check if we have enough statements
             total_stmts = sum(len(c) for c in parsed_candidates)
-            logger.
+            logger.debug(f"  Total: {len(parsed_candidates)} beams, {total_stmts} statements")
 
             if total_stmts >= min_expected:
+                logger.debug(f"  Sufficient statements ({total_stmts} >= {min_expected}), stopping")
                 break
 
         if not all_candidates:
+            logger.debug("No valid candidates generated, returning empty result")
             return []
 
+        logger.debug("-" * 40)
+        logger.debug("PHASE 3: Beam Selection/Merging")
+        logger.debug("-" * 40)
+
         # Select or merge beams
         if options.merge_beams:
+            logger.debug(f"Merging {len(all_candidates)} beams...")
             statements = beam_scorer.merge_beams(all_candidates, original_text)
+            logger.debug(f"  After merge: {len(statements)} statements")
         else:
+            logger.debug(f"Selecting best beam from {len(all_candidates)} candidates...")
             statements = beam_scorer.select_best_beam(all_candidates, original_text)
+            logger.debug(f"  Selected beam has {len(statements)} statements")
+
+        logger.debug("-" * 40)
+        logger.debug("PHASE 4: Deduplication")
+        logger.debug("-" * 40)
 
         # Apply embedding-based deduplication if enabled
         if options.embedding_dedup and options.deduplicate:
+            logger.debug("Using embedding-based deduplication...")
+            pre_dedup_count = len(statements)
             try:
                 comparer = self._get_predicate_comparer(options)
                 if comparer:
@@ -327,14 +377,32 @@ class StatementExtractor:
                         statements,
                         entity_canonicalizer=options.entity_canonicalizer
                     )
+                    logger.debug(f"  After embedding dedup: {len(statements)} statements (removed {pre_dedup_count - len(statements)})")
+
                     # Also normalize predicates if taxonomy provided
                     if options.predicate_taxonomy or self._predicate_taxonomy:
+                        logger.debug("Normalizing predicates to taxonomy...")
                         statements = comparer.normalize_predicates(statements)
             except Exception as e:
                 logger.warning(f"Embedding deduplication failed, falling back to exact match: {e}")
                 statements = self._deduplicate_statements_exact(statements, options)
+                logger.debug(f"  After exact dedup: {len(statements)} statements")
         elif options.deduplicate:
+            logger.debug("Using exact text deduplication...")
+            pre_dedup_count = len(statements)
             statements = self._deduplicate_statements_exact(statements, options)
+            logger.debug(f"  After exact dedup: {len(statements)} statements (removed {pre_dedup_count - len(statements)})")
+        else:
+            logger.debug("Deduplication disabled")
+
+        # Log final statements
+        logger.debug("-" * 40)
+        logger.debug("FINAL STATEMENTS:")
+        logger.debug("-" * 40)
+        for i, stmt in enumerate(statements):
+            conf = f" (conf={stmt.confidence_score:.2f})" if stmt.confidence_score else ""
+            canonical = f" -> {stmt.canonical_predicate}" if stmt.canonical_predicate else ""
+            logger.debug(f"  {i+1}. {stmt.subject.text} --[{stmt.predicate}{canonical}]--> {stmt.object.text}{conf}")
 
         return statements
 
@@ -350,12 +418,16 @@ class StatementExtractor:
         outputs = self.model.generate(
             **inputs,
             max_new_tokens=options.max_new_tokens,
+            max_length=None,  # Override model default, use max_new_tokens only
             num_beams=num_seqs,
             num_beam_groups=num_seqs,
             num_return_sequences=num_seqs,
             diversity_penalty=options.diversity_penalty,
             do_sample=False,
+            top_p=None,  # Override model config to suppress warning
+            top_k=None,  # Override model config to suppress warning
             trust_remote_code=True,
+            custom_generate="transformers-community/group-beam-search",
         )
 
         # Decode and process candidates
statement_extractor/models.py
CHANGED

@@ -280,5 +280,11 @@ class ExtractionOptions(BaseModel):
         description="Use embedding similarity for predicate deduplication"
     )
 
+    # Verbose logging
+    verbose: bool = Field(
+        default=False,
+        description="Enable verbose logging for debugging"
+    )
+
     class Config:
         arbitrary_types_allowed = True  # Allow Callable type
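Taken together with the cli.py change above (which passes `verbose=verbose` into the options), the new field means verbosity can be requested per extraction from the Python API as well as via `-v` on the command line — a usage sketch, assuming only the names shown elsewhere in this diff:

```python
import logging

from statement_extractor import StatementExtractor
from statement_extractor.models import ExtractionOptions

logging.basicConfig(level=logging.DEBUG)  # surface the pipeline's debug logs

extractor = StatementExtractor()  # device auto-detected: cuda > mps > cpu
options = ExtractionOptions(verbose=True)

result = extractor.extract("Apple announced a new iPhone.", options)
for stmt in result.statements:
    print(stmt.subject.text, stmt.predicate, stmt.object.text)
```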
statement_extractor/predicate_comparer.py
CHANGED

@@ -83,7 +83,12 @@ class PredicateComparer:
         # Auto-detect device
         if device is None:
             import torch
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                self.device = "cuda"
+            elif torch.backends.mps.is_available():
+                self.device = "mps"
+            else:
+                self.device = "cpu"
         else:
             self.device = device
 
@@ -289,6 +294,8 @@ class PredicateComparer:
         Returns:
             Deduplicated list of statements (keeps best contextualized match)
         """
+        logger.debug(f"Embedding deduplication: {len(statements)} statements, detect_reversals={detect_reversals}")
+
         if len(statements) <= 1:
             return statements
 
@@ -297,27 +304,33 @@ class PredicateComparer:
                 return entity_canonicalizer(text)
             return text.lower().strip()
 
+        logger.debug("  Computing predicate embeddings...")
         # Compute all predicate embeddings at once for efficiency
         predicates = [s.predicate for s in statements]
         pred_embeddings = self._compute_embeddings(predicates)
+        logger.debug(f"  Computed {len(pred_embeddings)} predicate embeddings")
 
+        logger.debug("  Computing contextualized embeddings (S P O)...")
         # Compute contextualized embeddings: "Subject Predicate Object" for each statement
         contextualized_texts = [
             f"{s.subject.text} {s.predicate} {s.object.text}" for s in statements
         ]
         contextualized_embeddings = self._compute_embeddings(contextualized_texts)
 
+        logger.debug("  Computing reversed embeddings (O P S)...")
         # Compute reversed contextualized embeddings: "Object Predicate Subject"
         reversed_texts = [
             f"{s.object.text} {s.predicate} {s.subject.text}" for s in statements
         ]
         reversed_embeddings = self._compute_embeddings(reversed_texts)
 
+        logger.debug("  Computing source text embeddings...")
         # Compute source text embeddings for scoring which duplicate to keep
         source_embeddings = []
         for stmt in statements:
             source_text = stmt.source_text or f"{stmt.subject.text} {stmt.predicate} {stmt.object.text}"
             source_embeddings.append(self._compute_embeddings([source_text])[0])
+        logger.debug("  All embeddings computed, starting comparison loop...")
 
         unique_statements: list[Statement] = []
         unique_pred_embeddings: list[np.ndarray] = []
@@ -358,9 +371,17 @@ class PredicateComparer:
             if similarity >= self.config.dedup_threshold:
                 duplicate_idx = j
                 is_reversed_match = reversed_match and not direct_match
+                match_type = "reversed" if is_reversed_match else "direct"
+                logger.debug(
+                    f"  [{i}] DUPLICATE of [{unique_indices[j]}] ({match_type}, sim={similarity:.3f}): "
+                    f"'{stmt.subject.text}' --[{stmt.predicate}]--> '{stmt.object.text}'"
+                )
                 break
 
         if duplicate_idx is None:
+            logger.debug(
+                f"  [{i}] UNIQUE: '{stmt.subject.text}' --[{stmt.predicate}]--> '{stmt.object.text}'"
+            )
             # Not a duplicate - add to unique list
             unique_statements.append(stmt)
             unique_pred_embeddings.append(pred_embeddings[i])
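The loop above is a greedy nearest-duplicate scan: each statement's contextualized "S P O" embedding is compared against the already-accepted unique statements, and the reversed "O P S" embedding lets it catch the same fact stated in the opposite direction. A self-contained sketch of that idea using sentence-transformers and cosine similarity (the model choice and 0.85 threshold are illustrative, not the package's configuration):

```python
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative model choice

def cos(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

triples = [
    ("Apple", "announced", "a new iPhone"),
    ("a new iPhone", "was announced by", "Apple"),  # reversed restatement
    ("Apple", "hired", "a new CEO"),
]
forward = model.encode([f"{s} {p} {o}" for s, p, o in triples])    # "S P O"
backward = model.encode([f"{o} {p} {s}" for s, p, o in triples])   # "O P S"

unique: list[int] = []
for i in range(len(triples)):
    # Duplicate if it matches an accepted triple directly OR in reverse
    dup = any(
        max(cos(forward[i], forward[j]), cos(backward[i], forward[j])) >= 0.85
        for j in unique
    )
    if not dup:
        unique.append(i)

# With a well-tuned threshold the reversed restatement is dropped
print([triples[i] for i in unique])
```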
@@ -451,6 +472,7 @@ class PredicateComparer:
                 merged_stmt = existing_stmt.merge_entity_types_from(stmt)
                 unique_statements[duplicate_idx] = merged_stmt
 
+        logger.debug(f"  Deduplication complete: {len(statements)} -> {len(unique_statements)} statements")
         return unique_statements
 
     def normalize_predicates(
statement_extractor/scoring.py
CHANGED

@@ -6,10 +6,13 @@ Provides:
 - BeamScorer: Score and select/merge beams based on quality metrics
 """
 
+import logging
 from typing import Optional
 
 from .models import ScoringConfig, Statement
 
+logger = logging.getLogger(__name__)
+
 
 class TripleScorer:
     """
@@ -32,6 +35,7 @@ class TripleScorer:
         Higher scores indicate better grounding in source text.
         """
         if not source_text:
+            logger.debug(f"  No source text, returning neutral score 0.5")
             return 0.5  # Neutral score if no source text
 
         score = 0.0
@@ -53,6 +57,7 @@ class TripleScorer:
             weights_sum += 0.2
 
         # Check proximity - subject and object in same/nearby region (weight: 0.2)
+        proximity_score = 0.0
         if subject_found and object_found:
             proximity_score = self._compute_proximity(
                 statement.subject.text,
@@ -62,7 +67,14 @@ class TripleScorer:
             score += 0.2 * proximity_score
             weights_sum += 0.2
 
-        return score / weights_sum if weights_sum > 0 else 0.0
+        final_score = score / weights_sum if weights_sum > 0 else 0.0
+
+        logger.debug(
+            f"  Score for '{statement.subject.text}' --[{statement.predicate}]--> '{statement.object.text}': "
+            f"{final_score:.2f} (subj={subject_found}, obj={object_found}, pred={predicate_grounded}, prox={proximity_score:.2f})"
+        )
+
+        return final_score
 
     def find_evidence_span(
         self,
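The scoring pattern here is a weighted average over whichever checks actually ran: each check adds `weight * check_score` to `score` and `weight` to `weights_sum`, and the final division renormalizes so a statement is not punished for checks that were skipped. A toy sketch of the same pattern — only the 0.2 predicate and proximity weights appear in this diff, so the 0.3 subject/object weights below are assumptions:

```python
def grounding_score(subject_found: bool, object_found: bool,
                    predicate_grounded: bool, proximity: float) -> float:
    """Weighted average of grounding checks, renormalized by the weights used."""
    score = weights_sum = 0.0
    score += 0.3 * subject_found                     # assumed weight
    weights_sum += 0.3
    score += 0.3 * object_found                      # assumed weight
    weights_sum += 0.3
    score += 0.2 * predicate_grounded                # weight shown in diff
    weights_sum += 0.2
    if subject_found and object_found:               # proximity only when both found
        score += 0.2 * proximity
        weights_sum += 0.2
    return score / weights_sum if weights_sum > 0 else 0.0

# Both entities found, predicate grounded, entities fairly close together:
print(round(grounding_score(True, True, True, proximity=0.8), 2))  # 0.96
```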
@@ -347,10 +359,12 @@ class BeamScorer:
             return []
 
         top_n = top_n or self.config.merge_top_n
+        logger.debug(f"Merging beams: {len(candidates)} candidates, selecting top {top_n}")
 
         # Score each beam
         scored_beams = []
-        for beam in candidates:
+        for i, beam in enumerate(candidates):
+            logger.debug(f"  Scoring beam {i} ({len(beam)} statements)...")
             for stmt in beam:
                 if stmt.confidence_score is None:
                     stmt.confidence_score = self.triple_scorer.score_triple(stmt, source_text)
@@ -359,31 +373,36 @@ class BeamScorer:
 
             beam_score = self.score_beam(beam, source_text)
             scored_beams.append((beam_score, beam))
+            logger.debug(f"  Beam {i} score: {beam_score:.3f}")
 
         # Sort and take top N
         scored_beams.sort(key=lambda x: x[0], reverse=True)
         top_beams = [beam for _, beam in scored_beams[:top_n]]
+        logger.debug(f"  Selected top {len(top_beams)} beams")
 
         # Pool all triples
         all_statements: list[Statement] = []
         for beam in top_beams:
             all_statements.extend(beam)
+        logger.debug(f"  Pooled {len(all_statements)} statements from top beams")
 
         # Filter by confidence threshold
         min_conf = self.config.min_confidence
         filtered = [s for s in all_statements if (s.confidence_score or 0) >= min_conf]
+        logger.debug(f"  After confidence filter (>={min_conf}): {len(filtered)} statements")
 
-        # Filter out statements where source_text doesn't support the predicate
-        # This catches model hallucinations where predicate doesn't match the evidence
-        consistent = [
-            s for s in filtered
-            if self._source_text_supports_predicate(s)
-        ]
+        # # Filter out statements where source_text doesn't support the predicate
+        # # This catches model hallucinations where predicate doesn't match the evidence
+        # consistent = [
+        #     s for s in filtered
+        #     if self._source_text_supports_predicate(s)
+        # ]
+        # logger.debug(f"  After predicate consistency filter: {len(consistent)} statements")
 
         # Deduplicate - keep highest confidence for each (subject, predicate, object)
         # Note: Same subject+predicate with different objects is valid (e.g., "Apple announced X and Y")
         seen: dict[tuple[str, str, str], Statement] = {}
-        for stmt in consistent:
+        for stmt in all_statements:
             key = (
                 stmt.subject.text.lower(),
                 stmt.predicate.lower(),
@@ -392,7 +411,10 @@ class BeamScorer:
             if key not in seen or (stmt.confidence_score or 0) > (seen[key].confidence_score or 0):
                 seen[key] = stmt
 
-        return list(seen.values())
+        result = list(seen.values())
+        logger.debug(f"  After deduplication: {len(result)} unique statements")
+
+        return result
 
     def _source_text_supports_predicate(self, stmt: Statement) -> bool:
         """
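As a whole, `merge_beams` is: score every beam, keep the top N, pool their statements, then collapse exact (subject, predicate, object) duplicates keeping the highest-confidence copy. A compact sketch of that pool-and-dedup step, with plain tuples standing in for the package's Statement objects and average confidence as a stand-in beam score:

```python
# Each beam is a list of (subject, predicate, object, confidence) tuples.
beams = [
    [("Apple", "announced", "a new iPhone", 0.9)],
    [("Apple", "announced", "a new iPhone", 0.7),
     ("Apple", "hired", "a new CEO", 0.8)],
]

def beam_score(beam):  # stand-in metric: average confidence
    return sum(conf for *_, conf in beam) / len(beam)

# 1. Score beams and keep the top N
top_n = 2
top = sorted(beams, key=beam_score, reverse=True)[:top_n]

# 2. Pool statements, then keep the best copy per (s, p, o) key
seen = {}
for s, p, o, conf in (stmt for beam in top for stmt in beam):
    key = (s.lower(), p.lower(), o.lower())
    if key not in seen or conf > seen[key][3]:
        seen[key] = (s, p, o, conf)

print(list(seen.values()))
# [('Apple', 'announced', 'a new iPhone', 0.9), ('Apple', 'hired', 'a new CEO', 0.8)]
```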
corp_extractor-0.2.5.dist-info/RECORD
DELETED

@@ -1,11 +0,0 @@
-statement_extractor/__init__.py,sha256=MIZgn-lD9-XGJapzdyYxMhEJFRrTzftbRklrhwA4e8w,2967
-statement_extractor/canonicalization.py,sha256=ZMLs6RLWJa_rOJ8XZ7PoHFU13-zeJkOMDnvK-ZaFa5s,5991
-statement_extractor/cli.py,sha256=kJnZm_mbq4np1vTxSjczMZM5zGuDlC8Z5xLJd8O3xZ4,7605
-statement_extractor/extractor.py,sha256=PX0SiJnYUnh06seyH5W77FcPpcvLXwEM8IGsuVuRh0Q,22158
-statement_extractor/models.py,sha256=xDF3pDPhIiqiMwFMPV94aBEgZGbSe-x2TkshahOiCog,10739
-statement_extractor/predicate_comparer.py,sha256=iwBfNJFNOFv8ODKN9F9EtmknpCeSThOpnu6P_PJSmgE,24898
-statement_extractor/scoring.py,sha256=Wa1BW6jXtHD7dZkUXwdwE39hwFo2ko6BuIogBc4E2Lk,14493
-corp_extractor-0.2.5.dist-info/METADATA,sha256=iN_MPbqHhizaFAGJKzR5JNSbDivrS133oSTiYWrFht4,13552
-corp_extractor-0.2.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-corp_extractor-0.2.5.dist-info/entry_points.txt,sha256=i0iKFqPIusvb-QTQ1zNnFgAqatgVah-jIhahbs5TToQ,115
-corp_extractor-0.2.5.dist-info/RECORD,,
{corp_extractor-0.2.5.dist-info → corp_extractor-0.2.11.dist-info}/WHEEL
File without changes

{corp_extractor-0.2.5.dist-info → corp_extractor-0.2.11.dist-info}/entry_points.txt
File without changes