rxnn 0.2.26__py3-none-any.whl → 0.2.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/models.py CHANGED
@@ -1,7 +1,7 @@
  import torch
  import torch.nn as nn
  from enum import Enum
- from typing import Literal
+ from typing import Literal, Iterator
  from huggingface_hub import PyTorchModelHubMixin
  from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder

@@ -75,7 +75,7 @@ class MrlActorModel(nn.Module):
          self.decoder = decoder
          self.memory_attention = memory_attention

-     def freeze_components(self, stage: Literal['update', 'fetch', 'both'] = 'both'):
+     def freeze_components(self, stage: Literal['update', 'fetch', 'joint'] = 'joint'):
          """Freeze encoder/decoder except memory-related layers."""
          if self.encoder.freeze_without_memory is not None:
              self.encoder.freeze_without_memory(unfreeze_norms=True)
@@ -131,6 +131,16 @@ class MrlActorModel(nn.Module):
              self.memory_attention.parameters()
          ))

+     def memory_cross_attention_parameters(self) -> list[nn.Parameter]:
+         return list(set(
+             self.encoder.memory_parameters() +
+             self.decoder.memory_parameters()
+         ))
+
+     def memory_attention_parameters(self) -> Iterator[nn.Parameter]:
+         return self.memory_attention.parameters()
+
+
      def not_memory_parameters(self) -> list[nn.Parameter]:
          return list(set(
              self.encoder.not_memory_parameters() +
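The two new helpers split the actor's memory parameters into a cross-attention subset (the memory layers inside encoder and decoder) and the memory-attention network itself, so each group can be given its own learning rate. Below is a minimal sketch of how they might be combined into AdamW parameter groups; the function name and the learning-rate values are illustrative, not part of the package:

```python
import torch

from rxnn.training.models import MrlActorModel


def build_memory_optimizer(
    actor: MrlActorModel,
    model_lr: float = 3e-4,      # backbone (non-memory) parameters
    mem_att_lr: float = 5e-4,    # memory-attention network
    cross_att_lr: float = 1e-4,  # memory cross-attention layers in encoder/decoder
) -> torch.optim.Optimizer:
    # One AdamW instance with a separate parameter group per subset,
    # mirroring how MRLTrainer builds its optimizers from these helpers.
    return torch.optim.AdamW([
        {'params': actor.not_memory_parameters(), 'lr': model_lr},
        {'params': actor.memory_attention_parameters(), 'lr': mem_att_lr},
        {'params': actor.memory_cross_attention_parameters(), 'lr': cross_att_lr},
    ], weight_decay=0.01)
```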
rxnn/training/mrl.py CHANGED
@@ -3,7 +3,7 @@ from torch.utils.data import DataLoader, DistributedSampler
  from torch.utils.tensorboard import SummaryWriter
  import torch.distributed as dist
  from torch.nn.parallel import DistributedDataParallel
- from typing import Optional, TypedDict, Union
+ from typing import Optional, TypedDict, Union, TypeAlias, Literal
  from enum import Enum
  import random, os
  from ..transformers.sampler import BatchSampler
@@ -31,6 +31,8 @@ class MrlStrategy(Enum):
      MULTI_STEP_STRATEGY = 2
      LONG_RANGE_STRATEGY = 3

+ UnfreezeItem = Union[int, tuple[int, float]]
+ UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int]]

  class CurriculumConfig(TypedDict):
      steps: int
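The new alias widens `unfreeze_epoch` from a plain epoch index (or a tuple of four ints) to a schedule whose first three entries may optionally carry a stage-specific learning rate. A hedged illustration of both accepted shapes; the epoch numbers and learning rates are made up:

```python
from rxnn.training.mrl import UnfreezeEpochsStrategy

# Simple form: unfreeze the whole actor at epoch 2.
unfreeze_simple: UnfreezeEpochsStrategy = 2

# Staged form: (update, fetch, joint, all) stage epochs, where each of the
# first three items may be an (epoch, lr) pair to override that stage's LR.
unfreeze_staged: UnfreezeEpochsStrategy = ((1, 1e-5), (3, 2e-5), (5, 1e-5), 7)
```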
@@ -39,7 +41,7 @@ class CurriculumConfig(TypedDict):
      eval_dataset: Optional[MrlCurriculumDataset]
      callbacks: Optional[list[MrlTrainerCallback]]
      strategy: MrlStrategy
-     unfreeze_epoch: Optional[Union[int, tuple[int, int, int, int]]]
+     unfreeze_epoch: Optional[UnfreezeEpochsStrategy]
      random_resets: Optional[bool]
      random_resets_from: Optional[int]
      random_resets_ratio: Optional[float]
@@ -132,7 +134,8 @@ class MRLTrainer:

          if self.separate_memory_lr:
              self.base_optim_config = {
-                 'lr': (config.get('lr', 3e-4), config.get('memory_lr', 5e-4)),
+                 'lr': config.get('lr', 3e-4),
+                 'memory_lr': config.get('memory_lr', 5e-4),
                  'critic_lr': config.get('critic_lr', 1e-4),
                  'weight_decay': config.get('weight_decay', 0.01),
                  'critic_weight_decay': config.get('critic_weight_decay', 0.01),
@@ -145,8 +148,9 @@ class MRLTrainer:
                  'critic_weight_decay': config.get('critic_weight_decay', 0.01),
              }

-         # Optimizers
-         self.optimizer, self.critic_optimizer = self._init_optimizers(**self.base_optim_config, separate_memory_lr=self.separate_memory_lr)
+         self.optim_config = self.base_optim_config
+
+         self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)

          self.scaler = torch.amp.GradScaler() if self.use_amp else None
          self.critic_scaler = torch.amp.GradScaler() if self.use_amp else None
@@ -173,11 +177,17 @@ class MRLTrainer:
          self.global_epoch = 0
          self.global_epochs_count = 0

-     def _init_optimizers(self, lr: Union[float, tuple[float, float]], critic_lr: float, weight_decay: float, critic_weight_decay: float, separate_memory_lr: bool = False) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
-         if separate_memory_lr:
-             rest_lr, memory_lr = lr
+     def _init_optimizers(
+             self,
+             lr: float,
+             critic_lr: float,
+             weight_decay: float,
+             critic_weight_decay: float,
+             memory_lr: Optional[float] = None,
+     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
+         if memory_lr is not None:
              optimizer = torch.optim.AdamW([
-                 { 'params': self.actor.not_memory_parameters(), 'lr': rest_lr },
+                 { 'params': self.actor.not_memory_parameters(), 'lr': lr },
                  { 'params': self.actor.memory_parameters(), 'lr': memory_lr },
              ],
                  weight_decay=weight_decay,
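With `separate_memory_lr` folded into an optional `memory_lr` keyword, the trainer can splat whichever config it stored straight into `_init_optimizers`; the presence or absence of the `'memory_lr'` key decides whether the actor gets one or two parameter groups. A small sketch of the two config shapes this implies, using the default values visible in the diff:

```python
# separate_memory_lr=True: memory parameters get their own learning rate.
optim_config_separate = {
    'lr': 3e-4,
    'memory_lr': 5e-4,
    'critic_lr': 1e-4,
    'weight_decay': 0.01,
    'critic_weight_decay': 0.01,
}

# separate_memory_lr=False: no 'memory_lr' key, so the actor uses a single LR.
optim_config_joint = {
    'lr': 3e-4,
    'critic_lr': 1e-4,
    'weight_decay': 0.01,
    'critic_weight_decay': 0.01,
}

# Inside MRLTrainer (sketch):
# self.optimizer, self.critic_optimizer = self._init_optimizers(**optim_config_separate)
```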
@@ -737,7 +747,7 @@ class MRLTrainer:

          return should_stop_stage

-     def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[tuple[int, int], tuple[bool, int, float]]:
+     def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[tuple[int, UnfreezeEpochsStrategy], tuple[bool, int, float]]:
          # 1. Set common fields based on config
          self.curriculum_steps = config.get('steps', 1) # number of steps to run in episode
          self.train_dataset = config.get('dataset', None) # training dataset for current curriculum stage
@@ -748,13 +758,28 @@ class MRLTrainer:
                                     MrlStrategy.MULTI_STEP_STRATEGY) # MRL strategy for given curriculum stage
          self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
          if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
-             self.optimizer, self.critic_optimizer = self._init_optimizers(
-                 lr=(config.get('lr', self.base_optim_config['lr'][0]), config.get('memory_lr', self.base_optim_config['lr'][1])) if config.get('separate_memory_lr', False) else config.get('lr', self.base_optim_config['lr']),
-                 critic_lr=config.get('critic_lr', self.base_optim_config['critic_lr']),
-                 weight_decay=config.get('weight_decay', self.base_optim_config['weight_decay']),
-                 critic_weight_decay=config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
-                 separate_memory_lr=config.get('separate_memory_lr', False),
-             )
+             if config.get('separate_memory_lr', False):
+                 self.optim_config = {
+                     'lr': config.get('lr', self.base_optim_config['lr']),
+                     'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
+                     'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
+                     'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
+                     'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
+                 }
+             else:
+                 self.optim_config = {
+                     'lr': config.get('lr', self.base_optim_config['lr']),
+                     'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
+                     'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
+                     'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
+                 }
+             self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
+         elif self.optim_config != self.base_optim_config:
+             self.optim_config = self.base_optim_config
+             self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
+
+
+

          # 2. Get epochs and random resets configs
          epochs = config.get('epochs', 5) # number of epochs for current stage
@@ -771,6 +796,82 @@ class MRLTrainer:

          return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)

+     def _apply_unfreeze_strategy(self, epoch: int, unfreeze_epoch: UnfreezeEpochsStrategy):
+         is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
+         if is_staged_unfreeze:
+             update_epoch, fetch_epoch, joint_epoch, all_epoch = unfreeze_epoch
+
+             if isinstance(update_epoch, tuple):
+                 switch_epoch, cross_att_lr = update_epoch
+                 if epoch == switch_epoch:
+                     self.actor.freeze_components('joint')
+                     self.optimizer = self._init_unfreeze_optimizer('update', cross_att_lr)
+                     print(f"Activating 'update' unfreeze strategy with custom cross_att_lr: {cross_att_lr}")
+             elif epoch == update_epoch:
+                 self.actor.freeze_components('update')
+                 print(f"Activating 'update' unfreeze strategy - mem-att trainable / cross-att frozen / rest model frozen")
+
+             if isinstance(fetch_epoch, tuple):
+                 switch_epoch, mem_att_lr = fetch_epoch
+                 if epoch == switch_epoch:
+                     self.actor.freeze_components('joint')
+                     self.optimizer = self._init_unfreeze_optimizer('fetch', mem_att_lr)
+                     print(f"Activating 'fetch' unfreeze strategy with custom mem_att_lr: {mem_att_lr}")
+             elif epoch == fetch_epoch:
+                 self.actor.freeze_components('fetch')
+                 print(f"Activating 'fetch' unfreeze strategy - mem-att frozen / cross-att trainable / rest model frozen")
+
+             if isinstance(joint_epoch, tuple):
+                 switch_epoch, model_lr = joint_epoch
+                 if epoch == switch_epoch:
+                     self.actor.unfreeze_components()
+                     self.optimizer = self._init_unfreeze_optimizer('joint', model_lr)
+                     print(f"Activating 'joint' unfreeze strategy with custom model_lr: {model_lr}")
+             elif epoch == joint_epoch:
+                 self.actor.freeze_components('joint')
+                 print(f"Activating 'joint' unfreeze strategy - mem-att/cross-att trainable / rest model frozen")
+             if epoch == all_epoch:
+                 self.actor.unfreeze_components()
+                 self.optimizer = self._init_unfreeze_optimizer('all', 0.)
+                 print(f"Switching to train 'all' strategy - unfreeze all components")
+         elif epoch == unfreeze_epoch:
+             self.actor.unfreeze_components()
+             print(f"Switching to train 'all' strategy - unfreeze all components")
+
+     def _init_unfreeze_optimizer(
+             self,
+             mode: Literal['update', 'fetch', 'joint', 'all'],
+             unfreeze_lr: float,
+     ) -> torch.optim.Optimizer:
+         memory_lr = self.optim_config['memory_lr'] if 'memory_lr' in self.optim_config else self.optim_config['lr']
+         model_lr = self.optim_config['lr']
+
+         if mode == 'update':
+             params = [
+                 {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
+                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
+                 {'params': self.actor.memory_cross_attention_parameters(), 'lr': unfreeze_lr},
+             ]
+         elif mode == 'fetch':
+             params = [
+                 {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
+                 {'params': self.actor.memory_cross_attention_parameters(), 'lr': memory_lr},
+                 {'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
+             ]
+         elif mode == 'joint':
+             params = [
+                 {'params': self.actor.not_memory_parameters(), 'lr': unfreeze_lr},
+                 {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+             ]
+         else:
+             params = [
+                 {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
+                 {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+             ]
+
+         return torch.optim.AdamW(params, weight_decay=self.optim_config['weight_decay'])
+
+
      def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int):
          """Start Memory Reinforcement Learning Curriculum."""

@@ -796,7 +897,11 @@ class MRLTrainer:

              # 4. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
              if unfreeze_epoch != 0:
-                 self.actor.freeze_components('both')
+                 self.actor.freeze_components('joint')
+                 if isinstance(unfreeze_epoch, tuple):
+                     print(f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
+                 else:
+                     print(f"Starting training with simple unfreeze - 'joint' - mem-att/cross-att trainable / rest model frozen")

              # 5. Setup train DataLoader
              if self.use_ddp:
@@ -836,21 +941,8 @@ class MRLTrainer:
                  else:
                      self.random_resets_ratio = 1.0

-                 # 11. Unfreeze all components before selected epoch
-                 is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
-                 if is_staged_unfreeze:
-                     update_epoch, fetch_epoch, both_epoch, all_epoch = unfreeze_epoch
-                     if epoch == update_epoch:
-                         self.actor.freeze_components('update')
-                     elif epoch == fetch_epoch:
-                         self.actor.freeze_components('fetch')
-                     elif epoch == both_epoch:
-                         self.actor.freeze_components('both')
-                     elif epoch == all_epoch:
-                         self.actor.unfreeze_components()
-                 else:
-                     if epoch == unfreeze_epoch:
-                         self.actor.unfreeze_components()
+                 # 11. Apply the unfreeze strategy
+                 self._apply_unfreeze_strategy(epoch, unfreeze_epoch)

                  # 12. Set epoch for distributed sampler
                  if train_sampler is not None:
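Taken together, a curriculum stage opts into the staged schedule through its `unfreeze_epoch` field, and `_apply_unfreeze_strategy` then switches the freeze mode once per matching epoch, rebuilding the optimizer when a stage carries an `(epoch, lr)` pair. A hedged, partial stage config; dataset, callbacks, and the other required CurriculumConfig keys are omitted and the numbers are illustrative:

```python
from rxnn.training.mrl import MrlStrategy

stage_config = {
    'steps': 4,
    'epochs': 10,
    'strategy': MrlStrategy.MULTI_STEP_STRATEGY,
    # Staged unfreeze: 'update' at epoch 1 (custom cross-attention LR),
    # 'fetch' at epoch 3 (custom memory-attention LR), 'joint' at epoch 5
    # (custom LR for the rest of the model), everything unfrozen at epoch 7.
    'unfreeze_epoch': ((1, 1e-5), (3, 2e-5), (5, 1e-5), 7),
}
```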
rxnn-0.2.26.dist-info/METADATA → rxnn-0.2.27.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.2.26
+ Version: 0.2.27
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
rxnn-0.2.26.dist-info/RECORD → rxnn-0.2.27.dist-info/RECORD RENAMED
@@ -15,8 +15,8 @@ rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
  rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
  rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
  rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
- rxnn/training/models.py,sha256=_TrFwrQ_m6NDPalrafd8faPRyCnDFFFtN_gfzavaCFs,6474
- rxnn/training/mrl.py,sha256=hDsKQTaQcEVmnJruD3TxHZJJzDWu5I6Rq2HVDLj8ADU,44747
+ rxnn/training/models.py,sha256=bY6yZoXYJEsrcymtb5Ep41vmFVHplCGWlrw1dI0oFRc,6807
+ rxnn/training/mrl.py,sha256=MnLaYWxblc5cF261R5PNjIvddVQVNxyjAkEYtchBn9E,49299
  rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
  rxnn/training/rl.py,sha256=j-KNLoZjhaEKasYNOc8DxHtwvknAgAJFwvXKot6otFA,3272
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
- rxnn-0.2.26.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
- rxnn-0.2.26.dist-info/METADATA,sha256=XDqI42X3zLRAAKZlVLmstm24KFPP_MfvDtObG9GBc0Y,25960
- rxnn-0.2.26.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- rxnn-0.2.26.dist-info/RECORD,,
+ rxnn-0.2.27.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+ rxnn-0.2.27.dist-info/METADATA,sha256=woZT3PVGgtEJP7DIAJv1-Mdfd4XvKoCRHANQgoTXoXk,25960
+ rxnn-0.2.27.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ rxnn-0.2.27.dist-info/RECORD,,