corp-extractor 0.4.0-py3-none-any.whl → 0.9.0-py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/METADATA +348 -64
  2. corp_extractor-0.9.0.dist-info/RECORD +76 -0
  3. statement_extractor/__init__.py +10 -1
  4. statement_extractor/cli.py +1663 -17
  5. statement_extractor/data/default_predicates.json +368 -0
  6. statement_extractor/data/statement_taxonomy.json +6972 -0
  7. statement_extractor/database/__init__.py +52 -0
  8. statement_extractor/database/embeddings.py +186 -0
  9. statement_extractor/database/hub.py +520 -0
  10. statement_extractor/database/importers/__init__.py +24 -0
  11. statement_extractor/database/importers/companies_house.py +545 -0
  12. statement_extractor/database/importers/gleif.py +538 -0
  13. statement_extractor/database/importers/sec_edgar.py +375 -0
  14. statement_extractor/database/importers/wikidata.py +1012 -0
  15. statement_extractor/database/importers/wikidata_people.py +632 -0
  16. statement_extractor/database/models.py +230 -0
  17. statement_extractor/database/resolver.py +245 -0
  18. statement_extractor/database/store.py +1609 -0
  19. statement_extractor/document/__init__.py +62 -0
  20. statement_extractor/document/chunker.py +410 -0
  21. statement_extractor/document/context.py +171 -0
  22. statement_extractor/document/deduplicator.py +173 -0
  23. statement_extractor/document/html_extractor.py +246 -0
  24. statement_extractor/document/loader.py +303 -0
  25. statement_extractor/document/pipeline.py +388 -0
  26. statement_extractor/document/summarizer.py +195 -0
  27. statement_extractor/extractor.py +1 -23
  28. statement_extractor/gliner_extraction.py +4 -74
  29. statement_extractor/llm.py +255 -0
  30. statement_extractor/models/__init__.py +89 -0
  31. statement_extractor/models/canonical.py +182 -0
  32. statement_extractor/models/document.py +308 -0
  33. statement_extractor/models/entity.py +102 -0
  34. statement_extractor/models/labels.py +220 -0
  35. statement_extractor/models/qualifiers.py +139 -0
  36. statement_extractor/models/statement.py +101 -0
  37. statement_extractor/models.py +4 -1
  38. statement_extractor/pipeline/__init__.py +39 -0
  39. statement_extractor/pipeline/config.py +129 -0
  40. statement_extractor/pipeline/context.py +177 -0
  41. statement_extractor/pipeline/orchestrator.py +416 -0
  42. statement_extractor/pipeline/registry.py +303 -0
  43. statement_extractor/plugins/__init__.py +55 -0
  44. statement_extractor/plugins/base.py +716 -0
  45. statement_extractor/plugins/extractors/__init__.py +13 -0
  46. statement_extractor/plugins/extractors/base.py +9 -0
  47. statement_extractor/plugins/extractors/gliner2.py +546 -0
  48. statement_extractor/plugins/labelers/__init__.py +29 -0
  49. statement_extractor/plugins/labelers/base.py +9 -0
  50. statement_extractor/plugins/labelers/confidence.py +138 -0
  51. statement_extractor/plugins/labelers/relation_type.py +87 -0
  52. statement_extractor/plugins/labelers/sentiment.py +159 -0
  53. statement_extractor/plugins/labelers/taxonomy.py +386 -0
  54. statement_extractor/plugins/labelers/taxonomy_embedding.py +477 -0
  55. statement_extractor/plugins/pdf/__init__.py +10 -0
  56. statement_extractor/plugins/pdf/pypdf.py +291 -0
  57. statement_extractor/plugins/qualifiers/__init__.py +30 -0
  58. statement_extractor/plugins/qualifiers/base.py +9 -0
  59. statement_extractor/plugins/qualifiers/companies_house.py +185 -0
  60. statement_extractor/plugins/qualifiers/embedding_company.py +420 -0
  61. statement_extractor/plugins/qualifiers/gleif.py +197 -0
  62. statement_extractor/plugins/qualifiers/person.py +785 -0
  63. statement_extractor/plugins/qualifiers/sec_edgar.py +209 -0
  64. statement_extractor/plugins/scrapers/__init__.py +10 -0
  65. statement_extractor/plugins/scrapers/http.py +236 -0
  66. statement_extractor/plugins/splitters/__init__.py +13 -0
  67. statement_extractor/plugins/splitters/base.py +9 -0
  68. statement_extractor/plugins/splitters/t5_gemma.py +293 -0
  69. statement_extractor/plugins/taxonomy/__init__.py +13 -0
  70. statement_extractor/plugins/taxonomy/embedding.py +484 -0
  71. statement_extractor/plugins/taxonomy/mnli.py +291 -0
  72. statement_extractor/scoring.py +8 -8
  73. corp_extractor-0.4.0.dist-info/RECORD +0 -12
  74. {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/WHEEL +0 -0
  75. {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,293 @@
+ """
+ T5GemmaSplitter - Stage 1 plugin that wraps the existing StatementExtractor.
+
+ Uses T5-Gemma2 model with Diverse Beam Search to generate high-quality
+ subject-predicate-object triples from text.
+ """
+
+ import logging
+ import re
+ from typing import Optional
+
+ from ..base import BaseSplitterPlugin, PluginCapability
+ from ...pipeline.context import PipelineContext
+ from ...pipeline.registry import PluginRegistry
+ from ...models import RawTriple
+
+ logger = logging.getLogger(__name__)
+
+
+ @PluginRegistry.splitter
+ class T5GemmaSplitter(BaseSplitterPlugin):
+     """
+     Splitter plugin that uses T5-Gemma2 for triple extraction.
+
+     Wraps the existing StatementExtractor from extractor.py to produce
+     RawTriple objects for the pipeline.
+     """
+
+     def __init__(
+         self,
+         model_id: Optional[str] = None,
+         device: Optional[str] = None,
+         num_beams: int = 4,
+         diversity_penalty: float = 1.0,
+         max_new_tokens: int = 2048,
+     ):
+         """
+         Initialize the T5Gemma splitter.
+
+         Args:
+             model_id: HuggingFace model ID (defaults to Corp-o-Rate model)
+             device: Device to use (auto-detected if not specified)
+             num_beams: Number of beams for diverse beam search
+             diversity_penalty: Penalty for beam diversity
+             max_new_tokens: Maximum tokens to generate
+         """
+         self._model_id = model_id
+         self._device = device
+         self._num_beams = num_beams
+         self._diversity_penalty = diversity_penalty
+         self._max_new_tokens = max_new_tokens
+         self._extractor = None
+
+     @property
+     def name(self) -> str:
+         return "t5_gemma_splitter"
+
+     @property
+     def priority(self) -> int:
+         return 10  # High priority - primary splitter
+
+     @property
+     def capabilities(self) -> PluginCapability:
+         return PluginCapability.LLM_REQUIRED | PluginCapability.BATCH_PROCESSING
+
+     @property
+     def description(self) -> str:
+         return "T5-Gemma2 model for extracting triples using Diverse Beam Search"
+
+     @property
+     def model_vram_gb(self) -> float:
+         """T5-Gemma2 model weights ~2GB in bfloat16."""
+         return 2.0
+
+     @property
+     def per_item_vram_gb(self) -> float:
+         """Each text item during batch processing ~0.5GB for KV cache and activations."""
+         return 0.5
+
+     def _get_extractor(self):
+         """Lazy-load the StatementExtractor."""
+         if self._extractor is None:
+             from ...extractor import StatementExtractor
+             # Only pass model_id and device if they were explicitly set
+             kwargs = {}
+             if self._model_id is not None:
+                 kwargs["model_id"] = self._model_id
+             if self._device is not None:
+                 kwargs["device"] = self._device
+             self._extractor = StatementExtractor(**kwargs)
+         return self._extractor
+
+     def split(
+         self,
+         text: str,
+         context: PipelineContext,
+     ) -> list[RawTriple]:
+         """
+         Split text into raw triples using T5-Gemma2.
+
+         Args:
+             text: Input text to split
+             context: Pipeline context
+
+         Returns:
+             List of RawTriple objects
+         """
+         logger.debug(f"T5GemmaSplitter processing {len(text)} chars")
+
+         # Get options from context if available
+         splitter_options = context.source_metadata.get("splitter_options", {})
+         num_beams = splitter_options.get("num_beams", self._num_beams)
+         diversity_penalty = splitter_options.get("diversity_penalty", self._diversity_penalty)
+         max_new_tokens = splitter_options.get("max_new_tokens", self._max_new_tokens)
+
+         # Create extraction options
+         from ...models import ExtractionOptions as LegacyExtractionOptions
+         options = LegacyExtractionOptions(
+             num_beams=num_beams,
+             diversity_penalty=diversity_penalty,
+             max_new_tokens=max_new_tokens,
+             # Disable GLiNER and dedup - we handle those in later stages
+             use_gliner_extraction=False,
+             embedding_dedup=False,
+             deduplicate=False,
+         )
+
+         # Get raw XML from extractor
+         extractor = self._get_extractor()
+         xml_output = extractor.extract_as_xml(text, options)
+
+         # Parse XML to RawTriple objects
+         raw_triples = self._parse_xml_to_raw_triples(xml_output)
+
+         logger.info(f"T5GemmaSplitter produced {len(raw_triples)} raw triples")
+         return raw_triples
+
+     def split_batch(
+         self,
+         texts: list[str],
+         context: PipelineContext,
+     ) -> list[list[RawTriple]]:
+         """
+         Split multiple texts into raw triples using batch processing.
+
+         Processes all texts through the T5-Gemma2 model in batches
+         sized for optimal GPU utilization.
+
+         Args:
+             texts: List of input texts to split
+             context: Pipeline context
+
+         Returns:
+             List of RawTriple lists, one per input text
+         """
+         if not texts:
+             return []
+
+         batch_size = self.get_optimal_batch_size()
+         logger.info(f"T5GemmaSplitter batch processing {len(texts)} texts with batch_size={batch_size}")
+
+         # Get options from context
+         splitter_options = context.source_metadata.get("splitter_options", {})
+         num_beams = splitter_options.get("num_beams", self._num_beams)
+         diversity_penalty = splitter_options.get("diversity_penalty", self._diversity_penalty)
+         max_new_tokens = splitter_options.get("max_new_tokens", self._max_new_tokens)
+
+         # Create extraction options
+         from ...models import ExtractionOptions as LegacyExtractionOptions
+         options = LegacyExtractionOptions(
+             num_beams=num_beams,
+             diversity_penalty=diversity_penalty,
+             max_new_tokens=max_new_tokens,
+             use_gliner_extraction=False,
+             embedding_dedup=False,
+             deduplicate=False,
+         )
+
+         extractor = self._get_extractor()
+         all_results: list[list[RawTriple]] = []
+
+         # Process in batches
+         for i in range(0, len(texts), batch_size):
+             batch_texts = texts[i:i + batch_size]
+             logger.debug(f"Processing batch {i // batch_size + 1}: {len(batch_texts)} texts")
+
+             batch_results = self._process_batch(batch_texts, extractor, options)
+             all_results.extend(batch_results)
+
+         total_triples = sum(len(r) for r in all_results)
+         logger.info(f"T5GemmaSplitter batch produced {total_triples} total triples from {len(texts)} texts")
+         return all_results
+
+     def _process_batch(
+         self,
+         texts: list[str],
+         extractor,
+         options,
+     ) -> list[list[RawTriple]]:
+         """
+         Process a batch of texts through the model.
+
+         Uses the model's batch generation capability for efficient GPU utilization.
+         """
+         import torch
+
+         # Wrap texts in page tags
+         wrapped_texts = [f"<page>{t}</page>" if not t.startswith("<page>") else t for t in texts]
+
+         # Tokenize batch
+         tokenizer = extractor.tokenizer
+         model = extractor.model
+
+         inputs = tokenizer(
+             wrapped_texts,
+             return_tensors="pt",
+             max_length=4096,
+             truncation=True,
+             padding=True,
+         ).to(extractor.device)
+
+         # Create stopping criteria
+         from ...extractor import StopOnSequence
+         from transformers import StoppingCriteriaList
+
+         input_length = inputs["input_ids"].shape[1]
+         stop_criteria = StopOnSequence(
+             tokenizer=tokenizer,
+             stop_sequence="</statements>",
+             input_length=input_length,
+         )
+
+         # Generate for all texts in batch
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=options.max_new_tokens,
+                 max_length=None,
+                 num_beams=options.num_beams,
+                 num_beam_groups=options.num_beams,
+                 num_return_sequences=1,  # One sequence per input for batch
+                 diversity_penalty=options.diversity_penalty,
+                 do_sample=False,
+                 top_p=None,
+                 top_k=None,
+                 trust_remote_code=True,
+                 custom_generate="transformers-community/group-beam-search",
+                 stopping_criteria=StoppingCriteriaList([stop_criteria]),
+             )
+
+         # Decode and parse each output
+         results: list[list[RawTriple]] = []
+         end_tag = "</statements>"
+
+         for output in outputs:
+             decoded = tokenizer.decode(output, skip_special_tokens=True)
+
+             # Truncate at </statements>
+             if end_tag in decoded:
+                 end_pos = decoded.find(end_tag) + len(end_tag)
+                 decoded = decoded[:end_pos]
+
+             triples = self._parse_xml_to_raw_triples(decoded)
+             results.append(triples)
+
+         return results
+
+     # Regex pattern to extract <text> content from <stmt> blocks
+     _STMT_TEXT_PATTERN = re.compile(r'<stmt>.*?<text>(.*?)</text>.*?</stmt>', re.DOTALL)
+
+     def _parse_xml_to_raw_triples(self, xml_output: str) -> list[RawTriple]:
+         """Extract source sentences from <stmt><text>...</text></stmt> blocks."""
+         raw_triples = []
+
+         # Find all <text> content within <stmt> blocks
+         text_matches = self._STMT_TEXT_PATTERN.findall(xml_output)
+         logger.debug(f"Found {len(text_matches)} stmt text blocks via regex")
+
+         for source_text in text_matches:
+             source_text = source_text.strip()
+             if source_text:
+                 raw_triples.append(RawTriple(
+                     subject_text="",
+                     predicate_text="",
+                     object_text="",
+                     source_sentence=source_text,
+                 ))
+
+         return raw_triples
+
+
+ # Allow importing without decorator for testing
+ T5GemmaSplitterClass = T5GemmaSplitter
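
The hunk above is the complete Stage 1 splitter module. A minimal usage sketch follows; it is an illustration, not code shipped in the package. Constructing PipelineContext with only a source_metadata mapping is an assumption inferred from how the plugin reads splitter_options (the released constructor signature is not visible in this diff), and the private helper _parse_xml_to_raw_triples is called directly purely to show the XML shape the splitter consumes.

    # Hypothetical driver for T5GemmaSplitter; PipelineContext construction
    # is assumed, not confirmed by this diff.
    from statement_extractor.pipeline.context import PipelineContext
    from statement_extractor.plugins.splitters.t5_gemma import T5GemmaSplitter

    splitter = T5GemmaSplitter(num_beams=4, diversity_penalty=1.0)

    # Per-call options in source_metadata["splitter_options"] override the
    # constructor defaults (see the .get(..., self._num_beams) fallbacks above).
    context = PipelineContext(source_metadata={"splitter_options": {"num_beams": 6}})
    triples = splitter.split("Acme Corp acquired Widget Inc. in 2021.", context)

    # The splitter only fills source_sentence; subject/predicate/object are
    # left empty for later pipeline stages to populate.
    xml = "<statements><stmt><text>Acme Corp acquired Widget Inc.</text></stmt></statements>"
    print(splitter._parse_xml_to_raw_triples(xml)[0].source_sentence)

Note that constructing the splitter is cheap: the T5-Gemma2 model is only loaded on the first call to split() or split_batch(), via the lazy _get_extractor() shown above.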
@@ -0,0 +1,13 @@
+ """
+ Taxonomy classifier plugins for Stage 6 (Taxonomy).
+
+ Classifies statements against large taxonomies using MNLI or embeddings.
+ """
+
+ from .mnli import MNLITaxonomyClassifier
+ from .embedding import EmbeddingTaxonomyClassifier
+
+ __all__ = [
+     "MNLITaxonomyClassifier",
+     "EmbeddingTaxonomyClassifier",
+ ]
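
Because this __init__.py re-exports both classifiers, either Stage 6 implementation can be imported from the subpackage directly, exactly as declared in __all__ above:

    from statement_extractor.plugins.taxonomy import (
        EmbeddingTaxonomyClassifier,
        MNLITaxonomyClassifier,
    )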