rxnn 0.2.44__py3-none-any.whl → 0.2.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/attention.py +6 -2
- rxnn/rxt/models.py +3 -1
- rxnn/training/callbacks.py +6 -6
- rxnn/training/mrl.py +14 -2
- {rxnn-0.2.44.dist-info → rxnn-0.2.46.dist-info}/METADATA +1 -1
- {rxnn-0.2.44.dist-info → rxnn-0.2.46.dist-info}/RECORD +8 -8
- {rxnn-0.2.44.dist-info → rxnn-0.2.46.dist-info}/LICENSE +0 -0
- {rxnn-0.2.44.dist-info → rxnn-0.2.46.dist-info}/WHEEL +0 -0
rxnn/memory/attention.py
CHANGED
@@ -11,6 +11,7 @@ class StmMemoryAttention(nn.Module):
             use_gated_residual: bool = False,
             per_slot_gate: bool = False,
             init_gate: float = 0.0,
+            use_dynamic_gate: bool = False,
             *args,
             **kwargs
     ):
@@ -22,8 +23,10 @@ class StmMemoryAttention(nn.Module):
         self.num_layers = len(attention_layers)
         self.use_gated_residual = use_gated_residual
         self.per_slot_gate = per_slot_gate
+        self.use_dynamic_gate = use_dynamic_gate
         if self.use_gated_residual:
-
+            gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
+            self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
 
     def update_max_len(self, max_seq_len: int):
         for i in range(self.num_layers):
@@ -44,7 +47,8 @@ class StmMemoryAttention(nn.Module):
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
             if self.use_gated_residual:
                 # gated residual
-
+                gate_input = self.gate[i] * (new_layer_stm + layer_stm) if self.use_dynamic_gate else self.gate[i]
+                layer_gate = torch.sigmoid(gate_input)
                 new_stm[i] = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
             else:
                 new_stm[i] = new_layer_stm + layer_stm  # residual
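The new `use_dynamic_gate` flag switches the gated residual from a purely learned gate to one conditioned on the current memory update. Below is a minimal, self-contained sketch of both modes with illustrative shapes; the actual module applies this per layer across the STM slots.

```python
import torch

# Illustrative shapes; the real StmMemoryAttention uses its own STM slot layout.
num_slots, dim = 8, 64
gate = torch.zeros(num_slots, 1)             # per-slot gate, init_gate = 0.0 -> sigmoid(0) = 0.5
layer_stm = torch.randn(num_slots, dim)      # previous short-term memory for one layer
new_layer_stm = torch.randn(num_slots, dim)  # memory-attention output for the same layer

use_dynamic_gate = True
if use_dynamic_gate:
    # dynamic gate: modulated by the actual content of this update
    gate_input = gate * (new_layer_stm + layer_stm)
else:
    # static gate: a learned constant (per layer, and per slot if per_slot_gate=True)
    gate_input = gate

layer_gate = torch.sigmoid(gate_input)
updated_stm = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
```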
rxnn/rxt/models.py
CHANGED
@@ -253,6 +253,7 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             use_gated_residual: bool = False,
             residual_per_slot_gate: bool = False,
             residual_init_gate: float = 0.0,
+            use_dynamic_residual_gate: bool = False,
             **kwargs,
     ):
         super(RxTAlphaMemoryAttention, self).__init__(**kwargs)
@@ -281,7 +282,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
         self.model = StmMemoryAttention(
             stm, attention_layers, memory_norm_layers,
-            use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate,
+            use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate,
+            init_gate=residual_init_gate, use_dynamic_gate=use_dynamic_residual_gate,
         )
 
     def freeze(self):
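For reference, the wrapper-level keyword arguments map onto `StmMemoryAttention` as sketched below; only the residual-gate kwargs from this diff are shown, with other required constructor arguments omitted and values chosen for illustration.

```python
# Hypothetical values; only the residual-gate kwargs from this diff are shown.
rxt_kwargs = dict(
    use_gated_residual=True,
    residual_per_slot_gate=True,
    residual_init_gate=0.0,
    use_dynamic_residual_gate=True,  # new in 0.2.46
)

# How RxTAlphaMemoryAttention forwards them to StmMemoryAttention:
stm_attention_kwargs = dict(
    use_gated_residual=rxt_kwargs['use_gated_residual'],
    per_slot_gate=rxt_kwargs['residual_per_slot_gate'],
    init_gate=rxt_kwargs['residual_init_gate'],
    use_dynamic_gate=rxt_kwargs['use_dynamic_residual_gate'],
)
```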
rxnn/training/callbacks.py
CHANGED
@@ -533,7 +533,7 @@ class MrlTrainerCallback:
                   reward: float) -> None:
         pass
 
-    def on_reward(self, actor: nn.Module,
+    def on_reward(self, actor: nn.Module, rewards: list[float], generated: str, reference: str, saved_data: str, eval_mode: bool) -> None:
         pass
 
     def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
@@ -563,20 +563,20 @@ class MrlPrintCallback(MrlTrainerCallback):
     def on_epoch_start(self, actor: nn.Module, epoch: int, stage_epochs: int, curriculum_config: dict,
                        global_epoch: int, global_epochs: int) -> None:
         print(
-            f'Starting epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global) for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')
+            f'Starting epoch {epoch}/{stage_epochs - 1} (stage) | {global_epoch}/{global_epochs} (global) for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')
 
     def on_epoch_end(self, actor: nn.Module, epoch: int, stage_epochs: int, policy_loss: float,
                      critic_loss: float, global_epoch: int, global_epochs: int) -> None:
-        print(f'Finished epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global)')
+        print(f'Finished epoch {epoch}/{stage_epochs - 1} (stage) | {global_epoch}/{global_epochs} (global)')
         print(f'Policy mean loss: {policy_loss} | Critic mean loss: {critic_loss}')
 
     def on_episode_collected(self, actor: nn.Module, batch_idx: int, episode_trajectories: list[dict],
                              reward: float) -> None:
         print(f'Collected {batch_idx} episode | mean reward {reward}')
 
-    def on_reward(self, actor: nn.Module,
+    def on_reward(self, actor: nn.Module, rewards: list[float], generated: dict[str, torch.Tensor],
                   reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
-        print(f"{'Eval' if eval_mode else 'Train'} |
+        print(f"{'Eval' if eval_mode else 'Train'} | Mean reward: {sum(rewards) / len(rewards)} | All collected rewards: {rewards}")
 
     def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
         print(f'Epoch {global_epoch} | Starting update epoch {update_epoch}')
@@ -780,7 +780,7 @@ class MrlGeneratedTokensCallback(MrlTrainerCallback):
         self.steps_log_interval = steps_log_interval
         self.step = 0
 
-    def on_reward(self, actor: nn.Module,
+    def on_reward(self, actor: nn.Module, rewards: list[float], generated: dict[str, torch.Tensor],
                   reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
         self.step += 1
         attention_mask = generated['attention_mask']
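Custom callbacks that override `on_reward` need the expanded signature shown above. Below is a minimal sketch of a subclass using it; `RewardLoggingCallback` is a hypothetical example, assuming the base class import from `rxnn.training.callbacks`.

```python
import torch
import torch.nn as nn

from rxnn.training.callbacks import MrlTrainerCallback


class RewardLoggingCallback(MrlTrainerCallback):
    """Sketch: collect per-step rewards with the 0.2.46 on_reward signature."""

    def __init__(self):
        self.train_rewards: list[float] = []
        self.eval_rewards: list[float] = []

    def on_reward(self, actor: nn.Module, rewards: list[float], generated: dict[str, torch.Tensor],
                  reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor],
                  eval_mode: bool) -> None:
        # keep train and eval rewards separate, mirroring the eval_mode flag
        (self.eval_rewards if eval_mode else self.train_rewards).extend(rewards)
```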
rxnn/training/mrl.py
CHANGED
@@ -91,6 +91,7 @@ class MrlTrajectoryEpisode(TypedDict):
     reset_stm: bool
     steps: list[MrlTrajectoryStep]
 
+OptimField: TypeAlias = Literal['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay', 'separate_memory_lr', 'memory_lr']
 
 class MRLTrainer:
     def __init__(
@@ -981,8 +982,19 @@ class MRLTrainer:
         self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
         self.update_epochs = config.get('update_epochs', self.shared_update_epochs)  # Internal update epochs
         self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
-
-
+
+
+
+        def has_param(field: OptimField) -> bool:
+            return field in config and config[field] is not None
+
+        optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay']
+
+        has_any_optim_param = any(
+            has_param(field) for field in optim_params
+        ) or (has_param('separate_memory_lr') and config['separate_memory_lr'] and has_param('memory_lr'))
+
+        if has_any_optim_param:
             if config.get('separate_memory_lr', False):
                 self.optim_config = {
                     'lr': config.get('lr', self.base_optim_config['lr']),
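The added block only rebuilds the per-stage optimizer config when the curriculum stage actually overrides one of the optimizer fields, or sets `separate_memory_lr` together with `memory_lr`. A standalone sketch of that check follows, run against an illustrative stage config dict (the field values are assumptions, not taken from the package).

```python
from typing import Literal, TypeAlias

OptimField: TypeAlias = Literal['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay',
                                'separate_memory_lr', 'memory_lr']

# Illustrative curriculum stage config; only 'lr' is overridden here.
config = {'steps': 4, 'strategy': 'multi_step', 'lr': 2e-4}

def has_param(field: OptimField) -> bool:
    return field in config and config[field] is not None

optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay']

has_any_optim_param = any(
    has_param(field) for field in optim_params
) or (has_param('separate_memory_lr') and config['separate_memory_lr'] and has_param('memory_lr'))

print(has_any_optim_param)  # True -> the trainer rebuilds its optimizer config for this stage
```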
{rxnn-0.2.44.dist-info → rxnn-0.2.46.dist-info}/RECORD
CHANGED
@@ -5,19 +5,19 @@ rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-c
 rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=
+rxnn/memory/attention.py,sha256=sXh6f_iOpEYCaqyG-QVp_C_A9IF0QcXTi3hW5G8FCwA,2630
 rxnn/memory/norm.py,sha256=E98jOQEuIOFFhlkvS8s4fFN-D4tLO6vaOqnObv1oVmA,6592
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=
+rxnn/rxt/models.py,sha256=4MbCL4xGY3ceewZQmopjmwAyLQS92L6KLOPqaW7-Fho,14673
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
-rxnn/training/callbacks.py,sha256=
+rxnn/training/callbacks.py,sha256=RPW3Lisi31VJvoYyZeAF3dQzttrceDQDsZ6G5Xl09HM,35933
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=tqABOt_xEcWbZNEW2I2Jt-3eyaGICK011zILwuTk6Zc,9082
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=L4G7xSPlxsymvNhvsSloCpaqYjOXxEm7GmKilM_Ojvc,59809
 rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
 rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.46.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.46.dist-info/METADATA,sha256=hpTQT4p75cKrAaGOz_56gCBm1rT_y-Nr1TI9Mhv6wv0,25960
+rxnn-0.2.46.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.46.dist-info/RECORD,,
{rxnn-0.2.44.dist-info → rxnn-0.2.46.dist-info}/LICENSE
File without changes

{rxnn-0.2.44.dist-info → rxnn-0.2.46.dist-info}/WHEEL
File without changes