corp-extractor 0.2.5__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,11 +14,12 @@ import xml.etree.ElementTree as ET
  from typing import Optional

  import torch
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList

  from .models import (
      Entity,
      EntityType,
+     ExtractionMethod,
      ExtractionOptions,
      ExtractionResult,
      PredicateComparisonConfig,
@@ -29,6 +30,114 @@ from .models import (

  logger = logging.getLogger(__name__)

+
+ class StopOnSequence(StoppingCriteria):
+     """
+     Stop generation when a specific multi-token sequence is generated.
+
+     Decodes the generated tokens and checks if the stop sequence appears.
+     Works with sequences that span multiple tokens (e.g., "</statements>").
+     """
+
+     def __init__(self, tokenizer, stop_sequence: str, input_length: int):
+         self.tokenizer = tokenizer
+         self.stop_sequence = stop_sequence
+         self.input_length = input_length
+         # Track which beams have stopped (for batch generation)
+         self.stopped = set()
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         # Check each sequence in the batch
+         for idx, seq in enumerate(input_ids):
+             if idx in self.stopped:
+                 continue
+             # Only decode the generated portion (after input)
+             generated = seq[self.input_length:]
+             decoded = self.tokenizer.decode(generated, skip_special_tokens=True)
+             if self.stop_sequence in decoded:
+                 self.stopped.add(idx)
+
+         # Stop when ALL sequences have the stop sequence
+         return len(self.stopped) >= len(input_ids)
+
+
+ def repair_xml(xml_string: str) -> tuple[str, list[str]]:
+     """
+     Attempt to repair common XML syntax errors.
+
+     Returns:
+         Tuple of (repaired_xml, list_of_repairs_made)
+     """
+     repairs = []
+     original = xml_string
+
+     # 1. Fix unescaped ampersands (but not already escaped entities)
+     # Match & not followed by amp; lt; gt; quot; apos; or #
+     ampersand_pattern = r'&(?!(amp|lt|gt|quot|apos|#\d+|#x[0-9a-fA-F]+);)'
+     if re.search(ampersand_pattern, xml_string):
+         xml_string = re.sub(ampersand_pattern, '&amp;', xml_string)
+         repairs.append("escaped unescaped ampersands")
+
+     # 2. Fix unescaped < and > inside text content (not tags)
+     # This is tricky - we need to be careful not to break actual tags
+     # For now, just handle the most common case: < followed by space or lowercase
+     less_than_pattern = r'<(?=\s|[a-z]{2,}[^a-z/>])'
+     if re.search(less_than_pattern, xml_string):
+         xml_string = re.sub(less_than_pattern, '&lt;', xml_string)
+         repairs.append("escaped unescaped less-than signs")
+
+     # 3. Fix truncated closing tags (e.g., "</statemen" -> try to complete)
+     truncated_patterns = [
+         (r'</statement[^s>]*$', '</statements>'),
+         (r'</stm[^t>]*$', '</stmt>'),
+         (r'</subjec[^t>]*$', '</subject>'),
+         (r'</objec[^t>]*$', '</object>'),
+         (r'</predica[^t>]*$', '</predicate>'),
+         (r'</tex[^t>]*$', '</text>'),
+     ]
+     for pattern, replacement in truncated_patterns:
+         if re.search(pattern, xml_string):
+             xml_string = re.sub(pattern, replacement, xml_string)
+             repairs.append(f"completed truncated tag: {replacement}")
+
+     # 4. Add missing </statements> if we have <statements> but no closing
+     if '<statements>' in xml_string and '</statements>' not in xml_string:
+         # Try to find a good place to add it
+         # Look for the last complete </stmt> and add after it
+         last_stmt = xml_string.rfind('</stmt>')
+         if last_stmt != -1:
+             insert_pos = last_stmt + len('</stmt>')
+             xml_string = xml_string[:insert_pos] + '</statements>'
+             repairs.append("added missing </statements> after last </stmt>")
+         else:
+             xml_string = xml_string + '</statements>'
+             repairs.append("added missing </statements> at end")
+
+     # 5. Fix unclosed <stmt> tags - find <stmt> without matching </stmt>
+     # Count opens and closes
+     open_stmts = len(re.findall(r'<stmt>', xml_string))
+     close_stmts = len(re.findall(r'</stmt>', xml_string))
+     if open_stmts > close_stmts:
+         # Find incomplete statement blocks and try to close them
+         # Look for patterns like <stmt>...<subject>...</subject> without </stmt>
+         # This is complex, so just add closing tags before </statements>
+         missing = open_stmts - close_stmts
+         if '</statements>' in xml_string:
+             xml_string = xml_string.replace('</statements>', '</stmt>' * missing + '</statements>')
+             repairs.append(f"added {missing} missing </stmt> tag(s)")
+
+     # 6. Remove any content after </statements>
+     end_pos = xml_string.find('</statements>')
+     if end_pos != -1:
+         end_pos += len('</statements>')
+         if end_pos < len(xml_string):
+             xml_string = xml_string[:end_pos]
+             repairs.append("removed content after </statements>")
+
+     if xml_string != original:
+         return xml_string, repairs
+     return xml_string, []
+
  # Default model
  DEFAULT_MODEL_ID = "Corp-o-Rate-Community/statement-extractor"

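As a point of reference, this is roughly how the new repair path behaves on a truncated generation. The snippet is an illustrative sketch rather than package code: the input string is invented and the import path is assumed.

    import xml.etree.ElementTree as ET
    from corp_extractor.extractor import repair_xml  # import path assumed

    # Invented model output that stopped before emitting </statements>
    broken = (
        "<statements>"
        "<stmt><subject>Acme Corp</subject>"
        "<predicate>acquired</predicate>"
        "<object>Widget Co</object></stmt>"
    )

    fixed, repairs = repair_xml(broken)
    # repairs == ["added missing </statements> after last </stmt>"]
    ET.fromstring(fixed)  # the repaired string now parses cleanly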
@@ -80,11 +189,16 @@ class StatementExtractor:

          # Auto-detect device
          if device is None:
-             self.device = "cuda" if torch.cuda.is_available() else "cpu"
+             if torch.cuda.is_available():
+                 self.device = "cuda"
+             elif torch.backends.mps.is_available():
+                 self.device = "mps"
+             else:
+                 self.device = "cpu"
          else:
              self.device = device

-         # Auto-detect dtype
+         # Auto-detect dtype (bfloat16 only for CUDA, float32 for MPS/CPU)
          if torch_dtype is None:
              self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
          else:
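A short usage sketch of what the device change means for callers: auto-detection now prefers MPS (with float32) on Apple Silicon, and both values can still be pinned explicitly. The keyword names come from the checks above; the import path and the rest of the constructor signature are assumptions.

    import torch
    from corp_extractor import StatementExtractor  # import path assumed

    extractor = StatementExtractor()  # auto-detect: cuda/bfloat16, mps/float32, or cpu/float32
    pinned = StatementExtractor(device="cpu", torch_dtype=torch.float32)  # bypass auto-detection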
@@ -175,6 +289,14 @@ class StatementExtractor:
          if options is None:
              options = ExtractionOptions()

+         logger.debug("=" * 60)
+         logger.debug("EXTRACTION STARTED")
+         logger.debug("=" * 60)
+         logger.debug(f"Input text length: {len(text)} chars")
+         logger.debug(f"Options: num_beams={options.num_beams}, diversity={options.diversity_penalty}")
+         logger.debug(f" merge_beams={options.merge_beams}, embedding_dedup={options.embedding_dedup}")
+         logger.debug(f" deduplicate={options.deduplicate}, max_new_tokens={options.max_new_tokens}")
+
          # Store original text for scoring
          original_text = text

@@ -185,6 +307,10 @@ class StatementExtractor:
          # Run extraction with retry logic
          statements = self._extract_with_scoring(text, original_text, options)

+         logger.debug("=" * 60)
+         logger.debug(f"EXTRACTION COMPLETE: {len(statements)} statements")
+         logger.debug("=" * 60)
+
          return ExtractionResult(
              statements=statements,
              source_text=original_text,
@@ -270,6 +396,10 @@ class StatementExtractor:
          4. Merges top beams or selects best beam
          5. Deduplicates using embeddings (if enabled)
          """
+         logger.debug("-" * 40)
+         logger.debug("PHASE 1: Tokenization")
+         logger.debug("-" * 40)
+
          # Tokenize input
          inputs = self.tokenizer(
              text,
@@ -278,48 +408,78 @@ class StatementExtractor:
              truncation=True,
          ).to(self.device)

+         input_ids = inputs["input_ids"]
+         logger.debug(f"Tokenized: {input_ids.shape[1]} tokens")
+
          # Count sentences for quality check
          num_sentences = self._count_sentences(text)
          min_expected = int(num_sentences * options.min_statement_ratio)

-         logger.info(f"Input has ~{num_sentences} sentences, expecting >= {min_expected} statements")
+         logger.debug(f"Input has ~{num_sentences} sentences, min expected: {min_expected}")

          # Get beam scorer
          beam_scorer = self._get_beam_scorer(options)

+         logger.debug("-" * 40)
+         logger.debug("PHASE 2: Diverse Beam Search Generation")
+         logger.debug("-" * 40)
+
          all_candidates: list[list[Statement]] = []

          for attempt in range(options.max_attempts):
+             logger.debug(f"Attempt {attempt + 1}/{options.max_attempts}: Generating {options.num_beams} beams...")
+
              # Generate candidate beams
              candidates = self._generate_candidate_beams(inputs, options)
+             logger.debug(f" Generated {len(candidates)} valid XML outputs")

              # Parse each candidate to statements
              parsed_candidates = []
-             for xml_output in candidates:
-                 statements = self._parse_xml_to_statements(xml_output)
+             for i, xml_output in enumerate(candidates):
+                 statements = self._parse_xml_to_statements(xml_output, options)
                  if statements:
                      parsed_candidates.append(statements)
+                     logger.debug(f" Beam {i}: {len(statements)} statements parsed")
+                 else:
+                     logger.warning(f" Beam {i}: 0 statements (parse failed)")
+                     logger.warning(f" Beam {i} XML output:\n{xml_output}")

              all_candidates.extend(parsed_candidates)

              # Check if we have enough statements
              total_stmts = sum(len(c) for c in parsed_candidates)
-             logger.info(f"Attempt {attempt + 1}/{options.max_attempts}: {len(parsed_candidates)} beams, {total_stmts} total statements")
+             logger.debug(f" Total: {len(parsed_candidates)} beams, {total_stmts} statements")

              if total_stmts >= min_expected:
+                 logger.debug(f" Sufficient statements ({total_stmts} >= {min_expected}), stopping")
                  break

          if not all_candidates:
+             logger.debug("No valid candidates generated, returning empty result")
              return []

+         logger.debug("-" * 40)
+         logger.debug("PHASE 3: Beam Selection/Merging")
+         logger.debug("-" * 40)
+
          # Select or merge beams
          if options.merge_beams:
+             logger.debug(f"Merging {len(all_candidates)} beams...")
              statements = beam_scorer.merge_beams(all_candidates, original_text)
+             logger.debug(f" After merge: {len(statements)} statements")
          else:
+             logger.debug(f"Selecting best beam from {len(all_candidates)} candidates...")
              statements = beam_scorer.select_best_beam(all_candidates, original_text)
+             logger.debug(f" Selected beam has {len(statements)} statements")
+
+         logger.debug("-" * 40)
+         logger.debug("PHASE 4: Deduplication")
+         logger.debug("-" * 40)

          # Apply embedding-based deduplication if enabled
          if options.embedding_dedup and options.deduplicate:
+             logger.debug("Using embedding-based deduplication...")
+             pre_dedup_count = len(statements)
              try:
                  comparer = self._get_predicate_comparer(options)
                  if comparer:
@@ -327,14 +487,41 @@ class StatementExtractor:
                          statements,
                          entity_canonicalizer=options.entity_canonicalizer
                      )
+                     logger.debug(f" After embedding dedup: {len(statements)} statements (removed {pre_dedup_count - len(statements)})")
+
                      # Also normalize predicates if taxonomy provided
                      if options.predicate_taxonomy or self._predicate_taxonomy:
+                         logger.debug("Normalizing predicates to taxonomy...")
                          statements = comparer.normalize_predicates(statements)
              except Exception as e:
                  logger.warning(f"Embedding deduplication failed, falling back to exact match: {e}")
                  statements = self._deduplicate_statements_exact(statements, options)
+                 logger.debug(f" After exact dedup: {len(statements)} statements")
          elif options.deduplicate:
+             logger.debug("Using exact text deduplication...")
+             pre_dedup_count = len(statements)
              statements = self._deduplicate_statements_exact(statements, options)
+             logger.debug(f" After exact dedup: {len(statements)} statements (removed {pre_dedup_count - len(statements)})")
+         else:
+             logger.debug("Deduplication disabled")
+
+         # Select best triple per source text (unless all_triples enabled)
+         if not options.all_triples:
+             logger.debug("-" * 40)
+             logger.debug("PHASE 5: Best Triple Selection")
+             logger.debug("-" * 40)
+             pre_select_count = len(statements)
+             statements = self._select_best_per_source(statements)
+             logger.debug(f" Selected best per source: {len(statements)} statements (from {pre_select_count})")
+
+         # Log final statements
+         logger.debug("-" * 40)
+         logger.debug("FINAL STATEMENTS:")
+         logger.debug("-" * 40)
+         for i, stmt in enumerate(statements):
+             conf = f" (conf={stmt.confidence_score:.2f})" if stmt.confidence_score else ""
+             canonical = f" -> {stmt.canonical_predicate}" if stmt.canonical_predicate else ""
+             logger.debug(f" {i+1}. {stmt.subject.text} --[{stmt.predicate}{canonical}]--> {stmt.object.text}{conf}")

          return statements

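All of the new phase-by-phase output goes through logger.debug, so none of it is visible unless the caller opts in. A minimal sketch; the package logger name below is an assumption derived from the module's logging.getLogger(__name__):

    import logging

    # Show everything, including the PHASE 1-5 breakdown and the final statement dump
    logging.basicConfig(level=logging.DEBUG)

    # Or scope it to this package only (logger name assumed)
    logging.getLogger("corp_extractor").setLevel(logging.DEBUG)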
@@ -346,16 +533,29 @@ class StatementExtractor:
          """Generate multiple candidate beams using diverse beam search."""
          num_seqs = options.num_beams

+         # Create stopping criteria to stop when </statements> is generated
+         input_length = inputs["input_ids"].shape[1]
+         stop_criteria = StopOnSequence(
+             tokenizer=self.tokenizer,
+             stop_sequence="</statements>",
+             input_length=input_length,
+         )
+
          with torch.no_grad():
              outputs = self.model.generate(
                  **inputs,
                  max_new_tokens=options.max_new_tokens,
+                 max_length=None,  # Override model default, use max_new_tokens only
                  num_beams=num_seqs,
                  num_beam_groups=num_seqs,
                  num_return_sequences=num_seqs,
                  diversity_penalty=options.diversity_penalty,
                  do_sample=False,
+                 top_p=None,  # Override model config to suppress warning
+                 top_k=None,  # Override model config to suppress warning
                  trust_remote_code=True,
+                 custom_generate="transformers-community/group-beam-search",
+                 stopping_criteria=StoppingCriteriaList([stop_criteria]),
              )

          # Decode and process candidates
@@ -439,6 +639,50 @@ class StatementExtractor:
              entity_canonicalizer=options.entity_canonicalizer
          )

+     def _select_best_per_source(
+         self,
+         statements: list[Statement],
+     ) -> list[Statement]:
+         """
+         Select the highest-scoring triple for each unique source text.
+
+         Groups statements by source_text and keeps only the one with
+         the highest confidence_score from each group.
+
+         Statements without source_text are kept as-is.
+         """
+         if not statements:
+             return statements
+
+         # Group by source_text
+         from collections import defaultdict
+         groups: dict[str | None, list[Statement]] = defaultdict(list)
+
+         for stmt in statements:
+             groups[stmt.source_text].append(stmt)
+
+         # Select best from each group
+         result: list[Statement] = []
+
+         for source_text, group in groups.items():
+             if source_text is None or len(group) == 1:
+                 # No source text or only one statement - keep as-is
+                 result.extend(group)
+             else:
+                 # Multiple candidates for same source - select best
+                 best = max(
+                     group,
+                     key=lambda s: s.confidence_score if s.confidence_score is not None else 0.0
+                 )
+                 logger.debug(
+                     f" Selected best for source '{source_text[:40]}...': "
+                     f"'{best.subject.text}' --[{best.predicate}]--> '{best.object.text}' "
+                     f"(score={best.confidence_score:.2f}, method={best.extraction_method.value})"
+                 )
+                 result.append(best)
+
+         return result
+
      def _deduplicate_xml(self, xml_output: str) -> str:
          """Remove duplicate <stmt> blocks from XML output (legacy method)."""
          try:
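A small hypothetical illustration of the grouping rule in _select_best_per_source (all values invented; Entity and Statement are constructed the same way the extractor builds them, and the import path is assumed):

    from corp_extractor.models import Entity, EntityType, Statement  # import path assumed

    src = "Acme Corp acquired Widget Co in 2021."
    candidates = [
        Statement(subject=Entity(text="Acme Corp", type=EntityType.UNKNOWN), predicate="acquired",
                  object=Entity(text="Widget Co", type=EntityType.UNKNOWN),
                  source_text=src, confidence_score=0.91),
        Statement(subject=Entity(text="Acme", type=EntityType.UNKNOWN), predicate="acquired",
                  object=Entity(text="Widget Co in 2021", type=EntityType.UNKNOWN),
                  source_text=src, confidence_score=0.64),
    ]
    # Both candidates share source_text, so _select_best_per_source keeps only the
    # 0.91-scored triple; a statement with source_text=None would pass through untouched.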
@@ -468,48 +712,159 @@ class StatementExtractor:

          return ET.tostring(new_root, encoding='unicode')

-     def _parse_xml_to_statements(self, xml_output: str) -> list[Statement]:
-         """Parse XML output into Statement objects."""
+     def _parse_xml_to_statements(
+         self,
+         xml_output: str,
+         options: Optional[ExtractionOptions] = None,
+     ) -> list[Statement]:
+         """
+         Parse XML output into Statement objects.
+
+         Uses model for subject, object, entity types, and source_text.
+         Always uses spaCy for predicate extraction (model predicates are unreliable).
+
+         Produces up to three candidates for each statement:
+         1. Hybrid: model subject/object + spaCy predicate
+         2. spaCy-only: all components from spaCy
+         3. Predicate-split: subject/object from splitting the source text around the predicate
+
+         All go into the candidate pool; scoring/dedup picks the best.
+         """
          statements: list[Statement] = []
+         use_spacy_extraction = options.use_spacy_extraction if options else True

          try:
              root = ET.fromstring(xml_output)
          except ET.ParseError as e:
              # Log full output for debugging
-             logger.warning(f"Failed to parse XML output: {e}")
-             logger.warning(f"Full XML output ({len(xml_output)} chars):\n{xml_output}")
-             return statements
+             logger.debug(f"Initial XML parse failed: {e}")
+             logger.debug(f"Raw XML output ({len(xml_output)} chars):\n{xml_output}")
+
+             # Try to repair the XML
+             repaired_xml, repairs = repair_xml(xml_output)
+             if repairs:
+                 logger.debug(f"Attempted XML repairs: {', '.join(repairs)}")
+                 try:
+                     root = ET.fromstring(repaired_xml)
+                     logger.info("XML repair successful, parsing repaired output")
+                 except ET.ParseError as e2:
+                     logger.warning(f"XML repair failed, still cannot parse: {e2}")
+                     logger.warning(f"Repaired XML ({len(repaired_xml)} chars):\n{repaired_xml}")
+                     return statements
+             else:
+                 logger.warning("No repairs possible for XML output")
+                 logger.warning(f"Full XML output ({len(xml_output)} chars):\n{xml_output}")
+                 return statements

          if root.tag != 'statements':
+             logger.warning(f"Root tag is '{root.tag}', expected 'statements'")
              return statements

          for stmt_elem in root.findall('stmt'):
              try:
-                 # Parse subject
+                 # Parse subject from model
                  subject_elem = stmt_elem.find('subject')
                  subject_text = subject_elem.text.strip() if subject_elem is not None and subject_elem.text else ""
                  subject_type = self._parse_entity_type(subject_elem.get('type') if subject_elem is not None else None)

-                 # Parse object
+                 # Parse object from model
                  object_elem = stmt_elem.find('object')
                  object_text = object_elem.text.strip() if object_elem is not None and object_elem.text else ""
                  object_type = self._parse_entity_type(object_elem.get('type') if object_elem is not None else None)

-                 # Parse predicate
-                 predicate_elem = stmt_elem.find('predicate')
-                 predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
-
-                 # Parse source text
+                 # Parse source text from model
                  text_elem = stmt_elem.find('text')
                  source_text = text_elem.text.strip() if text_elem is not None and text_elem.text else None

-                 if subject_text and predicate and object_text:
-                     statements.append(Statement(
-                         subject=Entity(text=subject_text, type=subject_type),
-                         predicate=predicate,
-                         object=Entity(text=object_text, type=object_type),
-                         source_text=source_text,
-                     ))
+                 # Skip if missing required components from model
+                 if not subject_text or not object_text:
+                     logger.debug("Skipping statement: missing subject or object from model")
+                     continue
+
+                 if use_spacy_extraction and source_text:
+                     try:
+                         from .spacy_extraction import extract_triple_from_text, extract_triple_by_predicate_split
+                         spacy_result = extract_triple_from_text(
+                             source_text=source_text,
+                             model_subject=subject_text,
+                             model_object=object_text,
+                             model_predicate="",  # Don't pass model predicate
+                         )
+                         if spacy_result:
+                             spacy_subj, spacy_pred, spacy_obj = spacy_result
+
+                             if spacy_pred:
+                                 # Candidate 1: Hybrid (model subject/object + spaCy predicate)
+                                 logger.debug(
+                                     f"Adding hybrid candidate: '{subject_text}' --[{spacy_pred}]--> '{object_text}'"
+                                 )
+                                 statements.append(Statement(
+                                     subject=Entity(text=subject_text, type=subject_type),
+                                     predicate=spacy_pred,
+                                     object=Entity(text=object_text, type=object_type),
+                                     source_text=source_text,
+                                     extraction_method=ExtractionMethod.HYBRID,
+                                 ))
+
+                                 # Candidate 2: spaCy-only (if different from hybrid)
+                                 if spacy_subj and spacy_obj:
+                                     is_different = (spacy_subj != subject_text or spacy_obj != object_text)
+                                     if is_different:
+                                         logger.debug(
+                                             f"Adding spaCy-only candidate: '{spacy_subj}' --[{spacy_pred}]--> '{spacy_obj}'"
+                                         )
+                                         statements.append(Statement(
+                                             subject=Entity(text=spacy_subj, type=subject_type),
+                                             predicate=spacy_pred,
+                                             object=Entity(text=spacy_obj, type=object_type),
+                                             source_text=source_text,
+                                             extraction_method=ExtractionMethod.SPACY,
+                                         ))
+
+                                 # Candidate 3: Predicate-split (split source text around predicate)
+                                 split_result = extract_triple_by_predicate_split(
+                                     source_text=source_text,
+                                     predicate=spacy_pred,
+                                 )
+                                 if split_result:
+                                     split_subj, split_pred, split_obj = split_result
+                                     # Only add if different from previous candidates
+                                     is_different_from_hybrid = (split_subj != subject_text or split_obj != object_text)
+                                     is_different_from_spacy = (split_subj != spacy_subj or split_obj != spacy_obj)
+                                     if is_different_from_hybrid and is_different_from_spacy:
+                                         logger.debug(
+                                             f"Adding predicate-split candidate: '{split_subj}' --[{split_pred}]--> '{split_obj}'"
+                                         )
+                                         statements.append(Statement(
+                                             subject=Entity(text=split_subj, type=subject_type),
+                                             predicate=split_pred,
+                                             object=Entity(text=split_obj, type=object_type),
+                                             source_text=source_text,
+                                             extraction_method=ExtractionMethod.SPLIT,
+                                         ))
+                             else:
+                                 logger.debug(
+                                     f"spaCy found no predicate for: '{subject_text}' --> '{object_text}'"
+                                 )
+                     except Exception as e:
+                         logger.debug(f"spaCy extraction failed: {e}")
+                 else:
+                     # spaCy disabled - fall back to model predicate
+                     predicate_elem = stmt_elem.find('predicate')
+                     model_predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
+
+                     if model_predicate:
+                         statements.append(Statement(
+                             subject=Entity(text=subject_text, type=subject_type),
+                             predicate=model_predicate,
+                             object=Entity(text=object_text, type=object_type),
+                             source_text=source_text,
+                             extraction_method=ExtractionMethod.MODEL,
+                         ))
+                     else:
+                         logger.debug(
+                             f"Skipping statement (no predicate, spaCy disabled): "
+                             f"'{subject_text}' --> '{object_text}'"
+                         )
              except Exception as e:
                  logger.warning(f"Failed to parse statement: {e}")
                  continue
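For reference, the element names handled by the parser and by repair_xml (statements, stmt, subject, object, predicate, text, plus a type attribute) imply model output of roughly this shape. The values below, including the type labels, are placeholders rather than the model's actual output:

    <statements>
      <stmt>
        <subject type="ORG">Acme Corp</subject>
        <predicate>acquired</predicate>
        <object type="ORG">Widget Co</object>
        <text>Acme Corp acquired Widget Co in 2021.</text>
      </stmt>
    </statements>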
@@ -24,6 +24,14 @@ class EntityType(str, Enum):
      UNKNOWN = "UNKNOWN"


+ class ExtractionMethod(str, Enum):
+     """Method used to extract the triple components."""
+     HYBRID = "hybrid"  # Model subject/object + spaCy predicate
+     SPACY = "spacy"    # All components from spaCy dependency parsing
+     SPLIT = "split"    # Subject/object from splitting source text around predicate
+     MODEL = "model"    # All components from T5-Gemma model (when spaCy disabled)
+
+
  class Entity(BaseModel):
      """An entity (subject or object) with its text and type."""
      text: str = Field(..., description="The entity text")
@@ -52,12 +60,18 @@ class Statement(BaseModel):
      object: Entity = Field(..., description="The object entity")
      source_text: Optional[str] = Field(None, description="The original text this statement was extracted from")

+     # Extraction method tracking
+     extraction_method: ExtractionMethod = Field(
+         default=ExtractionMethod.MODEL,
+         description="Method used to extract this triple (hybrid, spacy, split, or model)"
+     )
+
      # Quality scoring fields
      confidence_score: Optional[float] = Field(
          None,
          ge=0.0,
          le=1.0,
-         description="Groundedness score (0-1) indicating how well the triple is supported by source text"
+         description="Semantic similarity score (0-1) between source text and reassembled triple"
      )
      evidence_span: Optional[tuple[int, int]] = Field(
          None,
@@ -99,6 +113,7 @@ class Statement(BaseModel):
              object=merged_object,
              predicate=self.predicate,
              source_text=self.source_text,
+             extraction_method=self.extraction_method,
              confidence_score=self.confidence_score,
              evidence_span=self.evidence_span,
              canonical_predicate=self.canonical_predicate,
@@ -116,6 +131,7 @@ class Statement(BaseModel):
              object=self.subject,
              predicate=self.predicate,
              source_text=self.source_text,
+             extraction_method=self.extraction_method,
              confidence_score=self.confidence_score,
              evidence_span=self.evidence_span,
              canonical_predicate=self.canonical_predicate,
@@ -279,6 +295,22 @@ class ExtractionOptions(BaseModel):
          default=True,
          description="Use embedding similarity for predicate deduplication"
      )
+     use_spacy_extraction: bool = Field(
+         default=True,
+         description="Use spaCy for predicate/subject/object extraction (model provides structure + coreference)"
+     )
+
+     # Verbose logging
+     verbose: bool = Field(
+         default=False,
+         description="Enable verbose logging for debugging"
+     )
+
+     # Triple selection
+     all_triples: bool = Field(
+         default=False,
+         description="Keep all candidate triples instead of selecting the highest-scoring one per source"
+     )

      class Config:
          arbitrary_types_allowed = True  # Allow Callable type
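A hedged end-to-end sketch of the new options; the class and field names are taken from this diff, while the top-level import path and the extract() method name are assumptions:

    from corp_extractor import StatementExtractor  # import path assumed
    from corp_extractor.models import ExtractionOptions

    extractor = StatementExtractor()

    # Keep every hybrid/spacy/split candidate instead of one best triple per source sentence
    options = ExtractionOptions(all_triples=True)

    result = extractor.extract("Acme Corp acquired Widget Co in 2021.", options=options)
    for stmt in result.statements:
        print(stmt.subject.text, stmt.predicate, stmt.object.text, stmt.extraction_method.value)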