OntoLearner 1.4.6-py3-none-any.whl → 1.4.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ontolearner/VERSION +1 -1
- ontolearner/base/learner.py +20 -14
- ontolearner/learner/__init__.py +1 -1
- ontolearner/learner/label_mapper.py +1 -1
- ontolearner/learner/llm.py +73 -3
- ontolearner/learner/retriever.py +24 -3
- ontolearner/learner/taxonomy_discovery/__init__.py +18 -0
- ontolearner/learner/taxonomy_discovery/alexbek.py +500 -0
- ontolearner/learner/taxonomy_discovery/rwthdbis.py +1082 -0
- ontolearner/learner/taxonomy_discovery/sbunlp.py +402 -0
- ontolearner/learner/taxonomy_discovery/skhnlp.py +1138 -0
- ontolearner/learner/term_typing/__init__.py +17 -0
- ontolearner/learner/term_typing/alexbek.py +1262 -0
- ontolearner/learner/term_typing/rwthdbis.py +379 -0
- ontolearner/learner/term_typing/sbunlp.py +478 -0
- ontolearner/learner/text2onto/__init__.py +16 -0
- ontolearner/learner/text2onto/alexbek.py +1219 -0
- ontolearner/learner/text2onto/sbunlp.py +598 -0
- {ontolearner-1.4.6.dist-info → ontolearner-1.4.8.dist-info}/METADATA +5 -1
- {ontolearner-1.4.6.dist-info → ontolearner-1.4.8.dist-info}/RECORD +22 -10
- {ontolearner-1.4.6.dist-info → ontolearner-1.4.8.dist-info}/WHEEL +0 -0
- {ontolearner-1.4.6.dist-info → ontolearner-1.4.8.dist-info}/licenses/LICENSE +0 -0
ontolearner/learner/term_typing/alexbek.py
@@ -0,0 +1,1262 @@
# Copyright (c) 2025 SciKnowOrg
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/MIT
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Learners for supervised and retrieval-augmented *term typing*.

This module implements two learners:

- **AlexbekRFLearner** (retriever/classifier):
  Encodes terms with a Hugging Face encoder, optionally augments with simple
  graph features, and trains a One-vs-Rest RandomForest for multi-label typing.

- **AlexbekRAGLearner** (retrieval-augmented generation):
  Builds an in-memory example index with sentence embeddings, retrieves
  nearest examples for each query term, then prompts an instruction-tuned
  causal LLM to produce types, parsing the JSON response.

Both learners conform to the `AutoLearner` / `AutoRetriever` APIs used in
the outer pipeline.
"""

import gc
import json
import re
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import networkx as nx
from tqdm import tqdm
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier

from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer

from ...base import AutoLearner, AutoRetriever

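# --- Annotation added for this diff review; not part of the packaged file. ---
# A minimal sketch of the normalized training row both learners consume (see
# `_as_term_types_dicts` below); `_EXAMPLE_ROW` and its values are hypothetical,
# chosen only to illustrate the expected shape.
_EXAMPLE_ROW: Dict[str, Any] = {
    "term": "convolutional neural network",
    "types": ["Algorithm", "Model"],
    "RAG": [{"term": "recurrent neural network", "types": ["Model"]}],
}
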
class AlexbekRFLearner(AutoRetriever):
    """
    Embedding-based multi-label classifier for *term typing*.

    Pipeline
        1) Load a Hugging Face encoder (tokenizer + model).
        2) Encode input terms into sentence embeddings.
        3) Optionally augment with simple graph (co-occurrence) features.
        4) Train a One-vs-Rest RandomForest on the concatenated features.
        5) Predict multi-label types with a probability threshold (fallback to top-1).

    Implements the `AutoRetriever` interface used by the outer pipeline.
    """

    def __init__(
        self,
        device: str = "cpu",
        batch_size: int = 16,
        max_length: int = 256,
        threshold: float = 0.30,
        use_graph_features: bool = True,
        rf_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Configure the RF-based multi-label learner.

        Parameters
            device:
                Torch device spec ('cpu' or 'cuda').
            batch_size:
                Encoding mini-batch size for the transformer.
            max_length:
                Maximum input token length for the encoder tokenizer.
            threshold:
                Per-label probability threshold at prediction time.
            use_graph_features:
                If True, add simple graph features to embeddings.
            rf_kwargs:
                Optional RandomForest hyperparameters dictionary.

        """
        # Runtime / inference settings
        self.device = torch.device(device)
        self.batch_size = batch_size
        self.max_length = max_length
        self.threshold = threshold  # probability cutoff for selecting labels
        self.use_graph_features = use_graph_features

        # RandomForest hyperparameters (with sensible defaults)
        self.rf_kwargs = rf_kwargs or dict(
            n_estimators=200, max_depth=20, class_weight="balanced", random_state=42
        )

        # Filled during load/fit
        self.model_name: Optional[str] = None
        self.tokenizer: Optional[AutoTokenizer] = None
        self.embedding_model: Optional[AutoModel] = None

        # Label processing / classifier / optional graph
        self.label_binarizer = MultiLabelBinarizer()
        self.ovr_random_forest: Optional[OneVsRestClassifier] = None
        self.term_graph: Optional[nx.Graph] = None

    def load(self, model_id: str, **_: Any) -> None:
        """Load a Hugging Face encoder by model id (tokenizer + base model).

        Parameters
            model_id:
                HF model identifier or local path for an encoder backbone.

        Side Effects
            - Sets `self.model_name`, `self.tokenizer`, `self.embedding_model`.
            - Puts the model in eval mode and moves it to `self.device`.
        """
        self.model_name = model_id
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.embedding_model = AutoModel.from_pretrained(model_id)
        self.embedding_model.eval().to(self.device)

    def fit(self, data: Any, task: str, ontologizer: bool = True, **_: Any) -> None:
        """Train the One-vs-Rest RandomForest on term embeddings (+ optional graph features).

        Parameters
            data:
                Training payload; supported formats are routed via `_as_term_types_dicts`.
                Each example must contain at least `{"term": str, "types": List[str]}`.
            task:
                Must be `'term-typing'`.
            ontologizer:
                Unused here; accepted for API compatibility.
            **_:
                Ignored extra arguments.

        Raises
            ValueError
                If `task` is not `'term-typing'` or if no valid examples are found.
        """
        if task != "term-typing":
            raise ValueError(
                "OntologyTypeRFClassifier supports only task='term-typing'."
            )

        # Normalize incoming training data into a list of dicts: {term, types, RAG}
        training_rows = self._as_term_types_dicts(data)
        if not training_rows:
            raise ValueError(
                "No valid training examples found (need 'term' and 'types')."
            )

        # Split out terms and raw labels
        training_terms: List[str] = [row["term"] for row in training_rows]
        raw_label_lists: List[List[str]] = [row["types"] for row in training_rows]

        # Fit label binarizer to learn label space/order
        self.label_binarizer.fit(raw_label_lists)

        # Encode terms to sentence embeddings
        term_embeddings_train = self._encode(training_terms)

        # Optionally build a light-weight co-occurrence graph and extract features
        if self.use_graph_features:
            self.term_graph = self._create_term_graph(training_rows)
            graph_features_train = self._extract_graph_features(
                self.term_graph, training_terms
            )
            X_train = np.hstack([term_embeddings_train, graph_features_train])
        else:
            self.term_graph = None
            X_train = term_embeddings_train

        # Multi-label targets (multi-hot)
        Y_train = self.label_binarizer.transform(raw_label_lists)

        # One-vs-Rest RandomForest (one binary RF per label)
        self.ovr_random_forest = OneVsRestClassifier(
            RandomForestClassifier(**self.rf_kwargs)
        )
        self.ovr_random_forest.fit(X_train, Y_train)

    def predict(
        self, data: Any, task: str, ontologizer: bool = True, **_: Any
    ) -> List[Dict[str, Any]]:
        """Predict multi-label types for input terms.

        Parameters
            data:
                Evaluation payload; formats normalized by `_as_predict_terms_ids`.
            task:
                Must be `'term-typing'`.
            ontologizer:
                Unused here; accepted for API compatibility.
            **_:
                Ignored extra arguments.

        Returns
            List[Dict[str, Any]]
                A list of dictionaries with keys:
                - `id`: Original example id (if provided).
                - `term`: Input term string.
                - `types`: List of predicted label strings (selected by threshold or top-1).

        Raises
            ValueError
                If `task` is not `'term-typing'`.
            RuntimeError
                If `load()` and `fit()` have not been called.
        """
        if task != "term-typing":
            raise ValueError(
                "OntologyTypeRFClassifier supports only task='term-typing'."
            )
        if (
            self.ovr_random_forest is None
            or self.tokenizer is None
            or self.embedding_model is None
        ):
            raise RuntimeError("Call load() and fit() before predict().")

        # Normalize prediction input into parallel lists of terms and example ids
        test_terms, example_ids = self._as_predict_terms_ids(data)

        # Encode terms
        term_embeddings_test = self._encode(test_terms)

        # Match feature layout used during training
        if self.use_graph_features and self.term_graph is not None:
            graph_features_test = self._extract_graph_features(
                self.term_graph, test_terms
            )
            X_test = np.hstack([term_embeddings_test, graph_features_test])
        else:
            X_test = term_embeddings_test

        # Probabilities per label (shape: [n_samples, n_labels])
        probability_matrix = self.ovr_random_forest.predict_proba(X_test)

        predictions: List[Dict[str, Any]] = []
        label_names = self.label_binarizer.classes_
        threshold = float(self.threshold)

        # Select labels above threshold; fallback to argmax if none exceed it
        for row_index, label_probabilities in enumerate(probability_matrix):
            selected_label_indices = np.where(label_probabilities > threshold)[0]
            if len(selected_label_indices) == 0:
                selected_label_indices = [int(np.argmax(label_probabilities))]

            predicted_types = [
                label_names[label_idx] for label_idx in selected_label_indices
            ]

            predictions.append(
                {
                    "id": example_ids[row_index],
                    "term": test_terms[row_index],
                    "types": predicted_types,
                }
            )
        return predictions

    def tasks_ground_truth_former(self, data: Any, task: str) -> List[Dict[str, Any]]:
        """Normalize ground-truth into a list of {id, term, types} dicts for evaluation.

        Parameters
            data:
                Ground-truth payload; supported formats include objects exposing
                `.term_typings`, a list of dicts, or a list of tuples/lists.
            task:
                Must be `'term-typing'`.

        Returns
            List[Dict[str, Any]]
                A list of dictionaries with keys `id`, `term`, `types` (list of str).

        Raises
            ValueError
                If `task` is not `'term-typing'`.
        """
        if task != "term-typing":
            raise ValueError(
                "OntologyTypeRFClassifier supports only task='term-typing'."
            )
        return self._as_gold_id_term_types(data)

    def _encode(self, texts: List[str]) -> np.ndarray:
        """Encode a list of strings into L2-normalized sentence embeddings.

        Parameters
            texts:
                List of input texts/terms.

        Returns
            np.ndarray
                Array of shape `(len(texts), hidden_size)` with L2-normalized
                embeddings. If `texts` is empty, returns a `(0, hidden_size)` array.
        """
        assert self.tokenizer is not None and self.embedding_model is not None, (
            "Call load(model_id) first."
        )

        if not texts:
            hidden_size = getattr(
                getattr(self.embedding_model, "config", None), "hidden_size", 768
            )
            return np.zeros((0, hidden_size), dtype=np.float32)

        batch_embeddings: List[torch.Tensor] = []

        for start_idx in tqdm(range(0, len(texts), self.batch_size), desc="Embedding"):
            end_idx = start_idx + self.batch_size
            batch_texts = texts[start_idx:end_idx]

            # Tokenize and move to device
            tokenized_batch = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            ).to(self.device)

            # Forward pass without gradients
            with torch.no_grad():
                model_output = self.embedding_model(**tokenized_batch)

            # Prefer dedicated pooler if provided; otherwise pool by last valid token
            if (
                hasattr(model_output, "pooler_output")
                and model_output.pooler_output is not None
            ):
                sentence_embeddings = model_output.pooler_output
            else:
                sentence_embeddings = self._last_token_pool(
                    model_output.last_hidden_state,
                    tokenized_batch["attention_mask"],
                )

            # L2-normalize embeddings for stability
            sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

            # Detach, move to CPU, collect
            batch_embeddings.append(sentence_embeddings.detach().cpu())

            # Best-effort memory cleanup (especially useful on CUDA)
            del tokenized_batch, model_output, sentence_embeddings
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
            gc.collect()

        # Concatenate all batches and convert to NumPy
        return torch.cat(batch_embeddings, dim=0).numpy()

    def _last_token_pool(
        self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Select the last *non-padding* token embedding for each sequence.

        Parameters
            last_hidden_states:
                Tensor of shape `(batch, seq_len, hidden)`.
            attention_mask:
                Tensor of shape `(batch, seq_len)` with 1 for real tokens.

        Returns
            torch.Tensor
                Tensor of shape `(batch, hidden)` with per-sequence pooled embeddings.
        """
        last_valid_token_idx = attention_mask.sum(dim=1) - 1  # (batch,)
        batch_row_idx = torch.arange(
            last_hidden_states.size(0), device=last_hidden_states.device
        )
        return last_hidden_states[batch_row_idx, last_valid_token_idx]

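    # --- Annotation added for this diff review; not part of the packaged file. ---
    # A worked sketch of `_last_token_pool` on a toy right-padded batch
    # (hypothetical values, shapes as in the docstring above):
    #
    #   attention_mask = [[1, 1, 1, 0],   # sequence 0: 3 real tokens
    #                     [1, 1, 0, 0]]   # sequence 1: 2 real tokens
    #   attention_mask.sum(dim=1) - 1  ->  [2, 1]
    #
    # Row 0 therefore takes `last_hidden_states[0, 2]` and row 1 takes
    # `last_hidden_states[1, 1]`: the embedding of each sequence's final
    # non-padding token.
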
    def _create_term_graph(self, training_rows: List[Dict[str, Any]]) -> nx.Graph:
        """Create a simple undirected co-occurrence graph from training rows.

        Graph Structure
            Nodes
                Terms (node attribute `'types'` is stored per term).
            Edges
                Between a term and each neighbor from its optional RAG list.
                Edge weight = number of shared types (or 0.1 if none shared).

        Parameters
            training_rows:
                Normalized rows with keys: `'term'`, `'types'`, optional `'RAG'`.

        Returns
            networkx.Graph
                The constructed undirected graph.
        """
        graph = nx.Graph()

        for row in training_rows:
            term = row["term"]
            term_types = row.get("types", [])
            graph.add_node(term, types=term_types)

            # RAG may be a list of neighbor dicts like {"term": ..., "types": [...]}
            for neighbor in row.get("RAG", []) or []:
                neighbor_term = neighbor.get("term")
                neighbor_types = neighbor.get("types", [])

                # Shared-type-based edge weight (weak edge if no overlap)
                shared_types = set(term_types).intersection(set(neighbor_types))
                edge_weight = float(len(shared_types)) if shared_types else 0.1

                graph.add_edge(term, neighbor_term, weight=edge_weight)

        return graph

    def _extract_graph_features(
        self, term_graph: nx.Graph, terms: List[str]
    ) -> np.ndarray:
        """Compute simple per-term graph features.

        Feature Vector
            For each term we compute a 4-dim vector:
            `[degree, clustering_coefficient, degree_centrality, pagerank_score]`

        Parameters
            term_graph:
                Graph built over training terms.
            terms:
                List of term strings to extract features for.

        Returns
            np.ndarray
                Array of shape `(len(terms), 4)` (dtype float32).
        """
        if len(term_graph):
            degree_centrality = nx.degree_centrality(term_graph)
            pagerank_scores = nx.pagerank(term_graph)
        else:
            degree_centrality, pagerank_scores = {}, {}

        feature_rows: List[List[float]] = []
        for term in terms:
            if term in term_graph:
                feature_rows.append(
                    [
                        float(term_graph.degree(term)),
                        float(nx.clustering(term_graph, term)),
                        float(degree_centrality.get(term, 0.0)),
                        float(pagerank_scores.get(term, 0.0)),
                    ]
                )
            else:
                feature_rows.append([0.0, 0.0, 0.0, 0.0])

        return np.asarray(feature_rows, dtype=np.float32)

    def _as_term_types_dicts(self, data: Any) -> List[Dict[str, Any]]:
        """Normalize diverse training data formats to a list of dicts: {term, types, RAG}.

        Supported Inputs
            - Object with attribute `.term_typings` (iterable of items exposing
              `.term`, `.types`, optional `.RAG`).
            - List of dicts with keys `term`, `types`, optional `RAG`.
            - List/tuple of `(term, types[, RAG])`.

        Parameters
            data:
                Training payload.

        Returns
            List[Dict[str, Any]]
                Normalized dictionaries ready for training.

        Raises
            ValueError
                If `data` is neither a list/tuple nor exposes `.term_typings`.
        """
        normalized_rows: List[Dict[str, Any]] = []

        # Case 1: object with attribute `.term_typings`
        term_typings_attr = getattr(data, "term_typings", None)
        if term_typings_attr is not None:
            for item in term_typings_attr:
                term_text = getattr(item, "term", None)
                type_list = getattr(item, "types", None)
                rag_neighbors = getattr(item, "RAG", None)
                if term_text is None or type_list is None:
                    continue
                if not isinstance(type_list, list):
                    type_list = [type_list]
                normalized_rows.append(
                    {
                        "term": str(term_text),
                        "types": [str(x) for x in type_list],
                        "RAG": rag_neighbors,
                    }
                )
            return normalized_rows

        # Otherwise: must be a list/tuple-like container
        if not isinstance(data, (list, tuple)):
            raise ValueError(
                "Training data must be a list/tuple or expose .term_typings"
            )

        if not data:
            return normalized_rows

        # Case 2: list of dicts
        if isinstance(data[0], dict):
            for row in data:
                term_text = row.get("term")
                type_list = row.get("types")
                rag_neighbors = row.get("RAG")
                if term_text is None or type_list is None:
                    continue
                if not isinstance(type_list, list):
                    type_list = [type_list]
                normalized_rows.append(
                    {
                        "term": str(term_text),
                        "types": [str(x) for x in type_list],
                        "RAG": rag_neighbors,
                    }
                )
            return normalized_rows

        # Case 3: list of tuples/lists: (term, types[, RAG])
        for item in data:
            if not isinstance(item, (list, tuple)) or len(item) < 2:
                continue
            term_text, type_list = item[0], item[1]
            rag_neighbors = item[2] if len(item) > 2 else None
            if term_text is None or type_list is None:
                continue
            if not isinstance(type_list, list):
                type_list = [type_list]
            normalized_rows.append(
                {
                    "term": str(term_text),
                    "types": [str(x) for x in type_list],
                    "RAG": rag_neighbors,
                }
            )

        return normalized_rows

    def _as_predict_terms_ids(self, data: Any) -> Tuple[List[str], List[Any]]:
        """Normalize prediction input into parallel lists: (terms, ids).

        Supported Inputs
            - Object with `.term_typings`.
            - List of dicts with `term` and optional `id`.
            - List of tuples/lists `(term, id[, ...])`.
            - List of plain term strings.

        Parameters
            data:
                Evaluation payload.

        Returns
            Tuple[List[str], List[Any]]
                `(terms, example_ids)` lists aligned by index.

        Raises
            ValueError
                If the input format is unsupported.
        """
        terms: List[str] = []
        example_ids: List[Any] = []

        # Case 1: object with attribute `.term_typings`
        term_typings_attr = getattr(data, "term_typings", None)
        if term_typings_attr is not None:
            for idx, item in enumerate(term_typings_attr):
                terms.append(str(getattr(item, "term", "")))
                example_ids.append(getattr(item, "id", getattr(item, "ID", idx)))
            return terms, example_ids

        # Case 2: list/tuple container
        if isinstance(data, (list, tuple)) and data:
            first_element = data[0]

            # 2a) list of dicts
            if isinstance(first_element, dict):
                for i, row in enumerate(data):
                    terms.append(str(row.get("term", "")))
                    example_ids.append(row.get("id", row.get("ID", i)))
                return terms, example_ids

            # 2b) list of tuples/lists: (term, id[, ...])
            if isinstance(first_element, (list, tuple)):
                for i, tuple_row in enumerate(data):
                    if not tuple_row:
                        continue
                    terms.append(str(tuple_row[0]))
                    example_ids.append(tuple_row[1] if len(tuple_row) > 1 else i)
                return terms, example_ids

            # 2c) list of strings (terms only)
            if isinstance(first_element, str):
                terms = [str(x) for x in data]  # type: ignore[arg-type]
                example_ids = list(range(len(terms)))
                return terms, example_ids

        raise ValueError("Unsupported predict() input format.")

    def _as_gold_id_term_types(self, data: Any) -> List[Dict[str, Any]]:
        """Normalize gold labels into a list of dicts: {id, term, types}.

        Supported Inputs
            Mirrors `_as_term_types_dicts`, but ensures an `id` is set.

        Parameters
            data:
                Ground-truth payload.

        Returns
            List[Dict[str, Any]]
                `{'id': Any, 'term': str, 'types': List[str]}` entries.

        """
        gold_rows: List[Dict[str, Any]] = []

        # Case 1: object with attribute `.term_typings`
        term_typings_attr = getattr(data, "term_typings", None)
        if term_typings_attr is not None:
            for idx, item in enumerate(term_typings_attr):
                gold_id = getattr(item, "id", getattr(item, "ID", idx))
                term_text = str(getattr(item, "term", ""))
                type_list = getattr(item, "types", [])
                if not isinstance(type_list, list):
                    type_list = [type_list]
                gold_rows.append(
                    {
                        "id": gold_id,
                        "term": term_text,
                        "types": [str(t) for t in type_list],
                    }
                )
            return gold_rows

        # Case 2: list/tuple container
        if isinstance(data, (list, tuple)) and data:
            first_element = data[0]

            # 2a) list of dicts
            if isinstance(first_element, dict):
                for i, row in enumerate(data):
                    gold_id = row.get("id", row.get("ID", i))
                    term_text = str(row.get("term", ""))
                    type_list = row.get("types", [])
                    if not isinstance(type_list, list):
                        type_list = [type_list]
                    gold_rows.append(
                        {
                            "id": gold_id,
                            "term": term_text,
                            "types": [str(t) for t in type_list],
                        }
                    )
                return gold_rows

            # 2b) list of tuples/lists: (term, types[, id])
            if isinstance(first_element, (list, tuple)):
                for i, tuple_row in enumerate(data):
                    if not tuple_row or len(tuple_row) < 2:
                        continue
                    term_text = str(tuple_row[0])
                    type_list = tuple_row[1]
                    gold_id = tuple_row[2] if len(tuple_row) > 2 else i
                    if not isinstance(type_list, list):
                        type_list = [type_list]
                    gold_rows.append(
                        {
                            "id": gold_id,
                            "term": term_text,
                            "types": [str(t) for t in type_list],
                        }
                    )
                return gold_rows

        raise ValueError(
            "Unsupported ground-truth input format for tasks_ground_truth_former()."
        )


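# --- Annotation added for this diff review; not part of the packaged file. ---
# A minimal usage sketch for AlexbekRFLearner, assuming list-of-dicts training
# rows; the encoder id and example data are placeholders, not package defaults.
#
#   learner = AlexbekRFLearner(device="cpu", threshold=0.30)
#   learner.load("sentence-transformers/all-MiniLM-L6-v2")
#   learner.fit(
#       [{"term": "transformer", "types": ["Model"]},
#        {"term": "gradient descent", "types": ["Algorithm"]}],
#       task="term-typing",
#   )
#   learner.predict(["backpropagation"], task="term-typing")
#   # -> [{"id": 0, "term": "backpropagation", "types": [...]}]
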
class AlexbekRAGLearner(AutoLearner):
    """Retrieval-Augmented Term Typing learner (single task: term-typing).

    Flow
        1) `fit`: collect (term -> [types]) examples, build an in-memory index
           using a sentence-embedding model.
        2) `predict`: for each new term, retrieve top-k similar examples, compose a
           structured prompt, query an instruction-tuned causal LLM, and parse types.

    Returns
        List[Dict[str, Any]]
            `{"term": str, "types": List[str], "id": Optional[str]}` rows.
    """

    def __init__(
        self,
        llm_model_id: str = "Qwen/Qwen2.5-0.5B-Instruct",
        retriever_model_id: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: str = "auto",  # "auto" | "cuda" | "cpu"
        token: str = "",  # HF token if needed
        top_k: int = 3,
        max_new_tokens: int = 256,
        gen_batch_size: int = 4,  # generation batch size
        enc_batch_size: int = 64,  # embedding batch size
        **kwargs: Any,  # absorb extra pipeline-style args
    ) -> None:
        """Configure the RAG learner.

        Parameters
            llm_model_id:
                HF model id/path for the instruction-tuned causal LLM.
            retriever_model_id:
                Sentence-embedding model id for retrieval.
            device:
                Device policy ('auto'|'cuda'|'cpu') for the LLM.
            token:
                Optional HF token for gated models.
            top_k:
                Number of nearest examples to retrieve per query term.
            max_new_tokens:
                Decoding budget for the LLM.
            gen_batch_size:
                Number of prompts per generation batch.
            enc_batch_size:
                Number of texts per embedding batch.
            **kwargs:
                Extra configuration captured for downstream use.
        """
        super().__init__()

        # Consolidated configuration for simple serialization
        self.cfg: Dict[str, Any] = {
            "llm_model_id": llm_model_id,
            "retriever_model_id": retriever_model_id,
            "device": device,
            "token": token,
            "top_k": int(top_k),
            "max_new_tokens": int(max_new_tokens),
            "gen_batch_size": int(gen_batch_size),
            "enc_batch_size": int(enc_batch_size),
        }
        self.extra_cfg: Dict[str, Any] = dict(kwargs)

        # LLM components
        self.tokenizer: Optional[AutoTokenizer] = None
        self.generation_model: Optional[AutoModelForCausalLM] = None

        # Retriever components
        self.embedder: Optional[SentenceTransformer] = None
        self.indexed_corpus: List[str] = []  # items: "<term> || [<types>...]"
        self.corpus_embeddings: Optional[torch.Tensor] = None

        # Training cache of (term, [types]) tuples
        self.train_term_types: List[Tuple[str, List[str]]] = []

        # Prompt templates
        self._system_prompt: str = (
            "You are an expert in ontologies and semantic term classification.\n"
            "Task: determine semantic types for the TERM using the EXAMPLES provided.\n"
            "Rules:\n"
            "1) Types must be generalizing categories from the domain ontology.\n"
            "2) Be concise. Respond ONLY in JSON using double quotes.\n"
            'Format: {"term":"...", "reasoning":"<<=100 words>>", "types":["...", "..."]}\n'
        )
        self._user_prompt_template: str = """{examples}

TERM: {term}

TASK: Determine semantic types for the given term based on the domain ontology.
Remember: types are generalizing categories, not the term itself. Respond in JSON.
"""

    def load(
        self,
        model_id: Optional[str] = None,
        retriever_id: Optional[str] = None,
        device: Optional[str] = None,
        token: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Load the LLM and the embedding retriever. Overrides constructor values if provided.

        Parameters
            model_id:
                Optional override for the LLM model id.
            retriever_id:
                Optional override for the embedding model id.
            device:
                Optional override for device selection policy.
            token:
                Optional override for HF token.
            **kwargs:
                Extra values to store in `extra_cfg`.

        """
        if model_id is not None:
            self.cfg["llm_model_id"] = model_id
        if retriever_id is not None:
            self.cfg["retriever_model_id"] = retriever_id
        if device is not None:
            self.cfg["device"] = device
        if token is not None:
            self.cfg["token"] = token
        self.extra_cfg.update(kwargs)

        # Choose device & dtype for the LLM
        cuda_available: bool = torch.cuda.is_available()
        use_cuda: bool = cuda_available and (self.cfg["device"] != "cpu")
        device_map: str = "auto" if use_cuda else "cpu"
        torch_dtype = torch.bfloat16 if use_cuda else torch.float32

        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.cfg["llm_model_id"], padding_side="left", token=self.cfg["token"]
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # LLM
        self.generation_model = AutoModelForCausalLM.from_pretrained(
            self.cfg["llm_model_id"],
            device_map=device_map,
            torch_dtype=torch_dtype,
            token=self.cfg["token"],
        )

        # Deterministic decoding defaults
        generation_cfg = self.generation_model.generation_config
        generation_cfg.do_sample = False
        generation_cfg.temperature = None
        generation_cfg.top_p = None
        generation_cfg.top_k = None
        generation_cfg.num_beams = 1

        # Retriever
        self.embedder = SentenceTransformer(
            self.cfg["retriever_model_id"], trust_remote_code=True
        )

    def fit(self, train_data: Any, task: str, ontologizer: bool = True) -> None:
        """Prepare the retrieval index from training examples.

        Parameters
            train_data:
                Training payload containing terms and their types.
            task:
                Must be `'term-typing'`; other tasks are forwarded to base.
            ontologizer:
                Unused flag for API compatibility.

        Side Effects
            - Normalizes to a list of `(term, [types])`.
            - Builds an indexable text corpus and (if embedder is loaded)
              computes embeddings for retrieval.
        """
        if task != "term-typing":
            return super().fit(train_data, task, ontologizer)

        # Normalize incoming training data -> list[(term, [types])]
        self.train_term_types = self._unpack_train(train_data)

        # Build the textual corpus to index
        self.indexed_corpus = [
            f"{term} || {json.dumps(types, ensure_ascii=False)}"
            for term, types in self.train_term_types
        ]

        # Embed the corpus if available; else fall back to zero-shot prompting
        if self.indexed_corpus and self.embedder is not None:
            self.corpus_embeddings = self._encode_texts(self.indexed_corpus)
        else:
            self.corpus_embeddings = None

    def predict(self, eval_data: Any, task: str, ontologizer: bool = True) -> Any:
        """Predict types for evaluation items; returns a list of {term, types, id?}.

        Parameters
            eval_data:
                Evaluation payload to type (terms + optional ids).
            task:
                Must be `'term-typing'`; other tasks are forwarded to base.
            ontologizer:
                Unused flag for API compatibility.

        Returns
            List[Dict[str, Any]]
                For each input term, a dictionary with keys:
                - `term`: The input term.
                - `types`: A (unique, sorted) list of predicted types.
                - `id`: Optional example id (if provided in input).
        """
        if task != "term-typing":
            return super().predict(eval_data, task, ontologizer)

        eval_terms, eval_ids = self._unpack_eval(eval_data)
        if not eval_terms:
            return []

        # Use RAG if we have an indexed corpus & embeddings; otherwise zero-shot
        rag_available = (
            self.corpus_embeddings is not None
            and self.embedder is not None
            and len(self.indexed_corpus) > 0
        )

        if rag_available:
            neighbor_docs_per_query = self._retrieve_batch(
                eval_terms, top_k=int(self.cfg["top_k"])
            )
        else:
            neighbor_docs_per_query = [[] for _ in eval_terms]

        # Compose prompts
        prompts: List[str] = []
        for term, neighbor_docs in zip(eval_terms, neighbor_docs_per_query):
            example_pairs = self._decode_examples(neighbor_docs)
            examples_block = self._format_examples(example_pairs)
            prompt_text = self._compose_prompt(examples_block, term)
            prompts.append(prompt_text)

        predicted_types_lists = self._generate_and_parse(prompts)

        # Build standardized results
        results: List[Dict[str, Any]] = []
        for term, example_id, predicted_types in zip(
            eval_terms, eval_ids, predicted_types_lists
        ):
            result_row: Dict[str, Any] = {
                "term": term,
                "types": sorted({t for t in predicted_types}),  # unique + sorted
            }
            if example_id is not None:
                result_row["id"] = example_id
            results.append(result_row)

        assert all(("term" in row and "types" in row) for row in results), (
            "predict() must return term + types"
        )
        return results

    def _unpack_train(self, data: Any) -> List[Tuple[str, List[str]]]:
        """Extract `(term, [types])` tuples from supported training payloads.

        Supported Inputs
            - `data.term_typings` (objects exposing `.term` & `.types`)
            - `list[dict]` with keys `'term'` and `'types'`
            - `list[str]` → returns empty (nothing to index)
            - other formats → empty

        Parameters
            data:
                Training payload.

        Returns
            List[Tuple[str, List[str]]]
                (term, types) tuples (types kept as strings).
        """
        term_typings = getattr(data, "term_typings", None)
        if term_typings is not None:
            parsed_pairs: List[Tuple[str, List[str]]] = []
            for item in term_typings:
                term = getattr(item, "term", None)
                types = list(getattr(item, "types", []) or [])
                if term and types:
                    parsed_pairs.append(
                        (term, [t for t in types if isinstance(t, str)])
                    )
            return parsed_pairs

        if isinstance(data, list) and data and isinstance(data[0], dict):
            parsed_pairs = []
            for row in data:
                term = row.get("term")
                types = row.get("types") or []
                if term and isinstance(types, list) and types:
                    parsed_pairs.append(
                        (term, [t for t in types if isinstance(t, str)])
                    )
            return parsed_pairs

        # If only a list of strings is provided, there's nothing to index for RAG
        if isinstance(data, (list, set, tuple)) and all(
            isinstance(x, str) for x in data
        ):
            return []

        return []

    def _unpack_eval(self, data: Any) -> Tuple[List[str], List[Optional[str]]]:
        """Extract `(terms, ids)` from supported evaluation payloads.

        Supported Inputs
            - `data.term_typings` (objects exposing `.term` & optional `.id`)
            - `list[str]`
            - `list[dict]` with `term` and optional `id`

        Parameters
            data:
                Evaluation payload.

        Returns
            Tuple[List[str], List[Optional[str]]]
                Two lists aligned by index: terms and ids (ids may contain `None`).
        """
        term_typings = getattr(data, "term_typings", None)
        if term_typings is not None:
            terms: List[str] = []
            ids: List[Optional[str]] = []
            for item in term_typings:
                terms.append(getattr(item, "term", ""))
                ids.append(getattr(item, "id", None))
            return terms, ids

        if isinstance(data, list) and data and isinstance(data[0], str):
            return list(data), [None] * len(data)

        if isinstance(data, list) and data and isinstance(data[0], dict):
            terms: List[str] = []
            ids: List[Optional[str]] = []
            for row in data:
                terms.append(row.get("term", ""))
                ids.append(row.get("id"))
            return terms, ids

        return [], []

    def _encode_texts(self, texts: List[str]) -> torch.Tensor:
        """Encode a batch of texts with the sentence-embedding model.

        Parameters
            texts:
                List of strings to embed.

        Returns
            torch.Tensor
                Tensor of shape `(len(texts), hidden_dim)`. If `texts` is empty,
                returns an empty tensor with 0 rows.
        """
        batch_size = int(self.cfg["enc_batch_size"])
        batch_embeddings: List[torch.Tensor] = []

        for batch_start in range(0, len(texts), batch_size):
            batch_texts = texts[batch_start : batch_start + batch_size]
            embeddings = self.embedder.encode(
                batch_texts, convert_to_tensor=True, show_progress_bar=False
            )
            batch_embeddings.append(embeddings)

        return (
            torch.cat(batch_embeddings, dim=0) if batch_embeddings else torch.empty(0)
        )

    def _retrieve_batch(self, queries: List[str], top_k: int) -> List[List[str]]:
        """Return for each query the top-k most similar corpus entries.

        Parameters
            queries:
                List of query terms.
            top_k:
                Number of neighbors to retrieve for each query.

        Returns
            List[List[str]]
                For each query, a list of raw corpus strings formatted as
                `"<term> || [\\"type1\\", ...]"`.
        """
        if self.corpus_embeddings is None or not self.indexed_corpus:
            return [[] for _ in queries]

        query_embeddings = self._encode_texts(queries)  # [Q, D]
        doc_embeddings = self.corpus_embeddings  # [N, D]
        if query_embeddings.shape[-1] != doc_embeddings.shape[-1]:
            raise ValueError(
                f"Embedding dim mismatch: {query_embeddings.shape[-1]} vs {doc_embeddings.shape[-1]}"
            )

        # Cosine similarity via L2-normalized dot product
        q_norm = F.normalize(query_embeddings, p=2, dim=1)
        d_norm = F.normalize(doc_embeddings, p=2, dim=1)
        cos_sim = torch.matmul(q_norm, d_norm.T)  # [Q, N]

        k = min(max(1, top_k), len(self.indexed_corpus))
        _, top_indices = torch.topk(cos_sim, k=k, dim=1)
        return [[self.indexed_corpus[j] for j in row.tolist()] for row in top_indices]

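    # --- Annotation added for this diff review; not part of the packaged file. ---
    # The cosine step in `_retrieve_batch`, traced on hypothetical 2-d vectors:
    #
    #   q = F.normalize(torch.tensor([[3.0, 4.0]]), p=2, dim=1)  # [[0.6, 0.8]]
    #   d = F.normalize(torch.tensor([[6.0, 8.0]]), p=2, dim=1)  # [[0.6, 0.8]]
    #   torch.matmul(q, d.T)                                     # [[1.0]]
    #
    # After L2 normalization the dot product equals the cosine of the angle
    # between the vectors, so `torch.topk(cos_sim, k=k, dim=1)` selects, per
    # query row, the k corpus entries pointing in the most similar direction.
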
    def _decode_examples(self, docs: List[str]) -> List[Tuple[str, List[str]]]:
        """Parse raw corpus rows ('term || [types]') into `(term, [types])` pairs.

        Parameters
            docs:
                Raw strings from the index/corpus.

        Returns
            List[Tuple[str, List[str]]]
                Parsed (term, types) pairs; malformed rows are skipped.
        """
        example_pairs: List[Tuple[str, List[str]]] = []
        for raw_row in docs:
            try:
                term_raw, types_json = raw_row.split("||", 1)
                term = term_raw.strip()
                types_list = json.loads(types_json.strip())
                if isinstance(types_list, list):
                    example_pairs.append(
                        (term, [t for t in types_list if isinstance(t, str)])
                    )
            except Exception:
                continue
        return example_pairs

    def _format_examples(self, pairs: List[Tuple[str, List[str]]]) -> str:
        """Format retrieved example pairs into a compact block for the prompt.

        Parameters
            pairs:
                Retrieved `(term, [types])` examples.

        Returns
            str
                Human-readable lines to provide *light* guidance to the LLM.
        """
        if not pairs:
            return "EXAMPLES: (none provided)"
        lines: List[str] = ["CLASSIFICATION EXAMPLES:"]
        for idx, (term, types) in enumerate(pairs, 1):
            preview_types = types[:3]  # keep context small
            lines.append(f"{idx}. Term: '{term}' → Types: {list(preview_types)}")
        lines.append("END OF EXAMPLES.")
        return "\n".join(lines)

    def _compose_prompt(self, examples_block: str, term: str) -> str:
        """Compose the final prompt from system + user blocks.

        Parameters
            examples_block:
                Text block with retrieved examples.
            term:
                The query term to classify.

        Returns
            str
                Full prompt string passed to the LLM.
        """
        user_block = self._user_prompt_template.format(
            examples=examples_block, term=term
        )
        return f"{self._system_prompt}\n\n{user_block}\n"

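    # --- Annotation added for this diff review; not part of the packaged file. ---
    # What a composed prompt looks like for one retrieved neighbor, following
    # `_format_examples` and `_compose_prompt` above (term/types hypothetical):
    #
    #   <system prompt, ending with the JSON format rule>
    #
    #   CLASSIFICATION EXAMPLES:
    #   1. Term: 'transformer' -> Types: ['Model']
    #   END OF EXAMPLES.
    #
    #   TERM: backpropagation
    #
    #   TASK: Determine semantic types for the given term based on the domain ontology.
    #   Remember: types are generalizing categories, not the term itself. Respond in JSON.
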
    def _generate_and_parse(self, prompts: List[str]) -> List[List[str]]:
        """Run generation for a batch of prompts and parse the JSON `'types'` from outputs.

        Parameters
            prompts:
                Finalized prompts for the LLM.

        Returns
            List[List[str]]
                For each prompt, a list of predicted type strings.
        """
        batch_size = int(self.cfg["gen_batch_size"])
        all_predicted_types: List[List[str]] = []

        for batch_start in range(0, len(prompts), batch_size):
            prompt_batch = prompts[batch_start : batch_start + batch_size]

            # Tokenize and move to the LLM's device
            model_device = getattr(self.generation_model, "device", None)
            encodings = self.tokenizer(
                prompt_batch, return_tensors="pt", padding=True
            ).to(model_device)
            input_token_length = encodings["input_ids"].shape[1]

            # Deterministic decoding (greedy)
            with torch.no_grad():
                generated_tokens = self.generation_model.generate(
                    **encodings,
                    do_sample=False,
                    num_beams=1,
                    temperature=None,
                    top_p=None,
                    top_k=None,
                    max_new_tokens=int(self.cfg["max_new_tokens"]),
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            # Slice off the prompt tokens and decode only newly generated tokens
            new_token_span = generated_tokens[:, input_token_length:]
            decoded_texts = [
                self.tokenizer.decode(seq, skip_special_tokens=True)
                for seq in new_token_span
            ]

            parsed_types_per_prompt = [
                self._parse_types(text) for text in decoded_texts
            ]
            all_predicted_types.extend(parsed_types_per_prompt)

        return all_predicted_types

    def _parse_types(self, text: str) -> List[str]:
        """Extract a list of type strings from LLM output.

        Parsing Strategy (in order)
            1) Strict JSON object with `"types"`.
            2) Regex-extract JSON object containing `"types"`.
            3) Regex-extract first bracketed list.
            4) Comma-split fallback.

        Parameters
            text:
                Raw LLM output to parse.

        Returns
            List[str]
                Parsed list of type strings (possibly empty if parsing fails).
        """
        try:
            obj = json.loads(text)
            if isinstance(obj, dict) and isinstance(obj.get("types"), list):
                return [t for t in obj["types"] if isinstance(t, str)]
        except Exception:
            pass

        try:
            obj_match = re.search(
                r'\{[^{}]*"types"\s*:\s*\[[^\]]*\][^{}]*\}', text, re.S
            )
            if obj_match:
                obj = json.loads(obj_match.group(0))
                types = obj.get("types", [])
                return [t for t in types if isinstance(t, str)]
        except Exception:
            pass

        try:
            list_match = re.search(r"\[([^\]]+)\]", text)
            if list_match:
                items = [
                    x.strip().strip('"').strip("'")
                    for x in list_match.group(1).split(",")
                ]
                return [t for t in items if t]
        except Exception:
            pass

        if "," in text:
            items = [x.strip().strip('"').strip("'") for x in text.split(",")]
            return [t for t in items if t]

        return []
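
# --- Annotation added for this diff review; not part of the packaged file. ---
# A minimal usage sketch for AlexbekRAGLearner with its constructor defaults;
# the training rows below are hypothetical. `load()` downloads both the 0.5B
# instruct LLM and the MiniLM retriever, so the first call is network-bound.
#
#   learner = AlexbekRAGLearner(device="cpu", top_k=2)
#   learner.load()
#   learner.fit(
#       [{"term": "transformer", "types": ["Model"]},
#        {"term": "gradient descent", "types": ["Algorithm"]}],
#       task="term-typing",
#   )
#   learner.predict(["backpropagation"], task="term-typing")
#   # -> [{"term": "backpropagation", "types": [...]}]  (ids omitted when absent)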