corp-extractor 0.4.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/METADATA +348 -64
  2. corp_extractor-0.9.0.dist-info/RECORD +76 -0
  3. statement_extractor/__init__.py +10 -1
  4. statement_extractor/cli.py +1663 -17
  5. statement_extractor/data/default_predicates.json +368 -0
  6. statement_extractor/data/statement_taxonomy.json +6972 -0
  7. statement_extractor/database/__init__.py +52 -0
  8. statement_extractor/database/embeddings.py +186 -0
  9. statement_extractor/database/hub.py +520 -0
  10. statement_extractor/database/importers/__init__.py +24 -0
  11. statement_extractor/database/importers/companies_house.py +545 -0
  12. statement_extractor/database/importers/gleif.py +538 -0
  13. statement_extractor/database/importers/sec_edgar.py +375 -0
  14. statement_extractor/database/importers/wikidata.py +1012 -0
  15. statement_extractor/database/importers/wikidata_people.py +632 -0
  16. statement_extractor/database/models.py +230 -0
  17. statement_extractor/database/resolver.py +245 -0
  18. statement_extractor/database/store.py +1609 -0
  19. statement_extractor/document/__init__.py +62 -0
  20. statement_extractor/document/chunker.py +410 -0
  21. statement_extractor/document/context.py +171 -0
  22. statement_extractor/document/deduplicator.py +173 -0
  23. statement_extractor/document/html_extractor.py +246 -0
  24. statement_extractor/document/loader.py +303 -0
  25. statement_extractor/document/pipeline.py +388 -0
  26. statement_extractor/document/summarizer.py +195 -0
  27. statement_extractor/extractor.py +1 -23
  28. statement_extractor/gliner_extraction.py +4 -74
  29. statement_extractor/llm.py +255 -0
  30. statement_extractor/models/__init__.py +89 -0
  31. statement_extractor/models/canonical.py +182 -0
  32. statement_extractor/models/document.py +308 -0
  33. statement_extractor/models/entity.py +102 -0
  34. statement_extractor/models/labels.py +220 -0
  35. statement_extractor/models/qualifiers.py +139 -0
  36. statement_extractor/models/statement.py +101 -0
  37. statement_extractor/models.py +4 -1
  38. statement_extractor/pipeline/__init__.py +39 -0
  39. statement_extractor/pipeline/config.py +129 -0
  40. statement_extractor/pipeline/context.py +177 -0
  41. statement_extractor/pipeline/orchestrator.py +416 -0
  42. statement_extractor/pipeline/registry.py +303 -0
  43. statement_extractor/plugins/__init__.py +55 -0
  44. statement_extractor/plugins/base.py +716 -0
  45. statement_extractor/plugins/extractors/__init__.py +13 -0
  46. statement_extractor/plugins/extractors/base.py +9 -0
  47. statement_extractor/plugins/extractors/gliner2.py +546 -0
  48. statement_extractor/plugins/labelers/__init__.py +29 -0
  49. statement_extractor/plugins/labelers/base.py +9 -0
  50. statement_extractor/plugins/labelers/confidence.py +138 -0
  51. statement_extractor/plugins/labelers/relation_type.py +87 -0
  52. statement_extractor/plugins/labelers/sentiment.py +159 -0
  53. statement_extractor/plugins/labelers/taxonomy.py +386 -0
  54. statement_extractor/plugins/labelers/taxonomy_embedding.py +477 -0
  55. statement_extractor/plugins/pdf/__init__.py +10 -0
  56. statement_extractor/plugins/pdf/pypdf.py +291 -0
  57. statement_extractor/plugins/qualifiers/__init__.py +30 -0
  58. statement_extractor/plugins/qualifiers/base.py +9 -0
  59. statement_extractor/plugins/qualifiers/companies_house.py +185 -0
  60. statement_extractor/plugins/qualifiers/embedding_company.py +420 -0
  61. statement_extractor/plugins/qualifiers/gleif.py +197 -0
  62. statement_extractor/plugins/qualifiers/person.py +785 -0
  63. statement_extractor/plugins/qualifiers/sec_edgar.py +209 -0
  64. statement_extractor/plugins/scrapers/__init__.py +10 -0
  65. statement_extractor/plugins/scrapers/http.py +236 -0
  66. statement_extractor/plugins/splitters/__init__.py +13 -0
  67. statement_extractor/plugins/splitters/base.py +9 -0
  68. statement_extractor/plugins/splitters/t5_gemma.py +293 -0
  69. statement_extractor/plugins/taxonomy/__init__.py +13 -0
  70. statement_extractor/plugins/taxonomy/embedding.py +484 -0
  71. statement_extractor/plugins/taxonomy/mnli.py +291 -0
  72. statement_extractor/scoring.py +8 -8
  73. corp_extractor-0.4.0.dist-info/RECORD +0 -12
  74. {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/WHEEL +0 -0
  75. {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,386 @@
1
+ """
2
+ TaxonomyLabeler - Classifies statements against a large taxonomy using MNLI.
3
+
4
+ Uses zero-shot classification with MNLI models for taxonomy labeling where
5
+ there are too many possible values for simple multi-choice classification.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Optional, TypedDict
12
+
13
+ from ..base import BaseLabelerPlugin, TaxonomySchema, PluginCapability
14
+
15
+
16
# Shape of one label entry in the taxonomy JSON (category -> {label name -> entry}).
TaxonomyEntry = TypedDict(
    "TaxonomyEntry",
    {
        "description": str,  # free-text description of the label
        "id": int,  # numeric label ID, reported as "label_id" in label metadata
        "mnli_label": str,  # presumably the phrasing intended for MNLI scoring — confirm
        "embedding_label": str,  # presumably the phrasing intended for embedding lookup — confirm
    },
)
TaxonomyEntry.__doc__ = "Structure for each taxonomy label entry."
22
+
23
+
24
+ from ...pipeline.context import PipelineContext
25
+ from ...models import (
26
+ PipelineStatement,
27
+ CanonicalEntity,
28
+ StatementLabel,
29
+ )
30
+
31
logger = logging.getLogger(__name__)

# Built-in taxonomy shipped with the package, located three directory levels
# above this module: <statement_extractor>/data/statement_taxonomy.json.
DEFAULT_TAXONOMY_PATH = Path(__file__).parent.parent.parent.joinpath(
    "data", "statement_taxonomy.json"
)

# Categories used when the caller does not select any (the author's intent is
# "all of them").
DEFAULT_CATEGORIES = [
    # Topic categories
    "environment", "society", "governance", "animals", "industry",
    # *_harm / *_benefit categories, grouped by subject
    "human_harm", "human_benefit",
    "animal_harm", "animal_benefit",
    "environment_harm", "environment_benefit",
]
50
+
51
+
52
class TaxonomyClassifier:
    """
    MNLI-based zero-shot classifier for taxonomy labeling.

    Wraps the HuggingFace transformers ``zero-shot-classification`` pipeline.
    The pipeline is created lazily on the first classification call, so
    constructing this object does not load a model.
    """

    def __init__(
        self,
        model_id: str = "facebook/bart-large-mnli",
        device: Optional[str] = None,
    ):
        """
        Initialize the classifier.

        Args:
            model_id: HuggingFace model ID for MNLI classification
            device: Device to use ('cuda', 'mps', 'cpu', or None for auto)
        """
        self._model_id = model_id
        self._device = device
        self._classifier = None

    def _load_classifier(self):
        """
        Lazy-load the zero-shot classification pipeline (no-op once loaded).

        Raises:
            ImportError: If transformers or torch cannot be imported.
        """
        if self._classifier is not None:
            return

        # Keep the try narrow: only a failed import of the libraries themselves
        # is reported as "transformers is required"; ImportErrors raised while
        # building the pipeline propagate with their own, more specific message.
        try:
            from transformers import pipeline
            import torch
        except ImportError as e:
            raise ImportError(
                "transformers is required for MNLI classification. "
                "Install with: pip install transformers"
            ) from e

        # Auto-detect device if not specified
        device = self._device
        if device is None:
            # getattr guard: torch.backends.mps may not exist on older torch releases.
            mps_backend = getattr(torch.backends, "mps", None)
            if torch.cuda.is_available():
                device = "cuda"
            elif mps_backend is not None and mps_backend.is_available():
                device = "mps"
            else:
                device = "cpu"

        logger.info("Loading MNLI classifier '%s' on %s...", self._model_id, device)
        self._classifier = pipeline(
            "zero-shot-classification",
            model=self._model_id,
            # -1 selects CPU in the pipeline API; other device strings pass through.
            device=device if device != "cpu" else -1,
        )
        logger.debug("MNLI classifier loaded")

    def classify(
        self,
        text: str,
        candidate_labels: list[str],
        multi_label: bool = False,
    ) -> tuple[str, float]:
        """
        Classify text against candidate labels using MNLI.

        Args:
            text: Text to classify
            candidate_labels: List of possible labels (must be non-empty)
            multi_label: Whether multiple labels can apply

        Returns:
            Tuple of (best_label, confidence) for the highest-ranked label

        Raises:
            ValueError: If ``candidate_labels`` is empty.
        """
        # Fail fast instead of loading a large model only to error inside it.
        if not candidate_labels:
            raise ValueError("candidate_labels must contain at least one label")

        self._load_classifier()

        result = self._classifier(
            text,
            candidate_labels=candidate_labels,
            multi_label=multi_label,
        )

        # Result format: {'sequence': '...', 'labels': [...], 'scores': [...]};
        # index 0 is the top-ranked label.
        return result["labels"][0], result["scores"][0]

    def classify_hierarchical(
        self,
        text: str,
        taxonomy: dict[str, list[str]],
        top_k_categories: int = 3,
    ) -> tuple[Optional[str], Optional[str], float]:
        """
        Hierarchical classification: first category, then label within category.

        More efficient than flat classification for large taxonomies: the text
        is scored against the category names, then against the labels of each
        of the ``top_k_categories`` best categories. The winner is the pair that
        maximizes ``category_score * label_score``.

        Args:
            text: Text to classify
            taxonomy: Dict mapping category -> list of labels
            top_k_categories: Number of top categories to consider

        Returns:
            Tuple of (category, label, combined_confidence), or
            ``(None, None, 0.0)`` when no category has any labels.
        """
        # Categories without labels can never yield a result; ranking them would
        # let them occupy top-k slots and crowd out categories that can.
        categories = [cat for cat, labels in taxonomy.items() if labels]
        if not categories:
            return None, None, 0.0

        self._load_classifier()

        # Step 1: rank categories and keep the top-k with their scores.
        cat_result = self._classifier(text, candidate_labels=categories)
        top_categories = zip(
            cat_result["labels"][:top_k_categories],
            cat_result["scores"][:top_k_categories],
        )

        # Step 2: classify within each top category, keeping the best combined score.
        best_category: Optional[str] = None
        best_label: Optional[str] = None
        best_score = 0.0

        for cat, cat_score in top_categories:
            label_result = self._classifier(text, candidate_labels=taxonomy[cat])
            # Combined score: category confidence * label confidence
            combined_score = cat_score * label_result["scores"][0]
            if combined_score > best_score:
                best_score = combined_score
                best_label = label_result["labels"][0]
                best_category = cat

        return best_category, best_label, best_score
191
+
192
+
193
class TaxonomyLabeler(BaseLabelerPlugin):
    """
    Labeler that classifies statements against a large taxonomy using MNLI.

    Supports hierarchical classification (category first, then label within
    the best categories) for efficiency with large taxonomies, or a flat mode
    that scores every label of the selected categories in one call.
    """

    def __init__(
        self,
        taxonomy_path: Optional[str | Path] = None,
        categories: Optional[list[str]] = None,
        model_id: str = "facebook/bart-large-mnli",
        use_hierarchical: bool = True,
        top_k_categories: int = 3,
        min_confidence: float = 0.3,
    ):
        """
        Initialize the taxonomy labeler.

        Nothing is loaded here: the taxonomy file and the MNLI model are both
        loaded lazily on first use.

        Args:
            taxonomy_path: Path to taxonomy JSON file (default: built-in taxonomy)
            categories: List of categories to use (default: all categories)
            model_id: HuggingFace model ID for MNLI classifier
            use_hierarchical: Use hierarchical classification (category then label)
            top_k_categories: Number of top categories to consider in hierarchical mode
            min_confidence: Minimum confidence threshold for returning a label
        """
        self._taxonomy_path = Path(taxonomy_path) if taxonomy_path else DEFAULT_TAXONOMY_PATH
        self._categories = categories or DEFAULT_CATEGORIES
        self._model_id = model_id
        self._use_hierarchical = use_hierarchical
        self._top_k_categories = top_k_categories
        self._min_confidence = min_confidence

        # Lazily populated caches.
        self._taxonomy: Optional[dict[str, dict[str, TaxonomyEntry]]] = None
        self._classifier: Optional[TaxonomyClassifier] = None

    @property
    def name(self) -> str:
        return "taxonomy_labeler"

    @property
    def priority(self) -> int:
        return 60  # Lower priority than embedding taxonomy (use --plugins taxonomy_labeler to enable)

    @property
    def capabilities(self) -> PluginCapability:
        # NOTE(review): the backing model is an MNLI classifier, not an LLM —
        # confirm LLM_REQUIRED is the intended capability flag.
        return PluginCapability.LLM_REQUIRED

    @property
    def description(self) -> str:
        return "Classifies statements against a taxonomy using MNLI zero-shot classification"

    @property
    def label_type(self) -> str:
        return "taxonomy"

    @property
    def taxonomy_schema(self) -> TaxonomySchema:
        """Provide taxonomy schema (for documentation/introspection)."""
        # Same category filtering as used for classification.
        return TaxonomySchema(
            label_type=self.label_type,
            values=self._get_filtered_taxonomy(),
            description="Statement topic classification against corporate ESG taxonomy",
            scope="statement",
        )

    def _load_taxonomy(self) -> dict[str, dict[str, TaxonomyEntry]]:
        """
        Load (and cache) the taxonomy from its JSON file.

        Returns:
            Mapping of category -> {label name -> entry}.

        Raises:
            FileNotFoundError: If the taxonomy file does not exist.
        """
        if self._taxonomy is not None:
            return self._taxonomy

        if not self._taxonomy_path.exists():
            raise FileNotFoundError(f"Taxonomy file not found: {self._taxonomy_path}")

        # Explicit UTF-8 so loading does not depend on the platform's default encoding.
        with open(self._taxonomy_path, encoding="utf-8") as f:
            self._taxonomy = json.load(f)

        logger.debug("Loaded taxonomy with %d categories", len(self._taxonomy))
        return self._taxonomy

    def _get_classifier(self) -> TaxonomyClassifier:
        """Get or create the MNLI classifier."""
        if self._classifier is None:
            self._classifier = TaxonomyClassifier(model_id=self._model_id)
        return self._classifier

    def _get_filtered_taxonomy(self) -> dict[str, list[str]]:
        """Get taxonomy filtered to selected categories, with label names only."""
        taxonomy = self._load_taxonomy()
        return {
            cat: list(labels.keys())
            for cat, labels in taxonomy.items()
            if cat in self._categories
        }

    def label(
        self,
        statement: PipelineStatement,
        subject_canonical: CanonicalEntity,
        object_canonical: CanonicalEntity,
        context: PipelineContext,
    ) -> Optional[StatementLabel]:
        """
        Classify statement against the taxonomy.

        A classification already stored in ``context`` for this statement's
        source text is used in preference to running the MNLI model.

        Args:
            statement: The statement to label
            subject_canonical: Canonicalized subject (unused)
            object_canonical: Canonicalized object (unused)
            context: Pipeline context

        Returns:
            StatementLabel with taxonomy classification, or None if below
            threshold or if classification failed (failures are logged).
        """
        # Check for pre-computed classification from extractor
        precomputed = context.get_classification(statement.source_text, self.label_type)
        if precomputed:
            label_value, confidence = precomputed
            if confidence < self._min_confidence:
                return None
            return StatementLabel(
                label_type=self.label_type,
                label_value=label_value,
                confidence=confidence,
                labeler=self.name,
            )

        # Run MNLI classification. Errors are logged rather than raised so one
        # failing statement does not abort the labeling stage.
        try:
            classifier = self._get_classifier()
            taxonomy = self._get_filtered_taxonomy()
            text = statement.source_text

            # Stays None in flat mode, where no category is selected.
            category: Optional[str] = None
            if self._use_hierarchical:
                category, label, confidence = classifier.classify_hierarchical(
                    text,
                    taxonomy,
                    top_k_categories=self._top_k_categories,
                )
                # Include category in label for clarity
                full_label = f"{category}:{label}" if category and label else None
            else:
                # Flat classification across all labels. Deduplicate (keeping first
                # occurrence) so a label shared by several categories is scored once
                # instead of splitting the single-label score across copies.
                all_labels = list(
                    dict.fromkeys(lbl for labels in taxonomy.values() for lbl in labels)
                )
                label, confidence = classifier.classify(text, all_labels)
                full_label = label

            if full_label and confidence >= self._min_confidence:
                # Get the numeric ID for reproducibility
                label_id = self._get_label_id(category, label)
                metadata = {"label_id": label_id}
                if self._use_hierarchical:
                    metadata["category"] = category

                return StatementLabel(
                    label_type=self.label_type,
                    label_value=full_label,
                    confidence=confidence,
                    labeler=self.name,
                    metadata=metadata,
                )

        except Exception as e:
            logger.warning("Taxonomy classification failed: %s", e)

        return None

    def _get_label_id(self, category: Optional[str], label: str) -> Optional[int]:
        """
        Get the numeric ID for a label.

        Looks in ``category`` first when given; otherwise (or if not found
        there) returns the ID from the first category containing ``label``.
        Returns None when the label is not in the taxonomy.
        """
        taxonomy = self._load_taxonomy()

        if category and category in taxonomy:
            entry = taxonomy[category].get(label)
            if entry:
                return entry.get("id")

        # Search all categories for flat classification
        for cat_labels in taxonomy.values():
            if label in cat_labels:
                return cat_labels[label].get("id")

        return None
382
+
383
+
384
# Plain module-level aliases for the labeler and classifier classes defined above.
# NOTE(review): this was originally described as "importing without decorator for
# testing", but no decorator is applied in this module — confirm they are still needed.
TaxonomyLabelerClass, TaxonomyClassifierClass = TaxonomyLabeler, TaxonomyClassifier