rxnn 0.2.33__py3-none-any.whl → 0.2.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/models.py CHANGED
@@ -6,6 +6,7 @@ from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
 from ..transformers.ff import GatedLinearUnit, get_activation_layer
 
+
 class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
     def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
         super(MLMHead, self).__init__(*args, **kwargs)
@@ -38,6 +39,7 @@ class MLMTrainingModel(nn.Module):
         y = self.mlm_head(h)
         return y
 
+
 class JointTrainingModel(nn.Module):
     def __init__(
             self,
@@ -59,10 +61,12 @@ class JointTrainingModel(nn.Module):
         y_d = self.decoder(x_d, attention_mask=attention_mask)
         return y_e, y_d
 
+
 class MrlActorAction(Enum):
     DECODE = 1
     UPDATE = 2
 
+
 class MrlActorModel(nn.Module):
     def __init__(
             self,
@@ -154,15 +158,18 @@ class MrlActorModel(nn.Module):
             list(self.memory_attention.parameters())
         ))
 
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, action: MrlActorAction = MrlActorAction.DECODE) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None,
+                action: MrlActorAction = MrlActorAction.DECODE) -> torch.Tensor:
         if action == MrlActorAction.DECODE:
             return self.decoder(x, attention_mask=attention_mask)
         else:
             _, ed = self.encoder(x, attention_mask=attention_mask)
             return self.memory_attention(ed, attention_mask=attention_mask)
 
+
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
-    def __init__(self, encoder: nn.Module, embed_dim: int, out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
+    def __init__(self, encoder: nn.Module, embed_dim: int,
+                 out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
         self.value_head = nn.Sequential(
@@ -173,6 +180,12 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         )
         self.output_scale = output_scale
 
+    def head_parameters(self) -> Iterator[nn.Parameter]:
+        return self.value_head.parameters()
+
+    def encoder_parameters(self) -> Iterator[nn.Parameter]:
+        return self.encoder.parameters()
+
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x, _ = self.encoder(x, attention_mask=attention_mask)
 
@@ -183,4 +196,3 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         x = x.mean(dim=1)
 
         return self.value_head(x) * self.output_scale
-
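Aside from line wrapping and extra blank lines, the functional change in models.py is the pair of parameter accessors added to MrlCriticModel. A minimal sketch of how they can be used to give the critic's value head and encoder different learning rates; the critic instance and the learning-rate values below are placeholders, not part of this diff:

    import torch

    # critic: an MrlCriticModel built elsewhere, e.g. MrlCriticModel(encoder=..., embed_dim=...)
    critic_optimizer = torch.optim.AdamW(
        [
            {'params': critic.head_parameters(), 'lr': 1e-4},     # value head parameters
            {'params': critic.encoder_parameters(), 'lr': 5e-5},  # encoder parameters, typically a lower rate
        ],
        weight_decay=0.01,
    )

This is the same two-group pattern that MRLTrainer._init_optimizers adopts in the mrl.py changes below.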
rxnn/training/mrl.py CHANGED
@@ -15,11 +15,13 @@ from .reward import MrlRewardMode, MrlRewardModel
 from .models import MrlActorAction, MrlActorModel, MrlCriticModel
 from .ddp import get_os_ddp_config, distributed_mean
 
+
 class MrlConfig(TypedDict):
     lr: float
     separate_memory_lr: Optional[bool]
     memory_lr: Optional[float]
     critic_lr: float
+    critic_encoder_lr: float
     max_seq_len: int
     critic_max_len: int
     weight_decay: float
@@ -58,6 +60,7 @@ class CurriculumConfig(TypedDict):
     lr: Optional[float]
     memory_lr: Optional[float]
     critic_lr: Optional[float]
+    critic_encoder_lr: Optional[float]
     weight_decay: Optional[float]
     critic_weight_decay: Optional[float]
     update_epochs: Optional[int]
@@ -158,6 +161,7 @@ class MRLTrainer:
                 'critic_lr': config.get('critic_lr', 1e-4),
                 'weight_decay': config.get('weight_decay', 0.01),
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+                'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
             }
         else:
             self.base_optim_config = {
@@ -165,6 +169,7 @@ class MRLTrainer:
                 'critic_lr': config.get('critic_lr', 1e-4),
                 'weight_decay': config.get('weight_decay', 0.01),
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+                'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
             }
 
         self.optim_config = self.base_optim_config
@@ -202,6 +207,7 @@ class MRLTrainer:
             critic_lr: float,
             weight_decay: float,
             critic_weight_decay: float,
+            critic_encoder_lr: float,
             memory_lr: Optional[float] = None,
     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
         if memory_lr is not None:
@@ -219,8 +225,10 @@ class MRLTrainer:
             )
 
         critic_optimizer = torch.optim.AdamW(
-            self.critic.parameters(),
-            lr=critic_lr,
+            [
+                {'params': self.critic.head_parameters(), 'lr': critic_lr},
+                {'params': self.critic.encoder_parameters(), 'lr': critic_encoder_lr},
+            ],
             weight_decay=critic_weight_decay,
         )
 
@@ -633,7 +641,8 @@ class MRLTrainer:
                 for i, t in enumerate(episode['steps'])
             ]
             values = torch.stack([
-                self._critic_values_with_memory(r, *self._move_multiple_batches(*t['state'])) for t, r in flat_trajectories
+                self._critic_values_with_memory(r, *self._move_multiple_batches(*t['state'])) for t, r in
+                flat_trajectories
             ]).to(self.device)
             rewards = torch.stack([torch.tensor(t['reward']) for t, _ in flat_trajectories]).to(self.device)
             dones = torch.stack([torch.tensor(t['done']) for t, _ in flat_trajectories]).to(self.device)
@@ -646,7 +655,8 @@ class MRLTrainer:
             dones = torch.stack([torch.tensor(t['done']) for t in flat_trajectories]).to(self.device)
             return values, rewards, dones
 
-    def _critic_values_with_memory(self, reset_stm: bool, *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
+    def _critic_values_with_memory(self, reset_stm: bool,
+                                   *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
         # 1. Calculate critic values in memory aware version - reset/update STM before calculating values
         with torch.no_grad():
             # 2. Reset STM if it was reset in trajectory collection
@@ -933,6 +943,7 @@ class MRLTrainer:
                     'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
                     'critic_weight_decay': config.get('critic_weight_decay',
                                                       self.base_optim_config['critic_weight_decay']),
+                    'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
                     'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
                 }
             else:
@@ -942,6 +953,7 @@ class MRLTrainer:
                     'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
                     'critic_weight_decay': config.get('critic_weight_decay',
                                                       self.base_optim_config['critic_weight_decay']),
+                    'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
                 }
             self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
         elif self.optim_config != self.base_optim_config:
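Taken together, the mrl.py changes introduce a critic_encoder_lr option that flows from MrlConfig (and per-stage CurriculumConfig overrides) into the second parameter group of the critic optimizer; when it is omitted, the trainer falls back to critic_lr. A minimal sketch of supplying it in the config dict; only the optimizer-related keys are shown, and the remaining required MrlConfig fields and the trainer construction are assumed rather than taken from this diff:

    mrl_config = {
        'lr': 3e-4,                  # actor learning rate
        'critic_lr': 1e-4,           # critic value-head learning rate
        'critic_encoder_lr': 5e-5,   # new in 0.2.34: separate learning rate for the critic encoder
        'weight_decay': 0.01,
        'critic_weight_decay': 0.01,
        # ... max_seq_len, critic_max_len and the other MrlConfig fields go here
    }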
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.33
+Version: 0.2.34
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -16,8 +16,8 @@ rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
-rxnn/training/models.py,sha256=8FV5eZx1HxtqRSgikwfKoB_bNhPuMYyNi0uSXB65-M4,7223
-rxnn/training/mrl.py,sha256=1pYzjXI17FDZGPTVpmbaBvMYpB-a6SLv-84RHXA4JEA,55142
+rxnn/training/models.py,sha256=y-9XHedSheyK1AmLBp3ayulnUvAmDuJ3t0qVg8wHBRg,7463
+rxnn/training/mrl.py,sha256=fIrg1Er0aAK4TnyDRmJC1m7az9wdkhikxv0CBCrGT-c,55868
 rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
 rxnn/training/rl.py,sha256=ckx1nlzIGZBabzwZNRj4isvHqRZwg0y0jGOT-SN6KZc,5841
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.33.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.33.dist-info/METADATA,sha256=im17irb58IYMXOzMXE6QaSPF31Akx0iYS4ay-aRqA9Q,25960
-rxnn-0.2.33.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.33.dist-info/RECORD,,
+rxnn-0.2.34.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.34.dist-info/METADATA,sha256=Q7LqPr7KHFhMPL6UrbqG1SmtJbM2Ho-Yuxp_7LyCtYw,25960
+rxnn-0.2.34.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.34.dist-info/RECORD,,