corp-extractor 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/METADATA +181 -64
- corp_extractor-0.5.0.dist-info/RECORD +55 -0
- statement_extractor/__init__.py +9 -0
- statement_extractor/cli.py +446 -17
- statement_extractor/data/default_predicates.json +368 -0
- statement_extractor/data/statement_taxonomy.json +1182 -0
- statement_extractor/extractor.py +1 -23
- statement_extractor/gliner_extraction.py +4 -74
- statement_extractor/llm.py +255 -0
- statement_extractor/models/__init__.py +74 -0
- statement_extractor/models/canonical.py +139 -0
- statement_extractor/models/entity.py +102 -0
- statement_extractor/models/labels.py +191 -0
- statement_extractor/models/qualifiers.py +91 -0
- statement_extractor/models/statement.py +75 -0
- statement_extractor/models.py +4 -1
- statement_extractor/pipeline/__init__.py +39 -0
- statement_extractor/pipeline/config.py +134 -0
- statement_extractor/pipeline/context.py +177 -0
- statement_extractor/pipeline/orchestrator.py +447 -0
- statement_extractor/pipeline/registry.py +297 -0
- statement_extractor/plugins/__init__.py +43 -0
- statement_extractor/plugins/base.py +446 -0
- statement_extractor/plugins/canonicalizers/__init__.py +17 -0
- statement_extractor/plugins/canonicalizers/base.py +9 -0
- statement_extractor/plugins/canonicalizers/location.py +219 -0
- statement_extractor/plugins/canonicalizers/organization.py +230 -0
- statement_extractor/plugins/canonicalizers/person.py +242 -0
- statement_extractor/plugins/extractors/__init__.py +13 -0
- statement_extractor/plugins/extractors/base.py +9 -0
- statement_extractor/plugins/extractors/gliner2.py +536 -0
- statement_extractor/plugins/labelers/__init__.py +29 -0
- statement_extractor/plugins/labelers/base.py +9 -0
- statement_extractor/plugins/labelers/confidence.py +138 -0
- statement_extractor/plugins/labelers/relation_type.py +87 -0
- statement_extractor/plugins/labelers/sentiment.py +159 -0
- statement_extractor/plugins/labelers/taxonomy.py +373 -0
- statement_extractor/plugins/labelers/taxonomy_embedding.py +466 -0
- statement_extractor/plugins/qualifiers/__init__.py +19 -0
- statement_extractor/plugins/qualifiers/base.py +9 -0
- statement_extractor/plugins/qualifiers/companies_house.py +174 -0
- statement_extractor/plugins/qualifiers/gleif.py +186 -0
- statement_extractor/plugins/qualifiers/person.py +221 -0
- statement_extractor/plugins/qualifiers/sec_edgar.py +198 -0
- statement_extractor/plugins/splitters/__init__.py +13 -0
- statement_extractor/plugins/splitters/base.py +9 -0
- statement_extractor/plugins/splitters/t5_gemma.py +188 -0
- statement_extractor/plugins/taxonomy/__init__.py +13 -0
- statement_extractor/plugins/taxonomy/embedding.py +337 -0
- statement_extractor/plugins/taxonomy/mnli.py +279 -0
- corp_extractor-0.4.0.dist-info/RECORD +0 -12
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/WHEEL +0 -0
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""
|
|
2
|
+
T5GemmaSplitter - Stage 1 plugin that wraps the existing StatementExtractor.
|
|
3
|
+
|
|
4
|
+
Uses T5-Gemma2 model with Diverse Beam Search to generate high-quality
|
|
5
|
+
subject-predicate-object triples from text.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
import re
|
|
10
|
+
import xml.etree.ElementTree as ET
|
|
11
|
+
from typing import Optional
|
|
12
|
+
|
|
13
|
+
from ..base import BaseSplitterPlugin, PluginCapability
|
|
14
|
+
from ...pipeline.context import PipelineContext
|
|
15
|
+
from ...pipeline.registry import PluginRegistry
|
|
16
|
+
from ...models import RawTriple
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)


@PluginRegistry.splitter
class T5GemmaSplitter(BaseSplitterPlugin):
    """
    Splitter plugin that uses T5-Gemma2 for triple extraction.

    Wraps the existing StatementExtractor from extractor.py to produce
    RawTriple objects for the pipeline. The extractor (and therefore the
    model) is created lazily on first use.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        device: Optional[str] = None,
        num_beams: int = 4,
        diversity_penalty: float = 1.0,
        max_new_tokens: int = 2048,
    ):
        """
        Initialize the T5Gemma splitter.

        Args:
            model_id: HuggingFace model ID. When None, StatementExtractor's own
                default is used (presumably the Corp-o-Rate model — confirm in
                extractor.py).
            device: Device to use. When None, StatementExtractor picks one.
            num_beams: Default number of beams for diverse beam search.
            diversity_penalty: Default penalty for beam diversity.
            max_new_tokens: Default maximum number of tokens to generate.
        """
        self._model_id = model_id
        self._device = device
        self._num_beams = num_beams
        self._diversity_penalty = diversity_penalty
        self._max_new_tokens = max_new_tokens
        self._extractor = None

    @property
    def name(self) -> str:
        return "t5_gemma_splitter"

    @property
    def priority(self) -> int:
        return 10  # High priority - primary splitter

    @property
    def capabilities(self) -> PluginCapability:
        return PluginCapability.LLM_REQUIRED

    @property
    def description(self) -> str:
        return "T5-Gemma2 model for extracting triples using Diverse Beam Search"

    def _get_extractor(self):
        """Lazy-load the StatementExtractor (created once, then reused)."""
        if self._extractor is None:
            from ...extractor import StatementExtractor

            # Only pass model_id and device if they were explicitly set, so the
            # extractor's own defaults apply otherwise.
            kwargs = {}
            if self._model_id is not None:
                kwargs["model_id"] = self._model_id
            if self._device is not None:
                kwargs["device"] = self._device
            self._extractor = StatementExtractor(**kwargs)
        return self._extractor

    def split(
        self,
        text: str,
        context: PipelineContext,
    ) -> list[RawTriple]:
        """
        Split text into raw triples using T5-Gemma2.

        Per-call overrides for ``num_beams``, ``diversity_penalty`` and
        ``max_new_tokens`` are read from
        ``context.source_metadata["splitter_options"]``. Keys that are absent
        fall back to the values given to ``__init__``.

        Args:
            text: Input text to split
            context: Pipeline context

        Returns:
            List of RawTriple objects. The list is empty for empty or
            whitespace-only text, in which case the model is not loaded.
        """
        logger.debug("T5GemmaSplitter processing %d chars", len(text))

        if not text or not text.strip():
            return []

        # A missing *or* explicitly-None "splitter_options" entry means
        # "no overrides"; `.get(key, {})` alone would return None in the
        # latter case and crash on `.get` below.
        splitter_options = context.source_metadata.get("splitter_options") or {}
        num_beams = splitter_options.get("num_beams", self._num_beams)
        diversity_penalty = splitter_options.get("diversity_penalty", self._diversity_penalty)
        max_new_tokens = splitter_options.get("max_new_tokens", self._max_new_tokens)

        # Create extraction options
        from ...models import ExtractionOptions as LegacyExtractionOptions
        options = LegacyExtractionOptions(
            num_beams=num_beams,
            diversity_penalty=diversity_penalty,
            max_new_tokens=max_new_tokens,
            # Disable GLiNER and dedup - we handle those in later stages
            use_gliner_extraction=False,
            embedding_dedup=False,
            deduplicate=False,
        )

        # Get raw XML from extractor, then parse it into RawTriple objects.
        extractor = self._get_extractor()
        xml_output = extractor.extract_as_xml(text, options)
        raw_triples = self._parse_xml_to_raw_triples(xml_output)

        logger.info("T5GemmaSplitter produced %d raw triples", len(raw_triples))
        return raw_triples

    @staticmethod
    def _element_text(parent, tag: str) -> str:
        """Return the stripped text of ``parent``'s first ``tag`` child, or "" if absent/empty."""
        elem = parent.find(tag)
        # Compare against None explicitly: an Element with no children is falsy.
        if elem is None or not elem.text:
            return ""
        return elem.text.strip()

    def _parse_xml_to_raw_triples(self, xml_output: str) -> list[RawTriple]:
        """
        Parse the extractor's XML output into RawTriple objects.

        Expects a ``<statements>`` root containing ``<stmt>`` elements with
        ``<subject>``, ``<predicate>``, ``<object>`` and ``<text>`` children.
        If parsing fails, one repair pass is attempted. Statements missing a
        subject, object or source text are skipped. The predicate may be empty.

        Returns:
            The parsed triples. The list is empty on empty input, an
            unrepairable parse, or an unexpected root tag.
        """
        raw_triples: list[RawTriple] = []

        # Empty model output has nothing to parse. Bail out early instead of
        # logging a parse error and a failed repair, or raising TypeError on None.
        if not xml_output or not xml_output.strip():
            logger.debug("Empty XML output; no triples to parse")
            return raw_triples

        try:
            root = ET.fromstring(xml_output)
        except ET.ParseError as e:
            logger.warning("XML parse error: %s", e)
            # Try to repair
            xml_output = self._repair_xml(xml_output)
            try:
                root = ET.fromstring(xml_output)
            except ET.ParseError:
                logger.error("XML repair failed")
                return raw_triples

        if root.tag != "statements":
            logger.warning("Unexpected root tag: %s", root.tag)
            return raw_triples

        for stmt_elem in root.findall("stmt"):
            try:
                subject_text = self._element_text(stmt_elem, "subject")
                predicate_text = self._element_text(stmt_elem, "predicate")
                object_text = self._element_text(stmt_elem, "object")
                source_text = self._element_text(stmt_elem, "text")

                if subject_text and object_text and source_text:
                    raw_triples.append(RawTriple(
                        subject_text=subject_text,
                        predicate_text=predicate_text,
                        object_text=object_text,
                        source_sentence=source_text,
                    ))
                else:
                    logger.debug(
                        "Skipping incomplete triple: s=%s, p=%s, o=%s",
                        subject_text, predicate_text, object_text,
                    )
            except Exception as e:
                # Best-effort: one malformed statement must not discard the rest.
                logger.warning("Error parsing stmt element: %s", e)

        return raw_triples

    def _repair_xml(self, xml_string: str) -> str:
        """Attempt to repair common XML syntax errors."""
        # Use the repair function from extractor.py
        from ...extractor import repair_xml
        repaired, repairs = repair_xml(xml_string)
        if repairs:
            logger.debug("XML repairs: %s", ", ".join(repairs))
        return repaired
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
# Module-level alias for T5GemmaSplitter. It is bound *after* the
# @PluginRegistry.splitter decorator above has run, so it refers to the same
# object as T5GemmaSplitter (whatever the decorator returned). It does not
# bypass the decorator or its registration side effects.
# NOTE(review): the previous comment said "without decorator for testing";
# confirm what tests rely on this alias for.
T5GemmaSplitterClass = T5GemmaSplitter
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Taxonomy classifier plugins for Stage 6 (Taxonomy).
|
|
3
|
+
|
|
4
|
+
Classifies statements against large taxonomies using MNLI or embeddings.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .mnli import MNLITaxonomyClassifier
|
|
8
|
+
from .embedding import EmbeddingTaxonomyClassifier
|
|
9
|
+
|
|
10
|
+
# Public API of this package: the two taxonomy classifier plugins imported
# from the .mnli and .embedding submodules.
__all__ = [
    "MNLITaxonomyClassifier",
    "EmbeddingTaxonomyClassifier",
]
|
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
"""
|
|
2
|
+
EmbeddingTaxonomyClassifier - Classifies statements using embedding similarity.
|
|
3
|
+
|
|
4
|
+
Uses sentence-transformers to embed text and compare to pre-computed label
|
|
5
|
+
embeddings using cosine similarity with sigmoid calibration.
|
|
6
|
+
|
|
7
|
+
Faster than MNLI but may be less accurate for nuanced classification.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
import time
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Optional
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
from ..base import BaseTaxonomyPlugin, TaxonomySchema, PluginCapability
|
|
19
|
+
from ...pipeline.context import PipelineContext
|
|
20
|
+
from ...pipeline.registry import PluginRegistry
|
|
21
|
+
from ...models import (
|
|
22
|
+
PipelineStatement,
|
|
23
|
+
CanonicalEntity,
|
|
24
|
+
TaxonomyResult,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)

# Bundled taxonomy file, three directory levels above this module
# (plugins/taxonomy/embedding.py -> package root), under data/.
DEFAULT_TAXONOMY_PATH = Path(__file__).parent.parent.parent.joinpath(
    "data", "statement_taxonomy.json"
)

# Categories classified when the caller does not pass an explicit list.
# The order is significant: it is the iteration order used downstream.
DEFAULT_CATEGORIES = [
    # Broad topic areas
    "environment", "society", "governance", "animals", "industry",
    # Impact-direction categories, paired harm/benefit
    "human_harm", "human_benefit",
    "animal_harm", "animal_benefit",
    "environment_harm", "environment_benefit",
]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class EmbeddingClassifier:
    """
    Embedding-based classifier using cosine similarity.

    Label embeddings are pre-computed once and L2-normalized, so a dot product
    with a normalized input embedding equals their cosine similarity. Raw
    similarities are mapped into (0, 1) with a sigmoid calibration.

    Input-text embeddings are memoized in a least-recently-used cache whose
    size is bounded by ``max_cache_size``. This stops a long-running process
    from keeping one embedding per distinct input forever.
    """

    # Sigmoid midpoint, expressed on the [0, 1]-rescaled similarity axis.
    SIMILARITY_THRESHOLD = 0.65
    # Sigmoid slope; larger values make the calibration closer to a step.
    CALIBRATION_STEEPNESS = 25.0

    def __init__(
        self,
        model_name: str = "google/embeddinggemma-300m",
        device: Optional[str] = None,
        max_cache_size: Optional[int] = 10_000,
    ):
        """
        Args:
            model_name: sentence-transformers model name or path.
            device: Torch device string. When None it is auto-detected
                (cuda, then mps, then cpu) on first model load.
            max_cache_size: Maximum number of input-text embeddings to keep,
                evicting the least recently used first. ``None`` keeps every
                embedding (unbounded); ``0`` disables caching.
        """
        self._model_name = model_name
        self._device = device
        self._model = None
        # category -> label name -> L2-normalized float32 embedding
        self._label_embeddings: dict[str, dict[str, np.ndarray]] = {}
        # input text -> normalized embedding; dict insertion order is the LRU order
        self._text_embedding_cache: dict[str, np.ndarray] = {}
        self._max_cache_size = max_cache_size

    def _load_model(self):
        """Load the SentenceTransformer model once; later calls are no-ops."""
        if self._model is not None:
            return

        try:
            from sentence_transformers import SentenceTransformer
            import torch

            device = self._device
            if device is None:
                if torch.cuda.is_available():
                    device = "cuda"
                elif torch.backends.mps.is_available():
                    device = "mps"
                else:
                    device = "cpu"

            logger.info(f"Loading embedding model '{self._model_name}' on {device}...")
            self._model = SentenceTransformer(self._model_name, device=device)
            logger.debug("Embedding model loaded")

        except ImportError as e:
            raise ImportError(
                "sentence-transformers is required for embedding classification. "
                "Install with: pip install sentence-transformers"
            ) from e

    def precompute_label_embeddings(
        self,
        taxonomy: dict[str, dict[str, int]],
        categories: Optional[list[str]] = None,
    ) -> None:
        """Pre-compute normalized embeddings for every label name.

        Args:
            taxonomy: Mapping of category -> {label name: label id}. Only the
                label names (keys) are embedded.
            categories: Categories to embed. When None or empty, every category
                in ``taxonomy`` is embedded. Categories absent from ``taxonomy``
                or with no labels are skipped.
        """
        self._load_model()

        start_time = time.perf_counter()
        total_labels = 0

        for category in categories or list(taxonomy.keys()):
            labels = taxonomy.get(category)
            if not labels:
                continue

            label_names = list(labels)
            embeddings = self._model.encode(label_names, convert_to_numpy=True, show_progress_bar=False)

            # Normalize once so classification is a plain dot product.
            self._label_embeddings[category] = {
                label_name: (embedding / (np.linalg.norm(embedding) + 1e-8)).astype(np.float32)
                for label_name, embedding in zip(label_names, embeddings)
            }
            total_labels += len(self._label_embeddings[category])

        elapsed = time.perf_counter() - start_time
        logger.info(
            "Pre-computed embeddings for %d labels across %d categories in %.2fs",
            total_labels, len(self._label_embeddings), elapsed,
        )

    def _calibrate_score(self, raw_similarity: float) -> float:
        """Map a cosine similarity in [-1, 1] to a confidence in (0, 1).

        The similarity is rescaled to [0, 1], then passed through a sigmoid
        centred on SIMILARITY_THRESHOLD with slope CALIBRATION_STEEPNESS.
        """
        normalized = (raw_similarity + 1) / 2
        exponent = -self.CALIBRATION_STEEPNESS * (normalized - self.SIMILARITY_THRESHOLD)
        return float(1.0 / (1.0 + np.exp(exponent)))

    def _embed_text(self, text: str) -> np.ndarray:
        """Return the L2-normalized float32 embedding of ``text``, using the LRU cache."""
        cache = self._text_embedding_cache

        cached = cache.pop(text, None)
        if cached is not None:
            cache[text] = cached  # re-insert to mark as most recently used
            return cached

        embedding = self._model.encode(text, convert_to_numpy=True, show_progress_bar=False)
        normalized = (embedding / (np.linalg.norm(embedding) + 1e-8)).astype(np.float32)

        limit = self._max_cache_size
        if limit is None or limit > 0:
            cache[text] = normalized
            if limit is not None:
                # Evict least recently used entries (front of the dict).
                while len(cache) > limit:
                    cache.pop(next(iter(cache)))
            logger.debug(
                "Cached embedding for text: '%s...' (cache size: %d)", text[:50], len(cache)
            )
        return normalized

    def classify_hierarchical(
        self,
        text: str,
        top_k_categories: int = 3,
        min_score: float = 0.3,
    ) -> list[tuple[str, str, float]]:
        """Hierarchical classification: find categories, then all labels above threshold.

        Categories are ranked by the mean cosine similarity between ``text``
        and their labels. Every label in the ``top_k_categories`` best
        categories is then scored, and all labels whose calibrated score is at
        least ``min_score`` are returned, not just the best match.

        Args:
            text: Text to classify
            top_k_categories: Number of top categories to consider
            min_score: Minimum calibrated score to include in results

        Returns:
            List of (category, label, confidence) tuples above threshold,
            sorted by confidence descending.

        Raises:
            RuntimeError: If ``precompute_label_embeddings`` produced no labels.
        """
        self._load_model()

        if not self._label_embeddings:
            raise RuntimeError("Label embeddings not pre-computed.")

        input_normalized = self._embed_text(text)

        # Similarity to every label, computed once and reused for both the
        # category ranking and the per-label scoring below.
        label_sims: dict[str, list[tuple[str, float]]] = {
            category: [
                (label, float(np.dot(input_normalized, label_embedding)))
                for label, label_embedding in labels.items()
            ]
            for category, labels in self._label_embeddings.items()
            if labels
        }

        # Rank categories by mean similarity (stable sort keeps insertion order on ties).
        ranked_categories = sorted(
            label_sims,
            key=lambda category: np.mean([sim for _, sim in label_sims[category]]),
            reverse=True,
        )

        results: list[tuple[str, str, float]] = []
        for category in ranked_categories[:top_k_categories]:
            for label, raw_sim in label_sims[category]:
                calibrated_score = self._calibrate_score(raw_sim)
                if calibrated_score >= min_score:
                    results.append((category, label, calibrated_score))

        # Sort by confidence descending
        results.sort(key=lambda x: x[2], reverse=True)
        return results
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@PluginRegistry.taxonomy
class EmbeddingTaxonomyClassifier(BaseTaxonomyPlugin):
    """
    Taxonomy classifier using embedding similarity.

    Faster than MNLI, good for high-throughput scenarios. The taxonomy JSON
    and the embedding model are both loaded lazily on first use.
    """

    def __init__(
        self,
        taxonomy_path: Optional[str | Path] = None,
        categories: Optional[list[str]] = None,
        model_name: str = "google/embeddinggemma-300m",
        top_k_categories: int = 3,
        min_confidence: float = 0.8,
        device: Optional[str] = None,
    ):
        """
        Args:
            taxonomy_path: JSON file mapping category -> {label: label_id}.
                Defaults to DEFAULT_TAXONOMY_PATH.
            categories: Categories to classify against. Defaults to
                DEFAULT_CATEGORIES.
            model_name: sentence-transformers embedding model name.
            top_k_categories: Number of best-matching categories whose labels
                are scored.
            min_confidence: Minimum calibrated score for a label to be returned.
            device: Torch device for the embedding model. When None, the
                embedding classifier chooses one.
        """
        self._taxonomy_path = Path(taxonomy_path) if taxonomy_path else DEFAULT_TAXONOMY_PATH
        # Copy, so later mutation of the caller's list (or of the shared module
        # default) cannot silently change this plugin's configuration.
        self._categories = list(categories or DEFAULT_CATEGORIES)
        self._model_name = model_name
        self._top_k_categories = top_k_categories
        self._min_confidence = min_confidence
        self._device = device

        self._taxonomy: Optional[dict[str, dict[str, int]]] = None
        self._classifier: Optional[EmbeddingClassifier] = None
        self._embeddings_computed = False

    @property
    def name(self) -> str:
        return "embedding_taxonomy_classifier"

    @property
    def priority(self) -> int:
        return 10  # High priority - default taxonomy classifier (faster than MNLI)

    @property
    def capabilities(self) -> PluginCapability:
        return PluginCapability.LLM_REQUIRED | PluginCapability.BATCH_PROCESSING

    @property
    def description(self) -> str:
        return "Classifies statements using embedding similarity (faster than MNLI)"

    @property
    def taxonomy_name(self) -> str:
        return "esg_topics_embedding"

    @property
    def taxonomy_schema(self) -> TaxonomySchema:
        """Schema listing, for each configured category, its label names."""
        taxonomy = self._load_taxonomy()
        wanted = set(self._categories)
        filtered = {cat: list(labels) for cat, labels in taxonomy.items() if cat in wanted}
        return TaxonomySchema(
            label_type="taxonomy",
            values=filtered,
            description="ESG topic classification using embeddings",
            scope="statement",
        )

    @property
    def supported_categories(self) -> list[str]:
        return self._categories.copy()

    def _load_taxonomy(self) -> dict[str, dict[str, int]]:
        """Load (once) and return the taxonomy JSON as category -> {label: id}.

        Raises:
            FileNotFoundError: If the taxonomy file does not exist.
        """
        if self._taxonomy is not None:
            return self._taxonomy

        try:
            # Explicit UTF-8: the default text encoding is platform/locale dependent.
            with self._taxonomy_path.open(encoding="utf-8") as f:
                self._taxonomy = json.load(f)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Taxonomy file not found: {self._taxonomy_path}") from e

        logger.debug("Loaded taxonomy with %d categories", len(self._taxonomy))
        return self._taxonomy

    def _get_classifier(self) -> EmbeddingClassifier:
        """Create the classifier and embed the label set on first use."""
        if self._classifier is None:
            self._classifier = EmbeddingClassifier(model_name=self._model_name, device=self._device)

        if not self._embeddings_computed:
            taxonomy = self._load_taxonomy()
            self._classifier.precompute_label_embeddings(taxonomy, self._categories)
            self._embeddings_computed = True

        return self._classifier

    def classify(
        self,
        statement: PipelineStatement,
        subject_canonical: CanonicalEntity,
        object_canonical: CanonicalEntity,
        context: PipelineContext,
    ) -> list[TaxonomyResult]:
        """Classify statement using embedding similarity.

        Only ``statement.source_text`` is classified; the canonical entities
        and context are unused here. Returns all labels above the confidence
        threshold. Empty or whitespace-only source text yields no results.
        Any failure is logged as a warning, and the results gathered so far
        (possibly none) are returned rather than raising.
        """
        results: list[TaxonomyResult] = []

        try:
            text = statement.source_text
            # Nothing meaningful to embed: avoid matching an empty string
            # against every label (and loading the model for nothing).
            if not text or not text.strip():
                return results

            classifier = self._get_classifier()
            classifications = classifier.classify_hierarchical(
                text,
                top_k_categories=self._top_k_categories,
                min_score=self._min_confidence,
            )

            for category, label, confidence in classifications:
                results.append(TaxonomyResult(
                    taxonomy_name=self.taxonomy_name,
                    category=category,
                    label=label,
                    label_id=self._get_label_id(category, label),
                    confidence=round(float(confidence), 4),
                    classifier=self.name,
                ))

        except Exception as e:
            logger.warning("Embedding taxonomy classification failed: %s", e)

        return results

    def _get_label_id(self, category: str, label: str) -> Optional[int]:
        """Look up a label's numeric id in the taxonomy, or None if unknown."""
        taxonomy = self._load_taxonomy()
        if category in taxonomy:
            return taxonomy[category].get(label)
        return None
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
# Module-level alias for EmbeddingTaxonomyClassifier. It is bound *after* the
# @PluginRegistry.taxonomy decorator above has run, so it refers to the same
# object as EmbeddingTaxonomyClassifier (whatever the decorator returned). It
# does not bypass the decorator or its registration side effects.
# NOTE(review): the previous comment said "For testing without decorator";
# confirm what tests rely on this alias for.
EmbeddingTaxonomyClassifierClass = EmbeddingTaxonomyClassifier
|