OntoLearner 1.4.10__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. ontolearner/VERSION +1 -1
  2. ontolearner/base/learner.py +41 -18
  3. ontolearner/evaluation/metrics.py +72 -32
  4. ontolearner/learner/__init__.py +3 -2
  5. ontolearner/learner/label_mapper.py +5 -4
  6. ontolearner/learner/llm.py +257 -0
  7. ontolearner/learner/prompt.py +40 -5
  8. ontolearner/learner/rag/__init__.py +14 -0
  9. ontolearner/learner/{rag.py → rag/rag.py} +7 -2
  10. ontolearner/learner/retriever/__init__.py +1 -1
  11. ontolearner/learner/retriever/{llm_retriever.py → augmented_retriever.py} +48 -39
  12. ontolearner/learner/retriever/learner.py +3 -4
  13. ontolearner/learner/taxonomy_discovery/alexbek.py +632 -310
  14. ontolearner/learner/taxonomy_discovery/skhnlp.py +216 -156
  15. ontolearner/learner/text2onto/__init__.py +1 -1
  16. ontolearner/learner/text2onto/alexbek.py +484 -1105
  17. ontolearner/learner/text2onto/sbunlp.py +498 -493
  18. ontolearner/ontology/biology.py +2 -3
  19. ontolearner/ontology/chemistry.py +16 -18
  20. ontolearner/ontology/ecology_environment.py +2 -3
  21. ontolearner/ontology/general.py +4 -6
  22. ontolearner/ontology/material_science_engineering.py +64 -45
  23. ontolearner/ontology/medicine.py +2 -3
  24. ontolearner/ontology/scholarly_knowledge.py +6 -9
  25. ontolearner/processor.py +3 -3
  26. ontolearner/text2onto/splitter.py +69 -6
  27. {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/METADATA +2 -2
  28. {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/RECORD +30 -29
  29. {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/WHEEL +1 -1
  30. {ontolearner-1.4.10.dist-info → ontolearner-1.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -11,7 +11,6 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
15
14
  import os
16
15
  import re
17
16
  import random
@@ -31,7 +30,7 @@ from transformers import (
31
30
  Trainer,
32
31
  TrainingArguments,
33
32
  )
34
-
33
+ from tqdm import tqdm
35
34
  from ...base import AutoLearner, AutoPrompt
36
35
  from ...utils import taxonomy_split, train_test_split as ontology_split
37
36
  from ...data_structure import OntologyData, TaxonomicRelation
@@ -88,31 +87,31 @@ class SKHNLPSequentialFTLearner(AutoLearner):
88
87
  """
89
88
 
90
89
  def __init__(
91
- self,
92
- # core
93
- model_name: str = "bert-large-uncased",
94
- n_prompts: int = 7,
95
- random_state: int = 1403,
96
- num_labels: int = 2,
97
- device: str = "cpu", # "cuda" | "cpu" | None (auto)
98
- # data split & negative sampling (now configurable)
99
- eval_fraction: float = 0.16,
100
- neg_ratio_reversed: float = 1 / 3,
101
- neg_ratio_manipulated: float = 2 / 3,
102
- # ---- expose TrainingArguments as individual user-defined args ----
103
- output_dir: str = "./results/",
104
- num_train_epochs: int = 1,
105
- per_device_train_batch_size: int = 4,
106
- per_device_eval_batch_size: int = 4,
107
- warmup_steps: int = 500,
108
- weight_decay: float = 0.01,
109
- logging_dir: str = "./logs/",
110
- logging_steps: int = 50,
111
- eval_strategy: str = "epoch",
112
- save_strategy: str = "epoch",
113
- load_best_model_at_end: bool = True,
114
- use_fast_tokenizer: Optional[bool] = None,
115
- trust_remote_code: bool = False,
90
+ self,
91
+ # core
92
+ model_name: str = "bert-large-uncased",
93
+ n_prompts: int = 7,
94
+ random_state: int = 1403,
95
+ num_labels: int = 2,
96
+ device: str = "cpu", # "cuda" | "cpu" | None (auto)
97
+ # data split & negative sampling (now configurable)
98
+ eval_fraction: float = 0.16,
99
+ neg_ratio_reversed: float = 1 / 3,
100
+ neg_ratio_manipulated: float = 2 / 3,
101
+ # ---- expose TrainingArguments as individual user-defined args ----
102
+ output_dir: str = "./results/",
103
+ num_train_epochs: int = 1,
104
+ per_device_train_batch_size: int = 4,
105
+ per_device_eval_batch_size: int = 4,
106
+ warmup_steps: int = 500,
107
+ weight_decay: float = 0.01,
108
+ logging_dir: str = "./logs/",
109
+ logging_steps: int = 50,
110
+ eval_strategy: str = "epoch",
111
+ save_strategy: str = "epoch",
112
+ load_best_model_at_end: bool = True,
113
+ use_fast_tokenizer: Optional[bool] = None,
114
+ trust_remote_code: bool = False,
116
115
  ) -> None:
117
116
  """Configure the sequential fine-tuning learner.
118
117
 
@@ -170,6 +169,7 @@ class SKHNLPSequentialFTLearner(AutoLearner):
170
169
  self._last_eval: Optional[pd.DataFrame] = None
171
170
  self.trust_remote_code = bool(trust_remote_code)
172
171
  self.use_fast_tokenizer = use_fast_tokenizer
172
+ self.per_device_eval_batch_size = per_device_eval_batch_size
173
173
 
174
174
  random.seed(self.random_state)
175
175
 
@@ -216,9 +216,9 @@ class SKHNLPSequentialFTLearner(AutoLearner):
216
216
  if getattr(self.tokenizer, "pad_token", None) is None:
217
217
  # Try sensible fallbacks
218
218
  fallback = (
219
- getattr(self.tokenizer, "eos_token", None)
220
- or getattr(self.tokenizer, "sep_token", None)
221
- or getattr(self.tokenizer, "cls_token", None)
219
+ getattr(self.tokenizer, "eos_token", None)
220
+ or getattr(self.tokenizer, "sep_token", None)
221
+ or getattr(self.tokenizer, "cls_token", None)
222
222
  )
223
223
  if fallback is not None:
224
224
  self.tokenizer.pad_token = fallback
@@ -234,8 +234,8 @@ class SKHNLPSequentialFTLearner(AutoLearner):
234
234
 
235
235
  # Make sure padding ids line up
236
236
  if (
237
- getattr(self.model.config, "pad_token_id", None) is None
238
- and getattr(self.tokenizer, "pad_token_id", None) is not None
237
+ getattr(self.model.config, "pad_token_id", None) is None
238
+ and getattr(self.tokenizer, "pad_token_id", None) is not None
239
239
  ):
240
240
  self.model.config.pad_token_id = self.tokenizer.pad_token_id
241
241
 
@@ -279,7 +279,7 @@ class SKHNLPSequentialFTLearner(AutoLearner):
279
279
  return super().tasks_ground_truth_former(data, task)
280
280
 
281
281
  def _make_negatives(
282
- self, positives_df: pd.DataFrame
282
+ self, positives_df: pd.DataFrame
283
283
  ) -> Tuple[pd.DataFrame, pd.DataFrame]:
284
284
  """Create two types of negatives from a positives table.
285
285
 
@@ -313,10 +313,10 @@ class SKHNLPSequentialFTLearner(AutoLearner):
313
313
  return as_reversed(positives_df), with_random_parent(positives_df)
314
314
 
315
315
  def _balance_with_negatives(
316
- self,
317
- positives_df: pd.DataFrame,
318
- reversed_df: pd.DataFrame,
319
- manipulated_df: pd.DataFrame,
316
+ self,
317
+ positives_df: pd.DataFrame,
318
+ reversed_df: pd.DataFrame,
319
+ manipulated_df: pd.DataFrame,
320
320
  ) -> pd.DataFrame:
321
321
  """Combine positives with negatives using configured ratios.
322
322
 
@@ -368,7 +368,7 @@ class SKHNLPSequentialFTLearner(AutoLearner):
368
368
  return out
369
369
 
370
370
  def _df_from_relations(
371
- self, relations: List[TaxonomicRelation], label: bool = True
371
+ self, relations: List[TaxonomicRelation], label: bool = True
372
372
  ) -> pd.DataFrame:
373
373
  """Convert a list of `TaxonomicRelation` to a DataFrame.
374
374
 
@@ -400,7 +400,7 @@ class SKHNLPSequentialFTLearner(AutoLearner):
400
400
  ]
401
401
 
402
402
  def _build_masked_prompt(
403
- self, parent: str, child: str, index_1_based: int, mask_token: str = "[MASK]"
403
+ self, parent: str, child: str, index_1_based: int, mask_token: str = "[MASK]"
404
404
  ) -> str:
405
405
  """Construct one of several True/False prompts with a mask token.
406
406
 
@@ -441,7 +441,7 @@ class SKHNLPSequentialFTLearner(AutoLearner):
441
441
  predicted_label = torch.argmax(logits, dim=1).item()
442
442
  return predicted_label == 1
443
443
 
444
- def _select_parent_via_prompts(self, child: str) -> str:
444
+ def _select_parent_via_prompts_main(self, child: str) -> str:
445
445
  """Select the most likely parent for a given child via prompt voting.
446
446
 
447
447
  The procedure:
@@ -493,6 +493,128 @@ class SKHNLPSequentialFTLearner(AutoLearner):
493
493
 
494
494
  return recurse(list(scores.keys()), level=0)
495
495
 
496
+ def _select_parent_via_prompts(self, child: str) -> str:
497
+ """Select the most likely parent for a given child via batched inference.
498
+
499
+ This vectorized approach processes all candidate parents simultaneously,
500
+ making it O(n_prompts) instead of O(candidates × n_prompts).
501
+
502
+ Args:
503
+ child: The child label whose parent should be predicted.
504
+
505
+ Returns:
506
+ The chosen parent string with highest average probability.
507
+
508
+ Raises:
509
+ AssertionError: If candidate parents were not initialized.
510
+ """
511
+ assert self._candidate_parents, "Candidate parents not initialized."
512
+
513
+ # Build all prompts for all candidates at once
514
+ # Shape: (n_candidates, n_prompts)
515
+ all_prompts = []
516
+ for parent in self._candidate_parents:
517
+ parent_prompts = [
518
+ self._build_masked_prompt(parent, child, idx + 1)
519
+ for idx in range(self.n_prompts)
520
+ ]
521
+ all_prompts.extend(parent_prompts)
522
+
523
+ # Tokenize all prompts in one batch
524
+ encodings = self.tokenizer(
525
+ all_prompts,
526
+ return_tensors="pt",
527
+ padding=True,
528
+ truncation=True,
529
+ ).to(self.model.device)
530
+
531
+ # Single forward pass for all prompts
532
+ logits = self.model(**encodings).logits
533
+
534
+ # Get probabilities for the "True" class (index 1)
535
+ true_probs = torch.softmax(logits, dim=1)[:, 1]
536
+
537
+ # Reshape to (n_candidates, n_prompts)
538
+ true_probs = true_probs.view(len(self._candidate_parents), self.n_prompts)
539
+
540
+ # Average across prompts for each candidate
541
+ avg_scores = true_probs.mean(dim=1)
542
+
543
+ # Select candidate with highest average score
544
+ best_idx = torch.argmax(avg_scores).item()
545
+
546
+ return self._candidate_parents[best_idx]
547
+
548
+
549
+ def _select_parents_batch(self,
550
+ children: List[str],
551
+ batch_size: int = 8,
552
+ max_prompts_per_forward: int = 32) -> List[str]:
553
+ """Select parents for multiple children with memory-efficient chunked inference.
554
+
555
+ Args:
556
+ children: List of child labels to find parents for.
557
+ batch_size: Number of children to process simultaneously.
558
+ max_prompts_per_forward: Max prompts per forward pass (reduce if OOM: try 32, 16, or 8).
559
+
560
+ Returns:
561
+ List of chosen parent strings (same order as input children).
562
+ """
563
+ assert self._candidate_parents, "Candidate parents not initialized."
564
+
565
+ all_predictions = []
566
+
567
+ for batch_start in tqdm(range(0, len(children), batch_size), desc="Batch parent selection"):
568
+ batch_children = children[batch_start:batch_start + batch_size]
569
+
570
+ # Build all prompts for this batch of children
571
+ # Shape: (batch_size × n_candidates × n_prompts)
572
+ batch_prompts = []
573
+ for child in batch_children:
574
+ for parent in self._candidate_parents:
575
+ for idx in range(self.n_prompts):
576
+ batch_prompts.append(
577
+ self._build_masked_prompt(parent, child, idx + 1)
578
+ )
579
+
580
+ # Process prompts in chunks to avoid OOM
581
+ all_true_probs = []
582
+ for chunk_start in range(0, len(batch_prompts), max_prompts_per_forward):
583
+ chunk_prompts = batch_prompts[chunk_start:chunk_start + max_prompts_per_forward]
584
+
585
+ # Tokenize chunk
586
+ encodings = self.tokenizer(
587
+ chunk_prompts,
588
+ return_tensors="pt",
589
+ padding=True,
590
+ max_length=256,
591
+ truncation=True,
592
+ ).to(self.model.device)
593
+
594
+ # Forward pass on chunk
595
+ logits = self.model(**encodings).logits
596
+ true_probs = torch.softmax(logits, dim=1)[:, 1]
597
+ all_true_probs.append(true_probs.cpu()) # Move to CPU to free GPU memory
598
+
599
+ # Clear GPU cache
600
+ del encodings, logits
601
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
602
+
603
+ # Concatenate all chunks
604
+ all_true_probs = torch.cat(all_true_probs, dim=0)
605
+
606
+ # Reshape: (batch_size, n_candidates, n_prompts)
607
+ all_true_probs = all_true_probs.view(len(batch_children), len(self._candidate_parents), self.n_prompts)
608
+ # Average across prompts, then select best candidate per child
609
+ avg_scores = all_true_probs.mean(dim=2) # (batch_size, n_candidates)
610
+ best_indices = torch.argmax(avg_scores, dim=1) # (batch_size,)
611
+
612
+ # Map indices to parent names
613
+ batch_predictions = [self._candidate_parents[idx.item()] for idx in best_indices]
614
+ all_predictions.extend(batch_predictions)
615
+
616
+ return all_predictions
617
+
496
618
  def _taxonomy_discovery(self, data: Any, test: bool = False):
497
619
  """
498
620
  TRAIN:
@@ -540,18 +662,24 @@ class SKHNLPSequentialFTLearner(AutoLearner):
540
662
 
541
663
  if test:
542
664
  if is_ontology_object and self._candidate_parents:
543
- predictions: List[dict[str, str]] = []
544
- for _, row in pairs_df.iterrows():
545
- child_term = row["child"]
546
- chosen_parent = self._select_parent_via_prompts(child_term)
547
- predictions.append({"parent": chosen_parent, "child": child_term})
665
+ # predictions: List[dict[str, str]] = []
666
+ # for _, row in tqdm(pairs_df.iterrows(), desc="Select paren via prompt:"):
667
+ # child_term = row["child"]
668
+ # chosen_parent = self._select_parent_via_prompts(child_term)
669
+ # predictions.append({"parent": chosen_parent, "child": child_term})
670
+ # return predictions
671
+ # predictions: List[dict[str, str]] = []
672
+ children_list = pairs_df["child"].tolist()
673
+ chosen_parents = self._select_parents_batch(children_list, batch_size=self.per_device_eval_batch_size)
674
+ predictions = [{"parent": parent, "child": child}
675
+ for parent, child in zip(chosen_parents, children_list)]
548
676
  return predictions
549
677
 
550
678
  # pairwise binary classification
551
679
  prompts_df = self._add_prompt_columns(pairs_df.copy())
552
680
  true_probs_by_prompt: List[torch.Tensor] = []
553
681
 
554
- for i in range(self.n_prompts):
682
+ for i in tqdm(range(self.n_prompts), desc="Prompt via prompt:"):
555
683
  col = f"prompt_{i + 1}"
556
684
  enc = self.tokenizer(
557
685
  prompts_df[col].tolist(),
@@ -567,36 +695,21 @@ class SKHNLPSequentialFTLearner(AutoLearner):
567
695
  predicted_bool = (avg_true_prob >= 0.5).cpu().tolist()
568
696
 
569
697
  results: List[dict[str, Any]] = []
570
- for p, c, s, yhat in zip(
571
- pairs_df["parent"],
572
- pairs_df["child"],
573
- avg_true_prob.tolist(),
574
- predicted_bool,
575
- ):
576
- results.append(
577
- {
578
- "parent": p,
579
- "child": c,
580
- "label": int(bool(yhat)),
581
- "score": float(s),
582
- }
583
- )
698
+ for p, c, s, yhat in tqdm(zip(pairs_df["parent"],
699
+ pairs_df["child"],
700
+ avg_true_prob.tolist(),
701
+ predicted_bool), desc="Append:"):
702
+ results.append({"parent": p, "child": c, "label": int(bool(yhat)), "score": float(s)})
584
703
  return results
585
704
 
586
705
  if isinstance(data, OntologyData):
587
- train_onto, eval_onto = ontology_split(
588
- data,
589
- test_size=self._eval_fraction,
590
- random_state=self.random_state,
591
- verbose=False,
592
- )
706
+ train_onto, eval_onto = ontology_split(data,
707
+ test_size=self._eval_fraction,
708
+ random_state=self.random_state,
709
+ verbose=False)
593
710
 
594
- train_pos_rel: List[TaxonomicRelation] = (
595
- getattr(train_onto.type_taxonomies, "taxonomies", []) or []
596
- )
597
- eval_pos_rel: List[TaxonomicRelation] = (
598
- getattr(eval_onto.type_taxonomies, "taxonomies", []) or []
599
- )
711
+ train_pos_rel: List[TaxonomicRelation] = getattr(train_onto.type_taxonomies, "taxonomies", []) or []
712
+ eval_pos_rel: List[TaxonomicRelation] = getattr(eval_onto.type_taxonomies, "taxonomies", []) or []
600
713
 
601
714
  train_pos_df = self._df_from_relations(train_pos_rel, label=True)
602
715
  eval_pos_df = self._df_from_relations(eval_pos_rel, label=True)
@@ -612,9 +725,7 @@ class SKHNLPSequentialFTLearner(AutoLearner):
612
725
 
613
726
  else:
614
727
  if "label" not in pairs_df.columns or pairs_df["label"].nunique() == 1:
615
- positives_df = pairs_df[pairs_df.get("label", True)][
616
- ["parent", "child"]
617
- ].copy()
728
+ positives_df = pairs_df[pairs_df.get("label", True)][["parent", "child"]].copy()
618
729
  pos_rel = self._relations_from_df(positives_df)
619
730
 
620
731
  tr_rel, ev_rel = taxonomy_split(
@@ -630,12 +741,8 @@ class SKHNLPSequentialFTLearner(AutoLearner):
630
741
  tr_rev_df, tr_man_df = self._make_negatives(train_pos_df)
631
742
  ev_rev_df, ev_man_df = self._make_negatives(eval_pos_df)
632
743
 
633
- train_df = self._balance_with_negatives(
634
- train_pos_df, tr_rev_df, tr_man_df
635
- )
636
- eval_df = self._balance_with_negatives(
637
- eval_pos_df, ev_rev_df, ev_man_df
638
- )
744
+ train_df = self._balance_with_negatives(train_pos_df, tr_rev_df, tr_man_df)
745
+ eval_df = self._balance_with_negatives(eval_pos_df, ev_rev_df, ev_man_df)
639
746
 
640
747
  train_df = self._add_prompt_columns(train_df)
641
748
  eval_df = self._add_prompt_columns(eval_df)
@@ -655,20 +762,11 @@ class SKHNLPSequentialFTLearner(AutoLearner):
655
762
  eval_pos_df = self._df_from_relations(ev_rel, label=True)
656
763
 
657
764
  negatives_df = pairs_df[pairs_df["label"]][["parent", "child"]].copy()
658
- negatives_df = negatives_df.sample(
659
- frac=1.0, random_state=self.random_state
660
- ).reset_index(drop=True)
661
-
662
- n_eval_neg = (
663
- max(1, int(len(negatives_df) * self._eval_fraction))
664
- if len(negatives_df) > 0
665
- else 0
666
- )
667
- eval_neg_df = (
668
- negatives_df.iloc[:n_eval_neg].copy()
669
- if n_eval_neg > 0
670
- else negatives_df.iloc[:0].copy()
671
- )
765
+ negatives_df = negatives_df.sample(frac=1.0, random_state=self.random_state).reset_index(drop=True)
766
+
767
+ n_eval_neg = max(1, int(len(negatives_df) * self._eval_fraction)) if len(negatives_df) > 0 else 0
768
+
769
+ eval_neg_df = negatives_df.iloc[:n_eval_neg].copy() if n_eval_neg > 0 else negatives_df.iloc[:0].copy()
672
770
  train_neg_df = negatives_df.iloc[n_eval_neg:].copy()
673
771
 
674
772
  train_neg_df["label"] = False
@@ -687,35 +785,21 @@ class SKHNLPSequentialFTLearner(AutoLearner):
687
785
  # Sequential fine-tuning across prompts
688
786
  for i in range(self.n_prompts):
689
787
  prompt_col = f"prompt_{i + 1}"
690
- train_ds = Dataset.from_pandas(
691
- train_df[[prompt_col, "label"]].reset_index(drop=True)
692
- )
693
- eval_ds = Dataset.from_pandas(
694
- eval_df[[prompt_col, "label"]].reset_index(drop=True)
695
- )
788
+ train_ds = Dataset.from_pandas(train_df[[prompt_col, "label"]].reset_index(drop=True))
789
+ eval_ds = Dataset.from_pandas(eval_df[[prompt_col, "label"]].reset_index(drop=True))
696
790
 
697
791
  train_ds = train_ds.rename_column("label", "labels")
698
792
  eval_ds = eval_ds.rename_column("label", "labels")
699
793
 
700
794
  def tokenize_batch(batch):
701
795
  """Tokenize a batch for the current prompt column with truncation/padding."""
702
- return self.tokenizer(
703
- batch[prompt_col], padding="max_length", truncation=True
704
- )
796
+ return self.tokenizer(batch[prompt_col], padding="max_length", truncation=True)
705
797
 
706
- train_ds = train_ds.map(
707
- tokenize_batch, batched=True, remove_columns=[prompt_col]
708
- )
709
- eval_ds = eval_ds.map(
710
- tokenize_batch, batched=True, remove_columns=[prompt_col]
711
- )
798
+ train_ds = train_ds.map(tokenize_batch, batched=True, remove_columns=[prompt_col])
799
+ eval_ds = eval_ds.map(tokenize_batch, batched=True, remove_columns=[prompt_col])
712
800
 
713
- train_ds.set_format(
714
- type="torch", columns=["input_ids", "attention_mask", "labels"]
715
- )
716
- eval_ds.set_format(
717
- type="torch", columns=["input_ids", "attention_mask", "labels"]
718
- )
801
+ train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
802
+ eval_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
719
803
 
720
804
  trainer = Trainer(
721
805
  model=self.model,
@@ -729,7 +813,6 @@ class SKHNLPSequentialFTLearner(AutoLearner):
729
813
  self._last_eval = eval_df
730
814
  return None
731
815
 
732
-
733
816
  class SKHNLPZSLearner(AutoLearner):
734
817
  """
735
818
  Zero-shot taxonomy learner using an instruction-tuned causal LLM.
@@ -848,10 +931,7 @@ class SKHNLPZSLearner(AutoLearner):
848
931
  self._tokenizer = AutoTokenizer.from_pretrained(model_id)
849
932
 
850
933
  # Ensure a pad token is set for generation
851
- if (
852
- self._tokenizer.pad_token_id is None
853
- and self._tokenizer.eos_token_id is not None
854
- ):
934
+ if self._tokenizer.pad_token_id is None and self._tokenizer.eos_token_id is not None:
855
935
  self._tokenizer.pad_token = self._tokenizer.eos_token
856
936
 
857
937
  self._model = AutoModelForCausalLM.from_pretrained(
@@ -871,9 +951,7 @@ class SKHNLPZSLearner(AutoLearner):
871
951
  print("Device set to use", "cuda" if self._has_cuda else "cpu")
872
952
  print("[ZeroShotTaxonomyLearner] Model loaded.")
873
953
 
874
- def _taxonomy_discovery(
875
- self, data: Any, test: bool = False
876
- ) -> Optional[List[Dict[str, str]]]:
954
+ def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[List[Dict[str, str]]]:
877
955
  """
878
956
  Zero-shot prediction over all incoming rows (no filtering/augmentation).
879
957
 
@@ -967,16 +1045,14 @@ class SKHNLPZSLearner(AutoLearner):
967
1045
  add_generation_prompt=True,
968
1046
  )
969
1047
 
970
- generation = self._pipeline(
971
- prompt,
972
- max_new_tokens=self.max_new_tokens,
973
- do_sample=False,
974
- temperature=0.0,
975
- top_p=1.0,
976
- eos_token_id=self._tokenizer.eos_token_id,
977
- pad_token_id=self._tokenizer.eos_token_id,
978
- return_full_text=False,
979
- )[0]["generated_text"]
1048
+ generation = self._pipeline(prompt,
1049
+ max_new_tokens=self.max_new_tokens,
1050
+ do_sample=False,
1051
+ temperature=0.0,
1052
+ top_p=1.0,
1053
+ eos_token_id=self._tokenizer.eos_token_id,
1054
+ pad_token_id=self._tokenizer.eos_token_id,
1055
+ return_full_text=False)[0]["generated_text"]
980
1056
 
981
1057
  match = self._PREDICTION_PATTERN.search(generation)
982
1058
  parsed = match.group(1).strip() if match else "unknown"
@@ -1000,11 +1076,7 @@ class SKHNLPZSLearner(AutoLearner):
1000
1076
 
1001
1077
  for label in self.CLASS_LIST:
1002
1078
  label_lower = label.lower()
1003
- if (
1004
- lowered == label_lower
1005
- or lowered in label_lower
1006
- or label_lower in lowered
1007
- ):
1079
+ if lowered == label_lower or lowered in label_lower or label_lower in lowered:
1008
1080
  return label
1009
1081
  return "unknown"
1010
1082
 
@@ -1045,9 +1117,7 @@ class SKHNLPZSLearner(AutoLearner):
1045
1117
  Normalized label string or 'unknown'.
1046
1118
  """
1047
1119
  snapped = self._normalize_substring_only(text)
1048
- return (
1049
- snapped if snapped != "unknown" else self._normalize_levenshtein_only(text)
1050
- )
1120
+ return snapped if snapped != "unknown" else self._normalize_levenshtein_only(text)
1051
1121
 
1052
1122
  def _to_dataframe(self, data: Any) -> pd.DataFrame:
1053
1123
  """
@@ -1082,13 +1152,9 @@ class SKHNLPZSLearner(AutoLearner):
1082
1152
  if isinstance(first, (list, tuple)) and not isinstance(first, dict):
1083
1153
  n = len(first)
1084
1154
  if n >= 3:
1085
- return pd.DataFrame(
1086
- data, columns=["child", "parent", "label"]
1087
- ).reset_index(drop=True)
1155
+ return pd.DataFrame(data, columns=["child", "parent", "label"]).reset_index(drop=True)
1088
1156
  if n == 2:
1089
- return pd.DataFrame(data, columns=["child", "parent"]).reset_index(
1090
- drop=True
1091
- )
1157
+ return pd.DataFrame(data, columns=["child", "parent"]).reset_index(drop=True)
1092
1158
 
1093
1159
  try:
1094
1160
  type_taxonomies = getattr(data, "type_taxonomies", None)
@@ -1099,15 +1165,9 @@ class SKHNLPZSLearner(AutoLearner):
1099
1165
  for rel in taxonomies:
1100
1166
  parent = getattr(rel, "parent", None)
1101
1167
  child = getattr(rel, "child", None)
1102
- label = (
1103
- getattr(rel, "label", None)
1104
- if hasattr(rel, "label")
1105
- else None
1106
- )
1168
+ label = getattr(rel, "label", None) if hasattr(rel, "label") else None
1107
1169
  if parent is not None and child is not None:
1108
- rows.append(
1109
- {"child": child, "parent": parent, "label": label}
1110
- )
1170
+ rows.append({"child": child, "parent": parent, "label": label})
1111
1171
  if rows:
1112
1172
  return pd.DataFrame(rows).reset_index(drop=True)
1113
1173
  except Exception:
@@ -12,5 +12,5 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .alexbek import AlexbekFewShotLearner
15
+ from .alexbek import AlexbekRAGFewShotLearner
16
16
  from .sbunlp import SBUNLPFewShotLearner