rxnn 0.2.38__tar.gz → 0.2.40__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.38 → rxnn-0.2.40}/PKG-INFO +1 -1
  2. {rxnn-0.2.38 → rxnn-0.2.40}/pyproject.toml +1 -1
  3. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/rxt/models.py +11 -5
  4. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/training/bml.py +0 -8
  5. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/training/models.py +27 -9
  6. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/training/mrl.py +44 -10
  7. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/transformers/models.py +7 -2
  8. {rxnn-0.2.38 → rxnn-0.2.40}/LICENSE +0 -0
  9. {rxnn-0.2.38 → rxnn-0.2.40}/README.md +0 -0
  10. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/.DS_Store +0 -0
  11. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/__init__.py +0 -0
  12. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/experimental/__init__.py +0 -0
  13. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/experimental/attention.py +0 -0
  14. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/experimental/models.py +0 -0
  15. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/experimental/moe.py +0 -0
  16. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/memory/__init__.py +0 -0
  17. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/memory/attention.py +0 -0
  18. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/memory/norm.py +0 -0
  19. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/memory/stm.py +0 -0
  20. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/rxt/__init__.py +0 -0
  21. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/training/__init__.py +0 -0
  22. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/training/base.py +0 -0
  23. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/training/callbacks.py +0 -0
  24. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/training/dataset.py +0 -0
  25. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/training/ddp.py +0 -0
  26. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/training/reward.py +0 -0
  27. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/training/rl.py +0 -0
  28. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/training/scheduler.py +0 -0
  29. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/training/tokenizer.py +0 -0
  30. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/training/utils.py +0 -0
  31. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/transformers/__init__.py +0 -0
  32. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/transformers/attention.py +0 -0
  33. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/transformers/ff.py +0 -0
  34. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/transformers/layers.py +0 -0
  35. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/transformers/mask.py +0 -0
  36. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.38 → rxnn-0.2.40}/src/rxnn/utils.py +0 -0
--- rxnn-0.2.38/PKG-INFO
+++ rxnn-0.2.40/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.38
+Version: 0.2.40
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
--- rxnn-0.2.38/pyproject.toml
+++ rxnn-0.2.40/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.38"
+version = "0.2.40"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
--- rxnn-0.2.38/src/rxnn/rxt/models.py
+++ rxnn-0.2.40/src/rxnn/rxt/models.py
@@ -130,10 +130,10 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
                 memory_cross_attention=cross_att_init(),
             ) for _ in range(num_layers)
         ])
-        self.model = self._init_model(stm, layers, embedding, use_flash_attention, embed_dim, vocab_size)
+        self.model = self._init_model(stm, layers, embedding, use_flash_attention, embed_dim, vocab_size, use_moe)
 
     def _init_model(self, stm: ShortTermMemory, layers: nn.ModuleList, embedding: nn.Embedding,
-                    use_flash_attention: bool, embed_dim: int, vocab_size: int) -> ReactiveTransformerBase:
+                    use_flash_attention: bool, embed_dim: int, vocab_size: int, use_moe: bool) -> ReactiveTransformerBase:
        pass
 
     def params_count(self):
@@ -185,13 +185,15 @@ class RxTAlphaEncoder(RxTAlphaComponentBase, pipeline_tag="fill-mask", license="
             embedding: nn.Embedding,
             use_flash_attention: bool,
             embed_dim: int,
-            vocab_size: int
+            vocab_size: int,
+            use_moe: bool,
     ) -> ReactiveTransformerEncoder:
         return ReactiveTransformerEncoder(
             stm=stm,
             embedding=embedding,
             own_layers=layers,
             use_flash_attention=use_flash_attention,
+            use_moe=use_moe,
         )
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
@@ -210,7 +212,8 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
             embedding: nn.Embedding,
             use_flash_attention: bool,
             embed_dim: int,
-            vocab_size: int
+            vocab_size: int,
+            use_moe: bool,
     ) -> ReactiveTransformerDecoder:
         return ReactiveTransformerDecoder(
             embed_dim,
@@ -219,6 +222,7 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
             embedding=embedding,
             own_layers=layers,
             use_flash_attention=use_flash_attention,
+            use_moe=use_moe,
         )
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
@@ -307,13 +311,15 @@ class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classifica
             embedding: nn.Embedding,
             use_flash_attention: bool,
             embed_dim: int,
-            vocab_size: int
+            vocab_size: int,
+            use_moe: bool = False,
     ) -> ReactiveTransformerEncoderDetachStm:
         return ReactiveTransformerEncoderDetachStm(
             stm=stm,
             embedding=embedding,
             own_layers=layers,
             use_flash_attention=use_flash_attention,
+            use_moe=use_moe,
         )
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
--- rxnn-0.2.38/src/rxnn/training/bml.py
+++ rxnn-0.2.40/src/rxnn/training/bml.py
@@ -51,10 +51,6 @@ class MLMTrainer(BaseTrainer):
         model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
 
         router_loss = model.encoder.model.moe_router_loss()
-
-        if self.use_ddp:
-            router_loss = distributed_mean(router_loss)
-
         loss = main_loss + self.moe_aux_loss_scale * router_loss
 
         if self.writer is not None:
@@ -156,10 +152,6 @@ class AutoregressiveTrainer(BaseTrainer):
         model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
 
         router_loss = model.model.moe_router_loss()
-
-        if self.use_ddp:
-            router_loss = distributed_mean(router_loss)
-
         loss = main_loss + self.moe_aux_loss_scale * router_loss
 
         if self.writer is not None:
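
Both trainers combine the losses the same way; a minimal, self-contained sketch of that scaled addition (toy tensor values, hypothetical scale):

```python
import torch

# Toy values standing in for the trainer's task loss and the MoE router loss.
main_loss = torch.tensor(2.31, requires_grad=True)
router_loss = torch.tensor(0.12)
moe_aux_loss_scale = 0.01  # hypothetical value; MRLTrainer below uses 0.01 as its default

# Combined objective, as in `loss = main_loss + self.moe_aux_loss_scale * router_loss`.
loss = main_loss + moe_aux_loss_scale * router_loss
loss.backward()
```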
--- rxnn-0.2.38/src/rxnn/training/models.py
+++ rxnn-0.2.40/src/rxnn/training/models.py
@@ -82,23 +82,31 @@ class MrlActorModel(nn.Module):
 
     def freeze_components(self, stage: Literal['update', 'fetch', 'joint'] = 'joint'):
         """Freeze encoder/decoder except memory-related layers."""
+        # Freeze/unfreeze encoder
         if self.encoder.freeze_without_memory is not None:
-            self.encoder.freeze_without_memory(unfreeze_norms=True)
-            if stage == 'update':
+            if stage == 'update' or stage == 'joint':
+                self.encoder.unfreeze_all()
+            else:
+                self.encoder.freeze_without_memory(unfreeze_norms=True)
                 self.encoder.freeze_memory(with_norms=True)
         else:
             for param in self.encoder.parameters():
-                param.requires_grad = False
-            self.encoder.model.trainable_cross_attention_(True if stage != 'update' else False, with_norms=True)
+                param.requires_grad = True if stage != 'fetch' else False
+            self.encoder.model.trainable_cross_attention_(True if stage != 'fetch' else False, with_norms=True)
+        # Freeze/unfreeze decoder
         if self.decoder.freeze_without_memory is not None:
-            self.decoder.freeze_without_memory(unfreeze_norms=True)
-            if stage == 'update':
-                self.decoder.freeze_memory(with_norms=True)
+            if stage == 'fetch':
+                self.decoder.unfreeze_all()
+            else:
+                self.decoder.freeze_without_memory(unfreeze_norms=True)
+                if stage == 'update':
+                    self.decoder.freeze_memory(with_norms=True)
         else:
             for param in self.decoder.parameters():
-                param.requires_grad = False
+                param.requires_grad = True if stage == 'fetch' else False
             self.decoder.model.trainable_cross_attention_(True if stage != 'update' else False, with_norms=True)
-        # Unfreeze memory attention
+
+        # Freeze/unfreeze memory attention
         if self.memory_attention.freeze is not None:
             if stage == 'fetch':
                 self.memory_attention.freeze()
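
Read together, the branches give each MRL stage its own trainable subset of the actor. A rough summary plus a hypothetical usage sketch, assuming the freeze helpers (`unfreeze_all`, `freeze_without_memory`, `freeze_memory`) behave as their names suggest:

```python
# Approximate effect of MrlActorModel.freeze_components(stage), as read from the hunk above:
#   'update': encoder fully trainable, decoder frozen, memory attention trainable
#   'fetch':  decoder fully trainable, encoder frozen, memory attention frozen
#   'joint':  encoder fully trainable, decoder's memory cross-attention trainable,
#             memory attention trainable
#
# Hypothetical usage (actor is an MrlActorModel instance):
actor.freeze_components(stage='update')  # focus training on the memory-update path
actor.freeze_components(stage='fetch')   # focus training on the memory-fetch / decoding path
actor.freeze_components(stage='joint')   # train both paths together
```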
@@ -158,6 +166,16 @@ class MrlActorModel(nn.Module):
             list(self.memory_attention.parameters())
         ))
 
+    def moe_router_loss(self):
+        if self.encoder.model.use_moe and self.decoder.model.use_moe:
+            return (self.encoder.model.moe_router_loss() + self.decoder.model.moe_router_loss()) / 2
+        elif self.encoder.model.use_moe:
+            return self.encoder.model.moe_router_loss()
+        elif self.decoder.model.use_moe:
+            return self.decoder.model.moe_router_loss()
+        else:
+            return None
+
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None,
                 action: MrlActorAction = MrlActorAction.DECODE) -> torch.Tensor:
         if action == MrlActorAction.DECODE:
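
When both sub-models use MoE, the router losses are simply averaged; a standalone illustration with toy tensors (not the actual classes):

```python
import torch

# Toy values standing in for encoder.model.moe_router_loss() and decoder.model.moe_router_loss().
enc_router_loss = torch.tensor(0.02)
dec_router_loss = torch.tensor(0.04)

# Both sub-models use MoE -> mean of the two losses, as in MrlActorModel.moe_router_loss().
combined = (enc_router_loss + dec_router_loss) / 2
print(combined)  # tensor(0.0300)
```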
--- rxnn-0.2.38/src/rxnn/training/mrl.py
+++ rxnn-0.2.40/src/rxnn/training/mrl.py
@@ -31,6 +31,8 @@ class MrlConfig(TypedDict):
     end_token_id: int
     callbacks: Optional[list[MrlTrainerCallback]]
     memory_aware_critic: bool
+    use_moe_aux_loss: bool
+    moe_aux_loss_scale: float
 
 
 class MrlStrategy(Enum):
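
Both keys are read with `config.get(...)` defaults in the trainer (see the next hunk), so existing configs keep working. A partial, hypothetical config dict showing only the MoE-related fields; a real `MrlConfig` carries many more keys than shown here:

```python
# Partial sketch of an MRL config; only the fields relevant to the MoE auxiliary loss are shown.
mrl_config = {
    # ... remaining MrlConfig fields omitted ...
    'memory_aware_critic': False,
    'use_moe_aux_loss': True,    # opt in to adding the router loss during RL updates
    'moe_aux_loss_scale': 0.01,  # matches the default used by MRLTrainer
}
```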
@@ -125,6 +127,8 @@ class MRLTrainer:
         self.max_seq_len = config.get('max_seq_len', 256)
         self.critic_max_len = config.get('critic_max_len', 512)
         self.memory_aware_critic = config.get('memory_aware_critic', False)
+        self.use_moe_aux_loss = config.get('use_moe_aux_loss', False)
+        self.moe_aux_loss_scale = config.get('moe_aux_loss_scale', 0.01)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -212,6 +216,7 @@ class MRLTrainer:
     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
         if memory_lr is not None:
             optimizer = torch.optim.AdamW([
+                {'params': self.actor.encoder.embedding.parameters(), 'lr': lr},
                 {'params': self.actor.not_memory_parameters(), 'lr': lr},
                 {'params': self.actor.memory_parameters(), 'lr': memory_lr},
             ],
@@ -522,6 +527,18 @@ class MRLTrainer:
         # 6. Return loss item
         return critic_loss_item
 
+    def _moe_aux_loss(self, main_loss: torch.Tensor) -> torch.Tensor:
+        if not self.use_moe_aux_loss:
+            return main_loss
+
+        actor = next(self.actor.children()) if isinstance(self.actor, DistributedDataParallel) else self.actor
+
+        router_loss = actor.moe_router_loss()
+        if router_loss is not None:
+            return main_loss + self.moe_aux_loss_scale * router_loss
+        else:
+            return main_loss
+
     def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
         # 1. Reset actor gradients
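
`_moe_aux_loss` first unwraps the actor from its `DistributedDataParallel` wrapper so the custom `moe_router_loss()` method is reachable. A minimal sketch of that unwrapping idiom, with a hypothetical helper name:

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

def unwrap_ddp(module: nn.Module) -> nn.Module:
    """Return the underlying model when `module` is a DDP wrapper.

    DDP stores the wrapped model as its `.module` attribute, which in practice is also
    its only immediate child, so `module.module` and `next(module.children())` agree.
    """
    return module.module if isinstance(module, DistributedDataParallel) else module
```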
@@ -544,6 +561,8 @@ class MRLTrainer:
                 # 4.2 Calculate policy loss with selected algorithm
                 policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs,
                                                             advantages)
+                policy_loss = self._moe_aux_loss(policy_loss)
+
                 # 4.3 Run backpropagation with scaler
                 self.scaler.scale(policy_loss).backward(retain_graph=True)
                 # 4.4 Unscale and clip gradient norms
561
580
  action=MrlActorAction.DECODE)
562
581
  # 4.2 Calculate policy loss with selected algorithm
563
582
  policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs, advantages)
583
+ policy_loss = self._moe_aux_loss(policy_loss)
564
584
  # 4.3 Run backpropagation
565
585
  policy_loss.backward(retain_graph=True)
566
586
  # 4.4 Clip gradient norms
@@ -852,7 +872,7 @@ class MRLTrainer:
             if isinstance(update_epoch, tuple):
                 switch_epoch, cross_att_lr = update_epoch
                 if epoch == switch_epoch:
-                    self.actor.freeze_components('joint')
+                    self.actor.unfreeze_components()
                     self.optimizer = self._init_unfreeze_optimizer('update', cross_att_lr)
                     print(f"Activating 'update' unfreeze strategy with custom cross_att_lr: {cross_att_lr}")
             elif epoch == update_epoch:
@@ -863,7 +883,7 @@ class MRLTrainer:
             if isinstance(fetch_epoch, tuple):
                 switch_epoch, mem_att_lr = fetch_epoch
                 if epoch == switch_epoch:
-                    self.actor.freeze_components('joint')
+                    self.actor.unfreeze_components()
                     self.optimizer = self._init_unfreeze_optimizer('fetch', mem_att_lr)
                     print(f"Activating 'fetch' unfreeze strategy with custom mem_att_lr: {mem_att_lr}")
             elif epoch == fetch_epoch:
@@ -899,25 +919,39 @@ class MRLTrainer:
 
         if mode == 'update':
             params = [
-                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.encoder.embedding.parameters(), 'lr': model_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
-                {'params': self.actor.memory_cross_attention_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.decoder.memory_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.decoder.not_memory_parameters(), 'lr': unfreeze_lr},
             ]
         elif mode == 'fetch':
             params = [
-                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
-                {'params': self.actor.memory_cross_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.encoder.embedding.parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.decoder.not_memory_parameters(), 'lr': model_lr},
             ]
         elif mode == 'joint':
             params = [
-                {'params': self.actor.not_memory_parameters(), 'lr': unfreeze_lr},
-                {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.encoder.embedding.parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.decoder.not_memory_parameters(), 'lr': unfreeze_lr},
             ]
         else:
             params = [
-                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
-                {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.encoder.embedding.parameters(), 'lr': model_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.decoder.not_memory_parameters(), 'lr': model_lr},
             ]
 
         return torch.optim.AdamW(params, weight_decay=self.optim_config['weight_decay'])
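
The unfreeze optimizer leans on AdamW parameter groups, where each group carries its own learning rate while sharing the global weight decay. A generic toy illustration of the mechanism (not the actual model parameters):

```python
import torch
import torch.nn as nn

# Two toy modules standing in for e.g. memory-related vs. newly unfrozen parameter sets.
memory_like = nn.Linear(8, 8)
unfrozen_like = nn.Linear(8, 8)

optimizer = torch.optim.AdamW(
    [
        {'params': memory_like.parameters(), 'lr': 1e-4},    # faster learning rate for one group
        {'params': unfrozen_like.parameters(), 'lr': 1e-6},  # slower learning rate for another
    ],
    weight_decay=0.01,  # applies to all groups unless overridden per group
)
```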
--- rxnn-0.2.38/src/rxnn/transformers/models.py
+++ rxnn-0.2.40/src/rxnn/transformers/models.py
@@ -17,6 +17,7 @@ class ReactiveTransformerBase(nn.Module):
             absolute_embedding: AbsolutePositionalEmbedding = None,
             use_flash_attention: bool = False,
             use_relative_embedding: bool = False,
+            use_moe: bool = False,
             *args,
             **kwargs,
     ):
@@ -32,6 +33,7 @@ class ReactiveTransformerBase(nn.Module):
         self.layers = own_layers
         self.num_shared_layers = len(shared_layers) if shared_layers else 0
         self.num_own_layers = len(own_layers) if own_layers else 0
+        self.use_moe = use_moe
 
     def trainable_cross_attention_(self, is_trainable: bool, with_norms: bool = True):
         for i in range(self.num_shared_layers):
@@ -50,8 +52,11 @@ class ReactiveTransformerBase(nn.Module):
         return own + shared
 
     def moe_router_loss(self):
-        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att] + [
-            self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe or self.shared_layers[i].use_moe_att]).mean()
+        if self.use_moe:
+            return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att] + [
+                self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe or self.shared_layers[i].use_moe_att]).mean()
+        else:
+            return None
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Shared logic for encoders and decoders - apply embeddings and positional encoding
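
Since dense (non-MoE) models now return `None` here, callers are expected to guard against it, as `MRLTrainer._moe_aux_loss` does above. A self-contained caller-side sketch with a hypothetical helper:

```python
from typing import Optional
import torch

def total_loss(main_loss: torch.Tensor, router_loss: Optional[torch.Tensor],
               aux_scale: float = 0.01) -> torch.Tensor:
    """Hypothetical helper: add the router loss only when the model returned one."""
    return main_loss if router_loss is None else main_loss + aux_scale * router_loss

# e.g. total_loss(torch.tensor(2.0), None) -> tensor(2.)
```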