corp-extractor 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/METADATA +181 -64
  2. corp_extractor-0.5.0.dist-info/RECORD +55 -0
  3. statement_extractor/__init__.py +9 -0
  4. statement_extractor/cli.py +446 -17
  5. statement_extractor/data/default_predicates.json +368 -0
  6. statement_extractor/data/statement_taxonomy.json +1182 -0
  7. statement_extractor/extractor.py +1 -23
  8. statement_extractor/gliner_extraction.py +4 -74
  9. statement_extractor/llm.py +255 -0
  10. statement_extractor/models/__init__.py +74 -0
  11. statement_extractor/models/canonical.py +139 -0
  12. statement_extractor/models/entity.py +102 -0
  13. statement_extractor/models/labels.py +191 -0
  14. statement_extractor/models/qualifiers.py +91 -0
  15. statement_extractor/models/statement.py +75 -0
  16. statement_extractor/models.py +4 -1
  17. statement_extractor/pipeline/__init__.py +39 -0
  18. statement_extractor/pipeline/config.py +134 -0
  19. statement_extractor/pipeline/context.py +177 -0
  20. statement_extractor/pipeline/orchestrator.py +447 -0
  21. statement_extractor/pipeline/registry.py +297 -0
  22. statement_extractor/plugins/__init__.py +43 -0
  23. statement_extractor/plugins/base.py +446 -0
  24. statement_extractor/plugins/canonicalizers/__init__.py +17 -0
  25. statement_extractor/plugins/canonicalizers/base.py +9 -0
  26. statement_extractor/plugins/canonicalizers/location.py +219 -0
  27. statement_extractor/plugins/canonicalizers/organization.py +230 -0
  28. statement_extractor/plugins/canonicalizers/person.py +242 -0
  29. statement_extractor/plugins/extractors/__init__.py +13 -0
  30. statement_extractor/plugins/extractors/base.py +9 -0
  31. statement_extractor/plugins/extractors/gliner2.py +536 -0
  32. statement_extractor/plugins/labelers/__init__.py +29 -0
  33. statement_extractor/plugins/labelers/base.py +9 -0
  34. statement_extractor/plugins/labelers/confidence.py +138 -0
  35. statement_extractor/plugins/labelers/relation_type.py +87 -0
  36. statement_extractor/plugins/labelers/sentiment.py +159 -0
  37. statement_extractor/plugins/labelers/taxonomy.py +373 -0
  38. statement_extractor/plugins/labelers/taxonomy_embedding.py +466 -0
  39. statement_extractor/plugins/qualifiers/__init__.py +19 -0
  40. statement_extractor/plugins/qualifiers/base.py +9 -0
  41. statement_extractor/plugins/qualifiers/companies_house.py +174 -0
  42. statement_extractor/plugins/qualifiers/gleif.py +186 -0
  43. statement_extractor/plugins/qualifiers/person.py +221 -0
  44. statement_extractor/plugins/qualifiers/sec_edgar.py +198 -0
  45. statement_extractor/plugins/splitters/__init__.py +13 -0
  46. statement_extractor/plugins/splitters/base.py +9 -0
  47. statement_extractor/plugins/splitters/t5_gemma.py +188 -0
  48. statement_extractor/plugins/taxonomy/__init__.py +13 -0
  49. statement_extractor/plugins/taxonomy/embedding.py +337 -0
  50. statement_extractor/plugins/taxonomy/mnli.py +279 -0
  51. corp_extractor-0.4.0.dist-info/RECORD +0 -12
  52. {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/WHEEL +0 -0
  53. {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,373 @@
1
+ """
2
+ TaxonomyLabeler - Classifies statements against a large taxonomy using MNLI.
3
+
4
+ Uses zero-shot classification with MNLI models for taxonomy labeling where
5
+ there are too many possible values for simple multi-choice classification.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Optional
12
+
13
+ from ..base import BaseLabelerPlugin, TaxonomySchema, PluginCapability
14
+ from ...pipeline.context import PipelineContext
15
+ from ...models import (
16
+ PipelineStatement,
17
+ CanonicalEntity,
18
+ StatementLabel,
19
+ )
20
+
21
logger = logging.getLogger(__name__)

# Bundled taxonomy file, resolved relative to this module: the module's
# directory (labelers) -> plugins -> package root, then data/.
DEFAULT_TAXONOMY_PATH = Path(__file__).parent.parent.parent.joinpath(
    "data", "statement_taxonomy.json"
)

# Categories used when the caller does not select any (intended to be all of them).
DEFAULT_CATEGORIES = [
    # Broad subject areas
    "environment",
    "society",
    "governance",
    "animals",
    "industry",
    # Harm / benefit variants per affected group
    "human_harm",
    "human_benefit",
    "animal_harm",
    "animal_benefit",
    "environment_harm",
    "environment_benefit",
]
40
+
41
+
42
class TaxonomyClassifier:
    """
    MNLI-based zero-shot classifier for taxonomy labeling.

    Wraps the HuggingFace transformers ``zero-shot-classification`` pipeline.
    The pipeline (and therefore ``transformers``) is only loaded on first use,
    so constructing this object is cheap.
    """

    def __init__(
        self,
        model_id: str = "facebook/bart-large-mnli",
        device: Optional[str] = None,
    ):
        """
        Initialize the classifier.

        Args:
            model_id: HuggingFace model ID for MNLI classification
            device: Device to use ('cuda', 'mps', 'cpu', or None for auto)
        """
        self._model_id = model_id
        self._device = device
        self._classifier = None  # created lazily by _load_classifier()

    @staticmethod
    def _detect_device() -> str:
        """Return the best available torch device name: 'cuda', then 'mps', then 'cpu'."""
        import torch

        if torch.cuda.is_available():
            return "cuda"
        # Older torch builds have no torch.backends.mps at all; treat that as
        # "not available" instead of failing with AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def _load_classifier(self) -> None:
        """
        Lazy-load the zero-shot classification pipeline (no-op once loaded).

        Raises:
            ImportError: If transformers (or torch, when the device must be
                auto-detected) cannot be imported.
        """
        if self._classifier is not None:
            return

        # Only the imports are wrapped: an ImportError raised later while the
        # pipeline itself is built keeps its own, more specific message.
        try:
            from transformers import pipeline

            device = self._device if self._device is not None else self._detect_device()
        except ImportError as e:
            raise ImportError(
                "transformers is required for MNLI classification. "
                "Install with: pip install transformers"
            ) from e

        logger.info("Loading MNLI classifier '%s' on %s...", self._model_id, device)
        self._classifier = pipeline(
            "zero-shot-classification",
            model=self._model_id,
            # -1 is the pipeline's device index for CPU; other names pass through.
            device=device if device != "cpu" else -1,
        )
        logger.debug("MNLI classifier loaded")

    def classify(
        self,
        text: str,
        candidate_labels: list[str],
        multi_label: bool = False,
    ) -> tuple[str, float]:
        """
        Classify text against candidate labels using MNLI.

        Args:
            text: Text to classify
            candidate_labels: List of possible labels (must be non-empty)
            multi_label: Whether multiple labels can apply

        Returns:
            Tuple of (best_label, confidence) for the highest-ranked label.

        Raises:
            ValueError: If ``candidate_labels`` is empty.
        """
        if not candidate_labels:
            raise ValueError("candidate_labels must contain at least one label")

        self._load_classifier()

        result = self._classifier(
            text,
            candidate_labels=candidate_labels,
            multi_label=multi_label,
        )

        # Result format: {'sequence': '...', 'labels': [...], 'scores': [...]},
        # best label first.
        return result["labels"][0], result["scores"][0]

    def classify_hierarchical(
        self,
        text: str,
        taxonomy: dict[str, list[str]],
        top_k_categories: int = 3,
    ) -> tuple[Optional[str], Optional[str], float]:
        """
        Hierarchical classification: first category, then label within category.

        More efficient than flat classification for large taxonomies. Each
        (category, label) candidate is scored as
        ``category_score * label_score`` and the highest-scoring pair wins.

        Args:
            text: Text to classify
            taxonomy: Dict mapping category -> list of labels
            top_k_categories: Number of top categories to consider

        Returns:
            Tuple of (category, label, confidence), or ``(None, None, 0.0)``
            when no category has any labels or no pair scores above 0.
        """
        # Categories without labels can never produce a result. Excluding them
        # before the category step stops them from occupying top-k slots and
        # from absorbing probability mass in the normalised category scores.
        categories = [cat for cat, labels in taxonomy.items() if labels]
        if not categories:
            return None, None, 0.0

        self._load_classifier()

        # Step 1: rank the categories.
        cat_result = self._classifier(text, candidate_labels=categories)
        top_categories = cat_result["labels"][:top_k_categories]
        top_cat_scores = cat_result["scores"][:top_k_categories]

        # Step 2: find the best label within each top category.
        best_category: Optional[str] = None
        best_label: Optional[str] = None
        best_score = 0.0

        for cat, cat_score in zip(top_categories, top_cat_scores):
            label_result = self._classifier(text, candidate_labels=taxonomy[cat])
            # Combined score: category confidence * label confidence
            combined_score = cat_score * label_result["scores"][0]
            if combined_score > best_score:
                best_score = combined_score
                best_label = label_result["labels"][0]
                best_category = cat

        return best_category, best_label, best_score
181
+
182
+
183
class TaxonomyLabeler(BaseLabelerPlugin):
    """
    Labeler that classifies statements against a large taxonomy using MNLI.

    Two modes are supported:

    * hierarchical (default): rank categories first, then pick the best label
      inside the top categories; the label value is ``"category:label"``.
    * flat: classify against every label of the selected categories at once.

    The taxonomy JSON maps ``category -> {label: numeric_id}``; the numeric ID
    of the chosen label is reported in the label metadata.
    """

    def __init__(
        self,
        taxonomy_path: Optional[str | Path] = None,
        categories: Optional[list[str]] = None,
        model_id: str = "facebook/bart-large-mnli",
        use_hierarchical: bool = True,
        top_k_categories: int = 3,
        min_confidence: float = 0.3,
    ):
        """
        Initialize the taxonomy labeler.

        Args:
            taxonomy_path: Path to taxonomy JSON file (default: built-in taxonomy)
            categories: List of categories to use (default: all categories)
            model_id: HuggingFace model ID for MNLI classifier
            use_hierarchical: Use hierarchical classification (category then label)
            top_k_categories: Number of top categories to consider in hierarchical mode
            min_confidence: Minimum confidence threshold for returning a label
        """
        self._taxonomy_path = Path(taxonomy_path) if taxonomy_path else DEFAULT_TAXONOMY_PATH
        # Copy, so later mutation of the caller's list or of the module-level
        # DEFAULT_CATEGORIES cannot silently change this labeler's selection.
        self._categories = list(categories) if categories else list(DEFAULT_CATEGORIES)
        self._model_id = model_id
        self._use_hierarchical = use_hierarchical
        self._top_k_categories = top_k_categories
        self._min_confidence = min_confidence

        self._taxonomy: Optional[dict[str, dict[str, int]]] = None  # loaded lazily
        self._classifier: Optional[TaxonomyClassifier] = None  # created lazily

    @property
    def name(self) -> str:
        return "taxonomy_labeler"

    @property
    def priority(self) -> int:
        return 60  # Lower priority than embedding taxonomy (use --plugins taxonomy_labeler to enable)

    @property
    def capabilities(self) -> PluginCapability:
        return PluginCapability.LLM_REQUIRED

    @property
    def description(self) -> str:
        return "Classifies statements against a taxonomy using MNLI zero-shot classification"

    @property
    def label_type(self) -> str:
        return "taxonomy"

    @property
    def taxonomy_schema(self) -> TaxonomySchema:
        """Provide taxonomy schema (for documentation/introspection)."""
        return TaxonomySchema(
            label_type=self.label_type,
            values=self._get_filtered_taxonomy(),
            description="Statement topic classification against corporate ESG taxonomy",
            scope="statement",
        )

    def _load_taxonomy(self) -> dict[str, dict[str, int]]:
        """
        Load the taxonomy JSON file, caching it after the first call.

        Raises:
            FileNotFoundError: If the taxonomy file does not exist.
        """
        if self._taxonomy is not None:
            return self._taxonomy

        if not self._taxonomy_path.exists():
            raise FileNotFoundError(f"Taxonomy file not found: {self._taxonomy_path}")

        # Decode explicitly as UTF-8 rather than the platform's locale default,
        # which can misdecode non-ASCII labels (e.g. on Windows).
        with open(self._taxonomy_path, encoding="utf-8") as f:
            self._taxonomy = json.load(f)

        logger.debug(f"Loaded taxonomy with {len(self._taxonomy)} categories")
        return self._taxonomy

    def _get_classifier(self) -> TaxonomyClassifier:
        """Get or create the MNLI classifier."""
        if self._classifier is None:
            self._classifier = TaxonomyClassifier(model_id=self._model_id)
        return self._classifier

    def _get_filtered_taxonomy(self) -> dict[str, list[str]]:
        """Get taxonomy filtered to selected categories, with label names only."""
        taxonomy = self._load_taxonomy()
        return {
            cat: list(labels.keys())
            for cat, labels in taxonomy.items()
            if cat in self._categories
        }

    def label(
        self,
        statement: PipelineStatement,
        subject_canonical: CanonicalEntity,
        object_canonical: CanonicalEntity,
        context: PipelineContext,
    ) -> Optional[StatementLabel]:
        """
        Classify statement against the taxonomy.

        A classification already stored on ``context`` for this source text is
        used when present; otherwise MNLI classification is run. Failures are
        logged and yield ``None`` (best-effort).

        Args:
            statement: The statement to label
            subject_canonical: Canonicalized subject (unused here)
            object_canonical: Canonicalized object (unused here)
            context: Pipeline context

        Returns:
            StatementLabel with taxonomy classification, or None if below threshold
        """
        # Check for pre-computed classification from extractor
        result = context.get_classification(statement.source_text, self.label_type)
        if result:
            label_value, confidence = result
            if confidence >= self._min_confidence:
                return StatementLabel(
                    label_type=self.label_type,
                    label_value=label_value,
                    confidence=confidence,
                    labeler=self.name,
                )
            return None

        # Run MNLI classification
        try:
            classifier = self._get_classifier()
            taxonomy = self._get_filtered_taxonomy()
            text = statement.source_text

            # Only known in hierarchical mode; None makes _get_label_id search
            # every category.
            category: Optional[str] = None
            if self._use_hierarchical:
                category, label, confidence = classifier.classify_hierarchical(
                    text,
                    taxonomy,
                    top_k_categories=self._top_k_categories,
                )
                # Include category in label for clarity
                full_label = f"{category}:{label}" if category and label else None
            else:
                # Flat classification across all labels. De-duplicate labels
                # shared by several categories (order-preserving): in
                # single-label zero-shot mode scores are normalised across the
                # candidates, so a repeated label would split its own score.
                all_labels = list(dict.fromkeys(
                    lbl for labels in taxonomy.values() for lbl in labels
                ))
                label, confidence = classifier.classify(text, all_labels)
                full_label = label

            if full_label and confidence >= self._min_confidence:
                # Get the numeric ID for reproducibility
                label_id = self._get_label_id(category, label)
                metadata = {"label_id": label_id}
                if self._use_hierarchical:
                    metadata["category"] = category

                return StatementLabel(
                    label_type=self.label_type,
                    label_value=full_label,
                    confidence=confidence,
                    labeler=self.name,
                    metadata=metadata,
                )

        except Exception as e:
            # Best-effort: log and leave the statement unlabelled.
            logger.warning(f"Taxonomy classification failed: {e}")
            logger.debug("Taxonomy classification traceback", exc_info=True)

        return None

    def _get_label_id(self, category: Optional[str], label: str) -> Optional[int]:
        """
        Get the numeric ID for a label.

        Looks in ``category`` when it is given and known; otherwise returns the
        ID from the first category (in taxonomy order) containing ``label``.
        """
        taxonomy = self._load_taxonomy()

        if category and category in taxonomy:
            return taxonomy[category].get(label)

        # Search all categories for flat classification
        for cat_labels in taxonomy.values():
            if label in cat_labels:
                return cat_labels[label]

        return None
369
+
370
+
371
# Plain aliases so tests can import the classes "without decorator".
# NOTE(review): no decorator is applied to either class in this module;
# confirm whether these aliases are still needed.
TaxonomyLabelerClass, TaxonomyClassifierClass = TaxonomyLabeler, TaxonomyClassifier