OntoLearner 1.4.7-py3-none-any.whl → 1.4.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ontolearner/VERSION +1 -1
- ontolearner/base/learner.py +15 -12
- ontolearner/learner/label_mapper.py +1 -1
- ontolearner/learner/retriever.py +24 -3
- ontolearner/learner/taxonomy_discovery/__init__.py +18 -0
- ontolearner/learner/taxonomy_discovery/alexbek.py +500 -0
- ontolearner/learner/taxonomy_discovery/rwthdbis.py +1082 -0
- ontolearner/learner/taxonomy_discovery/sbunlp.py +402 -0
- ontolearner/learner/taxonomy_discovery/skhnlp.py +1138 -0
- ontolearner/learner/term_typing/__init__.py +17 -0
- ontolearner/learner/term_typing/alexbek.py +1262 -0
- ontolearner/learner/term_typing/rwthdbis.py +379 -0
- ontolearner/learner/term_typing/sbunlp.py +478 -0
- ontolearner/learner/text2onto/__init__.py +16 -0
- ontolearner/learner/text2onto/alexbek.py +1219 -0
- ontolearner/learner/text2onto/sbunlp.py +598 -0
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.8.dist-info}/METADATA +4 -1
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.8.dist-info}/RECORD +20 -8
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.8.dist-info}/WHEEL +0 -0
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.8.dist-info}/licenses/LICENSE +0 -0

ontolearner/learner/term_typing/sbunlp.py
@@ -0,0 +1,478 @@
+# Copyright (c) 2025 SciKnowOrg
+#
+# Licensed under the MIT License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/MIT
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional
+import re
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from ...base import AutoLearner
+
+
+class SBUNLPZSLearner(AutoLearner):
+    """
+    Qwen-based blind term typing learner (Task B), implemented as an AutoLearner.
+
+    Lifecycle:
+      • `fit(...)` learns/records the allowed type inventory from the training payload.
+      • `load(...)` explicitly loads the tokenizer/model (pass `model_id`/`token` here).
+      • `predict(...)` prompts the model per term and returns normalized types limited
+        to the learned inventory.
+    """
+
+    def __init__(
+        self,
+        device: str = "cpu",
+        max_new_tokens: int = 64,
+        temperature: float = 0.0,
+        model_id: str = "Qwen/Qwen2.5-0.5B-Instruct",
+        token: Optional[str] = None,
+    ) -> None:
+        """
+        Configure runtime knobs. Model identity and auth are provided to `load(...)`.
+
+        Args:
+            device: Torch device policy ("cuda", "mps", or "cpu").
+            max_new_tokens: Max tokens to generate per prompt (greedy decoding).
+            temperature: Reserved for future sampling; generation is greedy here.
+            model_id: Fallback model id/path used if `load()` is called without args.
+            token: Fallback HF token used if `load()` is called without args.
+
+        Side Effects:
+            Initializes runtime configuration, instance defaults for `load()`,
+            and placeholders for `tokenizer`, `model`, and `allowed_types`.
+        """
+        super().__init__()
+        self.device = device
+        self.max_new_tokens = max_new_tokens
+        self.temperature = temperature
+
+        # Defaults that load() may use when its args are None
+        self.model_id = model_id
+        self.token = token
+
+        # Placeholders populated by load()
+        self.tokenizer: Optional[AutoTokenizer] = None
+        self.model: Optional[AutoModelForCausalLM] = None
+
+        # Learned inventory
+        self.allowed_types: List[str] = []
+
+        # Regex used to extract quoted strings from model output (e.g., "type")
+        self._quoted_re = re.compile(r'"([^"]+)"')
+
+    def load(
+        self,
+        model_id: Optional[str] = None,
+        token: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        """
+        Load tokenizer and model weights explicitly.
+
+        Argument precedence:
+          1) Use `model_id` / `token` passed to this method (if provided).
+          2) Else fall back to `self.model_id` / `self.token`.
+
+        Device & dtype:
+          • If `dtype` is None, the default is float16 on CUDA/MPS and float32 on CPU.
+          • `device_map` is `"auto"` for non-CPU devices, `"cpu"` otherwise.
+
+        Args:
+            model_id: HF model id/path to load. If None, uses `self.model_id`.
+            token: HF token if the model is gated. If None, uses `self.token`.
+            dtype: Optional torch dtype override (e.g., `torch.float16`).
+
+        Returns:
+            self
+        """
+        resolved_model_id = model_id or self.model_id
+        resolved_token = token if token is not None else self.token
+
+        # Tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            resolved_model_id, token=resolved_token
+        )
+        if self.tokenizer.pad_token is None:
+            # Prefer EOS as pad if available
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        # Device & dtype
+        if dtype is None:
+            if self.device == "cpu":
+                resolved_dtype = torch.float32
+            else:
+                # Works for CUDA and Apple MPS
+                resolved_dtype = torch.float16
+        else:
+            resolved_dtype = dtype
+
+        device_map = "auto" if self.device != "cpu" else "cpu"
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            resolved_model_id,
+            device_map=device_map,
+            torch_dtype=resolved_dtype,  # keep torch_dtype for broad Transformers compatibility
+            token=resolved_token,
+        )
+        return self
+
+    def fit(self, train_data: Any, task: str, ontologizer: bool = True):
+        """
+        Learn the allowed type inventory from the training data.
+
+        Normalization rules:
+          • If `ontologizer=True`, the framework's `tasks_data_former(..., test=False)`
+            is used to normalize `train_data`.
+          • If a container exposes `.term_typings`, types are collected from there.
+          • If the normalized data is a list of dicts with `"parent"`, unique parents
+            become the allowed types.
+          • If it's a list of strings, that unique set becomes the allowed types.
+
+        Args:
+            train_data: Training payload provided by the pipeline.
+            task: Must be `"term-typing"`.
+            ontologizer: If True, normalize via `tasks_data_former()` first.
+
+        Returns:
+            self
+
+        Raises:
+            ValueError: If `task` is not `"term-typing"`.
+            TypeError: If the training data cannot be normalized to a list of
+                strings or relationship dicts.
+        """
+        train_fmt = (
+            self.tasks_data_former(data=train_data, task=task, test=False)
+            if ontologizer
+            else train_data
+        )
+        if task != "term-typing":
+            raise ValueError("SBUNLPZSLearner only implements 'term-typing'.")
+
+        # If framework passed a container with `.term_typings`, extract types from there
+        if not isinstance(train_fmt, list):
+            if hasattr(train_fmt, "term_typings"):
+                try:
+                    collected = set()
+                    for tt in getattr(train_fmt, "term_typings") or []:
+                        # tt.types could be list[str] or a single str
+                        if hasattr(tt, "types"):
+                            tvals = tt.types
+                        elif isinstance(tt, dict) and "types" in tt:
+                            tvals = tt["types"]
+                        else:
+                            tvals = None
+
+                        if isinstance(tvals, (list, tuple, set)):
+                            for x in tvals:
+                                if isinstance(x, str):
+                                    collected.add(x)
+                        elif isinstance(tvals, str):
+                            collected.add(tvals)
+
+                    if collected:
+                        self.allowed_types = sorted(collected)
+                        return self
+                except Exception:
+                    # Fall through to error below if unexpected issues occur.
+                    pass
+
+            raise TypeError("For term-typing, expected a list of type labels at fit().")
+
+        # At this point train_fmt is a list (original logic preserved)
+        if train_fmt and isinstance(train_fmt[0], dict) and "parent" in train_fmt[0]:
+            # Case A: Received raw relationships/pairs (e.g., from train_test_split).
+            unique_types = set(r.get("parent") for r in train_fmt if r.get("parent"))
+            self.allowed_types = sorted(unique_types)
+        elif all(isinstance(x, str) for x in train_fmt):
+            # Case B: Received a clean list of type labels (List[str]).
+            self.allowed_types = sorted(set(train_fmt))
+        else:
+            raise TypeError(
+                "For term-typing, input data format for fit() is invalid. "
+                "Expected list of strings (types) or list of relationships (dicts)."
+            )
+
+        return self
+
+    def predict(self, eval_data: Any, task: str, ontologizer: bool = True) -> Any:
+        """
+        Predict types for each term and return standardized rows.
+
+        Expected inputs:
+          • With `ontologizer=True`: a `list[str]` of terms (IDs are auto-generated),
+            or a container exposing `.term_typings` from which `{'id','term'}` pairs
+            can be extracted.
+          • With `ontologizer=False`: a `list[dict]` of `{'id','term'}` to preserve IDs.
+
+        Args:
+            eval_data: Evaluation payload as described above.
+            task: Must be `"term-typing"`.
+            ontologizer: If True, normalize through the pipeline's data former.
+
+        Returns:
+            A list of dictionaries:
+            `{"id": str, "term": str, "types": List[str]}`.
+        """
+        if task != "term-typing":
+            # Delegate to base for other tasks (not implemented here)
+            return super().predict(eval_data, task, ontologizer=ontologizer)
+
+        def _extract_list_of_dicts_from_term_typings(
+            obj,
+        ) -> Optional[List[Dict[str, str]]]:
+            """Try to derive `[{id, term}, ...]` from an object with `.term_typings`."""
+            tts = getattr(obj, "term_typings", None)
+            if tts is None:
+                return None
+            out = []
+            for tt in tts:
+                if isinstance(tt, dict):
+                    tid = tt.get("ID") or tt.get("id") or tt.get("Id") or tt.get("ID_")
+                    tterm = tt.get("term") or tt.get("label") or tt.get("name")
+                else:
+                    tid = (
+                        getattr(tt, "ID", None)
+                        or getattr(tt, "id", None)
+                        or getattr(tt, "Id", None)
+                    )
+                    tterm = (
+                        getattr(tt, "term", None)
+                        or getattr(tt, "label", None)
+                        or getattr(tt, "name", None)
+                    )
+                if tid is None or tterm is None:
+                    continue
+                out.append({"id": str(tid), "term": str(tterm)})
+            return out if out else None
+
+        # Case A: ontologizer=True -> framework often provides list[str]
+        if ontologizer:
+            if isinstance(eval_data, list) and all(
+                isinstance(x, str) for x in eval_data
+            ):
+                eval_pack = [
+                    {"id": f"TT_{i:06d}", "term": t} for i, t in enumerate(eval_data)
+                ]
+            else:
+                maybe = _extract_list_of_dicts_from_term_typings(eval_data)
+                if maybe is not None:
+                    eval_pack = maybe
+                else:
+                    # Last resort: attempt to coerce iterables of str
+                    if hasattr(eval_data, "__iter__") and not isinstance(
+                        eval_data, (str, bytes)
+                    ):
+                        lst = list(eval_data)
+                        if all(isinstance(x, str) for x in lst):
+                            eval_pack = [
+                                {"id": f"TT_{i:06d}", "term": t}
+                                for i, t in enumerate(lst)
+                            ]
+                        else:
+                            raise TypeError(
+                                "With ontologizer=True, eval_data must be list[str] of terms."
+                            )
+                    else:
+                        raise TypeError(
+                            "With ontologizer=True, eval_data must be list[str] of terms."
+                        )
+            return self._term_typing(eval_pack, test=True)
+
+        # Case B: ontologizer=False -> expect list[dict], but tolerate containers
+        else:
+            if isinstance(eval_data, list) and all(
+                isinstance(x, dict) for x in eval_data
+            ):
+                eval_pack = eval_data
+            else:
+                maybe = _extract_list_of_dicts_from_term_typings(eval_data)
+                if maybe is not None:
+                    eval_pack = maybe
+                else:
+                    if isinstance(eval_data, dict):
+                        for key in ("term_typings", "terms", "items"):
+                            if key in eval_data and isinstance(
+                                eval_data[key], (list, tuple)
+                            ):
+                                converted = []
+                                for x in eval_data[key]:
+                                    if (
+                                        isinstance(x, dict)
+                                        and ("id" in x or "ID" in x)
+                                        and ("term" in x or "name" in x)
+                                    ):
+                                        tid = x.get("ID") or x.get("id")
+                                        tterm = x.get("term") or x.get("name")
+                                        converted.append(
+                                            {"id": str(tid), "term": str(tterm)}
+                                        )
+                                if converted:
+                                    eval_pack = converted
+                                    break
+                        else:
+                            raise TypeError(
+                                "With ontologizer=False, eval_data must be a list of dicts with keys {'id','term'}."
+                            )
+                    else:
+                        raise TypeError(
+                            "With ontologizer=False, eval_data must be a list of dicts with keys {'id','term'}."
+                        )
+            return self._term_typing(eval_pack, test=True)
+
+    def _term_typing(self, data: Any, test: bool = False) -> Optional[Any]:
+        """
+        Internal implementation of the *term-typing* task.
+
+        Training mode (`test=False`):
+          • Expects a `list[str]` of allowed types. Stores a sorted unique copy.
+
+        Inference mode (`test=True`):
+          • Expects a `list[dict]` of `{"id","term"}` items.
+          • Requires `load()` to have been called (model/tokenizer available).
+          • Builds a blind prompt per item, generates text, parses quoted
+            candidates, and filters them to `self.allowed_types`.
+
+        Args:
+            data: See the mode-specific expectations above.
+            test: Set `True` to run inference; `False` to store the type inventory.
+
+        Returns:
+          • `None` in training mode.
+          • `list[dict]` with `{"id","term","types":[...]}` in inference mode.
+
+        Raises:
+            TypeError: If `data` is not in the expected shape for the mode.
+            RuntimeError: If model/tokenizer are not loaded at inference time.
+        """
+        if not test:
+            # training: expect a list of strings (type labels)
+            if not isinstance(data, list):
+                raise TypeError("Expected a list of type labels at training time.")
+            self.allowed_types = sorted(set(data))
+            return None
+
+        # Inference path
+        if not isinstance(data, list) or not all(isinstance(x, dict) for x in data):
+            raise TypeError(
+                "At prediction time, expected a list of {'id','term'} dicts."
+            )
+
+        if self.model is None or self.tokenizer is None:
+            raise RuntimeError(
+                "Model/tokenizer not loaded. Call .load() before predict()."
+            )
+
+        results = []
+        for item in data:
+            term_id = item["id"]
+            term_text = item["term"]
+            prompt = self._build_blind_prompt(term_id, term_text, self.allowed_types)
+            types = self._generate_and_parse_types(prompt)
+            results.append({"id": term_id, "term": term_text, "types": types})
+
+        return results
+
+    def _format_types_inline(self, allowed: List[str]) -> str:
+        """
+        Format the allowed types for inline inclusion in prompts.
+
+        Args:
+            allowed: List of allowed type labels.
+
+        Returns:
+            A comma-separated string of quoted types, e.g.:
+            `"type1", "type2", "type3"`. Returns an empty string for an empty list.
+        """
+        if not allowed:
+            return ""
+        return ", ".join(f'"{t}"' for t in allowed if isinstance(t, str) and t.strip())
+
+    def _build_blind_prompt(
+        self, term_id: str, term: str, allowed_types: List[str]
+    ) -> str:
+        """
+        Construct the blind JSON prompt for a single term.
+
+        The prompt:
+          • Instructs the model to produce ONLY a JSON array of `{id, types}` objects.
+          • Provides the allowed types list so the model should only use those.
+          • Includes the single input item for which the model must decide types.
+
+        Args:
+            term_id: Identifier to carry through to the output JSON.
+            term: The input term string to classify.
+            allowed_types: Inventory used to constrain outputs.
+
+        Returns:
+            The full prompt string to feed to the LLM.
+        """
+        allowed_str = self._format_types_inline(allowed_types)
+        return (
+            "Identify the type(s) of the term in a second JSON file.\n"
+            "A term can have more than one type.\n"
+            "Output file must be in this format:\n"
+            "[\n"
+            '{ "id": "TT_465e8904", "types": [ "type1" ] },\n'
+            '{ "id": "TT_01c7707e", "types": [ "type2", "type3" ] },\n'
+            '{ "id": "TT_b20cb478", "types": [ "type4" ] }\n'
+            "]\n"
+            "The id must be taken from the input JSON file.\n"
+            "You must find the type(s) for each term in the JSON file.\n"
+            "Types must be selected only from the types list.\n\n"
+            f"Types list: {allowed_str}\n\n"
+            f'{{ "id": "{term_id}", "term": "{term}" }}'
+        )
+
+    def _generate_and_parse_types(self, prompt: str) -> List[str]:
+        """
+        Greedy-generate text, extract candidate types, and filter to the inventory.
+
+        Workflow:
+          1) Tokenize the prompt and generate deterministically (greedy).
+          2) Decode and extract quoted substrings via regex (e.g., `"type"`).
+          3) Keep only those candidates that exist in `self.allowed_types`.
+          4) Return a unique, sorted list (stable across runs).
+
+        Args:
+            prompt: Fully formatted prompt string.
+
+        Returns:
+            List of predicted type labels (possibly empty if none found).
+
+        Raises:
+            AssertionError: If `model` or `tokenizer` are unexpectedly `None`.
+        """
+        assert self.model is not None and self.tokenizer is not None
+
+        # Tokenize prompt and move tensors to model device to avoid device mismatch
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=self.max_new_tokens,
+                do_sample=False,  # deterministic (greedy) decoding
+                pad_token_id=self.tokenizer.eos_token_id,
+            )
+
+        # Decode full generated sequence (prompt + generation). Then extract quoted strings.
+        text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        candidates = self._quoted_re.findall(text)
+
+        # Filter candidates to the allowed inventory and stabilize order.
+        filtered = [c for c in candidates if c in self.allowed_types]
+        return sorted(set(filtered))
ontolearner/learner/text2onto/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2025 SciKnowOrg
+#
+# Licensed under the MIT License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/MIT
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .alexbek import AlexbekFewShotLearner
+from .sbunlp import SBUNLPFewShotLearner
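
With these re-exports in place, the new text2onto learners can be imported from the subpackage directly, for example (import path assumed from the file list above):

    from ontolearner.learner.text2onto import AlexbekFewShotLearner, SBUNLPFewShotLearner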