corp-extractor 0.4.0-py3-none-any.whl → 0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/METADATA +181 -64
- corp_extractor-0.5.0.dist-info/RECORD +55 -0
- statement_extractor/__init__.py +9 -0
- statement_extractor/cli.py +446 -17
- statement_extractor/data/default_predicates.json +368 -0
- statement_extractor/data/statement_taxonomy.json +1182 -0
- statement_extractor/extractor.py +1 -23
- statement_extractor/gliner_extraction.py +4 -74
- statement_extractor/llm.py +255 -0
- statement_extractor/models/__init__.py +74 -0
- statement_extractor/models/canonical.py +139 -0
- statement_extractor/models/entity.py +102 -0
- statement_extractor/models/labels.py +191 -0
- statement_extractor/models/qualifiers.py +91 -0
- statement_extractor/models/statement.py +75 -0
- statement_extractor/models.py +4 -1
- statement_extractor/pipeline/__init__.py +39 -0
- statement_extractor/pipeline/config.py +134 -0
- statement_extractor/pipeline/context.py +177 -0
- statement_extractor/pipeline/orchestrator.py +447 -0
- statement_extractor/pipeline/registry.py +297 -0
- statement_extractor/plugins/__init__.py +43 -0
- statement_extractor/plugins/base.py +446 -0
- statement_extractor/plugins/canonicalizers/__init__.py +17 -0
- statement_extractor/plugins/canonicalizers/base.py +9 -0
- statement_extractor/plugins/canonicalizers/location.py +219 -0
- statement_extractor/plugins/canonicalizers/organization.py +230 -0
- statement_extractor/plugins/canonicalizers/person.py +242 -0
- statement_extractor/plugins/extractors/__init__.py +13 -0
- statement_extractor/plugins/extractors/base.py +9 -0
- statement_extractor/plugins/extractors/gliner2.py +536 -0
- statement_extractor/plugins/labelers/__init__.py +29 -0
- statement_extractor/plugins/labelers/base.py +9 -0
- statement_extractor/plugins/labelers/confidence.py +138 -0
- statement_extractor/plugins/labelers/relation_type.py +87 -0
- statement_extractor/plugins/labelers/sentiment.py +159 -0
- statement_extractor/plugins/labelers/taxonomy.py +373 -0
- statement_extractor/plugins/labelers/taxonomy_embedding.py +466 -0
- statement_extractor/plugins/qualifiers/__init__.py +19 -0
- statement_extractor/plugins/qualifiers/base.py +9 -0
- statement_extractor/plugins/qualifiers/companies_house.py +174 -0
- statement_extractor/plugins/qualifiers/gleif.py +186 -0
- statement_extractor/plugins/qualifiers/person.py +221 -0
- statement_extractor/plugins/qualifiers/sec_edgar.py +198 -0
- statement_extractor/plugins/splitters/__init__.py +13 -0
- statement_extractor/plugins/splitters/base.py +9 -0
- statement_extractor/plugins/splitters/t5_gemma.py +188 -0
- statement_extractor/plugins/taxonomy/__init__.py +13 -0
- statement_extractor/plugins/taxonomy/embedding.py +337 -0
- statement_extractor/plugins/taxonomy/mnli.py +279 -0
- corp_extractor-0.4.0.dist-info/RECORD +0 -12
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/WHEEL +0 -0
- {corp_extractor-0.4.0.dist-info → corp_extractor-0.5.0.dist-info}/entry_points.txt +0 -0
statement_extractor/pipeline/context.py
@@ -0,0 +1,177 @@
+"""
+PipelineContext - Data container that flows through all pipeline stages.
+
+The context accumulates outputs from each stage:
+- Stage 1 (Splitting): raw_triples
+- Stage 2 (Extraction): statements
+- Stage 3 (Qualification): qualified_entities
+- Stage 4 (Canonicalization): canonical_entities
+- Stage 5 (Labeling): labeled_statements
+"""
+
+from typing import Any, Optional
+
+from pydantic import BaseModel, Field
+
+from ..models import (
+    RawTriple,
+    PipelineStatement,
+    QualifiedEntity,
+    CanonicalEntity,
+    LabeledStatement,
+    TaxonomyResult,
+)
+
+
+class PipelineContext(BaseModel):
+    """
+    Context object that flows through all pipeline stages.
+
+    Accumulates outputs from each stage and provides access to
+    source text, metadata, and intermediate results.
+    """
+    # Input
+    source_text: str = Field(..., description="Original input text")
+    source_metadata: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Metadata about the source (e.g., document ID, URL, timestamp)"
+    )
+
+    # Stage 1 output: Raw triples from splitting
+    raw_triples: list[RawTriple] = Field(
+        default_factory=list,
+        description="Raw triples from Stage 1 (Splitting)"
+    )
+
+    # Stage 2 output: Statements with extracted entities
+    statements: list[PipelineStatement] = Field(
+        default_factory=list,
+        description="Statements from Stage 2 (Extraction)"
+    )
+
+    # Stage 3 output: Qualified entities (keyed by entity_ref)
+    qualified_entities: dict[str, QualifiedEntity] = Field(
+        default_factory=dict,
+        description="Qualified entities from Stage 3 (Qualification)"
+    )
+
+    # Stage 4 output: Canonical entities (keyed by entity_ref)
+    canonical_entities: dict[str, CanonicalEntity] = Field(
+        default_factory=dict,
+        description="Canonical entities from Stage 4 (Canonicalization)"
+    )
+
+    # Stage 5 output: Final labeled statements
+    labeled_statements: list[LabeledStatement] = Field(
+        default_factory=list,
+        description="Final labeled statements from Stage 5 (Labeling)"
+    )
+
+    # Classification results from extractor (populated by GLiNER2 or similar)
+    # Keyed by source_text -> label_type -> (label_value, confidence)
+    classification_results: dict[str, dict[str, tuple[str, float]]] = Field(
+        default_factory=dict,
+        description="Pre-computed classification results from Stage 2 extractor"
+    )
+
+    # Stage 6 output: Taxonomy classifications
+    # Keyed by (source_text, taxonomy_name) -> list of TaxonomyResult
+    # Multiple labels may match a single statement above threshold
+    taxonomy_results: dict[tuple[str, str], list[TaxonomyResult]] = Field(
+        default_factory=dict,
+        description="Taxonomy classifications from Stage 6 (multiple labels per statement)"
+    )
+
+    # Processing metadata
+    processing_errors: list[str] = Field(
+        default_factory=list,
+        description="Errors encountered during processing"
+    )
+    processing_warnings: list[str] = Field(
+        default_factory=list,
+        description="Warnings generated during processing"
+    )
+    stage_timings: dict[str, float] = Field(
+        default_factory=dict,
+        description="Timing information for each stage (stage_name -> seconds)"
+    )
+
+    def add_error(self, error: str) -> None:
+        """Add a processing error."""
+        self.processing_errors.append(error)
+
+    def add_warning(self, warning: str) -> None:
+        """Add a processing warning."""
+        self.processing_warnings.append(warning)
+
+    def record_timing(self, stage: str, duration: float) -> None:
+        """Record timing for a stage."""
+        self.stage_timings[stage] = duration
+
+    def get_entity_refs(self) -> set[str]:
+        """Get all unique entity refs from statements."""
+        refs = set()
+        for stmt in self.statements:
+            refs.add(stmt.subject.entity_ref)
+            refs.add(stmt.object.entity_ref)
+        return refs
+
+    def get_qualified_entity(self, entity_ref: str) -> Optional[QualifiedEntity]:
+        """Get qualified entity by ref, or None if not found."""
+        return self.qualified_entities.get(entity_ref)
+
+    def get_canonical_entity(self, entity_ref: str) -> Optional[CanonicalEntity]:
+        """Get canonical entity by ref, or None if not found."""
+        return self.canonical_entities.get(entity_ref)
+
+    def get_classification(
+        self,
+        source_text: str,
+        label_type: str,
+    ) -> Optional[tuple[str, float]]:
+        """
+        Get pre-computed classification result for a source text.
+
+        Args:
+            source_text: The source text that was classified
+            label_type: The type of label (e.g., "sentiment")
+
+        Returns:
+            Tuple of (label_value, confidence) or None if not found
+        """
+        if source_text in self.classification_results:
+            return self.classification_results[source_text].get(label_type)
+        return None
+
+    def set_classification(
+        self,
+        source_text: str,
+        label_type: str,
+        label_value: str,
+        confidence: float,
+    ) -> None:
+        """
+        Store a classification result for a source text.
+
+        Args:
+            source_text: The source text that was classified
+            label_type: The type of label (e.g., "sentiment")
+            label_value: The classification result (e.g., "positive")
+            confidence: Confidence score (0.0 to 1.0)
+        """
+        if source_text not in self.classification_results:
+            self.classification_results[source_text] = {}
+        self.classification_results[source_text][label_type] = (label_value, confidence)
+
+    @property
+    def has_errors(self) -> bool:
+        """Check if any errors occurred during processing."""
+        return len(self.processing_errors) > 0
+
+    @property
+    def statement_count(self) -> int:
+        """Get the number of statements in the final output."""
+        return len(self.labeled_statements) if self.labeled_statements else len(self.statements)
+
+    class Config:
+        arbitrary_types_allowed = True
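Read together, the helpers above make the context a single mutable record that every stage reads from and writes to. A minimal sketch of that accessor API, using only what this hunk shows; the input text, label values, and stage name are illustrative:

from statement_extractor.pipeline.context import PipelineContext

# Construct a context the way the orchestrator does, then exercise the helpers.
ctx = PipelineContext(
    source_text="Acme Corp acquired Widget Ltd.",
    source_metadata={"document_id": "doc-1"},
)

# classification_results is keyed source_text -> label_type -> (value, confidence).
ctx.set_classification("Acme Corp acquired Widget Ltd.", "sentiment", "positive", 0.91)
assert ctx.get_classification("Acme Corp acquired Widget Ltd.", "sentiment") == ("positive", 0.91)

# Warnings, errors, and per-stage timings accumulate on the same object.
ctx.add_warning("No splitter plugins registered")
ctx.record_timing("splitting", 0.42)  # stage name illustrative; the orchestrator uses get_stage_name()
print(ctx.has_errors)        # False: only a warning was recorded
print(ctx.statement_count)   # 0: no statements have been extracted yet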
statement_extractor/pipeline/orchestrator.py
@@ -0,0 +1,447 @@
+"""
+ExtractionPipeline - Main orchestrator for the 5-stage extraction pipeline.
+
+Coordinates the flow of data through all pipeline stages:
+1. Splitting: Text → RawTriple
+2. Extraction: RawTriple → PipelineStatement
+3. Qualification: Entity → QualifiedEntity
+4. Canonicalization: QualifiedEntity → CanonicalEntity
+5. Labeling: Statement → LabeledStatement
+"""
+
+import logging
+import time
+from typing import Any, Optional
+
+from .context import PipelineContext
+from .config import PipelineConfig, get_stage_name
+from .registry import PluginRegistry
+from ..models import (
+    QualifiedEntity,
+    EntityQualifiers,
+    CanonicalEntity,
+    LabeledStatement,
+    TaxonomyResult,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ExtractionPipeline:
+    """
+    Main pipeline orchestrator.
+
+    Coordinates the flow of data through all 5 stages, invoking registered
+    plugins in priority order and accumulating results in PipelineContext.
+    """
+
+    def __init__(self, config: Optional[PipelineConfig] = None):
+        """
+        Initialize the pipeline.
+
+        Args:
+            config: Pipeline configuration (uses defaults if not provided)
+        """
+        self.config = config or PipelineConfig.default()
+
+    def process(
+        self,
+        text: str,
+        metadata: Optional[dict[str, Any]] = None,
+    ) -> PipelineContext:
+        """
+        Process text through the extraction pipeline.
+
+        Args:
+            text: Input text to process
+            metadata: Optional metadata about the source
+
+        Returns:
+            PipelineContext with accumulated results from all stages
+        """
+        # Merge config options into metadata for plugins
+        combined_metadata = metadata.copy() if metadata else {}
+
+        # Pass extractor options from config to context
+        if self.config.extractor_options:
+            existing_extractor_opts = combined_metadata.get("extractor_options", {})
+            combined_metadata["extractor_options"] = {
+                **self.config.extractor_options,
+                **existing_extractor_opts,  # Allow explicit metadata to override config
+            }
+
+        ctx = PipelineContext(
+            source_text=text,
+            source_metadata=combined_metadata,
+        )
+
+        logger.info(f"Starting pipeline processing: {len(text)} chars")
+
+        try:
+            # Stage 1: Splitting
+            if self.config.is_stage_enabled(1):
+                ctx = self._run_splitting(ctx)
+
+            # Stage 2: Extraction
+            if self.config.is_stage_enabled(2):
+                ctx = self._run_extraction(ctx)
+
+            # Stage 3: Qualification
+            if self.config.is_stage_enabled(3):
+                ctx = self._run_qualification(ctx)
+
+            # Stage 4: Canonicalization
+            if self.config.is_stage_enabled(4):
+                ctx = self._run_canonicalization(ctx)
+
+            # Stage 5: Labeling
+            if self.config.is_stage_enabled(5):
+                ctx = self._run_labeling(ctx)
+
+            # Stage 6: Taxonomy classification
+            if self.config.is_stage_enabled(6):
+                ctx = self._run_taxonomy(ctx)
+
+        except Exception as e:
+            logger.exception("Pipeline processing failed")
+            ctx.add_error(f"Pipeline error: {str(e)}")
+            if self.config.fail_fast:
+                raise
+
+        logger.info(
+            f"Pipeline complete: {ctx.statement_count} statements, "
+            f"{len(ctx.processing_errors)} errors"
+        )
+
+        return ctx
+
+    def _run_splitting(self, ctx: PipelineContext) -> PipelineContext:
+        """Stage 1: Split text into raw triples."""
+        stage_name = get_stage_name(1)
+        logger.debug(f"Running {stage_name} stage")
+        start_time = time.time()
+
+        splitters = PluginRegistry.get_splitters()
+        if not splitters:
+            ctx.add_warning("No splitter plugins registered")
+            return ctx
+
+        # Use first enabled splitter (highest priority)
+        for splitter in splitters:
+            if not self.config.is_plugin_enabled(splitter.name):
+                continue
+
+            logger.debug(f"Using splitter: {splitter.name}")
+            try:
+                raw_triples = splitter.split(ctx.source_text, ctx)
+                ctx.raw_triples = raw_triples
+                logger.info(f"Splitting produced {len(raw_triples)} raw triples")
+                break
+            except Exception as e:
+                logger.exception(f"Splitter {splitter.name} failed")
+                ctx.add_error(f"Splitter {splitter.name} failed: {str(e)}")
+                if self.config.fail_fast:
+                    raise
+
+        ctx.record_timing(stage_name, time.time() - start_time)
+        return ctx
+
+    def _run_extraction(self, ctx: PipelineContext) -> PipelineContext:
+        """Stage 2: Extract statements with typed entities from raw triples."""
+        stage_name = get_stage_name(2)
+        logger.debug(f"Running {stage_name} stage")
+        start_time = time.time()
+
+        if not ctx.raw_triples:
+            logger.debug("No raw triples to extract from")
+            return ctx
+
+        extractors = PluginRegistry.get_extractors()
+        if not extractors:
+            ctx.add_warning("No extractor plugins registered")
+            return ctx
+
+        # Collect classification schemas from labelers for the extractor
+        classification_schemas = self._collect_classification_schemas()
+        if classification_schemas:
+            logger.debug(f"Collected {len(classification_schemas)} classification schemas from labelers")
+
+        # Use first enabled extractor (highest priority)
+        for extractor in extractors:
+            if not self.config.is_plugin_enabled(extractor.name):
+                continue
+
+            # Pass classification schemas to extractor if it supports them
+            if classification_schemas and hasattr(extractor, 'add_classification_schema'):
+                for schema in classification_schemas:
+                    extractor.add_classification_schema(schema)
+
+            logger.debug(f"Using extractor: {extractor.name}")
+            try:
+                statements = extractor.extract(ctx.raw_triples, ctx)
+                ctx.statements = statements
+                logger.info(f"Extraction produced {len(statements)} statements")
+                break
+            except Exception as e:
+                logger.exception(f"Extractor {extractor.name} failed")
+                ctx.add_error(f"Extractor {extractor.name} failed: {str(e)}")
+                if self.config.fail_fast:
+                    raise
+
+        ctx.record_timing(stage_name, time.time() - start_time)
+        return ctx
+
+    def _collect_classification_schemas(self) -> list:
+        """Collect classification schemas from enabled labelers."""
+        schemas = []
+        labelers = PluginRegistry.get_labelers()
+
+        for labeler in labelers:
+            if not self.config.is_plugin_enabled(labeler.name):
+                continue
+
+            # Check for classification schema (simple multi-choice)
+            if hasattr(labeler, 'classification_schema') and labeler.classification_schema:
+                schemas.append(labeler.classification_schema)
+                logger.debug(
+                    f"Labeler {labeler.name} provides classification schema: "
+                    f"{labeler.classification_schema}"
+                )
+
+        return schemas
+
+    def _run_qualification(self, ctx: PipelineContext) -> PipelineContext:
+        """Stage 3: Add qualifiers to entities."""
+        stage_name = get_stage_name(3)
+        logger.debug(f"Running {stage_name} stage")
+        start_time = time.time()
+
+        if not ctx.statements:
+            logger.debug("No statements to qualify")
+            return ctx
+
+        # Collect all unique entities from statements
+        entities_to_qualify = {}
+        for stmt in ctx.statements:
+            for entity in [stmt.subject, stmt.object]:
+                if entity.entity_ref not in entities_to_qualify:
+                    entities_to_qualify[entity.entity_ref] = entity
+
+        logger.debug(f"Qualifying {len(entities_to_qualify)} unique entities")
+
+        # Qualify each entity using applicable plugins
+        for entity_ref, entity in entities_to_qualify.items():
+            qualifiers = EntityQualifiers()
+            sources = []
+
+            # Get qualifiers for this entity type
+            type_qualifiers = PluginRegistry.get_qualifiers_for_type(entity.type)
+
+            for qualifier_plugin in type_qualifiers:
+                if not self.config.is_plugin_enabled(qualifier_plugin.name):
+                    continue
+
+                try:
+                    plugin_qualifiers = qualifier_plugin.qualify(entity, ctx)
+                    if plugin_qualifiers and plugin_qualifiers.has_any_qualifier():
+                        qualifiers = qualifiers.merge_with(plugin_qualifiers)
+                        sources.append(qualifier_plugin.name)
+                except Exception as e:
+                    logger.error(f"Qualifier {qualifier_plugin.name} failed for {entity.text}: {e}")
+                    ctx.add_error(f"Qualifier {qualifier_plugin.name} failed: {str(e)}")
+                    if self.config.fail_fast:
+                        raise
+
+            # Create QualifiedEntity
+            qualified = QualifiedEntity(
+                entity_ref=entity_ref,
+                original_text=entity.text,
+                entity_type=entity.type,
+                qualifiers=qualifiers,
+                qualification_sources=sources,
+            )
+            ctx.qualified_entities[entity_ref] = qualified
+
+        logger.info(f"Qualified {len(ctx.qualified_entities)} entities")
+        ctx.record_timing(stage_name, time.time() - start_time)
+        return ctx
+
+    def _run_canonicalization(self, ctx: PipelineContext) -> PipelineContext:
+        """Stage 4: Resolve entities to canonical forms."""
+        stage_name = get_stage_name(4)
+        logger.debug(f"Running {stage_name} stage")
+        start_time = time.time()
+
+        if not ctx.qualified_entities:
+            # Create basic qualified entities if stage 3 was skipped
+            for stmt in ctx.statements:
+                for entity in [stmt.subject, stmt.object]:
+                    if entity.entity_ref not in ctx.qualified_entities:
+                        ctx.qualified_entities[entity.entity_ref] = QualifiedEntity(
+                            entity_ref=entity.entity_ref,
+                            original_text=entity.text,
+                            entity_type=entity.type,
+                        )
+
+        # Canonicalize each qualified entity
+        for entity_ref, qualified in ctx.qualified_entities.items():
+            canonical_match = None
+            fqn = None
+
+            # Get canonicalizers for this entity type
+            type_canonicalizers = PluginRegistry.get_canonicalizers_for_type(qualified.entity_type)
+
+            for canon_plugin in type_canonicalizers:
+                if not self.config.is_plugin_enabled(canon_plugin.name):
+                    continue
+
+                try:
+                    match = canon_plugin.find_canonical(qualified, ctx)
+                    if match:
+                        canonical_match = match
+                        fqn = canon_plugin.format_fqn(qualified, match)
+                        break  # Use first successful match
+                except Exception as e:
+                    logger.error(f"Canonicalizer {canon_plugin.name} failed for {qualified.original_text}: {e}")
+                    ctx.add_error(f"Canonicalizer {canon_plugin.name} failed: {str(e)}")
+                    if self.config.fail_fast:
+                        raise
+
+            # Create CanonicalEntity
+            canonical = CanonicalEntity.from_qualified(
+                qualified=qualified,
+                canonical_match=canonical_match,
+                fqn=fqn,
+            )
+            ctx.canonical_entities[entity_ref] = canonical
+
+        logger.info(f"Canonicalized {len(ctx.canonical_entities)} entities")
+        ctx.record_timing(stage_name, time.time() - start_time)
+        return ctx
+
+    def _run_labeling(self, ctx: PipelineContext) -> PipelineContext:
+        """Stage 5: Apply labels to statements."""
+        stage_name = get_stage_name(5)
+        logger.debug(f"Running {stage_name} stage")
+        start_time = time.time()
+
+        if not ctx.statements:
+            logger.debug("No statements to label")
+            return ctx
+
+        # Ensure canonical entities exist
+        if not ctx.canonical_entities:
+            self._run_canonicalization(ctx)
+
+        labelers = PluginRegistry.get_labelers()
+
+        for stmt in ctx.statements:
+            # Get canonical entities
+            subj_canonical = ctx.canonical_entities.get(stmt.subject.entity_ref)
+            obj_canonical = ctx.canonical_entities.get(stmt.object.entity_ref)
+
+            if not subj_canonical or not obj_canonical:
+                # Create fallback canonical entities
+                if not subj_canonical:
+                    subj_qualified = ctx.qualified_entities.get(
+                        stmt.subject.entity_ref,
+                        QualifiedEntity(
+                            entity_ref=stmt.subject.entity_ref,
+                            original_text=stmt.subject.text,
+                            entity_type=stmt.subject.type,
+                        )
+                    )
+                    subj_canonical = CanonicalEntity.from_qualified(subj_qualified)
+
+                if not obj_canonical:
+                    obj_qualified = ctx.qualified_entities.get(
+                        stmt.object.entity_ref,
+                        QualifiedEntity(
+                            entity_ref=stmt.object.entity_ref,
+                            original_text=stmt.object.text,
+                            entity_type=stmt.object.type,
+                        )
+                    )
+                    obj_canonical = CanonicalEntity.from_qualified(obj_qualified)
+
+            # Create labeled statement
+            labeled = LabeledStatement(
+                statement=stmt,
+                subject_canonical=subj_canonical,
+                object_canonical=obj_canonical,
+            )
+
+            # Apply all labelers
+            for labeler in labelers:
+                if not self.config.is_plugin_enabled(labeler.name):
+                    continue
+
+                try:
+                    label = labeler.label(stmt, subj_canonical, obj_canonical, ctx)
+                    if label:
+                        labeled.add_label(label)
+                except Exception as e:
+                    logger.error(f"Labeler {labeler.name} failed: {e}")
+                    ctx.add_error(f"Labeler {labeler.name} failed: {str(e)}")
+                    if self.config.fail_fast:
+                        raise
+
+            ctx.labeled_statements.append(labeled)
+
+        logger.info(f"Labeled {len(ctx.labeled_statements)} statements")
+        ctx.record_timing(stage_name, time.time() - start_time)
+        return ctx
+
+    def _run_taxonomy(self, ctx: PipelineContext) -> PipelineContext:
+        """Stage 6: Classify statements against taxonomies."""
+        stage_name = get_stage_name(6)
+        logger.debug(f"Running {stage_name} stage")
+        start_time = time.time()
+
+        if not ctx.labeled_statements:
+            logger.debug("No labeled statements to classify")
+            return ctx
+
+        taxonomy_classifiers = PluginRegistry.get_taxonomy_classifiers()
+        if not taxonomy_classifiers:
+            logger.debug("No taxonomy classifiers registered")
+            return ctx
+
+        total_results = 0
+        for labeled_stmt in ctx.labeled_statements:
+            stmt = labeled_stmt.statement
+            subj_canonical = labeled_stmt.subject_canonical
+            obj_canonical = labeled_stmt.object_canonical
+
+            # Apply all taxonomy classifiers
+            for classifier in taxonomy_classifiers:
+                if not self.config.is_plugin_enabled(classifier.name):
+                    continue
+
+                try:
+                    results = classifier.classify(stmt, subj_canonical, obj_canonical, ctx)
+                    if results:
+                        # Store taxonomy results in context (list of results per key)
+                        key = (stmt.source_text, classifier.taxonomy_name)
+                        if key not in ctx.taxonomy_results:
+                            ctx.taxonomy_results[key] = []
+                        ctx.taxonomy_results[key].extend(results)
+                        total_results += len(results)
+
+                        # Also add to the labeled statement for easy access
+                        labeled_stmt.taxonomy_results.extend(results)
+
+                        for result in results:
+                            logger.debug(
+                                f"Taxonomy {classifier.name}: {result.full_label} "
+                                f"(confidence={result.confidence:.2f})"
+                            )
+                except Exception as e:
+                    logger.error(f"Taxonomy classifier {classifier.name} failed: {e}")
+                    ctx.add_error(f"Taxonomy classifier {classifier.name} failed: {str(e)}")
+                    if self.config.fail_fast:
+                        raise
+
+        logger.info(f"Taxonomy produced {total_results} labels across {len(ctx.taxonomy_results)} statement-taxonomy pairs")
+        ctx.record_timing(stage_name, time.time() - start_time)
+        return ctx
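For completeness, the whole pipeline added in this release is driven through ExtractionPipeline.process(). A minimal usage sketch assuming only what the hunks above show (PipelineConfig.default() for configuration, plugins already registered with PluginRegistry); the input text and metadata are illustrative:

from statement_extractor.pipeline.config import PipelineConfig
from statement_extractor.pipeline.orchestrator import ExtractionPipeline

# The orchestrator consults config.is_stage_enabled() and
# config.is_plugin_enabled() at each step, so stages and individual
# plugins are toggled through the config object.
pipeline = ExtractionPipeline(config=PipelineConfig.default())

ctx = pipeline.process(
    "Acme Corp acquired Widget Ltd.",
    metadata={"document_id": "doc-1"},
)

for labeled in ctx.labeled_statements:  # Stage 5 output
    print(labeled.statement, labeled.subject_canonical, labeled.object_canonical)
print(ctx.stage_timings)      # per-stage seconds recorded via record_timing()
print(ctx.processing_errors)  # failures collect here unless fail_fast made them raise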