corp-extractor 0.4.0-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/METADATA +181 -64
  2. corp_extractor-0.5.0.dist-info/RECORD +55 -0
  3. statement_extractor/__init__.py +9 -0
  4. statement_extractor/cli.py +446 -17
  5. statement_extractor/data/default_predicates.json +368 -0
  6. statement_extractor/data/statement_taxonomy.json +1182 -0
  7. statement_extractor/extractor.py +1 -23
  8. statement_extractor/gliner_extraction.py +4 -74
  9. statement_extractor/llm.py +255 -0
  10. statement_extractor/models/__init__.py +74 -0
  11. statement_extractor/models/canonical.py +139 -0
  12. statement_extractor/models/entity.py +102 -0
  13. statement_extractor/models/labels.py +191 -0
  14. statement_extractor/models/qualifiers.py +91 -0
  15. statement_extractor/models/statement.py +75 -0
  16. statement_extractor/models.py +4 -1
  17. statement_extractor/pipeline/__init__.py +39 -0
  18. statement_extractor/pipeline/config.py +134 -0
  19. statement_extractor/pipeline/context.py +177 -0
  20. statement_extractor/pipeline/orchestrator.py +447 -0
  21. statement_extractor/pipeline/registry.py +297 -0
  22. statement_extractor/plugins/__init__.py +43 -0
  23. statement_extractor/plugins/base.py +446 -0
  24. statement_extractor/plugins/canonicalizers/__init__.py +17 -0
  25. statement_extractor/plugins/canonicalizers/base.py +9 -0
  26. statement_extractor/plugins/canonicalizers/location.py +219 -0
  27. statement_extractor/plugins/canonicalizers/organization.py +230 -0
  28. statement_extractor/plugins/canonicalizers/person.py +242 -0
  29. statement_extractor/plugins/extractors/__init__.py +13 -0
  30. statement_extractor/plugins/extractors/base.py +9 -0
  31. statement_extractor/plugins/extractors/gliner2.py +536 -0
  32. statement_extractor/plugins/labelers/__init__.py +29 -0
  33. statement_extractor/plugins/labelers/base.py +9 -0
  34. statement_extractor/plugins/labelers/confidence.py +138 -0
  35. statement_extractor/plugins/labelers/relation_type.py +87 -0
  36. statement_extractor/plugins/labelers/sentiment.py +159 -0
  37. statement_extractor/plugins/labelers/taxonomy.py +373 -0
  38. statement_extractor/plugins/labelers/taxonomy_embedding.py +466 -0
  39. statement_extractor/plugins/qualifiers/__init__.py +19 -0
  40. statement_extractor/plugins/qualifiers/base.py +9 -0
  41. statement_extractor/plugins/qualifiers/companies_house.py +174 -0
  42. statement_extractor/plugins/qualifiers/gleif.py +186 -0
  43. statement_extractor/plugins/qualifiers/person.py +221 -0
  44. statement_extractor/plugins/qualifiers/sec_edgar.py +198 -0
  45. statement_extractor/plugins/splitters/__init__.py +13 -0
  46. statement_extractor/plugins/splitters/base.py +9 -0
  47. statement_extractor/plugins/splitters/t5_gemma.py +188 -0
  48. statement_extractor/plugins/taxonomy/__init__.py +13 -0
  49. statement_extractor/plugins/taxonomy/embedding.py +337 -0
  50. statement_extractor/plugins/taxonomy/mnli.py +279 -0
  51. corp_extractor-0.4.0.dist-info/RECORD +0 -12
  52. {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/WHEEL +0 -0
  53. {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/entry_points.txt +0 -0
statement_extractor/plugins/taxonomy/mnli.py
@@ -0,0 +1,279 @@
+ """
+ MNLITaxonomyClassifier - Classifies statements using MNLI zero-shot classification.
+
+ Uses HuggingFace transformers zero-shot-classification pipeline for taxonomy labeling
+ where there are too many possible values for simple multi-choice classification.
+ """
+
+ import json
+ import logging
+ from pathlib import Path
+ from typing import Optional
+
+ from ..base import BaseTaxonomyPlugin, TaxonomySchema, PluginCapability
+ from ...pipeline.context import PipelineContext
+ from ...pipeline.registry import PluginRegistry
+ from ...models import (
+     PipelineStatement,
+     CanonicalEntity,
+     TaxonomyResult,
+ )
+
+ logger = logging.getLogger(__name__)
+
+ # Default taxonomy file location (relative to this module)
+ DEFAULT_TAXONOMY_PATH = Path(__file__).parent.parent.parent / "data" / "statement_taxonomy.json"
+
+ # Default categories to use (all of them)
+ DEFAULT_CATEGORIES = [
+     "environment",
+     "society",
+     "governance",
+     "animals",
+     "industry",
+     "human_harm",
+     "human_benefit",
+     "animal_harm",
+     "animal_benefit",
+     "environment_harm",
+     "environment_benefit",
+ ]
+
+
+ class MNLIClassifier:
+     """
+     MNLI-based zero-shot classifier for taxonomy labeling.
+
+     Uses HuggingFace transformers zero-shot-classification pipeline.
+     """
+
+     def __init__(
+         self,
+         model_id: str = "facebook/bart-large-mnli",
+         device: Optional[str] = None,
+     ):
+         self._model_id = model_id
+         self._device = device
+         self._classifier = None
+
+     def _load_classifier(self):
+         """Lazy-load the zero-shot classification pipeline."""
+         if self._classifier is not None:
+             return
+
+         try:
+             from transformers import pipeline
+             import torch
+
+             device = self._device
+             if device is None:
+                 if torch.cuda.is_available():
+                     device = "cuda"
+                 elif torch.backends.mps.is_available():
+                     device = "mps"
+                 else:
+                     device = "cpu"
+
+             logger.info(f"Loading MNLI classifier '{self._model_id}' on {device}...")
+             self._classifier = pipeline(
+                 "zero-shot-classification",
+                 model=self._model_id,
+                 device=device if device != "cpu" else -1,
+             )
+             logger.debug("MNLI classifier loaded")
+
+         except ImportError as e:
+             raise ImportError(
+                 "transformers is required for MNLI classification. "
+                 "Install with: pip install transformers"
+             ) from e
+
+     def classify_hierarchical(
+         self,
+         text: str,
+         taxonomy: dict[str, list[str]],
+         top_k_categories: int = 3,
+         min_score: float = 0.3,
+     ) -> list[tuple[str, str, float]]:
+         """
+         Hierarchical classification: first category, then labels within category.
+
+         Returns all labels above the threshold, not just the best match.
+
+         Args:
+             text: Text to classify
+             taxonomy: Dict mapping category -> list of labels
+             top_k_categories: Number of top categories to consider
+             min_score: Minimum combined score to include in results
+
+         Returns:
+             List of (category, label, confidence) tuples above threshold
+         """
+         self._load_classifier()
+
+         categories = list(taxonomy.keys())
+         cat_result = self._classifier(text, candidate_labels=categories)
+
+         top_categories = cat_result["labels"][:top_k_categories]
+         top_cat_scores = cat_result["scores"][:top_k_categories]
+
+         results: list[tuple[str, str, float]] = []
+
+         for cat, cat_score in zip(top_categories, top_cat_scores):
+             labels = taxonomy[cat]
+             if not labels:
+                 continue
+
+             label_result = self._classifier(text, candidate_labels=labels)
+
+             # Get all labels above threshold for this category
+             for label, label_score in zip(label_result["labels"], label_result["scores"]):
+                 combined_score = cat_score * label_score
+
+                 if combined_score >= min_score:
+                     results.append((cat, label, combined_score))
+
+         # Sort by confidence descending
+         results.sort(key=lambda x: x[2], reverse=True)
+         return results
+
+
+ @PluginRegistry.taxonomy
+ class MNLITaxonomyClassifier(BaseTaxonomyPlugin):
+     """
+     Taxonomy classifier using MNLI zero-shot classification.
+
+     Supports hierarchical classification for efficiency with large taxonomies.
+     """
+
+     def __init__(
+         self,
+         taxonomy_path: Optional[str | Path] = None,
+         categories: Optional[list[str]] = None,
+         model_id: str = "facebook/bart-large-mnli",
+         top_k_categories: int = 3,
+         min_confidence: float = 0.3,
+     ):
+         self._taxonomy_path = Path(taxonomy_path) if taxonomy_path else DEFAULT_TAXONOMY_PATH
+         self._categories = categories or DEFAULT_CATEGORIES
+         self._model_id = model_id
+         self._top_k_categories = top_k_categories
+         self._min_confidence = min_confidence
+
+         self._taxonomy: Optional[dict[str, dict[str, int]]] = None
+         self._classifier: Optional[MNLIClassifier] = None
+
+     @property
+     def name(self) -> str:
+         return "mnli_taxonomy_classifier"
+
+     @property
+     def priority(self) -> int:
+         return 50  # Lower priority than embedding (use --plugins mnli_taxonomy_classifier to enable)
+
+     @property
+     def capabilities(self) -> PluginCapability:
+         return PluginCapability.LLM_REQUIRED
+
+     @property
+     def description(self) -> str:
+         return "Classifies statements against a taxonomy using MNLI zero-shot classification"
+
+     @property
+     def taxonomy_name(self) -> str:
+         return "esg_topics"
+
+     @property
+     def taxonomy_schema(self) -> TaxonomySchema:
+         taxonomy = self._load_taxonomy()
+         filtered = {cat: list(labels.keys()) for cat, labels in taxonomy.items() if cat in self._categories}
+         return TaxonomySchema(
+             label_type="taxonomy",
+             values=filtered,
+             description="ESG topic classification taxonomy",
+             scope="statement",
+         )
+
+     @property
+     def supported_categories(self) -> list[str]:
+         return self._categories.copy()
+
+     def _load_taxonomy(self) -> dict[str, dict[str, int]]:
+         """Load taxonomy from JSON file."""
+         if self._taxonomy is not None:
+             return self._taxonomy
+
+         if not self._taxonomy_path.exists():
+             raise FileNotFoundError(f"Taxonomy file not found: {self._taxonomy_path}")
+
+         with open(self._taxonomy_path) as f:
+             self._taxonomy = json.load(f)
+
+         logger.debug(f"Loaded taxonomy with {len(self._taxonomy)} categories")
+         return self._taxonomy
+
+     def _get_classifier(self) -> MNLIClassifier:
+         if self._classifier is None:
+             self._classifier = MNLIClassifier(model_id=self._model_id)
+         return self._classifier
+
+     def _get_filtered_taxonomy(self) -> dict[str, list[str]]:
+         taxonomy = self._load_taxonomy()
+         return {
+             cat: list(labels.keys())
+             for cat, labels in taxonomy.items()
+             if cat in self._categories
+         }
+
+     def classify(
+         self,
+         statement: PipelineStatement,
+         subject_canonical: CanonicalEntity,
+         object_canonical: CanonicalEntity,
+         context: PipelineContext,
+     ) -> list[TaxonomyResult]:
+         """Classify statement against the taxonomy using MNLI.
+
+         Returns all labels above the confidence threshold.
+         """
+         results: list[TaxonomyResult] = []
+
+         try:
+             classifier = self._get_classifier()
+             taxonomy = self._get_filtered_taxonomy()
+
+             text = statement.source_text
+
+             classifications = classifier.classify_hierarchical(
+                 text,
+                 taxonomy,
+                 top_k_categories=self._top_k_categories,
+                 min_score=self._min_confidence,
+             )
+
+             for category, label, confidence in classifications:
+                 label_id = self._get_label_id(category, label)
+
+                 results.append(TaxonomyResult(
+                     taxonomy_name=self.taxonomy_name,
+                     category=category,
+                     label=label,
+                     label_id=label_id,
+                     confidence=round(confidence, 4),
+                     classifier=self.name,
+                 ))
+
+         except Exception as e:
+             logger.warning(f"MNLI taxonomy classification failed: {e}")
+
+         return results
+
+     def _get_label_id(self, category: str, label: str) -> Optional[int]:
+         taxonomy = self._load_taxonomy()
+         if category in taxonomy:
+             return taxonomy[category].get(label)
+         return None
+
+
+ # For testing without decorator
+ MNLITaxonomyClassifierClass = MNLITaxonomyClassifier
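
For a sense of how the new hierarchical classifier is driven, here is a minimal usage sketch (not part of the diff). It assumes corp-extractor 0.5.0 is installed along with transformers and torch, and it substitutes a small made-up taxonomy for the bundled data/statement_taxonomy.json:

    # Hypothetical usage of the MNLIClassifier helper added in 0.5.0.
    from statement_extractor.plugins.taxonomy.mnli import MNLIClassifier

    classifier = MNLIClassifier(model_id="facebook/bart-large-mnli")

    # Toy taxonomy: category -> candidate labels. The real plugin loads
    # statement_extractor/data/statement_taxonomy.json and also tracks label ids.
    taxonomy = {
        "environment": ["carbon emissions", "deforestation", "water pollution"],
        "governance": ["board independence", "executive compensation"],
    }

    # Stage 1 scores the categories, stage 2 scores labels within the top
    # categories; each result's confidence is category_score * label_score.
    results = classifier.classify_hierarchical(
        "The company pledged to cut its carbon emissions by 40% by 2030.",
        taxonomy,
        top_k_categories=2,
        min_score=0.3,
    )
    for category, label, confidence in results:
        print(f"{category} / {label}: {confidence:.3f}")
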
corp_extractor-0.4.0.dist-info/RECORD
@@ -1,12 +0,0 @@
- statement_extractor/__init__.py,sha256=KwZfWnTB9oevTLw0TrNlYFu67qIYO-34JqDtcpjOhZI,3013
- statement_extractor/canonicalization.py,sha256=ZMLs6RLWJa_rOJ8XZ7PoHFU13-zeJkOMDnvK-ZaFa5s,5991
- statement_extractor/cli.py,sha256=FOkkihVfoROc-Biu8ICCzlLJeDScYYNLLJHnv0GCGGM,9507
- statement_extractor/extractor.py,sha256=d0HnCeCPybw-4jDxH_ffZ4LY9Klvqnza_wa90Bd4Q18,40074
- statement_extractor/gliner_extraction.py,sha256=KNs3n5-fnoUwY1wvbPwZL8j-3YVstmioJlcjp2k1FmY,10491
- statement_extractor/models.py,sha256=cyCQc3vlYB3qlg6-uL5Vt4odIiulKtHzz1Cyrf0lEAU,12198
- statement_extractor/predicate_comparer.py,sha256=jcuaBi5BYqD3TKoyj3pR9dxtX5ihfDJvjdhEd2LHCwc,26184
- statement_extractor/scoring.py,sha256=s_8nhavBNzPPFmGf2FyBummH4tgP7YGpXoMhl2Jh3Xw,16650
- corp_extractor-0.4.0.dist-info/METADATA,sha256=8f2CDtZG757kaB6XMfbBVdNSRMyS5-4Lflc_LoZCC_8,17725
- corp_extractor-0.4.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- corp_extractor-0.4.0.dist-info/entry_points.txt,sha256=i0iKFqPIusvb-QTQ1zNnFgAqatgVah-jIhahbs5TToQ,115
- corp_extractor-0.4.0.dist-info/RECORD,,