glinker-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glinker/__init__.py +54 -0
- glinker/core/__init__.py +56 -0
- glinker/core/base.py +103 -0
- glinker/core/builders.py +547 -0
- glinker/core/dag.py +898 -0
- glinker/core/factory.py +261 -0
- glinker/core/registry.py +31 -0
- glinker/l0/__init__.py +21 -0
- glinker/l0/component.py +472 -0
- glinker/l0/models.py +90 -0
- glinker/l0/processor.py +108 -0
- glinker/l1/__init__.py +15 -0
- glinker/l1/component.py +284 -0
- glinker/l1/models.py +47 -0
- glinker/l1/processor.py +152 -0
- glinker/l2/__init__.py +19 -0
- glinker/l2/component.py +1220 -0
- glinker/l2/models.py +99 -0
- glinker/l2/processor.py +170 -0
- glinker/l3/__init__.py +12 -0
- glinker/l3/component.py +184 -0
- glinker/l3/models.py +48 -0
- glinker/l3/processor.py +350 -0
- glinker/l4/__init__.py +9 -0
- glinker/l4/component.py +121 -0
- glinker/l4/models.py +21 -0
- glinker/l4/processor.py +156 -0
- glinker/py.typed +1 -0
- glinker-0.1.0.dist-info/METADATA +994 -0
- glinker-0.1.0.dist-info/RECORD +33 -0
- glinker-0.1.0.dist-info/WHEEL +5 -0
- glinker-0.1.0.dist-info/licenses/LICENSE +201 -0
- glinker-0.1.0.dist-info/top_level.txt +1 -0
glinker/l3/processor.py
ADDED
@@ -0,0 +1,350 @@
from typing import Any, List, Optional
import torch
from glinker.core.base import BaseProcessor
from glinker.core.registry import processor_registry
from .models import L3Config, L3Input, L3Output, L3Entity
from .component import L3Component


class L3Processor(BaseProcessor[L3Config, L3Input, L3Output]):
    """GLiNER entity linking processor"""

    def __init__(
        self,
        config: L3Config,
        component: L3Component,
        pipeline: list[tuple[str, dict[str, Any]]] = None
    ):
        super().__init__(config, component, pipeline)
        self._validate_pipeline()
        self.schema = {}
        self._l2_processor = None  # Will be set by DAG executor for cache write-back

    def _default_pipeline(self) -> list[tuple[str, dict[str, Any]]]:
        return [
            ("predict_entities", {}),
            ("filter_by_score", {}),
            ("sort_by_position", {})
        ]

    @staticmethod
    def _build_input_spans(l1_entities_for_text: List[Any]) -> List[List[dict]]:
        """Convert L1 entities to GLiNER input_spans format.

        Args:
            l1_entities_for_text: List of L1Entity objects for a single text

        Returns:
            List of span dicts with 'start' and 'end' keys,
            wrapped in an outer list as expected by GLiNER input_spans.
        """
        spans = [{"start": e.start, "end": e.end} for e in l1_entities_for_text]
        return [spans]

    def __call__(
        self,
        texts: List[str] = None,
        candidates: List[List[Any]] = None,
        l1_entities: List[List[Any]] = None,
        input_data: L3Input = None
    ) -> L3Output:
        """Process texts with candidate labels

        Args:
            texts: List of input texts
            candidates: List of candidate lists per text (from L2)
            l1_entities: Optional L1 entities per text, used to build input_spans
                so L3 predicts on the same spans extracted in L1
            input_data: Alternative L3Input object
        """

        # Support both direct params and L3Input
        if texts is not None and candidates is not None:
            texts_to_process = texts
            candidates_to_process = candidates
        elif input_data is not None:
            texts_to_process = input_data.texts
            candidates_to_process = input_data.labels
        else:
            raise ValueError("Either 'texts'+'candidates' or 'input_data' must be provided")

        all_entities = []

        # Detect shared candidates (all texts use the same list, e.g. simple pipeline)
        shared = (
            len(candidates_to_process) > 1
            and all(c is candidates_to_process[0] for c in candidates_to_process[1:])
        )

        # Pre-compute labels & embeddings once when candidates are shared
        shared_labels = None
        shared_label_to_candidate = None
        shared_use_precomputed = False
        shared_embeddings = None

        if shared:
            ref_candidates = candidates_to_process[0]
            if self.schema:
                shared_labels, shared_label_to_candidate = (
                    self._create_gliner_labels_with_mapping(ref_candidates)
                )
            else:
                shared_labels = [self._extract_label(c) for c in ref_candidates]
                shared_label_to_candidate = {}

            shared_use_precomputed = (
                self.config.use_precomputed_embeddings
                and self.component.supports_precomputed_embeddings
                and self._can_use_precomputed(ref_candidates, shared_label_to_candidate)
            )
            if shared_use_precomputed:
                shared_embeddings = self._get_embeddings_tensor(
                    ref_candidates, shared_labels, shared_label_to_candidate
                )

        for idx, (text, text_candidates) in enumerate(zip(texts_to_process, candidates_to_process)):
            # Build input_spans from L1 entities if available
            input_spans = None
            if l1_entities is not None and idx < len(l1_entities):
                text_l1 = l1_entities[idx]
                if text_l1:
                    input_spans = self._build_input_spans(text_l1)

            if shared:
                labels = shared_labels
                label_to_candidate = shared_label_to_candidate
                use_precomputed = shared_use_precomputed
                embeddings = shared_embeddings
            else:
                # Create labels from candidates (per-text)
                if self.schema:
                    labels, label_to_candidate = self._create_gliner_labels_with_mapping(text_candidates)
                else:
                    labels = [self._extract_label(c) for c in text_candidates]
                    label_to_candidate = {}

                use_precomputed = (
                    self.config.use_precomputed_embeddings
                    and self.component.supports_precomputed_embeddings
                    and self._can_use_precomputed(text_candidates, label_to_candidate)
                )
                embeddings = None
                if use_precomputed:
                    embeddings = self._get_embeddings_tensor(text_candidates, labels, label_to_candidate)

            if use_precomputed and embeddings is not None:
                entities = self.component.predict_with_embeddings(
                    text, labels, embeddings, input_spans=input_spans
                )
            else:
                # Regular prediction
                entities = self.component.predict_entities(text, labels, input_spans=input_spans)

            # Optionally cache computed embeddings
            if self.config.cache_embeddings and self.component.supports_precomputed_embeddings:
                self._cache_embeddings(text_candidates, labels, label_to_candidate)

            # Apply rest of pipeline
            for method_name, kwargs in self.pipeline[1:]:
                method = getattr(self.component, method_name)
                entities = method(entities, **kwargs)

            # Apply ranking if configured
            if self.schema.get('ranking'):
                entities = self._rank_entities(entities, text_candidates)

            all_entities.append(entities)

        return L3Output(entities=all_entities)

    def _can_use_precomputed(
        self,
        candidates: List[Any],
        label_to_candidate: dict
    ) -> bool:
        """Check if all candidates have compatible precomputed embeddings"""
        if not candidates:
            return False

        expected_model = self.config.model_name

        for candidate in candidates:
            # Check if candidate has embedding
            embedding = getattr(candidate, 'embedding', None)
            if embedding is None:
                return False

            # Check if model matches
            model_id = getattr(candidate, 'embedding_model_id', None)
            if model_id != expected_model:
                return False

        return True

    def _get_embeddings_tensor(
        self,
        candidates: List[Any],
        labels: List[str],
        label_to_candidate: dict
    ) -> torch.Tensor:
        """Build embeddings tensor from candidates in same order as labels"""
        embeddings = []

        for label in labels:
            candidate = label_to_candidate.get(label)
            if candidate and hasattr(candidate, 'embedding') and candidate.embedding:
                embeddings.append(candidate.embedding)
            else:
                # Should not happen if _can_use_precomputed returned True
                raise ValueError(f"Missing embedding for label: {label}")

        return torch.tensor(embeddings, device=self.component.device)

    def _cache_embeddings(
        self,
        candidates: List[Any],
        labels: List[str],
        label_to_candidate: dict
    ):
        """Compute and cache embeddings for candidates without them"""
        if not self._l2_processor:
            return

        # Find candidates without embeddings
        to_compute = []
        to_compute_ids = []

        for candidate in candidates:
            if not getattr(candidate, 'embedding', None):
                to_compute.append(candidate)
                to_compute_ids.append(candidate.entity_id)

        if not to_compute:
            return

        # Format labels for these candidates
        template = self.schema.get('template', '{label}')
        compute_labels = []
        for candidate in to_compute:
            try:
                if hasattr(candidate, 'model_dump'):
                    formatted = template.format(**candidate.model_dump())
                elif hasattr(candidate, 'dict'):
                    formatted = template.format(**candidate.dict())
                else:
                    formatted = candidate.label
                compute_labels.append(formatted)
            except KeyError:
                compute_labels.append(candidate.label)

        # Encode labels
        embeddings = self.component.encode_labels(compute_labels)

        # Update L2 layer
        if hasattr(self._l2_processor, 'component'):
            for layer in self._l2_processor.component.layers:
                if layer.is_available():
                    layer.update_embeddings(
                        to_compute_ids,
                        embeddings.tolist(),
                        self.config.model_name
                    )
                    break  # Update first available layer

    def _extract_label(self, candidate: Any) -> str:
        """Extract label from candidate"""
        if hasattr(candidate, 'label'):
            return candidate.label
        return str(candidate)

    def _create_gliner_labels_with_mapping(self, candidates: List[Any]) -> tuple:
        """
        Create GLiNER labels using schema template and return label->candidate mapping.

        Returns:
            tuple: (labels: List[str], label_to_candidate: dict)
        """
        template = self.schema.get('template', '{label}')
        labels = []
        label_to_candidate = {}
        seen = set()

        for candidate in candidates:
            try:
                if hasattr(candidate, 'model_dump'):
                    cand_dict = candidate.model_dump()
                elif hasattr(candidate, 'dict'):
                    cand_dict = candidate.dict()
                elif isinstance(candidate, dict):
                    cand_dict = candidate
                else:
                    label = str(candidate)
                    if label.lower() not in seen:
                        labels.append(label)
                        seen.add(label.lower())
                    continue

                label = template.format(**cand_dict)
                label_lower = label.lower()
                if label_lower not in seen:
                    labels.append(label)
                    label_to_candidate[label] = candidate
                    seen.add(label_lower)
            except (KeyError, AttributeError):
                if hasattr(candidate, 'label'):
                    if candidate.label.lower() not in seen:
                        labels.append(candidate.label)
                        label_to_candidate[candidate.label] = candidate
                        seen.add(candidate.label.lower())

        return labels, label_to_candidate

    def _create_gliner_labels(self, candidates: List[Any]) -> List[str]:
        """Create GLiNER labels using schema template (legacy, for compatibility)"""
        labels, _ = self._create_gliner_labels_with_mapping(candidates)
        return labels

    def _rank_entities(self, entities: List[L3Entity], candidates: List[Any]) -> List[L3Entity]:
        """Re-rank entities using multiple scoring factors"""
        # Build label to candidate mapping
        label_to_candidate = {}
        for c in candidates:
            if hasattr(c, 'label'):
                label_to_candidate[c.label] = c
            if hasattr(c, 'aliases'):
                for alias in c.aliases:
                    if alias not in label_to_candidate:
                        label_to_candidate[alias] = c

        # Calculate weighted scores
        for entity in entities:
            total_score = 0.0
            total_weight = 0.0

            for rank_spec in self.schema['ranking']:
                field = rank_spec['field']
                weight = rank_spec['weight']
                total_weight += weight

                if field == 'gliner_score':
                    total_score += entity.score * weight
                else:
                    candidate = label_to_candidate.get(entity.label)
                    if candidate and hasattr(candidate, field):
                        value = getattr(candidate, field, 0)
                        if isinstance(value, (int, float)):
                            normalized = min(value / 1000000.0, 1.0)
                            total_score += normalized * weight

            if total_weight > 0:
                entity.score = total_score / total_weight

        return sorted(entities, key=lambda x: x.score, reverse=True)


@processor_registry.register("l3_batch")
def create_l3_processor(config_dict: dict, pipeline: list = None) -> L3Processor:
    """Factory: creates component + processor"""
    config = L3Config(**config_dict)
    component = L3Component(config)
    return L3Processor(config, component, pipeline)
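To make the factory and call conventions above concrete, here is a minimal usage sketch. It is not part of the released package: the model id and texts are placeholders, it assumes L3Config accepts the fields this file references (model_name, use_precomputed_embeddings, cache_embeddings) without other required fields, and it assumes processor_registry.register returns the factory unchanged so it remains importable.

# Hypothetical usage sketch; not part of the package diff.
from glinker.l3.processor import create_l3_processor

processor = create_l3_processor({
    "model_name": "some-org/some-gliner-model",  # placeholder model id
    "use_precomputed_embeddings": False,         # skip the embedding fast path
    "cache_embeddings": False,                   # no L2 cache write-back
})

texts = ["Apple unveiled a new headset in Cupertino."]
# Plain-string candidates: with an empty schema, labels fall back to
# _extract_label, i.e. str(candidate), and no template is applied.
candidates = [["Apple Inc.", "apple (fruit)", "Cupertino, California"]]

output = processor(texts=texts, candidates=candidates)
for entity in output.entities[0]:
    print(entity.text, entity.label, round(entity.score, 3))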
glinker/l4/__init__.py
ADDED
glinker/l4/component.py
ADDED
@@ -0,0 +1,121 @@
from typing import List, Optional
from gliner import GLiNER
from glinker.core.base import BaseComponent
from glinker.l3.models import L3Entity
from .models import L4Config


class L4Component(BaseComponent[L4Config]):
    """GLiNER-based reranking component (uni-encoder only, no precomputed embeddings)"""

    def _setup(self):
        """Initialize GLiNER model"""
        self.model = GLiNER.from_pretrained(
            self.config.model_name,
            token=self.config.token,
            max_length=self.config.max_length
        )
        self.model.to(self.config.device)

    def get_available_methods(self) -> List[str]:
        return [
            "predict_entities",
            "predict_entities_chunked",
            "filter_by_score",
            "sort_by_position",
            "deduplicate_entities"
        ]

    def predict_entities(
        self,
        text: str,
        labels: List[str],
        input_spans: List[List[dict]] = None
    ) -> List[L3Entity]:
        """Predict entities using GLiNER for a single label set.

        Args:
            text: Input text
            labels: List of label strings
            input_spans: Optional list of span dicts with 'start' and 'end' keys
        """
        if not labels:
            return []

        kwargs = dict(
            threshold=self.config.threshold,
            flat_ner=self.config.flat_ner,
            multi_label=self.config.multi_label,
            return_class_probs=True
        )
        if input_spans is not None:
            kwargs["input_spans"] = input_spans

        entities = self.model.predict_entities(text, labels, **kwargs)

        return [
            L3Entity(
                text=e["text"],
                label=e["label"],
                start=e["start"],
                end=e["end"],
                score=e["score"],
                class_probs=e.get("class_probs")
            )
            for e in entities
        ]

    def predict_entities_chunked(
        self,
        text: str,
        labels: List[str],
        max_labels: int,
        input_spans: List[List[dict]] = None
    ) -> List[L3Entity]:
        """Predict entities with candidate chunking.

        Splits labels into chunks of max_labels, runs inference on each chunk,
        and merges results.

        Args:
            text: Input text
            labels: Full list of candidate label strings
            max_labels: Maximum labels per inference call
            input_spans: Optional span constraints from L1 entities
        """
        if not labels:
            return []

        if len(labels) <= max_labels:
            return self.predict_entities(text, labels, input_spans=input_spans)

        # Split labels into chunks
        chunks = [
            labels[i:i + max_labels]
            for i in range(0, len(labels), max_labels)
        ]

        all_entities = []
        for chunk in chunks:
            entities = self.predict_entities(text, chunk, input_spans=input_spans)
            all_entities.extend(entities)

        return all_entities

    def filter_by_score(self, entities: List[L3Entity], threshold: float = None) -> List[L3Entity]:
        """Filter entities by confidence score"""
        threshold = threshold if threshold is not None else self.config.threshold
        return [e for e in entities if e.score >= threshold]

    def sort_by_position(self, entities: List[L3Entity]) -> List[L3Entity]:
        """Sort entities by position in text"""
        return sorted(entities, key=lambda e: e.start)

    def deduplicate_entities(self, entities: List[L3Entity]) -> List[L3Entity]:
        """Remove duplicate entities, keeping the highest-scoring one per span"""
        best = {}
        for entity in entities:
            key = (entity.text, entity.start, entity.end)
            if key not in best or entity.score > best[key].score:
                best[key] = entity
        return list(best.values())
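The chunk-and-merge strategy above has one subtlety worth spelling out: flat_ner resolves overlapping spans only within a single inference call, so the same span can surface once per chunk, and deduplicate_entities is what collapses those cross-chunk duplicates. Below is a self-contained sketch of that interaction with the GLiNER call stubbed out; the stub, its scores, and the Entity stand-in are invented purely for illustration.

# Standalone illustration of chunking + span-level dedup; the model call
# is a stub so the control flow runs without loading GLiNER.
from dataclasses import dataclass

@dataclass
class Entity:  # stand-in for L3Entity
    text: str
    label: str
    start: int
    end: int
    score: float

def predict_stub(text, labels):
    # Pretend the model tags "Paris" with every label at varying confidence.
    return [Entity("Paris", lab, 0, 5, 0.5 + 0.1 * i) for i, lab in enumerate(labels)]

def chunked(labels, max_labels):
    return [labels[i:i + max_labels] for i in range(0, len(labels), max_labels)]

labels = ["city", "capital", "person", "river", "country"]
all_entities = []
for chunk in chunked(labels, max_labels=2):  # 5 labels -> chunks of 2, 2, 1
    all_entities.extend(predict_stub("Paris is lovely.", chunk))

# The same span appears once per chunk; keep the highest-scoring label per span.
best = {}
for e in all_entities:
    key = (e.text, e.start, e.end)
    if key not in best or e.score > best[key].score:
        best[key] = e
print([(e.label, e.score) for e in best.values()])  # one entity survives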
glinker/l4/models.py
ADDED
@@ -0,0 +1,21 @@
from pydantic import Field
from typing import List, Any, Optional
from glinker.core.base import BaseConfig, BaseInput, BaseOutput


class L4Config(BaseConfig):
    model_name: str = Field(...)
    token: str = Field(None)
    device: str = Field("cpu")
    threshold: float = Field(0.5)
    flat_ner: bool = Field(True)
    multi_label: bool = Field(False)
    max_labels: int = Field(
        20,
        description="Maximum number of candidate labels per inference call. "
                    "When candidates exceed this, they are split into chunks."
    )
    max_length: int = Field(
        None,
        description="Maximum sequence length for tokenization. Passed to GLiNER.from_pretrained."
    )
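As a reference point, the model above can be instantiated as shown below; the model id is a placeholder, and any field not passed falls back to the declared defaults (assuming BaseConfig adds no further required fields).

# Hypothetical config; the model id is a placeholder.
from glinker.l4.models import L4Config

config = L4Config(
    model_name="some-org/some-gliner-model",
    device="cuda",
    threshold=0.3,
    max_labels=50,   # allow 50 candidate labels per inference call
)
print(config.flat_ner, config.multi_label)  # -> True False (defaults)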
glinker/l4/processor.py
ADDED
@@ -0,0 +1,156 @@
from typing import Any, List, Optional
from glinker.core.base import BaseProcessor
from glinker.core.registry import processor_registry
from glinker.l3.models import L3Input, L3Output, L3Entity
from .models import L4Config
from .component import L4Component


class L4Processor(BaseProcessor[L4Config, L3Input, L3Output]):
    """GLiNER reranking processor with candidate chunking"""

    def __init__(
        self,
        config: L4Config,
        component: L4Component,
        pipeline: list[tuple[str, dict[str, Any]]] = None
    ):
        super().__init__(config, component, pipeline)
        self._validate_pipeline()
        self.schema = {}

    def _default_pipeline(self) -> list[tuple[str, dict[str, Any]]]:
        return [
            ("predict_entities_chunked", {}),
            ("deduplicate_entities", {}),
            ("filter_by_score", {}),
            ("sort_by_position", {})
        ]

    @staticmethod
    def _build_input_spans(l1_entities_for_text: List[Any]) -> List[List[dict]]:
        """Convert L1 entities to GLiNER input_spans format."""
        spans = [{"start": e.start, "end": e.end} for e in l1_entities_for_text]
        return [spans]

    def __call__(
        self,
        texts: List[str] = None,
        candidates: List[List[Any]] = None,
        l1_entities: List[List[Any]] = None,
        input_data: L3Input = None
    ) -> L3Output:
        """Process texts with candidate labels using chunked GLiNER inference.

        Args:
            texts: List of input texts
            candidates: List of candidate lists per text (from L2)
            l1_entities: Optional L1 entities per text, used to build input_spans
            input_data: Alternative L3Input object
        """
        if texts is not None and candidates is not None:
            texts_to_process = texts
            candidates_to_process = candidates
        elif input_data is not None:
            texts_to_process = input_data.texts
            candidates_to_process = input_data.labels
        else:
            raise ValueError("Either 'texts'+'candidates' or 'input_data' must be provided")

        all_entities = []
        max_labels = self.config.max_labels

        # Detect shared candidates (all texts use the same list)
        shared = (
            len(candidates_to_process) > 1
            and all(c is candidates_to_process[0] for c in candidates_to_process[1:])
        )

        shared_labels = None
        if shared:
            ref_candidates = candidates_to_process[0]
            if self.schema:
                shared_labels, _ = self._create_gliner_labels_with_mapping(ref_candidates)
            else:
                shared_labels = [self._extract_label(c) for c in ref_candidates]

        for idx, (text, text_candidates) in enumerate(zip(texts_to_process, candidates_to_process)):
            # Build input_spans from L1 entities if available
            input_spans = None
            if l1_entities is not None and idx < len(l1_entities):
                text_l1 = l1_entities[idx]
                if text_l1:
                    input_spans = self._build_input_spans(text_l1)

            if shared:
                labels = shared_labels
            else:
                if self.schema:
                    labels, _ = self._create_gliner_labels_with_mapping(text_candidates)
                else:
                    labels = [self._extract_label(c) for c in text_candidates]

            # Run chunked prediction
            entities = self.component.predict_entities_chunked(
                text, labels, max_labels, input_spans=input_spans
            )

            # Apply remaining pipeline steps (deduplicate, filter, sort)
            for method_name, kwargs in self.pipeline[1:]:
                method = getattr(self.component, method_name)
                entities = method(entities, **kwargs)

            all_entities.append(entities)

        return L3Output(entities=all_entities)

    def _extract_label(self, candidate: Any) -> str:
        """Extract label from candidate"""
        if hasattr(candidate, 'label'):
            return candidate.label
        return str(candidate)

    def _create_gliner_labels_with_mapping(self, candidates: List[Any]) -> tuple:
        """Create GLiNER labels using schema template and return label->candidate mapping."""
        template = self.schema.get('template', '{label}')
        labels = []
        label_to_candidate = {}
        seen = set()

        for candidate in candidates:
            try:
                if hasattr(candidate, 'model_dump'):
                    cand_dict = candidate.model_dump()
                elif hasattr(candidate, 'dict'):
                    cand_dict = candidate.dict()
                elif isinstance(candidate, dict):
                    cand_dict = candidate
                else:
                    label = str(candidate)
                    if label.lower() not in seen:
                        labels.append(label)
                        seen.add(label.lower())
                    continue

                label = template.format(**cand_dict)
                label_lower = label.lower()
                if label_lower not in seen:
                    labels.append(label)
                    label_to_candidate[label] = candidate
                    seen.add(label_lower)
            except (KeyError, AttributeError):
                if hasattr(candidate, 'label'):
                    if candidate.label.lower() not in seen:
                        labels.append(candidate.label)
                        label_to_candidate[candidate.label] = candidate
                        seen.add(candidate.label.lower())

        return labels, label_to_candidate


@processor_registry.register("l4_reranker")
def create_l4_processor(config_dict: dict, pipeline: list = None) -> L4Processor:
    """Factory: creates component + processor"""
    config = L4Config(**config_dict)
    component = L4Component(config)
    return L4Processor(config, component, pipeline)
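A closing sketch of how this reranker might be driven with a custom pipeline. Note a quirk of __call__ above: the first pipeline entry is never dispatched dynamically, since __call__ always runs predict_entities_chunked itself and then applies self.pipeline[1:], so a custom pipeline only changes the post-processing steps. The config values below are placeholders, and the example assumes no other L4Config fields are required.

# Hypothetical sketch; not part of the released package.
from glinker.l4.processor import create_l4_processor

processor = create_l4_processor(
    {"model_name": "some-org/some-gliner-model", "max_labels": 20},  # placeholders
    pipeline=[
        ("predict_entities_chunked", {}),  # slot consumed by __call__ itself
        ("deduplicate_entities", {}),
        ("sort_by_position", {}),          # keep every score; no filter_by_score
    ],
)

texts = ["Berlin hosted the climate summit."]
# 45 candidates with max_labels=20 -> three inference chunks per text.
candidates = [[f"candidate {i}" for i in range(45)]]
output = processor(texts=texts, candidates=candidates)
print(len(output.entities[0]))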
glinker/py.typed
ADDED
@@ -0,0 +1 @@
# Marker file for PEP 561