rxnn 0.2.46__tar.gz → 0.2.48__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.46 → rxnn-0.2.48}/PKG-INFO +1 -1
- {rxnn-0.2.46 → rxnn-0.2.48}/pyproject.toml +1 -1
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/memory/attention.py +11 -8
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/memory/norm.py +3 -1
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/rxt/models.py +2 -2
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/callbacks.py +18 -2
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/models.py +1 -1
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/mrl.py +23 -5
- {rxnn-0.2.46 → rxnn-0.2.48}/LICENSE +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/README.md +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/utils.py +0 -0
{rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/memory/attention.py

@@ -33,9 +33,16 @@ class StmMemoryAttention(nn.Module):
             if self.attention_layers[i].rope is not None:
                 self.attention_layers[i].rope.update_max_len(max_seq_len)

-    def
-
+    def _residual_gate(self, gate: torch.Tensor, layer_stm: torch.Tensor, new_layer_stm: torch.Tensor) -> torch.Tensor:
+        if self.use_dynamic_gate:
+            mean_dim = -1 if self.per_slot_gate else [1, 2]
+            gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
+            layer_gate = torch.sigmoid(gate_input)
+        else:
+            layer_gate = torch.sigmoid(gate)
+        return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm

+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         new_stm = torch.zeros_like(self.stm.memory)
         for i in range(self.num_layers):
             layer_stm = self.stm(i)

@@ -44,14 +51,10 @@ class StmMemoryAttention(nn.Module):
                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data)
             if self.use_gated_residual:
-                # gated residual
-                gate_input = self.gate[i] * (new_layer_stm + layer_stm) if self.use_dynamic_gate else self.gate[i]
-                layer_gate = torch.sigmoid(gate_input)
-                new_stm[i] = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
+                new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm)  # gated residual
             else:
                 new_stm[i] = new_layer_stm + layer_stm  # residual
         self.stm.update_all(new_stm)
         return self.stm.memory
-
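Note on the refactored gating above: the new _residual_gate method folds the dynamic and static gate paths into one place and, in the dynamic case, now averages the summed memories (per slot or per layer) before applying the gate. A minimal standalone sketch of that computation; the function name and tensor shapes below are illustrative assumptions, not values taken from the package:

import torch

def residual_gate(gate, layer_stm, new_layer_stm, use_dynamic_gate: bool, per_slot_gate: bool):
    # Dynamic gate: scale the learned gate by the mean of the summed memories,
    # per slot (last dim) or per layer (dims 1 and 2), then squash with sigmoid.
    if use_dynamic_gate:
        mean_dim = -1 if per_slot_gate else [1, 2]
        layer_gate = torch.sigmoid(gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True))
    else:
        # Static gate: sigmoid of the learned gate parameter alone.
        layer_gate = torch.sigmoid(gate)
    # Convex blend of the updated and previous short-term memory.
    return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm

# Illustrative shapes only: batch of 2, 8 memory slots, model dim 16, per-slot gate parameter.
old_stm, new_stm = torch.randn(2, 8, 16), torch.randn(2, 8, 16)
gate_param = torch.zeros(8, 1)
blended = residual_gate(gate_param, old_stm, new_stm, use_dynamic_gate=True, per_slot_gate=True)
assert blended.shape == old_stm.shape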
{rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/memory/norm.py

@@ -163,7 +163,7 @@ def init_memory_norm(
         init_scale: float = 1.0,
         per_dim_scale: bool = False,
 ) -> nn.Module:
-    assert norm_type in ['layer', 'rms', 'adaptive', 'positional']
+    assert norm_type in ['layer', 'rms', 'adaptive', 'positional', 'classic-rms']
     if norm_type == 'layer':
         return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
     elif norm_type == 'rms':

@@ -172,4 +172,6 @@ def init_memory_norm(
         return AdaptiveRMSMemoryNorm(dim, use_gate, decay, init_scale, init_gate)
     elif norm_type == 'positional':
         return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate, per_dim_scale)
+    elif norm_type == 'classic-rms':
+        return nn.RMSNorm(dim)
     return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
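The new 'classic-rms' branch above returns PyTorch's built-in torch.nn.RMSNorm (available in PyTorch 2.4+). A minimal sketch of what that normalization does, with assumed shapes:

import torch
import torch.nn as nn

dim = 16
norm = nn.RMSNorm(dim)                  # what the 'classic-rms' branch returns
x = torch.randn(2, 8, dim)              # assumed (batch, slots, dim) memory tensor
y = norm(x)

# Rough manual equivalent: divide each vector by its root-mean-square and apply the
# learned per-dim weight (exact eps handling inside nn.RMSNorm may differ slightly).
manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6) * norm.weight
print(torch.allclose(y, manual, atol=1e-4))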
{rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/rxt/models.py

@@ -306,8 +306,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def clone_reset_memory(self):
         self.model.stm.clone_detach_reset()

-    def forward(self, x: torch.Tensor
-        return self.model(x
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.model(x)

 class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) encoder model"""
{rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/callbacks.py

@@ -560,6 +560,12 @@ class MrlTrainerCallback:


 class MrlPrintCallback(MrlTrainerCallback):
+    def __init__(self, update_steps_interval: int = 10) -> None:
+        super(MrlPrintCallback, self).__init__()
+        self.update_steps_interval = update_steps_interval
+        self.policy_losses = []
+        self.critic_losses = []
+
     def on_epoch_start(self, actor: nn.Module, epoch: int, stage_epochs: int, curriculum_config: dict,
                        global_epoch: int, global_epochs: int) -> None:
         print(

@@ -582,11 +588,21 @@ class MrlPrintCallback(MrlTrainerCallback):
         print(f'Epoch {global_epoch} | Starting update epoch {update_epoch}')

     def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
-
+        if step != 0 and step % self.update_steps_interval == 0:
+            loss = sum(self.policy_losses) / len(self.policy_losses)
+            self.policy_losses = []
+            print(f'Epoch {epoch} | Steps {step - self.update_steps_interval} - {step} - mean policy loss {loss} | current policy loss {policy_loss}')
+        else:
+            self.policy_losses.append(policy_loss)

     def on_critic_updated(self, actor: nn.Module, critic: nn.Module, epoch: int, step: int,
                           critic_loss: float) -> None:
-
+        if step != 0 and step % self.update_steps_interval == 0:
+            loss = sum(self.critic_losses) / len(self.critic_losses)
+            self.critic_losses = []
+            print(f'Epoch {epoch} | Steps {step - self.update_steps_interval} - {step} - mean critic loss {loss} | current critic loss {critic_loss}')
+        else:
+            self.critic_losses.append(critic_loss)

     def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
         print(f'Epoch {global_epoch} | Update epoch {update_epoch} - mean policy loss {policy_loss} | mean critic loss {critic_loss}')
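The reworked print callback above buffers per-step losses and prints a mean every update_steps_interval steps instead of logging every batch. A standalone sketch of that buffering pattern; the loop and loss values below are made up purely for illustration:

losses: list[float] = []
update_steps_interval = 10

for step in range(35):
    policy_loss = 0.5 / (step + 1)  # fake loss value for the demo
    # As in the callback, the loss at an interval step is reported but not buffered.
    if step != 0 and step % update_steps_interval == 0:
        mean_loss = sum(losses) / len(losses)
        losses = []
        print(f'Steps {step - update_steps_interval} - {step} - mean policy loss {mean_loss:.4f}')
    else:
        losses.append(policy_loss)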
{rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/models.py

@@ -204,7 +204,7 @@ class MrlActorModel(nn.Module):
             return self.decoder(x, attention_mask=attention_mask)
         else:
             _, ed = self.encoder(x, attention_mask=attention_mask)
-            return self.memory_attention(ed
+            return self.memory_attention(ed)


 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
{rxnn-0.2.46 → rxnn-0.2.48}/src/rxnn/training/mrl.py

@@ -35,6 +35,7 @@ class MrlConfig(TypedDict):
     moe_aux_loss_scale: Optional[float]
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
+    use_memory_warmup: Optional[bool]


 class MrlStrategy(Enum):
@@ -136,6 +137,7 @@ class MRLTrainer:
         self.moe_aux_loss_scale = config.get('moe_aux_loss_scale', 0.01)
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
+        self.use_memory_warmup = config.get('use_memory_warmup', False)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -381,6 +383,11 @@ class MRLTrainer:
             self.writer.add_scalar(f'Collect/episode reward (steps: {self.curriculum_steps})', avg_reward,
                                    self.stage_step['collect'])

+    def memory_warmup(self, query: TokenizedDict, answer: TokenizedDict):
+        if self.use_memory_warmup:
+            with torch.no_grad():
+                self.encode_and_update_stm(query, answer)
+
     def collect_trajectories(self, dataloader: DataLoader, epoch: int, batch_size: int) -> list[MrlTrajectoryEpisode]:
         """Collect trajectories for PPO for current curriculum step."""
         # 1. Init trajectories list
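The new memory_warmup above primes the short-term memory under torch.no_grad() before the first gradient-tracked update, and the later hunks only call it when the STM was just reset. A schematic sketch of the resulting call order; the helpers below are stand-ins, not the trainer's real methods:

import torch

grad_states = []

def encode_and_update_stm(query, answer):
    # Stand-in for the trainer's STM update; records whether gradients are tracked.
    grad_states.append(torch.is_grad_enabled())

def memory_warmup(query, answer, use_memory_warmup: bool = True):
    if use_memory_warmup:
        with torch.no_grad():
            encode_and_update_stm(query, answer)

# Mirrors the collect/eval flow after an STM reset: warm up without gradients,
# then run the regular update on the same first interaction.
first_query = {'input_ids': torch.zeros(1, 4, dtype=torch.long)}
first_answer = {'input_ids': torch.ones(1, 4, dtype=torch.long)}
memory_warmup(first_query, first_answer)
encode_and_update_stm(first_query, first_answer)
print(grad_states)  # [False, True]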
@@ -402,8 +409,13 @@ class MRLTrainer:
             first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
             interactions = interactions[:self.curriculum_steps]
             interactions_len = len(interactions)
+
+            first_interaction = self._move_multiple_batches(first_query, first_answer)
+
+            if reset_done:
+                self.memory_warmup(*first_interaction)
             # 6. Encode and update STM with data to save from first interaction
-            self.encode_and_update_stm(*
+            self.encode_and_update_stm(*first_interaction)

             # 7. Save first interaction as data to save (for trajectory state)
             query, answer = first_query, first_answer
@@ -649,6 +661,9 @@ class MRLTrainer:

             self.actor.clone_reset_memory()

+            if should_reset_stm and step_idx == 0:
+                self.memory_warmup(query, answer)
+
             # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
             if self.memory_aware_critic:
                 self.encode_and_update_stm(query, answer)
@@ -798,13 +813,16 @@ class MRLTrainer:
             if batch['query']['input_ids'].size(0) == batch_size:
                 self._increment_steps('eval')
                 # 3. Reset STM with random resets ratio and reward model running mean
-                self.reset_stm()
+                reset_stm = self.reset_stm()
                 self.reward.reset_running_mean()

                 # 4. Get batches for first queries, answers and all follow-up interactions
                 first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
                 # 5. Encode and update STM with initial interactions (batch)
-                self.
+                first_interaction = self._move_multiple_batches(first_query, first_answer)
+                if reset_stm:
+                    self.memory_warmup(*first_interaction)
+                self.encode_and_update_stm(*first_interaction)

                 # 6. Save follow-up interactions len and first query and answer as previous one for iteration
                 interactions_len = len(interactions)
@@ -941,7 +959,7 @@ class MRLTrainer:
             ]
         elif mode == 'fetch':
             params = [
-                {'params': self.actor.embedding_parameters(), 'lr':
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
@@ -950,7 +968,7 @@ class MRLTrainer:
             ]
         elif mode == 'joint':
             params = [
-                {'params': self.actor.embedding_parameters(), 'lr':
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
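Both the 'fetch' and 'joint' branches above now pass embedding_lr through as the embedding parameter group's learning rate. For context, dicts in this format are what PyTorch optimizers accept for per-group learning rates; a minimal sketch with made-up modules and values:

import torch
from torch import nn

model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16))  # stand-in modules
embedding_lr, unfreeze_lr = 1e-5, 1e-4                           # illustrative values

params = [
    {'params': model[0].parameters(), 'lr': embedding_lr},  # embeddings get their own LR
    {'params': model[1].parameters(), 'lr': unfreeze_lr},
]
optimizer = torch.optim.AdamW(params)
print([group['lr'] for group in optimizer.param_groups])  # [1e-05, 0.0001]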