rxnn 0.2.33__tar.gz → 0.2.35__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.33 → rxnn-0.2.35}/PKG-INFO +1 -1
  2. {rxnn-0.2.33 → rxnn-0.2.35}/pyproject.toml +1 -1
  3. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/models.py +15 -3
  4. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/mrl.py +16 -4
  5. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/rl.py +12 -10
  6. {rxnn-0.2.33 → rxnn-0.2.35}/LICENSE +0 -0
  7. {rxnn-0.2.33 → rxnn-0.2.35}/README.md +0 -0
  8. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/.DS_Store +0 -0
  9. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/__init__.py +0 -0
  10. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/experimental/__init__.py +0 -0
  11. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/experimental/attention.py +0 -0
  12. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/experimental/models.py +0 -0
  13. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/experimental/moe.py +0 -0
  14. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/memory/__init__.py +0 -0
  15. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/memory/attention.py +0 -0
  16. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/memory/norm.py +0 -0
  17. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/memory/stm.py +0 -0
  18. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/rxt/__init__.py +0 -0
  19. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/rxt/models.py +0 -0
  20. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/__init__.py +0 -0
  21. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/base.py +0 -0
  22. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/bml.py +0 -0
  23. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/callbacks.py +0 -0
  24. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/dataset.py +0 -0
  25. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/ddp.py +0 -0
  26. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/reward.py +0 -0
  27. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/utils.py +0 -0

{rxnn-0.2.33 → rxnn-0.2.35}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.33
+Version: 0.2.35
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning

{rxnn-0.2.33 → rxnn-0.2.35}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.33"
+version = "0.2.35"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"

{rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/models.py

@@ -6,6 +6,7 @@ from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
 from ..transformers.ff import GatedLinearUnit, get_activation_layer
 
+
 class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
     def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
         super(MLMHead, self).__init__(*args, **kwargs)
@@ -38,6 +39,7 @@ class MLMTrainingModel(nn.Module):
         y = self.mlm_head(h)
         return y
 
+
 class JointTrainingModel(nn.Module):
     def __init__(
             self,
@@ -59,10 +61,12 @@ class JointTrainingModel(nn.Module):
         y_d = self.decoder(x_d, attention_mask=attention_mask)
         return y_e, y_d
 
+
 class MrlActorAction(Enum):
     DECODE = 1
     UPDATE = 2
 
+
 class MrlActorModel(nn.Module):
     def __init__(
             self,
@@ -154,15 +158,18 @@ class MrlActorModel(nn.Module):
             list(self.memory_attention.parameters())
         ))
 
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, action: MrlActorAction = MrlActorAction.DECODE) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None,
+                action: MrlActorAction = MrlActorAction.DECODE) -> torch.Tensor:
         if action == MrlActorAction.DECODE:
             return self.decoder(x, attention_mask=attention_mask)
         else:
             _, ed = self.encoder(x, attention_mask=attention_mask)
             return self.memory_attention(ed, attention_mask=attention_mask)
 
+
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
-    def __init__(self, encoder: nn.Module, embed_dim: int, out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
+    def __init__(self, encoder: nn.Module, embed_dim: int,
+                 out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
         self.value_head = nn.Sequential(
@@ -173,6 +180,12 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         )
         self.output_scale = output_scale
 
+    def head_parameters(self) -> Iterator[nn.Parameter]:
+        return self.value_head.parameters()
+
+    def encoder_parameters(self) -> Iterator[nn.Parameter]:
+        return self.encoder.parameters()
+
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x, _ = self.encoder(x, attention_mask=attention_mask)
 
@@ -183,4 +196,3 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         x = x.mean(dim=1)
 
         return self.value_head(x) * self.output_scale
-
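Note: the new `head_parameters()` and `encoder_parameters()` accessors expose the critic's two parameter sets separately. A minimal sketch of one way such accessors can be used, e.g. to freeze the critic encoder while keeping the value head trainable (the helper below is illustrative, not part of the package):

```python
def freeze_critic_encoder(critic) -> None:
    # Illustrative helper: assumes `critic` exposes the new accessors
    # from MrlCriticModel (encoder_parameters / head_parameters).
    for p in critic.encoder_parameters():
        p.requires_grad = False  # stop updating the encoder
    for p in critic.head_parameters():
        p.requires_grad = True   # keep training the value head
```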

{rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/mrl.py

@@ -15,11 +15,13 @@ from .reward import MrlRewardMode, MrlRewardModel
 from .models import MrlActorAction, MrlActorModel, MrlCriticModel
 from .ddp import get_os_ddp_config, distributed_mean
 
+
 class MrlConfig(TypedDict):
     lr: float
     separate_memory_lr: Optional[bool]
     memory_lr: Optional[float]
     critic_lr: float
+    critic_encoder_lr: float
     max_seq_len: int
     critic_max_len: int
     weight_decay: float
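With the new `critic_encoder_lr` field, the critic encoder can be trained with its own learning rate (the trainer falls back to `critic_lr` when the key is omitted). A sketch of a config dict using the new key; the numeric values below are placeholders, not recommended settings:

```python
mrl_config = {
    'lr': 1e-4,                 # actor learning rate
    'critic_lr': 1e-4,          # critic value-head learning rate
    'critic_encoder_lr': 5e-5,  # new: separate learning rate for the critic encoder
    'max_seq_len': 1024,
    'critic_max_len': 1024,
    'weight_decay': 0.01,
    'critic_weight_decay': 0.01,
}
```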
@@ -58,6 +60,7 @@ class CurriculumConfig(TypedDict):
     lr: Optional[float]
     memory_lr: Optional[float]
     critic_lr: Optional[float]
+    critic_encoder_lr: Optional[float]
     weight_decay: Optional[float]
     critic_weight_decay: Optional[float]
     update_epochs: Optional[int]
@@ -158,6 +161,7 @@ class MRLTrainer:
                 'critic_lr': config.get('critic_lr', 1e-4),
                 'weight_decay': config.get('weight_decay', 0.01),
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+                'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
             }
         else:
             self.base_optim_config = {
@@ -165,6 +169,7 @@ class MRLTrainer:
                 'critic_lr': config.get('critic_lr', 1e-4),
                 'weight_decay': config.get('weight_decay', 0.01),
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+                'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
             }
 
         self.optim_config = self.base_optim_config
@@ -202,6 +207,7 @@ class MRLTrainer:
             critic_lr: float,
             weight_decay: float,
             critic_weight_decay: float,
+            critic_encoder_lr: float,
             memory_lr: Optional[float] = None,
     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
         if memory_lr is not None:
@@ -219,8 +225,10 @@
         )
 
         critic_optimizer = torch.optim.AdamW(
-            self.critic.parameters(),
-            lr=critic_lr,
+            [
+                {'params': self.critic.head_parameters(), 'lr': critic_lr},
+                {'params': self.critic.encoder_parameters(), 'lr': critic_encoder_lr},
+            ],
             weight_decay=critic_weight_decay,
         )
 
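The critic optimizer now takes two AdamW parameter groups, so the value head and the encoder can use different learning rates. A minimal standalone sketch of the same pattern with stand-in modules (the module shapes and learning rates here are illustrative):

```python
import torch
import torch.nn as nn

head = nn.Linear(256, 1)       # stand-in for the critic value head
encoder = nn.Linear(256, 256)  # stand-in for the critic encoder

critic_optimizer = torch.optim.AdamW(
    [
        {'params': head.parameters(), 'lr': 1e-4},     # critic_lr
        {'params': encoder.parameters(), 'lr': 5e-5},  # critic_encoder_lr
    ],
    weight_decay=0.01,  # shared across groups unless overridden per group
)
```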
@@ -633,7 +641,8 @@
             for i, t in enumerate(episode['steps'])
         ]
         values = torch.stack([
-            self._critic_values_with_memory(r, *self._move_multiple_batches(*t['state'])) for t, r in flat_trajectories
+            self._critic_values_with_memory(r, *self._move_multiple_batches(*t['state'])) for t, r in
+            flat_trajectories
         ]).to(self.device)
         rewards = torch.stack([torch.tensor(t['reward']) for t, _ in flat_trajectories]).to(self.device)
         dones = torch.stack([torch.tensor(t['done']) for t, _ in flat_trajectories]).to(self.device)
@@ -646,7 +655,8 @@
         dones = torch.stack([torch.tensor(t['done']) for t in flat_trajectories]).to(self.device)
         return values, rewards, dones
 
-    def _critic_values_with_memory(self, reset_stm: bool, *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
+    def _critic_values_with_memory(self, reset_stm: bool,
+                                   *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
         # 1. Calculate critic values in memory aware version - reset/update STM before calculating values
         with torch.no_grad():
             # 2. Reset STM if it was reset in trajectory collection
@@ -933,6 +943,7 @@ class MRLTrainer:
                    'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
                    'critic_weight_decay': config.get('critic_weight_decay',
                                                      self.base_optim_config['critic_weight_decay']),
+                    'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
                    'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
                }
            else:
@@ -942,6 +953,7 @@ class MRLTrainer:
                    'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
                    'critic_weight_decay': config.get('critic_weight_decay',
                                                      self.base_optim_config['critic_weight_decay']),
+                    'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
                }
            self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
        elif self.optim_config != self.base_optim_config:

{rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/rl.py

@@ -10,7 +10,7 @@ from .ddp import distributed_mean
 class RlAlgorithm(ABC):
     def __init__(self):
         super(RlAlgorithm, self).__init__()
-        self.critic_loss = nn.MSELoss()
+        self.critic_loss_fn = nn.MSELoss()
 
     @abstractmethod
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
@@ -22,7 +22,7 @@ class RlAlgorithm(ABC):
         pass
 
     def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
-        return self.critic_loss(values, ref_values)
+        return self.critic_loss_fn(values, ref_values)
 
 
 class PPOConfig(TypedDict):
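The rename matters because the old code stored the `nn.MSELoss` module in `self.critic_loss`, the same name as the `critic_loss` method: in Python the instance attribute shadows the method, which is fragile and easy to break. A small self-contained illustration of that name collision (not the library's code):

```python
import torch
import torch.nn as nn

class Shadowed:
    def __init__(self):
        # Instance attribute with the same name as the method below.
        self.loss = nn.MSELoss()

    def loss(self, values, ref_values):
        # Unreachable via `obj.loss(...)`: attribute lookup finds the
        # nn.MSELoss instance before this method.
        return self.loss(values, ref_values)

obj = Shadowed()
v = torch.zeros(3)
print(obj.loss(v, v))  # calls nn.MSELoss directly; prints tensor(0.)
```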
@@ -55,7 +55,7 @@ class PPOAlgorithm(RlAlgorithm):
         # Critic loss with clipped values
         if self.clip_critic_values:
             values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
-        return self.critic_loss(values, ref_values)
+        return self.critic_loss_fn(values, ref_values)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
                     old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
@@ -114,15 +114,17 @@ class PPOAlgorithm(RlAlgorithm):
         next_value = last_value
         next_done = torch.zeros(batch_size, device=dones.device)  # Last state is terminal
         dones = dones.float()
-        for t in reversed(range(trajectory_len)):
-            # Check if next state is terminal
-            non_terminal = 1.0 - next_done
 
-            # Delta should not include next_value if next is terminal
-            delta = rewards[t] + self.gae_gamma * next_value * non_terminal - values[t]
-            advantages[t] = delta + self.gae_gamma * self.gae_lambda * non_terminal * last_advantage
+        for t in reversed(range(trajectory_len)):
+            if t == trajectory_len - 1:
+                # For the last step, use the provided last_value
+                delta = rewards[t] + self.gae_gamma * next_value * (1 - next_done) - values[t]
+            else:
+                # For other steps, use the next value in the trajectory
+                delta = rewards[t] + self.gae_gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
+
+            advantages[t] = delta + self.gae_gamma * self.gae_lambda * (1 - dones[t]) * last_advantage
             last_advantage = advantages[t]
-            next_value = values[t]
             next_done = dones[t]
 
         returns = advantages + values
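The GAE loop now bootstraps intermediate steps from `values[t + 1]` and the stored `dones`, and only uses the externally supplied `last_value` for the final step. A minimal self-contained sketch of the same recursion (function name, default gamma/lambda, and tensor shapes are assumptions for illustration):

```python
import torch

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """GAE over a trajectory: rewards/values/dones are (T, B), last_value is (B,)."""
    T = rewards.size(0)
    advantages = torch.zeros_like(rewards)
    last_advantage = torch.zeros_like(last_value)
    dones = dones.float()
    next_done = torch.zeros_like(last_value)  # bootstrap state treated as non-terminal
    for t in reversed(range(T)):
        if t == T - 1:
            # Final step bootstraps from the provided last_value
            delta = rewards[t] + gamma * last_value * (1 - next_done) - values[t]
        else:
            # Intermediate steps bootstrap from the next value in the trajectory
            delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
        advantages[t] = delta + gamma * lam * (1 - dones[t]) * last_advantage
        last_advantage = advantages[t]
    returns = advantages + values
    return advantages, returns
```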