corp-extractor 0.2.11__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- {corp_extractor-0.2.11.dist-info → corp_extractor-0.3.0.dist-info}/METADATA +104 -19
- corp_extractor-0.3.0.dist-info/RECORD +12 -0
- statement_extractor/__init__.py +3 -1
- statement_extractor/cli.py +10 -0
- statement_extractor/extractor.py +305 -22
- statement_extractor/models.py +27 -1
- statement_extractor/scoring.py +160 -90
- statement_extractor/spacy_extraction.py +386 -0
- corp_extractor-0.2.11.dist-info/RECORD +0 -11
- {corp_extractor-0.2.11.dist-info → corp_extractor-0.3.0.dist-info}/WHEEL +0 -0
- {corp_extractor-0.2.11.dist-info → corp_extractor-0.3.0.dist-info}/entry_points.txt +0 -0
statement_extractor/extractor.py
CHANGED
@@ -14,11 +14,12 @@ import xml.etree.ElementTree as ET
 from typing import Optional
 
 import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
 
 from .models import (
     Entity,
     EntityType,
+    ExtractionMethod,
     ExtractionOptions,
     ExtractionResult,
     PredicateComparisonConfig,
@@ -29,6 +30,114 @@ from .models import (
 
 logger = logging.getLogger(__name__)
 
+
+class StopOnSequence(StoppingCriteria):
+    """
+    Stop generation when a specific multi-token sequence is generated.
+
+    Decodes the generated tokens and checks if the stop sequence appears.
+    Works with sequences that span multiple tokens (e.g., "</statements>").
+    """
+
+    def __init__(self, tokenizer, stop_sequence: str, input_length: int):
+        self.tokenizer = tokenizer
+        self.stop_sequence = stop_sequence
+        self.input_length = input_length
+        # Track which beams have stopped (for batch generation)
+        self.stopped = set()
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        # Check each sequence in the batch
+        for idx, seq in enumerate(input_ids):
+            if idx in self.stopped:
+                continue
+            # Only decode the generated portion (after input)
+            generated = seq[self.input_length:]
+            decoded = self.tokenizer.decode(generated, skip_special_tokens=True)
+            if self.stop_sequence in decoded:
+                self.stopped.add(idx)
+
+        # Stop when ALL sequences have the stop sequence
+        return len(self.stopped) >= len(input_ids)
+
+
+def repair_xml(xml_string: str) -> tuple[str, list[str]]:
+    """
+    Attempt to repair common XML syntax errors.
+
+    Returns:
+        Tuple of (repaired_xml, list_of_repairs_made)
+    """
+    repairs = []
+    original = xml_string
+
+    # 1. Fix unescaped ampersands (but not already escaped entities)
+    # Match & not followed by amp; lt; gt; quot; apos; or #
+    ampersand_pattern = r'&(?!(amp|lt|gt|quot|apos|#\d+|#x[0-9a-fA-F]+);)'
+    if re.search(ampersand_pattern, xml_string):
+        xml_string = re.sub(ampersand_pattern, '&amp;', xml_string)
+        repairs.append("escaped unescaped ampersands")
+
+    # 2. Fix unescaped < and > inside text content (not tags)
+    # This is tricky - we need to be careful not to break actual tags
+    # For now, just handle the most common case: < followed by space or lowercase
+    less_than_pattern = r'<(?=\s|[a-z]{2,}[^a-z/>])'
+    if re.search(less_than_pattern, xml_string):
+        xml_string = re.sub(less_than_pattern, '&lt;', xml_string)
+        repairs.append("escaped unescaped less-than signs")
+
+    # 3. Fix truncated closing tags (e.g., "</statemen" -> try to complete)
+    truncated_patterns = [
+        (r'</statement[^s>]*$', '</statements>'),
+        (r'</stm[^t>]*$', '</stmt>'),
+        (r'</subjec[^t>]*$', '</subject>'),
+        (r'</objec[^t>]*$', '</object>'),
+        (r'</predica[^t>]*$', '</predicate>'),
+        (r'</tex[^t>]*$', '</text>'),
+    ]
+    for pattern, replacement in truncated_patterns:
+        if re.search(pattern, xml_string):
+            xml_string = re.sub(pattern, replacement, xml_string)
+            repairs.append(f"completed truncated tag: {replacement}")
+
+    # 4. Add missing </statements> if we have <statements> but no closing
+    if '<statements>' in xml_string and '</statements>' not in xml_string:
+        # Try to find a good place to add it
+        # Look for the last complete </stmt> and add after it
+        last_stmt = xml_string.rfind('</stmt>')
+        if last_stmt != -1:
+            insert_pos = last_stmt + len('</stmt>')
+            xml_string = xml_string[:insert_pos] + '</statements>'
+            repairs.append("added missing </statements> after last </stmt>")
+        else:
+            xml_string = xml_string + '</statements>'
+            repairs.append("added missing </statements> at end")
+
+    # 5. Fix unclosed <stmt> tags - find <stmt> without matching </stmt>
+    # Count opens and closes
+    open_stmts = len(re.findall(r'<stmt>', xml_string))
+    close_stmts = len(re.findall(r'</stmt>', xml_string))
+    if open_stmts > close_stmts:
+        # Find incomplete statement blocks and try to close them
+        # Look for patterns like <stmt>...<subject>...</subject> without </stmt>
+        # This is complex, so just add closing tags before </statements>
+        missing = open_stmts - close_stmts
+        if '</statements>' in xml_string:
+            xml_string = xml_string.replace('</statements>', '</stmt>' * missing + '</statements>')
+            repairs.append(f"added {missing} missing </stmt> tag(s)")
+
+    # 6. Remove any content after </statements>
+    end_pos = xml_string.find('</statements>')
+    if end_pos != -1:
+        end_pos += len('</statements>')
+        if end_pos < len(xml_string):
+            xml_string = xml_string[:end_pos]
+            repairs.append("removed content after </statements>")
+
+    if xml_string != original:
+        return xml_string, repairs
+    return xml_string, []
+
 # Default model
 DEFAULT_MODEL_ID = "Corp-o-Rate-Community/statement-extractor"
 
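Note: repair_xml is self-contained, so its behavior is easy to check directly. A minimal sketch on a truncated model output (the input string is invented for illustration; the expected results follow from the steps above):

    # Illustrative input: generation cut off mid closing tag, with a raw ampersand.
    broken = '<statements><stmt><subject>Acme & Co</subject></stm'
    fixed, repairs = repair_xml(broken)
    # fixed   -> '<statements><stmt><subject>Acme &amp; Co</subject></stmt></statements>'
    # repairs -> ['escaped unescaped ampersands',
    #             'completed truncated tag: </stmt>',
    #             'added missing </statements> after last </stmt>']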
@@ -327,12 +436,13 @@ class StatementExtractor:
         # Parse each candidate to statements
         parsed_candidates = []
         for i, xml_output in enumerate(candidates):
-            statements = self._parse_xml_to_statements(xml_output)
+            statements = self._parse_xml_to_statements(xml_output, options)
             if statements:
                 parsed_candidates.append(statements)
                 logger.debug(f"  Beam {i}: {len(statements)} statements parsed")
             else:
-                logger.
+                logger.warning(f"  Beam {i}: 0 statements (parse failed)")
+                logger.warning(f"  Beam {i} XML output:\n{xml_output}")
 
         all_candidates.extend(parsed_candidates)
 
@@ -395,6 +505,15 @@ class StatementExtractor:
         else:
             logger.debug("Deduplication disabled")
 
+        # Select best triple per source text (unless all_triples enabled)
+        if not options.all_triples:
+            logger.debug("-" * 40)
+            logger.debug("PHASE 5: Best Triple Selection")
+            logger.debug("-" * 40)
+            pre_select_count = len(statements)
+            statements = self._select_best_per_source(statements)
+            logger.debug(f"  Selected best per source: {len(statements)} statements (from {pre_select_count})")
+
         # Log final statements
         logger.debug("-" * 40)
         logger.debug("FINAL STATEMENTS:")
@@ -414,6 +533,14 @@ class StatementExtractor:
         """Generate multiple candidate beams using diverse beam search."""
         num_seqs = options.num_beams
 
+        # Create stopping criteria to stop when </statements> is generated
+        input_length = inputs["input_ids"].shape[1]
+        stop_criteria = StopOnSequence(
+            tokenizer=self.tokenizer,
+            stop_sequence="</statements>",
+            input_length=input_length,
+        )
+
         with torch.no_grad():
             outputs = self.model.generate(
                 **inputs,
@@ -428,6 +555,7 @@ class StatementExtractor:
                 top_k=None,  # Override model config to suppress warning
                 trust_remote_code=True,
                 custom_generate="transformers-community/group-beam-search",
+                stopping_criteria=StoppingCriteriaList([stop_criteria]),
             )
 
         # Decode and process candidates
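Note: the criterion plugs into generation through the standard transformers stopping_criteria hook. A minimal end-to-end sketch (max_new_tokens and the plain generate call are illustrative; the package itself runs diverse beam search with the extra options shown above):

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, StoppingCriteriaList
    from statement_extractor.extractor import DEFAULT_MODEL_ID, StopOnSequence

    tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL_ID)
    model = AutoModelForSeq2SeqLM.from_pretrained(DEFAULT_MODEL_ID)

    inputs = tokenizer("Acme Corp acquired Widget Inc in 2023.", return_tensors="pt")
    stop = StopOnSequence(
        tokenizer=tokenizer,
        stop_sequence="</statements>",
        input_length=inputs["input_ids"].shape[1],
    )
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,  # illustrative cap
            stopping_criteria=StoppingCriteriaList([stop]),
        )
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))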
@@ -511,6 +639,50 @@ class StatementExtractor:
             entity_canonicalizer=options.entity_canonicalizer
         )
 
+    def _select_best_per_source(
+        self,
+        statements: list[Statement],
+    ) -> list[Statement]:
+        """
+        Select the highest-scoring triple for each unique source text.
+
+        Groups statements by source_text and keeps only the one with
+        the highest confidence_score from each group.
+
+        Statements without source_text are kept as-is.
+        """
+        if not statements:
+            return statements
+
+        # Group by source_text
+        from collections import defaultdict
+        groups: dict[str | None, list[Statement]] = defaultdict(list)
+
+        for stmt in statements:
+            groups[stmt.source_text].append(stmt)
+
+        # Select best from each group
+        result: list[Statement] = []
+
+        for source_text, group in groups.items():
+            if source_text is None or len(group) == 1:
+                # No source text or only one statement - keep as-is
+                result.extend(group)
+            else:
+                # Multiple candidates for same source - select best
+                best = max(
+                    group,
+                    key=lambda s: s.confidence_score if s.confidence_score is not None else 0.0
+                )
+                logger.debug(
+                    f"  Selected best for source '{source_text[:40]}...': "
+                    f"'{best.subject.text}' --[{best.predicate}]--> '{best.object.text}' "
+                    f"(score={best.confidence_score:.2f}, method={best.extraction_method.value})"
+                )
+                result.append(best)
+
+        return result
+
     def _deduplicate_xml(self, xml_output: str) -> str:
         """Remove duplicate <stmt> blocks from XML output (legacy method)."""
         try:
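Note: the selection rule is straightforward to verify with hand-built candidates. A sketch using the models from this release (EntityType.UNKNOWN is the only member visible in this diff; the sentence and scores are invented):

    from statement_extractor.models import Entity, EntityType, ExtractionMethod, Statement

    src = "Acme Corp acquired Widget Inc in 2023."
    hybrid = Statement(
        subject=Entity(text="Acme Corp", type=EntityType.UNKNOWN),
        predicate="acquired",
        object=Entity(text="Widget Inc", type=EntityType.UNKNOWN),
        source_text=src,
        confidence_score=0.91,
        extraction_method=ExtractionMethod.HYBRID,
    )
    split = Statement(
        subject=Entity(text="Acme Corp", type=EntityType.UNKNOWN),
        predicate="acquired",
        object=Entity(text="Widget Inc in 2023", type=EntityType.UNKNOWN),
        source_text=src,
        confidence_score=0.64,
        extraction_method=ExtractionMethod.SPLIT,
    )
    # extractor._select_best_per_source([hybrid, split]) -> [hybrid]
    # Both candidates share the same source_text, so only the higher-scoring
    # one survives; set all_triples=True in ExtractionOptions to keep both.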
@@ -540,48 +712,159 @@ class StatementExtractor:
 
         return ET.tostring(new_root, encoding='unicode')
 
-    def _parse_xml_to_statements(
-
+    def _parse_xml_to_statements(
+        self,
+        xml_output: str,
+        options: Optional[ExtractionOptions] = None,
+    ) -> list[Statement]:
+        """
+        Parse XML output into Statement objects.
+
+        Uses model for subject, object, entity types, and source_text.
+        Always uses spaCy for predicate extraction (model predicates are unreliable).
+
+        Produces two candidates for each statement:
+        1. Hybrid: model subject/object + spaCy predicate
+        2. spaCy-only: all components from spaCy
+
+        Both go into the candidate pool; scoring/dedup picks the best.
+        """
         statements: list[Statement] = []
+        use_spacy_extraction = options.use_spacy_extraction if options else True
 
         try:
             root = ET.fromstring(xml_output)
         except ET.ParseError as e:
             # Log full output for debugging
-            logger.
-            logger.
-
+            logger.debug(f"Initial XML parse failed: {e}")
+            logger.debug(f"Raw XML output ({len(xml_output)} chars):\n{xml_output}")
+
+            # Try to repair the XML
+            repaired_xml, repairs = repair_xml(xml_output)
+            if repairs:
+                logger.debug(f"Attempted XML repairs: {', '.join(repairs)}")
+                try:
+                    root = ET.fromstring(repaired_xml)
+                    logger.info(f"XML repair successful, parsing repaired output")
+                except ET.ParseError as e2:
+                    logger.warning(f"XML repair failed, still cannot parse: {e2}")
+                    logger.warning(f"Repaired XML ({len(repaired_xml)} chars):\n{repaired_xml}")
+                    return statements
+            else:
+                logger.warning(f"No repairs possible for XML output")
+                logger.warning(f"Full XML output ({len(xml_output)} chars):\n{xml_output}")
+                return statements
 
         if root.tag != 'statements':
+            logger.warning(f"Root tag is '{root.tag}', expected 'statements'")
             return statements
 
         for stmt_elem in root.findall('stmt'):
             try:
-                # Parse subject
+                # Parse subject from model
                 subject_elem = stmt_elem.find('subject')
                 subject_text = subject_elem.text.strip() if subject_elem is not None and subject_elem.text else ""
                 subject_type = self._parse_entity_type(subject_elem.get('type') if subject_elem is not None else None)
 
-                # Parse object
+                # Parse object from model
                 object_elem = stmt_elem.find('object')
                 object_text = object_elem.text.strip() if object_elem is not None and object_elem.text else ""
                 object_type = self._parse_entity_type(object_elem.get('type') if object_elem is not None else None)
 
-                # Parse
-                predicate_elem = stmt_elem.find('predicate')
-                predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
-
-                # Parse source text
+                # Parse source text from model
                 text_elem = stmt_elem.find('text')
                 source_text = text_elem.text.strip() if text_elem is not None and text_elem.text else None
 
-                if
-
-
-
-
-
-
+                # Skip if missing required components from model
+                if not subject_text or not object_text:
+                    logger.debug(f"Skipping statement: missing subject or object from model")
+                    continue
+
+                if use_spacy_extraction and source_text:
+                    try:
+                        from .spacy_extraction import extract_triple_from_text, extract_triple_by_predicate_split
+                        spacy_result = extract_triple_from_text(
+                            source_text=source_text,
+                            model_subject=subject_text,
+                            model_object=object_text,
+                            model_predicate="",  # Don't pass model predicate
+                        )
+                        if spacy_result:
+                            spacy_subj, spacy_pred, spacy_obj = spacy_result
+
+                            if spacy_pred:
+                                # Candidate 1: Hybrid (model subject/object + spaCy predicate)
+                                logger.debug(
+                                    f"Adding hybrid candidate: '{subject_text}' --[{spacy_pred}]--> '{object_text}'"
+                                )
+                                statements.append(Statement(
+                                    subject=Entity(text=subject_text, type=subject_type),
+                                    predicate=spacy_pred,
+                                    object=Entity(text=object_text, type=object_type),
+                                    source_text=source_text,
+                                    extraction_method=ExtractionMethod.HYBRID,
+                                ))
+
+                                # Candidate 2: spaCy-only (if different from hybrid)
+                                if spacy_subj and spacy_obj:
+                                    is_different = (spacy_subj != subject_text or spacy_obj != object_text)
+                                    if is_different:
+                                        logger.debug(
+                                            f"Adding spaCy-only candidate: '{spacy_subj}' --[{spacy_pred}]--> '{spacy_obj}'"
+                                        )
+                                        statements.append(Statement(
+                                            subject=Entity(text=spacy_subj, type=subject_type),
+                                            predicate=spacy_pred,
+                                            object=Entity(text=spacy_obj, type=object_type),
+                                            source_text=source_text,
+                                            extraction_method=ExtractionMethod.SPACY,
+                                        ))
+
+                                # Candidate 3: Predicate-split (split source text around predicate)
+                                split_result = extract_triple_by_predicate_split(
+                                    source_text=source_text,
+                                    predicate=spacy_pred,
+                                )
+                                if split_result:
+                                    split_subj, split_pred, split_obj = split_result
+                                    # Only add if different from previous candidates
+                                    is_different_from_hybrid = (split_subj != subject_text or split_obj != object_text)
+                                    is_different_from_spacy = (split_subj != spacy_subj or split_obj != spacy_obj)
+                                    if is_different_from_hybrid and is_different_from_spacy:
+                                        logger.debug(
+                                            f"Adding predicate-split candidate: '{split_subj}' --[{split_pred}]--> '{split_obj}'"
+                                        )
+                                        statements.append(Statement(
+                                            subject=Entity(text=split_subj, type=subject_type),
+                                            predicate=split_pred,
+                                            object=Entity(text=split_obj, type=object_type),
+                                            source_text=source_text,
+                                            extraction_method=ExtractionMethod.SPLIT,
+                                        ))
+                            else:
+                                logger.debug(
+                                    f"spaCy found no predicate for: '{subject_text}' --> '{object_text}'"
+                                )
+                    except Exception as e:
+                        logger.debug(f"spaCy extraction failed: {e}")
+                else:
+                    # spaCy disabled - fall back to model predicate
+                    predicate_elem = stmt_elem.find('predicate')
+                    model_predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
+
+                    if model_predicate:
+                        statements.append(Statement(
+                            subject=Entity(text=subject_text, type=subject_type),
+                            predicate=model_predicate,
+                            object=Entity(text=object_text, type=object_type),
+                            source_text=source_text,
+                            extraction_method=ExtractionMethod.MODEL,
+                        ))
+                    else:
+                        logger.debug(
+                            f"Skipping statement (no predicate, spaCy disabled): "
+                            f"'{subject_text}' --> '{object_text}'"
+                        )
             except Exception as e:
                 logger.warning(f"Failed to parse statement: {e}")
                 continue
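Note: for reference, the XML shape this parser consumes, with tag names taken from the code above (the attribute values and sentence are illustrative):

    xml_output = """<statements>
      <stmt>
        <subject type="UNKNOWN">Acme Corp</subject>
        <predicate>acquired</predicate>
        <object type="UNKNOWN">Widget Inc</object>
        <text>Acme Corp acquired Widget Inc in 2023.</text>
      </stmt>
    </statements>"""
    # With use_spacy_extraction=True (the default), each <stmt> can yield up to
    # three candidates: HYBRID (model subject/object + spaCy predicate), SPACY
    # (all components from spaCy, if different), and SPLIT (source text split
    # around the predicate, if different again). With spaCy disabled, the
    # model's own <predicate> is used and the single candidate is tagged MODEL.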
statement_extractor/models.py
CHANGED
@@ -24,6 +24,14 @@ class EntityType(str, Enum):
     UNKNOWN = "UNKNOWN"
 
 
+class ExtractionMethod(str, Enum):
+    """Method used to extract the triple components."""
+    HYBRID = "hybrid"  # Model subject/object + spaCy predicate
+    SPACY = "spacy"  # All components from spaCy dependency parsing
+    SPLIT = "split"  # Subject/object from splitting source text around predicate
+    MODEL = "model"  # All components from T5-Gemma model (when spaCy disabled)
+
+
 class Entity(BaseModel):
     """An entity (subject or object) with its text and type."""
     text: str = Field(..., description="The entity text")
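Note: because ExtractionMethod subclasses str, members compare and serialize as plain strings, which keeps JSON output and log formatting simple:

    from statement_extractor.models import ExtractionMethod

    m = ExtractionMethod("hybrid")       # lookup by value
    assert m is ExtractionMethod.HYBRID
    assert m == "hybrid"                 # str subclass: equal to its own value
    print(m.value)                       # -> hybrid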
@@ -52,12 +60,18 @@ class Statement(BaseModel):
     object: Entity = Field(..., description="The object entity")
     source_text: Optional[str] = Field(None, description="The original text this statement was extracted from")
 
+    # Extraction method tracking
+    extraction_method: ExtractionMethod = Field(
+        default=ExtractionMethod.MODEL,
+        description="Method used to extract this triple (hybrid, spacy, split, or model)"
+    )
+
     # Quality scoring fields
     confidence_score: Optional[float] = Field(
         None,
         ge=0.0,
         le=1.0,
-        description="
+        description="Semantic similarity score (0-1) between source text and reassembled triple"
     )
     evidence_span: Optional[tuple[int, int]] = Field(
         None,
@@ -99,6 +113,7 @@ class Statement(BaseModel):
             object=merged_object,
             predicate=self.predicate,
             source_text=self.source_text,
+            extraction_method=self.extraction_method,
             confidence_score=self.confidence_score,
             evidence_span=self.evidence_span,
             canonical_predicate=self.canonical_predicate,
@@ -116,6 +131,7 @@ class Statement(BaseModel):
             object=self.subject,
             predicate=self.predicate,
             source_text=self.source_text,
+            extraction_method=self.extraction_method,
             confidence_score=self.confidence_score,
             evidence_span=self.evidence_span,
             canonical_predicate=self.canonical_predicate,
@@ -279,6 +295,10 @@ class ExtractionOptions(BaseModel):
         default=True,
         description="Use embedding similarity for predicate deduplication"
     )
+    use_spacy_extraction: bool = Field(
+        default=True,
+        description="Use spaCy for predicate/subject/object extraction (model provides structure + coreference)"
+    )
 
     # Verbose logging
     verbose: bool = Field(
@@ -286,5 +306,11 @@ class ExtractionOptions(BaseModel):
         description="Enable verbose logging for debugging"
     )
 
+    # Triple selection
+    all_triples: bool = Field(
+        default=False,
+        description="Keep all candidate triples instead of selecting the highest-scoring one per source"
+    )
+
     class Config:
         arbitrary_types_allowed = True  # Allow Callable type
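Note: the two options added in 0.3.0, shown with their default values (a sketch; the remaining ExtractionOptions fields are assumed to keep their defaults):

    from statement_extractor.models import ExtractionOptions

    opts = ExtractionOptions(
        use_spacy_extraction=True,  # spaCy supplies predicates and extra candidates
        all_triples=False,          # keep only the best-scoring triple per source text
    )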