corp-extractor 0.3.0-py3-none-any.whl → 0.5.0-py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/METADATA +235 -96
- corp_extractor-0.5.0.dist-info/RECORD +55 -0
- statement_extractor/__init__.py +9 -0
- statement_extractor/cli.py +460 -21
- statement_extractor/data/default_predicates.json +368 -0
- statement_extractor/data/statement_taxonomy.json +1182 -0
- statement_extractor/extractor.py +32 -47
- statement_extractor/gliner_extraction.py +218 -0
- statement_extractor/llm.py +255 -0
- statement_extractor/models/__init__.py +74 -0
- statement_extractor/models/canonical.py +139 -0
- statement_extractor/models/entity.py +102 -0
- statement_extractor/models/labels.py +191 -0
- statement_extractor/models/qualifiers.py +91 -0
- statement_extractor/models/statement.py +75 -0
- statement_extractor/models.py +15 -6
- statement_extractor/pipeline/__init__.py +39 -0
- statement_extractor/pipeline/config.py +134 -0
- statement_extractor/pipeline/context.py +177 -0
- statement_extractor/pipeline/orchestrator.py +447 -0
- statement_extractor/pipeline/registry.py +297 -0
- statement_extractor/plugins/__init__.py +43 -0
- statement_extractor/plugins/base.py +446 -0
- statement_extractor/plugins/canonicalizers/__init__.py +17 -0
- statement_extractor/plugins/canonicalizers/base.py +9 -0
- statement_extractor/plugins/canonicalizers/location.py +219 -0
- statement_extractor/plugins/canonicalizers/organization.py +230 -0
- statement_extractor/plugins/canonicalizers/person.py +242 -0
- statement_extractor/plugins/extractors/__init__.py +13 -0
- statement_extractor/plugins/extractors/base.py +9 -0
- statement_extractor/plugins/extractors/gliner2.py +536 -0
- statement_extractor/plugins/labelers/__init__.py +29 -0
- statement_extractor/plugins/labelers/base.py +9 -0
- statement_extractor/plugins/labelers/confidence.py +138 -0
- statement_extractor/plugins/labelers/relation_type.py +87 -0
- statement_extractor/plugins/labelers/sentiment.py +159 -0
- statement_extractor/plugins/labelers/taxonomy.py +373 -0
- statement_extractor/plugins/labelers/taxonomy_embedding.py +466 -0
- statement_extractor/plugins/qualifiers/__init__.py +19 -0
- statement_extractor/plugins/qualifiers/base.py +9 -0
- statement_extractor/plugins/qualifiers/companies_house.py +174 -0
- statement_extractor/plugins/qualifiers/gleif.py +186 -0
- statement_extractor/plugins/qualifiers/person.py +221 -0
- statement_extractor/plugins/qualifiers/sec_edgar.py +198 -0
- statement_extractor/plugins/splitters/__init__.py +13 -0
- statement_extractor/plugins/splitters/base.py +9 -0
- statement_extractor/plugins/splitters/t5_gemma.py +188 -0
- statement_extractor/plugins/taxonomy/__init__.py +13 -0
- statement_extractor/plugins/taxonomy/embedding.py +337 -0
- statement_extractor/plugins/taxonomy/mnli.py +279 -0
- statement_extractor/scoring.py +17 -69
- corp_extractor-0.3.0.dist-info/RECORD +0 -12
- statement_extractor/spacy_extraction.py +0 -386
- {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/WHEEL +0 -0
- {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/entry_points.txt +0 -0
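The headline change in 0.5.0 is the plugin-based pipeline under statement_extractor/pipeline/ and statement_extractor/plugins/. The largest new file, statement_extractor/pipeline/orchestrator.py (+447), is reproduced below. A minimal usage sketch, assuming pipeline/__init__.py re-exports ExtractionPipeline and PipelineConfig (the input text and metadata keys here are illustrative):

    from statement_extractor.pipeline import ExtractionPipeline, PipelineConfig

    # Config is optional; PipelineConfig.default() is used when omitted.
    pipeline = ExtractionPipeline()
    ctx = pipeline.process(
        "Acme Corp acquired Widget Ltd in 2021.",
        metadata={"source": "example-doc"},
    )

    # PipelineContext accumulates the output of every stage (see the orchestrator below).
    print(ctx.statement_count, "statements,", len(ctx.processing_errors), "errors")
    for labeled in ctx.labeled_statements:
        print(labeled.statement)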
statement_extractor/pipeline/orchestrator.py
@@ -0,0 +1,447 @@
"""
ExtractionPipeline - Main orchestrator for the 6-stage extraction pipeline.

Coordinates the flow of data through all pipeline stages:
1. Splitting: Text → RawTriple
2. Extraction: RawTriple → PipelineStatement
3. Qualification: Entity → QualifiedEntity
4. Canonicalization: QualifiedEntity → CanonicalEntity
5. Labeling: Statement → LabeledStatement
6. Taxonomy: LabeledStatement → TaxonomyResult
"""

import logging
import time
from typing import Any, Optional

from .context import PipelineContext
from .config import PipelineConfig, get_stage_name
from .registry import PluginRegistry
from ..models import (
    QualifiedEntity,
    EntityQualifiers,
    CanonicalEntity,
    LabeledStatement,
    TaxonomyResult,
)

logger = logging.getLogger(__name__)


class ExtractionPipeline:
    """
    Main pipeline orchestrator.

    Coordinates the flow of data through all 6 stages, invoking registered
    plugins in priority order and accumulating results in PipelineContext.
    """

    def __init__(self, config: Optional[PipelineConfig] = None):
        """
        Initialize the pipeline.

        Args:
            config: Pipeline configuration (uses defaults if not provided)
        """
        self.config = config or PipelineConfig.default()

    def process(
        self,
        text: str,
        metadata: Optional[dict[str, Any]] = None,
    ) -> PipelineContext:
        """
        Process text through the extraction pipeline.

        Args:
            text: Input text to process
            metadata: Optional metadata about the source

        Returns:
            PipelineContext with accumulated results from all stages
        """
        # Merge config options into metadata for plugins
        combined_metadata = metadata.copy() if metadata else {}

        # Pass extractor options from config to context
        if self.config.extractor_options:
            existing_extractor_opts = combined_metadata.get("extractor_options", {})
            combined_metadata["extractor_options"] = {
                **self.config.extractor_options,
                **existing_extractor_opts,  # Allow explicit metadata to override config
            }

        ctx = PipelineContext(
            source_text=text,
            source_metadata=combined_metadata,
        )

        logger.info(f"Starting pipeline processing: {len(text)} chars")

        try:
            # Stage 1: Splitting
            if self.config.is_stage_enabled(1):
                ctx = self._run_splitting(ctx)

            # Stage 2: Extraction
            if self.config.is_stage_enabled(2):
                ctx = self._run_extraction(ctx)

            # Stage 3: Qualification
            if self.config.is_stage_enabled(3):
                ctx = self._run_qualification(ctx)

            # Stage 4: Canonicalization
            if self.config.is_stage_enabled(4):
                ctx = self._run_canonicalization(ctx)

            # Stage 5: Labeling
            if self.config.is_stage_enabled(5):
                ctx = self._run_labeling(ctx)

            # Stage 6: Taxonomy classification
            if self.config.is_stage_enabled(6):
                ctx = self._run_taxonomy(ctx)

        except Exception as e:
            logger.exception("Pipeline processing failed")
            ctx.add_error(f"Pipeline error: {str(e)}")
            if self.config.fail_fast:
                raise

        logger.info(
            f"Pipeline complete: {ctx.statement_count} statements, "
            f"{len(ctx.processing_errors)} errors"
        )

        return ctx

    def _run_splitting(self, ctx: PipelineContext) -> PipelineContext:
        """Stage 1: Split text into raw triples."""
        stage_name = get_stage_name(1)
        logger.debug(f"Running {stage_name} stage")
        start_time = time.time()

        splitters = PluginRegistry.get_splitters()
        if not splitters:
            ctx.add_warning("No splitter plugins registered")
            return ctx

        # Use first enabled splitter (highest priority)
        for splitter in splitters:
            if not self.config.is_plugin_enabled(splitter.name):
                continue

            logger.debug(f"Using splitter: {splitter.name}")
            try:
                raw_triples = splitter.split(ctx.source_text, ctx)
                ctx.raw_triples = raw_triples
                logger.info(f"Splitting produced {len(raw_triples)} raw triples")
                break
            except Exception as e:
                logger.exception(f"Splitter {splitter.name} failed")
                ctx.add_error(f"Splitter {splitter.name} failed: {str(e)}")
                if self.config.fail_fast:
                    raise

        ctx.record_timing(stage_name, time.time() - start_time)
        return ctx

    def _run_extraction(self, ctx: PipelineContext) -> PipelineContext:
        """Stage 2: Extract statements with typed entities from raw triples."""
        stage_name = get_stage_name(2)
        logger.debug(f"Running {stage_name} stage")
        start_time = time.time()

        if not ctx.raw_triples:
            logger.debug("No raw triples to extract from")
            return ctx

        extractors = PluginRegistry.get_extractors()
        if not extractors:
            ctx.add_warning("No extractor plugins registered")
            return ctx

        # Collect classification schemas from labelers for the extractor
        classification_schemas = self._collect_classification_schemas()
        if classification_schemas:
            logger.debug(f"Collected {len(classification_schemas)} classification schemas from labelers")

        # Use first enabled extractor (highest priority)
        for extractor in extractors:
            if not self.config.is_plugin_enabled(extractor.name):
                continue

            # Pass classification schemas to extractor if it supports them
            if classification_schemas and hasattr(extractor, 'add_classification_schema'):
                for schema in classification_schemas:
                    extractor.add_classification_schema(schema)

            logger.debug(f"Using extractor: {extractor.name}")
            try:
                statements = extractor.extract(ctx.raw_triples, ctx)
                ctx.statements = statements
                logger.info(f"Extraction produced {len(statements)} statements")
                break
            except Exception as e:
                logger.exception(f"Extractor {extractor.name} failed")
                ctx.add_error(f"Extractor {extractor.name} failed: {str(e)}")
                if self.config.fail_fast:
                    raise

        ctx.record_timing(stage_name, time.time() - start_time)
        return ctx

    def _collect_classification_schemas(self) -> list:
        """Collect classification schemas from enabled labelers."""
        schemas = []
        labelers = PluginRegistry.get_labelers()

        for labeler in labelers:
            if not self.config.is_plugin_enabled(labeler.name):
                continue

            # Check for classification schema (simple multi-choice)
            if hasattr(labeler, 'classification_schema') and labeler.classification_schema:
                schemas.append(labeler.classification_schema)
                logger.debug(
                    f"Labeler {labeler.name} provides classification schema: "
                    f"{labeler.classification_schema}"
                )

        return schemas

    def _run_qualification(self, ctx: PipelineContext) -> PipelineContext:
        """Stage 3: Add qualifiers to entities."""
        stage_name = get_stage_name(3)
        logger.debug(f"Running {stage_name} stage")
        start_time = time.time()

        if not ctx.statements:
            logger.debug("No statements to qualify")
            return ctx

        # Collect all unique entities from statements
        entities_to_qualify = {}
        for stmt in ctx.statements:
            for entity in [stmt.subject, stmt.object]:
                if entity.entity_ref not in entities_to_qualify:
                    entities_to_qualify[entity.entity_ref] = entity

        logger.debug(f"Qualifying {len(entities_to_qualify)} unique entities")

        # Qualify each entity using applicable plugins
        for entity_ref, entity in entities_to_qualify.items():
            qualifiers = EntityQualifiers()
            sources = []

            # Get qualifiers for this entity type
            type_qualifiers = PluginRegistry.get_qualifiers_for_type(entity.type)

            for qualifier_plugin in type_qualifiers:
                if not self.config.is_plugin_enabled(qualifier_plugin.name):
                    continue

                try:
                    plugin_qualifiers = qualifier_plugin.qualify(entity, ctx)
                    if plugin_qualifiers and plugin_qualifiers.has_any_qualifier():
                        qualifiers = qualifiers.merge_with(plugin_qualifiers)
                        sources.append(qualifier_plugin.name)
                except Exception as e:
                    logger.error(f"Qualifier {qualifier_plugin.name} failed for {entity.text}: {e}")
                    ctx.add_error(f"Qualifier {qualifier_plugin.name} failed: {str(e)}")
                    if self.config.fail_fast:
                        raise

            # Create QualifiedEntity
            qualified = QualifiedEntity(
                entity_ref=entity_ref,
                original_text=entity.text,
                entity_type=entity.type,
                qualifiers=qualifiers,
                qualification_sources=sources,
            )
            ctx.qualified_entities[entity_ref] = qualified

        logger.info(f"Qualified {len(ctx.qualified_entities)} entities")
        ctx.record_timing(stage_name, time.time() - start_time)
        return ctx

    def _run_canonicalization(self, ctx: PipelineContext) -> PipelineContext:
        """Stage 4: Resolve entities to canonical forms."""
        stage_name = get_stage_name(4)
        logger.debug(f"Running {stage_name} stage")
        start_time = time.time()

        if not ctx.qualified_entities:
            # Create basic qualified entities if stage 3 was skipped
            for stmt in ctx.statements:
                for entity in [stmt.subject, stmt.object]:
                    if entity.entity_ref not in ctx.qualified_entities:
                        ctx.qualified_entities[entity.entity_ref] = QualifiedEntity(
                            entity_ref=entity.entity_ref,
                            original_text=entity.text,
                            entity_type=entity.type,
                        )

        # Canonicalize each qualified entity
        for entity_ref, qualified in ctx.qualified_entities.items():
            canonical_match = None
            fqn = None

            # Get canonicalizers for this entity type
            type_canonicalizers = PluginRegistry.get_canonicalizers_for_type(qualified.entity_type)

            for canon_plugin in type_canonicalizers:
                if not self.config.is_plugin_enabled(canon_plugin.name):
                    continue

                try:
                    match = canon_plugin.find_canonical(qualified, ctx)
                    if match:
                        canonical_match = match
                        fqn = canon_plugin.format_fqn(qualified, match)
                        break  # Use first successful match
                except Exception as e:
                    logger.error(f"Canonicalizer {canon_plugin.name} failed for {qualified.original_text}: {e}")
                    ctx.add_error(f"Canonicalizer {canon_plugin.name} failed: {str(e)}")
                    if self.config.fail_fast:
                        raise

            # Create CanonicalEntity
            canonical = CanonicalEntity.from_qualified(
                qualified=qualified,
                canonical_match=canonical_match,
                fqn=fqn,
            )
            ctx.canonical_entities[entity_ref] = canonical

        logger.info(f"Canonicalized {len(ctx.canonical_entities)} entities")
        ctx.record_timing(stage_name, time.time() - start_time)
        return ctx

    def _run_labeling(self, ctx: PipelineContext) -> PipelineContext:
        """Stage 5: Apply labels to statements."""
        stage_name = get_stage_name(5)
        logger.debug(f"Running {stage_name} stage")
        start_time = time.time()

        if not ctx.statements:
            logger.debug("No statements to label")
            return ctx

        # Ensure canonical entities exist
        if not ctx.canonical_entities:
            self._run_canonicalization(ctx)

        labelers = PluginRegistry.get_labelers()

        for stmt in ctx.statements:
            # Get canonical entities
            subj_canonical = ctx.canonical_entities.get(stmt.subject.entity_ref)
            obj_canonical = ctx.canonical_entities.get(stmt.object.entity_ref)

            if not subj_canonical or not obj_canonical:
                # Create fallback canonical entities
                if not subj_canonical:
                    subj_qualified = ctx.qualified_entities.get(
                        stmt.subject.entity_ref,
                        QualifiedEntity(
                            entity_ref=stmt.subject.entity_ref,
                            original_text=stmt.subject.text,
                            entity_type=stmt.subject.type,
                        )
                    )
                    subj_canonical = CanonicalEntity.from_qualified(subj_qualified)

                if not obj_canonical:
                    obj_qualified = ctx.qualified_entities.get(
                        stmt.object.entity_ref,
                        QualifiedEntity(
                            entity_ref=stmt.object.entity_ref,
                            original_text=stmt.object.text,
                            entity_type=stmt.object.type,
                        )
                    )
                    obj_canonical = CanonicalEntity.from_qualified(obj_qualified)

            # Create labeled statement
            labeled = LabeledStatement(
                statement=stmt,
                subject_canonical=subj_canonical,
                object_canonical=obj_canonical,
            )

            # Apply all labelers
            for labeler in labelers:
                if not self.config.is_plugin_enabled(labeler.name):
                    continue

                try:
                    label = labeler.label(stmt, subj_canonical, obj_canonical, ctx)
                    if label:
                        labeled.add_label(label)
                except Exception as e:
                    logger.error(f"Labeler {labeler.name} failed: {e}")
                    ctx.add_error(f"Labeler {labeler.name} failed: {str(e)}")
                    if self.config.fail_fast:
                        raise

            ctx.labeled_statements.append(labeled)

        logger.info(f"Labeled {len(ctx.labeled_statements)} statements")
        ctx.record_timing(stage_name, time.time() - start_time)
        return ctx

    def _run_taxonomy(self, ctx: PipelineContext) -> PipelineContext:
        """Stage 6: Classify statements against taxonomies."""
        stage_name = get_stage_name(6)
        logger.debug(f"Running {stage_name} stage")
        start_time = time.time()

        if not ctx.labeled_statements:
            logger.debug("No labeled statements to classify")
            return ctx

        taxonomy_classifiers = PluginRegistry.get_taxonomy_classifiers()
        if not taxonomy_classifiers:
            logger.debug("No taxonomy classifiers registered")
            return ctx

        total_results = 0
        for labeled_stmt in ctx.labeled_statements:
            stmt = labeled_stmt.statement
            subj_canonical = labeled_stmt.subject_canonical
            obj_canonical = labeled_stmt.object_canonical

            # Apply all taxonomy classifiers
            for classifier in taxonomy_classifiers:
                if not self.config.is_plugin_enabled(classifier.name):
                    continue

                try:
                    results = classifier.classify(stmt, subj_canonical, obj_canonical, ctx)
                    if results:
                        # Store taxonomy results in context (list of results per key)
                        key = (stmt.source_text, classifier.taxonomy_name)
                        if key not in ctx.taxonomy_results:
                            ctx.taxonomy_results[key] = []
                        ctx.taxonomy_results[key].extend(results)
                        total_results += len(results)

                        # Also add to the labeled statement for easy access
                        labeled_stmt.taxonomy_results.extend(results)

                        for result in results:
                            logger.debug(
                                f"Taxonomy {classifier.name}: {result.full_label} "
                                f"(confidence={result.confidence:.2f})"
                            )
                except Exception as e:
                    logger.error(f"Taxonomy classifier {classifier.name} failed: {e}")
                    ctx.add_error(f"Taxonomy classifier {classifier.name} failed: {str(e)}")
                    if self.config.fail_fast:
                        raise

        logger.info(f"Taxonomy produced {total_results} labels across {len(ctx.taxonomy_results)} statement-taxonomy pairs")
        ctx.record_timing(stage_name, time.time() - start_time)
        return ctx
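
One detail worth calling out from process(): extractor options set on the config are merged into the metadata handed to plugins, and explicit caller metadata wins per key. A sketch, assuming extractor_options is a plain settable mapping on the config (the option keys here are hypothetical):

    config = PipelineConfig.default()
    config.extractor_options = {"threshold": 0.5, "device": "cpu"}

    ctx = ExtractionPipeline(config).process(
        text,
        metadata={"extractor_options": {"threshold": 0.8}},
    )
    # Plugins read ctx.source_metadata["extractor_options"], which is now
    # {"threshold": 0.8, "device": "cpu"}: the caller's 0.8 overrides the config's 0.5.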
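Stages 1 and 2 use only the first enabled plugin in priority order, so a custom splitter needs just the interface the orchestrator calls: a name attribute and split(text, ctx) returning raw triples. A skeleton under those assumptions; registration goes through pipeline/registry.py, which this hunk does not show, so no registration call is sketched:

    class NaiveSentenceSplitter:
        """Splitter skeleton matching the orchestrator's calls above."""

        name = "naive_sentence_splitter"

        def split(self, text, ctx):
            # A real plugin should return RawTriple objects (see
            # statement_extractor/models); plain strings stand in here.
            return [s.strip() for s in text.split(".") if s.strip()]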
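Stage 6 output lands in two places, per the code above: a context-level index keyed by (source_text, taxonomy_name) tuples, and a per-statement list on each LabeledStatement. Reading both, given a processed ctx:

    for (source_text, taxonomy_name), results in ctx.taxonomy_results.items():
        for r in results:
            print(taxonomy_name, r.full_label, f"{r.confidence:.2f}")

    for labeled in ctx.labeled_statements:
        for r in labeled.taxonomy_results:
            print(r.full_label)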
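Error handling is accumulate-by-default: every stage catches plugin exceptions and records them with ctx.add_error(), so one bad plugin does not abort the run. Setting fail_fast re-raises instead (assuming fail_fast is a plain settable attribute on the config, as the reads above suggest):

    config = PipelineConfig.default()
    config.fail_fast = True

    try:
        ctx = ExtractionPipeline(config).process(text)
    except Exception:
        # With fail_fast the first plugin failure propagates to the caller;
        # without it, failures accumulate in ctx.processing_errors instead.
        raise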