rxnn 0.2.71__py3-none-any.whl → 0.2.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/attention.py +100 -39
- rxnn/memory/gate.py +60 -0
- rxnn/rxt/models.py +153 -9
- rxnn/training/mrl.py +3 -3
- rxnn/transformers/layers.py +10 -16
- rxnn/transformers/models.py +4 -4
- {rxnn-0.2.71.dist-info → rxnn-0.2.73.dist-info}/METADATA +1 -1
- {rxnn-0.2.71.dist-info → rxnn-0.2.73.dist-info}/RECORD +10 -9
- {rxnn-0.2.71.dist-info → rxnn-0.2.73.dist-info}/LICENSE +0 -0
- {rxnn-0.2.71.dist-info → rxnn-0.2.73.dist-info}/WHEEL +0 -0
rxnn/memory/attention.py
CHANGED
@@ -2,17 +2,15 @@ import torch
 import torch.nn as nn
 from .stm import ShortTermMemory
 
+
 class StmMemoryAttention(nn.Module):
     def __init__(
             self,
             stm: ShortTermMemory,
             attention_layers: nn.ModuleList,
             memory_norm_layers: nn.ModuleList,
-
-
-            init_gate: float = 0.0,
-            use_dynamic_gate: bool = False,
-            use_tanh_gate: bool = False,
+            memory_input_norm_layers: nn.ModuleList,
+            residual_gate_layers: nn.ModuleList,
             debug_mode: bool = False,
             debug_interval: int = 10,
             *args,
@@ -22,16 +20,12 @@ class StmMemoryAttention(nn.Module):
         self.stm = stm
         self.attention_layers = attention_layers
         self.memory_norm_layers = memory_norm_layers
-
+        self.memory_input_norm_layers = memory_input_norm_layers
+        self.residual_gate_layers = residual_gate_layers
+        assert (len(self.attention_layers) == len(self.memory_norm_layers) ==
+                len(self.residual_gate_layers) == len(self.memory_input_norm_layers) ==
+                self.stm.memory.size(0))
         self.num_layers = len(attention_layers)
-        self.use_gated_residual = use_gated_residual
-        self.per_slot_gate = per_slot_gate
-        self.use_dynamic_gate = use_dynamic_gate
-        self.use_tanh_gate = use_tanh_gate
-        if self.use_gated_residual:
-            gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
-            self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
-
         self.debug_mode = debug_mode
         self.debug_interval = debug_interval
         self.debug_step = 0
@@ -41,32 +35,27 @@ class StmMemoryAttention(nn.Module):
             if self.attention_layers[i].rope is not None:
                 self.attention_layers[i].rope.update_max_len(max_seq_len)
 
-    def _residual_gate(self, gate: torch.Tensor, layer_stm: torch.Tensor, new_layer_stm: torch.Tensor) -> torch.Tensor:
-        if self.use_dynamic_gate:
-            mean_dim = -1 if self.per_slot_gate else [1, 2]
-            gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
-            layer_gate = torch.tanh(gate_input) if self.use_tanh_gate else torch.sigmoid(gate_input)
-        else:
-            layer_gate = torch.tanh(gate) if self.use_tanh_gate else torch.sigmoid(gate)
-        if self.use_tanh_gate:
-            return (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm
-        else:
-            return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
-
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        # 1. Process correct attention mask
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+        # 2. Init new empty STM
         new_stm = torch.zeros_like(self.stm.memory)
+
+        # 3. Run Short-Term Memory update for all layers
         for i in range(self.num_layers):
+            # 4. Get current layer STM value
            layer_stm = self.stm(i)
-            #
+            # 5. Expand layer STM to batch size, if it's not in batch mode
             if layer_stm.size(0) == 1:
                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
-
+
+            # 6. Get encoded layer data and normalize it
+            encoded_layer_data = self.memory_input_norm_layers[i](x[i])
+            # 7. Normalize STM layer
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            if torch.isnan(normalized_layer_stm).any():
-                print(f"NaN detected in {i} layer memory norm output")
 
+            # 8. Print normalization stats in debug mode
             if self.debug_mode and self.training:
                 if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
                     self.debug_step = 0
@@ -74,16 +63,88 @@ class StmMemoryAttention(nn.Module):
                 else:
                     self.debug_step += 1
 
-
-                print(f"NaN detected in {i} layer encoded data input")
-
+            # 9. Calculate memory attention
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
-
-
+            # 10. Combine new updated layer state with current STM state in residual gate
+            new_stm[i] = self.residual_gate_layers[i](layer_stm, new_layer_stm)  # residual
+        # 11. Update all layers/models
+        self.stm.update_all(new_stm)
+        return self.stm.memory
+
+
+class InterlayerStmMemoryAttention(StmMemoryAttention):
+    def __init__(
+            self,
+            stm: ShortTermMemory,
+            attention_layers: nn.ModuleList,
+            memory_norm_layers: nn.ModuleList,
+            memory_input_norm_layers: nn.ModuleList,
+            residual_gate_layers: nn.ModuleList,
+            mean_attention_layers: nn.ModuleList,
+            mean_memory_norm_layers: nn.ModuleList,
+            mean_residual_gate_layers: nn.ModuleList,
+            mean_stm_norm: nn.Module,
+            debug_mode: bool = False,
+            debug_interval: int = 10,
+            **kwargs
+    ):
+        super(InterlayerStmMemoryAttention, self).__init__(
+            stm, attention_layers, memory_norm_layers, memory_input_norm_layers, residual_gate_layers,
+            debug_mode=debug_mode, debug_interval=debug_interval, **kwargs
+        )
+        self.mean_attention_layers = mean_attention_layers
+        self.mean_memory_norm_layers = mean_memory_norm_layers
+        self.mean_stm_norm = mean_stm_norm
+        self.mean_residual_gate_layers = mean_residual_gate_layers
+        assert (len(self.mean_attention_layers) == len(self.mean_memory_norm_layers) ==
+                len(self.mean_residual_gate_layers) == self.num_layers)
 
-
-
-
-
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        # 1. Process correct attention mask
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+        # 2. Init new empty STM
+        new_stm = torch.zeros_like(self.stm.memory)
+
+        # 3. Get mean STM value from layers for mean interlayer memory attention
+        mean_stm = self.stm.memory.mean(dim=0)  # [batch_size, stm_size, embed_dim]
+        # 4. Normalize mean STM layer
+        normalized_mean_stm = self.mean_stm_norm(mean_stm)
+
+        # 5. Run Short-Term Memory update for all layers
+        for i in range(self.num_layers):
+            # 6. Get current layer STM value
+            layer_stm = self.stm(i)
+            # 7. Expand layer STM to batch size, if it's not in batch mode
+            if layer_stm.size(0) == 1:
+                layer_stm = layer_stm.expand(x.size(0), -1, -1)
+
+            # 8. Mean interlayer memory attention
+            # a) normalize STM layer value
+            pre_normalized_layer_stm = self.mean_memory_norm_layers[i](layer_stm)
+            # b) calculate attention between STM layer and mean value of all STM layers (from previous interaction)
+            interlayer_stm = self.mean_attention_layers[i](pre_normalized_layer_stm, normalized_mean_stm, normalized_mean_stm, mask=None)
+            # c) combine updated interlayer state with current STM state in residual gate
+            updated_layer_stm = self.mean_residual_gate_layers[i](layer_stm, interlayer_stm)
+
+            # 9. Main memory attention
+            # a) get encoded data for current layer and normalize it
+            encoded_layer_data = self.memory_input_norm_layers[i](x[i])
+            # b) normalize STM layer value
+            normalized_layer_stm = self.memory_norm_layers[i](updated_layer_stm)
+            # c) print normalized STM stats in debug mode
+            if self.debug_mode and self.training:
+                if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                    self.debug_step = 0
+                    print(
+                        f"Pre-Normalized STM stats - mean: {pre_normalized_layer_stm.mean().item():.4f}, std: {pre_normalized_layer_stm.std().item():.4f}")
+                    print(f"Normalized STM stats - mean: {normalized_layer_stm.mean().item():.4f}, std: {normalized_layer_stm.std().item():.4f}")
+                else:
+                    self.debug_step += 1
+            # d) calculate memory attention between STM layer and encoded data
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
+            # e) combine new updated layer STM with previous state in residual gate
+            new_stm[i] = self.residual_gate_layers[i](updated_layer_stm, new_layer_stm)  # residual
+        # 10. Update all layers/models
         self.stm.update_all(new_stm)
         return self.stm.memory
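The refactored forward() above replaces the removed _residual_gate() helper with per-layer normalization and ResidualGate modules. Below is a minimal sketch of the per-layer update (steps 4-10), not taken from the package: it assumes rxnn 0.2.73 is installed for ResidualGate (see the added rxnn/memory/gate.py below), a PyTorch version that provides nn.RMSNorm, and it uses a hypothetical toy_attention function in place of the package's memory attention layers.

import torch
import torch.nn as nn
from rxnn.memory.gate import ResidualGate  # new module added in 0.2.73

batch, stm_size, seq_len, embed_dim = 2, 8, 16, 32

layer_stm = torch.randn(1, stm_size, embed_dim)      # one STM layer, not yet batched
encoded_x = torch.randn(batch, seq_len, embed_dim)   # encoded data for this layer (x[i])

memory_input_norm = nn.RMSNorm(embed_dim)            # step 6: normalize encoded data
memory_norm = nn.RMSNorm(embed_dim)                  # step 7: normalize STM (package uses init_memory_norm here)
residual_gate = ResidualGate(stm_size, use_gate=True, gate_type='static', init_gate=3.0)

def toy_attention(q, k, v, mask=None):
    # hypothetical stand-in for attention_layers[i]; q is the STM, k/v the encoded data
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

layer_stm = layer_stm.expand(batch, -1, -1)                        # step 5: broadcast to batch
encoded = memory_input_norm(encoded_x)                             # step 6
normalized_stm = memory_norm(layer_stm)                            # step 7
new_layer_stm = toy_attention(normalized_stm, encoded, encoded)    # step 9
updated_layer = residual_gate(layer_stm, new_layer_stm)            # step 10
print(updated_layer.shape)  # torch.Size([2, 8, 32])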
rxnn/memory/gate.py
ADDED
@@ -0,0 +1,60 @@
+import torch
+import torch.nn as nn
+from typing import TypeAlias, Literal
+
+ResidualGateType: TypeAlias = Literal['static', 'elementwise', 'linear']
+
+
+class ResidualGate(nn.Module):
+    def __init__(
+            self,
+            stm_size: int,
+            use_gate: bool = False,
+            gate_type: ResidualGateType = 'static',
+            per_slot_gate: bool = True,
+            init_gate: float = 0.0,
+            use_tanh_gate: bool = True,
+            **kwargs,
+    ):
+        super(ResidualGate, self).__init__(**kwargs)
+        self.use_gate = use_gate
+        self.per_slot_gate = per_slot_gate
+        self.gate_type = gate_type
+        self.use_tanh_gate = use_tanh_gate
+
+        if self.use_gate:
+            if self.gate_type == 'linear':
+                self.gate = nn.Linear(stm_size, stm_size if self.per_slot_gate else 1)
+            else:
+                gate_shape = (stm_size, 1) if self.per_slot_gate else (1,)
+                self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
+        else:
+            self.gate = None
+
+        self.gate_activation = nn.Tanh() if self.use_tanh_gate else nn.Sigmoid()
+
+    def _dynamic_gate(self, old_value: torch.Tensor, new_value: torch.Tensor):
+        if self.gate_type == 'linear':
+            mean_residual = (new_value + old_value).mean(dim=-1)
+            gate_input = self.gate(mean_residual).unsqueeze(-1)
+        else:
+            mean_dim = -1 if self.per_slot_gate else [1, 2]
+            gate_input = self.gate * (new_value + old_value).mean(dim=mean_dim, keepdim=True)
+        return self.gate_activation(gate_input)
+
+    def _calculate_output(self, layer_gate: torch.Tensor, old_value: torch.Tensor, new_value: torch.Tensor) -> torch.Tensor:
+        if self.use_tanh_gate:
+            return (1 + layer_gate) * new_value + (1 - layer_gate) * old_value
+        else:
+            return layer_gate * new_value + (1 - layer_gate) * old_value
+
+    def forward(self, old_value: torch.Tensor, new_value: torch.Tensor) -> torch.Tensor:
+        if not self.use_gate:
+            return new_value + old_value
+
+        if self.gate_type == 'static':
+            layer_gate = self.gate_activation(self.gate)
+        else:
+            layer_gate = self._dynamic_gate(old_value, new_value)
+
+        return self._calculate_output(layer_gate, old_value, new_value)
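A short usage sketch of the new ResidualGate (not part of the package sources), assuming rxnn 0.2.73 is installed; tensor shapes follow the [batch_size, stm_size, embed_dim] STM convention used in the memory attention above.

import torch
from rxnn.memory.gate import ResidualGate

batch_size, stm_size, embed_dim = 2, 16, 64
old_stm = torch.randn(batch_size, stm_size, embed_dim)
new_stm = torch.randn(batch_size, stm_size, embed_dim)

# With use_gate=False the module reduces to a plain residual sum: new_value + old_value.
plain = ResidualGate(stm_size, use_gate=False)
assert torch.allclose(plain(old_stm, new_stm), new_stm + old_stm)

# 'static': a learned per-slot parameter passed through tanh (use_tanh_gate=True),
# combined as (1 + g) * new_value + (1 - g) * old_value.
static = ResidualGate(stm_size, use_gate=True, gate_type='static',
                      per_slot_gate=True, init_gate=3.0, use_tanh_gate=True)
print(static(old_stm, new_stm).shape)  # torch.Size([2, 16, 64])

# 'linear': the gate input comes from a Linear layer over the mean residual,
# so the gate values depend on the current old/new states.
linear = ResidualGate(stm_size, use_gate=True, gate_type='linear')
print(linear(old_stm, new_stm).shape)  # torch.Size([2, 16, 64])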
rxnn/rxt/models.py
CHANGED
@@ -9,11 +9,13 @@ from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEn
 from ..transformers.ff import get_activation_layer
 from ..memory.stm import ShortTermMemory
 from ..memory.norm import init_memory_norm
-from ..memory.attention import StmMemoryAttention
+from ..memory.attention import StmMemoryAttention, InterlayerStmMemoryAttention
+from ..memory.gate import ResidualGate, ResidualGateType
 from ..utils import get_model_size
 from ..experimental.attention import init_experimental_attention
 
 
+
 class RxTAlphaComponentConfig(TypedDict):
     num_layers: int
     vocab_size: int
@@ -260,15 +262,15 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             att_experts: int = None,
             att_query_experts: int = None,
             att_query_groups: int = None,
-            norm_type: str = 'rms',
+            norm_type: str = 'classic-rms',
             norm_init_gate: float = -2.0,
             norm_per_dim_scale: bool = False,
             norm_decay: float = 0.9,
             use_gated_residual: bool = False,
-            residual_per_slot_gate: bool =
-
-
-            use_tanh_residual_gate: bool =
+            residual_per_slot_gate: bool = True,
+            residual_gate_init: float = 3.0,
+            residual_gate_type: ResidualGateType = 'static',
+            use_tanh_residual_gate: bool = True,
             debug_mode: bool = False,
             debug_interval: int = 10,
             **kwargs,
@@ -296,12 +298,153 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
                                                              init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
                                             for _ in range(num_layers)])
+        memory_input_norm_layers = nn.ModuleList(nn.RMSNorm(embed_dim) for _ in range(num_layers))
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+        residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type, per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
         self.model = StmMemoryAttention(
             stm, attention_layers, memory_norm_layers,
-
-
-
+            memory_input_norm_layers, residual_gates,
+            debug_mode=debug_mode, debug_interval=debug_interval,
+        )
+
+    def freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def unfreeze(self):
+        for param in self.parameters():
+            param.requires_grad = True
+
+    def load_shared_memory(self, stm: ShortTermMemory):
+        self.model.stm = stm
+
+    def update_max_len(self, max_seq_len: int):
+        self.model.update_max_len(max_seq_len)
+
+    def reset_memory(self, init_type: str = None):
+        self.model.stm.reset(init_type)
+
+    def clone_reset_memory(self):
+        self.model.stm.clone_detach_reset()
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        return self.model(x, attention_mask=attention_mask)
+
+
+class RxTAlphaInterlayerMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
+    """RxT-Alpha (Reactive Transformer) memory attention model with interlayer STM attention"""
+
+    def __init__(
+            self,
+            num_layers: int = 12,
+            embed_dim: int = 512,
+            att_heads: int = 16,
+            seq_len: int = 1024,
+            stm_size: int = 1024,
+            use_flash_attention: bool = False,
+            att_dropout: float = 0.0,
+            att_groups: int = 1,
+            att_type: str = 'sqa',
+            att_experts: int = None,
+            att_query_experts: int = None,
+            att_query_groups: int = None,
+            interlayer_att_dropout: float = 0.0,
+            interlayer_att_groups: int = 1,
+            interlayer_att_type: str = 'sqa',
+            interlayer_att_experts: int = None,
+            interlayer_att_query_experts: int = None,
+            interlayer_att_query_groups: int = None,
+            norm_type: str = 'classic-rms',
+            norm_init_gate: float = -2.0,
+            norm_per_dim_scale: bool = False,
+            norm_decay: float = 0.9,
+            use_gated_residual: bool = False,
+            residual_per_slot_gate: bool = True,
+            residual_gate_init: float = 3.0,
+            residual_gate_type: ResidualGateType = 'static',
+            use_tanh_residual_gate: bool = True,
+            debug_mode: bool = False,
+            debug_interval: int = 10,
+            **kwargs,
+    ):
+        super(RxTAlphaInterlayerMemoryAttention, self).__init__(**kwargs)
+
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                            'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+
+        rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
+        stm = ShortTermMemory(num_layers, embed_dim, stm_size)
+
+        if att_type in ['mha', 'gqa', 'mqa']:
+            att_init = lambda: init_attention(
+                embed_dim, att_heads, att_type, att_groups, rope=rope,
+                use_flash_attention=use_flash_attention, dropout=att_dropout,
+                max_seq_len=seq_len, is_causal=False, rope_only_for_keys=True
+            )
+        else:
+            att_init = lambda: init_experimental_attention(
+                embed_dim, att_heads, att_type, att_groups, rope=rope,
+                use_flash_attention=use_flash_attention, dropout=att_dropout,
+                max_seq_len=seq_len, is_causal=False, num_experts=att_experts,
+                num_query_experts=att_query_experts, num_query_groups=att_query_groups,
+                rope_only_for_keys=True
+            )
+
+        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                             init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                            for _ in range(num_layers)])
+        memory_input_norm_layers = nn.ModuleList(nn.RMSNorm(embed_dim) for _ in range(num_layers))
+        attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+        residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type,
+                per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
+        # Interlayer attention
+        if interlayer_att_type in ['mha', 'gqa', 'mqa']:
+            interlayer_att_init = lambda: init_attention(
+                embed_dim, att_heads, interlayer_att_type, interlayer_att_groups, rope=None,
+                use_flash_attention=use_flash_attention, dropout=interlayer_att_dropout, is_causal=False
+            )
+        else:
+            interlayer_att_init = lambda: init_experimental_attention(
+                embed_dim, att_heads, interlayer_att_type, interlayer_att_groups, rope=None,
+                use_flash_attention=use_flash_attention, dropout=interlayer_att_dropout, is_causal=False,
+                num_experts=interlayer_att_experts, num_query_experts=interlayer_att_query_experts, num_query_groups=interlayer_att_query_groups
+            )
+
+        mean_attention_layers = nn.ModuleList([interlayer_att_init() for _ in range(num_layers)])
+
+        mean_stm_norm = init_memory_norm(
+            norm_type, embed_dim, stm_size, decay=norm_decay,
+            init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale
+        )
+
+        mean_memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                                  init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                                 for _ in range(num_layers)])
+
+        mean_residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type,
+                per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
+        self.model = InterlayerStmMemoryAttention(
+            stm, attention_layers, memory_norm_layers, memory_input_norm_layers, residual_gates,
+            mean_attention_layers, mean_memory_norm_layers, mean_residual_gates, mean_stm_norm,
+            debug_mode=debug_mode, debug_interval=debug_interval,
         )
 
     def freeze(self):
@@ -327,6 +470,7 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
 
+
 class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) encoder model"""
 
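A hedged construction sketch for the new RxTAlphaInterlayerMemoryAttention wrapper, not taken from the package docs: keyword names and defaults come from the constructor shown above, while the small sizes and the 'mha' attention choice are illustrative assumptions (they route through the init_attention branch exactly as written in the diff).

from rxnn.rxt.models import RxTAlphaInterlayerMemoryAttention
from rxnn.memory.stm import ShortTermMemory

mem_att = RxTAlphaInterlayerMemoryAttention(
    num_layers=4,
    embed_dim=128,
    att_heads=8,
    seq_len=256,
    stm_size=64,
    att_type='mha',               # handled by init_attention, per the constructor above
    interlayer_att_type='mha',    # interlayer (mean STM) attention, built without RoPE
    use_gated_residual=True,
    residual_gate_type='static',
    use_tanh_residual_gate=True,
)

# Helper methods mirrored from RxTAlphaMemoryAttention (see the added methods above):
mem_att.load_shared_memory(ShortTermMemory(4, 128, 64))  # share an STM with other components
mem_att.update_max_len(512)                              # grow RoPE max length in the memory attention layers
mem_att.freeze()
mem_att.unfreeze()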
rxnn/training/mrl.py
CHANGED
@@ -592,7 +592,7 @@ class MRLTrainer:
 
         router_loss = actor.moe_router_loss()
         if torch.isnan(router_loss).any():
-            print("NaN detected in router loss")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in router loss")
         if router_loss is not None:
             return main_loss + self.moe_aux_loss_scale * router_loss
         else:
@@ -671,7 +671,7 @@ class MRLTrainer:
                 # 4.4 Unscale and clip gradient norms
                 self.scaler.unscale_(self.optimizer)
                 torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                               error_if_nonfinite=
+                                               error_if_nonfinite=False)
                 if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                     self._log_gradients(logits)
                 # 4.5 Run scaled optimization step
@@ -692,7 +692,7 @@ class MRLTrainer:
                 policy_loss.backward(retain_graph=True)
                 # 4.4 Clip gradient norms
                 torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                               error_if_nonfinite=
+                                               error_if_nonfinite=False)
                 if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                     self._log_gradients(logits)
                 # 4.5 Run scaled optimization step
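Both clipping calls now pass error_if_nonfinite=False. A small illustrative snippet (not from the package) showing what that flag does with PyTorch's clip_grad_norm_:

import torch

# With error_if_nonfinite=False, clip_grad_norm_ returns the (possibly inf/nan) total
# norm instead of raising a RuntimeError, so the step proceeds and the NaN debug prints
# elsewhere in this release can surface the problem.
params = [torch.nn.Parameter(torch.randn(4, 4))]
params[0].grad = torch.full((4, 4), float('inf'))
total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0, error_if_nonfinite=False)
print(total_norm)  # tensor(inf), no exception raised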
rxnn/transformers/layers.py
CHANGED
@@ -1,5 +1,7 @@
 import torch
 import torch.nn as nn
+from poetry.console.commands import self
+
 from .attention import MultiHeadAttention
 from .ff import FeedForward, GatedFeedForward
 from .moe import MoeFeedForward, GatedMoeFeedForward
@@ -49,10 +51,12 @@ class ReactiveTransformerLayer(nn.Module):
             self.norm1 = nn.RMSNorm(embed_dim)
             self.norm2 = nn.RMSNorm(embed_dim)
             self.norm3 = nn.RMSNorm(embed_dim)
+            self.stm_norm = nn.RMSNorm(embed_dim)
         else:
             self.norm1 = nn.LayerNorm(embed_dim)
             self.norm2 = nn.LayerNorm(embed_dim)
             self.norm3 = nn.LayerNorm(embed_dim)
+            self.stm_norm = nn.LayerNorm(embed_dim)
         self.use_post_norm = use_post_norm
         self.use_moe = use_moe
         self.use_moe_att = use_moe_att
@@ -63,9 +67,11 @@ class ReactiveTransformerLayer(nn.Module):
         if with_norms:
             for param in self.norm2.parameters():
                 param.requires_grad_(is_trainable)
+            for param in self.stm_norm.parameters():
+                param.requires_grad_(is_trainable)
 
     def memory_parameters(self) -> list[nn.Parameter]:
-        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters())
+        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters()) + list(self.stm_norm.parameters())
 
     def not_memory_parameters(self) -> list[nn.Parameter]:
         return (list(self.attention.parameters()) + list(self.norm1.parameters()) +
@@ -102,11 +108,8 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm1(x)
-        if torch.isnan(x).any():
-            print("NaN detected in pre-norm (self-attention) output")
         x = self.attention(x, x, x, mask=mask)
-
-            print("NaN detected in self-attention output")
+
         x = residual + x
         if self.use_post_norm:
             x = self.norm1(x)
@@ -114,18 +117,13 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
-        if torch.isnan(x).any():
-            print("NaN detected in pre-norm (cross-attention) output")
 
+        # normalize STM and prepare STM mask
+        stm = self.stm_norm(stm)
         mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1)) \
             if mask is not None else None
 
-        if torch.isnan(stm).any():
-            print("NaN detected in STM cross-attention input")
-
         x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
-        if torch.isnan(x).any():
-            print("NaN detected in cross-attention output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
@@ -134,11 +132,7 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm3(x)
-        if torch.isnan(x).any():
-            print("NaN detected in pre-norm (ff) output")
         x = self.ff(x)
-        if torch.isnan(x).any():
-            print("NaN detected in ff output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm3(x)
rxnn/transformers/models.py
CHANGED
@@ -94,7 +94,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x)  # apply embeddings
         if torch.isnan(x).any():
-            print("NaN detected in decoder embedding output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in decoder embedding output")
         seq_len = x.size(1)
         if not self.use_flash_attention and self.use_relative_embedding:
             mask = create_causal_mask(seq_len, device=x.device)
@@ -112,7 +112,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
         for i in range(self.num_own_layers):
             x = self._handle_layer(i, x, mask=mask)
             if torch.isnan(x).any():
-                print(f"NaN detected in {i}. decoder layer output")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i}. decoder layer output")
         return self.head(self.head_norm(x) if self.use_head_norm else x)
 
 
@@ -122,7 +122,7 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
         x = super().forward(x)  # apply embeddings
         if torch.isnan(x).any():
-            print("NaN detected in encoder embedding output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in encoder embedding output")
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
 
@@ -136,7 +136,7 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
         for i in range(self.num_own_layers):
             x = self._handle_layer(i, x, mask=attention_mask)
             if torch.isnan(x).any():
-                print(f"NaN detected in {i}. encoder layer output")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i}. encoder layer output")
             hidden_states.append(x)
         return x, torch.stack(hidden_states)
 
{rxnn-0.2.71.dist-info → rxnn-0.2.73.dist-info}/RECORD
CHANGED
@@ -5,11 +5,12 @@ rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhX
 rxnn/experimental/models.py,sha256=KheR1zSNJIaeVvpVAkEJwcuM5nOqQP0ZF08XhrtGJ8E,5387
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=
+rxnn/memory/attention.py,sha256=CReYJZNA5JRED_QWqX-yKqEKZTRX6DNCAB8uFLZtKxI,7513
+rxnn/memory/gate.py,sha256=pR_H2y9C7S02QskoFAEC9Tmluut0k4GGlHgvZGiw6m4,2332
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=
+rxnn/rxt/models.py,sha256=Pb48Frl6HV4Wb9CZgYtmzch3k_4Jess3rhs7dY1I96k,22209
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
@@ -17,7 +18,7 @@ rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36
 rxnn/training/dataset.py,sha256=ruU6k33pQmpTqhxpjLFNdDJnCjcrBcGeFOzJqFahJDM,51880
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=ILkcqBV1MImnULnq-YDSSEf8cUdEbUgQaH0FRTsa4LA,9069
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=KUJAdUznquhf5UlcpV-QF5oKHDBEsDecMEVmMLQZw7w,67380
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
 rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -26,14 +27,14 @@ rxnn/training/utils.py,sha256=ngDCm654NL3UsPy190Er4XPc9HI-OyEV6tDLMgEEvQc,6219
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=fxjlbQG6cwxq-b2ei4DnohSQGH5gwy4GkfP9duTUvjw,8492
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=
+rxnn/transformers/models.py,sha256=TP0H9do53Z0vd8kpHMISBzMpHE5X9QIHcy0B-iJHuNQ,11711
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.73.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.73.dist-info/METADATA,sha256=gtoRMeFgBuOZs4lRKl9JGUxZ2X4C9K78Ee-NHLMqW4E,60420
+rxnn-0.2.73.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.73.dist-info/RECORD,,
{rxnn-0.2.71.dist-info → rxnn-0.2.73.dist-info}/LICENSE
File without changes
{rxnn-0.2.71.dist-info → rxnn-0.2.73.dist-info}/WHEEL
File without changes