distil-trainer 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- distil_trainer/__init__.py +31 -0
- distil_trainer/core/__init__.py +23 -0
- distil_trainer/core/callbacks.py +188 -0
- distil_trainer/core/config.py +358 -0
- distil_trainer/core/trainer.py +843 -0
- distil_trainer/data/__init__.py +19 -0
- distil_trainer/data/collators.py +240 -0
- distil_trainer/data/datamodule.py +191 -0
- distil_trainer/data/datasets.py +245 -0
- distil_trainer/data/loaders.py +163 -0
- distil_trainer/distillation/__init__.py +21 -0
- distil_trainer/distillation/losses.py +345 -0
- distil_trainer/distillation/multilingual.py +285 -0
- distil_trainer/distillation/strategies.py +211 -0
- distil_trainer/evaluation/__init__.py +19 -0
- distil_trainer/evaluation/benchmarks.py +86 -0
- distil_trainer/evaluation/evaluators.py +343 -0
- distil_trainer/evaluation/metrics.py +75 -0
- distil_trainer/models/__init__.py +5 -0
- distil_trainer/models/layers.py +115 -0
- distil_trainer/pruning/__init__.py +13 -0
- distil_trainer/pruning/combined_pruning.py +122 -0
- distil_trainer/pruning/depth_pruning.py +261 -0
- distil_trainer/pruning/importance.py +365 -0
- distil_trainer/pruning/width_pruning.py +480 -0
- distil_trainer-0.1.10.dist-info/METADATA +443 -0
- distil_trainer-0.1.10.dist-info/RECORD +29 -0
- distil_trainer-0.1.10.dist-info/WHEEL +4 -0
- distil_trainer-0.1.10.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Data module for distillation training."""
|
|
2
|
+
|
|
3
|
+
from distil_trainer.data.datamodule import DistillationDataModule
|
|
4
|
+
from distil_trainer.data.datasets import (
|
|
5
|
+
SentenceDistillationDataset,
|
|
6
|
+
TripletDataset,
|
|
7
|
+
ParallelSentencesDataset,
|
|
8
|
+
)
|
|
9
|
+
from distil_trainer.data.collators import DistillationCollator
|
|
10
|
+
from distil_trainer.data.loaders import DatasetLoaders
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"DistillationDataModule",
|
|
14
|
+
"SentenceDistillationDataset",
|
|
15
|
+
"TripletDataset",
|
|
16
|
+
"ParallelSentencesDataset",
|
|
17
|
+
"DistillationCollator",
|
|
18
|
+
"DatasetLoaders",
|
|
19
|
+
]
|
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
"""Data collators for distillation training."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DistillationCollator:
    """
    Collator for distillation training data.

    Extracts the text field from every sample, tokenizes the texts as one
    batch when a tokenizer is configured, and attaches a ``label`` tensor
    when the samples carry labels.

    Example:
        >>> collator = DistillationCollator(tokenizer, max_length=512)
        >>> batch = collator([{"sentence": "Hello"}, {"sentence": "World"}])
    """

    def __init__(
        self,
        tokenizer: Any = None,
        max_length: int = 512,
        text_column: str = "sentence",
        padding: bool = True,
        truncation: bool = True,
    ):
        """
        Initialize the collator.

        Args:
            tokenizer: Callable used to encode the list of texts. When None,
                the raw texts are returned instead of encoded tensors.
            max_length: Forwarded to the tokenizer as ``max_length``.
            text_column: Key under which each sample stores its text.
            padding: Forwarded to the tokenizer as ``padding``.
            truncation: Forwarded to the tokenizer as ``truncation``.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.text_column = text_column
        self.padding = padding
        self.truncation = truncation

    def __call__(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        """
        Collate a batch of samples.

        Args:
            batch: List of sample dictionaries. An empty list is accepted.

        Returns:
            With a tokenizer: the tokenizer output converted to a plain dict.
            Without one: ``{text_column: [texts...]}``. In both cases a
            ``"label"`` tensor is added when the first sample has a
            ``"label"`` key.

        Raises:
            KeyError: If the first sample has a ``"label"`` but a later one
                does not.
        """
        # Fall back to the default "sentence" key, then to an empty string,
        # when a sample does not provide the configured text column.
        texts = [sample.get(self.text_column, sample.get("sentence", "")) for sample in batch]

        if self.tokenizer is not None:
            encoded = self.tokenizer(
                texts,
                padding=self.padding,
                truncation=self.truncation,
                max_length=self.max_length,
                return_tensors="pt",
            )
            result = dict(encoded)
        else:
            result = {self.text_column: texts}

        # Label presence is decided from the first sample only; the `batch`
        # truthiness check keeps an empty batch from raising IndexError.
        if batch and "label" in batch[0]:
            labels = [sample["label"] for sample in batch]
            if isinstance(labels[0], torch.Tensor):
                result["label"] = torch.stack(labels)
            else:
                result["label"] = torch.tensor(labels)

        return result
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class TripletCollator:
    """
    Collator for triplet data (query, positive, negatives).

    Queries, positives, and negatives are each tokenized as a separate
    batch. Negatives from all samples are flattened into a single list
    before tokenization, so the per-sample grouping is not preserved in the
    encoded output.

    Example:
        >>> collator = TripletCollator(tokenizer)
        >>> batch = collator([{"query": "q1", "positive": "p1", "negatives": ["n1"]}])
    """

    def __init__(
        self,
        tokenizer: Any = None,
        max_length: int = 512,
        padding: bool = True,
        truncation: bool = True,
    ):
        """
        Initialize the collator.

        Args:
            tokenizer: Callable used to encode lists of texts. When None,
                the raw strings are returned instead.
            max_length: Forwarded to the tokenizer as ``max_length``.
            padding: Forwarded to the tokenizer as ``padding``.
            truncation: Forwarded to the tokenizer as ``truncation``.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    def _encode(self, texts: list[str]) -> Any:
        """Run the configured tokenizer over ``texts`` with the stored options."""
        return self.tokenizer(
            texts,
            padding=self.padding,
            truncation=self.truncation,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __call__(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        """
        Collate a batch of triplet samples.

        Args:
            batch: List of dicts with ``"query"`` and ``"positive"`` keys and
                an optional ``"negatives"`` list. An empty list is accepted.

        Returns:
            With a tokenizer: ``query_*``/``positive_*`` (and, when any
            negatives exist, ``negative_*``) ``input_ids`` and
            ``attention_mask`` entries. Without one: the raw ``"query"``,
            ``"positive"`` and (if present) ``"negatives"`` lists.

        Raises:
            KeyError: If a sample lacks ``"query"`` or ``"positive"``.
        """
        queries = [sample["query"] for sample in batch]
        positives = [sample["positive"] for sample in batch]

        # Presence of negatives is decided from the first sample only; the
        # truthiness check keeps an empty batch from raising IndexError.
        has_negatives = bool(batch) and "negatives" in batch[0]

        result = {}

        if self.tokenizer is None:
            result["query"] = queries
            result["positive"] = positives
            if has_negatives:
                result["negatives"] = [sample.get("negatives", []) for sample in batch]
            return result

        query_encoded = self._encode(queries)
        result["query_input_ids"] = query_encoded["input_ids"]
        result["query_attention_mask"] = query_encoded["attention_mask"]

        pos_encoded = self._encode(positives)
        result["positive_input_ids"] = pos_encoded["input_ids"]
        result["positive_attention_mask"] = pos_encoded["attention_mask"]

        if has_negatives:
            # Flatten every sample's negatives into one list for a single
            # tokenizer call.
            all_negatives = [neg for sample in batch for neg in sample.get("negatives", [])]
            # Skip tokenization entirely when no sample actually has negatives.
            if all_negatives:
                neg_encoded = self._encode(all_negatives)
                result["negative_input_ids"] = neg_encoded["input_ids"]
                result["negative_attention_mask"] = neg_encoded["attention_mask"]

        return result
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class ParallelSentenceCollator:
    """
    Collator for parallel sentence data (multilingual).

    Each sample provides an ``"english"`` and a ``"non_english"`` text; the
    two sides are tokenized as separate batches.

    Example:
        >>> collator = ParallelSentenceCollator(tokenizer)
        >>> batch = collator([{"english": "Hello", "non_english": "Hallo"}])
    """

    def __init__(
        self,
        tokenizer: Any = None,
        max_length: int = 128,
        padding: bool = True,
        truncation: bool = True,
    ):
        """
        Initialize the collator.

        Args:
            tokenizer: Callable used to encode lists of texts. When None,
                the raw strings are returned instead.
            max_length: Forwarded to the tokenizer as ``max_length``.
            padding: Forwarded to the tokenizer as ``padding``.
            truncation: Forwarded to the tokenizer as ``truncation``.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    def _encode(self, texts: list[str]) -> Any:
        """Run the configured tokenizer over ``texts`` with the stored options."""
        return self.tokenizer(
            texts,
            padding=self.padding,
            truncation=self.truncation,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __call__(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        """
        Collate a batch of parallel sentence samples.

        Args:
            batch: List of dicts with ``"english"`` and ``"non_english"``
                keys and an optional ``"label"``. An empty list is accepted.

        Returns:
            With a tokenizer: ``english_*`` and ``non_english_*``
            ``input_ids``/``attention_mask`` entries. Without one: the raw
            ``"english"`` and ``"non_english"`` lists. A ``"label"`` tensor
            is added when the first sample has a ``"label"`` key.

        Raises:
            KeyError: If a sample lacks ``"english"`` or ``"non_english"``,
                or if the first sample has a ``"label"`` but a later one
                does not.
        """
        english = [sample["english"] for sample in batch]
        non_english = [sample["non_english"] for sample in batch]

        result = {}

        if self.tokenizer is not None:
            en_encoded = self._encode(english)
            result["english_input_ids"] = en_encoded["input_ids"]
            result["english_attention_mask"] = en_encoded["attention_mask"]

            ne_encoded = self._encode(non_english)
            result["non_english_input_ids"] = ne_encoded["input_ids"]
            result["non_english_attention_mask"] = ne_encoded["attention_mask"]
        else:
            result["english"] = english
            result["non_english"] = non_english

        # Label presence is decided from the first sample only; the `batch`
        # truthiness check keeps an empty batch from raising IndexError.
        if batch and "label" in batch[0]:
            labels = [sample["label"] for sample in batch]
            if isinstance(labels[0], torch.Tensor):
                result["label"] = torch.stack(labels)
            else:
                result["label"] = torch.tensor(labels)

        return result
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""Base data module for distillation training."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Any, Callable
|
|
7
|
+
|
|
8
|
+
from datasets import Dataset, DatasetDict, load_dataset
|
|
9
|
+
from torch.utils.data import DataLoader
|
|
10
|
+
|
|
11
|
+
from distil_trainer.data.collators import DistillationCollator
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DistillationDataModule:
    """
    Base data module for distillation training.

    Handles data loading, preprocessing, and DataLoader creation.

    Example:
        >>> datamodule = DistillationDataModule(
        ...     train_data="sentence-transformers/all-nli",
        ...     text_column="sentence",
        ...     batch_size=32,
        ... )
        >>> datamodule.prepare_data()
        >>> datamodule.setup()
        >>> train_loader = datamodule.train_dataloader()
    """

    def __init__(
        self,
        train_data: str | Dataset | None = None,
        eval_data: str | Dataset | None = None,
        test_data: str | Dataset | None = None,
        tokenizer: Any = None,
        text_column: str = "sentence",
        max_seq_length: int = 512,
        batch_size: int = 32,
        num_workers: int = 4,
        preprocessing_fn: Callable | None = None,
        max_samples: int | None = None,
    ):
        """
        Initialize the data module.

        Args:
            train_data: Training data path/name or Dataset.
            eval_data: Evaluation data path/name or Dataset.
            test_data: Test data path/name or Dataset.
            tokenizer: Tokenizer handed to the collator for encoding text.
            text_column: Name of the text column in the dataset.
            max_seq_length: Maximum sequence length for tokenization.
            batch_size: Batch size for DataLoaders.
            num_workers: Number of workers for data loading.
            preprocessing_fn: Optional function applied with a batched
                ``Dataset.map``; the original columns are removed.
            max_samples: Maximum number of samples to keep per dataset.
        """
        self.train_data = train_data
        self.eval_data = eval_data
        self.test_data = test_data
        self.tokenizer = tokenizer
        self.text_column = text_column
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.preprocessing_fn = preprocessing_fn
        self.max_samples = max_samples

        # Datasets are populated by setup().
        self.train_dataset: Dataset | None = None
        self.eval_dataset: Dataset | None = None
        self.test_dataset: Dataset | None = None

        # One collator instance is shared by all DataLoaders.
        self.collator = DistillationCollator(
            tokenizer=tokenizer,
            max_length=max_seq_length,
            text_column=text_column,
        )

    def prepare_data(self) -> None:
        """
        Download and prepare data. Called once on a single process.

        Only sources given as strings are fetched. The same split fallback
        as in setup() is used, so a dataset that lacks the expected split
        name does not make this step fail while setup() would succeed.
        """
        sources = (
            ("training", self.train_data, "train"),
            ("eval", self.eval_data, "validation"),
            ("test", self.test_data, "test"),
        )
        for label, source, split in sources:
            if isinstance(source, str):
                logger.info("Downloading %s data: %s", label, source)
                self._fetch_split(source, split)

    def setup(self, stage: str | None = None) -> None:
        """
        Set up datasets for training/evaluation.

        Args:
            stage: "fit", "validate", "test", or None for all.
        """
        if stage in (None, "fit"):
            self.train_dataset = self._load_dataset(self.train_data, "train")
            if self.eval_data is not None:
                self.eval_dataset = self._load_dataset(self.eval_data, "validation")

        if stage in (None, "validate"):
            # Skip when the "fit" branch above has already loaded eval data.
            if self.eval_dataset is None and self.eval_data is not None:
                self.eval_dataset = self._load_dataset(self.eval_data, "validation")

        if stage in (None, "test"):
            if self.test_data is not None:
                self.test_dataset = self._load_dataset(self.test_data, "test")

    def _fetch_split(self, name: str, split: str) -> Dataset:
        """
        Load ``split`` of the dataset ``name``, falling back when it is missing.

        If the direct load fails, the dataset is loaded without a split; the
        requested split is used when present, otherwise the first available
        split is returned and a warning is logged.

        Raises:
            ValueError: If the dataset loads as an empty DatasetDict.
        """
        try:
            return load_dataset(name, split=split)
        except Exception:
            # Deliberately broad: any failure of the direct load triggers the
            # load-everything fallback, which re-raises if the dataset
            # cannot be loaded at all.
            loaded = load_dataset(name)

        if not isinstance(loaded, DatasetDict):
            return loaded
        if split in loaded:
            return loaded[split]
        if not loaded:
            raise ValueError(f"Dataset {name!r} contains no splits")

        fallback_split, fallback = next(iter(loaded.items()))
        logger.warning(
            "Split %r not found in %s; falling back to split %r",
            split,
            name,
            fallback_split,
        )
        return fallback

    def _load_dataset(self, data: str | Dataset | None, split: str = "train") -> Dataset | None:
        """
        Load and preprocess a dataset.

        Args:
            data: Dataset instance, dataset path/name, or None.
            split: Split to request when ``data`` is a path/name.

        Returns:
            The (possibly truncated and preprocessed) dataset, or None when
            ``data`` is None.
        """
        if data is None:
            return None

        dataset = data if isinstance(data, Dataset) else self._fetch_split(data, split)

        # Keep only the first max_samples rows when a limit is configured.
        if self.max_samples is not None and len(dataset) > self.max_samples:
            dataset = dataset.select(range(self.max_samples))

        if self.preprocessing_fn is not None:
            dataset = dataset.map(
                self.preprocessing_fn,
                batched=True,
                remove_columns=dataset.column_names,
            )

        logger.info("Loaded dataset with %d samples", len(dataset))
        return dataset

    def _make_dataloader(self, dataset: Dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader over ``dataset`` with the module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=self.collator,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self) -> DataLoader:
        """
        Return training DataLoader (shuffled).

        Raises:
            ValueError: If the training dataset has not been loaded.
        """
        if self.train_dataset is None:
            raise ValueError("Training dataset not loaded. Call setup() first.")
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader | None:
        """Return validation DataLoader, or None when no eval dataset is loaded."""
        if self.eval_dataset is None:
            return None
        return self._make_dataloader(self.eval_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader | None:
        """Return test DataLoader, or None when no test dataset is loaded."""
        if self.test_dataset is None:
            return None
        return self._make_dataloader(self.test_dataset, shuffle=False)
|
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
"""Dataset classes for distillation training."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch.utils.data import Dataset as TorchDataset
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SentenceDistillationDataset(TorchDataset):
    """
    Dataset for sentence embedding distillation.

    Each item is a dict holding the sentence and, when teacher embeddings
    were supplied, the matching embedding under ``"label"``.

    Example:
        >>> dataset = SentenceDistillationDataset(
        ...     sentences=["Hello world", "How are you?"],
        ...     teacher_embeddings=teacher_embeddings,
        ...     tokenizer=tokenizer,
        ... )
        >>> item = dataset[0]
    """

    def __init__(
        self,
        sentences: list[str],
        teacher_embeddings: torch.Tensor | list | None = None,
        tokenizer: Any = None,
        max_length: int = 512,
    ):
        """
        Initialize the dataset.

        Args:
            sentences: Sentences served by this dataset.
            teacher_embeddings: Optional precomputed teacher embeddings,
                indexed in parallel with ``sentences``.
            tokenizer: Stored on the instance; not used by this class itself.
            max_length: Stored on the instance; not used by this class itself.
        """
        self.sentences = sentences
        self.teacher_embeddings = teacher_embeddings
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        """Return the number of sentences."""
        return len(self.sentences)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Return the sentence at ``idx`` plus its teacher embedding, if any."""
        sample = {"sentence": self.sentences[idx]}

        embeddings = self.teacher_embeddings
        if embeddings is None:
            return sample

        target = embeddings[idx]
        # Rows of a tensor are returned as-is; rows of any other container
        # are converted to a tensor.
        if not isinstance(embeddings, torch.Tensor):
            target = torch.tensor(target)
        sample["label"] = target

        return sample
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class TripletDataset(TorchDataset):
    """
    Dataset for triplet (query, positive, negative) training.

    Used for contrastive learning and retrieval model training.

    Example:
        >>> dataset = TripletDataset(
        ...     queries=["What is Python?"],
        ...     positive_docs=[["Python is a programming language..."]],
        ...     negative_docs=[["Java is a programming language..."]]
        ... )
    """

    def __init__(
        self,
        queries: list[str],
        positive_docs: list[list[str]],
        negative_docs: list[list[str]] | None = None,
        tokenizer: Any = None,
        max_length: int = 512,
        num_negatives: int = 5,
    ):
        """
        Initialize the dataset.

        Args:
            queries: Query strings, one per item.
            positive_docs: For each query, a list of positive documents.
            negative_docs: For each query, a list of negative documents.
            tokenizer: Stored on the instance; not used by this class itself.
            max_length: Stored on the instance; not used by this class itself.
            num_negatives: Upper bound on negatives returned per item.
        """
        self.queries = queries
        self.positive_docs = positive_docs
        self.negative_docs = negative_docs
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_negatives = num_negatives

    def __len__(self) -> int:
        """Return the number of queries."""
        return len(self.queries)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """
        Return the query at ``idx`` with one positive and, if available,
        up to ``num_negatives`` negatives.
        """
        query_text = self.queries[idx]
        candidates = self.positive_docs[idx]

        # Always the first positive; an empty positive list yields "".
        sample = {
            "query": query_text,
            "positive": candidates[0] if candidates else "",
        }

        # No negatives configured, or none recorded for this index.
        pool = self.negative_docs
        if pool is None or idx >= len(pool):
            return sample

        sample["negatives"] = pool[idx][: self.num_negatives]
        return sample
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class ParallelSentencesDataset(TorchDataset):
    """
    Dataset for parallel sentences (multilingual training).

    Items expose the source sentence under ``"english"`` and the target
    sentence under ``"non_english"``, plus an optional ``"label"`` holding
    the teacher embedding.

    Example:
        >>> dataset = ParallelSentencesDataset(
        ...     source_sentences=["Hello"],
        ...     target_sentences=["Hallo"],
        ...     teacher_embeddings=embeddings,
        ... )
    """

    def __init__(
        self,
        source_sentences: list[str],
        target_sentences: list[str],
        teacher_embeddings: torch.Tensor | list | None = None,
        tokenizer: Any = None,
        max_length: int = 128,
    ):
        """
        Initialize the dataset.

        Args:
            source_sentences: Source language sentences.
            target_sentences: Target language sentences, aligned with
                ``source_sentences``.
            teacher_embeddings: Optional teacher embeddings, indexed in
                parallel with the sentences.
            tokenizer: Stored on the instance; not used by this class itself.
            max_length: Stored on the instance; not used by this class itself.

        Raises:
            ValueError: If the two sentence lists differ in length.
        """
        # Explicit check instead of `assert`, which is removed under
        # `python -O` and would let misaligned pairs through silently.
        if len(source_sentences) != len(target_sentences):
            raise ValueError(
                "source_sentences and target_sentences must have the same length "
                f"(got {len(source_sentences)} and {len(target_sentences)})"
            )

        self.source_sentences = source_sentences
        self.target_sentences = target_sentences
        self.teacher_embeddings = teacher_embeddings
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        """Return the number of sentence pairs."""
        return len(self.source_sentences)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Return the sentence pair at ``idx`` plus its teacher embedding, if any."""
        item = {
            "english": self.source_sentences[idx],
            "non_english": self.target_sentences[idx],
        }

        if self.teacher_embeddings is not None:
            # Rows of a tensor are returned as-is; rows of any other
            # container are converted to a tensor.
            if isinstance(self.teacher_embeddings, torch.Tensor):
                item["label"] = self.teacher_embeddings[idx]
            else:
                item["label"] = torch.tensor(self.teacher_embeddings[idx])

        return item
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class ReasoningDataset(TorchDataset):
    """
    Dataset for chain-of-thought/reasoning distillation.

    Example:
        >>> dataset = ReasoningDataset(
        ...     questions=["What is 2+2?"],
        ...     reasoning_chains=["Let me think... 2+2=4"],
        ...     answers=["4"],
        ... )
    """

    def __init__(
        self,
        questions: list[str],
        reasoning_chains: list[str],
        answers: list[str],
        system_prompts: list[str] | None = None,
        tokenizer: Any = None,
        max_length: int = 16384,
        chat_format: bool = True,
    ):
        """
        Initialize the dataset.

        Args:
            questions: Questions, one per item.
            reasoning_chains: Reasoning text for each question.
            answers: Final answer for each question.
            system_prompts: Optional per-item system prompts.
            tokenizer: Stored on the instance; not used by this class itself.
            max_length: Stored on the instance; not used by this class itself.
            chat_format: When True, items also carry a ``"full_response"``
                string that wraps the reasoning and answer in thought and
                solution marker tokens.
        """
        self.questions = questions
        self.reasoning_chains = reasoning_chains
        self.answers = answers
        self.system_prompts = system_prompts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.chat_format = chat_format

    def __len__(self) -> int:
        """Return the number of questions."""
        return len(self.questions)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Return question, reasoning and answer at ``idx`` with optional extras."""
        question = self.questions[idx]
        reasoning = self.reasoning_chains[idx]
        answer = self.answers[idx]

        record = dict(question=question, reasoning=reasoning, answer=answer)

        # The system prompt is attached only when one exists for this index.
        prompts = self.system_prompts
        if prompts is not None and idx < len(prompts):
            record["system"] = prompts[idx]

        if not self.chat_format:
            return record

        # Wrap the reasoning and the answer in the marker tokens.
        record["full_response"] = (
            f"<|begin_of_thought|>\n{reasoning}\n<|end_of_thought|>\n\n<|begin_of_solution|>\n{answer}\n<|end_of_solution|>"
        )
        return record
|