OntoLearner 1.4.7__py3-none-any.whl → 1.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1138 @@
1
+ # Copyright (c) 2025 SciKnowOrg
2
+ #
3
+ # Licensed under the MIT License (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://opensource.org/licenses/MIT
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+ import random
18
+
19
+ import pandas as pd
20
+ import torch
21
+ import Levenshtein
22
+ from datasets import Dataset
23
+ from typing import Any, Optional, List, Tuple, Dict
24
+ from transformers import (
25
+ AutoTokenizer,
26
+ AutoModelForSequenceClassification,
27
+ AutoModelForCausalLM,
28
+ BertTokenizer,
29
+ BertForSequenceClassification,
30
+ pipeline,
31
+ Trainer,
32
+ TrainingArguments,
33
+ )
34
+
35
+ from ...base import AutoLearner, AutoPrompt
36
+ from ...utils import taxonomy_split, train_test_split as ontology_split
37
+ from ...data_structure import OntologyData, TaxonomicRelation
38
+
39
+
40
class SKHNLPTaxonomyPrompts(AutoPrompt):
    """Inventory of the 7 taxonomy prompt templates used for fine-tuning/inference.

    Each template verbalizes the (parent, child) relationship with a different
    phrasing and terminates in a ``[MASK]`` slot intended for True/False
    classification.
    """

    def __init__(self) -> None:
        """Register the default template with the base class and store all variants."""
        super().__init__(
            prompt_template="{parent} is the superclass of {child}. This statement is [MASK]."
        )
        # Seven equivalent verbalizations of the subsumption relation; the
        # trailing masked clause is identical across all of them.
        tail = "This statement is [MASK]."
        phrasings = [
            "{parent} is the superclass of {child}",
            "{child} is a subclass of {parent}",
            "{parent} is the parent class of {child}",
            "{child} is a child class of {parent}",
            "{parent} is a supertype of {child}",
            "{child} is a subtype of {parent}",
            "{parent} is an ancestor class of {child}",
        ]
        self.templates: List[str] = [f"{phrase}. {tail}" for phrase in phrasings]

    def format(self, parent: str, child: str, template_idx: int) -> str:
        """Render one prompt for a (parent, child) pair.

        Args:
            parent: The parent/superclass label.
            child: The child/subclass label.
            template_idx: 0-based index into `self.templates`.

        Returns:
            The fully formatted prompt string.
        """
        template = self.templates[template_idx]
        return template.format(parent=parent, child=child)
75
+
76
+
77
+ class SKHNLPSequentialFTLearner(AutoLearner):
78
+ """
79
+ BERT-based classifier for taxonomy discovery.
80
+
81
+ With OntologyData:
82
+ * TRAIN: ontology-aware split; create balanced train/eval with negatives.
83
+ * PREDICT/TEST: notebook-style parent selection -> list[{'parent', 'child'}].
84
+
85
+ With DataFrame/list:
86
+ * TRAIN: taxonomy_split + negatives; build prompts and fine-tune.
87
+ * PREDICT/TEST: pairwise binary classification (returns label + score).
88
+ """
89
+
90
    def __init__(
        self,
        # core
        model_name: str = "bert-large-uncased",
        n_prompts: int = 7,
        random_state: int = 1403,
        num_labels: int = 2,
        device: str = "cpu",  # "cuda" | "cpu" | None (auto)
        # data split & negative sampling (now configurable)
        eval_fraction: float = 0.16,
        neg_ratio_reversed: float = 1 / 3,
        neg_ratio_manipulated: float = 2 / 3,
        # ---- expose TrainingArguments as individual user-defined args ----
        output_dir: str = "./results/",
        num_train_epochs: int = 1,
        per_device_train_batch_size: int = 4,
        per_device_eval_batch_size: int = 4,
        warmup_steps: int = 500,
        weight_decay: float = 0.01,
        logging_dir: str = "./logs/",
        logging_steps: int = 50,
        eval_strategy: str = "epoch",
        save_strategy: str = "epoch",
        load_best_model_at_end: bool = True,
        use_fast_tokenizer: Optional[bool] = None,
        trust_remote_code: bool = False,
    ) -> None:
        """Configure the sequential fine-tuning learner.

        Args:
            model_name: HF model id or local path for the BERT backbone.
            n_prompts: Number of prompt variants to iterate over sequentially.
            random_state: RNG seed for shuffling/sampling steps.
            num_labels: Number of classes for the classifier head.
            device: Force device ('cuda' or 'cpu'). If None, auto-detects CUDA.
            eval_fraction: Fraction of positives to hold out for evaluation.
            neg_ratio_reversed: Proportion of reversed-parent negatives vs positives.
            neg_ratio_manipulated: Proportion of random-parent negatives vs positives.
            output_dir: Directory where HF Trainer writes checkpoints/outputs.
            num_train_epochs: Number of epochs per prompt.
            per_device_train_batch_size: Training batch size per device.
            per_device_eval_batch_size: Evaluation batch size per device.
            warmup_steps: Linear warmup steps for LR scheduler.
            weight_decay: Weight decay coefficient.
            logging_dir: Directory for Trainer logs.
            logging_steps: Interval for log events (in steps).
            eval_strategy: Evaluation schedule ('no', 'steps', 'epoch').
            save_strategy: Checkpoint save schedule ('no', 'steps', 'epoch').
            load_best_model_at_end: Whether to restore the best checkpoint.
            use_fast_tokenizer: Force fast/slow tokenizer. If None, try fast then fallback to slow.
            trust_remote_code: Passed through to HF `from_pretrained` calls in `load`.

        Raises:
            ValueError: If `eval_fraction` is not strictly inside (0, 1), or if
                either negative-sampling ratio is negative.

        Notes:
            The model is fine-tuned *sequentially* across prompt columns.
            You can control the eval split and negative sampling mix via
            `eval_fraction`, `neg_ratio_reversed`, and `neg_ratio_manipulated`.
        """
        super().__init__()
        self.model_name = model_name
        self.n_prompts = n_prompts
        self.random_state = random_state
        self.num_labels = num_labels
        self.device = device

        # user-tunable ratios / split (validated eagerly so bad configs fail fast)
        self._eval_fraction = float(eval_fraction)
        self._neg_ratio_reversed = float(neg_ratio_reversed)
        self._neg_ratio_manipulated = float(neg_ratio_manipulated)
        if not (0.0 < self._eval_fraction < 1.0):
            raise ValueError("eval_fraction must be in (0, 1).")
        if self._neg_ratio_reversed < 0 or self._neg_ratio_manipulated < 0:
            raise ValueError("neg_ratio_* must be >= 0.")

        # Populated later by `load`; None until then.
        self.tokenizer: Optional[BertTokenizer] = None
        self.model: Optional[BertForSequenceClassification] = None
        self.prompter = SKHNLPTaxonomyPrompts()

        # Candidate parents (unique parent list) for multi-class parent selection.
        self._candidate_parents: Optional[List[str]] = None

        # Keep last train/eval tables for inspection
        self._last_train: Optional[pd.DataFrame] = None
        self._last_eval: Optional[pd.DataFrame] = None
        self.trust_remote_code = bool(trust_remote_code)
        self.use_fast_tokenizer = use_fast_tokenizer

        # Seed the module-level RNG used by negative sampling / tie-breaking.
        random.seed(self.random_state)

        # Build TrainingArguments from the individual user-defined values
        self.training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            logging_dir=logging_dir,
            logging_steps=logging_steps,
            eval_strategy=eval_strategy,
            save_strategy=save_strategy,
            load_best_model_at_end=load_best_model_at_end,
        )
190
+
191
    def load(self, model_id: Optional[str] = None, **_: Any) -> None:
        """Load tokenizer & model in a backbone-agnostic way; move model to self.device.

        Args:
            model_id: Optional HF id/path override; defaults to `self.model_name`.
            **_: Ignored keyword arguments, accepted for interface compatibility.

        Side Effects:
            Sets `self.tokenizer` and `self.model`; ensures a pad token exists,
            aligns `pad_token_id` between tokenizer and model config, pins the
            problem type to single-label classification, and moves the model
            to `self.device`.
        """
        model_id = model_id or self.model_name

        # ---- Tokenizer (robust fast→slow fallback unless explicitly set) ----
        if self.use_fast_tokenizer is None:
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_id, use_fast=True, trust_remote_code=self.trust_remote_code
                )
            except Exception as fast_err:
                # Some checkpoints ship no fast-tokenizer files; retry slow.
                print(
                    f"[tokenizer] Fast tokenizer failed: {fast_err}. Falling back to slow tokenizer..."
                )
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_id, use_fast=False, trust_remote_code=self.trust_remote_code
                )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                use_fast=self.use_fast_tokenizer,
                trust_remote_code=self.trust_remote_code,
            )

        # Ensure pad token exists (some models lack it)
        if getattr(self.tokenizer, "pad_token", None) is None:
            # Try sensible fallbacks
            fallback = (
                getattr(self.tokenizer, "eos_token", None)
                or getattr(self.tokenizer, "sep_token", None)
                or getattr(self.tokenizer, "cls_token", None)
            )
            if fallback is not None:
                self.tokenizer.pad_token = fallback

        # ---- Model (classifier head sized to self.num_labels) ----
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            num_labels=self.num_labels,
            trust_remote_code=self.trust_remote_code,
            # Allows swapping in a new head size even if the checkpoint differs
            ignore_mismatched_sizes=True,
        )

        # Make sure padding ids line up
        if (
            getattr(self.model.config, "pad_token_id", None) is None
            and getattr(self.tokenizer, "pad_token_id", None) is not None
        ):
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

        # Set problem type (single-label classification by default)
        # If you plan multi-label, you'd switch to "multi_label_classification"
        self.model.config.problem_type = "single_label_classification"

        # Move to target device
        self.model.to(self.device)
248
+
249
+ def tasks_ground_truth_former(self, data: Any, task: str) -> Any:
250
+ """Normalize ground-truth inputs for 'taxonomy-discovery'.
251
+
252
+ Supports DataFrame with columns ['parent','child',('label')],
253
+ list of dicts, or falls back to the base class behavior.
254
+
255
+ Args:
256
+ data: Input object to normalize.
257
+ task: Task name, passed from the outer pipeline.
258
+
259
+ Returns:
260
+ A list of dictionaries with keys 'parent', 'child', and optionally
261
+ 'label' when present in the input.
262
+ """
263
+ if task != "taxonomy-discovery":
264
+ return super().tasks_ground_truth_former(data, task)
265
+
266
+ if isinstance(data, pd.DataFrame):
267
+ if "label" in data.columns:
268
+ return [
269
+ {"parent": p, "child": c, "label": bool(lbl)}
270
+ for p, c, lbl in zip(data["parent"], data["child"], data["label"])
271
+ ]
272
+ return [
273
+ {"parent": p, "child": c} for p, c in zip(data["parent"], data["child"])
274
+ ]
275
+
276
+ if isinstance(data, list):
277
+ return data
278
+
279
+ return super().tasks_ground_truth_former(data, task)
280
+
281
+ def _make_negatives(
282
+ self, positives_df: pd.DataFrame
283
+ ) -> Tuple[pd.DataFrame, pd.DataFrame]:
284
+ """Create two types of negatives from a positives table.
285
+
286
+ Returns:
287
+ A tuple `(reversed_df, manipulated_df)` where:
288
+ - `reversed_df`: pairs with parent/child columns swapped, label=False.
289
+ - `manipulated_df`: pairs with the parent replaced by a random
290
+ *different* parent from the same pool, label=False.
291
+
292
+ Notes:
293
+ The input DataFrame must contain columns ['parent', 'child'].
294
+ """
295
+ unique_parents = positives_df["parent"].unique().tolist()
296
+
297
+ def as_reversed(df: pd.DataFrame) -> pd.DataFrame:
298
+ out = df.copy()
299
+ out[["parent", "child"]] = out[["child", "parent"]].values
300
+ out["label"] = False
301
+ return out
302
+
303
+ def with_random_parent(df: pd.DataFrame) -> pd.DataFrame:
304
+ def pick_other_parent(p: str) -> str:
305
+ pool = [x for x in unique_parents if x != p]
306
+ return random.choice(pool) if pool else p
307
+
308
+ out = df.copy()
309
+ out["parent"] = out["parent"].apply(pick_other_parent)
310
+ out["label"] = False
311
+ return out
312
+
313
+ return as_reversed(positives_df), with_random_parent(positives_df)
314
+
315
+ def _balance_with_negatives(
316
+ self,
317
+ positives_df: pd.DataFrame,
318
+ reversed_df: pd.DataFrame,
319
+ manipulated_df: pd.DataFrame,
320
+ ) -> pd.DataFrame:
321
+ """Combine positives with negatives using configured ratios.
322
+
323
+ Sampling ratios are defined by the instance settings
324
+ `self._neg_ratio_reversed` and `self._neg_ratio_manipulated`,
325
+ keeping the positives count unchanged.
326
+
327
+ Args:
328
+ positives_df: Positive pairs with `label=True`.
329
+ reversed_df: Negative pairs produced by flipping parent/child.
330
+ manipulated_df: Negative pairs with randomly reassigned parents.
331
+
332
+ Returns:
333
+ A deduplicated, shuffled DataFrame with a class-balanced mix.
334
+ """
335
+ n_pos = len(positives_df)
336
+ n_rev = int(n_pos * self._neg_ratio_reversed)
337
+ n_man = int(n_pos * self._neg_ratio_manipulated)
338
+
339
+ combined = pd.concat(
340
+ [
341
+ positives_df.sample(n_pos, random_state=self.random_state),
342
+ reversed_df.sample(n_rev, random_state=self.random_state),
343
+ manipulated_df.sample(n_man, random_state=self.random_state),
344
+ ],
345
+ ignore_index=True,
346
+ )
347
+ combined = combined.drop_duplicates(
348
+ subset=["parent", "child", "label"]
349
+ ).reset_index(drop=True)
350
+ return combined
351
+
352
+ def _add_prompt_columns(self, df: pd.DataFrame) -> pd.DataFrame:
353
+ """Append one column per prompt variant to the given pairs table.
354
+
355
+ For each row `(parent, child)`, creates columns `prompt_1 ... prompt_n`.
356
+
357
+ Args:
358
+ df: Input DataFrame with columns ['parent', 'child', ...].
359
+
360
+ Returns:
361
+ A copy of `df` including the newly added prompt columns.
362
+ """
363
+ out = df.copy()
364
+ for i in range(self.n_prompts):
365
+ out[f"prompt_{i + 1}"] = out.apply(
366
+ lambda r, k=i: self.prompter.format(r["parent"], r["child"], k), axis=1
367
+ )
368
+ return out
369
+
370
+ def _df_from_relations(
371
+ self, relations: List[TaxonomicRelation], label: bool = True
372
+ ) -> pd.DataFrame:
373
+ """Convert a list of `TaxonomicRelation` to a DataFrame.
374
+
375
+ Args:
376
+ relations: Iterable of `TaxonomicRelation(parent, child)`.
377
+ label: Class label to assign to all resulting rows.
378
+
379
+ Returns:
380
+ DataFrame with columns ['parent', 'child', 'label'].
381
+ """
382
+ if not relations:
383
+ return pd.DataFrame(columns=["parent", "child", "label"])
384
+ return pd.DataFrame(
385
+ [{"parent": r.parent, "child": r.child, "label": label} for r in relations]
386
+ )
387
+
388
+ def _relations_from_df(self, df: pd.DataFrame) -> List[TaxonomicRelation]:
389
+ """Convert a DataFrame to a list of `TaxonomicRelation`.
390
+
391
+ Args:
392
+ df: DataFrame with columns ['parent', 'child'].
393
+
394
+ Returns:
395
+ List of `TaxonomicRelation` objects in row order.
396
+ """
397
+ return [
398
+ TaxonomicRelation(parent=p, child=c)
399
+ for p, c in zip(df["parent"], df["child"])
400
+ ]
401
+
402
+ def _build_masked_prompt(
403
+ self, parent: str, child: str, index_1_based: int, mask_token: str = "[MASK]"
404
+ ) -> str:
405
+ """Construct one of several True/False prompts with a mask token.
406
+
407
+ Args:
408
+ parent: Parent label.
409
+ child: Child label.
410
+ index_1_based: 1-based index selecting a template.
411
+ mask_token: The token used to denote the masked label.
412
+
413
+ Returns:
414
+ A formatted prompt string.
415
+ """
416
+ prompts_1based = [
417
+ f"{parent} is the superclass of {child}. This statement is {mask_token}.",
418
+ f"{child} is a subclass of {parent}. This statement is {mask_token}.",
419
+ f"{parent} is the parent class of {child}. This statement is {mask_token}.",
420
+ f"{child} is a child class of {parent}. This statement is {mask_token}.",
421
+ f"{parent} is a supertype of {child}. This statement is {mask_token}.",
422
+ f"{child} is a subtype of {parent}. This statement is {mask_token}.",
423
+ f"{parent} is an ancestor class of {child}. This statement is {mask_token}.",
424
+ f"{child} is a descendant classs of {child}. This statement is {mask_token}.",
425
+ f'"{parent}" is the superclass of "{child}". This statement is {mask_token}.',
426
+ ]
427
+ return prompts_1based[index_1_based - 1]
428
+
429
+ @torch.no_grad()
430
+ def _predict_prompt_true_false(self, sentence: str) -> bool:
431
+ """Run a single True/False prediction on a prompt.
432
+
433
+ Args:
434
+ sentence: Fully formatted prompt text.
435
+
436
+ Returns:
437
+ True iff the predicted class index is 1 (positive).
438
+ """
439
+ enc = self.tokenizer(sentence, return_tensors="pt").to(self.model.device)
440
+ logits = self.model(**enc).logits
441
+ predicted_label = torch.argmax(logits, dim=1).item()
442
+ return predicted_label == 1
443
+
444
    def _select_parent_via_prompts(self, child: str) -> str:
        """Select the most likely parent for a given child via prompt voting.

        The procedure:
          1) Generate prompts for each candidate parent at increasing "levels".
          2) Accumulate votes from the True/False classifier.
          3) Resolve ties by recursing to the next level; after 4 levels, break ties randomly.

        Args:
            child: The child label whose parent should be predicted.

        Returns:
            The chosen parent string.

        Raises:
            AssertionError: If candidate parents were not initialized.
        """
        assert self._candidate_parents, "Candidate parents not initialized."
        # Running vote tally per candidate parent; shared (and mutated) across
        # all recursion levels below.
        scores: dict[str, int] = {p: 0 for p in self._candidate_parents}

        def prompt_indices_for_level(level: int) -> List[int]:
            # Level 0 uses template 1 alone; level k>0 uses the pair (2k, 2k+1).
            if level == 0:
                return [1]
            return [2 * level, 2 * level + 1]

        def recurse(active_parents: List[str], level: int) -> str:
            # Drop template indices beyond the configured prompt inventory
            # (with n_prompts=7, levels 0..3 cover templates 1..7).
            idxs = [
                i for i in prompt_indices_for_level(level) if 1 <= i <= self.n_prompts
            ]
            if idxs:
                for parent in active_parents:
                    # One vote per template whose True/False classifier answers True.
                    votes = sum(
                        1
                        for idx in idxs
                        if self._predict_prompt_true_false(
                            self._build_masked_prompt(
                                parent=parent, child=child, index_1_based=idx
                            )
                        )
                    )
                    scores[parent] += votes

            # Only the parents still in contention compete at this level.
            max_score = max(scores[p] for p in active_parents)
            tied = [p for p in active_parents if scores[p] == max_score]
            if len(tied) == 1:
                return tied[0]
            if level < 4:
                # Still tied: consult the next pair of templates for the tied set only.
                return recurse(tied, level + 1)
            # Out of levels: break the tie randomly (RNG seeded in __init__).
            return random.choice(tied)

        return recurse(list(scores.keys()), level=0)
495
+
496
+ def _taxonomy_discovery(self, data: Any, test: bool = False):
497
+ """
498
+ TRAIN:
499
+ - OntologyData -> ontology-aware split; negatives per split; balanced sets.
500
+ - DataFrame/list -> taxonomy_split for positives; negatives proportional.
501
+ TEST:
502
+ - OntologyData -> parent selection: [{'parent': predicted, 'child': child}]
503
+ - DataFrame/list -> binary pair classification with 'label' + 'score'
504
+
505
+ Args:
506
+ data: One of {OntologyData, pandas.DataFrame, list[dict], list[tuple]}.
507
+ test: If True, run inference; otherwise perform training.
508
+
509
+ Returns:
510
+ - On training: None (model is fine-tuned in-place).
511
+ - On inference with OntologyData: list of {'parent','child'} predictions.
512
+ - On inference with pairs: list of dicts including 'label' and 'score'.
513
+ """
514
+ is_ontology_object = isinstance(data, OntologyData)
515
+
516
+ # Normalize input
517
+ if isinstance(data, pd.DataFrame):
518
+ pairs_df = data.copy()
519
+ elif isinstance(data, list):
520
+ pairs_df = pd.DataFrame(data)
521
+ else:
522
+ gt_pairs = super().tasks_ground_truth_former(data, "taxonomy-discovery")
523
+ pairs_df = pd.DataFrame(gt_pairs)
524
+ if "label" not in pairs_df.columns:
525
+ pairs_df["label"] = True
526
+
527
+ # Maintain candidate parents across calls
528
+ if "parent" in pairs_df.columns:
529
+ parents_in_call = sorted(pd.unique(pairs_df["parent"]).tolist())
530
+ if test:
531
+ if self._candidate_parents is None:
532
+ self._candidate_parents = parents_in_call
533
+ else:
534
+ self._candidate_parents = sorted(
535
+ set(self._candidate_parents).union(parents_in_call)
536
+ )
537
+ else:
538
+ if self._candidate_parents is None:
539
+ self._candidate_parents = parents_in_call
540
+
541
+ if test:
542
+ if is_ontology_object and self._candidate_parents:
543
+ predictions: List[dict[str, str]] = []
544
+ for _, row in pairs_df.iterrows():
545
+ child_term = row["child"]
546
+ chosen_parent = self._select_parent_via_prompts(child_term)
547
+ predictions.append({"parent": chosen_parent, "child": child_term})
548
+ return predictions
549
+
550
+ # pairwise binary classification
551
+ prompts_df = self._add_prompt_columns(pairs_df.copy())
552
+ true_probs_by_prompt: List[torch.Tensor] = []
553
+
554
+ for i in range(self.n_prompts):
555
+ col = f"prompt_{i + 1}"
556
+ enc = self.tokenizer(
557
+ prompts_df[col].tolist(),
558
+ return_tensors="pt",
559
+ padding=True,
560
+ truncation=True,
561
+ ).to(self.model.device)
562
+ with torch.no_grad():
563
+ logits = self.model(**enc).logits
564
+ true_probs_by_prompt.append(torch.softmax(logits, dim=1)[:, 1])
565
+
566
+ avg_true_prob = torch.stack(true_probs_by_prompt, dim=0).mean(0)
567
+ predicted_bool = (avg_true_prob >= 0.5).cpu().tolist()
568
+
569
+ results: List[dict[str, Any]] = []
570
+ for p, c, s, yhat in zip(
571
+ pairs_df["parent"],
572
+ pairs_df["child"],
573
+ avg_true_prob.tolist(),
574
+ predicted_bool,
575
+ ):
576
+ results.append(
577
+ {
578
+ "parent": p,
579
+ "child": c,
580
+ "label": int(bool(yhat)),
581
+ "score": float(s),
582
+ }
583
+ )
584
+ return results
585
+
586
+ if isinstance(data, OntologyData):
587
+ train_onto, eval_onto = ontology_split(
588
+ data,
589
+ test_size=self._eval_fraction,
590
+ random_state=self.random_state,
591
+ verbose=False,
592
+ )
593
+
594
+ train_pos_rel: List[TaxonomicRelation] = (
595
+ getattr(train_onto.type_taxonomies, "taxonomies", []) or []
596
+ )
597
+ eval_pos_rel: List[TaxonomicRelation] = (
598
+ getattr(eval_onto.type_taxonomies, "taxonomies", []) or []
599
+ )
600
+
601
+ train_pos_df = self._df_from_relations(train_pos_rel, label=True)
602
+ eval_pos_df = self._df_from_relations(eval_pos_rel, label=True)
603
+
604
+ tr_rev_df, tr_man_df = self._make_negatives(train_pos_df)
605
+ ev_rev_df, ev_man_df = self._make_negatives(eval_pos_df)
606
+
607
+ train_df = self._balance_with_negatives(train_pos_df, tr_rev_df, tr_man_df)
608
+ eval_df = self._balance_with_negatives(eval_pos_df, ev_rev_df, ev_man_df)
609
+
610
+ train_df = self._add_prompt_columns(train_df)
611
+ eval_df = self._add_prompt_columns(eval_df)
612
+
613
+ else:
614
+ if "label" not in pairs_df.columns or pairs_df["label"].nunique() == 1:
615
+ positives_df = pairs_df[pairs_df.get("label", True)][
616
+ ["parent", "child"]
617
+ ].copy()
618
+ pos_rel = self._relations_from_df(positives_df)
619
+
620
+ tr_rel, ev_rel = taxonomy_split(
621
+ pos_rel,
622
+ train_terms=None,
623
+ test_size=self._eval_fraction,
624
+ random_state=self.random_state,
625
+ verbose=False,
626
+ )
627
+ train_pos_df = self._df_from_relations(tr_rel, label=True)
628
+ eval_pos_df = self._df_from_relations(ev_rel, label=True)
629
+
630
+ tr_rev_df, tr_man_df = self._make_negatives(train_pos_df)
631
+ ev_rev_df, ev_man_df = self._make_negatives(eval_pos_df)
632
+
633
+ train_df = self._balance_with_negatives(
634
+ train_pos_df, tr_rev_df, tr_man_df
635
+ )
636
+ eval_df = self._balance_with_negatives(
637
+ eval_pos_df, ev_rev_df, ev_man_df
638
+ )
639
+
640
+ train_df = self._add_prompt_columns(train_df)
641
+ eval_df = self._add_prompt_columns(eval_df)
642
+
643
+ else:
644
+ positives_df = pairs_df[pairs_df["label"]][["parent", "child"]].copy()
645
+ pos_rel = self._relations_from_df(positives_df)
646
+
647
+ tr_rel, ev_rel = taxonomy_split(
648
+ pos_rel,
649
+ train_terms=None,
650
+ test_size=self._eval_fraction,
651
+ random_state=self.random_state,
652
+ verbose=False,
653
+ )
654
+ train_pos_df = self._df_from_relations(tr_rel, label=True)
655
+ eval_pos_df = self._df_from_relations(ev_rel, label=True)
656
+
657
+ negatives_df = pairs_df[pairs_df["label"]][["parent", "child"]].copy()
658
+ negatives_df = negatives_df.sample(
659
+ frac=1.0, random_state=self.random_state
660
+ ).reset_index(drop=True)
661
+
662
+ n_eval_neg = (
663
+ max(1, int(len(negatives_df) * self._eval_fraction))
664
+ if len(negatives_df) > 0
665
+ else 0
666
+ )
667
+ eval_neg_df = (
668
+ negatives_df.iloc[:n_eval_neg].copy()
669
+ if n_eval_neg > 0
670
+ else negatives_df.iloc[:0].copy()
671
+ )
672
+ train_neg_df = negatives_df.iloc[n_eval_neg:].copy()
673
+
674
+ train_neg_df["label"] = False
675
+ eval_neg_df["label"] = False
676
+
677
+ train_df = pd.concat([train_pos_df, train_neg_df], ignore_index=True)
678
+ eval_df = pd.concat([eval_pos_df, eval_neg_df], ignore_index=True)
679
+
680
+ train_df = self._add_prompt_columns(train_df)
681
+ eval_df = self._add_prompt_columns(eval_df)
682
+
683
+ # Ensure labels are int64
684
+ train_df["label"] = train_df["label"].astype("int64")
685
+ eval_df["label"] = eval_df["label"].astype("int64")
686
+
687
+ # Sequential fine-tuning across prompts
688
+ for i in range(self.n_prompts):
689
+ prompt_col = f"prompt_{i + 1}"
690
+ train_ds = Dataset.from_pandas(
691
+ train_df[[prompt_col, "label"]].reset_index(drop=True)
692
+ )
693
+ eval_ds = Dataset.from_pandas(
694
+ eval_df[[prompt_col, "label"]].reset_index(drop=True)
695
+ )
696
+
697
+ train_ds = train_ds.rename_column("label", "labels")
698
+ eval_ds = eval_ds.rename_column("label", "labels")
699
+
700
+ def tokenize_batch(batch):
701
+ """Tokenize a batch for the current prompt column with truncation/padding."""
702
+ return self.tokenizer(
703
+ batch[prompt_col], padding="max_length", truncation=True
704
+ )
705
+
706
+ train_ds = train_ds.map(
707
+ tokenize_batch, batched=True, remove_columns=[prompt_col]
708
+ )
709
+ eval_ds = eval_ds.map(
710
+ tokenize_batch, batched=True, remove_columns=[prompt_col]
711
+ )
712
+
713
+ train_ds.set_format(
714
+ type="torch", columns=["input_ids", "attention_mask", "labels"]
715
+ )
716
+ eval_ds.set_format(
717
+ type="torch", columns=["input_ids", "attention_mask", "labels"]
718
+ )
719
+
720
+ trainer = Trainer(
721
+ model=self.model,
722
+ args=self.training_args,
723
+ train_dataset=train_ds,
724
+ eval_dataset=eval_ds,
725
+ )
726
+ trainer.train()
727
+
728
+ self._last_train = train_df
729
+ self._last_eval = eval_df
730
+ return None
731
+
732
+
733
+ class SKHNLPZSLearner(AutoLearner):
734
+ """
735
+ Zero-shot taxonomy learner using an instruction-tuned causal LLM.
736
+
737
+ Behavior
738
+ --------
739
+ - Builds a fixed classification prompt listing 9 GeoNames parent classes.
740
+ - For each input row (child term), generates a short completion and parses
741
+ the predicted class from a strict '#[ ... ]#' format.
742
+ - Optionally normalizes the raw prediction to one of the valid 9 labels via:
743
+ * "none" : keep the parsed text as-is
744
+ * "substring" : snap to a label if either is a substring of the other
745
+ * "levenshtein" : snap to the closest label by edit distance
746
+ * "auto" : substring, then Levenshtein if needed
747
+ - Saves raw and normalized predictions to CSV if `save_path` is provided.
748
+
749
+ Inputs the learner accepts (via `_to_dataframe`):
750
+ - pandas.DataFrame with columns: ['child', 'parent'] or ['child', 'parent', 'label']
751
+ - list[dict] with keys: 'child', 'parent' (and optionally 'label')
752
+ - list of tuples/lists: (child, parent) or (child, parent, label)
753
+ - OntoLearner-style object exposing .type_taxonomies.taxonomies iterable with (child, parent)
754
+ """
755
+
756
    # Fixed class inventory: the 9 coarse GeoNames parent classes enumerated in
    # the classification prompt; predictions are normalized against this list.
    CLASS_LIST = [
        "city, village",
        "country, state, region",
        "forest, heath",
        "mountain, hill, rock",
        "parks, area",
        "road, railroad",
        "spot, building, farm",
        "stream, lake",
        "undersea",
    ]

    # Strict answer format: the predicted class name wrapped as #[ ... ]#;
    # the non-greedy group captures the name with surrounding whitespace trimmed.
    _PREDICTION_PATTERN = re.compile(r"#\[\s*([^\]]+?)\s*\]#")
771
+
772
+ def __init__(
773
+ self,
774
+ model_name: str = "Qwen/Qwen2.5-0.5B-Instruct",
775
+ device: Optional[str] = None, # "cuda" | "cpu" | None (auto)
776
+ max_new_tokens: int = 16,
777
+ save_path: Optional[str] = None, # directory or full path
778
+ verbose: bool = True,
779
+ normalize_mode: str = "none", # "none" | "substring" | "levenshtein" | "auto"
780
+ random_state: int = 1403,
781
+ ) -> None:
782
+ """Configure the zero-shot learner.
783
+
784
+ Args:
785
+ model_name: HF model id/path for the instruction-tuned causal LLM.
786
+ device: Force device ('cuda' or 'cpu'), else auto-detect.
787
+ max_new_tokens: Generation length budget for each completion.
788
+ save_path: Optional CSV path or directory for saving predictions.
789
+ verbose: If True, print progress messages.
790
+ normalize_mode: Post-processing for class names
791
+ ('none' | 'substring' | 'levenshtein' | 'auto').
792
+ random_state: RNG seed for any sampling steps.
793
+ """
794
+ super().__init__()
795
+ self.model_name = model_name
796
+ self.verbose = verbose
797
+ self.max_new_tokens = max_new_tokens
798
+ self.save_path = save_path
799
+ self.normalize_mode = (normalize_mode or "none").lower().strip()
800
+ self.random_state = random_state
801
+
802
+ random.seed(self.random_state)
803
+
804
+ # Device: auto-detect CUDA if not specified
805
+ if device is None:
806
+ self._has_cuda = torch.cuda.is_available()
807
+ else:
808
+ self._has_cuda = device == "cuda"
809
+ self._pipe_device = 0 if self._has_cuda else -1
810
+ self._model_device_map = {"": "cuda"} if self._has_cuda else None
811
+
812
+ self._tokenizer = None
813
+ self._model = None
814
+ self._pipeline = None
815
+
816
+ # Prompt template used for every example
817
+ self._classification_prompt = (
818
+ "My task is classification. My classes are as follows: "
819
+ "(city, village), (country, state, region), (forest, heath), "
820
+ "(mountain, hill, rock), (parks, area), (road, railroad), "
821
+ "(spot, building, farm), (stream, lake), (undersea). "
822
+ 'I will provide you with a phrase like "wadi mouth". '
823
+ "The name of each class is placed within a pair of parentheses. "
824
+ "I want you to choose the most appropriate class from those mentioned above "
825
+ "based on the given phrase and present it in a format like #[parks, area]#. "
826
+ "So, the general format for each response will be #[class name]#. "
827
+ "Pay attention to the format of the response. Start with a '#' character, "
828
+ "include the class name inside it, and end with another '#' character. "
829
+ "Additionally, make sure to include a '#' character at the end to indicate "
830
+ "that the answer is complete. I don't need any additional explanations."
831
+ )
832
+
833
+ def load(self, model_id: str = "") -> None:
834
+ """
835
+ Load tokenizer, model, and text-generation pipeline.
836
+
837
+ Args:
838
+ model_id: Optional HF id/path override; defaults to `self.model_name`.
839
+
840
+ Side Effects:
841
+ Initializes the tokenizer and model, configures the generation
842
+ pipeline on CPU/GPU, and sets a pad token if absent.
843
+ """
844
+ model_id = model_id or self.model_name
845
+ if self.verbose:
846
+ print(f"[ZeroShotTaxonomyLearner] Loading {model_id}")
847
+
848
+ self._tokenizer = AutoTokenizer.from_pretrained(model_id)
849
+
850
+ # Ensure a pad token is set for generation
851
+ if (
852
+ self._tokenizer.pad_token_id is None
853
+ and self._tokenizer.eos_token_id is not None
854
+ ):
855
+ self._tokenizer.pad_token = self._tokenizer.eos_token
856
+
857
+ self._model = AutoModelForCausalLM.from_pretrained(
858
+ model_id,
859
+ device_map=self._model_device_map,
860
+ torch_dtype="auto",
861
+ )
862
+
863
+ self._pipeline = pipeline(
864
+ task="text-generation",
865
+ model=self._model,
866
+ tokenizer=self._tokenizer,
867
+ device=self._pipe_device, # 0 for GPU, -1 for CPU
868
+ )
869
+
870
+ if self.verbose:
871
+ print("Device set to use", "cuda" if self._has_cuda else "cpu")
872
+ print("[ZeroShotTaxonomyLearner] Model loaded.")
873
+
874
+ def _taxonomy_discovery(
875
+ self, data: Any, test: bool = False
876
+ ) -> Optional[List[Dict[str, str]]]:
877
+ """
878
+ Zero-shot prediction over all incoming rows (no filtering/augmentation).
879
+
880
+ Args:
881
+ data: One of {DataFrame, list[dict], list[tuple], Ontology-like}.
882
+ test: If False, training is skipped (zero-shot learner), and None is returned.
883
+
884
+ Returns:
885
+ On `test=True`, a list of dicts [{'parent': predicted_label, 'child': child}, ...].
886
+ On `test=False`, returns None.
887
+ """
888
+ if not test:
889
+ if self.verbose:
890
+ print("[ZeroShot] Training skipped (zero-shot).")
891
+ return None
892
+
893
+ df = self._to_dataframe(data)
894
+
895
+ if self.verbose:
896
+ print(f"[ZeroShot] Incoming rows: {len(df)}; columns: {list(df.columns)}")
897
+
898
+ eval_df = pd.DataFrame(df).reset_index(drop=True)
899
+ if eval_df.empty:
900
+ return []
901
+
902
+ # Prepare columns for inspection and saving
903
+ eval_df["prediction_raw"] = ""
904
+ eval_df["prediction_sub"] = ""
905
+ eval_df["prediction_lvn"] = ""
906
+ eval_df["prediction_auto"] = ""
907
+ eval_df["prediction"] = "" # final (per normalize_mode)
908
+
909
+ # Generate predictions row by row
910
+ for idx, row in eval_df.iterrows():
911
+ child_term = str(row["child"])
912
+ raw_text, parsed_raw = self._generate_and_parse(child_term)
913
+
914
+ # Choose a string to normalize (parsed token if matched, otherwise whole output)
915
+ basis = parsed_raw if parsed_raw != "unknown" else raw_text
916
+
917
+ # Compute all normalization variants
918
+ sub_norm = self._normalize_substring_only(basis)
919
+ lvn_norm = self._normalize_levenshtein_only(basis)
920
+ auto_norm = self._normalize_auto(basis)
921
+
922
+ # Final selection by mode
923
+ if self.normalize_mode == "none":
924
+ final_label = parsed_raw
925
+ elif self.normalize_mode == "substring":
926
+ final_label = sub_norm
927
+ elif self.normalize_mode == "levenshtein":
928
+ final_label = lvn_norm
929
+ elif self.normalize_mode == "auto":
930
+ final_label = auto_norm
931
+ else:
932
+ final_label = parsed_raw # fallback
933
+
934
+ # Persist to DataFrame for inspection/export
935
+ eval_df.at[idx, "prediction_raw"] = parsed_raw
936
+ eval_df.at[idx, "prediction_sub"] = sub_norm
937
+ eval_df.at[idx, "prediction_lvn"] = lvn_norm
938
+ eval_df.at[idx, "prediction_auto"] = auto_norm
939
+ eval_df.at[idx, "prediction"] = final_label
940
+
941
+ # Return in the format expected by the pipeline
942
+ return [
943
+ {"parent": p, "child": c}
944
+ for p, c in zip(eval_df["prediction"], eval_df["child"])
945
+ ]
946
+
947
+ def _generate_and_parse(self, child_term: str) -> (str, str):
948
+ """
949
+ Generate a completion for the given child term and extract the raw predicted class
950
+ using the strict '#[ ... ]#' pattern.
951
+
952
+ Args:
953
+ child_term: The child label to classify into one of the fixed classes.
954
+
955
+ Returns:
956
+ Tuple `(raw_generation_text, parsed_prediction_or_unknown)`, where the second
957
+ element is either the text inside '#[ ... ]#' or the string 'unknown'.
958
+ """
959
+ messages = [
960
+ {"role": "system", "content": "You are a helpful classifier."},
961
+ {"role": "user", "content": f"{self._classification_prompt} {child_term}"},
962
+ ]
963
+
964
+ prompt = self._tokenizer.apply_chat_template(
965
+ messages,
966
+ tokenize=False,
967
+ add_generation_prompt=True,
968
+ )
969
+
970
+ generation = self._pipeline(
971
+ prompt,
972
+ max_new_tokens=self.max_new_tokens,
973
+ do_sample=False,
974
+ temperature=0.0,
975
+ top_p=1.0,
976
+ eos_token_id=self._tokenizer.eos_token_id,
977
+ pad_token_id=self._tokenizer.eos_token_id,
978
+ return_full_text=False,
979
+ )[0]["generated_text"]
980
+
981
+ match = self._PREDICTION_PATTERN.search(generation)
982
+ parsed = match.group(1).strip() if match else "unknown"
983
+ return generation, parsed
984
+
985
+ def _normalize_substring_only(self, text: str) -> str:
986
+ """
987
+ Snap to a label if the string is equal to / contained in / contains a valid label (case-insensitive).
988
+
989
+ Args:
990
+ text: Raw class text to normalize.
991
+
992
+ Returns:
993
+ One of `CLASS_LIST` on a match; otherwise 'unknown'.
994
+ """
995
+ if not isinstance(text, str):
996
+ return "unknown"
997
+ lowered = text.strip().lower()
998
+ if not lowered:
999
+ return "unknown"
1000
+
1001
+ for label in self.CLASS_LIST:
1002
+ label_lower = label.lower()
1003
+ if (
1004
+ lowered == label_lower
1005
+ or lowered in label_lower
1006
+ or label_lower in lowered
1007
+ ):
1008
+ return label
1009
+ return "unknown"
1010
+
1011
+ def _normalize_levenshtein_only(self, text: str) -> str:
1012
+ """
1013
+ Snap to the nearest label by Levenshtein (edit) distance.
1014
+
1015
+ Args:
1016
+ text: Raw class text to normalize.
1017
+
1018
+ Returns:
1019
+ The nearest label in `CLASS_LIST`, or 'unknown' if input is empty/invalid.
1020
+ """
1021
+ if not isinstance(text, str):
1022
+ return "unknown"
1023
+ lowered = text.strip().lower()
1024
+ if not lowered:
1025
+ return "unknown"
1026
+
1027
+ best_label = None
1028
+ best_distance = 10**9
1029
+ for label in self.CLASS_LIST:
1030
+ label_lower = label.lower()
1031
+ distance = Levenshtein.distance(lowered, label_lower)
1032
+ if distance < best_distance:
1033
+ best_distance = distance
1034
+ best_label = label
1035
+ return best_label or "unknown"
1036
+
1037
+ def _normalize_auto(self, text: str) -> str:
1038
+ """
1039
+ Cascade: try substring-first; if no match, fall back to Levenshtein snapping.
1040
+
1041
+ Args:
1042
+ text: Raw class text to normalize.
1043
+
1044
+ Returns:
1045
+ Normalized label string or 'unknown'.
1046
+ """
1047
+ snapped = self._normalize_substring_only(text)
1048
+ return (
1049
+ snapped if snapped != "unknown" else self._normalize_levenshtein_only(text)
1050
+ )
1051
+
1052
+ def _to_dataframe(self, data: Any) -> pd.DataFrame:
1053
+ """
1054
+ Normalize various input formats into a DataFrame.
1055
+
1056
+ Supported inputs:
1057
+ * pandas.DataFrame with columns ['child','parent',('label')]
1058
+ * list[dict] with keys 'child','parent',('label')
1059
+ * list of tuples/lists: (child, parent) or (child, parent, label)
1060
+ * Ontology-like object with `.type_taxonomies.taxonomies`
1061
+
1062
+ Args:
1063
+ data: The source object to normalize.
1064
+
1065
+ Returns:
1066
+ A pandas DataFrame with standardized columns.
1067
+
1068
+ Raises:
1069
+ ValueError: If the input type/shape is not recognized.
1070
+ """
1071
+ if isinstance(data, pd.DataFrame):
1072
+ df = data.copy()
1073
+ df.columns = [str(c).lower() for c in df.columns]
1074
+ return df.reset_index(drop=True)
1075
+
1076
+ if isinstance(data, list) and data and isinstance(data[0], dict):
1077
+ rows = [{str(k).lower(): v for k, v in d.items()} for d in data]
1078
+ return pd.DataFrame(rows).reset_index(drop=True)
1079
+
1080
+ if isinstance(data, (list, tuple)) and data:
1081
+ first = data[0]
1082
+ if isinstance(first, (list, tuple)) and not isinstance(first, dict):
1083
+ n = len(first)
1084
+ if n >= 3:
1085
+ return pd.DataFrame(
1086
+ data, columns=["child", "parent", "label"]
1087
+ ).reset_index(drop=True)
1088
+ if n == 2:
1089
+ return pd.DataFrame(data, columns=["child", "parent"]).reset_index(
1090
+ drop=True
1091
+ )
1092
+
1093
+ try:
1094
+ type_taxonomies = getattr(data, "type_taxonomies", None)
1095
+ if type_taxonomies is not None:
1096
+ taxonomies = getattr(type_taxonomies, "taxonomies", None)
1097
+ if taxonomies is not None:
1098
+ rows = []
1099
+ for rel in taxonomies:
1100
+ parent = getattr(rel, "parent", None)
1101
+ child = getattr(rel, "child", None)
1102
+ label = (
1103
+ getattr(rel, "label", None)
1104
+ if hasattr(rel, "label")
1105
+ else None
1106
+ )
1107
+ if parent is not None and child is not None:
1108
+ rows.append(
1109
+ {"child": child, "parent": parent, "label": label}
1110
+ )
1111
+ if rows:
1112
+ return pd.DataFrame(rows).reset_index(drop=True)
1113
+ except Exception:
1114
+ pass
1115
+
1116
+ raise ValueError(
1117
+ "Unsupported data format. Provide a DataFrame, a list of dicts, "
1118
+ "a list of (child, parent[, label]) tuples/lists, or an object with "
1119
+ ".type_taxonomies.taxonomies."
1120
+ )
1121
+
1122
+ def _resolve_save_path(self, save_path: str, default_filename: str) -> str:
1123
+ """
1124
+ Resolve a target file path from a directory or path-like input.
1125
+
1126
+ If `save_path` points to a directory, joins it with `default_filename`.
1127
+ If it already looks like a file path (has an extension), returns as-is.
1128
+
1129
+ Args:
1130
+ save_path: Directory or file path supplied by the caller.
1131
+ default_filename: Basename to use when `save_path` is a directory.
1132
+
1133
+ Returns:
1134
+ A concrete file path where outputs can be written.
1135
+ """
1136
+ base = os.path.basename(save_path)
1137
+ has_ext = os.path.splitext(base)[1] != ""
1138
+ return save_path if has_ext else os.path.join(save_path, default_filename)