corp-extractor 0.2.7 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,649 @@
+"""
+Statement Extractor - Extract structured statements from text using T5-Gemma 2.
+
+This module uses Diverse Beam Search (Vijayakumar et al., 2016) to generate
+multiple candidate extractions and selects/merges the best results using
+quality scoring.
+
+Paper: https://arxiv.org/abs/1610.02424
+"""
+
+import logging
+import re
+import xml.etree.ElementTree as ET
+from typing import Optional
+
+import torch
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+from .models import (
+    Entity,
+    EntityType,
+    ExtractionOptions,
+    ExtractionResult,
+    PredicateComparisonConfig,
+    PredicateTaxonomy,
+    ScoringConfig,
+    Statement,
+)
+
+logger = logging.getLogger(__name__)
+
+# Default model
+DEFAULT_MODEL_ID = "Corp-o-Rate-Community/statement-extractor"
+
+
+class StatementExtractor:
+    """
+    Extract structured statements from unstructured text.
+
+    Uses the T5-Gemma 2 statement extraction model with Diverse Beam Search
+    to generate high-quality subject-predicate-object triples.
+
+    Features:
+    - Quality-based beam scoring (not just longest output)
+    - Beam merging for better coverage
+    - Embedding-based predicate comparison for smart deduplication
+    - Configurable precision/recall tradeoff
+
+    Example:
+        >>> extractor = StatementExtractor()
+        >>> result = extractor.extract("Apple Inc. announced a new iPhone today.")
+        >>> for stmt in result:
+        ...     print(stmt)
+        Apple Inc. -- announced --> a new iPhone
+    """
+
+    def __init__(
+        self,
+        model_id: str = DEFAULT_MODEL_ID,
+        device: Optional[str] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+        predicate_taxonomy: Optional[PredicateTaxonomy] = None,
+        predicate_config: Optional[PredicateComparisonConfig] = None,
+        scoring_config: Optional[ScoringConfig] = None,
+    ):
+        """
+        Initialize the statement extractor.
+
+        Args:
+            model_id: HuggingFace model ID or local path
+            device: Device to use ('cuda', 'cpu', or None for auto-detect)
+            torch_dtype: Torch dtype (default: bfloat16 on GPU, float32 on CPU)
+            predicate_taxonomy: Optional taxonomy for predicate normalization
+            predicate_config: Configuration for predicate comparison
+            scoring_config: Configuration for quality scoring
+        """
+        self.model_id = model_id
+        self._model: Optional[AutoModelForSeq2SeqLM] = None
+        self._tokenizer: Optional[AutoTokenizer] = None
+
+        # Auto-detect device
+        if device is None:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
+
+        # Auto-detect dtype
+        if torch_dtype is None:
+            self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
+        else:
+            self.torch_dtype = torch_dtype
+
+        # Scoring and comparison config
+        self._predicate_taxonomy = predicate_taxonomy
+        self._predicate_config = predicate_config
+        self._scoring_config = scoring_config
+
+        # Lazy-loaded components
+        self._beam_scorer = None
+        self._predicate_comparer = None
+
+    def _load_model(self) -> None:
+        """Load model and tokenizer if not already loaded."""
+        if self._model is not None:
+            return
+
+        logger.info(f"Loading model: {self.model_id}")
+
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self.model_id,
+            trust_remote_code=True,
+        )
+
+        if self.device == "cuda":
+            self._model = AutoModelForSeq2SeqLM.from_pretrained(
+                self.model_id,
+                torch_dtype=self.torch_dtype,
+                trust_remote_code=True,
+                device_map="auto",
+            )
+        else:
+            self._model = AutoModelForSeq2SeqLM.from_pretrained(
+                self.model_id,
+                trust_remote_code=True,
+            )
+            self._model = self._model.to(self.device)
+
+        logger.info(f"Model loaded on {self.device}")
+
+    def _get_beam_scorer(self, options: ExtractionOptions):
+        """Get or create beam scorer with current config."""
+        from .scoring import BeamScorer
+
+        config = options.scoring_config or self._scoring_config or ScoringConfig()
+        return BeamScorer(config=config)
+
+    def _get_predicate_comparer(self, options: ExtractionOptions):
+        """Get or create predicate comparer if embeddings enabled."""
+        if not options.embedding_dedup:
+            return None
+
+        from .predicate_comparer import PredicateComparer
+
+        taxonomy = options.predicate_taxonomy or self._predicate_taxonomy
+        config = options.predicate_config or self._predicate_config or PredicateComparisonConfig()
+        return PredicateComparer(taxonomy=taxonomy, config=config, device=self.device)
+
+    @property
+    def model(self) -> AutoModelForSeq2SeqLM:
+        """Get the model, loading it if necessary."""
+        self._load_model()
+        return self._model
+
+    @property
+    def tokenizer(self) -> AutoTokenizer:
+        """Get the tokenizer, loading it if necessary."""
+        self._load_model()
+        return self._tokenizer
+
+    def extract(
+        self,
+        text: str,
+        options: Optional[ExtractionOptions] = None,
+    ) -> ExtractionResult:
+        """
+        Extract statements from text.
+
+        Args:
+            text: Input text to extract statements from
+            options: Extraction options (uses defaults if not provided)
+
+        Returns:
+            ExtractionResult containing the extracted statements
+        """
+        if options is None:
+            options = ExtractionOptions()
+
+        # Store original text for scoring
+        original_text = text
+
+        # Wrap text in page tags if not already wrapped
+        if not text.startswith("<page>"):
+            text = f"<page>{text}</page>"
+
+        # Run extraction with retry logic
+        statements = self._extract_with_scoring(text, original_text, options)
+
+        return ExtractionResult(
+            statements=statements,
+            source_text=original_text,
+        )
+
+    def extract_as_xml(
+        self,
+        text: str,
+        options: Optional[ExtractionOptions] = None,
+    ) -> str:
+        """
+        Extract statements and return raw XML output.
+
+        Note: This bypasses the new scoring/merging logic for backward compatibility.
+        Use extract() for full quality scoring.
+
+        Args:
+            text: Input text to extract statements from
+            options: Extraction options
+
+        Returns:
+            XML string with <statements> containing <stmt> elements
+        """
+        if options is None:
+            options = ExtractionOptions()
+
+        if not text.startswith("<page>"):
+            text = f"<page>{text}</page>"
+
+        return self._extract_raw_xml(text, options)
+
+    def extract_as_json(
+        self,
+        text: str,
+        options: Optional[ExtractionOptions] = None,
+        indent: Optional[int] = 2,
+    ) -> str:
+        """
+        Extract statements and return JSON string.
+
+        Args:
+            text: Input text to extract statements from
+            options: Extraction options
+            indent: JSON indentation (None for compact)
+
+        Returns:
+            JSON string representation of the extraction result
+        """
+        result = self.extract(text, options)
+        return result.model_dump_json(indent=indent)
+
+    def extract_as_dict(
+        self,
+        text: str,
+        options: Optional[ExtractionOptions] = None,
+    ) -> dict:
+        """
+        Extract statements and return as dictionary.
+
+        Args:
+            text: Input text to extract statements from
+            options: Extraction options
+
+        Returns:
+            Dictionary representation of the extraction result
+        """
+        result = self.extract(text, options)
+        return result.model_dump()
+
+    def _extract_with_scoring(
+        self,
+        text: str,
+        original_text: str,
+        options: ExtractionOptions,
+    ) -> list[Statement]:
+        """
+        Extract statements with quality scoring and beam merging.
+
+        This is the new extraction pipeline that:
+        1. Generates multiple candidates via DBS
+        2. Parses each to statements
+        3. Scores each triple for groundedness
+        4. Merges top beams or selects best beam
+        5. Deduplicates using embeddings (if enabled)
+        """
+        # Tokenize input
+        inputs = self.tokenizer(
+            text,
+            return_tensors="pt",
+            max_length=4096,
+            truncation=True,
+        ).to(self.device)
+
+        # Count sentences for quality check
+        num_sentences = self._count_sentences(text)
+        min_expected = int(num_sentences * options.min_statement_ratio)
+
+        logger.info(f"Input has ~{num_sentences} sentences, expecting >= {min_expected} statements")
+
+        # Get beam scorer
+        beam_scorer = self._get_beam_scorer(options)
+
+        all_candidates: list[list[Statement]] = []
+
+        for attempt in range(options.max_attempts):
+            # Generate candidate beams
+            candidates = self._generate_candidate_beams(inputs, options)
+
+            # Parse each candidate to statements
+            parsed_candidates = []
+            for xml_output in candidates:
+                statements = self._parse_xml_to_statements(xml_output)
+                if statements:
+                    parsed_candidates.append(statements)
+
+            all_candidates.extend(parsed_candidates)
+
+            # Check if we have enough statements
+            total_stmts = sum(len(c) for c in parsed_candidates)
+            logger.info(f"Attempt {attempt + 1}/{options.max_attempts}: {len(parsed_candidates)} beams, {total_stmts} total statements")
+
+            if total_stmts >= min_expected:
+                break
+
+        if not all_candidates:
+            return []
+
+        # Select or merge beams
+        if options.merge_beams:
+            statements = beam_scorer.merge_beams(all_candidates, original_text)
+        else:
+            statements = beam_scorer.select_best_beam(all_candidates, original_text)
+
+        # Apply embedding-based deduplication if enabled
+        if options.embedding_dedup and options.deduplicate:
+            try:
+                comparer = self._get_predicate_comparer(options)
+                if comparer:
+                    statements = comparer.deduplicate_statements(
+                        statements,
+                        entity_canonicalizer=options.entity_canonicalizer
+                    )
+                    # Also normalize predicates if taxonomy provided
+                    if options.predicate_taxonomy or self._predicate_taxonomy:
+                        statements = comparer.normalize_predicates(statements)
+            except Exception as e:
+                logger.warning(f"Embedding deduplication failed, falling back to exact match: {e}")
+                statements = self._deduplicate_statements_exact(statements, options)
+        elif options.deduplicate:
+            statements = self._deduplicate_statements_exact(statements, options)
+
+        return statements
+
+    def _generate_candidate_beams(
+        self,
+        inputs,
+        options: ExtractionOptions,
+    ) -> list[str]:
+        """Generate multiple candidate beams using diverse beam search."""
+        num_seqs = options.num_beams
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=options.max_new_tokens,
+                num_beams=num_seqs,
+                num_beam_groups=num_seqs,
+                num_return_sequences=num_seqs,
+                diversity_penalty=options.diversity_penalty,
+                do_sample=False,
+                trust_remote_code=True,
+            )
+
+        # Decode and process candidates
+        end_tag = "</statements>"
+        candidates: list[str] = []
+
+        for i, output in enumerate(outputs):
+            decoded = self.tokenizer.decode(output, skip_special_tokens=True)
+            output_len = len(output)
+
+            # Truncate at </statements>
+            if end_tag in decoded:
+                end_pos = decoded.find(end_tag) + len(end_tag)
+                decoded = decoded[:end_pos]
+                candidates.append(decoded)
+                logger.debug(f"Beam {i}: {output_len} tokens, found end tag, {len(decoded)} chars")
+            else:
+                # Log the issue - likely truncated
+                logger.warning(f"Beam {i}: {output_len} tokens, NO end tag found (truncated?)")
+                logger.warning(f"Beam {i} full output ({len(decoded)} chars):\n{decoded}")
+
+        # Include fallback if no valid candidates
+        if not candidates and len(outputs) > 0:
+            fallback = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            logger.warning(f"Using fallback beam (no valid candidates found), {len(fallback)} chars")
+            candidates.append(fallback)
+
+        return candidates
+
+    def _extract_raw_xml(
+        self,
+        text: str,
+        options: ExtractionOptions,
+    ) -> str:
+        """
+        Extract and return raw XML (legacy method for backward compatibility).
+
+        Uses length-based selection like the original implementation.
+        """
+        # Tokenize input
+        inputs = self.tokenizer(
+            text,
+            return_tensors="pt",
+            max_length=4096,
+            truncation=True,
+        ).to(self.device)
+
+        num_sentences = self._count_sentences(text)
+        min_expected = int(num_sentences * options.min_statement_ratio)
+
+        all_results: list[tuple[str, int]] = []
+
+        for attempt in range(options.max_attempts):
+            candidates = self._generate_candidate_beams(inputs, options)
+
+            for candidate in candidates:
+                if options.deduplicate:
+                    candidate = self._deduplicate_xml(candidate)
+                num_stmts = self._count_statements(candidate)
+                all_results.append((candidate, num_stmts))
+
+            best_so_far = max(all_results, key=lambda x: x[1])[1] if all_results else 0
+            if best_so_far >= min_expected:
+                break
+
+        if not all_results:
+            return "<statements></statements>"
+
+        # Select best result (longest, for backward compatibility)
+        return max(all_results, key=lambda x: len(x[0]))[0]
+
+    def _deduplicate_statements_exact(
+        self,
+        statements: list[Statement],
+        options: ExtractionOptions,
+    ) -> list[Statement]:
+        """Deduplicate statements using exact text matching."""
+        from .canonicalization import deduplicate_statements_exact
+        return deduplicate_statements_exact(
+            statements,
+            entity_canonicalizer=options.entity_canonicalizer
+        )
+
+    def _deduplicate_xml(self, xml_output: str) -> str:
+        """Remove duplicate <stmt> blocks from XML output (legacy method)."""
+        try:
+            root = ET.fromstring(xml_output)
+        except ET.ParseError:
+            return xml_output
+
+        if root.tag != 'statements':
+            return xml_output
+
+        seen: set[tuple[str, str, str]] = set()
+        unique_stmts: list[ET.Element] = []
+
+        for stmt in root.findall('stmt'):
+            subject = stmt.findtext('subject', '').strip().lower()
+            predicate = stmt.findtext('predicate', '').strip().lower()
+            obj = stmt.findtext('object', '').strip().lower()
+            key = (subject, predicate, obj)
+
+            if key not in seen:
+                seen.add(key)
+                unique_stmts.append(stmt)
+
+        new_root = ET.Element('statements')
+        for stmt in unique_stmts:
+            new_root.append(stmt)
+
+        return ET.tostring(new_root, encoding='unicode')
+
+    def _parse_xml_to_statements(self, xml_output: str) -> list[Statement]:
+        """Parse XML output into Statement objects."""
+        statements: list[Statement] = []
+
+        try:
+            root = ET.fromstring(xml_output)
+        except ET.ParseError as e:
+            # Log full output for debugging
+            logger.warning(f"Failed to parse XML output: {e}")
+            logger.warning(f"Full XML output ({len(xml_output)} chars):\n{xml_output}")
+            return statements
+
+        if root.tag != 'statements':
+            return statements
+
+        for stmt_elem in root.findall('stmt'):
+            try:
+                # Parse subject
+                subject_elem = stmt_elem.find('subject')
+                subject_text = subject_elem.text.strip() if subject_elem is not None and subject_elem.text else ""
+                subject_type = self._parse_entity_type(subject_elem.get('type') if subject_elem is not None else None)
+
+                # Parse object
+                object_elem = stmt_elem.find('object')
+                object_text = object_elem.text.strip() if object_elem is not None and object_elem.text else ""
+                object_type = self._parse_entity_type(object_elem.get('type') if object_elem is not None else None)
+
+                # Parse predicate
+                predicate_elem = stmt_elem.find('predicate')
+                predicate = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
+
+                # Parse source text
+                text_elem = stmt_elem.find('text')
+                source_text = text_elem.text.strip() if text_elem is not None and text_elem.text else None
+
+                if subject_text and predicate and object_text:
+                    statements.append(Statement(
+                        subject=Entity(text=subject_text, type=subject_type),
+                        predicate=predicate,
+                        object=Entity(text=object_text, type=object_type),
+                        source_text=source_text,
+                    ))
+            except Exception as e:
+                logger.warning(f"Failed to parse statement: {e}")
+                continue
+
+        return statements
+
+    def _parse_entity_type(self, type_str: Optional[str]) -> EntityType:
+        """Parse entity type string to EntityType enum."""
+        if type_str is None:
+            return EntityType.UNKNOWN
+        try:
+            return EntityType(type_str.upper())
+        except ValueError:
+            return EntityType.UNKNOWN
+
+    @staticmethod
+    def _count_sentences(text: str) -> int:
+        """Count approximate number of sentences in text."""
+        clean_text = re.sub(r'<[^>]+>', '', text)
+        sentences = re.split(r'[.!?]+', clean_text)
+        sentences = [s.strip() for s in sentences if len(s.strip()) > 10]
+        return max(1, len(sentences))
+
+    @staticmethod
+    def _count_statements(xml_output: str) -> int:
+        """Count number of <stmt> tags in output."""
+        return len(re.findall(r'<stmt>', xml_output))
+
+
+# Convenience functions for simple usage
+
+_default_extractor: Optional[StatementExtractor] = None
+
+
+def _get_default_extractor() -> StatementExtractor:
+    """Get or create the default extractor instance."""
+    global _default_extractor
+    if _default_extractor is None:
+        _default_extractor = StatementExtractor()
+    return _default_extractor
+
+
+def extract_statements(
+    text: str,
+    options: Optional[ExtractionOptions] = None,
+    **kwargs,
+) -> ExtractionResult:
+    """
+    Extract structured statements from text.
+
+    This is a convenience function that uses a default StatementExtractor instance.
+    For more control, create your own StatementExtractor.
+
+    By default, uses embedding-based deduplication and beam merging for
+    high-quality extraction. Requires sentence-transformers package.
+
+    Args:
+        text: Input text to extract statements from
+        options: Extraction options (or pass individual options as kwargs)
+        **kwargs: Individual option overrides (num_beams, diversity_penalty, etc.)
+
+    Returns:
+        ExtractionResult containing Statement objects
+
+    Example:
+        >>> result = extract_statements("Apple announced a new product.")
+        >>> for stmt in result:
+        ...     print(f"{stmt.subject.text} -> {stmt.predicate} -> {stmt.object.text}")
+    """
+    if options is None and kwargs:
+        options = ExtractionOptions(**kwargs)
+    return _get_default_extractor().extract(text, options)
+
+
+def extract_statements_as_xml(
+    text: str,
+    options: Optional[ExtractionOptions] = None,
+    **kwargs,
+) -> str:
+    """
+    Extract statements and return raw XML output.
+
+    Args:
+        text: Input text to extract statements from
+        options: Extraction options
+        **kwargs: Individual option overrides
+
+    Returns:
+        XML string with <statements> containing <stmt> elements
+    """
+    if options is None and kwargs:
+        options = ExtractionOptions(**kwargs)
+    return _get_default_extractor().extract_as_xml(text, options)
+
+
+def extract_statements_as_json(
+    text: str,
+    options: Optional[ExtractionOptions] = None,
+    indent: Optional[int] = 2,
+    **kwargs,
+) -> str:
+    """
+    Extract statements and return JSON string.
+
+    Args:
+        text: Input text to extract statements from
+        options: Extraction options
+        indent: JSON indentation (None for compact)
+        **kwargs: Individual option overrides
+
+    Returns:
+        JSON string representation of the extraction result
+    """
+    if options is None and kwargs:
+        options = ExtractionOptions(**kwargs)
+    return _get_default_extractor().extract_as_json(text, options, indent)
+
+
+def extract_statements_as_dict(
+    text: str,
+    options: Optional[ExtractionOptions] = None,
+    **kwargs,
+) -> dict:
+    """
+    Extract statements and return as dictionary.
+
+    Args:
+        text: Input text to extract statements from
+        options: Extraction options
+        **kwargs: Individual option overrides
+
+    Returns:
+        Dictionary representation of the extraction result
+    """
+    if options is None and kwargs:
+        options = ExtractionOptions(**kwargs)
+    return _get_default_extractor().extract_as_dict(text, options)
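For orientation, here is a minimal usage sketch based on the API shown in this diff. The top-level import path is not visible in the file above and is assumed to be corp_extractor; the keyword overrides (num_beams, diversity_penalty) and the Statement fields come from the module's own docstrings, and the values used here are only illustrative.

# Hypothetical usage sketch; the import path is an assumption, not shown in this diff.
from corp_extractor import extract_statements

# Per the extract_statements docstring, kwargs are forwarded to ExtractionOptions.
result = extract_statements(
    "Apple Inc. announced a new iPhone today.",
    num_beams=4,            # example value: number of diverse beams/groups to generate
    diversity_penalty=1.0,  # example value: penalty encouraging beams to differ
)

for stmt in result:
    # Statement objects expose subject/predicate/object, per _parse_xml_to_statements.
    print(f"{stmt.subject.text} -> {stmt.predicate} -> {stmt.object.text}")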