rxnn 0.2.44__py3-none-any.whl → 0.2.46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/attention.py CHANGED
@@ -11,6 +11,7 @@ class StmMemoryAttention(nn.Module):
             use_gated_residual: bool = False,
             per_slot_gate: bool = False,
             init_gate: float = 0.0,
+            use_dynamic_gate: bool = False,
             *args,
             **kwargs
     ):
@@ -22,8 +23,10 @@ class StmMemoryAttention(nn.Module):
         self.num_layers = len(attention_layers)
         self.use_gated_residual = use_gated_residual
         self.per_slot_gate = per_slot_gate
+        self.use_dynamic_gate = use_dynamic_gate
         if self.use_gated_residual:
-            self.gate = nn.Parameter(torch.full((self.num_layers, self.stm.stm_size, 1), init_gate) if self.per_slot_gate else torch.full((self.num_layers,), init_gate))
+            gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
+            self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
 
     def update_max_len(self, max_seq_len: int):
         for i in range(self.num_layers):
@@ -44,7 +47,8 @@ class StmMemoryAttention(nn.Module):
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
             if self.use_gated_residual:
                 # gated residual
-                layer_gate = torch.sigmoid(self.gate[i])
+                gate_input = self.gate[i] * (new_layer_stm + layer_stm) if self.use_dynamic_gate else self.gate[i]
+                layer_gate = torch.sigmoid(gate_input)
                 new_stm[i] = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
             else:
                 new_stm[i] = new_layer_stm + layer_stm # residual
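The gated-residual update above can be read in isolation. Below is a minimal, standalone sketch (not code from the package) of the two gating modes the new use_dynamic_gate flag switches between; the tensor sizes are hypothetical.

import torch

def gated_residual(new_layer_stm: torch.Tensor, layer_stm: torch.Tensor,
                   gate: torch.Tensor, use_dynamic_gate: bool) -> torch.Tensor:
    # Static gate: the learned parameter is squashed directly, so the blend
    # ratio is the same for every input once training ends.
    # Dynamic gate: the parameter first scales the summed old/new states, so
    # the blend ratio depends on the current memory content.
    gate_input = gate * (new_layer_stm + layer_stm) if use_dynamic_gate else gate
    layer_gate = torch.sigmoid(gate_input)
    return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm

# Hypothetical shapes: 8 STM slots, 16-dim embeddings, per-slot gate.
old_stm = torch.randn(8, 16)
new_stm = torch.randn(8, 16)
per_slot_gate = torch.zeros(8, 1)  # init_gate = 0.0 -> sigmoid(0) = 0.5 blend
print(gated_residual(new_stm, old_stm, per_slot_gate, use_dynamic_gate=True).shape)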
rxnn/rxt/models.py CHANGED
@@ -253,6 +253,7 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             use_gated_residual: bool = False,
             residual_per_slot_gate: bool = False,
             residual_init_gate: float = 0.0,
+            use_dynamic_residual_gate: bool = False,
             **kwargs,
     ):
         super(RxTAlphaMemoryAttention, self).__init__(**kwargs)
@@ -281,7 +282,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
         self.model = StmMemoryAttention(
             stm, attention_layers, memory_norm_layers,
-            use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate, init_gate=residual_init_gate
+            use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate,
+            init_gate=residual_init_gate, use_dynamic_gate=use_dynamic_residual_gate,
         )
 
     def freeze(self):
rxnn/training/callbacks.py CHANGED
@@ -533,7 +533,7 @@ class MrlTrainerCallback:
                              reward: float) -> None:
         pass
 
-    def on_reward(self, actor: nn.Module, reward: float, generated: str, reference: str, saved_data: str, eval_mode: bool) -> None:
+    def on_reward(self, actor: nn.Module, rewards: list[float], generated: str, reference: str, saved_data: str, eval_mode: bool) -> None:
         pass
 
     def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
@@ -563,20 +563,20 @@ class MrlPrintCallback(MrlTrainerCallback):
     def on_epoch_start(self, actor: nn.Module, epoch: int, stage_epochs: int, curriculum_config: dict,
                        global_epoch: int, global_epochs: int) -> None:
         print(
-            f'Starting epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global) for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')
+            f'Starting epoch {epoch}/{stage_epochs - 1} (stage) | {global_epoch}/{global_epochs} (global) for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')
 
     def on_epoch_end(self, actor: nn.Module, epoch: int, stage_epochs: int, policy_loss: float,
                      critic_loss: float, global_epoch: int, global_epochs: int) -> None:
-        print(f'Finished epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global)')
+        print(f'Finished epoch {epoch}/{stage_epochs - 1} (stage) | {global_epoch}/{global_epochs} (global)')
         print(f'Policy mean loss: {policy_loss} | Critic mean loss: {critic_loss}')
 
     def on_episode_collected(self, actor: nn.Module, batch_idx: int, episode_trajectories: list[dict],
                              reward: float) -> None:
         print(f'Collected {batch_idx} episode | mean reward {reward}')
 
-    def on_reward(self, actor: nn.Module, reward: float, generated: dict[str, torch.Tensor],
+    def on_reward(self, actor: nn.Module, rewards: list[float], generated: dict[str, torch.Tensor],
                   reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
-        print(f"{'Eval' if eval_mode else 'Train'} | Collected reward {reward}")
+        print(f"{'Eval' if eval_mode else 'Train'} | Mean reward: {sum(rewards) / len(rewards)} | All collected rewards: {rewards}")
 
     def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
         print(f'Epoch {global_epoch} | Starting update epoch {update_epoch}')
@@ -780,7 +780,7 @@ class MrlGeneratedTokensCallback(MrlTrainerCallback):
         self.steps_log_interval = steps_log_interval
         self.step = 0
 
-    def on_reward(self, actor: nn.Module, reward: float, generated: dict[str, torch.Tensor],
+    def on_reward(self, actor: nn.Module, rewards: list[float], generated: dict[str, torch.Tensor],
                   reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
         self.step += 1
         attention_mask = generated['attention_mask']
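Because on_reward now receives the full list of per-sample rewards instead of a single float, custom callbacks that override this hook need a matching signature change. A minimal sketch of an adapted callback follows; the class name and print format are illustrative, not part of the package.

from rxnn.training.callbacks import MrlTrainerCallback

class MeanRewardLogger(MrlTrainerCallback):
    # Overrides only the hook whose signature changed in this release.
    def on_reward(self, actor, rewards: list[float], generated, reference,
                  saved_data, eval_mode: bool) -> None:
        mean_reward = sum(rewards) / len(rewards)
        mode = 'Eval' if eval_mode else 'Train'
        print(f'{mode} | {len(rewards)} rewards | mean: {mean_reward:.4f}')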
rxnn/training/mrl.py CHANGED
@@ -91,6 +91,7 @@ class MrlTrajectoryEpisode(TypedDict):
     reset_stm: bool
     steps: list[MrlTrajectoryStep]
 
+OptimField: TypeAlias = Literal['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay', 'separate_memory_lr', 'memory_lr']
 
 class MRLTrainer:
     def __init__(
@@ -981,8 +982,19 @@
         self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
         self.update_epochs = config.get('update_epochs', self.shared_update_epochs) # Internal update epochs
         self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
-        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config[
-            'critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
+
+        def has_param(field: OptimField) -> bool:
+            return field in config and config[field] is not None
+
+        optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay']
+
+        has_any_optim_param = any(
+            has_param(field) for field in optim_params
+        ) or (has_param('separate_memory_lr') and config['separate_memory_lr'] and has_param('memory_lr'))
+
+        if has_any_optim_param:
             if config.get('separate_memory_lr', False):
                 self.optim_config = {
                     'lr': config.get('lr', self.base_optim_config['lr']),
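The new OptimField alias and has_param helper replace direct config['...'] indexing, so a curriculum stage config that omits an optimizer key is treated as "not set" rather than raising a KeyError. A standalone sketch of the same check against a hypothetical stage config:

from typing import Literal, TypeAlias

OptimField: TypeAlias = Literal['lr', 'critic_lr', 'weight_decay',
                                'critic_weight_decay', 'separate_memory_lr', 'memory_lr']

# Hypothetical stage config: overrides only the actor lr and the memory lr.
config = {'lr': 3e-4, 'separate_memory_lr': True, 'memory_lr': 1e-4}

def has_param(field: OptimField) -> bool:
    # A missing key and an explicit None both mean "use the shared optimizer config".
    return field in config and config[field] is not None

optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay']
has_any_optim_param = any(has_param(field) for field in optim_params) or (
    has_param('separate_memory_lr') and config['separate_memory_lr'] and has_param('memory_lr')
)
print(has_any_optim_param)  # True -> this stage rebuilds its optimizer settings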
rxnn-0.2.44.dist-info/METADATA → rxnn-0.2.46.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.44
+Version: 0.2.46
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.44.dist-info/RECORD → rxnn-0.2.46.dist-info/RECORD CHANGED
@@ -5,19 +5,19 @@ rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-c
 rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=POszZeW0QBKOh4VTDVekmZGKKwUr1Zj0FOAilTv8Vyg,2411
+rxnn/memory/attention.py,sha256=sXh6f_iOpEYCaqyG-QVp_C_A9IF0QcXTi3hW5G8FCwA,2630
 rxnn/memory/norm.py,sha256=E98jOQEuIOFFhlkvS8s4fFN-D4tLO6vaOqnObv1oVmA,6592
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=jh7TNLu_7CL0PH_T99rMZHcLezFPiZi-xnPazNyn_dU,14563
+rxnn/rxt/models.py,sha256=4MbCL4xGY3ceewZQmopjmwAyLQS92L6KLOPqaW7-Fho,14673
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
-rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
+rxnn/training/callbacks.py,sha256=RPW3Lisi31VJvoYyZeAF3dQzttrceDQDsZ6G5Xl09HM,35933
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=tqABOt_xEcWbZNEW2I2Jt-3eyaGICK011zILwuTk6Zc,9082
-rxnn/training/mrl.py,sha256=BvrwqrIIyg_EmUA5p7c6UBcfFQ0ePIcl-EHEFQqyl2E,59472
+rxnn/training/mrl.py,sha256=L4G7xSPlxsymvNhvsSloCpaqYjOXxEm7GmKilM_Ojvc,59809
 rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
 rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.44.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.44.dist-info/METADATA,sha256=tW2Ve4whRK2LfCxix10dTLS5Dl_0C6KhcK8FsoKq-x0,25960
-rxnn-0.2.44.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.44.dist-info/RECORD,,
+rxnn-0.2.46.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.46.dist-info/METADATA,sha256=hpTQT4p75cKrAaGOz_56gCBm1rT_y-Nr1TI9Mhv6wv0,25960
+rxnn-0.2.46.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.46.dist-info/RECORD,,