rxnn 0.1.83__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/.DS_Store ADDED
Binary file
rxnn/experimental/attention.py CHANGED
@@ -287,6 +287,7 @@ class SparseQueryAttention(MultiHeadAttention):
             k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
             v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
         else:
+            # The relative embedding version does not work without this mapping - it will be removed in a future version
             group_heads = self.num_heads // self.num_groups
             query_heads = self.num_heads // self.num_query_groups
             # Process Q
@@ -457,6 +458,7 @@ def init_experimental_attention(
         dropout: float = 0.0,
         rope: RotaryPositionalEmbedding = None,
         rope_only_for_query: bool = False,
+        rope_only_for_keys: bool = False,
         use_relative_embeddings: bool = False,
         max_seq_len: int = 1024,
         use_flash_attention: bool = False,
@@ -478,6 +480,7 @@ def init_experimental_attention(
             use_relative_embeddings=use_relative_embeddings,
             max_seq_len=max_seq_len,
             rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
             use_flash_attention=use_flash_attention,
             is_causal=is_causal,
             use_bias=use_bias,
@@ -493,6 +496,7 @@ def init_experimental_attention(
             use_relative_embeddings=use_relative_embeddings,
             max_seq_len=max_seq_len,
             rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
             use_flash_attention=use_flash_attention,
             is_causal=is_causal,
             use_bias=use_bias,
@@ -511,6 +515,7 @@ def init_experimental_attention(
             use_relative_embeddings=use_relative_embeddings,
             max_seq_len=max_seq_len,
             rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
             use_flash_attention=use_flash_attention,
             is_causal=is_causal,
             use_bias=use_bias,
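
The new rope_only_for_keys flag is threaded through every experimental attention constructor, mirroring the existing rope_only_for_query option; the memory-attention model further below passes it as True, so positions are encoded only on the keys/values of the encoded sequence while the position-free STM query slots stay unrotated. A minimal gating sketch of the assumed semantics (apply_rope and rotate are illustrative names, not the library API):

import torch
from typing import Callable

def apply_rope(rotate: Callable[[torch.Tensor], torch.Tensor],
               q: torch.Tensor, k: torch.Tensor,
               rope_only_for_query: bool = False,
               rope_only_for_keys: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
    # `rotate` stands in for however RotaryPositionalEmbedding rotates a single
    # (batch, heads, seq_len, head_dim) tensor; the real API is not shown in this diff.
    if rope_only_for_query:
        return rotate(q), k
    if rope_only_for_keys:
        return q, rotate(k)
    return rotate(q), rotate(k)

# Toy check of the gating with an identity "rotation":
q, k = torch.randn(2, 8, 16, 32), torch.randn(2, 8, 16, 32)
q2, k2 = apply_rope(lambda t: t, q, k, rope_only_for_keys=True)
assert torch.equal(q2, q) and torch.equal(k2, k)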
rxnn/memory/attention.py ADDED
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+from .stm import ShortTermMemory
+
+class StmMemoryAttention(nn.Module):
+    def __init__(
+            self,
+            stm: ShortTermMemory,
+            attention_layers: nn.ModuleList,
+            memory_norm_layers: nn.ModuleList,
+            *args,
+            **kwargs
+    ):
+        super(StmMemoryAttention, self).__init__(*args, **kwargs)
+        self.stm = stm
+        self.attention_layers = attention_layers
+        self.memory_norm_layers = memory_norm_layers
+        assert len(self.attention_layers) == len(self.memory_norm_layers) == self.stm.memory.size(0)
+        self.num_layers = len(attention_layers)
+
+    def update_max_len(self, max_seq_len: int):
+        for i in range(self.num_layers):
+            if self.attention_layers[i].rope is not None:
+                self.attention_layers[i].rope.update_max_len(max_seq_len)
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() if attention_mask is not None else None
+
+        new_stm = torch.zeros_like(self.stm.memory)
+        for i in range(self.num_layers):
+            layer_stm = self.stm(i)
+            # expand layer STM to batch size, if it's not in batch mode
+            if layer_stm.size(0) == 1:
+                layer_stm = layer_stm.expand(x.size(0), -1, -1)
+            encoded_layer_data = x[i]
+            normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
+            # self.stm.update_layer(i, new_layer_stm + layer_stm)
+            new_stm[i] = new_layer_stm + layer_stm  # residual
+        self.stm.update_all(new_stm)
+        return self.stm.memory
+
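
StmMemoryAttention expects x to stack the encoded hidden states of every layer, so x[i] drives the update of STM layer i: the pre-normalized memory slots act as queries, the encoded tokens as keys/values, and the result is added back to the memory as a residual. A minimal shape walkthrough under stated assumptions: the attention layers are replaced by a hypothetical DummyAttention that only mimics the (query, key, value, mask) call signature and the rope attribute, and the STM uses the batched (non-legacy) layout from the stm.py changes below.

import torch
import torch.nn as nn
from rxnn.memory.stm import ShortTermMemory
from rxnn.memory.attention import StmMemoryAttention

class DummyAttention(nn.Module):
    # Hypothetical stand-in for the real attention layers, used only to show shapes.
    def __init__(self, embed_dim: int):
        super().__init__()
        self.rope = None  # StmMemoryAttention.update_max_len checks this attribute
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, q, k, v, mask=None):
        # real layers attend from STM slots (queries) to encoded layer data (keys/values)
        return self.proj(q)

num_layers, batch, seq_len, stm_size, embed_dim = 2, 4, 16, 8, 32
stm = ShortTermMemory(num_layers, embed_dim, stm_size, legacy_init=False)
stm.batched_memory(batch)  # memory: (num_layers, batch, stm_size, embed_dim)

memory_attention = StmMemoryAttention(
    stm,
    nn.ModuleList([DummyAttention(embed_dim) for _ in range(num_layers)]),
    nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(num_layers)]),
)

x = torch.randn(num_layers, batch, seq_len, embed_dim)  # stacked per-layer encoder states
updated = memory_attention(x)
print(updated.shape)  # torch.Size([2, 4, 8, 32]) - the updated STM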
rxnn/memory/stm.py CHANGED
@@ -5,32 +5,45 @@ class ShortTermMemory(nn.Module):
     """Short-term memory module for the Attention-based Memory System"""

     def __init__(self, num_layers: int, embed_dim: int, stm_size: int, init_type: str = 'normal',
-                 is_trainable: bool = False, *args, **kwargs):
+                 is_trainable: bool = False, legacy_init: bool = True, *args, **kwargs):
         super(ShortTermMemory, self).__init__(*args, **kwargs)
         self.num_layers = num_layers
         self.embed_dim = embed_dim
         self.stm_size = stm_size
+        self.batch_size = 1  # initial batch size is 1 (used for inference/pre-training; bigger batches are used in the RL stages)
         self.is_trainable = is_trainable
         assert init_type in ['normal', 'standard', 'uniform', 'ones', 'zeros'], \
             'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
+        self.init_type = init_type
+        # Legacy init - temporary option for loading old models with a non-batched STM (once they are loaded and updated, the option will be removed)
+        self.legacy_init = legacy_init
+        stm = self._init_tensor()
+        if self.is_trainable:
+            self.memory = nn.Parameter(stm)
+        else:
+            self.register_buffer('memory', stm)
+
+    def _init_tensor(self, init_type: str = None):
+        init_type = init_type or self.init_type
+        stm_shape = (self.num_layers, self.stm_size, self.embed_dim) \
+            if self.legacy_init else (self.num_layers, self.batch_size, self.stm_size, self.embed_dim)
         if init_type == 'normal':
-            stm = torch.normal(0, 0.02, (num_layers, stm_size, embed_dim))
+            return torch.normal(0, 0.02, stm_shape)
         elif init_type == 'standard':
-            stm = torch.normal(0, 1, (num_layers, stm_size, embed_dim))
+            return torch.normal(0, 1, stm_shape)
         elif init_type == 'uniform':
-            stm = torch.rand(num_layers, stm_size, embed_dim) * 0.02
+            return torch.rand(*stm_shape) * 0.02
         elif init_type == 'ones':
-            stm = torch.ones(num_layers, stm_size, embed_dim)
+            return torch.ones(*stm_shape)
         else:
-            stm = torch.zeros(num_layers, stm_size, embed_dim)
+            return torch.zeros(*stm_shape)

-        if self.is_trainable:
-            self.memory = nn.Parameter(stm)
-        else:
-            self.register_buffer('memory', stm)
+    def reset_legacy_(self):
+        self.legacy_init = False
+        self.memory = self._init_tensor()

     def forward(self, layer: int) -> torch.Tensor:
-        return self.memory[layer].unsqueeze(0)
+        return self.memory[layer].unsqueeze(0) if self.legacy_init else self.memory[layer]

     def update_layer(self, layer: int, new_stm: torch.Tensor):
         self.memory[layer] = new_stm
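
For reference, a short sketch of the two memory layouts produced by _init_tensor (the printed shapes are inferred from the code above):

from rxnn.memory.stm import ShortTermMemory

stm_legacy = ShortTermMemory(num_layers=6, embed_dim=256, stm_size=512)                      # legacy_init=True (default)
stm_batched = ShortTermMemory(num_layers=6, embed_dim=256, stm_size=512, legacy_init=False)

print(stm_legacy.memory.shape)   # torch.Size([6, 512, 256])    - pre-0.2.0 layout, no batch dimension
print(stm_batched.memory.shape)  # torch.Size([6, 1, 512, 256]) - new layout with a batch dimension of 1

# forward() hides the difference for single-sequence use: both return (1, stm_size, embed_dim)
print(stm_legacy(0).shape, stm_batched(0).shape)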
@@ -50,4 +63,32 @@ class ShortTermMemory(nn.Module):
         self.requires_grad_(False)
         trained_stm = self.memory.clone()
         del self.memory
-        self.register_buffer('memory', trained_stm)
+        self.register_buffer('memory', trained_stm)
+
+    def reset(self, init_type: str = None):
+        self.memory = self._init_tensor(init_type)
+
+    def resize(self, new_stm_size: int, init_type: str = None):
+        self.stm_size = new_stm_size
+        self.memory = self._init_tensor(init_type)
+
+    def batched_memory(self, batch_size: int, init_type: str = None):
+        if init_type is not None:
+            assert init_type in ['normal', 'standard', 'uniform', 'ones', 'zeros'], \
+                'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
+            self.init_type = init_type
+        self.batch_size = batch_size
+        self.memory = self._init_tensor()
+
+    def single_memory(self, init_type: str = None, use_mean_from_batch: bool = False):
+        if init_type is not None:
+            assert init_type in ['normal', 'standard', 'uniform', 'ones', 'zeros'], \
+                'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
+            self.init_type = init_type
+        self.batch_size = 1
+        if use_mean_from_batch:
+            batch_mean = self.memory.mean(dim=(1, 2, 3), keepdim=True)
+            self.memory = self._init_tensor()
+            self.memory.copy_(batch_mean)
+        else:
+            self.memory = self._init_tensor()
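
The new helpers cover the STM lifecycle across training stages: switch to a batched memory for RL, then collapse back to a single memory, optionally keeping the per-layer batch mean. A short usage sketch, with shapes inferred from _init_tensor:

from rxnn.memory.stm import ShortTermMemory

stm = ShortTermMemory(num_layers=6, embed_dim=256, stm_size=512, legacy_init=False)

stm.batched_memory(batch_size=32)            # re-initialise with a batch dimension for RL stages
print(stm.memory.shape)                      # torch.Size([6, 32, 512, 256])

stm.single_memory(use_mean_from_batch=True)  # back to batch size 1, filled with the per-layer batch mean
print(stm.memory.shape)                      # torch.Size([6, 1, 512, 256])

stm.resize(new_stm_size=1024)                # re-initialise with more memory slots
stm.reset(init_type='zeros')                 # or simply re-initialise in place

# Checkpoints saved before 0.2.0 keep the old non-batched layout; after loading such a
# model with legacy_init=True, reset_legacy_() switches the memory to the batched layout.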
rxnn/rxt/models.py CHANGED
@@ -8,6 +8,8 @@ from ..transformers.layers import ReactiveTransformerLayer
 from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder
 from ..transformers.ff import get_activation_layer
 from ..memory.stm import ShortTermMemory
+from ..memory.norm import init_memory_norm
+from ..memory.attention import StmMemoryAttention
 from ..utils import get_model_size
 from ..experimental.attention import init_experimental_attention

@@ -135,6 +137,22 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
     def load_shared_memory(self, stm: ShortTermMemory):
         self.model.stm = stm

+    def freeze_without_memory(self):
+        for param in self.model.parameters():
+            param.requires_grad_(False)
+        self.model.trainable_cross_attention_(True)
+
+    def freeze_memory(self):
+        self.model.trainable_cross_attention_(False)
+
+    def unfreeze_all(self):
+        for param in self.model.parameters():
+            param.requires_grad_(True)
+
+    def update_max_len(self, max_seq_len: int):
+        for layer in self.model.layers:
+            layer.update_max_len(max_seq_len)
+
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> Union[
         torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         return self.model(x, attention_mask=attention_mask)
@@ -205,3 +223,56 @@ def build_rxt_alpha_for_pretraining(

     return encoder, decoder

+class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
+    """RxT-Alpha (Reactive Transformer) memory attention model"""
+    def __init__(
+            self,
+            num_layers: int = 12,
+            embed_dim: int = 512,
+            att_heads: int = 16,
+            seq_len: int = 1024,
+            stm_size: int = 1024,
+            use_flash_attention: bool = True,
+            att_dropout: float = 0.0,
+            norm_type: str = 'rms',
+            att_groups: int = 1,
+            att_type: str = 'sqa',
+            att_experts: int = None,
+            att_query_experts: int = None,
+            att_query_groups: int = None,
+            **kwargs,
+    ):
+        super(RxTAlphaMemoryAttention, self).__init__(**kwargs)
+
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Memory attention type must be one of "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+
+        rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
+        stm = ShortTermMemory(num_layers, embed_dim, stm_size)
+
+        if att_type in ['mha', 'gqa', 'mqa']:
+            att_init = lambda: init_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
+                                              use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                              max_seq_len=seq_len, is_causal=False, rope_only_for_keys=True)
+        else:
+            att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
+                                                           use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                           max_seq_len=seq_len, is_causal=False, num_experts=att_experts,
+                                                           num_query_experts=att_query_experts,
+                                                           num_query_groups=att_query_groups, rope_only_for_keys=True)
+
+        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size) for _ in range(num_layers)])
+        attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+        self.model = StmMemoryAttention(stm, attention_layers, memory_norm_layers)
+
+    def load_shared_memory(self, stm: ShortTermMemory):
+        self.model.stm = stm
+
+    def update_max_len(self, max_seq_len: int):
+        self.model.update_max_len(max_seq_len)
+
+    def reset_memory(self, init_type: str = None):
+        self.model.stm.reset_memory(init_type)
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        return self.model(x, attention_mask=attention_mask)
+
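
A hedged sketch of how the new memory-attention model is meant to sit next to the RxT-Alpha encoder/decoder components, assuming the default configuration builds as intended. The encoder/decoder construction is only indicated in comments because its exact signature (e.g. build_rxt_alpha_for_pretraining) is outside this diff.

from rxnn.rxt.models import RxTAlphaMemoryAttention

# Class defaults: 12 layers, embed_dim 512, 16 heads, SQA memory attention.
mem_att = RxTAlphaMemoryAttention(num_layers=12, embed_dim=512, att_heads=16,
                                  seq_len=1024, stm_size=1024)

# Share one ShortTermMemory instance across components, following the existing
# load_shared_memory pattern (encoder/decoder built elsewhere, e.g. via
# build_rxt_alpha_for_pretraining, whose signature is not shown here):
# encoder.load_shared_memory(mem_att.model.stm)
# decoder.load_shared_memory(mem_att.model.stm)

# Memory-training stage: train only memory cross-attention in a component.
# decoder.freeze_without_memory()
# ...train memory attention + cross-attention...
# decoder.unfreeze_all()

# Moving to longer sequences: propagate the new max length to RoPE.
mem_att.update_max_len(2048)
# decoder.update_max_len(2048)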
rxnn/training/bml.py CHANGED
@@ -1,46 +1,12 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel
 import math
-from huggingface_hub import PyTorchModelHubMixin
 from typing import Union
 import torch.distributed as dist
-from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
+from ..transformers.models import ReactiveTransformerDecoder
 from ..training.base import BaseTrainer
-
-class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
-    def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
-        super(MLMHead, self).__init__(*args, **kwargs)
-        self.dense = nn.Linear(embed_dim, embed_dim)
-        self.act = nn.GELU()
-        self.layer_norm = nn.LayerNorm(embed_dim)
-        self.decoder = nn.Linear(embed_dim, vocab_size)
-
-    def forward(self, hidden_states):
-        x = self.dense(hidden_states)
-        x = self.act(x)
-        x = self.layer_norm(x)
-        return self.decoder(x)
-
-
-class MLMTrainingModel(nn.Module):
-    def __init__(
-            self,
-            encoder: ReactiveTransformerEncoder,
-            mlm_head: MLMHead,
-            *args,
-            **kwargs
-    ):
-        super(MLMTrainingModel, self).__init__(*args, **kwargs)
-        self.encoder = encoder
-        self.mlm_head = mlm_head
-
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
-        h, _ = self.encoder(x, attention_mask=attention_mask)
-        y = self.mlm_head(h)
-        return y
-
+from .models import MLMTrainingModel, JointTrainingModel

 class MLMTrainer(BaseTrainer):
     def __init__(
@@ -242,29 +208,6 @@ class AutoregressiveTrainer(BaseTrainer):
         self.model.train()
         return avg_loss, metrics

-
-class JointTrainingModel(nn.Module):
-    def __init__(
-            self,
-            encoder: ReactiveTransformerEncoder,
-            decoder: ReactiveTransformerDecoder,
-            mlm_head: MLMHead,
-            *args,
-            **kwargs
-    ):
-        super(JointTrainingModel, self).__init__(*args, **kwargs)
-        self.encoder = encoder
-        self.mlm_head = mlm_head
-        self.decoder = decoder
-
-    def forward(self, x_e: torch.Tensor, x_d: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[
-        torch.Tensor, torch.Tensor]:
-        encoder_result, _ = self.encoder(x_e, attention_mask=attention_mask)
-        y_e = self.mlm_head(encoder_result)
-        y_d = self.decoder(x_d, attention_mask=attention_mask)
-        return y_e, y_d
-
-
 class JointLMTrainer(BaseTrainer):
     """"
     It's not recommended to use Joint LM Training in current implementation. More info soon
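
For user code that imported the training wrapper models from rxnn.training.bml in 0.1.x, the import path in 0.2.0 becomes rxnn.training.models; this diff only shows the new internal import, and MLMHead is presumably relocated the same way, though its new location is not visible here.

# 0.1.x (no longer available in rxnn.training.bml):
# from rxnn.training.bml import MLMTrainingModel, JointTrainingModel

# 0.2.0:
from rxnn.training.models import MLMTrainingModel, JointTrainingModel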