corp-extractor 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/METADATA +235 -96
- corp_extractor-0.5.0.dist-info/RECORD +55 -0
- statement_extractor/__init__.py +9 -0
- statement_extractor/cli.py +460 -21
- statement_extractor/data/default_predicates.json +368 -0
- statement_extractor/data/statement_taxonomy.json +1182 -0
- statement_extractor/extractor.py +32 -47
- statement_extractor/gliner_extraction.py +218 -0
- statement_extractor/llm.py +255 -0
- statement_extractor/models/__init__.py +74 -0
- statement_extractor/models/canonical.py +139 -0
- statement_extractor/models/entity.py +102 -0
- statement_extractor/models/labels.py +191 -0
- statement_extractor/models/qualifiers.py +91 -0
- statement_extractor/models/statement.py +75 -0
- statement_extractor/models.py +15 -6
- statement_extractor/pipeline/__init__.py +39 -0
- statement_extractor/pipeline/config.py +134 -0
- statement_extractor/pipeline/context.py +177 -0
- statement_extractor/pipeline/orchestrator.py +447 -0
- statement_extractor/pipeline/registry.py +297 -0
- statement_extractor/plugins/__init__.py +43 -0
- statement_extractor/plugins/base.py +446 -0
- statement_extractor/plugins/canonicalizers/__init__.py +17 -0
- statement_extractor/plugins/canonicalizers/base.py +9 -0
- statement_extractor/plugins/canonicalizers/location.py +219 -0
- statement_extractor/plugins/canonicalizers/organization.py +230 -0
- statement_extractor/plugins/canonicalizers/person.py +242 -0
- statement_extractor/plugins/extractors/__init__.py +13 -0
- statement_extractor/plugins/extractors/base.py +9 -0
- statement_extractor/plugins/extractors/gliner2.py +536 -0
- statement_extractor/plugins/labelers/__init__.py +29 -0
- statement_extractor/plugins/labelers/base.py +9 -0
- statement_extractor/plugins/labelers/confidence.py +138 -0
- statement_extractor/plugins/labelers/relation_type.py +87 -0
- statement_extractor/plugins/labelers/sentiment.py +159 -0
- statement_extractor/plugins/labelers/taxonomy.py +373 -0
- statement_extractor/plugins/labelers/taxonomy_embedding.py +466 -0
- statement_extractor/plugins/qualifiers/__init__.py +19 -0
- statement_extractor/plugins/qualifiers/base.py +9 -0
- statement_extractor/plugins/qualifiers/companies_house.py +174 -0
- statement_extractor/plugins/qualifiers/gleif.py +186 -0
- statement_extractor/plugins/qualifiers/person.py +221 -0
- statement_extractor/plugins/qualifiers/sec_edgar.py +198 -0
- statement_extractor/plugins/splitters/__init__.py +13 -0
- statement_extractor/plugins/splitters/base.py +9 -0
- statement_extractor/plugins/splitters/t5_gemma.py +188 -0
- statement_extractor/plugins/taxonomy/__init__.py +13 -0
- statement_extractor/plugins/taxonomy/embedding.py +337 -0
- statement_extractor/plugins/taxonomy/mnli.py +279 -0
- statement_extractor/scoring.py +17 -69
- corp_extractor-0.3.0.dist-info/RECORD +0 -12
- statement_extractor/spacy_extraction.py +0 -386
- {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/WHEEL +0 -0
- {corp_extractor-0.3.0.dist-info → corp_extractor-0.5.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,536 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GLiNER2Extractor - Stage 2 plugin that refines triples using GLiNER2.
|
|
3
|
+
|
|
4
|
+
Uses GLiNER2 for:
|
|
5
|
+
1. Entity extraction: Refine subject/object boundaries
|
|
6
|
+
2. Relation extraction: When predicate list is provided
|
|
7
|
+
3. Entity scoring: Score how entity-like subjects/objects are
|
|
8
|
+
4. Classification: Run labeler classification schemas in single pass
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import logging
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Optional
|
|
15
|
+
|
|
16
|
+
from ..base import BaseExtractorPlugin, ClassificationSchema, PluginCapability
|
|
17
|
+
from ...pipeline.context import PipelineContext
|
|
18
|
+
from ...pipeline.registry import PluginRegistry
|
|
19
|
+
from ...models import RawTriple, PipelineStatement, ExtractedEntity, EntityType
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
# Type alias for predicate configuration with description and threshold
|
|
24
|
+
PredicateConfig = dict[str, str | float] # {"description": str, "threshold": float}
|
|
25
|
+
|
|
26
|
+
# Path to bundled default predicates JSON
|
|
27
|
+
DEFAULT_PREDICATES_PATH = Path(__file__).parent.parent.parent / "data" / "default_predicates.json"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def load_predicates_from_json(path: Path) -> dict[str, dict[str, PredicateConfig]]:
|
|
31
|
+
"""
|
|
32
|
+
Load predicate categories from a JSON file.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
path: Path to JSON file containing predicate categories
|
|
36
|
+
|
|
37
|
+
Returns:
|
|
38
|
+
Dict of category -> {predicate -> {description, threshold}}
|
|
39
|
+
|
|
40
|
+
Raises:
|
|
41
|
+
FileNotFoundError: If path doesn't exist
|
|
42
|
+
json.JSONDecodeError: If JSON is invalid
|
|
43
|
+
"""
|
|
44
|
+
with open(path) as f:
|
|
45
|
+
return json.load(f)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _load_default_predicates() -> dict[str, dict[str, PredicateConfig]]:
|
|
49
|
+
"""Load the bundled default predicates."""
|
|
50
|
+
try:
|
|
51
|
+
return load_predicates_from_json(DEFAULT_PREDICATES_PATH)
|
|
52
|
+
except (FileNotFoundError, json.JSONDecodeError) as e:
|
|
53
|
+
logger.warning(f"Failed to load default predicates from {DEFAULT_PREDICATES_PATH}: {e}")
|
|
54
|
+
return {}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Load default predicates on module import
PREDICATE_CATEGORIES: dict[str, dict[str, PredicateConfig]] = _load_default_predicates()

# Ensure we have predicates loaded
if not PREDICATE_CATEGORIES:
    logger.error("No predicate categories loaded - relation extraction will fail")

# Reverse lookup: predicate -> category.
# Built with a comprehension so the iteration variables do not leak into the
# module namespace (and get exported by `import *`). If the same predicate
# appears in several categories, the last category listed wins.
PREDICATE_TO_CATEGORY: dict[str, str] = {
    predicate: category
    for category, predicates in PREDICATE_CATEGORIES.items()
    for predicate in predicates
}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_predicate_category(predicate: str) -> Optional[str]:
    """
    Return the category that ``predicate`` belongs to.

    The predicate is first tried exactly as given. On a miss it is retried in
    normalized form: lowercased, with spaces and hyphens turned into
    underscores (e.g. ``"Acquired By"`` -> ``"acquired_by"``).

    Args:
        predicate: Predicate name to look up

    Returns:
        The category name, or None if neither form is a known predicate
    """
    category = PREDICATE_TO_CATEGORY.get(predicate)
    if category is None:
        normalized = predicate.lower().replace(" ", "_").replace("-", "_")
        category = PREDICATE_TO_CATEGORY.get(normalized)
    return category
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# Maps GLiNER2 entity labels onto this package's EntityType.
# Key order matters: GLiNER2Extractor._get_entity_types uses the keys, in this
# order, as the default list of entity labels sent to the model.
GLINER_TYPE_MAP: dict[str, EntityType] = {
    # People
    "person": EntityType.PERSON,
    # Organizations
    "organization": EntityType.ORG,
    "company": EntityType.ORG,
    # Places: generic locations vs. geo-political entities
    "location": EntityType.LOC,
    "city": EntityType.GPE,
    "country": EntityType.GPE,
    # Things and happenings
    "product": EntityType.PRODUCT,
    "event": EntityType.EVENT,
    # Values
    "date": EntityType.DATE,
    "money": EntityType.MONEY,
    "quantity": EntityType.QUANTITY,
}
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@PluginRegistry.extractor
class GLiNER2Extractor(BaseExtractorPlugin):
    """
    Extractor plugin that uses GLiNER2 for entity and relation refinement.

    Processes raw triples from Stage 1 and produces PipelineStatement
    objects with typed entities. Also runs classification schemas from
    labeler plugins in a single pass.
    """

    # Model checkpoint loaded on first use unless another one is passed to __init__.
    DEFAULT_MODEL_ID = "fastino/gliner2-base-v1"

    # Confidence assumed for a relation whose GLiNER2 output carries no score.
    _DEFAULT_RELATION_CONFIDENCE = 0.5

    # Confidence recorded for a classification when GLiNER2 reports none.
    _DEFAULT_CLASSIFICATION_CONFIDENCE = 0.85

    # Keys under which GLiNER2 output may carry a confidence value, in lookup order.
    _CONFIDENCE_KEYS = ("score", "confidence", "prob")

    def __init__(
        self,
        predicates_file: Optional[str | Path] = None,
        entity_types: Optional[list[str]] = None,
        classification_schemas: Optional[list[ClassificationSchema]] = None,
        min_confidence: float = 0.75,
        model_id: str = DEFAULT_MODEL_ID,
    ):
        """
        Initialize the GLiNER2 extractor.

        Args:
            predicates_file: Optional path to custom predicates JSON file.
                If not provided (or it fails to load), uses the bundled
                default_predicates.json.
            entity_types: Optional list of entity types to extract.
                If not provided, uses the keys of GLINER_TYPE_MAP.
            classification_schemas: Optional list of classification schemas from
                labelers. The list is copied, so add_classification_schema()
                never modifies the caller's list.
            min_confidence: Minimum confidence threshold for relation extraction (default 0.75)
            model_id: Checkpoint passed to ``GLiNER2.from_pretrained`` on first use.
        """
        self._predicates_file = Path(predicates_file) if predicates_file else None
        self._predicate_categories: Optional[dict[str, dict[str, PredicateConfig]]] = None
        self._entity_types = entity_types
        self._classification_schemas: list[ClassificationSchema] = list(classification_schemas or [])
        self._min_confidence = min_confidence
        self._model_id = model_id
        self._model = None
        # Set once gliner2 turns out to be missing, so the import is not retried
        # (and the warning not re-logged) on every extract() call.
        self._model_unavailable = False

        # Load custom predicates if file provided; fall back to defaults on failure
        if self._predicates_file:
            try:
                self._predicate_categories = load_predicates_from_json(self._predicates_file)
                logger.info(f"Loaded {len(self._predicate_categories)} predicate categories from {self._predicates_file}")
            except Exception as e:
                logger.warning(f"Failed to load custom predicates from {self._predicates_file}: {e}")
                self._predicate_categories = None

    def _get_predicate_categories(self) -> dict[str, dict[str, PredicateConfig]]:
        """Get predicate categories - custom file or default from JSON."""
        if self._predicate_categories is not None:
            return self._predicate_categories
        return PREDICATE_CATEGORIES

    def _get_entity_types(self) -> list[str]:
        """Get entity types - from init or derived from GLINER_TYPE_MAP keys."""
        if self._entity_types is not None:
            return self._entity_types
        # Use keys from GLINER_TYPE_MAP as default entity types
        return list(GLINER_TYPE_MAP.keys())

    @property
    def name(self) -> str:
        return "gliner2_extractor"

    @property
    def priority(self) -> int:
        return 10  # High priority - primary extractor

    @property
    def capabilities(self) -> PluginCapability:
        return PluginCapability.BATCH_PROCESSING | PluginCapability.LLM_REQUIRED

    @property
    def description(self) -> str:
        return "GLiNER2 model for entity and relation extraction"

    def _get_model(self):
        """
        Lazy-load the GLiNER2 model.

        Returns:
            The loaded model, or None if gliner2 cannot be imported. An import
            failure is remembered so later calls return None without retrying.
        """
        if self._model is None and not getattr(self, "_model_unavailable", False):
            model_id = getattr(self, "_model_id", self.DEFAULT_MODEL_ID)
            try:
                from gliner2 import GLiNER2
                logger.info("Loading GLiNER2 model...")
                self._model = GLiNER2.from_pretrained(model_id)
                logger.debug("GLiNER2 model loaded")
            except ImportError:
                logger.warning("GLiNER2 not installed, using fallback")
                self._model = None
                self._model_unavailable = True
        return self._model

    def add_classification_schema(self, schema: ClassificationSchema) -> None:
        """Add a classification schema to run during extraction."""
        self._classification_schemas.append(schema)

    def extract(
        self,
        raw_triples: list[RawTriple],
        context: PipelineContext,
    ) -> list[PipelineStatement]:
        """
        Extract statements from raw triples using GLiNER2.

        Returns ALL matching relations from GLiNER2 (not just the best one).
        Also runs any classification schemas (once per distinct source text
        that produced a statement) and stores results in context.

        Args:
            raw_triples: Raw triples from Stage 1
            context: Pipeline context

        Returns:
            List of PipelineStatement objects (may contain multiple per raw triple);
            empty if the GLiNER2 model is unavailable.
        """
        predicate_categories = self._get_predicate_categories()
        logger.info(f"GLiNER2Extractor processing {len(raw_triples)} triples")
        logger.info(f"Using {len(predicate_categories)} predicate categories")

        model = self._get_model()
        if model is None:
            # Nothing can be extracted without the model; warn once, not once per triple.
            if raw_triples:
                logger.warning("No GLiNER2 model available - skipping extraction")
            logger.info(f"GLiNER2Extractor produced 0 statements from {len(raw_triples)} raw triples")
            return []

        statements: list[PipelineStatement] = []
        classified_texts: set[str] = set()

        for raw in raw_triples:
            try:
                # Iterates through predicate categories; returns ALL matches, not just the best one
                extracted_stmts = self._extract_with_relations(raw, model, predicate_categories)

                for stmt in extracted_stmts:
                    statements.append(stmt)

                    # Run classifications for this statement's source text (once per unique text)
                    if self._classification_schemas and stmt.source_text not in classified_texts:
                        self._run_classifications(model, stmt.source_text, context)
                        classified_texts.add(stmt.source_text)

            except Exception as e:
                # No fallback - skip this triple
                logger.warning(f"Error extracting triple: {e}")

        logger.info(f"GLiNER2Extractor produced {len(statements)} statements from {len(raw_triples)} raw triples")
        return statements

    def _run_classifications(
        self,
        model,
        source_text: str,
        context: PipelineContext,
    ) -> None:
        """
        Run classification schemas using GLiNER2 and store results in context.

        Uses GLiNER2's create_schema() API to run all classifications in one
        model call. Failures inside the model call are logged, not raised.
        """
        if not self._classification_schemas:
            return

        # Skip if this text already has classification results in the context
        if source_text in context.classification_results:
            return

        try:
            # Build one schema holding every classification
            schema = model.create_schema()
            for class_schema in self._classification_schemas:
                schema = schema.classification(
                    class_schema.label_type,
                    class_schema.choices,
                )

            results = model.extract(source_text, schema, include_confidence=True)

            for class_schema in self._classification_schemas:
                label_type = class_schema.label_type
                if label_type not in results:
                    continue

                result_value = results[label_type]
                # With include_confidence=True, GLiNER2 returns
                # {'label': 'value', 'confidence': 0.95} for classifications
                if isinstance(result_value, dict):
                    label_value = result_value.get("label", str(result_value))
                    confidence = result_value.get("confidence", self._DEFAULT_CLASSIFICATION_CONFIDENCE)
                else:
                    label_value = str(result_value)
                    confidence = self._DEFAULT_CLASSIFICATION_CONFIDENCE

                context.set_classification(
                    source_text, label_type, label_value, confidence
                )
                logger.debug(
                    f"GLiNER2 classified '{source_text[:50]}...' "
                    f"as {label_type}={label_value}"
                )

        except Exception as e:
            logger.warning(f"GLiNER2 classification failed: {e}")

    def _extract_with_relations(
        self,
        raw: RawTriple,
        model,
        predicate_categories: dict[str, dict[str, PredicateConfig]],
    ) -> list[PipelineStatement]:
        """
        Extract using GLiNER2 relation extraction, iterating through categories.

        Iterates through each predicate category separately to stay under
        GLiNER2's ~25 label limit. Uses schema API with entities + relations.
        Returns ALL matching relations at or above ``min_confidence``, most
        confident first.

        Args:
            raw: Raw triple from Stage 1
            model: GLiNER2 model instance
            predicate_categories: Dict of category -> predicates to use

        Returns:
            List of PipelineStatements for all relations found
        """
        logger.debug(f"Attempting relation extraction for: '{raw.source_sentence[:80]}...'")

        entity_types = self._get_entity_types()
        # (head, rel_type, tail, category, confidence)
        all_relations: list[tuple[str, str, str, str, float]] = []

        for category_name, category_predicates in predicate_categories.items():
            # The schema's .relations() method expects a {relation_name: description} dict
            relations_dict = {
                pred_name: pred_config.get("description", pred_name) if isinstance(pred_config, dict) else str(pred_config)
                for pred_name, pred_config in category_predicates.items()
            }

            try:
                schema = model.create_schema().entities(entity_types).relations(relations_dict)
                result = model.extract(raw.source_sentence, schema, include_confidence=True)

                relation_data = result.get("relations", result.get("relation_extraction", {})) or {}
                for rel_type, relations in relation_data.items():
                    if not relations:
                        continue
                    for rel in relations:
                        head, tail, confidence = self._parse_relation(rel)
                        if head and tail:
                            all_relations.append((head, rel_type, tail, category_name, confidence))
                            logger.debug(f" [{category_name}] {head} --[{rel_type}]--> {tail} (conf={confidence:.2f})")

            except Exception as e:
                logger.debug(f" Category {category_name} extraction failed: {e}")

        total_found = len(all_relations)
        logger.debug(f" GLiNER2 found {total_found} total relations across all categories")

        if not all_relations:
            logger.debug(f"No GLiNER2 relation match in: '{raw.source_sentence[:60]}...'")
            return []

        # Keep relations at or above the threshold, most confident first (stable for ties)
        all_relations = sorted(
            (rel for rel in all_relations if rel[4] >= self._min_confidence),
            key=lambda rel: rel[4],
            reverse=True,
        )

        filtered_count = total_found - len(all_relations)
        if filtered_count > 0:
            logger.debug(f" Filtered {filtered_count} relations below confidence threshold ({self._min_confidence})")

        if not all_relations:
            logger.debug(f"No relations above confidence threshold ({self._min_confidence})")
            return []

        # Every head/tail comes from the same sentence: run entity extraction once
        # here instead of twice per relation inside _infer_entity_type.
        sentence_entities = self._extract_entities(model, raw.source_sentence)

        statements = []
        for head, rel_type, tail, category, confidence in all_relations:
            logger.info(
                f"GLiNER2 relation match: {head} --[{rel_type}]--> {tail} "
                f"(category={category}, confidence={confidence:.2f})"
            )

            subj_type = self._infer_entity_type(head, model, raw.source_sentence, entities=sentence_entities)
            obj_type = self._infer_entity_type(tail, model, raw.source_sentence, entities=sentence_entities)
            logger.debug(f" Entity types: {subj_type.value}, {obj_type.value}")

            statements.append(
                PipelineStatement(
                    subject=ExtractedEntity(
                        text=head,
                        type=subj_type,
                        confidence=confidence,
                    ),
                    predicate=rel_type,
                    predicate_category=category,
                    object=ExtractedEntity(
                        text=tail,
                        type=obj_type,
                        confidence=confidence,
                    ),
                    source_text=raw.source_sentence,
                    confidence_score=confidence,
                    extraction_method="gliner_relation",
                )
            )

        return statements

    def _extract_with_entities(
        self,
        raw: RawTriple,
        model,
    ) -> Optional[PipelineStatement]:
        """
        Entity extraction mode - returns None since we don't use T5-Gemma predicates.

        This method is called when predicates are disabled. Without GLiNER2 relation
        extraction, we cannot form valid statements.
        """
        logger.debug(f"Entity extraction mode (no predicates) - skipping: '{raw.source_sentence[:60]}...'")
        return None

    @staticmethod
    def _first_present(data: dict, keys: tuple[str, ...]):
        """
        Return the first value in ``data`` under one of ``keys`` that is not None.

        Unlike chaining ``data.get(a) or data.get(b)``, this keeps legitimate
        falsy values such as a confidence of ``0.0``.
        """
        for key in keys:
            value = data.get(key)
            if value is not None:
                return value
        return None

    def _parse_relation(self, rel) -> tuple[str, str, float]:
        """
        Parse a relation from GLiNER2 output.

        Args:
            rel: Relation data (tuple, dict, or other format from GLiNER2)

        Returns:
            Tuple of (head_text, tail_text, confidence). A score of 0.0 is kept
            as 0.0; only a missing score falls back to the default of 0.5.
            Unknown formats yield ("", "", 0.0).
        """
        # Log the actual structure for debugging
        logger.debug(f" Parsing relation: type={type(rel).__name__}, value={rel}")

        # Tuple format: (head, tail) or (head, tail, score); head/tail may be dicts
        if isinstance(rel, (tuple, list)) and len(rel) >= 2:
            head, tail = rel[0], rel[1]
            head_text = head.get("text", str(head)) if isinstance(head, dict) else str(head)
            tail_text = tail.get("text", str(tail)) if isinstance(tail, dict) else str(tail)

            if len(rel) >= 3:
                score = rel[2]
                return head_text, tail_text, (
                    float(score) if score is not None else self._DEFAULT_RELATION_CONFIDENCE
                )

            # (head, tail): no relation-level score, so use the weaker endpoint's score
            endpoint_confs = []
            for endpoint in (head, tail):
                conf = self._first_present(endpoint, self._CONFIDENCE_KEYS) if isinstance(endpoint, dict) else None
                endpoint_confs.append(float(conf) if conf is not None else self._DEFAULT_RELATION_CONFIDENCE)
            return head_text, tail_text, min(endpoint_confs)

        # Dict format with head/tail keys
        if isinstance(rel, dict):
            # Try different key names for head/tail
            head_data = rel.get("head") or rel.get("source") or rel.get("subject") or {}
            tail_data = rel.get("tail") or rel.get("target") or rel.get("object") or {}

            # Overall relation confidence, if available
            rel_conf = self._first_present(rel, self._CONFIDENCE_KEYS)

            spans = []
            for span in (head_data, tail_data):
                if isinstance(span, dict):
                    span_text = span.get("text") or span.get("name") or span.get("span") or ""
                    span_conf = self._first_present(span, self._CONFIDENCE_KEYS)
                else:
                    span_text = str(span) if span else ""
                    span_conf = None
                spans.append((span_text, span_conf))
            (head, head_conf), (tail, tail_conf) = spans

            # Final confidence: prefer relation-level, then min of head/tail, then default
            if rel_conf is not None:
                confidence = float(rel_conf)
            else:
                endpoint_confs = [float(c) for c in (head_conf, tail_conf) if c is not None]
                confidence = min(endpoint_confs) if endpoint_confs else self._DEFAULT_RELATION_CONFIDENCE

            return head, tail, confidence

        # Unknown format
        logger.warning(f" Unknown relation format: {type(rel).__name__}")
        return "", "", 0.0

    def _extract_entities(self, model, source_text: str) -> dict:
        """
        Run GLiNER2 entity extraction over ``source_text``.

        Returns:
            The ``"entities"`` part of the model output, or {} if extraction failed.
        """
        try:
            result = model.extract_entities(source_text, self._get_entity_types(), include_confidence=True)
            return result.get("entities", {})
        except Exception as e:
            logger.debug(f"Entity type inference failed: {e}")
            return {}

    def _infer_entity_type(
        self,
        text: str,
        model,
        source_text: str,
        entities: Optional[dict] = None,
    ) -> EntityType:
        """
        Infer the EntityType of ``text`` from GLiNER2 entities found in ``source_text``.

        An entity whose text equals ``text`` (case-insensitively) wins. Otherwise
        the first entity whose text contains, or is contained in, ``text`` is used.
        Entities with empty text are ignored, because an empty string would
        "contain"-match any span.

        Args:
            text: Head or tail span to type
            model: GLiNER2 model instance (only used when ``entities`` is None)
            source_text: Sentence the span came from
            entities: Pre-computed entity extraction output for ``source_text``;
                pass it to avoid re-running the model for every span.

        Returns:
            The mapped EntityType, or EntityType.UNKNOWN when nothing matches
        """
        if entities is None:
            entities = self._extract_entities(model, source_text)

        partial_match: Optional[EntityType] = None
        try:
            text_lower = text.lower()
            if not text_lower:
                return EntityType.UNKNOWN

            for entity_type, entity_list in entities.items():
                for entity in entity_list:
                    if isinstance(entity, dict):
                        entity_text = str(entity.get("text") or "").lower()
                    else:
                        entity_text = str(entity).lower()

                    if not entity_text:
                        continue
                    if entity_text == text_lower:
                        return GLINER_TYPE_MAP.get(entity_type.lower(), EntityType.UNKNOWN)
                    if partial_match is None and (entity_text in text_lower or text_lower in entity_text):
                        partial_match = GLINER_TYPE_MAP.get(entity_type.lower(), EntityType.UNKNOWN)

        except Exception as e:
            logger.debug(f"Entity type inference failed: {e}")

        return partial_match if partial_match is not None else EntityType.UNKNOWN


# Module-level alias for the class. It is bound after @PluginRegistry.extractor
# has run, so it refers to whatever that decorator returned (presumably the same,
# already-registered class - confirm against PluginRegistry), not an undecorated copy.
GLiNER2ExtractorClass = GLiNER2Extractor
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Labeler plugins for Stage 5 (Labeling).
|
|
3
|
+
|
|
4
|
+
Applies labels to statements (sentiment, relation type, confidence).
|
|
5
|
+
|
|
6
|
+
Note: Taxonomy classification is handled in Stage 6 (Taxonomy) via
|
|
7
|
+
plugins/taxonomy/ modules, not here. The TaxonomyLabeler classes are
|
|
8
|
+
provided for backward compatibility but are NOT auto-registered.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from .base import BaseLabelerPlugin
|
|
12
|
+
from .sentiment import SentimentLabeler
|
|
13
|
+
from .relation_type import RelationTypeLabeler
|
|
14
|
+
from .confidence import ConfidenceLabeler
|
|
15
|
+
|
|
16
|
+
# Taxonomy labelers - exported for backward compatibility only
|
|
17
|
+
# NOT auto-registered as Stage 5 labelers (use Stage 6 taxonomy plugins instead)
|
|
18
|
+
from .taxonomy import TaxonomyLabeler
|
|
19
|
+
from .taxonomy_embedding import EmbeddingTaxonomyLabeler
|
|
20
|
+
|
|
21
|
+
# Public API of the labelers package.
__all__ = [
    # Base class and the Stage 5 labelers
    "BaseLabelerPlugin", "SentimentLabeler", "RelationTypeLabeler", "ConfidenceLabeler",
    # Taxonomy labelers: exported for backward compatibility / manual use only;
    # they are not auto-registered (Stage 6 taxonomy plugins replace them)
    "TaxonomyLabeler", "EmbeddingTaxonomyLabeler",
]
|