rxnn-0.2.46.tar.gz → rxnn-0.2.47.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.46 → rxnn-0.2.47}/PKG-INFO +1 -1
- {rxnn-0.2.46 → rxnn-0.2.47}/pyproject.toml +1 -1
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/memory/attention.py +11 -8
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/rxt/models.py +2 -2
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/training/callbacks.py +18 -2
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/training/models.py +1 -1
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/training/mrl.py +2 -2
- {rxnn-0.2.46 → rxnn-0.2.47}/LICENSE +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/README.md +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.47}/src/rxnn/utils.py +0 -0
@@ -33,9 +33,16 @@ class StmMemoryAttention(nn.Module):
|
|
33
33
|
if self.attention_layers[i].rope is not None:
|
34
34
|
self.attention_layers[i].rope.update_max_len(max_seq_len)
|
35
35
|
|
36
|
-
def
|
37
|
-
|
36
|
+
def _residual_gate(self, gate: torch.Tensor, layer_stm: torch.Tensor, new_layer_stm: torch.Tensor) -> torch.Tensor:
|
37
|
+
if self.use_dynamic_gate:
|
38
|
+
mean_dim = -1 if self.per_slot_gate else [1, 2]
|
39
|
+
gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
|
40
|
+
layer_gate = torch.sigmoid(gate_input)
|
41
|
+
else:
|
42
|
+
layer_gate = torch.sigmoid(gate)
|
43
|
+
return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
|
38
44
|
|
45
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
39
46
|
new_stm = torch.zeros_like(self.stm.memory)
|
40
47
|
for i in range(self.num_layers):
|
41
48
|
layer_stm = self.stm(i)
|
@@ -44,14 +51,10 @@ class StmMemoryAttention(nn.Module):
|
|
44
51
|
layer_stm = layer_stm.expand(x.size(0), -1, -1)
|
45
52
|
encoded_layer_data = x[i]
|
46
53
|
normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
|
47
|
-
new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data
|
54
|
+
new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data)
|
48
55
|
if self.use_gated_residual:
|
49
|
-
# gated residual
|
50
|
-
gate_input = self.gate[i] * (new_layer_stm + layer_stm) if self.use_dynamic_gate else self.gate[i]
|
51
|
-
layer_gate = torch.sigmoid(gate_input)
|
52
|
-
new_stm[i] = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
|
56
|
+
new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm) # gated residual
|
53
57
|
else:
|
54
58
|
new_stm[i] = new_layer_stm + layer_stm # residual
|
55
59
|
self.stm.update_all(new_stm)
|
56
60
|
return self.stm.memory
|
57
|
-
|
@@ -306,8 +306,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
|
|
306
306
|
def clone_reset_memory(self):
|
307
307
|
self.model.stm.clone_detach_reset()
|
308
308
|
|
309
|
-
def forward(self, x: torch.Tensor
|
310
|
-
return self.model(x
|
309
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
310
|
+
return self.model(x)
|
311
311
|
|
312
312
|
class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
|
313
313
|
"""RxT-Alpha (Reactive Transformer) encoder model"""
|
@@ -560,6 +560,12 @@ class MrlTrainerCallback:
|
|
560
560
|
|
561
561
|
|
562
562
|
class MrlPrintCallback(MrlTrainerCallback):
|
563
|
+
def __init__(self, update_steps_interval: int = 10) -> None:
|
564
|
+
super(MrlPrintCallback, self).__init__()
|
565
|
+
self.update_steps_interval = update_steps_interval
|
566
|
+
self.policy_losses = []
|
567
|
+
self.critic_losses = []
|
568
|
+
|
563
569
|
def on_epoch_start(self, actor: nn.Module, epoch: int, stage_epochs: int, curriculum_config: dict,
|
564
570
|
global_epoch: int, global_epochs: int) -> None:
|
565
571
|
print(
|
@@ -582,11 +588,21 @@ class MrlPrintCallback(MrlTrainerCallback):
|
|
582
588
|
print(f'Epoch {global_epoch} | Starting update epoch {update_epoch}')
|
583
589
|
|
584
590
|
def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
|
585
|
-
|
591
|
+
if step != 0 and step % self.update_steps_interval == 0:
|
592
|
+
loss = sum(self.policy_losses) / len(self.policy_losses)
|
593
|
+
self.policy_losses = []
|
594
|
+
print(f'Epoch {epoch} | Steps {step - self.update_steps_interval} - {step} - mean policy loss {loss} | current policy loss {policy_loss}')
|
595
|
+
else:
|
596
|
+
self.policy_losses.append(policy_loss)
|
586
597
|
|
587
598
|
def on_critic_updated(self, actor: nn.Module, critic: nn.Module, epoch: int, step: int,
|
588
599
|
critic_loss: float) -> None:
|
589
|
-
|
600
|
+
if step != 0 and step % self.update_steps_interval == 0:
|
601
|
+
loss = sum(self.critic_losses) / len(self.critic_losses)
|
602
|
+
self.critic_losses = []
|
603
|
+
print(f'Epoch {epoch} | Steps {step - self.update_steps_interval} - {step} - mean critic loss {loss} | current critic loss {critic_loss}')
|
604
|
+
else:
|
605
|
+
self.critic_losses.append(critic_loss)
|
590
606
|
|
591
607
|
def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
|
592
608
|
print(f'Epoch {global_epoch} | Update epoch {update_epoch} - mean policy loss {policy_loss} | mean critic loss {critic_loss}')
|
@@ -204,7 +204,7 @@ class MrlActorModel(nn.Module):
|
|
204
204
|
return self.decoder(x, attention_mask=attention_mask)
|
205
205
|
else:
|
206
206
|
_, ed = self.encoder(x, attention_mask=attention_mask)
|
207
|
-
return self.memory_attention(ed
|
207
|
+
return self.memory_attention(ed)
|
208
208
|
|
209
209
|
|
210
210
|
class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
|
@@ -941,7 +941,7 @@ class MRLTrainer:
|
|
941
941
|
]
|
942
942
|
elif mode == 'fetch':
|
943
943
|
params = [
|
944
|
-
{'params': self.actor.embedding_parameters(), 'lr':
|
944
|
+
{'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
|
945
945
|
{'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
|
946
946
|
{'params': self.actor.encoder.memory_parameters(), 'lr': unfreeze_lr},
|
947
947
|
{'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
|
@@ -950,7 +950,7 @@ class MRLTrainer:
|
|
950
950
|
]
|
951
951
|
elif mode == 'joint':
|
952
952
|
params = [
|
953
|
-
{'params': self.actor.embedding_parameters(), 'lr':
|
953
|
+
{'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
|
954
954
|
{'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
|
955
955
|
{'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
|
956
956
|
{'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|