rxnn 0.2.46__py3-none-any.whl → 0.2.48__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/attention.py CHANGED
@@ -33,9 +33,16 @@ class StmMemoryAttention(nn.Module):
             if self.attention_layers[i].rope is not None:
                 self.attention_layers[i].rope.update_max_len(max_seq_len)
 
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
-        mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() if attention_mask is not None else None
+    def _residual_gate(self, gate: torch.Tensor, layer_stm: torch.Tensor, new_layer_stm: torch.Tensor) -> torch.Tensor:
+        if self.use_dynamic_gate:
+            mean_dim = -1 if self.per_slot_gate else [1, 2]
+            gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
+            layer_gate = torch.sigmoid(gate_input)
+        else:
+            layer_gate = torch.sigmoid(gate)
+        return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
 
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         new_stm = torch.zeros_like(self.stm.memory)
         for i in range(self.num_layers):
             layer_stm = self.stm(i)
@@ -44,14 +51,10 @@ class StmMemoryAttention(nn.Module):
                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data)
             if self.use_gated_residual:
-                # gated residual
-                gate_input = self.gate[i] * (new_layer_stm + layer_stm) if self.use_dynamic_gate else self.gate[i]
-                layer_gate = torch.sigmoid(gate_input)
-                new_stm[i] = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
+                new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm)  # gated residual
             else:
                 new_stm[i] = new_layer_stm + layer_stm  # residual
         self.stm.update_all(new_stm)
         return self.stm.memory
-
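Note on the change above: the gated-residual logic that was previously inlined in forward() now lives in a dedicated _residual_gate helper, and the dynamic variant gates on the mean of the combined memory states (per slot, or over slots and features) rather than on the raw sum; the attention_mask argument is dropped from the memory-attention path entirely. A minimal standalone sketch of the new gating behaviour (tensor shapes and the per-slot gate parameter below are illustrative assumptions, not values taken from the library):

```python
import torch

# Standalone sketch of the dynamic residual gate, assuming STM layers shaped
# (batch, slots, dim) and a per-slot gate parameter.
def residual_gate(gate, layer_stm, new_layer_stm, use_dynamic_gate=True, per_slot_gate=True):
    if use_dynamic_gate:
        # Per-slot gating averages over the feature dim; otherwise over slots and features.
        mean_dim = -1 if per_slot_gate else [1, 2]
        gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
        layer_gate = torch.sigmoid(gate_input)
    else:
        layer_gate = torch.sigmoid(gate)
    # Convex combination of updated and previous memory.
    return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm

old_stm = torch.randn(2, 8, 64)   # hypothetical (batch, slots, dim) memory layer
new_stm = torch.randn(2, 8, 64)
gate = torch.zeros(8, 1)          # hypothetical per-slot gate parameter
print(residual_gate(gate, old_stm, new_stm).shape)  # torch.Size([2, 8, 64])
```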
rxnn/memory/norm.py CHANGED
@@ -163,7 +163,7 @@ def init_memory_norm(
         init_scale: float = 1.0,
         per_dim_scale: bool = False,
 ) -> nn.Module:
-    assert norm_type in ['layer', 'rms', 'adaptive', 'positional']
+    assert norm_type in ['layer', 'rms', 'adaptive', 'positional', 'classic-rms']
     if norm_type == 'layer':
         return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
     elif norm_type == 'rms':
@@ -172,4 +172,6 @@ def init_memory_norm(
         return AdaptiveRMSMemoryNorm(dim, use_gate, decay, init_scale, init_gate)
     elif norm_type == 'positional':
         return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate, per_dim_scale)
+    elif norm_type == 'classic-rms':
+        return nn.RMSNorm(dim)
     return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
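The new 'classic-rms' option bypasses the library's memory-specific norms and simply returns PyTorch's built-in nn.RMSNorm (available in PyTorch 2.4+). A rough sketch of what the new branch produces, with an assumed slot/dimension configuration (the commented call mirrors the signature shown in the diff but its full argument list is not reproduced here):

```python
import torch
import torch.nn as nn

# Assumed call into the factory shown above:
# norm = init_memory_norm('classic-rms', dim=512, num_slots=8)
# which, per the new branch, is equivalent to:
norm = nn.RMSNorm(512)                # plain RMSNorm over the feature dim
stm_layer = torch.randn(2, 8, 512)    # hypothetical (batch, slots, dim) memory layer
print(norm(stm_layer).shape)          # torch.Size([2, 8, 512])
```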
rxnn/rxt/models.py CHANGED
@@ -306,8 +306,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def clone_reset_memory(self):
         self.model.stm.clone_detach_reset()
 
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
-        return self.model(x, attention_mask=attention_mask)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.model(x)
 
 class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) encoder model"""
rxnn/training/callbacks.py CHANGED
@@ -560,6 +560,12 @@ class MrlTrainerCallback:
 
 
 class MrlPrintCallback(MrlTrainerCallback):
+    def __init__(self, update_steps_interval: int = 10) -> None:
+        super(MrlPrintCallback, self).__init__()
+        self.update_steps_interval = update_steps_interval
+        self.policy_losses = []
+        self.critic_losses = []
+
     def on_epoch_start(self, actor: nn.Module, epoch: int, stage_epochs: int, curriculum_config: dict,
                        global_epoch: int, global_epochs: int) -> None:
         print(
@@ -582,11 +588,21 @@ class MrlPrintCallback(MrlTrainerCallback):
         print(f'Epoch {global_epoch} | Starting update epoch {update_epoch}')
 
     def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
-        print(f'Epoch {epoch} | Step {step} - updated policy loss {policy_loss}')
+        if step != 0 and step % self.update_steps_interval == 0:
+            loss = sum(self.policy_losses) / len(self.policy_losses)
+            self.policy_losses = []
+            print(f'Epoch {epoch} | Steps {step - self.update_steps_interval} - {step} - mean policy loss {loss} | current policy loss {policy_loss}')
+        else:
+            self.policy_losses.append(policy_loss)
 
     def on_critic_updated(self, actor: nn.Module, critic: nn.Module, epoch: int, step: int,
                           critic_loss: float) -> None:
-        print(f'Epoch {epoch} | Step {step} - updated critic loss {critic_loss}')
+        if step != 0 and step % self.update_steps_interval == 0:
+            loss = sum(self.critic_losses) / len(self.critic_losses)
+            self.critic_losses = []
+            print(f'Epoch {epoch} | Steps {step - self.update_steps_interval} - {step} - mean critic loss {loss} | current critic loss {critic_loss}')
+        else:
+            self.critic_losses.append(critic_loss)
 
     def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
         print(f'Epoch {global_epoch} | Update epoch {update_epoch} - mean policy loss {policy_loss} | mean critic loss {critic_loss}')
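MrlPrintCallback now buffers per-step losses and prints a windowed mean every update_steps_interval steps instead of one line per policy or critic update. A hedged sketch that drives the callback by hand to show the effect, assuming the class lives in rxnn.training.callbacks as the RECORD and hunk context suggest (the dummy actor and loss values are made up; in real use the MRL trainer invokes these hooks):

```python
import torch.nn as nn
from rxnn.training.callbacks import MrlPrintCallback  # assumed import path

callback = MrlPrintCallback(update_steps_interval=5)
dummy_actor = nn.Identity()  # placeholder; the callback only prints, it never calls the actor

# Fabricated policy losses: steps 1-4 are buffered, step 5 prints their mean,
# then the buffer starts filling again.
for step, loss in enumerate([0.90, 0.82, 0.77, 0.73, 0.70, 0.68, 0.66, 0.65, 0.63, 0.62]):
    callback.on_batch_updated(dummy_actor, epoch=0, step=step, policy_loss=loss)
```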
rxnn/training/models.py CHANGED
@@ -204,7 +204,7 @@ class MrlActorModel(nn.Module):
             return self.decoder(x, attention_mask=attention_mask)
         else:
             _, ed = self.encoder(x, attention_mask=attention_mask)
-            return self.memory_attention(ed, attention_mask=attention_mask)
+            return self.memory_attention(ed)
 
 
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
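Taken together with the StmMemoryAttention and RxTAlphaMemoryAttention changes above, this removes the attention_mask argument from the whole memory-update path: the actor still encodes the interaction with a mask, but passes only the per-layer encoded data to memory attention. A schematic sketch of the new call contract (the stub module and shapes below are illustrative assumptions, not the real classes):

```python
import torch
import torch.nn as nn

# Stand-in for the memory-attention module: the point is only the call shape,
# which now receives encoder layer data without an attention_mask argument.
class MemoryAttentionStub(nn.Module):
    def forward(self, encoded_layer_data: torch.Tensor) -> torch.Tensor:
        return encoded_layer_data  # the real module updates STM and returns it

memory_attention = MemoryAttentionStub()
ed = torch.randn(3, 2, 16, 64)       # hypothetical (layers, batch, seq, dim) encoder output
updated_stm = memory_attention(ed)   # no attention_mask keyword anymore
```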
rxnn/training/mrl.py CHANGED
@@ -35,6 +35,7 @@ class MrlConfig(TypedDict):
     moe_aux_loss_scale: Optional[float]
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
+    use_memory_warmup: Optional[bool]
 
 
 class MrlStrategy(Enum):
@@ -136,6 +137,7 @@ class MRLTrainer:
         self.moe_aux_loss_scale = config.get('moe_aux_loss_scale', 0.01)
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
+        self.use_memory_warmup = config.get('use_memory_warmup', False)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -381,6 +383,11 @@ class MRLTrainer:
             self.writer.add_scalar(f'Collect/episode reward (steps: {self.curriculum_steps})', avg_reward,
                                    self.stage_step['collect'])
 
+    def memory_warmup(self, query: TokenizedDict, answer: TokenizedDict):
+        if self.use_memory_warmup:
+            with torch.no_grad():
+                self.encode_and_update_stm(query, answer)
+
     def collect_trajectories(self, dataloader: DataLoader, epoch: int, batch_size: int) -> list[MrlTrajectoryEpisode]:
         """Collect trajectories for PPO for current curriculum step."""
         # 1. Init trajectories list
@@ -402,8 +409,13 @@ class MRLTrainer:
                 first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
                 interactions = interactions[:self.curriculum_steps]
                 interactions_len = len(interactions)
+
+                first_interaction = self._move_multiple_batches(first_query, first_answer)
+
+                if reset_done:
+                    self.memory_warmup(*first_interaction)
                 # 6. Encode and update STM with data to save from first interaction
-                self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+                self.encode_and_update_stm(*first_interaction)
 
                 # 7. Save first interaction as data to save (for trajectory state)
                 query, answer = first_query, first_answer
@@ -649,6 +661,9 @@ class MRLTrainer:
 
                 self.actor.clone_reset_memory()
 
+                if should_reset_stm and step_idx == 0:
+                    self.memory_warmup(query, answer)
+
                 # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
                 if self.memory_aware_critic:
                     self.encode_and_update_stm(query, answer)
@@ -798,13 +813,16 @@ class MRLTrainer:
                 if batch['query']['input_ids'].size(0) == batch_size:
                     self._increment_steps('eval')
                     # 3. Reset STM with random resets ratio and reward model running mean
-                    self.reset_stm()
+                    reset_stm = self.reset_stm()
                     self.reward.reset_running_mean()
 
                     # 4. Get batches for first queries, answers and all follow-up interactions
                     first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
                     # 5. Encode and update STM with initial interactions (batch)
-                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+                    first_interaction = self._move_multiple_batches(first_query, first_answer)
+                    if reset_stm:
+                        self.memory_warmup(*first_interaction)
+                    self.encode_and_update_stm(*first_interaction)
 
                     # 6. Save follow-up interactions len and first query and answer as previous one for iteration
                     interactions_len = len(interactions)
@@ -941,7 +959,7 @@ class MRLTrainer:
             ]
         elif mode == 'fetch':
             params = [
-                {'params': self.actor.embedding_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
@@ -950,7 +968,7 @@ class MRLTrainer:
             ]
        elif mode == 'joint':
             params = [
-                {'params': self.actor.embedding_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
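Beyond the mask removal, the MRL changes add an optional memory warm-up: when use_memory_warmup is enabled and the STM has just been reset, the trainer runs one no-grad encode_and_update_stm pass on the first interaction before the regular gradient-carrying update, in trajectory collection, in the inner update loop, and in evaluation. The last two hunks also fix the 'fetch' and 'joint' unfreeze modes to apply embedding_lr to embedding parameters instead of unfreeze_lr. A hedged configuration sketch, showing only the keys relevant here (values are placeholders; real MrlConfig usage needs the rest of the TypedDict fields):

```python
from rxnn.training.mrl import MrlConfig  # import path as listed in the wheel RECORD

# Partial config sketch: only keys visible in the hunks above are shown.
mrl_config: MrlConfig = {
    'freeze_embeddings': False,
    'embedding_lr': 1e-5,        # now actually used in 'fetch'/'joint' unfreeze modes
    'use_memory_warmup': True,   # new in 0.2.48: no-grad STM warm-up pass after memory resets
}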
rxnn-0.2.46.dist-info/METADATA → rxnn-0.2.48.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.46
+Version: 0.2.48
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.46.dist-info/RECORD → rxnn-0.2.48.dist-info/RECORD CHANGED
@@ -5,19 +5,19 @@ rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-c
 rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=sXh6f_iOpEYCaqyG-QVp_C_A9IF0QcXTi3hW5G8FCwA,2630
-rxnn/memory/norm.py,sha256=E98jOQEuIOFFhlkvS8s4fFN-D4tLO6vaOqnObv1oVmA,6592
+rxnn/memory/attention.py,sha256=kan6UNPTjLfO7zKNp92hGooldgWPi3li_2-_L5xiErs,2784
+rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=4MbCL4xGY3ceewZQmopjmwAyLQS92L6KLOPqaW7-Fho,14673
+rxnn/rxt/models.py,sha256=new_YXLe9vfIBPX-pmFRoV523d7yCjEgfTY06EaH3Ms,14605
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
-rxnn/training/callbacks.py,sha256=RPW3Lisi31VJvoYyZeAF3dQzttrceDQDsZ6G5Xl09HM,35933
+rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36779
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
-rxnn/training/models.py,sha256=tqABOt_xEcWbZNEW2I2Jt-3eyaGICK011zILwuTk6Zc,9082
-rxnn/training/mrl.py,sha256=L4G7xSPlxsymvNhvsSloCpaqYjOXxEm7GmKilM_Ojvc,59809
+rxnn/training/models.py,sha256=L2emJM06u7B9f9T1dFsGXzXX-rsV77ND7L1pAM9Z_Ow,9051
+rxnn/training/mrl.py,sha256=cTVdNmyohiz4BB6NsmT1CWzFCbSgO7DCD7tfffoYEpc,60558
 rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
 rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.46.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.46.dist-info/METADATA,sha256=hpTQT4p75cKrAaGOz_56gCBm1rT_y-Nr1TI9Mhv6wv0,25960
-rxnn-0.2.46.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.46.dist-info/RECORD,,
+rxnn-0.2.48.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.48.dist-info/METADATA,sha256=IJUCcjg8vteeX8WkLEzwbciH814TOzdXPKikdb5xDgw,25960
+rxnn-0.2.48.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.48.dist-info/RECORD,,