rxnn 0.2.72__py3-none-any.whl → 0.2.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/attention.py +100 -39
- rxnn/memory/gate.py +60 -0
- rxnn/rxt/models.py +153 -9
- rxnn/transformers/layers.py +10 -16
- rxnn/utils.py +0 -5
- {rxnn-0.2.72.dist-info → rxnn-0.2.73.dist-info}/METADATA +1 -1
- {rxnn-0.2.72.dist-info → rxnn-0.2.73.dist-info}/RECORD +9 -8
- {rxnn-0.2.72.dist-info → rxnn-0.2.73.dist-info}/LICENSE +0 -0
- {rxnn-0.2.72.dist-info → rxnn-0.2.73.dist-info}/WHEEL +0 -0
rxnn/memory/attention.py
CHANGED
@@ -2,17 +2,15 @@ import torch
 import torch.nn as nn
 from .stm import ShortTermMemory
 
+
 class StmMemoryAttention(nn.Module):
     def __init__(
             self,
             stm: ShortTermMemory,
             attention_layers: nn.ModuleList,
             memory_norm_layers: nn.ModuleList,
-            …
-            …
-            init_gate: float = 0.0,
-            use_dynamic_gate: bool = False,
-            use_tanh_gate: bool = False,
+            memory_input_norm_layers: nn.ModuleList,
+            residual_gate_layers: nn.ModuleList,
             debug_mode: bool = False,
             debug_interval: int = 10,
             *args,
@@ -22,16 +20,12 @@ class StmMemoryAttention(nn.Module):
         self.stm = stm
         self.attention_layers = attention_layers
         self.memory_norm_layers = memory_norm_layers
-        …
+        self.memory_input_norm_layers = memory_input_norm_layers
+        self.residual_gate_layers = residual_gate_layers
+        assert (len(self.attention_layers) == len(self.memory_norm_layers) ==
+                len(self.residual_gate_layers) == len(self.memory_input_norm_layers) ==
+                self.stm.memory.size(0))
         self.num_layers = len(attention_layers)
-        self.use_gated_residual = use_gated_residual
-        self.per_slot_gate = per_slot_gate
-        self.use_dynamic_gate = use_dynamic_gate
-        self.use_tanh_gate = use_tanh_gate
-        if self.use_gated_residual:
-            gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
-            self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
-
         self.debug_mode = debug_mode
         self.debug_interval = debug_interval
         self.debug_step = 0
@@ -41,32 +35,27 @@ class StmMemoryAttention(nn.Module):
         if self.attention_layers[i].rope is not None:
             self.attention_layers[i].rope.update_max_len(max_seq_len)
 
-    def _residual_gate(self, gate: torch.Tensor, layer_stm: torch.Tensor, new_layer_stm: torch.Tensor) -> torch.Tensor:
-        if self.use_dynamic_gate:
-            mean_dim = -1 if self.per_slot_gate else [1, 2]
-            gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
-            layer_gate = torch.tanh(gate_input) if self.use_tanh_gate else torch.sigmoid(gate_input)
-        else:
-            layer_gate = torch.tanh(gate) if self.use_tanh_gate else torch.sigmoid(gate)
-        if self.use_tanh_gate:
-            return (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm
-        else:
-            return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
-
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        # 1. Process correct attention mask
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+        # 2. Init new empty STM
         new_stm = torch.zeros_like(self.stm.memory)
+
+        # 3. Run Short-Term Memory update for all layers
         for i in range(self.num_layers):
+            # 4. Get current layer STM value
            layer_stm = self.stm(i)
-            # …
+            # 5. Expand layer STM to batch size, if it's not in batch mode
             if layer_stm.size(0) == 1:
                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
-
+
+            # 6. Get encoded layer data and normalize it
+            encoded_layer_data = self.memory_input_norm_layers[i](x[i])
+            # 7. Normalize STM layer
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            if torch.isnan(normalized_layer_stm).any():
-                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i} layer memory norm output")
 
+            # 8. Print normalization stats in debug mode
             if self.debug_mode and self.training:
                 if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
                     self.debug_step = 0
@@ -74,16 +63,88 @@ class StmMemoryAttention(nn.Module):
                 else:
                     self.debug_step += 1
 
-            …
-                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i} layer encoded data input")
-            …
+            # 9. Calculate memory attention
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
-            …
-            …
+            # 10. Combine new updated layer state with current STM state in residual gate
+            new_stm[i] = self.residual_gate_layers[i](layer_stm, new_layer_stm)  # residual
+        # 11. Update all layers/models
+        self.stm.update_all(new_stm)
+        return self.stm.memory
+
+
+class InterlayerStmMemoryAttention(StmMemoryAttention):
+    def __init__(
+            self,
+            stm: ShortTermMemory,
+            attention_layers: nn.ModuleList,
+            memory_norm_layers: nn.ModuleList,
+            memory_input_norm_layers: nn.ModuleList,
+            residual_gate_layers: nn.ModuleList,
+            mean_attention_layers: nn.ModuleList,
+            mean_memory_norm_layers: nn.ModuleList,
+            mean_residual_gate_layers: nn.ModuleList,
+            mean_stm_norm: nn.Module,
+            debug_mode: bool = False,
+            debug_interval: int = 10,
+            **kwargs
+    ):
+        super(InterlayerStmMemoryAttention, self).__init__(
+            stm, attention_layers, memory_norm_layers, memory_input_norm_layers, residual_gate_layers,
+            debug_mode=debug_mode, debug_interval=debug_interval, **kwargs
+        )
+        self.mean_attention_layers = mean_attention_layers
+        self.mean_memory_norm_layers = mean_memory_norm_layers
+        self.mean_stm_norm = mean_stm_norm
+        self.mean_residual_gate_layers = mean_residual_gate_layers
+        assert (len(self.mean_attention_layers) == len(self.mean_memory_norm_layers) ==
+                len(self.mean_residual_gate_layers) == self.num_layers)
 
-            …
-            …
-            …
-            …
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        # 1. Process correct attention mask
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+        # 2. Init new empty STM
+        new_stm = torch.zeros_like(self.stm.memory)
+
+        # 3. Get mean STM value from layers for mean interlayer memory attention
+        mean_stm = self.stm.memory.mean(dim=0)  # [batch_size, stm_size, embed_dim]
+        # 4. Normalize mean STM layer
+        normalized_mean_stm = self.mean_stm_norm(mean_stm)
+
+        # 5. Run Short-Term Memory update for all layers
+        for i in range(self.num_layers):
+            # 6. Get current layer STM value
+            layer_stm = self.stm(i)
+            # 7. Expand layer STM to batch size, if it's not in batch mode
+            if layer_stm.size(0) == 1:
+                layer_stm = layer_stm.expand(x.size(0), -1, -1)
+
+            # 8. Mean interlayer memory attention
+            # a) normalize STM layer value
+            pre_normalized_layer_stm = self.mean_memory_norm_layers[i](layer_stm)
+            # b) calculate attention between STM layer and mean value of all STM layers (from previous interaction)
+            interlayer_stm = self.mean_attention_layers[i](pre_normalized_layer_stm, normalized_mean_stm, normalized_mean_stm, mask=None)
+            # c) combine updated interlayer state with current STM state in residual gate
+            updated_layer_stm = self.mean_residual_gate_layers[i](layer_stm, interlayer_stm)
+
+            # 9. Main memory attention
+            # a) get encoded data for current layer and normalize it
+            encoded_layer_data = self.memory_input_norm_layers[i](x[i])
+            # b) normalize STM layer value
+            normalized_layer_stm = self.memory_norm_layers[i](updated_layer_stm)
+            # c) print normalized STM stats in debug mode
+            if self.debug_mode and self.training:
+                if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                    self.debug_step = 0
+                    print(
+                        f"Pre-Normalized STM stats - mean: {pre_normalized_layer_stm.mean().item():.4f}, std: {pre_normalized_layer_stm.std().item():.4f}")
+                    print(f"Normalized STM stats - mean: {normalized_layer_stm.mean().item():.4f}, std: {normalized_layer_stm.std().item():.4f}")
+                else:
+                    self.debug_step += 1
+            # d) calculate memory attention between STM layer and encoded data
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
+            # e) combine new updated layer STM with previous state in residual gate
+            new_stm[i] = self.residual_gate_layers[i](updated_layer_stm, new_layer_stm)  # residual
+        # 10. Update all layers/models
         self.stm.update_all(new_stm)
         return self.stm.memory
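
For orientation, the per-layer update that the reworked StmMemoryAttention.forward now performs can be sketched with generic PyTorch modules standing in for the package's configured attention, norm, and gate layers; the shapes and module choices below are illustrative assumptions, not the library's defaults:

import torch
import torch.nn as nn

# Stand-ins for one layer's modules; in the package these are
# memory_input_norm_layers[i], memory_norm_layers[i], attention_layers[i]
# and residual_gate_layers[i].
embed_dim, stm_size, seq_len, batch = 16, 8, 32, 2
memory_input_norm = nn.RMSNorm(embed_dim)
memory_norm = nn.RMSNorm(embed_dim)
attention = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)

layer_stm = torch.randn(batch, stm_size, embed_dim)   # current STM slots for layer i
encoded = torch.randn(batch, seq_len, embed_dim)      # encoded layer data x[i]

q = memory_norm(layer_stm)           # step 7: normalize the STM layer
kv = memory_input_norm(encoded)      # step 6: normalize the encoded layer data
update, _ = attention(q, kv, kv)     # step 9: STM slots attend to the new data
new_layer_stm = update + layer_stm   # step 10: ResidualGate with use_gate=False is this plain residual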
rxnn/memory/gate.py
ADDED
@@ -0,0 +1,60 @@
+import torch
+import torch.nn as nn
+from typing import TypeAlias, Literal
+
+ResidualGateType: TypeAlias = Literal['static', 'elementwise', 'linear']
+
+
+class ResidualGate(nn.Module):
+    def __init__(
+            self,
+            stm_size: int,
+            use_gate: bool = False,
+            gate_type: ResidualGateType = 'static',
+            per_slot_gate: bool = True,
+            init_gate: float = 0.0,
+            use_tanh_gate: bool = True,
+            **kwargs,
+    ):
+        super(ResidualGate, self).__init__(**kwargs)
+        self.use_gate = use_gate
+        self.per_slot_gate = per_slot_gate
+        self.gate_type = gate_type
+        self.use_tanh_gate = use_tanh_gate
+
+        if self.use_gate:
+            if self.gate_type == 'linear':
+                self.gate = nn.Linear(stm_size, stm_size if self.per_slot_gate else 1)
+            else:
+                gate_shape = (stm_size, 1) if self.per_slot_gate else (1,)
+                self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
+        else:
+            self.gate = None
+
+        self.gate_activation = nn.Tanh() if self.use_tanh_gate else nn.Sigmoid()
+
+    def _dynamic_gate(self, old_value: torch.Tensor, new_value: torch.Tensor):
+        if self.gate_type == 'linear':
+            mean_residual = (new_value + old_value).mean(dim=-1)
+            gate_input = self.gate(mean_residual).unsqueeze(-1)
+        else:
+            mean_dim = -1 if self.per_slot_gate else [1, 2]
+            gate_input = self.gate * (new_value + old_value).mean(dim=mean_dim, keepdim=True)
+        return self.gate_activation(gate_input)
+
+    def _calculate_output(self, layer_gate: torch.Tensor, old_value: torch.Tensor, new_value: torch.Tensor) -> torch.Tensor:
+        if self.use_tanh_gate:
+            return (1 + layer_gate) * new_value + (1 - layer_gate) * old_value
+        else:
+            return layer_gate * new_value + (1 - layer_gate) * old_value
+
+    def forward(self, old_value: torch.Tensor, new_value: torch.Tensor) -> torch.Tensor:
+        if not self.use_gate:
+            return new_value + old_value
+
+        if self.gate_type == 'static':
+            layer_gate = self.gate_activation(self.gate)
+        else:
+            layer_gate = self._dynamic_gate(old_value, new_value)
+
+        return self._calculate_output(layer_gate, old_value, new_value)
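
ResidualGate is the module the memory attention classes above now call to merge the previous STM state with its attention update. A minimal usage sketch, assuming rxnn 0.2.73 is installed; the tensor shapes are illustrative:

import torch
from rxnn.memory.gate import ResidualGate

# One learnable gate value per memory slot, squashed with tanh.
gate = ResidualGate(stm_size=8, use_gate=True, gate_type='static',
                    per_slot_gate=True, init_gate=3.0, use_tanh_gate=True)

old_state = torch.randn(2, 8, 16)  # previous STM layer state [batch, stm_size, embed_dim]
new_state = torch.randn(2, 8, 16)  # candidate update from memory attention

out = gate(old_state, new_state)   # per slot: (1 + tanh(g)) * new + (1 - tanh(g)) * old
print(out.shape)                   # torch.Size([2, 8, 16])

With use_gate=False the module reduces to a plain residual (new + old); the 'linear' gate_type derives the gate value from the states themselves through a small nn.Linear instead of a fixed parameter.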
rxnn/rxt/models.py
CHANGED
@@ -9,11 +9,13 @@ from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEn
 from ..transformers.ff import get_activation_layer
 from ..memory.stm import ShortTermMemory
 from ..memory.norm import init_memory_norm
-from ..memory.attention import StmMemoryAttention
+from ..memory.attention import StmMemoryAttention, InterlayerStmMemoryAttention
+from ..memory.gate import ResidualGate, ResidualGateType
 from ..utils import get_model_size
 from ..experimental.attention import init_experimental_attention
 
 
+
 class RxTAlphaComponentConfig(TypedDict):
     num_layers: int
     vocab_size: int
@@ -260,15 +262,15 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             att_experts: int = None,
             att_query_experts: int = None,
             att_query_groups: int = None,
-            norm_type: str = 'rms',
+            norm_type: str = 'classic-rms',
             norm_init_gate: float = -2.0,
             norm_per_dim_scale: bool = False,
             norm_decay: float = 0.9,
             use_gated_residual: bool = False,
-            residual_per_slot_gate: bool = …
-            …
-            …
-            use_tanh_residual_gate: bool = …
+            residual_per_slot_gate: bool = True,
+            residual_gate_init: float = 3.0,
+            residual_gate_type: ResidualGateType = 'static',
+            use_tanh_residual_gate: bool = True,
             debug_mode: bool = False,
             debug_interval: int = 10,
             **kwargs,
@@ -296,12 +298,153 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
                                                              init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
                                             for _ in range(num_layers)])
+        memory_input_norm_layers = nn.ModuleList(nn.RMSNorm(embed_dim) for _ in range(num_layers))
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+        residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type, per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
         self.model = StmMemoryAttention(
             stm, attention_layers, memory_norm_layers,
-            …
-            …
-            …
+            memory_input_norm_layers, residual_gates,
+            debug_mode=debug_mode, debug_interval=debug_interval,
+        )
+
+    def freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def unfreeze(self):
+        for param in self.parameters():
+            param.requires_grad = True
+
+    def load_shared_memory(self, stm: ShortTermMemory):
+        self.model.stm = stm
+
+    def update_max_len(self, max_seq_len: int):
+        self.model.update_max_len(max_seq_len)
+
+    def reset_memory(self, init_type: str = None):
+        self.model.stm.reset(init_type)
+
+    def clone_reset_memory(self):
+        self.model.stm.clone_detach_reset()
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        return self.model(x, attention_mask=attention_mask)
+
+
+class RxTAlphaInterlayerMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
+    """RxT-Alpha (Reactive Transformer) memory attention model with interlayer STM attention"""
+
+    def __init__(
+            self,
+            num_layers: int = 12,
+            embed_dim: int = 512,
+            att_heads: int = 16,
+            seq_len: int = 1024,
+            stm_size: int = 1024,
+            use_flash_attention: bool = False,
+            att_dropout: float = 0.0,
+            att_groups: int = 1,
+            att_type: str = 'sqa',
+            att_experts: int = None,
+            att_query_experts: int = None,
+            att_query_groups: int = None,
+            interlayer_att_dropout: float = 0.0,
+            interlayer_att_groups: int = 1,
+            interlayer_att_type: str = 'sqa',
+            interlayer_att_experts: int = None,
+            interlayer_att_query_experts: int = None,
+            interlayer_att_query_groups: int = None,
+            norm_type: str = 'classic-rms',
+            norm_init_gate: float = -2.0,
+            norm_per_dim_scale: bool = False,
+            norm_decay: float = 0.9,
+            use_gated_residual: bool = False,
+            residual_per_slot_gate: bool = True,
+            residual_gate_init: float = 3.0,
+            residual_gate_type: ResidualGateType = 'static',
+            use_tanh_residual_gate: bool = True,
+            debug_mode: bool = False,
+            debug_interval: int = 10,
+            **kwargs,
+    ):
+        super(RxTAlphaInterlayerMemoryAttention, self).__init__(**kwargs)
+
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                            'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+
+        rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
+        stm = ShortTermMemory(num_layers, embed_dim, stm_size)
+
+        if att_type in ['mha', 'gqa', 'mqa']:
+            att_init = lambda: init_attention(
+                embed_dim, att_heads, att_type, att_groups, rope=rope,
+                use_flash_attention=use_flash_attention, dropout=att_dropout,
+                max_seq_len=seq_len, is_causal=False, rope_only_for_keys=True
+            )
+        else:
+            att_init = lambda: init_experimental_attention(
+                embed_dim, att_heads, att_type, att_groups, rope=rope,
+                use_flash_attention=use_flash_attention, dropout=att_dropout,
+                max_seq_len=seq_len, is_causal=False, num_experts=att_experts,
+                num_query_experts=att_query_experts, num_query_groups=att_query_groups,
+                rope_only_for_keys=True
+            )
+
+        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                             init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                            for _ in range(num_layers)])
+        memory_input_norm_layers = nn.ModuleList(nn.RMSNorm(embed_dim) for _ in range(num_layers))
+        attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+        residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type,
+                per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
+        # Interlayer attention
+        if interlayer_att_type in ['mha', 'gqa', 'mqa']:
+            interlayer_att_init = lambda: init_attention(
+                embed_dim, att_heads, interlayer_att_type, interlayer_att_groups, rope=None,
+                use_flash_attention=use_flash_attention, dropout=interlayer_att_dropout, is_causal=False
+            )
+        else:
+            interlayer_att_init = lambda: init_experimental_attention(
+                embed_dim, att_heads, interlayer_att_type, interlayer_att_groups, rope=None,
+                use_flash_attention=use_flash_attention, dropout=interlayer_att_dropout, is_causal=False,
+                num_experts=interlayer_att_experts, num_query_experts=interlayer_att_query_experts, num_query_groups=interlayer_att_query_groups
+            )
+
+        mean_attention_layers = nn.ModuleList([interlayer_att_init() for _ in range(num_layers)])
+
+        mean_stm_norm = init_memory_norm(
+            norm_type, embed_dim, stm_size, decay=norm_decay,
+            init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale
+        )
+
+        mean_memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                                  init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                                 for _ in range(num_layers)])
+
+        mean_residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type,
+                per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
+        self.model = InterlayerStmMemoryAttention(
+            stm, attention_layers, memory_norm_layers, memory_input_norm_layers, residual_gates,
+            mean_attention_layers, mean_memory_norm_layers, mean_residual_gates, mean_stm_norm,
+            debug_mode=debug_mode, debug_interval=debug_interval,
         )
 
     def freeze(self):
@@ -327,6 +470,7 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
 
+
 class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) encoder model"""
 
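
A hypothetical configuration sketch for the new wrapper, assuming rxnn 0.2.73 is installed: the keyword names come from the __init__ signature shown above, while the concrete values (layer count, dimensions, the 'gqa' attention variants) are arbitrary illustration rather than recommended settings:

from rxnn.rxt.models import RxTAlphaInterlayerMemoryAttention

memory_attention = RxTAlphaInterlayerMemoryAttention(
    num_layers=6,
    embed_dim=256,
    att_heads=8,
    seq_len=512,
    stm_size=256,
    att_type='gqa',
    att_groups=4,
    interlayer_att_type='gqa',
    interlayer_att_groups=4,
    use_gated_residual=True,
    residual_gate_type='static',
)
# forward(x, attention_mask=...) passes the encoded data for each layer
# (indexed as x[i] in the forward above) and returns the updated STM.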
rxnn/transformers/layers.py
CHANGED
@@ -1,5 +1,7 @@
 import torch
 import torch.nn as nn
+from poetry.console.commands import self
+
 from .attention import MultiHeadAttention
 from .ff import FeedForward, GatedFeedForward
 from .moe import MoeFeedForward, GatedMoeFeedForward
@@ -49,10 +51,12 @@ class ReactiveTransformerLayer(nn.Module):
             self.norm1 = nn.RMSNorm(embed_dim)
             self.norm2 = nn.RMSNorm(embed_dim)
             self.norm3 = nn.RMSNorm(embed_dim)
+            self.stm_norm = nn.RMSNorm(embed_dim)
         else:
             self.norm1 = nn.LayerNorm(embed_dim)
             self.norm2 = nn.LayerNorm(embed_dim)
             self.norm3 = nn.LayerNorm(embed_dim)
+            self.stm_norm = nn.LayerNorm(embed_dim)
         self.use_post_norm = use_post_norm
         self.use_moe = use_moe
         self.use_moe_att = use_moe_att
@@ -63,9 +67,11 @@ class ReactiveTransformerLayer(nn.Module):
         if with_norms:
             for param in self.norm2.parameters():
                 param.requires_grad_(is_trainable)
+            for param in self.stm_norm.parameters():
+                param.requires_grad_(is_trainable)
 
     def memory_parameters(self) -> list[nn.Parameter]:
-        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters())
+        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters()) + list(self.stm_norm.parameters())
 
     def not_memory_parameters(self) -> list[nn.Parameter]:
         return (list(self.attention.parameters()) + list(self.norm1.parameters()) +
@@ -102,11 +108,8 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm1(x)
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! !!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (self-attention) output")
         x = self.attention(x, x, x, mask=mask)
-        …
-            print("!!!!!!!!!!!!!!!!!!!!!! !!!!!!!!!!!!!!!!!!!!!! NaN detected in self-attention output")
+
         x = residual + x
         if self.use_post_norm:
             x = self.norm1(x)
@@ -114,18 +117,13 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (cross-attention) output")
 
+        # normalize STM and prepare STM mask
+        stm = self.stm_norm(stm)
         mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1)) \
             if mask is not None else None
 
-        if torch.isnan(stm).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in STM cross-attention input")
-
         x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in cross-attention output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
@@ -134,11 +132,7 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm3(x)
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (ff) output")
         x = self.ff(x)
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in ff output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm3(x)
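
Because memory_parameters() now also returns the new stm_norm weights, code that builds separate optimizer groups for memory and non-memory parameters picks up the STM normalization automatically. A hypothetical helper illustrating that split (the function name and learning rates are ours, not part of the package):

import torch
import torch.nn as nn

def build_param_groups(layer: nn.Module, base_lr: float = 1e-4, memory_lr: float = 5e-4):
    # memory_parameters() = memory cross-attention + norm2 + stm_norm (per the change above);
    # not_memory_parameters() returns the remaining layer parameters.
    return torch.optim.AdamW([
        {'params': layer.not_memory_parameters(), 'lr': base_lr},
        {'params': layer.memory_parameters(), 'lr': memory_lr},
    ])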
rxnn/utils.py
CHANGED
@@ -1,11 +1,6 @@
 import random, gc
-from typing import Optional, Union, List, Dict, Any
-
 import torch
 import numpy as np
-from huggingface_hub import PyTorchModelHubMixin
-from huggingface_hub.hub_mixin import DataclassInstance
-
 
 def human_format(num: int):
     """Format numbers to human-readable format."""
{rxnn-0.2.72.dist-info → rxnn-0.2.73.dist-info}/RECORD
CHANGED
@@ -5,11 +5,12 @@ rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhX
 rxnn/experimental/models.py,sha256=KheR1zSNJIaeVvpVAkEJwcuM5nOqQP0ZF08XhrtGJ8E,5387
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=…
+rxnn/memory/attention.py,sha256=CReYJZNA5JRED_QWqX-yKqEKZTRX6DNCAB8uFLZtKxI,7513
+rxnn/memory/gate.py,sha256=pR_H2y9C7S02QskoFAEC9Tmluut0k4GGlHgvZGiw6m4,2332
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=…
+rxnn/rxt/models.py,sha256=Pb48Frl6HV4Wb9CZgYtmzch3k_4Jess3rhs7dY1I96k,22209
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
@@ -26,14 +27,14 @@ rxnn/training/utils.py,sha256=ngDCm654NL3UsPy190Er4XPc9HI-OyEV6tDLMgEEvQc,6219
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
-rxnn/transformers/layers.py,sha256=…
+rxnn/transformers/layers.py,sha256=fxjlbQG6cwxq-b2ei4DnohSQGH5gwy4GkfP9duTUvjw,8492
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=TP0H9do53Z0vd8kpHMISBzMpHE5X9QIHcy0B-iJHuNQ,11711
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
-rxnn/utils.py,sha256=…
-rxnn-0.2.…
-rxnn-0.2.…
-rxnn-0.2.…
-rxnn-0.2.…
+rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
+rxnn-0.2.73.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.73.dist-info/METADATA,sha256=gtoRMeFgBuOZs4lRKl9JGUxZ2X4C9K78Ee-NHLMqW4E,60420
+rxnn-0.2.73.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.73.dist-info/RECORD,,
{rxnn-0.2.72.dist-info → rxnn-0.2.73.dist-info}/LICENSE
File without changes

{rxnn-0.2.72.dist-info → rxnn-0.2.73.dist-info}/WHEEL
File without changes