corp-extractor 0.2.5__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.2.5.dist-info → corp_extractor-0.3.0.dist-info}/METADATA +115 -22
- corp_extractor-0.3.0.dist-info/RECORD +12 -0
- statement_extractor/__init__.py +3 -1
- statement_extractor/cli.py +41 -1
- statement_extractor/extractor.py +381 -26
- statement_extractor/models.py +33 -1
- statement_extractor/predicate_comparer.py +23 -1
- statement_extractor/scoring.py +189 -97
- statement_extractor/spacy_extraction.py +386 -0
- corp_extractor-0.2.5.dist-info/RECORD +0 -11
- {corp_extractor-0.2.5.dist-info → corp_extractor-0.3.0.dist-info}/WHEEL +0 -0
- {corp_extractor-0.2.5.dist-info → corp_extractor-0.3.0.dist-info}/entry_points.txt +0 -0
statement_extractor/extractor.py
CHANGED
@@ -14,11 +14,12 @@ import xml.etree.ElementTree as ET
 from typing import Optional
 
 import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
 
 from .models import (
     Entity,
     EntityType,
+    ExtractionMethod,
     ExtractionOptions,
     ExtractionResult,
     PredicateComparisonConfig,
@@ -29,6 +30,114 @@ from .models import (
 
 logger = logging.getLogger(__name__)
 
+
+class StopOnSequence(StoppingCriteria):
+    """
+    Stop generation when a specific multi-token sequence is generated.
+
+    Decodes the generated tokens and checks if the stop sequence appears.
+    Works with sequences that span multiple tokens (e.g., "</statements>").
+    """
+
+    def __init__(self, tokenizer, stop_sequence: str, input_length: int):
+        self.tokenizer = tokenizer
+        self.stop_sequence = stop_sequence
+        self.input_length = input_length
+        # Track which beams have stopped (for batch generation)
+        self.stopped = set()
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        # Check each sequence in the batch
+        for idx, seq in enumerate(input_ids):
+            if idx in self.stopped:
+                continue
+            # Only decode the generated portion (after input)
+            generated = seq[self.input_length:]
+            decoded = self.tokenizer.decode(generated, skip_special_tokens=True)
+            if self.stop_sequence in decoded:
+                self.stopped.add(idx)
+
+        # Stop when ALL sequences have the stop sequence
+        return len(self.stopped) >= len(input_ids)
+
+
+def repair_xml(xml_string: str) -> tuple[str, list[str]]:
+    """
+    Attempt to repair common XML syntax errors.
+
+    Returns:
+        Tuple of (repaired_xml, list_of_repairs_made)
+    """
+    repairs = []
+    original = xml_string
+
+    # 1. Fix unescaped ampersands (but not already escaped entities)
+    # Match & not followed by amp; lt; gt; quot; apos; or #
+    ampersand_pattern = r'&(?!(amp|lt|gt|quot|apos|#\d+|#x[0-9a-fA-F]+);)'
+    if re.search(ampersand_pattern, xml_string):
+        xml_string = re.sub(ampersand_pattern, '&amp;', xml_string)
+        repairs.append("escaped unescaped ampersands")
+
+    # 2. Fix unescaped < and > inside text content (not tags)
+    # This is tricky - we need to be careful not to break actual tags
+    # For now, just handle the most common case: < followed by space or lowercase
+    less_than_pattern = r'<(?=\s|[a-z]{2,}[^a-z/>])'
+    if re.search(less_than_pattern, xml_string):
+        xml_string = re.sub(less_than_pattern, '&lt;', xml_string)
+        repairs.append("escaped unescaped less-than signs")
+
+    # 3. Fix truncated closing tags (e.g., "</statemen" -> try to complete)
+    truncated_patterns = [
+        (r'</statement[^s>]*$', '</statements>'),
+        (r'</stm[^t>]*$', '</stmt>'),
+        (r'</subjec[^t>]*$', '</subject>'),
+        (r'</objec[^t>]*$', '</object>'),
+        (r'</predica[^t>]*$', '</predicate>'),
+        (r'</tex[^t>]*$', '</text>'),
+    ]
+    for pattern, replacement in truncated_patterns:
+        if re.search(pattern, xml_string):
+            xml_string = re.sub(pattern, replacement, xml_string)
+            repairs.append(f"completed truncated tag: {replacement}")
+
+    # 4. Add missing </statements> if we have <statements> but no closing
+    if '<statements>' in xml_string and '</statements>' not in xml_string:
+        # Try to find a good place to add it
+        # Look for the last complete </stmt> and add after it
+        last_stmt = xml_string.rfind('</stmt>')
+        if last_stmt != -1:
+            insert_pos = last_stmt + len('</stmt>')
+            xml_string = xml_string[:insert_pos] + '</statements>'
+            repairs.append("added missing </statements> after last </stmt>")
+        else:
+            xml_string = xml_string + '</statements>'
+            repairs.append("added missing </statements> at end")
+
+    # 5. Fix unclosed <stmt> tags - find <stmt> without matching </stmt>
+    # Count opens and closes
+    open_stmts = len(re.findall(r'<stmt>', xml_string))
+    close_stmts = len(re.findall(r'</stmt>', xml_string))
+    if open_stmts > close_stmts:
+        # Find incomplete statement blocks and try to close them
+        # Look for patterns like <stmt>...<subject>...</subject> without </stmt>
+        # This is complex, so just add closing tags before </statements>
+        missing = open_stmts - close_stmts
+        if '</statements>' in xml_string:
+            xml_string = xml_string.replace('</statements>', '</stmt>' * missing + '</statements>')
+            repairs.append(f"added {missing} missing </stmt> tag(s)")
+
+    # 6. Remove any content after </statements>
+    end_pos = xml_string.find('</statements>')
+    if end_pos != -1:
+        end_pos += len('</statements>')
+        if end_pos < len(xml_string):
+            xml_string = xml_string[:end_pos]
+            repairs.append("removed content after </statements>")
+
+    if xml_string != original:
+        return xml_string, repairs
+    return xml_string, []
+
 # Default model
 DEFAULT_MODEL_ID = "Corp-o-Rate-Community/statement-extractor"
 
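A quick exercise of the new repair path (a usage sketch; the broken XML is illustrative):

```python
from statement_extractor.extractor import repair_xml

# A typical failure mode: an unescaped ampersand plus a cut-off closing tag.
broken = "<statements><stmt><subject>Procter & Gamble</subject></stm"

fixed, repairs = repair_xml(broken)
print(fixed)
# <statements><stmt><subject>Procter &amp; Gamble</subject></stmt></statements>
print(repairs)
# ['escaped unescaped ampersands', 'completed truncated tag: </stmt>',
#  'added missing </statements> after last </stmt>']
```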
@@ -80,11 +189,16 @@ class StatementExtractor:
 
         # Auto-detect device
         if device is None:
-
+            if torch.cuda.is_available():
+                self.device = "cuda"
+            elif torch.backends.mps.is_available():
+                self.device = "mps"
+            else:
+                self.device = "cpu"
         else:
             self.device = device
 
-        # Auto-detect dtype
+        # Auto-detect dtype (bfloat16 only for CUDA, float32 for MPS/CPU)
         if torch_dtype is None:
             self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
         else:
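The detection order (CUDA, then Apple MPS, then CPU) uses plain PyTorch capability checks; a standalone sketch of the same logic:

```python
import torch

def detect_device() -> str:
    """Same preference order the extractor now uses: cuda -> mps -> cpu."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

device = detect_device()
# bfloat16 stays CUDA-only; MPS and CPU get float32, matching the dtype logic above.
dtype = torch.bfloat16 if device == "cuda" else torch.float32
```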
@@ -175,6 +289,14 @@ class StatementExtractor:
         if options is None:
             options = ExtractionOptions()
 
+        logger.debug("=" * 60)
+        logger.debug("EXTRACTION STARTED")
+        logger.debug("=" * 60)
+        logger.debug(f"Input text length: {len(text)} chars")
+        logger.debug(f"Options: num_beams={options.num_beams}, diversity={options.diversity_penalty}")
+        logger.debug(f"  merge_beams={options.merge_beams}, embedding_dedup={options.embedding_dedup}")
+        logger.debug(f"  deduplicate={options.deduplicate}, max_new_tokens={options.max_new_tokens}")
+
         # Store original text for scoring
         original_text = text
 
@@ -185,6 +307,10 @@ class StatementExtractor:
         # Run extraction with retry logic
         statements = self._extract_with_scoring(text, original_text, options)
 
+        logger.debug("=" * 60)
+        logger.debug(f"EXTRACTION COMPLETE: {len(statements)} statements")
+        logger.debug("=" * 60)
+
         return ExtractionResult(
             statements=statements,
             source_text=original_text,
@@ -270,6 +396,10 @@ class StatementExtractor:
         4. Merges top beams or selects best beam
         5. Deduplicates using embeddings (if enabled)
         """
+        logger.debug("-" * 40)
+        logger.debug("PHASE 1: Tokenization")
+        logger.debug("-" * 40)
+
         # Tokenize input
         inputs = self.tokenizer(
             text,
@@ -278,48 +408,78 @@ class StatementExtractor:
             truncation=True,
         ).to(self.device)
 
+        input_ids = inputs["input_ids"]
+        logger.debug(f"Tokenized: {input_ids.shape[1]} tokens")
+
         # Count sentences for quality check
         num_sentences = self._count_sentences(text)
         min_expected = int(num_sentences * options.min_statement_ratio)
 
-        logger.
+        logger.debug(f"Input has ~{num_sentences} sentences, min expected: {min_expected}")
 
         # Get beam scorer
         beam_scorer = self._get_beam_scorer(options)
 
+        logger.debug("-" * 40)
+        logger.debug("PHASE 2: Diverse Beam Search Generation")
+        logger.debug("-" * 40)
+
         all_candidates: list[list[Statement]] = []
 
         for attempt in range(options.max_attempts):
+            logger.debug(f"Attempt {attempt + 1}/{options.max_attempts}: Generating {options.num_beams} beams...")
+
             # Generate candidate beams
             candidates = self._generate_candidate_beams(inputs, options)
+            logger.debug(f"  Generated {len(candidates)} valid XML outputs")
 
             # Parse each candidate to statements
             parsed_candidates = []
-            for xml_output in candidates:
-                statements = self._parse_xml_to_statements(xml_output)
+            for i, xml_output in enumerate(candidates):
+                statements = self._parse_xml_to_statements(xml_output, options)
                 if statements:
                     parsed_candidates.append(statements)
+                    logger.debug(f"  Beam {i}: {len(statements)} statements parsed")
+                else:
+                    logger.warning(f"  Beam {i}: 0 statements (parse failed)")
+                    logger.warning(f"  Beam {i} XML output:\n{xml_output}")
 
             all_candidates.extend(parsed_candidates)
 
             # Check if we have enough statements
             total_stmts = sum(len(c) for c in parsed_candidates)
-            logger.
+            logger.debug(f"  Total: {len(parsed_candidates)} beams, {total_stmts} statements")
 
             if total_stmts >= min_expected:
+                logger.debug(f"  Sufficient statements ({total_stmts} >= {min_expected}), stopping")
                 break
 
         if not all_candidates:
+            logger.debug("No valid candidates generated, returning empty result")
            return []
 
+        logger.debug("-" * 40)
+        logger.debug("PHASE 3: Beam Selection/Merging")
+        logger.debug("-" * 40)
+
         # Select or merge beams
         if options.merge_beams:
+            logger.debug(f"Merging {len(all_candidates)} beams...")
             statements = beam_scorer.merge_beams(all_candidates, original_text)
+            logger.debug(f"  After merge: {len(statements)} statements")
         else:
+            logger.debug(f"Selecting best beam from {len(all_candidates)} candidates...")
             statements = beam_scorer.select_best_beam(all_candidates, original_text)
+            logger.debug(f"  Selected beam has {len(statements)} statements")
+
+        logger.debug("-" * 40)
+        logger.debug("PHASE 4: Deduplication")
+        logger.debug("-" * 40)
 
         # Apply embedding-based deduplication if enabled
         if options.embedding_dedup and options.deduplicate:
+            logger.debug("Using embedding-based deduplication...")
+            pre_dedup_count = len(statements)
             try:
                 comparer = self._get_predicate_comparer(options)
                 if comparer:
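The PHASE 1-5 markers above are emitted through the standard `logging` module under the `statement_extractor` namespace, so callers can surface the whole trace without touching library code (a minimal sketch):

```python
import logging

# Surface the extractor's PHASE 1-5 debug trace on stderr.
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("statement_extractor").setLevel(logging.DEBUG)
```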
@@ -327,14 +487,41 @@ class StatementExtractor:
                         statements,
                         entity_canonicalizer=options.entity_canonicalizer
                     )
+                    logger.debug(f"  After embedding dedup: {len(statements)} statements (removed {pre_dedup_count - len(statements)})")
+
                     # Also normalize predicates if taxonomy provided
                     if options.predicate_taxonomy or self._predicate_taxonomy:
+                        logger.debug("Normalizing predicates to taxonomy...")
                         statements = comparer.normalize_predicates(statements)
             except Exception as e:
                 logger.warning(f"Embedding deduplication failed, falling back to exact match: {e}")
                 statements = self._deduplicate_statements_exact(statements, options)
+                logger.debug(f"  After exact dedup: {len(statements)} statements")
         elif options.deduplicate:
+            logger.debug("Using exact text deduplication...")
+            pre_dedup_count = len(statements)
             statements = self._deduplicate_statements_exact(statements, options)
+            logger.debug(f"  After exact dedup: {len(statements)} statements (removed {pre_dedup_count - len(statements)})")
+        else:
+            logger.debug("Deduplication disabled")
+
+        # Select best triple per source text (unless all_triples enabled)
+        if not options.all_triples:
+            logger.debug("-" * 40)
+            logger.debug("PHASE 5: Best Triple Selection")
+            logger.debug("-" * 40)
+            pre_select_count = len(statements)
+            statements = self._select_best_per_source(statements)
+            logger.debug(f"  Selected best per source: {len(statements)} statements (from {pre_select_count})")
+
+        # Log final statements
+        logger.debug("-" * 40)
+        logger.debug("FINAL STATEMENTS:")
+        logger.debug("-" * 40)
+        for i, stmt in enumerate(statements):
+            conf = f" (conf={stmt.confidence_score:.2f})" if stmt.confidence_score else ""
+            canonical = f" -> {stmt.canonical_predicate}" if stmt.canonical_predicate else ""
+            logger.debug(f"  {i+1}. {stmt.subject.text} --[{stmt.predicate}{canonical}]--> {stmt.object.text}{conf}")
 
         return statements
 
@@ -346,16 +533,29 @@ class StatementExtractor:
         """Generate multiple candidate beams using diverse beam search."""
         num_seqs = options.num_beams
 
+        # Create stopping criteria to stop when </statements> is generated
+        input_length = inputs["input_ids"].shape[1]
+        stop_criteria = StopOnSequence(
+            tokenizer=self.tokenizer,
+            stop_sequence="</statements>",
+            input_length=input_length,
+        )
+
         with torch.no_grad():
             outputs = self.model.generate(
                 **inputs,
                 max_new_tokens=options.max_new_tokens,
+                max_length=None,  # Override model default, use max_new_tokens only
                 num_beams=num_seqs,
                 num_beam_groups=num_seqs,
                 num_return_sequences=num_seqs,
                 diversity_penalty=options.diversity_penalty,
                 do_sample=False,
+                top_p=None,  # Override model config to suppress warning
+                top_k=None,  # Override model config to suppress warning
                 trust_remote_code=True,
+                custom_generate="transformers-community/group-beam-search",
+                stopping_criteria=StoppingCriteriaList([stop_criteria]),
             )
 
         # Decode and process candidates
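The criterion can also be exercised in isolation to confirm its all-beams semantics (a sketch; `gpt2` is just a convenient tokenizer that round-trips text, and `input_length=0` stands in for the prompt length):

```python
from transformers import AutoTokenizer
from statement_extractor.extractor import StopOnSequence

tokenizer = AutoTokenizer.from_pretrained("gpt2")

generated = "<statements><stmt><subject>ACME</subject></stmt></statements>"
ids = tokenizer(generated, return_tensors="pt")["input_ids"]

stop = StopOnSequence(tokenizer=tokenizer, stop_sequence="</statements>", input_length=0)

# Decodes the generated slice of every sequence in the batch; returns True only
# once ALL sequences contain the stop string, so beam search keeps running
# until every beam has emitted </statements>.
print(stop(ids, scores=None))  # True
```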
@@ -439,6 +639,50 @@ class StatementExtractor:
                 entity_canonicalizer=options.entity_canonicalizer
             )
 
+    def _select_best_per_source(
+        self,
+        statements: list[Statement],
+    ) -> list[Statement]:
+        """
+        Select the highest-scoring triple for each unique source text.
+
+        Groups statements by source_text and keeps only the one with
+        the highest confidence_score from each group.
+
+        Statements without source_text are kept as-is.
+        """
+        if not statements:
+            return statements
+
+        # Group by source_text
+        from collections import defaultdict
+        groups: dict[str | None, list[Statement]] = defaultdict(list)
+
+        for stmt in statements:
+            groups[stmt.source_text].append(stmt)
+
+        # Select best from each group
+        result: list[Statement] = []
+
+        for source_text, group in groups.items():
+            if source_text is None or len(group) == 1:
+                # No source text or only one statement - keep as-is
+                result.extend(group)
+            else:
+                # Multiple candidates for same source - select best
+                best = max(
+                    group,
+                    key=lambda s: s.confidence_score if s.confidence_score is not None else 0.0
+                )
+                logger.debug(
+                    f"  Selected best for source '{source_text[:40]}...': "
+                    f"'{best.subject.text}' --[{best.predicate}]--> '{best.object.text}' "
+                    f"(score={best.confidence_score:.2f}, method={best.extraction_method.value})"
+                )
+                result.append(best)
+
+        return result
+
     def _deduplicate_xml(self, xml_output: str) -> str:
         """Remove duplicate <stmt> blocks from XML output (legacy method)."""
         try:
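The selection rule is a plain group-by-source followed by an arg-max on confidence; the same logic on toy tuples (a standalone illustration, not the package API):

```python
from collections import defaultdict

# Toy stand-ins for Statement objects: (source_text, triple, confidence_score).
candidates = [
    ("Acme acquired Beta.", ("Acme", "acquired", "Beta"), 0.91),
    ("Acme acquired Beta.", ("Acme acquired", "is", "Beta"), 0.40),
    ("Beta makes widgets.", ("Beta", "makes", "widgets"), 0.77),
]

groups = defaultdict(list)
for source, triple, score in candidates:
    groups[source].append((source, triple, score))

best_per_source = [
    max(group, key=lambda c: c[2] if c[2] is not None else 0.0)
    for group in groups.values()
]
print([c[1] for c in best_per_source])
# [('Acme', 'acquired', 'Beta'), ('Beta', 'makes', 'widgets')]
```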
@@ -468,48 +712,159 @@ class StatementExtractor:
 
         return ET.tostring(new_root, encoding='unicode')
 
-    def _parse_xml_to_statements(
-
+    def _parse_xml_to_statements(
+        self,
+        xml_output: str,
+        options: Optional[ExtractionOptions] = None,
+    ) -> list[Statement]:
+        """
+        Parse XML output into Statement objects.
+
+        Uses model for subject, object, entity types, and source_text.
+        Always uses spaCy for predicate extraction (model predicates are unreliable).
+
+        Produces two candidates for each statement:
+        1. Hybrid: model subject/object + spaCy predicate
+        2. spaCy-only: all components from spaCy
+
+        Both go into the candidate pool; scoring/dedup picks the best.
+        """
         statements: list[Statement] = []
+        use_spacy_extraction = options.use_spacy_extraction if options else True
 
         try:
             root = ET.fromstring(xml_output)
         except ET.ParseError as e:
             # Log full output for debugging
-            logger.
-            logger.
-
+            logger.debug(f"Initial XML parse failed: {e}")
+            logger.debug(f"Raw XML output ({len(xml_output)} chars):\n{xml_output}")
+
+            # Try to repair the XML
+            repaired_xml, repairs = repair_xml(xml_output)
+            if repairs:
+                logger.debug(f"Attempted XML repairs: {', '.join(repairs)}")
+                try:
+                    root = ET.fromstring(repaired_xml)
+                    logger.info(f"XML repair successful, parsing repaired output")
+                except ET.ParseError as e2:
+                    logger.warning(f"XML repair failed, still cannot parse: {e2}")
+                    logger.warning(f"Repaired XML ({len(repaired_xml)} chars):\n{repaired_xml}")
+                    return statements
+            else:
+                logger.warning(f"No repairs possible for XML output")
+                logger.warning(f"Full XML output ({len(xml_output)} chars):\n{xml_output}")
+                return statements
 
         if root.tag != 'statements':
+            logger.warning(f"Root tag is '{root.tag}', expected 'statements'")
             return statements
 
         for stmt_elem in root.findall('stmt'):
             try:
-                # Parse subject
+                # Parse subject from model
                 subject_elem = stmt_elem.find('subject')
                 subject_text = subject_elem.text.strip() if subject_elem is not None and subject_elem.text else ""
                 subject_type = self._parse_entity_type(subject_elem.get('type') if subject_elem is not None else None)
 
-                # Parse object
+                # Parse object from model
                 object_elem = stmt_elem.find('object')
                 object_text = object_elem.text.strip() if object_elem is not None and object_elem.text else ""
                 object_type = self._parse_entity_type(object_elem.get('type') if object_elem is not None else None)
 
-                # Parse
-                predicate_elem = stmt_elem.find('predicate')
-                predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
-
-                # Parse source text
+                # Parse source text from model
                 text_elem = stmt_elem.find('text')
                 source_text = text_elem.text.strip() if text_elem is not None and text_elem.text else None
 
-                if
-
-
-
-
-
-
+                # Skip if missing required components from model
+                if not subject_text or not object_text:
+                    logger.debug(f"Skipping statement: missing subject or object from model")
+                    continue
+
+                if use_spacy_extraction and source_text:
+                    try:
+                        from .spacy_extraction import extract_triple_from_text, extract_triple_by_predicate_split
+                        spacy_result = extract_triple_from_text(
+                            source_text=source_text,
+                            model_subject=subject_text,
+                            model_object=object_text,
+                            model_predicate="",  # Don't pass model predicate
+                        )
+                        if spacy_result:
+                            spacy_subj, spacy_pred, spacy_obj = spacy_result
+
+                            if spacy_pred:
+                                # Candidate 1: Hybrid (model subject/object + spaCy predicate)
+                                logger.debug(
+                                    f"Adding hybrid candidate: '{subject_text}' --[{spacy_pred}]--> '{object_text}'"
+                                )
+                                statements.append(Statement(
+                                    subject=Entity(text=subject_text, type=subject_type),
+                                    predicate=spacy_pred,
+                                    object=Entity(text=object_text, type=object_type),
+                                    source_text=source_text,
+                                    extraction_method=ExtractionMethod.HYBRID,
+                                ))
+
+                                # Candidate 2: spaCy-only (if different from hybrid)
+                                if spacy_subj and spacy_obj:
+                                    is_different = (spacy_subj != subject_text or spacy_obj != object_text)
+                                    if is_different:
+                                        logger.debug(
+                                            f"Adding spaCy-only candidate: '{spacy_subj}' --[{spacy_pred}]--> '{spacy_obj}'"
+                                        )
+                                        statements.append(Statement(
+                                            subject=Entity(text=spacy_subj, type=subject_type),
+                                            predicate=spacy_pred,
+                                            object=Entity(text=spacy_obj, type=object_type),
+                                            source_text=source_text,
+                                            extraction_method=ExtractionMethod.SPACY,
+                                        ))
+
+                                # Candidate 3: Predicate-split (split source text around predicate)
+                                split_result = extract_triple_by_predicate_split(
+                                    source_text=source_text,
+                                    predicate=spacy_pred,
+                                )
+                                if split_result:
+                                    split_subj, split_pred, split_obj = split_result
+                                    # Only add if different from previous candidates
+                                    is_different_from_hybrid = (split_subj != subject_text or split_obj != object_text)
+                                    is_different_from_spacy = (split_subj != spacy_subj or split_obj != spacy_obj)
+                                    if is_different_from_hybrid and is_different_from_spacy:
+                                        logger.debug(
+                                            f"Adding predicate-split candidate: '{split_subj}' --[{split_pred}]--> '{split_obj}'"
+                                        )
+                                        statements.append(Statement(
+                                            subject=Entity(text=split_subj, type=subject_type),
+                                            predicate=split_pred,
+                                            object=Entity(text=split_obj, type=object_type),
+                                            source_text=source_text,
+                                            extraction_method=ExtractionMethod.SPLIT,
+                                        ))
+                            else:
+                                logger.debug(
+                                    f"spaCy found no predicate for: '{subject_text}' --> '{object_text}'"
+                                )
+                    except Exception as e:
+                        logger.debug(f"spaCy extraction failed: {e}")
+                else:
+                    # spaCy disabled - fall back to model predicate
+                    predicate_elem = stmt_elem.find('predicate')
+                    model_predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
+
+                    if model_predicate:
+                        statements.append(Statement(
+                            subject=Entity(text=subject_text, type=subject_type),
+                            predicate=model_predicate,
+                            object=Entity(text=object_text, type=object_type),
+                            source_text=source_text,
+                            extraction_method=ExtractionMethod.MODEL,
+                        ))
+                    else:
+                        logger.debug(
+                            f"Skipping statement (no predicate, spaCy disabled): "
+                            f"'{subject_text}' --> '{object_text}'"
+                        )
             except Exception as e:
                 logger.warning(f"Failed to parse statement: {e}")
                 continue
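For reference, the parser implies this schema for the model output: a `<statements>` root, one `<stmt>` per extraction with typed `<subject>`/`<object>`, a `<predicate>`, and the originating `<text>`. A hand-written instance (the concrete values and the ORG type are illustrative):

```python
import xml.etree.ElementTree as ET

xml_output = """
<statements>
  <stmt>
    <subject type="ORG">Acme Corp</subject>
    <predicate>acquired</predicate>
    <object type="ORG">Beta Inc</object>
    <text>Acme Corp acquired Beta Inc in 2024.</text>
  </stmt>
</statements>
""".strip()

root = ET.fromstring(xml_output)
for stmt in root.findall("stmt"):
    subject = stmt.find("subject")
    print(subject.text, subject.get("type"))  # Acme Corp ORG
```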
statement_extractor/models.py
CHANGED
@@ -24,6 +24,14 @@ class EntityType(str, Enum):
     UNKNOWN = "UNKNOWN"
 
 
+class ExtractionMethod(str, Enum):
+    """Method used to extract the triple components."""
+    HYBRID = "hybrid"  # Model subject/object + spaCy predicate
+    SPACY = "spacy"    # All components from spaCy dependency parsing
+    SPLIT = "split"    # Subject/object from splitting source text around predicate
+    MODEL = "model"    # All components from T5-Gemma model (when spaCy disabled)
+
+
 class Entity(BaseModel):
     """An entity (subject or object) with its text and type."""
     text: str = Field(..., description="The entity text")
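Because `ExtractionMethod` subclasses `str`, members compare equal to their plain string values, which keeps serialized statements readable (a small check):

```python
from statement_extractor.models import ExtractionMethod

assert ExtractionMethod.HYBRID == "hybrid"
assert ExtractionMethod("split") is ExtractionMethod.SPLIT
print([m.value for m in ExtractionMethod])  # ['hybrid', 'spacy', 'split', 'model']
```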
@@ -52,12 +60,18 @@ class Statement(BaseModel):
     object: Entity = Field(..., description="The object entity")
     source_text: Optional[str] = Field(None, description="The original text this statement was extracted from")
 
+    # Extraction method tracking
+    extraction_method: ExtractionMethod = Field(
+        default=ExtractionMethod.MODEL,
+        description="Method used to extract this triple (hybrid, spacy, split, or model)"
+    )
+
     # Quality scoring fields
     confidence_score: Optional[float] = Field(
         None,
         ge=0.0,
         le=1.0,
-        description="
+        description="Semantic similarity score (0-1) between source text and reassembled triple"
     )
     evidence_span: Optional[tuple[int, int]] = Field(
         None,
@@ -99,6 +113,7 @@ class Statement(BaseModel):
             object=merged_object,
             predicate=self.predicate,
             source_text=self.source_text,
+            extraction_method=self.extraction_method,
             confidence_score=self.confidence_score,
             evidence_span=self.evidence_span,
             canonical_predicate=self.canonical_predicate,
@@ -116,6 +131,7 @@ class Statement(BaseModel):
             object=self.subject,
             predicate=self.predicate,
             source_text=self.source_text,
+            extraction_method=self.extraction_method,
             confidence_score=self.confidence_score,
             evidence_span=self.evidence_span,
             canonical_predicate=self.canonical_predicate,
@@ -279,6 +295,22 @@ class ExtractionOptions(BaseModel):
         default=True,
         description="Use embedding similarity for predicate deduplication"
     )
+    use_spacy_extraction: bool = Field(
+        default=True,
+        description="Use spaCy for predicate/subject/object extraction (model provides structure + coreference)"
+    )
+
+    # Verbose logging
+    verbose: bool = Field(
+        default=False,
+        description="Enable verbose logging for debugging"
+    )
+
+    # Triple selection
+    all_triples: bool = Field(
+        default=False,
+        description="Keep all candidate triples instead of selecting the highest-scoring one per source"
+    )
 
     class Config:
         arbitrary_types_allowed = True  # Allow Callable type
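The new options compose with the existing ones; for example, to keep every candidate triple and fall back to the model's own predicates (a sketch; unspecified fields keep their defaults):

```python
from statement_extractor.models import ExtractionOptions

options = ExtractionOptions(
    use_spacy_extraction=False,  # skip the spaCy predicate pass
    all_triples=True,            # disable PHASE 5 best-per-source selection
    verbose=True,
)
```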