rxnn 0.2.39__py3-none-any.whl → 0.2.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/attention.py CHANGED
@@ -8,6 +8,9 @@ class StmMemoryAttention(nn.Module):
             stm: ShortTermMemory,
             attention_layers: nn.ModuleList,
             memory_norm_layers: nn.ModuleList,
+            use_gated_residual: bool = False,
+            per_slot_gate: bool = False,
+            init_gate: float = 0.0,
             *args,
             **kwargs
     ):
@@ -17,6 +20,10 @@ class StmMemoryAttention(nn.Module):
         self.memory_norm_layers = memory_norm_layers
         assert len(self.attention_layers) == len(self.memory_norm_layers) == self.stm.memory.size(0)
         self.num_layers = len(attention_layers)
+        self.use_gated_residual = use_gated_residual
+        self.per_slot_gate = per_slot_gate
+        if self.use_gated_residual:
+            self.gate = nn.Parameter(torch.full((self.num_layers, self.stm.stm_size, 1), init_gate) if self.per_slot_gate else torch.full((self.num_layers,), init_gate))
 
     def update_max_len(self, max_seq_len: int):
         for i in range(self.num_layers):
@@ -35,7 +42,12 @@ class StmMemoryAttention(nn.Module):
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
-            new_stm[i] = new_layer_stm + layer_stm # residual
+            if self.use_gated_residual:
+                # gated residual
+                layer_gate = torch.sigmoid(self.gate[i])
+                new_stm[i] = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
+            else:
+                new_stm[i] = new_layer_stm + layer_stm # residual
         self.stm.update_all(new_stm)
         return self.stm.memory
 
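The change above replaces the unconditional residual `new_stm[i] = new_layer_stm + layer_stm` with an optional learned interpolation: a sigmoid-squashed gate, one per layer (or one per memory slot when `per_slot_gate=True`), decides how much of the attention output to blend into the previous short-term memory state. A minimal standalone sketch of that update rule; tensor shapes here are illustrative, only the formula is taken from the diff:

```python
import torch
import torch.nn as nn

num_layers, stm_size, embed_dim = 4, 16, 32
init_gate = 0.0  # sigmoid(0.0) = 0.5, an even mix at initialization

# Per-layer scalar gate; with per_slot_gate=True the parameter would instead
# have shape (num_layers, stm_size, 1), one gate per memory slot.
gate = nn.Parameter(torch.full((num_layers,), init_gate))

layer_stm = torch.randn(stm_size, embed_dim)      # previous STM state of layer i
new_layer_stm = torch.randn(stm_size, embed_dim)  # memory-attention output of layer i

i = 0
layer_gate = torch.sigmoid(gate[i])
# Convex combination: gate -> 1 keeps only the new state, gate -> 0 keeps the old.
new_stm_i = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
assert new_stm_i.shape == layer_stm.shape
```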
rxnn/rxt/models.py CHANGED
@@ -130,10 +130,10 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
             memory_cross_attention=cross_att_init(),
         ) for _ in range(num_layers)
         ])
-        self.model = self._init_model(stm, layers, embedding, use_flash_attention, embed_dim, vocab_size)
+        self.model = self._init_model(stm, layers, embedding, use_flash_attention, embed_dim, vocab_size, use_moe)
 
     def _init_model(self, stm: ShortTermMemory, layers: nn.ModuleList, embedding: nn.Embedding,
-                    use_flash_attention: bool, embed_dim: int, vocab_size: int) -> ReactiveTransformerBase:
+                    use_flash_attention: bool, embed_dim: int, vocab_size: int, use_moe: bool) -> ReactiveTransformerBase:
         pass
 
     def params_count(self):
@@ -185,13 +185,15 @@ class RxTAlphaEncoder(RxTAlphaComponentBase, pipeline_tag="fill-mask", license="
             embedding: nn.Embedding,
             use_flash_attention: bool,
             embed_dim: int,
-            vocab_size: int
+            vocab_size: int,
+            use_moe: bool,
     ) -> ReactiveTransformerEncoder:
         return ReactiveTransformerEncoder(
             stm=stm,
             embedding=embedding,
             own_layers=layers,
             use_flash_attention=use_flash_attention,
+            use_moe=use_moe,
         )
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
@@ -210,7 +212,8 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
             embedding: nn.Embedding,
             use_flash_attention: bool,
             embed_dim: int,
-            vocab_size: int
+            vocab_size: int,
+            use_moe: bool,
     ) -> ReactiveTransformerDecoder:
         return ReactiveTransformerDecoder(
             embed_dim,
@@ -219,6 +222,7 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
             embedding=embedding,
             own_layers=layers,
             use_flash_attention=use_flash_attention,
+            use_moe=use_moe,
         )
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
@@ -246,6 +250,9 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             norm_init_gate: float = -2.0,
             norm_per_dim_scale: bool = False,
             norm_decay: float = 0.9,
+            use_gated_residual: bool = False,
+            residual_per_slot_gate: bool = False,
+            residual_init_gate: float = 0.0,
             **kwargs,
     ):
         super(RxTAlphaMemoryAttention, self).__init__(**kwargs)
@@ -272,7 +279,10 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
                                               init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
                                           for _ in range(num_layers)])
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
-        self.model = StmMemoryAttention(stm, attention_layers, memory_norm_layers)
+        self.model = StmMemoryAttention(
+            stm, attention_layers, memory_norm_layers,
+            use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate, init_gate=residual_init_gate
+        )
 
     def freeze(self):
         for param in self.parameters():
@@ -307,13 +317,15 @@ class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classifica
             embedding: nn.Embedding,
             use_flash_attention: bool,
             embed_dim: int,
-            vocab_size: int
+            vocab_size: int,
+            use_moe: bool = False,
     ) -> ReactiveTransformerEncoderDetachStm:
         return ReactiveTransformerEncoderDetachStm(
             stm=stm,
             embedding=embedding,
             own_layers=layers,
             use_flash_attention=use_flash_attention,
+            use_moe=use_moe,
         )
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
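Two things are threaded through this file: the gated-residual options are forwarded from `RxTAlphaMemoryAttention` into `StmMemoryAttention`, and a `use_moe` flag now travels from the component wrappers through the abstract `_init_model` hook into the `ReactiveTransformer*` constructors. A toy mirror of that hook pattern (class and attribute names below are illustrative, not from the package):

```python
import torch.nn as nn

class ComponentBase(nn.Module):
    """Template-method pattern as in RxTAlphaComponentBase: the base class
    builds the model through a hook, so a new flag must be added to the
    hook's signature in every subclass."""
    def __init__(self, use_moe: bool = False):
        super().__init__()
        # The flag is not consumed here; it is forwarded to the subclass
        # hook, just like _init_model(..., use_moe) in the hunks above.
        self.model = self._init_model(use_moe)

    def _init_model(self, use_moe: bool) -> nn.Module:
        raise NotImplementedError

class EncoderComponent(ComponentBase):
    def _init_model(self, use_moe: bool) -> nn.Module:
        # A real subclass would pass use_moe into ReactiveTransformerEncoder;
        # a stand-in module records it here.
        body = nn.Linear(8, 8)
        body.use_moe = use_moe
        return body

enc = EncoderComponent(use_moe=True)
print(enc.model.use_moe)  # True
```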
rxnn/training/models.py CHANGED
@@ -80,25 +80,33 @@ class MrlActorModel(nn.Module):
         self.decoder = decoder
         self.memory_attention = memory_attention
 
-    def freeze_components(self, stage: Literal['update', 'fetch', 'joint'] = 'joint'):
+    def freeze_components(self, stage: Literal['update', 'fetch', 'joint'] = 'joint', freeze_embeddings: bool = False):
         """Freeze encoder/decoder except memory-related layers."""
+        # Freeze/unfreeze encoder
         if self.encoder.freeze_without_memory is not None:
-            self.encoder.freeze_without_memory(unfreeze_norms=True)
-            if stage == 'update':
+            if stage == 'update' or stage == 'joint':
+                self.encoder.unfreeze_all()
+            else:
+                self.encoder.freeze_without_memory(unfreeze_norms=True)
                 self.encoder.freeze_memory(with_norms=True)
         else:
             for param in self.encoder.parameters():
-                param.requires_grad = False
-            self.encoder.model.trainable_cross_attention_(True if stage != 'update' else False, with_norms=True)
+                param.requires_grad = True if stage != 'fetch' else False
+            self.encoder.model.trainable_cross_attention_(True if stage != 'fetch' else False, with_norms=True)
+        # Freeze/unfreeze decoder
         if self.decoder.freeze_without_memory is not None:
-            self.decoder.freeze_without_memory(unfreeze_norms=True)
-            if stage == 'update':
-                self.decoder.freeze_memory(with_norms=True)
+            if stage == 'fetch':
+                self.decoder.unfreeze_all()
+            else:
+                self.decoder.freeze_without_memory(unfreeze_norms=True)
+                if stage == 'update':
+                    self.decoder.freeze_memory(with_norms=True)
         else:
             for param in self.decoder.parameters():
-                param.requires_grad = False
+                param.requires_grad = True if stage == 'fetch' else False
             self.decoder.model.trainable_cross_attention_(True if stage != 'update' else False, with_norms=True)
-        # Unfreeze memory attention
+
+        # Freeze/unfreeze memory attention
         if self.memory_attention.freeze is not None:
             if stage == 'fetch':
                 self.memory_attention.freeze()
@@ -108,7 +116,11 @@ class MrlActorModel(nn.Module):
             for param in self.memory_attention.parameters():
                 param.requires_grad = True if stage != 'fetch' else False
 
-    def unfreeze_components(self):
+        if freeze_embeddings:
+            for param in self.encoder.model.embedding.parameters():
+                param.requires_grad = False
+
+    def unfreeze_components(self, freeze_embeddings: bool = False):
         """Unfreeze all components after initial training."""
         if self.encoder.unfreeze_all is not None:
             self.encoder.unfreeze_all()
@@ -126,6 +138,11 @@ class MrlActorModel(nn.Module):
             for param in self.memory_attention.parameters():
                 param.requires_grad = True
 
+        if freeze_embeddings:
+            for param in self.encoder.model.embedding.parameters():
+                param.requires_grad = False
+
+
     def reset_memory(self):
         self.memory_attention.reset_memory()
 
@@ -151,12 +168,29 @@ class MrlActorModel(nn.Module):
             self.decoder.not_memory_parameters()
         ))
 
-    def unique_parameters(self):
-        return list(set(
-            list(self.encoder.parameters()) +
-            list(self.decoder.parameters()) +
-            list(self.memory_attention.parameters())
-        ))
+    def unique_parameters(self, with_embedding: bool = True):
+        if with_embedding:
+            return list(set(
+                list(self.encoder.parameters()) +
+                list(self.decoder.parameters()) +
+                list(self.memory_attention.parameters())
+            ))
+        else:
+            return list(set(
+                self.not_memory_parameters() +
+                self.memory_cross_attention_parameters() +
+                list(self.memory_attention_parameters())
+            ))
+
+    def moe_router_loss(self):
+        if self.encoder.model.use_moe and self.decoder.model.use_moe:
+            return (self.encoder.model.moe_router_loss() + self.decoder.model.moe_router_loss()) / 2
+        elif self.encoder.model.use_moe:
+            return self.encoder.model.moe_router_loss()
+        elif self.decoder.model.use_moe:
+            return self.decoder.model.moe_router_loss()
+        else:
+            return None
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None,
                 action: MrlActorAction = MrlActorAction.DECODE) -> torch.Tensor:
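The reworked `freeze_components` no longer freezes encoder and decoder symmetrically, and both it and `unfreeze_components` can now pin the embedding table afterwards via `freeze_embeddings`. As the new control flow reads, each MRL stage trains a different subset; a compact summary as a sketch (an interpretation of the code above, simplified: norm handling and the fallback branches are omitted):

```python
from typing import Literal

Stage = Literal['update', 'fetch', 'joint']

def trainable_parts(stage: Stage) -> dict[str, str]:
    # 'update': encoder fully trains, decoder is frozen, memory attention trains.
    if stage == 'update':
        return {'encoder': 'all', 'decoder': 'none', 'memory_attention': 'all'}
    # 'fetch': decoder learns to read memory; encoder and memory attention are fixed.
    if stage == 'fetch':
        return {'encoder': 'none', 'decoder': 'all', 'memory_attention': 'none'}
    # 'joint': decoder trains only its memory cross-attention.
    return {'encoder': 'all', 'decoder': 'memory cross-attention', 'memory_attention': 'all'}

for s in ('update', 'fetch', 'joint'):
    print(s, trainable_parts(s))
```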
rxnn/training/mrl.py CHANGED
@@ -21,16 +21,20 @@ class MrlConfig(TypedDict):
     separate_memory_lr: Optional[bool]
     memory_lr: Optional[float]
     critic_lr: float
-    critic_encoder_lr: float
+    critic_encoder_lr: Optional[float]
     max_seq_len: int
     critic_max_len: int
-    weight_decay: float
-    critic_weight_decay: float
+    weight_decay: Optional[float]
+    critic_weight_decay: Optional[float]
     update_epochs: int
     pad_token_id: int
     end_token_id: int
     callbacks: Optional[list[MrlTrainerCallback]]
-    memory_aware_critic: bool
+    memory_aware_critic: Optional[bool]
+    use_moe_aux_loss: Optional[bool]
+    moe_aux_loss_scale: Optional[float]
+    freeze_embeddings: Optional[bool]
+    embedding_lr: Optional[float]
 
 
 class MrlStrategy(Enum):
@@ -64,6 +68,8 @@ class CurriculumConfig(TypedDict):
     weight_decay: Optional[float]
     critic_weight_decay: Optional[float]
     update_epochs: Optional[int]
+    freeze_embeddings: Optional[bool]
+    embedding_lr: Optional[float]
 
 
 class SamplerConfig(TypedDict):
@@ -125,6 +131,10 @@ class MRLTrainer:
         self.max_seq_len = config.get('max_seq_len', 256)
         self.critic_max_len = config.get('critic_max_len', 512)
         self.memory_aware_critic = config.get('memory_aware_critic', False)
+        self.use_moe_aux_loss = config.get('use_moe_aux_loss', False)
+        self.moe_aux_loss_scale = config.get('moe_aux_loss_scale', 0.01)
+        self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
+        self.freeze_embeddings = self.shared_freeze_embeddings
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -162,6 +172,7 @@ class MRLTrainer:
                 'weight_decay': config.get('weight_decay', 0.01),
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
                 'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
+                'embedding_lr': config.get('embedding_lr', config.get('lr', 3e-4)),
             }
         else:
             self.base_optim_config = {
@@ -170,6 +181,7 @@ class MRLTrainer:
                 'weight_decay': config.get('weight_decay', 0.01),
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
                 'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
+                'embedding_lr': config.get('embedding_lr', config.get('lr', 3e-4)),
             }
 
         self.optim_config = self.base_optim_config
@@ -208,19 +220,22 @@ class MRLTrainer:
             weight_decay: float,
             critic_weight_decay: float,
             critic_encoder_lr: float,
+            embedding_lr: float,
             memory_lr: Optional[float] = None,
     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
         if memory_lr is not None:
             optimizer = torch.optim.AdamW([
+                {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
                 {'params': self.actor.not_memory_parameters(), 'lr': lr},
                 {'params': self.actor.memory_parameters(), 'lr': memory_lr},
             ],
                 weight_decay=weight_decay,
             )
         else:
-            optimizer = torch.optim.AdamW(
-                self.actor.unique_parameters(),
-                lr=lr,
+            optimizer = torch.optim.AdamW([
+                {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
+                {'params': self.actor.unique_parameters(with_embedding=False), 'lr': lr},
+            ],
                 weight_decay=weight_decay,
             )
 
@@ -522,6 +537,18 @@ class MRLTrainer:
         # 6. Return loss item
         return critic_loss_item
 
+    def _moe_aux_loss(self, main_loss: torch.Tensor) -> torch.Tensor:
+        if not self.use_moe_aux_loss:
+            return main_loss
+
+        actor = next(self.actor.children()) if isinstance(self.actor, DistributedDataParallel) else self.actor
+
+        router_loss = actor.moe_router_loss()
+        if router_loss is not None:
+            return main_loss + self.moe_aux_loss_scale * router_loss
+        else:
+            return main_loss
+
     def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
         # 1. Reset actor gradients
@@ -544,6 +571,8 @@ class MRLTrainer:
                 # 4.2 Calculate policy loss with selected algorithm
                 policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs,
                                                             advantages)
+                policy_loss = self._moe_aux_loss(policy_loss)
+
                 # 4.3 Run backpropagation with scaler
                 self.scaler.scale(policy_loss).backward(retain_graph=True)
                 # 4.4 Unscale and clip gradient norms
@@ -561,6 +590,7 @@ class MRLTrainer:
                                     action=MrlActorAction.DECODE)
             # 4.2 Calculate policy loss with selected algorithm
             policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs, advantages)
+            policy_loss = self._moe_aux_loss(policy_loss)
             # 4.3 Run backpropagation
             policy_loss.backward(retain_graph=True)
             # 4.4 Clip gradient norms
@@ -852,41 +882,41 @@ class MRLTrainer:
         if isinstance(update_epoch, tuple):
             switch_epoch, cross_att_lr = update_epoch
             if epoch == switch_epoch:
-                self.actor.freeze_components('joint')
+                self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
                 self.optimizer = self._init_unfreeze_optimizer('update', cross_att_lr)
                 print(f"Activating 'update' unfreeze strategy with custom cross_att_lr: {cross_att_lr}")
         elif epoch == update_epoch:
-            self.actor.freeze_components('update')
+            self.actor.freeze_components('update', freeze_embeddings=self.freeze_embeddings)
             print(
                 f"Activating 'update' unfreeze strategy - mem-att trainable / cross-att frozen / rest model frozen")
 
         if isinstance(fetch_epoch, tuple):
             switch_epoch, mem_att_lr = fetch_epoch
             if epoch == switch_epoch:
-                self.actor.freeze_components('joint')
+                self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
                 self.optimizer = self._init_unfreeze_optimizer('fetch', mem_att_lr)
                 print(f"Activating 'fetch' unfreeze strategy with custom mem_att_lr: {mem_att_lr}")
         elif epoch == fetch_epoch:
-            self.actor.freeze_components('fetch')
+            self.actor.freeze_components('fetch', freeze_embeddings=self.freeze_embeddings)
             print(
                 f"Activating 'fetch' unfreeze strategy - mem-att frozen / cross-att trainable / rest model frozen")
 
         if isinstance(joint_epoch, tuple):
             switch_epoch, model_lr = joint_epoch
             if epoch == switch_epoch:
-                self.actor.unfreeze_components()
+                self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
                 self.optimizer = self._init_unfreeze_optimizer('joint', model_lr)
                 print(f"Activating 'joint' unfreeze strategy with custom model_lr: {model_lr}")
         elif epoch == joint_epoch:
-            self.actor.freeze_components('joint')
+            self.actor.freeze_components('joint', freeze_embeddings=self.freeze_embeddings)
             print(f"Activating 'joint' unfreeze strategy - mem-att/cross-att trainable / rest model frozen")
 
         if epoch == all_epoch:
-            self.actor.unfreeze_components()
+            self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
             self.optimizer = self._init_unfreeze_optimizer('all', 0.)
             print(f"Switching to train 'all' strategy - unfreeze all components")
         elif epoch == unfreeze_epoch:
-            self.actor.unfreeze_components()
+            self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
             print(f"Switching to train 'all' strategy - unfreeze all components")
 
@@ -895,29 +925,43 @@ class MRLTrainer:
             unfreeze_lr: float,
     ) -> torch.optim.Optimizer:
         memory_lr = self.optim_config['memory_lr'] if 'memory_lr' in self.optim_config else self.optim_config['lr']
-        model_lr = self.optim_config['lr']
+        model_lr, embedding_lr = self.optim_config['lr'], self.optim_config['embedding_lr']
 
         if mode == 'update':
             params = [
-                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
-                {'params': self.actor.memory_cross_attention_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.decoder.memory_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.decoder.not_memory_parameters(), 'lr': unfreeze_lr},
             ]
         elif mode == 'fetch':
             params = [
-                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
-                {'params': self.actor.memory_cross_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.encoder.embedding.parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.decoder.not_memory_parameters(), 'lr': model_lr},
             ]
         elif mode == 'joint':
             params = [
-                {'params': self.actor.not_memory_parameters(), 'lr': unfreeze_lr},
-                {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.encoder.embedding.parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.decoder.not_memory_parameters(), 'lr': unfreeze_lr},
            ]
         else:
             params = [
-                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
-                {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.decoder.not_memory_parameters(), 'lr': model_lr},
             ]
 
         return torch.optim.AdamW(params, weight_decay=self.optim_config['weight_decay'])
@@ -934,6 +978,7 @@ class MRLTrainer:
                                        MrlStrategy.MULTI_STEP_STRATEGY)  # MRL strategy for given curriculum stage
         self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
         self.update_epochs = config.get('update_epochs', self.shared_update_epochs)  # Internal update epochs
+        self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
         if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config[
             'critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
             if config.get('separate_memory_lr', False):
@@ -945,6 +990,7 @@ class MRLTrainer:
                                                           self.base_optim_config['critic_weight_decay']),
                     'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
                     'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
+                    'embedding_lr': config.get('embedding_lr', self.base_optim_config['embedding_lr'])
                 }
             else:
                 self.optim_config = {
@@ -954,6 +1000,7 @@ class MRLTrainer:
                     'critic_weight_decay': config.get('critic_weight_decay',
                                                       self.base_optim_config['critic_weight_decay']),
                     'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
+                    'embedding_lr': config.get('embedding_lr', self.base_optim_config['embedding_lr'])
                 }
             self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
         elif self.optim_config != self.base_optim_config:
@@ -1005,7 +1052,7 @@ class MRLTrainer:
         if callable(unfreeze_epoch):
             unfreeze_epoch(-1)
         else:
-            self.actor.freeze_components('joint')
+            self.actor.freeze_components('joint', freeze_embeddings=self.freeze_embeddings)
             if isinstance(unfreeze_epoch, tuple):
                 print(
                     f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
rxnn/transformers/models.py CHANGED
@@ -17,6 +17,7 @@ class ReactiveTransformerBase(nn.Module):
             absolute_embedding: AbsolutePositionalEmbedding = None,
             use_flash_attention: bool = False,
             use_relative_embedding: bool = False,
+            use_moe: bool = False,
             *args,
             **kwargs,
     ):
@@ -32,6 +33,7 @@ class ReactiveTransformerBase(nn.Module):
         self.layers = own_layers
         self.num_shared_layers = len(shared_layers) if shared_layers else 0
         self.num_own_layers = len(own_layers) if own_layers else 0
+        self.use_moe = use_moe
 
     def trainable_cross_attention_(self, is_trainable: bool, with_norms: bool = True):
         for i in range(self.num_shared_layers):
@@ -50,8 +52,11 @@ class ReactiveTransformerBase(nn.Module):
         return own + shared
 
     def moe_router_loss(self):
-        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att] + [
-            self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe or self.shared_layers[i].use_moe_att]).mean()
+        if self.use_moe:
+            return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att] + [
+                self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe or self.shared_layers[i].use_moe_att]).mean()
+        else:
+            return None
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Shared logic for encoders and decoders - apply embeddings and positional encoding
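The guard matters because the old `moe_router_loss` unconditionally called `torch.stack` on the collected per-layer losses; for a dense model that list is empty, and `torch.stack([])` raises a RuntimeError. Returning `None` when `use_moe=False` lets callers (such as `MrlActorModel.moe_router_loss` above) skip the auxiliary loss cleanly. A minimal illustration:

```python
import torch

def router_loss(per_layer_losses: list[torch.Tensor], use_moe: bool):
    # Mirrors the guarded aggregation: average layer losses only when the
    # model was built as a mixture-of-experts, otherwise signal "no loss".
    if not use_moe:
        return None
    return torch.stack(per_layer_losses).mean()

print(router_loss([torch.tensor(0.2), torch.tensor(0.4)], use_moe=True))  # tensor(0.3000)
print(router_loss([], use_moe=False))                                     # None
```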
rxnn-0.2.39.dist-info/METADATA → rxnn-0.2.41.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.39
+Version: 0.2.41
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.39.dist-info/RECORD → rxnn-0.2.41.dist-info/RECORD RENAMED
@@ -5,19 +5,19 @@ rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-c
 rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=PXVBZQYNsRraZh7QDBgUOdPy3lTI8B0d8CzduojBjG0,1747
+rxnn/memory/attention.py,sha256=POszZeW0QBKOh4VTDVekmZGKKwUr1Zj0FOAilTv8Vyg,2411
 rxnn/memory/norm.py,sha256=E98jOQEuIOFFhlkvS8s4fFN-D4tLO6vaOqnObv1oVmA,6592
 rxnn/memory/stm.py,sha256=SSfc-RL9FE-RLkmOEkLB-9Rb00ZXbMLbsAEPdpIW89o,3851
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=CzFELVv5-ybAwl1s1ptpmwM7wdJ07M4jaT1-I8PYrR0,13999
+rxnn/rxt/models.py,sha256=lRn7NRIAAeCxr8hoIXanhaD-cGwVwA23hBdIQpBK6kc,14484
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
-rxnn/training/models.py,sha256=y-9XHedSheyK1AmLBp3ayulnUvAmDuJ3t0qVg8wHBRg,7463
-rxnn/training/mrl.py,sha256=fIrg1Er0aAK4TnyDRmJC1m7az9wdkhikxv0CBCrGT-c,55868
+rxnn/training/models.py,sha256=4hDH-R9l1lNvBMW_CGG_QgmCVrkyG7Lyo40PPzvkovQ,8876
+rxnn/training/mrl.py,sha256=tv7LjW1HBXF9H7rrITQD4EmN1-qgJT44UblREzsjeew,59378
 rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
 rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -28,12 +28,12 @@ rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTg
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
 rxnn/transformers/layers.py,sha256=l0bXmhN7KOkCw0KTVLixWSo9Op4SesGabWJ4R4EQBMY,7988
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=hey6tFN9gmLfWCZLjtl_9OcvIjGpWLI1IDeVnr5y8YM,10583
+rxnn/transformers/models.py,sha256=7ypPNFFnacdZjvaLVue1KR2PmMSdVYsbCMQSunXDL70,10720
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.39.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.39.dist-info/METADATA,sha256=0Ky_SOITUSAzWBAcLtNl6Wq2n6ESnMNEs6_sBKezQ88,25960
-rxnn-0.2.39.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.39.dist-info/RECORD,,
+rxnn-0.2.41.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.41.dist-info/METADATA,sha256=5oKrThfhnOQK8KjDYJfcP-LTb03hNyUrSTjbOSpUUdg,25960
+rxnn-0.2.41.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.41.dist-info/RECORD,,
rxnn-0.2.39.dist-info/LICENSE → rxnn-0.2.41.dist-info/LICENSE RENAMED
File without changes
rxnn-0.2.39.dist-info/WHEEL → rxnn-0.2.41.dist-info/WHEEL RENAMED
File without changes