rxnn 0.2.56__tar.gz → 0.2.58__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.56 → rxnn-0.2.58}/PKG-INFO +1 -1
  2. {rxnn-0.2.56 → rxnn-0.2.58}/pyproject.toml +1 -1
  3. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/memory/attention.py +3 -2
  4. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/rxt/models.py +2 -2
  5. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/dataset.py +1 -1
  6. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/models.py +1 -1
  7. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/mrl.py +6 -4
  8. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/rl.py +16 -8
  9. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/transformers/layers.py +5 -1
  10. {rxnn-0.2.56 → rxnn-0.2.58}/LICENSE +0 -0
  11. {rxnn-0.2.56 → rxnn-0.2.58}/README.md +0 -0
  12. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/.DS_Store +0 -0
  13. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/__init__.py +0 -0
  14. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/experimental/__init__.py +0 -0
  15. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/experimental/attention.py +0 -0
  16. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/experimental/models.py +0 -0
  17. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/experimental/moe.py +0 -0
  18. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/memory/__init__.py +0 -0
  19. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/memory/norm.py +0 -0
  20. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/memory/stm.py +0 -0
  21. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/rxt/__init__.py +0 -0
  22. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/__init__.py +0 -0
  23. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/base.py +0 -0
  24. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/bml.py +0 -0
  25. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/callbacks.py +0 -0
  26. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/ddp.py +0 -0
  27. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/reward.py +0 -0
  28. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/scheduler.py +0 -0
  29. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/tokenizer.py +0 -0
  30. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/utils.py +0 -0
  31. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/transformers/__init__.py +0 -0
  32. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/transformers/attention.py +0 -0
  33. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/transformers/ff.py +0 -0
  34. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/utils.py +0 -0
{rxnn-0.2.56 → rxnn-0.2.58}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.2.56
+ Version: 0.2.58
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.56 → rxnn-0.2.58}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

  [tool.poetry]
  name = "rxnn"
- version = "0.2.56"
+ version = "0.2.58"
  description = "RxNN: Reactive Neural Networks Platform"

  license = "Apache-2.0"
{rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/memory/attention.py
@@ -47,7 +47,8 @@ class StmMemoryAttention(nn.Module):
  else:
  return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm

- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+ mem_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() if attention_mask else None
  new_stm = torch.zeros_like(self.stm.memory)
  for i in range(self.num_layers):
  layer_stm = self.stm(i)
@@ -56,7 +57,7 @@ class StmMemoryAttention(nn.Module):
  layer_stm = layer_stm.expand(x.size(0), -1, -1)
  encoded_layer_data = x[i]
  normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
- new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data)
+ new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mem_mask)
  if self.use_gated_residual:
  new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm) # gated residual
  else:
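
StmMemoryAttention.forward now accepts an optional padding mask and reshapes it into a key mask over the encoded layer data before every per-layer memory attention call. Below is a minimal shape sketch of that reshaping, assuming a [batch, seq_len] mask with 1 = real token; the sketch guards with "is not None", since a bare "if attention_mask" truth test on a multi-element tensor raises an error in PyTorch.

    import torch

    # hypothetical mask: two sequences of length 6 with right padding (1 = real token)
    attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                                   [1, 1, 1, 0, 0, 0]])

    # same reshaping as the diff: [batch, seq_len] -> [batch, 1, 1, seq_len] boolean key mask
    mem_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() if attention_mask is not None else None
    print(mem_mask.shape)  # torch.Size([2, 1, 1, 6])

In that shape the mask broadcasts across attention heads and across the STM query slots, so padded encoder positions no longer contribute keys or values to the memory update.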
{rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/rxt/models.py
@@ -306,8 +306,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
  def clone_reset_memory(self):
  self.model.stm.clone_detach_reset()

- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return self.model(x)
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+ return self.model(x, attention_mask=attention_mask)

  class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
  """RxT-Alpha (Reactive Transformer) encoder model"""
{rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/dataset.py
@@ -943,7 +943,7 @@ class MrlCurriculumDataset(Dataset):
  return self.get_tokenized_item(idx)

  def __len__(self) -> int:
- return len(self.episodes)
+ return len(self.inputs if self.is_pre_tokenized else self.episodes)

  def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "MRlCurriculumDataset":
  split_point = int(len(self.episodes) * ((1 - size) if not from_start else size))
{rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/models.py
@@ -204,7 +204,7 @@ class MrlActorModel(nn.Module):
  return self.decoder(x, attention_mask=attention_mask)
  else:
  _, ed = self.encoder(x, attention_mask=attention_mask)
- return self.memory_attention(ed)
+ return self.memory_attention(ed, attention_mask=attention_mask)


  class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
{rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/mrl.py
@@ -36,6 +36,8 @@ class MrlConfig(TypedDict):
  freeze_embeddings: Optional[bool]
  embedding_lr: Optional[float]
  use_memory_warmup: Optional[bool]
+ debug_mode: Optional[bool]
+ debug_interval: Optional[int]


  class MrlStrategy(Enum):
@@ -109,7 +111,6 @@ class MRLTrainer:
  use_ddp: bool = False,
  use_amp: bool = False,
  dtype: torch.dtype = torch.float32,
- debug_mode: bool = False,
  ):
  """
  Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.
@@ -140,7 +141,8 @@ class MRLTrainer:
  self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
  self.freeze_embeddings = self.shared_freeze_embeddings
  self.use_memory_warmup = config.get('use_memory_warmup', False)
- self.debug_mode = debug_mode
+ self.debug_mode = config.get('debug_mode', False)
+ self.debug_interval = config.get('debug_interval', 10)
  # Internal update epochs config
  self.shared_update_epochs = config.get('update_epochs', 10)
  self.update_epochs = self.shared_update_epochs
@@ -606,7 +608,7 @@ class MRLTrainer:
  self.scaler.unscale_(self.optimizer)
  torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
  error_if_nonfinite=False)
- if self.debug_mode:
+ if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
  self._log_gradients()
  # 4.5 Run scaled optimization step
  self.scaler.step(self.optimizer)
@@ -625,7 +627,7 @@ class MRLTrainer:
  # 4.4 Clip gradient norms
  torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
  error_if_nonfinite=False)
- if self.debug_mode:
+ if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
  self._log_gradients()
  # 4.5 Run scaled optimization step
  self.optimizer.step()
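
With this release debug_mode is no longer an MRLTrainer constructor argument; both it and the new debug_interval live in the MrlConfig dict, and gradient logging fires only on every debug_interval-th train step. A minimal sketch of the relevant config keys, assuming the rest of the trainer setup is unchanged (only the keys shown here appear in the diff):

    # hypothetical MrlConfig fragment; only the debug keys are new in 0.2.58
    mrl_config = {
        'update_epochs': 10,
        'use_memory_warmup': False,
        'debug_mode': True,       # previously MRLTrainer(..., debug_mode=True)
        'debug_interval': 10,     # _log_gradients() runs on every 10th train step
    }
    # trainer = MRLTrainer(..., config=mrl_config)  # remaining arguments as in 0.2.56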
{rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/training/rl.py
@@ -33,10 +33,12 @@ class PPOConfig(TypedDict):
  use_distributed_advantage_norm: Optional[bool]
  clip_critic_values: Optional[bool]
  critic_value_clip: Optional[float]
+ debug_mode: Optional[bool]
+ debug_interval: Optional[int]


  class PPOAlgorithm(RlAlgorithm):
- def __init__(self, config: Optional[PPOConfig] = None, debug_mode: bool = False):
+ def __init__(self, config: Optional[PPOConfig] = None):
  super(PPOAlgorithm, self).__init__()

  if config is None:
@@ -50,7 +52,9 @@ class PPOAlgorithm(RlAlgorithm):
  self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
  self.clip_critic_values = config.get('clip_critic_values', True)
  self.critic_value_clip = config.get('critic_value_clip', 20.0)
- self.debug_mode = debug_mode
+ self.debug_mode = config.get('debug_mode', False)
+ self.debug_interval = config.get('debug_interval', 10)
+ self.debug_step = 0

  def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
  # Critic loss with clipped values
@@ -98,12 +102,16 @@ class PPOAlgorithm(RlAlgorithm):
  advantages = advantages.unsqueeze(-1)

  if self.debug_mode:
- print(
- f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
- print(
- f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
- print(
- f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+ if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+ self.debug_step = 0
+ print(
+ f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
+ print(
+ f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
+ print(
+ f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+ else:
+ self.debug_step += 1

  # c) Clipped surrogate loss
  surr1 = ratio * advantages
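
PPOAlgorithm follows the same pattern: the separate debug_mode constructor argument is gone, the flags move into PPOConfig, and a debug_step counter limits the logits/ratio/advantage printouts to roughly one in every debug_interval policy-loss calls. A sketch of the updated construction, using only keys that appear in the diff:

    # hypothetical PPOConfig fragment for 0.2.58
    ppo_config = {
        'clip_critic_values': True,
        'critic_value_clip': 20.0,
        'debug_mode': True,       # was PPOAlgorithm(config, debug_mode=True) in 0.2.56
        'debug_interval': 10,     # stats are printed about once per 10 policy-loss calls
    }
    algorithm = PPOAlgorithm(ppo_config)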
{rxnn-0.2.56 → rxnn-0.2.58}/src/rxnn/transformers/layers.py
@@ -110,7 +110,11 @@ class ReactiveTransformerLayer(nn.Module):
  residual = x
  if not self.use_post_norm:
  x = self.norm2(x)
- x = self.memory_cross_attention(x, stm, stm, mask=mask)
+
+ if mask is not None:
+ mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1))
+
+ x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
  x = residual + x
  if self.use_post_norm:
  x = self.norm2(x)
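
The decoder layer now builds its own memory cross-attention mask: each decoder position's padding flag is repeated across all STM slots, so the cross-attention receives one flag per (query position, memory slot) pair. A shape-only sketch, assuming the mask arrives as the [batch, 1, 1, seq_len] padding mask built upstream:

    import torch

    batch, seq_len, stm_size = 2, 6, 4
    attention_mask = torch.ones(batch, seq_len, dtype=torch.bool)   # hypothetical, no padding
    mask = attention_mask.unsqueeze(1).unsqueeze(1)                 # [batch, 1, 1, seq_len]

    # same expansion as the diff: one flag per query position, repeated for each STM slot
    mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm_size)
    print(mem_mask.shape)  # torch.Size([2, 1, 6, 4]) -> [batch, 1, query_pos, stm_slot]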