rxnn 0.1.83__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/.DS_Store +0 -0
- rxnn/experimental/attention.py +5 -0
- rxnn/memory/attention.py +42 -0
- rxnn/memory/stm.py +53 -12
- rxnn/rxt/models.py +71 -0
- rxnn/training/bml.py +2 -59
- rxnn/training/callbacks.py +302 -39
- rxnn/training/dataset.py +344 -1
- rxnn/training/models.py +142 -0
- rxnn/training/mrl.py +808 -0
- rxnn/training/reward.py +111 -0
- rxnn/training/rl.py +69 -0
- rxnn/training/utils.py +148 -0
- rxnn/transformers/attention.py +10 -0
- rxnn/transformers/layers.py +6 -0
- rxnn/transformers/models.py +16 -4
- rxnn/transformers/positional.py +7 -0
- rxnn/transformers/sampler.py +283 -9
- {rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/METADATA +11 -9
- rxnn-0.2.0.dist-info/RECORD +38 -0
- rxnn-0.1.83.dist-info/RECORD +0 -31
- {rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/LICENSE +0 -0
- {rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/WHEEL +0 -0
rxnn/.DS_Store
ADDED
Binary file
|
rxnn/experimental/attention.py
CHANGED
@@ -287,6 +287,7 @@ class SparseQueryAttention(MultiHeadAttention):
|
|
287
287
|
k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
|
288
288
|
v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
|
289
289
|
else:
|
290
|
+
# Relative embedding version is not working without this strange mapping - it will be removed in next versions
|
290
291
|
group_heads = self.num_heads // self.num_groups
|
291
292
|
query_heads = self.num_heads // self.num_query_groups
|
292
293
|
# Process Q
|
@@ -457,6 +458,7 @@ def init_experimental_attention(
|
|
457
458
|
dropout: float = 0.0,
|
458
459
|
rope: RotaryPositionalEmbedding = None,
|
459
460
|
rope_only_for_query: bool = False,
|
461
|
+
rope_only_for_keys: bool = False,
|
460
462
|
use_relative_embeddings: bool = False,
|
461
463
|
max_seq_len: int = 1024,
|
462
464
|
use_flash_attention: bool = False,
|
@@ -478,6 +480,7 @@ def init_experimental_attention(
|
|
478
480
|
use_relative_embeddings=use_relative_embeddings,
|
479
481
|
max_seq_len=max_seq_len,
|
480
482
|
rope_only_for_query=rope_only_for_query,
|
483
|
+
rope_only_for_keys=rope_only_for_keys,
|
481
484
|
use_flash_attention=use_flash_attention,
|
482
485
|
is_causal=is_causal,
|
483
486
|
use_bias=use_bias,
|
@@ -493,6 +496,7 @@ def init_experimental_attention(
|
|
493
496
|
use_relative_embeddings=use_relative_embeddings,
|
494
497
|
max_seq_len=max_seq_len,
|
495
498
|
rope_only_for_query=rope_only_for_query,
|
499
|
+
rope_only_for_keys=rope_only_for_keys,
|
496
500
|
use_flash_attention=use_flash_attention,
|
497
501
|
is_causal=is_causal,
|
498
502
|
use_bias=use_bias,
|
@@ -511,6 +515,7 @@ def init_experimental_attention(
|
|
511
515
|
use_relative_embeddings=use_relative_embeddings,
|
512
516
|
max_seq_len=max_seq_len,
|
513
517
|
rope_only_for_query=rope_only_for_query,
|
518
|
+
rope_only_for_keys=rope_only_for_keys,
|
514
519
|
use_flash_attention=use_flash_attention,
|
515
520
|
is_causal=is_causal,
|
516
521
|
use_bias=use_bias,
|
rxnn/memory/attention.py
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
from .stm import ShortTermMemory
|
4
|
+
|
5
|
+
class StmMemoryAttention(nn.Module):
    """Memory-attention network that writes encoded data into Short-Term Memory (STM).

    For every memory layer ``i`` the current STM layer is normalized, used as the attention query over
    ``x[i]`` (keys and values), and the attention output is added back to the un-normalized STM layer as
    a residual. The new state of all layers is written to ``stm`` with a single ``update_all`` call.
    """

    def __init__(
            self,
            stm: ShortTermMemory,
            attention_layers: nn.ModuleList,
            memory_norm_layers: nn.ModuleList,
            *args,
            **kwargs
    ):
        super(StmMemoryAttention, self).__init__(*args, **kwargs)
        self.stm = stm
        self.attention_layers = attention_layers
        self.memory_norm_layers = memory_norm_layers
        # exactly one attention layer and one norm layer per STM layer
        assert len(self.attention_layers) == len(self.memory_norm_layers) == self.stm.memory.size(0)
        self.num_layers = len(attention_layers)

    def update_max_len(self, max_seq_len: int):
        """Propagate a new maximum sequence length to the rotary embedding of every attention layer."""
        for attention in self.attention_layers:
            if attention.rope is not None:
                attention.rope.update_max_len(max_seq_len)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        """Update STM with encoded data and return the new memory tensor.

        Args:
            x: encoded data indexed by memory layer first - ``x[i]`` is the key/value input for STM
                layer ``i`` (each ``x[i]`` presumably (batch, seq_len, embed_dim) - confirm with caller).
            attention_mask: optional mask over the encoded sequence; a (batch, seq_len) mask is turned
                into a boolean (batch, 1, 1, seq_len) mask.

        Returns:
            ``self.stm.memory`` after the update.
        """
        mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() if attention_mask is not None else None

        new_stm = torch.zeros_like(self.stm.memory)
        for i, (attention, memory_norm) in enumerate(zip(self.attention_layers, self.memory_norm_layers)):
            encoded_layer_data = x[i]
            layer_stm = self.stm(i)
            # Expand a non-batched (size-1) STM layer to the batch size of this layer's encoded data.
            # x is indexed by layer, so x.size(0) would be the layer count, not the batch size.
            if layer_stm.size(0) == 1:
                layer_stm = layer_stm.expand(encoded_layer_data.size(0), -1, -1)
            normalized_layer_stm = memory_norm(layer_stm)
            new_layer_stm = attention(normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
            # Residual connection around memory attention.
            # NOTE(review): new_stm[i] has the STM's own batch size, so a batch > 1 result only fits when
            # the memory was batched beforehand (e.g. ShortTermMemory.batched_memory) - confirm callers.
            new_stm[i] = new_layer_stm + layer_stm
        self.stm.update_all(new_stm)
        return self.stm.memory
|
42
|
+
|
rxnn/memory/stm.py
CHANGED
@@ -5,32 +5,45 @@ class ShortTermMemory(nn.Module):
|
|
5
5
|
"""Short-term memory module for the Attention-based Memory System"""
|
6
6
|
|
7
7
|
def __init__(self, num_layers: int, embed_dim: int, stm_size: int, init_type: str = 'normal',
             is_trainable: bool = False, legacy_init: bool = True, *args, **kwargs):
    """Create the short-term memory state.

    Args:
        num_layers: number of memory layers (first axis of ``memory``).
        embed_dim: size of each memory slot (last axis of ``memory``).
        stm_size: number of memory slots per layer.
        init_type: initial values - one of "normal", "standard", "uniform", "ones", "zeros".
        is_trainable: register ``memory`` as an ``nn.Parameter`` instead of a buffer.
        legacy_init: use the old un-batched (num_layers, stm_size, embed_dim) layout instead of
            (num_layers, batch_size, stm_size, embed_dim).
    """
    super(ShortTermMemory, self).__init__(*args, **kwargs)
    self.num_layers = num_layers
    self.embed_dim = embed_dim
    self.stm_size = stm_size
    # Initial batch size of 1 (normally used in inference/pre-training; bigger batches are for RL stages)
    self.batch_size = 1
    self.is_trainable = is_trainable
    assert init_type in ['normal', 'standard', 'uniform', 'ones', 'zeros'], \
        'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
    self.init_type = init_type
    # Legacy init - temporary option to load old models with not-batched STM (to be removed later).
    # It must be set BEFORE _init_tensor(), which reads it to choose the memory shape.
    self.legacy_init = legacy_init
    stm = self._init_tensor()
    if self.is_trainable:
        self.memory = nn.Parameter(stm)
    else:
        self.register_buffer('memory', stm)
|
25
|
+
|
26
|
+
def _init_tensor(self, init_type: str = None) -> torch.Tensor:
    """Build a freshly initialised memory tensor.

    The shape is (num_layers, stm_size, embed_dim) in legacy mode, otherwise
    (num_layers, batch_size, stm_size, embed_dim). If a memory tensor already exists, the new tensor is
    created on its device and with its dtype. This means re-initialising a model that was moved with
    ``.to()`` does not silently fall back to a CPU float32 tensor.

    Args:
        init_type: overrides ``self.init_type`` for this call only.
    """
    init_type = init_type or self.init_type
    if self.legacy_init:
        stm_shape = (self.num_layers, self.stm_size, self.embed_dim)
    else:
        stm_shape = (self.num_layers, self.batch_size, self.stm_size, self.embed_dim)

    if init_type == 'normal':
        stm = torch.normal(0, 0.02, stm_shape)
    elif init_type == 'standard':
        stm = torch.normal(0, 1, stm_shape)
    elif init_type == 'uniform':
        stm = torch.rand(*stm_shape) * 0.02
    elif init_type == 'ones':
        stm = torch.ones(*stm_shape)
    else:
        stm = torch.zeros(*stm_shape)

    # While __init__ runs there is no memory yet - keep torch defaults in that case.
    current = getattr(self, 'memory', None)
    if current is not None:
        stm = stm.to(device=current.device, dtype=current.dtype)
    return stm
|
26
40
|
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
self.register_buffer('memory', stm)
|
41
|
+
def reset_legacy_(self):
    """Leave the legacy un-batched layout and re-initialise memory in the batched layout.

    Previous memory content is discarded. The new tensor keeps the device and dtype of the old one, and
    a trainable (``nn.Parameter``) memory stays a Parameter.
    """
    self.legacy_init = False
    current = self.memory
    new_memory = self._init_tensor().to(device=current.device, dtype=current.dtype)
    if isinstance(current, nn.Parameter):
        # nn.Module refuses to overwrite a registered Parameter with a plain tensor
        self.memory = nn.Parameter(new_memory, requires_grad=current.requires_grad)
    else:
        self.memory = new_memory
|
31
44
|
|
32
45
|
def forward(self, layer: int) -> torch.Tensor:
    """Return the memory state of one STM layer with a leading batch axis.

    Legacy (un-batched) memory has no batch axis, so a singleton one is added. Batched memory already
    has one and its layer slice is returned as-is.
    """
    layer_state = self.memory[layer]
    if not self.legacy_init:
        return layer_state
    return layer_state.unsqueeze(0)
|
34
47
|
|
35
48
|
def update_layer(self, layer: int, new_stm: torch.Tensor):
|
36
49
|
self.memory[layer] = new_stm
|
@@ -50,4 +63,32 @@ class ShortTermMemory(nn.Module):
|
|
50
63
|
self.requires_grad_(False)
|
51
64
|
trained_stm = self.memory.clone()
|
52
65
|
del self.memory
|
53
|
-
self.register_buffer('memory', trained_stm)
|
66
|
+
self.register_buffer('memory', trained_stm)
|
67
|
+
|
68
|
+
def _assign_memory(self, new_memory: torch.Tensor):
    """Install ``new_memory`` as the STM state, keeping the current device, dtype and registration kind.

    ``nn.Module`` refuses to overwrite a registered ``nn.Parameter`` (trainable STM) with a plain tensor,
    so trainable memory is re-wrapped as a Parameter. Buffer memory is simply replaced.
    """
    current = self.memory
    new_memory = new_memory.to(device=current.device, dtype=current.dtype)
    if isinstance(current, nn.Parameter):
        self.memory = nn.Parameter(new_memory, requires_grad=current.requires_grad)
    else:
        self.memory = new_memory

def reset(self, init_type: str = None):
    """Re-initialise memory content with the currently configured shape.

    Args:
        init_type: one-off override of ``self.init_type`` (not stored).
    """
    self._assign_memory(self._init_tensor(init_type))

def resize(self, new_stm_size: int, init_type: str = None):
    """Change the number of memory slots per layer and re-initialise memory (old content is discarded)."""
    self.stm_size = new_stm_size
    self._assign_memory(self._init_tensor(init_type))

def batched_memory(self, batch_size: int, init_type: str = None):
    """Switch to a batch of ``batch_size`` independent memories, re-initialising the content.

    ``init_type``, when given, is validated and stored as the new default.
    NOTE(review): in legacy mode the memory shape has no batch axis, so ``batch_size`` has no effect.
    """
    if init_type is not None:
        assert init_type in ['normal', 'standard', 'uniform', 'ones', 'zeros'], \
            'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
        self.init_type = init_type
    self.batch_size = batch_size
    self._assign_memory(self._init_tensor())

def single_memory(self, init_type: str = None, use_mean_from_batch: bool = False):
    """Switch back to a single (batch size 1) memory.

    Args:
        init_type: validated and stored as the new default init type, when given.
        use_mean_from_batch: instead of re-initialising, keep the average of the batched memories.
            The average is taken over the batch axis only, so every slot keeps its own averaged state.
            (Averaging over dims 1-3 collapsed each layer to a single scalar.)
    """
    if init_type is not None:
        assert init_type in ['normal', 'standard', 'uniform', 'ones', 'zeros'], \
            'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
        self.init_type = init_type
    self.batch_size = 1
    if use_mean_from_batch:
        memory = self.memory.detach()
        if self.legacy_init:
            # legacy memory has no batch axis - it already is a single memory
            batch_mean = memory.clone()
        else:
            # (num_layers, batch, stm_size, embed_dim) -> (num_layers, 1, stm_size, embed_dim)
            batch_mean = memory.mean(dim=1, keepdim=True)
        self._assign_memory(batch_mean)
    else:
        self._assign_memory(self._init_tensor())
|
rxnn/rxt/models.py
CHANGED
@@ -8,6 +8,8 @@ from ..transformers.layers import ReactiveTransformerLayer
|
|
8
8
|
from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder
|
9
9
|
from ..transformers.ff import get_activation_layer
|
10
10
|
from ..memory.stm import ShortTermMemory
|
11
|
+
from ..memory.norm import init_memory_norm
|
12
|
+
from ..memory.attention import StmMemoryAttention
|
11
13
|
from ..utils import get_model_size
|
12
14
|
from ..experimental.attention import init_experimental_attention
|
13
15
|
|
@@ -135,6 +137,22 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
|
|
135
137
|
def load_shared_memory(self, stm: ShortTermMemory):
    """Plug an externally created STM module into the wrapped model (e.g. to share one memory)."""
    setattr(self.model, 'stm', stm)
|
137
139
|
|
140
|
+
def freeze_without_memory(self):
    """Freeze every parameter of the wrapped model, then re-enable its memory cross-attention.

    ``trainable_cross_attention_(True)`` is called last, so it takes precedence over the blanket freeze
    (presumably setting requires_grad on the memory cross-attention layers - see the model implementation).
    """
    all_params = self.model.parameters()
    for weight in all_params:
        weight.requires_grad_(False)
    self.model.trainable_cross_attention_(True)
|
144
|
+
|
145
|
+
def freeze_memory(self):
    """Disable training of the wrapped model's memory cross-attention; other parameters are untouched."""
    model = self.model
    model.trainable_cross_attention_(False)
|
147
|
+
|
148
|
+
def unfreeze_all(self):
    """Make every parameter of the wrapped model trainable again."""
    all_params = self.model.parameters()
    for weight in all_params:
        weight.requires_grad_(True)
|
151
|
+
|
152
|
+
def update_max_len(self, max_seq_len: int):
    """Forward a new maximum sequence length to every layer of the wrapped model, in order."""
    model_layers = self.model.layers
    for transformer_layer in model_layers:
        transformer_layer.update_max_len(max_seq_len)
|
155
|
+
|
138
156
|
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> Union[
    torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Run the wrapped model on ``x`` and return its output unchanged.

    Per the annotation, the output is either a single tensor or a pair of tensors.
    """
    outputs = self.model(x, attention_mask=attention_mask)
    return outputs
|
@@ -205,3 +223,56 @@ def build_rxt_alpha_for_pretraining(
|
|
205
223
|
|
206
224
|
return encoder, decoder
|
207
225
|
|
226
|
+
class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
    """RxT-Alpha (Reactive Transformer) memory attention model.

    Builds a ``ShortTermMemory``, plus one memory-norm layer and one attention layer per memory layer.
    These are wrapped in ``StmMemoryAttention``, which performs the memory update in ``forward``.
    """

    def __init__(
            self,
            num_layers: int = 12,
            embed_dim: int = 512,
            att_heads: int = 16,
            seq_len: int = 1024,
            stm_size: int = 1024,
            use_flash_attention: bool = True,
            att_dropout: float = 0.0,
            norm_type: str = 'rms',
            att_groups: int = 1,
            att_type: str = 'sqa',
            att_experts: int = None,
            att_query_experts: int = None,
            att_query_groups: int = None,
            **kwargs,
    ):
        """
        Args:
            num_layers: number of memory layers (STM layers and attention/norm pairs).
            embed_dim: embedding size of the model and of each memory slot.
            att_heads: attention heads; the rotary embedding dim is ``embed_dim // att_heads``.
            seq_len: max sequence length for the rotary embedding and attention layers.
            stm_size: number of memory slots per layer.
            use_flash_attention: forwarded to every attention layer.
            att_dropout: attention dropout, forwarded to every attention layer.
            norm_type: memory-norm type passed to ``init_memory_norm``.
            att_groups: groups argument passed to the attention initialiser.
            att_type: "mha", "gqa", "mqa" (standard attention) or "gma", "dma", "sqa" (experimental).
            att_experts: experimental attention only - ``num_experts``.
            att_query_experts: experimental attention only - ``num_query_experts``.
            att_query_groups: experimental attention only - ``num_query_groups``.
        """
        super(RxTAlphaMemoryAttention, self).__init__(**kwargs)

        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'

        # A single rotary embedding instance is shared by all attention layers
        rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
        stm = ShortTermMemory(num_layers, embed_dim, stm_size)

        def att_init():
            # Non-causal attention; rope_only_for_keys=True presumably restricts RoPE to the key side
            if att_type in ['mha', 'gqa', 'mqa']:
                return init_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
                                      use_flash_attention=use_flash_attention, dropout=att_dropout,
                                      max_seq_len=seq_len, is_causal=False, rope_only_for_keys=True)
            return init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
                                               use_flash_attention=use_flash_attention, dropout=att_dropout,
                                               max_seq_len=seq_len, is_causal=False, num_experts=att_experts,
                                               num_query_experts=att_query_experts,
                                               num_query_groups=att_query_groups, rope_only_for_keys=True)

        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size) for _ in range(num_layers)])
        attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
        self.model = StmMemoryAttention(stm, attention_layers, memory_norm_layers)

    def load_shared_memory(self, stm: ShortTermMemory):
        """Replace the owned STM with an externally created one (e.g. shared with encoder/decoder)."""
        self.model.stm = stm

    def update_max_len(self, max_seq_len: int):
        """Propagate a new maximum sequence length to the attention layers' rotary embeddings."""
        self.model.update_max_len(max_seq_len)

    def reset_memory(self, init_type: str = None):
        """Re-initialise the STM; ``init_type`` optionally overrides its configured init type.

        ``ShortTermMemory`` exposes this operation as ``reset`` - it has no ``reset_memory`` method.
        """
        self.model.stm.reset(init_type)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        """Update STM with per-layer encoded data ``x`` and return the new memory tensor."""
        return self.model(x, attention_mask=attention_mask)
|
278
|
+
|
rxnn/training/bml.py
CHANGED
@@ -1,46 +1,12 @@
|
|
1
1
|
import torch
|
2
|
-
import torch.nn as nn
|
3
2
|
import torch.nn.functional as F
|
4
3
|
from torch.nn.parallel import DistributedDataParallel
|
5
4
|
import math
|
6
|
-
from huggingface_hub import PyTorchModelHubMixin
|
7
5
|
from typing import Union
|
8
6
|
import torch.distributed as dist
|
9
|
-
from ..transformers.models import
|
7
|
+
from ..transformers.models import ReactiveTransformerDecoder
|
10
8
|
from ..training.base import BaseTrainer
|
11
|
-
|
12
|
-
class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
|
13
|
-
def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
|
14
|
-
super(MLMHead, self).__init__(*args, **kwargs)
|
15
|
-
self.dense = nn.Linear(embed_dim, embed_dim)
|
16
|
-
self.act = nn.GELU()
|
17
|
-
self.layer_norm = nn.LayerNorm(embed_dim)
|
18
|
-
self.decoder = nn.Linear(embed_dim, vocab_size)
|
19
|
-
|
20
|
-
def forward(self, hidden_states):
|
21
|
-
x = self.dense(hidden_states)
|
22
|
-
x = self.act(x)
|
23
|
-
x = self.layer_norm(x)
|
24
|
-
return self.decoder(x)
|
25
|
-
|
26
|
-
|
27
|
-
class MLMTrainingModel(nn.Module):
|
28
|
-
def __init__(
|
29
|
-
self,
|
30
|
-
encoder: ReactiveTransformerEncoder,
|
31
|
-
mlm_head: MLMHead,
|
32
|
-
*args,
|
33
|
-
**kwargs
|
34
|
-
):
|
35
|
-
super(MLMTrainingModel, self).__init__(*args, **kwargs)
|
36
|
-
self.encoder = encoder
|
37
|
-
self.mlm_head = mlm_head
|
38
|
-
|
39
|
-
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
|
40
|
-
h, _ = self.encoder(x, attention_mask=attention_mask)
|
41
|
-
y = self.mlm_head(h)
|
42
|
-
return y
|
43
|
-
|
9
|
+
from .models import MLMTrainingModel, JointTrainingModel
|
44
10
|
|
45
11
|
class MLMTrainer(BaseTrainer):
|
46
12
|
def __init__(
|
@@ -242,29 +208,6 @@ class AutoregressiveTrainer(BaseTrainer):
|
|
242
208
|
self.model.train()
|
243
209
|
return avg_loss, metrics
|
244
210
|
|
245
|
-
|
246
|
-
class JointTrainingModel(nn.Module):
|
247
|
-
def __init__(
|
248
|
-
self,
|
249
|
-
encoder: ReactiveTransformerEncoder,
|
250
|
-
decoder: ReactiveTransformerDecoder,
|
251
|
-
mlm_head: MLMHead,
|
252
|
-
*args,
|
253
|
-
**kwargs
|
254
|
-
):
|
255
|
-
super(JointTrainingModel, self).__init__(*args, **kwargs)
|
256
|
-
self.encoder = encoder
|
257
|
-
self.mlm_head = mlm_head
|
258
|
-
self.decoder = decoder
|
259
|
-
|
260
|
-
def forward(self, x_e: torch.Tensor, x_d: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[
|
261
|
-
torch.Tensor, torch.Tensor]:
|
262
|
-
encoder_result, _ = self.encoder(x_e, attention_mask=attention_mask)
|
263
|
-
y_e = self.mlm_head(encoder_result)
|
264
|
-
y_d = self.decoder(x_d, attention_mask=attention_mask)
|
265
|
-
return y_e, y_d
|
266
|
-
|
267
|
-
|
268
211
|
class JointLMTrainer(BaseTrainer):
|
269
212
|
""""
|
270
213
|
It's not recommended to use Joint LM Training in current implementation. More info soon
|