rxnn-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
src/training/bml.py ADDED
@@ -0,0 +1,345 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from huggingface_hub import PyTorchModelHubMixin
+ from typing import Union
+ import torch.distributed as dist
+ from src.transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
+ from src.training.base import BaseTrainer
+ 
+ class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
+     def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
+         super(MLMHead, self).__init__(*args, **kwargs)
+         self.dense = nn.Linear(embed_dim, embed_dim)
+         self.act = nn.GELU()
+         self.layer_norm = nn.LayerNorm(embed_dim)
+         self.decoder = nn.Linear(embed_dim, vocab_size)
+ 
+     def forward(self, hidden_states):
+         x = self.dense(hidden_states)
+         x = self.act(x)
+         x = self.layer_norm(x)
+         return self.decoder(x)
+ 
+ 
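Because MLMHead mixes in PyTorchModelHubMixin, it inherits save_pretrained, from_pretrained and push_to_hub. A minimal sketch, assuming a recent huggingface_hub that serializes the json-compatible __init__ arguments into config.json (the path and dimensions below are illustrative, not from the package):

    head = MLMHead(embed_dim=256, vocab_size=32000)
    head.save_pretrained("./mlm-head")                # writes config.json + weights locally
    restored = MLMHead.from_pretrained("./mlm-head")  # rebuilds the module from the saved config
    # head.push_to_hub("your-username/mlm-head")      # optional: upload to the HF Hub
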
+ class MLMTrainingModel(nn.Module):
+     def __init__(
+             self,
+             encoder: ReactiveTransformerEncoder,
+             mlm_head: MLMHead,
+             *args,
+             **kwargs
+     ):
+         super(MLMTrainingModel, self).__init__(*args, **kwargs)
+         self.encoder = encoder
+         self.mlm_head = mlm_head
+ 
+     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+         h, _ = self.encoder(x, attention_mask=attention_mask)
+         y = self.mlm_head(h)
+         return y
+ 
+ 
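A quick shape check of the wrapper, with illustrative sizes; `encoder` is assumed to be a pre-built ReactiveTransformerEncoder returning a (hidden_states, layer_states) tuple, as the unpacking in forward implies:

    x = torch.randint(0, 32000, (8, 128))                  # batch of 8 sequences, 128 tokens each
    mask = torch.ones(8, 128, dtype=torch.long)
    model = MLMTrainingModel(encoder, MLMHead(embed_dim=256, vocab_size=32000))
    logits = model(x, attention_mask=mask)                 # -> [8, 128, 32000], one distribution per position
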
+ class MLMTrainer(BaseTrainer):
+     def __init__(
+             self,
+             model: MLMTrainingModel,
+             device: torch.device,
+             vocab_size: int,
+             use_amp: bool = False,
+             dtype: torch.dtype = None,
+             **kwargs
+     ):
+         super(MLMTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype, **kwargs)
+         self.vocab_size = vocab_size
+ 
+     def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         inputs = batch['input_ids']
+         attention_mask = batch['attention_mask']
+         labels = batch['labels']
+ 
+         logits = self.model(
+             inputs,
+             attention_mask=attention_mask
+         )
+ 
+         return F.cross_entropy(
+             logits.view(-1, self.vocab_size),
+             labels.view(-1),
+             ignore_index=-100
+         ), logits
+ 
+     def validate(self, batch_size: int) -> tuple[float, dict]:
+         self.model.eval()
+         val_dataloader = self._valid_loader(batch_size)
+         val_loss = torch.tensor(0.0).to(self.device)
+         correct = torch.tensor(0).to(self.device)
+         total = torch.tensor(0).to(self.device)
+ 
+         with torch.no_grad():
+             for batch in val_dataloader:
+                 if self.get_batch_size(batch) == batch_size:
+                     loss, logits = self.valid_step(batch)
+                     val_loss += loss
+                     labels = batch[self.target_field_name].to(self.device)
+                     valid_indices = labels != -100
+                     if valid_indices.any():
+                         preds = logits.argmax(-1)
+                         correct += (preds[valid_indices] == labels[valid_indices]).sum()
+                         total += valid_indices.sum()
+ 
+         avg_loss = (val_loss / len(val_dataloader)).item()
+         if self.use_ddp:
+             dist.all_reduce(correct, op=dist.ReduceOp.SUM)
+             dist.all_reduce(total, op=dist.ReduceOp.SUM)
+ 
+         metrics = {
+             'accuracy': (correct / total * 100).item() if total > 0 else 0.0
+         }
+         self.model.train()
+         return avg_loss, metrics
+ 
+ 
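The trainer expects batches that follow the usual masked-LM convention: input_ids contain mask tokens at corrupted positions, and labels hold the original token wherever a prediction should be scored and -100 everywhere cross_entropy should stay silent. A hand-built toy batch (token ids are made up; `trainer` is an MLMTrainer assumed to be constructed already, with the model on the same device as the tensors):

    batch = {
        'input_ids':      torch.tensor([[101, 103, 2003, 103, 102]]),     # 103 stands in for [MASK]
        'attention_mask': torch.tensor([[1, 1, 1, 1, 1]]),
        'labels':         torch.tensor([[-100, 2023, -100, 2307, -100]]), # only masked slots are scored
    }
    loss, logits = trainer.compute_loss(batch)
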
+ class AutoregressiveTrainer(BaseTrainer):
+     def __init__(
+             self,
+             model: ReactiveTransformerDecoder,
+             device: torch.device,
+             vocab_size: int,
+             use_amp: bool = False,
+             dtype: torch.dtype = None,
+             **kwargs
+     ):
+         super(AutoregressiveTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
+                                                     target_field_name='targets', **kwargs)
+         self.vocab_size = vocab_size
+ 
+     def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         inputs = batch['input_ids']
+         attention_mask = batch['attention_mask']
+         targets = batch['targets']
+ 
+         outputs = self.model(
+             inputs,
+             attention_mask=attention_mask
+         )
+ 
+         # Standard next-token alignment: the logit at position t is scored against token t+1.
+         shifted_logits = outputs[:, :-1].contiguous()
+         shifted_targets = targets[:, 1:].contiguous()
+ 
+         return F.cross_entropy(
+             shifted_logits.view(-1, self.vocab_size),
+             shifted_targets.view(-1),
+             ignore_index=-100  # skip -100 positions, matching the mask applied in validate()
+         ), outputs
+ 
+     def validate(self, batch_size: int) -> tuple[float, dict]:
+         self.model.eval()
+         val_dataloader = self._valid_loader(batch_size)
+         val_loss = torch.tensor(0.0).to(self.device)
+         correct = torch.tensor(0).to(self.device)
+         total = torch.tensor(0).to(self.device)
+ 
+         with torch.no_grad():
+             for batch in val_dataloader:
+                 if self.get_batch_size(batch) == batch_size:
+                     loss, logits = self.valid_step(batch)
+                     val_loss += loss
+                     shifted_logits = logits[:, :-1].contiguous()
+                     shifted_targets = batch[self.target_field_name][:, 1:].to(self.device).contiguous()
+                     valid_indices = shifted_targets != -100
+                     if valid_indices.any():
+                         preds = shifted_logits.argmax(-1)
+                         correct += (preds[valid_indices] == shifted_targets[valid_indices]).sum()
+                         total += valid_indices.sum()
+ 
+         avg_loss = (val_loss / len(val_dataloader)).item()
+ 
+         if self.use_ddp:
+             dist.all_reduce(correct, op=dist.ReduceOp.SUM)
+             dist.all_reduce(total, op=dist.ReduceOp.SUM)
+ 
+         metrics = {
+             'accuracy': (correct / total * 100).item() if total > 0 else 0.0
+         }
+         self.model.train()
+         return avg_loss, metrics
+ 
+ 
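The shift-by-one in compute_loss is the usual causal-LM alignment: the last logit column and the first target column are dropped so predictions and gold tokens line up. In miniature (random logits, vocabulary of 10, values illustrative):

    targets = torch.tensor([[5, 7, 9, 2]])    # a single 4-token sequence
    logits = torch.randn(1, 4, 10)
    pred_cols = logits[:, :-1]                # positions 0..2 predict ...
    gold_cols = targets[:, 1:]                # ... tokens 1..3
    loss = F.cross_entropy(pred_cols.reshape(-1, 10), gold_cols.reshape(-1))
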
+ class JointTrainingModel(nn.Module):
+     def __init__(
+             self,
+             encoder: ReactiveTransformerEncoder,
+             decoder: ReactiveTransformerDecoder,
+             mlm_head: MLMHead,
+             *args,
+             **kwargs
+     ):
+         super(JointTrainingModel, self).__init__(*args, **kwargs)
+         self.encoder = encoder
+         self.mlm_head = mlm_head
+         self.decoder = decoder
+ 
+     def forward(self, x_e: torch.Tensor, x_d: torch.Tensor,
+                 attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
+         encoder_result, _ = self.encoder(x_e, attention_mask=attention_mask)
+         y_e = self.mlm_head(encoder_result)
+         y_d = self.decoder(x_d, attention_mask=attention_mask)
+         return y_e, y_d
+ 
+ 
+ class JointLMTrainer(BaseTrainer):
+     def __init__(
+             self,
+             model: JointTrainingModel,
+             device: torch.device,
+             vocab_size: int,
+             use_amp: bool = False,
+             dtype: torch.dtype = None,
+             components_loss_log_interval: int = None,
+             encoder_loss_scale: float = 1.0,
+             **kwargs
+     ):
+         super(JointLMTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype, **kwargs)
+         self.vocab_size = vocab_size
+         self.components_loss_log_interval = components_loss_log_interval
+         self.encoder_loss_scale = encoder_loss_scale
+ 
+     def train_step(self, batch: dict[str, Union[torch.Tensor, dict[str, torch.Tensor]]],
+                    batch_idx: int) -> torch.Tensor:
+         if self.use_amp:
+             batch = {
+                 k: ({kk: vv.to(self.device) for kk, vv in v.items()} if not torch.is_tensor(v)
+                     else v.to(self.device))
+                 for k, v in batch.items()
+             }
+             with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
+                 (encoder_loss, decoder_loss), _ = self.compute_loss(batch)
+         else:
+             batch = {
+                 k: ({kk: vv.to(self.device, dtype=self.dtype) for kk, vv in v.items()} if not torch.is_tensor(v)
+                     else v.to(self.device, dtype=self.dtype))
+                 for k, v in batch.items()
+             }
+             (encoder_loss, decoder_loss), _ = self.compute_loss(batch)
+         if self.components_loss_log_interval is not None:
+             if batch_idx % self.components_loss_log_interval == 0:
+                 print(f"Encoder loss: {encoder_loss.item():.4f}")
+                 print(f"Decoder loss: {decoder_loss.item():.4f}")
+                 if self.encoder_loss_scale != 1.0:
+                     print(f"Encoder loss scaled by {self.encoder_loss_scale}: "
+                           f"{(encoder_loss * self.encoder_loss_scale).item():.4f}")
+         return (encoder_loss * self.encoder_loss_scale) + decoder_loss
+ 
+     def compute_loss(self, batch: dict[str, dict[str, torch.Tensor]]) -> tuple[
+             tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+         encoder_inputs = batch['encoder']['input_ids']
+         encoder_labels = batch['encoder']['labels']
+         decoder_inputs = batch['decoder']['input_ids']
+         decoder_targets = batch['decoder']['targets']
+         attention_mask = batch['attention_mask']
+ 
+         encoder_logits, decoder_logits = self.model(
+             encoder_inputs,
+             decoder_inputs,
+             attention_mask=attention_mask
+         )
+ 
+         encoder_loss = F.cross_entropy(
+             encoder_logits.view(-1, self.vocab_size),
+             encoder_labels.view(-1),
+             ignore_index=-100
+         )
+ 
+         shifted_logits = decoder_logits[:, :-1].contiguous()
+         shifted_targets = decoder_targets[:, 1:].contiguous()
+ 
+         decoder_loss = F.cross_entropy(
+             shifted_logits.view(-1, self.vocab_size),
+             shifted_targets.view(-1),
+             ignore_index=-100  # consistent with the -100 masking applied in validate()
+         )
+ 
+         return (encoder_loss, decoder_loss), (encoder_logits, decoder_logits)
+ 
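JointLMTrainer therefore expects a two-level batch: one sub-dict per component plus a shared attention mask. A skeletal layout, read off compute_loss above (the variable names are placeholders for tensors built by your collator):

    batch = {
        'encoder': {
            'input_ids': masked_ids,       # [B, T] with mask-token corruptions
            'labels': mlm_labels,          # [B, T], -100 outside masked positions
        },
        'decoder': {
            'input_ids': clean_ids,        # [B, T] uncorrupted sequence
            'targets': clean_ids.clone(),  # next-token targets, shifted inside compute_loss
        },
        'attention_mask': mask,            # [B, T], shared by encoder and decoder
    }
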
+     def _valid_writer(self, epoch: int, val_loss: float, val_metrics: dict):
+         self.writer.add_scalar('Loss/validation', val_loss, epoch)
+         self.writer.add_scalar('Perplexity/validation', math.exp(val_loss), epoch)
+         if val_metrics['accuracy']:
+             self.writer.add_scalar('Encoder accuracy/validation', val_metrics['accuracy']['encoder'], epoch)
+             self.writer.add_scalar('Decoder accuracy/validation', val_metrics['accuracy']['decoder'], epoch)
+         if val_metrics['loss']:
+             self.writer.add_scalar('Encoder loss/validation', val_metrics['loss']['encoder'], epoch)
+             self.writer.add_scalar('Encoder perplexity/validation', math.exp(val_metrics['loss']['encoder']), epoch)
+             self.writer.add_scalar('Decoder loss/validation', val_metrics['loss']['decoder'], epoch)
+             self.writer.add_scalar('Decoder perplexity/validation', math.exp(val_metrics['loss']['decoder']), epoch)
+ 
+     def validate(self, batch_size: int) -> tuple[float, dict]:
+         self.model.eval()
+         val_loss = torch.tensor(0.0).to(self.device)
+         dec_loss = torch.tensor(0.0).to(self.device)
+         enc_loss = torch.tensor(0.0).to(self.device)
+         correct_mlm = torch.tensor(0).to(self.device)
+         total_mlm = torch.tensor(0).to(self.device)
+         correct_alm = torch.tensor(0).to(self.device)
+         total_alm = torch.tensor(0).to(self.device)
+ 
+         val_dataloader = self._valid_loader(batch_size)
+ 
+         with torch.no_grad():
+             for batch in val_dataloader:
+                 if self.get_batch_size(batch) == batch_size:
+                     if self.use_amp:
+                         batch = {
+                             k: ({kk: vv.to(self.device) for kk, vv in v.items()} if not torch.is_tensor(v)
+                                 else v.to(self.device))
+                             for k, v in batch.items()
+                         }
+                         with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
+                             (encoder_loss, decoder_loss), (encoder_logits, decoder_logits) = self.compute_loss(batch)
+                     else:
+                         batch = {
+                             k: ({kk: vv.to(self.device, dtype=self.dtype) for kk, vv in v.items()}
+                                 if not torch.is_tensor(v) else v.to(self.device, dtype=self.dtype))
+                             for k, v in batch.items()
+                         }
+                         (encoder_loss, decoder_loss), (encoder_logits, decoder_logits) = self.compute_loss(batch)
+                     enc_loss += encoder_loss
+                     dec_loss += decoder_loss
+                     # Accumulate the per-batch combined loss, not the running enc_loss/dec_loss sums.
+                     val_loss += (encoder_loss * self.encoder_loss_scale) + decoder_loss
+ 
+                     encoder_labels = batch['encoder']['labels'].to(self.device)
+                     valid_mlm_indices = encoder_labels != -100
+                     if valid_mlm_indices.any():
+                         preds_mlm = encoder_logits.argmax(-1)
+                         correct_mlm += (preds_mlm[valid_mlm_indices] == encoder_labels[valid_mlm_indices]).sum()
+                         total_mlm += valid_mlm_indices.sum()
+ 
+                     shifted_logits = decoder_logits[:, :-1].contiguous()
+                     shifted_targets = batch['decoder']['targets'][:, 1:].to(self.device).contiguous()
+                     valid_alm_indices = shifted_targets != -100
+                     if valid_alm_indices.any():
+                         preds_alm = shifted_logits.argmax(-1)
+                         correct_alm += (preds_alm[valid_alm_indices] == shifted_targets[valid_alm_indices]).sum()
+                         total_alm += valid_alm_indices.sum()
+ 
+         loader_len = len(val_dataloader)
+         avg_loss = (val_loss / loader_len).item()
+         avg_dec_loss = dec_loss / loader_len
+         avg_enc_loss = enc_loss / loader_len
+ 
+         if self.use_ddp:
+             dist.all_reduce(avg_dec_loss, op=dist.ReduceOp.SUM)
+             dist.all_reduce(avg_enc_loss, op=dist.ReduceOp.SUM)
+             dist.all_reduce(correct_mlm, op=dist.ReduceOp.SUM)
+             dist.all_reduce(total_mlm, op=dist.ReduceOp.SUM)
+             dist.all_reduce(correct_alm, op=dist.ReduceOp.SUM)
+             dist.all_reduce(total_alm, op=dist.ReduceOp.SUM)
+             avg_dec_loss = avg_dec_loss / dist.get_world_size()
+             avg_enc_loss = avg_enc_loss / dist.get_world_size()
+ 
+         mlm_acc = (correct_mlm / total_mlm * 100).item() if total_mlm > 0 else 0.0
+         alm_acc = (correct_alm / total_alm * 100).item() if total_alm > 0 else 0.0
+ 
+         metrics = {
+             'accuracy': {
+                 'encoder': mlm_acc,
+                 'decoder': alm_acc,
+             },
+             'loss': {
+                 'encoder': avg_enc_loss.item(),
+                 'decoder': avg_dec_loss.item(),
+             }
+         }
+         self.model.train()
+         return avg_loss, metrics
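
Putting the joint setup together, a minimal driver sketch: the encoder/decoder instances and the BaseTrainer keyword arguments (validation dataset, TensorBoard writer, DDP flags) are assumed to be configured elsewhere, since BaseTrainer is defined in src/training/base.py, not in this file:

    model = JointTrainingModel(encoder, decoder, MLMHead(embed_dim=256, vocab_size=32000))
    trainer = JointLMTrainer(
        model,
        device=torch.device('cuda'),
        vocab_size=32000,
        encoder_loss_scale=0.5,            # down-weight the MLM term in the combined loss
        components_loss_log_interval=100,  # print both component losses every 100 batches
    )
    avg_loss, metrics = trainer.validate(batch_size=16)
    print(metrics['accuracy'], metrics['loss'])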