rxnn 0.2.39__py3-none-any.whl → 0.2.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/attention.py +13 -1
- rxnn/rxt/models.py +18 -6
- rxnn/training/models.py +51 -17
- rxnn/training/mrl.py +72 -25
- rxnn/transformers/models.py +7 -2
- {rxnn-0.2.39.dist-info → rxnn-0.2.41.dist-info}/METADATA +1 -1
- {rxnn-0.2.39.dist-info → rxnn-0.2.41.dist-info}/RECORD +9 -9
- {rxnn-0.2.39.dist-info → rxnn-0.2.41.dist-info}/LICENSE +0 -0
- {rxnn-0.2.39.dist-info → rxnn-0.2.41.dist-info}/WHEEL +0 -0
rxnn/memory/attention.py
CHANGED
@@ -8,6 +8,9 @@ class StmMemoryAttention(nn.Module):
|
|
8
8
|
stm: ShortTermMemory,
|
9
9
|
attention_layers: nn.ModuleList,
|
10
10
|
memory_norm_layers: nn.ModuleList,
|
11
|
+
use_gated_residual: bool = False,
|
12
|
+
per_slot_gate: bool = False,
|
13
|
+
init_gate: float = 0.0,
|
11
14
|
*args,
|
12
15
|
**kwargs
|
13
16
|
):
|
@@ -17,6 +20,10 @@ class StmMemoryAttention(nn.Module):
|
|
17
20
|
self.memory_norm_layers = memory_norm_layers
|
18
21
|
assert len(self.attention_layers) == len(self.memory_norm_layers) == self.stm.memory.size(0)
|
19
22
|
self.num_layers = len(attention_layers)
|
23
|
+
self.use_gated_residual = use_gated_residual
|
24
|
+
self.per_slot_gate = per_slot_gate
|
25
|
+
if self.use_gated_residual:
|
26
|
+
self.gate = nn.Parameter(torch.full((self.num_layers, self.stm.stm_size, 1), init_gate) if self.per_slot_gate else torch.full((self.num_layers,), init_gate))
|
20
27
|
|
21
28
|
def update_max_len(self, max_seq_len: int):
|
22
29
|
for i in range(self.num_layers):
|
@@ -35,7 +42,12 @@ class StmMemoryAttention(nn.Module):
|
|
35
42
|
encoded_layer_data = x[i]
|
36
43
|
normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
|
37
44
|
new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
|
38
|
-
|
45
|
+
if self.use_gated_residual:
|
46
|
+
# gated residual
|
47
|
+
layer_gate = torch.sigmoid(self.gate[i])
|
48
|
+
new_stm[i] = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
|
49
|
+
else:
|
50
|
+
new_stm[i] = new_layer_stm + layer_stm # residual
|
39
51
|
self.stm.update_all(new_stm)
|
40
52
|
return self.stm.memory
|
41
53
|
|
rxnn/rxt/models.py
CHANGED
@@ -130,10 +130,10 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
|
|
130
130
|
memory_cross_attention=cross_att_init(),
|
131
131
|
) for _ in range(num_layers)
|
132
132
|
])
|
133
|
-
self.model = self._init_model(stm, layers, embedding, use_flash_attention, embed_dim, vocab_size)
|
133
|
+
self.model = self._init_model(stm, layers, embedding, use_flash_attention, embed_dim, vocab_size, use_moe)
|
134
134
|
|
135
135
|
def _init_model(self, stm: ShortTermMemory, layers: nn.ModuleList, embedding: nn.Embedding,
|
136
|
-
use_flash_attention: bool, embed_dim: int, vocab_size: int) -> ReactiveTransformerBase:
|
136
|
+
use_flash_attention: bool, embed_dim: int, vocab_size: int, use_moe: bool) -> ReactiveTransformerBase:
|
137
137
|
pass
|
138
138
|
|
139
139
|
def params_count(self):
|
@@ -185,13 +185,15 @@ class RxTAlphaEncoder(RxTAlphaComponentBase, pipeline_tag="fill-mask", license="
|
|
185
185
|
embedding: nn.Embedding,
|
186
186
|
use_flash_attention: bool,
|
187
187
|
embed_dim: int,
|
188
|
-
vocab_size: int
|
188
|
+
vocab_size: int,
|
189
|
+
use_moe: bool,
|
189
190
|
) -> ReactiveTransformerEncoder:
|
190
191
|
return ReactiveTransformerEncoder(
|
191
192
|
stm=stm,
|
192
193
|
embedding=embedding,
|
193
194
|
own_layers=layers,
|
194
195
|
use_flash_attention=use_flash_attention,
|
196
|
+
use_moe=use_moe,
|
195
197
|
)
|
196
198
|
|
197
199
|
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
|
@@ -210,7 +212,8 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
|
|
210
212
|
embedding: nn.Embedding,
|
211
213
|
use_flash_attention: bool,
|
212
214
|
embed_dim: int,
|
213
|
-
vocab_size: int
|
215
|
+
vocab_size: int,
|
216
|
+
use_moe: bool,
|
214
217
|
) -> ReactiveTransformerDecoder:
|
215
218
|
return ReactiveTransformerDecoder(
|
216
219
|
embed_dim,
|
@@ -219,6 +222,7 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
|
|
219
222
|
embedding=embedding,
|
220
223
|
own_layers=layers,
|
221
224
|
use_flash_attention=use_flash_attention,
|
225
|
+
use_moe=use_moe,
|
222
226
|
)
|
223
227
|
|
224
228
|
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
|
@@ -246,6 +250,9 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
|
|
246
250
|
norm_init_gate: float = -2.0,
|
247
251
|
norm_per_dim_scale: bool = False,
|
248
252
|
norm_decay: float = 0.9,
|
253
|
+
use_gated_residual: bool = False,
|
254
|
+
residual_per_slot_gate: bool = False,
|
255
|
+
residual_init_gate: float = 0.0,
|
249
256
|
**kwargs,
|
250
257
|
):
|
251
258
|
super(RxTAlphaMemoryAttention, self).__init__(**kwargs)
|
@@ -272,7 +279,10 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
|
|
272
279
|
init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
|
273
280
|
for _ in range(num_layers)])
|
274
281
|
attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
|
275
|
-
self.model = StmMemoryAttention(
|
282
|
+
self.model = StmMemoryAttention(
|
283
|
+
stm, attention_layers, memory_norm_layers,
|
284
|
+
use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate, init_gate=residual_init_gate
|
285
|
+
)
|
276
286
|
|
277
287
|
def freeze(self):
|
278
288
|
for param in self.parameters():
|
@@ -307,13 +317,15 @@ class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classifica
|
|
307
317
|
embedding: nn.Embedding,
|
308
318
|
use_flash_attention: bool,
|
309
319
|
embed_dim: int,
|
310
|
-
vocab_size: int
|
320
|
+
vocab_size: int,
|
321
|
+
use_moe: bool = False,
|
311
322
|
) -> ReactiveTransformerEncoderDetachStm:
|
312
323
|
return ReactiveTransformerEncoderDetachStm(
|
313
324
|
stm=stm,
|
314
325
|
embedding=embedding,
|
315
326
|
own_layers=layers,
|
316
327
|
use_flash_attention=use_flash_attention,
|
328
|
+
use_moe=use_moe,
|
317
329
|
)
|
318
330
|
|
319
331
|
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
|
rxnn/training/models.py
CHANGED
@@ -80,25 +80,33 @@ class MrlActorModel(nn.Module):
|
|
80
80
|
self.decoder = decoder
|
81
81
|
self.memory_attention = memory_attention
|
82
82
|
|
83
|
-
def freeze_components(self, stage: Literal['update', 'fetch', 'joint'] = 'joint'):
|
83
|
+
def freeze_components(self, stage: Literal['update', 'fetch', 'joint'] = 'joint', freeze_embeddings: bool = False):
|
84
84
|
"""Freeze encoder/decoder except memory-related layers."""
|
85
|
+
# Freeze/unfreeze encoder
|
85
86
|
if self.encoder.freeze_without_memory is not None:
|
86
|
-
|
87
|
-
|
87
|
+
if stage == 'update' or stage == 'joint':
|
88
|
+
self.encoder.unfreeze_all()
|
89
|
+
else:
|
90
|
+
self.encoder.freeze_without_memory(unfreeze_norms=True)
|
88
91
|
self.encoder.freeze_memory(with_norms=True)
|
89
92
|
else:
|
90
93
|
for param in self.encoder.parameters():
|
91
|
-
param.requires_grad = False
|
92
|
-
self.encoder.model.trainable_cross_attention_(True if stage != '
|
94
|
+
param.requires_grad = True if stage != 'fetch' else False
|
95
|
+
self.encoder.model.trainable_cross_attention_(True if stage != 'fetch' else False, with_norms=True)
|
96
|
+
# Freeze/unfreeze decoder
|
93
97
|
if self.decoder.freeze_without_memory is not None:
|
94
|
-
|
95
|
-
|
96
|
-
|
98
|
+
if stage == 'fetch':
|
99
|
+
self.decoder.unfreeze_all()
|
100
|
+
else:
|
101
|
+
self.decoder.freeze_without_memory(unfreeze_norms=True)
|
102
|
+
if stage == 'update':
|
103
|
+
self.decoder.freeze_memory(with_norms=True)
|
97
104
|
else:
|
98
105
|
for param in self.decoder.parameters():
|
99
|
-
param.requires_grad = False
|
106
|
+
param.requires_grad = True if stage == 'fetch' else False
|
100
107
|
self.decoder.model.trainable_cross_attention_(True if stage != 'update' else False, with_norms=True)
|
101
|
-
|
108
|
+
|
109
|
+
# Freeze/unfreeze memory attention
|
102
110
|
if self.memory_attention.freeze is not None:
|
103
111
|
if stage == 'fetch':
|
104
112
|
self.memory_attention.freeze()
|
@@ -108,7 +116,11 @@ class MrlActorModel(nn.Module):
|
|
108
116
|
for param in self.memory_attention.parameters():
|
109
117
|
param.requires_grad = True if stage != 'fetch' else False
|
110
118
|
|
111
|
-
|
119
|
+
if freeze_embeddings:
|
120
|
+
for param in self.encoder.model.embedding.parameters():
|
121
|
+
param.requires_grad = False
|
122
|
+
|
123
|
+
def unfreeze_components(self, freeze_embeddings: bool = False):
|
112
124
|
"""Unfreeze all components after initial training."""
|
113
125
|
if self.encoder.unfreeze_all is not None:
|
114
126
|
self.encoder.unfreeze_all()
|
@@ -126,6 +138,11 @@ class MrlActorModel(nn.Module):
|
|
126
138
|
for param in self.memory_attention.parameters():
|
127
139
|
param.requires_grad = True
|
128
140
|
|
141
|
+
if freeze_embeddings:
|
142
|
+
for param in self.encoder.model.embedding.parameters():
|
143
|
+
param.requires_grad = False
|
144
|
+
|
145
|
+
|
129
146
|
def reset_memory(self):
|
130
147
|
self.memory_attention.reset_memory()
|
131
148
|
|
@@ -151,12 +168,29 @@ class MrlActorModel(nn.Module):
|
|
151
168
|
self.decoder.not_memory_parameters()
|
152
169
|
))
|
153
170
|
|
154
|
-
def unique_parameters(self):
|
155
|
-
|
156
|
-
list(
|
157
|
-
|
158
|
-
|
159
|
-
|
171
|
+
def unique_parameters(self, with_embedding: bool = True):
|
172
|
+
if with_embedding:
|
173
|
+
return list(set(
|
174
|
+
list(self.encoder.parameters()) +
|
175
|
+
list(self.decoder.parameters()) +
|
176
|
+
list(self.memory_attention.parameters())
|
177
|
+
))
|
178
|
+
else:
|
179
|
+
return list(set(
|
180
|
+
self.not_memory_parameters() +
|
181
|
+
self.memory_cross_attention_parameters() +
|
182
|
+
list(self.memory_attention_parameters())
|
183
|
+
))
|
184
|
+
|
185
|
+
def moe_router_loss(self):
|
186
|
+
if self.encoder.model.use_moe and self.decoder.model.use_moe:
|
187
|
+
return (self.encoder.model.moe_router_loss() + self.decoder.model.moe_router_loss()) / 2
|
188
|
+
elif self.encoder.model.use_moe:
|
189
|
+
return self.encoder.model.moe_router_loss()
|
190
|
+
elif self.decoder.model.use_moe:
|
191
|
+
return self.decoder.model.moe_router_loss()
|
192
|
+
else:
|
193
|
+
return None
|
160
194
|
|
161
195
|
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None,
|
162
196
|
action: MrlActorAction = MrlActorAction.DECODE) -> torch.Tensor:
|
rxnn/training/mrl.py
CHANGED
@@ -21,16 +21,20 @@ class MrlConfig(TypedDict):
|
|
21
21
|
separate_memory_lr: Optional[bool]
|
22
22
|
memory_lr: Optional[float]
|
23
23
|
critic_lr: float
|
24
|
-
critic_encoder_lr: float
|
24
|
+
critic_encoder_lr: Optional[float]
|
25
25
|
max_seq_len: int
|
26
26
|
critic_max_len: int
|
27
|
-
weight_decay: float
|
28
|
-
critic_weight_decay: float
|
27
|
+
weight_decay: Optional[float]
|
28
|
+
critic_weight_decay: Optional[float]
|
29
29
|
update_epochs: int
|
30
30
|
pad_token_id: int
|
31
31
|
end_token_id: int
|
32
32
|
callbacks: Optional[list[MrlTrainerCallback]]
|
33
|
-
memory_aware_critic: bool
|
33
|
+
memory_aware_critic: Optional[bool]
|
34
|
+
use_moe_aux_loss: Optional[bool]
|
35
|
+
moe_aux_loss_scale: Optional[float]
|
36
|
+
freeze_embeddings: Optional[bool]
|
37
|
+
embedding_lr: Optional[float]
|
34
38
|
|
35
39
|
|
36
40
|
class MrlStrategy(Enum):
|
@@ -64,6 +68,8 @@ class CurriculumConfig(TypedDict):
|
|
64
68
|
weight_decay: Optional[float]
|
65
69
|
critic_weight_decay: Optional[float]
|
66
70
|
update_epochs: Optional[int]
|
71
|
+
freeze_embeddings: Optional[bool]
|
72
|
+
embedding_lr: Optional[float]
|
67
73
|
|
68
74
|
|
69
75
|
class SamplerConfig(TypedDict):
|
@@ -125,6 +131,10 @@ class MRLTrainer:
|
|
125
131
|
self.max_seq_len = config.get('max_seq_len', 256)
|
126
132
|
self.critic_max_len = config.get('critic_max_len', 512)
|
127
133
|
self.memory_aware_critic = config.get('memory_aware_critic', False)
|
134
|
+
self.use_moe_aux_loss = config.get('use_moe_aux_loss', False)
|
135
|
+
self.moe_aux_loss_scale = config.get('moe_aux_loss_scale', 0.01)
|
136
|
+
self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
|
137
|
+
self.freeze_embeddings = self.shared_freeze_embeddings
|
128
138
|
# Internal update epochs config
|
129
139
|
self.shared_update_epochs = config.get('update_epochs', 10)
|
130
140
|
self.update_epochs = self.shared_update_epochs
|
@@ -162,6 +172,7 @@ class MRLTrainer:
|
|
162
172
|
'weight_decay': config.get('weight_decay', 0.01),
|
163
173
|
'critic_weight_decay': config.get('critic_weight_decay', 0.01),
|
164
174
|
'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
|
175
|
+
'embedding_lr': config.get('embedding_lr', config.get('lr', 3e-4)),
|
165
176
|
}
|
166
177
|
else:
|
167
178
|
self.base_optim_config = {
|
@@ -170,6 +181,7 @@ class MRLTrainer:
|
|
170
181
|
'weight_decay': config.get('weight_decay', 0.01),
|
171
182
|
'critic_weight_decay': config.get('critic_weight_decay', 0.01),
|
172
183
|
'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
|
184
|
+
'embedding_lr': config.get('embedding_lr', config.get('lr', 3e-4)),
|
173
185
|
}
|
174
186
|
|
175
187
|
self.optim_config = self.base_optim_config
|
@@ -208,19 +220,22 @@ class MRLTrainer:
|
|
208
220
|
weight_decay: float,
|
209
221
|
critic_weight_decay: float,
|
210
222
|
critic_encoder_lr: float,
|
223
|
+
embedding_lr: float,
|
211
224
|
memory_lr: Optional[float] = None,
|
212
225
|
) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
|
213
226
|
if memory_lr is not None:
|
214
227
|
optimizer = torch.optim.AdamW([
|
228
|
+
{'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
|
215
229
|
{'params': self.actor.not_memory_parameters(), 'lr': lr},
|
216
230
|
{'params': self.actor.memory_parameters(), 'lr': memory_lr},
|
217
231
|
],
|
218
232
|
weight_decay=weight_decay,
|
219
233
|
)
|
220
234
|
else:
|
221
|
-
optimizer = torch.optim.AdamW(
|
222
|
-
self.actor.
|
223
|
-
|
235
|
+
optimizer = torch.optim.AdamW([
|
236
|
+
{'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
|
237
|
+
{'params': self.actor.unique_parameters(with_embedding=False), 'lr': lr},
|
238
|
+
],
|
224
239
|
weight_decay=weight_decay,
|
225
240
|
)
|
226
241
|
|
@@ -522,6 +537,18 @@ class MRLTrainer:
|
|
522
537
|
# 6. Return loss item
|
523
538
|
return critic_loss_item
|
524
539
|
|
540
|
+
def _moe_aux_loss(self, main_loss: torch.Tensor) -> torch.Tensor:
|
541
|
+
if not self.use_moe_aux_loss:
|
542
|
+
return main_loss
|
543
|
+
|
544
|
+
actor = next(self.actor.children()) if isinstance(self.actor, DistributedDataParallel) else self.actor
|
545
|
+
|
546
|
+
router_loss = actor.moe_router_loss()
|
547
|
+
if router_loss is not None:
|
548
|
+
return main_loss + self.moe_aux_loss_scale * router_loss
|
549
|
+
else:
|
550
|
+
return main_loss
|
551
|
+
|
525
552
|
def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
|
526
553
|
advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
|
527
554
|
# 1. Reset actor gradients
|
@@ -544,6 +571,8 @@ class MRLTrainer:
|
|
544
571
|
# 4.2 Calculate policy loss with selected algorithm
|
545
572
|
policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs,
|
546
573
|
advantages)
|
574
|
+
policy_loss = self._moe_aux_loss(policy_loss)
|
575
|
+
|
547
576
|
# 4.3 Run backpropagation with scaler
|
548
577
|
self.scaler.scale(policy_loss).backward(retain_graph=True)
|
549
578
|
# 4.4 Unscale and clip gradient norms
|
@@ -561,6 +590,7 @@ class MRLTrainer:
|
|
561
590
|
action=MrlActorAction.DECODE)
|
562
591
|
# 4.2 Calculate policy loss with selected algorithm
|
563
592
|
policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs, advantages)
|
593
|
+
policy_loss = self._moe_aux_loss(policy_loss)
|
564
594
|
# 4.3 Run backpropagation
|
565
595
|
policy_loss.backward(retain_graph=True)
|
566
596
|
# 4.4 Clip gradient norms
|
@@ -852,41 +882,41 @@ class MRLTrainer:
|
|
852
882
|
if isinstance(update_epoch, tuple):
|
853
883
|
switch_epoch, cross_att_lr = update_epoch
|
854
884
|
if epoch == switch_epoch:
|
855
|
-
self.actor.
|
885
|
+
self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
|
856
886
|
self.optimizer = self._init_unfreeze_optimizer('update', cross_att_lr)
|
857
887
|
print(f"Activating 'update' unfreeze strategy with custom cross_att_lr: {cross_att_lr}")
|
858
888
|
elif epoch == update_epoch:
|
859
|
-
self.actor.freeze_components('update')
|
889
|
+
self.actor.freeze_components('update', freeze_embeddings=self.freeze_embeddings)
|
860
890
|
print(
|
861
891
|
f"Activating 'update' unfreeze strategy - mem-att trainable / cross-att frozen / rest model frozen")
|
862
892
|
|
863
893
|
if isinstance(fetch_epoch, tuple):
|
864
894
|
switch_epoch, mem_att_lr = fetch_epoch
|
865
895
|
if epoch == switch_epoch:
|
866
|
-
self.actor.
|
896
|
+
self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
|
867
897
|
self.optimizer = self._init_unfreeze_optimizer('fetch', mem_att_lr)
|
868
898
|
print(f"Activating 'fetch' unfreeze strategy with custom mem_att_lr: {mem_att_lr}")
|
869
899
|
elif epoch == fetch_epoch:
|
870
|
-
self.actor.freeze_components('fetch')
|
900
|
+
self.actor.freeze_components('fetch', freeze_embeddings=self.freeze_embeddings)
|
871
901
|
print(
|
872
902
|
f"Activating 'fetch' unfreeze strategy - mem-att frozen / cross-att trainable / rest model frozen")
|
873
903
|
|
874
904
|
if isinstance(joint_epoch, tuple):
|
875
905
|
switch_epoch, model_lr = joint_epoch
|
876
906
|
if epoch == switch_epoch:
|
877
|
-
self.actor.unfreeze_components()
|
907
|
+
self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
|
878
908
|
self.optimizer = self._init_unfreeze_optimizer('joint', model_lr)
|
879
909
|
print(f"Activating 'joint' unfreeze strategy with custom model_lr: {model_lr}")
|
880
910
|
elif epoch == joint_epoch:
|
881
|
-
self.actor.freeze_components('joint')
|
911
|
+
self.actor.freeze_components('joint', freeze_embeddings=self.freeze_embeddings)
|
882
912
|
print(f"Activating 'joint' unfreeze strategy - mem-att/cross-att trainable / rest model frozen")
|
883
913
|
|
884
914
|
if epoch == all_epoch:
|
885
|
-
self.actor.unfreeze_components()
|
915
|
+
self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
|
886
916
|
self.optimizer = self._init_unfreeze_optimizer('all', 0.)
|
887
917
|
print(f"Switching to train 'all' strategy - unfreeze all components")
|
888
918
|
elif epoch == unfreeze_epoch:
|
889
|
-
self.actor.unfreeze_components()
|
919
|
+
self.actor.unfreeze_components(freeze_embeddings=self.freeze_embeddings)
|
890
920
|
print(f"Switching to train 'all' strategy - unfreeze all components")
|
891
921
|
|
892
922
|
def _init_unfreeze_optimizer(
|
@@ -895,29 +925,43 @@ class MRLTrainer:
|
|
895
925
|
unfreeze_lr: float,
|
896
926
|
) -> torch.optim.Optimizer:
|
897
927
|
memory_lr = self.optim_config['memory_lr'] if 'memory_lr' in self.optim_config else self.optim_config['lr']
|
898
|
-
model_lr = self.optim_config['lr']
|
928
|
+
model_lr, embedding_lr = self.optim_config['lr'], self.optim_config['embedding_lr']
|
899
929
|
|
900
930
|
if mode == 'update':
|
901
931
|
params = [
|
902
|
-
{'params': self.actor.
|
932
|
+
{'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
|
933
|
+
{'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
|
934
|
+
{'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
|
903
935
|
{'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
|
904
|
-
{'params': self.actor.
|
936
|
+
{'params': self.actor.decoder.memory_parameters(), 'lr': unfreeze_lr},
|
937
|
+
{'params': self.actor.decoder.not_memory_parameters(), 'lr': unfreeze_lr},
|
905
938
|
]
|
906
939
|
elif mode == 'fetch':
|
907
940
|
params = [
|
908
|
-
{'params': self.actor.
|
909
|
-
{'params': self.actor.
|
941
|
+
{'params': self.actor.encoder.embedding.parameters(), 'lr': unfreeze_lr},
|
942
|
+
{'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
|
943
|
+
{'params': self.actor.encoder.memory_parameters(), 'lr': unfreeze_lr},
|
910
944
|
{'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
|
945
|
+
{'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
|
946
|
+
{'params': self.actor.decoder.not_memory_parameters(), 'lr': model_lr},
|
911
947
|
]
|
912
948
|
elif mode == 'joint':
|
913
949
|
params = [
|
914
|
-
{'params': self.actor.
|
915
|
-
{'params': self.actor.
|
950
|
+
{'params': self.actor.encoder.embedding.parameters(), 'lr': unfreeze_lr},
|
951
|
+
{'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
|
952
|
+
{'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
|
953
|
+
{'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
|
954
|
+
{'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
|
955
|
+
{'params': self.actor.decoder.not_memory_parameters(), 'lr': unfreeze_lr},
|
916
956
|
]
|
917
957
|
else:
|
918
958
|
params = [
|
919
|
-
{'params': self.actor.
|
920
|
-
{'params': self.actor.
|
959
|
+
{'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
|
960
|
+
{'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
|
961
|
+
{'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
|
962
|
+
{'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
|
963
|
+
{'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
|
964
|
+
{'params': self.actor.decoder.not_memory_parameters(), 'lr': model_lr},
|
921
965
|
]
|
922
966
|
|
923
967
|
return torch.optim.AdamW(params, weight_decay=self.optim_config['weight_decay'])
|
@@ -934,6 +978,7 @@ class MRLTrainer:
|
|
934
978
|
MrlStrategy.MULTI_STEP_STRATEGY) # MRL strategy for given curriculum stage
|
935
979
|
self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
|
936
980
|
self.update_epochs = config.get('update_epochs', self.shared_update_epochs) # Internal update epochs
|
981
|
+
self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
|
937
982
|
if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config[
|
938
983
|
'critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
|
939
984
|
if config.get('separate_memory_lr', False):
|
@@ -945,6 +990,7 @@ class MRLTrainer:
|
|
945
990
|
self.base_optim_config['critic_weight_decay']),
|
946
991
|
'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
|
947
992
|
'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
|
993
|
+
'embedding_lr': config.get('embedding_lr', self.base_optim_config['embedding_lr'])
|
948
994
|
}
|
949
995
|
else:
|
950
996
|
self.optim_config = {
|
@@ -954,6 +1000,7 @@ class MRLTrainer:
|
|
954
1000
|
'critic_weight_decay': config.get('critic_weight_decay',
|
955
1001
|
self.base_optim_config['critic_weight_decay']),
|
956
1002
|
'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
|
1003
|
+
'embedding_lr': config.get('embedding_lr', self.base_optim_config['embedding_lr'])
|
957
1004
|
}
|
958
1005
|
self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
|
959
1006
|
elif self.optim_config != self.base_optim_config:
|
@@ -1005,7 +1052,7 @@ class MRLTrainer:
|
|
1005
1052
|
if callable(unfreeze_epoch):
|
1006
1053
|
unfreeze_epoch(-1)
|
1007
1054
|
else:
|
1008
|
-
self.actor.freeze_components('joint')
|
1055
|
+
self.actor.freeze_components('joint', freeze_embeddings=self.freeze_embeddings)
|
1009
1056
|
if isinstance(unfreeze_epoch, tuple):
|
1010
1057
|
print(
|
1011
1058
|
f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
|
rxnn/transformers/models.py
CHANGED
@@ -17,6 +17,7 @@ class ReactiveTransformerBase(nn.Module):
|
|
17
17
|
absolute_embedding: AbsolutePositionalEmbedding = None,
|
18
18
|
use_flash_attention: bool = False,
|
19
19
|
use_relative_embedding: bool = False,
|
20
|
+
use_moe: bool = False,
|
20
21
|
*args,
|
21
22
|
**kwargs,
|
22
23
|
):
|
@@ -32,6 +33,7 @@ class ReactiveTransformerBase(nn.Module):
|
|
32
33
|
self.layers = own_layers
|
33
34
|
self.num_shared_layers = len(shared_layers) if shared_layers else 0
|
34
35
|
self.num_own_layers = len(own_layers) if own_layers else 0
|
36
|
+
self.use_moe = use_moe
|
35
37
|
|
36
38
|
def trainable_cross_attention_(self, is_trainable: bool, with_norms: bool = True):
|
37
39
|
for i in range(self.num_shared_layers):
|
@@ -50,8 +52,11 @@ class ReactiveTransformerBase(nn.Module):
|
|
50
52
|
return own + shared
|
51
53
|
|
52
54
|
def moe_router_loss(self):
|
53
|
-
|
54
|
-
self.
|
55
|
+
if self.use_moe:
|
56
|
+
return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att] + [
|
57
|
+
self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe or self.shared_layers[i].use_moe_att]).mean()
|
58
|
+
else:
|
59
|
+
return None
|
55
60
|
|
56
61
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
57
62
|
# Shared logic for encoders and decoders - apply embeddings and positional encoding
|
@@ -5,19 +5,19 @@ rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-c
|
|
5
5
|
rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
|
6
6
|
rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
|
7
7
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
8
|
-
rxnn/memory/attention.py,sha256=
|
8
|
+
rxnn/memory/attention.py,sha256=POszZeW0QBKOh4VTDVekmZGKKwUr1Zj0FOAilTv8Vyg,2411
|
9
9
|
rxnn/memory/norm.py,sha256=E98jOQEuIOFFhlkvS8s4fFN-D4tLO6vaOqnObv1oVmA,6592
|
10
10
|
rxnn/memory/stm.py,sha256=SSfc-RL9FE-RLkmOEkLB-9Rb00ZXbMLbsAEPdpIW89o,3851
|
11
11
|
rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
rxnn/rxt/models.py,sha256=
|
12
|
+
rxnn/rxt/models.py,sha256=lRn7NRIAAeCxr8hoIXanhaD-cGwVwA23hBdIQpBK6kc,14484
|
13
13
|
rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
14
|
rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
|
15
15
|
rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
|
16
16
|
rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
|
17
17
|
rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
|
18
18
|
rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
|
19
|
-
rxnn/training/models.py,sha256=
|
20
|
-
rxnn/training/mrl.py,sha256=
|
19
|
+
rxnn/training/models.py,sha256=4hDH-R9l1lNvBMW_CGG_QgmCVrkyG7Lyo40PPzvkovQ,8876
|
20
|
+
rxnn/training/mrl.py,sha256=tv7LjW1HBXF9H7rrITQD4EmN1-qgJT44UblREzsjeew,59378
|
21
21
|
rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
|
22
22
|
rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
|
23
23
|
rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
|
@@ -28,12 +28,12 @@ rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTg
|
|
28
28
|
rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
|
29
29
|
rxnn/transformers/layers.py,sha256=l0bXmhN7KOkCw0KTVLixWSo9Op4SesGabWJ4R4EQBMY,7988
|
30
30
|
rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
|
31
|
-
rxnn/transformers/models.py,sha256=
|
31
|
+
rxnn/transformers/models.py,sha256=7ypPNFFnacdZjvaLVue1KR2PmMSdVYsbCMQSunXDL70,10720
|
32
32
|
rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
|
33
33
|
rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
|
34
34
|
rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
|
35
35
|
rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
|
36
|
-
rxnn-0.2.
|
37
|
-
rxnn-0.2.
|
38
|
-
rxnn-0.2.
|
39
|
-
rxnn-0.2.
|
36
|
+
rxnn-0.2.41.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
37
|
+
rxnn-0.2.41.dist-info/METADATA,sha256=5oKrThfhnOQK8KjDYJfcP-LTb03hNyUrSTjbOSpUUdg,25960
|
38
|
+
rxnn-0.2.41.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
39
|
+
rxnn-0.2.41.dist-info/RECORD,,
|
File without changes
|
File without changes
|