corp-extractor 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
Files changed (55)
  1. {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/METADATA +235 -96
  2. corp_extractor-0.5.0.dist-info/RECORD +55 -0
  3. statement_extractor/__init__.py +9 -0
  4. statement_extractor/cli.py +460 -21
  5. statement_extractor/data/default_predicates.json +368 -0
  6. statement_extractor/data/statement_taxonomy.json +1182 -0
  7. statement_extractor/extractor.py +32 -47
  8. statement_extractor/gliner_extraction.py +218 -0
  9. statement_extractor/llm.py +255 -0
  10. statement_extractor/models/__init__.py +74 -0
  11. statement_extractor/models/canonical.py +139 -0
  12. statement_extractor/models/entity.py +102 -0
  13. statement_extractor/models/labels.py +191 -0
  14. statement_extractor/models/qualifiers.py +91 -0
  15. statement_extractor/models/statement.py +75 -0
  16. statement_extractor/models.py +15 -6
  17. statement_extractor/pipeline/__init__.py +39 -0
  18. statement_extractor/pipeline/config.py +134 -0
  19. statement_extractor/pipeline/context.py +177 -0
  20. statement_extractor/pipeline/orchestrator.py +447 -0
  21. statement_extractor/pipeline/registry.py +297 -0
  22. statement_extractor/plugins/__init__.py +43 -0
  23. statement_extractor/plugins/base.py +446 -0
  24. statement_extractor/plugins/canonicalizers/__init__.py +17 -0
  25. statement_extractor/plugins/canonicalizers/base.py +9 -0
  26. statement_extractor/plugins/canonicalizers/location.py +219 -0
  27. statement_extractor/plugins/canonicalizers/organization.py +230 -0
  28. statement_extractor/plugins/canonicalizers/person.py +242 -0
  29. statement_extractor/plugins/extractors/__init__.py +13 -0
  30. statement_extractor/plugins/extractors/base.py +9 -0
  31. statement_extractor/plugins/extractors/gliner2.py +536 -0
  32. statement_extractor/plugins/labelers/__init__.py +29 -0
  33. statement_extractor/plugins/labelers/base.py +9 -0
  34. statement_extractor/plugins/labelers/confidence.py +138 -0
  35. statement_extractor/plugins/labelers/relation_type.py +87 -0
  36. statement_extractor/plugins/labelers/sentiment.py +159 -0
  37. statement_extractor/plugins/labelers/taxonomy.py +373 -0
  38. statement_extractor/plugins/labelers/taxonomy_embedding.py +466 -0
  39. statement_extractor/plugins/qualifiers/__init__.py +19 -0
  40. statement_extractor/plugins/qualifiers/base.py +9 -0
  41. statement_extractor/plugins/qualifiers/companies_house.py +174 -0
  42. statement_extractor/plugins/qualifiers/gleif.py +186 -0
  43. statement_extractor/plugins/qualifiers/person.py +221 -0
  44. statement_extractor/plugins/qualifiers/sec_edgar.py +198 -0
  45. statement_extractor/plugins/splitters/__init__.py +13 -0
  46. statement_extractor/plugins/splitters/base.py +9 -0
  47. statement_extractor/plugins/splitters/t5_gemma.py +188 -0
  48. statement_extractor/plugins/taxonomy/__init__.py +13 -0
  49. statement_extractor/plugins/taxonomy/embedding.py +337 -0
  50. statement_extractor/plugins/taxonomy/mnli.py +279 -0
  51. statement_extractor/scoring.py +17 -69
  52. corp_extractor-0.3.0.dist-info/RECORD +0 -12
  53. statement_extractor/spacy_extraction.py +0 -386
  54. {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/WHEEL +0 -0
  55. {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/entry_points.txt +0 -0
statement_extractor/plugins/splitters/t5_gemma.py
@@ -0,0 +1,188 @@
+ """
+ T5GemmaSplitter - Stage 1 plugin that wraps the existing StatementExtractor.
+
+ Uses T5-Gemma2 model with Diverse Beam Search to generate high-quality
+ subject-predicate-object triples from text.
+ """
+
+ import logging
+ import re
+ import xml.etree.ElementTree as ET
+ from typing import Optional
+
+ from ..base import BaseSplitterPlugin, PluginCapability
+ from ...pipeline.context import PipelineContext
+ from ...pipeline.registry import PluginRegistry
+ from ...models import RawTriple
+
+ logger = logging.getLogger(__name__)
+
+
+ @PluginRegistry.splitter
+ class T5GemmaSplitter(BaseSplitterPlugin):
+     """
+     Splitter plugin that uses T5-Gemma2 for triple extraction.
+
+     Wraps the existing StatementExtractor from extractor.py to produce
+     RawTriple objects for the pipeline.
+     """
+
+     def __init__(
+         self,
+         model_id: Optional[str] = None,
+         device: Optional[str] = None,
+         num_beams: int = 4,
+         diversity_penalty: float = 1.0,
+         max_new_tokens: int = 2048,
+     ):
+         """
+         Initialize the T5Gemma splitter.
+
+         Args:
+             model_id: HuggingFace model ID (defaults to Corp-o-Rate model)
+             device: Device to use (auto-detected if not specified)
+             num_beams: Number of beams for diverse beam search
+             diversity_penalty: Penalty for beam diversity
+             max_new_tokens: Maximum tokens to generate
+         """
+         self._model_id = model_id
+         self._device = device
+         self._num_beams = num_beams
+         self._diversity_penalty = diversity_penalty
+         self._max_new_tokens = max_new_tokens
+         self._extractor = None
+
+     @property
+     def name(self) -> str:
+         return "t5_gemma_splitter"
+
+     @property
+     def priority(self) -> int:
+         return 10  # High priority - primary splitter
+
+     @property
+     def capabilities(self) -> PluginCapability:
+         return PluginCapability.LLM_REQUIRED
+
+     @property
+     def description(self) -> str:
+         return "T5-Gemma2 model for extracting triples using Diverse Beam Search"
+
+     def _get_extractor(self):
+         """Lazy-load the StatementExtractor."""
+         if self._extractor is None:
+             from ...extractor import StatementExtractor
+             # Only pass model_id and device if they were explicitly set
+             kwargs = {}
+             if self._model_id is not None:
+                 kwargs["model_id"] = self._model_id
+             if self._device is not None:
+                 kwargs["device"] = self._device
+             self._extractor = StatementExtractor(**kwargs)
+         return self._extractor
+
+     def split(
+         self,
+         text: str,
+         context: PipelineContext,
+     ) -> list[RawTriple]:
+         """
+         Split text into raw triples using T5-Gemma2.
+
+         Args:
+             text: Input text to split
+             context: Pipeline context
+
+         Returns:
+             List of RawTriple objects
+         """
+         logger.debug(f"T5GemmaSplitter processing {len(text)} chars")
+
+         # Get options from context if available
+         splitter_options = context.source_metadata.get("splitter_options", {})
+         num_beams = splitter_options.get("num_beams", self._num_beams)
+         diversity_penalty = splitter_options.get("diversity_penalty", self._diversity_penalty)
+         max_new_tokens = splitter_options.get("max_new_tokens", self._max_new_tokens)
+
+         # Create extraction options
+         from ...models import ExtractionOptions as LegacyExtractionOptions
+         options = LegacyExtractionOptions(
+             num_beams=num_beams,
+             diversity_penalty=diversity_penalty,
+             max_new_tokens=max_new_tokens,
+             # Disable GLiNER and dedup - we handle those in later stages
+             use_gliner_extraction=False,
+             embedding_dedup=False,
+             deduplicate=False,
+         )
+
+         # Get raw XML from extractor
+         extractor = self._get_extractor()
+         xml_output = extractor.extract_as_xml(text, options)
+
+         # Parse XML to RawTriple objects
+         raw_triples = self._parse_xml_to_raw_triples(xml_output)
+
+         logger.info(f"T5GemmaSplitter produced {len(raw_triples)} raw triples")
+         return raw_triples
+
+     def _parse_xml_to_raw_triples(self, xml_output: str) -> list[RawTriple]:
+         """Parse XML output into RawTriple objects."""
+         raw_triples = []
+
+         try:
+             root = ET.fromstring(xml_output)
+         except ET.ParseError as e:
+             logger.warning(f"XML parse error: {e}")
+             # Try to repair
+             xml_output = self._repair_xml(xml_output)
+             try:
+                 root = ET.fromstring(xml_output)
+             except ET.ParseError:
+                 logger.error("XML repair failed")
+                 return raw_triples
+
+         if root.tag != "statements":
+             logger.warning(f"Unexpected root tag: {root.tag}")
+             return raw_triples
+
+         for stmt_elem in root.findall("stmt"):
+             try:
+                 subject_elem = stmt_elem.find("subject")
+                 predicate_elem = stmt_elem.find("predicate")
+                 object_elem = stmt_elem.find("object")
+                 text_elem = stmt_elem.find("text")
+
+                 subject_text = subject_elem.text.strip() if subject_elem is not None and subject_elem.text else ""
+                 predicate_text = predicate_elem.text.strip() if predicate_elem is not None and predicate_elem.text else ""
+                 object_text = object_elem.text.strip() if object_elem is not None and object_elem.text else ""
+                 source_text = text_elem.text.strip() if text_elem is not None and text_elem.text else ""
+
+                 if subject_text and object_text and source_text:
+                     raw_triples.append(RawTriple(
+                         subject_text=subject_text,
+                         predicate_text=predicate_text,
+                         object_text=object_text,
+                         source_sentence=source_text,
+                     ))
+                 else:
+                     logger.debug(f"Skipping incomplete triple: s={subject_text}, p={predicate_text}, o={object_text}")
+
+             except Exception as e:
+                 logger.warning(f"Error parsing stmt element: {e}")
+                 continue
+
+         return raw_triples
+
+     def _repair_xml(self, xml_string: str) -> str:
+         """Attempt to repair common XML syntax errors."""
+         # Use the repair function from extractor.py
+         from ...extractor import repair_xml
+         repaired, repairs = repair_xml(xml_string)
+         if repairs:
+             logger.debug(f"XML repairs: {', '.join(repairs)}")
+         return repaired
+
+
+ # Allow importing without decorator for testing
+ T5GemmaSplitterClass = T5GemmaSplitter
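The XML shape this splitter consumes can be read off `_parse_xml_to_raw_triples`: a `<statements>` root whose `<stmt>` children carry `<subject>`, `<predicate>`, `<object>`, and `<text>` elements. A minimal sketch of that parse path (the sample sentence and triple are invented for illustration):

import xml.etree.ElementTree as ET

# Invented sample in the shape the splitter's parser expects.
xml_output = """
<statements>
  <stmt>
    <subject>Acme Corp</subject>
    <predicate>acquired</predicate>
    <object>Widget Ltd</object>
    <text>Acme Corp acquired Widget Ltd in 2024.</text>
  </stmt>
</statements>
"""

root = ET.fromstring(xml_output)
for stmt in root.findall("stmt"):
    fields = {child.tag: (child.text or "").strip() for child in stmt}
    print(fields)
# {'subject': 'Acme Corp', 'predicate': 'acquired', 'object': 'Widget Ltd',
#  'text': 'Acme Corp acquired Widget Ltd in 2024.'}

Note that triples missing a subject, object, or source sentence are dropped, so malformed beams degrade gracefully instead of failing the pipeline.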
statement_extractor/plugins/taxonomy/__init__.py
@@ -0,0 +1,13 @@
+ """
+ Taxonomy classifier plugins for Stage 6 (Taxonomy).
+
+ Classifies statements against large taxonomies using MNLI or embeddings.
+ """
+
+ from .mnli import MNLITaxonomyClassifier
+ from .embedding import EmbeddingTaxonomyClassifier
+
+ __all__ = [
+     "MNLITaxonomyClassifier",
+     "EmbeddingTaxonomyClassifier",
+ ]
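Since both classifiers are re-exported here, downstream code can import them from the subpackage directly. A minimal sketch, with keyword arguments mirroring the constructor defaults declared in embedding.py below (whether you would construct the plugin directly or let the PluginRegistry do so depends on pipeline wiring this diff doesn't show):

from statement_extractor.plugins.taxonomy import EmbeddingTaxonomyClassifier

# Constructor knobs as declared in embedding.py; values shown are the defaults.
clf = EmbeddingTaxonomyClassifier(
    taxonomy_path=None,     # falls back to data/statement_taxonomy.json
    categories=None,        # falls back to DEFAULT_CATEGORIES
    model_name="google/embeddinggemma-300m",
    top_k_categories=3,
    min_confidence=0.8,
)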
statement_extractor/plugins/taxonomy/embedding.py
@@ -0,0 +1,337 @@
+ """
+ EmbeddingTaxonomyClassifier - Classifies statements using embedding similarity.
+
+ Uses sentence-transformers to embed text and compare to pre-computed label
+ embeddings using cosine similarity with sigmoid calibration.
+
+ Faster than MNLI but may be less accurate for nuanced classification.
+ """
+
+ import json
+ import logging
+ import time
+ from pathlib import Path
+ from typing import Optional
+
+ import numpy as np
+
+ from ..base import BaseTaxonomyPlugin, TaxonomySchema, PluginCapability
+ from ...pipeline.context import PipelineContext
+ from ...pipeline.registry import PluginRegistry
+ from ...models import (
+     PipelineStatement,
+     CanonicalEntity,
+     TaxonomyResult,
+ )
+
+ logger = logging.getLogger(__name__)
+
+ # Default taxonomy file location
+ DEFAULT_TAXONOMY_PATH = Path(__file__).parent.parent.parent / "data" / "statement_taxonomy.json"
+
+ # Default categories
+ DEFAULT_CATEGORIES = [
+     "environment",
+     "society",
+     "governance",
+     "animals",
+     "industry",
+     "human_harm",
+     "human_benefit",
+     "animal_harm",
+     "animal_benefit",
+     "environment_harm",
+     "environment_benefit",
+ ]
+
+
+ class EmbeddingClassifier:
+     """
+     Embedding-based classifier using cosine similarity.
+
+     Pre-computes embeddings for all labels and uses dot product
+     (on normalized vectors) for fast classification.
+     """
+
+     SIMILARITY_THRESHOLD = 0.65
+     CALIBRATION_STEEPNESS = 25.0
+
+     def __init__(
+         self,
+         model_name: str = "google/embeddinggemma-300m",
+         device: Optional[str] = None,
+     ):
+         self._model_name = model_name
+         self._device = device
+         self._model = None
+         self._label_embeddings: dict[str, dict[str, np.ndarray]] = {}
+         self._text_embedding_cache: dict[str, np.ndarray] = {}  # Cache for input text embeddings
+
+     def _load_model(self):
+         if self._model is not None:
+             return
+
+         try:
+             from sentence_transformers import SentenceTransformer
+             import torch
+
+             device = self._device
+             if device is None:
+                 if torch.cuda.is_available():
+                     device = "cuda"
+                 elif torch.backends.mps.is_available():
+                     device = "mps"
+                 else:
+                     device = "cpu"
+
+             logger.info(f"Loading embedding model '{self._model_name}' on {device}...")
+             self._model = SentenceTransformer(self._model_name, device=device)
+             logger.debug("Embedding model loaded")
+
+         except ImportError as e:
+             raise ImportError(
+                 "sentence-transformers is required for embedding classification. "
+                 "Install with: pip install sentence-transformers"
+             ) from e
+
+     def precompute_label_embeddings(
+         self,
+         taxonomy: dict[str, dict[str, int]],
+         categories: Optional[list[str]] = None,
+     ) -> None:
+         """Pre-compute embeddings for all label names."""
+         self._load_model()
+
+         start_time = time.perf_counter()
+         total_labels = 0
+
+         categories_to_process = categories or list(taxonomy.keys())
+
+         for category in categories_to_process:
+             if category not in taxonomy:
+                 continue
+
+             labels = taxonomy[category]
+             label_names = list(labels.keys())
+
+             if not label_names:
+                 continue
+
+             embeddings = self._model.encode(label_names, convert_to_numpy=True, show_progress_bar=False)
+
+             self._label_embeddings[category] = {}
+             for label_name, embedding in zip(label_names, embeddings):
+                 norm = np.linalg.norm(embedding)
+                 normalized = embedding / (norm + 1e-8)
+                 self._label_embeddings[category][label_name] = normalized.astype(np.float32)
+                 total_labels += 1
+
+         elapsed = time.perf_counter() - start_time
+         logger.info(
+             f"Pre-computed embeddings for {total_labels} labels "
+             f"across {len(self._label_embeddings)} categories in {elapsed:.2f}s"
+         )
+
+     def _calibrate_score(self, raw_similarity: float) -> float:
+         normalized = (raw_similarity + 1) / 2
+         exponent = -self.CALIBRATION_STEEPNESS * (normalized - self.SIMILARITY_THRESHOLD)
+         return 1.0 / (1.0 + np.exp(exponent))
+
+     def classify_hierarchical(
+         self,
+         text: str,
+         top_k_categories: int = 3,
+         min_score: float = 0.3,
+     ) -> list[tuple[str, str, float]]:
+         """Hierarchical classification: find categories, then all labels above threshold.
+
+         Returns all labels above the threshold, not just the best match.
+
+         Args:
+             text: Text to classify
+             top_k_categories: Number of top categories to consider
+             min_score: Minimum calibrated score to include in results
+
+         Returns:
+             List of (category, label, confidence) tuples above threshold
+         """
+         self._load_model()
+
+         if not self._label_embeddings:
+             raise RuntimeError("Label embeddings not pre-computed.")
+
+         # Check cache for input text embedding
+         if text in self._text_embedding_cache:
+             input_normalized = self._text_embedding_cache[text]
+         else:
+             input_embedding = self._model.encode(text, convert_to_numpy=True, show_progress_bar=False)
+             input_norm = np.linalg.norm(input_embedding)
+             input_normalized = (input_embedding / (input_norm + 1e-8)).astype(np.float32)
+             self._text_embedding_cache[text] = input_normalized
+             logger.debug(f"Cached embedding for text: '{text[:50]}...' (cache size: {len(self._text_embedding_cache)})")
+
+         # Compute average similarity to each category
+         category_scores: list[tuple[str, float]] = []
+         for category, labels in self._label_embeddings.items():
+             if not labels:
+                 continue
+
+             sims = []
+             for label_embedding in labels.values():
+                 sim = float(np.dot(input_normalized, label_embedding))
+                 sims.append(sim)
+
+             avg_sim = np.mean(sims)
+             category_scores.append((category, avg_sim))
+
+         category_scores.sort(key=lambda x: x[1], reverse=True)
+
+         results: list[tuple[str, str, float]] = []
+
+         for category, _ in category_scores[:top_k_categories]:
+             for label, label_embedding in self._label_embeddings[category].items():
+                 raw_sim = float(np.dot(input_normalized, label_embedding))
+                 calibrated_score = self._calibrate_score(raw_sim)
+
+                 if calibrated_score >= min_score:
+                     results.append((category, label, calibrated_score))
+
+         # Sort by confidence descending
+         results.sort(key=lambda x: x[2], reverse=True)
+         return results
+
+
+ @PluginRegistry.taxonomy
+ class EmbeddingTaxonomyClassifier(BaseTaxonomyPlugin):
+     """
+     Taxonomy classifier using embedding similarity.
+
+     Faster than MNLI, good for high-throughput scenarios.
+     """
+
+     def __init__(
+         self,
+         taxonomy_path: Optional[str | Path] = None,
+         categories: Optional[list[str]] = None,
+         model_name: str = "google/embeddinggemma-300m",
+         top_k_categories: int = 3,
+         min_confidence: float = 0.8,
+     ):
+         self._taxonomy_path = Path(taxonomy_path) if taxonomy_path else DEFAULT_TAXONOMY_PATH
+         self._categories = categories or DEFAULT_CATEGORIES
+         self._model_name = model_name
+         self._top_k_categories = top_k_categories
+         self._min_confidence = min_confidence
+
+         self._taxonomy: Optional[dict[str, dict[str, int]]] = None
+         self._classifier: Optional[EmbeddingClassifier] = None
+         self._embeddings_computed = False
+
+     @property
+     def name(self) -> str:
+         return "embedding_taxonomy_classifier"
+
+     @property
+     def priority(self) -> int:
+         return 10  # High priority - default taxonomy classifier (faster than MNLI)
+
+     @property
+     def capabilities(self) -> PluginCapability:
+         return PluginCapability.LLM_REQUIRED | PluginCapability.BATCH_PROCESSING
+
+     @property
+     def description(self) -> str:
+         return "Classifies statements using embedding similarity (faster than MNLI)"
+
+     @property
+     def taxonomy_name(self) -> str:
+         return "esg_topics_embedding"
+
+     @property
+     def taxonomy_schema(self) -> TaxonomySchema:
+         taxonomy = self._load_taxonomy()
+         filtered = {cat: list(labels.keys()) for cat, labels in taxonomy.items() if cat in self._categories}
+         return TaxonomySchema(
+             label_type="taxonomy",
+             values=filtered,
+             description="ESG topic classification using embeddings",
+             scope="statement",
+         )
+
+     @property
+     def supported_categories(self) -> list[str]:
+         return self._categories.copy()
+
+     def _load_taxonomy(self) -> dict[str, dict[str, int]]:
+         if self._taxonomy is not None:
+             return self._taxonomy
+
+         if not self._taxonomy_path.exists():
+             raise FileNotFoundError(f"Taxonomy file not found: {self._taxonomy_path}")
+
+         with open(self._taxonomy_path) as f:
+             self._taxonomy = json.load(f)
+
+         logger.debug(f"Loaded taxonomy with {len(self._taxonomy)} categories")
+         return self._taxonomy
+
+     def _get_classifier(self) -> EmbeddingClassifier:
+         if self._classifier is None:
+             self._classifier = EmbeddingClassifier(model_name=self._model_name)
+
+         if not self._embeddings_computed:
+             taxonomy = self._load_taxonomy()
+             self._classifier.precompute_label_embeddings(taxonomy, self._categories)
+             self._embeddings_computed = True
+
+         return self._classifier
+
+     def classify(
+         self,
+         statement: PipelineStatement,
+         subject_canonical: CanonicalEntity,
+         object_canonical: CanonicalEntity,
+         context: PipelineContext,
+     ) -> list[TaxonomyResult]:
+         """Classify statement using embedding similarity.
+
+         Returns all labels above the confidence threshold.
+         """
+         results: list[TaxonomyResult] = []
+
+         try:
+             classifier = self._get_classifier()
+             text = statement.source_text
+
+             classifications = classifier.classify_hierarchical(
+                 text,
+                 top_k_categories=self._top_k_categories,
+                 min_score=self._min_confidence,
+             )
+
+             for category, label, confidence in classifications:
+                 label_id = self._get_label_id(category, label)
+
+                 results.append(TaxonomyResult(
+                     taxonomy_name=self.taxonomy_name,
+                     category=category,
+                     label=label,
+                     label_id=label_id,
+                     confidence=round(confidence, 4),
+                     classifier=self.name,
+                 ))
+
+         except Exception as e:
+             logger.warning(f"Embedding taxonomy classification failed: {e}")
+
+         return results
+
+     def _get_label_id(self, category: str, label: str) -> Optional[int]:
+         taxonomy = self._load_taxonomy()
+         if category in taxonomy:
+             return taxonomy[category].get(label)
+         return None
+
+
+ # For testing without decorator
+ EmbeddingTaxonomyClassifierClass = EmbeddingTaxonomyClassifier
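The `_calibrate_score` step above maps a raw cosine similarity in [-1, 1] to a pseudo-probability: the similarity is rescaled to [0, 1] and passed through a sigmoid centred on SIMILARITY_THRESHOLD (0.65) with steepness 25. A standalone sketch of the same arithmetic, re-derived from the constants shown rather than by running the package:

import math

SIMILARITY_THRESHOLD = 0.65
CALIBRATION_STEEPNESS = 25.0

def calibrate(raw_similarity: float) -> float:
    # Rescale cosine similarity from [-1, 1] to [0, 1], then apply a
    # sigmoid centred on the threshold.
    normalized = (raw_similarity + 1) / 2
    return 1.0 / (1.0 + math.exp(-CALIBRATION_STEEPNESS * (normalized - SIMILARITY_THRESHOLD)))

for s in (0.20, 0.30, 0.40):
    print(f"raw={s:.2f} -> calibrated={calibrate(s):.2f}")
# raw=0.20 -> calibrated=0.22
# raw=0.30 -> calibrated=0.50
# raw=0.40 -> calibrated=0.78

A raw similarity of 0.30 sits exactly at the threshold after rescaling and calibrates to 0.5; with the plugin's default min_confidence of 0.8 passed through as min_score, only labels whose raw cosine similarity is roughly 0.41 or higher survive.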