corp-extractor 0.2.11__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,11 +14,12 @@ import xml.etree.ElementTree as ET
  from typing import Optional
 
  import torch
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
 
  from .models import (
      Entity,
      EntityType,
+     ExtractionMethod,
      ExtractionOptions,
      ExtractionResult,
      PredicateComparisonConfig,
@@ -29,6 +30,114 @@ from .models import (
 
  logger = logging.getLogger(__name__)
 
+
+ class StopOnSequence(StoppingCriteria):
+     """
+     Stop generation when a specific multi-token sequence is generated.
+
+     Decodes the generated tokens and checks if the stop sequence appears.
+     Works with sequences that span multiple tokens (e.g., "</statements>").
+     """
+
+     def __init__(self, tokenizer, stop_sequence: str, input_length: int):
+         self.tokenizer = tokenizer
+         self.stop_sequence = stop_sequence
+         self.input_length = input_length
+         # Track which beams have stopped (for batch generation)
+         self.stopped = set()
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         # Check each sequence in the batch
+         for idx, seq in enumerate(input_ids):
+             if idx in self.stopped:
+                 continue
+             # Only decode the generated portion (after input)
+             generated = seq[self.input_length:]
+             decoded = self.tokenizer.decode(generated, skip_special_tokens=True)
+             if self.stop_sequence in decoded:
+                 self.stopped.add(idx)
+
+         # Stop when ALL sequences have the stop sequence
+         return len(self.stopped) >= len(input_ids)
+
+
+ def repair_xml(xml_string: str) -> tuple[str, list[str]]:
+     """
+     Attempt to repair common XML syntax errors.
+
+     Returns:
+         Tuple of (repaired_xml, list_of_repairs_made)
+     """
+     repairs = []
+     original = xml_string
+
+     # 1. Fix unescaped ampersands (but not already escaped entities)
+     # Match & not followed by amp; lt; gt; quot; apos; or #
+     ampersand_pattern = r'&(?!(amp|lt|gt|quot|apos|#\d+|#x[0-9a-fA-F]+);)'
+     if re.search(ampersand_pattern, xml_string):
+         xml_string = re.sub(ampersand_pattern, '&amp;', xml_string)
+         repairs.append("escaped unescaped ampersands")
+
+     # 2. Fix unescaped < and > inside text content (not tags)
+     # This is tricky - we need to be careful not to break actual tags
+     # For now, just handle the most common case: < followed by space or lowercase
+     less_than_pattern = r'<(?=\s|[a-z]{2,}[^a-z/>])'
+     if re.search(less_than_pattern, xml_string):
+         xml_string = re.sub(less_than_pattern, '&lt;', xml_string)
+         repairs.append("escaped unescaped less-than signs")
+
+     # 3. Fix truncated closing tags (e.g., "</statemen" -> try to complete)
+     truncated_patterns = [
+         (r'</statement[^s>]*$', '</statements>'),
+         (r'</stm[^t>]*$', '</stmt>'),
+         (r'</subjec[^t>]*$', '</subject>'),
+         (r'</objec[^t>]*$', '</object>'),
+         (r'</predica[^t>]*$', '</predicate>'),
+         (r'</tex[^t>]*$', '</text>'),
+     ]
+     for pattern, replacement in truncated_patterns:
+         if re.search(pattern, xml_string):
+             xml_string = re.sub(pattern, replacement, xml_string)
+             repairs.append(f"completed truncated tag: {replacement}")
+
+     # 4. Add missing </statements> if we have <statements> but no closing
+     if '<statements>' in xml_string and '</statements>' not in xml_string:
+         # Try to find a good place to add it
+         # Look for the last complete </stmt> and add after it
+         last_stmt = xml_string.rfind('</stmt>')
+         if last_stmt != -1:
+             insert_pos = last_stmt + len('</stmt>')
+             xml_string = xml_string[:insert_pos] + '</statements>'
+             repairs.append("added missing </statements> after last </stmt>")
+         else:
+             xml_string = xml_string + '</statements>'
+             repairs.append("added missing </statements> at end")
+
+     # 5. Fix unclosed <stmt> tags - find <stmt> without matching </stmt>
+     # Count opens and closes
+     open_stmts = len(re.findall(r'<stmt>', xml_string))
+     close_stmts = len(re.findall(r'</stmt>', xml_string))
+     if open_stmts > close_stmts:
+         # Find incomplete statement blocks and try to close them
+         # Look for patterns like <stmt>...<subject>...</subject> without </stmt>
+         # This is complex, so just add closing tags before </statements>
+         missing = open_stmts - close_stmts
+         if '</statements>' in xml_string:
+             xml_string = xml_string.replace('</statements>', '</stmt>' * missing + '</statements>')
+             repairs.append(f"added {missing} missing </stmt> tag(s)")
+
+     # 6. Remove any content after </statements>
+     end_pos = xml_string.find('</statements>')
+     if end_pos != -1:
+         end_pos += len('</statements>')
+         if end_pos < len(xml_string):
+             xml_string = xml_string[:end_pos]
+             repairs.append("removed content after </statements>")
+
+     if xml_string != original:
+         return xml_string, repairs
+     return xml_string, []
+
  # Default model
  DEFAULT_MODEL_ID = "Corp-o-Rate-Community/statement-extractor"
 
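Editor's note: since repair_xml is pure string manipulation, the new repair path is easy to exercise in isolation. A minimal sketch (the input string is invented for illustration): a generation cut off before the closing tag is completed by repair #4 and then parses cleanly.

    import xml.etree.ElementTree as ET

    truncated = (
        "<statements>"
        "<stmt><subject>Acme</subject><predicate>acquired</predicate>"
        "<object>Widgets</object></stmt>"
    )
    fixed, repairs = repair_xml(truncated)
    print(repairs)        # ['added missing </statements> after last </stmt>']
    ET.fromstring(fixed)  # now parses without error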
@@ -327,12 +436,13 @@ class StatementExtractor:
          # Parse each candidate to statements
          parsed_candidates = []
          for i, xml_output in enumerate(candidates):
-             statements = self._parse_xml_to_statements(xml_output)
+             statements = self._parse_xml_to_statements(xml_output, options)
              if statements:
                  parsed_candidates.append(statements)
                  logger.debug(f"  Beam {i}: {len(statements)} statements parsed")
              else:
-                 logger.debug(f"  Beam {i}: 0 statements (parse failed)")
+                 logger.warning(f"  Beam {i}: 0 statements (parse failed)")
+                 logger.warning(f"  Beam {i} XML output:\n{xml_output}")
 
          all_candidates.extend(parsed_candidates)
 
@@ -395,6 +505,15 @@ class StatementExtractor:
          else:
              logger.debug("Deduplication disabled")
 
+         # Select best triple per source text (unless all_triples enabled)
+         if not options.all_triples:
+             logger.debug("-" * 40)
+             logger.debug("PHASE 5: Best Triple Selection")
+             logger.debug("-" * 40)
+             pre_select_count = len(statements)
+             statements = self._select_best_per_source(statements)
+             logger.debug(f"  Selected best per source: {len(statements)} statements (from {pre_select_count})")
+
          # Log final statements
          logger.debug("-" * 40)
          logger.debug("FINAL STATEMENTS:")
@@ -414,6 +533,14 @@ class StatementExtractor:
          """Generate multiple candidate beams using diverse beam search."""
          num_seqs = options.num_beams
 
+         # Create stopping criteria to stop when </statements> is generated
+         input_length = inputs["input_ids"].shape[1]
+         stop_criteria = StopOnSequence(
+             tokenizer=self.tokenizer,
+             stop_sequence="</statements>",
+             input_length=input_length,
+         )
+
          with torch.no_grad():
              outputs = self.model.generate(
                  **inputs,
@@ -428,6 +555,7 @@ class StatementExtractor:
                  top_k=None,  # Override model config to suppress warning
                  trust_remote_code=True,
                  custom_generate="transformers-community/group-beam-search",
+                 stopping_criteria=StoppingCriteriaList([stop_criteria]),
              )
 
          # Decode and process candidates
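Editor's note: the stopping_criteria hook is standard Hugging Face transformers API. A self-contained sketch of the same wiring, using the StopOnSequence class added above with a placeholder model rather than this package's checkpoint:

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, StoppingCriteriaList

    tokenizer = AutoTokenizer.from_pretrained("t5-small")  # placeholder model
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    inputs = tokenizer("summarize: some input text", return_tensors="pt")
    stop = StopOnSequence(
        tokenizer=tokenizer,
        stop_sequence="</statements>",
        input_length=inputs["input_ids"].shape[1],
    )

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            num_beams=4,
            num_return_sequences=4,
            stopping_criteria=StoppingCriteriaList([stop]),
        )

One caveat worth noting: for encoder-decoder models, generate passes decoder-side token ids to the criteria, so slicing with the encoder's input_length only yields text once the output grows longer than the prompt; the class's slicing assumption fits decoder-only models most naturally, where input_ids really do begin with the prompt.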
@@ -511,6 +639,50 @@ class StatementExtractor:
              entity_canonicalizer=options.entity_canonicalizer
          )
 
+     def _select_best_per_source(
+         self,
+         statements: list[Statement],
+     ) -> list[Statement]:
+         """
+         Select the highest-scoring triple for each unique source text.
+
+         Groups statements by source_text and keeps only the one with
+         the highest confidence_score from each group.
+
+         Statements without source_text are kept as-is.
+         """
+         if not statements:
+             return statements
+
+         # Group by source_text
+         from collections import defaultdict
+         groups: dict[str | None, list[Statement]] = defaultdict(list)
+
+         for stmt in statements:
+             groups[stmt.source_text].append(stmt)
+
+         # Select best from each group
+         result: list[Statement] = []
+
+         for source_text, group in groups.items():
+             if source_text is None or len(group) == 1:
+                 # No source text or only one statement - keep as-is
+                 result.extend(group)
+             else:
+                 # Multiple candidates for same source - select best
+                 best = max(
+                     group,
+                     key=lambda s: s.confidence_score if s.confidence_score is not None else 0.0
+                 )
+                 logger.debug(
+                     f"  Selected best for source '{source_text[:40]}...': "
+                     f"'{best.subject.text}' --[{best.predicate}]--> '{best.object.text}' "
+                     f"(score={best.confidence_score:.2f}, method={best.extraction_method.value})"
+                 )
+                 result.append(best)
+
+         return result
+
      def _deduplicate_xml(self, xml_output: str) -> str:
          """Remove duplicate <stmt> blocks from XML output (legacy method)."""
          try:
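Editor's note: the selection above is the ordinary group-by-then-max pattern. A generic sketch for reference (function and variable names here are illustrative, not package API):

    from collections import defaultdict

    def best_per_key(items, key_fn, score_fn):
        groups = defaultdict(list)
        for item in items:
            groups[key_fn(item)].append(item)
        # Keep the highest-scoring member of each group
        return [max(group, key=score_fn) for group in groups.values()]

    rows = [("src-a", 0.2), ("src-a", 0.9), ("src-b", 0.5)]
    print(best_per_key(rows, key_fn=lambda r: r[0], score_fn=lambda r: r[1]))
    # [('src-a', 0.9), ('src-b', 0.5)]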
@@ -540,48 +712,166 @@ class StatementExtractor:
 
          return ET.tostring(new_root, encoding='unicode')
 
-     def _parse_xml_to_statements(self, xml_output: str) -> list[Statement]:
-         """Parse XML output into Statement objects."""
+     def _parse_xml_to_statements(
+         self,
+         xml_output: str,
+         options: Optional[ExtractionOptions] = None,
+     ) -> list[Statement]:
+         """
+         Parse XML output into Statement objects.
+
+         Uses the model for subject, object, entity types, and source_text.
+         By default uses GLiNER2 for predicate extraction (model predicates
+         are unreliable); falls back to the model predicate when disabled.
+
+         Produces up to three candidates for each statement:
+         1. Hybrid: model subject/object + GLiNER2 predicate
+         2. GLiNER2-only: all components from GLiNER2
+         3. Predicate-split: source text split around the GLiNER2 predicate
+
+         All go into the candidate pool; scoring/dedup picks the best.
+         """
          statements: list[Statement] = []
+         use_gliner_extraction = options.use_gliner_extraction if options else True
+         predicates = options.predicates if options else None
 
          try:
              root = ET.fromstring(xml_output)
          except ET.ParseError as e:
              # Log full output for debugging
-             logger.warning(f"Failed to parse XML output: {e}")
-             logger.warning(f"Full XML output ({len(xml_output)} chars):\n{xml_output}")
-             return statements
+             logger.debug(f"Initial XML parse failed: {e}")
+             logger.debug(f"Raw XML output ({len(xml_output)} chars):\n{xml_output}")
+
+             # Try to repair the XML
+             repaired_xml, repairs = repair_xml(xml_output)
+             if repairs:
+                 logger.debug(f"Attempted XML repairs: {', '.join(repairs)}")
+                 try:
+                     root = ET.fromstring(repaired_xml)
+                     logger.info("XML repair successful, parsing repaired output")
+                 except ET.ParseError as e2:
+                     logger.warning(f"XML repair failed, still cannot parse: {e2}")
+                     logger.warning(f"Repaired XML ({len(repaired_xml)} chars):\n{repaired_xml}")
+                     return statements
+             else:
+                 logger.warning("No repairs possible for XML output")
+                 logger.warning(f"Full XML output ({len(xml_output)} chars):\n{xml_output}")
+                 return statements
 
          if root.tag != 'statements':
+             logger.warning(f"Root tag is '{root.tag}', expected 'statements'")
              return statements
 
          for stmt_elem in root.findall('stmt'):
              try:
-                 # Parse subject
+                 # Parse subject from model
                  subject_elem = stmt_elem.find('subject')
                  subject_text = subject_elem.text.strip() if subject_elem is not None and subject_elem.text else ""
                  subject_type = self._parse_entity_type(subject_elem.get('type') if subject_elem is not None else None)
 
-                 # Parse object
+                 # Parse object from model
                  object_elem = stmt_elem.find('object')
                  object_text = object_elem.text.strip() if object_elem is not None and object_elem.text else ""
                  object_type = self._parse_entity_type(object_elem.get('type') if object_elem is not None else None)
 
-                 # Parse predicate
-                 predicate_elem = stmt_elem.find('predicate')
-                 predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
-
-                 # Parse source text
+                 # Parse source text from model
                  text_elem = stmt_elem.find('text')
                  source_text = text_elem.text.strip() if text_elem is not None and text_elem.text else None
 
-                 if subject_text and predicate and object_text:
-                     statements.append(Statement(
-                         subject=Entity(text=subject_text, type=subject_type),
-                         predicate=predicate,
-                         object=Entity(text=object_text, type=object_type),
-                         source_text=source_text,
-                     ))
+                 # Skip if missing required components from model
+                 if not subject_text or not object_text:
+                     logger.debug("Skipping statement: missing subject or object from model")
+                     continue
+
+                 if use_gliner_extraction and source_text:
+                     try:
+                         from .gliner_extraction import extract_triple_from_text, extract_triple_by_predicate_split
+
+                         # Get model predicate for fallback/refinement
+                         predicate_elem = stmt_elem.find('predicate')
+                         model_predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
+
+                         gliner_result = extract_triple_from_text(
+                             source_text=source_text,
+                             model_subject=subject_text,
+                             model_object=object_text,
+                             model_predicate=model_predicate,
+                             predicates=predicates,
+                         )
+                         if gliner_result:
+                             gliner_subj, gliner_pred, gliner_obj = gliner_result
+
+                             if gliner_pred:
+                                 # Candidate 1: Hybrid (model subject/object + GLiNER2 predicate)
+                                 logger.debug(
+                                     f"Adding hybrid candidate: '{subject_text}' --[{gliner_pred}]--> '{object_text}'"
+                                 )
+                                 statements.append(Statement(
+                                     subject=Entity(text=subject_text, type=subject_type),
+                                     predicate=gliner_pred,
+                                     object=Entity(text=object_text, type=object_type),
+                                     source_text=source_text,
+                                     extraction_method=ExtractionMethod.HYBRID,
+                                 ))
+
+                                 # Candidate 2: GLiNER2-only (if different from hybrid)
+                                 if gliner_subj and gliner_obj:
+                                     is_different = (gliner_subj != subject_text or gliner_obj != object_text)
+                                     if is_different:
+                                         logger.debug(
+                                             f"Adding GLiNER2-only candidate: '{gliner_subj}' --[{gliner_pred}]--> '{gliner_obj}'"
+                                         )
+                                         statements.append(Statement(
+                                             subject=Entity(text=gliner_subj, type=subject_type),
+                                             predicate=gliner_pred,
+                                             object=Entity(text=gliner_obj, type=object_type),
+                                             source_text=source_text,
+                                             extraction_method=ExtractionMethod.GLINER,
+                                         ))
+
+                                 # Candidate 3: Predicate-split (split source text around predicate)
+                                 split_result = extract_triple_by_predicate_split(
+                                     source_text=source_text,
+                                     predicate=gliner_pred,
+                                 )
+                                 if split_result:
+                                     split_subj, split_pred, split_obj = split_result
+                                     # Only add if different from previous candidates
+                                     is_different_from_hybrid = (split_subj != subject_text or split_obj != object_text)
+                                     is_different_from_gliner = (split_subj != gliner_subj or split_obj != gliner_obj)
+                                     if is_different_from_hybrid and is_different_from_gliner:
+                                         logger.debug(
+                                             f"Adding predicate-split candidate: '{split_subj}' --[{split_pred}]--> '{split_obj}'"
+                                         )
+                                         statements.append(Statement(
+                                             subject=Entity(text=split_subj, type=subject_type),
+                                             predicate=split_pred,
+                                             object=Entity(text=split_obj, type=object_type),
+                                             source_text=source_text,
+                                             extraction_method=ExtractionMethod.SPLIT,
+                                         ))
+                             else:
+                                 logger.debug(
+                                     f"GLiNER2 found no predicate for: '{subject_text}' --> '{object_text}'"
+                                 )
+                     except Exception as e:
+                         logger.debug(f"GLiNER2 extraction failed: {e}")
+                 else:
+                     # GLiNER2 disabled - fall back to model predicate
+                     predicate_elem = stmt_elem.find('predicate')
+                     model_predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
+
+                     if model_predicate:
+                         statements.append(Statement(
+                             subject=Entity(text=subject_text, type=subject_type),
+                             predicate=model_predicate,
+                             object=Entity(text=object_text, type=object_type),
+                             source_text=source_text,
+                             extraction_method=ExtractionMethod.MODEL,
+                         ))
+                     else:
+                         logger.debug(
+                             f"Skipping statement (no predicate, GLiNER2 disabled): "
+                             f"'{subject_text}' --> '{object_text}'"
+                         )
              except Exception as e:
                  logger.warning(f"Failed to parse statement: {e}")
                  continue
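Editor's note: for reference, the <statements> payload the parser walks has this shape. The tag names and the type attribute come from the code above, while the concrete values are invented for illustration:

    import xml.etree.ElementTree as ET

    sample = (
        '<statements>'
        '<stmt>'
        '<subject type="company">Acme Corp</subject>'
        '<predicate>acquired</predicate>'
        '<object type="company">Widget Inc</object>'
        '<text>Acme Corp acquired Widget Inc in 2024.</text>'
        '</stmt>'
        '</statements>'
    )

    root = ET.fromstring(sample)
    for stmt in root.findall('stmt'):
        subj, obj = stmt.find('subject'), stmt.find('object')
        print(subj.text, subj.get('type'), '->', obj.text, obj.get('type'))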