rxnn 0.2.45__tar.gz → 0.2.47__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.45 → rxnn-0.2.47}/PKG-INFO +1 -1
  2. {rxnn-0.2.45 → rxnn-0.2.47}/pyproject.toml +1 -1
  3. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/memory/attention.py +15 -8
  4. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/rxt/models.py +5 -3
  5. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/callbacks.py +24 -8
  6. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/models.py +1 -1
  7. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/mrl.py +2 -2
  8. {rxnn-0.2.45 → rxnn-0.2.47}/LICENSE +0 -0
  9. {rxnn-0.2.45 → rxnn-0.2.47}/README.md +0 -0
  10. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/.DS_Store +0 -0
  11. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/__init__.py +0 -0
  12. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/experimental/__init__.py +0 -0
  13. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/experimental/attention.py +0 -0
  14. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/experimental/models.py +0 -0
  15. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/experimental/moe.py +0 -0
  16. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/memory/__init__.py +0 -0
  17. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/memory/norm.py +0 -0
  18. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/memory/stm.py +0 -0
  19. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/rxt/__init__.py +0 -0
  20. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/__init__.py +0 -0
  21. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/base.py +0 -0
  22. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/bml.py +0 -0
  23. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/dataset.py +0 -0
  24. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/ddp.py +0 -0
  25. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/reward.py +0 -0
  26. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/rl.py +0 -0
  27. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/utils.py +0 -0
{rxnn-0.2.45 → rxnn-0.2.47}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.45
+Version: 0.2.47
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.45 → rxnn-0.2.47}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

 [tool.poetry]
 name = "rxnn"
-version = "0.2.45"
+version = "0.2.47"
 description = "RxNN: Reactive Neural Networks Platform"

 license = "Apache-2.0"
{rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/memory/attention.py
@@ -11,6 +11,7 @@ class StmMemoryAttention(nn.Module):
             use_gated_residual: bool = False,
             per_slot_gate: bool = False,
             init_gate: float = 0.0,
+            use_dynamic_gate: bool = False,
             *args,
             **kwargs
     ):
@@ -22,17 +23,26 @@ class StmMemoryAttention(nn.Module):
         self.num_layers = len(attention_layers)
         self.use_gated_residual = use_gated_residual
         self.per_slot_gate = per_slot_gate
+        self.use_dynamic_gate = use_dynamic_gate
         if self.use_gated_residual:
-            self.gate = nn.Parameter(torch.full((self.num_layers, self.stm.stm_size, 1), init_gate) if self.per_slot_gate else torch.full((self.num_layers,), init_gate))
+            gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
+            self.gate = nn.Parameter(torch.full(gate_shape, init_gate))

     def update_max_len(self, max_seq_len: int):
         for i in range(self.num_layers):
             if self.attention_layers[i].rope is not None:
                 self.attention_layers[i].rope.update_max_len(max_seq_len)

-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
-        mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() if attention_mask is not None else None
+    def _residual_gate(self, gate: torch.Tensor, layer_stm: torch.Tensor, new_layer_stm: torch.Tensor) -> torch.Tensor:
+        if self.use_dynamic_gate:
+            mean_dim = -1 if self.per_slot_gate else [1, 2]
+            gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
+            layer_gate = torch.sigmoid(gate_input)
+        else:
+            layer_gate = torch.sigmoid(gate)
+        return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm

+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         new_stm = torch.zeros_like(self.stm.memory)
         for i in range(self.num_layers):
             layer_stm = self.stm(i)
@@ -41,13 +51,10 @@ class StmMemoryAttention(nn.Module):
                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data)
             if self.use_gated_residual:
-                # gated residual
-                layer_gate = torch.sigmoid(self.gate[i])
-                new_stm[i] = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
+                new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm)  # gated residual
             else:
                 new_stm[i] = new_layer_stm + layer_stm  # residual
         self.stm.update_all(new_stm)
         return self.stm.memory
-
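The change above is twofold: StmMemoryAttention.forward no longer accepts an attention mask, and the gated residual gains an optional dynamic mode, where the gate is modulated by the content of the memory states instead of being a fixed learned constant. A minimal standalone sketch of the two gating modes, using hypothetical shapes (batch 2, 8 STM slots, 16 dims) and the per-slot configuration:

    import torch

    # Hypothetical single-layer tensors: (batch, stm_size, dim)
    layer_stm = torch.randn(2, 8, 16)      # previous short-term memory
    new_layer_stm = torch.randn(2, 8, 16)  # attention output over encoded data
    gate = torch.zeros(8, 1)               # per-slot gate parameter for one layer

    # Static mode: the gate is a learned constant, squashed by a sigmoid.
    static_gate = torch.sigmoid(gate)

    # Dynamic mode (new in 0.2.47): the same parameter is scaled by the mean
    # of the combined states, so the interpolation strength depends on the
    # current memory content. (With per_slot_gate=False the mean is taken
    # over dims [1, 2] instead of the last dim.)
    dynamic_gate = torch.sigmoid(gate * (new_layer_stm + layer_stm).mean(dim=-1, keepdim=True))

    # Both modes interpolate between the new and the previous memory state.
    updated = dynamic_gate * new_layer_stm + (1 - dynamic_gate) * layer_stm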
{rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/rxt/models.py
@@ -253,6 +253,7 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             use_gated_residual: bool = False,
             residual_per_slot_gate: bool = False,
             residual_init_gate: float = 0.0,
+            use_dynamic_residual_gate: bool = False,
             **kwargs,
     ):
         super(RxTAlphaMemoryAttention, self).__init__(**kwargs)
@@ -281,7 +282,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
         self.model = StmMemoryAttention(
             stm, attention_layers, memory_norm_layers,
-            use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate, init_gate=residual_init_gate
+            use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate,
+            init_gate=residual_init_gate, use_dynamic_gate=use_dynamic_residual_gate,
         )

     def freeze(self):
@@ -304,8 +306,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def clone_reset_memory(self):
         self.model.stm.clone_detach_reset()

-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
-        return self.model(x, attention_mask=attention_mask)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.model(x)

 class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) encoder model"""
{rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/callbacks.py
@@ -533,7 +533,7 @@ class MrlTrainerCallback:
                              reward: float) -> None:
         pass

-    def on_reward(self, actor: nn.Module, reward: float, generated: str, reference: str, saved_data: str, eval_mode: bool) -> None:
+    def on_reward(self, actor: nn.Module, rewards: list[float], generated: str, reference: str, saved_data: str, eval_mode: bool) -> None:
         pass

     def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
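Since on_reward now receives the full list of per-sample rewards instead of a single scalar, custom callbacks overriding this hook must update their signatures. A minimal subclass sketch (the logging body is illustrative, not from the package):

    import torch.nn as nn
    from rxnn.training.callbacks import MrlTrainerCallback

    class MeanRewardLogger(MrlTrainerCallback):
        # Signature matches the 0.2.47 hook: rewards is now list[float].
        def on_reward(self, actor: nn.Module, rewards: list[float], generated: str,
                      reference: str, saved_data: str, eval_mode: bool) -> None:
            if rewards:
                print(f'mean reward: {sum(rewards) / len(rewards):.4f}')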
@@ -560,33 +560,49 @@ class MrlTrainerCallback:


 class MrlPrintCallback(MrlTrainerCallback):
+    def __init__(self, update_steps_interval: int = 10) -> None:
+        super(MrlPrintCallback, self).__init__()
+        self.update_steps_interval = update_steps_interval
+        self.policy_losses = []
+        self.critic_losses = []
+
     def on_epoch_start(self, actor: nn.Module, epoch: int, stage_epochs: int, curriculum_config: dict,
                        global_epoch: int, global_epochs: int) -> None:
         print(
-            f'Starting epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global) for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')
+            f'Starting epoch {epoch}/{stage_epochs - 1} (stage) | {global_epoch}/{global_epochs} (global) for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')

     def on_epoch_end(self, actor: nn.Module, epoch: int, stage_epochs: int, policy_loss: float,
                      critic_loss: float, global_epoch: int, global_epochs: int) -> None:
-        print(f'Finished epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global)')
+        print(f'Finished epoch {epoch}/{stage_epochs - 1} (stage) | {global_epoch}/{global_epochs} (global)')
         print(f'Policy mean loss: {policy_loss} | Critic mean loss: {critic_loss}')

     def on_episode_collected(self, actor: nn.Module, batch_idx: int, episode_trajectories: list[dict],
                              reward: float) -> None:
         print(f'Collected {batch_idx} episode | mean reward {reward}')

-    def on_reward(self, actor: nn.Module, reward: float, generated: dict[str, torch.Tensor],
+    def on_reward(self, actor: nn.Module, rewards: list[float], generated: dict[str, torch.Tensor],
                   reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
-        print(f"{'Eval' if eval_mode else 'Train'} | Collected reward {reward}")
+        print(f"{'Eval' if eval_mode else 'Train'} | Mean reward: {sum(rewards) / len(rewards)} | All collected rewards: {rewards}")

     def on_update_epoch_start(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int) -> None:
         print(f'Epoch {global_epoch} | Starting update epoch {update_epoch}')

     def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
-        print(f'Epoch {epoch} | Step {step} - updated policy loss {policy_loss}')
+        if step != 0 and step % self.update_steps_interval == 0:
+            loss = sum(self.policy_losses) / len(self.policy_losses)
+            self.policy_losses = []
+            print(f'Epoch {epoch} | Steps {step - self.update_steps_interval} - {step} - mean policy loss {loss} | current policy loss {policy_loss}')
+        else:
+            self.policy_losses.append(policy_loss)

     def on_critic_updated(self, actor: nn.Module, critic: nn.Module, epoch: int, step: int,
                           critic_loss: float) -> None:
-        print(f'Epoch {epoch} | Step {step} - updated critic loss {critic_loss}')
+        if step != 0 and step % self.update_steps_interval == 0:
+            loss = sum(self.critic_losses) / len(self.critic_losses)
+            self.critic_losses = []
+            print(f'Epoch {epoch} | Steps {step - self.update_steps_interval} - {step} - mean critic loss {loss} | current critic loss {critic_loss}')
+        else:
+            self.critic_losses.append(critic_loss)

     def on_update_epoch_end(self, actor: nn.Module, critic: nn.Module, global_epoch: int, update_epoch: int, policy_loss: float, critic_loss: float) -> None:
         print(f'Epoch {global_epoch} | Update epoch {update_epoch} - mean policy loss {policy_loss} | mean critic loss {critic_loss}')
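The reworked MrlPrintCallback buffers per-step losses and prints a windowed mean every update_steps_interval steps instead of logging each step; note that on the interval step itself the current loss is printed alongside the window mean but not added to the next window. A usage sketch (the simulated loop and loss values are illustrative):

    from rxnn.training.callbacks import MrlPrintCallback

    # Print a mean over each window of 10 policy-update steps.
    callback = MrlPrintCallback(update_steps_interval=10)

    # Simulated update loop: steps 1-9 are buffered, step 10 prints the
    # window mean. The actor argument is unused by the print callback.
    for step in range(1, 21):
        callback.on_batch_updated(actor=None, epoch=0, step=step, policy_loss=0.5 / step)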
@@ -780,7 +796,7 @@ class MrlGeneratedTokensCallback(MrlTrainerCallback):
         self.steps_log_interval = steps_log_interval
         self.step = 0

-    def on_reward(self, actor: nn.Module, reward: float, generated: dict[str, torch.Tensor],
+    def on_reward(self, actor: nn.Module, rewards: list[float], generated: dict[str, torch.Tensor],
                   reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
         self.step += 1
         attention_mask = generated['attention_mask']
{rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/models.py
@@ -204,7 +204,7 @@ class MrlActorModel(nn.Module):
             return self.decoder(x, attention_mask=attention_mask)
         else:
             _, ed = self.encoder(x, attention_mask=attention_mask)
-            return self.memory_attention(ed, attention_mask=attention_mask)
+            return self.memory_attention(ed)


 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
{rxnn-0.2.45 → rxnn-0.2.47}/src/rxnn/training/mrl.py
@@ -941,7 +941,7 @@ class MRLTrainer:
             ]
         elif mode == 'fetch':
             params = [
-                {'params': self.actor.embedding_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
@@ -950,7 +950,7 @@ class MRLTrainer:
             ]
         elif mode == 'joint':
             params = [
-                {'params': self.actor.embedding_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
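The mrl.py change is a bug fix: in the 'fetch' and 'joint' unfreeze modes, the embedding parameter group previously received unfreeze_lr instead of the dedicated embedding_lr. A self-contained sketch of how such per-group learning rates reach the optimizer, with toy modules standing in for the actor's parameter groups (the optimizer choice is an assumption):

    import torch
    import torch.nn as nn

    # Toy stand-ins for the actor's parameter groups (hypothetical modules).
    embedding = nn.Embedding(100, 16)
    encoder = nn.Linear(16, 16)

    embedding_lr, unfreeze_lr = 1e-5, 1e-4  # illustrative values

    optimizer = torch.optim.AdamW([
        # Before the fix this group used unfreeze_lr in 'fetch'/'joint' modes.
        {'params': embedding.parameters(), 'lr': embedding_lr},
        {'params': encoder.parameters(), 'lr': unfreeze_lr},
    ])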