rxnn 0.2.45__tar.gz → 0.2.47__tar.gz
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as published.
- {rxnn-0.2.45 → rxnn-0.2.47}/PKG-INFO +1 -1
- {rxnn-0.2.45 → rxnn-0.2.47}/pyproject.toml +1 -1
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/memory/attention.py +15 -8
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/rxt/models.py +5 -3
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/callbacks.py +24 -8
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/models.py +1 -1
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/mrl.py +2 -2
- {rxnn-0.2.45 → rxnn-0.2.47}/LICENSE +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/README.md +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/utils.py +0 -0
{rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/memory/attention.py +15 -8

@@ -11,6 +11,7 @@ class StmMemoryAttention(nn.Module):
             use_gated_residual: bool = False,
             per_slot_gate: bool = False,
             init_gate: float = 0.0,
+            use_dynamic_gate: bool = False,
             *args,
             **kwargs
     ):
@@ -22,17 +23,26 @@ class StmMemoryAttention(nn.Module):
         self.num_layers = len(attention_layers)
         self.use_gated_residual = use_gated_residual
         self.per_slot_gate = per_slot_gate
+        self.use_dynamic_gate = use_dynamic_gate
         if self.use_gated_residual:
-            …
+            gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
+            self.gate = nn.Parameter(torch.full(gate_shape, init_gate))

     def update_max_len(self, max_seq_len: int):
         for i in range(self.num_layers):
             if self.attention_layers[i].rope is not None:
                 self.attention_layers[i].rope.update_max_len(max_seq_len)

-    def …
-    …
+    def _residual_gate(self, gate: torch.Tensor, layer_stm: torch.Tensor, new_layer_stm: torch.Tensor) -> torch.Tensor:
+        if self.use_dynamic_gate:
+            mean_dim = -1 if self.per_slot_gate else [1, 2]
+            gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
+            layer_gate = torch.sigmoid(gate_input)
+        else:
+            layer_gate = torch.sigmoid(gate)
+        return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm

+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         new_stm = torch.zeros_like(self.stm.memory)
         for i in range(self.num_layers):
             layer_stm = self.stm(i)
@@ -41,13 +51,10 @@ class StmMemoryAttention(nn.Module):
             layer_stm = layer_stm.expand(x.size(0), -1, -1)
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data)
             if self.use_gated_residual:
-                # gated residual
-                layer_gate = torch.sigmoid(self.gate[i])
-                new_stm[i] = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
+                new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm)  # gated residual
             else:
                 new_stm[i] = new_layer_stm + layer_stm  # residual
         self.stm.update_all(new_stm)
         return self.stm.memory
-
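The gated-residual change above is easiest to read in isolation. Below is a minimal standalone sketch of the `_residual_gate` formula (not the library class itself); the tensor shapes, the dummy inputs, and `per_slot_gate=True` are assumptions chosen only for illustration.

```python
import torch

def residual_gate(gate, layer_stm, new_layer_stm, use_dynamic_gate, per_slot_gate):
    # gate: learned parameter for one layer; layer_stm: previous STM slots;
    # new_layer_stm: memory-attention output for the same layer.
    if use_dynamic_gate:
        # Dynamic gate: scale the learned parameter by the mean activation of the
        # combined memory before squashing it to (0, 1).
        mean_dim = -1 if per_slot_gate else [1, 2]
        gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
        layer_gate = torch.sigmoid(gate_input)
    else:
        # Static gate: sigmoid of the learned parameter alone.
        layer_gate = torch.sigmoid(gate)
    # Convex combination of new and old memory, controlled by the gate.
    return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm

# Illustrative shapes: batch=2, stm_size=4, dim=8, one gate value per memory slot.
gate = torch.zeros(4, 1)
layer_stm = torch.randn(2, 4, 8)
new_layer_stm = torch.randn(2, 4, 8)
out = residual_gate(gate, layer_stm, new_layer_stm, use_dynamic_gate=True, per_slot_gate=True)
print(out.shape)  # torch.Size([2, 4, 8])
```

With `use_dynamic_gate=True` the effective gate depends on the current memory content; with the flag off it reduces to a sigmoid of the learned parameter, matching the earlier behaviour.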
{rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/rxt/models.py +5 -3

@@ -253,6 +253,7 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             use_gated_residual: bool = False,
             residual_per_slot_gate: bool = False,
             residual_init_gate: float = 0.0,
+            use_dynamic_residual_gate: bool = False,
             **kwargs,
     ):
         super(RxTAlphaMemoryAttention, self).__init__(**kwargs)
@@ -281,7 +282,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
         self.model = StmMemoryAttention(
             stm, attention_layers, memory_norm_layers,
-            use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate,
+            use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate,
+            init_gate=residual_init_gate, use_dynamic_gate=use_dynamic_residual_gate,
         )

     def freeze(self):
@@ -304,8 +306,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def clone_reset_memory(self):
         self.model.stm.clone_detach_reset()

-    def forward(self, x: torch.Tensor
-        return self.model(x
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.model(x)

 class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) encoder model"""
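The wrapper change is plumbing: the `residual_*` constructor options are forwarded to the underlying `StmMemoryAttention`. A hypothetical mapping of the four options (names taken from the hunks above, values arbitrary):

```python
# RxTAlphaMemoryAttention kwarg      ->  StmMemoryAttention parameter
residual_gate_options = dict(
    use_gated_residual=True,         # -> use_gated_residual
    residual_per_slot_gate=True,     # -> per_slot_gate
    residual_init_gate=0.0,          # -> init_gate
    use_dynamic_residual_gate=True,  # -> use_dynamic_gate (new in 0.2.47)
)
```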
{rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/callbacks.py +24 -8

@@ -533,7 +533,7 @@ class MrlTrainerCallback:
                               reward: float) -> None:
         pass

-    def on_reward(self, actor: nn.Module,
+    def on_reward(self, actor: nn.Module, rewards: list[float], generated: str, reference: str, saved_data: str, eval_mode: bool) -> None:
         pass

     def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
@@ -560,33 +560,49 @@ class MrlTrainerCallback:


 class MrlPrintCallback(MrlTrainerCallback):
+    def __init__(self, update_steps_interval: int = 10) -> None:
+        super(MrlPrintCallback, self).__init__()
+        self.update_steps_interval = update_steps_interval
+        self.policy_losses = []
+        self.critic_losses = []
+
     def on_epoch_start(self, actor: nn.Module, epoch: int, stage_epochs: int, curriculum_config: dict,
                        global_epoch: int, global_epochs: int) -> None:
         print(
-            f'Starting epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global) for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')
+            f'Starting epoch {epoch}/{stage_epochs - 1} (stage) | {global_epoch}/{global_epochs} (global) for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')

     def on_epoch_end(self, actor: nn.Module, epoch: int, stage_epochs: int, policy_loss: float,
                      critic_loss: float, global_epoch: int, global_epochs: int) -> None:
-        print(f'Finished epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global)')
+        print(f'Finished epoch {epoch}/{stage_epochs - 1} (stage) | {global_epoch}/{global_epochs} (global)')
         print(f'Policy mean loss: {policy_loss} | Critic mean loss: {critic_loss}')

     def on_episode_collected(self, actor: nn.Module, batch_idx: int, episode_trajectories: list[dict],
                              reward: float) -> None:
         print(f'Collected {batch_idx} episode | mean reward {reward}')

-    def on_reward(self, actor: nn.Module,
+    def on_reward(self, actor: nn.Module, rewards: list[float], generated: dict[str, torch.Tensor],
                   reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
-        print(f"{'Eval' if eval_mode else 'Train'} |
+        print(f"{'Eval' if eval_mode else 'Train'} | Mean reward: {sum(rewards) / len(rewards)} | All collected rewards: {rewards}")

     def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
         print(f'Epoch {global_epoch} | Starting update epoch {update_epoch}')

     def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
-        …
+        if step != 0 and step % self.update_steps_interval == 0:
+            loss = sum(self.policy_losses) / len(self.policy_losses)
+            self.policy_losses = []
+            print(f'Epoch {epoch} | Steps {step - self.update_steps_interval} - {step} - mean policy loss {loss} | current policy loss {policy_loss}')
+        else:
+            self.policy_losses.append(policy_loss)

     def on_critic_updated(self, actor: nn.Module, critic: nn.Module, epoch: int, step: int,
                           critic_loss: float) -> None:
-        …
+        if step != 0 and step % self.update_steps_interval == 0:
+            loss = sum(self.critic_losses) / len(self.critic_losses)
+            self.critic_losses = []
+            print(f'Epoch {epoch} | Steps {step - self.update_steps_interval} - {step} - mean critic loss {loss} | current critic loss {critic_loss}')
+        else:
+            self.critic_losses.append(critic_loss)

     def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
         print(f'Epoch {global_epoch} | Update epoch {update_epoch} - mean policy loss {policy_loss} | mean critic loss {critic_loss}')
@@ -780,7 +796,7 @@ class MrlGeneratedTokensCallback(MrlTrainerCallback):
         self.steps_log_interval = steps_log_interval
         self.step = 0

-    def on_reward(self, actor: nn.Module,
+    def on_reward(self, actor: nn.Module, rewards: list[float], generated: dict[str, torch.Tensor],
                   reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
         self.step += 1
         attention_mask = generated['attention_mask']
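The print callback now buffers losses and reports a running mean every `update_steps_interval` update steps. A hedged usage sketch (the dummy actor and the loss values are placeholders; the class and method signature are taken from the diff above):

```python
import torch.nn as nn
from rxnn.training.callbacks import MrlPrintCallback

callback = MrlPrintCallback(update_steps_interval=3)
dummy_actor = nn.Identity()  # stands in for the real actor model

# Steps 0-2 buffer their losses; step 3 prints their mean alongside the current
# loss and resets the buffer, then buffering starts again.
for step, policy_loss in enumerate([0.9, 0.8, 0.7, 0.6, 0.5, 0.4]):
    callback.on_batch_updated(dummy_actor, epoch=0, step=step, policy_loss=policy_loss)
```

`on_critic_updated` follows the same pattern with a separate `critic_losses` buffer.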
{rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/models.py +1 -1

@@ -204,7 +204,7 @@ class MrlActorModel(nn.Module):
             return self.decoder(x, attention_mask=attention_mask)
         else:
             _, ed = self.encoder(x, attention_mask=attention_mask)
-            return self.memory_attention(ed
+            return self.memory_attention(ed)


 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
{rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/mrl.py +2 -2

@@ -941,7 +941,7 @@ class MRLTrainer:
             ]
         elif mode == 'fetch':
             params = [
-                {'params': self.actor.embedding_parameters(), 'lr':
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
@@ -950,7 +950,7 @@ class MRLTrainer:
             ]
         elif mode == 'joint':
             params = [
-                {'params': self.actor.embedding_parameters(), 'lr':
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
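The two fixed lines restore an explicit learning rate for the embedding parameter group. As a generic sketch of the pattern (dummy modules and rates, not the MRLTrainer internals):

```python
import torch
import torch.nn as nn

# Per-group learning rates: each param group carries its own 'lr' key, which
# torch.optim.AdamW uses in place of the optimizer-wide default.
embedding = nn.Embedding(100, 16)
encoder = nn.Linear(16, 16)
memory_attention = nn.Linear(16, 16)

embedding_lr, unfreeze_lr, memory_lr = 1e-5, 1e-4, 5e-4
params = [
    {'params': embedding.parameters(), 'lr': embedding_lr},
    {'params': encoder.parameters(), 'lr': unfreeze_lr},
    {'params': memory_attention.parameters(), 'lr': memory_lr},
]
optimizer = torch.optim.AdamW(params)
print([group['lr'] for group in optimizer.param_groups])  # [1e-05, 0.0001, 0.0005]
```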