rxnn-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,491 @@
+ import os, traceback, shutil
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from typing import Union
+ from src.utils import human_format
+ from torch.nn.parallel import DistributedDataParallel
+ from huggingface_hub import PyTorchModelHubMixin
+
+
+ class TrainerCallback:
+     def on_epoch_start(self, model: torch.nn.Module, epoch: int) -> None:
+         pass
+
+     def on_epoch_end(self, model: torch.nn.Module, epoch: int) -> Union[bool, None]:
+         pass
+
+     def on_batch_start(self, model: torch.nn.Module, batch_idx: int, batch: dict[str, torch.Tensor]) -> None:
+         pass
+
+     def on_batch_end(self, model: torch.nn.Module, batch_idx: int, loss: float, batch: dict[str, torch.Tensor]) -> Union[bool, None]:
+         pass
+
+     def on_training_end(self, model: torch.nn.Module) -> None:
+         pass
+
+     def on_validation_end(self, model: torch.nn.Module, epoch: int, val_loss: float, val_metrics: dict) -> Union[bool, None]:
+         pass
+
+
+ class PrintLossCallback(TrainerCallback):
+     def __init__(self, batch_log_interval: int = 100, joint_mode: bool = False, batches_per_epoch: int = None):
+         self.epoch_means = []
+         self.epoch_losses = []
+         self.batch_group_losses = []
+         self.batch_log_interval = batch_log_interval
+         self.joint_mode = joint_mode
+         self.batches_per_epoch = batches_per_epoch
+
+     def on_batch_start(self, model: nn.Module, batch_idx: int, batch: dict[str, torch.Tensor]) -> None:
+         pass
+
+     def on_batch_end(self, model: nn.Module, batch_idx: int, loss: float,
+                      batch: dict[str, torch.Tensor]) -> None:
+         self.batch_group_losses.append(loss)
+         self.epoch_losses.append(loss)
+
+         if batch_idx != 0 and batch_idx % self.batch_log_interval == 0:
+             batch_group_mean = np.stack(self.batch_group_losses).mean()
+             self.batch_group_losses = []
+             if self.batches_per_epoch is not None:
+                 print(
+                     f'Batch {batch_idx} / {self.batches_per_epoch} - loss: {loss}, last {self.batch_log_interval} batches mean loss: {batch_group_mean:.4f}')
+             else:
+                 print(
+                     f'Batch {batch_idx} - loss: {loss}, last {self.batch_log_interval} batches mean loss: {batch_group_mean:.4f}')
+
+     def on_epoch_start(self, model: nn.Module, epoch: int) -> None:
+         self.epoch_losses = []
+         print(f'Start epoch: {epoch}')
+
+     def on_epoch_end(self, model: nn.Module, epoch: int) -> None:
+         epoch_mean = np.stack(self.epoch_losses).mean()
+         print(f'Epoch {epoch} - mean loss: {epoch_mean:.4f}')
+         self.epoch_means.append(epoch_mean)
+
+     def on_training_end(self, model: nn.Module) -> None:
+         print('Finished training! All losses:')
+         print(self.epoch_means)
+
+     def on_validation_end(self, model: nn.Module, epoch: int, val_loss: float, val_metrics: dict) -> None:
+         if self.joint_mode:
+             print(f"Epoch {epoch} - encoder loss: {val_metrics['loss']['encoder']:.4f}")
+             print(f"Epoch {epoch} - decoder loss: {val_metrics['loss']['decoder']:.4f}")
+         print(f"Epoch {epoch} - validation loss: {val_loss:.4f}")
+
+
+ class PrintAccuracyCallback(TrainerCallback):
+     def __init__(self, joint_mode: bool = False):
+         self.joint_mode = joint_mode
+
+     def on_validation_end(self, model: nn.Module, epoch: int, val_loss: float, val_metrics: dict) -> None:
+         if self.joint_mode:
+             print(f"Epoch {epoch} - encoder accuracy: {val_metrics['accuracy']['encoder']:.4f}")
+             print(f"Epoch {epoch} - decoder accuracy: {val_metrics['accuracy']['decoder']:.4f}")
+         else:
+             print(f"Epoch {epoch} - accuracy: {val_metrics['accuracy']:.4f}")
+
+
+ class TokenCounterCallback(TrainerCallback):
+     def __init__(self, limit: int, batch_log_interval: int = 100):
+         self.total_tokens = 0
+         self.limit = limit
+         self.batch_log_interval = batch_log_interval
+
+     def on_batch_end(self, model: nn.Module, batch_idx: int, loss: float,
+                      batch: dict[str, torch.Tensor]) -> bool:
+         attention_mask = batch['attention_mask']
+         batch_tokens = attention_mask.sum().item()
+         self.total_tokens += batch_tokens
+         if batch_idx != 0 and batch_idx % self.batch_log_interval == 0:
+             print(f'Total processed tokens: {human_format(self.total_tokens)}')
+
+         should_stop_training = self.total_tokens >= self.limit
+         if should_stop_training:
+             print(f'Reached a limit of {human_format(self.limit)} processed tokens - stopping training')
+         return should_stop_training
+
+     def on_training_end(self, model: torch.nn.Module) -> None:
+         print(f'Total training tokens: {human_format(self.total_tokens)}')
+
+     def get_total_tokens(self):
+         return self.total_tokens
+
+
+ class ModelSaveCallback(TrainerCallback):
+     def __init__(
+             self,
+             save_dir: str,
+             save_best_only: bool = True,
+             max_keep: int = 3,
+             push_to_hub: bool = False,
+             hub_model_id: str = None,
+             private_repo: bool = False,
+             hf_token: str = None,
+             push_checkpoint_weights: bool = True,
+             final_commit_message: str = None,
+             save_checkpoint_after_n_batches: int = None,
+             push_batch_checkpoint: bool = False,
+             display_exc_trace: bool = False,
+     ):
+         self.save_dir = save_dir
+         self.save_best_only = save_best_only
+         self.max_keep = max_keep
+         self.best_loss = float('inf')
+         self.ckpt_paths = []
+         self.push_to_hub = push_to_hub
+         self.hub_model_id = hub_model_id
+         self.private_repo = private_repo
+         self.hf_token = hf_token
+         self.push_checkpoint_weights = push_checkpoint_weights
+         self.final_commit_message = final_commit_message
+         self.save_checkpoint_after_n_batches = save_checkpoint_after_n_batches
+         self.push_batch_checkpoint = push_batch_checkpoint
+         self.finished_epochs = 0
+         self.display_exc_trace = display_exc_trace
+
+     def on_batch_end(self, model: torch.nn.Module, batch_idx: int, loss: float, batch: dict[str, torch.Tensor]) -> Union[bool, None]:
+         if self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
+             if isinstance(model, DistributedDataParallel):
+                 model = next(model.children())
+             try:
+                 if getattr(model, 'save_pretrained', None) is not None:
+                     ckpt_path = os.path.join(
+                         self.save_dir,
+                         'batch_checkpoint'
+                     )
+                     if not os.path.exists(ckpt_path):
+                         os.makedirs(ckpt_path)
+                     model.save_pretrained(save_directory=ckpt_path)
+                 else:
+                     if not os.path.exists(self.save_dir):
+                         os.makedirs(self.save_dir)
+                     ckpt_path = os.path.join(
+                         self.save_dir,
+                         'batch_checkpoint.pt'
+                     )
+                     if os.path.exists(ckpt_path):
+                         os.remove(ckpt_path)
+                     torch.save(model.state_dict(), ckpt_path)
+             except Exception as e:
+                 print(f"Error saving batch checkpoint: {str(e)}")
+                 if self.display_exc_trace:
+                     traceback.print_exc()
+             try:
+                 if self.push_to_hub and self.push_batch_checkpoint and getattr(model, 'push_to_hub', None) is not None and self.hub_model_id:
+                     model.push_to_hub(
+                         repo_id=self.hub_model_id,
+                         token=self.hf_token,
+                         private=self.private_repo,
+                     )
+             except Exception as e:
+                 print(f"Error pushing batch checkpoint: {str(e)}")
+                 if self.display_exc_trace:
+                     traceback.print_exc()
+
+     def on_validation_end(
+             self,
+             model: Union[torch.nn.Module, PyTorchModelHubMixin],
+             epoch: int,
+             val_loss: float,
+             val_metrics: dict
+     ):
+         self.finished_epochs += 1
+         if val_loss < self.best_loss:
+             self.best_loss = val_loss
+             if isinstance(model, DistributedDataParallel):
+                 model = next(model.children())
+             try:
+                 if getattr(model, 'save_pretrained', None) is not None:
+                     ckpt_path = os.path.join(
+                         self.save_dir,
+                         f'epoch_{epoch}_val_loss_{val_loss:.4f}'
+                     )
+                     if not os.path.exists(ckpt_path):
+                         os.makedirs(ckpt_path)
+                     model.save_pretrained(save_directory=ckpt_path)
+                 else:
+                     if not os.path.exists(self.save_dir):
+                         os.makedirs(self.save_dir)
+                     ckpt_path = os.path.join(
+                         self.save_dir,
+                         f'epoch_{epoch}_val_loss_{val_loss:.4f}.pt'
+                     )
+                     torch.save(model.state_dict(), ckpt_path)
+                 self.ckpt_paths.append(ckpt_path)
+
+                 # Keep only N best checkpoints
+                 if len(self.ckpt_paths) > self.max_keep:
+                     oldest_path = self.ckpt_paths.pop(0)
+                     if os.path.isdir(oldest_path):
+                         shutil.rmtree(oldest_path)
+                     else:
+                         os.remove(oldest_path)
+             except Exception as e:
+                 print(f"Error saving epoch checkpoint: {str(e)}")
+                 if self.display_exc_trace:
+                     traceback.print_exc()
+
+             try:
+                 if self.push_to_hub and self.push_checkpoint_weights and getattr(model, 'push_to_hub', None) is not None and self.hub_model_id:
+                     model.push_to_hub(
+                         repo_id=self.hub_model_id,
+                         commit_message=f'Epoch {epoch} - Val loss {val_loss:.4f}',
+                         token=self.hf_token,
+                         private=self.private_repo,
+                     )
+             except Exception as e:
+                 print(f"Error pushing epoch checkpoint: {str(e)}")
+                 if self.display_exc_trace:
+                     traceback.print_exc()
+
+     def on_training_end(self, model: Union[torch.nn.Module, PyTorchModelHubMixin]):
+         if isinstance(model, DistributedDataParallel):
+             model = next(model.children())
+         try:
+             # Save final model
+             if getattr(model, 'save_pretrained', None) is not None:
+                 ckpt_path = os.path.join(
+                     self.save_dir,
+                     'final_model'
+                 )
+                 model.save_pretrained(save_directory=ckpt_path)
+             else:
+                 ckpt_path = os.path.join(self.save_dir, 'final_model.pt')
+                 torch.save(model.state_dict(), ckpt_path)
+             print(f"Final model saved to {ckpt_path}")
+         except Exception as e:
+             print(f"Error saving final model: {str(e)}")
+             if self.display_exc_trace:
+                 traceback.print_exc()
+         try:
+             if self.push_to_hub and getattr(model, 'push_to_hub', None) is not None and self.hub_model_id:
+                 model.push_to_hub(
+                     repo_id=self.hub_model_id,
+                     commit_message=self.final_commit_message or f'Final pre-trained model, after {self.finished_epochs} epochs',
+                     token=self.hf_token,
+                     private=self.private_repo,
+                 )
+                 print(f"Model uploaded to repo: {self.hub_model_id}")
+         except Exception as e:
+             print(f"Error pushing final model: {str(e)}")
+             if self.display_exc_trace:
+                 traceback.print_exc()
+
+
+ class JointModelSaveCallback(TrainerCallback):
+     def __init__(
+             self,
+             save_dir: str,
+             save_best_only: bool = True,
+             max_keep: int = 3,
+             push_to_hub: bool = False,
+             hub_model_decoder: str = None,
+             hub_model_encoder: str = None,
+             hub_model_head: str = None,
+             private_repo: bool = False,
+             hf_token: str = None,
+             push_checkpoint_weights: bool = True,
+             final_commit_message: str = None,
+             save_checkpoint_after_n_batches: int = None,
+             push_batch_checkpoint: bool = False,
+             mlm_mode: bool = False,
+             display_exc_trace: bool = False,
+     ):
+         self.save_dir = save_dir
+         self.save_best_only = save_best_only
+         self.max_keep = max_keep
+         self.best_loss = float('inf')
+         self.ckpt_paths = []
+         self.push_to_hub = push_to_hub
+         self.hub_model_decoder = hub_model_decoder
+         self.hub_model_encoder = hub_model_encoder
+         self.hub_model_head = hub_model_head
+         self.private_repo = private_repo
+         self.hf_token = hf_token
+         self.push_checkpoint_weights = push_checkpoint_weights
+         self.final_commit_message = final_commit_message
+         self.save_checkpoint_after_n_batches = save_checkpoint_after_n_batches
+         self.push_batch_checkpoint = push_batch_checkpoint
+         self.finished_epochs = 0
+         self.mlm_mode = mlm_mode
+         self.display_exc_trace = display_exc_trace
+
+     def _save_batch(self, model: Union[nn.Module, PyTorchModelHubMixin], component: str, hub_id: str = None):
+         try:
+             if getattr(model, 'save_pretrained', None) is not None:
+                 ckpt_path = os.path.join(
+                     self.save_dir,
+                     component,
+                     'batch_checkpoint'
+                 )
+                 if not os.path.exists(ckpt_path):
+                     os.makedirs(ckpt_path)
+                 model.save_pretrained(save_directory=ckpt_path)
+             else:
+                 comp_path = os.path.join(
+                     self.save_dir,
+                     component
+                 )
+                 if not os.path.exists(comp_path):
+                     os.makedirs(comp_path)
+                 ckpt_path = os.path.join(
+                     comp_path,
+                     'batch_checkpoint.pt'
+                 )
+                 if os.path.exists(ckpt_path):
+                     os.remove(ckpt_path)
+                 torch.save(model.state_dict(), ckpt_path)
+         except Exception as e:
+             print(f"Error saving batch checkpoint: {str(e)}")
+             if self.display_exc_trace:
+                 traceback.print_exc()
+         try:
+             if self.push_to_hub and self.push_batch_checkpoint and getattr(model, 'push_to_hub', None) is not None and hub_id:
+                 model.push_to_hub(
+                     repo_id=hub_id,
+                     token=self.hf_token,
+                     private=self.private_repo,
+                 )
+         except Exception as e:
+             print(f"Error pushing batch checkpoint: {str(e)}")
+             if self.display_exc_trace:
+                 traceback.print_exc()
+
+     def on_batch_end(self, model: torch.nn.Module, batch_idx: int, loss: float, batch: dict[str, torch.Tensor]) -> Union[bool, None]:
+         if self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
+             if isinstance(model, DistributedDataParallel):
+                 model = next(model.children())
+             self._save_batch(model.encoder, 'encoder', hub_id=self.hub_model_encoder)
+             if not self.mlm_mode:
+                 self._save_batch(model.decoder, 'decoder', hub_id=self.hub_model_decoder)
+             self._save_batch(model.mlm_head, 'head', hub_id=self.hub_model_head)
+
+     def _save_validation(self, model: Union[nn.Module, PyTorchModelHubMixin], component: str, epoch: int,
+                          val_loss: float, hub_id: str = None):
+         try:
+             if getattr(model, 'save_pretrained', None) is not None:
+                 ckpt_path = os.path.join(
+                     self.save_dir,
+                     component,
+                     f'epoch_{epoch}_val_loss_{val_loss:.4f}'
+                 )
+                 if not os.path.exists(ckpt_path):
+                     os.makedirs(ckpt_path)
+                 model.save_pretrained(save_directory=ckpt_path)
+             else:
+                 comp_path = os.path.join(
+                     self.save_dir,
+                     component
+                 )
+                 if not os.path.exists(comp_path):
+                     os.makedirs(comp_path)
+                 ckpt_path = os.path.join(
+                     comp_path,
+                     f'epoch_{epoch}_val_loss_{val_loss:.4f}.pt'
+                 )
+                 torch.save(model.state_dict(), ckpt_path)
+             self.ckpt_paths.append(ckpt_path)
+
+             # Keep only N best checkpoints
+             if len(self.ckpt_paths) > self.max_keep:
+                 oldest_path = self.ckpt_paths.pop(0)
+                 if os.path.isdir(oldest_path):
+                     shutil.rmtree(oldest_path)
+                 else:
+                     os.remove(oldest_path)
+         except Exception as e:
+             print(f"Error saving epoch checkpoint: {str(e)}")
+             if self.display_exc_trace:
+                 traceback.print_exc()
+
+         try:
+             if self.push_to_hub and self.push_checkpoint_weights and getattr(model, 'push_to_hub', None) is not None and hub_id:
+                 model.push_to_hub(
+                     repo_id=hub_id,
+                     commit_message=f'Epoch {epoch} - Val loss {val_loss:.4f}',
+                     token=self.hf_token,
+                     private=self.private_repo,
+                 )
+         except Exception as e:
+             print(f"Error pushing epoch checkpoint: {str(e)}")
+             if self.display_exc_trace:
+                 traceback.print_exc()
+
+     def on_validation_end(
+             self,
+             model: Union[torch.nn.Module, PyTorchModelHubMixin],
+             epoch: int,
+             val_loss: float,
+             val_metrics: dict
+     ):
+         self.finished_epochs += 1
+         if val_loss < self.best_loss:
+             self.best_loss = val_loss
+             if isinstance(model, DistributedDataParallel):
+                 model = next(model.children())
+             self._save_validation(model.encoder, 'encoder', epoch, val_loss, hub_id=self.hub_model_encoder)
+             if not self.mlm_mode:
+                 self._save_validation(model.decoder, 'decoder', epoch, val_loss, hub_id=self.hub_model_decoder)
+             self._save_validation(model.mlm_head, 'head', epoch, val_loss, hub_id=self.hub_model_head)
+
+     def _save_final(self, model: Union[nn.Module, PyTorchModelHubMixin], component: str, hub_id: str = None):
+         try:
+             # Save final model
+             if getattr(model, 'save_pretrained', None) is not None:
+                 ckpt_path = os.path.join(
+                     self.save_dir,
+                     component,
+                     'final_model'
+                 )
+                 if not os.path.exists(ckpt_path):
+                     os.makedirs(ckpt_path)
+                 model.save_pretrained(save_directory=ckpt_path)
+             else:
+                 comp_path = os.path.join(
+                     self.save_dir,
+                     component
+                 )
+                 if not os.path.exists(comp_path):
+                     os.makedirs(comp_path)
+                 ckpt_path = os.path.join(comp_path, 'final_model.pt')
+                 torch.save(model.state_dict(), ckpt_path)
+             print(f"Final model saved to {ckpt_path}")
+         except Exception as e:
+             print(f"Error saving final model: {str(e)}")
+             if self.display_exc_trace:
+                 traceback.print_exc()
+         try:
+             if self.push_to_hub and getattr(model, 'push_to_hub', None) is not None and hub_id:
+                 model.push_to_hub(
+                     repo_id=hub_id,
+                     commit_message=self.final_commit_message or f'Final pre-trained model, after {self.finished_epochs} epochs',
+                     token=self.hf_token,
+                     private=self.private_repo,
+                 )
+         except Exception as e:
+             print(f"Error pushing final model: {str(e)}")
+             if self.display_exc_trace:
+                 traceback.print_exc()
+
+     def on_training_end(self, model: Union[torch.nn.Module, PyTorchModelHubMixin]):
+         if isinstance(model, DistributedDataParallel):
+             model = next(model.children())
+         self._save_final(model.encoder, 'encoder', hub_id=self.hub_model_encoder)
+         if not self.mlm_mode:
+             self._save_final(model.decoder, 'decoder', hub_id=self.hub_model_decoder)
+         self._save_final(model.mlm_head, 'head', hub_id=self.hub_model_head)
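Taken together, these classes implement a simple hook protocol (on_epoch_start, on_batch_end, on_validation_end, on_training_end) rather than a full trainer. A minimal sketch of driving the hooks by hand is shown below; the loop, the model, and the loss values are hypothetical stand-ins, since the package's trainer class is not part of this diff.

import torch
import torch.nn as nn

# Hypothetical driver loop for the callback hooks above; not the package's trainer.
model = nn.Linear(16, 16)
callbacks = [
    PrintLossCallback(batch_log_interval=2),
    TokenCounterCallback(limit=1_000_000),
]
batches = [{'attention_mask': torch.ones(4, 32, dtype=torch.long)} for _ in range(6)]

for epoch in range(2):
    for cb in callbacks:
        cb.on_epoch_start(model, epoch)
    for batch_idx, batch in enumerate(batches):
        loss = 1.0 / (batch_idx + 1)  # placeholder loss value
        results = [cb.on_batch_end(model, batch_idx, loss, batch) for cb in callbacks]
        if any(results):  # a truthy return value requests early stopping
            break
    for cb in callbacks:
        cb.on_epoch_end(model, epoch)
for cb in callbacks:
    cb.on_training_end(model)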
@@ -0,0 +1,164 @@
+ import torch
+ from torch.utils.data import Dataset
+ from datasets import Dataset as HfDataset
+ from transformers import PreTrainedTokenizer
+
+ from typing import Union
+
+
+ class BaseDataset(Dataset):
+     def __init__(
+             self,
+             texts: Union[list[str], HfDataset],
+             tokenizer: PreTrainedTokenizer,
+             max_seq_len: int = 1024,
+             hf_field: str = 'text',
+             merge_short_from: int = None,
+             *args,
+             **kwargs
+     ):
+         super(BaseDataset, self).__init__(*args, **kwargs)
+         self.tokenizer = tokenizer
+         self.max_seq_len = max_seq_len
+         self.texts = texts
+         self.hf_field = hf_field
+         self.merge_short_from = merge_short_from
+
+     def get_tokenized_text(self, idx: int):
+         if isinstance(self.texts, list):
+             text = self.texts[idx]
+         else:
+             text = self.texts[idx][self.hf_field]
+
+         inputs = self.tokenizer(
+             text,
+             max_length=self.max_seq_len,
+             truncation=True,
+             padding='max_length',
+             return_tensors='pt',
+             return_attention_mask=True
+         )
+         # Clamp out-of-vocabulary or negative token ids to the unknown token id
+         if not (inputs['input_ids'][0] < self.tokenizer.vocab_size).all():
+             inputs['input_ids'][0][(inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
+         if not (inputs['input_ids'][0] >= 0).all():
+             inputs['input_ids'][0][inputs['input_ids'][0] < 0] = self.tokenizer.unk_token_id
+
+         return inputs
+
+
+ class JointLMDataset(BaseDataset):
+     def __init__(
+             self,
+             texts: Union[list[str], HfDataset],
+             tokenizer: PreTrainedTokenizer,
+             max_seq_len: int = 1024,
+             mask_prob: float = 0.15,
+             hf_field: str = 'text',
+             *args,
+             **kwargs
+     ):
+         super(JointLMDataset, self).__init__(texts, tokenizer, max_seq_len, hf_field, *args, **kwargs)
+         self.mask_prob = mask_prob
+
+     def __getitem__(self, idx: int) -> dict[str, dict[str, torch.Tensor]]:
+         inputs = self.get_tokenized_text(idx)
+         encoder_input_ids = inputs['input_ids'][0]
+         attention_mask = inputs['attention_mask'][0]
+
+         decoder_input_ids = encoder_input_ids.clone()
+
+         encoder_labels = encoder_input_ids.clone()
+         decoder_targets = encoder_input_ids.clone()
+
+         # Create masked indices
+         masked_indices = torch.bernoulli(
+             torch.full(encoder_labels.shape, self.mask_prob)
+         ).bool() & attention_mask.bool()
+
+         # Apply mask
+         encoder_labels[~masked_indices] = -100
+         encoder_input_ids[masked_indices] = self.tokenizer.mask_token_id
+
+         return {
+             'decoder': {
+                 'input_ids': decoder_input_ids,
+                 'targets': decoder_targets,
+             },
+             'encoder': {
+                 'input_ids': encoder_input_ids,
+                 'labels': encoder_labels,
+             },
+             'attention_mask': attention_mask,
+         }
+
+     def __len__(self):
+         return len(self.texts)
+
+
+ class MaskedLMDataset(BaseDataset):
+     def __init__(
+             self,
+             texts: Union[list[str], HfDataset],
+             tokenizer: PreTrainedTokenizer,
+             max_seq_len: int = 1024,
+             mask_prob: float = 0.15,
+             hf_field: str = 'text',
+             *args,
+             **kwargs
+     ):
+         super(MaskedLMDataset, self).__init__(texts, tokenizer, max_seq_len, hf_field, *args, **kwargs)
+         self.mask_prob = mask_prob
+
+     def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+         inputs = self.get_tokenized_text(idx)
+
+         input_ids = inputs['input_ids'][0]
+         attention_mask = inputs['attention_mask'][0]
+         labels = input_ids.clone()
+
+         # Create masked indices
+         masked_indices = torch.bernoulli(
+             torch.full(labels.shape, self.mask_prob)
+         ).bool() & attention_mask.bool()
+
+         # Apply mask
+         labels[~masked_indices] = -100
+         input_ids[masked_indices] = self.tokenizer.mask_token_id
+
+         return {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'labels': labels
+         }
+
+     def __len__(self):
+         return len(self.texts)
+
+
+ class AutoregressiveLMDataset(BaseDataset):
+     def __init__(
+             self,
+             texts: Union[list[str], HfDataset],
+             tokenizer: PreTrainedTokenizer,
+             max_seq_len: int = 1024,
+             hf_field: str = 'text',
+             *args,
+             **kwargs
+     ):
+         super(AutoregressiveLMDataset, self).__init__(texts, tokenizer, max_seq_len, hf_field, *args, **kwargs)
+
+     def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+         inputs = self.get_tokenized_text(idx)
+
+         input_ids = inputs['input_ids'][0]
+         attention_mask = inputs['attention_mask'][0]
+         targets = input_ids.clone()
+
+         return {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'targets': targets
+         }
+
+     def __len__(self):
+         return len(self.texts)
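A short usage sketch for these datasets, assuming any Hugging Face tokenizer that defines mask and unk tokens; the checkpoint name below is only an example and is not shipped with this package.

from transformers import AutoTokenizer

# Hypothetical example: any tokenizer exposing mask_token_id and unk_token_id works here.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
dataset = MaskedLMDataset(
    ['Example sentence one.', 'Another short example.'],
    tokenizer,
    max_seq_len=32,
    mask_prob=0.15,
)
sample = dataset[0]
print(sample['input_ids'].shape, sample['attention_mask'].shape, sample['labels'].shape)
# Unmasked positions in 'labels' are set to -100 so they are ignored by cross-entropy losses.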
@@ -0,0 +1,19 @@
+ import torch
+ from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
+ import math
+
+ def get_transformer_lr_scheduler(
+         optimizer: torch.optim.Optimizer,
+         num_training_steps: int,
+         warmup_steps: int = 0
+ ):
+     if warmup_steps > 0:
+         # Linear warmup followed by cosine decay
+         def lr_lambda(current_step):
+             if current_step < warmup_steps:
+                 return float(current_step) / max(1, warmup_steps)
+             remaining = max(0, current_step - warmup_steps)
+             return 0.5 * (1.0 + math.cos(math.pi * remaining / max(1, num_training_steps - warmup_steps)))
+         return LambdaLR(optimizer, lr_lambda)
+     else:
+         return CosineAnnealingLR(optimizer, T_max=num_training_steps)
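A minimal sketch of pairing this scheduler with an optimizer; the model, learning rate, and step counts are placeholder assumptions.

import torch
import torch.nn as nn

# Hypothetical usage: warm up for 100 steps, then cosine-decay over the remaining 900.
model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = get_transformer_lr_scheduler(optimizer, num_training_steps=1000, warmup_steps=100)

for step in range(1000):
    optimizer.step()   # would follow loss.backward() in a real training loop
    scheduler.step()   # update the learning rate once per training step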