rxnn-0.2.19-py3-none-any.whl → rxnn-0.2.21-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/models.py CHANGED
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 from enum import Enum
+from typing import Literal
 from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
 
@@ -74,23 +75,27 @@ class MrlActorModel(nn.Module):
         self.decoder = decoder
         self.memory_attention = memory_attention
 
-    def freeze_components(self):
+    def freeze_components(self, stage: Literal['update', 'fetch', 'both'] = 'both'):
         """Freeze encoder/decoder except memory-related layers."""
         if self.encoder.freeze_without_memory is not None:
             self.encoder.freeze_without_memory()
+            if stage == 'update':
+                self.encoder.freeze_memory()
         else:
             for param in self.encoder.parameters():
                 param.requires_grad = False
-            self.encoder.model.trainable_cross_attention_(True)
+            self.encoder.model.trainable_cross_attention_(True if stage != 'update' else False)
         if self.decoder.freeze_without_memory is not None:
             self.decoder.freeze_without_memory()
+            if stage == 'update':
+                self.decoder.freeze_memory()
         else:
             for param in self.decoder.parameters():
                 param.requires_grad = False
-            self.decoder.model.trainable_cross_attention_(True)
+            self.decoder.model.trainable_cross_attention_(True if stage != 'update' else False)
         # Unfreeze memory attention
         for param in self.memory_attention.parameters():
-            param.requires_grad = True
+            param.requires_grad = True if stage != 'fetch' else False
 
     def unfreeze_components(self):
         """Unfreeze all components after initial training."""
@@ -124,7 +129,7 @@ class MrlActorModel(nn.Module):
         _, ed = self.encoder(x, attention_mask=attention_mask)
         return self.memory_attention(ed, attention_mask=attention_mask)
 
-class MrlCriticModel(nn.Module):
+class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
     def __init__(self, encoder: nn.Module, embed_dim: int, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
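Note: with `PyTorchModelHubMixin` added to its bases (plus the `license` and `pipeline_tag` metadata), the critic inherits the mixin's `save_pretrained`, `from_pretrained` and `push_to_hub` helpers from `huggingface_hub`. A hedged sketch, assuming a constructed critic and a purely illustrative repo id:

    # Sketch only - embed_dim and the repo id are illustrative, not from the package.
    critic = MrlCriticModel(encoder, embed_dim=256)
    critic.save_pretrained("./mrl-critic")           # local export via the Hub mixin
    critic.push_to_hub("your-org/rxnn-mrl-critic")   # published with the license/pipeline tag above
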
rxnn/training/mrl.py CHANGED
@@ -3,7 +3,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torch.utils.tensorboard import SummaryWriter
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
-from typing import Optional, TypedDict
+from typing import Optional, TypedDict, Union
 from enum import Enum
 import random, os
 from ..transformers.sampler import BatchSampler
@@ -37,11 +37,15 @@ class CurriculumConfig(TypedDict):
     eval_dataset: Optional[MrlCurriculumDataset]
     callbacks: Optional[list[MrlTrainerCallback]]
     strategy: MrlStrategy
-    unfreeze_epoch: Optional[int]
+    unfreeze_epoch: Optional[Union[int, tuple[int, int, int]]]
     random_resets: Optional[bool]
     random_resets_from: Optional[int]
     random_resets_ratio: Optional[float]
     reward_model: Optional[MrlRewardModel]
+    lr: Optional[float]
+    critic_lr: Optional[float]
+    weight_decay: Optional[float]
+    critic_weight_decay: Optional[float]
 
 
 class SamplerConfig(TypedDict):
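Note: the new optional keys let a single curriculum stage override the trainer-level optimizer settings and opt into the staged unfreeze schedule. The stage-setup code shown later reads `config['lr']` and the other new keys by direct indexing, so they look like they need to be present (use `None` to keep the base values). An illustrative sketch only; the numbers are not recommendations:

    # Illustrative stage config using the new keys (other existing keys omitted).
    stage_config = {
        'epochs': 8,
        'strategy': MrlStrategy.MULTI_STEP_STRATEGY,
        'unfreeze_epoch': (2, 4, 6),   # staged: 'fetch' at 2, 'both' at 4, full unfreeze at 6
        'lr': 1e-4,                    # per-stage actor LR override
        'critic_lr': None,             # None -> fall back to base_optim_config
        'weight_decay': None,
        'critic_weight_decay': None,
    }
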
@@ -119,17 +123,15 @@ class MRLTrainer:
         self.use_amp = use_amp
         self.dtype = dtype
 
+        self.base_optim_config = {
+            'lr': config.get('lr', 3e-4),
+            'critic_lr': config.get('critic_lr', 1e-4),
+            'weight_decay': config.get('weight_decay', 0.01),
+            'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+        }
+
         # Optimizers
-        self.optimizer = torch.optim.AdamW(
-            self.actor.unique_parameters(),
-            lr=config.get("lr", 3e-4),
-            weight_decay=config.get("weight_decay", 0.01),
-        )
-        self.critic_optimizer = torch.optim.AdamW(
-            self.critic.parameters(),
-            lr=config.get("critic_lr", 1e-4),
-            weight_decay=config.get("critic_weight_decay", 0.01),
-        )
+        self.optimizer, self.critic_optimizer = self._init_optimizers(**self.base_optim_config)
 
         self.scaler = torch.amp.GradScaler() if self.use_amp else None
         self.critic_scaler = torch.amp.GradScaler() if self.use_amp else None
@@ -156,6 +158,21 @@ class MRLTrainer:
         self.global_epoch = 0
         self.global_epochs_count = 0
 
+    def _init_optimizers(self, lr: float, critic_lr: float, weight_decay: float, critic_weight_decay: float):
+        optimizer = torch.optim.AdamW(
+            self.actor.unique_parameters(),
+            lr=lr,
+            weight_decay=weight_decay,
+        )
+
+        critic_optimizer = torch.optim.AdamW(
+            self.critic.parameters(),
+            lr=critic_lr,
+            weight_decay=critic_weight_decay,
+        )
+        return optimizer, critic_optimizer
+
+
     def _init_steps(self):
         return {
             'collect': 0,
@@ -705,6 +722,13 @@ class MRLTrainer:
         self.strategy = config.get('strategy',
                                    MrlStrategy.MULTI_STEP_STRATEGY)  # MRL strategy for given curriculum stage
         self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
+        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None:
+            self.optimizer, self.critic_optimizer = self._init_optimizers(
+                lr=config['lr'] or self.base_optim_config['lr'],
+                critic_lr=config['critic_lr'] or self.base_optim_config['critic_lr'],
+                weight_decay=config['weight_decay'] or self.base_optim_config['weight_decay'],
+                critic_weight_decay=config['critic_weight_decay'] or self.base_optim_config['critic_weight_decay']
+            )
 
         # 2. Get epochs and random resets configs
         epochs = config.get('epochs', 5)  # number of epochs for current stage
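Note: one behaviour worth keeping in mind from the block above: overriding any single field rebuilds both AdamW optimizers (discarding their moment state), and because the fallbacks use `or`, a falsy override behaves like no override at all. A tiny sketch of that fallback semantics:

    # 'x or default' treats falsy overrides the same as missing ones.
    lr = None or 3e-4            # -> 3e-4, as intended
    weight_decay = 0.0 or 0.01   # -> 0.01, even though 0.0 was an explicit override
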
@@ -746,7 +770,11 @@ class MRLTrainer:
 
         # 4. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
         if unfreeze_epoch != 0:
-            self.actor.freeze_components()
+            is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
+            if is_staged_unfreeze:
+                self.actor.freeze_components('update')
+            else:
+                self.actor.freeze_components()
 
         # 5. Setup train DataLoader
         if self.use_ddp:
@@ -787,8 +815,18 @@ class MRLTrainer:
             self.random_resets_ratio = 1.0
 
             # 11. Unfreeze all components before selected epoch
-            if epoch == unfreeze_epoch:
-                self.actor.unfreeze_components()
+            is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
+            if is_staged_unfreeze:
+                fetch_epoch, both_epoch, all_epoch = unfreeze_epoch
+                if epoch == fetch_epoch:
+                    self.actor.freeze_components('fetch')
+                elif epoch == both_epoch:
+                    self.actor.freeze_components('both')
+                elif epoch == all_epoch:
+                    self.actor.unfreeze_components()
+            else:
+                if epoch == unfreeze_epoch:
+                    self.actor.unfreeze_components()
 
             # 12. Set epoch for distributed sampler
             if train_sampler is not None:
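Note: combined with the earlier `freeze_components('update')` call at stage start, a tuple `unfreeze_epoch` yields a four-phase schedule. A minimal sketch (an assumed helper, not part of the package) that maps an epoch to the implied stage for an illustrative tuple `(2, 4, 6)`:

    # Assumed helper mirroring the branches above; not part of rxnn.
    def staged_mode(epoch: int, unfreeze_epoch: tuple[int, int, int]) -> str:
        fetch_epoch, both_epoch, all_epoch = unfreeze_epoch
        if epoch < fetch_epoch:
            return 'update'  # only memory attention trains
        if epoch < both_epoch:
            return 'fetch'   # only memory cross-attention trains
        if epoch < all_epoch:
            return 'both'    # both memory paths train
        return 'all'         # fully unfrozen

    assert [staged_mode(e, (2, 4, 6)) for e in range(7)] == \
        ['update', 'update', 'fetch', 'fetch', 'both', 'both', 'all']
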
rxnn-0.2.19.dist-info/METADATA → rxnn-0.2.21.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.19
+Version: 0.2.21
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.19.dist-info/RECORD → rxnn-0.2.21.dist-info/RECORD CHANGED
@@ -15,8 +15,8 @@ rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
-rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=RSbeJRRjAH1lzkySzeoDmng6hleRmUfnNcM1YVv57as,41388
+rxnn/training/models.py,sha256=wf98gYKKm9-ZY3zwdX9NIeJ-lvh7Ro1SoAijmQxYM28,5599
+rxnn/training/mrl.py,sha256=zk4m1JFuX0y82J0tG2XkY0Pz6Uy2did9cngOXqR9lMk,43326
 rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
 rxnn/training/rl.py,sha256=DHFwnPUlnq2JVj6CS6DwifnC_eMeBAUVp36UCAWNMis,3934
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.19.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.19.dist-info/METADATA,sha256=y3om6t6e6WreQXmVjEfmr_vSkqBl-R04Tmch9Qk6rQg,25960
-rxnn-0.2.19.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.19.dist-info/RECORD,,
+rxnn-0.2.21.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.21.dist-info/METADATA,sha256=XXf_qBMs2dOwWyAN5oNEg1W1-oPVIAQPy0FkNcO7QZQ,25960
+rxnn-0.2.21.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.21.dist-info/RECORD,,