corp-extractor 0.4.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/METADATA +348 -64
- corp_extractor-0.9.0.dist-info/RECORD +76 -0
- statement_extractor/__init__.py +10 -1
- statement_extractor/cli.py +1663 -17
- statement_extractor/data/default_predicates.json +368 -0
- statement_extractor/data/statement_taxonomy.json +6972 -0
- statement_extractor/database/__init__.py +52 -0
- statement_extractor/database/embeddings.py +186 -0
- statement_extractor/database/hub.py +520 -0
- statement_extractor/database/importers/__init__.py +24 -0
- statement_extractor/database/importers/companies_house.py +545 -0
- statement_extractor/database/importers/gleif.py +538 -0
- statement_extractor/database/importers/sec_edgar.py +375 -0
- statement_extractor/database/importers/wikidata.py +1012 -0
- statement_extractor/database/importers/wikidata_people.py +632 -0
- statement_extractor/database/models.py +230 -0
- statement_extractor/database/resolver.py +245 -0
- statement_extractor/database/store.py +1609 -0
- statement_extractor/document/__init__.py +62 -0
- statement_extractor/document/chunker.py +410 -0
- statement_extractor/document/context.py +171 -0
- statement_extractor/document/deduplicator.py +173 -0
- statement_extractor/document/html_extractor.py +246 -0
- statement_extractor/document/loader.py +303 -0
- statement_extractor/document/pipeline.py +388 -0
- statement_extractor/document/summarizer.py +195 -0
- statement_extractor/extractor.py +1 -23
- statement_extractor/gliner_extraction.py +4 -74
- statement_extractor/llm.py +255 -0
- statement_extractor/models/__init__.py +89 -0
- statement_extractor/models/canonical.py +182 -0
- statement_extractor/models/document.py +308 -0
- statement_extractor/models/entity.py +102 -0
- statement_extractor/models/labels.py +220 -0
- statement_extractor/models/qualifiers.py +139 -0
- statement_extractor/models/statement.py +101 -0
- statement_extractor/models.py +4 -1
- statement_extractor/pipeline/__init__.py +39 -0
- statement_extractor/pipeline/config.py +129 -0
- statement_extractor/pipeline/context.py +177 -0
- statement_extractor/pipeline/orchestrator.py +416 -0
- statement_extractor/pipeline/registry.py +303 -0
- statement_extractor/plugins/__init__.py +55 -0
- statement_extractor/plugins/base.py +716 -0
- statement_extractor/plugins/extractors/__init__.py +13 -0
- statement_extractor/plugins/extractors/base.py +9 -0
- statement_extractor/plugins/extractors/gliner2.py +546 -0
- statement_extractor/plugins/labelers/__init__.py +29 -0
- statement_extractor/plugins/labelers/base.py +9 -0
- statement_extractor/plugins/labelers/confidence.py +138 -0
- statement_extractor/plugins/labelers/relation_type.py +87 -0
- statement_extractor/plugins/labelers/sentiment.py +159 -0
- statement_extractor/plugins/labelers/taxonomy.py +386 -0
- statement_extractor/plugins/labelers/taxonomy_embedding.py +477 -0
- statement_extractor/plugins/pdf/__init__.py +10 -0
- statement_extractor/plugins/pdf/pypdf.py +291 -0
- statement_extractor/plugins/qualifiers/__init__.py +30 -0
- statement_extractor/plugins/qualifiers/base.py +9 -0
- statement_extractor/plugins/qualifiers/companies_house.py +185 -0
- statement_extractor/plugins/qualifiers/embedding_company.py +420 -0
- statement_extractor/plugins/qualifiers/gleif.py +197 -0
- statement_extractor/plugins/qualifiers/person.py +785 -0
- statement_extractor/plugins/qualifiers/sec_edgar.py +209 -0
- statement_extractor/plugins/scrapers/__init__.py +10 -0
- statement_extractor/plugins/scrapers/http.py +236 -0
- statement_extractor/plugins/splitters/__init__.py +13 -0
- statement_extractor/plugins/splitters/base.py +9 -0
- statement_extractor/plugins/splitters/t5_gemma.py +293 -0
- statement_extractor/plugins/taxonomy/__init__.py +13 -0
- statement_extractor/plugins/taxonomy/embedding.py +484 -0
- statement_extractor/plugins/taxonomy/mnli.py +291 -0
- statement_extractor/scoring.py +8 -8
- corp_extractor-0.4.0.dist-info/RECORD +0 -12
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/WHEEL +0 -0
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.9.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,484 @@
|
|
|
1
|
+
"""
|
|
2
|
+
EmbeddingTaxonomyClassifier - Classifies statements using embedding similarity.
|
|
3
|
+
|
|
4
|
+
Uses sentence-transformers to embed text and compare to pre-computed label
|
|
5
|
+
embeddings using cosine similarity with sigmoid calibration.
|
|
6
|
+
|
|
7
|
+
Faster than MNLI but may be less accurate for nuanced classification.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
import time
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Optional, TypedDict
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Shape of a single label entry in the taxonomy JSON. In this module the
# taxonomy is typed as {category: {label name: TaxonomyEntry}}; only the
# "id" field is read here (for TaxonomyResult.label_id).
TaxonomyEntry = TypedDict(
    "TaxonomyEntry",
    {
        "description": str,
        "id": int,
        "mnli_label": str,
        "embedding_label": str,
    },
)
TaxonomyEntry.__doc__ = "Typed mapping describing one taxonomy label entry."
|
|
25
|
+
|
|
26
|
+
from ..base import BaseTaxonomyPlugin, TaxonomySchema, PluginCapability
|
|
27
|
+
from ...pipeline.context import PipelineContext
|
|
28
|
+
from ...pipeline.registry import PluginRegistry
|
|
29
|
+
from ...models import (
|
|
30
|
+
PipelineStatement,
|
|
31
|
+
CanonicalEntity,
|
|
32
|
+
TaxonomyResult,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
logger = logging.getLogger(__name__)

# Bundled taxonomy file: "data/statement_taxonomy.json" under the directory
# reached by taking three `.parent` steps up from this module's file.
DEFAULT_TAXONOMY_PATH = Path(__file__).parent.parent.parent.joinpath("data", "statement_taxonomy.json")

# Categories classified when no explicit category list is supplied.
DEFAULT_CATEGORIES = [
    "environment",
    "society",
    "governance",
    "animals",
    "industry",
    # paired *_harm / *_benefit categories
    "human_harm", "human_benefit",
    "animal_harm", "animal_benefit",
    "environment_harm", "environment_benefit",
]
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class EmbeddingClassifier:
    """
    Embedding-based classifier using cosine similarity.

    Label names are embedded once by ``precompute_label_embeddings`` and stored
    L2-normalized. Input texts are normalized the same way, so cosine
    similarity reduces to a dot product.

    For each input, categories are ranked by mean label similarity. Every
    label in the ``top_k_categories`` best categories whose sigmoid-calibrated
    score reaches ``min_score`` is returned.
    """

    # Sigmoid midpoint on the (raw + 1) / 2 scale: a raw cosine similarity of
    # 0.3 maps to a calibrated score of exactly 0.5.
    SIMILARITY_THRESHOLD = 0.65
    # Sigmoid slope; larger values push calibrated scores closer to 0 or 1.
    CALIBRATION_STEEPNESS = 25.0

    def __init__(
        self,
        model_name: str = "google/embeddinggemma-300m",
        device: Optional[str] = None,
    ):
        """
        Args:
            model_name: sentence-transformers model name, loaded lazily on first use.
            device: Torch device string; ``None`` auto-selects cuda, then mps, then cpu.
        """
        self._model_name = model_name
        self._device = device
        self._model = None
        # category -> {label name -> L2-normalized float32 embedding}
        self._label_embeddings: dict[str, dict[str, np.ndarray]] = {}
        # Lazily built (category, label names, stacked label matrix) triples.
        # Reset whenever label embeddings are recomputed.
        self._label_matrices: Optional[list[tuple[str, list[str], np.ndarray]]] = None
        # Input text -> L2-normalized float32 embedding.
        # NOTE: entries are never evicted, so memory grows with the number of
        # distinct texts this instance has encoded.
        self._text_embedding_cache: dict[str, np.ndarray] = {}

    def _load_model(self) -> None:
        """Load the SentenceTransformer model on first call; later calls are no-ops.

        Raises:
            ImportError: if ``sentence_transformers`` or ``torch`` cannot be imported.
        """
        if self._model is not None:
            return

        # Only the imports are guarded, so ImportErrors raised while building
        # the model are not mislabeled as a missing sentence-transformers.
        try:
            from sentence_transformers import SentenceTransformer
            import torch
        except ImportError as e:
            raise ImportError(
                "sentence-transformers is required for embedding classification. "
                "Install with: pip install sentence-transformers"
            ) from e

        device = self._device
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"

        logger.info(f"Loading embedding model '{self._model_name}' on {device}...")
        self._model = SentenceTransformer(self._model_name, device=device)
        logger.debug("Embedding model loaded")

    @staticmethod
    def _normalize_rows(embeddings: np.ndarray) -> np.ndarray:
        """Return ``embeddings`` with each row L2-normalized, as float32.

        The small epsilon avoids division by zero for all-zero vectors.
        """
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        return (embeddings / (norms + 1e-8)).astype(np.float32)

    def precompute_label_embeddings(
        self,
        taxonomy: dict[str, dict[str, TaxonomyEntry]],
        categories: Optional[list[str]] = None,
    ) -> None:
        """Pre-compute normalized embeddings for all label names.

        Only the label names (the keys of each category dict) are embedded.
        The entries' fields are not used here. Categories missing from
        ``taxonomy``, or with no labels, are skipped. Categories embedded by
        earlier calls are kept.

        Args:
            taxonomy: Mapping of category -> {label name -> entry}.
            categories: Categories to embed; ``None`` or empty means every
                category in ``taxonomy``.
        """
        self._load_model()
        # Any previously stacked label matrices are stale once embeddings change.
        self._label_matrices = None

        start_time = time.perf_counter()
        total_labels = 0

        for category in categories or list(taxonomy.keys()):
            if category not in taxonomy:
                continue

            label_names = list(taxonomy[category].keys())
            if not label_names:
                continue

            raw = self._model.encode(label_names, convert_to_numpy=True, show_progress_bar=False)
            normalized = self._normalize_rows(np.asarray(raw))
            self._label_embeddings[category] = dict(zip(label_names, normalized))
            total_labels += len(label_names)

        elapsed = time.perf_counter() - start_time
        logger.info(
            f"Pre-computed embeddings for {total_labels} labels "
            f"across {len(self._label_embeddings)} categories in {elapsed:.2f}s"
        )

    def _calibrate_scores(self, raw_similarities: np.ndarray) -> np.ndarray:
        """Map raw cosine similarities in [-1, 1] to (0, 1) with a sigmoid.

        Computed in float64 so the results match the scalar formula exactly.
        """
        normalized = (np.asarray(raw_similarities, dtype=np.float64) + 1) / 2
        exponent = -self.CALIBRATION_STEEPNESS * (normalized - self.SIMILARITY_THRESHOLD)
        return 1.0 / (1.0 + np.exp(exponent))

    def _calibrate_score(self, raw_similarity: float) -> float:
        """Scalar version of ``_calibrate_scores``; returns a plain float."""
        return float(self._calibrate_scores(raw_similarity))

    def encode_batch(self, texts: list[str]) -> np.ndarray:
        """
        Encode multiple texts into normalized embeddings in a single batch.

        Previously seen texts are served from an in-memory cache. A text that
        appears several times in ``texts`` is encoded only once.

        Args:
            texts: List of texts to encode

        Returns:
            2D float32 array of shape (len(texts), embedding_dim) with
            L2-normalized rows. For empty input the array has zero rows.
        """
        self._load_model()

        # dict.fromkeys de-duplicates while keeping first-seen order.
        pending = [text for text in dict.fromkeys(texts) if text not in self._text_embedding_cache]

        if pending:
            raw = self._model.encode(pending, convert_to_numpy=True, show_progress_bar=False)
            self._text_embedding_cache.update(zip(pending, self._normalize_rows(np.asarray(raw))))
            logger.debug(f"Batch encoded {len(pending)} texts (cache size: {len(self._text_embedding_cache)})")

        if not texts:
            # np.stack rejects an empty sequence; return a well-shaped empty array.
            dim = self._model.get_sentence_embedding_dimension() or 0
            return np.empty((0, dim), dtype=np.float32)

        return np.stack([self._text_embedding_cache[text] for text in texts])

    def _get_label_matrices(self) -> list[tuple[str, list[str], np.ndarray]]:
        """Return (category, label names, label matrix) for each non-empty category.

        Each label matrix has one row per label name, in the same order as the
        names. The result is cached until the next ``precompute_label_embeddings``.
        """
        if self._label_matrices is None:
            matrices: list[tuple[str, list[str], np.ndarray]] = []
            for category, labels in self._label_embeddings.items():
                if labels:
                    matrices.append((category, list(labels.keys()), np.stack(list(labels.values()))))
            self._label_matrices = matrices
        return self._label_matrices

    def classify_batch(
        self,
        texts: list[str],
        top_k_categories: int = 3,
        min_score: float = 0.3,
    ) -> list[list[tuple[str, str, float]]]:
        """
        Classify multiple texts in a single batch for efficiency.

        Args:
            texts: List of texts to classify
            top_k_categories: Number of top categories (by mean label similarity)
                whose labels are scored for each text
            min_score: Minimum calibrated score to include in results

        Returns:
            One list per input text of (category, label, score) tuples, sorted
            by score descending.

        Raises:
            RuntimeError: if ``precompute_label_embeddings`` has not produced any
                label embeddings.
        """
        if not texts:
            return []

        self._load_model()

        if not self._label_embeddings:
            raise RuntimeError("Label embeddings not pre-computed.")

        input_embeddings = self.encode_batch(texts)

        # One matrix product per category: sims[i, j] = cosine(text i, label j).
        category_sims = [
            (category, names, input_embeddings @ matrix.T)
            for category, names, matrix in self._get_label_matrices()
        ]
        # Mean similarity of each text to each category, accumulated in float64.
        category_means = [np.mean(sims, axis=1, dtype=np.float64) for _, _, sims in category_sims]

        all_results: list[list[tuple[str, str, float]]] = []
        for row in range(len(texts)):
            row_means = [float(means[row]) for means in category_means]
            # Stable descending sort: tied categories keep their insertion order.
            ranked = sorted(range(len(row_means)), key=row_means.__getitem__, reverse=True)

            results: list[tuple[str, str, float]] = []
            for index in ranked[:top_k_categories]:
                category, names, sims = category_sims[index]
                scores = self._calibrate_scores(sims[row])
                results.extend(
                    (category, name, float(score))
                    for name, score in zip(names, scores)
                    if score >= min_score
                )

            # Sort by confidence descending
            results.sort(key=lambda item: item[2], reverse=True)
            all_results.append(results)

        return all_results

    def classify_hierarchical(
        self,
        text: str,
        top_k_categories: int = 3,
        min_score: float = 0.3,
    ) -> list[tuple[str, str, float]]:
        """Hierarchical classification: find categories, then all labels above threshold.

        Returns all labels above the threshold, not just the best match.

        Args:
            text: Text to classify
            top_k_categories: Number of top categories to consider
            min_score: Minimum calibrated score to include in results

        Returns:
            List of (category, label, confidence) tuples above threshold
        """
        # Use batch method for single text
        results = self.classify_batch([text], top_k_categories, min_score)
        return results[0] if results else []
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
@PluginRegistry.taxonomy
class EmbeddingTaxonomyClassifier(BaseTaxonomyPlugin):
    """
    Taxonomy classifier using embedding similarity.

    Labels are loaded from a JSON taxonomy file (the bundled
    ``statement_taxonomy.json`` by default) and restricted to the configured
    categories. An :class:`EmbeddingClassifier` is built lazily on first use.
    Faster than MNLI, good for high-throughput scenarios.
    """

    def __init__(
        self,
        taxonomy_path: Optional[str | Path] = None,
        categories: Optional[list[str]] = None,
        model_name: str = "google/embeddinggemma-300m",
        top_k_categories: int = 3,
        min_confidence: float = 0.8,
    ):
        """
        Args:
            taxonomy_path: JSON taxonomy file; defaults to ``DEFAULT_TAXONOMY_PATH``.
            categories: Categories to classify into; ``None`` or empty uses
                ``DEFAULT_CATEGORIES``.
            model_name: sentence-transformers model for the underlying classifier.
            top_k_categories: Number of best-matching categories whose labels are scored.
            min_confidence: Minimum calibrated score for a label to be returned.
        """
        self._taxonomy_path = Path(taxonomy_path) if taxonomy_path else DEFAULT_TAXONOMY_PATH
        # Private copy: later mutation of the caller's list (or of the shared
        # DEFAULT_CATEGORIES) cannot alter this instance. Tuples are accepted too.
        self._categories = list(categories) if categories else list(DEFAULT_CATEGORIES)
        self._model_name = model_name
        self._top_k_categories = top_k_categories
        self._min_confidence = min_confidence

        self._taxonomy: Optional[dict[str, dict[str, TaxonomyEntry]]] = None
        self._classifier: Optional[EmbeddingClassifier] = None
        self._embeddings_computed = False

    @property
    def name(self) -> str:
        return "embedding_taxonomy_classifier"

    @property
    def priority(self) -> int:
        return 10  # High priority - default taxonomy classifier (faster than MNLI)

    @property
    def capabilities(self) -> PluginCapability:
        # NOTE(review): LLM_REQUIRED is declared although the underlying model is
        # a sentence-transformers embedding model; presumably the flag gates model
        # loading in the pipeline — confirm.
        return PluginCapability.LLM_REQUIRED | PluginCapability.BATCH_PROCESSING

    @property
    def description(self) -> str:
        return "Classifies statements using embedding similarity (faster than MNLI)"

    @property
    def model_vram_gb(self) -> float:
        """EmbeddingGemma model weights ~1.2GB."""
        return 1.2

    @property
    def per_item_vram_gb(self) -> float:
        """Each text embedding ~0.05GB (embeddings are small)."""
        return 0.05

    @property
    def taxonomy_name(self) -> str:
        return "esg_topics_embedding"

    @property
    def taxonomy_schema(self) -> TaxonomySchema:
        """Schema listing the label names of every configured category present in the taxonomy."""
        taxonomy = self._load_taxonomy()
        filtered = {cat: list(labels.keys()) for cat, labels in taxonomy.items() if cat in self._categories}
        return TaxonomySchema(
            label_type="taxonomy",
            values=filtered,
            description="ESG topic classification using embeddings",
            scope="statement",
        )

    @property
    def supported_categories(self) -> list[str]:
        return self._categories.copy()

    def _load_taxonomy(self) -> dict[str, dict[str, TaxonomyEntry]]:
        """Load and cache the taxonomy JSON.

        Raises:
            FileNotFoundError: if the taxonomy file does not exist.
        """
        if self._taxonomy is not None:
            return self._taxonomy

        if not self._taxonomy_path.exists():
            raise FileNotFoundError(f"Taxonomy file not found: {self._taxonomy_path}")

        # JSON is UTF-8; don't depend on the platform's default encoding.
        with open(self._taxonomy_path, encoding="utf-8") as f:
            self._taxonomy = json.load(f)

        logger.debug(f"Loaded taxonomy with {len(self._taxonomy)} categories")
        return self._taxonomy

    def _get_classifier(self) -> EmbeddingClassifier:
        """Return the classifier, creating it and embedding the labels on first use.

        If label pre-computation raises, the flag stays False and it is retried
        on the next call.
        """
        if self._classifier is None:
            self._classifier = EmbeddingClassifier(model_name=self._model_name)

        if not self._embeddings_computed:
            taxonomy = self._load_taxonomy()
            self._classifier.precompute_label_embeddings(taxonomy, self._categories)
            self._embeddings_computed = True

        return self._classifier

    def _to_taxonomy_results(
        self,
        classifications: list[tuple[str, str, float]],
    ) -> list[TaxonomyResult]:
        """Convert (category, label, confidence) tuples into TaxonomyResult objects."""
        return [
            TaxonomyResult(
                taxonomy_name=self.taxonomy_name,
                category=category,
                label=label,
                label_id=self._get_label_id(category, label),
                confidence=round(confidence, 4),
                classifier=self.name,
            )
            for category, label, confidence in classifications
        ]

    def classify(
        self,
        statement: PipelineStatement,
        subject_canonical: CanonicalEntity,
        object_canonical: CanonicalEntity,
        context: PipelineContext,
    ) -> list[TaxonomyResult]:
        """Classify statement using embedding similarity.

        Only ``statement.source_text`` is classified; the canonical entities and
        context are accepted for interface compatibility but unused. Returns all
        labels above the confidence threshold. Any failure is logged as a warning
        and yields an empty list (best-effort).
        """
        try:
            classifier = self._get_classifier()
            classifications = classifier.classify_hierarchical(
                statement.source_text,
                top_k_categories=self._top_k_categories,
                min_score=self._min_confidence,
            )
            return self._to_taxonomy_results(classifications)
        except Exception as e:
            logger.warning(f"Embedding taxonomy classification failed: {e}")
            return []

    def _get_label_id(self, category: str, label: str) -> Optional[int]:
        """Return the taxonomy ``id`` for ``label`` in ``category``, or None if absent."""
        taxonomy = self._load_taxonomy()
        if category in taxonomy:
            entry = taxonomy[category].get(label)
            if entry:
                return entry.get("id")
        return None

    def classify_batch(
        self,
        items: list[tuple[PipelineStatement, CanonicalEntity, CanonicalEntity]],
        context: PipelineContext,
    ) -> list[list[TaxonomyResult]]:
        """
        Classify multiple statements in a single batch for efficiency.

        Each distinct source text is encoded and classified once. The results
        are then fanned back out to every statement that shares that text.

        Args:
            items: List of (statement, subject_canonical, object_canonical) tuples
            context: Pipeline context

        Returns:
            List of TaxonomyResult lists, one per input statement. On failure the
            error is logged and an empty list is returned for every item.
        """
        if not items:
            return []

        texts = [stmt.source_text for stmt, _, _ in items]
        # Order-preserving de-duplication. list(set(...)) made the encode order
        # depend on string hash randomization, so batches differed between runs.
        unique_texts = list(dict.fromkeys(texts))

        logger.info(f"Batch classifying {len(items)} statements ({len(unique_texts)} unique texts)")

        try:
            classifier = self._get_classifier()

            # Batch classify all unique texts
            batch_results = classifier.classify_batch(
                unique_texts,
                top_k_categories=self._top_k_categories,
                min_score=self._min_confidence,
            )

            # Map unique texts to their classifications
            text_to_results: dict[str, list[tuple[str, str, float]]] = dict(zip(unique_texts, batch_results))

            return [self._to_taxonomy_results(text_to_results.get(text, [])) for text in texts]

        except Exception as e:
            logger.warning(f"Batch taxonomy classification failed: {e}")
            # Return empty results for all items
            return [[] for _ in items]
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
# Module-level alias of the plugin class, intended for testing.
# NOTE(review): the original comment said "without decorator", but this name is
# bound to whatever `@PluginRegistry.taxonomy` returned for
# EmbeddingTaxonomyClassifier, i.e. the already-decorated object. Presumably
# the decorator returns the class unchanged; confirm.
EmbeddingTaxonomyClassifierClass = EmbeddingTaxonomyClassifier
|