OntoLearner 1.4.10__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ontolearner/VERSION +1 -1
- ontolearner/base/learner.py +41 -18
- ontolearner/evaluation/metrics.py +72 -32
- ontolearner/learner/__init__.py +3 -2
- ontolearner/learner/label_mapper.py +5 -4
- ontolearner/learner/llm.py +257 -0
- ontolearner/learner/prompt.py +40 -5
- ontolearner/learner/rag/__init__.py +14 -0
- ontolearner/learner/{rag.py → rag/rag.py} +7 -2
- ontolearner/learner/retriever/__init__.py +1 -1
- ontolearner/learner/retriever/{llm_retriever.py → augmented_retriever.py} +48 -39
- ontolearner/learner/retriever/learner.py +3 -4
- ontolearner/learner/taxonomy_discovery/alexbek.py +632 -310
- ontolearner/learner/taxonomy_discovery/skhnlp.py +216 -156
- ontolearner/learner/text2onto/__init__.py +1 -1
- ontolearner/learner/text2onto/alexbek.py +484 -1105
- ontolearner/learner/text2onto/sbunlp.py +498 -493
- ontolearner/ontology/biology.py +2 -3
- ontolearner/ontology/chemistry.py +16 -18
- ontolearner/ontology/ecology_environment.py +2 -3
- ontolearner/ontology/general.py +4 -6
- ontolearner/ontology/material_science_engineering.py +64 -45
- ontolearner/ontology/medicine.py +2 -3
- ontolearner/ontology/scholarly_knowledge.py +6 -9
- ontolearner/processor.py +3 -3
- ontolearner/text2onto/splitter.py +69 -6
- {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/METADATA +2 -2
- {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/RECORD +30 -29
- {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/WHEEL +1 -1
- {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/licenses/LICENSE +0 -0
ontolearner/VERSION
CHANGED
@@ -1 +1 @@
-1.4.10
+1.5.0
ontolearner/base/learner.py
CHANGED
@@ -18,6 +18,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import torch.nn.functional as F
 from sentence_transformers import SentenceTransformer
+from collections import defaultdict

 class AutoLearner(ABC):
     """
@@ -70,6 +71,7 @@ class AutoLearner(ABC):
             - "term-typing": Predict semantic types for terms
             - "taxonomy-discovery": Identify hierarchical relationships
             - "non-taxonomy-discovery": Identify non-hierarchical relationships
+            - "text2onto": Extract ontology terms and their semantic types from documents

         Raises:
             NotImplementedError: If not implemented by concrete class.
@@ -81,6 +83,8 @@ class AutoLearner(ABC):
             self._taxonomy_discovery(train_data, test=False)
         elif task == 'non-taxonomic-re':
             self._non_taxonomic_re(train_data, test=False)
+        elif task == 'text2onto':
+            self._text2onto(train_data, test=False)
         else:
             raise ValueError(f"{task} is not a valid task.")

@@ -103,6 +107,7 @@ class AutoLearner(ABC):
             - term-typing: List of predicted types for each term
             - taxonomy-discovery: Boolean predictions for relationships
             - non-taxonomy-discovery: Predicted relation types
+            - text2onto: Extract ontology terms and their semantic types from documents

         Raises:
             NotImplementedError: If not implemented by concrete class.
@@ -115,6 +120,8 @@ class AutoLearner(ABC):
             return self._taxonomy_discovery(eval_data, test=True)
         elif task == 'non-taxonomic-re':
             return self._non_taxonomic_re(eval_data, test=True)
+        elif task == 'text2onto':
+            return self._text2onto(eval_data, test=True)
         else:
             raise ValueError(f"{task} is not a valid task.")

@@ -147,6 +154,9 @@ class AutoLearner(ABC):
     def _non_taxonomic_re(self, data: Any, test: bool = False) -> Optional[Any]:
         pass

+    def _text2onto(self, data: Any, test: bool = False) -> Optional[Any]:
+        pass
+
     def tasks_data_former(self, data: Any, task: str, test: bool = False) -> List[str | Dict[str, str]]:
         formatted_data = []
         if task == "term-typing":
@@ -171,6 +181,7 @@ class AutoLearner(ABC):
             non_taxonomic_types = list(set(non_taxonomic_types))
             non_taxonomic_res = list(set(non_taxonomic_res))
             formatted_data = {"types": non_taxonomic_types, "relations": non_taxonomic_res}
+
         return formatted_data

     def tasks_ground_truth_former(self, data: Any, task: str) -> List[Dict[str, str]]:
@@ -186,6 +197,26 @@ class AutoLearner(ABC):
                 formatted_data.append({"head": non_taxonomic_triplets.head,
                                        "tail": non_taxonomic_triplets.tail,
                                        "relation": non_taxonomic_triplets.relation})
+        if task == "text2onto":
+            terms2docs = data.get("terms2docs", {}) or {}
+            terms2types = data.get("terms2types", {}) or {}
+
+            # gold doc→terms
+            gold_terms = []
+            for term, doc_ids in terms2docs.items():
+                for doc_id in doc_ids or []:
+                    gold_terms.append({"doc_id": doc_id, "term": term})
+
+            # gold doc→types derived via doc→terms + term→types
+            doc2types = defaultdict(set)
+            for term, doc_ids in terms2docs.items():
+                for doc_id in doc_ids or []:
+                    for ty in (terms2types.get(term, []) or []):
+                        if isinstance(ty, str) and ty.strip():
+                            doc2types[doc_id].add(ty.strip())
+            gold_types = [{"doc_id": doc_id, "type": ty} for doc_id, tys in doc2types.items() for ty in tys]
+            return {"terms": gold_terms, "types": gold_types}
+
         return formatted_data

 class AutoLLM(ABC):
@@ -201,7 +232,7 @@ class AutoLLM(ABC):
         tokenizer: The tokenizer associated with the model.
     """

-    def __init__(self, label_mapper: Any, device: str='cpu', token: str="") -> None:
+    def __init__(self, label_mapper: Any, device: str='cpu', token: str="", max_length: int = 512) -> None:
         """
         Initialize the LLM component.

@@ -213,6 +244,7 @@ class AutoLLM(ABC):
         self.device=device
         self.model: Optional[Any] = None
         self.tokenizer: Optional[Any] = None
+        self.max_length = max_length

     def load(self, model_id: str) -> None:
@@ -236,10 +268,8 @@ class AutoLLM(ABC):
         self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left', token=self.token)
         self.tokenizer.pad_token = self.tokenizer.eos_token
         if self.device == "cpu":
-            # device_map = "cpu"
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_id,
-                # device_map=device_map,
                 torch_dtype=torch.bfloat16,
                 token=self.token
             )
@@ -248,11 +278,12 @@ class AutoLLM(ABC):
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_id,
                 device_map=device_map,
-
-
+                token=self.token,
+                trust_remote_code=True,
             )
         self.label_mapper.fit()

+    @torch.no_grad()
     def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
         """
         Generate text responses for the given input prompts.
@@ -276,29 +307,21 @@ class AutoLLM(ABC):
         List of generated text responses, one for each input prompt.
         Responses include the original input plus generated continuation.
         """
-        # Tokenize inputs and move to device
         encoded_inputs = self.tokenizer(inputs,
                                         return_tensors="pt",
-
-                                        truncation=True
+                                        max_length=self.max_length,
+                                        truncation=True,
+                                        padding=True).to(self.model.device)
         input_ids = encoded_inputs["input_ids"]
         input_length = input_ids.shape[1]
-
-        # Generate output
         outputs = self.model.generate(
             **encoded_inputs,
             max_new_tokens=max_new_tokens,
-            pad_token_id=self.tokenizer.eos_token_id
+            pad_token_id=self.tokenizer.eos_token_id,
+            eos_token_id=self.tokenizer.eos_token_id
         )
-
-        # Extract only the newly generated tokens (excluding prompt)
         generated_tokens = outputs[:, input_length:]
-
-        # Decode only the generated part
         decoded_outputs = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated_tokens]
-        print(decoded_outputs)
-        print(self.label_mapper.predict(decoded_outputs))
-        # Map the decoded text to labels
         return self.label_mapper.predict(decoded_outputs)

 class AutoRetriever(ABC):
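The `text2onto` branch added to `tasks_ground_truth_former` flattens `terms2docs` into document-level gold terms and derives document-level gold types by composing doc→terms with term→types. A minimal standalone sketch of that construction, using made-up input data (only the `terms2docs`/`terms2types` keys come from the diff above):

# Minimal sketch of the gold-label construction shown above; the input
# dictionaries are hypothetical examples, not taken from the package.
from collections import defaultdict

terms2docs = {"polymer": ["doc1", "doc2"], "alloy": ["doc2"]}
terms2types = {"polymer": ["Material"], "alloy": ["Material", "Metal"]}

gold_terms = [{"doc_id": d, "term": t} for t, docs in terms2docs.items() for d in docs]

doc2types = defaultdict(set)
for term, docs in terms2docs.items():
    for doc_id in docs:
        for ty in terms2types.get(term, []):
            doc2types[doc_id].add(ty.strip())
gold_types = [{"doc_id": d, "type": ty} for d, tys in doc2types.items() for ty in tys]

# gold_terms -> [{'doc_id': 'doc1', 'term': 'polymer'}, ...]
# gold_types -> [{'doc_id': 'doc2', 'type': 'Metal'}, ...]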
ontolearner/evaluation/metrics.py
CHANGED
@@ -11,44 +11,84 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Dict, Tuple, Set
+from typing import List, Dict, Tuple, Set, Any, Union

 SYMMETRIC_RELATIONS = {"equivalentclass", "sameas", "disjointwith"}

-def text2onto_metrics(
+def text2onto_metrics(
+        y_true: Dict[str, Any],
+        y_pred: Dict[str, Any],
+        similarity_threshold: float = 0.8
+) -> Dict[str, Any]:
+    """
+    Expects:
+        y_true = {"terms": [{"doc_id": str, "term": str}, ...],
+                  "types": [{"doc_id": str, "type": str}, ...]}
+        y_pred = same shape
+
+    Returns:
+        {"terms": {...}, "types": {...}}
+    """
+
+    def jaccard_similarity(text_a: str, text_b: str) -> float:
+        tokens_a = set(text_a.lower().split())
+        tokens_b = set(text_b.lower().split())
+        if not tokens_a and not tokens_b:
             return 1.0
-        return len(
+        return len(tokens_a & tokens_b) / len(tokens_a | tokens_b)
+
+    def pairs_to_strings(rows: List[Dict[str, str]], value_key: str) -> List[str]:
+        paired_strings: List[str] = []
+        for row in rows or []:
+            doc_id = (row.get("doc_id") or "").strip()
+            value = (row.get(value_key) or "").strip()
+            if doc_id and value:
+                # keep doc association + allow token Jaccard
+                paired_strings.append(f"{doc_id} {value}")
+        return paired_strings
+
+    def score_list(ground_truth_items: List[str], predicted_items: List[str]) -> Dict[str, Union[float, int]]:
+        matched_ground_truth_indices: Set[int] = set()
+        matched_predicted_indices: Set[int] = set()
+
+        for predicted_index, predicted_item in enumerate(predicted_items):
+            for ground_truth_index, ground_truth_item in enumerate(ground_truth_items):
+                if ground_truth_index in matched_ground_truth_indices:
+                    continue
+
+                if jaccard_similarity(predicted_item, ground_truth_item) >= similarity_threshold:
+                    matched_predicted_indices.add(predicted_index)
+                    matched_ground_truth_indices.add(ground_truth_index)
+                    break
+
+        total_correct = len(matched_predicted_indices)
+        total_predicted = len(predicted_items)
+        total_ground_truth = len(ground_truth_items)
+
+        precision = total_correct / total_predicted if total_predicted else 0.0
+        recall = total_correct / total_ground_truth if total_ground_truth else 0.0
+        f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
+
+        return {
+            "f1_score": f1,
+            "precision": precision,
+            "recall": recall,
+            "total_correct": total_correct,
+            "total_predicted": total_predicted,
+            "total_ground_truth": total_ground_truth,
+        }
+
+    ground_truth_terms = pairs_to_strings(y_true.get("terms", []), "term")
+    predicted_terms = pairs_to_strings(y_pred.get("terms", []), "term")
+    ground_truth_types = pairs_to_strings(y_true.get("types", []), "type")
+    predicted_types = pairs_to_strings(y_pred.get("types", []), "type")
+
+    terms_metrics = score_list(ground_truth_terms, predicted_terms)
+    types_metrics = score_list(ground_truth_types, predicted_types)

-    precision = total_correct / total_predicted if total_predicted > 0 else 0
-    recall = total_correct / total_ground_truth if total_ground_truth > 0 else 0
-    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
     return {
-        "recall": recall,
-        "total_correct": total_correct,
-        "total_predicted": total_predicted,
-        "total_ground_truth": total_ground_truth
+        "terms": terms_metrics,
+        "types": types_metrics,
     }

 def term_typing_metrics(y_true: List[Dict[str, List[str]]], y_pred: List[Dict[str, List[str]]]) -> Dict[str, float | int]:
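The rewritten `text2onto_metrics` scores terms and types separately: each gold and predicted pair is rendered as a "doc_id value" string and matched greedily against the ground truth by token-level Jaccard similarity at the given threshold. A usage sketch based on the signature and shapes documented above; the example data is invented:

# Usage sketch; example data is invented. Exact string matches have Jaccard 1.0.
from ontolearner.evaluation.metrics import text2onto_metrics

y_true = {"terms": [{"doc_id": "doc1", "term": "shape memory alloy"}],
          "types": [{"doc_id": "doc1", "type": "Material"}]}
y_pred = {"terms": [{"doc_id": "doc1", "term": "shape memory alloy"},
                    {"doc_id": "doc1", "term": "nickel"}],
          "types": [{"doc_id": "doc1", "type": "Material"}]}

scores = text2onto_metrics(y_true, y_pred, similarity_threshold=0.8)
# scores["terms"] -> precision 0.5, recall 1.0 (one spurious predicted term)
# scores["types"] -> precision 1.0, recall 1.0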
ontolearner/learner/__init__.py
CHANGED
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .llm import AutoLLMLearner, FalconLLM, MistralLLM
+from .llm import AutoLLMLearner, FalconLLM, MistralLLM, LogitMistralLLM, \
+    QwenInstructLLM, QwenThinkingLLM, LogitAutoLLM, LogitQuantAutoLLM
 from .retriever import AutoRetrieverLearner, LLMAugmentedRetrieverLearner
-from .rag import AutoRAGLearner
+from .rag import AutoRAGLearner, LLMAugmentedRAGLearner
 from .prompt import StandardizedPrompting
 from .label_mapper import LabelMapper
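With these exports, the new learners are importable directly from the `ontolearner.learner` package, for example:

# Import paths follow the re-exports declared in this __init__.
from ontolearner.learner import (
    AutoLLMLearner, LogitAutoLLM, LogitQuantAutoLLM,
    QwenInstructLLM, QwenThinkingLLM, LLMAugmentedRAGLearner,
)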
ontolearner/learner/label_mapper.py
CHANGED
@@ -31,7 +31,7 @@ class LabelMapper:
                  ngram_range: Tuple=(1, 1),
                  label_dict: Dict[str, List[str]]=None,
                  analyzer: str = 'word',
-                 iterator_no: int =
+                 iterator_no: int = 1000):
         """
         Initializes the TFIDFLabelMapper with a specified classifier and TF-IDF configuration.

@@ -45,11 +45,12 @@ class LabelMapper:
         if label_dict is None:
             label_dict = {
                 "yes": ["yes", "true"],
-                "no": ["no", "false"
+                "no": ["no", "false"]
             }
-        self.
+        self.label_dict = label_dict
+        self.labels = [label.lower() for label in list(self.label_dict.keys())]
         self.x_train, self.y_train = [], []
-        for label, candidates in label_dict.items():
+        for label, candidates in self.label_dict.items():
             self.x_train += [label] + candidates
             self.y_train += [label] * (len(candidates) + 1)
         self.x_train = iterator_no * self.x_train
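The constructor now closes the previously truncated defaults (`iterator_no = 1000`, the bracket on the "no" synonyms) and keeps `label_dict`/`labels` on the instance; the training lists pair each canonical label with its synonyms and replicate the result `iterator_no` times. A standalone sketch of that expansion (the matching `y_train` replication is assumed, since the hunk ends at the `x_train` line):

# Standalone sketch of the expansion shown above; the y_train replication is assumed.
label_dict = {"yes": ["yes", "true"], "no": ["no", "false"]}
iterator_no = 1000

x_train, y_train = [], []
for label, candidates in label_dict.items():
    x_train += [label] + candidates               # surface forms: the label itself + synonyms
    y_train += [label] * (len(candidates) + 1)    # canonical label for each surface form

x_train = iterator_no * x_train                   # replicate the small seed set
y_train = iterator_no * y_train                   # assumed to mirror x_train

assert len(x_train) == len(y_train) == 6 * iterator_no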
ontolearner/learner/llm.py
CHANGED
@@ -18,9 +18,11 @@ import warnings
 from tqdm import tqdm
 from torch.utils.data import DataLoader
 import torch
+import torch.nn.functional as F
 from transformers import Mistral3ForConditionalGeneration
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

 class AutoLLMLearner(AutoLearner):

@@ -144,6 +146,7 @@ class AutoLLMLearner(AutoLearner):

 class FalconLLM(AutoLLM):

+    @torch.no_grad()
     def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
         encoded_inputs = self.tokenizer(inputs,
                                         return_tensors="pt",
@@ -160,6 +163,7 @@ class FalconLLM(AutoLLM):
         decoded_outputs = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated_tokens]
         return self.label_mapper.predict(decoded_outputs)

+
 class MistralLLM(AutoLLM):

     def load(self, model_id: str) -> None:
@@ -178,6 +182,7 @@ class MistralLLM(AutoLLM):
         self.tokenizer.pad_token_id = self.model.generation_config.eos_token_id
         self.label_mapper.fit()

+    @torch.no_grad()
     def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
         tokenized_list = []
         for prompt in inputs:
@@ -206,3 +211,255 @@ class MistralLLM(AutoLLM):
             output_text = self.tokenizer.decode(tokens[len(tokenized_list[i]):])
             decoded_outputs.append(output_text)
         return self.label_mapper.predict(decoded_outputs)
+
+
+class LogitMistralLLM(AutoLLM):
+    label_dict = {
+        "yes": ["yes", "true", " yes", "Yes"],
+        "no": ["no", "false", " no", "No"]
+    }
+
+    def _get_label_token_ids(self):
+        label_token_ids = {}
+
+        for label, words in self.label_dict.items():
+            ids = []
+            for w in words:
+                messages = [{"role": "user", "content": [{"type": "text", "text": w}]}]
+                tokenized = self.tokenizer.encode_chat_completion(ChatCompletionRequest(messages=messages))
+                token_ids = tokenized.tokens[2:-1]
+                ids.append(token_ids)
+            label_token_ids[label] = ids
+        return label_token_ids
+
+    def load(self, model_id: str) -> None:
+        self.tokenizer = MistralTokenizer.from_hf_hub(model_id)
+        self.tokenizer.padding_side = 'left'
+        device_map = "cpu" if self.device == "cpu" else "balanced"
+        self.model = Mistral3ForConditionalGeneration.from_pretrained(
+            model_id,
+            device_map=device_map,
+            torch_dtype=torch.bfloat16,
+            token=self.token
+        )
+        self.pad_token_id = self.model.generation_config.eos_token_id
+        self.label_token_ids = self._get_label_token_ids()
+
+    @torch.no_grad()
+    def generate(self, inputs: List[str], max_new_tokens: int = 1) -> List[str]:
+        tokenized_list = []
+        for prompt in inputs:
+            messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+            req = ChatCompletionRequest(messages=messages)
+            tokenized = self.tokenizer.encode_chat_completion(req)
+            tokenized_list.append(tokenized.tokens)
+
+        max_len = max(len(t) for t in tokenized_list)
+        input_ids, attention_masks = [], []
+        for tokens in tokenized_list:
+            pad_len = max_len - len(tokens)
+            input_ids.append(tokens + [self.pad_token_id] * pad_len)
+            attention_masks.append([1] * len(tokens) + [0] * pad_len)
+
+        input_ids = torch.tensor(input_ids).to(self.model.device)
+        attention_masks = torch.tensor(attention_masks).to(self.model.device)
+
+        outputs = self.model(input_ids=input_ids, attention_mask=attention_masks)
+        # logits: [batch, seq_len, vocab]
+        logits = outputs.logits
+        # next-token prediction
+        last_logits = logits[:, -1, :]
+        probs = torch.softmax(last_logits, dim=-1)
+        predictions = []
+        for i in range(probs.size(0)):
+            label_scores = {}
+            for label, token_id_lists in self.label_token_ids.items():
+                score = 0.0
+                for token_ids in token_id_lists:
+                    # single-token in practice, but safe
+                    score += probs[i, token_ids[0]].item()
+                label_scores[label] = score
+            predictions.append(max(label_scores, key=label_scores.get))
+        return predictions
+
+
+class QwenInstructLLM(AutoLLM):
+
+    def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
+        messages = [[{"role": "user", "content": prompt + " Please show your final response with 'answer': 'label'."}]
+                    for prompt in inputs]
+
+        texts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+        encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True,
+                                        max_length=256).to(self.model.device)
+
+        generated_ids = self.model.generate(**encoded_inputs,
+                                            max_new_tokens=max_new_tokens,
+                                            use_cache=False,
+                                            pad_token_id=self.tokenizer.pad_token_id,
+                                            eos_token_id=self.tokenizer.eos_token_id)
+        decoded_outputs = []
+        for i in range(len(generated_ids)):
+            prompt_len = encoded_inputs.attention_mask[i].sum().item()
+            output_ids = generated_ids[i][prompt_len:].tolist()
+            output_content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()
+            decoded_outputs.append(output_content)
+        return self.label_mapper.predict(decoded_outputs)
+
+
+class QwenThinkingLLM(AutoLLM):
+
+    @torch.no_grad()
+    def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
+        messages = [[{"role": "user", "content": prompt + " Please show your final response with 'answer': 'label'."}]
+                    for prompt in inputs]
+        texts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True).to(self.model.device)
+        generated_ids = self.model.generate(**encoded_inputs, max_new_tokens=max_new_tokens)
+        decoded_outputs = []
+        for i in range(len(generated_ids)):
+            prompt_len = encoded_inputs.attention_mask[i].sum().item()
+            output_ids = generated_ids[i][prompt_len:].tolist()
+            try:
+                end = len(output_ids) - output_ids[::-1].index(151668)
+                thinking_ids = output_ids[:end]
+            except ValueError:
+                thinking_ids = output_ids
+            thinking_content = self.tokenizer.decode(thinking_ids, skip_special_tokens=True).strip()
+            decoded_outputs.append(thinking_content)
+        return self.label_mapper.predict(decoded_outputs)
+
+
+class LogitAutoLLM(AutoLLM):
+    def _get_label_token_ids(self):
+        label_token_ids = {}
+        for label, words in self.label_mapper.label_dict.items():
+            ids = []
+            for w in words:
+                token_ids = self.tokenizer.encode(w, add_special_tokens=False)
+                ids.append(token_ids)
+            label_token_ids[label] = ids
+        return label_token_ids
+
+    def load(self, model_id: str) -> None:
+        super().load(model_id)
+        self.label_token_ids = self._get_label_token_ids()
+
+    @torch.no_grad()
+    def generate(self, inputs: List[str], max_new_tokens: int = 1) -> List[str]:
+        encoded = self.tokenizer(inputs, return_tensors="pt", truncation=True, padding=True).to(self.model.device)
+        outputs = self.model(**encoded)
+        logits = outputs.logits  # logits: [batch, seq_len, vocab]
+        last_logits = logits[:, -1, :]  # [batch, vocab]  # we only care about the NEXT token prediction
+        probs = F.softmax(last_logits, dim=-1)
+        predictions = []
+        for i in range(probs.size(0)):
+            label_scores = {}
+            for label, token_id_lists in self.label_token_ids.items():
+                score = 0.0
+                for token_ids in token_id_lists:
+                    if len(token_ids) == 1:
+                        score += probs[i, token_ids[0]].item()
+                    else:
+                        score += probs[i, token_ids[0]].item()  # multi-token fallback (rare but safe)
+                label_scores[label] = score
+            predictions.append(max(label_scores, key=label_scores.get))
+        return predictions
+
+
+class LogitQuantAutoLLM(AutoLLM):
+    label_dict = {
+        "yes": ["yes", "true", " yes", "Yes"],
+        "no": ["no", "false", " no", "No"]
+    }
+
+    def _get_label_token_ids(self):
+        label_token_ids = {}
+
+        for label, words in self.label_dict.items():
+            ids = []
+            for w in words:
+                token_ids = self.tokenizer.encode(
+                    w,
+                    add_special_tokens=False
+                )
+                # usually single-token, but be safe
+                ids.append(token_ids)
+            label_token_ids[label] = ids
+
+        return label_token_ids
+
+    def load(self, model_id: str) -> None:
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left', token=self.token)
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        if self.device == "cpu":
+            # device_map = "cpu"
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                # device_map=device_map,
+                torch_dtype=torch.bfloat16,
+                token=self.token
+            )
+        else:
+            device_map = "balanced"
+            # self.model = AutoModelForCausalLM.from_pretrained(
+            #     model_id,
+            #     device_map=device_map,
+            #     torch_dtype=torch.bfloat16,
+            #     token=self.token
+            # )
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_use_double_quant=True
+            )
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                quantization_config=bnb_config,
+                device_map=device_map,
+                token=self.token,
+                # trust_remote_code=True,
+                # attn_implementation="flash_attention_2"
+            )
+        self.label_token_ids = self._get_label_token_ids()
+
+    @torch.no_grad()
+    def generate(self, inputs: List[str], max_new_tokens: int = 1) -> List[str]:
+        encoded = self.tokenizer(
+            inputs,
+            return_tensors="pt",
+            max_length=256,
+            truncation=True,
+            padding=True
+        ).to(self.model.device)
+
+        outputs = self.model(**encoded)
+
+        # logits: [batch, seq_len, vocab]
+        logits = outputs.logits
+
+        # we only care about the NEXT token prediction
+        last_logits = logits[:, -1, :]  # [batch, vocab]
+
+        probs = F.softmax(last_logits, dim=-1)
+
+        predictions = []
+
+        for i in range(probs.size(0)):
+            label_scores = {}
+
+            for label, token_id_lists in self.label_token_ids.items():
+                score = 0.0
+                for token_ids in token_id_lists:
+                    if len(token_ids) == 1:
+                        score += probs[i, token_ids[0]].item()
+                    else:
+                        # multi-token fallback (rare but safe)
+                        score += probs[i, token_ids[0]].item()
+                label_scores[label] = score
+
+            predictions.append(max(label_scores, key=label_scores.get))
+        return predictions
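The three logit-based classes above (`LogitMistralLLM`, `LogitAutoLLM`, `LogitQuantAutoLLM`) share the same decision rule: run one forward pass, take the next-token probability distribution, and pick the label whose candidate tokens receive the most probability mass. A minimal sketch of that scoring step with a synthetic probability tensor and placeholder token ids:

# Minimal sketch of the label-scoring rule used by the Logit* classes; the
# probability tensor and token ids below are synthetic placeholders.
import torch

probs = torch.softmax(torch.randn(2, 10), dim=-1)        # [batch, vocab]
label_token_ids = {"yes": [[3], [7]], "no": [[1], [5]]}  # first-token ids per surface form

predictions = []
for i in range(probs.size(0)):
    label_scores = {
        label: sum(probs[i, ids[0]].item() for ids in token_id_lists)
        for label, token_id_lists in label_token_ids.items()
    }
    predictions.append(max(label_scores, key=label_scores.get))

print(predictions)  # e.g. ['yes', 'no']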
ontolearner/learner/prompt.py
CHANGED
@@ -17,15 +17,50 @@ from ..base import AutoPrompt
 class StandardizedPrompting(AutoPrompt):
     def __init__(self, task: str = None):
         if task == "term-typing":
-            prompt_template = """
+            prompt_template = """You are performing term typing.
+
+Determine whether the given term is a clear and unambiguous instance of the specified high-level type.
+
+Rules:
+- Answer "yes" only if the term commonly and directly belongs to the type.
+- Answer "no" if the term does not belong to the type, is ambiguous, or only weakly related.
+- Use the most common meaning of the term.
+- Do not explain your answer.
+
 Term: {term}
 Type: {type}
-Answer:
+Answer (yes or no):"""
         elif task == "taxonomy-discovery":
-            prompt_template =
+            prompt_template = """You are identifying taxonomic (is-a) relationships.
+
+Question:
+Is "{parent}" a superclass (direct or indirect) of "{child}" in a standard conceptual or ontological hierarchy?
+
+Rules:
+- A superclass means: "{child}" is a type or instance of "{parent}".
+- Answer "yes" only if the relationship is a true is-a relationship.
+- Answer "no" for part-of, related-to, or associative relationships.
+- Use general world knowledge.
+- Do not explain.
+
+Parent: {parent}
+Child: {child}
+Answer (yes or no):"""
         elif task == "non-taxonomic-re":
-            prompt_template = """
+            prompt_template = """You are identifying non-taxonomic conceptual relationships.
+
+Given two conceptual types, determine whether the specified relation typically holds between them.
+
+Rules:
+- Answer "yes" only if the relation commonly and meaningfully applies.
+- Answer "no" if the relation is rare, indirect, or context-dependent.
+- Do not infer relations that require specific situations.
+- Do not explain.
+
+Head type: {head}
+Tail type: {tail}
+Relation: {relation}
+Answer (yes or no):"""
         else:
             raise ValueError("Unknown task! Current tasks are: 'term-typing', 'taxonomy-discovery', 'non-taxonomic-re'")
         super().__init__(prompt_template)