corp-extractor 0.4.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/METADATA +348 -64
  2. corp_extractor-0.9.0.dist-info/RECORD +76 -0
  3. statement_extractor/__init__.py +10 -1
  4. statement_extractor/cli.py +1663 -17
  5. statement_extractor/data/default_predicates.json +368 -0
  6. statement_extractor/data/statement_taxonomy.json +6972 -0
  7. statement_extractor/database/__init__.py +52 -0
  8. statement_extractor/database/embeddings.py +186 -0
  9. statement_extractor/database/hub.py +520 -0
  10. statement_extractor/database/importers/__init__.py +24 -0
  11. statement_extractor/database/importers/companies_house.py +545 -0
  12. statement_extractor/database/importers/gleif.py +538 -0
  13. statement_extractor/database/importers/sec_edgar.py +375 -0
  14. statement_extractor/database/importers/wikidata.py +1012 -0
  15. statement_extractor/database/importers/wikidata_people.py +632 -0
  16. statement_extractor/database/models.py +230 -0
  17. statement_extractor/database/resolver.py +245 -0
  18. statement_extractor/database/store.py +1609 -0
  19. statement_extractor/document/__init__.py +62 -0
  20. statement_extractor/document/chunker.py +410 -0
  21. statement_extractor/document/context.py +171 -0
  22. statement_extractor/document/deduplicator.py +173 -0
  23. statement_extractor/document/html_extractor.py +246 -0
  24. statement_extractor/document/loader.py +303 -0
  25. statement_extractor/document/pipeline.py +388 -0
  26. statement_extractor/document/summarizer.py +195 -0
  27. statement_extractor/extractor.py +1 -23
  28. statement_extractor/gliner_extraction.py +4 -74
  29. statement_extractor/llm.py +255 -0
  30. statement_extractor/models/__init__.py +89 -0
  31. statement_extractor/models/canonical.py +182 -0
  32. statement_extractor/models/document.py +308 -0
  33. statement_extractor/models/entity.py +102 -0
  34. statement_extractor/models/labels.py +220 -0
  35. statement_extractor/models/qualifiers.py +139 -0
  36. statement_extractor/models/statement.py +101 -0
  37. statement_extractor/models.py +4 -1
  38. statement_extractor/pipeline/__init__.py +39 -0
  39. statement_extractor/pipeline/config.py +129 -0
  40. statement_extractor/pipeline/context.py +177 -0
  41. statement_extractor/pipeline/orchestrator.py +416 -0
  42. statement_extractor/pipeline/registry.py +303 -0
  43. statement_extractor/plugins/__init__.py +55 -0
  44. statement_extractor/plugins/base.py +716 -0
  45. statement_extractor/plugins/extractors/__init__.py +13 -0
  46. statement_extractor/plugins/extractors/base.py +9 -0
  47. statement_extractor/plugins/extractors/gliner2.py +546 -0
  48. statement_extractor/plugins/labelers/__init__.py +29 -0
  49. statement_extractor/plugins/labelers/base.py +9 -0
  50. statement_extractor/plugins/labelers/confidence.py +138 -0
  51. statement_extractor/plugins/labelers/relation_type.py +87 -0
  52. statement_extractor/plugins/labelers/sentiment.py +159 -0
  53. statement_extractor/plugins/labelers/taxonomy.py +386 -0
  54. statement_extractor/plugins/labelers/taxonomy_embedding.py +477 -0
  55. statement_extractor/plugins/pdf/__init__.py +10 -0
  56. statement_extractor/plugins/pdf/pypdf.py +291 -0
  57. statement_extractor/plugins/qualifiers/__init__.py +30 -0
  58. statement_extractor/plugins/qualifiers/base.py +9 -0
  59. statement_extractor/plugins/qualifiers/companies_house.py +185 -0
  60. statement_extractor/plugins/qualifiers/embedding_company.py +420 -0
  61. statement_extractor/plugins/qualifiers/gleif.py +197 -0
  62. statement_extractor/plugins/qualifiers/person.py +785 -0
  63. statement_extractor/plugins/qualifiers/sec_edgar.py +209 -0
  64. statement_extractor/plugins/scrapers/__init__.py +10 -0
  65. statement_extractor/plugins/scrapers/http.py +236 -0
  66. statement_extractor/plugins/splitters/__init__.py +13 -0
  67. statement_extractor/plugins/splitters/base.py +9 -0
  68. statement_extractor/plugins/splitters/t5_gemma.py +293 -0
  69. statement_extractor/plugins/taxonomy/__init__.py +13 -0
  70. statement_extractor/plugins/taxonomy/embedding.py +484 -0
  71. statement_extractor/plugins/taxonomy/mnli.py +291 -0
  72. statement_extractor/scoring.py +8 -8
  73. corp_extractor-0.4.0.dist-info/RECORD +0 -12
  74. {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/WHEEL +0 -0
  75. {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,291 @@
1
+ """
2
+ MNLITaxonomyClassifier - Classifies statements using MNLI zero-shot classification.
3
+
4
+ Uses HuggingFace transformers zero-shot-classification pipeline for taxonomy labeling
5
+ where there are too many possible values for simple multi-choice classification.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Optional, TypedDict
12
+
13
+ from ..base import BaseTaxonomyPlugin, TaxonomySchema, PluginCapability
14
+
15
+
16
class TaxonomyEntry(TypedDict):
    """Shape of one label entry inside the taxonomy JSON file."""

    # Human-readable explanation of what the label means.
    description: str
    # Stable numeric identifier for the label (what _get_label_id returns).
    id: int
    # Label phrasing variant for the MNLI classifier — presumably the
    # hypothesis wording; confirm against the taxonomy JSON.
    mnli_label: str
    # Label phrasing variant for the embedding-based classifier.
    embedding_label: str
23
+
24
+ from ...pipeline.context import PipelineContext
25
+ from ...pipeline.registry import PluginRegistry
26
+ from ...models import (
27
+ PipelineStatement,
28
+ CanonicalEntity,
29
+ TaxonomyResult,
30
+ )
31
+
32
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# Default taxonomy file location (relative to this module)
DEFAULT_TAXONOMY_PATH = Path(__file__).parent.parent.parent / "data" / "statement_taxonomy.json"

# Default categories to use (all of them)
DEFAULT_CATEGORIES = [
    "environment",
    "society",
    "governance",
    "animals",
    "industry",
    "human_harm",
    "human_benefit",
    "animal_harm",
    "animal_benefit",
    "environment_harm",
    "environment_benefit",
]
51
+
52
+
53
class MNLIClassifier:
    """Lazily-loaded zero-shot classifier built on an MNLI model.

    A thin wrapper around the HuggingFace transformers
    zero-shot-classification pipeline; the heavyweight model is only
    instantiated on first use.
    """

    def __init__(
        self,
        model_id: str = "facebook/bart-large-mnli",
        device: Optional[str] = None,
    ):
        # device=None means "auto-detect" at load time (cuda > mps > cpu).
        self._model_id = model_id
        self._device = device
        self._classifier = None

    def _load_classifier(self):
        """Instantiate the pipeline once; subsequent calls are no-ops.

        Raises:
            ImportError: If the transformers package is not installed.
        """
        if self._classifier is not None:
            return

        try:
            from transformers import pipeline
            import torch

            chosen = self._device
            if chosen is None:
                # Auto-detect: prefer CUDA, then Apple-silicon MPS, else CPU.
                if torch.cuda.is_available():
                    chosen = "cuda"
                elif torch.backends.mps.is_available():
                    chosen = "mps"
                else:
                    chosen = "cpu"

            logger.info(f"Loading MNLI classifier '{self._model_id}' on {chosen}...")
            # transformers uses -1 to mean CPU for the device argument.
            self._classifier = pipeline(
                "zero-shot-classification",
                model=self._model_id,
                device=chosen if chosen != "cpu" else -1,
            )
            logger.debug("MNLI classifier loaded")

        except ImportError as e:
            raise ImportError(
                "transformers is required for MNLI classification. "
                "Install with: pip install transformers"
            ) from e

    def classify_hierarchical(
        self,
        text: str,
        taxonomy: dict[str, list[str]],
        top_k_categories: int = 3,
        min_score: float = 0.3,
    ) -> list[tuple[str, str, float]]:
        """Two-stage classification: pick top categories, then labels within them.

        Every (category, label) pair whose combined score clears the
        threshold is returned, not just the single best match.

        Args:
            text: Text to classify.
            taxonomy: Mapping of category name -> candidate label list.
            top_k_categories: How many best-scoring categories to expand.
            min_score: Minimum combined (category * label) score to keep.

        Returns:
            (category, label, confidence) tuples, highest confidence first.
        """
        self._load_classifier()

        category_out = self._classifier(text, candidate_labels=list(taxonomy.keys()))

        hits: list[tuple[str, str, float]] = []
        ranked = zip(
            category_out["labels"][:top_k_categories],
            category_out["scores"][:top_k_categories],
        )
        for category, category_score in ranked:
            candidates = taxonomy[category]
            if not candidates:
                continue

            label_out = self._classifier(text, candidate_labels=candidates)

            # Keep every label in this category whose combined score clears
            # the threshold.
            hits.extend(
                (category, candidate, category_score * label_score)
                for candidate, label_score in zip(label_out["labels"], label_out["scores"])
                if category_score * label_score >= min_score
            )

        # Highest combined confidence first (sorted() is stable, like .sort()).
        return sorted(hits, key=lambda hit: hit[2], reverse=True)
149
+
150
+
151
@PluginRegistry.taxonomy
class MNLITaxonomyClassifier(BaseTaxonomyPlugin):
    """
    Taxonomy classifier using MNLI zero-shot classification.

    Supports hierarchical classification for efficiency with large taxonomies.
    """

    def __init__(
        self,
        taxonomy_path: Optional[str | Path] = None,
        categories: Optional[list[str]] = None,
        model_id: str = "facebook/bart-large-mnli",
        top_k_categories: int = 3,
        min_confidence: float = 0.3,
    ):
        """
        Args:
            taxonomy_path: Path to the taxonomy JSON file; defaults to the
                bundled ``statement_taxonomy.json``.
            categories: Taxonomy categories to classify against; defaults to
                ``DEFAULT_CATEGORIES``.
            model_id: HuggingFace model id of the MNLI model to use.
            top_k_categories: Number of top-scoring categories to expand
                during hierarchical classification.
            min_confidence: Minimum combined (category * label) score for a
                classification to be returned.
        """
        self._taxonomy_path = Path(taxonomy_path) if taxonomy_path else DEFAULT_TAXONOMY_PATH
        self._categories = categories or DEFAULT_CATEGORIES
        self._model_id = model_id
        self._top_k_categories = top_k_categories
        self._min_confidence = min_confidence

        # Both are created lazily on first use.
        self._taxonomy: Optional[dict[str, dict[str, TaxonomyEntry]]] = None
        self._classifier: Optional[MNLIClassifier] = None

    @property
    def name(self) -> str:
        return "mnli_taxonomy_classifier"

    @property
    def priority(self) -> int:
        return 50  # Lower priority than embedding (use --plugins mnli_taxonomy_classifier to enable)

    @property
    def capabilities(self) -> PluginCapability:
        return PluginCapability.LLM_REQUIRED

    @property
    def description(self) -> str:
        return "Classifies statements against a taxonomy using MNLI zero-shot classification"

    @property
    def taxonomy_name(self) -> str:
        return "esg_topics"

    @property
    def taxonomy_schema(self) -> TaxonomySchema:
        """Schema describing the label values this plugin can emit."""
        # Reuse the same filtering as classification so the advertised
        # schema always matches what classify() can actually produce.
        return TaxonomySchema(
            label_type="taxonomy",
            values=self._get_filtered_taxonomy(),
            description="ESG topic classification taxonomy",
            scope="statement",
        )

    @property
    def supported_categories(self) -> list[str]:
        # Return a copy so callers cannot mutate our configuration.
        return self._categories.copy()

    def _load_taxonomy(self) -> dict[str, dict[str, TaxonomyEntry]]:
        """Load and cache the taxonomy from the JSON file.

        Raises:
            FileNotFoundError: If the taxonomy file does not exist.
        """
        if self._taxonomy is not None:
            return self._taxonomy

        if not self._taxonomy_path.exists():
            raise FileNotFoundError(f"Taxonomy file not found: {self._taxonomy_path}")

        # JSON is UTF-8 by specification; don't rely on the platform default.
        with open(self._taxonomy_path, encoding="utf-8") as f:
            self._taxonomy = json.load(f)

        logger.debug(f"Loaded taxonomy with {len(self._taxonomy)} categories")
        return self._taxonomy

    def _get_classifier(self) -> MNLIClassifier:
        """Lazily create the underlying MNLI classifier."""
        if self._classifier is None:
            self._classifier = MNLIClassifier(model_id=self._model_id)
        return self._classifier

    def _get_filtered_taxonomy(self) -> dict[str, list[str]]:
        """Return category -> label-name lists, restricted to configured categories."""
        taxonomy = self._load_taxonomy()
        return {
            cat: list(labels.keys())
            for cat, labels in taxonomy.items()
            if cat in self._categories
        }

    def classify(
        self,
        statement: PipelineStatement,
        subject_canonical: CanonicalEntity,
        object_canonical: CanonicalEntity,
        context: PipelineContext,
    ) -> list[TaxonomyResult]:
        """Classify statement against the taxonomy using MNLI.

        Returns all labels above the confidence threshold. Failures are
        logged and yield an empty result list rather than raising.
        """
        results: list[TaxonomyResult] = []

        try:
            classifier = self._get_classifier()
            taxonomy = self._get_filtered_taxonomy()

            classifications = classifier.classify_hierarchical(
                statement.source_text,
                taxonomy,
                top_k_categories=self._top_k_categories,
                min_score=self._min_confidence,
            )

            for category, label, confidence in classifications:
                results.append(TaxonomyResult(
                    taxonomy_name=self.taxonomy_name,
                    category=category,
                    label=label,
                    label_id=self._get_label_id(category, label),
                    confidence=round(confidence, 4),
                    classifier=self.name,
                ))

        except Exception as e:
            # Deliberately best-effort: one failing classifier should not
            # abort the whole pipeline run.
            logger.warning(f"MNLI taxonomy classification failed: {e}")

        return results

    def _get_label_id(self, category: str, label: str) -> Optional[int]:
        """Look up the numeric id for a (category, label) pair, if present."""
        entry = self._load_taxonomy().get(category, {}).get(label)
        return entry.get("id") if entry else None


# For testing without decorator
MNLITaxonomyClassifierClass = MNLITaxonomyClassifier
@@ -409,18 +409,18 @@ class BeamScorer:
409
409
  filtered = [s for s in all_statements if (s.confidence_score or 0) >= min_conf]
410
410
  logger.debug(f" After confidence filter (>={min_conf}): {len(filtered)} statements")
411
411
 
412
- # # Filter out statements where source_text doesn't support the predicate
413
- # # This catches model hallucinations where predicate doesn't match the evidence
414
- # consistent = [
415
- # s for s in filtered
416
- # if self._source_text_supports_predicate(s)
417
- # ]
418
- # logger.debug(f" After predicate consistency filter: {len(consistent)} statements")
412
+ # Filter out statements where source_text doesn't support the predicate
413
+ # This catches model hallucinations where predicate doesn't match the evidence
414
+ consistent = [
415
+ s for s in filtered
416
+ if self._source_text_supports_predicate(s)
417
+ ]
418
+ logger.debug(f" After predicate consistency filter: {len(consistent)} statements")
419
419
 
420
420
  # Deduplicate - keep highest confidence for each (subject, predicate, object)
421
421
  # Note: Same subject+predicate with different objects is valid (e.g., "Apple announced X and Y")
422
422
  seen: dict[tuple[str, str, str], Statement] = {}
423
- for stmt in all_statements:
423
+ for stmt in consistent:
424
424
  key = (
425
425
  stmt.subject.text.lower(),
426
426
  stmt.predicate.lower(),
@@ -1,12 +0,0 @@
1
- statement_extractor/__init__.py,sha256=KwZfWnTB9oevTLw0TrNlYFu67qIYO-34JqDtcpjOhZI,3013
2
- statement_extractor/canonicalization.py,sha256=ZMLs6RLWJa_rOJ8XZ7PoHFU13-zeJkOMDnvK-ZaFa5s,5991
3
- statement_extractor/cli.py,sha256=FOkkihVfoROc-Biu8ICCzlLJeDScYYNLLJHnv0GCGGM,9507
4
- statement_extractor/extractor.py,sha256=d0HnCeCPybw-4jDxH_ffZ4LY9Klvqnza_wa90Bd4Q18,40074
5
- statement_extractor/gliner_extraction.py,sha256=KNs3n5-fnoUwY1wvbPwZL8j-3YVstmioJlcjp2k1FmY,10491
6
- statement_extractor/models.py,sha256=cyCQc3vlYB3qlg6-uL5Vt4odIiulKtHzz1Cyrf0lEAU,12198
7
- statement_extractor/predicate_comparer.py,sha256=jcuaBi5BYqD3TKoyj3pR9dxtX5ihfDJvjdhEd2LHCwc,26184
8
- statement_extractor/scoring.py,sha256=s_8nhavBNzPPFmGf2FyBummH4tgP7YGpXoMhl2Jh3Xw,16650
9
- corp_extractor-0.4.0.dist-info/METADATA,sha256=8f2CDtZG757kaB6XMfbBVdNSRMyS5-4Lflc_LoZCC_8,17725
10
- corp_extractor-0.4.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
- corp_extractor-0.4.0.dist-info/entry_points.txt,sha256=i0iKFqPIusvb-QTQ1zNnFgAqatgVah-jIhahbs5TToQ,115
12
- corp_extractor-0.4.0.dist-info/RECORD,,