rxnn 0.1.83__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.1.83 → rxnn-0.2.1}/PKG-INFO +11 -9
- {rxnn-0.1.83 → rxnn-0.2.1}/README.md +10 -8
- {rxnn-0.1.83 → rxnn-0.2.1}/pyproject.toml +1 -1
- rxnn-0.2.1/src/rxnn/.DS_Store +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/experimental/attention.py +5 -0
- rxnn-0.2.1/src/rxnn/memory/attention.py +42 -0
- rxnn-0.2.1/src/rxnn/memory/stm.py +96 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/rxt/models.py +71 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/bml.py +2 -59
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/callbacks.py +302 -39
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/dataset.py +344 -1
- rxnn-0.2.1/src/rxnn/training/models.py +142 -0
- rxnn-0.2.1/src/rxnn/training/mrl.py +808 -0
- rxnn-0.2.1/src/rxnn/training/reward.py +111 -0
- rxnn-0.2.1/src/rxnn/training/rl.py +69 -0
- rxnn-0.2.1/src/rxnn/training/utils.py +148 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/attention.py +10 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/layers.py +6 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/models.py +16 -4
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/positional.py +7 -0
- rxnn-0.2.1/src/rxnn/transformers/sampler.py +443 -0
- rxnn-0.1.83/src/rxnn/memory/stm.py +0 -53
- rxnn-0.1.83/src/rxnn/transformers/sampler.py +0 -169
- {rxnn-0.1.83 → rxnn-0.2.1}/LICENSE +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/__init__.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/base.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/utils.py +0 -0
{rxnn-0.1.83 → rxnn-0.2.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.83
+Version: 0.2.1
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -23,8 +23,10 @@ Project-URL: Homepage, https://rxai.dev/rxnn
 Project-URL: Repository, https://github.com/RxAI-dev/rxnn/python
 Description-Content-Type: text/markdown

-<
-<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/
+<span>
+<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxai_v2.png" width="400" />
+<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxnn_v2.png" width="400" />
+</span>

 # Reactive AI - RxNN
 ## Reactive Neural Networks Platform
@@ -61,8 +63,8 @@ We are working on three new reactive architectures, that progressively advance f

 Each new architecture is based on the previous one and adding new features/abilities. They will be progressively
 released with next versions of **RxNN** framework:
-- 0.1.x: Reactive Transformer base models, Base Model Learning (pre-training/fine-tuning) & Transformers extensions (MoE Attention, Short-Term Memory, etc.)
-- 0.2.x: Memory Reinforcement Learning (MRL) for Short-Term Memory & Reactive Transformer, Attention-based Memory System details
+- 0.1.x (Released): Reactive Transformer base models, Base Model Learning (pre-training/fine-tuning) & Transformers extensions (MoE Attention, Short-Term Memory, etc.)
+- 0.2.x (Released): Memory Reinforcement Learning (MRL) for Short-Term Memory & Reactive Transformer, Attention-based Memory System details
 - 0.3.x: Reinforcement Learning from Human Feedback for Reactive models (RxRLHF), basic Tensor Reactive
 Extensions (TRX/Rust) for full Reactive Transformer, RxT-Alpha release (+following models - RxT-Beta, etc.)
 - 0.4.x: Preactor base models, Tensor Database (TDB/Rust) for Long-Term Memory, mxRAG/revRAG subsystems
@@ -126,7 +128,7 @@ Submodules:
 - `rxnn.transformers.moe` - Mixture-of-Experts feed forward layers - `MoeFeedForward` & `GatedMoeFeedForward` (recommended)
 - `rxnn.transformer.layers` - complete reactive/classic transformer layers - `ReactiveTransformerLayer` & `ClassicTransformerLayer`
 - `rxnn.transformer.models` - reactive/classic transformer models - `ReactiveTransformerEncoder`, `ReactiveTransformerDecoder` & `ClassicTransformerEncoder`, `ClassicTransformerDecoder`
-- `rxnn.transformer.sampler` - samplers for reactive models (Sampler is the integral part of reactive architectures) - `Sampler` & `
+- `rxnn.transformer.sampler` - samplers for reactive models (Sampler is the integral part of reactive architectures) - `Sampler`, `SampleDecoder`, `BatchSampler` & `BatchSampleDecoder`

 In **RxNN** models are initialized in declarative style by class composition, but then they are wrapped in imperative classes,
 to be compatible with HuggingFace **JSON** config. In example:
@@ -211,7 +213,7 @@ include **Long-Term Memory**.

 The main `ShortTermMemory` class is located in `rxnn.memory.stm` module - the usage example is in Transformers module description.

-
+> 0.2.x Memory modules docs in progress - will be released soon

 #### Training
 Training module includes **Trainers** for different training stages of reactive models and shared training utils.
@@ -233,9 +235,9 @@ Submodules:
 - `rxnn.training.callbacks` contain Trainer callbacks, for different kind of utils (more info below)
 - `rxnn.training.scheduler` includes learning rate scheduler for training
 - `rxnn.training.bml` - Base Model Learning module with Trainers for pre-training and fine-tuning
-- `rxnn.training.mrl` - Memory Reinforcement Learning module with Trainers for MRL
+- `rxnn.training.mrl` - Memory Reinforcement Learning module with Trainers for MRL
 - `rxnn.training.rxrlhf` - Reinforcement Learning from Human Feedback for Reactive Models module (from 0.3.x)
-- `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x
+- `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x)

 ##### Base Model Learning
 Docs in progress
{rxnn-0.1.83 → rxnn-0.2.1}/README.md

@@ -1,5 +1,7 @@
-<
-<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/
+<span>
+<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxai_v2.png" width="400" />
+<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxnn_v2.png" width="400" />
+</span>

 # Reactive AI - RxNN
 ## Reactive Neural Networks Platform
@@ -36,8 +38,8 @@ We are working on three new reactive architectures, that progressively advance f

 Each new architecture is based on the previous one and adding new features/abilities. They will be progressively
 released with next versions of **RxNN** framework:
-- 0.1.x: Reactive Transformer base models, Base Model Learning (pre-training/fine-tuning) & Transformers extensions (MoE Attention, Short-Term Memory, etc.)
-- 0.2.x: Memory Reinforcement Learning (MRL) for Short-Term Memory & Reactive Transformer, Attention-based Memory System details
+- 0.1.x (Released): Reactive Transformer base models, Base Model Learning (pre-training/fine-tuning) & Transformers extensions (MoE Attention, Short-Term Memory, etc.)
+- 0.2.x (Released): Memory Reinforcement Learning (MRL) for Short-Term Memory & Reactive Transformer, Attention-based Memory System details
 - 0.3.x: Reinforcement Learning from Human Feedback for Reactive models (RxRLHF), basic Tensor Reactive
 Extensions (TRX/Rust) for full Reactive Transformer, RxT-Alpha release (+following models - RxT-Beta, etc.)
 - 0.4.x: Preactor base models, Tensor Database (TDB/Rust) for Long-Term Memory, mxRAG/revRAG subsystems
@@ -101,7 +103,7 @@ Submodules:
 - `rxnn.transformers.moe` - Mixture-of-Experts feed forward layers - `MoeFeedForward` & `GatedMoeFeedForward` (recommended)
 - `rxnn.transformer.layers` - complete reactive/classic transformer layers - `ReactiveTransformerLayer` & `ClassicTransformerLayer`
 - `rxnn.transformer.models` - reactive/classic transformer models - `ReactiveTransformerEncoder`, `ReactiveTransformerDecoder` & `ClassicTransformerEncoder`, `ClassicTransformerDecoder`
-- `rxnn.transformer.sampler` - samplers for reactive models (Sampler is the integral part of reactive architectures) - `Sampler` & `
+- `rxnn.transformer.sampler` - samplers for reactive models (Sampler is the integral part of reactive architectures) - `Sampler`, `SampleDecoder`, `BatchSampler` & `BatchSampleDecoder`

 In **RxNN** models are initialized in declarative style by class composition, but then they are wrapped in imperative classes,
 to be compatible with HuggingFace **JSON** config. In example:
@@ -186,7 +188,7 @@ include **Long-Term Memory**.

 The main `ShortTermMemory` class is located in `rxnn.memory.stm` module - the usage example is in Transformers module description.

-
+> 0.2.x Memory modules docs in progress - will be released soon

 #### Training
 Training module includes **Trainers** for different training stages of reactive models and shared training utils.
@@ -208,9 +210,9 @@ Submodules:
 - `rxnn.training.callbacks` contain Trainer callbacks, for different kind of utils (more info below)
 - `rxnn.training.scheduler` includes learning rate scheduler for training
 - `rxnn.training.bml` - Base Model Learning module with Trainers for pre-training and fine-tuning
-- `rxnn.training.mrl` - Memory Reinforcement Learning module with Trainers for MRL
+- `rxnn.training.mrl` - Memory Reinforcement Learning module with Trainers for MRL
 - `rxnn.training.rxrlhf` - Reinforcement Learning from Human Feedback for Reactive Models module (from 0.3.x)
-- `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x
+- `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x)

 ##### Base Model Learning
 Docs in progress
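
The sampler entry above now lists four exports. As a minimal, hedged note for downstream code: the class names are taken from the README line itself and the module path from the file list (`src/rxnn/transformers/sampler.py`); constructor signatures and usage are not shown in this diff.

```python
# Names from the README entry above; signatures/usage are not part of this diff.
from rxnn.transformers.sampler import Sampler, SampleDecoder, BatchSampler, BatchSampleDecoder
```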
rxnn-0.2.1/src/rxnn/.DS_Store

Binary file (no diff shown)
{rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/experimental/attention.py

@@ -287,6 +287,7 @@ class SparseQueryAttention(MultiHeadAttention):
             k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
             v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
         else:
+            # Relative embedding version is not working without this strange mapping - it will be removed in next versions
             group_heads = self.num_heads // self.num_groups
             query_heads = self.num_heads // self.num_query_groups
             # Process Q
@@ -457,6 +458,7 @@ def init_experimental_attention(
        dropout: float = 0.0,
        rope: RotaryPositionalEmbedding = None,
        rope_only_for_query: bool = False,
+       rope_only_for_keys: bool = False,
        use_relative_embeddings: bool = False,
        max_seq_len: int = 1024,
        use_flash_attention: bool = False,
@@ -478,6 +480,7 @@ def init_experimental_attention(
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            rope_only_for_query=rope_only_for_query,
+           rope_only_for_keys=rope_only_for_keys,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
@@ -493,6 +496,7 @@ def init_experimental_attention(
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            rope_only_for_query=rope_only_for_query,
+           rope_only_for_keys=rope_only_for_keys,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
@@ -511,6 +515,7 @@ def init_experimental_attention(
            use_relative_embeddings=use_relative_embeddings,
            max_seq_len=max_seq_len,
            rope_only_for_query=rope_only_for_query,
+           rope_only_for_keys=rope_only_for_keys,
            use_flash_attention=use_flash_attention,
            is_causal=is_causal,
            use_bias=use_bias,
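
The new `rope_only_for_keys` flag mirrors the existing `rope_only_for_query` option and is forwarded to each experimental attention variant. Below is a hedged sketch of a call, modelled on the lambda added to `rxnn.rxt.models` later in this diff; the concrete argument values are illustrative and the factory's remaining defaults are assumed, not documented here.

```python
# Illustrative call only - mirrors the invocation in RxTAlphaMemoryAttention below; values are arbitrary.
from rxnn.transformers.positional import RotaryPositionalEmbedding
from rxnn.experimental.attention import init_experimental_attention

embed_dim, att_heads, seq_len = 512, 16, 1024
rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)

# Non-causal Sparse Query Attention with RoPE applied to keys only,
# as used by the memory attention model added in this release.
attention = init_experimental_attention(
    embed_dim, att_heads, 'sqa', 1, rope=rope,
    use_flash_attention=True, dropout=0.0, max_seq_len=seq_len,
    is_causal=False, num_experts=None, num_query_experts=None,
    num_query_groups=8, rope_only_for_keys=True,
)
```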
rxnn-0.2.1/src/rxnn/memory/attention.py (new file)

@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+from .stm import ShortTermMemory
+
+class StmMemoryAttention(nn.Module):
+    def __init__(
+            self,
+            stm: ShortTermMemory,
+            attention_layers: nn.ModuleList,
+            memory_norm_layers: nn.ModuleList,
+            *args,
+            **kwargs
+    ):
+        super(StmMemoryAttention, self).__init__(*args, **kwargs)
+        self.stm = stm
+        self.attention_layers = attention_layers
+        self.memory_norm_layers = memory_norm_layers
+        assert len(self.attention_layers) == len(self.memory_norm_layers) == self.stm.memory.size(0)
+        self.num_layers = len(attention_layers)
+
+    def update_max_len(self, max_seq_len: int):
+        for i in range(self.num_layers):
+            if self.attention_layers[i].rope is not None:
+                self.attention_layers[i].rope.update_max_len(max_seq_len)
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() if attention_mask is not None else None
+
+        new_stm = torch.zeros_like(self.stm.memory)
+        for i in range(self.num_layers):
+            layer_stm = self.stm(i)
+            # expand layer STM to batch size, if it's not in batch mode
+            if layer_stm.size(0) == 1:
+                layer_stm = layer_stm.expand(x.size(0), -1, -1)
+            encoded_layer_data = x[i]
+            normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
+            # self.stm.update_layer(i, new_layer_stm + layer_stm)
+            new_stm[i] = new_layer_stm + layer_stm # residual
+        self.stm.update_all(new_stm)
+        return self.stm.memory
+
rxnn-0.2.1/src/rxnn/memory/stm.py (new file)

@@ -0,0 +1,96 @@
+import torch
+import torch.nn as nn
+
+class ShortTermMemory(nn.Module):
+    """Short-term memory module for the Attention-based Memory System"""
+
+    def __init__(self, num_layers: int, embed_dim: int, stm_size: int, init_type: str = 'normal',
+                 is_trainable: bool = False, legacy_init: bool = True, *args, **kwargs):
+        super(ShortTermMemory, self).__init__(*args, **kwargs)
+        self.num_layers = num_layers
+        self.embed_dim = embed_dim
+        self.stm_size = stm_size
+        self.batch_size = 1 # setting 1 as initial batch size (it will be normally used in inference/pre-training. Bigger batches are for RL stages)
+        self.is_trainable = is_trainable
+        assert init_type in ['normal', 'standard', 'uniform', 'ones', 'zeros'], \
+            'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
+
+        # Legacy init - temporary option to load old models with not-batched STM (they will be loaded, updated and then the option will be removed)
+        self.legacy_init = legacy_init
+
+        self.init_type = init_type
+        stm = self._init_tensor()
+        if self.is_trainable:
+            self.memory = nn.Parameter(stm)
+        else:
+            self.register_buffer('memory', stm)
+
+    def _init_tensor(self, init_type: str = None):
+        init_type = init_type or self.init_type
+        stm_shape = (self.num_layers, self.stm_size, self.embed_dim) \
+            if self.legacy_init else (self.num_layers, self.batch_size, self.stm_size, self.embed_dim)
+        if init_type == 'normal':
+            return torch.normal(0, 0.02, stm_shape)
+        elif init_type == 'standard':
+            return torch.normal(0, 1, stm_shape)
+        elif init_type == 'uniform':
+            return torch.rand(*stm_shape) * 0.02
+        elif init_type == 'ones':
+            return torch.ones(*stm_shape)
+        else:
+            return torch.zeros(*stm_shape)
+
+    def reset_legacy_(self):
+        self.legacy_init = False
+        self.memory = self._init_tensor()
+
+    def forward(self, layer: int) -> torch.Tensor:
+        return self.memory[layer].unsqueeze(0) if self.legacy_init else self.memory[layer]
+
+    def update_layer(self, layer: int, new_stm: torch.Tensor):
+        self.memory[layer] = new_stm
+
+    def update_all(self, new_stm: torch.Tensor):
+        self.memory.copy_(new_stm)
+
+    def make_trainable(self):
+        if not self.is_trainable:
+            self.is_trainable = True
+            initial_stm = self.memory.clone()
+            del self.memory
+            self.memory = nn.Parameter(initial_stm)
+
+    def freeze(self):
+        if self.is_trainable:
+            self.requires_grad_(False)
+            trained_stm = self.memory.clone()
+            del self.memory
+            self.register_buffer('memory', trained_stm)
+
+    def reset(self, init_type: str = None):
+        self.memory = self._init_tensor(init_type)
+
+    def resize(self, new_stm_size: int, init_type: str = None):
+        self.stm_size = new_stm_size
+        self.memory = self._init_tensor(init_type)
+
+    def batched_memory(self, batch_size: int, init_type: str = None):
+        if init_type is not None:
+            assert init_type in ['normal', 'standard', 'uniform', 'ones', 'zeros'], \
+                'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
+            self.init_type = init_type
+        self.batch_size = batch_size
+        self.memory = self._init_tensor()
+
+    def single_memory(self, init_type: str = None, use_mean_from_batch: bool = False):
+        if init_type is not None:
+            assert init_type in ['normal', 'standard', 'uniform', 'ones', 'zeros'], \
+                'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
+            self.init_type = init_type
+        self.batch_size = 1
+        if use_mean_from_batch:
+            batch_mean = self.memory.mean(dim=(1, 2, 3), keepdim=True)
+            self.memory = self._init_tensor()
+            self.memory.copy_(batch_mean)
+        else:
+            self.memory = self._init_tensor()
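
A short, hedged usage sketch of the new batched STM API; the behaviour and shapes are inferred from the `ShortTermMemory` code above, and the sizes are arbitrary.

```python
# Behaviour inferred from the ShortTermMemory code above; sizes are arbitrary.
from rxnn.memory.stm import ShortTermMemory

stm = ShortTermMemory(num_layers=12, embed_dim=512, stm_size=1024)
stm.memory.shape   # (12, 1024, 512) - legacy, non-batched layout (default legacy_init=True)
stm(0).shape       # (1, 1024, 512)  - forward(layer) adds a batch dim of 1

stm.reset_legacy_()                          # switch to batched layout: (12, 1, 1024, 512)
stm.batched_memory(batch_size=16)            # (12, 16, 1024, 512) - e.g. for MRL batches
stm.single_memory(use_mean_from_batch=True)  # back to batch size 1, filled from the batch mean
stm.make_trainable()                         # promote the memory buffer to an nn.Parameter
```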
{rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/rxt/models.py

@@ -8,6 +8,8 @@ from ..transformers.layers import ReactiveTransformerLayer
 from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder
 from ..transformers.ff import get_activation_layer
 from ..memory.stm import ShortTermMemory
+from ..memory.norm import init_memory_norm
+from ..memory.attention import StmMemoryAttention
 from ..utils import get_model_size
 from ..experimental.attention import init_experimental_attention

@@ -135,6 +137,22 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
     def load_shared_memory(self, stm: ShortTermMemory):
         self.model.stm = stm

+    def freeze_without_memory(self):
+        for param in self.model.parameters():
+            param.requires_grad_(False)
+        self.model.trainable_cross_attention_(True)
+
+    def freeze_memory(self):
+        self.model.trainable_cross_attention_(False)
+
+    def unfreeze_all(self):
+        for param in self.model.parameters():
+            param.requires_grad_(True)
+
+    def update_max_len(self, max_seq_len: int):
+        for layer in self.model.layers:
+            layer.update_max_len(max_seq_len)
+
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> Union[
         torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         return self.model(x, attention_mask=attention_mask)
@@ -205,3 +223,56 @@ def build_rxt_alpha_for_pretraining(

     return encoder, decoder

+class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
+    """RxT-Alpha (Reactive Transformer) memory attention model"""
+    def __init__(
+            self,
+            num_layers: int = 12,
+            embed_dim: int = 512,
+            att_heads: int = 16,
+            seq_len: int = 1024,
+            stm_size: int = 1024,
+            use_flash_attention: bool = True,
+            att_dropout: float = 0.0,
+            norm_type: str = 'rms',
+            att_groups: int = 1,
+            att_type: str = 'sqa',
+            att_experts: int = None,
+            att_query_experts: int = None,
+            att_query_groups: int = None,
+            **kwargs,
+    ):
+        super(RxTAlphaMemoryAttention, self).__init__(**kwargs)
+
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+
+        rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
+        stm = ShortTermMemory(num_layers, embed_dim, stm_size)
+
+        if att_type in ['mha', 'gqa', 'mqa']:
+            att_init = lambda: init_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
+                                              use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                              max_seq_len=seq_len, is_causal=False, rope_only_for_keys=True)
+        else:
+            att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
+                                                           use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                           max_seq_len=seq_len, is_causal=False, num_experts=att_experts,
+                                                           num_query_experts=att_query_experts,
+                                                           num_query_groups=att_query_groups, rope_only_for_keys=True)
+
+        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size) for _ in range(num_layers)])
+        attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+        self.model = StmMemoryAttention(stm, attention_layers, memory_norm_layers)
+
+    def load_shared_memory(self, stm: ShortTermMemory):
+        self.model.stm = stm
+
+    def update_max_len(self, max_seq_len: int):
+        self.model.update_max_len(max_seq_len)
+
+    def reset_memory(self, init_type: str = None):
+        self.model.stm.reset_memory(init_type)
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        return self.model(x, attention_mask=attention_mask)
+
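
A hedged wiring sketch for the new component: construct the memory-attention model and share one `ShortTermMemory` instance with the RxT-Alpha encoder/decoder via `load_shared_memory`. Argument values are illustrative, and the forward call is left as a comment because the encoder-side tensors are not constructed here.

```python
# Hedged sketch - argument values are illustrative; only methods shown in this diff are used.
from rxnn.rxt.models import RxTAlphaMemoryAttention
from rxnn.memory.stm import ShortTermMemory

mem_attn = RxTAlphaMemoryAttention(
    num_layers=12, embed_dim=512, att_heads=16, seq_len=1024,
    stm_size=1024, att_type='gqa', att_groups=4,
)

# One STM shared by encoder, decoder and memory attention.
shared_stm = ShortTermMemory(12, 512, 1024)
mem_attn.load_shared_memory(shared_stm)
# encoder.load_shared_memory(shared_stm); decoder.load_shared_memory(shared_stm)  # on the RxT-Alpha components

# After encoding an interaction, the stacked per-layer encoder states update the STM:
# updated_memory = mem_attn(encoded_layer_states, attention_mask=attention_mask)
```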
{rxnn-0.1.83 → rxnn-0.2.1}/src/rxnn/training/bml.py

@@ -1,46 +1,12 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel
 import math
-from huggingface_hub import PyTorchModelHubMixin
 from typing import Union
 import torch.distributed as dist
-from ..transformers.models import
+from ..transformers.models import ReactiveTransformerDecoder
 from ..training.base import BaseTrainer
-
-class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
-    def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
-        super(MLMHead, self).__init__(*args, **kwargs)
-        self.dense = nn.Linear(embed_dim, embed_dim)
-        self.act = nn.GELU()
-        self.layer_norm = nn.LayerNorm(embed_dim)
-        self.decoder = nn.Linear(embed_dim, vocab_size)
-
-    def forward(self, hidden_states):
-        x = self.dense(hidden_states)
-        x = self.act(x)
-        x = self.layer_norm(x)
-        return self.decoder(x)
-
-
-class MLMTrainingModel(nn.Module):
-    def __init__(
-            self,
-            encoder: ReactiveTransformerEncoder,
-            mlm_head: MLMHead,
-            *args,
-            **kwargs
-    ):
-        super(MLMTrainingModel, self).__init__(*args, **kwargs)
-        self.encoder = encoder
-        self.mlm_head = mlm_head
-
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
-        h, _ = self.encoder(x, attention_mask=attention_mask)
-        y = self.mlm_head(h)
-        return y
-
+from .models import MLMTrainingModel, JointTrainingModel

 class MLMTrainer(BaseTrainer):
     def __init__(
@@ -242,29 +208,6 @@ class AutoregressiveTrainer(BaseTrainer):
         self.model.train()
         return avg_loss, metrics

-
-class JointTrainingModel(nn.Module):
-    def __init__(
-            self,
-            encoder: ReactiveTransformerEncoder,
-            decoder: ReactiveTransformerDecoder,
-            mlm_head: MLMHead,
-            *args,
-            **kwargs
-    ):
-        super(JointTrainingModel, self).__init__(*args, **kwargs)
-        self.encoder = encoder
-        self.mlm_head = mlm_head
-        self.decoder = decoder
-
-    def forward(self, x_e: torch.Tensor, x_d: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[
-        torch.Tensor, torch.Tensor]:
-        encoder_result, _ = self.encoder(x_e, attention_mask=attention_mask)
-        y_e = self.mlm_head(encoder_result)
-        y_d = self.decoder(x_d, attention_mask=attention_mask)
-        return y_e, y_d
-
-
 class JointLMTrainer(BaseTrainer):
     """"
     It's not recommended to use Joint LM Training in current implementation. More info soon