OntoLearner 1.4.7__py3-none-any.whl → 1.4.9__py3-none-any.whl
- ontolearner/VERSION +1 -1
- ontolearner/base/learner.py +15 -12
- ontolearner/learner/__init__.py +1 -1
- ontolearner/learner/label_mapper.py +1 -1
- ontolearner/learner/retriever/__init__.py +19 -0
- ontolearner/learner/retriever/crossencoder.py +129 -0
- ontolearner/learner/retriever/embedding.py +229 -0
- ontolearner/learner/retriever/learner.py +217 -0
- ontolearner/learner/retriever/llm_retriever.py +356 -0
- ontolearner/learner/retriever/ngram.py +123 -0
- ontolearner/learner/taxonomy_discovery/__init__.py +18 -0
- ontolearner/learner/taxonomy_discovery/alexbek.py +500 -0
- ontolearner/learner/taxonomy_discovery/rwthdbis.py +1082 -0
- ontolearner/learner/taxonomy_discovery/sbunlp.py +402 -0
- ontolearner/learner/taxonomy_discovery/skhnlp.py +1138 -0
- ontolearner/learner/term_typing/__init__.py +17 -0
- ontolearner/learner/term_typing/alexbek.py +1262 -0
- ontolearner/learner/term_typing/rwthdbis.py +379 -0
- ontolearner/learner/term_typing/sbunlp.py +478 -0
- ontolearner/learner/text2onto/__init__.py +16 -0
- ontolearner/learner/text2onto/alexbek.py +1219 -0
- ontolearner/learner/text2onto/sbunlp.py +598 -0
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.9.dist-info}/METADATA +16 -12
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.9.dist-info}/RECORD +26 -9
- ontolearner/learner/retriever.py +0 -101
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.9.dist-info}/WHEEL +0 -0
- {ontolearner-1.4.7.dist-info → ontolearner-1.4.9.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1138 @@
# Copyright (c) 2025 SciKnowOrg
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/MIT
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import random

import pandas as pd
import torch
import Levenshtein
from datasets import Dataset
from typing import Any, Optional, List, Tuple, Dict
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    BertTokenizer,
    BertForSequenceClassification,
    pipeline,
    Trainer,
    TrainingArguments,
)

from ...base import AutoLearner, AutoPrompt
from ...utils import taxonomy_split, train_test_split as ontology_split
from ...data_structure import OntologyData, TaxonomicRelation

class SKHNLPTaxonomyPrompts(AutoPrompt):
    """Builds the 7 taxonomy prompts used during fine-tuning / inference.

    The class stores a small inventory of prompt templates that verbalize the
    (parent, child) relationship using different phrasings. Each template ends
    with a masked token slot intended for True/False classification.
    """

    def __init__(self) -> None:
        """Initialize prompt templates and the default prompt in the base class."""
        super().__init__(
            prompt_template="{parent} is the superclass of {child}. This statement is [MASK]."
        )
        self.templates: List[str] = [
            "{parent} is the superclass of {child}. This statement is [MASK].",
            "{child} is a subclass of {parent}. This statement is [MASK].",
            "{parent} is the parent class of {child}. This statement is [MASK].",
            "{child} is a child class of {parent}. This statement is [MASK].",
            "{parent} is a supertype of {child}. This statement is [MASK].",
            "{child} is a subtype of {parent}. This statement is [MASK].",
            "{parent} is an ancestor class of {child}. This statement is [MASK].",
        ]

    def format(self, parent: str, child: str, template_idx: int) -> str:
        """Render a prompt for a (parent, child) pair using a specific template.

        Args:
            parent: The parent/superclass label.
            child: The child/subclass label.
            template_idx: Index into the internal `templates` list.

        Returns:
            The fully formatted prompt string.
        """
        return self.templates[template_idx].format(parent=parent, child=child)

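# Illustrative sketch (not part of the released file): rendering the prompt
# inventory above for one (parent, child) pair; only SKHNLPTaxonomyPrompts
# from this module is assumed.
#
#     prompts = SKHNLPTaxonomyPrompts()
#     for idx in range(len(prompts.templates)):
#         print(prompts.format(parent="stream, lake", child="river", template_idx=idx))
#
# Every rendered sentence ends in "[MASK]", the slot the True/False
# classifier below scores.
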
class SKHNLPSequentialFTLearner(AutoLearner):
    """
    BERT-based classifier for taxonomy discovery.

    With OntologyData:
      * TRAIN: ontology-aware split; create balanced train/eval with negatives.
      * PREDICT/TEST: notebook-style parent selection -> list[{'parent', 'child'}].

    With DataFrame/list:
      * TRAIN: taxonomy_split + negatives; build prompts and fine-tune.
      * PREDICT/TEST: pairwise binary classification (returns label + score).
    """

    def __init__(
        self,
        # core
        model_name: str = "bert-large-uncased",
        n_prompts: int = 7,
        random_state: int = 1403,
        num_labels: int = 2,
        device: str = "cpu",  # "cuda" | "cpu" | None (auto)
        # data split & negative sampling (now configurable)
        eval_fraction: float = 0.16,
        neg_ratio_reversed: float = 1 / 3,
        neg_ratio_manipulated: float = 2 / 3,
        # ---- expose TrainingArguments as individual user-defined args ----
        output_dir: str = "./results/",
        num_train_epochs: int = 1,
        per_device_train_batch_size: int = 4,
        per_device_eval_batch_size: int = 4,
        warmup_steps: int = 500,
        weight_decay: float = 0.01,
        logging_dir: str = "./logs/",
        logging_steps: int = 50,
        eval_strategy: str = "epoch",
        save_strategy: str = "epoch",
        load_best_model_at_end: bool = True,
        use_fast_tokenizer: Optional[bool] = None,
        trust_remote_code: bool = False,
    ) -> None:
        """Configure the sequential fine-tuning learner.

        Args:
            model_name: HF model id or local path for the BERT backbone.
            n_prompts: Number of prompt variants to iterate over sequentially.
            random_state: RNG seed for shuffling/sampling steps.
            num_labels: Number of classes for the classifier head.
            device: Force device ('cuda' or 'cpu'). If None, auto-detects CUDA.
            eval_fraction: Fraction of positives to hold out for evaluation.
            neg_ratio_reversed: Proportion of reversed-parent negatives vs. positives.
            neg_ratio_manipulated: Proportion of random-parent negatives vs. positives.
            output_dir: Directory where the HF Trainer writes checkpoints/outputs.
            num_train_epochs: Number of epochs per prompt.
            per_device_train_batch_size: Training batch size per device.
            per_device_eval_batch_size: Evaluation batch size per device.
            warmup_steps: Linear warmup steps for the LR scheduler.
            weight_decay: Weight decay coefficient.
            logging_dir: Directory for Trainer logs.
            logging_steps: Interval for log events (in steps).
            eval_strategy: Evaluation schedule ('no', 'steps', 'epoch').
            save_strategy: Checkpoint save schedule ('no', 'steps', 'epoch').
            load_best_model_at_end: Whether to restore the best checkpoint.
            use_fast_tokenizer: Force fast/slow tokenizer. If None, try fast then fall back to slow.
            trust_remote_code: Passed through to the Hugging Face loaders.

        Notes:
            The model is fine-tuned *sequentially* across prompt columns.
            You can control the eval split and negative sampling mix via
            `eval_fraction`, `neg_ratio_reversed`, and `neg_ratio_manipulated`.
        """
        super().__init__()
        self.model_name = model_name
        self.n_prompts = n_prompts
        self.random_state = random_state
        self.num_labels = num_labels
        self.device = device

        # user-tunable ratios / split
        self._eval_fraction = float(eval_fraction)
        self._neg_ratio_reversed = float(neg_ratio_reversed)
        self._neg_ratio_manipulated = float(neg_ratio_manipulated)
        if not (0.0 < self._eval_fraction < 1.0):
            raise ValueError("eval_fraction must be in (0, 1).")
        if self._neg_ratio_reversed < 0 or self._neg_ratio_manipulated < 0:
            raise ValueError("neg_ratio_* must be >= 0.")

        self.tokenizer: Optional[BertTokenizer] = None
        self.model: Optional[BertForSequenceClassification] = None
        self.prompter = SKHNLPTaxonomyPrompts()

        # Candidate parents (unique parent list) for multi-class parent selection.
        self._candidate_parents: Optional[List[str]] = None

        # Keep last train/eval tables for inspection
        self._last_train: Optional[pd.DataFrame] = None
        self._last_eval: Optional[pd.DataFrame] = None
        self.trust_remote_code = bool(trust_remote_code)
        self.use_fast_tokenizer = use_fast_tokenizer

        random.seed(self.random_state)

        # Build TrainingArguments from the individual user-defined values
        self.training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            logging_dir=logging_dir,
            logging_steps=logging_steps,
            eval_strategy=eval_strategy,
            save_strategy=save_strategy,
            load_best_model_at_end=load_best_model_at_end,
        )

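    # Configuration sketch (illustrative values, not the defaults above): a CPU
    # run with a smaller backbone and a larger held-out split might look like
    #
    #     learner = SKHNLPSequentialFTLearner(
    #         model_name="bert-base-uncased",
    #         eval_fraction=0.2,
    #         num_train_epochs=1,
    #         device="cpu",
    #     )
    #     learner.load()
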
    def load(self, model_id: Optional[str] = None, **_: Any) -> None:
        """Load tokenizer & model in a backbone-agnostic way; move model to self.device."""
        model_id = model_id or self.model_name

        # ---- Tokenizer (robust fast→slow fallback unless explicitly set) ----
        if self.use_fast_tokenizer is None:
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_id, use_fast=True, trust_remote_code=self.trust_remote_code
                )
            except Exception as fast_err:
                print(
                    f"[tokenizer] Fast tokenizer failed: {fast_err}. Falling back to slow tokenizer..."
                )
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_id, use_fast=False, trust_remote_code=self.trust_remote_code
                )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                use_fast=self.use_fast_tokenizer,
                trust_remote_code=self.trust_remote_code,
            )

        # Ensure pad token exists (some models lack it)
        if getattr(self.tokenizer, "pad_token", None) is None:
            # Try sensible fallbacks
            fallback = (
                getattr(self.tokenizer, "eos_token", None)
                or getattr(self.tokenizer, "sep_token", None)
                or getattr(self.tokenizer, "cls_token", None)
            )
            if fallback is not None:
                self.tokenizer.pad_token = fallback

        # ---- Model (classifier head sized to self.num_labels) ----
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            num_labels=self.num_labels,
            trust_remote_code=self.trust_remote_code,
            # Allows swapping in a new head size even if the checkpoint differs
            ignore_mismatched_sizes=True,
        )

        # Make sure padding ids line up
        if (
            getattr(self.model.config, "pad_token_id", None) is None
            and getattr(self.tokenizer, "pad_token_id", None) is not None
        ):
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

        # Set problem type (single-label classification by default)
        # If you plan multi-label, you'd switch to "multi_label_classification"
        self.model.config.problem_type = "single_label_classification"

        # Move to target device
        self.model.to(self.device)

    def tasks_ground_truth_former(self, data: Any, task: str) -> Any:
        """Normalize ground-truth inputs for 'taxonomy-discovery'.

        Supports DataFrame with columns ['parent','child',('label')],
        list of dicts, or falls back to the base class behavior.

        Args:
            data: Input object to normalize.
            task: Task name, passed from the outer pipeline.

        Returns:
            A list of dictionaries with keys 'parent', 'child', and optionally
            'label' when present in the input.
        """
        if task != "taxonomy-discovery":
            return super().tasks_ground_truth_former(data, task)

        if isinstance(data, pd.DataFrame):
            if "label" in data.columns:
                return [
                    {"parent": p, "child": c, "label": bool(lbl)}
                    for p, c, lbl in zip(data["parent"], data["child"], data["label"])
                ]
            return [
                {"parent": p, "child": c} for p, c in zip(data["parent"], data["child"])
            ]

        if isinstance(data, list):
            return data

        return super().tasks_ground_truth_former(data, task)

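    # Shape example (illustrative): a DataFrame row (parent="animal",
    # child="dog", label=True) is normalized to
    #     {"parent": "animal", "child": "dog", "label": True}
    # and, without a 'label' column, to {"parent": "animal", "child": "dog"}.
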
    def _make_negatives(
        self, positives_df: pd.DataFrame
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Create two types of negatives from a positives table.

        Returns:
            A tuple `(reversed_df, manipulated_df)` where:
              - `reversed_df`: pairs with parent/child columns swapped, label=False.
              - `manipulated_df`: pairs with the parent replaced by a random
                *different* parent from the same pool, label=False.

        Notes:
            The input DataFrame must contain columns ['parent', 'child'].
        """
        unique_parents = positives_df["parent"].unique().tolist()

        def as_reversed(df: pd.DataFrame) -> pd.DataFrame:
            out = df.copy()
            out[["parent", "child"]] = out[["child", "parent"]].values
            out["label"] = False
            return out

        def with_random_parent(df: pd.DataFrame) -> pd.DataFrame:
            def pick_other_parent(p: str) -> str:
                pool = [x for x in unique_parents if x != p]
                return random.choice(pool) if pool else p

            out = df.copy()
            out["parent"] = out["parent"].apply(pick_other_parent)
            out["label"] = False
            return out

        return as_reversed(positives_df), with_random_parent(positives_df)

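    # Worked micro-example (illustrative): for a positives table with rows
    # ("animal", "dog") and ("plant", "oak"), as_reversed() yields
    # ("dog", "animal") and ("oak", "plant") with label=False, while
    # with_random_parent() keeps each child but swaps the parent for the other
    # one in the pool, e.g. ("plant", "dog") and ("animal", "oak").
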
    def _balance_with_negatives(
        self,
        positives_df: pd.DataFrame,
        reversed_df: pd.DataFrame,
        manipulated_df: pd.DataFrame,
    ) -> pd.DataFrame:
        """Combine positives with negatives using the configured ratios.

        Sampling ratios are defined by the instance settings
        `self._neg_ratio_reversed` and `self._neg_ratio_manipulated`,
        keeping the positives count unchanged.

        Args:
            positives_df: Positive pairs with `label=True`.
            reversed_df: Negative pairs produced by flipping parent/child.
            manipulated_df: Negative pairs with randomly reassigned parents.

        Returns:
            A deduplicated, shuffled DataFrame with a class-balanced mix.
        """
        n_pos = len(positives_df)
        n_rev = int(n_pos * self._neg_ratio_reversed)
        n_man = int(n_pos * self._neg_ratio_manipulated)

        combined = pd.concat(
            [
                positives_df.sample(n_pos, random_state=self.random_state),
                reversed_df.sample(n_rev, random_state=self.random_state),
                manipulated_df.sample(n_man, random_state=self.random_state),
            ],
            ignore_index=True,
        )
        combined = combined.drop_duplicates(
            subset=["parent", "child", "label"]
        ).reset_index(drop=True)
        return combined

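    # Sampling arithmetic (illustrative): with the default ratios 1/3 and 2/3,
    # 300 positives are mixed with int(300 * 1/3) = 100 reversed and
    # int(300 * 2/3) = 200 manipulated negatives, i.e. roughly a 1:1
    # positive/negative split before deduplication.
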
    def _add_prompt_columns(self, df: pd.DataFrame) -> pd.DataFrame:
        """Append one column per prompt variant to the given pairs table.

        For each row `(parent, child)`, creates columns `prompt_1 ... prompt_n`.

        Args:
            df: Input DataFrame with columns ['parent', 'child', ...].

        Returns:
            A copy of `df` including the newly added prompt columns.
        """
        out = df.copy()
        for i in range(self.n_prompts):
            out[f"prompt_{i + 1}"] = out.apply(
                lambda r, k=i: self.prompter.format(r["parent"], r["child"], k), axis=1
            )
        return out

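    # Layout example (illustrative): with n_prompts=7, a row
    # (parent="animal", child="dog") gains columns prompt_1 ... prompt_7, e.g.
    #     prompt_1 == "animal is the superclass of dog. This statement is [MASK]."
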
    def _df_from_relations(
        self, relations: List[TaxonomicRelation], label: bool = True
    ) -> pd.DataFrame:
        """Convert a list of `TaxonomicRelation` to a DataFrame.

        Args:
            relations: Iterable of `TaxonomicRelation(parent, child)`.
            label: Class label to assign to all resulting rows.

        Returns:
            DataFrame with columns ['parent', 'child', 'label'].
        """
        if not relations:
            return pd.DataFrame(columns=["parent", "child", "label"])
        return pd.DataFrame(
            [{"parent": r.parent, "child": r.child, "label": label} for r in relations]
        )

    def _relations_from_df(self, df: pd.DataFrame) -> List[TaxonomicRelation]:
        """Convert a DataFrame to a list of `TaxonomicRelation`.

        Args:
            df: DataFrame with columns ['parent', 'child'].

        Returns:
            List of `TaxonomicRelation` objects in row order.
        """
        return [
            TaxonomicRelation(parent=p, child=c)
            for p, c in zip(df["parent"], df["child"])
        ]

    def _build_masked_prompt(
        self, parent: str, child: str, index_1_based: int, mask_token: str = "[MASK]"
    ) -> str:
        """Construct one of several True/False prompts with a mask token.

        Args:
            parent: Parent label.
            child: Child label.
            index_1_based: 1-based index selecting a template.
            mask_token: The token used to denote the masked label.

        Returns:
            A formatted prompt string.
        """
        prompts_1based = [
            f"{parent} is the superclass of {child}. This statement is {mask_token}.",
            f"{child} is a subclass of {parent}. This statement is {mask_token}.",
            f"{parent} is the parent class of {child}. This statement is {mask_token}.",
            f"{child} is a child class of {parent}. This statement is {mask_token}.",
            f"{parent} is a supertype of {child}. This statement is {mask_token}.",
            f"{child} is a subtype of {parent}. This statement is {mask_token}.",
            f"{parent} is an ancestor class of {child}. This statement is {mask_token}.",
            f"{child} is a descendant class of {parent}. This statement is {mask_token}.",
            f'"{parent}" is the superclass of "{child}". This statement is {mask_token}.',
        ]
        return prompts_1based[index_1_based - 1]

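    # Indexing note (illustrative): index_1_based=1 selects the "superclass"
    # template and index_1_based=9 the quoted variant, so this helper exposes
    # more templates (9) than the default self.n_prompts (7) ever requests.
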
    @torch.no_grad()
    def _predict_prompt_true_false(self, sentence: str) -> bool:
        """Run a single True/False prediction on a prompt.

        Args:
            sentence: Fully formatted prompt text.

        Returns:
            True iff the predicted class index is 1 (positive).
        """
        enc = self.tokenizer(sentence, return_tensors="pt").to(self.model.device)
        logits = self.model(**enc).logits
        predicted_label = torch.argmax(logits, dim=1).item()
        return predicted_label == 1

    def _select_parent_via_prompts(self, child: str) -> str:
        """Select the most likely parent for a given child via prompt voting.

        The procedure:
          1) Generate prompts for each candidate parent at increasing "levels".
          2) Accumulate votes from the True/False classifier.
          3) Resolve ties by recursing to the next level; after 4 levels, break ties randomly.

        Args:
            child: The child label whose parent should be predicted.

        Returns:
            The chosen parent string.

        Raises:
            AssertionError: If candidate parents were not initialized.
        """
        assert self._candidate_parents, "Candidate parents not initialized."
        scores: dict[str, int] = {p: 0 for p in self._candidate_parents}

        def prompt_indices_for_level(level: int) -> List[int]:
            if level == 0:
                return [1]
            return [2 * level, 2 * level + 1]

        def recurse(active_parents: List[str], level: int) -> str:
            idxs = [
                i for i in prompt_indices_for_level(level) if 1 <= i <= self.n_prompts
            ]
            if idxs:
                for parent in active_parents:
                    votes = sum(
                        1
                        for idx in idxs
                        if self._predict_prompt_true_false(
                            self._build_masked_prompt(
                                parent=parent, child=child, index_1_based=idx
                            )
                        )
                    )
                    scores[parent] += votes

            max_score = max(scores[p] for p in active_parents)
            tied = [p for p in active_parents if scores[p] == max_score]
            if len(tied) == 1:
                return tied[0]
            if level < 4:
                return recurse(tied, level + 1)
            return random.choice(tied)

        return recurse(list(scores.keys()), level=0)

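    # Voting schedule (illustrative walk-through): level 0 asks prompt 1 only,
    # level 1 asks prompts 2-3, level 2 prompts 4-5, and level 3 prompts 6-7
    # (with n_prompts=7, level 4 has no prompts left). Two candidates tied 1-1
    # after level 0 therefore get up to three more voting rounds before the
    # final random tie-break.
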
    def _taxonomy_discovery(self, data: Any, test: bool = False):
        """
        TRAIN:
          - OntologyData -> ontology-aware split; negatives per split; balanced sets.
          - DataFrame/list -> taxonomy_split for positives; negatives proportional.
        TEST:
          - OntologyData -> parent selection: [{'parent': predicted, 'child': child}]
          - DataFrame/list -> binary pair classification with 'label' + 'score'

        Args:
            data: One of {OntologyData, pandas.DataFrame, list[dict], list[tuple]}.
            test: If True, run inference; otherwise perform training.

        Returns:
            - On training: None (model is fine-tuned in-place).
            - On inference with OntologyData: list of {'parent','child'} predictions.
            - On inference with pairs: list of dicts including 'label' and 'score'.
        """
        is_ontology_object = isinstance(data, OntologyData)

        # Normalize input
        if isinstance(data, pd.DataFrame):
            pairs_df = data.copy()
        elif isinstance(data, list):
            pairs_df = pd.DataFrame(data)
        else:
            gt_pairs = super().tasks_ground_truth_former(data, "taxonomy-discovery")
            pairs_df = pd.DataFrame(gt_pairs)
        if "label" not in pairs_df.columns:
            pairs_df["label"] = True

        # Maintain candidate parents across calls
        if "parent" in pairs_df.columns:
            parents_in_call = sorted(pd.unique(pairs_df["parent"]).tolist())
            if test:
                if self._candidate_parents is None:
                    self._candidate_parents = parents_in_call
                else:
                    self._candidate_parents = sorted(
                        set(self._candidate_parents).union(parents_in_call)
                    )
            else:
                if self._candidate_parents is None:
                    self._candidate_parents = parents_in_call

        if test:
            if is_ontology_object and self._candidate_parents:
                predictions: List[dict[str, str]] = []
                for _, row in pairs_df.iterrows():
                    child_term = row["child"]
                    chosen_parent = self._select_parent_via_prompts(child_term)
                    predictions.append({"parent": chosen_parent, "child": child_term})
                return predictions

            # pairwise binary classification
            prompts_df = self._add_prompt_columns(pairs_df.copy())
            true_probs_by_prompt: List[torch.Tensor] = []

            for i in range(self.n_prompts):
                col = f"prompt_{i + 1}"
                enc = self.tokenizer(
                    prompts_df[col].tolist(),
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                ).to(self.model.device)
                with torch.no_grad():
                    logits = self.model(**enc).logits
                true_probs_by_prompt.append(torch.softmax(logits, dim=1)[:, 1])

            avg_true_prob = torch.stack(true_probs_by_prompt, dim=0).mean(0)
            predicted_bool = (avg_true_prob >= 0.5).cpu().tolist()

            results: List[dict[str, Any]] = []
            for p, c, s, yhat in zip(
                pairs_df["parent"],
                pairs_df["child"],
                avg_true_prob.tolist(),
                predicted_bool,
            ):
                results.append(
                    {
                        "parent": p,
                        "child": c,
                        "label": int(bool(yhat)),
                        "score": float(s),
                    }
                )
            return results

        if isinstance(data, OntologyData):
            train_onto, eval_onto = ontology_split(
                data,
                test_size=self._eval_fraction,
                random_state=self.random_state,
                verbose=False,
            )

            train_pos_rel: List[TaxonomicRelation] = (
                getattr(train_onto.type_taxonomies, "taxonomies", []) or []
            )
            eval_pos_rel: List[TaxonomicRelation] = (
                getattr(eval_onto.type_taxonomies, "taxonomies", []) or []
            )

            train_pos_df = self._df_from_relations(train_pos_rel, label=True)
            eval_pos_df = self._df_from_relations(eval_pos_rel, label=True)

            tr_rev_df, tr_man_df = self._make_negatives(train_pos_df)
            ev_rev_df, ev_man_df = self._make_negatives(eval_pos_df)

            train_df = self._balance_with_negatives(train_pos_df, tr_rev_df, tr_man_df)
            eval_df = self._balance_with_negatives(eval_pos_df, ev_rev_df, ev_man_df)

            train_df = self._add_prompt_columns(train_df)
            eval_df = self._add_prompt_columns(eval_df)

        else:
            if "label" not in pairs_df.columns or pairs_df["label"].nunique() == 1:
                positives_df = pairs_df[pairs_df.get("label", True)][
                    ["parent", "child"]
                ].copy()
                pos_rel = self._relations_from_df(positives_df)

                tr_rel, ev_rel = taxonomy_split(
                    pos_rel,
                    train_terms=None,
                    test_size=self._eval_fraction,
                    random_state=self.random_state,
                    verbose=False,
                )
                train_pos_df = self._df_from_relations(tr_rel, label=True)
                eval_pos_df = self._df_from_relations(ev_rel, label=True)

                tr_rev_df, tr_man_df = self._make_negatives(train_pos_df)
                ev_rev_df, ev_man_df = self._make_negatives(eval_pos_df)

                train_df = self._balance_with_negatives(
                    train_pos_df, tr_rev_df, tr_man_df
                )
                eval_df = self._balance_with_negatives(
                    eval_pos_df, ev_rev_df, ev_man_df
                )

                train_df = self._add_prompt_columns(train_df)
                eval_df = self._add_prompt_columns(eval_df)

            else:
                positives_df = pairs_df[pairs_df["label"]][["parent", "child"]].copy()
                pos_rel = self._relations_from_df(positives_df)

                tr_rel, ev_rel = taxonomy_split(
                    pos_rel,
                    train_terms=None,
                    test_size=self._eval_fraction,
                    random_state=self.random_state,
                    verbose=False,
                )
                train_pos_df = self._df_from_relations(tr_rel, label=True)
                eval_pos_df = self._df_from_relations(ev_rel, label=True)

                negatives_df = pairs_df[~pairs_df["label"]][["parent", "child"]].copy()
                negatives_df = negatives_df.sample(
                    frac=1.0, random_state=self.random_state
                ).reset_index(drop=True)

                n_eval_neg = (
                    max(1, int(len(negatives_df) * self._eval_fraction))
                    if len(negatives_df) > 0
                    else 0
                )
                eval_neg_df = (
                    negatives_df.iloc[:n_eval_neg].copy()
                    if n_eval_neg > 0
                    else negatives_df.iloc[:0].copy()
                )
                train_neg_df = negatives_df.iloc[n_eval_neg:].copy()

                train_neg_df["label"] = False
                eval_neg_df["label"] = False

                train_df = pd.concat([train_pos_df, train_neg_df], ignore_index=True)
                eval_df = pd.concat([eval_pos_df, eval_neg_df], ignore_index=True)

                train_df = self._add_prompt_columns(train_df)
                eval_df = self._add_prompt_columns(eval_df)

        # Ensure labels are int64
        train_df["label"] = train_df["label"].astype("int64")
        eval_df["label"] = eval_df["label"].astype("int64")

        # Sequential fine-tuning across prompts
        for i in range(self.n_prompts):
            prompt_col = f"prompt_{i + 1}"
            train_ds = Dataset.from_pandas(
                train_df[[prompt_col, "label"]].reset_index(drop=True)
            )
            eval_ds = Dataset.from_pandas(
                eval_df[[prompt_col, "label"]].reset_index(drop=True)
            )

            train_ds = train_ds.rename_column("label", "labels")
            eval_ds = eval_ds.rename_column("label", "labels")

            def tokenize_batch(batch):
                """Tokenize a batch for the current prompt column with truncation/padding."""
                return self.tokenizer(
                    batch[prompt_col], padding="max_length", truncation=True
                )

            train_ds = train_ds.map(
                tokenize_batch, batched=True, remove_columns=[prompt_col]
            )
            eval_ds = eval_ds.map(
                tokenize_batch, batched=True, remove_columns=[prompt_col]
            )

            train_ds.set_format(
                type="torch", columns=["input_ids", "attention_mask", "labels"]
            )
            eval_ds.set_format(
                type="torch", columns=["input_ids", "attention_mask", "labels"]
            )

            trainer = Trainer(
                model=self.model,
                args=self.training_args,
                train_dataset=train_ds,
                eval_dataset=eval_ds,
            )
            trainer.train()

        self._last_train = train_df
        self._last_eval = eval_df
        return None

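# End-to-end sketch (illustrative; it calls the private task entry point
# directly, whereas typical use goes through the AutoLearner base interface):
#
#     import pandas as pd
#     pairs = pd.DataFrame({"parent": ["animal", "plant"],
#                           "child": ["dog", "oak"]})
#     learner = SKHNLPSequentialFTLearner(model_name="bert-base-uncased")
#     learner.load()
#     learner._taxonomy_discovery(pairs, test=False)        # fine-tune
#     print(learner._taxonomy_discovery(pairs, test=True))  # label + score dicts
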
class SKHNLPZSLearner(AutoLearner):
    """
    Zero-shot taxonomy learner using an instruction-tuned causal LLM.

    Behavior
    --------
    - Builds a fixed classification prompt listing 9 GeoNames parent classes.
    - For each input row (child term), generates a short completion and parses
      the predicted class from a strict '#[ ... ]#' format.
    - Optionally normalizes the raw prediction to one of the 9 valid labels via:
        * "none"        : keep the parsed text as-is
        * "substring"   : snap to a label if either is a substring of the other
        * "levenshtein" : snap to the closest label by edit distance
        * "auto"        : substring, then Levenshtein if needed
    - Saves raw and normalized predictions to CSV if `save_path` is provided.

    Inputs the learner accepts (via `_to_dataframe`):
      - pandas.DataFrame with columns: ['child', 'parent'] or ['child', 'parent', 'label']
      - list[dict] with keys: 'child', 'parent' (and optionally 'label')
      - list of tuples/lists: (child, parent) or (child, parent, label)
      - OntoLearner-style object exposing .type_taxonomies.taxonomies iterable with (child, parent)
    """

    # Fixed class inventory (GeoNames parents)
    CLASS_LIST = [
        "city, village",
        "country, state, region",
        "forest, heath",
        "mountain, hill, rock",
        "parks, area",
        "road, railroad",
        "spot, building, farm",
        "stream, lake",
        "undersea",
    ]

    # Strict format: #[ ... ]#
    _PREDICTION_PATTERN = re.compile(r"#\[\s*([^\]]+?)\s*\]#")

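    # Pattern behavior (illustrative): the regex captures the trimmed interior
    # of the first '#[ ... ]#' span, e.g.
    #     _PREDICTION_PATTERN.search("answer: #[parks, area]# done").group(1)
    # returns "parks, area".
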
    def __init__(
        self,
        model_name: str = "Qwen/Qwen2.5-0.5B-Instruct",
        device: Optional[str] = None,  # "cuda" | "cpu" | None (auto)
        max_new_tokens: int = 16,
        save_path: Optional[str] = None,  # directory or full path
        verbose: bool = True,
        normalize_mode: str = "none",  # "none" | "substring" | "levenshtein" | "auto"
        random_state: int = 1403,
    ) -> None:
        """Configure the zero-shot learner.

        Args:
            model_name: HF model id/path for the instruction-tuned causal LLM.
            device: Force device ('cuda' or 'cpu'), else auto-detect.
            max_new_tokens: Generation length budget for each completion.
            save_path: Optional CSV path or directory for saving predictions.
            verbose: If True, print progress messages.
            normalize_mode: Post-processing for class names
                ('none' | 'substring' | 'levenshtein' | 'auto').
            random_state: RNG seed for any sampling steps.
        """
        super().__init__()
        self.model_name = model_name
        self.verbose = verbose
        self.max_new_tokens = max_new_tokens
        self.save_path = save_path
        self.normalize_mode = (normalize_mode or "none").lower().strip()
        self.random_state = random_state

        random.seed(self.random_state)

        # Device: auto-detect CUDA if not specified
        if device is None:
            self._has_cuda = torch.cuda.is_available()
        else:
            self._has_cuda = device == "cuda"
        self._pipe_device = 0 if self._has_cuda else -1
        self._model_device_map = {"": "cuda"} if self._has_cuda else None

        self._tokenizer = None
        self._model = None
        self._pipeline = None

        # Prompt template used for every example
        self._classification_prompt = (
            "My task is classification. My classes are as follows: "
            "(city, village), (country, state, region), (forest, heath), "
            "(mountain, hill, rock), (parks, area), (road, railroad), "
            "(spot, building, farm), (stream, lake), (undersea). "
            'I will provide you with a phrase like "wadi mouth". '
            "The name of each class is placed within a pair of parentheses. "
            "I want you to choose the most appropriate class from those mentioned above "
            "based on the given phrase and present it in a format like #[parks, area]#. "
            "So, the general format for each response will be #[class name]#. "
            "Pay attention to the format of the response. Start with a '#' character, "
            "include the class name inside it, and end with another '#' character. "
            "Additionally, make sure to include a '#' character at the end to indicate "
            "that the answer is complete. I don't need any additional explanations."
        )

    def load(self, model_id: str = "") -> None:
        """
        Load tokenizer, model, and text-generation pipeline.

        Args:
            model_id: Optional HF id/path override; defaults to `self.model_name`.

        Side Effects:
            Initializes the tokenizer and model, configures the generation
            pipeline on CPU/GPU, and sets a pad token if absent.
        """
        model_id = model_id or self.model_name
        if self.verbose:
            print(f"[ZeroShotTaxonomyLearner] Loading {model_id}")

        self._tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Ensure a pad token is set for generation
        if (
            self._tokenizer.pad_token_id is None
            and self._tokenizer.eos_token_id is not None
        ):
            self._tokenizer.pad_token = self._tokenizer.eos_token

        self._model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=self._model_device_map,
            torch_dtype="auto",
        )

        self._pipeline = pipeline(
            task="text-generation",
            model=self._model,
            tokenizer=self._tokenizer,
            device=self._pipe_device,  # 0 for GPU, -1 for CPU
        )

        if self.verbose:
            print("Device set to use", "cuda" if self._has_cuda else "cpu")
            print("[ZeroShotTaxonomyLearner] Model loaded.")

    def _taxonomy_discovery(
        self, data: Any, test: bool = False
    ) -> Optional[List[Dict[str, str]]]:
        """
        Zero-shot prediction over all incoming rows (no filtering/augmentation).

        Args:
            data: One of {DataFrame, list[dict], list[tuple], Ontology-like}.
            test: If False, training is skipped (zero-shot learner), and None is returned.

        Returns:
            On `test=True`, a list of dicts [{'parent': predicted_label, 'child': child}, ...].
            On `test=False`, returns None.
        """
        if not test:
            if self.verbose:
                print("[ZeroShot] Training skipped (zero-shot).")
            return None

        df = self._to_dataframe(data)

        if self.verbose:
            print(f"[ZeroShot] Incoming rows: {len(df)}; columns: {list(df.columns)}")

        eval_df = pd.DataFrame(df).reset_index(drop=True)
        if eval_df.empty:
            return []

        # Prepare columns for inspection and saving
        eval_df["prediction_raw"] = ""
        eval_df["prediction_sub"] = ""
        eval_df["prediction_lvn"] = ""
        eval_df["prediction_auto"] = ""
        eval_df["prediction"] = ""  # final (per normalize_mode)

        # Generate predictions row by row
        for idx, row in eval_df.iterrows():
            child_term = str(row["child"])
            raw_text, parsed_raw = self._generate_and_parse(child_term)

            # Choose a string to normalize (parsed token if matched, otherwise whole output)
            basis = parsed_raw if parsed_raw != "unknown" else raw_text

            # Compute all normalization variants
            sub_norm = self._normalize_substring_only(basis)
            lvn_norm = self._normalize_levenshtein_only(basis)
            auto_norm = self._normalize_auto(basis)

            # Final selection by mode
            if self.normalize_mode == "none":
                final_label = parsed_raw
            elif self.normalize_mode == "substring":
                final_label = sub_norm
            elif self.normalize_mode == "levenshtein":
                final_label = lvn_norm
            elif self.normalize_mode == "auto":
                final_label = auto_norm
            else:
                final_label = parsed_raw  # fallback

            # Persist to DataFrame for inspection/export
            eval_df.at[idx, "prediction_raw"] = parsed_raw
            eval_df.at[idx, "prediction_sub"] = sub_norm
            eval_df.at[idx, "prediction_lvn"] = lvn_norm
            eval_df.at[idx, "prediction_auto"] = auto_norm
            eval_df.at[idx, "prediction"] = final_label

        # Return in the format expected by the pipeline
        return [
            {"parent": p, "child": c}
            for p, c in zip(eval_df["prediction"], eval_df["child"])
        ]

    def _generate_and_parse(self, child_term: str) -> Tuple[str, str]:
        """
        Generate a completion for the given child term and extract the raw predicted class
        using the strict '#[ ... ]#' pattern.

        Args:
            child_term: The child label to classify into one of the fixed classes.

        Returns:
            Tuple `(raw_generation_text, parsed_prediction_or_unknown)`, where the second
            element is either the text inside '#[ ... ]#' or the string 'unknown'.
        """
        messages = [
            {"role": "system", "content": "You are a helpful classifier."},
            {"role": "user", "content": f"{self._classification_prompt} {child_term}"},
        ]

        prompt = self._tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        generation = self._pipeline(
            prompt,
            max_new_tokens=self.max_new_tokens,
            do_sample=False,
            temperature=0.0,
            top_p=1.0,
            eos_token_id=self._tokenizer.eos_token_id,
            pad_token_id=self._tokenizer.eos_token_id,
            return_full_text=False,
        )[0]["generated_text"]

        match = self._PREDICTION_PATTERN.search(generation)
        parsed = match.group(1).strip() if match else "unknown"
        return generation, parsed

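    # Parse example (illustrative): a completion "#[stream, lake]#" yields
    # ("#[stream, lake]#", "stream, lake"); a completion with no '#[ ... ]#'
    # span yields (raw_text, "unknown"), which the normalizers below may still
    # snap to a valid label.
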
    def _normalize_substring_only(self, text: str) -> str:
        """
        Snap to a label if the string is equal to / contained in / contains a valid label (case-insensitive).

        Args:
            text: Raw class text to normalize.

        Returns:
            One of `CLASS_LIST` on a match; otherwise 'unknown'.
        """
        if not isinstance(text, str):
            return "unknown"
        lowered = text.strip().lower()
        if not lowered:
            return "unknown"

        for label in self.CLASS_LIST:
            label_lower = label.lower()
            if (
                lowered == label_lower
                or lowered in label_lower
                or label_lower in lowered
            ):
                return label
        return "unknown"

    def _normalize_levenshtein_only(self, text: str) -> str:
        """
        Snap to the nearest label by Levenshtein (edit) distance.

        Args:
            text: Raw class text to normalize.

        Returns:
            The nearest label in `CLASS_LIST`, or 'unknown' if input is empty/invalid.
        """
        if not isinstance(text, str):
            return "unknown"
        lowered = text.strip().lower()
        if not lowered:
            return "unknown"

        best_label = None
        best_distance = 10**9
        for label in self.CLASS_LIST:
            label_lower = label.lower()
            distance = Levenshtein.distance(lowered, label_lower)
            if distance < best_distance:
                best_distance = distance
                best_label = label
        return best_label or "unknown"

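    # Distance example (illustrative): for the slightly garbled output
    # "stream lake", Levenshtein.distance("stream lake", "stream, lake") == 1
    # is the minimum over CLASS_LIST, so the method snaps to "stream, lake".
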
    def _normalize_auto(self, text: str) -> str:
        """
        Cascade: try substring-first; if no match, fall back to Levenshtein snapping.

        Args:
            text: Raw class text to normalize.

        Returns:
            Normalized label string or 'unknown'.
        """
        snapped = self._normalize_substring_only(text)
        return (
            snapped if snapped != "unknown" else self._normalize_levenshtein_only(text)
        )

    def _to_dataframe(self, data: Any) -> pd.DataFrame:
        """
        Normalize various input formats into a DataFrame.

        Supported inputs:
          * pandas.DataFrame with columns ['child','parent',('label')]
          * list[dict] with keys 'child','parent',('label')
          * list of tuples/lists: (child, parent) or (child, parent, label)
          * Ontology-like object with `.type_taxonomies.taxonomies`

        Args:
            data: The source object to normalize.

        Returns:
            A pandas DataFrame with standardized columns.

        Raises:
            ValueError: If the input type/shape is not recognized.
        """
        if isinstance(data, pd.DataFrame):
            df = data.copy()
            df.columns = [str(c).lower() for c in df.columns]
            return df.reset_index(drop=True)

        if isinstance(data, list) and data and isinstance(data[0], dict):
            rows = [{str(k).lower(): v for k, v in d.items()} for d in data]
            return pd.DataFrame(rows).reset_index(drop=True)

        if isinstance(data, (list, tuple)) and data:
            first = data[0]
            if isinstance(first, (list, tuple)) and not isinstance(first, dict):
                n = len(first)
                if n >= 3:
                    return pd.DataFrame(
                        data, columns=["child", "parent", "label"]
                    ).reset_index(drop=True)
                if n == 2:
                    return pd.DataFrame(data, columns=["child", "parent"]).reset_index(
                        drop=True
                    )

        try:
            type_taxonomies = getattr(data, "type_taxonomies", None)
            if type_taxonomies is not None:
                taxonomies = getattr(type_taxonomies, "taxonomies", None)
                if taxonomies is not None:
                    rows = []
                    for rel in taxonomies:
                        parent = getattr(rel, "parent", None)
                        child = getattr(rel, "child", None)
                        label = (
                            getattr(rel, "label", None)
                            if hasattr(rel, "label")
                            else None
                        )
                        if parent is not None and child is not None:
                            rows.append(
                                {"child": child, "parent": parent, "label": label}
                            )
                    if rows:
                        return pd.DataFrame(rows).reset_index(drop=True)
        except Exception:
            pass

        raise ValueError(
            "Unsupported data format. Provide a DataFrame, a list of dicts, "
            "a list of (child, parent[, label]) tuples/lists, or an object with "
            ".type_taxonomies.taxonomies."
        )

    def _resolve_save_path(self, save_path: str, default_filename: str) -> str:
        """
        Resolve a target file path from a directory or path-like input.

        If `save_path` points to a directory, joins it with `default_filename`.
        If it already looks like a file path (has an extension), returns it as-is.

        Args:
            save_path: Directory or file path supplied by the caller.
            default_filename: Basename to use when `save_path` is a directory.

        Returns:
            A concrete file path where outputs can be written.
        """
        base = os.path.basename(save_path)
        has_ext = os.path.splitext(base)[1] != ""
        return save_path if has_ext else os.path.join(save_path, default_filename)
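
# Zero-shot usage sketch (illustrative; load() downloads the model weights):
#
#     zs = SKHNLPZSLearner(normalize_mode="auto", verbose=False)
#     zs.load()
#     preds = zs._taxonomy_discovery([{"child": "wadi mouth"}], test=True)
#     # expected shape: [{"parent": "<one of CLASS_LIST or 'unknown'>",
#     #                   "child": "wadi mouth"}]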