OntoLearner 1.4.11.tar.gz → 1.5.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. {ontolearner-1.4.11 → ontolearner-1.5.0}/PKG-INFO +1 -1
  2. ontolearner-1.5.0/ontolearner/VERSION +1 -0
  3. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/base/learner.py +4 -2
  4. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/__init__.py +2 -1
  5. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/label_mapper.py +4 -3
  6. ontolearner-1.5.0/ontolearner/learner/llm.py +465 -0
  7. ontolearner-1.5.0/ontolearner/learner/taxonomy_discovery/alexbek.py +822 -0
  8. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/taxonomy_discovery/skhnlp.py +216 -156
  9. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/biology.py +2 -3
  10. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/chemistry.py +16 -18
  11. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/ecology_environment.py +2 -3
  12. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/general.py +4 -6
  13. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/material_science_engineering.py +64 -45
  14. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/medicine.py +2 -3
  15. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/scholarly_knowledge.py +6 -9
  16. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/processor.py +3 -3
  17. {ontolearner-1.4.11 → ontolearner-1.5.0}/pyproject.toml +1 -1
  18. ontolearner-1.4.11/ontolearner/VERSION +0 -1
  19. ontolearner-1.4.11/ontolearner/learner/llm.py +0 -208
  20. ontolearner-1.4.11/ontolearner/learner/taxonomy_discovery/alexbek.py +0 -500
  21. {ontolearner-1.4.11 → ontolearner-1.5.0}/LICENSE +0 -0
  22. {ontolearner-1.4.11 → ontolearner-1.5.0}/README.md +0 -0
  23. {ontolearner-1.4.11 → ontolearner-1.5.0}/images/logo.png +0 -0
  24. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/__init__.py +0 -0
  25. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/_learner.py +0 -0
  26. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/_ontology.py +0 -0
  27. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/base/__init__.py +0 -0
  28. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/base/ontology.py +0 -0
  29. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/base/text2onto.py +0 -0
  30. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/data_structure/__init__.py +0 -0
  31. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/data_structure/data.py +0 -0
  32. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/data_structure/metric.py +0 -0
  33. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/evaluation/__init__.py +0 -0
  34. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/evaluation/evaluate.py +0 -0
  35. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/evaluation/metrics.py +0 -0
  36. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/prompt.py +0 -0
  37. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/rag/__init__.py +0 -0
  38. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/rag/rag.py +0 -0
  39. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/retriever/__init__.py +0 -0
  40. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/retriever/augmented_retriever.py +0 -0
  41. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/retriever/crossencoder.py +0 -0
  42. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/retriever/embedding.py +0 -0
  43. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/retriever/learner.py +0 -0
  44. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/retriever/ngram.py +0 -0
  45. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/taxonomy_discovery/__init__.py +0 -0
  46. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/taxonomy_discovery/rwthdbis.py +0 -0
  47. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/taxonomy_discovery/sbunlp.py +0 -0
  48. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/term_typing/__init__.py +0 -0
  49. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/term_typing/alexbek.py +0 -0
  50. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/term_typing/rwthdbis.py +0 -0
  51. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/term_typing/sbunlp.py +0 -0
  52. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/text2onto/__init__.py +0 -0
  53. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/text2onto/alexbek.py +0 -0
  54. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/learner/text2onto/sbunlp.py +0 -0
  55. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/__init__.py +0 -0
  56. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/agriculture.py +0 -0
  57. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/arts_humanities.py +0 -0
  58. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/education.py +0 -0
  59. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/events.py +0 -0
  60. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/finance.py +0 -0
  61. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/food_beverage.py +0 -0
  62. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/geography.py +0 -0
  63. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/industry.py +0 -0
  64. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/law.py +0 -0
  65. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/library_cultural_heritage.py +0 -0
  66. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/news_media.py +0 -0
  67. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/social_sciences.py +0 -0
  68. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/units_measurements.py +0 -0
  69. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/upper_ontologies.py +0 -0
  70. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/ontology/web.py +0 -0
  71. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/text2onto/__init__.py +0 -0
  72. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/text2onto/batchifier.py +0 -0
  73. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/text2onto/general.py +0 -0
  74. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/text2onto/splitter.py +0 -0
  75. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/text2onto/synthesizer.py +0 -0
  76. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/tools/__init__.py +0 -0
  77. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/tools/analyzer.py +0 -0
  78. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/tools/visualizer.py +0 -0
  79. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/utils/__init__.py +0 -0
  80. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/utils/io.py +0 -0
  81. {ontolearner-1.4.11 → ontolearner-1.5.0}/ontolearner/utils/train_test_split.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: OntoLearner
- Version: 1.4.11
+ Version: 1.5.0
  Summary: OntoLearner: A Modular Python Library for Ontology Learning with LLMs.
  License: MIT
  License-File: LICENSE
@@ -0,0 +1 @@
+ 1.5.0
@@ -232,7 +232,7 @@ class AutoLLM(ABC):
          tokenizer: The tokenizer associated with the model.
      """

-     def __init__(self, label_mapper: Any, device: str='cpu', token: str="", max_length: int = 256) -> None:
+     def __init__(self, label_mapper: Any, device: str='cpu', token: str="", max_length: int = 512) -> None:
          """
          Initialize the LLM component.

@@ -283,6 +283,7 @@ class AutoLLM(ABC):
          )
          self.label_mapper.fit()

+     @torch.no_grad()
      def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
          """
          Generate text responses for the given input prompts.
@@ -309,7 +310,8 @@ class AutoLLM(ABC):
          encoded_inputs = self.tokenizer(inputs,
                                          return_tensors="pt",
                                          max_length=self.max_length,
-                                         truncation=True).to(self.model.device)
+                                         truncation=True,
+                                         padding=True).to(self.model.device)
          input_ids = encoded_inputs["input_ids"]
          input_length = input_ids.shape[1]
          outputs = self.model.generate(
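
Two of these changes matter mainly for batched prediction: `@torch.no_grad()` disables gradient tracking during inference (no autograd graph, less memory), and `padding=True` lets a list of prompts of different lengths be stacked into a single tensor. A minimal sketch of the same batched-encoding pattern with the Hugging Face tokenizer API (model id and prompts are illustrative, not taken from the package):

from transformers import AutoTokenizer  # illustrative sketch, not OntoLearner code

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
prompts = [
    "Is 'enzyme' a kind of 'protein'? Answer yes or no.",
    "Is 'oxide' a kind of 'material'? Answer yes or no. Think about it first.",
]
encoded = tokenizer(prompts, return_tensors="pt", max_length=512, truncation=True, padding=True)
print(encoded["input_ids"].shape)  # (2, length of the longer prompt); the shorter prompt is padded
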
@@ -12,7 +12,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .llm import AutoLLMLearner, FalconLLM, MistralLLM
+ from .llm import AutoLLMLearner, FalconLLM, MistralLLM, LogitMistralLLM, \
+     QwenInstructLLM, QwenThinkingLLM, LogitAutoLLM, LogitQuantAutoLLM
  from .retriever import AutoRetrieverLearner, LLMAugmentedRetrieverLearner
  from .rag import AutoRAGLearner, LLMAugmentedRAGLearner
  from .prompt import StandardizedPrompting
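
The widened export list mirrors the classes added in the new `ontolearner/learner/llm.py` further down in this diff. A minimal usage sketch, assuming the constructor signatures visible in that file; the import path for `LabelMapper` and its default arguments are inferred from the file list and the `label_mapper.py` hunk below, so treat those as assumptions:

from ontolearner.learner import AutoLLMLearner, LogitAutoLLM, StandardizedPrompting
from ontolearner.learner.label_mapper import LabelMapper  # module path assumed from the file list above

learner = AutoLLMLearner(
    prompting=StandardizedPrompting,  # prompting class; instantiated per task inside the learner, e.g. prompting(task='term-typing')
    label_mapper=LabelMapper(),       # default label_dict maps yes/true -> "yes" and no/false -> "no"
    llm=LogitAutoLLM,                 # wrapper class; instantiated internally as llm(token=..., label_mapper=..., device=...)
    token="hf_...",                   # Hugging Face token, needed only for gated models
    device="cuda",
)
learner.load(model_id="mistralai/Mistral-7B-Instruct-v0.1")  # default model id shown in AutoLLMLearner.load
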
@@ -45,11 +45,12 @@ class LabelMapper:
          if label_dict is None:
              label_dict = {
                  "yes": ["yes", "true"],
-                 "no": ["no", "false", " "]
+                 "no": ["no", "false"]
              }
-         self.labels = [label.lower() for label in list(label_dict.keys())]
+         self.label_dict = label_dict
+         self.labels = [label.lower() for label in list(self.label_dict.keys())]
          self.x_train, self.y_train = [], []
-         for label, candidates in label_dict.items():
+         for label, candidates in self.label_dict.items():
              self.x_train += [label] + candidates
              self.y_train += [label] * (len(candidates) + 1)
          self.x_train = iterator_no * self.x_train
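
Two effects of this change are worth spelling out: dropping `" "` from the "no" candidates means an empty or whitespace-only generation is no longer hard-wired to the "no" label, and keeping `label_dict` on the instance lets the new `LogitAutoLLM` (in `llm.py` below) read the same mapping via `self.label_mapper.label_dict`. For concreteness, a small worked example of what the loop above builds (`iterator_no` is a repetition factor from the surrounding constructor; the value 3 is arbitrary):

label_dict = {"yes": ["yes", "true"], "no": ["no", "false"]}
x_train, y_train = [], []
for label, candidates in label_dict.items():
    x_train += [label] + candidates             # -> ["yes", "yes", "true", "no", "no", "false"]
    y_train += [label] * (len(candidates) + 1)  # -> ["yes", "yes", "yes", "no", "no", "no"]
x_train = 3 * x_train  # iterator_no-fold repetition, as in the line above
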
@@ -0,0 +1,465 @@
+ # Copyright (c) 2025 SciKnowOrg
+ #
+ # Licensed under the MIT License (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://opensource.org/licenses/MIT
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ..base import AutoLLM, AutoLearner
+ from typing import Any, List
+ import warnings
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader
+ import torch
+ import torch.nn.functional as F
+ from transformers import Mistral3ForConditionalGeneration
+ from mistral_common.protocol.instruct.request import ChatCompletionRequest
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+ class AutoLLMLearner(AutoLearner):
+
+     def __init__(self,
+                  prompting,
+                  label_mapper,
+                  llm: AutoLLM = AutoLLM,
+                  token: str = "",
+                  max_new_tokens: int = 5,
+                  batch_size: int = 10,
+                  device='cpu') -> None:
+         super().__init__()
+         self.llm = llm(token=token, label_mapper=label_mapper, device=device)
+         self.prompting = prompting
+         self.batch_size = batch_size
+         self.max_new_tokens = max_new_tokens
+         self._is_term_typing_fit = False
+
+     def load(self, model_id: str = "mistralai/Mistral-7B-Instruct-v0.1", **kwargs: Any):
+         self.llm.load(model_id=model_id)
+
+     def _term_typing_predict(self, dataset):
+         dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
+         predictions = {}
+         for batch in tqdm(dataloader):
+             prediction = self.llm.generate(inputs=batch['prompt'], max_new_tokens=self.max_new_tokens)
+             for term, type, predict in zip(batch['term'], batch['type'], prediction):
+                 if term not in predictions:
+                     predictions[term] = []
+                 if predict == 'yes':
+                     predictions[term].append(type)
+         predicts = [{"term": term, "types": types} for term, types in predictions.items()]
+         return predicts
+
+     def _term_typing(self, data: Any, test: bool = False) -> Any:
+         """
+         during training: data = ["type-1", .... ],
+         during testing: data = ['term-1', ...]
+         """
+         if not isinstance(data, list) and not all(isinstance(item, str) for item in data):
+             raise TypeError("Expected a list of strings (types) for llm at term-typing task.")
+         if test:
+             if self._is_term_typing_fit:
+                 prompting = self.prompting(task='term-typing')
+                 dataset = [{"term": term, "type": type, "prompt": prompting.format(term=term, type=type)}
+                            for term in data for type in self.candidate_types]
+                 return self._term_typing_predict(dataset=dataset)
+             else:
+                 raise RuntimeError("Term typing model must be fit before prediction.")
+         else:
+             self.candidate_types = data
+             self._is_term_typing_fit = True
+
+     def _taxonomy_discovery_predict(self, dataset):
+         dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
+         predictions = []
+         for batch in tqdm(dataloader):
+             prediction = self.llm.generate(inputs=batch['prompt'], max_new_tokens=self.max_new_tokens)
+             predictions.extend({"parent": parent, "child": child}
+                                for parent, child, predict in zip(batch['parent'], batch['child'], prediction)
+                                if predict == 'yes')
+         return predictions
+
+     def _taxonomy_discovery(self, data: Any, test: bool = False) -> Any:
+         """
+         during training: data = ['type-1', ...],
+         during testing (same data): data = ['type-1', ...]
+         """
+         if test:
+             if not isinstance(data, list) and not all(isinstance(item, str) for item in data):
+                 raise TypeError("Expected a list of strings (types) for llm at term-typing task.")
+             prompting = self.prompting(task='taxonomy-discovery')
+             dataset = [{"parent": type_i, "child": type_j, "prompt": prompting.format(parent=type_i, child=type_j)}
+                        for idx, type_i in enumerate(data) for jdx, type_j in enumerate(data) if idx < jdx]
+             return self._taxonomy_discovery_predict(dataset=dataset)
+         else:
+             warnings.warn("No requirement for fitting the taxonomy-discovery model, the predict module will use the input data to do the 'is-a' relationship detection")
+
+     def _non_taxonomic_re_predict(self, dataset):
+         dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
+         predictions = []
+         for batch in tqdm(dataloader):
+             prediction = self.llm.generate(inputs=batch['prompt'], max_new_tokens=self.max_new_tokens)
+             predictions.extend({"head": head, "tail": tail, "relation": relation}
+                                for head, tail, relation, predict in
+                                zip(batch['head'], batch['tail'], batch['relation'], prediction)
+                                if predict == 'yes')
+         return predictions
+
+     def _non_taxonomic_re(self, data: Any, test: bool = False) -> Any:
+         """
+         during training: data = ['type-1', ...],
+         during testing: {'types': [...], 'relations': [... ]}
+         """
+         if test:
+             if 'types' not in data or 'relations' not in data:
+                 raise ValueError("The non-taxonomic re predict should take {'types': [...], 'relations': [... ]}")
+             if len(data['types']) == 0:
+                 warnings.warn("No `types` available to do the non-taxonomic re-prediction.")
+                 return None
+             # pairing and finding pairs that can have a relationship
+             prompting = self.prompting(task='taxonomy-discovery')
+             dataset = [{"parent": type_i, "child": type_j, "prompt": prompting.format(parent=type_i, child=type_j)}
+                        for idx, type_i in enumerate(data['types']) for jdx, type_j in enumerate(data['types']) if idx < jdx]
+             dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
+             predicts_lst = []
+             for batch in tqdm(dataloader):
+                 prediction = self.llm.generate(inputs=batch['prompt'], max_new_tokens=self.max_new_tokens)
+                 predicts_lst.extend((parent, child)
+                                     for parent, child, predict in zip(batch['parent'], batch['child'], prediction)
+                                     if predict == 'yes')
+             # finding relationships
+             prompting = self.prompting(task='non-taxonomic-re')
+             dataset = [{"head": head, "tail": tail, "relation": relation,
+                         "prompt": prompting.format(head=head, tail=tail, relation=relation)}
+                        for head, tail in predicts_lst for relation in data['relations']]
+             return self._non_taxonomic_re_predict(dataset=dataset)
+         else:
+             warnings.warn("No requirement for fitting the non-taxonomic-re model, the predict module will use the input data to do the task.")
+
+
+ class FalconLLM(AutoLLM):
+
+     @torch.no_grad()
+     def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
+         encoded_inputs = self.tokenizer(inputs,
+                                         return_tensors="pt",
+                                         padding=True,
+                                         truncation=True).to(self.model.device)
+         input_ids = encoded_inputs["input_ids"]
+         input_length = input_ids.shape[1]
+         outputs = self.model.generate(
+             input_ids,
+             max_new_tokens=max_new_tokens,
+             pad_token_id=self.tokenizer.eos_token_id
+         )
+         generated_tokens = outputs[:, input_length:]
+         decoded_outputs = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated_tokens]
+         return self.label_mapper.predict(decoded_outputs)
+
+
+ class MistralLLM(AutoLLM):
+
+     def load(self, model_id: str) -> None:
+         self.tokenizer = MistralTokenizer.from_hf_hub(model_id)
+         if self.device == "cpu":
+             device_map = "cpu"
+         else:
+             device_map = "balanced"
+         self.model = Mistral3ForConditionalGeneration.from_pretrained(
+             model_id,
+             device_map=device_map,
+             torch_dtype=torch.bfloat16,
+             token=self.token
+         )
+         if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token_id is None:
+             self.tokenizer.pad_token_id = self.model.generation_config.eos_token_id
+         self.label_mapper.fit()
+
+     @torch.no_grad()
+     def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
+         tokenized_list = []
+         for prompt in inputs:
+             messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+             tokenized = self.tokenizer.encode_chat_completion(ChatCompletionRequest(messages=messages))
+             tokenized_list.append(tokenized.tokens)
+         max_len = max(len(tokens) for tokens in tokenized_list)
+         input_ids, attention_masks = [], []
+         for tokens in tokenized_list:
+             pad_length = max_len - len(tokens)
+             input_ids.append(tokens + [self.tokenizer.pad_token_id] * pad_length)
+             attention_masks.append([1] * len(tokens) + [0] * pad_length)
+
+         input_ids = torch.tensor(input_ids).to(self.model.device)
+         attention_masks = torch.tensor(attention_masks).to(self.model.device)
+
+         outputs = self.model.generate(
+             input_ids=input_ids,
+             attention_mask=attention_masks,
+             eos_token_id=self.model.generation_config.eos_token_id,
+             pad_token_id=self.tokenizer.pad_token_id,
+             max_new_tokens=max_new_tokens,
+         )
+         decoded_outputs = []
+         for i, tokens in enumerate(outputs):
+             output_text = self.tokenizer.decode(tokens[len(tokenized_list[i]):])
+             decoded_outputs.append(output_text)
+         return self.label_mapper.predict(decoded_outputs)
+
+
+ class LogitMistralLLM(AutoLLM):
+     label_dict = {
+         "yes": ["yes", "true", " yes", "Yes"],
+         "no": ["no", "false", " no", "No"]
+     }
+
+     def _get_label_token_ids(self):
+         label_token_ids = {}
+
+         for label, words in self.label_dict.items():
+             ids = []
+             for w in words:
+                 messages = [{"role": "user", "content": [{"type": "text", "text": w}]}]
+                 tokenized = self.tokenizer.encode_chat_completion(ChatCompletionRequest(messages=messages))
+                 token_ids = tokenized.tokens[2:-1]
+                 ids.append(token_ids)
+             label_token_ids[label] = ids
+         return label_token_ids
+
+     def load(self, model_id: str) -> None:
+         self.tokenizer = MistralTokenizer.from_hf_hub(model_id)
+         self.tokenizer.padding_side = 'left'
+         device_map = "cpu" if self.device == "cpu" else "balanced"
+         self.model = Mistral3ForConditionalGeneration.from_pretrained(
+             model_id,
+             device_map=device_map,
+             torch_dtype=torch.bfloat16,
+             token=self.token
+         )
+         self.pad_token_id = self.model.generation_config.eos_token_id
+         self.label_token_ids = self._get_label_token_ids()
+
+     @torch.no_grad()
+     def generate(self, inputs: List[str], max_new_tokens: int = 1) -> List[str]:
+         tokenized_list = []
+         for prompt in inputs:
+             messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+             req = ChatCompletionRequest(messages=messages)
+             tokenized = self.tokenizer.encode_chat_completion(req)
+             tokenized_list.append(tokenized.tokens)
+
+         max_len = max(len(t) for t in tokenized_list)
+         input_ids, attention_masks = [], []
+         for tokens in tokenized_list:
+             pad_len = max_len - len(tokens)
+             input_ids.append(tokens + [self.pad_token_id] * pad_len)
+             attention_masks.append([1] * len(tokens) + [0] * pad_len)
+
+         input_ids = torch.tensor(input_ids).to(self.model.device)
+         attention_masks = torch.tensor(attention_masks).to(self.model.device)
+
+         outputs = self.model(input_ids=input_ids, attention_mask=attention_masks)
+         # logits: [batch, seq_len, vocab]
+         logits = outputs.logits
+         # next-token prediction
+         last_logits = logits[:, -1, :]
+         probs = torch.softmax(last_logits, dim=-1)
+         predictions = []
+         for i in range(probs.size(0)):
+             label_scores = {}
+             for label, token_id_lists in self.label_token_ids.items():
+                 score = 0.0
+                 for token_ids in token_id_lists:
+                     # single-token in practice, but safe
+                     score += probs[i, token_ids[0]].item()
+                 label_scores[label] = score
+             predictions.append(max(label_scores, key=label_scores.get))
+         return predictions
+
+
+ class QwenInstructLLM(AutoLLM):
+
+     def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
+         messages = [[{"role": "user", "content": prompt + " Please show your final response with 'answer': 'label'."}]
+                     for prompt in inputs]
+
+         texts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+         encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True,
+                                         max_length=256).to(self.model.device)
+
+         generated_ids = self.model.generate(**encoded_inputs,
+                                             max_new_tokens=max_new_tokens,
+                                             use_cache=False,
+                                             pad_token_id=self.tokenizer.pad_token_id,
+                                             eos_token_id=self.tokenizer.eos_token_id)
+         decoded_outputs = []
+         for i in range(len(generated_ids)):
+             prompt_len = encoded_inputs.attention_mask[i].sum().item()
+             output_ids = generated_ids[i][prompt_len:].tolist()
+             output_content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()
+             decoded_outputs.append(output_content)
+         return self.label_mapper.predict(decoded_outputs)
+
+
+ class QwenThinkingLLM(AutoLLM):
+
+     @torch.no_grad()
+     def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
+         messages = [[{"role": "user", "content": prompt + " Please show your final response with 'answer': 'label'."}]
+                     for prompt in inputs]
+         texts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True).to(self.model.device)
+         generated_ids = self.model.generate(**encoded_inputs, max_new_tokens=max_new_tokens)
+         decoded_outputs = []
+         for i in range(len(generated_ids)):
+             prompt_len = encoded_inputs.attention_mask[i].sum().item()
+             output_ids = generated_ids[i][prompt_len:].tolist()
+             try:
+                 end = len(output_ids) - output_ids[::-1].index(151668)
+                 thinking_ids = output_ids[:end]
+             except ValueError:
+                 thinking_ids = output_ids
+             thinking_content = self.tokenizer.decode(thinking_ids, skip_special_tokens=True).strip()
+             decoded_outputs.append(thinking_content)
+         return self.label_mapper.predict(decoded_outputs)
+
+
+ class LogitAutoLLM(AutoLLM):
+     def _get_label_token_ids(self):
+         label_token_ids = {}
+         for label, words in self.label_mapper.label_dict.items():
+             ids = []
+             for w in words:
+                 token_ids = self.tokenizer.encode(w, add_special_tokens=False)
+                 ids.append(token_ids)
+             label_token_ids[label] = ids
+         return label_token_ids
+
+     def load(self, model_id: str) -> None:
+         super().load(model_id)
+         self.label_token_ids = self._get_label_token_ids()
+
+     @torch.no_grad()
+     def generate(self, inputs: List[str], max_new_tokens: int = 1) -> List[str]:
+         encoded = self.tokenizer(inputs, return_tensors="pt", truncation=True, padding=True).to(self.model.device)
+         outputs = self.model(**encoded)
+         logits = outputs.logits  # logits: [batch, seq_len, vocab]
+         last_logits = logits[:, -1, :]  # [batch, vocab]; we only care about the NEXT token prediction
+         probs = F.softmax(last_logits, dim=-1)
+         predictions = []
+         for i in range(probs.size(0)):
+             label_scores = {}
+             for label, token_id_lists in self.label_token_ids.items():
+                 score = 0.0
+                 for token_ids in token_id_lists:
+                     if len(token_ids) == 1:
+                         score += probs[i, token_ids[0]].item()
+                     else:
+                         score += probs[i, token_ids[0]].item()  # multi-token fallback (rare but safe)
+                 label_scores[label] = score
+             predictions.append(max(label_scores, key=label_scores.get))
+         return predictions
+
+
+ class LogitQuantAutoLLM(AutoLLM):
+     label_dict = {
+         "yes": ["yes", "true", " yes", "Yes"],
+         "no": ["no", "false", " no", "No"]
+     }
+
+     def _get_label_token_ids(self):
+         label_token_ids = {}
+
+         for label, words in self.label_dict.items():
+             ids = []
+             for w in words:
+                 token_ids = self.tokenizer.encode(
+                     w,
+                     add_special_tokens=False
+                 )
+                 # usually single-token, but be safe
+                 ids.append(token_ids)
+             label_token_ids[label] = ids
+
+         return label_token_ids
+
+     def load(self, model_id: str) -> None:
+
+         self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left', token=self.token)
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+         if self.device == "cpu":
+             # device_map = "cpu"
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_id,
+                 # device_map=device_map,
+                 torch_dtype=torch.bfloat16,
+                 token=self.token
+             )
+         else:
+             device_map = "balanced"
+             # self.model = AutoModelForCausalLM.from_pretrained(
+             #     model_id,
+             #     device_map=device_map,
+             #     torch_dtype=torch.bfloat16,
+             #     token=self.token
+             # )
+             bnb_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_quant_type="nf4",
+                 bnb_4bit_compute_dtype=torch.float16,
+                 bnb_4bit_use_double_quant=True
+             )
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_id,
+                 quantization_config=bnb_config,
+                 device_map=device_map,
+                 token=self.token,
+                 # trust_remote_code=True,
+                 # attn_implementation="flash_attention_2"
+             )
+         self.label_token_ids = self._get_label_token_ids()
+
+     @torch.no_grad()
+     def generate(self, inputs: List[str], max_new_tokens: int = 1) -> List[str]:
+         encoded = self.tokenizer(
+             inputs,
+             return_tensors="pt",
+             max_length=256,
+             truncation=True,
+             padding=True
+         ).to(self.model.device)
+
+         outputs = self.model(**encoded)
+
+         # logits: [batch, seq_len, vocab]
+         logits = outputs.logits
+
+         # we only care about the NEXT token prediction
+         last_logits = logits[:, -1, :]  # [batch, vocab]
+
+         probs = F.softmax(last_logits, dim=-1)
+
+         predictions = []
+
+         for i in range(probs.size(0)):
+             label_scores = {}
+
+             for label, token_id_lists in self.label_token_ids.items():
+                 score = 0.0
+                 for token_ids in token_id_lists:
+                     if len(token_ids) == 1:
+                         score += probs[i, token_ids[0]].item()
+                     else:
+                         # multi-token fallback (rare but safe)
+                         score += probs[i, token_ids[0]].item()
+                 label_scores[label] = score
+
+             predictions.append(max(label_scores, key=label_scores.get))
+         return predictions
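
All of the `Logit*` wrappers above share one idea: instead of decoding free-form text and mapping it back to a label, they run a single forward pass, take the logits at the last prompt position, and compare the probability mass assigned to the first token of each candidate label. A standalone sketch of that technique with plain `transformers` (model id, prompt, and label words are illustrative; this is not OntoLearner code):

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # any causal LM works the same way
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")  # left padding keeps position -1 at the prompt's end
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_id)

yes_id = tokenizer.encode("yes", add_special_tokens=False)[0]
no_id = tokenizer.encode("no", add_special_tokens=False)[0]

prompts = ["Question: Is a dog a type of animal? Answer yes or no. Answer:"]
enc = tokenizer(prompts, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(**enc).logits              # [batch, seq_len, vocab]
probs = F.softmax(logits[:, -1, :], dim=-1)   # next-token distribution
print("yes" if probs[0, yes_id] > probs[0, no_id] else "no")

Left padding matters here for the same reason `LogitQuantAutoLLM` sets `padding_side='left'`: with right padding, the last position of a padded prompt would be a pad token, so the next-token logits would no longer correspond to the end of the prompt.
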