rxnn 0.2.31__tar.gz → 0.2.33__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.31 → rxnn-0.2.33}/PKG-INFO +1 -1
  2. {rxnn-0.2.31 → rxnn-0.2.33}/pyproject.toml +1 -1
  3. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/memory/stm.py +0 -1
  4. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/rxt/models.py +27 -1
  5. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/models.py +11 -3
  6. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/mrl.py +2 -2
  7. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/reward.py +9 -1
  8. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/rl.py +28 -17
  9. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/transformers/ff.py +2 -0
  10. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/transformers/models.py +33 -0
  11. {rxnn-0.2.31 → rxnn-0.2.33}/LICENSE +0 -0
  12. {rxnn-0.2.31 → rxnn-0.2.33}/README.md +0 -0
  13. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/.DS_Store +0 -0
  14. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/__init__.py +0 -0
  15. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/experimental/__init__.py +0 -0
  16. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/experimental/attention.py +0 -0
  17. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/experimental/models.py +0 -0
  18. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/experimental/moe.py +0 -0
  19. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/memory/__init__.py +0 -0
  20. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/memory/attention.py +0 -0
  21. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/memory/norm.py +0 -0
  22. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/rxt/__init__.py +0 -0
  23. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/__init__.py +0 -0
  24. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/base.py +0 -0
  25. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/bml.py +0 -0
  26. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/callbacks.py +0 -0
  27. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/dataset.py +0 -0
  28. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/ddp.py +0 -0
  29. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/scheduler.py +0 -0
  30. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/tokenizer.py +0 -0
  31. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/utils.py +0 -0
  32. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/transformers/__init__.py +0 -0
  33. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/transformers/attention.py +0 -0
  34. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/transformers/layers.py +0 -0
  35. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/transformers/mask.py +0 -0
  36. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/utils.py +0 -0

{rxnn-0.2.31 → rxnn-0.2.33}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.31
+Version: 0.2.33
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning

{rxnn-0.2.31 → rxnn-0.2.33}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.31"
+version = "0.2.33"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"

{rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/memory/stm.py
@@ -44,7 +44,6 @@ class ShortTermMemory(nn.Module):
 
     def update_all(self, new_stm: torch.Tensor):
         self.memory = new_stm
-        # self.memory.copy_(new_stm)
 
     def make_trainable(self):
         if not self.is_trainable:

{rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/rxt/models.py
@@ -5,7 +5,7 @@ from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.positional import RotaryPositionalEmbedding
 from ..transformers.attention import init_attention
 from ..transformers.layers import ReactiveTransformerLayer
-from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder
+from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder, ReactiveTransformerEncoderDetachStm
 from ..transformers.ff import get_activation_layer
 from ..memory.stm import ShortTermMemory
 from ..memory.norm import init_memory_norm
@@ -293,3 +293,29 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
+
+class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
+    """RxT-Alpha (Reactive Transformer) encoder model"""
+
+    def __init__(self, **kwargs: RxTAlphaComponentConfig):
+        super(RxTAlphaCriticEncoder, self).__init__(False, **kwargs)
+
+    def _init_model(
+            self,
+            stm: ShortTermMemory,
+            layers: nn.ModuleList,
+            embedding: nn.Embedding,
+            use_flash_attention: bool,
+            embed_dim: int,
+            vocab_size: int
+    ) -> ReactiveTransformerEncoderDetachStm:
+        return ReactiveTransformerEncoderDetachStm(
+            stm=stm,
+            embedding=embedding,
+            own_layers=layers,
+            use_flash_attention=use_flash_attention,
+        )
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.model(x, attention_mask=attention_mask)
+
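
Note: the new RxTAlphaCriticEncoder is the encoder component meant to sit under the MRL critic head. A minimal usage sketch, assuming a compatible pretrained checkpoint exists; the repo id, embed_dim and scaling values below are placeholders, not part of the package:

    import torch
    from rxnn.rxt.models import RxTAlphaCriticEncoder
    from rxnn.training.models import MrlCriticModel

    # Placeholder repo id; from_pretrained is provided by PyTorchModelHubMixin.
    encoder = RxTAlphaCriticEncoder.from_pretrained('your-org/critic-encoder')

    # embed_dim must match the encoder's embedding size (256 is illustrative).
    critic = MrlCriticModel(encoder, embed_dim=256, out_activation='tanh', output_scale=3.0)

    input_ids = torch.randint(0, 1000, (2, 32))
    mask = torch.ones_like(input_ids)
    values = critic(input_ids, attention_mask=mask)  # shape (2, 1), bounded to [-3, 3]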

{rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/models.py
@@ -4,6 +4,7 @@ from enum import Enum
 from typing import Literal, Iterator
 from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
+from ..transformers.ff import GatedLinearUnit, get_activation_layer
 
 class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
     def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
@@ -161,10 +162,16 @@ class MrlActorModel(nn.Module):
         return self.memory_attention(ed, attention_mask=attention_mask)
 
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
-    def __init__(self, encoder: nn.Module, embed_dim: int, **kwargs):
+    def __init__(self, encoder: nn.Module, embed_dim: int, out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
-        self.value_head = nn.Linear(embed_dim, 1)
+        self.value_head = nn.Sequential(
+            GatedLinearUnit(embed_dim, embed_dim, nn.SiLU()),
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, 1),
+            get_activation_layer(out_activation)
+        )
+        self.output_scale = output_scale
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x, _ = self.encoder(x, attention_mask=attention_mask)
@@ -175,4 +182,5 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         else:
             x = x.mean(dim=1)
 
-        return self.value_head(x)
+        return self.value_head(x) * self.output_scale
+
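
Note: the critic's value head is no longer a single linear layer; it gates the pooled encoder output, normalizes it, projects to a scalar, applies the chosen activation, and multiplies by output_scale. A standalone sketch of the resulting output range; a plain Linear stands in for the package's GatedLinearUnit, and the dimensions are illustrative:

    import torch
    import torch.nn as nn

    embed_dim, output_scale = 256, 2.0
    value_head = nn.Sequential(
        nn.Linear(embed_dim, embed_dim),  # stand-in for GatedLinearUnit(embed_dim, embed_dim, nn.SiLU())
        nn.LayerNorm(embed_dim),
        nn.Linear(embed_dim, 1),
        nn.Tanh(),                        # out_activation='tanh'
    )

    pooled = torch.randn(4, embed_dim)    # pooled encoder output
    values = value_head(pooled) * output_scale
    assert values.shape == (4, 1) and values.abs().max() <= output_scale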

{rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/mrl.py
@@ -481,7 +481,7 @@ class MRLTrainer:
                 critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                           pad_token_id=self.pad_token_id)
                 values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
-                critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
+                critic_loss = self.rl_algorithm.critic_loss(values, ref_values.detach())
             # 2.2 Run backpropagation with scaler
             self.critic_scaler.scale(critic_loss).backward()
             # 2.3 Unscale and clip gradients
@@ -495,7 +495,7 @@ class MRLTrainer:
             critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                       pad_token_id=self.pad_token_id)
             values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
-            critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
+            critic_loss = self.rl_algorithm.critic_loss(values, ref_values.detach())
             # 2.2 Run backpropagation
             critic_loss.backward()
             # 2.3 Clip gradients
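
Note: both call sites now detach ref_values before computing the critic loss, so the regression target is treated as a constant and no gradient flows back through whatever produced it. A tiny illustration of the effect, using MSE only as a stand-in for the algorithm's critic loss:

    import torch
    import torch.nn.functional as F

    values = torch.randn(8, requires_grad=True)      # critic predictions
    ref_values = torch.randn(8, requires_grad=True)  # targets, e.g. GAE returns

    F.mse_loss(values, ref_values.detach()).backward()
    assert values.grad is not None and ref_values.grad is None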

{rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/reward.py
@@ -42,6 +42,7 @@ class MrlRewardModel:
             running_mean_decay: float = 0.2,
             bleu_saved_weights: tuple = (0.5, 0.5),
             bleu_ref_weights: tuple = (0.5, 0.5),
+            tanh_reward_scale: bool = False,
             rewards_scale: float = 1.0,
     ):
         self.shared_embedding = shared_embedding.to(device)
@@ -71,6 +72,7 @@ class MrlRewardModel:
         self.running_mean_decay = running_mean_decay
         self.bleu_ref_weights = bleu_ref_weights
         self.bleu_saved_weights = bleu_saved_weights
+        self.tanh_reward_scale = tanh_reward_scale
         self.rewards_scale = rewards_scale
 
         self.prev_data_running_mean = None
@@ -175,6 +177,12 @@ class MrlRewardModel:
         self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
             prev_data) + self.running_mean_decay * self.prev_data_running_mean
 
+    def _pre_scale_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
+        if self.tanh_reward_scale:
+            return (rewards * 2) - 1  # Convert [0,1] to [-1,1]
+        else:
+            return rewards
+
     def __call__(
             self,
             generated: TokenizedDict,
@@ -204,5 +212,5 @@ class MrlRewardModel:
         cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
         sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
 
-        rewards = (sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
+        rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
         return rewards.tolist()
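
Note: with tanh_reward_scale enabled, rewards computed in [0, 1] are remapped to [-1, 1] before the final rewards_scale multiplication, matching a critic whose output activation is tanh. A quick worked example with made-up reward values:

    import torch

    def pre_scale_rewards(rewards: torch.Tensor, tanh_reward_scale: bool) -> torch.Tensor:
        # Same affine map as MrlRewardModel._pre_scale_rewards
        return (rewards * 2) - 1 if tanh_reward_scale else rewards

    rewards = torch.tensor([0.0, 0.25, 0.5, 1.0])
    pre_scale_rewards(rewards, tanh_reward_scale=True)   # -> [-1.0, -0.5, 0.0, 1.0]
    pre_scale_rewards(rewards, tanh_reward_scale=False)  # -> [ 0.0, 0.25, 0.5, 1.0]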

{rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/training/rl.py
@@ -21,8 +21,8 @@ class RlAlgorithm(ABC):
     def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         pass
 
-    def critic_loss(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
-        return self.critic_loss(rewards, values)
+    def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
+        return self.critic_loss(values, ref_values)
 
 
 class PPOConfig(TypedDict):
@@ -31,6 +31,8 @@ class PPOConfig(TypedDict):
     gae_gamma: Optional[float]
     entropy_coef: Optional[float]
     use_distributed_advantage_norm: Optional[bool]
+    clip_critic_values: Optional[bool]
+    critic_value_clip: Optional[float]
 
 
 class PPOAlgorithm(RlAlgorithm):
@@ -46,6 +48,14 @@ class PPOAlgorithm(RlAlgorithm):
         self.gae_gamma = config.get('gae_gamma', 0.99)
         self.entropy_coef = config.get('entropy_coef', 0.01)
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
+        self.clip_critic_values = config.get('clip_critic_values', True)
+        self.critic_value_clip = config.get('critic_value_clip', 10.0)
+
+    def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
+        # Critic loss with clipped values
+        if self.clip_critic_values:
+            values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
+        return self.critic_loss(values, ref_values)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
                     old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
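
Note: clamping the critic's predictions before the loss keeps a few extreme value estimates from dominating the regression early in training. A small sketch of the clamp step in isolation, with MSE assumed as the underlying critic loss and the numbers invented:

    import torch
    import torch.nn.functional as F

    critic_value_clip = 10.0
    values = torch.tensor([0.5, -3.0, 42.0, -17.5])   # raw critic predictions
    ref_values = torch.tensor([1.0, -2.0, 8.0, -9.0])

    clipped = torch.clamp(values, -critic_value_clip, critic_value_clip)
    # clipped == [0.5, -3.0, 10.0, -10.0]
    loss = F.mse_loss(clipped, ref_values)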
@@ -96,23 +106,24 @@ class PPOAlgorithm(RlAlgorithm):
 
         return policy_loss
 
-    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor, next_value: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        T, B = rewards.shape
-        advantages = torch.zeros_like(rewards, device=values.device)
+    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor,
+                     last_value: torch.Tensor, dones: torch.Tensor):
+        trajectory_len, batch_size = rewards.shape
+        advantages = torch.zeros_like(rewards, device=rewards.device)
         last_advantage = 0
-        last_value = next_value.detach()
-
-        for t in reversed(range(T)):
-            if t == T - 1:
-                next_values = last_value
-            else:
-                next_values = values[t + 1]
-
-            # Mask next values if episode ended
-            next_values = next_values * ~dones[t]
-            delta = rewards[t] + self.gae_gamma * next_values - values[t]
-            advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
+        next_value = last_value
+        next_done = torch.zeros(batch_size, device=dones.device)  # Last state is terminal
+        dones = dones.float()
+        for t in reversed(range(trajectory_len)):
+            # Check if next state is terminal
+            non_terminal = 1.0 - next_done
+
+            # Delta should not include next_value if next is terminal
+            delta = rewards[t] + self.gae_gamma * next_value * non_terminal - values[t]
+            advantages[t] = delta + self.gae_gamma * self.gae_lambda * non_terminal * last_advantage
             last_advantage = advantages[t]
+            next_value = values[t]
+            next_done = dones[t]
 
         returns = advantages + values
         return advantages, returns
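
Note: the rewritten loop is the standard GAE recursion, delta_t = r_t + gamma * V_{t+1} * (1 - done) - V_t and A_t = delta_t + gamma * lambda * (1 - done) * A_{t+1}, applied backwards over the trajectory while bootstrapping the last step from last_value. A self-contained mirror of the loop, useful for experimenting outside the trainer; the gamma/lambda defaults and the shapes are illustrative:

    import torch

    def compute_gae(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
        trajectory_len, batch_size = rewards.shape
        advantages = torch.zeros_like(rewards)
        last_advantage = 0
        next_value, next_done = last_value, torch.zeros(batch_size)
        dones = dones.float()
        for t in reversed(range(trajectory_len)):
            non_terminal = 1.0 - next_done
            delta = rewards[t] + gamma * next_value * non_terminal - values[t]
            advantages[t] = delta + gamma * lam * non_terminal * last_advantage
            last_advantage = advantages[t]
            next_value, next_done = values[t], dones[t]
        return advantages, advantages + values

    T, B = 4, 2
    adv, ret = compute_gae(torch.rand(T, B), torch.rand(T, B), torch.rand(B), torch.zeros(T, B))
    assert adv.shape == ret.shape == (T, B)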

{rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/transformers/ff.py
@@ -66,6 +66,8 @@ def get_activation_layer(activation: str):
         return nn.SiLU()
     elif activation == 'sigmoid':
         return nn.Sigmoid()
+    elif activation == 'tanh':
+        return nn.Tanh()
     elif activation == 'linear':
         return LinearActivation()
     else:

{rxnn-0.2.31 → rxnn-0.2.33}/src/rxnn/transformers/models.py
@@ -126,6 +126,39 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
         return x, torch.stack(hidden_states)
 
 
+class ReactiveTransformerEncoderDetachStm(ReactiveTransformerBase):
+    """
+    Reactive Transformer encoder DetachStm version - reactive transformer encoder that's detaching Short-Term Memory tensors,
+    before processing them in layers (memory cross-attention). Made for Memory-Aware Critic models, to not include memory
+    update gradients in Critic optimization.
+    """
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
+        x = super().forward(x)  # apply embeddings
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+
+        hidden_states = []
+        # Process shared layers
+        if self.shared_layers is not None:
+            for i in range(self.num_shared_layers):
+                layer_stm = self.stm(i).detach()  # <- Detach STM layer
+                # expand layer STM to batch size, if it's not in batch mode
+                if layer_stm.size(0) == 1:
+                    layer_stm = layer_stm.expand(x.size(0), -1, -1)
+                x = self.shared_layers[i](x, layer_stm, mask=attention_mask)
+                hidden_states.append(x)
+        # Process own layers
+        for i in range(self.num_own_layers):
+            layer_stm = self.stm(i).detach()  # <- Detach STM layer
+            # expand layer STM to batch size, if it's not in batch mode
+            if layer_stm.size(0) == 1:
+                layer_stm = layer_stm.expand(x.size(0), -1, -1)
+            x = self.layers[i](x, layer_stm, mask=attention_mask)
+            hidden_states.append(x)
+        return x, torch.stack(hidden_states)
+
+
 class ClassicTransformerBase(nn.Module):
     """Base class for Classic Transformer models - common logic for both decoders and encoders."""
 
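
Note: the only difference from the regular ReactiveTransformerEncoder is the .detach() on each STM slot, which stops critic-loss gradients from flowing into the shared short-term memory. A tiny illustration of that property, independent of the encoder itself (the shapes are invented):

    import torch

    stm_layer = torch.randn(1, 16, 64, requires_grad=True)  # one STM slot
    x = torch.randn(2, 8, 64, requires_grad=True)            # token hidden states

    # Stand-in for memory cross-attention: any op mixing x with the detached memory.
    out = x + stm_layer.detach().expand(x.size(0), -1, -1).mean(dim=1, keepdim=True)
    out.sum().backward()
    assert x.grad is not None and stm_layer.grad is None      # no gradient reaches the memory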