corp-extractor 0.4.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/METADATA +348 -64
- corp_extractor-0.9.0.dist-info/RECORD +76 -0
- statement_extractor/__init__.py +10 -1
- statement_extractor/cli.py +1663 -17
- statement_extractor/data/default_predicates.json +368 -0
- statement_extractor/data/statement_taxonomy.json +6972 -0
- statement_extractor/database/__init__.py +52 -0
- statement_extractor/database/embeddings.py +186 -0
- statement_extractor/database/hub.py +520 -0
- statement_extractor/database/importers/__init__.py +24 -0
- statement_extractor/database/importers/companies_house.py +545 -0
- statement_extractor/database/importers/gleif.py +538 -0
- statement_extractor/database/importers/sec_edgar.py +375 -0
- statement_extractor/database/importers/wikidata.py +1012 -0
- statement_extractor/database/importers/wikidata_people.py +632 -0
- statement_extractor/database/models.py +230 -0
- statement_extractor/database/resolver.py +245 -0
- statement_extractor/database/store.py +1609 -0
- statement_extractor/document/__init__.py +62 -0
- statement_extractor/document/chunker.py +410 -0
- statement_extractor/document/context.py +171 -0
- statement_extractor/document/deduplicator.py +173 -0
- statement_extractor/document/html_extractor.py +246 -0
- statement_extractor/document/loader.py +303 -0
- statement_extractor/document/pipeline.py +388 -0
- statement_extractor/document/summarizer.py +195 -0
- statement_extractor/extractor.py +1 -23
- statement_extractor/gliner_extraction.py +4 -74
- statement_extractor/llm.py +255 -0
- statement_extractor/models/__init__.py +89 -0
- statement_extractor/models/canonical.py +182 -0
- statement_extractor/models/document.py +308 -0
- statement_extractor/models/entity.py +102 -0
- statement_extractor/models/labels.py +220 -0
- statement_extractor/models/qualifiers.py +139 -0
- statement_extractor/models/statement.py +101 -0
- statement_extractor/models.py +4 -1
- statement_extractor/pipeline/__init__.py +39 -0
- statement_extractor/pipeline/config.py +129 -0
- statement_extractor/pipeline/context.py +177 -0
- statement_extractor/pipeline/orchestrator.py +416 -0
- statement_extractor/pipeline/registry.py +303 -0
- statement_extractor/plugins/__init__.py +55 -0
- statement_extractor/plugins/base.py +716 -0
- statement_extractor/plugins/extractors/__init__.py +13 -0
- statement_extractor/plugins/extractors/base.py +9 -0
- statement_extractor/plugins/extractors/gliner2.py +546 -0
- statement_extractor/plugins/labelers/__init__.py +29 -0
- statement_extractor/plugins/labelers/base.py +9 -0
- statement_extractor/plugins/labelers/confidence.py +138 -0
- statement_extractor/plugins/labelers/relation_type.py +87 -0
- statement_extractor/plugins/labelers/sentiment.py +159 -0
- statement_extractor/plugins/labelers/taxonomy.py +386 -0
- statement_extractor/plugins/labelers/taxonomy_embedding.py +477 -0
- statement_extractor/plugins/pdf/__init__.py +10 -0
- statement_extractor/plugins/pdf/pypdf.py +291 -0
- statement_extractor/plugins/qualifiers/__init__.py +30 -0
- statement_extractor/plugins/qualifiers/base.py +9 -0
- statement_extractor/plugins/qualifiers/companies_house.py +185 -0
- statement_extractor/plugins/qualifiers/embedding_company.py +420 -0
- statement_extractor/plugins/qualifiers/gleif.py +197 -0
- statement_extractor/plugins/qualifiers/person.py +785 -0
- statement_extractor/plugins/qualifiers/sec_edgar.py +209 -0
- statement_extractor/plugins/scrapers/__init__.py +10 -0
- statement_extractor/plugins/scrapers/http.py +236 -0
- statement_extractor/plugins/splitters/__init__.py +13 -0
- statement_extractor/plugins/splitters/base.py +9 -0
- statement_extractor/plugins/splitters/t5_gemma.py +293 -0
- statement_extractor/plugins/taxonomy/__init__.py +13 -0
- statement_extractor/plugins/taxonomy/embedding.py +484 -0
- statement_extractor/plugins/taxonomy/mnli.py +291 -0
- statement_extractor/scoring.py +8 -8
- corp_extractor-0.4.0.dist-info/RECORD +0 -12
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/WHEEL +0 -0
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
"""
|
|
2
|
+
TaxonomyLabeler - Classifies statements against a large taxonomy using MNLI.
|
|
3
|
+
|
|
4
|
+
Uses zero-shot classification with MNLI models for taxonomy labeling where
|
|
5
|
+
there are too many possible values for simple multi-choice classification.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Optional, TypedDict
|
|
12
|
+
|
|
13
|
+
from ..base import BaseLabelerPlugin, TaxonomySchema, PluginCapability
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TaxonomyEntry(TypedDict):
    """Record stored for a single label inside a taxonomy category.

    The loaded taxonomy is typed as ``dict[category, dict[label_name, TaxonomyEntry]]``.
    """

    # Free-text description of the label.
    description: str
    # Numeric identifier; the labeler surfaces it as the ``label_id`` metadata value.
    id: int
    # Alternate label text, presumably phrased for MNLI hypotheses (not read in this module).
    mnli_label: str
    # Alternate label text, presumably for embedding-based labeling (not read in this module).
    embedding_label: str
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
from ...pipeline.context import PipelineContext
|
|
25
|
+
from ...models import (
|
|
26
|
+
PipelineStatement,
|
|
27
|
+
CanonicalEntity,
|
|
28
|
+
StatementLabel,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)

# Built-in taxonomy file: <three directories above this module>/data/statement_taxonomy.json
DEFAULT_TAXONOMY_PATH = Path(__file__).parent.parent.parent.joinpath("data", "statement_taxonomy.json")

# Categories used when the caller does not restrict them (every category).
DEFAULT_CATEGORIES = [
    # Plain category names
    "environment", "society", "governance", "animals", "industry",
    # Categories suffixed with _harm / _benefit
    "human_harm", "human_benefit",
    "animal_harm", "animal_benefit",
    "environment_harm", "environment_benefit",
]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class TaxonomyClassifier:
    """
    MNLI-based zero-shot classifier for taxonomy labeling.

    Uses the HuggingFace transformers zero-shot-classification pipeline. The
    pipeline is created lazily on first use, so constructing this object does
    not import transformers/torch or load a model.
    """

    def __init__(
        self,
        model_id: str = "facebook/bart-large-mnli",
        device: Optional[str] = None,
    ):
        """
        Initialize the classifier.

        Args:
            model_id: HuggingFace model ID for MNLI classification
            device: Device to use ('cuda', 'mps', 'cpu', or None for auto)
        """
        self._model_id = model_id
        self._device = device
        # Populated on first use by _load_classifier().
        self._classifier = None

    def _load_classifier(self):
        """Lazy-load the zero-shot classification pipeline.

        Raises:
            ImportError: If transformers or torch cannot be imported.
        """
        if self._classifier is not None:
            return

        # Only the imports are guarded: an ImportError raised later while the
        # pipeline itself is being built must not be misreported as
        # "transformers is missing".
        try:
            from transformers import pipeline
            import torch
        except ImportError as e:
            raise ImportError(
                "transformers and torch are required for MNLI classification. "
                "Install with: pip install transformers torch"
            ) from e

        # Auto-detect device if not specified
        device = self._device
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"

        logger.info(f"Loading MNLI classifier '{self._model_id}' on {device}...")
        self._classifier = pipeline(
            "zero-shot-classification",
            model=self._model_id,
            # The pipeline uses -1 to mean CPU.
            device=device if device != "cpu" else -1,
        )
        logger.debug("MNLI classifier loaded")

    def classify(
        self,
        text: str,
        candidate_labels: list[str],
        multi_label: bool = False,
    ) -> tuple[str, float]:
        """
        Classify text against candidate labels using MNLI.

        Args:
            text: Text to classify
            candidate_labels: List of possible labels
            multi_label: Whether multiple labels can apply

        Returns:
            Tuple of (best_label, confidence)
        """
        self._load_classifier()

        result = self._classifier(
            text,
            candidate_labels=candidate_labels,
            multi_label=multi_label,
        )

        # Result format: {'sequence': '...', 'labels': [...], 'scores': [...]}
        # with labels sorted by descending score.
        return result["labels"][0], result["scores"][0]

    def classify_hierarchical(
        self,
        text: str,
        taxonomy: dict[str, list[str]],
        top_k_categories: int = 3,
    ) -> tuple[Optional[str], Optional[str], float]:
        """
        Hierarchical classification: first category, then label within category.

        More efficient than flat classification for large taxonomies.

        Args:
            text: Text to classify
            taxonomy: Dict mapping category -> list of labels
            top_k_categories: Number of top categories to consider (values
                below 1 are treated as 1)

        Returns:
            Tuple of (category, label, confidence), where confidence is the
            category score multiplied by the label score. Returns
            (None, None, 0.0) when no category has any labels.
        """
        # A category without labels can never produce a result. If it stayed in
        # step 1, it could occupy a top-k slot and dilute the other categories' scores.
        categories = [cat for cat, labels in taxonomy.items() if labels]
        if not categories:
            return None, None, 0.0

        self._load_classifier()

        # Step 1: Classify into category
        cat_result = self._classifier(text, candidate_labels=categories)

        # Get top-k categories (at least one; a non-positive k would slice oddly)
        k = max(1, top_k_categories)
        top_categories = cat_result["labels"][:k]
        top_cat_scores = cat_result["scores"][:k]

        # Step 2: Classify within top categories
        best_category: Optional[str] = None
        best_label: Optional[str] = None
        best_score = 0.0

        for cat, cat_score in zip(top_categories, top_cat_scores):
            label_result = self._classifier(text, candidate_labels=taxonomy[cat])
            label = label_result["labels"][0]

            # Combined score: category confidence * label confidence
            combined_score = cat_score * label_result["scores"][0]

            if combined_score > best_score:
                best_score = combined_score
                best_label = label
                best_category = cat

        return best_category, best_label, best_score
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class TaxonomyLabeler(BaseLabelerPlugin):
    """
    Labeler that classifies statements against a large taxonomy using MNLI.

    Supports hierarchical classification for efficiency with large taxonomies.
    The taxonomy file and the MNLI model are both loaded lazily on first use.
    """

    def __init__(
        self,
        taxonomy_path: Optional[str | Path] = None,
        categories: Optional[list[str]] = None,
        model_id: str = "facebook/bart-large-mnli",
        use_hierarchical: bool = True,
        top_k_categories: int = 3,
        min_confidence: float = 0.3,
    ):
        """
        Initialize the taxonomy labeler.

        Args:
            taxonomy_path: Path to taxonomy JSON file (default: built-in taxonomy)
            categories: List of categories to use (default: all categories;
                an empty list also falls back to all categories)
            model_id: HuggingFace model ID for MNLI classifier
            use_hierarchical: Use hierarchical classification (category then label)
            top_k_categories: Number of top categories to consider in hierarchical mode
            min_confidence: Minimum confidence threshold for returning a label
        """
        self._taxonomy_path = Path(taxonomy_path) if taxonomy_path else DEFAULT_TAXONOMY_PATH
        # Copy, so that mutating self._categories can never alter the module-level
        # DEFAULT_CATEGORIES list or the caller's list.
        self._categories = list(categories) if categories else list(DEFAULT_CATEGORIES)
        self._model_id = model_id
        self._use_hierarchical = use_hierarchical
        self._top_k_categories = top_k_categories
        self._min_confidence = min_confidence

        # Lazily populated caches.
        self._taxonomy: Optional[dict[str, dict[str, TaxonomyEntry]]] = None
        self._classifier: Optional[TaxonomyClassifier] = None

    @property
    def name(self) -> str:
        return "taxonomy_labeler"

    @property
    def priority(self) -> int:
        return 60  # Lower priority than embedding taxonomy (use --plugins taxonomy_labeler to enable)

    @property
    def capabilities(self) -> PluginCapability:
        # NOTE(review): this plugin runs an MNLI classifier, not an LLM; confirm
        # whether LLM_REQUIRED is the intended capability flag.
        return PluginCapability.LLM_REQUIRED

    @property
    def description(self) -> str:
        return "Classifies statements against a taxonomy using MNLI zero-shot classification"

    @property
    def label_type(self) -> str:
        return "taxonomy"

    @property
    def taxonomy_schema(self) -> TaxonomySchema:
        """Provide taxonomy schema (for documentation/introspection)."""
        return TaxonomySchema(
            label_type=self.label_type,
            values=self._get_filtered_taxonomy(),
            description="Statement topic classification against corporate ESG taxonomy",
            scope="statement",
        )

    def _load_taxonomy(self) -> dict[str, dict[str, TaxonomyEntry]]:
        """Load (and cache) the taxonomy from its JSON file.

        Raises:
            FileNotFoundError: If the taxonomy file does not exist.
        """
        if self._taxonomy is not None:
            return self._taxonomy

        if not self._taxonomy_path.exists():
            raise FileNotFoundError(f"Taxonomy file not found: {self._taxonomy_path}")

        # Explicit UTF-8: JSON is UTF-8, and the platform default encoding
        # (e.g. cp1252 on Windows) would fail on non-ASCII labels.
        with open(self._taxonomy_path, encoding="utf-8") as f:
            self._taxonomy = json.load(f)

        logger.debug(f"Loaded taxonomy with {len(self._taxonomy)} categories")
        return self._taxonomy

    def _get_classifier(self) -> TaxonomyClassifier:
        """Get or create the MNLI classifier."""
        if self._classifier is None:
            self._classifier = TaxonomyClassifier(model_id=self._model_id)
        return self._classifier

    def _get_filtered_taxonomy(self) -> dict[str, list[str]]:
        """Get taxonomy filtered to selected categories, with label names only.

        Category order follows the taxonomy file.
        """
        taxonomy = self._load_taxonomy()
        return {
            cat: list(labels.keys())
            for cat, labels in taxonomy.items()
            if cat in self._categories
        }

    def label(
        self,
        statement: PipelineStatement,
        subject_canonical: CanonicalEntity,
        object_canonical: CanonicalEntity,
        context: PipelineContext,
    ) -> Optional[StatementLabel]:
        """
        Classify statement against the taxonomy.

        A classification already stored in ``context`` for the statement's
        source text takes precedence. Otherwise the MNLI classifier runs on it.

        Args:
            statement: The statement to label
            subject_canonical: Canonicalized subject
            object_canonical: Canonicalized object
            context: Pipeline context

        Returns:
            StatementLabel with taxonomy classification, or None if the
            confidence is below threshold, no labels are selected, or
            classification fails (failures are logged, not raised)
        """
        # Check for pre-computed classification from extractor
        result = context.get_classification(statement.source_text, self.label_type)
        if result:
            label_value, confidence = result
            if confidence >= self._min_confidence:
                return StatementLabel(
                    label_type=self.label_type,
                    label_value=label_value,
                    confidence=confidence,
                    labeler=self.name,
                )
            return None

        # Run MNLI classification
        try:
            taxonomy = self._get_filtered_taxonomy()
            if not any(taxonomy.values()):
                logger.debug("No taxonomy labels in the selected categories; skipping")
                return None

            classifier = self._get_classifier()

            # Text to classify
            text = statement.source_text

            category: Optional[str] = None
            if self._use_hierarchical:
                category, label_name, confidence = classifier.classify_hierarchical(
                    text,
                    taxonomy,
                    top_k_categories=self._top_k_categories,
                )
                # Include category in label for clarity
                full_label = f"{category}:{label_name}" if category and label_name else None
                lookup_category = category
            else:
                # Flat classification across all labels. Record the first
                # selected category each label name appears in, so that:
                # duplicate names are scored only once, and the label id is
                # resolved within the selected categories rather than any category.
                label_to_category: dict[str, str] = {}
                for cat, labels in taxonomy.items():
                    for name in labels:
                        label_to_category.setdefault(name, cat)

                label_name, confidence = classifier.classify(text, list(label_to_category))
                full_label = label_name
                lookup_category = label_to_category.get(label_name)

            if full_label and confidence >= self._min_confidence:
                # Get the numeric ID for reproducibility
                label_id = self._get_label_id(lookup_category, label_name)

                metadata = {"label_id": label_id}
                if self._use_hierarchical:
                    metadata["category"] = category

                return StatementLabel(
                    label_type=self.label_type,
                    label_value=full_label,
                    confidence=confidence,
                    labeler=self.name,
                    metadata=metadata,
                )

        except Exception as e:
            # Best effort: a labeling failure must not abort the pipeline.
            logger.warning(f"Taxonomy classification failed: {e}")

        return None

    def _get_label_id(self, category: Optional[str], label: str) -> Optional[int]:
        """Get the numeric ID for a label.

        Looks in ``category`` first when it is given. Otherwise, or when the
        label is not found there, returns the id from the first category in
        the full taxonomy that contains the label. Returns None if no
        category contains it.
        """
        taxonomy = self._load_taxonomy()

        if category and category in taxonomy:
            entry = taxonomy[category].get(label)
            if entry:
                return entry.get("id")

        # Search all categories for flat classification
        for cat_labels in taxonomy.values():
            if label in cat_labels:
                entry = cat_labels[label]
                return entry.get("id")

        return None
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
# Module-level aliases kept for importing the classes without a decorator
# (e.g. in tests). NOTE(review): neither class is decorated in this module,
# so this rationale may be stale; confirm before relying on it.
TaxonomyLabelerClass, TaxonomyClassifierClass = TaxonomyLabeler, TaxonomyClassifier
|