rxnn-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn-0.1.0.dist-info/LICENSE +201 -0
- rxnn-0.1.0.dist-info/METADATA +257 -0
- rxnn-0.1.0.dist-info/RECORD +23 -0
- rxnn-0.1.0.dist-info/WHEEL +4 -0
- src/experimental/attention.py +133 -0
- src/memory/norm.py +173 -0
- src/memory/stm.py +53 -0
- src/rxt/models.py +180 -0
- src/training/base.py +275 -0
- src/training/bml.py +345 -0
- src/training/callbacks.py +491 -0
- src/training/dataset.py +164 -0
- src/training/scheduler.py +19 -0
- src/training/tokenizer.py +208 -0
- src/transformers/attention.py +324 -0
- src/transformers/ff.py +72 -0
- src/transformers/layers.py +150 -0
- src/transformers/mask.py +10 -0
- src/transformers/models.py +168 -0
- src/transformers/moe.py +139 -0
- src/transformers/positional.py +105 -0
- src/transformers/sampler.py +109 -0
- src/utils.py +14 -0
src/training/callbacks.py
ADDED

import os, traceback, shutil
import numpy as np
import torch
import torch.nn as nn
from typing import Union
from src.utils import human_format
from torch.nn.parallel import DistributedDataParallel
from huggingface_hub import PyTorchModelHubMixin


class TrainerCallback:
    def on_epoch_start(self, model: torch.nn.Module, epoch: int) -> None:
        pass

    def on_epoch_end(self, model: torch.nn.Module, epoch: int) -> Union[bool, None]:
        pass

    def on_batch_start(self, model: torch.nn.Module, batch_idx: int, batch: dict[str, torch.Tensor]) -> None:
        pass

    def on_batch_end(self, model: torch.nn.Module, batch_idx: int, loss: float,
                     batch: dict[str, torch.Tensor]) -> Union[bool, None]:
        pass

    def on_training_end(self, model: torch.nn.Module) -> None:
        pass

    def on_validation_end(self, model: torch.nn.Module, epoch: int, val_loss: float,
                          val_metrics: dict) -> Union[bool, None]:
        pass


class PrintLossCallback(TrainerCallback):
    def __init__(self, batch_log_interval: int = 100, joint_mode: bool = False, batches_per_epoch: int = None):
        self.epoch_means = []
        self.epoch_losses = []
        self.batch_group_losses = []
        self.batch_log_interval = batch_log_interval
        self.joint_mode = joint_mode
        self.batches_per_epoch = batches_per_epoch

    def on_batch_start(self, model: nn.Module, batch_idx: int, batch: dict[str, torch.Tensor]) -> None:
        pass

    def on_batch_end(self, model: nn.Module, batch_idx: int, loss: float,
                     batch: dict[str, torch.Tensor]) -> None:
        self.batch_group_losses.append(loss)
        self.epoch_losses.append(loss)

        if batch_idx != 0 and batch_idx % self.batch_log_interval == 0:
            batch_group_mean = np.stack(self.batch_group_losses).mean()
            self.batch_group_losses = []
            if self.batches_per_epoch is not None:
                print(
                    f'Batch {batch_idx} / {self.batches_per_epoch} - loss: {loss}, last {self.batch_log_interval} batches mean loss: {batch_group_mean:.4f}')
            else:
                print(
                    f'Batch {batch_idx} - loss: {loss}, last {self.batch_log_interval} batches mean loss: {batch_group_mean:.4f}')

    def on_epoch_start(self, model: nn.Module, epoch: int) -> None:
        self.epoch_losses = []
        print(f'Start epoch: {epoch}')

    def on_epoch_end(self, model: nn.Module, epoch: int) -> None:
        epoch_mean = np.stack(self.epoch_losses).mean()
        print(f'Epoch {epoch} - mean loss: {epoch_mean:.4f}')
        self.epoch_means.append(epoch_mean)

    def on_training_end(self, model: nn.Module) -> None:
        print('Finished training! All losses:')
        print(self.epoch_means)

    def on_validation_end(self, model: nn.Module, epoch: int, val_loss: float, val_metrics: dict) -> None:
        if self.joint_mode:
            print(f"Epoch {epoch} - encoder loss: {val_metrics['loss']['encoder']:.4f}")
            print(f"Epoch {epoch} - decoder loss: {val_metrics['loss']['decoder']:.4f}")
        print(f"Epoch {epoch} - validation loss: {val_loss:.4f}")


class PrintAccuracyCallback(TrainerCallback):
    def __init__(self, joint_mode: bool = False):
        self.joint_mode = joint_mode

    def on_validation_end(self, model: nn.Module, epoch: int, val_loss: float, val_metrics: dict) -> None:
        if self.joint_mode:
            print(f"Epoch {epoch} - encoder accuracy: {val_metrics['accuracy']['encoder']:.4f}")
            print(f"Epoch {epoch} - decoder accuracy: {val_metrics['accuracy']['decoder']:.4f}")
        else:
            print(f"Epoch {epoch} - accuracy: {val_metrics['accuracy']:.4f}")


class TokenCounterCallback(TrainerCallback):
    def __init__(self, limit: int, batch_log_interval: int = 100):
        self.total_tokens = 0
        self.limit = limit
        self.batch_log_interval = batch_log_interval

    def on_batch_end(self, model: nn.Module, batch_idx: int, loss: float,
                     batch: dict[str, torch.Tensor]) -> bool:
        attention_mask = batch['attention_mask']
        batch_tokens = attention_mask.sum().item()
        self.total_tokens += batch_tokens
        if batch_idx != 0 and batch_idx % self.batch_log_interval == 0:
            print(f'Total processed tokens: {human_format(self.total_tokens)}')

        should_stop_training = self.total_tokens >= self.limit
        if should_stop_training:
            print(f'Reached a limit of {human_format(self.limit)} processed tokens - stopping training')
        return should_stop_training

    def on_training_end(self, model: torch.nn.Module) -> None:
        print(f'Total training tokens: {human_format(self.total_tokens)}')

    def get_total_tokens(self):
        return self.total_tokens


class ModelSaveCallback(TrainerCallback):
    def __init__(
            self,
            save_dir: str,
            save_best_only: bool = True,
            max_keep: int = 3,
            push_to_hub: bool = False,
            hub_model_id: str = None,
            private_repo: bool = False,
            hf_token: str = None,
            push_checkpoint_weights: bool = True,
            final_commit_message: str = None,
            save_checkpoint_after_n_batches: int = None,
            push_batch_checkpoint: bool = False,
            display_exc_trace: bool = False,
    ):
        self.save_dir = save_dir
        self.save_best_only = save_best_only
        self.max_keep = max_keep
        self.best_loss = float('inf')
        self.ckpt_paths = []
        self.push_to_hub = push_to_hub
        self.hub_model_id = hub_model_id
        self.private_repo = private_repo
        self.hf_token = hf_token
        self.push_checkpoint_weights = push_checkpoint_weights
        self.final_commit_message = final_commit_message
        self.save_checkpoint_after_n_batches = save_checkpoint_after_n_batches
        self.push_batch_checkpoint = push_batch_checkpoint
        self.finished_epochs = 0
        self.display_exc_trace = display_exc_trace

    def on_batch_end(self, model: torch.nn.Module, batch_idx: int, loss: float,
                     batch: dict[str, torch.Tensor]) -> Union[bool, None]:
        if self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
            if isinstance(model, DistributedDataParallel):
                model = next(model.children())
            try:
                # Hugging Face-style models expose save_pretrained; plain modules fall back to a raw state_dict.
                if hasattr(model, 'save_pretrained'):
                    ckpt_path = os.path.join(
                        self.save_dir,
                        'batch_checkpoint'
                    )
                    path_exists = os.path.exists(ckpt_path)
                    if not path_exists:
                        os.makedirs(ckpt_path)
                    model.save_pretrained(save_directory=ckpt_path)
                else:
                    path_exists = os.path.exists(self.save_dir)
                    if not path_exists:
                        os.makedirs(self.save_dir)
                    ckpt_path = os.path.join(
                        self.save_dir,
                        'batch_checkpoint.pt'
                    )
                    if os.path.exists(ckpt_path):
                        os.remove(ckpt_path)
                    torch.save(model.state_dict(), ckpt_path)
            except Exception as e:
                print(f"Error saving batch checkpoint: {str(e)}")
                if self.display_exc_trace:
                    traceback.print_exc()
            try:
                if self.push_to_hub and self.push_batch_checkpoint and hasattr(model, 'push_to_hub') and self.hub_model_id:
                    model.push_to_hub(
                        repo_id=self.hub_model_id,
                        token=self.hf_token,
                        private=self.private_repo,
                    )
            except Exception as e:
                print(f"Error pushing batch checkpoint: {str(e)}")
                if self.display_exc_trace:
                    traceback.print_exc()

    def on_validation_end(
            self,
            model: Union[torch.nn.Module, PyTorchModelHubMixin],
            epoch: int,
            val_loss: float,
            val_metrics: dict
    ):
        self.finished_epochs += 1
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            if isinstance(model, DistributedDataParallel):
                model = next(model.children())
            try:
                if hasattr(model, 'save_pretrained'):
                    ckpt_path = os.path.join(
                        self.save_dir,
                        f'epoch_{epoch}_val_loss_{val_loss:.4f}'
                    )
                    path_exists = os.path.exists(ckpt_path)
                    if not path_exists:
                        os.makedirs(ckpt_path)
                    model.save_pretrained(save_directory=ckpt_path)
                else:
                    path_exists = os.path.exists(self.save_dir)
                    if not path_exists:
                        os.makedirs(self.save_dir)
                    ckpt_path = os.path.join(
                        self.save_dir,
                        f'epoch_{epoch}_val_loss_{val_loss:.4f}.pt'
                    )
                    torch.save(model.state_dict(), ckpt_path)
                self.ckpt_paths.append(ckpt_path)

                # Keep only N best checkpoints
                if len(self.ckpt_paths) > self.max_keep:
                    oldest_path = self.ckpt_paths.pop(0)
                    if hasattr(model, 'save_pretrained'):
                        shutil.rmtree(oldest_path)
                    else:
                        os.remove(oldest_path)
            except Exception as e:
                print(f"Error saving epoch checkpoint: {str(e)}")
                if self.display_exc_trace:
                    traceback.print_exc()

            try:
                if self.push_to_hub and self.push_checkpoint_weights and hasattr(model, 'push_to_hub') and self.hub_model_id:
                    model.push_to_hub(
                        repo_id=self.hub_model_id,
                        commit_message=f'Epoch {epoch} - Val loss {val_loss:.4f}',
                        token=self.hf_token,
                        private=self.private_repo,
                    )
            except Exception as e:
                print(f"Error pushing epoch checkpoint: {str(e)}")
                if self.display_exc_trace:
                    traceback.print_exc()

    def on_training_end(self, model: Union[torch.nn.Module, PyTorchModelHubMixin]):
        if isinstance(model, DistributedDataParallel):
            model = next(model.children())
        try:
            # Save final model
            if hasattr(model, 'save_pretrained'):
                ckpt_path = os.path.join(
                    self.save_dir,
                    'final_model'
                )
                model.save_pretrained(save_directory=ckpt_path)
            else:
                ckpt_path = os.path.join(self.save_dir, 'final_model.pt')
                torch.save(model.state_dict(), ckpt_path)
            print(f"Final model saved to {ckpt_path}")
        except Exception as e:
            print(f"Error saving final model: {str(e)}")
            if self.display_exc_trace:
                traceback.print_exc()
        try:
            if self.push_to_hub and hasattr(model, 'push_to_hub'):
                model.push_to_hub(
                    repo_id=self.hub_model_id,
                    commit_message=self.final_commit_message or f'Final pre-trained model, after {self.finished_epochs} epochs',
                    token=self.hf_token,
                    private=self.private_repo,
                )
                print(f"Model uploaded to repo: {self.hub_model_id}")
        except Exception as e:
            print(f"Error pushing final model: {str(e)}")
            if self.display_exc_trace:
                traceback.print_exc()


class JointModelSaveCallback(TrainerCallback):
    def __init__(
            self,
            save_dir: str,
            save_best_only: bool = True,
            max_keep: int = 3,
            push_to_hub: bool = False,
            hub_model_decoder: str = None,
            hub_model_encoder: str = None,
            hub_model_head: str = None,
            private_repo: bool = False,
            hf_token: str = None,
            push_checkpoint_weights: bool = True,
            final_commit_message: str = None,
            save_checkpoint_after_n_batches: int = None,
            push_batch_checkpoint: bool = False,
            mlm_mode: bool = False,
            display_exc_trace: bool = False,
    ):
        self.save_dir = save_dir
        self.save_best_only = save_best_only
        self.max_keep = max_keep
        self.best_loss = float('inf')
        self.ckpt_paths = []
        self.push_to_hub = push_to_hub
        self.hub_model_decoder = hub_model_decoder
        self.hub_model_encoder = hub_model_encoder
        self.hub_model_head = hub_model_head
        self.private_repo = private_repo
        self.hf_token = hf_token
        self.push_checkpoint_weights = push_checkpoint_weights
        self.final_commit_message = final_commit_message
        self.save_checkpoint_after_n_batches = save_checkpoint_after_n_batches
        self.push_batch_checkpoint = push_batch_checkpoint
        self.finished_epochs = 0
        self.mlm_mode = mlm_mode
        self.display_exc_trace = display_exc_trace

    def _save_batch(self, model: Union[nn.Module, PyTorchModelHubMixin], component: str, hub_id: str = None):
        try:
            if hasattr(model, 'save_pretrained'):
                ckpt_path = os.path.join(
                    self.save_dir,
                    component,
                    'batch_checkpoint'
                )
                path_exists = os.path.exists(ckpt_path)
                if not path_exists:
                    os.makedirs(ckpt_path)
                model.save_pretrained(save_directory=ckpt_path)
            else:
                comp_path = os.path.join(
                    self.save_dir,
                    component
                )
                path_exists = os.path.exists(comp_path)
                if not path_exists:
                    os.makedirs(comp_path)
                ckpt_path = os.path.join(
                    comp_path,
                    'batch_checkpoint.pt'
                )
                if os.path.exists(ckpt_path):
                    os.remove(ckpt_path)
                torch.save(model.state_dict(), ckpt_path)
        except Exception as e:
            print(f"Error saving batch checkpoint: {str(e)}")
            if self.display_exc_trace:
                traceback.print_exc()
        try:
            if self.push_to_hub and self.push_batch_checkpoint and hasattr(model, 'push_to_hub') and hub_id:
                model.push_to_hub(
                    repo_id=hub_id,
                    token=self.hf_token,
                    private=self.private_repo,
                )
        except Exception as e:
            print(f"Error pushing batch checkpoint: {str(e)}")
            if self.display_exc_trace:
                traceback.print_exc()

    def on_batch_end(self, model: torch.nn.Module, batch_idx: int, loss: float,
                     batch: dict[str, torch.Tensor]) -> Union[bool, None]:
        if self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
            if isinstance(model, DistributedDataParallel):
                model = next(model.children())
            self._save_batch(model.encoder, 'encoder', hub_id=self.hub_model_encoder)
            if not self.mlm_mode:
                self._save_batch(model.decoder, 'decoder', hub_id=self.hub_model_decoder)
            self._save_batch(model.mlm_head, 'head', hub_id=self.hub_model_head)

    def _save_validation(self, model: Union[nn.Module, PyTorchModelHubMixin], component: str, epoch: int,
                         val_loss: float, hub_id: str = None):
        try:
            if hasattr(model, 'save_pretrained'):
                ckpt_path = os.path.join(
                    self.save_dir,
                    component,
                    f'epoch_{epoch}_val_loss_{val_loss:.4f}'
                )
                path_exists = os.path.exists(ckpt_path)
                if not path_exists:
                    os.makedirs(ckpt_path)
                model.save_pretrained(save_directory=ckpt_path)
            else:
                comp_path = os.path.join(
                    self.save_dir,
                    component
                )
                path_exists = os.path.exists(comp_path)
                if not path_exists:
                    os.makedirs(comp_path)
                ckpt_path = os.path.join(
                    comp_path,
                    f'epoch_{epoch}_val_loss_{val_loss:.4f}.pt'
                )
                torch.save(model.state_dict(), ckpt_path)
            self.ckpt_paths.append(ckpt_path)

            # Keep only N best checkpoints
            if len(self.ckpt_paths) > self.max_keep:
                oldest_path = self.ckpt_paths.pop(0)
                if hasattr(model, 'save_pretrained'):
                    shutil.rmtree(oldest_path)
                else:
                    os.remove(oldest_path)
        except Exception as e:
            print(f"Error saving epoch checkpoint: {str(e)}")
            if self.display_exc_trace:
                traceback.print_exc()

        try:
            if self.push_to_hub and self.push_checkpoint_weights and hasattr(model, 'push_to_hub') and hub_id:
                model.push_to_hub(
                    repo_id=hub_id,
                    commit_message=f'Epoch {epoch} - Val loss {val_loss:.4f}',
                    token=self.hf_token,
                    private=self.private_repo,
                )
        except Exception as e:
            print(f"Error pushing epoch checkpoint: {str(e)}")
            if self.display_exc_trace:
                traceback.print_exc()

    def on_validation_end(
            self,
            model: Union[torch.nn.Module, PyTorchModelHubMixin],
            epoch: int,
            val_loss: float,
            val_metrics: dict
    ):
        self.finished_epochs += 1
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            if isinstance(model, DistributedDataParallel):
                model = next(model.children())
            self._save_validation(model.encoder, 'encoder', epoch, val_loss, hub_id=self.hub_model_encoder)
            if not self.mlm_mode:
                self._save_validation(model.decoder, 'decoder', epoch, val_loss, hub_id=self.hub_model_decoder)
            self._save_validation(model.mlm_head, 'head', epoch, val_loss, hub_id=self.hub_model_head)

    def _save_final(self, model: Union[nn.Module, PyTorchModelHubMixin], component: str, hub_id: str = None):
        try:
            # Save final model
            if hasattr(model, 'save_pretrained'):
                ckpt_path = os.path.join(
                    self.save_dir,
                    component,
                    'final_model'
                )
                path_exists = os.path.exists(ckpt_path)
                if not path_exists:
                    os.makedirs(ckpt_path)
                model.save_pretrained(save_directory=ckpt_path)
            else:
                comp_path = os.path.join(
                    self.save_dir,
                    component
                )
                path_exists = os.path.exists(comp_path)
                if not path_exists:
                    os.makedirs(comp_path)
                ckpt_path = os.path.join(comp_path, 'final_model.pt')
                torch.save(model.state_dict(), ckpt_path)
            print(f"Final model saved to {ckpt_path}")
        except Exception as e:
            print(f"Error saving final model: {str(e)}")
            if self.display_exc_trace:
                traceback.print_exc()
        try:
            if self.push_to_hub and hasattr(model, 'push_to_hub') and hub_id:
                model.push_to_hub(
                    repo_id=hub_id,
                    commit_message=self.final_commit_message or f'Final pre-trained model, after {self.finished_epochs} epochs',
                    token=self.hf_token,
                    private=self.private_repo,
                )
        except Exception as e:
            print(f"Error pushing final model: {str(e)}")
            if self.display_exc_trace:
                traceback.print_exc()

    def on_training_end(self, model: Union[torch.nn.Module, PyTorchModelHubMixin]):
        if isinstance(model, DistributedDataParallel):
            model = next(model.children())
        self._save_final(model.encoder, 'encoder', hub_id=self.hub_model_encoder)
        if not self.mlm_mode:
            self._save_final(model.decoder, 'decoder', hub_id=self.hub_model_decoder)
        self._save_final(model.mlm_head, 'head', hub_id=self.hub_model_head)
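
For orientation, the snippet below is a minimal usage sketch (not part of the package) showing how a training loop can drive the TrainerCallback hooks defined above. The actual trainers live in src/training/base.py and src/training/bml.py, which are not shown in this hunk, so the loop, model, optimizer, data loader and compute_loss function here are placeholders; only the callback classes and their hook signatures are taken from the file itself.

# Usage sketch only - a generic PyTorch loop; the real rxnn trainers are not shown here.
import torch
from torch.utils.data import DataLoader
from src.training.callbacks import PrintLossCallback, TokenCounterCallback, ModelSaveCallback


def run_epoch(model, loader: DataLoader, optimizer, compute_loss, callbacks, epoch: int) -> bool:
    """Run one epoch, firing the callback hooks; returns True if training should stop early."""
    for cb in callbacks:
        cb.on_epoch_start(model, epoch)
    for batch_idx, batch in enumerate(loader):
        for cb in callbacks:
            cb.on_batch_start(model, batch_idx, batch)
        optimizer.zero_grad()
        loss = compute_loss(model, batch)  # placeholder: task-specific forward pass + loss
        loss.backward()
        optimizer.step()
        # A truthy return from on_batch_end (e.g. TokenCounterCallback hitting its token limit)
        # is treated as an early-stop request.
        stop = False
        for cb in callbacks:
            if cb.on_batch_end(model, batch_idx, loss.item(), batch):
                stop = True
        if stop:
            return True
    for cb in callbacks:
        cb.on_epoch_end(model, epoch)
    return False


callbacks = [
    PrintLossCallback(batch_log_interval=50),
    TokenCounterCallback(limit=1_000_000_000),
    ModelSaveCallback('./checkpoints', push_to_hub=False),
]
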
src/training/dataset.py
ADDED

import torch
from torch.utils.data import Dataset
from datasets import Dataset as HfDataset
from transformers import PreTrainedTokenizer

from typing import Union


class BaseDataset(Dataset):
    def __init__(
            self,
            texts: Union[list[str], HfDataset],
            tokenizer: PreTrainedTokenizer,
            max_seq_len: int = 1024,
            hf_field: str = 'text',
            merge_short_from: int = None,
            *args,
            **kwargs
    ):
        super(BaseDataset, self).__init__(*args, **kwargs)
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.texts = texts
        self.hf_field = hf_field
        self.merge_short_from = merge_short_from

    def get_tokenized_text(self, idx: int):
        if isinstance(self.texts, list):
            text = self.texts[idx]
        else:
            text = self.texts[idx][self.hf_field]

        inputs = self.tokenizer(
            text,
            max_length=self.max_seq_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
            return_attention_mask=True
        )
        if not (inputs['input_ids'][0] < self.tokenizer.vocab_size).all():
            inputs['input_ids'][0][(inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
        if not (inputs['input_ids'][0] >= 0).all():
            inputs['input_ids'][0][inputs['input_ids'][0] < 0] = self.tokenizer.unk_token_id

        return inputs


class JointLMDataset(BaseDataset):
    def __init__(
            self,
            texts: Union[list[str], HfDataset],
            tokenizer: PreTrainedTokenizer,
            max_seq_len: int = 1024,
            mask_prob: float = 0.15,
            hf_field: str = 'text',
            *args,
            **kwargs
    ):
        super(JointLMDataset, self).__init__(texts, tokenizer, max_seq_len, hf_field, *args, **kwargs)
        self.mask_prob = mask_prob

    def __getitem__(self, idx: int) -> dict[str, dict[str, torch.Tensor]]:
        inputs = self.get_tokenized_text(idx)
        encoder_input_ids = inputs['input_ids'][0]
        attention_mask = inputs['attention_mask'][0]

        decoder_input_ids = encoder_input_ids.clone()

        encoder_labels = encoder_input_ids.clone()
        decoder_targets = encoder_input_ids.clone()

        # Create masked indices
        masked_indices = torch.bernoulli(
            torch.full(encoder_labels.shape, self.mask_prob)
        ).bool() & attention_mask.bool()

        # Apply mask
        encoder_labels[~masked_indices] = -100
        encoder_input_ids[masked_indices] = self.tokenizer.mask_token_id

        return {
            'decoder': {
                'input_ids': decoder_input_ids,
                'targets': decoder_targets,
            },
            'encoder': {
                'input_ids': encoder_input_ids,
                'labels': encoder_labels,
            },
            'attention_mask': attention_mask,
        }

    def __len__(self):
        return len(self.texts)


class MaskedLMDataset(BaseDataset):
    def __init__(
            self,
            texts: Union[list[str], HfDataset],
            tokenizer: PreTrainedTokenizer,
            max_seq_len: int = 1024,
            mask_prob: float = 0.15,
            hf_field: str = 'text',
            *args,
            **kwargs
    ):
        super(MaskedLMDataset, self).__init__(texts, tokenizer, max_seq_len, hf_field, *args, **kwargs)
        self.mask_prob = mask_prob

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        inputs = self.get_tokenized_text(idx)

        input_ids = inputs['input_ids'][0]
        attention_mask = inputs['attention_mask'][0]
        labels = input_ids.clone()

        # Create masked indices
        masked_indices = torch.bernoulli(
            torch.full(labels.shape, self.mask_prob)
        ).bool() & attention_mask.bool()

        # Apply mask
        labels[~masked_indices] = -100
        input_ids[masked_indices] = self.tokenizer.mask_token_id

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

    def __len__(self):
        return len(self.texts)


class AutoregressiveLMDataset(BaseDataset):
    def __init__(
            self,
            texts: Union[list[str], HfDataset],
            tokenizer: PreTrainedTokenizer,
            max_seq_len: int = 1024,
            hf_field: str = 'text',
            *args,
            **kwargs
    ):
        super(AutoregressiveLMDataset, self).__init__(texts, tokenizer, max_seq_len, hf_field, *args, **kwargs)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        inputs = self.get_tokenized_text(idx)

        input_ids = inputs['input_ids'][0]
        attention_mask = inputs['attention_mask'][0]
        targets = input_ids.clone()

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'targets': targets
        }

    def __len__(self):
        return len(self.texts)
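
A short usage sketch (not part of the package) follows: wrapping a Hugging Face dataset in MaskedLMDataset and feeding it to a standard DataLoader. The tokenizer and dataset names are only examples; any tokenizer that defines mask_token_id and unk_token_id, and any datasets.Dataset with a text column, should work the same way.

# Usage sketch only - the tokenizer checkpoint and dataset names below are illustrative.
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from src.training.dataset import MaskedLMDataset

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # provides mask_token_id / unk_token_id
texts = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train')  # exposes a 'text' field

dataset = MaskedLMDataset(texts, tokenizer, max_seq_len=512, mask_prob=0.15, hf_field='text')
loader = DataLoader(dataset, batch_size=16, shuffle=True)

batch = next(iter(loader))
# input_ids hold the masked sequence; labels hold the original ids at masked positions and -100 elsewhere.
print(batch['input_ids'].shape, batch['labels'].shape, batch['attention_mask'].shape)
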
src/training/scheduler.py
ADDED

import torch
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
import math


def get_transformer_lr_scheduler(
        optimizer: torch.optim.Optimizer,
        num_training_steps: int,
        warmup_steps: int = 0
):
    if warmup_steps > 0:
        # Warmup + cosine decay
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return float(current_step) / max(1, warmup_steps)
            remaining = max(0, current_step - warmup_steps)
            return 0.5 * (1.0 + math.cos(math.pi * remaining / (num_training_steps - warmup_steps)))

        return LambdaLR(optimizer, lr_lambda)
    else:
        return CosineAnnealingLR(optimizer, T_max=num_training_steps)
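
To close, a small usage sketch (not part of the package): with warmup_steps > 0 the learning rate grows linearly from 0 to the optimizer's base value over the warmup, then follows a cosine decay towards 0 over the remaining steps; with warmup_steps = 0 it is plain cosine annealing. The model, step counts and optimizer settings below are placeholders.

# Usage sketch only - the scheduler is stepped once per optimizer step, not once per epoch.
import torch
from src.training.scheduler import get_transformer_lr_scheduler

model = torch.nn.Linear(128, 128)  # stand-in for a real model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

epochs, steps_per_epoch = 3, 1000
scheduler = get_transformer_lr_scheduler(
    optimizer,
    num_training_steps=epochs * steps_per_epoch,
    warmup_steps=200,
)

for step in range(epochs * steps_per_epoch):
    # forward / backward / optimizer.step() for one batch would go here
    optimizer.step()
    scheduler.step()  # call after optimizer.step(), as recent PyTorch versions expect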