rxnn 0.2.26__py3-none-any.whl → 0.2.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/models.py +12 -2
- rxnn/training/mrl.py +125 -33
- {rxnn-0.2.26.dist-info → rxnn-0.2.27.dist-info}/METADATA +1 -1
- {rxnn-0.2.26.dist-info → rxnn-0.2.27.dist-info}/RECORD +6 -6
- {rxnn-0.2.26.dist-info → rxnn-0.2.27.dist-info}/LICENSE +0 -0
- {rxnn-0.2.26.dist-info → rxnn-0.2.27.dist-info}/WHEEL +0 -0
rxnn/training/models.py
CHANGED
```diff
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from enum import Enum
-from typing import Literal
+from typing import Literal, Iterator
 from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
 
@@ -75,7 +75,7 @@ class MrlActorModel(nn.Module):
         self.decoder = decoder
         self.memory_attention = memory_attention
 
-    def freeze_components(self, stage: Literal['update', 'fetch', '
+    def freeze_components(self, stage: Literal['update', 'fetch', 'joint'] = 'joint'):
         """Freeze encoder/decoder except memory-related layers."""
         if self.encoder.freeze_without_memory is not None:
             self.encoder.freeze_without_memory(unfreeze_norms=True)
@@ -131,6 +131,16 @@ class MrlActorModel(nn.Module):
             self.memory_attention.parameters()
         ))
 
+    def memory_cross_attention_parameters(self) -> list[nn.Parameter]:
+        return list(set(
+            self.encoder.memory_parameters() +
+            self.decoder.memory_parameters()
+        ))
+
+    def memory_attention_parameters(self) -> Iterator[nn.Parameter]:
+        return self.memory_attention.parameters()
+
+
     def not_memory_parameters(self) -> list[nn.Parameter]:
         return list(set(
             self.encoder.not_memory_parameters() +
```
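
The two new accessors expose the actor's memory system as separate parameter groups: `memory_attention_parameters()` covers the memory-attention network (memory update path), while `memory_cross_attention_parameters()` covers the memory cross-attention layers inside the encoder and decoder (memory fetch path). A minimal usage sketch of how these groups can receive different learning rates; the `actor` instance, the helper name `build_actor_optimizer`, and the concrete values are placeholders for illustration, not part of the package:

```python
import torch

# Sketch only: `actor` is assumed to be an already-constructed MrlActorModel;
# the learning rates below are illustrative, not package defaults.
def build_actor_optimizer(actor, lr=3e-4, memory_lr=5e-4, cross_att_lr=1e-4, weight_decay=0.01):
    param_groups = [
        # everything outside the memory system
        {'params': actor.not_memory_parameters(), 'lr': lr},
        # memory-attention network (memory update path)
        {'params': actor.memory_attention_parameters(), 'lr': memory_lr},
        # memory cross-attention layers in encoder/decoder (memory fetch path)
        {'params': actor.memory_cross_attention_parameters(), 'lr': cross_att_lr},
    ]
    return torch.optim.AdamW(param_groups, weight_decay=weight_decay)
```
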
rxnn/training/mrl.py
CHANGED
```diff
@@ -3,7 +3,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torch.utils.tensorboard import SummaryWriter
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
-from typing import Optional, TypedDict, Union
+from typing import Optional, TypedDict, Union, TypeAlias, Literal
 from enum import Enum
 import random, os
 from ..transformers.sampler import BatchSampler
@@ -31,6 +31,8 @@ class MrlStrategy(Enum):
     MULTI_STEP_STRATEGY = 2
     LONG_RANGE_STRATEGY = 3
 
+UnfreezeItem = Union[int, tuple[int, float]]
+UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int]]
 
 class CurriculumConfig(TypedDict):
     steps: int
@@ -39,7 +41,7 @@ class CurriculumConfig(TypedDict):
     eval_dataset: Optional[MrlCurriculumDataset]
     callbacks: Optional[list[MrlTrainerCallback]]
     strategy: MrlStrategy
-    unfreeze_epoch: Optional[
+    unfreeze_epoch: Optional[UnfreezeEpochsStrategy]
     random_resets: Optional[bool]
     random_resets_from: Optional[int]
     random_resets_ratio: Optional[float]
```
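
With the new `UnfreezeEpochsStrategy` alias, `unfreeze_epoch` can still be a single int (unfreeze everything at that epoch) or a four-item schedule of switch points for the 'update', 'fetch', 'joint' and final 'all' stages, where each of the first three items may also carry a custom learning rate. A few illustrative values; the alias definitions are copied from the diff, while the epoch numbers and learning rates are placeholders:

```python
from typing import Union, TypeAlias

# Aliases as introduced in rxnn/training/mrl.py
UnfreezeItem = Union[int, tuple[int, float]]
UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int]]

# Simple mode: unfreeze the whole model at epoch 2.
simple: UnfreezeEpochsStrategy = 2

# Staged mode: (update_epoch, fetch_epoch, joint_epoch, all_epoch).
staged: UnfreezeEpochsStrategy = (0, 2, 4, 6)

# Any of the first three items may be an (epoch, lr) pair; the trainer then
# rebuilds the optimizer with that custom LR for the group that would
# otherwise stay frozen in that stage.
staged_with_lrs: UnfreezeEpochsStrategy = ((0, 1e-5), (2, 1e-5), (4, 1e-4), 6)
```
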
```diff
@@ -132,7 +134,8 @@ class MRLTrainer:
 
         if self.separate_memory_lr:
             self.base_optim_config = {
-                'lr':
+                'lr': config.get('lr', 3e-4),
+                'memory_lr': config.get('memory_lr', 5e-4),
                 'critic_lr': config.get('critic_lr', 1e-4),
                 'weight_decay': config.get('weight_decay', 0.01),
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
@@ -145,8 +148,9 @@ class MRLTrainer:
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
             }
 
-
-
+        self.optim_config = self.base_optim_config
+
+        self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
 
         self.scaler = torch.amp.GradScaler() if self.use_amp else None
         self.critic_scaler = torch.amp.GradScaler() if self.use_amp else None
@@ -173,11 +177,17 @@ class MRLTrainer:
         self.global_epoch = 0
         self.global_epochs_count = 0
 
-    def _init_optimizers(
-
-
+    def _init_optimizers(
+            self,
+            lr: float,
+            critic_lr: float,
+            weight_decay: float,
+            critic_weight_decay: float,
+            memory_lr: Optional[float] = None,
+    ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
+        if memory_lr is not None:
             optimizer = torch.optim.AdamW([
-                { 'params': self.actor.not_memory_parameters(), 'lr':
+                { 'params': self.actor.not_memory_parameters(), 'lr': lr },
                 { 'params': self.actor.memory_parameters(), 'lr': memory_lr },
             ],
                 weight_decay=weight_decay,
```
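
When `separate_memory_lr` is enabled, `base_optim_config` now carries both a base `lr` and a `memory_lr`, and `_init_optimizers` builds AdamW with two parameter groups (non-memory vs. memory parameters). A sketch of the default config in that mode, using the default values visible in this diff; how the surrounding trainer config is assembled and passed in is assumed rather than shown here:

```python
# Default optimizer settings with a separate memory LR
# (values are the defaults from the diff; the dict itself is just a sketch).
base_optim_config = {
    'lr': 3e-4,         # non-memory parameters
    'memory_lr': 5e-4,  # memory attention / cross-attention parameters
    'critic_lr': 1e-4,
    'weight_decay': 0.01,
    'critic_weight_decay': 0.01,
}
```
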
```diff
@@ -737,7 +747,7 @@ class MRLTrainer:
 
         return should_stop_stage
 
-    def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[tuple[int,
+    def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[tuple[int, UnfreezeEpochsStrategy], tuple[bool, int, float]]:
         # 1. Set common fields based on config
         self.curriculum_steps = config.get('steps', 1) # number of steps to run in episode
         self.train_dataset = config.get('dataset', None) # training dataset for current curriculum stage
@@ -748,13 +758,28 @@ class MRLTrainer:
                                      MrlStrategy.MULTI_STEP_STRATEGY) # MRL strategy for given curriculum stage
         self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
         if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
-
-
-
-
-
-
-
+            if config.get('separate_memory_lr', False):
+                self.optim_config = {
+                    'lr': config.get('lr', self.base_optim_config['lr']),
+                    'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
+                    'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
+                    'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
+                    'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
+                }
+            else:
+                self.optim_config = {
+                    'lr': config.get('lr', self.base_optim_config['lr']),
+                    'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
+                    'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
+                    'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
+                }
+            self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
+        elif self.optim_config != self.base_optim_config:
+            self.optim_config = self.base_optim_config
+            self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
+
+
+
 
         # 2. Get epochs and random resets configs
         epochs = config.get('epochs', 5) # number of epochs for current stage
```
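
A curriculum stage can now override the optimizer settings for that stage only: `_setup_curriculum_step` merges the stage values over `base_optim_config`, rebuilds both optimizers, and restores the base config once a later stage has no overrides. A hedged sketch of such a stage entry; only the optimizer-related keys are shown, the numbers are placeholders, and the remaining `CurriculumConfig` fields (dataset, steps, strategy, and so on) are assumed to be provided as usual:

```python
# Hypothetical per-stage override (placeholder values, not recommendations).
stage_config = {
    # ... other CurriculumConfig fields (steps, epochs, dataset, strategy, ...)
    'separate_memory_lr': True,
    'lr': 1e-4,          # stage-specific base LR
    'memory_lr': 3e-4,   # stage-specific LR for memory parameters
    'critic_lr': None,   # left unset here, intended to fall back to base_optim_config
    'weight_decay': None,
    'critic_weight_decay': None,
    'unfreeze_epoch': ((0, 1e-5), 2, 4, 6),
}
```
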
```diff
@@ -771,6 +796,82 @@ class MRLTrainer:
 
         return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
 
+    def _apply_unfreeze_strategy(self, epoch: int, unfreeze_epoch: UnfreezeEpochsStrategy):
+        is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
+        if is_staged_unfreeze:
+            update_epoch, fetch_epoch, joint_epoch, all_epoch = unfreeze_epoch
+
+            if isinstance(update_epoch, tuple):
+                switch_epoch, cross_att_lr = update_epoch
+                if epoch == switch_epoch:
+                    self.actor.freeze_components('joint')
+                    self.optimizer = self._init_unfreeze_optimizer('update', cross_att_lr)
+                    print(f"Activating 'update' unfreeze strategy with custom cross_att_lr: {cross_att_lr}")
+            elif epoch == update_epoch:
+                self.actor.freeze_components('update')
+                print(f"Activating 'update' unfreeze strategy - mem-att trainable / cross-att frozen / rest model frozen")
+
+            if isinstance(fetch_epoch, tuple):
+                switch_epoch, mem_att_lr = fetch_epoch
+                if epoch == fetch_epoch:
+                    self.actor.freeze_components('joint')
+                    self.optimizer = self._init_unfreeze_optimizer('fetch', mem_att_lr)
+                    print(f"Activating 'fetch' unfreeze strategy with custom mem_att_lr: {mem_att_lr}")
+            elif epoch == fetch_epoch:
+                self.actor.freeze_components('fetch')
+                print(f"Activating 'fetch' unfreeze strategy - mem-att frozen / cross-att trainable / rest model frozen")
+
+            if isinstance(joint_epoch, tuple):
+                switch_epoch, model_lr = joint_epoch
+                if epoch == joint_epoch:
+                    self.actor.unfreeze_components()
+                    self.optimizer = self._init_unfreeze_optimizer('joint', model_lr)
+                    print(f"Activating 'joint' unfreeze strategy with custom model_lr: {model_lr}")
+            elif epoch == joint_epoch:
+                self.actor.freeze_components('joint')
+                print(f"Activating 'joint' unfreeze strategy - mem-att/cross-att trainable / rest model frozen")
+            if epoch == all_epoch:
+                self.actor.unfreeze_components()
+                self.optimizer = self._init_unfreeze_optimizer('all', 0.)
+                print(f"Switching to train 'all' strategy - unfreeze all components")
+        elif epoch == unfreeze_epoch:
+            self.actor.unfreeze_components()
+            print(f"Switching to train 'all' strategy - unfreeze all components")
+
+    def _init_unfreeze_optimizer(
+            self,
+            mode: Literal['update', 'fetch', 'joint', 'all'],
+            unfreeze_lr: float,
+    ) -> torch.optim.Optimizer:
+        memory_lr = self.optim_config['memory_lr'] if 'memory_lr' in self.optim_config else self.optim_config['lr']
+        model_lr = self.optim_config['lr']
+
+        if mode == 'update':
+            params = [
+                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.memory_cross_attention_parameters(), 'lr': unfreeze_lr},
+            ]
+        elif mode == 'fetch':
+            params = [
+                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.memory_cross_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
+            ]
+        elif mode == 'joint':
+            params = [
+                {'params': self.actor.not_memory_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+            ]
+        else:
+            params = [
+                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+            ]
+
+        return torch.optim.AdamW(params, weight_decay=self.optim_config['weight_decay'])
+
+
     def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int):
         """Start Memory Reinforcement Learning Curriculum."""
 
```
```diff
@@ -796,7 +897,11 @@ class MRLTrainer:
 
         # 4. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
         if unfreeze_epoch != 0:
-            self.actor.freeze_components('
+            self.actor.freeze_components('joint')
+            if isinstance(unfreeze_epoch, tuple):
+                print(f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
+            else:
+                print(f"Starting training with simple unfreeze - 'joint' - mem-att/cross-att trainable / rest model frozen")
 
         # 5. Setup train DataLoader
         if self.use_ddp:
```
```diff
@@ -836,21 +941,8 @@ class MRLTrainer:
             else:
                 self.random_resets_ratio = 1.0
 
-            # 11.
-
-            if is_staged_unfreeze:
-                update_epoch, fetch_epoch, both_epoch, all_epoch = unfreeze_epoch
-                if epoch == update_epoch:
-                    self.actor.freeze_components('update')
-                elif epoch == fetch_epoch:
-                    self.actor.freeze_components('fetch')
-                elif epoch == both_epoch:
-                    self.actor.freeze_components('both')
-                elif epoch == all_epoch:
-                    self.actor.unfreeze_components()
-            else:
-                if epoch == unfreeze_epoch:
-                    self.actor.unfreeze_components()
+            # 11. Apply the unfreeze strategy
+            self._apply_unfreeze_strategy(epoch, unfreeze_epoch)
 
             # 12. Set epoch for distributed sampler
             if train_sampler is not None:
```
{rxnn-0.2.26.dist-info → rxnn-0.2.27.dist-info}/RECORD
CHANGED
```diff
@@ -15,8 +15,8 @@ rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
-rxnn/training/models.py,sha256=
-rxnn/training/mrl.py,sha256=
+rxnn/training/models.py,sha256=bY6yZoXYJEsrcymtb5Ep41vmFVHplCGWlrw1dI0oFRc,6807
+rxnn/training/mrl.py,sha256=MnLaYWxblc5cF261R5PNjIvddVQVNxyjAkEYtchBn9E,49299
 rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
 rxnn/training/rl.py,sha256=j-KNLoZjhaEKasYNOc8DxHtwvknAgAJFwvXKot6otFA,3272
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.27.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.27.dist-info/METADATA,sha256=woZT3PVGgtEJP7DIAJv1-Mdfd4XvKoCRHANQgoTXoXk,25960
+rxnn-0.2.27.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.27.dist-info/RECORD,,
```
{rxnn-0.2.26.dist-info → rxnn-0.2.27.dist-info}/LICENSE
File without changes
{rxnn-0.2.26.dist-info → rxnn-0.2.27.dist-info}/WHEEL
File without changes