rxnn 0.1.10__tar.gz → 0.1.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29):
  1. {rxnn-0.1.10 → rxnn-0.1.12}/PKG-INFO +1 -1
  2. {rxnn-0.1.10 → rxnn-0.1.12}/pyproject.toml +1 -1
  3. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/training/base.py +6 -4
  4. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/training/bml.py +95 -10
  5. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/transformers/layers.py +8 -0
  6. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/transformers/models.py +7 -0
  7. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/transformers/moe.py +7 -5
  8. {rxnn-0.1.10 → rxnn-0.1.12}/LICENSE +0 -0
  9. {rxnn-0.1.10 → rxnn-0.1.12}/README.md +0 -0
  10. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/__init__.py +0 -0
  11. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/experimental/__init__.py +0 -0
  12. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/experimental/attention.py +0 -0
  13. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/memory/__init__.py +0 -0
  14. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/memory/norm.py +0 -0
  15. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/memory/stm.py +0 -0
  16. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/rxt/__init__.py +0 -0
  17. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/rxt/models.py +0 -0
  18. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/training/__init__.py +0 -0
  19. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/training/callbacks.py +0 -0
  20. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/training/dataset.py +0 -0
  21. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/training/scheduler.py +0 -0
  22. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/training/tokenizer.py +0 -0
  23. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/transformers/__init__.py +0 -0
  24. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/transformers/attention.py +0 -0
  25. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/transformers/ff.py +0 -0
  26. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/transformers/mask.py +0 -0
  27. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/transformers/positional.py +0 -0
  28. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/transformers/sampler.py +0 -0
  29. {rxnn-0.1.10 → rxnn-0.1.12}/src/rxnn/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.1.10
3
+ Version: 0.1.12
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [tool.poetry]
6
6
  name = "rxnn"
7
- version = "0.1.10"
7
+ version = "0.1.12"
8
8
  description = "RxNN: Reactive Neural Networks Platform"
9
9
 
10
10
  license = "Apache-2.0"
@@ -49,6 +49,7 @@ class BaseTrainer(ABC):
49
49
  self.validation_metrics = {}
50
50
  self.target_field_name = target_field_name
51
51
  self.total_tokens = 0
52
+ self.total_steps = 0
52
53
  self.gradient_accumulation_steps = gradient_accumulation_steps
53
54
  self.accumulated_loss = 0.0
54
55
  self.optimizer_step_count = 0
@@ -140,6 +141,7 @@ class BaseTrainer(ABC):
140
141
  for callback in self.callbacks:
141
142
  callback.on_batch_start(self.model, batch_idx, batch)
142
143
  if self.get_batch_size(batch) == batch_size:
144
+ self.total_steps += 1
143
145
  loss = self.train_step(batch, batch_idx)
144
146
  orig_loss = loss.item()
145
147
  self.accumulated_loss += orig_loss
@@ -226,11 +228,11 @@ class BaseTrainer(ABC):
226
228
  self.writer.close()
227
229
 
228
230
  def _valid_writer(self, epoch: int, val_loss: float, val_metrics: dict):
229
- self.writer.add_scalar('Loss/validation', val_loss, epoch)
230
- self.writer.add_scalar('Perplexity/validation', math.exp(val_loss), epoch)
231
+ self.writer.add_scalar('Loss/Valid', val_loss, epoch)
232
+ self.writer.add_scalar('Perplexity/Valid', math.exp(val_loss), epoch)
231
233
  if val_metrics['accuracy']:
232
- self.writer.add_scalar('Node Accuracy/validation', val_metrics['node_accuracy'], epoch)
233
- self.writer.add_scalar('Avg. Accuracy/validation', val_metrics['accuracy'], epoch)
234
+ self.writer.add_scalar('Node Accuracy/Valid', val_metrics['node_accuracy'], epoch)
235
+ self.writer.add_scalar('Avg. Accuracy/Valid', val_metrics['accuracy'], epoch)
234
236
 
235
237
  def valid_step(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
236
238
  if self.use_amp:
@@ -1,6 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
  import torch.nn.functional as F
4
+ from torch.nn.parallel import DistributedDataParallel
4
5
  import math
5
6
  from huggingface_hub import PyTorchModelHubMixin
6
7
  from typing import Union
@@ -171,6 +172,90 @@ class AutoregressiveTrainer(BaseTrainer):
171
172
  return avg_loss, metrics
172
173
 
173
174
 
175
+ class AutoregressiveMoeTrainer(BaseTrainer):
176
+ def __init__(
177
+ self,
178
+ model: ReactiveTransformerDecoder,
179
+ device: torch.device,
180
+ vocab_size: int,
181
+ use_amp: bool = False,
182
+ dtype: torch.dtype = None,
183
+ router_loss_scale: float = 0.1,
184
+ **kwargs
185
+ ):
186
+ super(AutoregressiveMoeTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
187
+ target_field_name='targets', **kwargs)
188
+ self.vocab_size = vocab_size
189
+ self.router_loss_scale = router_loss_scale
190
+
191
+ def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
192
+ inputs = batch['input_ids']
193
+ attention_mask = batch['attention_mask']
194
+ targets = batch['targets']
195
+
196
+ outputs = self.model(
197
+ inputs,
198
+ attention_mask=attention_mask
199
+ )
200
+
201
+ shifted_logits = outputs[:, :-1].contiguous()
202
+ shifted_targets = targets[:, 1:].contiguous()
203
+
204
+ main_loss = F.cross_entropy(
205
+ shifted_logits.view(-1, self.vocab_size),
206
+ shifted_targets.view(-1)
207
+ )
208
+
209
+ model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
210
+
211
+ router_loss = model.model.moe_router_loss()
212
+ loss = main_loss + self.router_loss_scale * router_loss
213
+
214
+ if self.writer is not None:
215
+ if self.model.training:
216
+ self.writer.add_scalar('Router aux loss/Train', router_loss.item(), self.total_steps)
217
+ self.writer.add_scalar('Model loss/Train', main_loss.item(), self.total_steps)
218
+ else:
219
+ self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_steps)
220
+ self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_steps)
221
+
222
+ return loss, outputs
223
+
224
+ def validate(self, batch_size: int) -> tuple[float, dict]:
225
+ self.model.eval()
226
+ val_dataloader = self._valid_loader(batch_size)
227
+ val_loss = torch.tensor(0.0).to(self.device)
228
+ correct = torch.tensor(0).to(self.device)
229
+ total = torch.tensor(0).to(self.device)
230
+
231
+ with torch.no_grad():
232
+ for batch in val_dataloader:
233
+ if self.get_batch_size(batch) == batch_size:
234
+ loss, logits = self.valid_step(batch)
235
+ val_loss += loss
236
+ shifted_logits = logits[:, :-1].contiguous()
237
+ shifted_targets = batch[self.target_field_name][:, 1:].to(self.device).contiguous()
238
+ valid_indices = shifted_targets != -100
239
+ if valid_indices.any():
240
+ preds = shifted_logits.argmax(-1)
241
+ correct += (preds[valid_indices] == shifted_targets[valid_indices]).sum()
242
+ total += valid_indices.sum()
243
+
244
+ avg_loss = (val_loss / len(val_dataloader)).item()
245
+ acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
246
+ node_acc = acc.item()
247
+ if self.use_ddp:
248
+ dist.all_reduce(acc, op=dist.ReduceOp.SUM)
249
+ acc = acc / dist.get_world_size()
250
+
251
+ metrics = {
252
+ 'accuracy': acc.item(),
253
+ 'node_accuracy': node_acc,
254
+ }
255
+ self.model.train()
256
+ return avg_loss, metrics
257
+
258
+
174
259
  class JointTrainingModel(nn.Module):
175
260
  def __init__(
176
261
  self,
@@ -262,18 +347,18 @@ class JointLMTrainer(BaseTrainer):
262
347
  return (encoder_loss, decoder_loss), (encoder_logits, decoder_logits)
263
348
 
264
349
  def _valid_writer(self, epoch: int, val_loss: float, val_metrics: dict):
265
- self.writer.add_scalar('Loss/validation', val_loss, epoch)
266
- self.writer.add_scalar('Perplexity/validation', math.exp(val_loss), epoch)
350
+ self.writer.add_scalar('Loss/Valid', val_loss, epoch)
351
+ self.writer.add_scalar('Perplexity/Valid', math.exp(val_loss), epoch)
267
352
  if val_metrics['accuracy']:
268
- self.writer.add_scalar('Encoder node accuracy/validation', val_metrics['accuracy']['node_encoder'], epoch)
269
- self.writer.add_scalar('Decoder node accuracy/validation', val_metrics['accuracy']['node_decoder'], epoch)
270
- self.writer.add_scalar('Encoder avg. accuracy/validation', val_metrics['accuracy']['encoder'], epoch)
271
- self.writer.add_scalar('Decoder avg. accuracy/validation', val_metrics['accuracy']['decoder'], epoch)
353
+ self.writer.add_scalar('Encoder node accuracy/Valid', val_metrics['accuracy']['node_encoder'], epoch)
354
+ self.writer.add_scalar('Decoder node accuracy/Valid', val_metrics['accuracy']['node_decoder'], epoch)
355
+ self.writer.add_scalar('Encoder avg. accuracy/Valid', val_metrics['accuracy']['encoder'], epoch)
356
+ self.writer.add_scalar('Decoder avg. accuracy/Valid', val_metrics['accuracy']['decoder'], epoch)
272
357
  if val_metrics['loss']:
273
- self.writer.add_scalar('Encoder loss/validation', val_metrics['loss']['encoder'], epoch)
274
- self.writer.add_scalar('Encoder perplexity/validation', math.exp(val_metrics['loss']['encoder']), epoch)
275
- self.writer.add_scalar('Decoder accuracy/validation', val_metrics['loss']['decoder'], epoch)
276
- self.writer.add_scalar('Decoder perplexity/validation', math.exp(val_metrics['loss']['decoder']), epoch)
358
+ self.writer.add_scalar('Encoder loss/Valid', val_metrics['loss']['encoder'], epoch)
359
+ self.writer.add_scalar('Encoder perplexity/Valid', math.exp(val_metrics['loss']['encoder']), epoch)
360
+ self.writer.add_scalar('Decoder accuracy/Valid', val_metrics['loss']['decoder'], epoch)
361
+ self.writer.add_scalar('Decoder perplexity/Valid', math.exp(val_metrics['loss']['decoder']), epoch)
277
362
 
278
363
  def validate(self, batch_size: int) -> tuple[float, dict]:
279
364
  self.model.eval()
@@ -53,11 +53,15 @@ class ReactiveTransformerLayer(nn.Module):
53
53
  self.norm2 = nn.LayerNorm(embed_dim)
54
54
  self.norm3 = nn.LayerNorm(embed_dim)
55
55
  self.use_post_norm = use_post_norm
56
+ self.use_moe = use_moe
56
57
 
57
58
  def trainable_cross_attention_(self, is_trainable: bool):
58
59
  for param in self.memory_cross_attention.parameters():
59
60
  param.requires_grad_(is_trainable)
60
61
 
62
+ def moe_router_loss_(self):
63
+ return self.ff.router_loss() if self.use_moe else None
64
+
61
65
  def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
62
66
  # First step, self-attention
63
67
  residual = x
@@ -129,6 +133,10 @@ class ClassicTransformerLayer(nn.Module):
129
133
  self.norm1 = nn.LayerNorm(embed_dim)
130
134
  self.norm2 = nn.LayerNorm(embed_dim)
131
135
  self.use_post_norm = use_post_norm
136
+ self.use_moe = use_moe
137
+
138
+ def moe_router_loss_(self):
139
+ return self.ff.router_loss() if self.use_moe else torch.tensor(0.0)
132
140
 
133
141
  def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
134
142
  # First step, self-attention
@@ -37,6 +37,10 @@ class ReactiveTransformerBase(nn.Module):
37
37
  for i in range(self.num_own_layers):
38
38
  self.layers[i].trainable_cross_attention_(is_trainable)
39
39
 
40
+ def moe_router_loss_(self):
41
+ return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe] + [
42
+ self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe]).mean()
43
+
40
44
  def forward(self, x: torch.Tensor) -> torch.Tensor:
41
45
  # Shared logic for encoders and decoders - apply embeddings and positional encoding
42
46
  x = self.embedding(x)
@@ -119,6 +123,9 @@ class ClassicTransformerBase(nn.Module):
119
123
  self.layers = layers
120
124
  self.num_layers = len(layers) if layers else 0
121
125
 
126
+ def moe_router_loss_(self):
127
+ return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_layers) if self.layers[i].use_moe]).mean()
128
+
122
129
  def forward(self, x: torch.Tensor) -> torch.Tensor:
123
130
  # Shared logic for encoders and decoders - apply embeddings and positional encoding
124
131
  x = self.embedding(x)
@@ -11,7 +11,8 @@ class MoeRouter(nn.Module):
11
11
  self.top_k = top_k
12
12
  self.num_experts = num_experts
13
13
  self.gate = nn.Linear(embed_dim, num_experts, bias=False)
14
- self.aux_loss = 0.0 # For expert load balancing
14
+ # For expert load balancing
15
+ self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
15
16
 
16
17
  def forward(self, x: torch.Tensor):
17
18
  # x shape: [batch_size*seq_len, embed_dim]
@@ -19,10 +20,8 @@ class MoeRouter(nn.Module):
19
20
  probs = F.softmax(logits, dim=-1)
20
21
 
21
22
  # Expert load balancing loss
22
- if self.training:
23
- probs_for_bal = F.softmax(logits, dim=0)
24
- self.aux_loss = (probs_for_bal.mean(dim=0) *
25
- torch.log(probs_for_bal.mean(dim=0) + 1e-9)).sum()
23
+ mean_probs = probs.mean(dim=0) # Mean probability per expert across batch
24
+ self.aux_loss = (mean_probs * torch.log(mean_probs + 1e-9)).sum() # Entropy-based loss
26
25
 
27
26
  top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
28
27
  top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
@@ -74,6 +73,9 @@ class MoeFeedForward(nn.Module):
74
73
  def _activate(self, h: torch.Tensor):
75
74
  return self.activation(h)
76
75
 
76
+ def router_loss(self):
77
+ return self.router.aux_loss
78
+
77
79
  def forward(self, x: torch.Tensor):
78
80
  orig_shape = x.shape
79
81
  x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes