rxnn 0.2.24-py3-none-any.whl → 0.2.26-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/attention.py +0 -1
- rxnn/memory/norm.py +26 -24
- rxnn/rxt/models.py +36 -27
- rxnn/training/models.py +13 -0
- rxnn/training/mrl.py +42 -16
- rxnn/transformers/layers.py +7 -0
- rxnn/transformers/models.py +10 -0
- {rxnn-0.2.24.dist-info → rxnn-0.2.26.dist-info}/METADATA +1 -1
- {rxnn-0.2.24.dist-info → rxnn-0.2.26.dist-info}/RECORD +11 -11
- {rxnn-0.2.24.dist-info → rxnn-0.2.26.dist-info}/LICENSE +0 -0
- {rxnn-0.2.24.dist-info → rxnn-0.2.26.dist-info}/WHEEL +0 -0
rxnn/memory/attention.py
CHANGED
@@ -35,7 +35,6 @@ class StmMemoryAttention(nn.Module):
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
-            # self.stm.update_layer(i, new_layer_stm + layer_stm)
             new_stm[i] = new_layer_stm + layer_stm # residual
         self.stm.update_all(new_stm)
         return self.stm.memory
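The only functional change here is that the per-layer update path (the commented-out `update_layer` call) gives way to collecting each layer's residual update in `new_stm` and committing it with a single `update_all`. Below is a minimal standalone sketch of that residual-accumulation pattern, with assumed shapes (`[num_layers, num_slots, dim]`) and toy tensors in place of the real `ShortTermMemory` API:

```python
# Minimal sketch of the residual STM update pattern shown in the diff above
# (not the full StmMemoryAttention class): each layer's attention output is
# added to the previous memory state, collected into `new_stm`, and committed
# once at the end. Shapes are assumptions based on the diff.
import torch

def residual_stm_update(stm_memory: torch.Tensor,
                        attention_outputs: torch.Tensor) -> torch.Tensor:
    """stm_memory, attention_outputs: [num_layers, num_slots, dim]."""
    new_stm = torch.zeros_like(stm_memory)
    for i in range(stm_memory.size(0)):
        # residual connection: new slots = attended update + previous slots
        new_stm[i] = attention_outputs[i] + stm_memory[i]
    return new_stm

layers, slots, dim = 4, 8, 16
updated = residual_stm_update(torch.randn(layers, slots, dim),
                              torch.randn(layers, slots, dim))
print(updated.shape)  # torch.Size([4, 8, 16])
```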
rxnn/memory/norm.py
CHANGED
@@ -7,10 +7,11 @@ class AdaptivePositionalMemoryNorm(nn.Module):
         self,
         num_slots: int,
         dim: int,
-        decay: float = 0.
+        decay: float = 0.9,
         use_scale: bool = True,
         use_gate: bool = True,
-        init_gate: float = -
+        init_gate: float = -2.0,
+        per_dim_scale: bool = False,
     ):
         super(AdaptivePositionalMemoryNorm, self).__init__()
         self.use_gate = use_gate
@@ -20,39 +21,38 @@ class AdaptivePositionalMemoryNorm(nn.Module):
         self.eps = 1e-6

         # Learnable parameters
-
-        self.
+        scale_shape = (num_slots, 1) if not per_dim_scale else (dim,)
+        self.scale = nn.Parameter(torch.ones(*scale_shape)) if use_scale else None
+        self.gate = nn.Parameter(torch.full((num_slots, 1), init_gate)) if use_gate else None

         # EMA buffers
         self.register_buffer("ema_rms", torch.ones(num_slots, 1))

         # Initialize parameters
         if self.scale is not None:
-            nn.init.normal_(self.scale, mean=1.0, std=0.
+            nn.init.normal_(self.scale, mean=1.0, std=0.1)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # x shape: [batch_size, num_slots, dim]
-        batch_size = x.size(0)
-
         # Calculate current RMS per slot
-
-
+        # x: [batch_size, num_slots, dim]
+        current_rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt() # [batch, num_slots, 1]
+        slot_rms = current_rms.mean(dim=0) # [num_slots, 1] (average over batch)

         # Update EMA during training
         if self.training:
-            self.ema_rms = self.decay * self.ema_rms + (1 - self.decay) * slot_rms.detach()
+            self.ema_rms = self.decay * self.ema_rms + (1 - self.decay) * slot_rms.detach() # [num_slots, 1]

         # Normalize using EMA statistics
-        x_norm = x * torch.rsqrt(self.ema_rms + self.eps)
+        x_norm = x * torch.rsqrt(self.ema_rms + self.eps) # [batch_size, num_slots, dim] * [num_slots, 1]

         # Apply learned scale per slot
         if self.scale is not None:
-            x_norm = x_norm * self.scale
+            x_norm = x_norm * self.scale # [batch_size, num_slots, dim] * [num_slots, 1] or [dim]

         # Apply gating mechanism
         if self.use_gate:
-            gate = torch.sigmoid(self.gate) # [
-            return gate * x_norm + (1 - gate) * x
+            gate = torch.sigmoid(self.gate) # [num_slots, 1]
+            return gate * x_norm + (1 - gate) * x # [batch_size, num_slots, dim] * [num_slots, 1]

         return x_norm

@@ -77,7 +77,7 @@ class AdaptiveRMSMemoryNorm(nn.Module):
         # x shape: [batch_size, num_slots, dim]
         if self.training and hasattr(self, 'ema_rms'):
             # Compute current RMS across all slots and batch (scalar)
-            current_rms = x.pow(2).mean(
+            current_rms = x.pow(2).mean(dim=-1).mean().sqrt()
             self.ema_rms = self.ema_rms * self.decay + current_rms * (1 - self.decay)
             rms = self.ema_rms
         else:
@@ -150,24 +150,26 @@ class MemoryNormConfig(TypedDict):
     use_gate: bool
     init_gate: float
     init_scale: float
+    per_dim_scale: bool

 def init_memory_norm(
         norm_type: str,
         dim: int,
         num_slots: int = None,
-        decay: float = 0.
+        decay: float = 0.9,
         use_scale: bool = True,
         use_gate: bool = True,
-        init_gate: float = -
+        init_gate: float = -2.0,
         init_scale: float = 1.0,
+        per_dim_scale: bool = False,
 ) -> nn.Module:
-    assert norm_type in [
-    if norm_type ==
+    assert norm_type in ['layer', 'rms', 'adaptive', 'positional']
+    if norm_type == 'layer':
         return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
-    elif norm_type ==
+    elif norm_type == 'rms':
         return SimpleRMSMemoryNorm(dim, use_gate, init_scale, init_gate)
-    elif norm_type ==
+    elif norm_type == 'adaptive':
         return AdaptiveRMSMemoryNorm(dim, use_gate, decay, init_scale, init_gate)
-    elif norm_type ==
-        return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate)
+    elif norm_type == 'positional':
+        return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate, per_dim_scale)
     return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
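Taken together, the norm changes set concrete defaults (`decay=0.9`, `init_gate=-2.0`, `std=0.1`), add a `per_dim_scale` switch that changes the scale parameter's shape from `[num_slots, 1]` to `[dim]`, and normalize against an EMA of the per-slot RMS. A simplified, self-contained sketch of that logic (it mirrors the diff but is not the packaged `AdaptivePositionalMemoryNorm` class):

```python
# Standalone sketch of the EMA-RMS + per_dim_scale behavior described above.
import torch
from torch import nn

class EmaSlotRMSNorm(nn.Module):
    def __init__(self, num_slots: int, dim: int, decay: float = 0.9,
                 per_dim_scale: bool = False, init_gate: float = -2.0):
        super().__init__()
        self.decay, self.eps = decay, 1e-6
        # scale is per-slot [num_slots, 1] by default, or per-dim [dim]
        scale_shape = (dim,) if per_dim_scale else (num_slots, 1)
        self.scale = nn.Parameter(torch.ones(*scale_shape))
        self.gate = nn.Parameter(torch.full((num_slots, 1), init_gate))
        self.register_buffer("ema_rms", torch.ones(num_slots, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, num_slots, dim]
        current_rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()  # [batch, slots, 1]
        slot_rms = current_rms.mean(dim=0)                        # [slots, 1]
        if self.training:
            # exponential moving average of the per-slot RMS statistic
            self.ema_rms = self.decay * self.ema_rms + (1 - self.decay) * slot_rms.detach()
        x_norm = x * torch.rsqrt(self.ema_rms + self.eps) * self.scale
        gate = torch.sigmoid(self.gate)                           # [slots, 1]
        return gate * x_norm + (1 - gate) * x

norm = EmaSlotRMSNorm(num_slots=8, dim=16, per_dim_scale=True)
print(norm(torch.randn(2, 8, 16)).shape)  # torch.Size([2, 8, 16])
```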
rxnn/rxt/models.py
CHANGED
@@ -13,6 +13,7 @@ from ..memory.attention import StmMemoryAttention
 from ..utils import get_model_size
 from ..experimental.attention import init_experimental_attention

+
 class RxTAlphaComponentConfig(TypedDict):
     num_layers: int
     vocab_size: int
@@ -76,8 +77,10 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert self_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
-
+        assert self_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                                 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert cross_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                                  'sqa'], 'Memory cross-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'

         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -92,20 +95,25 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
         else:
             att_init = lambda: init_experimental_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
                                                            use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                           max_seq_len=seq_len, is_causal=is_causal,
+                                                           max_seq_len=seq_len, is_causal=is_causal,
+                                                           num_experts=att_experts,
                                                            num_query_experts=att_query_experts,
                                                            num_query_groups=att_query_groups)

         if cross_att_type in ['mha', 'gqa', 'mqa']:
             cross_att_init = lambda: init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
-
-
+                                                    use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                    max_seq_len=seq_len, is_causal=is_causal, rope_only_for_query=True)
         else:
-            cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, cross_att_type,
-
-
-
-
+            cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, cross_att_type,
+                                                                 cross_att_groups or att_groups, rope=rope,
+                                                                 use_flash_attention=use_flash_attention,
+                                                                 dropout=att_dropout,
+                                                                 max_seq_len=seq_len, is_causal=is_causal,
+                                                                 num_experts=att_experts,
+                                                                 num_query_experts=att_query_experts,
+                                                                 num_query_groups=cross_att_query_groups or att_query_groups,
+                                                                 rope_only_for_query=True)

         layers = nn.ModuleList([
             ReactiveTransformerLayer(
@@ -137,6 +145,12 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
     def load_shared_memory(self, stm: ShortTermMemory):
         self.model.stm = stm

+    def memory_parameters(self) -> list[nn.Parameter]:
+        return self.model.memory_parameters()
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        return self.model.not_memory_parameters()
+
     def freeze_without_memory(self, unfreeze_norms: bool = True):
         for param in self.model.parameters():
             param.requires_grad_(False)
@@ -211,20 +225,9 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
         return self.model(x, attention_mask=attention_mask)


-def build_rxt_alpha_for_pretraining(
-        encoder_config: RxTAlphaComponentConfig,
-        decoder_config: RxTAlphaComponentConfig,
-) -> tuple[RxTAlphaEncoder, RxTAlphaDecoder]:
-    encoder = RxTAlphaEncoder(**encoder_config)
-    decoder = RxTAlphaDecoder(**decoder_config)
-
-    encoder.load_shared_memory(decoder.model.stm)
-    encoder.load_shared_embedding(decoder.model.embedding)
-
-    return encoder, decoder
-
 class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) memory attention model"""
+
     def __init__(
         self,
         num_layers: int = 12,
@@ -234,17 +237,21 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         stm_size: int = 1024,
         use_flash_attention: bool = False,
         att_dropout: float = 0.0,
-        norm_type: str = 'rms',
         att_groups: int = 1,
         att_type: str = 'sqa',
         att_experts: int = None,
         att_query_experts: int = None,
         att_query_groups: int = None,
+        norm_type: str = 'rms',
+        norm_init_gate: float = -2.0,
+        norm_per_dim_scale: bool = False,
+        norm_decay: float = 0.9,
         **kwargs,
     ):
         super(RxTAlphaMemoryAttention, self).__init__(**kwargs)

-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                            'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'

         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
         stm = ShortTermMemory(num_layers, embed_dim, stm_size)
@@ -256,11 +263,14 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         else:
             att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
                                                            use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                           max_seq_len=seq_len, is_causal=False,
+                                                           max_seq_len=seq_len, is_causal=False,
+                                                           num_experts=att_experts,
                                                            num_query_experts=att_query_experts,
                                                            num_query_groups=att_query_groups, rope_only_for_keys=True)

-        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size
+        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                              init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                            for _ in range(num_layers)])
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
         self.model = StmMemoryAttention(stm, attention_layers, memory_norm_layers)

@@ -283,4 +293,3 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2

     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
-
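The memory-attention model now exposes its norm configuration directly through constructor keywords. Below is a hypothetical construction call using the new arguments from this release; the argument names come from the diff, the values are illustrative only, and all other constructor arguments are assumed to keep their defaults:

```python
# Hypothetical usage of the new norm-related keyword arguments
# (norm_type, norm_decay, norm_init_gate, norm_per_dim_scale).
from rxnn.rxt.models import RxTAlphaMemoryAttention

mem_att = RxTAlphaMemoryAttention(
    num_layers=12,
    att_type='sqa',
    att_groups=1,
    norm_type='positional',     # routed to init_memory_norm(...) per layer
    norm_decay=0.9,             # EMA decay for adaptive/positional norms
    norm_init_gate=-2.0,        # initial gate bias (sigmoid(-2.0) ≈ 0.12)
    norm_per_dim_scale=False,   # per-slot scale [num_slots, 1] vs per-dim [dim]
)
```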
rxnn/training/models.py
CHANGED
@@ -124,6 +124,19 @@ class MrlActorModel(nn.Module):
     def reset_memory(self):
         self.memory_attention.reset_memory()

+    def memory_parameters(self) -> list[nn.Parameter]:
+        return list(set(
+            self.encoder.memory_parameters() +
+            self.decoder.memory_parameters() +
+            self.memory_attention.parameters()
+        ))
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        return list(set(
+            self.encoder.not_memory_parameters() +
+            self.decoder.not_memory_parameters()
+        ))
+
     def unique_parameters(self):
         return list(set(
             list(self.encoder.parameters()) +
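`memory_parameters()` and `not_memory_parameters()` concatenate parameter lists from several components and wrap them in `set()` to drop duplicates; this works because `nn.Parameter` (like `torch.Tensor`) hashes by object identity. A small sketch of that de-duplication idea with stand-in modules, not the actual `MrlActorModel` internals:

```python
# Sketch: set() drops parameters that are shared between components
# (e.g. tied modules), because Parameter hashing is identity-based.
from torch import nn

shared = nn.Linear(4, 4)  # pretend this module is shared by two components
enc_params = list(shared.parameters()) + list(nn.Linear(4, 4).parameters())
dec_params = list(shared.parameters())  # overlaps with enc_params

unique = list(set(enc_params + dec_params))
print(len(enc_params + dec_params), len(unique))  # 6 deduplicated to 4
```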
rxnn/training/mrl.py
CHANGED
@@ -17,6 +17,8 @@ from .models import MrlActorAction, MrlActorModel, MrlCriticModel

 class MrlConfig(TypedDict):
     lr: float
+    separate_memory_lr: Optional[bool]
+    memory_lr: Optional[float]
     critic_lr: float
     max_seq_len: int
     critic_max_len: int
@@ -42,7 +44,9 @@ class CurriculumConfig(TypedDict):
     random_resets_from: Optional[int]
     random_resets_ratio: Optional[float]
     reward_model: Optional[MrlRewardModel]
+    separate_memory_lr: Optional[bool]
     lr: Optional[float]
+    memory_lr: Optional[float]
     critic_lr: Optional[float]
     weight_decay: Optional[float]
     critic_weight_decay: Optional[float]
@@ -84,6 +88,7 @@ class MRLTrainer:
         use_amp: bool = False,
         dtype: torch.dtype = torch.float32,
         callbacks: list[MrlTrainerCallback] = None,
+
     ):
         """
         Trainer for Memory Reinforcement Learning (MRL) in Reactive Transformer.
@@ -123,15 +128,25 @@ class MRLTrainer:
         self.use_amp = use_amp
         self.dtype = dtype

-        self.
-
-
-
-
-
+        self.separate_memory_lr = config.get('separate_memory_lr', False)
+
+        if self.separate_memory_lr:
+            self.base_optim_config = {
+                'lr': (config.get('lr', 3e-4), config.get('memory_lr', 5e-4)),
+                'critic_lr': config.get('critic_lr', 1e-4),
+                'weight_decay': config.get('weight_decay', 0.01),
+                'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+            }
+        else:
+            self.base_optim_config = {
+                'lr': config.get('lr', 3e-4),
+                'critic_lr': config.get('critic_lr', 1e-4),
+                'weight_decay': config.get('weight_decay', 0.01),
+                'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+            }

         # Optimizers
-        self.optimizer, self.critic_optimizer = self._init_optimizers(**self.base_optim_config)
+        self.optimizer, self.critic_optimizer = self._init_optimizers(**self.base_optim_config, separate_memory_lr=self.separate_memory_lr)

         self.scaler = torch.amp.GradScaler() if self.use_amp else None
         self.critic_scaler = torch.amp.GradScaler() if self.use_amp else None
@@ -158,18 +173,28 @@ class MRLTrainer:
         self.global_epoch = 0
         self.global_epochs_count = 0

-    def _init_optimizers(self, lr: float, critic_lr: float, weight_decay: float, critic_weight_decay: float):
-
-
-
-
-
+    def _init_optimizers(self, lr: Union[float, tuple[float, float]], critic_lr: float, weight_decay: float, critic_weight_decay: float, separate_memory_lr: bool = False) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
+        if separate_memory_lr:
+            rest_lr, memory_lr = lr
+            optimizer = torch.optim.AdamW([
+                { 'params': self.actor.not_memory_parameters(), 'lr': rest_lr },
+                { 'params': self.actor.memory_parameters(), 'lr': memory_lr },
+            ],
+                weight_decay=weight_decay,
+            )
+        else:
+            optimizer = torch.optim.AdamW(
+                self.actor.unique_parameters(),
+                lr=lr,
+                weight_decay=weight_decay,
+            )

         critic_optimizer = torch.optim.AdamW(
             self.critic.parameters(),
             lr=critic_lr,
             weight_decay=critic_weight_decay,
         )
+
         return optimizer, critic_optimizer


@@ -722,12 +747,13 @@ class MRLTrainer:
         self.strategy = config.get('strategy',
                                    MrlStrategy.MULTI_STEP_STRATEGY)  # MRL strategy for given curriculum stage
         self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
-        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None:
+        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
             self.optimizer, self.critic_optimizer = self._init_optimizers(
-                lr=config.get('lr', self.base_optim_config['lr']),
+                lr=(config.get('lr', self.base_optim_config['lr'][0]), config.get('memory_lr', self.base_optim_config['lr'][1])) if config.get('separate_memory_lr', False) else config.get('lr', self.base_optim_config['lr']),
                 critic_lr=config.get('critic_lr', self.base_optim_config['critic_lr']),
                 weight_decay=config.get('weight_decay', self.base_optim_config['weight_decay']),
-                critic_weight_decay=config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay'])
+                critic_weight_decay=config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
+                separate_memory_lr=config.get('separate_memory_lr', False),
             )

         # 2. Get epochs and random resets configs
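The net effect: when `separate_memory_lr` is enabled, the actor optimizer is built from two parameter groups with different learning rates (the diff suggests defaults of `3e-4` for regular parameters and `5e-4` for memory parameters). A minimal sketch of that setup, with stand-in modules in place of `not_memory_parameters()` / `memory_parameters()`:

```python
# Minimal sketch of the two-group AdamW configuration used when
# separate_memory_lr is enabled. The Linear modules are stand-ins.
import torch
from torch import nn

rest = nn.Linear(8, 8)     # stand-in for non-memory parameters
memory = nn.Linear(8, 8)   # stand-in for memory cross-attention parameters

rest_lr, memory_lr = 3e-4, 5e-4   # defaults suggested by the diff
optimizer = torch.optim.AdamW(
    [
        {'params': rest.parameters(), 'lr': rest_lr},
        {'params': memory.parameters(), 'lr': memory_lr},
    ],
    weight_decay=0.01,
)
print([group['lr'] for group in optimizer.param_groups])  # [0.0003, 0.0005]
```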
rxnn/transformers/layers.py
CHANGED
@@ -64,6 +64,13 @@ class ReactiveTransformerLayer(nn.Module):
         for param in self.norm2.parameters():
             param.requires_grad_(is_trainable)

+    def memory_parameters(self) -> list[nn.Parameter]:
+        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters())
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        memory_params = self.memory_parameters()
+        return [param for param in self.parameters() if param not in memory_params]
+
     def update_max_len(self, max_seq_len: int):
         if self.attention.rope is not None:
             self.attention.rope.update_max_len(max_seq_len)
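Here `not_memory_parameters()` means "all of the layer's parameters except those returned by `memory_parameters()`". One way to express that split is shown below with an `id()` set, so membership is checked by object identity rather than by tensor value comparison; this is a sketch with stand-in submodules, not the `ReactiveTransformerLayer` implementation itself:

```python
# Sketch: splitting a layer's parameters into memory / non-memory groups
# using identity-based membership checks.
from torch import nn

class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = nn.Linear(8, 8)
        self.memory_cross_attention = nn.Linear(8, 8)
        self.norm2 = nn.LayerNorm(8)

    def memory_parameters(self):
        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters())

    def not_memory_parameters(self):
        memory_ids = {id(p) for p in self.memory_parameters()}
        return [p for p in self.parameters() if id(p) not in memory_ids]

layer = ToyLayer()
print(len(layer.memory_parameters()), len(layer.not_memory_parameters()))  # 4 2
```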
rxnn/transformers/models.py
CHANGED
@@ -39,6 +39,16 @@ class ReactiveTransformerBase(nn.Module):
         for i in range(self.num_own_layers):
             self.layers[i].trainable_cross_attention_(is_trainable, with_norms)

+    def memory_parameters(self) -> list[nn.Parameter]:
+        own = [param for layer in self.layers for param in layer.memory_parameters()]
+        shared = [param for layer in self.shared_layers for param in layer.memory_parameters()] if self.shared_layers else []
+        return own + shared
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        own = [param for layer in self.layers for param in layer.not_memory_parameters()]
+        shared = [param for layer in self.shared_layers for param in layer.not_memory_parameters()] if self.shared_layers else []
+        return own + shared
+
     def moe_router_loss(self):
         return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att] + [
             self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe or self.shared_layers[i].use_moe_att]).mean()
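At the model level the same two methods simply aggregate over the model's own layers and, when present, its shared layers. A toy sketch of that aggregation pattern (the stand-in layer only mimics the `memory_parameters()` interface):

```python
# Sketch: gathering memory parameters from own layers plus optional shared layers.
from torch import nn

class ToyMemoryLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.memory_cross_attention = nn.Linear(8, 8)

    def memory_parameters(self):
        return list(self.memory_cross_attention.parameters())

def collect_memory_parameters(layers, shared_layers=None):
    own = [p for layer in layers for p in layer.memory_parameters()]
    shared = [p for layer in shared_layers for p in layer.memory_parameters()] if shared_layers else []
    return own + shared

own_layers = nn.ModuleList([ToyMemoryLayer() for _ in range(3)])
print(len(collect_memory_parameters(own_layers)))                  # 6 (3 layers x weight+bias)
print(len(collect_memory_parameters(own_layers, own_layers[:1])))  # 8 (plus one shared layer)
```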
{rxnn-0.2.24.dist-info → rxnn-0.2.26.dist-info}/RECORD
CHANGED
@@ -5,18 +5,18 @@ rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-c
 rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=
-rxnn/memory/norm.py,sha256=
+rxnn/memory/attention.py,sha256=PXVBZQYNsRraZh7QDBgUOdPy3lTI8B0d8CzduojBjG0,1747
+rxnn/memory/norm.py,sha256=E98jOQEuIOFFhlkvS8s4fFN-D4tLO6vaOqnObv1oVmA,6592
 rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=
+rxnn/rxt/models.py,sha256=r8wZeeNTC2VAhiiNe4y7LrbnB4wjFu_cupKiGkpdgjI,13002
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
-rxnn/training/models.py,sha256=
-rxnn/training/mrl.py,sha256=
+rxnn/training/models.py,sha256=_TrFwrQ_m6NDPalrafd8faPRyCnDFFFtN_gfzavaCFs,6474
+rxnn/training/mrl.py,sha256=hDsKQTaQcEVmnJruD3TxHZJJzDWu5I6Rq2HVDLj8ADU,44747
 rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
 rxnn/training/rl.py,sha256=j-KNLoZjhaEKasYNOc8DxHtwvknAgAJFwvXKot6otFA,3272
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -25,14 +25,14 @@ rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=LXSY829fIHSCmFmClhQ6B7I5aKbiOqy9mZmwlJG_r7U,7961
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=
+rxnn/transformers/models.py,sha256=QwVxYN9DrKllEpOiFoAx4CiThOWafeTa-OAY7L6gN0Y,8929
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.26.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.26.dist-info/METADATA,sha256=XDqI42X3zLRAAKZlVLmstm24KFPP_MfvDtObG9GBc0Y,25960
+rxnn-0.2.26.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.26.dist-info/RECORD,,
{rxnn-0.2.24.dist-info → rxnn-0.2.26.dist-info}/LICENSE
File without changes
{rxnn-0.2.24.dist-info → rxnn-0.2.26.dist-info}/WHEEL
File without changes