rxnn 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/norm.py +17 -13
- rxnn/rxt/models.py +36 -27
- rxnn/training/models.py +25 -2
- rxnn/training/mrl.py +158 -40
- rxnn/transformers/layers.py +7 -0
- rxnn/transformers/models.py +10 -0
- {rxnn-0.2.25.dist-info → rxnn-0.2.27.dist-info}/METADATA +1 -1
- {rxnn-0.2.25.dist-info → rxnn-0.2.27.dist-info}/RECORD +10 -10
- {rxnn-0.2.25.dist-info → rxnn-0.2.27.dist-info}/LICENSE +0 -0
- {rxnn-0.2.25.dist-info → rxnn-0.2.27.dist-info}/WHEEL +0 -0
rxnn/memory/norm.py
CHANGED
```diff
@@ -7,10 +7,11 @@ class AdaptivePositionalMemoryNorm(nn.Module):
             self,
             num_slots: int,
             dim: int,
-            decay: float = 0.
+            decay: float = 0.9,
             use_scale: bool = True,
             use_gate: bool = True,
-            init_gate: float = -
+            init_gate: float = -2.0,
+            per_dim_scale: bool = False,
     ):
         super(AdaptivePositionalMemoryNorm, self).__init__()
         self.use_gate = use_gate
@@ -20,7 +21,8 @@ class AdaptivePositionalMemoryNorm(nn.Module):
         self.eps = 1e-6
 
         # Learnable parameters
-
+        scale_shape = (num_slots, 1) if not per_dim_scale else (dim,)
+        self.scale = nn.Parameter(torch.ones(*scale_shape)) if use_scale else None
         self.gate = nn.Parameter(torch.full((num_slots, 1), init_gate)) if use_gate else None
 
         # EMA buffers
@@ -28,7 +30,7 @@ class AdaptivePositionalMemoryNorm(nn.Module):
 
         # Initialize parameters
         if self.scale is not None:
-            nn.init.normal_(self.scale, mean=1.0, std=0.
+            nn.init.normal_(self.scale, mean=1.0, std=0.1)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Calculate current RMS per slot
@@ -45,7 +47,7 @@ class AdaptivePositionalMemoryNorm(nn.Module):
 
         # Apply learned scale per slot
         if self.scale is not None:
-            x_norm = x_norm * self.scale  # [batch_size, num_slots, dim] * [num_slots, dim]
+            x_norm = x_norm * self.scale  # [batch_size, num_slots, dim] * [num_slots, 1] or [dim]
 
         # Apply gating mechanism
         if self.use_gate:
@@ -148,24 +150,26 @@ class MemoryNormConfig(TypedDict):
     use_gate: bool
     init_gate: float
     init_scale: float
+    per_dim_scale: bool
 
 def init_memory_norm(
         norm_type: str,
         dim: int,
         num_slots: int = None,
-        decay: float = 0.
+        decay: float = 0.9,
         use_scale: bool = True,
         use_gate: bool = True,
-        init_gate: float = -
+        init_gate: float = -2.0,
         init_scale: float = 1.0,
+        per_dim_scale: bool = False,
 ) -> nn.Module:
-    assert norm_type in [
-    if norm_type ==
+    assert norm_type in ['layer', 'rms', 'adaptive', 'positional']
+    if norm_type == 'layer':
         return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
-    elif norm_type ==
+    elif norm_type == 'rms':
         return SimpleRMSMemoryNorm(dim, use_gate, init_scale, init_gate)
-    elif norm_type ==
+    elif norm_type == 'adaptive':
         return AdaptiveRMSMemoryNorm(dim, use_gate, decay, init_scale, init_gate)
-    elif norm_type ==
-        return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate)
+    elif norm_type == 'positional':
+        return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate, per_dim_scale)
     return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
```
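The practical effect of the new `per_dim_scale` flag is the shape of the learned scale parameter: `(num_slots, 1)` by default (one scale per memory slot) versus `(dim,)` when enabled (one scale per embedding channel, shared across slots). A minimal usage sketch, assuming only the factory shown above; the dimension values are illustrative:

```python
from rxnn.memory.norm import init_memory_norm

# Default: one learned scale per STM slot -> parameter shape (num_slots, 1).
slot_scaled = init_memory_norm('positional', dim=512, num_slots=1024)

# New in 0.2.27: one learned scale per embedding dimension -> shape (dim,),
# shared across all slots.
dim_scaled = init_memory_norm('positional', dim=512, num_slots=1024, per_dim_scale=True)
```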
rxnn/rxt/models.py
CHANGED
```diff
@@ -13,6 +13,7 @@ from ..memory.attention import StmMemoryAttention
 from ..utils import get_model_size
 from ..experimental.attention import init_experimental_attention
 
+
 class RxTAlphaComponentConfig(TypedDict):
     num_layers: int
     vocab_size: int
@@ -76,8 +77,10 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert self_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
-
+        assert self_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                                 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert cross_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                                  'sqa'], 'Memory cross-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -92,20 +95,25 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
         else:
             att_init = lambda: init_experimental_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
                                                            use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                           max_seq_len=seq_len, is_causal=is_causal,
+                                                           max_seq_len=seq_len, is_causal=is_causal,
+                                                           num_experts=att_experts,
                                                            num_query_experts=att_query_experts,
                                                            num_query_groups=att_query_groups)
 
         if cross_att_type in ['mha', 'gqa', 'mqa']:
             cross_att_init = lambda: init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
-
-
+                                                    use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                    max_seq_len=seq_len, is_causal=is_causal, rope_only_for_query=True)
         else:
-            cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, cross_att_type,
-
-
-
-
+            cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, cross_att_type,
+                                                                 cross_att_groups or att_groups, rope=rope,
+                                                                 use_flash_attention=use_flash_attention,
+                                                                 dropout=att_dropout,
+                                                                 max_seq_len=seq_len, is_causal=is_causal,
+                                                                 num_experts=att_experts,
+                                                                 num_query_experts=att_query_experts,
+                                                                 num_query_groups=cross_att_query_groups or att_query_groups,
+                                                                 rope_only_for_query=True)
 
         layers = nn.ModuleList([
             ReactiveTransformerLayer(
@@ -137,6 +145,12 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
     def load_shared_memory(self, stm: ShortTermMemory):
         self.model.stm = stm
 
+    def memory_parameters(self) -> list[nn.Parameter]:
+        return self.model.memory_parameters()
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        return self.model.not_memory_parameters()
+
     def freeze_without_memory(self, unfreeze_norms: bool = True):
         for param in self.model.parameters():
             param.requires_grad_(False)
@@ -211,20 +225,9 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", license="apache-2.0"):
         return self.model(x, attention_mask=attention_mask)
 
 
-def build_rxt_alpha_for_pretraining(
-        encoder_config: RxTAlphaComponentConfig,
-        decoder_config: RxTAlphaComponentConfig,
-) -> tuple[RxTAlphaEncoder, RxTAlphaDecoder]:
-    encoder = RxTAlphaEncoder(**encoder_config)
-    decoder = RxTAlphaDecoder(**decoder_config)
-
-    encoder.load_shared_memory(decoder.model.stm)
-    encoder.load_shared_embedding(decoder.model.embedding)
-
-    return encoder, decoder
-
 class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) memory attention model"""
+
     def __init__(
             self,
             num_layers: int = 12,
@@ -234,17 +237,21 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
             stm_size: int = 1024,
             use_flash_attention: bool = False,
             att_dropout: float = 0.0,
-            norm_type: str = 'rms',
             att_groups: int = 1,
             att_type: str = 'sqa',
             att_experts: int = None,
             att_query_experts: int = None,
             att_query_groups: int = None,
+            norm_type: str = 'rms',
+            norm_init_gate: float = -2.0,
+            norm_per_dim_scale: bool = False,
+            norm_decay: float = 0.9,
             **kwargs,
     ):
         super(RxTAlphaMemoryAttention, self).__init__(**kwargs)
 
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                            'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
 
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
         stm = ShortTermMemory(num_layers, embed_dim, stm_size)
@@ -256,11 +263,14 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
         else:
             att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
                                                            use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                           max_seq_len=seq_len, is_causal=False,
+                                                           max_seq_len=seq_len, is_causal=False,
+                                                           num_experts=att_experts,
                                                            num_query_experts=att_query_experts,
                                                            num_query_groups=att_query_groups, rope_only_for_keys=True)
 
-        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size
+        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                              init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                            for _ in range(num_layers)])
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
         self.model = StmMemoryAttention(stm, attention_layers, memory_norm_layers)
 
@@ -283,4 +293,3 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
-
```
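With `build_rxt_alpha_for_pretraining` removed, encoder and decoder are constructed directly and wired together via `load_shared_memory`/`load_shared_embedding`, and the memory attention model now exposes its norm configuration through the new `norm_*` kwargs. A hedged construction sketch; all hyperparameter values are illustrative, not prescribed defaults:

```python
from rxnn.rxt.models import RxTAlphaMemoryAttention

# The new norm_* kwargs are forwarded to init_memory_norm for every layer.
mem_att = RxTAlphaMemoryAttention(
    num_layers=12,
    stm_size=1024,
    att_type='sqa',
    norm_type='positional',    # selects AdaptivePositionalMemoryNorm
    norm_decay=0.9,            # forwarded as decay=...
    norm_init_gate=-2.0,       # forwarded as init_gate=...
    norm_per_dim_scale=False,  # forwarded as per_dim_scale=...
)
```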
rxnn/training/models.py
CHANGED
```diff
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from enum import Enum
-from typing import Literal
+from typing import Literal, Iterator
 from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
 
@@ -75,7 +75,7 @@ class MrlActorModel(nn.Module):
         self.decoder = decoder
         self.memory_attention = memory_attention
 
-    def freeze_components(self, stage: Literal['update', 'fetch', '
+    def freeze_components(self, stage: Literal['update', 'fetch', 'joint'] = 'joint'):
         """Freeze encoder/decoder except memory-related layers."""
         if self.encoder.freeze_without_memory is not None:
             self.encoder.freeze_without_memory(unfreeze_norms=True)
@@ -124,6 +124,29 @@ class MrlActorModel(nn.Module):
     def reset_memory(self):
         self.memory_attention.reset_memory()
 
+    def memory_parameters(self) -> list[nn.Parameter]:
+        return list(set(
+            self.encoder.memory_parameters() +
+            self.decoder.memory_parameters() +
+            self.memory_attention.parameters()
+        ))
+
+    def memory_cross_attention_parameters(self) -> list[nn.Parameter]:
+        return list(set(
+            self.encoder.memory_parameters() +
+            self.decoder.memory_parameters()
+        ))
+
+    def memory_attention_parameters(self) -> Iterator[nn.Parameter]:
+        return self.memory_attention.parameters()
+
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        return list(set(
+            self.encoder.not_memory_parameters() +
+            self.decoder.not_memory_parameters()
+        ))
+
     def unique_parameters(self):
         return list(set(
             list(self.encoder.parameters()) +
```
rxnn/training/mrl.py
CHANGED
```diff
@@ -3,7 +3,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torch.utils.tensorboard import SummaryWriter
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
-from typing import Optional, TypedDict, Union
+from typing import Optional, TypedDict, Union, TypeAlias, Literal
 from enum import Enum
 import random, os
 from ..transformers.sampler import BatchSampler
@@ -17,6 +17,8 @@ from .models import MrlActorAction, MrlActorModel, MrlCriticModel
 
 class MrlConfig(TypedDict):
     lr: float
+    separate_memory_lr: Optional[bool]
+    memory_lr: Optional[float]
     critic_lr: float
     max_seq_len: int
     critic_max_len: int
@@ -29,6 +31,8 @@ class MrlStrategy(Enum):
     MULTI_STEP_STRATEGY = 2
     LONG_RANGE_STRATEGY = 3
 
+UnfreezeItem = Union[int, tuple[int, float]]
+UnfreezeEpochsStrategy: TypeAlias = Union[int, tuple[UnfreezeItem, UnfreezeItem, UnfreezeItem, int]]
 
 class CurriculumConfig(TypedDict):
     steps: int
@@ -37,12 +41,14 @@ class CurriculumConfig(TypedDict):
     eval_dataset: Optional[MrlCurriculumDataset]
     callbacks: Optional[list[MrlTrainerCallback]]
     strategy: MrlStrategy
-    unfreeze_epoch: Optional[
+    unfreeze_epoch: Optional[UnfreezeEpochsStrategy]
     random_resets: Optional[bool]
     random_resets_from: Optional[int]
     random_resets_ratio: Optional[float]
     reward_model: Optional[MrlRewardModel]
+    separate_memory_lr: Optional[bool]
     lr: Optional[float]
+    memory_lr: Optional[float]
     critic_lr: Optional[float]
     weight_decay: Optional[float]
     critic_weight_decay: Optional[float]
@@ -84,6 +90,7 @@ class MRLTrainer:
             use_amp: bool = False,
             dtype: torch.dtype = torch.float32,
             callbacks: list[MrlTrainerCallback] = None,
+
     ):
         """
         Trainer for Memory Reinforcement Learning (MRL) in Reactive Transformer.
@@ -123,15 +130,27 @@ class MRLTrainer:
         self.use_amp = use_amp
         self.dtype = dtype
 
-        self.
-
-
-
-
-
+        self.separate_memory_lr = config.get('separate_memory_lr', False)
+
+        if self.separate_memory_lr:
+            self.base_optim_config = {
+                'lr': config.get('lr', 3e-4),
+                'memory_lr': config.get('memory_lr', 5e-4),
+                'critic_lr': config.get('critic_lr', 1e-4),
+                'weight_decay': config.get('weight_decay', 0.01),
+                'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+            }
+        else:
+            self.base_optim_config = {
+                'lr': config.get('lr', 3e-4),
+                'critic_lr': config.get('critic_lr', 1e-4),
+                'weight_decay': config.get('weight_decay', 0.01),
+                'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+            }
+
+        self.optim_config = self.base_optim_config
 
-
-        self.optimizer, self.critic_optimizer = self._init_optimizers(**self.base_optim_config)
+        self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
 
         self.scaler = torch.amp.GradScaler() if self.use_amp else None
         self.critic_scaler = torch.amp.GradScaler() if self.use_amp else None
@@ -158,18 +177,34 @@ class MRLTrainer:
         self.global_epoch = 0
         self.global_epochs_count = 0
 
-    def _init_optimizers(
-
-
-
-            weight_decay
-
+    def _init_optimizers(
+            self,
+            lr: float,
+            critic_lr: float,
+            weight_decay: float,
+            critic_weight_decay: float,
+            memory_lr: Optional[float] = None,
+    ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
+        if memory_lr is not None:
+            optimizer = torch.optim.AdamW([
+                { 'params': self.actor.not_memory_parameters(), 'lr': lr },
+                { 'params': self.actor.memory_parameters(), 'lr': memory_lr },
+            ],
+                weight_decay=weight_decay,
+            )
+        else:
+            optimizer = torch.optim.AdamW(
+                self.actor.unique_parameters(),
+                lr=lr,
+                weight_decay=weight_decay,
+            )
 
         critic_optimizer = torch.optim.AdamW(
             self.critic.parameters(),
             lr=critic_lr,
             weight_decay=critic_weight_decay,
         )
+
         return optimizer, critic_optimizer
 
 
@@ -712,7 +747,7 @@ class MRLTrainer:
 
         return should_stop_stage
 
-    def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[tuple[int,
+    def _setup_curriculum_step(self, config: CurriculumConfig) -> tuple[tuple[int, UnfreezeEpochsStrategy], tuple[bool, int, float]]:
         # 1. Set common fields based on config
         self.curriculum_steps = config.get('steps', 1)  # number of steps to run in episode
         self.train_dataset = config.get('dataset', None)  # training dataset for current curriculum stage
@@ -722,13 +757,29 @@ class MRLTrainer:
         self.strategy = config.get('strategy',
                                    MrlStrategy.MULTI_STEP_STRATEGY)  # MRL strategy for given curriculum stage
         self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
-        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None:
-
-
-
-
-
-
+        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
+            if config.get('separate_memory_lr', False):
+                self.optim_config = {
+                    'lr': config.get('lr', self.base_optim_config['lr']),
+                    'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
+                    'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
+                    'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
+                    'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
+                }
+            else:
+                self.optim_config = {
+                    'lr': config.get('lr', self.base_optim_config['lr']),
+                    'critic_lr': config.get('critic_lr', self.base_optim_config['critic_lr']),
+                    'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
+                    'critic_weight_decay': config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
+                }
+            self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
+        elif self.optim_config != self.base_optim_config:
+            self.optim_config = self.base_optim_config
+            self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
+
+
+
 
         # 2. Get epochs and random resets configs
         epochs = config.get('epochs', 5)  # number of epochs for current stage
@@ -745,6 +796,82 @@ class MRLTrainer:
 
         return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
 
+    def _apply_unfreeze_strategy(self, epoch: int, unfreeze_epoch: UnfreezeEpochsStrategy):
+        is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
+        if is_staged_unfreeze:
+            update_epoch, fetch_epoch, joint_epoch, all_epoch = unfreeze_epoch
+
+            if isinstance(update_epoch, tuple):
+                switch_epoch, cross_att_lr = update_epoch
+                if epoch == switch_epoch:
+                    self.actor.freeze_components('joint')
+                    self.optimizer = self._init_unfreeze_optimizer('update', cross_att_lr)
+                    print(f"Activating 'update' unfreeze strategy with custom cross_att_lr: {cross_att_lr}")
+            elif epoch == update_epoch:
+                self.actor.freeze_components('update')
+                print(f"Activating 'update' unfreeze strategy - mem-att trainable / cross-att frozen / rest model frozen")
+
+            if isinstance(fetch_epoch, tuple):
+                switch_epoch, mem_att_lr = fetch_epoch
+                if epoch == fetch_epoch:
+                    self.actor.freeze_components('joint')
+                    self.optimizer = self._init_unfreeze_optimizer('fetch', mem_att_lr)
+                    print(f"Activating 'fetch' unfreeze strategy with custom mem_att_lr: {mem_att_lr}")
+            elif epoch == fetch_epoch:
+                self.actor.freeze_components('fetch')
+                print(f"Activating 'fetch' unfreeze strategy - mem-att frozen / cross-att trainable / rest model frozen")
+
+            if isinstance(joint_epoch, tuple):
+                switch_epoch, model_lr = joint_epoch
+                if epoch == joint_epoch:
+                    self.actor.unfreeze_components()
+                    self.optimizer = self._init_unfreeze_optimizer('joint', model_lr)
+                    print(f"Activating 'joint' unfreeze strategy with custom model_lr: {model_lr}")
+            elif epoch == joint_epoch:
+                self.actor.freeze_components('joint')
+                print(f"Activating 'joint' unfreeze strategy - mem-att/cross-att trainable / rest model frozen")
+            if epoch == all_epoch:
+                self.actor.unfreeze_components()
+                self.optimizer = self._init_unfreeze_optimizer('all', 0.)
+                print(f"Switching to train 'all' strategy - unfreeze all components")
+        elif epoch == unfreeze_epoch:
+            self.actor.unfreeze_components()
+            print(f"Switching to train 'all' strategy - unfreeze all components")
+
+    def _init_unfreeze_optimizer(
+            self,
+            mode: Literal['update', 'fetch', 'joint', 'all'],
+            unfreeze_lr: float,
+    ) -> torch.optim.Optimizer:
+        memory_lr = self.optim_config['memory_lr'] if 'memory_lr' in self.optim_config else self.optim_config['lr']
+        model_lr = self.optim_config['lr']
+
+        if mode == 'update':
+            params = [
+                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.memory_cross_attention_parameters(), 'lr': unfreeze_lr},
+            ]
+        elif mode == 'fetch':
+            params = [
+                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.memory_cross_attention_parameters(), 'lr': memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
+            ]
+        elif mode == 'joint':
+            params = [
+                {'params': self.actor.not_memory_parameters(), 'lr': unfreeze_lr},
+                {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+            ]
+        else:
+            params = [
+                {'params': self.actor.not_memory_parameters(), 'lr': model_lr},
+                {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+            ]
+
+        return torch.optim.AdamW(params, weight_decay=self.optim_config['weight_decay'])
+
+
     def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int):
         """Start Memory Reinforcement Learning Curriculum."""
 
@@ -770,7 +897,11 @@ class MRLTrainer:
 
         # 4. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
         if unfreeze_epoch != 0:
-            self.actor.freeze_components('
+            self.actor.freeze_components('joint')
+            if isinstance(unfreeze_epoch, tuple):
+                print(f"Starting training with unfreeze strategies - 'warmup' - mem-att/cross-att trainable / rest model frozen")
+            else:
+                print(f"Starting training with simple unfreeze - 'joint' - mem-att/cross-att trainable / rest model frozen")
 
         # 5. Setup train DataLoader
         if self.use_ddp:
@@ -810,21 +941,8 @@ class MRLTrainer:
         else:
             self.random_resets_ratio = 1.0
 
-        # 11.
-
-        if is_staged_unfreeze:
-            update_epoch, fetch_epoch, both_epoch, all_epoch = unfreeze_epoch
-            if epoch == update_epoch:
-                self.actor.freeze_components('update')
-            elif epoch == fetch_epoch:
-                self.actor.freeze_components('fetch')
-            elif epoch == both_epoch:
-                self.actor.freeze_components('both')
-            elif epoch == all_epoch:
-                self.actor.unfreeze_components()
-        else:
-            if epoch == unfreeze_epoch:
-                self.actor.unfreeze_components()
+        # 11. Apply the unfreeze strategy
+        self._apply_unfreeze_strategy(epoch, unfreeze_epoch)
 
         # 12. Set epoch for distributed sampler
         if train_sampler is not None:
```
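Two configuration surfaces change here: `MrlConfig`/`CurriculumConfig` gain `separate_memory_lr` and `memory_lr`, and `unfreeze_epoch` may now be a 4-tuple `(update_epoch, fetch_epoch, joint_epoch, all_epoch)` in which each of the first three items may itself be an `(epoch, lr)` pair that also rebuilds the optimizer with a custom learning rate for that stage. A hedged sketch of both, with illustrative values; required stage fields such as `dataset` are omitted:

```python
from rxnn.training.mrl import MrlStrategy

# Trainer-level optimizer config with a separate memory learning rate.
mrl_config = {
    'lr': 3e-4,                  # non-memory parameters
    'separate_memory_lr': True,  # enables the two-group AdamW in _init_optimizers
    'memory_lr': 5e-4,           # memory attention / cross-attention / norms
    'critic_lr': 1e-4,
    'weight_decay': 0.01,
    'critic_weight_decay': 0.01,
}

# One curriculum stage with a staged unfreeze schedule:
# epoch 0 -> 'update' stage with custom cross-attention LR,
# epoch 2 -> 'fetch' stage, epoch 4 -> 'joint' stage with custom model LR,
# epoch 6 -> unfreeze everything.
stage = {
    'steps': 4,
    'epochs': 8,
    'strategy': MrlStrategy.MULTI_STEP_STRATEGY,
    'unfreeze_epoch': ((0, 1e-5), 2, (4, 1e-5), 6),
    'separate_memory_lr': True,
    'memory_lr': 5e-4,
    # _setup_curriculum_step reads these keys directly, so set them explicitly.
    'lr': 3e-4,
    'critic_lr': 1e-4,
    'weight_decay': 0.01,
    'critic_weight_decay': 0.01,
}
```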
rxnn/transformers/layers.py
CHANGED
```diff
@@ -64,6 +64,13 @@ class ReactiveTransformerLayer(nn.Module):
         for param in self.norm2.parameters():
             param.requires_grad_(is_trainable)
 
+    def memory_parameters(self) -> list[nn.Parameter]:
+        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters())
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        memory_params = self.memory_parameters()
+        return [param for param in self.parameters() if param not in memory_params]
+
     def update_max_len(self, max_seq_len: int):
         if self.attention.rope is not None:
             self.attention.rope.update_max_len(max_seq_len)
```
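At the layer level, the memory path is exactly the memory cross-attention plus its norm (`norm2`). A small hedged sketch of using the new accessor, e.g. to train only the memory path of a single layer; `layer` stands for a `ReactiveTransformerLayer` instance:

```python
# Freeze the whole layer, then re-enable just the memory path.
for p in layer.parameters():
    p.requires_grad_(False)
for p in layer.memory_parameters():
    p.requires_grad_(True)
```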
rxnn/transformers/models.py
CHANGED
```diff
@@ -39,6 +39,16 @@ class ReactiveTransformerBase(nn.Module):
         for i in range(self.num_own_layers):
             self.layers[i].trainable_cross_attention_(is_trainable, with_norms)
 
+    def memory_parameters(self) -> list[nn.Parameter]:
+        own = [param for layer in self.layers for param in layer.memory_parameters()]
+        shared = [param for layer in self.shared_layers for param in layer.memory_parameters()] if self.shared_layers else []
+        return own + shared
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        own = [param for layer in self.layers for param in layer.not_memory_parameters()]
+        shared = [param for layer in self.shared_layers for param in layer.not_memory_parameters()] if self.shared_layers else []
+        return own + shared
+
     def moe_router_loss(self):
         return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att] + [
             self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe or self.shared_layers[i].use_moe_att]).mean()
```
{rxnn-0.2.25.dist-info → rxnn-0.2.27.dist-info}/RECORD
CHANGED
```diff
@@ -6,17 +6,17 @@ rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=PXVBZQYNsRraZh7QDBgUOdPy3lTI8B0d8CzduojBjG0,1747
-rxnn/memory/norm.py,sha256=
+rxnn/memory/norm.py,sha256=E98jOQEuIOFFhlkvS8s4fFN-D4tLO6vaOqnObv1oVmA,6592
 rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=
+rxnn/rxt/models.py,sha256=r8wZeeNTC2VAhiiNe4y7LrbnB4wjFu_cupKiGkpdgjI,13002
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
-rxnn/training/models.py,sha256=
-rxnn/training/mrl.py,sha256=
+rxnn/training/models.py,sha256=bY6yZoXYJEsrcymtb5Ep41vmFVHplCGWlrw1dI0oFRc,6807
+rxnn/training/mrl.py,sha256=MnLaYWxblc5cF261R5PNjIvddVQVNxyjAkEYtchBn9E,49299
 rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
 rxnn/training/rl.py,sha256=j-KNLoZjhaEKasYNOc8DxHtwvknAgAJFwvXKot6otFA,3272
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -25,14 +25,14 @@ rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=LXSY829fIHSCmFmClhQ6B7I5aKbiOqy9mZmwlJG_r7U,7961
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=
+rxnn/transformers/models.py,sha256=QwVxYN9DrKllEpOiFoAx4CiThOWafeTa-OAY7L6gN0Y,8929
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.27.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.27.dist-info/METADATA,sha256=woZT3PVGgtEJP7DIAJv1-Mdfd4XvKoCRHANQgoTXoXk,25960
+rxnn-0.2.27.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.27.dist-info/RECORD,,
```
{rxnn-0.2.25.dist-info → rxnn-0.2.27.dist-info}/LICENSE
File without changes
{rxnn-0.2.25.dist-info → rxnn-0.2.27.dist-info}/WHEEL
File without changes