OntoLearner 1.4.10__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. ontolearner/VERSION +1 -1
  2. ontolearner/base/learner.py +41 -18
  3. ontolearner/evaluation/metrics.py +72 -32
  4. ontolearner/learner/__init__.py +3 -2
  5. ontolearner/learner/label_mapper.py +5 -4
  6. ontolearner/learner/llm.py +257 -0
  7. ontolearner/learner/prompt.py +40 -5
  8. ontolearner/learner/rag/__init__.py +14 -0
  9. ontolearner/learner/{rag.py → rag/rag.py} +7 -2
  10. ontolearner/learner/retriever/__init__.py +1 -1
  11. ontolearner/learner/retriever/{llm_retriever.py → augmented_retriever.py} +48 -39
  12. ontolearner/learner/retriever/learner.py +3 -4
  13. ontolearner/learner/taxonomy_discovery/alexbek.py +632 -310
  14. ontolearner/learner/taxonomy_discovery/skhnlp.py +216 -156
  15. ontolearner/learner/text2onto/__init__.py +1 -1
  16. ontolearner/learner/text2onto/alexbek.py +484 -1105
  17. ontolearner/learner/text2onto/sbunlp.py +498 -493
  18. ontolearner/ontology/biology.py +2 -3
  19. ontolearner/ontology/chemistry.py +16 -18
  20. ontolearner/ontology/ecology_environment.py +2 -3
  21. ontolearner/ontology/general.py +4 -6
  22. ontolearner/ontology/material_science_engineering.py +64 -45
  23. ontolearner/ontology/medicine.py +2 -3
  24. ontolearner/ontology/scholarly_knowledge.py +6 -9
  25. ontolearner/processor.py +3 -3
  26. ontolearner/text2onto/splitter.py +69 -6
  27. {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/METADATA +2 -2
  28. {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/RECORD +30 -29
  29. {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/WHEEL +1 -1
  30. {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/licenses/LICENSE +0 -0
ontolearner/VERSION CHANGED
@@ -1 +1 @@
- 1.4.10
+ 1.5.0
ontolearner/base/learner.py CHANGED
@@ -18,6 +18,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
  import torch
  import torch.nn.functional as F
  from sentence_transformers import SentenceTransformer
+ from collections import defaultdict

  class AutoLearner(ABC):
  """
@@ -70,6 +71,7 @@ class AutoLearner(ABC):
  - "term-typing": Predict semantic types for terms
  - "taxonomy-discovery": Identify hierarchical relationships
  - "non-taxonomy-discovery": Identify non-hierarchical relationships
+ - "text2onto": Extract ontology terms and their semantic types from documents

  Raises:
  NotImplementedError: If not implemented by concrete class.
@@ -81,6 +83,8 @@ class AutoLearner(ABC):
  self._taxonomy_discovery(train_data, test=False)
  elif task == 'non-taxonomic-re':
  self._non_taxonomic_re(train_data, test=False)
+ elif task == 'text2onto':
+ self._text2onto(train_data, test=False)
  else:
  raise ValueError(f"{task} is not a valid task.")

@@ -103,6 +107,7 @@ class AutoLearner(ABC):
  - term-typing: List of predicted types for each term
  - taxonomy-discovery: Boolean predictions for relationships
  - non-taxonomy-discovery: Predicted relation types
+ - text2onto: Extract ontology terms and their semantic types from documents

  Raises:
  NotImplementedError: If not implemented by concrete class.
@@ -115,6 +120,8 @@ class AutoLearner(ABC):
  return self._taxonomy_discovery(eval_data, test=True)
  elif task == 'non-taxonomic-re':
  return self._non_taxonomic_re(eval_data, test=True)
+ elif task == 'text2onto':
+ return self._text2onto(eval_data, test=True)
  else:
  raise ValueError(f"{task} is not a valid task.")

@@ -147,6 +154,9 @@ class AutoLearner(ABC):
  def _non_taxonomic_re(self, data: Any, test: bool = False) -> Optional[Any]:
  pass

+ def _text2onto(self, data: Any, test: bool = False) -> Optional[Any]:
+ pass
+
  def tasks_data_former(self, data: Any, task: str, test: bool = False) -> List[str | Dict[str, str]]:
  formatted_data = []
  if task == "term-typing":
@@ -171,6 +181,7 @@ class AutoLearner(ABC):
  non_taxonomic_types = list(set(non_taxonomic_types))
  non_taxonomic_res = list(set(non_taxonomic_res))
  formatted_data = {"types": non_taxonomic_types, "relations": non_taxonomic_res}
+
  return formatted_data

  def tasks_ground_truth_former(self, data: Any, task: str) -> List[Dict[str, str]]:
@@ -186,6 +197,26 @@ class AutoLearner(ABC):
  formatted_data.append({"head": non_taxonomic_triplets.head,
  "tail": non_taxonomic_triplets.tail,
  "relation": non_taxonomic_triplets.relation})
+ if task == "text2onto":
+ terms2docs = data.get("terms2docs", {}) or {}
+ terms2types = data.get("terms2types", {}) or {}
+
+ # gold doc→terms
+ gold_terms = []
+ for term, doc_ids in terms2docs.items():
+ for doc_id in doc_ids or []:
+ gold_terms.append({"doc_id": doc_id, "term": term})
+
+ # gold doc→types derived via doc→terms + term→types
+ doc2types = defaultdict(set)
+ for term, doc_ids in terms2docs.items():
+ for doc_id in doc_ids or []:
+ for ty in (terms2types.get(term, []) or []):
+ if isinstance(ty, str) and ty.strip():
+ doc2types[doc_id].add(ty.strip())
+ gold_types = [{"doc_id": doc_id, "type": ty} for doc_id, tys in doc2types.items() for ty in tys]
+ return {"terms": gold_terms, "types": gold_types}
+
  return formatted_data

  class AutoLLM(ABC):
@@ -201,7 +232,7 @@ class AutoLLM(ABC):
  tokenizer: The tokenizer associated with the model.
  """

- def __init__(self, label_mapper: Any, device: str='cpu', token: str="") -> None:
+ def __init__(self, label_mapper: Any, device: str='cpu', token: str="", max_length: int = 512) -> None:
  """
  Initialize the LLM component.

@@ -213,6 +244,7 @@ class AutoLLM(ABC):
  self.device=device
  self.model: Optional[Any] = None
  self.tokenizer: Optional[Any] = None
+ self.max_length = max_length


  def load(self, model_id: str) -> None:
@@ -236,10 +268,8 @@ class AutoLLM(ABC):
  self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left', token=self.token)
  self.tokenizer.pad_token = self.tokenizer.eos_token
  if self.device == "cpu":
- # device_map = "cpu"
  self.model = AutoModelForCausalLM.from_pretrained(
  model_id,
- # device_map=device_map,
  torch_dtype=torch.bfloat16,
  token=self.token
  )
@@ -248,11 +278,12 @@ class AutoLLM(ABC):
  self.model = AutoModelForCausalLM.from_pretrained(
  model_id,
  device_map=device_map,
- torch_dtype=torch.bfloat16,
- token=self.token
+ token=self.token,
+ trust_remote_code=True,
  )
  self.label_mapper.fit()

+ @torch.no_grad()
  def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
  """
  Generate text responses for the given input prompts.
@@ -276,29 +307,21 @@ class AutoLLM(ABC):
  List of generated text responses, one for each input prompt.
  Responses include the original input plus generated continuation.
  """
- # Tokenize inputs and move to device
  encoded_inputs = self.tokenizer(inputs,
  return_tensors="pt",
- padding=True,
- truncation=True).to(self.model.device)
+ max_length=self.max_length,
+ truncation=True,
+ padding=True).to(self.model.device)
  input_ids = encoded_inputs["input_ids"]
  input_length = input_ids.shape[1]
-
- # Generate output
  outputs = self.model.generate(
  **encoded_inputs,
  max_new_tokens=max_new_tokens,
- pad_token_id=self.tokenizer.eos_token_id
+ pad_token_id=self.tokenizer.eos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id
  )
-
- # Extract only the newly generated tokens (excluding prompt)
  generated_tokens = outputs[:, input_length:]
-
- # Decode only the generated part
  decoded_outputs = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated_tokens]
- print(decoded_outputs)
- print(self.label_mapper.predict(decoded_outputs))
- # Map the decoded text to labels
  return self.label_mapper.predict(decoded_outputs)

  class AutoRetriever(ABC):
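
Editor's note: the new text2onto branch of tasks_ground_truth_former (shown above) builds document-level gold labels by joining terms2docs with terms2types. A minimal standalone sketch of that derivation, using made-up input values that are not part of the package:

    from collections import defaultdict

    # Illustrative inputs in the shape the code above expects (hypothetical values).
    terms2docs = {"graphene": ["doc1", "doc2"], "band gap": ["doc2"]}
    terms2types = {"graphene": ["Material"], "band gap": ["Property"]}

    # Gold doc -> term pairs.
    gold_terms = [{"doc_id": d, "term": t} for t, docs in terms2docs.items() for d in docs]

    # Gold doc -> type pairs, derived by joining doc -> terms with term -> types.
    doc2types = defaultdict(set)
    for term, doc_ids in terms2docs.items():
        for doc_id in doc_ids:
            for ty in terms2types.get(term, []):
                doc2types[doc_id].add(ty.strip())
    gold_types = [{"doc_id": d, "type": ty} for d, tys in doc2types.items() for ty in tys]

    print(gold_terms)  # e.g. [{'doc_id': 'doc1', 'term': 'graphene'}, ...]
    print(gold_types)  # e.g. [{'doc_id': 'doc1', 'type': 'Material'}, ...]
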
ontolearner/evaluation/metrics.py CHANGED
@@ -11,44 +11,84 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- from typing import List, Dict, Tuple, Set
+ from typing import List, Dict, Tuple, Set, Any, Union

  SYMMETRIC_RELATIONS = {"equivalentclass", "sameas", "disjointwith"}

- def text2onto_metrics(y_true: List[str], y_pred: List[str], similarity_threshold: float = 0.8) -> Dict[str, float | int]:
- def jaccard_similarity(a: str, b: str) -> float:
- set_a = set(a.lower().split())
- set_b = set(b.lower().split())
- if not set_a and not set_b:
+ def text2onto_metrics(
+ y_true: Dict[str, Any],
+ y_pred: Dict[str, Any],
+ similarity_threshold: float = 0.8
+ ) -> Dict[str, Any]:
+ """
+ Expects:
+ y_true = {"terms": [{"doc_id": str, "term": str}, ...],
+ "types": [{"doc_id": str, "type": str}, ...]}
+ y_pred = same shape
+
+ Returns:
+ {"terms": {...}, "types": {...}}
+ """
+
+ def jaccard_similarity(text_a: str, text_b: str) -> float:
+ tokens_a = set(text_a.lower().split())
+ tokens_b = set(text_b.lower().split())
+ if not tokens_a and not tokens_b:
  return 1.0
- return len(set_a & set_b) / len(set_a | set_b)
-
- matched_gt_indices = set()
- matched_pred_indices = set()
- for i, pred_label in enumerate(y_pred):
- for j, gt_label in enumerate(y_true):
- if j in matched_gt_indices:
- continue
- sim = jaccard_similarity(pred_label, gt_label)
- if sim >= similarity_threshold:
- matched_pred_indices.add(i)
- matched_gt_indices.add(j)
- break # each gt matched once
-
- total_correct = len(matched_pred_indices)
- total_predicted = len(y_pred)
- total_ground_truth = len(y_true)
+ return len(tokens_a & tokens_b) / len(tokens_a | tokens_b)
+
+ def pairs_to_strings(rows: List[Dict[str, str]], value_key: str) -> List[str]:
+ paired_strings: List[str] = []
+ for row in rows or []:
+ doc_id = (row.get("doc_id") or "").strip()
+ value = (row.get(value_key) or "").strip()
+ if doc_id and value:
+ # keep doc association + allow token Jaccard
+ paired_strings.append(f"{doc_id} {value}")
+ return paired_strings
+
+ def score_list(ground_truth_items: List[str], predicted_items: List[str]) -> Dict[str, Union[float, int]]:
+ matched_ground_truth_indices: Set[int] = set()
+ matched_predicted_indices: Set[int] = set()
+
+ for predicted_index, predicted_item in enumerate(predicted_items):
+ for ground_truth_index, ground_truth_item in enumerate(ground_truth_items):
+ if ground_truth_index in matched_ground_truth_indices:
+ continue
+
+ if jaccard_similarity(predicted_item, ground_truth_item) >= similarity_threshold:
+ matched_predicted_indices.add(predicted_index)
+ matched_ground_truth_indices.add(ground_truth_index)
+ break
+
+ total_correct = len(matched_predicted_indices)
+ total_predicted = len(predicted_items)
+ total_ground_truth = len(ground_truth_items)
+
+ precision = total_correct / total_predicted if total_predicted else 0.0
+ recall = total_correct / total_ground_truth if total_ground_truth else 0.0
+ f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
+
+ return {
+ "f1_score": f1,
+ "precision": precision,
+ "recall": recall,
+ "total_correct": total_correct,
+ "total_predicted": total_predicted,
+ "total_ground_truth": total_ground_truth,
+ }
+
+ ground_truth_terms = pairs_to_strings(y_true.get("terms", []), "term")
+ predicted_terms = pairs_to_strings(y_pred.get("terms", []), "term")
+ ground_truth_types = pairs_to_strings(y_true.get("types", []), "type")
+ predicted_types = pairs_to_strings(y_pred.get("types", []), "type")
+
+ terms_metrics = score_list(ground_truth_terms, predicted_terms)
+ types_metrics = score_list(ground_truth_types, predicted_types)

- precision = total_correct / total_predicted if total_predicted > 0 else 0
- recall = total_correct / total_ground_truth if total_ground_truth > 0 else 0
- f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
  return {
- "f1_score": f1_score,
- "precision": precision,
- "recall": recall,
- "total_correct": total_correct,
- "total_predicted": total_predicted,
- "total_ground_truth": total_ground_truth
+ "terms": terms_metrics,
+ "types": types_metrics,
  }

  def term_typing_metrics(y_true: List[Dict[str, List[str]]], y_pred: List[Dict[str, List[str]]]) -> Dict[str, float | int]:
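
Editor's note: the reworked text2onto_metrics now scores terms and types separately, matching doc_id-prefixed strings under a token-level Jaccard threshold. A usage sketch with made-up gold and predicted values (the import path follows the file layout listed above):

    from ontolearner.evaluation.metrics import text2onto_metrics

    y_true = {"terms": [{"doc_id": "doc1", "term": "graphene"}],
              "types": [{"doc_id": "doc1", "type": "Material"}]}
    y_pred = {"terms": [{"doc_id": "doc1", "term": "graphene"}],
              "types": [{"doc_id": "doc1", "type": "Material"},
                        {"doc_id": "doc2", "type": "Process"}]}

    scores = text2onto_metrics(y_true, y_pred, similarity_threshold=0.8)
    print(scores["terms"]["f1_score"])   # 1.0: the single gold term is matched
    print(scores["types"]["precision"])  # 0.5: one of the two predicted types matches
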
ontolearner/learner/__init__.py CHANGED
@@ -12,8 +12,9 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .llm import AutoLLMLearner, FalconLLM, MistralLLM
+ from .llm import AutoLLMLearner, FalconLLM, MistralLLM, LogitMistralLLM, \
+ QwenInstructLLM, QwenThinkingLLM, LogitAutoLLM, LogitQuantAutoLLM
  from .retriever import AutoRetrieverLearner, LLMAugmentedRetrieverLearner
- from .rag import AutoRAGLearner
+ from .rag import AutoRAGLearner, LLMAugmentedRAGLearner
  from .prompt import StandardizedPrompting
  from .label_mapper import LabelMapper
ontolearner/learner/label_mapper.py CHANGED
@@ -31,7 +31,7 @@ class LabelMapper:
  ngram_range: Tuple=(1, 1),
  label_dict: Dict[str, List[str]]=None,
  analyzer: str = 'word',
- iterator_no: int = 100):
+ iterator_no: int = 1000):
  """
  Initializes the TFIDFLabelMapper with a specified classifier and TF-IDF configuration.

@@ -45,11 +45,12 @@ class LabelMapper:
  if label_dict is None:
  label_dict = {
  "yes": ["yes", "true"],
- "no": ["no", "false", " "]
+ "no": ["no", "false"]
  }
- self.labels = [label.lower() for label in list(label_dict.keys())]
+ self.label_dict = label_dict
+ self.labels = [label.lower() for label in list(self.label_dict.keys())]
  self.x_train, self.y_train = [], []
- for label, candidates in label_dict.items():
+ for label, candidates in self.label_dict.items():
  self.x_train += [label] + candidates
  self.y_train += [label] * (len(candidates) + 1)
  self.x_train = iterator_no * self.x_train
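
Editor's note: the LabelMapper change keeps label_dict on the instance (so the new logit-based learners below can read self.label_mapper.label_dict) and repeats the label/synonym training pairs iterator_no times before fitting. A minimal standalone sketch of the underlying TF-IDF mapping idea, not the package's actual class:

    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import make_pipeline

    label_dict = {"yes": ["yes", "true"], "no": ["no", "false"]}
    iterator_no = 1000

    # Build (text, label) pairs from each label and its synonyms, then oversample them.
    x_train, y_train = [], []
    for label, candidates in label_dict.items():
        x_train += [label] + candidates
        y_train += [label] * (len(candidates) + 1)
    x_train, y_train = x_train * iterator_no, y_train * iterator_no

    mapper = make_pipeline(TfidfVectorizer(analyzer="word", ngram_range=(1, 1)),
                           LogisticRegression(max_iter=1000))
    mapper.fit(x_train, y_train)
    print(mapper.predict(["Yes, it is.", "No."]))  # free-form LLM output -> canonical labels
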
ontolearner/learner/llm.py CHANGED
@@ -18,9 +18,11 @@ import warnings
  from tqdm import tqdm
  from torch.utils.data import DataLoader
  import torch
+ import torch.nn.functional as F
  from transformers import Mistral3ForConditionalGeneration
  from mistral_common.protocol.instruct.request import ChatCompletionRequest
  from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

  class AutoLLMLearner(AutoLearner):

@@ -144,6 +146,7 @@ class AutoLLMLearner(AutoLearner):

  class FalconLLM(AutoLLM):

+ @torch.no_grad()
  def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
  encoded_inputs = self.tokenizer(inputs,
  return_tensors="pt",
@@ -160,6 +163,7 @@ class FalconLLM(AutoLLM):
  decoded_outputs = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated_tokens]
  return self.label_mapper.predict(decoded_outputs)

+
  class MistralLLM(AutoLLM):

  def load(self, model_id: str) -> None:
@@ -178,6 +182,7 @@ class MistralLLM(AutoLLM):
  self.tokenizer.pad_token_id = self.model.generation_config.eos_token_id
  self.label_mapper.fit()

+ @torch.no_grad()
  def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
  tokenized_list = []
  for prompt in inputs:
@@ -206,3 +211,255 @@ class MistralLLM(AutoLLM):
  output_text = self.tokenizer.decode(tokens[len(tokenized_list[i]):])
  decoded_outputs.append(output_text)
  return self.label_mapper.predict(decoded_outputs)
+
+
+ class LogitMistralLLM(AutoLLM):
+ label_dict = {
+ "yes": ["yes", "true", " yes", "Yes"],
+ "no": ["no", "false", " no", "No"]
+ }
+
+ def _get_label_token_ids(self):
+ label_token_ids = {}
+
+ for label, words in self.label_dict.items():
+ ids = []
+ for w in words:
+ messages = [{"role": "user", "content": [{"type": "text", "text": w}]}]
+ tokenized = self.tokenizer.encode_chat_completion(ChatCompletionRequest(messages=messages))
+ token_ids = tokenized.tokens[2:-1]
+ ids.append(token_ids)
+ label_token_ids[label] = ids
+ return label_token_ids
+
+ def load(self, model_id: str) -> None:
+ self.tokenizer = MistralTokenizer.from_hf_hub(model_id)
+ self.tokenizer.padding_side = 'left'
+ device_map = "cpu" if self.device == "cpu" else "balanced"
+ self.model = Mistral3ForConditionalGeneration.from_pretrained(
+ model_id,
+ device_map=device_map,
+ torch_dtype=torch.bfloat16,
+ token=self.token
+ )
+ self.pad_token_id = self.model.generation_config.eos_token_id
+ self.label_token_ids = self._get_label_token_ids()
+
+ @torch.no_grad()
+ def generate(self, inputs: List[str], max_new_tokens: int = 1) -> List[str]:
+ tokenized_list = []
+ for prompt in inputs:
+ messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+ req = ChatCompletionRequest(messages=messages)
+ tokenized = self.tokenizer.encode_chat_completion(req)
+ tokenized_list.append(tokenized.tokens)
+
+ max_len = max(len(t) for t in tokenized_list)
+ input_ids, attention_masks = [], []
+ for tokens in tokenized_list:
+ pad_len = max_len - len(tokens)
+ input_ids.append(tokens + [self.pad_token_id] * pad_len)
+ attention_masks.append([1] * len(tokens) + [0] * pad_len)
+
+ input_ids = torch.tensor(input_ids).to(self.model.device)
+ attention_masks = torch.tensor(attention_masks).to(self.model.device)
+
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_masks)
+ # logits: [batch, seq_len, vocab]
+ logits = outputs.logits
+ # next-token prediction
+ last_logits = logits[:, -1, :]
+ probs = torch.softmax(last_logits, dim=-1)
+ predictions = []
+ for i in range(probs.size(0)):
+ label_scores = {}
+ for label, token_id_lists in self.label_token_ids.items():
+ score = 0.0
+ for token_ids in token_id_lists:
+ # single-token in practice, but safe
+ score += probs[i, token_ids[0]].item()
+ label_scores[label] = score
+ predictions.append(max(label_scores, key=label_scores.get))
+ return predictions
+
+
+ class QwenInstructLLM(AutoLLM):
+
+ def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
+ messages = [[{"role": "user", "content": prompt + " Please show your final response with 'answer': 'label'."}]
+ for prompt in inputs]
+
+ texts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+ encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True,
+ max_length=256).to(self.model.device)
+
+ generated_ids = self.model.generate(**encoded_inputs,
+ max_new_tokens=max_new_tokens,
+ use_cache=False,
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=self.tokenizer.eos_token_id)
+ decoded_outputs = []
+ for i in range(len(generated_ids)):
+ prompt_len = encoded_inputs.attention_mask[i].sum().item()
+ output_ids = generated_ids[i][prompt_len:].tolist()
+ output_content = self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()
+ decoded_outputs.append(output_content)
+ return self.label_mapper.predict(decoded_outputs)
+
+
+ class QwenThinkingLLM(AutoLLM):
+
+ @torch.no_grad()
+ def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
+ messages = [[{"role": "user", "content": prompt + " Please show your final response with 'answer': 'label'."}]
+ for prompt in inputs]
+ texts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True).to(self.model.device)
+ generated_ids = self.model.generate(**encoded_inputs, max_new_tokens=max_new_tokens)
+ decoded_outputs = []
+ for i in range(len(generated_ids)):
+ prompt_len = encoded_inputs.attention_mask[i].sum().item()
+ output_ids = generated_ids[i][prompt_len:].tolist()
+ try:
+ end = len(output_ids) - output_ids[::-1].index(151668)
+ thinking_ids = output_ids[:end]
+ except ValueError:
+ thinking_ids = output_ids
+ thinking_content = self.tokenizer.decode(thinking_ids, skip_special_tokens=True).strip()
+ decoded_outputs.append(thinking_content)
+ return self.label_mapper.predict(decoded_outputs)
+
+
+ class LogitAutoLLM(AutoLLM):
+ def _get_label_token_ids(self):
+ label_token_ids = {}
+ for label, words in self.label_mapper.label_dict.items():
+ ids = []
+ for w in words:
+ token_ids = self.tokenizer.encode(w, add_special_tokens=False)
+ ids.append(token_ids)
+ label_token_ids[label] = ids
+ return label_token_ids
+
+ def load(self, model_id: str) -> None:
+ super().load(model_id)
+ self.label_token_ids = self._get_label_token_ids()
+
+ @torch.no_grad()
+ def generate(self, inputs: List[str], max_new_tokens: int = 1) -> List[str]:
+ encoded = self.tokenizer(inputs, return_tensors="pt", truncation=True, padding=True).to(self.model.device)
+ outputs = self.model(**encoded)
+ logits = outputs.logits # logits: [batch, seq_len, vocab]
+ last_logits = logits[:, -1, :] # [batch, vocab] # we only care about the NEXT token prediction
+ probs = F.softmax(last_logits, dim=-1)
+ predictions = []
+ for i in range(probs.size(0)):
+ label_scores = {}
+ for label, token_id_lists in self.label_token_ids.items():
+ score = 0.0
+ for token_ids in token_id_lists:
+ if len(token_ids) == 1:
+ score += probs[i, token_ids[0]].item()
+ else:
+ score += probs[i, token_ids[0]].item() # multi-token fallback (rare but safe)
+ label_scores[label] = score
+ predictions.append(max(label_scores, key=label_scores.get))
+ return predictions
+
+
+ class LogitQuantAutoLLM(AutoLLM):
+ label_dict = {
+ "yes": ["yes", "true", " yes", "Yes"],
+ "no": ["no", "false", " no", "No"]
+ }
+
+ def _get_label_token_ids(self):
+ label_token_ids = {}
+
+ for label, words in self.label_dict.items():
+ ids = []
+ for w in words:
+ token_ids = self.tokenizer.encode(
+ w,
+ add_special_tokens=False
+ )
+ # usually single-token, but be safe
+ ids.append(token_ids)
+ label_token_ids[label] = ids
+
+ return label_token_ids
+
+ def load(self, model_id: str) -> None:
+
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left', token=self.token)
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+ if self.device == "cpu":
+ # device_map = "cpu"
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ # device_map=device_map,
+ torch_dtype=torch.bfloat16,
+ token=self.token
+ )
+ else:
+ device_map = "balanced"
+ # self.model = AutoModelForCausalLM.from_pretrained(
+ # model_id,
+ # device_map=device_map,
+ # torch_dtype=torch.bfloat16,
+ # token=self.token
+ # )
+ bnb_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+ bnb_4bit_use_double_quant=True
+ )
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ quantization_config=bnb_config,
+ device_map=device_map,
+ token=self.token,
+ # trust_remote_code=True,
+ # attn_implementation="flash_attention_2"
+ )
+ self.label_token_ids = self._get_label_token_ids()
+
+ @torch.no_grad()
+ def generate(self, inputs: List[str], max_new_tokens: int = 1) -> List[str]:
+ encoded = self.tokenizer(
+ inputs,
+ return_tensors="pt",
+ max_length=256,
+ truncation=True,
+ padding=True
+ ).to(self.model.device)
+
+ outputs = self.model(**encoded)
+
+ # logits: [batch, seq_len, vocab]
+ logits = outputs.logits
+
+ # we only care about the NEXT token prediction
+ last_logits = logits[:, -1, :] # [batch, vocab]
+
+ probs = F.softmax(last_logits, dim=-1)
+
+ predictions = []
+
+ for i in range(probs.size(0)):
+ label_scores = {}
+
+ for label, token_id_lists in self.label_token_ids.items():
+ score = 0.0
+ for token_ids in token_id_lists:
+ if len(token_ids) == 1:
+ score += probs[i, token_ids[0]].item()
+ else:
+ # multi-token fallback (rare but safe)
+ score += probs[i, token_ids[0]].item()
+ label_scores[label] = score
+
+ predictions.append(max(label_scores, key=label_scores.get))
+ return predictions
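
Editor's note: the Logit* learners added above avoid free-text generation for yes/no decisions; they run a single forward pass, take the next-token logits, and compare the probability mass assigned to the "yes" versus "no" label tokens. A compressed standalone sketch of that scoring step with a generic Hugging Face causal LM (the model id is a placeholder for illustration, not what OntoLearner ships with):

    import torch
    import torch.nn.functional as F
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "gpt2"  # placeholder model for the sketch
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id)

    label_words = {"yes": ["yes", " yes", "Yes"], "no": ["no", " no", "No"]}
    label_token_ids = {label: [tokenizer.encode(w, add_special_tokens=False)[0] for w in words]
                       for label, words in label_words.items()}

    prompts = ["Term: gold\nType: metal\nAnswer (yes or no):"]
    with torch.no_grad():
        encoded = tokenizer(prompts, return_tensors="pt", padding=True)
        probs = F.softmax(model(**encoded).logits[:, -1, :], dim=-1)  # next-token distribution

    for i in range(probs.size(0)):
        scores = {label: sum(probs[i, tid].item() for tid in ids)
                  for label, ids in label_token_ids.items()}
        print(max(scores, key=scores.get))  # label whose tokens received more probability mass
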
ontolearner/learner/prompt.py CHANGED
@@ -17,15 +17,50 @@ from ..base import AutoPrompt
  class StandardizedPrompting(AutoPrompt):
  def __init__(self, task: str = None):
  if task == "term-typing":
- prompt_template = """Determine whether the given term can be categorized as an instance of the specified high-level type. Answer with `yes` if it is otherwise answer with `no`. Do not explain.
+ prompt_template = """You are performing term typing.
+
+ Determine whether the given term is a clear and unambiguous instance of the specified high-level type.
+
+ Rules:
+ - Answer "yes" only if the term commonly and directly belongs to the type.
+ - Answer "no" if the term does not belong to the type, is ambiguous, or only weakly related.
+ - Use the most common meaning of the term.
+ - Do not explain your answer.
+
  Term: {term}
  Type: {type}
- Answer: """
+ Answer (yes or no):"""
  elif task == "taxonomy-discovery":
- prompt_template = """Is {parent} a direct or indirect superclass (or parent concept) of {child} in a conceptual hierarchy? Answer with yes or no.
- Answer: """
+ prompt_template = """You are identifying taxonomic (is-a) relationships.
+
+ Question:
+ Is "{parent}" a superclass (direct or indirect) of "{child}" in a standard conceptual or ontological hierarchy?
+
+ Rules:
+ - A superclass means: "{child}" is a type or instance of "{parent}".
+ - Answer "yes" only if the relationship is a true is-a relationship.
+ - Answer "no" for part-of, related-to, or associative relationships.
+ - Use general world knowledge.
+ - Do not explain.
+
+ Parent: {parent}
+ Child: {child}
+ Answer (yes or no):"""
  elif task == "non-taxonomic-re":
- prompt_template = """Given the conceptual types `{head}` and `{tail}`, does a `{relation}` relation exist between them? Respond with "yes" if it does, otherwise respond with "no"."""
+ prompt_template = """You are identifying non-taxonomic conceptual relationships.
+
+ Given two conceptual types, determine whether the specified relation typically holds between them.
+
+ Rules:
+ - Answer "yes" only if the relation commonly and meaningfully applies.
+ - Answer "no" if the relation is rare, indirect, or context-dependent.
+ - Do not infer relations that require specific situations.
+ - Do not explain.
+
+ Head type: {head}
+ Tail type: {tail}
+ Relation: {relation}
+ Answer (yes or no):"""
  else:
  raise ValueError("Unknown task! Current tasks are: 'term-typing', 'taxonomy-discovery', 'non-taxonomic-re'")
  super().__init__(prompt_template)
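
Editor's note: the rewritten prompts keep the same placeholders ({term}/{type}, {parent}/{child}, {head}/{tail}/{relation}), so they can still be filled as ordinary Python format strings. A small illustration with made-up values; the template text is abbreviated here:

    template = ("You are performing term typing.\n"
                "...\n"
                "Term: {term}\n"
                "Type: {type}\n"
                "Answer (yes or no):")
    print(template.format(term="graphene", type="material"))
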