rxnn 0.2.32__tar.gz → 0.2.33__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.32 → rxnn-0.2.33}/PKG-INFO +1 -1
  2. {rxnn-0.2.32 → rxnn-0.2.33}/pyproject.toml +1 -1
  3. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/training/models.py +11 -3
  4. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/training/mrl.py +2 -2
  5. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/training/reward.py +9 -1
  6. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/training/rl.py +28 -17
  7. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/transformers/ff.py +2 -0
  8. {rxnn-0.2.32 → rxnn-0.2.33}/LICENSE +0 -0
  9. {rxnn-0.2.32 → rxnn-0.2.33}/README.md +0 -0
  10. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/.DS_Store +0 -0
  11. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/__init__.py +0 -0
  12. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/experimental/__init__.py +0 -0
  13. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/experimental/attention.py +0 -0
  14. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/experimental/models.py +0 -0
  15. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/experimental/moe.py +0 -0
  16. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/memory/__init__.py +0 -0
  17. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/memory/attention.py +0 -0
  18. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/memory/norm.py +0 -0
  19. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/memory/stm.py +0 -0
  20. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/rxt/__init__.py +0 -0
  21. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/rxt/models.py +0 -0
  22. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/training/__init__.py +0 -0
  23. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/training/base.py +0 -0
  24. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/training/bml.py +0 -0
  25. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/training/callbacks.py +0 -0
  26. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/training/dataset.py +0 -0
  27. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/training/ddp.py +0 -0
  28. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/training/scheduler.py +0 -0
  29. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/training/tokenizer.py +0 -0
  30. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/training/utils.py +0 -0
  31. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/transformers/__init__.py +0 -0
  32. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/transformers/attention.py +0 -0
  33. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.32 → rxnn-0.2.33}/src/rxnn/utils.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.32
+Version: 0.2.33
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.32"
+version = "0.2.33"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
@@ -4,6 +4,7 @@ from enum import Enum
 from typing import Literal, Iterator
 from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
+from ..transformers.ff import GatedLinearUnit, get_activation_layer
 
 class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
     def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
@@ -161,10 +162,16 @@ class MrlActorModel(nn.Module):
         return self.memory_attention(ed, attention_mask=attention_mask)
 
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
-    def __init__(self, encoder: nn.Module, embed_dim: int, **kwargs):
+    def __init__(self, encoder: nn.Module, embed_dim: int, out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
-        self.value_head = nn.Linear(embed_dim, 1)
+        self.value_head = nn.Sequential(
+            GatedLinearUnit(embed_dim, embed_dim, nn.SiLU()),
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, 1),
+            get_activation_layer(out_activation)
+        )
+        self.output_scale = output_scale
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x, _ = self.encoder(x, attention_mask=attention_mask)
@@ -175,4 +182,5 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         else:
             x = x.mean(dim=1)
 
-        return self.value_head(x)
+        return self.value_head(x) * self.output_scale
+
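Note on the models.py change above: the critic's value head is now a gated MLP whose final activation is configurable, and the activated output is multiplied by output_scale. A minimal usage sketch follows; the import path and the dummy encoder are assumptions for illustration, not code from the package (the real encoder is expected to return a (hidden_states, extra) tuple, as the forward pass above unpacks it).

import torch
import torch.nn as nn
from rxnn.training.models import MrlCriticModel  # assumed import path

class DummyEncoder(nn.Module):
    # Stand-in encoder: mimics the (hidden_states, extra) return shape the critic unpacks.
    def __init__(self, vocab_size: int, embed_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, x, attention_mask=None):
        return self.embedding(x), None

embed_dim = 64
critic = MrlCriticModel(
    DummyEncoder(vocab_size=1000, embed_dim=embed_dim),
    embed_dim,
    out_activation='tanh',  # new in 0.2.33: 'sigmoid' (default), 'tanh' or 'linear'
    output_scale=10.0,      # new in 0.2.33: scales the activated value head output
)

input_ids = torch.randint(0, 1000, (2, 16))
attention_mask = torch.ones_like(input_ids)
values = critic(input_ids, attention_mask=attention_mask)  # shape (2, 1); with tanh * 10.0, values lie in (-10, 10)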
@@ -481,7 +481,7 @@ class MRLTrainer:
                 critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                           pad_token_id=self.pad_token_id)
                 values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
-                critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
+                critic_loss = self.rl_algorithm.critic_loss(values, ref_values.detach())
                 # 2.2 Run backpropagation with scaler
                 self.critic_scaler.scale(critic_loss).backward()
                 # 2.3 Unscale and clip gradients
@@ -495,7 +495,7 @@ class MRLTrainer:
                 critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                           pad_token_id=self.pad_token_id)
                 values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
-                critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
+                critic_loss = self.rl_algorithm.critic_loss(values, ref_values.detach())
                 # 2.2 Run backpropagation
                 critic_loss.backward()
                 # 2.3 Clip gradients
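The only functional change in mrl.py is that the reference values are detached before the critic loss is computed, in both branches above. A minimal illustration of the effect, using hypothetical tensors and an MSE loss for illustration rather than trainer state:

import torch
import torch.nn.functional as F

values = torch.randn(8, requires_grad=True)      # current critic predictions
ref_values = torch.randn(8, requires_grad=True)  # targets that could otherwise carry a graph

loss = F.mse_loss(values, ref_values.detach())   # gradients flow only into `values`
loss.backward()
assert values.grad is not None
assert ref_values.grad is None                   # detach() keeps the target side out of the graph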
@@ -42,6 +42,7 @@ class MrlRewardModel:
             running_mean_decay: float = 0.2,
             bleu_saved_weights: tuple = (0.5, 0.5),
             bleu_ref_weights: tuple = (0.5, 0.5),
+            tanh_reward_scale: bool = False,
             rewards_scale: float = 1.0,
     ):
         self.shared_embedding = shared_embedding.to(device)
@@ -71,6 +72,7 @@ class MrlRewardModel:
         self.running_mean_decay = running_mean_decay
         self.bleu_ref_weights = bleu_ref_weights
         self.bleu_saved_weights = bleu_saved_weights
+        self.tanh_reward_scale = tanh_reward_scale
         self.rewards_scale = rewards_scale
 
         self.prev_data_running_mean = None
@@ -175,6 +177,12 @@ class MrlRewardModel:
             self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
                 prev_data) + self.running_mean_decay * self.prev_data_running_mean
 
+    def _pre_scale_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
+        if self.tanh_reward_scale:
+            return (rewards * 2) - 1  # Convert [0,1] to [-1,1]
+        else:
+            return rewards
+
     def __call__(
             self,
             generated: TokenizedDict,
@@ -204,5 +212,5 @@ class MrlRewardModel:
             cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
             sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
 
-        rewards = (sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
+        rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
         return rewards.tolist()
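Functionally, the new tanh_reward_scale flag remaps rewards from [0, 1] to [-1, 1] before rewards_scale is applied, presumably to match a critic configured with the new tanh output activation. A standalone sketch of the mapping (not the MrlRewardModel itself):

import torch

def pre_scale_rewards(rewards: torch.Tensor, tanh_reward_scale: bool) -> torch.Tensor:
    # Remap [0, 1] rewards to [-1, 1] when requested; otherwise pass through unchanged.
    return rewards * 2 - 1 if tanh_reward_scale else rewards

rewards = torch.tensor([0.0, 0.25, 0.5, 1.0])
print(pre_scale_rewards(rewards, tanh_reward_scale=True))   # tensor([-1.0000, -0.5000,  0.0000,  1.0000])
print(pre_scale_rewards(rewards, tanh_reward_scale=False))  # tensor([0.0000, 0.2500, 0.5000, 1.0000])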
@@ -21,8 +21,8 @@ class RlAlgorithm(ABC):
     def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         pass
 
-    def critic_loss(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
-        return self.critic_loss(rewards, values)
+    def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
+        return self.critic_loss(values, ref_values)
 
 
 class PPOConfig(TypedDict):
@@ -31,6 +31,8 @@ class PPOConfig(TypedDict):
     gae_gamma: Optional[float]
     entropy_coef: Optional[float]
    use_distributed_advantage_norm: Optional[bool]
+    clip_critic_values: Optional[bool]
+    critic_value_clip: Optional[float]
 
 
 class PPOAlgorithm(RlAlgorithm):
@@ -46,6 +48,14 @@ class PPOAlgorithm(RlAlgorithm):
         self.gae_gamma = config.get('gae_gamma', 0.99)
         self.entropy_coef = config.get('entropy_coef', 0.01)
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
+        self.clip_critic_values = config.get('clip_critic_values', True)
+        self.critic_value_clip = config.get('critic_value_clip', 10.0)
+
+    def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
+        # Critic loss with clipped values
+        if self.clip_critic_values:
+            values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
+        return self.critic_loss(values, ref_values)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
                     old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
@@ -96,23 +106,24 @@ class PPOAlgorithm(RlAlgorithm):
 
         return policy_loss
 
-    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor, next_value: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        T, B = rewards.shape
-        advantages = torch.zeros_like(rewards, device=values.device)
+    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor,
+                     last_value: torch.Tensor, dones: torch.Tensor):
+        trajectory_len, batch_size = rewards.shape
+        advantages = torch.zeros_like(rewards, device=rewards.device)
         last_advantage = 0
-        last_value = next_value.detach()
-
-        for t in reversed(range(T)):
-            if t == T - 1:
-                next_values = last_value
-            else:
-                next_values = values[t + 1]
-
-            # Mask next values if episode ended
-            next_values = next_values * ~dones[t]
-            delta = rewards[t] + self.gae_gamma * next_values - values[t]
-            advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
+        next_value = last_value
+        next_done = torch.zeros(batch_size, device=dones.device)  # Last state is terminal
+        dones = dones.float()
+        for t in reversed(range(trajectory_len)):
+            # Check if next state is terminal
+            non_terminal = 1.0 - next_done
+
+            # Delta should not include next_value if next is terminal
+            delta = rewards[t] + self.gae_gamma * next_value * non_terminal - values[t]
+            advantages[t] = delta + self.gae_gamma * self.gae_lambda * non_terminal * last_advantage
             last_advantage = advantages[t]
+            next_value = values[t]
+            next_done = dones[t]
 
         returns = advantages + values
         return advantages, returns
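The rewritten _compute_gae carries the bootstrap value and done flag backwards through the loop, masking both the TD target and the advantage recursion at episode boundaries. For reference, the same recursion extracted into a standalone function (a sketch with hypothetical gamma/lambda defaults, not the PPOAlgorithm method itself):

import torch

def compute_gae(rewards, values, last_value, dones, gamma=0.99, gae_lambda=0.95):
    # rewards, values, dones: [T, B]; last_value: [B] bootstrap value after the final step.
    trajectory_len, batch_size = rewards.shape
    advantages = torch.zeros_like(rewards)
    last_advantage = torch.zeros(batch_size)
    next_value = last_value
    next_done = torch.zeros(batch_size)  # no episode boundary assumed after the final step
    dones = dones.float()
    for t in reversed(range(trajectory_len)):
        non_terminal = 1.0 - next_done   # zero out bootstrapping across episode ends
        delta = rewards[t] + gamma * next_value * non_terminal - values[t]
        advantages[t] = delta + gamma * gae_lambda * non_terminal * last_advantage
        last_advantage = advantages[t]
        next_value = values[t]
        next_done = dones[t]
    returns = advantages + values        # value targets for the critic
    return advantages, returns

T, B = 4, 2
adv, ret = compute_gae(torch.rand(T, B), torch.rand(T, B), torch.rand(B), torch.zeros(T, B))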
@@ -66,6 +66,8 @@ def get_activation_layer(activation: str):
         return nn.SiLU()
     elif activation == 'sigmoid':
         return nn.Sigmoid()
+    elif activation == 'tanh':
+        return nn.Tanh()
     elif activation == 'linear':
         return LinearActivation()
     else:
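The ff.py change registers 'tanh' with get_activation_layer, which the new critic value head relies on. A short usage sketch (assumed import path):

from rxnn.transformers.ff import get_activation_layer  # assumed import path
layer = get_activation_layer('tanh')  # returns torch.nn.Tanh()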