OntoLearner 1.4.11.tar.gz → 1.5.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ontolearner-1.4.11 → ontolearner-1.5.0}/PKG-INFO +1 -1
- ontolearner-1.5.0/ontolearner/VERSION +1 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/base/learner.py +4 -2
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/__init__.py +2 -1
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/label_mapper.py +4 -3
- ontolearner-1.5.0/ontolearner/learner/llm.py +465 -0
- ontolearner-1.5.0/ontolearner/learner/taxonomy_discovery/alexbek.py +822 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/taxonomy_discovery/skhnlp.py +216 -156
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/biology.py +2 -3
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/chemistry.py +16 -18
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/ecology_environment.py +2 -3
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/general.py +4 -6
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/material_science_engineering.py +64 -45
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/medicine.py +2 -3
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/scholarly_knowledge.py +6 -9
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/processor.py +3 -3
- {ontolearner-1.4.11 → ontolearner-1.5.0}/pyproject.toml +1 -1
- ontolearner-1.4.11/ontolearner/VERSION +0 -1
- ontolearner-1.4.11/ontolearner/learner/llm.py +0 -208
- ontolearner-1.4.11/ontolearner/learner/taxonomy_discovery/alexbek.py +0 -500
- {ontolearner-1.4.11 → ontolearner-1.5.0}/LICENSE +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/README.md +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/images/logo.png +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/__init__.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/_learner.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/_ontology.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/base/__init__.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/base/ontology.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/base/text2onto.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/data_structure/__init__.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/data_structure/data.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/data_structure/metric.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/evaluation/__init__.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/evaluation/evaluate.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/evaluation/metrics.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/prompt.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/rag/__init__.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/rag/rag.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/retriever/__init__.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/retriever/augmented_retriever.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/retriever/crossencoder.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/retriever/embedding.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/retriever/learner.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/retriever/ngram.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/taxonomy_discovery/__init__.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/taxonomy_discovery/rwthdbis.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/taxonomy_discovery/sbunlp.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/term_typing/__init__.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/term_typing/alexbek.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/term_typing/rwthdbis.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/term_typing/sbunlp.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/text2onto/__init__.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/text2onto/alexbek.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/text2onto/sbunlp.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/__init__.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/agriculture.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/arts_humanities.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/education.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/events.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/finance.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/food_beverage.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/geography.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/industry.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/law.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/library_cultural_heritage.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/news_media.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/social_sciences.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/units_measurements.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/upper_ontologies.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/web.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/text2onto/__init__.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/text2onto/batchifier.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/text2onto/general.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/text2onto/splitter.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/text2onto/synthesizer.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/tools/__init__.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/tools/analyzer.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/tools/visualizer.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/utils/__init__.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/utils/io.py +0 -0
- {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/utils/train_test_split.py +0 -0
ontolearner-1.5.0/ontolearner/VERSION
@@ -0,0 +1 @@
+1.5.0
{ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/base/learner.py
@@ -232,7 +232,7 @@ class AutoLLM(ABC):
         tokenizer: The tokenizer associated with the model.
     """

-    def __init__(self, label_mapper: Any, device: str='cpu', token: str="", max_length: int =
+    def __init__(self, label_mapper: Any, device: str='cpu', token: str="", max_length: int = 512) -> None:
         """
         Initialize the LLM component.

@@ -283,6 +283,7 @@ class AutoLLM(ABC):
         )
         self.label_mapper.fit()

+    @torch.no_grad()
     def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
         """
         Generate text responses for the given input prompts.
@@ -309,7 +310,8 @@ class AutoLLM(ABC):
         encoded_inputs = self.tokenizer(inputs,
                                         return_tensors="pt",
                                         max_length=self.max_length,
-                                        truncation=True
+                                        truncation=True,
+                                        padding=True).to(self.model.device)
         input_ids = encoded_inputs["input_ids"]
         input_length = input_ids.shape[1]
         outputs = self.model.generate(
{ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/__init__.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .llm import AutoLLMLearner, FalconLLM, MistralLLM
+from .llm import AutoLLMLearner, FalconLLM, MistralLLM, LogitMistralLLM, \
+    QwenInstructLLM, QwenThinkingLLM, LogitAutoLLM, LogitQuantAutoLLM
 from .retriever import AutoRetrieverLearner, LLMAugmentedRetrieverLearner
 from .rag import AutoRAGLearner, LLMAugmentedRAGLearner
 from .prompt import StandardizedPrompting
{ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/label_mapper.py
@@ -45,11 +45,12 @@ class LabelMapper:
         if label_dict is None:
             label_dict = {
                 "yes": ["yes", "true"],
-                "no": ["no", "false"
+                "no": ["no", "false"]
             }
-        self.
+        self.label_dict = label_dict
+        self.labels = [label.lower() for label in list(self.label_dict.keys())]
         self.x_train, self.y_train = [], []
-        for label, candidates in label_dict.items():
+        for label, candidates in self.label_dict.items():
             self.x_train += [label] + candidates
             self.y_train += [label] * (len(candidates) + 1)
         self.x_train = iterator_no * self.x_train
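For reference, a minimal sketch of what the corrected `LabelMapper` constructor body now builds for the default `label_dict`. The `iterator_no` repetition count is a constructor argument not shown in this hunk, so the value used below is only illustrative:

    label_dict = {
        "yes": ["yes", "true"],
        "no": ["no", "false"],
    }
    iterator_no = 2  # assumed constructor argument; repeats the training examples

    labels = [label.lower() for label in label_dict.keys()]   # ['yes', 'no']
    x_train, y_train = [], []
    for label, candidates in label_dict.items():
        x_train += [label] + candidates                # the label plus its accepted surface forms
        y_train += [label] * (len(candidates) + 1)
    x_train = iterator_no * x_train                    # only x_train is repeated, mirroring the hunk above

    # x_train -> ['yes', 'yes', 'true', 'no', 'no', 'false', 'yes', 'yes', 'true', 'no', 'no', 'false']
    # y_train -> ['yes', 'yes', 'yes', 'no', 'no', 'no']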
ontolearner-1.5.0/ontolearner/learner/llm.py
@@ -0,0 +1,465 @@
+# Copyright (c) 2025 SciKnowOrg
+#
+# Licensed under the MIT License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/MIT
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import AutoLLM, AutoLearner
+from typing import Any, List
+import warnings
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+import torch
+import torch.nn.functional as F
+from transformers import Mistral3ForConditionalGeneration
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+class AutoLLMLearner(AutoLearner):
+
+    def __init__(self,
+                 prompting,
+                 label_mapper,
+                 llm: AutoLLM = AutoLLM,
+                 token: str = "",
+                 max_new_tokens: int = 5,
+                 batch_size: int = 10,
+                 device='cpu') -> None:
+        super().__init__()
+        self.llm = llm(token=token, label_mapper=label_mapper, device=device)
+        self.prompting = prompting
+        self.batch_size = batch_size
+        self.max_new_tokens = max_new_tokens
+        self._is_term_typing_fit = False
+
+    def load(self, model_id: str = "mistralai/Mistral-7B-Instruct-v0.1", **kwargs: Any):
+        self.llm.load(model_id=model_id)
+
+    def _term_typing_predict(self, dataset):
+        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
+        predictions = {}
+        for batch in tqdm(dataloader):
+            prediction = self.llm.generate(inputs=batch['prompt'], max_new_tokens=self.max_new_tokens)
+            for term, type, predict in zip(batch['term'], batch['type'], prediction):
+                if term not in predictions:
+                    predictions[term] = []
+                if predict == 'yes':
+                    predictions[term].append(type)
+        predicts = [{"term": term, "types": types} for term, types in predictions.items()]
+        return predicts
+
+    def _term_typing(self, data: Any, test: bool = False) -> Any:
+        """
+        during training: data = ["type-1", .... ],
+        during testing: data = ['term-1', ...]
+        """
+        if not isinstance(data, list) and not all(isinstance(item, str) for item in data):
+            raise TypeError("Expected a list of strings (types) for llm at term-typing task.")
+        if test:
+            if self._is_term_typing_fit:
+                prompting = self.prompting(task='term-typing')
+                dataset = [{"term": term, "type": type, "prompt": prompting.format(term=term, type=type)}
+                           for term in data for type in self.candidate_types]
+                return self._term_typing_predict(dataset=dataset)
+            else:
+                raise RuntimeError("Term typing model must be fit before prediction.")
+        else:
+            self.candidate_types = data
+            self._is_term_typing_fit = True
+
+    def _taxonomy_discovery_predict(self, dataset):
+        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
+        predictions = []
+        for batch in tqdm(dataloader):
+            prediction = self.llm.generate(inputs=batch['prompt'], max_new_tokens=self.max_new_tokens)
+            predictions.extend({"parent": parent, "child": child}
+                               for parent, child, predict in zip(batch['parent'], batch['child'], prediction)
+                               if predict == 'yes')
+        return predictions
+
+    def _taxonomy_discovery(self, data: Any, test: bool = False) -> Any:
+        """
+        during training: data = ['type-1', ...],
+        during testing (same data): data= ['type-1', ...]
+        """
+        if test:
+            if not isinstance(data, list) and not all(isinstance(item, str) for item in data):
+                raise TypeError("Expected a list of strings (types) for llm at term-typing task.")
+            prompting = self.prompting(task='taxonomy-discovery')
+            dataset = [{"parent": type_i, "child": type_j, "prompt": prompting.format(parent=type_i, child=type_j)}
+                       for idx, type_i in enumerate(data) for jdx, type_j in enumerate(data) if idx < jdx]
+            return self._taxonomy_discovery_predict(dataset=dataset)
+        else:
+            warnings.warn("No requirement for fiting the taxonomy-discovery model, the predict module will use the input data to do the 'is-a' relationship detection")
+
+    def _non_taxonomic_re_predict(self, dataset):
+        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
+        predictions = []
+        for batch in tqdm(dataloader):
+            prediction = self.llm.generate(inputs=batch['prompt'], max_new_tokens=self.max_new_tokens)
+            predictions.extend({"head": head, "tail": tail, "relation": relation}
+                               for head, tail, relation, predict in
+                               zip(batch['head'], batch['tail'], batch['relation'], prediction)
+                               if predict == 'yes')
+        return predictions
+
+    def _non_taxonomic_re(self, data: Any, test: bool = False) -> Any:
+        """
+        during training: data = ['type-1', ...],
+        during testing: {'types': [...], 'relations': [... ]}
+        """
+        if test:
+            if 'types' not in data or 'relations' not in data:
+                raise ValueError("The non-taxonomic re predict should take {'types': [...], 'relations': [... ]}")
+            if len(data['types']) == 0:
+                warnings.warn("No `types` avaliable to do the non-taxonomic re-prediction.")
+                return None
+            # paring and finding paris that can have a relationship
+            prompting = self.prompting(task='taxonomy-discovery')
+            dataset = [{"parent": type_i, "child": type_j, "prompt": prompting.format(parent=type_i, child=type_j)}
+                       for idx, type_i in enumerate(data['types']) for jdx, type_j in enumerate(data['types']) if idx < jdx]
+            dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
+            predicts_lst = []
+            for batch in tqdm(dataloader):
+                prediction = self.llm.generate(inputs=batch['prompt'], max_new_tokens=self.max_new_tokens)
+                predicts_lst.extend((parent, child)
+                                    for parent, child, predict in zip(batch['parent'], batch['child'], prediction)
+                                    if predict == 'yes')
+            # finding relationships
+            prompting = self.prompting(task='non-taxonomic-re')
+            dataset = [{"head": head, "tail": tail, "relation": relation,
+                        "prompt": prompting.format(head=head, tail=tail, relation=relation)}
+                       for head, tail in predicts_lst for relation in data['relations']]
+            return self._non_taxonomic_re_predict(dataset=dataset)
+        else:
+            warnings.warn("No requirement for fiting the non-taxonomic-re model, the predict module will use the input data to do the task.")
+
+
+class FalconLLM(AutoLLM):
+
+    @torch.no_grad()
+    def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
+        encoded_inputs = self.tokenizer(inputs,
+                                        return_tensors="pt",
+                                        padding=True,
+                                        truncation=True).to(self.model.device)
+        input_ids = encoded_inputs["input_ids"]
+        input_length = input_ids.shape[1]
+        outputs = self.model.generate(
+            input_ids,
+            max_new_tokens=max_new_tokens,
+            pad_token_id=self.tokenizer.eos_token_id
+        )
+        generated_tokens = outputs[:, input_length:]
+        decoded_outputs = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated_tokens]
+        return self.label_mapper.predict(decoded_outputs)
+
+
+class MistralLLM(AutoLLM):
+
+    def load(self, model_id: str) -> None:
+        self.tokenizer = MistralTokenizer.from_hf_hub(model_id)
+        if self.device == "cpu":
+            device_map = "cpu"
+        else:
+            device_map = "balanced"
+        self.model = Mistral3ForConditionalGeneration.from_pretrained(
+            model_id,
+            device_map=device_map,
+            torch_dtype=torch.bfloat16,
+            token=self.token
+        )
+        if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = self.model.generation_config.eos_token_id
+        self.label_mapper.fit()
+
+    @torch.no_grad()
+    def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
+        tokenized_list = []
+        for prompt in inputs:
+            messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+            tokenized = self.tokenizer.encode_chat_completion(ChatCompletionRequest(messages=messages))
+            tokenized_list.append(tokenized.tokens)
+        max_len = max(len(tokens) for tokens in tokenized_list)
+        input_ids, attention_masks = [], []
+        for tokens in tokenized_list:
+            pad_length = max_len - len(tokens)
+            input_ids.append(tokens + [self.tokenizer.pad_token_id] * pad_length)
+            attention_masks.append([1] * len(tokens) + [0] * pad_length)
+
+        input_ids = torch.tensor(input_ids).to(self.model.device)
+        attention_masks = torch.tensor(attention_masks).to(self.model.device)
+
+        outputs =self.model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_masks,
+            eos_token_id=self.model.generation_config.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
+            max_new_tokens=max_new_tokens,
+        )
+        decoded_outputs = []
+        for i, tokens in enumerate(outputs):
+            output_text = self.tokenizer.decode(tokens[len(tokenized_list[i]):])
+            decoded_outputs.append(output_text)
+        return self.label_mapper.predict(decoded_outputs)
+
+
+class LogitMistralLLM(AutoLLM):
+    label_dict = {
+        "yes": ["yes", "true", " yes", "Yes"],
+        "no": ["no", "false", " no", "No"]
+    }
+
+    def _get_label_token_ids(self):
+        label_token_ids = {}
+
+        for label, words in self.label_dict.items():
+            ids = []
+            for w in words:
+                messages = [{"role": "user", "content": [{"type": "text", "text": w}]}]
+                tokenized = self.tokenizer.encode_chat_completion(ChatCompletionRequest(messages=messages))
+                token_ids = tokenized.tokens[2:-1]
+                ids.append(token_ids)
+            label_token_ids[label] = ids
+        return label_token_ids
+
+    def load(self, model_id: str) -> None:
+        self.tokenizer = MistralTokenizer.from_hf_hub(model_id)
+        self.tokenizer.padding_side = 'left'
+        device_map = "cpu" if self.device == "cpu" else "balanced"
+        self.model = Mistral3ForConditionalGeneration.from_pretrained(
+            model_id,
+            device_map=device_map,
+            torch_dtype=torch.bfloat16,
+            token=self.token
+        )
+        self.pad_token_id = self.model.generation_config.eos_token_id
+        self.label_token_ids = self._get_label_token_ids()
+
+    @torch.no_grad()
+    def generate(self, inputs: List[str], max_new_tokens: int = 1) -> List[str]:
+        tokenized_list = []
+        for prompt in inputs:
+            messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+            req = ChatCompletionRequest(messages=messages)
+            tokenized = self.tokenizer.encode_chat_completion(req)
+            tokenized_list.append(tokenized.tokens)
+
+        max_len = max(len(t) for t in tokenized_list)
+        input_ids, attention_masks = [], []
+        for tokens in tokenized_list:
+            pad_len = max_len - len(tokens)
+            input_ids.append(tokens + [self.pad_token_id] * pad_len)
+            attention_masks.append([1] * len(tokens) + [0] * pad_len)
+
+        input_ids = torch.tensor(input_ids).to(self.model.device)
+        attention_masks = torch.tensor(attention_masks).to(self.model.device)
+
+        outputs = self.model(input_ids=input_ids, attention_mask=attention_masks)
+        # logits: [batch, seq_len, vocab]
+        logits = outputs.logits
+        # next-token prediction
+        last_logits = logits[:, -1, :]
+        probs = torch.softmax(last_logits, dim=-1)
+        predictions = []
+        for i in range(probs.size(0)):
+            label_scores = {}
+            for label, token_id_lists in self.label_token_ids.items():
+                score = 0.0
+                for token_ids in token_id_lists:
+                    # single-token in practice, but safe
+                    score += probs[i, token_ids[0]].item()
+                label_scores[label] = score
+            predictions.append(max(label_scores, key=label_scores.get))
+        return predictions
+
+
+class QwenInstructLLM(AutoLLM):
+
+    def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
+        messages = [[{"role": "user", "content": prompt + " Please show your final response with 'answer': 'label'."}]
+                    for prompt in inputs]
+
+        texts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+        encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True,
+                                        max_length=256).to(self.model.device)
+
+        generated_ids = self.model.generate(**encoded_inputs,
+                                            max_new_tokens=max_new_tokens,
+                                            use_cache=False,
+                                            pad_token_id=self.tokenizer.pad_token_id,
+                                            eos_token_id=self.tokenizer.eos_token_id)
+        decoded_outputs = []
+        for i in range(len(generated_ids)):
+            prompt_len = encoded_inputs.attention_mask[i].sum().item()
+            output_ids = generated_ids[i][prompt_len:].tolist()
+            output_content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()
+            decoded_outputs.append(output_content)
+        return self.label_mapper.predict(decoded_outputs)
+
+
+class QwenThinkingLLM(AutoLLM):
+
+    @torch.no_grad()
+    def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
+        messages = [[{"role": "user", "content": prompt + " Please show your final response with 'answer': 'label'."}]
+                    for prompt in inputs]
+        texts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True).to(self.model.device)
+        generated_ids = self.model.generate(**encoded_inputs, max_new_tokens=max_new_tokens)
+        decoded_outputs = []
+        for i in range(len(generated_ids)):
+            prompt_len = encoded_inputs.attention_mask[i].sum().item()
+            output_ids = generated_ids[i][prompt_len:].tolist()
+            try:
+                end = len(output_ids) - output_ids[::-1].index(151668)
+                thinking_ids = output_ids[:end]
+            except ValueError:
+                thinking_ids = output_ids
+            thinking_content = self.tokenizer.decode(thinking_ids, skip_special_tokens=True).strip()
+            decoded_outputs.append(thinking_content)
+        return self.label_mapper.predict(decoded_outputs)
+
+
+class LogitAutoLLM(AutoLLM):
+    def _get_label_token_ids(self):
+        label_token_ids = {}
+        for label, words in self.label_mapper.label_dict.items():
+            ids = []
+            for w in words:
+                token_ids = self.tokenizer.encode(w, add_special_tokens=False)
+                ids.append(token_ids)
+            label_token_ids[label] = ids
+        return label_token_ids
+
+    def load(self, model_id: str) -> None:
+        super().load(model_id)
+        self.label_token_ids = self._get_label_token_ids()
+
+    @torch.no_grad()
+    def generate(self, inputs: List[str], max_new_tokens: int = 1) -> List[str]:
+        encoded = self.tokenizer(inputs, return_tensors="pt", truncation=True, padding=True).to(self.model.device)
+        outputs = self.model(**encoded)
+        logits = outputs.logits  # logits: [batch, seq_len, vocab]
+        last_logits = logits[:, -1, :]  # [batch, vocab] # we only care about the NEXT token prediction
+        probs = F.softmax(last_logits, dim=-1)
+        predictions = []
+        for i in range(probs.size(0)):
+            label_scores = {}
+            for label, token_id_lists in self.label_token_ids.items():
+                score = 0.0
+                for token_ids in token_id_lists:
+                    if len(token_ids) == 1:
+                        score += probs[i, token_ids[0]].item()
+                    else:
+                        score += probs[i, token_ids[0]].item()  # multi-token fallback (rare but safe)
+                label_scores[label] = score
+            predictions.append(max(label_scores, key=label_scores.get))
+        return predictions
+
+
+class LogitQuantAutoLLM(AutoLLM):
+    label_dict = {
+        "yes": ["yes", "true", " yes", "Yes"],
+        "no": ["no", "false", " no", "No"]
+    }
+
+    def _get_label_token_ids(self):
+        label_token_ids = {}
+
+        for label, words in self.label_dict.items():
+            ids = []
+            for w in words:
+                token_ids = self.tokenizer.encode(
+                    w,
+                    add_special_tokens=False
+                )
+                # usually single-token, but be safe
+                ids.append(token_ids)
+            label_token_ids[label] = ids
+
+        return label_token_ids
+
+    def load(self, model_id: str) -> None:
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left', token=self.token)
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        if self.device == "cpu":
+            # device_map = "cpu"
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                # device_map=device_map,
+                torch_dtype=torch.bfloat16,
+                token=self.token
+            )
+        else:
+            device_map = "balanced"
+            # self.model = AutoModelForCausalLM.from_pretrained(
+            #     model_id,
+            #     device_map=device_map,
+            #     torch_dtype=torch.bfloat16,
+            #     token=self.token
+            # )
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_use_double_quant=True
+            )
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                quantization_config=bnb_config,
+                device_map=device_map,
+                token=self.token,
+                # trust_remote_code=True,
+                # attn_implementation="flash_attention_2"
+            )
+        self.label_token_ids = self._get_label_token_ids()
+
+    @torch.no_grad()
+    def generate(self, inputs: List[str], max_new_tokens: int = 1) -> List[str]:
+        encoded = self.tokenizer(
+            inputs,
+            return_tensors="pt",
+            max_length=256,
+            truncation=True,
+            padding=True
+        ).to(self.model.device)
+
+        outputs = self.model(**encoded)
+
+        # logits: [batch, seq_len, vocab]
+        logits = outputs.logits
+
+        # we only care about the NEXT token prediction
+        last_logits = logits[:, -1, :]  # [batch, vocab]
+
+        probs = F.softmax(last_logits, dim=-1)
+
+        predictions = []
+
+        for i in range(probs.size(0)):
+            label_scores = {}
+
+            for label, token_id_lists in self.label_token_ids.items():
+                score = 0.0
+                for token_ids in token_id_lists:
+                    if len(token_ids) == 1:
+                        score += probs[i, token_ids[0]].item()
+                    else:
+                        # multi-token fallback (rare but safe)
+                        score += probs[i, token_ids[0]].item()
+                label_scores[label] = score
+
+            predictions.append(max(label_scores, key=label_scores.get))
+        return predictions