rxnn 0.1.0__py3-none-any.whl
This diff represents the content of a publicly available package version released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- rxnn-0.1.0.dist-info/LICENSE +201 -0
- rxnn-0.1.0.dist-info/METADATA +257 -0
- rxnn-0.1.0.dist-info/RECORD +23 -0
- rxnn-0.1.0.dist-info/WHEEL +4 -0
- src/experimental/attention.py +133 -0
- src/memory/norm.py +173 -0
- src/memory/stm.py +53 -0
- src/rxt/models.py +180 -0
- src/training/base.py +275 -0
- src/training/bml.py +345 -0
- src/training/callbacks.py +491 -0
- src/training/dataset.py +164 -0
- src/training/scheduler.py +19 -0
- src/training/tokenizer.py +208 -0
- src/transformers/attention.py +324 -0
- src/transformers/ff.py +72 -0
- src/transformers/layers.py +150 -0
- src/transformers/mask.py +10 -0
- src/transformers/models.py +168 -0
- src/transformers/moe.py +139 -0
- src/transformers/positional.py +105 -0
- src/transformers/sampler.py +109 -0
- src/utils.py +14 -0
src/training/bml.py
ADDED
@@ -0,0 +1,345 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from huggingface_hub import PyTorchModelHubMixin
from typing import Union
import torch.distributed as dist
from src.transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
from src.training.base import BaseTrainer


# Masked language modeling head: projects encoder hidden states to vocabulary logits.
class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
    def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
        super(MLMHead, self).__init__(*args, **kwargs)
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.act = nn.GELU()
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.decoder = nn.Linear(embed_dim, vocab_size)

    def forward(self, hidden_states):
        x = self.dense(hidden_states)
        x = self.act(x)
        x = self.layer_norm(x)
        return self.decoder(x)


# Wraps an encoder with an MLM head for masked language model pre-training.
class MLMTrainingModel(nn.Module):
    def __init__(
            self,
            encoder: ReactiveTransformerEncoder,
            mlm_head: MLMHead,
            *args,
            **kwargs
    ):
        super(MLMTrainingModel, self).__init__(*args, **kwargs)
        self.encoder = encoder
        self.mlm_head = mlm_head

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        h, _ = self.encoder(x, attention_mask=attention_mask)
        y = self.mlm_head(h)
        return y


class MLMTrainer(BaseTrainer):
    def __init__(
            self,
            model: MLMTrainingModel,
            device: torch.device,
            vocab_size: int,
            use_amp: bool = False,
            dtype: torch.dtype = None,
            **kwargs
    ):
        super(MLMTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype, **kwargs)
        self.vocab_size = vocab_size

    def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        inputs = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        logits = self.model(
            inputs,
            attention_mask=attention_mask
        )

        # Non-masked positions carry the label -100 and are excluded from the loss.
        return F.cross_entropy(
            logits.view(-1, self.vocab_size),
            labels.view(-1),
            ignore_index=-100
        ), logits

    def validate(self, batch_size: int) -> tuple[float, dict]:
        self.model.eval()
        val_dataloader = self._valid_loader(batch_size)
        val_loss = torch.tensor(0.0).to(self.device)
        correct = torch.tensor(0).to(self.device)
        total = torch.tensor(0).to(self.device)

        with torch.no_grad():
            for batch in val_dataloader:
                if self.get_batch_size(batch) == batch_size:
                    loss, logits = self.valid_step(batch)
                    val_loss += loss
                    labels = batch[self.target_field_name].to(self.device)
                    # Accuracy is computed only over masked (labeled) positions.
                    valid_indices = labels != -100
                    if valid_indices.any():
                        preds = logits.argmax(-1)
                        correct += (preds[valid_indices] == labels[valid_indices]).sum()
                        total += valid_indices.sum()

        avg_loss = (val_loss / len(val_dataloader)).item()
        if self.use_ddp:
            dist.all_reduce(correct, op=dist.ReduceOp.SUM)
            dist.all_reduce(total, op=dist.ReduceOp.SUM)

        metrics = {
            'accuracy': (correct / total * 100).item() if total > 0 else 0.0
        }
        self.model.train()
        return avg_loss, metrics


class AutoregressiveTrainer(BaseTrainer):
    def __init__(
            self,
            model: ReactiveTransformerDecoder,
            device: torch.device,
            vocab_size: int,
            use_amp: bool = False,
            dtype: torch.dtype = None,
            **kwargs
    ):
        super(AutoregressiveTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
                                                    target_field_name='targets', **kwargs)
        self.vocab_size = vocab_size

    def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        inputs = batch['input_ids']
        attention_mask = batch['attention_mask']
        targets = batch['targets']

        outputs = self.model(
            inputs,
            attention_mask=attention_mask
        )

        # Shift by one position so each token predicts its successor.
        shifted_logits = outputs[:, :-1].contiguous()
        shifted_targets = targets[:, 1:].contiguous()

        return F.cross_entropy(
            shifted_logits.view(-1, self.vocab_size),
            shifted_targets.view(-1)
        ), outputs

    def validate(self, batch_size: int) -> tuple[float, dict]:
        self.model.eval()
        val_dataloader = self._valid_loader(batch_size)
        val_loss = torch.tensor(0.0).to(self.device)
        correct = torch.tensor(0).to(self.device)
        total = torch.tensor(0).to(self.device)

        with torch.no_grad():
            for batch in val_dataloader:
                if self.get_batch_size(batch) == batch_size:
                    loss, logits = self.valid_step(batch)
                    val_loss += loss
                    shifted_logits = logits[:, :-1].contiguous()
                    shifted_targets = batch[self.target_field_name][:, 1:].to(self.device).contiguous()
                    valid_indices = shifted_targets != -100
                    if valid_indices.any():
                        preds = shifted_logits.argmax(-1)
                        correct += (preds[valid_indices] == shifted_targets[valid_indices]).sum()
                        total += valid_indices.sum()

        avg_loss = (val_loss / len(val_dataloader)).item()

        if self.use_ddp:
            dist.all_reduce(correct, op=dist.ReduceOp.SUM)
            dist.all_reduce(total, op=dist.ReduceOp.SUM)

        metrics = {
            'accuracy': (correct / total * 100).item() if total > 0 else 0.0
        }
        self.model.train()
        return avg_loss, metrics


# Runs the MLM objective (encoder + head) and the autoregressive objective (decoder)
# on one shared batch.
class JointTrainingModel(nn.Module):
    def __init__(
            self,
            encoder: ReactiveTransformerEncoder,
            decoder: ReactiveTransformerDecoder,
            mlm_head: MLMHead,
            *args,
            **kwargs
    ):
        super(JointTrainingModel, self).__init__(*args, **kwargs)
        self.encoder = encoder
        self.mlm_head = mlm_head
        self.decoder = decoder

    def forward(self, x_e: torch.Tensor, x_d: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[
        torch.Tensor, torch.Tensor]:
        encoder_result, _ = self.encoder(x_e, attention_mask=attention_mask)
        y_e = self.mlm_head(encoder_result)
        y_d = self.decoder(x_d, attention_mask=attention_mask)
        return y_e, y_d


class JointLMTrainer(BaseTrainer):
    def __init__(
            self,
            model: JointTrainingModel,
            device: torch.device,
            vocab_size: int,
            use_amp: bool = False,
            dtype: torch.dtype = None,
            components_loss_log_interval: int = None,
            encoder_loss_scale: float = 1.0,
            **kwargs
    ):
        super(JointLMTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype, **kwargs)
        self.vocab_size = vocab_size
        self.components_loss_log_interval = components_loss_log_interval
        self.encoder_loss_scale = encoder_loss_scale

    def train_step(self, batch: dict[str, Union[torch.Tensor, dict[str, torch.Tensor]]],
                   batch_idx: int) -> torch.Tensor:
        # Batches are nested dicts ({'encoder': {...}, 'decoder': {...}, 'attention_mask': ...}),
        # so tensors are moved to the device one nesting level deep.
        if self.use_amp:
            batch = {
                k: ({kk: vv.to(self.device) for kk, vv in v.items()} if not torch.is_tensor(v) else v.to(self.device))
                for k, v in batch.items()}
            with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
                (encoder_loss, decoder_loss), _ = self.compute_loss(batch)
        else:
            batch = {k: (
                {kk: vv.to(self.device, dtype=self.dtype) for kk, vv in v.items()} if not torch.is_tensor(v) else v.to(
                    self.device, dtype=self.dtype)) for k, v in batch.items()}
            (encoder_loss, decoder_loss), _ = self.compute_loss(batch)
        if self.components_loss_log_interval is not None:
            if batch_idx % self.components_loss_log_interval == 0:
                print(f"Encoder loss: {encoder_loss.item():.4f}")
                print(f"Decoder loss: {decoder_loss.item():.4f}")
                if self.encoder_loss_scale != 1.0:
                    print(
                        f"Encoder loss scaled by {self.encoder_loss_scale}: {(encoder_loss * self.encoder_loss_scale).item():.4f}")
        return (encoder_loss * self.encoder_loss_scale) + decoder_loss

    def compute_loss(self, batch: dict[str, dict[str, torch.Tensor]]) -> tuple[
        tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
        encoder_inputs = batch['encoder']['input_ids']
        encoder_labels = batch['encoder']['labels']
        decoder_inputs = batch['decoder']['input_ids']
        decoder_targets = batch['decoder']['targets']
        attention_mask = batch['attention_mask']

        encoder_logits, decoder_logits = self.model(
            encoder_inputs,
            decoder_inputs,
            attention_mask=attention_mask
        )

        encoder_loss = F.cross_entropy(
            encoder_logits.view(-1, self.vocab_size),
            encoder_labels.view(-1),
            ignore_index=-100
        )

        shifted_logits = decoder_logits[:, :-1].contiguous()
        shifted_targets = decoder_targets[:, 1:].contiguous()

        decoder_loss = F.cross_entropy(
            shifted_logits.view(-1, self.vocab_size),
            shifted_targets.view(-1)
        )

        return (encoder_loss, decoder_loss), (encoder_logits, decoder_logits)

    def _valid_writer(self, epoch: int, val_loss: float, val_metrics: dict):
        self.writer.add_scalar('Loss/validation', val_loss, epoch)
        self.writer.add_scalar('Perplexity/validation', math.exp(val_loss), epoch)
        if val_metrics['accuracy']:
            self.writer.add_scalar('Encoder accuracy/validation', val_metrics['accuracy']['encoder'], epoch)
            self.writer.add_scalar('Decoder accuracy/validation', val_metrics['accuracy']['decoder'], epoch)
        if val_metrics['loss']:
            self.writer.add_scalar('Encoder loss/validation', val_metrics['loss']['encoder'], epoch)
            self.writer.add_scalar('Encoder perplexity/validation', math.exp(val_metrics['loss']['encoder']), epoch)
            self.writer.add_scalar('Decoder loss/validation', val_metrics['loss']['decoder'], epoch)
            self.writer.add_scalar('Decoder perplexity/validation', math.exp(val_metrics['loss']['decoder']), epoch)

    def validate(self, batch_size: int) -> tuple[float, dict]:
        self.model.eval()
        val_loss = torch.tensor(0.0).to(self.device)
        dec_loss = torch.tensor(0.0).to(self.device)
        enc_loss = torch.tensor(0.0).to(self.device)
        correct_mlm = torch.tensor(0).to(self.device)
        total_mlm = torch.tensor(0).to(self.device)
        correct_alm = torch.tensor(0).to(self.device)
        total_alm = torch.tensor(0).to(self.device)

        val_dataloader = self._valid_loader(batch_size)

        with torch.no_grad():
            for batch in val_dataloader:
                if self.get_batch_size(batch) == batch_size:
                    if self.use_amp:
                        batch = {
                            k: ({kk: vv.to(self.device) for kk, vv in v.items()} if not torch.is_tensor(v) else v.to(
                                self.device)) for k, v in batch.items()}
                        with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
                            (encoder_loss, decoder_loss), (encoder_logits, decoder_logits) = self.compute_loss(batch)
                    else:
                        batch = {k: (
                            {kk: vv.to(self.device, dtype=self.dtype) for kk, vv in v.items()} if not torch.is_tensor(
                                v) else v.to(self.device, dtype=self.dtype)) for k, v in batch.items()}
                        (encoder_loss, decoder_loss), (encoder_logits, decoder_logits) = self.compute_loss(batch)
                    enc_loss += encoder_loss
                    dec_loss += decoder_loss
                    # Accumulate the per-batch combined loss, matching the train_step objective.
                    val_loss += (encoder_loss * self.encoder_loss_scale) + decoder_loss

                    encoder_labels = batch['encoder']['labels'].to(self.device)
                    valid_mlm_indices = encoder_labels != -100
                    if valid_mlm_indices.any():
                        preds_mlm = encoder_logits.argmax(-1)
                        correct_mlm += (preds_mlm[valid_mlm_indices] == encoder_labels[valid_mlm_indices]).sum()
                        total_mlm += valid_mlm_indices.sum()

                    shifted_logits = decoder_logits[:, :-1].contiguous()
                    shifted_targets = batch['decoder']['targets'][:, 1:].to(self.device).contiguous()
                    valid_alm_indices = shifted_targets != -100
                    if valid_alm_indices.any():
                        preds_alm = shifted_logits.argmax(-1)
                        correct_alm += (preds_alm[valid_alm_indices] == shifted_targets[valid_alm_indices]).sum()
                        total_alm += valid_alm_indices.sum()

        loader_len = len(val_dataloader)
        avg_loss = val_loss / loader_len
        avg_dec_loss = dec_loss / loader_len
        avg_enc_loss = enc_loss / loader_len

        if self.use_ddp:
            # Sum counters across ranks, then average the per-rank mean losses.
            dist.all_reduce(avg_dec_loss, op=dist.ReduceOp.SUM)
            dist.all_reduce(avg_enc_loss, op=dist.ReduceOp.SUM)
            dist.all_reduce(correct_mlm, op=dist.ReduceOp.SUM)
            dist.all_reduce(total_mlm, op=dist.ReduceOp.SUM)
            dist.all_reduce(correct_alm, op=dist.ReduceOp.SUM)
            dist.all_reduce(total_alm, op=dist.ReduceOp.SUM)
            avg_dec_loss = avg_dec_loss / dist.get_world_size()
            avg_enc_loss = avg_enc_loss / dist.get_world_size()

        mlm_acc = (correct_mlm / total_mlm * 100).item() if total_mlm > 0 else 0.0
        alm_acc = (correct_alm / total_alm * 100).item() if total_alm > 0 else 0.0

        metrics = {
            'accuracy': {
                'encoder': mlm_acc,
                'decoder': alm_acc,
            },
            'loss': {
                'encoder': avg_enc_loss.item(),
                'decoder': avg_dec_loss.item(),
            }
        }
        self.model.train()
        return avg_loss.item(), metrics
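
For orientation, below is a minimal sketch of how the MLM path in this file could be wired together. The StubEncoder, the sizes, and the bare trainer call are illustrative assumptions only: the real ReactiveTransformerEncoder lives in src/transformers/models.py, and BaseTrainer (src/training/base.py) may require further arguments such as datasets or callbacks that are not shown here.

import torch
import torch.nn as nn

from src.training.bml import MLMHead, MLMTrainingModel, MLMTrainer

# Hypothetical stand-in for ReactiveTransformerEncoder: it only mimics the
# interface used above, i.e. forward(x, attention_mask=...) -> (hidden, extras).
class StubEncoder(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, x, attention_mask=None):
        return self.embedding(x), None

vocab_size, embed_dim = 32000, 256  # illustrative sizes, not package defaults
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MLMTrainingModel(StubEncoder(vocab_size, embed_dim), MLMHead(embed_dim, vocab_size))
logits = model(torch.randint(0, vocab_size, (2, 16)))  # -> shape (2, 16, vocab_size)

# Trainer construction; any extra kwargs are forwarded to BaseTrainer.
trainer = MLMTrainer(model, device, vocab_size, use_amp=False)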