corp-extractor 0.2.11__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,11 +14,12 @@ import xml.etree.ElementTree as ET
 from typing import Optional
 
 import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
 
 from .models import (
     Entity,
     EntityType,
+    ExtractionMethod,
     ExtractionOptions,
     ExtractionResult,
     PredicateComparisonConfig,
@@ -29,6 +30,114 @@ from .models import (
 
 logger = logging.getLogger(__name__)
 
+
+class StopOnSequence(StoppingCriteria):
+    """
+    Stop generation when a specific multi-token sequence is generated.
+
+    Decodes the generated tokens and checks if the stop sequence appears.
+    Works with sequences that span multiple tokens (e.g., "</statements>").
+    """
+
+    def __init__(self, tokenizer, stop_sequence: str, input_length: int):
+        self.tokenizer = tokenizer
+        self.stop_sequence = stop_sequence
+        self.input_length = input_length
+        # Track which beams have stopped (for batch generation)
+        self.stopped = set()
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        # Check each sequence in the batch
+        for idx, seq in enumerate(input_ids):
+            if idx in self.stopped:
+                continue
+            # Only decode the generated portion (after input)
+            generated = seq[self.input_length:]
+            decoded = self.tokenizer.decode(generated, skip_special_tokens=True)
+            if self.stop_sequence in decoded:
+                self.stopped.add(idx)
+
+        # Stop when ALL sequences have the stop sequence
+        return len(self.stopped) >= len(input_ids)
+
+
+def repair_xml(xml_string: str) -> tuple[str, list[str]]:
+    """
+    Attempt to repair common XML syntax errors.
+
+    Returns:
+        Tuple of (repaired_xml, list_of_repairs_made)
+    """
+    repairs = []
+    original = xml_string
+
+    # 1. Fix unescaped ampersands (but not already escaped entities)
+    # Match & not followed by amp; lt; gt; quot; apos; or #
+    ampersand_pattern = r'&(?!(amp|lt|gt|quot|apos|#\d+|#x[0-9a-fA-F]+);)'
+    if re.search(ampersand_pattern, xml_string):
+        xml_string = re.sub(ampersand_pattern, '&amp;', xml_string)
+        repairs.append("escaped unescaped ampersands")
+
+    # 2. Fix unescaped < and > inside text content (not tags)
+    # This is tricky - we need to be careful not to break actual tags
+    # For now, just handle the most common case: < followed by space or lowercase
+    less_than_pattern = r'<(?=\s|[a-z]{2,}[^a-z/>])'
+    if re.search(less_than_pattern, xml_string):
+        xml_string = re.sub(less_than_pattern, '&lt;', xml_string)
+        repairs.append("escaped unescaped less-than signs")
+
+    # 3. Fix truncated closing tags (e.g., "</statemen" -> try to complete)
+    truncated_patterns = [
+        (r'</statement[^s>]*$', '</statements>'),
+        (r'</stm[^t>]*$', '</stmt>'),
+        (r'</subjec[^t>]*$', '</subject>'),
+        (r'</objec[^t>]*$', '</object>'),
+        (r'</predica[^t>]*$', '</predicate>'),
+        (r'</tex[^t>]*$', '</text>'),
+    ]
+    for pattern, replacement in truncated_patterns:
+        if re.search(pattern, xml_string):
+            xml_string = re.sub(pattern, replacement, xml_string)
+            repairs.append(f"completed truncated tag: {replacement}")
+
+    # 4. Add missing </statements> if we have <statements> but no closing
+    if '<statements>' in xml_string and '</statements>' not in xml_string:
+        # Try to find a good place to add it
+        # Look for the last complete </stmt> and add after it
+        last_stmt = xml_string.rfind('</stmt>')
+        if last_stmt != -1:
+            insert_pos = last_stmt + len('</stmt>')
+            xml_string = xml_string[:insert_pos] + '</statements>'
+            repairs.append("added missing </statements> after last </stmt>")
+        else:
+            xml_string = xml_string + '</statements>'
+            repairs.append("added missing </statements> at end")
+
+    # 5. Fix unclosed <stmt> tags - find <stmt> without matching </stmt>
+    # Count opens and closes
+    open_stmts = len(re.findall(r'<stmt>', xml_string))
+    close_stmts = len(re.findall(r'</stmt>', xml_string))
+    if open_stmts > close_stmts:
+        # Find incomplete statement blocks and try to close them
+        # Look for patterns like <stmt>...<subject>...</subject> without </stmt>
+        # This is complex, so just add closing tags before </statements>
+        missing = open_stmts - close_stmts
+        if '</statements>' in xml_string:
+            xml_string = xml_string.replace('</statements>', '</stmt>' * missing + '</statements>')
+            repairs.append(f"added {missing} missing </stmt> tag(s)")
+
+    # 6. Remove any content after </statements>
+    end_pos = xml_string.find('</statements>')
+    if end_pos != -1:
+        end_pos += len('</statements>')
+        if end_pos < len(xml_string):
+            xml_string = xml_string[:end_pos]
+            repairs.append("removed content after </statements>")
+
+    if xml_string != original:
+        return xml_string, repairs
+    return xml_string, []
+
 # Default model
 DEFAULT_MODEL_ID = "Corp-o-Rate-Community/statement-extractor"
 
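
The repair pass is meant to salvage truncated or lightly malformed generations before giving up on a beam. A minimal usage sketch follows; it is not part of the diff, the import path is an assumption based on the package layout, and the input string is made up:

    # Hypothetical usage of repair_xml; the module path is assumed.
    from corp_extractor.extractor import repair_xml

    truncated = (
        "<statements>"
        "<stmt><subject>Acme & Co</subject><predicate>acquired</predicate>"
        "<object>Globex</object></stmt>"
        "<stmt><subject>Glob"  # generation cut off mid-statement
    )
    repaired, repairs = repair_xml(truncated)
    # Given the rules above, the ampersand is escaped and '</statements>' is
    # appended after the last complete '</stmt>', dropping the trailing
    # incomplete <stmt>:
    # repairs == ["escaped unescaped ampersands",
    #             "added missing </statements> after last </stmt>"]
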
@@ -327,12 +436,13 @@ class StatementExtractor:
         # Parse each candidate to statements
         parsed_candidates = []
         for i, xml_output in enumerate(candidates):
-            statements = self._parse_xml_to_statements(xml_output)
+            statements = self._parse_xml_to_statements(xml_output, options)
             if statements:
                 parsed_candidates.append(statements)
                 logger.debug(f" Beam {i}: {len(statements)} statements parsed")
             else:
-                logger.debug(f" Beam {i}: 0 statements (parse failed)")
+                logger.warning(f" Beam {i}: 0 statements (parse failed)")
+                logger.warning(f" Beam {i} XML output:\n{xml_output}")
 
         all_candidates.extend(parsed_candidates)
 
@@ -395,6 +505,15 @@ class StatementExtractor:
         else:
             logger.debug("Deduplication disabled")
 
+        # Select best triple per source text (unless all_triples enabled)
+        if not options.all_triples:
+            logger.debug("-" * 40)
+            logger.debug("PHASE 5: Best Triple Selection")
+            logger.debug("-" * 40)
+            pre_select_count = len(statements)
+            statements = self._select_best_per_source(statements)
+            logger.debug(f" Selected best per source: {len(statements)} statements (from {pre_select_count})")
+
         # Log final statements
         logger.debug("-" * 40)
         logger.debug("FINAL STATEMENTS:")
@@ -414,6 +533,14 @@ class StatementExtractor:
         """Generate multiple candidate beams using diverse beam search."""
         num_seqs = options.num_beams
 
+        # Create stopping criteria to stop when </statements> is generated
+        input_length = inputs["input_ids"].shape[1]
+        stop_criteria = StopOnSequence(
+            tokenizer=self.tokenizer,
+            stop_sequence="</statements>",
+            input_length=input_length,
+        )
+
         with torch.no_grad():
             outputs = self.model.generate(
                 **inputs,
@@ -428,6 +555,7 @@ class StatementExtractor:
                 top_k=None,  # Override model config to suppress warning
                 trust_remote_code=True,
                 custom_generate="transformers-community/group-beam-search",
+                stopping_criteria=StoppingCriteriaList([stop_criteria]),
             )
 
         # Decode and process candidates
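
Pulled out of the class, the early-stop wiring amounts to the following standalone sketch. The checkpoint name and prompt are placeholders (the package itself loads DEFAULT_MODEL_ID), and StopOnSequence is the class added earlier in this diff:

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, StoppingCriteriaList

    tokenizer = AutoTokenizer.from_pretrained("t5-small")  # placeholder checkpoint
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    inputs = tokenizer("Acme Corp acquired Globex in 2021.", return_tensors="pt")
    stop = StopOnSequence(
        tokenizer=tokenizer,
        stop_sequence="</statements>",
        input_length=inputs["input_ids"].shape[1],
    )
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            stopping_criteria=StoppingCriteriaList([stop]),
        )
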
@@ -511,6 +639,50 @@ class StatementExtractor:
             entity_canonicalizer=options.entity_canonicalizer
         )
 
+    def _select_best_per_source(
+        self,
+        statements: list[Statement],
+    ) -> list[Statement]:
+        """
+        Select the highest-scoring triple for each unique source text.
+
+        Groups statements by source_text and keeps only the one with
+        the highest confidence_score from each group.
+
+        Statements without source_text are kept as-is.
+        """
+        if not statements:
+            return statements
+
+        # Group by source_text
+        from collections import defaultdict
+        groups: dict[str | None, list[Statement]] = defaultdict(list)
+
+        for stmt in statements:
+            groups[stmt.source_text].append(stmt)
+
+        # Select best from each group
+        result: list[Statement] = []
+
+        for source_text, group in groups.items():
+            if source_text is None or len(group) == 1:
+                # No source text or only one statement - keep as-is
+                result.extend(group)
+            else:
+                # Multiple candidates for same source - select best
+                best = max(
+                    group,
+                    key=lambda s: s.confidence_score if s.confidence_score is not None else 0.0
+                )
+                logger.debug(
+                    f" Selected best for source '{source_text[:40]}...': "
+                    f"'{best.subject.text}' --[{best.predicate}]--> '{best.object.text}' "
+                    f"(score={best.confidence_score:.2f}, method={best.extraction_method.value})"
+                )
+                result.append(best)
+
+        return result
+
     def _deduplicate_xml(self, xml_output: str) -> str:
         """Remove duplicate <stmt> blocks from XML output (legacy method)."""
         try:
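
The selection added in _select_best_per_source is a per-source argmax over confidence_score. The same rule on plain dicts, as a self-contained illustration (toy data, not the package's Statement objects):

    from collections import defaultdict

    candidates = [
        {"source": "Acme acquired Globex.", "predicate": "acquired", "score": 0.91},
        {"source": "Acme acquired Globex.", "predicate": "acquired by", "score": 0.55},
        {"source": "Globex opened a plant.", "predicate": "opened", "score": 0.80},
    ]

    groups = defaultdict(list)
    for cand in candidates:
        groups[cand["source"]].append(cand)

    # One winner per source text: here the 0.91 and 0.80 candidates survive.
    best = [max(group, key=lambda c: c["score"]) for group in groups.values()]
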
@@ -540,48 +712,159 @@ class StatementExtractor:
 
         return ET.tostring(new_root, encoding='unicode')
 
-    def _parse_xml_to_statements(self, xml_output: str) -> list[Statement]:
-        """Parse XML output into Statement objects."""
+    def _parse_xml_to_statements(
+        self,
+        xml_output: str,
+        options: Optional[ExtractionOptions] = None,
+    ) -> list[Statement]:
+        """
+        Parse XML output into Statement objects.
+
+        Uses model for subject, object, entity types, and source_text.
+        Always uses spaCy for predicate extraction (model predicates are unreliable).
+
+        Produces two candidates for each statement:
+        1. Hybrid: model subject/object + spaCy predicate
+        2. spaCy-only: all components from spaCy
+
+        Both go into the candidate pool; scoring/dedup picks the best.
+        """
         statements: list[Statement] = []
+        use_spacy_extraction = options.use_spacy_extraction if options else True
 
         try:
             root = ET.fromstring(xml_output)
         except ET.ParseError as e:
             # Log full output for debugging
-            logger.warning(f"Failed to parse XML output: {e}")
-            logger.warning(f"Full XML output ({len(xml_output)} chars):\n{xml_output}")
-            return statements
+            logger.debug(f"Initial XML parse failed: {e}")
+            logger.debug(f"Raw XML output ({len(xml_output)} chars):\n{xml_output}")
+
+            # Try to repair the XML
+            repaired_xml, repairs = repair_xml(xml_output)
+            if repairs:
+                logger.debug(f"Attempted XML repairs: {', '.join(repairs)}")
+                try:
+                    root = ET.fromstring(repaired_xml)
+                    logger.info(f"XML repair successful, parsing repaired output")
+                except ET.ParseError as e2:
+                    logger.warning(f"XML repair failed, still cannot parse: {e2}")
+                    logger.warning(f"Repaired XML ({len(repaired_xml)} chars):\n{repaired_xml}")
+                    return statements
+            else:
+                logger.warning(f"No repairs possible for XML output")
+                logger.warning(f"Full XML output ({len(xml_output)} chars):\n{xml_output}")
+                return statements
 
         if root.tag != 'statements':
+            logger.warning(f"Root tag is '{root.tag}', expected 'statements'")
             return statements
 
         for stmt_elem in root.findall('stmt'):
             try:
-                # Parse subject
+                # Parse subject from model
                 subject_elem = stmt_elem.find('subject')
                 subject_text = subject_elem.text.strip() if subject_elem is not None and subject_elem.text else ""
                 subject_type = self._parse_entity_type(subject_elem.get('type') if subject_elem is not None else None)
 
-                # Parse object
+                # Parse object from model
                 object_elem = stmt_elem.find('object')
                 object_text = object_elem.text.strip() if object_elem is not None and object_elem.text else ""
                 object_type = self._parse_entity_type(object_elem.get('type') if object_elem is not None else None)
 
-                # Parse predicate
-                predicate_elem = stmt_elem.find('predicate')
-                predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
-
-                # Parse source text
+                # Parse source text from model
                 text_elem = stmt_elem.find('text')
                 source_text = text_elem.text.strip() if text_elem is not None and text_elem.text else None
 
-                if subject_text and predicate and object_text:
-                    statements.append(Statement(
-                        subject=Entity(text=subject_text, type=subject_type),
-                        predicate=predicate,
-                        object=Entity(text=object_text, type=object_type),
-                        source_text=source_text,
-                    ))
+                # Skip if missing required components from model
+                if not subject_text or not object_text:
+                    logger.debug(f"Skipping statement: missing subject or object from model")
+                    continue
+
+                if use_spacy_extraction and source_text:
+                    try:
+                        from .spacy_extraction import extract_triple_from_text, extract_triple_by_predicate_split
+                        spacy_result = extract_triple_from_text(
+                            source_text=source_text,
+                            model_subject=subject_text,
+                            model_object=object_text,
+                            model_predicate="",  # Don't pass model predicate
+                        )
+                        if spacy_result:
+                            spacy_subj, spacy_pred, spacy_obj = spacy_result
+
+                            if spacy_pred:
+                                # Candidate 1: Hybrid (model subject/object + spaCy predicate)
+                                logger.debug(
+                                    f"Adding hybrid candidate: '{subject_text}' --[{spacy_pred}]--> '{object_text}'"
+                                )
+                                statements.append(Statement(
+                                    subject=Entity(text=subject_text, type=subject_type),
+                                    predicate=spacy_pred,
+                                    object=Entity(text=object_text, type=object_type),
+                                    source_text=source_text,
+                                    extraction_method=ExtractionMethod.HYBRID,
+                                ))
+
+                                # Candidate 2: spaCy-only (if different from hybrid)
+                                if spacy_subj and spacy_obj:
+                                    is_different = (spacy_subj != subject_text or spacy_obj != object_text)
+                                    if is_different:
+                                        logger.debug(
+                                            f"Adding spaCy-only candidate: '{spacy_subj}' --[{spacy_pred}]--> '{spacy_obj}'"
+                                        )
+                                        statements.append(Statement(
+                                            subject=Entity(text=spacy_subj, type=subject_type),
+                                            predicate=spacy_pred,
+                                            object=Entity(text=spacy_obj, type=object_type),
+                                            source_text=source_text,
+                                            extraction_method=ExtractionMethod.SPACY,
+                                        ))
+
+                                # Candidate 3: Predicate-split (split source text around predicate)
+                                split_result = extract_triple_by_predicate_split(
+                                    source_text=source_text,
+                                    predicate=spacy_pred,
+                                )
+                                if split_result:
+                                    split_subj, split_pred, split_obj = split_result
+                                    # Only add if different from previous candidates
+                                    is_different_from_hybrid = (split_subj != subject_text or split_obj != object_text)
+                                    is_different_from_spacy = (split_subj != spacy_subj or split_obj != spacy_obj)
+                                    if is_different_from_hybrid and is_different_from_spacy:
+                                        logger.debug(
+                                            f"Adding predicate-split candidate: '{split_subj}' --[{split_pred}]--> '{split_obj}'"
+                                        )
+                                        statements.append(Statement(
+                                            subject=Entity(text=split_subj, type=subject_type),
+                                            predicate=split_pred,
+                                            object=Entity(text=split_obj, type=object_type),
+                                            source_text=source_text,
+                                            extraction_method=ExtractionMethod.SPLIT,
+                                        ))
+                            else:
+                                logger.debug(
+                                    f"spaCy found no predicate for: '{subject_text}' --> '{object_text}'"
+                                )
+                    except Exception as e:
+                        logger.debug(f"spaCy extraction failed: {e}")
+                else:
+                    # spaCy disabled - fall back to model predicate
+                    predicate_elem = stmt_elem.find('predicate')
+                    model_predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
+
+                    if model_predicate:
+                        statements.append(Statement(
+                            subject=Entity(text=subject_text, type=subject_type),
+                            predicate=model_predicate,
+                            object=Entity(text=object_text, type=object_type),
+                            source_text=source_text,
+                            extraction_method=ExtractionMethod.MODEL,
+                        ))
+                    else:
+                        logger.debug(
+                            f"Skipping statement (no predicate, spaCy disabled): "
+                            f"'{subject_text}' --> '{object_text}'"
+                        )
             except Exception as e:
                 logger.warning(f"Failed to parse statement: {e}")
                 continue
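
For orientation, the XML consumed by _parse_xml_to_statements uses the tags referenced in the parsing code above (<statements>, <stmt>, <subject type=...>, <predicate>, <object type=...>, <text>). An illustrative output, with invented values:

    # Shape only: tag names come from the parser above; entity values and
    # type attributes are made up for the example.
    example_xml = """<statements>
      <stmt>
        <subject type="ORG">Acme Corp</subject>
        <predicate>acquired</predicate>
        <object type="ORG">Globex</object>
        <text>Acme Corp acquired Globex in 2021.</text>
      </stmt>
    </statements>"""
    # With use_spacy_extraction enabled, each <stmt> can yield up to three
    # candidates (HYBRID, SPACY, SPLIT); with it disabled, the <predicate>
    # element is used as-is and the statement is tagged MODEL.
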
@@ -24,6 +24,14 @@ class EntityType(str, Enum):
     UNKNOWN = "UNKNOWN"
 
 
+class ExtractionMethod(str, Enum):
+    """Method used to extract the triple components."""
+    HYBRID = "hybrid"  # Model subject/object + spaCy predicate
+    SPACY = "spacy"    # All components from spaCy dependency parsing
+    SPLIT = "split"    # Subject/object from splitting source text around predicate
+    MODEL = "model"    # All components from T5-Gemma model (when spaCy disabled)
+
+
 class Entity(BaseModel):
     """An entity (subject or object) with its text and type."""
     text: str = Field(..., description="The entity text")
@@ -52,12 +60,18 @@ class Statement(BaseModel):
     object: Entity = Field(..., description="The object entity")
     source_text: Optional[str] = Field(None, description="The original text this statement was extracted from")
 
+    # Extraction method tracking
+    extraction_method: ExtractionMethod = Field(
+        default=ExtractionMethod.MODEL,
+        description="Method used to extract this triple (hybrid, spacy, split, or model)"
+    )
+
     # Quality scoring fields
     confidence_score: Optional[float] = Field(
         None,
         ge=0.0,
         le=1.0,
-        description="Groundedness score (0-1) indicating how well the triple is supported by source text"
+        description="Semantic similarity score (0-1) between source text and reassembled triple"
     )
     evidence_span: Optional[tuple[int, int]] = Field(
         None,
@@ -99,6 +113,7 @@ class Statement(BaseModel):
             object=merged_object,
             predicate=self.predicate,
             source_text=self.source_text,
+            extraction_method=self.extraction_method,
             confidence_score=self.confidence_score,
             evidence_span=self.evidence_span,
             canonical_predicate=self.canonical_predicate,
@@ -116,6 +131,7 @@ class Statement(BaseModel):
             object=self.subject,
             predicate=self.predicate,
             source_text=self.source_text,
+            extraction_method=self.extraction_method,
             confidence_score=self.confidence_score,
             evidence_span=self.evidence_span,
             canonical_predicate=self.canonical_predicate,
@@ -279,6 +295,10 @@ class ExtractionOptions(BaseModel):
         default=True,
         description="Use embedding similarity for predicate deduplication"
     )
+    use_spacy_extraction: bool = Field(
+        default=True,
+        description="Use spaCy for predicate/subject/object extraction (model provides structure + coreference)"
+    )
 
     # Verbose logging
     verbose: bool = Field(
@@ -286,5 +306,11 @@ class ExtractionOptions(BaseModel):
         description="Enable verbose logging for debugging"
     )
 
+    # Triple selection
+    all_triples: bool = Field(
+        default=False,
+        description="Keep all candidate triples instead of selecting the highest-scoring one per source"
+    )
+
     class Config:
         arbitrary_types_allowed = True  # Allow Callable type
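
A minimal sketch of the two new options in use; the absolute import path is an assumption (the diff only shows the relative module ".models"), and only fields shown in this diff are referenced:

    from corp_extractor.models import ExtractionOptions  # assumed import path

    # Keep every candidate triple and fall back to the model's own predicates.
    opts = ExtractionOptions(all_triples=True, use_spacy_extraction=False)
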