distil-trainer 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,19 @@
1
+ """Data module for distillation training."""
2
+
3
from distil_trainer.data.datamodule import DistillationDataModule
from distil_trainer.data.datasets import (
    SentenceDistillationDataset,
    TripletDataset,
    ParallelSentencesDataset,
    ReasoningDataset,
)
from distil_trainer.data.collators import (
    DistillationCollator,
    TripletCollator,
    ParallelSentenceCollator,
)
from distil_trainer.data.loaders import DatasetLoaders

# Public API of the data package: the data module, the dataset classes, the
# collator classes and the dataset loaders are re-exported so callers can import
# them directly from ``distil_trainer.data``.
__all__ = [
    "DistillationDataModule",
    "SentenceDistillationDataset",
    "TripletDataset",
    "ParallelSentencesDataset",
    "ReasoningDataset",
    "DistillationCollator",
    "TripletCollator",
    "ParallelSentenceCollator",
    "DatasetLoaders",
]
@@ -0,0 +1,240 @@
1
+ """Data collators for distillation training."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import torch
8
+
9
+
10
class DistillationCollator:
    """
    Batch single-text samples for distillation training.

    The text of each sample is read from ``text_column``; if that key is absent
    the ``"sentence"`` key is used, and failing that an empty string. With a
    tokenizer configured, the texts are encoded into PyTorch tensors; without
    one, the raw strings are returned under ``text_column``. When the first
    sample has a ``"label"`` key, the labels of all samples are combined into a
    tensor under ``"label"``.

    Example:
        >>> collator = DistillationCollator(tokenizer, max_length=512)
        >>> batch = collator([{"sentence": "Hello"}, {"sentence": "World"}])
    """

    def __init__(
        self,
        tokenizer: Any = None,
        max_length: int = 512,
        text_column: str = "sentence",
        padding: bool = True,
        truncation: bool = True,
    ):
        """
        Store the encoding options.

        Args:
            tokenizer: Callable used to encode text; ``None`` disables encoding.
            max_length: Maximum sequence length forwarded to the tokenizer.
            text_column: Key that holds the text in each sample.
            padding: Padding option forwarded to the tokenizer.
            truncation: Truncation option forwarded to the tokenizer.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.text_column = text_column
        self.padding = padding
        self.truncation = truncation

    def __call__(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        """
        Merge a list of samples into one batch dictionary.

        Args:
            batch: Sample dictionaries to combine.

        Returns:
            The tokenizer output (as a plain dict) or ``{text_column: texts}``,
            plus ``"label"`` when the first sample carries one.
        """
        column = self.text_column
        texts = []
        for sample in batch:
            fallback = sample.get("sentence", "")
            texts.append(sample.get(column, fallback))

        if self.tokenizer is None:
            collated = {column: texts}
        else:
            collated = dict(
                self.tokenizer(
                    texts,
                    padding=self.padding,
                    truncation=self.truncation,
                    max_length=self.max_length,
                    return_tensors="pt",
                )
            )

        # Whether labels exist is decided from the first sample only.
        if "label" not in batch[0]:
            return collated

        targets = [sample["label"] for sample in batch]
        stack = isinstance(targets[0], torch.Tensor)
        collated["label"] = torch.stack(targets) if stack else torch.tensor(targets)
        return collated
80
+
81
+
82
class TripletCollator:
    """
    Collator for triplet data (query, positive, negatives).

    With a tokenizer, queries, positives and negatives are each encoded into
    separate ``*_input_ids`` / ``*_attention_mask`` tensors. Negatives from all
    samples are flattened into a single tensor, so ``negative_counts`` records
    how many negatives each sample contributed (in batch order), which lets the
    caller regroup them per query. Without a tokenizer, the raw strings are
    returned and negatives stay grouped per sample.

    Example:
        >>> collator = TripletCollator(tokenizer)
        >>> batch = collator([{"query": "q1", "positive": "p1", "negatives": ["n1"]}])
    """

    def __init__(
        self,
        tokenizer: Any = None,
        max_length: int = 512,
        padding: bool = True,
        truncation: bool = True,
    ):
        """
        Initialize the collator.

        Args:
            tokenizer: Tokenizer for encoding text; ``None`` returns raw strings.
            max_length: Maximum sequence length.
            padding: Whether to pad sequences.
            truncation: Whether to truncate sequences.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    @staticmethod
    def _sample_negatives(sample: dict[str, Any]) -> list[str]:
        """Return a sample's negatives as a list; a bare string counts as one negative."""
        negatives = sample.get("negatives")
        if negatives is None:
            return []
        if isinstance(negatives, str):
            # Without this, extending a list with a string adds single characters.
            return [negatives]
        return list(negatives)

    def _encode(self, texts: list[str]) -> Any:
        """Tokenize ``texts`` with this collator's padding/truncation settings."""
        return self.tokenizer(
            texts,
            padding=self.padding,
            truncation=self.truncation,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __call__(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        """
        Collate a batch of triplet samples.

        Args:
            batch: Sample dicts with ``"query"`` and ``"positive"`` keys and an
                optional ``"negatives"`` key.

        Returns:
            Encoded tensors (tokenizer set) or raw strings (no tokenizer). When
            negatives are encoded, ``"negative_counts"`` holds the number of
            negatives taken from each sample.
        """
        queries = [sample["query"] for sample in batch]
        positives = [sample["positive"] for sample in batch]

        # Inspect every sample, not only the first, so negatives are not
        # dropped when the first sample happens to lack them.
        has_negatives = any("negatives" in sample for sample in batch)
        grouped_negatives = (
            [self._sample_negatives(sample) for sample in batch] if has_negatives else []
        )

        result: dict[str, Any] = {}

        if self.tokenizer is None:
            result["query"] = queries
            result["positive"] = positives
            if has_negatives:
                result["negatives"] = grouped_negatives
            return result

        query_encoded = self._encode(queries)
        result["query_input_ids"] = query_encoded["input_ids"]
        result["query_attention_mask"] = query_encoded["attention_mask"]

        pos_encoded = self._encode(positives)
        result["positive_input_ids"] = pos_encoded["input_ids"]
        result["positive_attention_mask"] = pos_encoded["attention_mask"]

        all_negatives = [neg for negatives in grouped_negatives for neg in negatives]
        if all_negatives:
            neg_encoded = self._encode(all_negatives)
            result["negative_input_ids"] = neg_encoded["input_ids"]
            result["negative_attention_mask"] = neg_encoded["attention_mask"]
            # Flattened negatives lose their grouping; keep per-sample counts.
            result["negative_counts"] = torch.tensor(
                [len(negatives) for negatives in grouped_negatives], dtype=torch.long
            )

        return result
166
+
167
+
168
class ParallelSentenceCollator:
    """
    Batch English / non-English sentence pairs for multilingual training.

    Each sample must provide ``"english"`` and ``"non_english"`` strings. With
    a tokenizer, both sides are encoded separately (English first); without
    one, the raw strings are returned. Labels are collected into a tensor when
    the first sample has a ``"label"`` key.

    Example:
        >>> collator = ParallelSentenceCollator(tokenizer)
        >>> batch = collator([{"english": "Hello", "non_english": "Hallo"}])
    """

    def __init__(
        self,
        tokenizer: Any = None,
        max_length: int = 128,
        padding: bool = True,
        truncation: bool = True,
    ):
        """
        Store the encoding options.

        Args:
            tokenizer: Callable used to encode text; ``None`` disables encoding.
            max_length: Maximum sequence length forwarded to the tokenizer.
            padding: Padding option forwarded to the tokenizer.
            truncation: Truncation option forwarded to the tokenizer.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    def __call__(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        """
        Merge parallel sentence samples into one batch dictionary.

        Args:
            batch: Sample dicts with ``"english"`` and ``"non_english"`` keys.

        Returns:
            Per-side ``*_input_ids`` / ``*_attention_mask`` tensors, or the raw
            string lists, plus ``"label"`` when present in the first sample.
        """
        english = [sample["english"] for sample in batch]
        non_english = [sample["non_english"] for sample in batch]

        collated = {}
        if self.tokenizer is None:
            collated["english"] = english
            collated["non_english"] = non_english
        else:
            # (texts, ids key, mask key) per side; English is encoded first.
            sides = (
                (english, "english_input_ids", "english_attention_mask"),
                (non_english, "non_english_input_ids", "non_english_attention_mask"),
            )
            for texts, ids_key, mask_key in sides:
                encoded = self.tokenizer(
                    texts,
                    padding=self.padding,
                    truncation=self.truncation,
                    max_length=self.max_length,
                    return_tensors="pt",
                )
                collated[ids_key] = encoded["input_ids"]
                collated[mask_key] = encoded["attention_mask"]

        # Whether labels exist is decided from the first sample only.
        if "label" in batch[0]:
            targets = [sample["label"] for sample in batch]
            stack = isinstance(targets[0], torch.Tensor)
            collated["label"] = torch.stack(targets) if stack else torch.tensor(targets)

        return collated
@@ -0,0 +1,191 @@
1
+ """Base data module for distillation training."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import Any, Callable
7
+
8
+ from datasets import Dataset, DatasetDict, load_dataset
9
+ from torch.utils.data import DataLoader
10
+
11
+ from distil_trainer.data.collators import DistillationCollator
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class DistillationDataModule:
    """
    Base data module for distillation training.

    Handles data loading, preprocessing, and DataLoader creation. Each data
    source may be a ``datasets.Dataset``, a ``datasets.DatasetDict`` (the split
    is picked by name), a name/path understood by ``datasets.load_dataset``,
    or ``None``.

    Example:
        >>> datamodule = DistillationDataModule(
        ...     train_data="sentence-transformers/all-nli",
        ...     text_column="sentence",
        ...     batch_size=32,
        ... )
        >>> datamodule.prepare_data()
        >>> datamodule.setup()
        >>> train_loader = datamodule.train_dataloader()
    """

    def __init__(
        self,
        train_data: str | Dataset | DatasetDict | None = None,
        eval_data: str | Dataset | DatasetDict | None = None,
        test_data: str | Dataset | DatasetDict | None = None,
        tokenizer: Any = None,
        text_column: str = "sentence",
        max_seq_length: int = 512,
        batch_size: int = 32,
        num_workers: int = 4,
        preprocessing_fn: Callable | None = None,
        max_samples: int | None = None,
    ):
        """
        Initialize the data module.

        Args:
            train_data: Training data path/name, Dataset, or DatasetDict.
            eval_data: Evaluation data path/name, Dataset, or DatasetDict.
            test_data: Test data path/name, Dataset, or DatasetDict.
            tokenizer: Tokenizer for encoding text (passed to the collator).
            text_column: Name of the text column in the dataset.
            max_seq_length: Maximum sequence length for tokenization.
            batch_size: Batch size for DataLoaders.
            num_workers: Number of workers for data loading.
            preprocessing_fn: Optional batched preprocessing function; the
                original columns are removed after it is applied.
            max_samples: Maximum number of samples to keep per dataset.
        """
        self.train_data = train_data
        self.eval_data = eval_data
        self.test_data = test_data
        self.tokenizer = tokenizer
        self.text_column = text_column
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.preprocessing_fn = preprocessing_fn
        self.max_samples = max_samples

        # Datasets, populated by setup()
        self.train_dataset: Dataset | None = None
        self.eval_dataset: Dataset | None = None
        self.test_dataset: Dataset | None = None

        # Collator shared by all DataLoaders
        self.collator = DistillationCollator(
            tokenizer=tokenizer,
            max_length=max_seq_length,
            text_column=text_column,
        )

    def prepare_data(self) -> None:
        """
        Download and cache every dataset given by name. Called once on a single process.

        Datasets are fetched without selecting a split, so sources that lack a
        ``"train"``/``"validation"``/``"test"`` split do not fail here;
        ``setup()`` chooses the split later (with its own fallback).
        """
        sources = (
            ("training", self.train_data),
            ("eval", self.eval_data),
            ("test", self.test_data),
        )
        for label, source in sources:
            if isinstance(source, str):
                logger.info("Downloading %s data: %s", label, source)
                load_dataset(source)

    def setup(self, stage: str | None = None) -> None:
        """
        Set up datasets for training/evaluation.

        Args:
            stage: "fit", "validate", "test", or None for all.
        """
        if stage in (None, "fit"):
            self.train_dataset = self._load_dataset(self.train_data, "train")
            if self.eval_data is not None:
                self.eval_dataset = self._load_dataset(self.eval_data, "validation")

        if stage in (None, "validate"):
            if self.eval_dataset is None and self.eval_data is not None:
                self.eval_dataset = self._load_dataset(self.eval_data, "validation")

        if stage in (None, "test"):
            if self.test_data is not None:
                self.test_dataset = self._load_dataset(self.test_data, "test")

    @staticmethod
    def _select_split(splits: DatasetDict, split: str, source: str) -> Dataset:
        """Return ``splits[split]``, or the first available split with a warning."""
        if split in splits:
            return splits[split]
        if not splits:
            raise ValueError(f"No splits available in {source}")
        fallback = next(iter(splits))
        # A silent fallback could e.g. evaluate on training data; make it visible.
        logger.warning(
            "Split %r not found in %s; falling back to split %r", split, source, fallback
        )
        return splits[fallback]

    def _load_dataset(
        self, data: str | Dataset | DatasetDict | None, split: str = "train"
    ) -> Dataset | None:
        """
        Load, truncate and preprocess one dataset.

        Args:
            data: Dataset, DatasetDict, dataset name/path, or None.
            split: Split to use when ``data`` has several splits.

        Returns:
            The prepared Dataset, or None when ``data`` is None.
        """
        if data is None:
            return None

        if isinstance(data, Dataset):
            dataset = data
        elif isinstance(data, DatasetDict):
            dataset = self._select_split(data, split, source="the provided DatasetDict")
        else:
            try:
                dataset = load_dataset(data, split=split)
            except Exception as err:
                # Best-effort fallback: load all splits and pick one.
                logger.debug(
                    "Loading split %r of %s failed (%s); loading all splits instead",
                    split,
                    data,
                    err,
                )
                loaded = load_dataset(data)
                if isinstance(loaded, DatasetDict):
                    dataset = self._select_split(loaded, split, source=str(data))
                else:
                    dataset = loaded

        # Limit samples if specified
        if self.max_samples is not None and len(dataset) > self.max_samples:
            dataset = dataset.select(range(self.max_samples))

        # Apply preprocessing; original columns are dropped afterwards.
        if self.preprocessing_fn is not None:
            dataset = dataset.map(
                self.preprocessing_fn,
                batched=True,
                remove_columns=dataset.column_names,
            )

        logger.info("Loaded dataset with %d samples", len(dataset))
        return dataset

    def _build_dataloader(self, dataset: Dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using this module's batching settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=self.collator,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self) -> DataLoader:
        """
        Return the shuffled training DataLoader.

        Raises:
            ValueError: If setup() has not loaded a training dataset.
        """
        if self.train_dataset is None:
            raise ValueError("Training dataset not loaded. Call setup() first.")
        return self._build_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader | None:
        """Return the validation DataLoader, or None if no eval dataset is loaded."""
        if self.eval_dataset is None:
            return None
        return self._build_dataloader(self.eval_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader | None:
        """Return the test DataLoader, or None if no test dataset is loaded."""
        if self.test_dataset is None:
            return None
        return self._build_dataloader(self.test_dataset, shuffle=False)
@@ -0,0 +1,245 @@
1
+ """Dataset classes for distillation training."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset as TorchDataset
9
+
10
+
11
class SentenceDistillationDataset(TorchDataset):
    """
    Dataset for sentence embedding distillation.

    Each item is a dict with the sentence under ``"sentence"`` and, when
    teacher embeddings were supplied, the embedding at the same index as a
    tensor under ``"label"``. ``tokenizer`` and ``max_length`` are stored but
    not used by this class.

    Example:
        >>> dataset = SentenceDistillationDataset(
        ...     sentences=["Hello world", "How are you?"],
        ...     teacher_embeddings=teacher_embeddings,
        ...     tokenizer=tokenizer,
        ... )
        >>> item = dataset[0]
    """

    def __init__(
        self,
        sentences: list[str],
        teacher_embeddings: torch.Tensor | list | None = None,
        tokenizer: Any = None,
        max_length: int = 512,
    ):
        """
        Initialize the dataset.

        Args:
            sentences: List of sentences to encode.
            teacher_embeddings: Precomputed teacher embeddings, one per sentence.
            tokenizer: Tokenizer for encoding text.
            max_length: Maximum sequence length.

        Raises:
            ValueError: If ``teacher_embeddings`` is given and its length
                differs from the number of sentences.
        """
        if teacher_embeddings is not None and len(teacher_embeddings) != len(sentences):
            raise ValueError(
                f"Got {len(sentences)} sentences but {len(teacher_embeddings)} "
                "teacher embeddings; they must align one-to-one."
            )
        self.sentences = sentences
        self.teacher_embeddings = teacher_embeddings
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.sentences)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        item = {"sentence": self.sentences[idx]}

        if self.teacher_embeddings is not None:
            # as_tensor returns tensor rows unchanged (no copy, no
            # copy-construct warning) and converts lists/arrays as before.
            item["label"] = torch.as_tensor(self.teacher_embeddings[idx])

        return item
58
+
59
+
60
class TripletDataset(TorchDataset):
    """
    Dataset for triplet (query, positive, negative) training.

    Used for contrastive learning and retrieval model training. Each item has
    ``"query"``, ``"positive"`` (the first positive document) and, when
    negatives are available for that index, ``"negatives"`` (at most
    ``num_negatives`` of them). A single document given as a bare string
    instead of a list is treated as a one-element list.

    Example:
        >>> dataset = TripletDataset(
        ...     queries=["What is Python?"],
        ...     positive_docs=[["Python is a programming language..."]],
        ...     negative_docs=[["Java is a programming language..."]]
        ... )
    """

    def __init__(
        self,
        queries: list[str],
        positive_docs: list[list[str]],
        negative_docs: list[list[str]] | None = None,
        tokenizer: Any = None,
        max_length: int = 512,
        num_negatives: int = 5,
    ):
        """
        Initialize the dataset.

        Args:
            queries: List of query strings.
            positive_docs: List of positive document lists (multiple per query).
            negative_docs: List of negative document lists.
            tokenizer: Tokenizer for encoding text (stored, not used here).
            max_length: Maximum sequence length (stored, not used here).
            num_negatives: Maximum number of negatives returned per query.
        """
        self.queries = queries
        self.positive_docs = positive_docs
        self.negative_docs = negative_docs
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_negatives = num_negatives

    def __len__(self) -> int:
        return len(self.queries)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        query = self.queries[idx]
        positives = self.positive_docs[idx]

        # Always the first positive. A bare string is one document; indexing
        # it would yield only its first character.
        if isinstance(positives, str):
            positive = positives
        else:
            positive = positives[0] if positives else ""

        item = {
            "query": query,
            "positive": positive,
        }

        # Add negatives if available for this index
        if self.negative_docs is not None and idx < len(self.negative_docs):
            negatives = self.negative_docs[idx]
            if negatives is None:
                negatives = []
            elif isinstance(negatives, str):
                # Slicing a string would return its first characters.
                negatives = [negatives]
            item["negatives"] = negatives[: self.num_negatives]

        return item
123
+
124
+
125
class ParallelSentencesDataset(TorchDataset):
    """
    Dataset for parallel sentences (multilingual training).

    Each item has the source sentence under ``"english"``, the target sentence
    under ``"non_english"`` and, when teacher embeddings were supplied, the
    embedding at the same index as a tensor under ``"label"``.

    Example:
        >>> dataset = ParallelSentencesDataset(
        ...     source_sentences=["Hello"],
        ...     target_sentences=["Hallo"],
        ...     teacher_embeddings=embeddings,
        ... )
    """

    def __init__(
        self,
        source_sentences: list[str],
        target_sentences: list[str],
        teacher_embeddings: torch.Tensor | list | None = None,
        tokenizer: Any = None,
        max_length: int = 128,
    ):
        """
        Initialize the dataset.

        Args:
            source_sentences: Source language sentences.
            target_sentences: Target language sentences (aligned with sources).
            teacher_embeddings: Teacher embeddings for source sentences, one per pair.
            tokenizer: Tokenizer for encoding text (stored, not used here).
            max_length: Maximum sequence length (stored, not used here).

        Raises:
            ValueError: If source and target lists differ in length, or if
                ``teacher_embeddings`` does not have one entry per pair.
        """
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if len(source_sentences) != len(target_sentences):
            raise ValueError(
                f"Got {len(source_sentences)} source sentences but "
                f"{len(target_sentences)} target sentences; they must align one-to-one."
            )
        if teacher_embeddings is not None and len(teacher_embeddings) != len(source_sentences):
            raise ValueError(
                f"Got {len(source_sentences)} sentence pairs but "
                f"{len(teacher_embeddings)} teacher embeddings; they must align one-to-one."
            )

        self.source_sentences = source_sentences
        self.target_sentences = target_sentences
        self.teacher_embeddings = teacher_embeddings
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.source_sentences)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        item = {
            "english": self.source_sentences[idx],
            "non_english": self.target_sentences[idx],
        }

        if self.teacher_embeddings is not None:
            # as_tensor returns tensor rows unchanged and converts lists/arrays.
            item["label"] = torch.as_tensor(self.teacher_embeddings[idx])

        return item
179
+
180
+
181
class ReasoningDataset(TorchDataset):
    """
    Question / reasoning-chain / answer triples for reasoning distillation.

    Items carry ``"question"``, ``"reasoning"`` and ``"answer"``. A
    ``"system"`` prompt is added when ``system_prompts`` has an entry for the
    index. With ``chat_format`` enabled, ``"full_response"`` wraps the
    reasoning and the answer in thought/solution marker tokens.
    ``tokenizer`` and ``max_length`` are stored but not used by this class.

    Example:
        >>> dataset = ReasoningDataset(
        ...     questions=["What is 2+2?"],
        ...     reasoning_chains=["Let me think... 2+2=4"],
        ...     answers=["4"],
        ... )
    """

    def __init__(
        self,
        questions: list[str],
        reasoning_chains: list[str],
        answers: list[str],
        system_prompts: list[str] | None = None,
        tokenizer: Any = None,
        max_length: int = 16384,
        chat_format: bool = True,
    ):
        """
        Store the reasoning samples and formatting options.

        Args:
            questions: Question per sample; its length defines the dataset size.
            reasoning_chains: Reasoning text per sample.
            answers: Final answer per sample.
            system_prompts: Optional per-sample system prompts (may be shorter).
            tokenizer: Tokenizer (kept for callers; unused here).
            max_length: Maximum sequence length (kept for callers; unused here).
            chat_format: When true, items also include ``"full_response"``.
        """
        self.chat_format = chat_format
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.system_prompts = system_prompts
        self.answers = answers
        self.reasoning_chains = reasoning_chains
        self.questions = questions

    def __len__(self) -> int:
        return len(self.questions)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        question, reasoning, answer = (
            self.questions[idx],
            self.reasoning_chains[idx],
            self.answers[idx],
        )
        item = {"question": question, "reasoning": reasoning, "answer": answer}

        prompts = self.system_prompts
        if prompts is not None and idx < len(prompts):
            item["system"] = prompts[idx]

        if not self.chat_format:
            return item

        item["full_response"] = f"<|begin_of_thought|>\n{reasoning}\n<|end_of_thought|>\n\n<|begin_of_solution|>\n{answer}\n<|end_of_solution|>"
        return item