corp-extractor 0.4.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/METADATA +348 -64
- corp_extractor-0.9.0.dist-info/RECORD +76 -0
- statement_extractor/__init__.py +10 -1
- statement_extractor/cli.py +1663 -17
- statement_extractor/data/default_predicates.json +368 -0
- statement_extractor/data/statement_taxonomy.json +6972 -0
- statement_extractor/database/__init__.py +52 -0
- statement_extractor/database/embeddings.py +186 -0
- statement_extractor/database/hub.py +520 -0
- statement_extractor/database/importers/__init__.py +24 -0
- statement_extractor/database/importers/companies_house.py +545 -0
- statement_extractor/database/importers/gleif.py +538 -0
- statement_extractor/database/importers/sec_edgar.py +375 -0
- statement_extractor/database/importers/wikidata.py +1012 -0
- statement_extractor/database/importers/wikidata_people.py +632 -0
- statement_extractor/database/models.py +230 -0
- statement_extractor/database/resolver.py +245 -0
- statement_extractor/database/store.py +1609 -0
- statement_extractor/document/__init__.py +62 -0
- statement_extractor/document/chunker.py +410 -0
- statement_extractor/document/context.py +171 -0
- statement_extractor/document/deduplicator.py +173 -0
- statement_extractor/document/html_extractor.py +246 -0
- statement_extractor/document/loader.py +303 -0
- statement_extractor/document/pipeline.py +388 -0
- statement_extractor/document/summarizer.py +195 -0
- statement_extractor/extractor.py +1 -23
- statement_extractor/gliner_extraction.py +4 -74
- statement_extractor/llm.py +255 -0
- statement_extractor/models/__init__.py +89 -0
- statement_extractor/models/canonical.py +182 -0
- statement_extractor/models/document.py +308 -0
- statement_extractor/models/entity.py +102 -0
- statement_extractor/models/labels.py +220 -0
- statement_extractor/models/qualifiers.py +139 -0
- statement_extractor/models/statement.py +101 -0
- statement_extractor/models.py +4 -1
- statement_extractor/pipeline/__init__.py +39 -0
- statement_extractor/pipeline/config.py +129 -0
- statement_extractor/pipeline/context.py +177 -0
- statement_extractor/pipeline/orchestrator.py +416 -0
- statement_extractor/pipeline/registry.py +303 -0
- statement_extractor/plugins/__init__.py +55 -0
- statement_extractor/plugins/base.py +716 -0
- statement_extractor/plugins/extractors/__init__.py +13 -0
- statement_extractor/plugins/extractors/base.py +9 -0
- statement_extractor/plugins/extractors/gliner2.py +546 -0
- statement_extractor/plugins/labelers/__init__.py +29 -0
- statement_extractor/plugins/labelers/base.py +9 -0
- statement_extractor/plugins/labelers/confidence.py +138 -0
- statement_extractor/plugins/labelers/relation_type.py +87 -0
- statement_extractor/plugins/labelers/sentiment.py +159 -0
- statement_extractor/plugins/labelers/taxonomy.py +386 -0
- statement_extractor/plugins/labelers/taxonomy_embedding.py +477 -0
- statement_extractor/plugins/pdf/__init__.py +10 -0
- statement_extractor/plugins/pdf/pypdf.py +291 -0
- statement_extractor/plugins/qualifiers/__init__.py +30 -0
- statement_extractor/plugins/qualifiers/base.py +9 -0
- statement_extractor/plugins/qualifiers/companies_house.py +185 -0
- statement_extractor/plugins/qualifiers/embedding_company.py +420 -0
- statement_extractor/plugins/qualifiers/gleif.py +197 -0
- statement_extractor/plugins/qualifiers/person.py +785 -0
- statement_extractor/plugins/qualifiers/sec_edgar.py +209 -0
- statement_extractor/plugins/scrapers/__init__.py +10 -0
- statement_extractor/plugins/scrapers/http.py +236 -0
- statement_extractor/plugins/splitters/__init__.py +13 -0
- statement_extractor/plugins/splitters/base.py +9 -0
- statement_extractor/plugins/splitters/t5_gemma.py +293 -0
- statement_extractor/plugins/taxonomy/__init__.py +13 -0
- statement_extractor/plugins/taxonomy/embedding.py +484 -0
- statement_extractor/plugins/taxonomy/mnli.py +291 -0
- statement_extractor/scoring.py +8 -8
- corp_extractor-0.4.0.dist-info/RECORD +0 -12
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/WHEEL +0 -0
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/entry_points.txt +0 -0
statement_extractor/plugins/splitters/t5_gemma.py
@@ -0,0 +1,293 @@
+"""
+T5GemmaSplitter - Stage 1 plugin that wraps the existing StatementExtractor.
+
+Uses T5-Gemma2 model with Diverse Beam Search to generate high-quality
+subject-predicate-object triples from text.
+"""
+
+import logging
+import re
+from typing import Optional
+
+from ..base import BaseSplitterPlugin, PluginCapability
+from ...pipeline.context import PipelineContext
+from ...pipeline.registry import PluginRegistry
+from ...models import RawTriple
+
+logger = logging.getLogger(__name__)
+
+
+@PluginRegistry.splitter
+class T5GemmaSplitter(BaseSplitterPlugin):
+    """
+    Splitter plugin that uses T5-Gemma2 for triple extraction.
+
+    Wraps the existing StatementExtractor from extractor.py to produce
+    RawTriple objects for the pipeline.
+    """
+
+    def __init__(
+        self,
+        model_id: Optional[str] = None,
+        device: Optional[str] = None,
+        num_beams: int = 4,
+        diversity_penalty: float = 1.0,
+        max_new_tokens: int = 2048,
+    ):
+        """
+        Initialize the T5Gemma splitter.
+
+        Args:
+            model_id: HuggingFace model ID (defaults to Corp-o-Rate model)
+            device: Device to use (auto-detected if not specified)
+            num_beams: Number of beams for diverse beam search
+            diversity_penalty: Penalty for beam diversity
+            max_new_tokens: Maximum tokens to generate
+        """
+        self._model_id = model_id
+        self._device = device
+        self._num_beams = num_beams
+        self._diversity_penalty = diversity_penalty
+        self._max_new_tokens = max_new_tokens
+        self._extractor = None
+
+    @property
+    def name(self) -> str:
+        return "t5_gemma_splitter"
+
+    @property
+    def priority(self) -> int:
+        return 10  # High priority - primary splitter
+
+    @property
+    def capabilities(self) -> PluginCapability:
+        return PluginCapability.LLM_REQUIRED | PluginCapability.BATCH_PROCESSING
+
+    @property
+    def description(self) -> str:
+        return "T5-Gemma2 model for extracting triples using Diverse Beam Search"
+
+    @property
+    def model_vram_gb(self) -> float:
+        """T5-Gemma2 model weights ~2GB in bfloat16."""
+        return 2.0
+
+    @property
+    def per_item_vram_gb(self) -> float:
+        """Each text item during batch processing ~0.5GB for KV cache and activations."""
+        return 0.5
+
+    def _get_extractor(self):
+        """Lazy-load the StatementExtractor."""
+        if self._extractor is None:
+            from ...extractor import StatementExtractor
+            # Only pass model_id and device if they were explicitly set
+            kwargs = {}
+            if self._model_id is not None:
+                kwargs["model_id"] = self._model_id
+            if self._device is not None:
+                kwargs["device"] = self._device
+            self._extractor = StatementExtractor(**kwargs)
+        return self._extractor
+
+    def split(
+        self,
+        text: str,
+        context: PipelineContext,
+    ) -> list[RawTriple]:
+        """
+        Split text into raw triples using T5-Gemma2.
+
+        Args:
+            text: Input text to split
+            context: Pipeline context
+
+        Returns:
+            List of RawTriple objects
+        """
+        logger.debug(f"T5GemmaSplitter processing {len(text)} chars")
+
+        # Get options from context if available
+        splitter_options = context.source_metadata.get("splitter_options", {})
+        num_beams = splitter_options.get("num_beams", self._num_beams)
+        diversity_penalty = splitter_options.get("diversity_penalty", self._diversity_penalty)
+        max_new_tokens = splitter_options.get("max_new_tokens", self._max_new_tokens)
+
+        # Create extraction options
+        from ...models import ExtractionOptions as LegacyExtractionOptions
+        options = LegacyExtractionOptions(
+            num_beams=num_beams,
+            diversity_penalty=diversity_penalty,
+            max_new_tokens=max_new_tokens,
+            # Disable GLiNER and dedup - we handle those in later stages
+            use_gliner_extraction=False,
+            embedding_dedup=False,
+            deduplicate=False,
+        )
+
+        # Get raw XML from extractor
+        extractor = self._get_extractor()
+        xml_output = extractor.extract_as_xml(text, options)
+
+        # Parse XML to RawTriple objects
+        raw_triples = self._parse_xml_to_raw_triples(xml_output)
+
+        logger.info(f"T5GemmaSplitter produced {len(raw_triples)} raw triples")
+        return raw_triples
+
+    def split_batch(
+        self,
+        texts: list[str],
+        context: PipelineContext,
+    ) -> list[list[RawTriple]]:
+        """
+        Split multiple texts into atomic triples using batch processing.
+
+        Processes all texts through the T5-Gemma2 model in batches
+        sized for optimal GPU utilization.
+
+        Args:
+            texts: List of input texts to split
+            context: Pipeline context
+
+        Returns:
+            List of RawTriple lists, one per input text
+        """
+        if not texts:
+            return []
+
+        batch_size = self.get_optimal_batch_size()
+        logger.info(f"T5GemmaSplitter batch processing {len(texts)} texts with batch_size={batch_size}")
+
+        # Get options from context
+        splitter_options = context.source_metadata.get("splitter_options", {})
+        num_beams = splitter_options.get("num_beams", self._num_beams)
+        diversity_penalty = splitter_options.get("diversity_penalty", self._diversity_penalty)
+        max_new_tokens = splitter_options.get("max_new_tokens", self._max_new_tokens)
+
+        # Create extraction options
+        from ...models import ExtractionOptions as LegacyExtractionOptions
+        options = LegacyExtractionOptions(
+            num_beams=num_beams,
+            diversity_penalty=diversity_penalty,
+            max_new_tokens=max_new_tokens,
+            use_gliner_extraction=False,
+            embedding_dedup=False,
+            deduplicate=False,
+        )
+
+        extractor = self._get_extractor()
+        all_results: list[list[RawTriple]] = []
+
+        # Process in batches
+        for i in range(0, len(texts), batch_size):
+            batch_texts = texts[i:i + batch_size]
+            logger.debug(f"Processing batch {i // batch_size + 1}: {len(batch_texts)} texts")
+
+            batch_results = self._process_batch(batch_texts, extractor, options)
+            all_results.extend(batch_results)
+
+        total_triples = sum(len(r) for r in all_results)
+        logger.info(f"T5GemmaSplitter batch produced {total_triples} total triples from {len(texts)} texts")
+        return all_results
+
+    def _process_batch(
+        self,
+        texts: list[str],
+        extractor,
+        options,
+    ) -> list[list[RawTriple]]:
+        """
+        Process a batch of texts through the model.
+
+        Uses the model's batch generation capability for efficient GPU utilization.
+        """
+        import torch
+
+        # Wrap texts in page tags
+        wrapped_texts = [f"<page>{t}</page>" if not t.startswith("<page>") else t for t in texts]
+
+        # Tokenize batch
+        tokenizer = extractor.tokenizer
+        model = extractor.model
+
+        inputs = tokenizer(
+            wrapped_texts,
+            return_tensors="pt",
+            max_length=4096,
+            truncation=True,
+            padding=True,
+        ).to(extractor.device)
+
+        # Create stopping criteria
+        from ...extractor import StopOnSequence
+        from transformers import StoppingCriteriaList
+
+        input_length = inputs["input_ids"].shape[1]
+        stop_criteria = StopOnSequence(
+            tokenizer=tokenizer,
+            stop_sequence="</statements>",
+            input_length=input_length,
+        )
+
+        # Generate for all texts in batch
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=options.max_new_tokens,
+                max_length=None,
+                num_beams=options.num_beams,
+                num_beam_groups=options.num_beams,
+                num_return_sequences=1,  # One sequence per input for batch
+                diversity_penalty=options.diversity_penalty,
+                do_sample=False,
+                top_p=None,
+                top_k=None,
+                trust_remote_code=True,
+                custom_generate="transformers-community/group-beam-search",
+                stopping_criteria=StoppingCriteriaList([stop_criteria]),
+            )
+
+        # Decode and parse each output
+        results: list[list[RawTriple]] = []
+        end_tag = "</statements>"
+
+        for output in outputs:
+            decoded = tokenizer.decode(output, skip_special_tokens=True)
+
+            # Truncate at </statements>
+            if end_tag in decoded:
+                end_pos = decoded.find(end_tag) + len(end_tag)
+                decoded = decoded[:end_pos]
+
+            triples = self._parse_xml_to_raw_triples(decoded)
+            results.append(triples)
+
+        return results
+
+    # Regex pattern to extract <text> content from <stmt> blocks
+    _STMT_TEXT_PATTERN = re.compile(r'<stmt>.*?<text>(.*?)</text>.*?</stmt>', re.DOTALL)
+
+    def _parse_xml_to_raw_triples(self, xml_output: str) -> list[RawTriple]:
+        """Extract source sentences from <stmt><text>...</text></stmt> blocks."""
+        raw_triples = []
+
+        # Find all <text> content within <stmt> blocks
+        text_matches = self._STMT_TEXT_PATTERN.findall(xml_output)
+        logger.debug(f"Found {len(text_matches)} stmt text blocks via regex")
+
+        for source_text in text_matches:
+            source_text = source_text.strip()
+            if source_text:
+                raw_triples.append(RawTriple(
+                    subject_text="",
+                    predicate_text="",
+                    object_text="",
+                    source_sentence=source_text,
+                ))
+
+        return raw_triples
+
+
+# Allow importing without decorator for testing
+T5GemmaSplitterClass = T5GemmaSplitter
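For orientation, here is a minimal usage sketch for the new splitter plugin, written against the signatures shown in the hunk above. The PipelineContext construction is an assumption: only its source_metadata attribute (and the splitter_options key read from it) appears in this diff, so the keyword-argument call below is illustrative rather than the package's documented API; the sample input text is likewise invented.

from statement_extractor.plugins.splitters.t5_gemma import T5GemmaSplitter
from statement_extractor.pipeline.context import PipelineContext

# Constructor defaults mirror the hunk above: num_beams=4,
# diversity_penalty=1.0, max_new_tokens=2048.
splitter = T5GemmaSplitter(device="cuda")

# Assumption: PipelineContext construction is not shown in this diff;
# the splitter only reads context.source_metadata["splitter_options"].
context = PipelineContext(source_metadata={"splitter_options": {"num_beams": 6}})

triples = splitter.split("Acme Corp acquired Widget Ltd in 2023.", context)
for triple in triples:
    # Stage 1 only fills source_sentence; subject/predicate/object text
    # is left empty for later pipeline stages to populate.
    print(triple.source_sentence)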
statement_extractor/plugins/taxonomy/__init__.py
@@ -0,0 +1,13 @@
+"""
+Taxonomy classifier plugins for Stage 6 (Taxonomy).
+
+Classifies statements against large taxonomies using MNLI or embeddings.
+"""
+
+from .mnli import MNLITaxonomyClassifier
+from .embedding import EmbeddingTaxonomyClassifier
+
+__all__ = [
+    "MNLITaxonomyClassifier",
+    "EmbeddingTaxonomyClassifier",
+]
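And a correspondingly small import sketch for the two Stage 6 classifiers exported above; their constructor arguments are not part of this hunk, so the no-argument instantiation is an assumption.

from statement_extractor.plugins.taxonomy import (
    EmbeddingTaxonomyClassifier,
    MNLITaxonomyClassifier,
)

# Assumption: both classifiers can be constructed without arguments;
# their __init__ signatures live in mnli.py and embedding.py, not in this hunk.
mnli_classifier = MNLITaxonomyClassifier()
embedding_classifier = EmbeddingTaxonomyClassifier()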