rxnn 0.1.83__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {rxnn-0.1.83 → rxnn-0.2.1}/PKG-INFO +11 -9
  2. {rxnn-0.1.83 → rxnn-0.2.1}/README.md +10 -8
  3. {rxnn-0.1.83 → rxnn-0.2.1}/pyproject.toml +1 -1
  4. rxnn-0.2.1/src/rxnn/.DS_Store +0 -0
  5. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/experimental/attention.py +5 -0
  6. rxnn-0.2.1/src/rxnn/memory/attention.py +42 -0
  7. rxnn-0.2.1/src/rxnn/memory/stm.py +96 -0
  8. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/rxt/models.py +71 -0
  9. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/bml.py +2 -59
  10. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/callbacks.py +302 -39
  11. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/dataset.py +344 -1
  12. rxnn-0.2.1/src/rxnn/training/models.py +142 -0
  13. rxnn-0.2.1/src/rxnn/training/mrl.py +808 -0
  14. rxnn-0.2.1/src/rxnn/training/reward.py +111 -0
  15. rxnn-0.2.1/src/rxnn/training/rl.py +69 -0
  16. rxnn-0.2.1/src/rxnn/training/utils.py +148 -0
  17. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/attention.py +10 -0
  18. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/layers.py +6 -0
  19. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/models.py +16 -4
  20. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/positional.py +7 -0
  21. rxnn-0.2.1/src/rxnn/transformers/sampler.py +443 -0
  22. rxnn-0.1.83/src/rxnn/memory/stm.py +0 -53
  23. rxnn-0.1.83/src/rxnn/transformers/sampler.py +0 -169
  24. {rxnn-0.1.83 → rxnn-0.2.1}/LICENSE +0 -0
  25. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/__init__.py +0 -0
  26. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/experimental/__init__.py +0 -0
  27. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/experimental/models.py +0 -0
  28. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/experimental/moe.py +0 -0
  29. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/memory/__init__.py +0 -0
  30. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/memory/norm.py +0 -0
  31. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/rxt/__init__.py +0 -0
  32. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/__init__.py +0 -0
  33. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/base.py +0 -0
  34. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/scheduler.py +0 -0
  35. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/tokenizer.py +0 -0
  36. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/__init__.py +0 -0
  37. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/ff.py +0 -0
  38. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/mask.py +0 -0
  39. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/moe.py +0 -0
  40. {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/utils.py +0 -0
{rxnn-0.1.83 → rxnn-0.2.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.83
+Version: 0.2.1
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -23,8 +23,10 @@ Project-URL: Homepage, https://rxai.dev/rxnn
 Project-URL: Repository, https://github.com/RxAI-dev/rxnn/python
 Description-Content-Type: text/markdown

-<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxai.webp" width="300" />
-<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxnn.webp" width="300" />
+<span>
+<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxai_v2.png" width="400" />
+<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxnn_v2.png" width="400" />
+</span>

 # Reactive AI - RxNN
 ## Reactive Neural Networks Platform
@@ -61,8 +63,8 @@ We are working on three new reactive architectures, that progressively advance f

 Each new architecture is based on the previous one and adding new features/abilities. They will be progressively
 released with next versions of **RxNN** framework:
-- 0.1.x: Reactive Transformer base models, Base Model Learning (pre-training/fine-tuning) & Transformers extensions (MoE Attention, Short-Term Memory, etc.)
-- 0.2.x: Memory Reinforcement Learning (MRL) for Short-Term Memory & Reactive Transformer, Attention-based Memory System details
+- 0.1.x (Released): Reactive Transformer base models, Base Model Learning (pre-training/fine-tuning) & Transformers extensions (MoE Attention, Short-Term Memory, etc.)
+- 0.2.x (Released): Memory Reinforcement Learning (MRL) for Short-Term Memory & Reactive Transformer, Attention-based Memory System details
 - 0.3.x: Reinforcement Learning from Human Feedback for Reactive models (RxRLHF), basic Tensor Reactive
 Extensions (TRX/Rust) for full Reactive Transformer, RxT-Alpha release (+following models - RxT-Beta, etc.)
 - 0.4.x: Preactor base models, Tensor Database (TDB/Rust) for Long-Term Memory, mxRAG/revRAG subsystems
@@ -126,7 +128,7 @@ Submodules:
 - `rxnn.transformers.moe` - Mixture-of-Experts feed forward layers - `MoeFeedForward` & `GatedMoeFeedForward` (recommended)
 - `rxnn.transformer.layers` - complete reactive/classic transformer layers - `ReactiveTransformerLayer` & `ClassicTransformerLayer`
 - `rxnn.transformer.models` - reactive/classic transformer models - `ReactiveTransformerEncoder`, `ReactiveTransformerDecoder` & `ClassicTransformerEncoder`, `ClassicTransformerDecoder`
-- `rxnn.transformer.sampler` - samplers for reactive models (Sampler is the integral part of reactive architectures) - `Sampler` & `SampleDecoder`
+- `rxnn.transformer.sampler` - samplers for reactive models (Sampler is the integral part of reactive architectures) - `Sampler`, `SampleDecoder`, `BatchSampler` & `BatchSampleDecoder`

 In **RxNN** models are initialized in declarative style by class composition, but then they are wrapped in imperative classes,
 to be compatible with HuggingFace **JSON** config. In example:
@@ -211,7 +213,7 @@ include **Long-Term Memory**.

 The main `ShortTermMemory` class is located in `rxnn.memory.stm` module - the usage example is in Transformers module description.

-Other submodules are connected to **Memory Attention** and will be described in 0.2.x version, after MRL
+> 0.2.x Memory modules docs in progress - will be released soon

 #### Training
 Training module includes **Trainers** for different training stages of reactive models and shared training utils.
@@ -233,9 +235,9 @@ Submodules:
 - `rxnn.training.callbacks` contain Trainer callbacks, for different kind of utils (more info below)
 - `rxnn.training.scheduler` includes learning rate scheduler for training
 - `rxnn.training.bml` - Base Model Learning module with Trainers for pre-training and fine-tuning
-- `rxnn.training.mrl` - Memory Reinforcement Learning module with Trainers for MRL (from 0.2.x)
+- `rxnn.training.mrl` - Memory Reinforcement Learning module with Trainers for MRL
 - `rxnn.training.rxrlhf` - Reinforcement Learning from Human Feedback for Reactive Models module (from 0.3.x)
-- `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x
+- `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x)

 ##### Base Model Learning
 Docs in progress
{rxnn-0.1.83 → rxnn-0.2.1}/README.md

@@ -1,5 +1,7 @@
-<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxai.webp" width="300" />
-<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxnn.webp" width="300" />
+<span>
+<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxai_v2.png" width="400" />
+<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxnn_v2.png" width="400" />
+</span>

 # Reactive AI - RxNN
 ## Reactive Neural Networks Platform
@@ -36,8 +38,8 @@ We are working on three new reactive architectures, that progressively advance f

 Each new architecture is based on the previous one and adding new features/abilities. They will be progressively
 released with next versions of **RxNN** framework:
-- 0.1.x: Reactive Transformer base models, Base Model Learning (pre-training/fine-tuning) & Transformers extensions (MoE Attention, Short-Term Memory, etc.)
-- 0.2.x: Memory Reinforcement Learning (MRL) for Short-Term Memory & Reactive Transformer, Attention-based Memory System details
+- 0.1.x (Released): Reactive Transformer base models, Base Model Learning (pre-training/fine-tuning) & Transformers extensions (MoE Attention, Short-Term Memory, etc.)
+- 0.2.x (Released): Memory Reinforcement Learning (MRL) for Short-Term Memory & Reactive Transformer, Attention-based Memory System details
 - 0.3.x: Reinforcement Learning from Human Feedback for Reactive models (RxRLHF), basic Tensor Reactive
 Extensions (TRX/Rust) for full Reactive Transformer, RxT-Alpha release (+following models - RxT-Beta, etc.)
 - 0.4.x: Preactor base models, Tensor Database (TDB/Rust) for Long-Term Memory, mxRAG/revRAG subsystems
@@ -101,7 +103,7 @@ Submodules:
 - `rxnn.transformers.moe` - Mixture-of-Experts feed forward layers - `MoeFeedForward` & `GatedMoeFeedForward` (recommended)
 - `rxnn.transformer.layers` - complete reactive/classic transformer layers - `ReactiveTransformerLayer` & `ClassicTransformerLayer`
 - `rxnn.transformer.models` - reactive/classic transformer models - `ReactiveTransformerEncoder`, `ReactiveTransformerDecoder` & `ClassicTransformerEncoder`, `ClassicTransformerDecoder`
-- `rxnn.transformer.sampler` - samplers for reactive models (Sampler is the integral part of reactive architectures) - `Sampler` & `SampleDecoder`
+- `rxnn.transformer.sampler` - samplers for reactive models (Sampler is the integral part of reactive architectures) - `Sampler`, `SampleDecoder`, `BatchSampler` & `BatchSampleDecoder`

 In **RxNN** models are initialized in declarative style by class composition, but then they are wrapped in imperative classes,
 to be compatible with HuggingFace **JSON** config. In example:
@@ -186,7 +188,7 @@ include **Long-Term Memory**.

 The main `ShortTermMemory` class is located in `rxnn.memory.stm` module - the usage example is in Transformers module description.

-Other submodules are connected to **Memory Attention** and will be described in 0.2.x version, after MRL
+> 0.2.x Memory modules docs in progress - will be released soon

 #### Training
 Training module includes **Trainers** for different training stages of reactive models and shared training utils.
@@ -208,9 +210,9 @@ Submodules:
 - `rxnn.training.callbacks` contain Trainer callbacks, for different kind of utils (more info below)
 - `rxnn.training.scheduler` includes learning rate scheduler for training
 - `rxnn.training.bml` - Base Model Learning module with Trainers for pre-training and fine-tuning
-- `rxnn.training.mrl` - Memory Reinforcement Learning module with Trainers for MRL (from 0.2.x)
+- `rxnn.training.mrl` - Memory Reinforcement Learning module with Trainers for MRL
 - `rxnn.training.rxrlhf` - Reinforcement Learning from Human Feedback for Reactive Models module (from 0.3.x)
-- `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x
+- `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x)

 ##### Base Model Learning
 Docs in progress
{rxnn-0.1.83 → rxnn-0.2.1}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

 [tool.poetry]
 name = "rxnn"
-version = "0.1.83"
+version = "0.2.1"
 description = "RxNN: Reactive Neural Networks Platform"

 license = "Apache-2.0"
rxnn-0.2.1/src/rxnn/.DS_Store (binary file)
{rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/experimental/attention.py

@@ -287,6 +287,7 @@ class SparseQueryAttention(MultiHeadAttention):
             k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
             v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
         else:
+            # Relative embedding version is not working without this strange mapping - it will be removed in next versions
             group_heads = self.num_heads // self.num_groups
             query_heads = self.num_heads // self.num_query_groups
             # Process Q
@@ -457,6 +458,7 @@ def init_experimental_attention(
        dropout: float = 0.0,
        rope: RotaryPositionalEmbedding = None,
        rope_only_for_query: bool = False,
+       rope_only_for_keys: bool = False,
        use_relative_embeddings: bool = False,
        max_seq_len: int = 1024,
        use_flash_attention: bool = False,
@@ -478,6 +480,7 @@ def init_experimental_attention(
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            rope_only_for_query=rope_only_for_query,
+           rope_only_for_keys=rope_only_for_keys,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
@@ -493,6 +496,7 @@ def init_experimental_attention(
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            rope_only_for_query=rope_only_for_query,
+           rope_only_for_keys=rope_only_for_keys,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
@@ -511,6 +515,7 @@ def init_experimental_attention(
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            rope_only_for_query=rope_only_for_query,
+           rope_only_for_keys=rope_only_for_keys,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
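
The hunk above threads a new `rope_only_for_keys` flag through `init_experimental_attention` into each experimental attention variant. Below is a hedged sketch of a call using it, mirroring the arguments `RxTAlphaMemoryAttention` passes later in this diff; the import path for `RotaryPositionalEmbedding` is an assumption (the package's positional module) and is not confirmed by this hunk:

```python
from rxnn.experimental.attention import init_experimental_attention
# Assumed import path for the RoPE class referenced in the signature above.
from rxnn.transformers.positional import RotaryPositionalEmbedding

embed_dim, att_heads, att_groups, seq_len = 512, 16, 1, 1024
rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)

# Non-causal memory cross-attention with RoPE applied to keys only.
# Values mirror the defaults used by RxTAlphaMemoryAttention further below.
memory_cross_att = init_experimental_attention(
    embed_dim, att_heads, 'sqa', att_groups, rope=rope,
    use_flash_attention=True, dropout=0.0, max_seq_len=seq_len,
    is_causal=False, num_experts=None, num_query_experts=None,
    num_query_groups=None, rope_only_for_keys=True,
)
```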
rxnn-0.2.1/src/rxnn/memory/attention.py (new file)

@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+from .stm import ShortTermMemory
+
+class StmMemoryAttention(nn.Module):
+    def __init__(
+            self,
+            stm: ShortTermMemory,
+            attention_layers: nn.ModuleList,
+            memory_norm_layers: nn.ModuleList,
+            *args,
+            **kwargs
+    ):
+        super(StmMemoryAttention, self).__init__(*args, **kwargs)
+        self.stm = stm
+        self.attention_layers = attention_layers
+        self.memory_norm_layers = memory_norm_layers
+        assert len(self.attention_layers) == len(self.memory_norm_layers) == self.stm.memory.size(0)
+        self.num_layers = len(attention_layers)
+
+    def update_max_len(self, max_seq_len: int):
+        for i in range(self.num_layers):
+            if self.attention_layers[i].rope is not None:
+                self.attention_layers[i].rope.update_max_len(max_seq_len)
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() if attention_mask is not None else None
+
+        new_stm = torch.zeros_like(self.stm.memory)
+        for i in range(self.num_layers):
+            layer_stm = self.stm(i)
+            # expand layer STM to batch size, if it's not in batch mode
+            if layer_stm.size(0) == 1:
+                layer_stm = layer_stm.expand(x.size(0), -1, -1)
+            encoded_layer_data = x[i]
+            normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
+            # self.stm.update_layer(i, new_layer_stm + layer_stm)
+            new_stm[i] = new_layer_stm + layer_stm  # residual
+        self.stm.update_all(new_stm)
+        return self.stm.memory
+
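
The new `StmMemoryAttention` module treats each STM layer's slots as queries that cross-attend over that layer's encoded states, adds a residual, and writes the result back through `ShortTermMemory.update_all`. Below is a minimal sketch of driving it end to end, assuming the import paths `rxnn.memory.stm` and `rxnn.memory.attention`; `ToyAttention` and the plain `LayerNorm` are hypothetical stand-ins for the attention and memory-norm layers the package normally builds with its own factories (`init_attention`/`init_experimental_attention`, `init_memory_norm`):

```python
import torch
import torch.nn as nn
from rxnn.memory.stm import ShortTermMemory
from rxnn.memory.attention import StmMemoryAttention

class ToyAttention(nn.Module):
    """Hypothetical stand-in for the package's attention layers: (q, k, v, mask) -> tensor."""
    def __init__(self, embed_dim: int):
        super().__init__()
        self.rope = None  # StmMemoryAttention.update_max_len() checks this attribute
        self.attn = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)

    def forward(self, q, k, v, mask=None):
        out, _ = self.attn(q, k, v)  # mask handling omitted in this toy layer
        return out

num_layers, stm_size, embed_dim, seq_len, batch = 2, 8, 32, 16, 4

stm = ShortTermMemory(num_layers, embed_dim, stm_size)
stm.reset_legacy_()        # switch to the batched (num_layers, batch, stm_size, embed_dim) layout
stm.batched_memory(batch)  # one memory state per sequence in the batch

memory_attention = StmMemoryAttention(
    stm,
    nn.ModuleList([ToyAttention(embed_dim) for _ in range(num_layers)]),
    nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(num_layers)]),
)

# x stacks the encoder's per-layer hidden states: (num_layers, batch, seq_len, embed_dim)
x = torch.randn(num_layers, batch, seq_len, embed_dim)
updated = memory_attention(x)
print(updated.shape)  # torch.Size([2, 4, 8, 32]) - same layout as stm.memory
```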
rxnn-0.2.1/src/rxnn/memory/stm.py (new file)

@@ -0,0 +1,96 @@
+import torch
+import torch.nn as nn
+
+class ShortTermMemory(nn.Module):
+    """Short-term memory module for the Attention-based Memory System"""
+
+    def __init__(self, num_layers: int, embed_dim: int, stm_size: int, init_type: str = 'normal',
+                 is_trainable: bool = False, legacy_init: bool = True, *args, **kwargs):
+        super(ShortTermMemory, self).__init__(*args, **kwargs)
+        self.num_layers = num_layers
+        self.embed_dim = embed_dim
+        self.stm_size = stm_size
+        self.batch_size = 1  # setting 1 as initial batch size (it will be normally used in inference/pre-training. Bigger batches are for RL stages)
+        self.is_trainable = is_trainable
+        assert init_type in ['normal', 'standard', 'uniform', 'ones', 'zeros'], \
+            'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
+
+        # Legacy init - temporary option to load old models with not-batched STM (they will be loaded, updated and then the option will be removed)
+        self.legacy_init = legacy_init
+
+        self.init_type = init_type
+        stm = self._init_tensor()
+        if self.is_trainable:
+            self.memory = nn.Parameter(stm)
+        else:
+            self.register_buffer('memory', stm)
+
+    def _init_tensor(self, init_type: str = None):
+        init_type = init_type or self.init_type
+        stm_shape = (self.num_layers, self.stm_size, self.embed_dim) \
+            if self.legacy_init else (self.num_layers, self.batch_size, self.stm_size, self.embed_dim)
+        if init_type == 'normal':
+            return torch.normal(0, 0.02, stm_shape)
+        elif init_type == 'standard':
+            return torch.normal(0, 1, stm_shape)
+        elif init_type == 'uniform':
+            return torch.rand(*stm_shape) * 0.02
+        elif init_type == 'ones':
+            return torch.ones(*stm_shape)
+        else:
+            return torch.zeros(*stm_shape)
+
+    def reset_legacy_(self):
+        self.legacy_init = False
+        self.memory = self._init_tensor()
+
+    def forward(self, layer: int) -> torch.Tensor:
+        return self.memory[layer].unsqueeze(0) if self.legacy_init else self.memory[layer]
+
+    def update_layer(self, layer: int, new_stm: torch.Tensor):
+        self.memory[layer] = new_stm
+
+    def update_all(self, new_stm: torch.Tensor):
+        self.memory.copy_(new_stm)
+
+    def make_trainable(self):
+        if not self.is_trainable:
+            self.is_trainable = True
+            initial_stm = self.memory.clone()
+            del self.memory
+            self.memory = nn.Parameter(initial_stm)
+
+    def freeze(self):
+        if self.is_trainable:
+            self.requires_grad_(False)
+            trained_stm = self.memory.clone()
+            del self.memory
+            self.register_buffer('memory', trained_stm)
+
+    def reset(self, init_type: str = None):
+        self.memory = self._init_tensor(init_type)
+
+    def resize(self, new_stm_size: int, init_type: str = None):
+        self.stm_size = new_stm_size
+        self.memory = self._init_tensor(init_type)
+
+    def batched_memory(self, batch_size: int, init_type: str = None):
+        if init_type is not None:
+            assert init_type in ['normal', 'standard', 'uniform', 'ones', 'zeros'], \
+                'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
+            self.init_type = init_type
+        self.batch_size = batch_size
+        self.memory = self._init_tensor()
+
+    def single_memory(self, init_type: str = None, use_mean_from_batch: bool = False):
+        if init_type is not None:
+            assert init_type in ['normal', 'standard', 'uniform', 'ones', 'zeros'], \
+                'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
+            self.init_type = init_type
+        self.batch_size = 1
+        if use_mean_from_batch:
+            batch_mean = self.memory.mean(dim=(1, 2, 3), keepdim=True)
+            self.memory = self._init_tensor()
+            self.memory.copy_(batch_mean)
+        else:
+            self.memory = self._init_tensor()
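
The rewritten `ShortTermMemory` adds an explicit batch dimension plus helpers to move between single-item and batched memory (for the RL stages) and to make the initial state trainable. Below is a short usage sketch based only on the methods shown above, assuming the module is importable as `rxnn.memory.stm`:

```python
from rxnn.memory.stm import ShortTermMemory

# Fresh STM: 6 layers, 256-dim embeddings, 512 memory slots per layer.
stm = ShortTermMemory(num_layers=6, embed_dim=256, stm_size=512)
print(stm.memory.shape)  # legacy layout: (6, 512, 256)
print(stm(0).shape)      # per-layer read adds a batch dim: (1, 512, 256)

# Switch to the batched layout, then expand for an RL stage with batch size 16.
stm.reset_legacy_()      # re-initializes as (6, 1, 512, 256)
stm.batched_memory(16)   # re-initializes as (6, 16, 512, 256)

# Back to single-item memory, seeded with the mean of the batched state.
stm.single_memory(use_mean_from_batch=True)

# Optionally learn the initial memory state, then freeze it again.
stm.make_trainable()
stm.freeze()
```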
{rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/rxt/models.py

@@ -8,6 +8,8 @@ from ..transformers.layers import ReactiveTransformerLayer
 from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder
 from ..transformers.ff import get_activation_layer
 from ..memory.stm import ShortTermMemory
+from ..memory.norm import init_memory_norm
+from ..memory.attention import StmMemoryAttention
 from ..utils import get_model_size
 from ..experimental.attention import init_experimental_attention

@@ -135,6 +137,22 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
     def load_shared_memory(self, stm: ShortTermMemory):
         self.model.stm = stm

+    def freeze_without_memory(self):
+        for param in self.model.parameters():
+            param.requires_grad_(False)
+        self.model.trainable_cross_attention_(True)
+
+    def freeze_memory(self):
+        self.model.trainable_cross_attention_(False)
+
+    def unfreeze_all(self):
+        for param in self.model.parameters():
+            param.requires_grad_(True)
+
+    def update_max_len(self, max_seq_len: int):
+        for layer in self.model.layers:
+            layer.update_max_len(max_seq_len)
+
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> Union[
         torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         return self.model(x, attention_mask=attention_mask)
@@ -205,3 +223,56 @@ def build_rxt_alpha_for_pretraining(

     return encoder, decoder

+class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
+    """RxT-Alpha (Reactive Transformer) memory attention model"""
+    def __init__(
+            self,
+            num_layers: int = 12,
+            embed_dim: int = 512,
+            att_heads: int = 16,
+            seq_len: int = 1024,
+            stm_size: int = 1024,
+            use_flash_attention: bool = True,
+            att_dropout: float = 0.0,
+            norm_type: str = 'rms',
+            att_groups: int = 1,
+            att_type: str = 'sqa',
+            att_experts: int = None,
+            att_query_experts: int = None,
+            att_query_groups: int = None,
+            **kwargs,
+    ):
+        super(RxTAlphaMemoryAttention, self).__init__(**kwargs)
+
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+
+        rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
+        stm = ShortTermMemory(num_layers, embed_dim, stm_size)
+
+        if att_type in ['mha', 'gqa', 'mqa']:
+            att_init = lambda: init_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
+                                              use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                              max_seq_len=seq_len, is_causal=False, rope_only_for_keys=True)
+        else:
+            att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
+                                                           use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                           max_seq_len=seq_len, is_causal=False, num_experts=att_experts,
+                                                           num_query_experts=att_query_experts,
+                                                           num_query_groups=att_query_groups, rope_only_for_keys=True)
+
+        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size) for _ in range(num_layers)])
+        attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+        self.model = StmMemoryAttention(stm, attention_layers, memory_norm_layers)
+
+    def load_shared_memory(self, stm: ShortTermMemory):
+        self.model.stm = stm
+
+    def update_max_len(self, max_seq_len: int):
+        self.model.update_max_len(max_seq_len)
+
+    def reset_memory(self, init_type: str = None):
+        self.model.stm.reset_memory(init_type)
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        return self.model(x, attention_mask=attention_mask)
+
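
`RxTAlphaMemoryAttention` wraps `StmMemoryAttention` behind the same Hub-mixin interface as the other RxT-Alpha components, and exposes `load_shared_memory` so one `ShortTermMemory` instance can be shared between encoder, decoder and memory attention. Below is a hedged sketch of that wiring, assuming the default configuration constructs cleanly and that the class is importable from `rxnn.rxt.models`:

```python
from rxnn.rxt.models import RxTAlphaMemoryAttention
from rxnn.memory.stm import ShortTermMemory

# Build the memory-attention network with its default configuration
# (12 layers, embed_dim=512, stm_size=1024, SQA attention).
mem_att = RxTAlphaMemoryAttention()

# Share a single STM instance across components; the encoder/decoder wrappers
# expose the same load_shared_memory() hook (see RxTAlphaComponentBase above).
shared_stm = ShortTermMemory(num_layers=12, embed_dim=512, stm_size=1024)
mem_att.load_shared_memory(shared_stm)

# Rebuild RoPE caches when the working sequence length changes (e.g. for MRL curricula).
mem_att.update_max_len(2048)
```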
{rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/bml.py

@@ -1,46 +1,12 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel
 import math
-from huggingface_hub import PyTorchModelHubMixin
 from typing import Union
 import torch.distributed as dist
-from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
+from ..transformers.models import ReactiveTransformerDecoder
 from ..training.base import BaseTrainer
-
-class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
-    def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
-        super(MLMHead, self).__init__(*args, **kwargs)
-        self.dense = nn.Linear(embed_dim, embed_dim)
-        self.act = nn.GELU()
-        self.layer_norm = nn.LayerNorm(embed_dim)
-        self.decoder = nn.Linear(embed_dim, vocab_size)
-
-    def forward(self, hidden_states):
-        x = self.dense(hidden_states)
-        x = self.act(x)
-        x = self.layer_norm(x)
-        return self.decoder(x)
-
-
-class MLMTrainingModel(nn.Module):
-    def __init__(
-            self,
-            encoder: ReactiveTransformerEncoder,
-            mlm_head: MLMHead,
-            *args,
-            **kwargs
-    ):
-        super(MLMTrainingModel, self).__init__(*args, **kwargs)
-        self.encoder = encoder
-        self.mlm_head = mlm_head
-
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
-        h, _ = self.encoder(x, attention_mask=attention_mask)
-        y = self.mlm_head(h)
-        return y
-
+from .models import MLMTrainingModel, JointTrainingModel

 class MLMTrainer(BaseTrainer):
     def __init__(
@@ -242,29 +208,6 @@ class AutoregressiveTrainer(BaseTrainer):
         self.model.train()
         return avg_loss, metrics

-
-class JointTrainingModel(nn.Module):
-    def __init__(
-            self,
-            encoder: ReactiveTransformerEncoder,
-            decoder: ReactiveTransformerDecoder,
-            mlm_head: MLMHead,
-            *args,
-            **kwargs
-    ):
-        super(JointTrainingModel, self).__init__(*args, **kwargs)
-        self.encoder = encoder
-        self.mlm_head = mlm_head
-        self.decoder = decoder
-
-    def forward(self, x_e: torch.Tensor, x_d: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[
-        torch.Tensor, torch.Tensor]:
-        encoder_result, _ = self.encoder(x_e, attention_mask=attention_mask)
-        y_e = self.mlm_head(encoder_result)
-        y_d = self.decoder(x_d, attention_mask=attention_mask)
-        return y_e, y_d
-
-
 class JointLMTrainer(BaseTrainer):
     """"
     It's not recommended to use Joint LM Training in current implementation. More info soon