corp-extractor 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/METADATA +181 -64
- corp_extractor-0.5.0.dist-info/RECORD +55 -0
- statement_extractor/__init__.py +9 -0
- statement_extractor/cli.py +446 -17
- statement_extractor/data/default_predicates.json +368 -0
- statement_extractor/data/statement_taxonomy.json +1182 -0
- statement_extractor/extractor.py +1 -23
- statement_extractor/gliner_extraction.py +4 -74
- statement_extractor/llm.py +255 -0
- statement_extractor/models/__init__.py +74 -0
- statement_extractor/models/canonical.py +139 -0
- statement_extractor/models/entity.py +102 -0
- statement_extractor/models/labels.py +191 -0
- statement_extractor/models/qualifiers.py +91 -0
- statement_extractor/models/statement.py +75 -0
- statement_extractor/models.py +4 -1
- statement_extractor/pipeline/__init__.py +39 -0
- statement_extractor/pipeline/config.py +134 -0
- statement_extractor/pipeline/context.py +177 -0
- statement_extractor/pipeline/orchestrator.py +447 -0
- statement_extractor/pipeline/registry.py +297 -0
- statement_extractor/plugins/__init__.py +43 -0
- statement_extractor/plugins/base.py +446 -0
- statement_extractor/plugins/canonicalizers/__init__.py +17 -0
- statement_extractor/plugins/canonicalizers/base.py +9 -0
- statement_extractor/plugins/canonicalizers/location.py +219 -0
- statement_extractor/plugins/canonicalizers/organization.py +230 -0
- statement_extractor/plugins/canonicalizers/person.py +242 -0
- statement_extractor/plugins/extractors/__init__.py +13 -0
- statement_extractor/plugins/extractors/base.py +9 -0
- statement_extractor/plugins/extractors/gliner2.py +536 -0
- statement_extractor/plugins/labelers/__init__.py +29 -0
- statement_extractor/plugins/labelers/base.py +9 -0
- statement_extractor/plugins/labelers/confidence.py +138 -0
- statement_extractor/plugins/labelers/relation_type.py +87 -0
- statement_extractor/plugins/labelers/sentiment.py +159 -0
- statement_extractor/plugins/labelers/taxonomy.py +373 -0
- statement_extractor/plugins/labelers/taxonomy_embedding.py +466 -0
- statement_extractor/plugins/qualifiers/__init__.py +19 -0
- statement_extractor/plugins/qualifiers/base.py +9 -0
- statement_extractor/plugins/qualifiers/companies_house.py +174 -0
- statement_extractor/plugins/qualifiers/gleif.py +186 -0
- statement_extractor/plugins/qualifiers/person.py +221 -0
- statement_extractor/plugins/qualifiers/sec_edgar.py +198 -0
- statement_extractor/plugins/splitters/__init__.py +13 -0
- statement_extractor/plugins/splitters/base.py +9 -0
- statement_extractor/plugins/splitters/t5_gemma.py +188 -0
- statement_extractor/plugins/taxonomy/__init__.py +13 -0
- statement_extractor/plugins/taxonomy/embedding.py +337 -0
- statement_extractor/plugins/taxonomy/mnli.py +279 -0
- corp_extractor-0.4.0.dist-info/RECORD +0 -12
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/WHEEL +0 -0
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MNLITaxonomyClassifier - Classifies statements using MNLI zero-shot classification.
|
|
3
|
+
|
|
4
|
+
Uses HuggingFace transformers zero-shot-classification pipeline for taxonomy labeling
|
|
5
|
+
where there are too many possible values for simple multi-choice classification.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Optional
|
|
12
|
+
|
|
13
|
+
from ..base import BaseTaxonomyPlugin, TaxonomySchema, PluginCapability
|
|
14
|
+
from ...pipeline.context import PipelineContext
|
|
15
|
+
from ...pipeline.registry import PluginRegistry
|
|
16
|
+
from ...models import (
|
|
17
|
+
PipelineStatement,
|
|
18
|
+
CanonicalEntity,
|
|
19
|
+
TaxonomyResult,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)

# Bundled taxonomy JSON: the "data/statement_taxonomy.json" file found three
# directory levels above this module file.
DEFAULT_TAXONOMY_PATH = Path(__file__).parent.parent.parent.joinpath(
    "data", "statement_taxonomy.json"
)

# Category names used when no explicit category selection is supplied.
DEFAULT_CATEGORIES = [
    "environment", "society", "governance", "animals", "industry",
    "human_harm", "human_benefit",
    "animal_harm", "animal_benefit",
    "environment_harm", "environment_benefit",
]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class MNLIClassifier:
    """
    MNLI-based zero-shot classifier for taxonomy labeling.

    Wraps the HuggingFace transformers zero-shot-classification pipeline.
    The pipeline is built lazily on first classification, so constructing
    this object neither imports transformers/torch nor loads a model.
    """

    def __init__(
        self,
        model_id: str = "facebook/bart-large-mnli",
        device: Optional[str] = None,
    ):
        """
        Args:
            model_id: Model identifier handed to the transformers pipeline.
            device: Device name ("cuda", "mps", "cpu"). When None, the device
                is auto-detected the first time the pipeline is loaded.
        """
        self._model_id = model_id
        self._device = device
        self._classifier = None  # populated by _load_classifier()

    def _load_classifier(self):
        """Lazy-load the zero-shot classification pipeline.

        Does nothing if the pipeline has already been created.

        Raises:
            ImportError: If transformers or torch cannot be imported.
        """
        if self._classifier is not None:
            return

        # Only the imports are guarded, so an unrelated ImportError raised
        # while building the pipeline is not relabelled as "transformers is
        # required".
        try:
            from transformers import pipeline
            import torch
        except ImportError as e:
            raise ImportError(
                "transformers is required for MNLI classification. "
                "Install with: pip install transformers"
            ) from e

        device = self._device
        if device is None:
            # Prefer CUDA, then Apple MPS, then fall back to CPU.
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"

        logger.info("Loading MNLI classifier '%s' on %s...", self._model_id, device)
        self._classifier = pipeline(
            "zero-shot-classification",
            model=self._model_id,
            # The pipeline is given -1 for CPU; other device names pass through.
            device=device if device != "cpu" else -1,
        )
        logger.debug("MNLI classifier loaded")

    def classify_hierarchical(
        self,
        text: str,
        taxonomy: dict[str, list[str]],
        top_k_categories: int = 3,
        min_score: float = 0.3,
    ) -> list[tuple[str, str, float]]:
        """
        Hierarchical classification: first category, then labels within category.

        The text is first scored against the category names; the labels of
        the ``top_k_categories`` best categories are then scored, and each
        label's confidence is the product of its category score and its own
        label score. All labels above the threshold are returned, not just
        the best match.

        Args:
            text: Text to classify
            taxonomy: Dict mapping category -> list of labels
            top_k_categories: Number of top categories to consider
            min_score: Minimum combined score to include in results

        Returns:
            List of (category, label, confidence) tuples above threshold,
            sorted by confidence descending. Empty when ``taxonomy`` is empty.
        """
        categories = list(taxonomy)
        if not categories:
            # The zero-shot pipeline rejects an empty candidate-label list, so
            # an empty taxonomy yields no results (and no model load).
            return []

        self._load_classifier()

        cat_result = self._classifier(text, candidate_labels=categories)
        top_categories = zip(
            cat_result["labels"][:top_k_categories],
            cat_result["scores"][:top_k_categories],
        )

        results: list[tuple[str, str, float]] = []
        for cat, cat_score in top_categories:
            labels = taxonomy[cat]
            if not labels:
                # Nothing to score inside this category.
                continue

            label_result = self._classifier(text, candidate_labels=labels)

            # Keep every label whose combined score clears the threshold.
            for label, label_score in zip(label_result["labels"], label_result["scores"]):
                combined_score = cat_score * label_score
                if combined_score >= min_score:
                    results.append((cat, label, combined_score))

        # Sort by confidence descending
        results.sort(key=lambda item: item[2], reverse=True)
        return results
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@PluginRegistry.taxonomy
class MNLITaxonomyClassifier(BaseTaxonomyPlugin):
    """
    Taxonomy classifier using MNLI zero-shot classification.

    Supports hierarchical classification for efficiency with large taxonomies.
    The taxonomy JSON and the underlying MNLIClassifier are both created
    lazily on first use and cached on the instance.
    """

    def __init__(
        self,
        taxonomy_path: Optional[str | Path] = None,
        categories: Optional[list[str]] = None,
        model_id: str = "facebook/bart-large-mnli",
        top_k_categories: int = 3,
        min_confidence: float = 0.3,
    ):
        """
        Args:
            taxonomy_path: Path to the taxonomy JSON file; defaults to
                DEFAULT_TAXONOMY_PATH when not given.
            categories: Taxonomy categories to classify against; defaults to
                DEFAULT_CATEGORIES when not given or empty.
            model_id: Model identifier forwarded to MNLIClassifier.
            top_k_categories: Number of top categories whose labels are scored.
            min_confidence: Minimum combined score for a label to be returned.
        """
        self._taxonomy_path = Path(taxonomy_path) if taxonomy_path else DEFAULT_TAXONOMY_PATH
        # Copy so the instance never aliases the caller's list or the
        # module-level default list.
        self._categories = list(categories) if categories else list(DEFAULT_CATEGORIES)
        self._model_id = model_id
        self._top_k_categories = top_k_categories
        self._min_confidence = min_confidence

        # Both are populated lazily (see _load_taxonomy / _get_classifier).
        self._taxonomy: Optional[dict[str, dict[str, int]]] = None
        self._classifier: Optional[MNLIClassifier] = None

    @property
    def name(self) -> str:
        """Plugin name; also recorded as ``classifier`` on every result."""
        return "mnli_taxonomy_classifier"

    @property
    def priority(self) -> int:
        """Plugin priority value."""
        return 50  # Lower priority than embedding (use --plugins mnli_taxonomy_classifier to enable)

    @property
    def capabilities(self) -> PluginCapability:
        """Capability flag declared by this plugin."""
        return PluginCapability.LLM_REQUIRED

    @property
    def description(self) -> str:
        """Human-readable description of the plugin."""
        return "Classifies statements against a taxonomy using MNLI zero-shot classification"

    @property
    def taxonomy_name(self) -> str:
        """Name of the taxonomy recorded on every result."""
        return "esg_topics"

    @property
    def taxonomy_schema(self) -> TaxonomySchema:
        """Schema listing the enabled categories and their label names.

        Accessing this loads the taxonomy file if it is not cached yet.
        """
        return TaxonomySchema(
            label_type="taxonomy",
            values=self._get_filtered_taxonomy(),
            description="ESG topic classification taxonomy",
            scope="statement",
        )

    @property
    def supported_categories(self) -> list[str]:
        """Copy of the enabled category names."""
        return list(self._categories)

    def _load_taxonomy(self) -> dict[str, dict[str, int]]:
        """Load taxonomy from JSON file, caching it on the instance.

        Returns:
            The parsed taxonomy (category -> {label -> id}, per the
            annotations used in this class).

        Raises:
            FileNotFoundError: If the taxonomy file does not exist.
        """
        if self._taxonomy is not None:
            return self._taxonomy

        try:
            # Explicit UTF-8: without it the decoding depends on the
            # platform's locale encoding.
            with open(self._taxonomy_path, encoding="utf-8") as f:
                self._taxonomy = json.load(f)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Taxonomy file not found: {self._taxonomy_path}") from e

        logger.debug("Loaded taxonomy with %d categories", len(self._taxonomy))
        return self._taxonomy

    def _get_classifier(self) -> MNLIClassifier:
        """Return the cached MNLIClassifier, creating it on first call."""
        if self._classifier is None:
            self._classifier = MNLIClassifier(model_id=self._model_id)
        return self._classifier

    def _get_filtered_taxonomy(self) -> dict[str, list[str]]:
        """Return category -> label names, restricted to the enabled categories."""
        taxonomy = self._load_taxonomy()
        return {
            cat: list(labels.keys())
            for cat, labels in taxonomy.items()
            if cat in self._categories
        }

    def classify(
        self,
        statement: PipelineStatement,
        subject_canonical: CanonicalEntity,
        object_canonical: CanonicalEntity,
        context: PipelineContext,
    ) -> list[TaxonomyResult]:
        """Classify statement against the taxonomy using MNLI.

        Only ``statement.source_text`` is classified; ``subject_canonical``,
        ``object_canonical`` and ``context`` are not read by this method.

        Returns all labels above the confidence threshold. Classification is
        best-effort: any exception is logged as a warning and whatever
        results were collected so far (possibly none) are returned.
        """
        results: list[TaxonomyResult] = []

        try:
            classifier = self._get_classifier()
            taxonomy = self._get_filtered_taxonomy()

            classifications = classifier.classify_hierarchical(
                statement.source_text,
                taxonomy,
                top_k_categories=self._top_k_categories,
                min_score=self._min_confidence,
            )

            for category, label, confidence in classifications:
                results.append(TaxonomyResult(
                    taxonomy_name=self.taxonomy_name,
                    category=category,
                    label=label,
                    label_id=self._get_label_id(category, label),
                    confidence=round(confidence, 4),
                    classifier=self.name,
                ))

        except Exception as e:
            # Deliberate best-effort: a failure here is logged, not raised.
            logger.warning("MNLI taxonomy classification failed: %s", e)

        return results

    def _get_label_id(self, category: str, label: str) -> Optional[int]:
        """Look up the id stored for ``label`` under ``category``.

        Returns None when the category or the label is not in the taxonomy.
        """
        taxonomy = self._load_taxonomy()
        if category in taxonomy:
            return taxonomy[category].get(label)
        return None
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
# For testing without decorator.
# NOTE(review): this alias is bound after @PluginRegistry.taxonomy has already
# been applied to the class, so it refers to whatever that decorator returned.
# Confirm the decorator returns the class unchanged if tests rely on this name
# being an undecorated class.
MNLITaxonomyClassifierClass = MNLITaxonomyClassifier
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
statement_extractor/__init__.py,sha256=KwZfWnTB9oevTLw0TrNlYFu67qIYO-34JqDtcpjOhZI,3013
|
|
2
|
-
statement_extractor/canonicalization.py,sha256=ZMLs6RLWJa_rOJ8XZ7PoHFU13-zeJkOMDnvK-ZaFa5s,5991
|
|
3
|
-
statement_extractor/cli.py,sha256=FOkkihVfoROc-Biu8ICCzlLJeDScYYNLLJHnv0GCGGM,9507
|
|
4
|
-
statement_extractor/extractor.py,sha256=d0HnCeCPybw-4jDxH_ffZ4LY9Klvqnza_wa90Bd4Q18,40074
|
|
5
|
-
statement_extractor/gliner_extraction.py,sha256=KNs3n5-fnoUwY1wvbPwZL8j-3YVstmioJlcjp2k1FmY,10491
|
|
6
|
-
statement_extractor/models.py,sha256=cyCQc3vlYB3qlg6-uL5Vt4odIiulKtHzz1Cyrf0lEAU,12198
|
|
7
|
-
statement_extractor/predicate_comparer.py,sha256=jcuaBi5BYqD3TKoyj3pR9dxtX5ihfDJvjdhEd2LHCwc,26184
|
|
8
|
-
statement_extractor/scoring.py,sha256=s_8nhavBNzPPFmGf2FyBummH4tgP7YGpXoMhl2Jh3Xw,16650
|
|
9
|
-
corp_extractor-0.4.0.dist-info/METADATA,sha256=8f2CDtZG757kaB6XMfbBVdNSRMyS5-4Lflc_LoZCC_8,17725
|
|
10
|
-
corp_extractor-0.4.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
11
|
-
corp_extractor-0.4.0.dist-info/entry_points.txt,sha256=i0iKFqPIusvb-QTQ1zNnFgAqatgVah-jIhahbs5TToQ,115
|
|
12
|
-
corp_extractor-0.4.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|