rxnn 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/models.py +10 -5
- rxnn/training/mrl.py +53 -15
- {rxnn-0.2.19.dist-info → rxnn-0.2.21.dist-info}/METADATA +1 -1
- {rxnn-0.2.19.dist-info → rxnn-0.2.21.dist-info}/RECORD +6 -6
- {rxnn-0.2.19.dist-info → rxnn-0.2.21.dist-info}/LICENSE +0 -0
- {rxnn-0.2.19.dist-info → rxnn-0.2.21.dist-info}/WHEEL +0 -0
rxnn/training/models.py
CHANGED
```diff
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 from enum import Enum
+from typing import Literal
 from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
 
@@ -74,23 +75,27 @@ class MrlActorModel(nn.Module):
         self.decoder = decoder
         self.memory_attention = memory_attention
 
-    def freeze_components(self):
+    def freeze_components(self, stage: Literal['update', 'fetch', 'both'] = 'both'):
         """Freeze encoder/decoder except memory-related layers."""
         if self.encoder.freeze_without_memory is not None:
             self.encoder.freeze_without_memory()
+            if stage == 'update':
+                self.encoder.freeze_memory()
         else:
             for param in self.encoder.parameters():
                 param.requires_grad = False
-            self.encoder.model.trainable_cross_attention_(True)
+            self.encoder.model.trainable_cross_attention_(True if stage != 'update' else False)
         if self.decoder.freeze_without_memory is not None:
             self.decoder.freeze_without_memory()
+            if stage == 'update':
+                self.decoder.freeze_memory()
         else:
             for param in self.decoder.parameters():
                 param.requires_grad = False
-            self.decoder.model.trainable_cross_attention_(True)
+            self.decoder.model.trainable_cross_attention_(True if stage != 'update' else False)
         # Unfreeze memory attention
         for param in self.memory_attention.parameters():
-            param.requires_grad = True
+            param.requires_grad = True if stage != 'fetch' else False
 
     def unfreeze_components(self):
         """Unfreeze all components after initial training."""
```
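The new `stage` argument turns actor freezing into three selectable configurations. A minimal usage sketch, assuming an already constructed `MrlActorModel` instance named `actor` (the variable and call sites are illustrative, not part of this diff):

```python
# Hypothetical usage sketch of the staged freezing API added in 0.2.21.
# 'actor' is assumed to be an MrlActorModel built elsewhere.

# Phase 1: train only the memory-update path (memory attention);
# encoder/decoder memory cross-attention stays frozen.
actor.freeze_components('update')

# Phase 2: train only the memory-fetch path (memory cross-attention);
# the memory attention network stays frozen.
actor.freeze_components('fetch')

# Phase 3: train both memory paths, everything else still frozen
# (matches the old no-argument behaviour).
actor.freeze_components('both')

# Finally, unfreeze the whole actor for joint training.
actor.unfreeze_components()
```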
```diff
@@ -124,7 +129,7 @@ class MrlActorModel(nn.Module):
         _, ed = self.encoder(x, attention_mask=attention_mask)
         return self.memory_attention(ed, attention_mask=attention_mask)
 
-class MrlCriticModel(nn.Module):
+class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
     def __init__(self, encoder: nn.Module, embed_dim: int, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
```
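With `PyTorchModelHubMixin` added to `MrlCriticModel`, a trained critic can be saved and shared through the standard mixin helpers from `huggingface_hub`. A hedged sketch, where `critic` is an already trained `MrlCriticModel` and the repository id is purely illustrative:

```python
# Sketch only: these helpers come from PyTorchModelHubMixin itself,
# not from rxnn-specific publishing code.
critic.save_pretrained("./mrl-critic")             # write weights + config locally
critic.push_to_hub("your-org/mrl-critic-example")  # illustrative repository id
```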
rxnn/training/mrl.py
CHANGED
```diff
@@ -3,7 +3,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torch.utils.tensorboard import SummaryWriter
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
-from typing import Optional, TypedDict
+from typing import Optional, TypedDict, Union
 from enum import Enum
 import random, os
 from ..transformers.sampler import BatchSampler
@@ -37,11 +37,15 @@ class CurriculumConfig(TypedDict):
     eval_dataset: Optional[MrlCurriculumDataset]
     callbacks: Optional[list[MrlTrainerCallback]]
     strategy: MrlStrategy
-    unfreeze_epoch: Optional[int]
+    unfreeze_epoch: Optional[Union[int, tuple[int, int, int]]]
     random_resets: Optional[bool]
     random_resets_from: Optional[int]
     random_resets_ratio: Optional[float]
     reward_model: Optional[MrlRewardModel]
+    lr: Optional[float]
+    critic_lr: Optional[float]
+    weight_decay: Optional[float]
+    critic_weight_decay: Optional[float]
 
 
 class SamplerConfig(TypedDict):
```
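`CurriculumConfig` now accepts a three-element tuple for `unfreeze_epoch` and optional per-stage optimizer settings. An illustrative stage fragment, assuming `MrlStrategy` is importable from `rxnn.training.mrl` (as used elsewhere in this file); the values are placeholders and the remaining stage keys (dataset, epochs, callbacks, ...) are omitted:

```python
from rxnn.training.mrl import MrlStrategy

# Illustrative curriculum-stage fragment using the new CurriculumConfig fields.
stage_config = {
    'strategy': MrlStrategy.MULTI_STEP_STRATEGY,
    'unfreeze_epoch': (2, 4, 6),   # staged unfreeze: (fetch_epoch, both_epoch, all_epoch)
    'lr': 5e-5,                    # per-stage actor learning rate override
    'critic_lr': None,             # None -> fall back to the trainer-level default
    'weight_decay': 0.01,
    'critic_weight_decay': None,
}
```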
```diff
@@ -119,17 +123,15 @@ class MRLTrainer:
         self.use_amp = use_amp
         self.dtype = dtype
 
+        self.base_optim_config = {
+            'lr': config.get('lr', 3e-4),
+            'critic_lr': config.get('critic_lr', 1e-4),
+            'weight_decay': config.get('weight_decay', 0.01),
+            'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+        }
+
         # Optimizers
-        self.optimizer = torch.optim.AdamW(
-            self.actor.unique_parameters(),
-            lr=config.get("lr", 3e-4),
-            weight_decay=config.get("weight_decay", 0.01),
-        )
-        self.critic_optimizer = torch.optim.AdamW(
-            self.critic.parameters(),
-            lr=config.get("critic_lr", 1e-4),
-            weight_decay=config.get("critic_weight_decay", 0.01),
-        )
+        self.optimizer, self.critic_optimizer = self._init_optimizers(**self.base_optim_config)
 
         self.scaler = torch.amp.GradScaler() if self.use_amp else None
         self.critic_scaler = torch.amp.GradScaler() if self.use_amp else None
```
```diff
@@ -156,6 +158,21 @@ class MRLTrainer:
         self.global_epoch = 0
         self.global_epochs_count = 0
 
+    def _init_optimizers(self, lr: float, critic_lr: float, weight_decay: float, critic_weight_decay: float):
+        optimizer = torch.optim.AdamW(
+            self.actor.unique_parameters(),
+            lr=lr,
+            weight_decay=weight_decay,
+        )
+
+        critic_optimizer = torch.optim.AdamW(
+            self.critic.parameters(),
+            lr=critic_lr,
+            weight_decay=critic_weight_decay,
+        )
+        return optimizer, critic_optimizer
+
+
     def _init_steps(self):
         return {
             'collect': 0,
@@ -705,6 +722,13 @@ class MRLTrainer:
         self.strategy = config.get('strategy',
                                    MrlStrategy.MULTI_STEP_STRATEGY)  # MRL strategy for given curriculum stage
         self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
+        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None:
+            self.optimizer, self.critic_optimizer = self._init_optimizers(
+                lr=config['lr'] or self.base_optim_config['lr'],
+                critic_lr=config['critic_lr'] or self.base_optim_config['critic_lr'],
+                weight_decay=config['weight_decay'] or self.base_optim_config['weight_decay'],
+                critic_weight_decay=config['critic_weight_decay'] or self.base_optim_config['critic_weight_decay']
+            )
 
         # 2. Get epochs and random resets configs
         epochs = config.get('epochs', 5)  # number of epochs for current stage
```
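The `or`-based fallback above means a stage value of `None` (or any other falsy value) defers to the trainer-level `base_optim_config`. A standalone sketch of that behaviour with illustrative numbers (not package code):

```python
# Illustrative values, mirroring the fallback expression used in the diff above.
base_optim_config = {'lr': 3e-4, 'critic_lr': 1e-4}
stage = {'lr': 5e-5, 'critic_lr': None}

effective_lr = stage['lr'] or base_optim_config['lr']                        # 5e-5 (stage override wins)
effective_critic_lr = stage['critic_lr'] or base_optim_config['critic_lr']   # 1e-4 (falls back to base)
```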
```diff
@@ -746,7 +770,11 @@ class MRLTrainer:
 
         # 4. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
         if unfreeze_epoch != 0:
-            self.actor.freeze_components()
+            is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
+            if is_staged_unfreeze:
+                self.actor.freeze_components('update')
+            else:
+                self.actor.freeze_components()
 
         # 5. Setup train DataLoader
         if self.use_ddp:
@@ -787,8 +815,18 @@ class MRLTrainer:
                 self.random_resets_ratio = 1.0
 
             # 11. Unfreeze all components before selected epoch
-            if epoch == unfreeze_epoch:
-                self.actor.unfreeze_components()
+            is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
+            if is_staged_unfreeze:
+                fetch_epoch, both_epoch, all_epoch = unfreeze_epoch
+                if epoch == fetch_epoch:
+                    self.actor.freeze_components('fetch')
+                elif epoch == both_epoch:
+                    self.actor.freeze_components('both')
+                elif epoch == all_epoch:
+                    self.actor.unfreeze_components()
+            else:
+                if epoch == unfreeze_epoch:
+                    self.actor.unfreeze_components()
 
             # 12. Set epoch for distributed sampler
             if train_sampler is not None:
```
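Taken together, a tuple-valued `unfreeze_epoch` schedules the actor through four training phases within one curriculum stage. A self-contained sketch that prints the phase per epoch for an illustrative `(2, 4, 6)` schedule (epoch counts are placeholders, but the mapping follows the logic in this diff):

```python
# Mirrors the staged-unfreeze logic from this diff; values are illustrative.
unfreeze_epoch = (2, 4, 6)  # (fetch_epoch, both_epoch, all_epoch)
fetch_epoch, both_epoch, all_epoch = unfreeze_epoch

phase = "update: only memory attention trains (set before epoch 0)"
for epoch in range(8):
    if epoch == fetch_epoch:
        phase = "fetch: only memory cross-attention trains"
    elif epoch == both_epoch:
        phase = "both: memory attention + cross-attention train"
    elif epoch == all_epoch:
        phase = "all: whole actor unfrozen"
    print(f"epoch {epoch}: {phase}")
```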
{rxnn-0.2.19.dist-info → rxnn-0.2.21.dist-info}/RECORD
CHANGED
```diff
@@ -15,8 +15,8 @@ rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
-rxnn/training/models.py,sha256=…
-rxnn/training/mrl.py,sha256=…
+rxnn/training/models.py,sha256=wf98gYKKm9-ZY3zwdX9NIeJ-lvh7Ro1SoAijmQxYM28,5599
+rxnn/training/mrl.py,sha256=zk4m1JFuX0y82J0tG2XkY0Pz6Uy2did9cngOXqR9lMk,43326
 rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
 rxnn/training/rl.py,sha256=DHFwnPUlnq2JVj6CS6DwifnC_eMeBAUVp36UCAWNMis,3934
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.19.dist-info/LICENSE,sha256=…
-rxnn-0.2.19.dist-info/METADATA,sha256=…
-rxnn-0.2.19.dist-info/WHEEL,sha256=…
-rxnn-0.2.19.dist-info/RECORD,,
+rxnn-0.2.21.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.21.dist-info/METADATA,sha256=XXf_qBMs2dOwWyAN5oNEg1W1-oPVIAQPy0FkNcO7QZQ,25960
+rxnn-0.2.21.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.21.dist-info/RECORD,,
```
{rxnn-0.2.19.dist-info → rxnn-0.2.21.dist-info}/LICENSE
File without changes
{rxnn-0.2.19.dist-info → rxnn-0.2.21.dist-info}/WHEEL
File without changes