rxnn 0.2.60-py3-none-any.whl → 0.2.61-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/attention.py CHANGED
@@ -13,6 +13,8 @@ class StmMemoryAttention(nn.Module):
             init_gate: float = 0.0,
             use_dynamic_gate: bool = False,
             use_tanh_gate: bool = False,
+            debug_mode: bool = False,
+            debug_interval: int = 10,
             *args,
             **kwargs
     ):
@@ -30,6 +32,10 @@ class StmMemoryAttention(nn.Module):
         gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
         self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
 
+        self.debug_mode = debug_mode
+        self.debug_interval = debug_interval
+        self.debug_step = 0
+
     def update_max_len(self, max_seq_len: int):
         for i in range(self.num_layers):
             if self.attention_layers[i].rope is not None:
@@ -58,6 +64,14 @@ class StmMemoryAttention(nn.Module):
             layer_stm = layer_stm.expand(x.size(0), -1, -1)
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
+
+            if self.debug_mode:
+                if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                    self.debug_step = 0
+                    print(f"Normalized STM stats - mean: {normalized_layer_stm.mean().item():.4f}, std: {normalized_layer_stm.std().item():.4f}")
+                else:
+                    self.debug_step += 1
+
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
             if self.use_gated_residual:
                 new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm)  # gated residual
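The debug path above logs the mean and standard deviation of the normalized STM on every debug_interval-th pass through the layer loop, then resets the counter. A minimal standalone sketch of that cadence, assuming only the counter logic shown in this hunk (DebugCadence is a hypothetical name used for illustration; in the module itself the counter ticks once per attention layer, since the check sits inside the per-layer loop):

import torch

class DebugCadence:
    # Mirrors the debug counter added to StmMemoryAttention above.
    def __init__(self, debug_interval: int = 10):
        self.debug_interval = debug_interval
        self.debug_step = 0

    def maybe_log(self, normalized_layer_stm: torch.Tensor) -> None:
        if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
            self.debug_step = 0  # reset, so stats print once per interval
            print(f"Normalized STM stats - mean: {normalized_layer_stm.mean().item():.4f}, std: {normalized_layer_stm.std().item():.4f}")
        else:
            self.debug_step += 1

cadence = DebugCadence(debug_interval=10)
for _ in range(25):
    cadence.maybe_log(torch.randn(2, 16, 64))  # prints on the 11th and 22nd calls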
rxnn/rxt/models.py CHANGED
@@ -254,6 +254,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             residual_per_slot_gate: bool = False,
             residual_init_gate: float = 0.0,
             use_dynamic_residual_gate: bool = False,
+            debug_mode: bool = False,
+            debug_interval: int = 10,
             **kwargs,
     ):
         super(RxTAlphaMemoryAttention, self).__init__(**kwargs)
@@ -284,6 +286,7 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             stm, attention_layers, memory_norm_layers,
             use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate,
             init_gate=residual_init_gate, use_dynamic_gate=use_dynamic_residual_gate,
+            debug_mode=debug_mode, debug_interval=debug_interval,
         )
 
     def freeze(self):
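The rxt/models.py change only threads the two new flags through to StmMemoryAttention. A hedged sketch of enabling them when loading the module, assuming keyword arguments passed to PyTorchModelHubMixin.from_pretrained override the stored config (the behaviour in recent huggingface_hub releases); the repo id below is a placeholder, not a real checkpoint:

from rxnn.rxt.models import RxTAlphaMemoryAttention

mem_attn = RxTAlphaMemoryAttention.from_pretrained(
    "your-org/your-memory-attention",  # placeholder repo id
    debug_mode=True,                   # forwarded to StmMemoryAttention
    debug_interval=10,                 # log normalized STM stats every 10th layer-loop step
)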
rxnn/training/mrl.py CHANGED
@@ -22,6 +22,9 @@ class MrlConfig(TypedDict):
     memory_lr: Optional[float]
     critic_lr: float
     critic_encoder_lr: Optional[float]
+    encoder_lr: Optional[float]
+    encoder_memory_lr: Optional[float]
+    memory_attn_lr: Optional[float]
     max_seq_len: int
     critic_max_len: int
     weight_decay: Optional[float]
@@ -68,6 +71,9 @@ class CurriculumConfig(TypedDict):
     memory_lr: Optional[float]
     critic_lr: Optional[float]
     critic_encoder_lr: Optional[float]
+    encoder_lr: Optional[float]
+    encoder_memory_lr: Optional[float]
+    memory_attn_lr: Optional[float]
     weight_decay: Optional[float]
     critic_weight_decay: Optional[float]
     update_epochs: Optional[int]
@@ -95,7 +101,10 @@ class MrlTrajectoryEpisode(TypedDict):
     reset_stm: bool
     steps: list[MrlTrajectoryStep]
 
-OptimField: TypeAlias = Literal['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay', 'separate_memory_lr', 'memory_lr']
+OptimField: TypeAlias = Literal[
+    'lr', 'critic_lr', 'weight_decay', 'critic_weight_decay', 'separate_memory_lr',
+    'memory_lr', 'encoder_lr', 'encoder_memory_lr', 'memory_attn_lr'
+]
 
 class MRLTrainer:
     def __init__(
@@ -181,6 +190,9 @@ class MRLTrainer:
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
                 'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
                 'embedding_lr': config.get('embedding_lr', config.get('lr', 3e-4)),
+                'encoder_lr': config.get('encoder_lr', config.get('lr', 3e-4)),
+                'encoder_memory_lr': config.get('encoder_memory_lr', config.get('memory_lr', 5e-4)),
+                'memory_attn_lr': config.get('memory_attn_lr', config.get('memory_lr', 5e-4)),
             }
         else:
             self.base_optim_config = {
@@ -190,6 +202,7 @@ class MRLTrainer:
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
                 'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
                 'embedding_lr': config.get('embedding_lr', config.get('lr', 3e-4)),
+                'encoder_lr': config.get('encoder_lr', config.get('lr', 3e-4)),
             }
 
         self.optim_config = self.base_optim_config
@@ -574,9 +587,18 @@ class MRLTrainer:
         encoder_total, encoder_mean = get_gradient_norms(self.actor.encoder)
         decoder_total, decoder_mean = get_gradient_norms(self.actor.decoder)
         mem_att_total, mem_att_mean = get_gradient_norms(self.actor.memory_attention)
-        print(f"Encoder grad norm - total: {encoder_total:.4f}, mean: {encoder_mean:.4f}")
-        print(f"Decoder grad norm - total: {decoder_total:.4f}, mean: {decoder_mean:.4f}")
-        print(f"Memory attention grad norm - total: {mem_att_total:.4f}, mean: {mem_att_mean:.4f}")
+        print(f"Encoder grad norm - total: {encoder_total:.6f}, mean: {encoder_mean:.6f}")
+        print(f"Decoder grad norm - total: {decoder_total:.6f}, mean: {decoder_mean:.6f}")
+        print(f"Memory attention grad norm - total: {mem_att_total:.6f}, mean: {mem_att_mean:.6f}")
+        # decoder's cross att
+        dec_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[0] for layer in self.actor.decoder.model.layers]
+        print(f"Decoder cross-att mean total norm: {(sum(dec_x_att_norms) / len(dec_x_att_norms)):.6f}, all: {dec_x_att_norms}")
+
+        mem_att_norms = [get_gradient_norms(layer)[0] for layer in self.actor.memory_attention.model.attention_layers]
+        print(f"Memory attention layers mean total norm: {(sum(mem_att_norms) / len(mem_att_norms)):.6f}, all: {mem_att_norms}")
+
+        enc_ff_norms = [get_gradient_norms(layer.ff)[0] for layer in self.actor.encoder.model.layers]
+        print(f"Encoder ff mean total norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")
 
     def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
@@ -969,14 +991,16 @@ class MRLTrainer:
             unfreeze_lr: float,
     ) -> torch.optim.Optimizer:
         memory_lr = self.optim_config['memory_lr'] if 'memory_lr' in self.optim_config else self.optim_config['lr']
-        model_lr, embedding_lr = self.optim_config['lr'], self.optim_config['embedding_lr']
+        encoder_memory_lr = self.optim_config['encoder_memory_lr'] if 'encoder_memory_lr' in self.optim_config else self.optim_config['encoder_lr']
+        memory_attn_lr = self.optim_config['memory_attn_lr'] if 'memory_attn_lr' in self.optim_config else self.optim_config['lr']
+        model_lr, embedding_lr, encoder_lr = self.optim_config['lr'], self.optim_config['embedding_lr'], self.optim_config['encoder_lr']
 
         if mode == 'update':
             params = [
                 {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
-                {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
-                {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
-                {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': encoder_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': encoder_memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_attn_lr},
                 {'params': self.actor.decoder.memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.decoder.not_memory_parameters(), 'lr': unfreeze_lr},
             ]
@@ -993,17 +1017,17 @@ class MRLTrainer:
             params = [
                 {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
-                {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
-                {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': encoder_memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_attn_lr},
                 {'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.decoder.not_memory_parameters(), 'lr': unfreeze_lr},
             ]
         else:
             params = [
                 {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
-                {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
-                {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
-                {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': encoder_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': encoder_memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_attn_lr},
                 {'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.decoder.not_memory_parameters(), 'lr': model_lr},
             ]
@@ -1028,11 +1052,14 @@ class MRLTrainer:
         def has_param(field: OptimField) -> bool:
             return field in config and config[field] is not None
 
-        optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay']
+        optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay', 'encoder_lr']
+        mem_optim_params: list[OptimField] = ['memory_lr', 'encoder_memory_lr', 'memory_attn_lr']
 
         has_any_optim_param = any(
             has_param(field) for field in optim_params
-        ) or (has_param('separate_memory_lr') and config['separate_memory_lr'] and has_param('memory_lr'))
+        ) or (has_param('separate_memory_lr') and config['separate_memory_lr'] and any(
+            has_param(field) for field in mem_optim_params
+        ))
 
         if has_any_optim_param:
             if config.get('separate_memory_lr', False):
@@ -1044,7 +1071,10 @@ class MRLTrainer:
                                                       self.base_optim_config['critic_weight_decay']),
                     'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
                     'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
-                    'embedding_lr': config.get('embedding_lr', self.base_optim_config['embedding_lr'])
+                    'embedding_lr': config.get('embedding_lr', self.base_optim_config['embedding_lr']),
+                    'encoder_lr': config.get('encoder_lr', self.base_optim_config['encoder_lr']),
+                    'encoder_memory_lr': config.get('encoder_memory_lr', self.base_optim_config['encoder_memory_lr']),
+                    'memory_attn_lr': config.get('memory_attn_lr', self.base_optim_config['memory_attn_lr']),
                 }
             else:
                 self.optim_config = {
@@ -1054,7 +1084,8 @@ class MRLTrainer:
                     'critic_weight_decay': config.get('critic_weight_decay',
                                                       self.base_optim_config['critic_weight_decay']),
                     'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
-                    'embedding_lr': config.get('embedding_lr', self.base_optim_config['embedding_lr'])
+                    'embedding_lr': config.get('embedding_lr', self.base_optim_config['embedding_lr']),
+                    'encoder_lr': config.get('encoder_lr', self.base_optim_config['encoder_lr']),
                 }
             self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
         elif self.optim_config != self.base_optim_config:
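Taken together, the mrl.py changes split the single memory_lr into per-module learning rates: encoder_lr for the encoder's non-memory parameters, encoder_memory_lr for its memory parameters, and memory_attn_lr for the memory-attention module, each falling back to the previous behaviour when omitted. A standalone sketch of that fallback chain (resolve_new_lrs is a hypothetical helper mirroring the config.get calls in MRLTrainer.__init__ on the separate_memory_lr=True branch):

def resolve_new_lrs(config: dict) -> dict:
    # Same fallbacks as the diff: encoder_lr -> lr, the two memory-side
    # rates -> memory_lr, with the library defaults as a last resort.
    return {
        'encoder_lr': config.get('encoder_lr', config.get('lr', 3e-4)),
        'encoder_memory_lr': config.get('encoder_memory_lr', config.get('memory_lr', 5e-4)),
        'memory_attn_lr': config.get('memory_attn_lr', config.get('memory_lr', 5e-4)),
    }

# Omitting the new keys reproduces the pre-0.2.61 behaviour:
print(resolve_new_lrs({'lr': 3e-4, 'memory_lr': 5e-4}))
# {'encoder_lr': 0.0003, 'encoder_memory_lr': 0.0005, 'memory_attn_lr': 0.0005}

# Setting a key decouples that module's learning rate:
print(resolve_new_lrs({'lr': 3e-4, 'memory_lr': 5e-4, 'memory_attn_lr': 1e-3}))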
{rxnn-0.2.60.dist-info → rxnn-0.2.61.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.60
+Version: 0.2.61
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.60.dist-info → rxnn-0.2.61.dist-info}/RECORD RENAMED
@@ -5,11 +5,11 @@ rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhX
 rxnn/experimental/models.py,sha256=oJWd56LUsLc9S8eCZw-ShvuWjoQxj4C9GitbohlQ0ok,5139
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=t-SWJhQ71TV8X_8I_yp0Cr5df7fnWxI-EnYiN8gjpok,3268
+rxnn/memory/attention.py,sha256=thp4t1IVwuIFv-w2WnKZwZeQK_GHwOxaWlnTscs4AH0,3826
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=jd1UVBUWJzWdw7Rjcvo9k5BXCJriQ0khuVszqEyfD7M,14665
+rxnn/rxt/models.py,sha256=mLHvb3ablQK9UtupuOHmLlG440Q_NW-OuLWcxGMfGuY,14807
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
@@ -17,7 +17,7 @@ rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36
 rxnn/training/dataset.py,sha256=tbtOSYldHnQB6SWgee_yUj9zTbgoEoLFNa6wvUS6Apg,51292
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=KIiOCW0VgKtMA4EMQ---xsVExdI1mBsgWjtRSmJpecA,9033
-rxnn/training/mrl.py,sha256=H2JcamaJv19vKqOgdoyhcCBwu1lb_aKfCmR_MuuvmS0,62085
+rxnn/training/mrl.py,sha256=BWp87Lj4epjTlROmrQK8RnS_83IucqS7XWI1cBae7BM,64424
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
 rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.60.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.60.dist-info/METADATA,sha256=x_juLxld_xGztBqC7bbBWTn4llwNuZtu29xyjz4uiX8,25997
-rxnn-0.2.60.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.60.dist-info/RECORD,,
+rxnn-0.2.61.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.61.dist-info/METADATA,sha256=jVl_BOHGqGcDFbA2HnaWS830fczGtWJ_TR-Dp4Ga1ig,25997
+rxnn-0.2.61.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.61.dist-info/RECORD,,
{rxnn-0.2.60.dist-info → rxnn-0.2.61.dist-info}/LICENSE RENAMED
File without changes
{rxnn-0.2.60.dist-info → rxnn-0.2.61.dist-info}/WHEEL RENAMED
File without changes