rxnn 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/base.py CHANGED
@@ -49,6 +49,7 @@ class BaseTrainer(ABC):
49
49
  self.validation_metrics = {}
50
50
  self.target_field_name = target_field_name
51
51
  self.total_tokens = 0
52
+ self.total_steps = 0
52
53
  self.gradient_accumulation_steps = gradient_accumulation_steps
53
54
  self.accumulated_loss = 0.0
54
55
  self.optimizer_step_count = 0
@@ -140,6 +141,7 @@ class BaseTrainer(ABC):
140
141
  for callback in self.callbacks:
141
142
  callback.on_batch_start(self.model, batch_idx, batch)
142
143
  if self.get_batch_size(batch) == batch_size:
144
+ self.total_steps += 1
143
145
  loss = self.train_step(batch, batch_idx)
144
146
  orig_loss = loss.item()
145
147
  self.accumulated_loss += orig_loss
@@ -226,11 +228,11 @@ class BaseTrainer(ABC):
226
228
  self.writer.close()
227
229
 
228
230
  def _valid_writer(self, epoch: int, val_loss: float, val_metrics: dict):
229
- self.writer.add_scalar('Loss/validation', val_loss, epoch)
230
- self.writer.add_scalar('Perplexity/validation', math.exp(val_loss), epoch)
231
+ self.writer.add_scalar('Loss/Valid', val_loss, epoch)
232
+ self.writer.add_scalar('Perplexity/Valid', math.exp(val_loss), epoch)
231
233
  if val_metrics['accuracy']:
232
- self.writer.add_scalar('Node Accuracy/validation', val_metrics['node_accuracy'], epoch)
233
- self.writer.add_scalar('Avg. Accuracy/validation', val_metrics['accuracy'], epoch)
234
+ self.writer.add_scalar('Node Accuracy/Valid', val_metrics['node_accuracy'], epoch)
235
+ self.writer.add_scalar('Avg. Accuracy/Valid', val_metrics['accuracy'], epoch)
234
236
 
235
237
  def valid_step(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
236
238
  if self.use_amp:
rxnn/training/bml.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
  import torch.nn.functional as F
4
+ from torch.nn.parallel import DistributedDataParallel
4
5
  import math
5
6
  from huggingface_hub import PyTorchModelHubMixin
6
7
  from typing import Union
@@ -171,6 +172,90 @@ class AutoregressiveTrainer(BaseTrainer):
171
172
  return avg_loss, metrics
172
173
 
173
174
 
175
+ class AutoregressiveMoeTrainer(BaseTrainer):
176
+ def __init__(
177
+ self,
178
+ model: ReactiveTransformerDecoder,
179
+ device: torch.device,
180
+ vocab_size: int,
181
+ use_amp: bool = False,
182
+ dtype: torch.dtype = None,
183
+ router_loss_scale: float = 0.1,
184
+ **kwargs
185
+ ):
186
+ super(AutoregressiveMoeTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
187
+ target_field_name='targets', **kwargs)
188
+ self.vocab_size = vocab_size
189
+ self.router_loss_scale = router_loss_scale
190
+
191
+ def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
192
+ inputs = batch['input_ids']
193
+ attention_mask = batch['attention_mask']
194
+ targets = batch['targets']
195
+
196
+ outputs = self.model(
197
+ inputs,
198
+ attention_mask=attention_mask
199
+ )
200
+
201
+ shifted_logits = outputs[:, :-1].contiguous()
202
+ shifted_targets = targets[:, 1:].contiguous()
203
+
204
+ main_loss = F.cross_entropy(
205
+ shifted_logits.view(-1, self.vocab_size),
206
+ shifted_targets.view(-1)
207
+ )
208
+
209
+ model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
210
+
211
+ router_loss = model.moe_router_loss()
212
+ loss = main_loss + self.router_loss_scale * router_loss
213
+
214
+ if self.writer is not None:
215
+ if self.model.training:
216
+ self.writer.add_scalar('Router aux loss/Train', router_loss.item(), self.total_steps)
217
+ self.writer.add_scalar('Model loss/Train', main_loss.item(), self.total_steps)
218
+ else:
219
+ self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_steps)
220
+ self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_steps)
221
+
222
+ return loss, outputs
223
+
224
+ def validate(self, batch_size: int) -> tuple[float, dict]:
225
+ self.model.eval()
226
+ val_dataloader = self._valid_loader(batch_size)
227
+ val_loss = torch.tensor(0.0).to(self.device)
228
+ correct = torch.tensor(0).to(self.device)
229
+ total = torch.tensor(0).to(self.device)
230
+
231
+ with torch.no_grad():
232
+ for batch in val_dataloader:
233
+ if self.get_batch_size(batch) == batch_size:
234
+ loss, logits = self.valid_step(batch)
235
+ val_loss += loss
236
+ shifted_logits = logits[:, :-1].contiguous()
237
+ shifted_targets = batch[self.target_field_name][:, 1:].to(self.device).contiguous()
238
+ valid_indices = shifted_targets != -100
239
+ if valid_indices.any():
240
+ preds = shifted_logits.argmax(-1)
241
+ correct += (preds[valid_indices] == shifted_targets[valid_indices]).sum()
242
+ total += valid_indices.sum()
243
+
244
+ avg_loss = (val_loss / len(val_dataloader)).item()
245
+ acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
246
+ node_acc = acc.item()
247
+ if self.use_ddp:
248
+ dist.all_reduce(acc, op=dist.ReduceOp.SUM)
249
+ acc = acc / dist.get_world_size()
250
+
251
+ metrics = {
252
+ 'accuracy': acc.item(),
253
+ 'node_accuracy': node_acc,
254
+ }
255
+ self.model.train()
256
+ return avg_loss, metrics
257
+
258
+
174
259
  class JointTrainingModel(nn.Module):
175
260
  def __init__(
176
261
  self,
@@ -262,18 +347,18 @@ class JointLMTrainer(BaseTrainer):
262
347
  return (encoder_loss, decoder_loss), (encoder_logits, decoder_logits)
263
348
 
264
349
  def _valid_writer(self, epoch: int, val_loss: float, val_metrics: dict):
265
- self.writer.add_scalar('Loss/validation', val_loss, epoch)
266
- self.writer.add_scalar('Perplexity/validation', math.exp(val_loss), epoch)
350
+ self.writer.add_scalar('Loss/Valid', val_loss, epoch)
351
+ self.writer.add_scalar('Perplexity/Valid', math.exp(val_loss), epoch)
267
352
  if val_metrics['accuracy']:
268
- self.writer.add_scalar('Encoder node accuracy/validation', val_metrics['accuracy']['node_encoder'], epoch)
269
- self.writer.add_scalar('Decoder node accuracy/validation', val_metrics['accuracy']['node_decoder'], epoch)
270
- self.writer.add_scalar('Encoder avg. accuracy/validation', val_metrics['accuracy']['encoder'], epoch)
271
- self.writer.add_scalar('Decoder avg. accuracy/validation', val_metrics['accuracy']['decoder'], epoch)
353
+ self.writer.add_scalar('Encoder node accuracy/Valid', val_metrics['accuracy']['node_encoder'], epoch)
354
+ self.writer.add_scalar('Decoder node accuracy/Valid', val_metrics['accuracy']['node_decoder'], epoch)
355
+ self.writer.add_scalar('Encoder avg. accuracy/Valid', val_metrics['accuracy']['encoder'], epoch)
356
+ self.writer.add_scalar('Decoder avg. accuracy/Valid', val_metrics['accuracy']['decoder'], epoch)
272
357
  if val_metrics['loss']:
273
- self.writer.add_scalar('Encoder loss/validation', val_metrics['loss']['encoder'], epoch)
274
- self.writer.add_scalar('Encoder perplexity/validation', math.exp(val_metrics['loss']['encoder']), epoch)
275
- self.writer.add_scalar('Decoder accuracy/validation', val_metrics['loss']['decoder'], epoch)
276
- self.writer.add_scalar('Decoder perplexity/validation', math.exp(val_metrics['loss']['decoder']), epoch)
358
+ self.writer.add_scalar('Encoder loss/Valid', val_metrics['loss']['encoder'], epoch)
359
+ self.writer.add_scalar('Encoder perplexity/Valid', math.exp(val_metrics['loss']['encoder']), epoch)
360
+ self.writer.add_scalar('Decoder accuracy/Valid', val_metrics['loss']['decoder'], epoch)
361
+ self.writer.add_scalar('Decoder perplexity/Valid', math.exp(val_metrics['loss']['decoder']), epoch)
277
362
 
278
363
  def validate(self, batch_size: int) -> tuple[float, dict]:
279
364
  self.model.eval()
@@ -53,11 +53,15 @@ class ReactiveTransformerLayer(nn.Module):
53
53
  self.norm2 = nn.LayerNorm(embed_dim)
54
54
  self.norm3 = nn.LayerNorm(embed_dim)
55
55
  self.use_post_norm = use_post_norm
56
+ self.use_moe = use_moe
56
57
 
57
58
  def trainable_cross_attention_(self, is_trainable: bool):
58
59
  for param in self.memory_cross_attention.parameters():
59
60
  param.requires_grad_(is_trainable)
60
61
 
62
+ def moe_router_loss_(self):
63
+ return self.ff.router_loss() if self.use_moe else None
64
+
61
65
  def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
62
66
  # First step, self-attention
63
67
  residual = x
@@ -129,6 +133,10 @@ class ClassicTransformerLayer(nn.Module):
129
133
  self.norm1 = nn.LayerNorm(embed_dim)
130
134
  self.norm2 = nn.LayerNorm(embed_dim)
131
135
  self.use_post_norm = use_post_norm
136
+ self.use_moe = use_moe
137
+
138
+ def moe_router_loss_(self):
139
+ return self.ff.router_loss() if self.use_moe else torch.tensor(0.0)
132
140
 
133
141
  def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
134
142
  # First step, self-attention
@@ -37,6 +37,10 @@ class ReactiveTransformerBase(nn.Module):
37
37
  for i in range(self.num_own_layers):
38
38
  self.layers[i].trainable_cross_attention_(is_trainable)
39
39
 
40
+ def moe_router_loss_(self):
41
+ return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe] + [
42
+ self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe]).mean()
43
+
40
44
  def forward(self, x: torch.Tensor) -> torch.Tensor:
41
45
  # Shared logic for encoders and decoders - apply embeddings and positional encoding
42
46
  x = self.embedding(x)
@@ -119,6 +123,9 @@ class ClassicTransformerBase(nn.Module):
119
123
  self.layers = layers
120
124
  self.num_layers = len(layers) if layers else 0
121
125
 
126
+ def moe_router_loss_(self):
127
+ return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_layers) if self.layers[i].use_moe]).mean()
128
+
122
129
  def forward(self, x: torch.Tensor) -> torch.Tensor:
123
130
  # Shared logic for encoders and decoders - apply embeddings and positional encoding
124
131
  x = self.embedding(x)
rxnn/transformers/moe.py CHANGED
@@ -11,7 +11,8 @@ class MoeRouter(nn.Module):
11
11
  self.top_k = top_k
12
12
  self.num_experts = num_experts
13
13
  self.gate = nn.Linear(embed_dim, num_experts, bias=False)
14
- self.aux_loss = 0.0 # For expert load balancing
14
+ # For expert load balancing
15
+ self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
15
16
 
16
17
  def forward(self, x: torch.Tensor):
17
18
  # x shape: [batch_size*seq_len, embed_dim]
@@ -19,10 +20,8 @@ class MoeRouter(nn.Module):
19
20
  probs = F.softmax(logits, dim=-1)
20
21
 
21
22
  # Expert load balancing loss
22
- if self.training:
23
- probs_for_bal = F.softmax(logits, dim=0)
24
- self.aux_loss = (probs_for_bal.mean(dim=0) *
25
- torch.log(probs_for_bal.mean(dim=0) + 1e-9)).sum()
23
+ mean_probs = probs.mean(dim=0) # Mean probability per expert across batch
24
+ self.aux_loss = (mean_probs * torch.log(mean_probs + 1e-9)).sum() # Entropy-based loss
26
25
 
27
26
  top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
28
27
  top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
@@ -74,6 +73,9 @@ class MoeFeedForward(nn.Module):
74
73
  def _activate(self, h: torch.Tensor):
75
74
  return self.activation(h)
76
75
 
76
+ def router_loss(self):
77
+ return self.router.aux_loss
78
+
77
79
  def forward(self, x: torch.Tensor):
78
80
  orig_shape = x.shape
79
81
  x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.1.10
3
+ Version: 0.1.11
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -7,8 +7,8 @@ rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
7
7
  rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  rxnn/rxt/models.py,sha256=INTFeNcqzAsjyWhNtbBHL4Tx7tYDsaQHgm72tf6u20M,6918
9
9
  rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- rxnn/training/base.py,sha256=YOtSLlG6-h0r54OJtyU777k5rNkbSCps3YFfB-Fh35g,11176
11
- rxnn/training/bml.py,sha256=pEH0_pDy8QThsuYgfcT2lSdfMOnqGhlhu63xMFkUSOs,15246
10
+ rxnn/training/base.py,sha256=QD8uS14jSyR5Y_8BgCaBQTKpsarerU3lyufsWsCq_6o,11227
11
+ rxnn/training/bml.py,sha256=2kk9q3Buxq4wBHUQhyIAuHoBCninYX2K8hykWAJnxB0,18654
12
12
  rxnn/training/callbacks.py,sha256=_YfMKY_eFdc-tubhO9nYH2PXDZDQwlSI74FVOoCXpQg,22108
13
13
  rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
14
14
  rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
@@ -16,14 +16,14 @@ rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,80
16
16
  rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
17
  rxnn/transformers/attention.py,sha256=FfEYE0THO73p_1eRupr2mcwfW4UbI_riIxkHfr8X_1c,14022
18
18
  rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
19
- rxnn/transformers/layers.py,sha256=jdM7L0uOMO68aZiu9p6jba1Hx3aLGOChF1Zz-j4vJ5U,5364
19
+ rxnn/transformers/layers.py,sha256=xMocHzdSu7hcC_mPE_aG3-LQg2RXgunKSxcgNXYnOeo,5631
20
20
  rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
21
- rxnn/transformers/models.py,sha256=sLYMkVOWQ1NcM1evpCTUMucXvklySpeNT0IqpIGKmyc,6716
22
- rxnn/transformers/moe.py,sha256=JQ5QSX4FS7S-fqB7-s1ZmJbPpOeD_Injn8o4vo7wGQE,4936
21
+ rxnn/transformers/models.py,sha256=PVhiTTSQ7VTDVdOcyRf-xGNvj6oOa_2fUV2mfthcE0Y,7171
22
+ rxnn/transformers/moe.py,sha256=v21HDEhkDr10--If0P-XBjT5C7IlQJo0wGQlpDnVWEA,5020
23
23
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
24
24
  rxnn/transformers/sampler.py,sha256=wSz_1wNloqtuiix5w2Mcsj5NhaO9QlY0j__TVG7wJnM,3938
25
25
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
26
- rxnn-0.1.10.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
27
- rxnn-0.1.10.dist-info/METADATA,sha256=dbmUcafrjisLl8YzU7Y9bBeSm0cJ2IaWnts8DdqWzMY,14629
28
- rxnn-0.1.10.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
29
- rxnn-0.1.10.dist-info/RECORD,,
26
+ rxnn-0.1.11.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
27
+ rxnn-0.1.11.dist-info/METADATA,sha256=WFoe6AqfJVI6wFZ23i3qGQ3babDlLtjIMU0htjOIikw,14629
28
+ rxnn-0.1.11.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
29
+ rxnn-0.1.11.dist-info/RECORD,,
File without changes
File without changes