rxnn 0.2.32__py3-none-any.whl → 0.2.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/models.py CHANGED
@@ -4,6 +4,7 @@ from enum import Enum
 from typing import Literal, Iterator
 from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
+from ..transformers.ff import GatedLinearUnit, get_activation_layer
 
 class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
     def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
@@ -161,10 +162,16 @@ class MrlActorModel(nn.Module):
         return self.memory_attention(ed, attention_mask=attention_mask)
 
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
-    def __init__(self, encoder: nn.Module, embed_dim: int, **kwargs):
+    def __init__(self, encoder: nn.Module, embed_dim: int, out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
-        self.value_head = nn.Linear(embed_dim, 1)
+        self.value_head = nn.Sequential(
+            GatedLinearUnit(embed_dim, embed_dim, nn.SiLU()),
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, 1),
+            get_activation_layer(out_activation)
+        )
+        self.output_scale = output_scale
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x, _ = self.encoder(x, attention_mask=attention_mask)
@@ -175,4 +182,5 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         else:
             x = x.mean(dim=1)
 
-        return self.value_head(x)
+        return self.value_head(x) * self.output_scale
+
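For orientation, here is a minimal stand-alone sketch of the new critic value head in plain PyTorch, so it runs without the rxnn package. GatedSiLU below is a hypothetical stand-in for rxnn's GatedLinearUnit (whose internals are not part of this diff); the layer order, the 'sigmoid' default and the output_scale multiplication mirror the hunks above.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedSiLU(nn.Module):
    """Stand-in gated linear unit: value path gated by SiLU(gate path)."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, 2 * out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = self.linear(x).chunk(2, dim=-1)
        return value * F.silu(gate)

embed_dim, output_scale = 256, 1.0
value_head = nn.Sequential(
    GatedSiLU(embed_dim, embed_dim),
    nn.LayerNorm(embed_dim),
    nn.Linear(embed_dim, 1),
    nn.Sigmoid(),            # the 'sigmoid' default keeps the raw head output in [0, 1]
)

pooled = torch.randn(4, embed_dim)           # stand-in for the pooled encoder output
values = value_head(pooled) * output_scale   # shape (4, 1), scaled as in the new forward()
print(values.shape)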
rxnn/training/mrl.py CHANGED
@@ -481,7 +481,7 @@ class MRLTrainer:
             critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                       pad_token_id=self.pad_token_id)
             values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
-            critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
+            critic_loss = self.rl_algorithm.critic_loss(values, ref_values.detach())
             # 2.2 Run backpropagation with scaler
             self.critic_scaler.scale(critic_loss).backward()
             # 2.3 Unscale and clip gradients
@@ -495,7 +495,7 @@
             critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                       pad_token_id=self.pad_token_id)
             values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
-            critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
+            critic_loss = self.rl_algorithm.critic_loss(values, ref_values.detach())
             # 2.2 Run backpropagation
             critic_loss.backward()
             # 2.3 Clip gradients
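The only functional change in both hunks is detaching ref_values, so the critic regression target is treated as a constant and gradients stop at the current value predictions. A tiny sketch of the effect, using MSE as a stand-in for the algorithm's critic loss:

import torch
import torch.nn.functional as F

values = torch.randn(8, requires_grad=True)      # current critic predictions
ref_values = torch.randn(8, requires_grad=True)  # targets derived from earlier value estimates

loss = F.mse_loss(values, ref_values.detach())   # target is excluded from the autograd graph
loss.backward()

print(values.grad is not None)   # True  - the critic still receives gradients
print(ref_values.grad is None)   # True  - nothing flows back into the targets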
rxnn/training/reward.py CHANGED
@@ -42,6 +42,7 @@ class MrlRewardModel:
             running_mean_decay: float = 0.2,
             bleu_saved_weights: tuple = (0.5, 0.5),
             bleu_ref_weights: tuple = (0.5, 0.5),
+            tanh_reward_scale: bool = False,
             rewards_scale: float = 1.0,
     ):
         self.shared_embedding = shared_embedding.to(device)
@@ -71,6 +72,7 @@
         self.running_mean_decay = running_mean_decay
         self.bleu_ref_weights = bleu_ref_weights
         self.bleu_saved_weights = bleu_saved_weights
+        self.tanh_reward_scale = tanh_reward_scale
         self.rewards_scale = rewards_scale
 
         self.prev_data_running_mean = None
@@ -175,6 +177,12 @@
         self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
             prev_data) + self.running_mean_decay * self.prev_data_running_mean
 
+    def _pre_scale_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
+        if self.tanh_reward_scale:
+            return (rewards * 2) - 1  # Convert [0,1] to [-1,1]
+        else:
+            return rewards
+
     def __call__(
             self,
             generated: TokenizedDict,
@@ -204,5 +212,5 @@
         cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
         sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
 
-        rewards = (sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
+        rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
         return rewards.tolist()
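A short sketch of the new pre-scaling option, assuming the combined BLEU/cosine/length reward lands in [0, 1] as the in-code comment suggests: with tanh_reward_scale=True it is remapped linearly to [-1, 1] before rewards_scale is applied.

import torch

def pre_scale_rewards(rewards: torch.Tensor, tanh_reward_scale: bool) -> torch.Tensor:
    # Mirrors MrlRewardModel._pre_scale_rewards: linear remap of [0, 1] onto [-1, 1]
    return rewards * 2 - 1 if tanh_reward_scale else rewards

raw = torch.tensor([0.0, 0.25, 0.5, 1.0])
print(pre_scale_rewards(raw, tanh_reward_scale=False))  # tensor([0.0000, 0.2500, 0.5000, 1.0000])
print(pre_scale_rewards(raw, tanh_reward_scale=True))   # tensor([-1.0000, -0.5000,  0.0000,  1.0000])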
rxnn/training/rl.py CHANGED
@@ -21,8 +21,8 @@ class RlAlgorithm(ABC):
     def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         pass
 
-    def critic_loss(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
-        return self.critic_loss(rewards, values)
+    def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
+        return self.critic_loss(values, ref_values)
 
 
 class PPOConfig(TypedDict):
@@ -31,6 +31,8 @@ class PPOConfig(TypedDict):
     gae_gamma: Optional[float]
     entropy_coef: Optional[float]
    use_distributed_advantage_norm: Optional[bool]
+    clip_critic_values: Optional[bool]
+    critic_value_clip: Optional[float]
 
 
 class PPOAlgorithm(RlAlgorithm):
@@ -46,6 +48,14 @@ class PPOAlgorithm(RlAlgorithm):
         self.gae_gamma = config.get('gae_gamma', 0.99)
         self.entropy_coef = config.get('entropy_coef', 0.01)
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
+        self.clip_critic_values = config.get('clip_critic_values', True)
+        self.critic_value_clip = config.get('critic_value_clip', 10.0)
+
+    def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
+        # Critic loss with clipped values
+        if self.clip_critic_values:
+            values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
+        return self.critic_loss(values, ref_values)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
                     old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
@@ -96,23 +106,24 @@ class PPOAlgorithm(RlAlgorithm):
 
         return policy_loss
 
-    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor, next_value: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        T, B = rewards.shape
-        advantages = torch.zeros_like(rewards, device=values.device)
+    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor,
+                     last_value: torch.Tensor, dones: torch.Tensor):
+        trajectory_len, batch_size = rewards.shape
+        advantages = torch.zeros_like(rewards, device=rewards.device)
         last_advantage = 0
-        last_value = next_value.detach()
-
-        for t in reversed(range(T)):
-            if t == T - 1:
-                next_values = last_value
-            else:
-                next_values = values[t + 1]
-
-            # Mask next values if episode ended
-            next_values = next_values * ~dones[t]
-            delta = rewards[t] + self.gae_gamma * next_values - values[t]
-            advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
+        next_value = last_value
+        next_done = torch.zeros(batch_size, device=dones.device)  # Last state is terminal
+        dones = dones.float()
+        for t in reversed(range(trajectory_len)):
+            # Check if next state is terminal
+            non_terminal = 1.0 - next_done
+
+            # Delta should not include next_value if next is terminal
+            delta = rewards[t] + self.gae_gamma * next_value * non_terminal - values[t]
+            advantages[t] = delta + self.gae_gamma * self.gae_lambda * non_terminal * last_advantage
             last_advantage = advantages[t]
+            next_value = values[t]
+            next_done = dones[t]
 
         returns = advantages + values
         return advantages, returns
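The reworked GAE recursion masks both the bootstrap value and the carried advantage with the next step's done flag, instead of masking values[t + 1] directly. Below is a self-contained sketch of the same recursion in plain PyTorch; the gamma and lambda defaults here are illustrative (the diff shows gae_gamma defaulting to 0.99, gae_lambda is not shown in this changeset).

import torch

def compute_gae(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
    trajectory_len, batch_size = rewards.shape
    advantages = torch.zeros_like(rewards)
    last_advantage = torch.zeros(batch_size)
    next_value = last_value
    next_done = torch.zeros(batch_size)          # nothing follows the final step
    dones = dones.float()
    for t in reversed(range(trajectory_len)):
        non_terminal = 1.0 - next_done           # zero out the bootstrap across episode ends
        delta = rewards[t] + gamma * next_value * non_terminal - values[t]
        advantages[t] = delta + gamma * lam * non_terminal * last_advantage
        last_advantage = advantages[t]
        next_value = values[t]
        next_done = dones[t]
    returns = advantages + values
    return advantages, returns

T, B = 5, 2
adv, ret = compute_gae(torch.rand(T, B), torch.rand(T, B), torch.rand(B), torch.zeros(T, B))
print(adv.shape, ret.shape)  # torch.Size([5, 2]) torch.Size([5, 2])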
rxnn/transformers/ff.py CHANGED
@@ -66,6 +66,8 @@ def get_activation_layer(activation: str):
         return nn.SiLU()
     elif activation == 'sigmoid':
         return nn.Sigmoid()
+    elif activation == 'tanh':
+        return nn.Tanh()
     elif activation == 'linear':
         return LinearActivation()
     else:
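A quick usage check for the extended activation factory, assuming the rxnn package is installed; the new 'tanh' branch is what allows MrlCriticModel to select a tanh output activation.

from rxnn.transformers.ff import get_activation_layer

print(get_activation_layer('tanh'))     # Tanh()
print(get_activation_layer('sigmoid'))  # Sigmoid()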
{rxnn-0.2.32.dist-info → rxnn-0.2.33.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.32
+Version: 0.2.33
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.32.dist-info → rxnn-0.2.33.dist-info}/RECORD RENAMED
@@ -16,16 +16,16 @@ rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
-rxnn/training/models.py,sha256=2KhNT7yx0AgUke4nmsFqzQKx_YYp78QvsLWYZjWeUgQ,6812
-rxnn/training/mrl.py,sha256=Aimiiqf_4p6dp5Ty9pY9VwetySBS_OFpCQlcVHVkO4Q,55124
-rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
-rxnn/training/rl.py,sha256=eL3C0yryiNBgl_xb-D-5dyYUtK4V4-K4t3a60x5ir28,5142
+rxnn/training/models.py,sha256=8FV5eZx1HxtqRSgikwfKoB_bNhPuMYyNi0uSXB65-M4,7223
+rxnn/training/mrl.py,sha256=1pYzjXI17FDZGPTVpmbaBvMYpB-a6SLv-84RHXA4JEA,55142
+rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
+rxnn/training/rl.py,sha256=ckx1nlzIGZBabzwZNRj4isvHqRZwg0y0jGOT-SN6KZc,5841
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
-rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
+rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
 rxnn/transformers/layers.py,sha256=l0bXmhN7KOkCw0KTVLixWSo9Op4SesGabWJ4R4EQBMY,7988
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=hey6tFN9gmLfWCZLjtl_9OcvIjGpWLI1IDeVnr5y8YM,10583
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.32.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.32.dist-info/METADATA,sha256=Ugq4LZZBakM-pUwpd0ZY0W_Ot2zUJHAxFtofIqAHzA8,25960
-rxnn-0.2.32.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.32.dist-info/RECORD,,
+rxnn-0.2.33.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.33.dist-info/METADATA,sha256=im17irb58IYMXOzMXE6QaSPF31Akx0iYS4ay-aRqA9Q,25960
+rxnn-0.2.33.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.33.dist-info/RECORD,,