corp-extractor 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/METADATA +235 -96
  2. corp_extractor-0.5.0.dist-info/RECORD +55 -0
  3. statement_extractor/__init__.py +9 -0
  4. statement_extractor/cli.py +460 -21
  5. statement_extractor/data/default_predicates.json +368 -0
  6. statement_extractor/data/statement_taxonomy.json +1182 -0
  7. statement_extractor/extractor.py +32 -47
  8. statement_extractor/gliner_extraction.py +218 -0
  9. statement_extractor/llm.py +255 -0
  10. statement_extractor/models/__init__.py +74 -0
  11. statement_extractor/models/canonical.py +139 -0
  12. statement_extractor/models/entity.py +102 -0
  13. statement_extractor/models/labels.py +191 -0
  14. statement_extractor/models/qualifiers.py +91 -0
  15. statement_extractor/models/statement.py +75 -0
  16. statement_extractor/models.py +15 -6
  17. statement_extractor/pipeline/__init__.py +39 -0
  18. statement_extractor/pipeline/config.py +134 -0
  19. statement_extractor/pipeline/context.py +177 -0
  20. statement_extractor/pipeline/orchestrator.py +447 -0
  21. statement_extractor/pipeline/registry.py +297 -0
  22. statement_extractor/plugins/__init__.py +43 -0
  23. statement_extractor/plugins/base.py +446 -0
  24. statement_extractor/plugins/canonicalizers/__init__.py +17 -0
  25. statement_extractor/plugins/canonicalizers/base.py +9 -0
  26. statement_extractor/plugins/canonicalizers/location.py +219 -0
  27. statement_extractor/plugins/canonicalizers/organization.py +230 -0
  28. statement_extractor/plugins/canonicalizers/person.py +242 -0
  29. statement_extractor/plugins/extractors/__init__.py +13 -0
  30. statement_extractor/plugins/extractors/base.py +9 -0
  31. statement_extractor/plugins/extractors/gliner2.py +536 -0
  32. statement_extractor/plugins/labelers/__init__.py +29 -0
  33. statement_extractor/plugins/labelers/base.py +9 -0
  34. statement_extractor/plugins/labelers/confidence.py +138 -0
  35. statement_extractor/plugins/labelers/relation_type.py +87 -0
  36. statement_extractor/plugins/labelers/sentiment.py +159 -0
  37. statement_extractor/plugins/labelers/taxonomy.py +373 -0
  38. statement_extractor/plugins/labelers/taxonomy_embedding.py +466 -0
  39. statement_extractor/plugins/qualifiers/__init__.py +19 -0
  40. statement_extractor/plugins/qualifiers/base.py +9 -0
  41. statement_extractor/plugins/qualifiers/companies_house.py +174 -0
  42. statement_extractor/plugins/qualifiers/gleif.py +186 -0
  43. statement_extractor/plugins/qualifiers/person.py +221 -0
  44. statement_extractor/plugins/qualifiers/sec_edgar.py +198 -0
  45. statement_extractor/plugins/splitters/__init__.py +13 -0
  46. statement_extractor/plugins/splitters/base.py +9 -0
  47. statement_extractor/plugins/splitters/t5_gemma.py +188 -0
  48. statement_extractor/plugins/taxonomy/__init__.py +13 -0
  49. statement_extractor/plugins/taxonomy/embedding.py +337 -0
  50. statement_extractor/plugins/taxonomy/mnli.py +279 -0
  51. statement_extractor/scoring.py +17 -69
  52. corp_extractor-0.3.0.dist-info/RECORD +0 -12
  53. statement_extractor/spacy_extraction.py +0 -386
  54. {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/WHEEL +0 -0
  55. {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,279 @@
1
+ """
2
+ MNLITaxonomyClassifier - Classifies statements using MNLI zero-shot classification.
3
+
4
+ Uses HuggingFace transformers zero-shot-classification pipeline for taxonomy labeling
5
+ where there are too many possible values for simple multi-choice classification.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Optional
12
+
13
+ from ..base import BaseTaxonomyPlugin, TaxonomySchema, PluginCapability
14
+ from ...pipeline.context import PipelineContext
15
+ from ...pipeline.registry import PluginRegistry
16
+ from ...models import (
17
+ PipelineStatement,
18
+ CanonicalEntity,
19
+ TaxonomyResult,
20
+ )
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Default taxonomy file location (relative to this module)
25
+ DEFAULT_TAXONOMY_PATH = Path(__file__).parent.parent.parent / "data" / "statement_taxonomy.json"
26
+
27
+ # Default categories to use (all of them)
28
+ DEFAULT_CATEGORIES = [
29
+ "environment",
30
+ "society",
31
+ "governance",
32
+ "animals",
33
+ "industry",
34
+ "human_harm",
35
+ "human_benefit",
36
+ "animal_harm",
37
+ "animal_benefit",
38
+ "environment_harm",
39
+ "environment_benefit",
40
+ ]
41
+
42
+
43
class MNLIClassifier:
    """
    MNLI-based zero-shot classifier for taxonomy labeling.

    Wraps the HuggingFace ``zero-shot-classification`` pipeline and adds a
    two-stage (category, then label) classification scheme on top of it.
    """

    def __init__(
        self,
        model_id: str = "facebook/bart-large-mnli",
        device: Optional[str] = None,
    ):
        # Pipeline construction is deferred until first classification call.
        self._model_id = model_id
        self._device = device
        self._classifier = None

    def _load_classifier(self):
        """Instantiate the zero-shot pipeline on first use (no-op afterwards)."""
        if self._classifier is not None:
            return

        try:
            from transformers import pipeline
            import torch

            target = self._device
            if target is None:
                # Auto-select: prefer CUDA, then Apple MPS, otherwise CPU.
                if torch.cuda.is_available():
                    target = "cuda"
                elif torch.backends.mps.is_available():
                    target = "mps"
                else:
                    target = "cpu"

            logger.info(f"Loading MNLI classifier '{self._model_id}' on {target}...")
            # transformers accepts -1 as the legacy "CPU" device index.
            self._classifier = pipeline(
                "zero-shot-classification",
                model=self._model_id,
                device=-1 if target == "cpu" else target,
            )
            logger.debug("MNLI classifier loaded")

        except ImportError as e:
            raise ImportError(
                "transformers is required for MNLI classification. "
                "Install with: pip install transformers"
            ) from e

    def classify_hierarchical(
        self,
        text: str,
        taxonomy: dict[str, list[str]],
        top_k_categories: int = 3,
        min_score: float = 0.3,
    ) -> list[tuple[str, str, float]]:
        """
        Hierarchical classification: first category, then labels within category.

        Returns all labels above the threshold, not just the best match.

        Args:
            text: Text to classify
            taxonomy: Dict mapping category -> list of labels
            top_k_categories: Number of top categories to consider
            min_score: Minimum combined score to include in results

        Returns:
            List of (category, label, confidence) tuples above threshold,
            sorted by confidence, descending.
        """
        self._load_classifier()

        # Stage 1: rank the categories themselves.
        cat_result = self._classifier(text, candidate_labels=list(taxonomy.keys()))
        ranked_categories = list(zip(cat_result["labels"], cat_result["scores"]))

        matches: list[tuple[str, str, float]] = []

        # Stage 2: within each top category, rank its labels.
        for category, category_score in ranked_categories[:top_k_categories]:
            candidates = taxonomy[category]
            if not candidates:
                continue

            label_result = self._classifier(text, candidate_labels=candidates)

            # Combined confidence = category score x label score; keep every
            # label clearing the threshold for this category.
            for label, label_score in zip(label_result["labels"], label_result["scores"]):
                combined = category_score * label_score
                if combined >= min_score:
                    matches.append((category, label, combined))

        return sorted(matches, key=lambda item: item[2], reverse=True)
139
+
140
+
141
@PluginRegistry.taxonomy
class MNLITaxonomyClassifier(BaseTaxonomyPlugin):
    """
    Taxonomy classifier using MNLI zero-shot classification.

    Supports hierarchical classification for efficiency with large taxonomies.
    """

    def __init__(
        self,
        taxonomy_path: Optional[str | Path] = None,
        categories: Optional[list[str]] = None,
        model_id: str = "facebook/bart-large-mnli",
        top_k_categories: int = 3,
        min_confidence: float = 0.3,
    ):
        """
        Args:
            taxonomy_path: JSON file mapping category -> {label: id}; defaults
                to the bundled statement taxonomy.
            categories: Category names to classify against (default: all).
            model_id: HuggingFace model id for the zero-shot pipeline.
            top_k_categories: Number of top-ranked categories to expand.
            min_confidence: Minimum combined score for a returned label.
        """
        self._taxonomy_path = Path(taxonomy_path) if taxonomy_path else DEFAULT_TAXONOMY_PATH
        self._categories = categories or DEFAULT_CATEGORIES
        self._model_id = model_id
        self._top_k_categories = top_k_categories
        self._min_confidence = min_confidence

        # Lazy-loaded state: taxonomy file contents and the MNLI pipeline.
        self._taxonomy: Optional[dict[str, dict[str, int]]] = None
        self._classifier: Optional[MNLIClassifier] = None

    @property
    def name(self) -> str:
        return "mnli_taxonomy_classifier"

    @property
    def priority(self) -> int:
        return 50  # Lower priority than embedding (use --plugins mnli_taxonomy_classifier to enable)

    @property
    def capabilities(self) -> PluginCapability:
        return PluginCapability.LLM_REQUIRED

    @property
    def description(self) -> str:
        return "Classifies statements against a taxonomy using MNLI zero-shot classification"

    @property
    def taxonomy_name(self) -> str:
        return "esg_topics"

    @property
    def taxonomy_schema(self) -> TaxonomySchema:
        # Reuse the exact filtering used by classify() so the advertised
        # schema always matches what is actually scored against.
        return TaxonomySchema(
            label_type="taxonomy",
            values=self._get_filtered_taxonomy(),
            description="ESG topic classification taxonomy",
            scope="statement",
        )

    @property
    def supported_categories(self) -> list[str]:
        # Copy so callers cannot mutate our configuration.
        return self._categories.copy()

    def _load_taxonomy(self) -> dict[str, dict[str, int]]:
        """Load (and cache) the taxonomy from its JSON file.

        Raises:
            FileNotFoundError: If the configured taxonomy file does not exist.
        """
        if self._taxonomy is not None:
            return self._taxonomy

        if not self._taxonomy_path.exists():
            raise FileNotFoundError(f"Taxonomy file not found: {self._taxonomy_path}")

        # JSON is UTF-8 by spec; be explicit rather than trusting the
        # platform default encoding.
        with open(self._taxonomy_path, encoding="utf-8") as f:
            self._taxonomy = json.load(f)

        logger.debug(f"Loaded taxonomy with {len(self._taxonomy)} categories")
        return self._taxonomy

    def _get_classifier(self) -> MNLIClassifier:
        """Return the shared MNLI classifier, creating it on first use."""
        if self._classifier is None:
            self._classifier = MNLIClassifier(model_id=self._model_id)
        return self._classifier

    def _get_filtered_taxonomy(self) -> dict[str, list[str]]:
        """Return category -> label-name lists, restricted to the configured categories."""
        taxonomy = self._load_taxonomy()
        return {
            cat: list(labels.keys())
            for cat, labels in taxonomy.items()
            if cat in self._categories
        }

    def classify(
        self,
        statement: PipelineStatement,
        subject_canonical: CanonicalEntity,
        object_canonical: CanonicalEntity,
        context: PipelineContext,
    ) -> list[TaxonomyResult]:
        """Classify statement against the taxonomy using MNLI.

        Returns all labels above the confidence threshold. Best-effort: any
        failure is logged and an empty result list is returned.
        """
        results: list[TaxonomyResult] = []

        try:
            classifier = self._get_classifier()
            taxonomy = self._get_filtered_taxonomy()

            text = statement.source_text

            classifications = classifier.classify_hierarchical(
                text,
                taxonomy,
                top_k_categories=self._top_k_categories,
                min_score=self._min_confidence,
            )

            for category, label, confidence in classifications:
                label_id = self._get_label_id(category, label)

                results.append(TaxonomyResult(
                    taxonomy_name=self.taxonomy_name,
                    category=category,
                    label=label,
                    label_id=label_id,
                    confidence=round(confidence, 4),
                    classifier=self.name,
                ))

        except Exception as e:
            # Deliberate best-effort: classification failure must not abort the pipeline.
            logger.warning(f"MNLI taxonomy classification failed: {e}")

        return results

    def _get_label_id(self, category: str, label: str) -> Optional[int]:
        """Look up the numeric id for a (category, label) pair, or None if unknown."""
        taxonomy = self._load_taxonomy()
        if category in taxonomy:
            return taxonomy[category].get(label)
        return None


# For testing without decorator
MNLITaxonomyClassifierClass = MNLITaxonomyClassifier
@@ -15,41 +15,21 @@ from .models import ScoringConfig, Statement
15
15
 
16
16
  logger = logging.getLogger(__name__)
17
17
 
18
- # Lazy-loaded spaCy model for grammatical analysis
19
- _nlp = None
20
-
21
-
22
- def _get_nlp():
23
- """Lazy-load spaCy model for POS tagging."""
24
- global _nlp
25
- if _nlp is None:
26
- import spacy
27
- try:
28
- _nlp = spacy.load("en_core_web_sm", disable=["ner", "lemmatizer"])
29
- except OSError:
30
- # Model not found, try to download
31
- from .spacy_extraction import _download_model
32
- if _download_model():
33
- _nlp = spacy.load("en_core_web_sm", disable=["ner", "lemmatizer"])
34
- else:
35
- raise
36
- return _nlp
37
-
38
18
 
39
19
  class TripleScorer:
40
20
  """
41
- Score individual triples combining semantic similarity and grammatical accuracy.
21
+ Score individual triples combining semantic similarity and entity recognition.
42
22
 
43
23
  The score is a weighted combination of:
44
24
  - Semantic similarity (50%): Cosine similarity between source text and reassembled triple
45
- - Subject noun score (25%): How noun-like the subject is
46
- - Object noun score (25%): How noun-like the object is
47
-
48
- Noun scoring:
49
- - Proper noun only (PROPN): 1.0
50
- - Common noun only (NOUN): 0.8
51
- - Contains noun + other words: 0.6
52
- - No noun: 0.2
25
+ - Subject entity score (25%): How entity-like the subject is (via GLiNER2)
26
+ - Object entity score (25%): How entity-like the object is (via GLiNER2)
27
+
28
+ Entity scoring (via GLiNER2):
29
+ - Recognized entity with high confidence: 1.0
30
+ - Recognized entity with moderate confidence: 0.8
31
+ - Partially recognized: 0.6
32
+ - Not recognized: 0.2
53
33
  """
54
34
 
55
35
  def __init__(
@@ -102,54 +82,22 @@ class TripleScorer:
102
82
 
103
83
  def _score_noun_content(self, text: str) -> float:
104
84
  """
105
- Score how noun-like a text is.
85
+ Score how entity-like a text is using GLiNER2 entity recognition.
106
86
 
107
87
  Returns:
108
- 1.0 - Entirely proper noun(s)
109
- 0.8 - Entirely common noun(s)
110
- 0.6 - Contains noun(s) but also other words
111
- 0.2 - No nouns found
88
+ 1.0 - Recognized as a named entity with high confidence
89
+ 0.8 - Recognized as an entity with moderate confidence
90
+ 0.6 - Partially recognized or contains entity-like content
91
+ 0.2 - Not recognized as any entity type
112
92
  """
113
93
  if not text or not text.strip():
114
94
  return 0.2
115
95
 
116
96
  try:
117
- nlp = _get_nlp()
118
- doc = nlp(text)
119
-
120
- # Count token types (excluding punctuation and spaces)
121
- tokens = [t for t in doc if not t.is_punct and not t.is_space]
122
- if not tokens:
123
- return 0.2
124
-
125
- proper_nouns = sum(1 for t in tokens if t.pos_ == "PROPN")
126
- common_nouns = sum(1 for t in tokens if t.pos_ == "NOUN")
127
- total_nouns = proper_nouns + common_nouns
128
- total_tokens = len(tokens)
129
-
130
- if total_nouns == 0:
131
- # No nouns at all
132
- return 0.2
133
-
134
- if total_nouns == total_tokens:
135
- # Entirely nouns
136
- if proper_nouns == total_tokens:
137
- # All proper nouns
138
- return 1.0
139
- elif common_nouns == total_tokens:
140
- # All common nouns
141
- return 0.8
142
- else:
143
- # Mix of proper and common nouns
144
- return 0.9
145
-
146
- # Contains nouns but also other words
147
- # Score based on noun ratio
148
- noun_ratio = total_nouns / total_tokens
149
- return 0.4 + (noun_ratio * 0.4) # Range: 0.4 to 0.8
150
-
97
+ from .gliner_extraction import score_entity_content
98
+ return score_entity_content(text)
151
99
  except Exception as e:
152
- logger.debug(f"Noun scoring failed for '{text}': {e}")
100
+ logger.debug(f"Entity scoring failed for '{text}': {e}")
153
101
  return 0.5 # Neutral score on error
154
102
 
155
103
  def score_triple(self, statement: Statement, source_text: str) -> float:
@@ -1,12 +0,0 @@
1
- statement_extractor/__init__.py,sha256=KwZfWnTB9oevTLw0TrNlYFu67qIYO-34JqDtcpjOhZI,3013
2
- statement_extractor/canonicalization.py,sha256=ZMLs6RLWJa_rOJ8XZ7PoHFU13-zeJkOMDnvK-ZaFa5s,5991
3
- statement_extractor/cli.py,sha256=JMEXiT2xwmW1J8JmJliQh32AT-7bTAtAscPx1AGRfPg,9054
4
- statement_extractor/extractor.py,sha256=vS8UCgE8uITt_28PwCh4WCqOjWLpfrJcN3fh1YPBcjA,39657
5
- statement_extractor/models.py,sha256=FxLj2fIodX317XVIJLZ0GFNahm_VV07KzdoLSSjoVD4,11952
6
- statement_extractor/predicate_comparer.py,sha256=jcuaBi5BYqD3TKoyj3pR9dxtX5ihfDJvjdhEd2LHCwc,26184
7
- statement_extractor/scoring.py,sha256=pdNgyLHmlk-npISzm4nycK9G4wM2nztg5KTG7piFACI,18135
8
- statement_extractor/spacy_extraction.py,sha256=ACvIB-Ag7H7h_Gb0cdypIr8fnf3A-UjyJnqqjWD5Ccs,12320
9
- corp_extractor-0.3.0.dist-info/METADATA,sha256=eu8b7R_FQxFyc_9FSocy078TTyB7BwvGX-YAS79hKgg,17042
10
- corp_extractor-0.3.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
- corp_extractor-0.3.0.dist-info/entry_points.txt,sha256=i0iKFqPIusvb-QTQ1zNnFgAqatgVah-jIhahbs5TToQ,115
12
- corp_extractor-0.3.0.dist-info/RECORD,,