rxnn 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/rxt/models.py +12 -4
- rxnn/training/models.py +19 -10
- rxnn/training/mrl.py +10 -12
- rxnn/transformers/layers.py +4 -1
- rxnn/transformers/models.py +3 -3
- {rxnn-0.2.22.dist-info → rxnn-0.2.24.dist-info}/METADATA +1 -1
- {rxnn-0.2.22.dist-info → rxnn-0.2.24.dist-info}/RECORD +9 -9
- {rxnn-0.2.22.dist-info → rxnn-0.2.24.dist-info}/LICENSE +0 -0
- {rxnn-0.2.22.dist-info → rxnn-0.2.24.dist-info}/WHEEL +0 -0
rxnn/rxt/models.py
CHANGED
@@ -137,13 +137,13 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
|
|
137
137
|
def load_shared_memory(self, stm: ShortTermMemory):
|
138
138
|
self.model.stm = stm
|
139
139
|
|
140
|
-
def freeze_without_memory(self):
|
140
|
+
def freeze_without_memory(self, unfreeze_norms: bool = True):
|
141
141
|
for param in self.model.parameters():
|
142
142
|
param.requires_grad_(False)
|
143
|
-
self.model.trainable_cross_attention_(True)
|
143
|
+
self.model.trainable_cross_attention_(True, with_norms=unfreeze_norms)
|
144
144
|
|
145
|
-
def freeze_memory(self):
|
146
|
-
self.model.trainable_cross_attention_(False)
|
145
|
+
def freeze_memory(self, with_norms: bool = True):
|
146
|
+
self.model.trainable_cross_attention_(False, with_norms=with_norms)
|
147
147
|
|
148
148
|
def unfreeze_all(self):
|
149
149
|
for param in self.model.parameters():
|
@@ -264,6 +264,14 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
|
|
264
264
|
attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
|
265
265
|
self.model = StmMemoryAttention(stm, attention_layers, memory_norm_layers)
|
266
266
|
|
267
|
+
def freeze(self):
|
268
|
+
for param in self.parameters():
|
269
|
+
param.requires_grad = False
|
270
|
+
|
271
|
+
def unfreeze(self):
|
272
|
+
for param in self.parameters():
|
273
|
+
param.requires_grad = True
|
274
|
+
|
267
275
|
def load_shared_memory(self, stm: ShortTermMemory):
|
268
276
|
self.model.stm = stm
|
269
277
|
|
rxnn/training/models.py
CHANGED
@@ -78,24 +78,30 @@ class MrlActorModel(nn.Module):
|
|
78
78
|
def freeze_components(self, stage: Literal['update', 'fetch', 'both'] = 'both'):
|
79
79
|
"""Freeze encoder/decoder except memory-related layers."""
|
80
80
|
if self.encoder.freeze_without_memory is not None:
|
81
|
-
self.encoder.freeze_without_memory()
|
81
|
+
self.encoder.freeze_without_memory(unfreeze_norms=True)
|
82
82
|
if stage == 'update':
|
83
|
-
self.encoder.freeze_memory()
|
83
|
+
self.encoder.freeze_memory(with_norms=True)
|
84
84
|
else:
|
85
85
|
for param in self.encoder.parameters():
|
86
86
|
param.requires_grad = False
|
87
|
-
self.encoder.model.trainable_cross_attention_(True if stage != 'update' else False)
|
87
|
+
self.encoder.model.trainable_cross_attention_(True if stage != 'update' else False, with_norms=True)
|
88
88
|
if self.decoder.freeze_without_memory is not None:
|
89
|
-
self.decoder.freeze_without_memory()
|
89
|
+
self.decoder.freeze_without_memory(unfreeze_norms=True)
|
90
90
|
if stage == 'update':
|
91
|
-
self.decoder.freeze_memory()
|
91
|
+
self.decoder.freeze_memory(with_norms=True)
|
92
92
|
else:
|
93
93
|
for param in self.decoder.parameters():
|
94
94
|
param.requires_grad = False
|
95
|
-
self.decoder.model.trainable_cross_attention_(True if stage != 'update' else False)
|
95
|
+
self.decoder.model.trainable_cross_attention_(True if stage != 'update' else False, with_norms=True)
|
96
96
|
# Unfreeze memory attention
|
97
|
-
|
98
|
-
|
97
|
+
if self.memory_attention.freeze is not None:
|
98
|
+
if stage == 'fetch':
|
99
|
+
self.memory_attention.freeze()
|
100
|
+
else:
|
101
|
+
self.memory_attention.unfreeze()
|
102
|
+
else:
|
103
|
+
for param in self.memory_attention.parameters():
|
104
|
+
param.requires_grad = True if stage != 'fetch' else False
|
99
105
|
|
100
106
|
def unfreeze_components(self):
|
101
107
|
"""Unfreeze all components after initial training."""
|
@@ -109,8 +115,11 @@ class MrlActorModel(nn.Module):
|
|
109
115
|
else:
|
110
116
|
for param in self.decoder.parameters():
|
111
117
|
param.requires_grad = True
|
112
|
-
|
113
|
-
|
118
|
+
if self.memory_attention.unfreeze is not None:
|
119
|
+
self.memory_attention.unfreeze()
|
120
|
+
else:
|
121
|
+
for param in self.memory_attention.parameters():
|
122
|
+
param.requires_grad = True
|
114
123
|
|
115
124
|
def reset_memory(self):
|
116
125
|
self.memory_attention.reset_memory()
|
rxnn/training/mrl.py
CHANGED
@@ -37,7 +37,7 @@ class CurriculumConfig(TypedDict):
|
|
37
37
|
eval_dataset: Optional[MrlCurriculumDataset]
|
38
38
|
callbacks: Optional[list[MrlTrainerCallback]]
|
39
39
|
strategy: MrlStrategy
|
40
|
-
unfreeze_epoch: Optional[Union[int, tuple[int, int, int]]]
|
40
|
+
unfreeze_epoch: Optional[Union[int, tuple[int, int, int, int]]]
|
41
41
|
random_resets: Optional[bool]
|
42
42
|
random_resets_from: Optional[int]
|
43
43
|
random_resets_ratio: Optional[float]
|
@@ -724,10 +724,10 @@ class MRLTrainer:
|
|
724
724
|
self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
|
725
725
|
if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None:
|
726
726
|
self.optimizer, self.critic_optimizer = self._init_optimizers(
|
727
|
-
lr=config
|
728
|
-
critic_lr=config
|
729
|
-
weight_decay=config
|
730
|
-
critic_weight_decay=config
|
727
|
+
lr=config.get('lr', self.base_optim_config['lr']),
|
728
|
+
critic_lr=config.get('critic_lr', self.base_optim_config['critic_lr']),
|
729
|
+
weight_decay=config.get('weight_decay', self.base_optim_config['weight_decay']),
|
730
|
+
critic_weight_decay=config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay'])
|
731
731
|
)
|
732
732
|
|
733
733
|
# 2. Get epochs and random resets configs
|
@@ -770,11 +770,7 @@ class MRLTrainer:
|
|
770
770
|
|
771
771
|
# 4. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
|
772
772
|
if unfreeze_epoch != 0:
|
773
|
-
|
774
|
-
if is_staged_unfreeze:
|
775
|
-
self.actor.freeze_components('update')
|
776
|
-
else:
|
777
|
-
self.actor.freeze_components()
|
773
|
+
self.actor.freeze_components('both')
|
778
774
|
|
779
775
|
# 5. Setup train DataLoader
|
780
776
|
if self.use_ddp:
|
@@ -817,8 +813,10 @@ class MRLTrainer:
|
|
817
813
|
# 11. Unfreeze all components before selected epoch
|
818
814
|
is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
|
819
815
|
if is_staged_unfreeze:
|
820
|
-
fetch_epoch, both_epoch, all_epoch = unfreeze_epoch
|
821
|
-
if epoch ==
|
816
|
+
update_epoch, fetch_epoch, both_epoch, all_epoch = unfreeze_epoch
|
817
|
+
if epoch == update_epoch:
|
818
|
+
self.actor.freeze_components('update')
|
819
|
+
elif epoch == fetch_epoch:
|
822
820
|
self.actor.freeze_components('fetch')
|
823
821
|
elif epoch == both_epoch:
|
824
822
|
self.actor.freeze_components('both')
|
rxnn/transformers/layers.py
CHANGED
@@ -57,9 +57,12 @@ class ReactiveTransformerLayer(nn.Module):
|
|
57
57
|
self.use_moe = use_moe
|
58
58
|
self.use_moe_att = use_moe_att
|
59
59
|
|
60
|
-
def trainable_cross_attention_(self, is_trainable: bool):
|
60
|
+
def trainable_cross_attention_(self, is_trainable: bool, with_norms: bool = True):
|
61
61
|
for param in self.memory_cross_attention.parameters():
|
62
62
|
param.requires_grad_(is_trainable)
|
63
|
+
if with_norms:
|
64
|
+
for param in self.norm2.parameters():
|
65
|
+
param.requires_grad_(is_trainable)
|
63
66
|
|
64
67
|
def update_max_len(self, max_seq_len: int):
|
65
68
|
if self.attention.rope is not None:
|
rxnn/transformers/models.py
CHANGED
@@ -33,11 +33,11 @@ class ReactiveTransformerBase(nn.Module):
|
|
33
33
|
self.num_shared_layers = len(shared_layers) if shared_layers else 0
|
34
34
|
self.num_own_layers = len(own_layers) if own_layers else 0
|
35
35
|
|
36
|
-
def trainable_cross_attention_(self, is_trainable: bool):
|
36
|
+
def trainable_cross_attention_(self, is_trainable: bool, with_norms: bool = True):
|
37
37
|
for i in range(self.num_shared_layers):
|
38
|
-
self.shared_layers[i].trainable_cross_attention_(is_trainable)
|
38
|
+
self.shared_layers[i].trainable_cross_attention_(is_trainable, with_norms)
|
39
39
|
for i in range(self.num_own_layers):
|
40
|
-
self.layers[i].trainable_cross_attention_(is_trainable)
|
40
|
+
self.layers[i].trainable_cross_attention_(is_trainable, with_norms)
|
41
41
|
|
42
42
|
def moe_router_loss(self):
|
43
43
|
return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att] + [
|
@@ -9,14 +9,14 @@ rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
|
|
9
9
|
rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
|
10
10
|
rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
|
11
11
|
rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
rxnn/rxt/models.py,sha256=
|
12
|
+
rxnn/rxt/models.py,sha256=3gCYD_OXvQc8GaXQvRCSj1OcYOSHayWlpP5lsg9wMMk,12389
|
13
13
|
rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
14
|
rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
|
15
15
|
rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
|
16
16
|
rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
|
17
17
|
rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
|
18
|
-
rxnn/training/models.py,sha256=
|
19
|
-
rxnn/training/mrl.py,sha256=
|
18
|
+
rxnn/training/models.py,sha256=5fl1hESVj2Hakqz5to8ZJzw5Q4_RKZAUq2bn6nRiPV8,6045
|
19
|
+
rxnn/training/mrl.py,sha256=14wx3pVha15B7eRWPRgoxRtV5dPtBI0yadIHOYZjX6k,43275
|
20
20
|
rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
|
21
21
|
rxnn/training/rl.py,sha256=j-KNLoZjhaEKasYNOc8DxHtwvknAgAJFwvXKot6otFA,3272
|
22
22
|
rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
|
@@ -25,14 +25,14 @@ rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
|
|
25
25
|
rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
26
26
|
rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
|
27
27
|
rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
|
28
|
-
rxnn/transformers/layers.py,sha256=
|
28
|
+
rxnn/transformers/layers.py,sha256=UQZbrAg1UAttPASeqS7BP1a4JalktThmRMzX99Qghss,7618
|
29
29
|
rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
|
30
|
-
rxnn/transformers/models.py,sha256=
|
30
|
+
rxnn/transformers/models.py,sha256=_2qO1SASHtKvTW3dW-Dy9HEmAvoNVC1_addm2tM9Zbs,8325
|
31
31
|
rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
|
32
32
|
rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
|
33
33
|
rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
|
34
34
|
rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
|
35
|
-
rxnn-0.2.
|
36
|
-
rxnn-0.2.
|
37
|
-
rxnn-0.2.
|
38
|
-
rxnn-0.2.
|
35
|
+
rxnn-0.2.24.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
36
|
+
rxnn-0.2.24.dist-info/METADATA,sha256=PrVfcCd8NBFtFnD8lAJqU7UW3lLEc-Tr7MQhK6obvuo,25960
|
37
|
+
rxnn-0.2.24.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
38
|
+
rxnn-0.2.24.dist-info/RECORD,,
|
File without changes
|
File without changes
|