rxnn 0.2.32__py3-none-any.whl → 0.2.34__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
rxnn/training/models.py CHANGED
@@ -4,6 +4,8 @@ from enum import Enum
 from typing import Literal, Iterator
 from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
+from ..transformers.ff import GatedLinearUnit, get_activation_layer
+
 
 class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
     def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
@@ -37,6 +39,7 @@ class MLMTrainingModel(nn.Module):
         y = self.mlm_head(h)
         return y
 
+
 class JointTrainingModel(nn.Module):
     def __init__(
             self,
@@ -58,10 +61,12 @@ class JointTrainingModel(nn.Module):
         y_d = self.decoder(x_d, attention_mask=attention_mask)
         return y_e, y_d
 
+
 class MrlActorAction(Enum):
     DECODE = 1
     UPDATE = 2
 
+
 class MrlActorModel(nn.Module):
     def __init__(
             self,
@@ -153,18 +158,33 @@ class MrlActorModel(nn.Module):
             list(self.memory_attention.parameters())
         ))
 
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, action: MrlActorAction = MrlActorAction.DECODE) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None,
+                action: MrlActorAction = MrlActorAction.DECODE) -> torch.Tensor:
         if action == MrlActorAction.DECODE:
             return self.decoder(x, attention_mask=attention_mask)
         else:
             _, ed = self.encoder(x, attention_mask=attention_mask)
             return self.memory_attention(ed, attention_mask=attention_mask)
 
+
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
-    def __init__(self, encoder: nn.Module, embed_dim: int, **kwargs):
+    def __init__(self, encoder: nn.Module, embed_dim: int,
+                 out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
-        self.value_head = nn.Linear(embed_dim, 1)
+        self.value_head = nn.Sequential(
+            GatedLinearUnit(embed_dim, embed_dim, nn.SiLU()),
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, 1),
+            get_activation_layer(out_activation)
+        )
+        self.output_scale = output_scale
+
+    def head_parameters(self) -> Iterator[nn.Parameter]:
+        return self.value_head.parameters()
+
+    def encoder_parameters(self) -> Iterator[nn.Parameter]:
+        return self.encoder.parameters()
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x, _ = self.encoder(x, attention_mask=attention_mask)
@@ -175,4 +195,4 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         else:
             x = x.mean(dim=1)
 
-        return self.value_head(x)
+        return self.value_head(x) * self.output_scale
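
The notable change here is the critic's value head: instead of a single nn.Linear, MrlCriticModel now stacks a gated linear unit, LayerNorm, a scalar projection, and a configurable output activation, then multiplies the result by output_scale; the new head_parameters() and encoder_parameters() accessors let the trainer give the head and the encoder different learning rates (see the mrl.py changes below). The following is a minimal standalone sketch of the head's shape flow, with a hypothetical GLUStandIn in place of rxnn's GatedLinearUnit, whose internals are not shown in this diff:

import torch
import torch.nn as nn

class GLUStandIn(nn.Module):
    # Rough stand-in for rxnn's GatedLinearUnit: a value branch gated by an activated gate branch.
    def __init__(self, in_dim: int, out_dim: int, activation: nn.Module):
        super().__init__()
        self.value = nn.Linear(in_dim, out_dim)
        self.gate = nn.Linear(in_dim, out_dim)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.value(x) * self.activation(self.gate(x))

embed_dim, output_scale = 256, 10.0
value_head = nn.Sequential(
    GLUStandIn(embed_dim, embed_dim, nn.SiLU()),
    nn.LayerNorm(embed_dim),
    nn.Linear(embed_dim, 1),
    nn.Sigmoid(),  # corresponds to the default out_activation='sigmoid'
)

pooled = torch.randn(4, embed_dim)          # pooled encoder output, as in MrlCriticModel.forward
values = value_head(pooled) * output_scale  # sigmoid output scaled into [0, output_scale]
print(values.shape)                         # torch.Size([4, 1])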
rxnn/training/mrl.py CHANGED
@@ -15,11 +15,13 @@ from .reward import MrlRewardMode, MrlRewardModel
 from .models import MrlActorAction, MrlActorModel, MrlCriticModel
 from .ddp import get_os_ddp_config, distributed_mean
 
+
 class MrlConfig(TypedDict):
     lr: float
     separate_memory_lr: Optional[bool]
     memory_lr: Optional[float]
     critic_lr: float
+    critic_encoder_lr: float
     max_seq_len: int
     critic_max_len: int
     weight_decay: float
@@ -58,6 +60,7 @@ class CurriculumConfig(TypedDict):
     lr: Optional[float]
     memory_lr: Optional[float]
     critic_lr: Optional[float]
+    critic_encoder_lr: Optional[float]
     weight_decay: Optional[float]
     critic_weight_decay: Optional[float]
     update_epochs: Optional[int]
@@ -158,6 +161,7 @@ class MRLTrainer:
                 'critic_lr': config.get('critic_lr', 1e-4),
                 'weight_decay': config.get('weight_decay', 0.01),
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+                'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
             }
         else:
             self.base_optim_config = {
@@ -165,6 +169,7 @@ class MRLTrainer:
                 'critic_lr': config.get('critic_lr', 1e-4),
                 'weight_decay': config.get('weight_decay', 0.01),
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+                'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
             }
 
         self.optim_config = self.base_optim_config
@@ -202,6 +207,7 @@ class MRLTrainer:
             critic_lr: float,
             weight_decay: float,
             critic_weight_decay: float,
+            critic_encoder_lr: float,
             memory_lr: Optional[float] = None,
     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
         if memory_lr is not None:
@@ -219,8 +225,10 @@ class MRLTrainer:
         )
 
         critic_optimizer = torch.optim.AdamW(
-            self.critic.parameters(),
-            lr=critic_lr,
+            [
+                {'params': self.critic.head_parameters(), 'lr': critic_lr},
+                {'params': self.critic.encoder_parameters(), 'lr': critic_encoder_lr},
+            ],
             weight_decay=critic_weight_decay,
         )
 
@@ -481,7 +489,7 @@ class MRLTrainer:
                 critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                           pad_token_id=self.pad_token_id)
                 values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
-                critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
+                critic_loss = self.rl_algorithm.critic_loss(values, ref_values.detach())
             # 2.2 Run backpropagation with scaler
             self.critic_scaler.scale(critic_loss).backward()
             # 2.3 Unscale and clip gradients
@@ -495,7 +503,7 @@ class MRLTrainer:
             critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                       pad_token_id=self.pad_token_id)
             values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
-            critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
+            critic_loss = self.rl_algorithm.critic_loss(values, ref_values.detach())
             # 2.2 Run backpropagation
             critic_loss.backward()
             # 2.3 Clip gradients
@@ -633,7 +641,8 @@ class MRLTrainer:
                 for i, t in enumerate(episode['steps'])
             ]
             values = torch.stack([
-                self._critic_values_with_memory(r, *self._move_multiple_batches(*t['state'])) for t, r in flat_trajectories
+                self._critic_values_with_memory(r, *self._move_multiple_batches(*t['state'])) for t, r in
+                flat_trajectories
             ]).to(self.device)
             rewards = torch.stack([torch.tensor(t['reward']) for t, _ in flat_trajectories]).to(self.device)
             dones = torch.stack([torch.tensor(t['done']) for t, _ in flat_trajectories]).to(self.device)
@@ -646,7 +655,8 @@ class MRLTrainer:
             dones = torch.stack([torch.tensor(t['done']) for t in flat_trajectories]).to(self.device)
         return values, rewards, dones
 
-    def _critic_values_with_memory(self, reset_stm: bool, *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
+    def _critic_values_with_memory(self, reset_stm: bool,
+                                   *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
         # 1. Calculate critic values in memory aware version - reset/update STM before calculating values
         with torch.no_grad():
             # 2. Reset STM if it was reset in trajectory collection
@@ -933,6 +943,7 @@ class MRLTrainer:
                 'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
                 'critic_weight_decay': config.get('critic_weight_decay',
                                                   self.base_optim_config['critic_weight_decay']),
+                'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
                 'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
             }
         else:
@@ -942,6 +953,7 @@ class MRLTrainer:
                 'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
                 'critic_weight_decay': config.get('critic_weight_decay',
                                                   self.base_optim_config['critic_weight_decay']),
+                'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
             }
             self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
         elif self.optim_config != self.base_optim_config:
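
Two things change in the trainer: the critic target is now detached (ref_values.detach()) before computing the critic loss, and the critic optimizer is built from two parameter groups so the value head and the encoder can use separate learning rates, with critic_encoder_lr falling back to critic_lr when it is not configured. A minimal sketch of the same AdamW parameter-group pattern, using a hypothetical ToyCritic in place of MrlCriticModel:

import torch
import torch.nn as nn

class ToyCritic(nn.Module):
    # Hypothetical critic exposing the same parameter-group split as MrlCriticModel.
    def __init__(self, embed_dim: int = 64):
        super().__init__()
        self.encoder = nn.Linear(embed_dim, embed_dim)
        self.value_head = nn.Linear(embed_dim, 1)

    def head_parameters(self):
        return self.value_head.parameters()

    def encoder_parameters(self):
        return self.encoder.parameters()

critic = ToyCritic()
critic_lr, critic_encoder_lr, critic_weight_decay = 1e-4, 1e-5, 0.01

critic_optimizer = torch.optim.AdamW(
    [
        {'params': critic.head_parameters(), 'lr': critic_lr},
        {'params': critic.encoder_parameters(), 'lr': critic_encoder_lr},
    ],
    weight_decay=critic_weight_decay,
)
print([group['lr'] for group in critic_optimizer.param_groups])  # [0.0001, 1e-05]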
rxnn/training/reward.py CHANGED
@@ -42,6 +42,7 @@ class MrlRewardModel:
             running_mean_decay: float = 0.2,
             bleu_saved_weights: tuple = (0.5, 0.5),
             bleu_ref_weights: tuple = (0.5, 0.5),
+            tanh_reward_scale: bool = False,
             rewards_scale: float = 1.0,
     ):
         self.shared_embedding = shared_embedding.to(device)
@@ -71,6 +72,7 @@ class MrlRewardModel:
         self.running_mean_decay = running_mean_decay
         self.bleu_ref_weights = bleu_ref_weights
         self.bleu_saved_weights = bleu_saved_weights
+        self.tanh_reward_scale = tanh_reward_scale
         self.rewards_scale = rewards_scale
 
         self.prev_data_running_mean = None
@@ -175,6 +177,12 @@ class MrlRewardModel:
         self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
             prev_data) + self.running_mean_decay * self.prev_data_running_mean
 
+    def _pre_scale_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
+        if self.tanh_reward_scale:
+            return (rewards * 2) - 1  # Convert [0,1] to [-1,1]
+        else:
+            return rewards
+
     def __call__(
             self,
             generated: TokenizedDict,
@@ -204,5 +212,5 @@ class MrlRewardModel:
             cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
             sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
 
-        rewards = (sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
+        rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
         return rewards.tolist()
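
The new tanh_reward_scale flag remaps rewards computed in [0, 1] to [-1, 1] before rewards_scale is applied, which pairs naturally with a critic configured with out_activation='tanh'. A small standalone illustration of the mapping (the free function name here is ours; the package keeps it as the private method _pre_scale_rewards):

import torch

def pre_scale_rewards(rewards: torch.Tensor, tanh_reward_scale: bool) -> torch.Tensor:
    # Same affine remap as MrlRewardModel._pre_scale_rewards: [0, 1] -> [-1, 1] when enabled.
    return rewards * 2 - 1 if tanh_reward_scale else rewards

raw = torch.tensor([0.0, 0.25, 0.5, 1.0])
print(pre_scale_rewards(raw, tanh_reward_scale=False))  # tensor([0.0000, 0.2500, 0.5000, 1.0000])
print(pre_scale_rewards(raw, tanh_reward_scale=True))   # tensor([-1.0000, -0.5000,  0.0000,  1.0000])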
rxnn/training/rl.py CHANGED
@@ -21,8 +21,8 @@ class RlAlgorithm(ABC):
     def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         pass
 
-    def critic_loss(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
-        return self.critic_loss(rewards, values)
+    def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
+        return self.critic_loss(values, ref_values)
 
 
 class PPOConfig(TypedDict):
@@ -31,6 +31,8 @@ class PPOConfig(TypedDict):
     gae_gamma: Optional[float]
     entropy_coef: Optional[float]
     use_distributed_advantage_norm: Optional[bool]
+    clip_critic_values: Optional[bool]
+    critic_value_clip: Optional[float]
 
 
 class PPOAlgorithm(RlAlgorithm):
@@ -46,6 +48,14 @@ class PPOAlgorithm(RlAlgorithm):
         self.gae_gamma = config.get('gae_gamma', 0.99)
         self.entropy_coef = config.get('entropy_coef', 0.01)
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
+        self.clip_critic_values = config.get('clip_critic_values', True)
+        self.critic_value_clip = config.get('critic_value_clip', 10.0)
+
+    def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
+        # Critic loss with clipped values
+        if self.clip_critic_values:
+            values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
+        return self.critic_loss(values, ref_values)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
                     old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
@@ -96,23 +106,24 @@ class PPOAlgorithm(RlAlgorithm):
 
         return policy_loss
 
-    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor, next_value: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        T, B = rewards.shape
-        advantages = torch.zeros_like(rewards, device=values.device)
+    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor,
+                     last_value: torch.Tensor, dones: torch.Tensor):
+        trajectory_len, batch_size = rewards.shape
+        advantages = torch.zeros_like(rewards, device=rewards.device)
         last_advantage = 0
-        last_value = next_value.detach()
-
-        for t in reversed(range(T)):
-            if t == T - 1:
-                next_values = last_value
-            else:
-                next_values = values[t + 1]
-
-            # Mask next values if episode ended
-            next_values = next_values * ~dones[t]
-            delta = rewards[t] + self.gae_gamma * next_values - values[t]
-            advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
+        next_value = last_value
+        next_done = torch.zeros(batch_size, device=dones.device)  # Last state is terminal
+        dones = dones.float()
+        for t in reversed(range(trajectory_len)):
+            # Check if next state is terminal
+            non_terminal = 1.0 - next_done
+
+            # Delta should not include next_value if next is terminal
+            delta = rewards[t] + self.gae_gamma * next_value * non_terminal - values[t]
+            advantages[t] = delta + self.gae_gamma * self.gae_lambda * non_terminal * last_advantage
             last_advantage = advantages[t]
+            next_value = values[t]
+            next_done = dones[t]
 
         returns = advantages + values
         return advantages, returns
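
Two PPO changes land here. First, PPOAlgorithm now overrides critic_loss and, when clip_critic_values is enabled (the default), clamps predicted values to [-critic_value_clip, critic_value_clip] before computing the loss. Second, _compute_gae is rewritten: instead of indexing values[t + 1] and masking with ~dones[t], the loop carries next_value and next_done backwards and multiplies both the value bootstrap and the accumulated advantage by 1 - next_done, so a terminal step cuts off the recursion cleanly. A self-contained sketch of that recursion; the gamma and lam defaults here are illustrative, not the package's:

import torch

def compute_gae(rewards: torch.Tensor, values: torch.Tensor, last_value: torch.Tensor,
                dones: torch.Tensor, gamma: float = 0.99, lam: float = 0.95):
    # Backward GAE pass in the same style as the reworked PPOAlgorithm._compute_gae.
    trajectory_len, batch_size = rewards.shape
    advantages = torch.zeros_like(rewards)
    last_advantage = torch.zeros(batch_size)
    next_value = last_value              # bootstrap value after the final step
    next_done = torch.zeros(batch_size)  # no terminal flag recorded after the final step
    dones = dones.float()
    for t in reversed(range(trajectory_len)):
        non_terminal = 1.0 - next_done
        delta = rewards[t] + gamma * next_value * non_terminal - values[t]
        advantages[t] = delta + gamma * lam * non_terminal * last_advantage
        last_advantage = advantages[t]
        next_value = values[t]
        next_done = dones[t]
    returns = advantages + values
    return advantages, returns

T, B = 4, 2
adv, ret = compute_gae(torch.rand(T, B), torch.rand(T, B), torch.rand(B), torch.zeros(T, B))
print(adv.shape, ret.shape)  # torch.Size([4, 2]) torch.Size([4, 2])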
rxnn/transformers/ff.py CHANGED
@@ -66,6 +66,8 @@ def get_activation_layer(activation: str):
         return nn.SiLU()
     elif activation == 'sigmoid':
         return nn.Sigmoid()
+    elif activation == 'tanh':
+        return nn.Tanh()
     elif activation == 'linear':
         return LinearActivation()
     else:
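
get_activation_layer simply gains a 'tanh' branch, which is what lets MrlCriticModel accept out_activation='tanh'. Assuming rxnn >= 0.2.34 is installed, a quick check of the factory might look like the snippet below (only option names visible in this diff are used):

from rxnn.transformers.ff import get_activation_layer

for name in ('sigmoid', 'tanh', 'linear'):
    layer = get_activation_layer(name)
    print(name, type(layer).__name__)  # Sigmoid, Tanh, LinearActivation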
rxnn-0.2.32.dist-info/METADATA → rxnn-0.2.34.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.32
+Version: 0.2.34
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.32.dist-info/RECORD → rxnn-0.2.34.dist-info/RECORD CHANGED
@@ -16,16 +16,16 @@ rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
-rxnn/training/models.py,sha256=2KhNT7yx0AgUke4nmsFqzQKx_YYp78QvsLWYZjWeUgQ,6812
-rxnn/training/mrl.py,sha256=Aimiiqf_4p6dp5Ty9pY9VwetySBS_OFpCQlcVHVkO4Q,55124
-rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
-rxnn/training/rl.py,sha256=eL3C0yryiNBgl_xb-D-5dyYUtK4V4-K4t3a60x5ir28,5142
+rxnn/training/models.py,sha256=y-9XHedSheyK1AmLBp3ayulnUvAmDuJ3t0qVg8wHBRg,7463
+rxnn/training/mrl.py,sha256=fIrg1Er0aAK4TnyDRmJC1m7az9wdkhikxv0CBCrGT-c,55868
+rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
+rxnn/training/rl.py,sha256=ckx1nlzIGZBabzwZNRj4isvHqRZwg0y0jGOT-SN6KZc,5841
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
-rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
+rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
 rxnn/transformers/layers.py,sha256=l0bXmhN7KOkCw0KTVLixWSo9Op4SesGabWJ4R4EQBMY,7988
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=hey6tFN9gmLfWCZLjtl_9OcvIjGpWLI1IDeVnr5y8YM,10583
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.32.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.32.dist-info/METADATA,sha256=Ugq4LZZBakM-pUwpd0ZY0W_Ot2zUJHAxFtofIqAHzA8,25960
-rxnn-0.2.32.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.32.dist-info/RECORD,,
+rxnn-0.2.34.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.34.dist-info/METADATA,sha256=Q7LqPr7KHFhMPL6UrbqG1SmtJbM2Ho-Yuxp_7LyCtYw,25960
+rxnn-0.2.34.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.34.dist-info/RECORD,,