rxnn 0.2.40__tar.gz → 0.2.42__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.40 → rxnn-0.2.42}/PKG-INFO +1 -1
  2. {rxnn-0.2.40 → rxnn-0.2.42}/pyproject.toml +1 -1
  3. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/memory/attention.py +13 -1
  4. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/rxt/models.py +7 -1
  5. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/models.py +24 -8
  6. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/mrl.py +37 -24
  7. {rxnn-0.2.40 → rxnn-0.2.42}/LICENSE +0 -0
  8. {rxnn-0.2.40 → rxnn-0.2.42}/README.md +0 -0
  9. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/.DS_Store +0 -0
  10. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/__init__.py +0 -0
  11. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/experimental/__init__.py +0 -0
  12. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/experimental/attention.py +0 -0
  13. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/experimental/models.py +0 -0
  14. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/experimental/moe.py +0 -0
  15. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/memory/__init__.py +0 -0
  16. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/memory/norm.py +0 -0
  17. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/memory/stm.py +0 -0
  18. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/rxt/__init__.py +0 -0
  19. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/__init__.py +0 -0
  20. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/base.py +0 -0
  21. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/bml.py +0 -0
  22. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/callbacks.py +0 -0
  23. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/dataset.py +0 -0
  24. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/ddp.py +0 -0
  25. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/reward.py +0 -0
  26. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/rl.py +0 -0
  27. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/utils.py +0 -0
{rxnn-0.2.40 → rxnn-0.2.42}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.2.40
+ Version: 0.2.42
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.40 → rxnn-0.2.42}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
  [tool.poetry]
  name = "rxnn"
- version = "0.2.40"
+ version = "0.2.42"
  description = "RxNN: Reactive Neural Networks Platform"
 
  license = "Apache-2.0"
{rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/memory/attention.py
@@ -8,6 +8,9 @@ class StmMemoryAttention(nn.Module):
  stm: ShortTermMemory,
  attention_layers: nn.ModuleList,
  memory_norm_layers: nn.ModuleList,
+ use_gated_residual: bool = False,
+ per_slot_gate: bool = False,
+ init_gate: float = 0.0,
  *args,
  **kwargs
  ):
@@ -17,6 +20,10 @@ class StmMemoryAttention(nn.Module):
  self.memory_norm_layers = memory_norm_layers
  assert len(self.attention_layers) == len(self.memory_norm_layers) == self.stm.memory.size(0)
  self.num_layers = len(attention_layers)
+ self.use_gated_residual = use_gated_residual
+ self.per_slot_gate = per_slot_gate
+ if self.use_gated_residual:
+     self.gate = nn.Parameter(torch.full((self.num_layers, self.stm.stm_size, 1), init_gate) if self.per_slot_gate else torch.full((self.num_layers,), init_gate))
 
  def update_max_len(self, max_seq_len: int):
  for i in range(self.num_layers):
@@ -35,7 +42,12 @@ class StmMemoryAttention(nn.Module):
  encoded_layer_data = x[i]
  normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
  new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
- new_stm[i] = new_layer_stm + layer_stm # residual
+ if self.use_gated_residual:
+     # gated residual
+     layer_gate = torch.sigmoid(self.gate[i])
+     new_stm[i] = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
+ else:
+     new_stm[i] = new_layer_stm + layer_stm # residual
  self.stm.update_all(new_stm)
  return self.stm.memory
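
The hunks above replace the plain additive STM update with a learned, sigmoid-gated interpolation between the attention output and the previous memory state. A minimal standalone sketch of that update rule (the function and variable names below are illustrative, not part of the rxnn API; only the formula mirrors the diff):

import torch

def gated_stm_update(new_layer_stm: torch.Tensor, layer_stm: torch.Tensor,
                     gate_param: torch.Tensor) -> torch.Tensor:
    # gate_param is the raw (pre-sigmoid) gate: a scalar when gating per layer,
    # or a (stm_size, 1) tensor when per_slot_gate=True, which broadcasts over
    # the embedding dimension so each memory slot gets its own mixing weight.
    gate = torch.sigmoid(gate_param)
    return gate * new_layer_stm + (1 - gate) * layer_stm

# With init_gate=0.0, sigmoid(0) = 0.5, so training starts as an even blend
# of old and new memory instead of the previous unbounded sum.
old = torch.randn(32, 64)   # (stm_size, embed_dim), example sizes only
new = torch.randn(32, 64)
updated = gated_stm_update(new, old, torch.tensor(0.0))

Unlike the additive residual, the gated form keeps the updated slots a convex combination of the old and new states, which bounds how far the STM can drift in a single update.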
{rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/rxt/models.py
@@ -250,6 +250,9 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
  norm_init_gate: float = -2.0,
  norm_per_dim_scale: bool = False,
  norm_decay: float = 0.9,
+ use_gated_residual: bool = False,
+ residual_per_slot_gate: bool = False,
+ residual_init_gate: float = 0.0,
  **kwargs,
  ):
  super(RxTAlphaMemoryAttention, self).__init__(**kwargs)
@@ -276,7 +279,10 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
  init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
  for _ in range(num_layers)])
  attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
- self.model = StmMemoryAttention(stm, attention_layers, memory_norm_layers)
+ self.model = StmMemoryAttention(
+     stm, attention_layers, memory_norm_layers,
+     use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate, init_gate=residual_init_gate
+ )
 
  def freeze(self):
  for param in self.parameters():
{rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/models.py
@@ -80,7 +80,7 @@ class MrlActorModel(nn.Module):
  self.decoder = decoder
  self.memory_attention = memory_attention
 
- def freeze_components(self, stage: Literal['update', 'fetch', 'joint'] = 'joint'):
+ def freeze_components(self, stage: Literal['update', 'fetch', 'joint'] = 'joint', freeze_embeddings: bool = False):
  """Freeze encoder/decoder except memory-related layers."""
  # Freeze/unfreeze encoder
  if self.encoder.freeze_without_memory is not None:
@@ -116,7 +116,11 @@ class MrlActorModel(nn.Module):
  for param in self.memory_attention.parameters():
  param.requires_grad = True if stage != 'fetch' else False
 
- def unfreeze_components(self):
+ if freeze_embeddings:
+     for param in self.encoder.model.embedding.parameters():
+         param.requires_grad = False
+
+ def unfreeze_components(self, freeze_embeddings: bool = False):
  """Unfreeze all components after initial training."""
  if self.encoder.unfreeze_all is not None:
  self.encoder.unfreeze_all()
@@ -134,6 +138,11 @@ class MrlActorModel(nn.Module):
  for param in self.memory_attention.parameters():
  param.requires_grad = True
 
+ if freeze_embeddings:
+     for param in self.encoder.model.embedding.parameters():
+         param.requires_grad = False
+
+
  def reset_memory(self):
  self.memory_attention.reset_memory()
 
@@ -159,12 +168,19 @@ class MrlActorModel(nn.Module):
  self.decoder.not_memory_parameters()
  ))
 
- def unique_parameters(self):
- return list(set(
- list(self.encoder.parameters()) +
- list(self.decoder.parameters()) +
- list(self.memory_attention.parameters())
- ))
+ def unique_parameters(self, with_embedding: bool = True):
+     if with_embedding:
+         return list(set(
+             list(self.encoder.parameters()) +
+             list(self.decoder.parameters()) +
+             list(self.memory_attention.parameters())
+         ))
+     else:
+         return list(set(
+             self.not_memory_parameters() +
+             self.memory_cross_attention_parameters() +
+             list(self.memory_attention_parameters())
+         ))
 
  def moe_router_loss(self):
  if self.encoder.model.use_moe and self.decoder.model.use_moe:
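
Both freezing paths now accept a freeze_embeddings flag that re-freezes the encoder's embedding table after the stage-specific logic has run, so the embedding weights stay fixed no matter which memory components are trainable. A toy, self-contained sketch of the pattern (the class and function below are illustrative, not the rxnn implementation):

import torch.nn as nn

class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(1000, 64)
        self.layer = nn.Linear(64, 64)

def unfreeze_components(encoder: ToyEncoder, freeze_embeddings: bool = False) -> None:
    # unfreeze everything first (the stage-specific logic), then optionally pin the embeddings
    for param in encoder.parameters():
        param.requires_grad = True
    if freeze_embeddings:
        for param in encoder.embedding.parameters():
            param.requires_grad = False

encoder = ToyEncoder()
unfreeze_components(encoder, freeze_embeddings=True)
print(encoder.embedding.weight.requires_grad)  # False
print(encoder.layer.weight.requires_grad)      # True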
{rxnn-0.2.40 → rxnn-0.2.42}/src/rxnn/training/mrl.py
@@ -21,18 +21,20 @@ class MrlConfig(TypedDict):
  separate_memory_lr: Optional[bool]
  memory_lr: Optional[float]
  critic_lr: float
- critic_encoder_lr: float
+ critic_encoder_lr: Optional[float]
  max_seq_len: int
  critic_max_len: int
- weight_decay: float
- critic_weight_decay: float
+ weight_decay: Optional[float]
+ critic_weight_decay: Optional[float]
  update_epochs: int
  pad_token_id: int
  end_token_id: int
  callbacks: Optional[list[MrlTrainerCallback]]
- memory_aware_critic: bool
- use_moe_aux_loss: bool
- moe_aux_loss_scale: float
+ memory_aware_critic: Optional[bool]
+ use_moe_aux_loss: Optional[bool]
+ moe_aux_loss_scale: Optional[float]
+ freeze_embeddings: Optional[bool]
+ embedding_lr: Optional[float]
 
 
  class MrlStrategy(Enum):
@@ -66,6 +68,8 @@ class CurriculumConfig(TypedDict):
  weight_decay: Optional[float]
  critic_weight_decay: Optional[float]
  update_epochs: Optional[int]
+ freeze_embeddings: Optional[bool]
+ embedding_lr: Optional[float]
 
 
  class SamplerConfig(TypedDict):
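
Both MrlConfig and CurriculumConfig gain optional freeze_embeddings and embedding_lr keys, so embedding handling can be set globally and overridden per curriculum stage. A hedged illustration of how these options could be set (values are arbitrary examples; most other config keys are omitted):

# Global MRL config: keep embeddings frozen by default, but give them their own
# learning rate for stages where they are trainable. (Other required keys omitted.)
mrl_config = {
    'critic_lr': 1e-4,
    'update_epochs': 10,
    'freeze_embeddings': True,   # new in 0.2.42
    'embedding_lr': 1e-4,        # new in 0.2.42; falls back to 'lr' (default 3e-4) when unset
}

# Per-stage override inside a CurriculumConfig entry:
stage_config = {
    'update_epochs': 5,
    'freeze_embeddings': False,  # train embeddings in this stage only
    'embedding_lr': 5e-5,
}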
@@ -129,6 +133,8 @@ class MRLTrainer:
  self.memory_aware_critic = config.get('memory_aware_critic', False)
  self.use_moe_aux_loss = config.get('use_moe_aux_loss', False)
  self.moe_aux_loss_scale = config.get('moe_aux_loss_scale', 0.01)
+ self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
+ self.freeze_embeddings = self.shared_freeze_embeddings
  # Internal update epochs config
  self.shared_update_epochs = config.get('update_epochs', 10)
  self.update_epochs = self.shared_update_epochs
@@ -166,6 +172,7 @@
  'weight_decay': config.get('weight_decay', 0.01),
  'critic_weight_decay': config.get('critic_weight_decay', 0.01),
  'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
+ 'embedding_lr': config.get('embedding_lr', config.get('lr', 3e-4)),
  }
  else:
  self.base_optim_config = {
@@ -174,6 +181,7 @@
  'weight_decay': config.get('weight_decay', 0.01),
  'critic_weight_decay': config.get('critic_weight_decay', 0.01),
  'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
+ 'embedding_lr': config.get('embedding_lr', config.get('lr', 3e-4)),
  }
 
  self.optim_config = self.base_optim_config
@@ -212,20 +220,22 @@
  weight_decay: float,
  critic_weight_decay: float,
  critic_encoder_lr: float,
+ embedding_lr: float,
  memory_lr: Optional[float] = None,
  ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
  if memory_lr is not None:
  optimizer = torch.optim.AdamW([
- {'params': self.actor.encoder.embedding.parameters(), 'lr': lr},
+ {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
  {'params': self.actor.not_memory_parameters(), 'lr': lr},
  {'params': self.actor.memory_parameters(), 'lr': memory_lr},
  ],
  weight_decay=weight_decay,
  )
  else:
- optimizer = torch.optim.AdamW(
- self.actor.unique_parameters(),
- lr=lr,
+ optimizer = torch.optim.AdamW([
+ {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
+ {'params': self.actor.unique_parameters(with_embedding=False), 'lr': lr},
+ ],
  weight_decay=weight_decay,
  )
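
The actor optimizer now always builds AdamW from parameter groups, with the embedding table in its own group driven by embedding_lr. A self-contained sketch of that pattern, with toy modules standing in for the actor's components:

import torch
import torch.nn as nn

embedding = nn.Embedding(1000, 64)   # stands in for actor.encoder.embedding
rest = nn.Linear(64, 64)             # stands in for the remaining (non-embedding) parameters

lr, embedding_lr, weight_decay = 3e-4, 1e-4, 0.01
optimizer = torch.optim.AdamW(
    [
        {'params': embedding.parameters(), 'lr': embedding_lr},
        {'params': rest.parameters(), 'lr': lr},
    ],
    weight_decay=weight_decay,   # shared across groups unless overridden per group
)
print([group['lr'] for group in optimizer.param_groups])  # [0.0001, 0.0003]

This is presumably also why unique_parameters gained the with_embedding flag: PyTorch optimizers reject a parameter that appears in more than one group, so the non-embedding group has to exclude the embedding table that already has its own group.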
@@ -872,41 +882,41 @@
  if isinstance(update_epoch, tuple):
  switch_epoch, cross_att_lr = update_epoch
  if epoch == switch_epoch:
- self.actor.unfreeze_components()
+ self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
  self.optimizer = self._init_unfreeze_optimizer('update', cross_att_lr)
  print(f"Activating 'update' unfreeze strategy with custom cross_att_lr: {cross_att_lr}")
  elif epoch == update_epoch:
- self.actor.freeze_components('update')
+ self.actor.freeze_components('update', freeze_embeddings=self.freeze_embeddings)
  print(
  f"Activating 'update' unfreeze strategy - mem-att trainable / cross-att frozen / rest model frozen")
 
  if isinstance(fetch_epoch, tuple):
  switch_epoch, mem_att_lr = fetch_epoch
  if epoch == switch_epoch:
- self.actor.unfreeze_components()
+ self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
  self.optimizer = self._init_unfreeze_optimizer('fetch', mem_att_lr)
  print(f"Activating 'fetch' unfreeze strategy with custom mem_att_lr: {mem_att_lr}")
  elif epoch == fetch_epoch:
- self.actor.freeze_components('fetch')
+ self.actor.freeze_components('fetch', freeze_embeddings=self.freeze_embeddings)
  print(
  f"Activating 'fetch' unfreeze strategy - mem-att frozen / cross-att trainable / rest model frozen")
 
  if isinstance(joint_epoch, tuple):
  switch_epoch, model_lr = joint_epoch
  if epoch == switch_epoch:
- self.actor.unfreeze_components()
+ self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
  self.optimizer = self._init_unfreeze_optimizer('joint', model_lr)
  print(f"Activating 'joint' unfreeze strategy with custom model_lr: {model_lr}")
  elif epoch == joint_epoch:
- self.actor.freeze_components('joint')
+ self.actor.freeze_components('joint', freeze_embeddings=self.freeze_embeddings)
  print(f"Activating 'joint' unfreeze strategy - mem-att/cross-att trainable / rest model frozen")
 
  if epoch == all_epoch:
- self.actor.unfreeze_components()
+ self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
  self.optimizer = self._init_unfreeze_optimizer('all', 0.)
  print(f"Switching to train 'all' strategy - unfreeze all components")
  elif epoch == unfreeze_epoch:
- self.actor.unfreeze_components()
+ self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
  print(f"Switching to train 'all' strategy - unfreeze all components")
 
  def _init_unfreeze_optimizer(
@@ -915,11 +925,11 @@
  unfreeze_lr: float,
  ) -> torch.optim.Optimizer:
  memory_lr = self.optim_config['memory_lr'] if 'memory_lr' in self.optim_config else self.optim_config['lr']
- model_lr = self.optim_config['lr']
+ model_lr, embedding_lr = self.optim_config['lr'], self.optim_config['embedding_lr']
 
  if mode == 'update':
  params = [
- {'params': self.actor.encoder.embedding.parameters(), 'lr': model_lr},
+ {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
  {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
  {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
  {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
@@ -946,7 +956,7 @@
  ]
  else:
  params = [
- {'params': self.actor.encoder.embedding.parameters(), 'lr': model_lr},
+ {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
  {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
  {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
  {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
@@ -968,6 +978,7 @@
  MrlStrategy.MULTI_STEP_STRATEGY) # MRL strategy for given curriculum stage
  self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
  self.update_epochs = config.get('update_epochs', self.shared_update_epochs) # Internal update epochs
+ self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
  if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config[
  'critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
  if config.get('separate_memory_lr', False):
@@ -979,6 +990,7 @@
  self.base_optim_config['critic_weight_decay']),
  'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
  'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
+ 'embedding_lr': config.get('embedding_lr', self.base_optim_config['embedding_lr'])
  }
  else:
  self.optim_config = {
@@ -988,6 +1000,7 @@
  'critic_weight_decay': config.get('critic_weight_decay',
  self.base_optim_config['critic_weight_decay']),
  'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
+ 'embedding_lr': config.get('embedding_lr', self.base_optim_config['embedding_lr'])
  }
  self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
  elif self.optim_config != self.base_optim_config:
@@ -1009,7 +1022,7 @@
 
  return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
 
- def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int):
+ def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int, ddp_find_unused_parameters: bool = False):
  """Start Memory Reinforcement Learning Curriculum."""
 
  # 0. Set global epoch count for all stages
@@ -1020,7 +1033,7 @@
  if self.use_ddp:
  rank, world_size = get_os_ddp_config()
  dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
- self.actor = DistributedDataParallel(self.actor, device_ids=[self.device.index])
+ self.actor = DistributedDataParallel(self.actor, device_ids=[self.device.index], find_unused_parameters=ddp_find_unused_parameters)
  self.critic = DistributedDataParallel(self.critic, device_ids=[self.device.index])
 
  # 2. Init BatchSampler with actor model (we have to run it after DDP init)
@@ -1039,7 +1052,7 @@
  if callable(unfreeze_epoch):
  unfreeze_epoch(-1)
  else:
- self.actor.freeze_components('joint')
+ self.actor.freeze_components('joint', freeze_embeddings=self.freeze_embeddings)
  if isinstance(unfreeze_epoch, tuple):
  print(
  f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
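
Finally, __call__ exposes DDP's find_unused_parameters flag. When some trainable parameters do not take part in producing the loss in a given step (for example, components that are only exercised in certain curriculum stages), DistributedDataParallel raises an error unless this flag is enabled, at the cost of an extra graph scan per iteration. A hedged usage sketch (the trainer instance and curriculum_stages list are assumed to be constructed elsewhere):

# MRLTrainer instance and list[CurriculumConfig] prepared beforehand (not shown).
trainer(
    curriculum_stages,
    batch_size=16,
    ddp_find_unused_parameters=True,  # needed under DDP when some trainable sub-modules produce no grads in a step
)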