rxnn 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn-0.1.0.dist-info/LICENSE +201 -0
- rxnn-0.1.0.dist-info/METADATA +257 -0
- rxnn-0.1.0.dist-info/RECORD +23 -0
- rxnn-0.1.0.dist-info/WHEEL +4 -0
- src/experimental/attention.py +133 -0
- src/memory/norm.py +173 -0
- src/memory/stm.py +53 -0
- src/rxt/models.py +180 -0
- src/training/base.py +275 -0
- src/training/bml.py +345 -0
- src/training/callbacks.py +491 -0
- src/training/dataset.py +164 -0
- src/training/scheduler.py +19 -0
- src/training/tokenizer.py +208 -0
- src/transformers/attention.py +324 -0
- src/transformers/ff.py +72 -0
- src/transformers/layers.py +150 -0
- src/transformers/mask.py +10 -0
- src/transformers/models.py +168 -0
- src/transformers/moe.py +139 -0
- src/transformers/positional.py +105 -0
- src/transformers/sampler.py +109 -0
- src/utils.py +14 -0
src/transformers/layers.py
ADDED
@@ -0,0 +1,150 @@
import torch
import torch.nn as nn
from src.transformers.attention import MultiHeadAttention
from src.transformers.ff import FeedForward, GatedFeedForward
from src.transformers.moe import MoeFeedForward, GatedMoeFeedForward


class ReactiveTransformerLayer(nn.Module):
    """Reactive Transformer layer - extending the classic Transformer layer with Memory Cross-Attention"""

    def __init__(
            self,
            embed_dim: int,
            ff_dim: int,
            self_attention: MultiHeadAttention,
            memory_cross_attention: MultiHeadAttention,
            use_rms_norm: bool = False,
            use_post_norm: bool = False,
            ff_activation: nn.Module = nn.GELU(),
            ff_dropout: float = 0.1,
            use_gated: bool = False,
            use_moe: bool = False,
            num_experts: int = 1,
            moe_top_k: int = 1,
            *args,
            **kwargs,
    ):
        super(ReactiveTransformerLayer, self).__init__(*args, **kwargs)

        self.attention = self_attention
        self.memory_cross_attention = memory_cross_attention

        if use_gated:
            if use_moe:
                self.ff = GatedMoeFeedForward(embed_dim, ff_dim, num_experts, ff_activation, top_k=moe_top_k,
                                              dropout=ff_dropout)
            else:
                self.ff = GatedFeedForward(embed_dim, ff_dim, ff_activation, dropout=ff_dropout)
        else:
            if use_moe:
                self.ff = MoeFeedForward(embed_dim, ff_dim, num_experts, ff_activation, top_k=moe_top_k,
                                         dropout=ff_dropout)
            else:
                self.ff = FeedForward(embed_dim, ff_dim, ff_activation, dropout=ff_dropout)

        if use_rms_norm:
            self.norm1 = nn.RMSNorm(embed_dim)
            self.norm2 = nn.RMSNorm(embed_dim)
            self.norm3 = nn.RMSNorm(embed_dim)
        else:
            self.norm1 = nn.LayerNorm(embed_dim)
            self.norm2 = nn.LayerNorm(embed_dim)
            self.norm3 = nn.LayerNorm(embed_dim)
        self.use_post_norm = use_post_norm

    def trainable_cross_attention_(self, is_trainable: bool):
        for param in self.memory_cross_attention.parameters():
            param.requires_grad_(is_trainable)

    def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # First step, self-attention
        residual = x
        if not self.use_post_norm:
            x = self.norm1(x)
        x = self.attention(x, x, x, mask=mask)
        x = residual + x
        if self.use_post_norm:
            x = self.norm1(x)
        # Second step, Memory cross-attention
        residual = x
        if not self.use_post_norm:
            x = self.norm2(x)
        x = self.memory_cross_attention(x, stm, stm)
        x = residual + x
        if self.use_post_norm:
            x = self.norm2(x)

        # Third step, Feed Forward network
        residual = x
        if not self.use_post_norm:
            x = self.norm3(x)
        x = self.ff(x)
        x = residual + x
        if self.use_post_norm:
            x = self.norm3(x)
        return x


class ClassicTransformerLayer(nn.Module):
    """Classic Transformer layer - classic decoder-only/encoder-only Transformer layer with self-attention and Feed-Forward network."""

    def __init__(
            self,
            embed_dim: int,
            ff_dim: int,
            self_attention: MultiHeadAttention,
            use_rms_norm: bool = False,
            use_post_norm: bool = False,
            ff_activation: nn.Module = nn.GELU(),
            ff_dropout: float = 0.1,
            use_gated: bool = False,
            use_moe: bool = False,
            num_experts: int = 1,
            moe_top_k: int = 1,
            *args,
            **kwargs,
    ):
        super(ClassicTransformerLayer, self).__init__(*args, **kwargs)

        self.attention = self_attention

        if use_gated:
            if use_moe:
                self.ff = GatedMoeFeedForward(embed_dim, ff_dim, num_experts, top_k=moe_top_k, dropout=ff_dropout)
            else:
                self.ff = GatedFeedForward(embed_dim, ff_dim, dropout=ff_dropout)
        else:
            if use_moe:
                self.ff = MoeFeedForward(embed_dim, ff_dim, num_experts, ff_activation, top_k=moe_top_k,
                                         dropout=ff_dropout)
            else:
                self.ff = FeedForward(embed_dim, ff_dim, ff_activation, dropout=ff_dropout)

        if use_rms_norm:
            self.norm1 = nn.RMSNorm(embed_dim)
            self.norm2 = nn.RMSNorm(embed_dim)
        else:
            self.norm1 = nn.LayerNorm(embed_dim)
            self.norm2 = nn.LayerNorm(embed_dim)
        self.use_post_norm = use_post_norm

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # First step, self-attention
        residual = x
        if not self.use_post_norm:
            x = self.norm1(x)
        x = self.attention(x, x, x, mask=mask)
        x = residual + x
        if self.use_post_norm:
            x = self.norm1(x)
        # Second step, Feed Forward network
        residual = x
        if not self.use_post_norm:
            x = self.norm2(x)
        x = self.ff(x)
        x = residual + x
        if self.use_post_norm:
            x = self.norm2(x)
        return x
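Below is a minimal usage sketch of ReactiveTransformerLayer; it is not part of the package. The constructor of MultiHeadAttention lives in src/transformers/attention.py, which is not shown in this hunk, so the MultiHeadAttention(embed_dim, num_heads) call and the short-term memory shape are assumptions for illustration only.

    import torch
    from src.transformers.attention import MultiHeadAttention  # assumed signature: (embed_dim, num_heads)
    from src.transformers.layers import ReactiveTransformerLayer

    embed_dim, ff_dim, num_heads, stm_size = 256, 1024, 8, 64

    layer = ReactiveTransformerLayer(
        embed_dim, ff_dim,
        self_attention=MultiHeadAttention(embed_dim, num_heads),          # assumption
        memory_cross_attention=MultiHeadAttention(embed_dim, num_heads),  # assumption
        use_gated=True,
    )

    x = torch.randn(2, 32, embed_dim)          # token states: (batch, seq_len, embed_dim)
    stm = torch.randn(2, stm_size, embed_dim)  # assumed per-layer short-term memory slots
    out = layer(x, stm)                        # -> (2, 32, embed_dim)

    # Freeze memory cross-attention, e.g. during a training stage that ignores memory
    layer.trainable_cross_attention_(False)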
src/transformers/mask.py
ADDED
@@ -0,0 +1,10 @@
import torch


def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Create a causal (lower triangular) attention mask for a given sequence length."""
    # Create a lower triangular matrix of ones
    mask = torch.tril(torch.ones((seq_len, seq_len), device=device))
    # Expand the mask to have the shape (1, 1, seq_len, seq_len)
    mask = mask.unsqueeze(0).unsqueeze(0)
    return mask
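A quick sketch of what the helper returns and how the causal mask can be combined with a padding mask, mirroring the decoder logic in models.py; the row-vector padding mask below is an assumption about the expected attention_mask layout.

    import torch
    from src.transformers.mask import create_causal_mask

    causal = create_causal_mask(4, device=torch.device('cpu'))  # (1, 1, 4, 4), lower-triangular ones
    padding = torch.tensor([[1, 1, 1, 0]])                      # assumed padding mask: last position padded
    combined = causal.bool() & padding.unsqueeze(1).unsqueeze(1).bool()
    print(combined[0, 0].long())
    # tensor([[1, 0, 0, 0],
    #         [1, 1, 0, 0],
    #         [1, 1, 1, 0],
    #         [1, 1, 1, 0]])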
src/transformers/models.py
ADDED
@@ -0,0 +1,168 @@
import torch
import torch.nn as nn
from src.transformers.positional import AbsolutePositionalEmbedding
from src.transformers.mask import create_causal_mask
from src.memory.stm import ShortTermMemory


class ReactiveTransformerBase(nn.Module):
    """Base class for Reactive Transformer models - common logic for both decoders and encoders."""

    def __init__(
            self,
            stm: ShortTermMemory,
            embedding: nn.Embedding,
            own_layers: nn.ModuleList,
            shared_layers: nn.ModuleList = None,
            absolute_embedding: AbsolutePositionalEmbedding = None,
            use_flash_attention: bool = False,
            *args,
            **kwargs,
    ):
        super(ReactiveTransformerBase, self).__init__(*args, **kwargs)

        self.embedding = embedding
        self.stm = stm
        self.pos_embedding = absolute_embedding
        self.use_flash_attention = use_flash_attention

        self.shared_layers = shared_layers
        self.layers = own_layers
        self.num_shared_layers = len(shared_layers) if shared_layers else 0
        self.num_own_layers = len(own_layers) if own_layers else 0

    def trainable_cross_attention_(self, is_trainable: bool):
        for i in range(self.num_shared_layers):
            self.shared_layers[i].trainable_cross_attention_(is_trainable)
        for i in range(self.num_own_layers):
            self.layers[i].trainable_cross_attention_(is_trainable)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shared logic for encoders and decoders - apply embeddings and positional encoding
        x = self.embedding(x)
        if self.pos_embedding is not None:
            x = self.pos_embedding(x)
        return x


class ReactiveTransformerDecoder(ReactiveTransformerBase):
    """Reactive Transformer decoder - extending the classic Transformer decoder with Memory Cross-Attention"""

    def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
        super(ReactiveTransformerDecoder, self).__init__(*args, **kwargs)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        x = super().forward(x)  # apply embeddings
        seq_len = x.size(1)
        if not self.use_flash_attention:
            mask = create_causal_mask(seq_len, device=x.device)
            if attention_mask is not None:
                # Combine the causal mask with the padding mask (both as boolean masks)
                mask = mask.bool() & attention_mask.unsqueeze(1).unsqueeze(1).bool()
        elif attention_mask is not None:
            mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
        else:
            mask = None
        # Process shared layers
        if self.shared_layers is not None:
            for i in range(self.num_shared_layers):
                layer_stm = self.stm(i)
                x = self.shared_layers[i](x, layer_stm, mask=mask)
        # Process own layers
        for i in range(self.num_own_layers):
            layer_stm = self.stm(i)
            x = self.layers[i](x, layer_stm, mask=mask)
        return self.head(x)


class ReactiveTransformerEncoder(ReactiveTransformerBase):
    """Reactive Transformer encoder - extending the classic Transformer encoder with Memory Cross-Attention"""

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        x = super().forward(x)  # apply embeddings
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()

        hidden_states = []
        # Process shared layers
        if self.shared_layers is not None:
            for i in range(self.num_shared_layers):
                layer_stm = self.stm(i)
                x = self.shared_layers[i](x, layer_stm, mask=attention_mask)
                hidden_states.append(x)
        # Process own layers
        for i in range(self.num_own_layers):
            layer_stm = self.stm(i)
            x = self.layers[i](x, layer_stm, mask=attention_mask)
            hidden_states.append(x)
        return x, torch.stack(hidden_states)


class ClassicTransformerBase(nn.Module):
    """Base class for Classic Transformer models - common logic for both decoders and encoders."""

    def __init__(
            self,
            embedding: nn.Embedding,
            layers: nn.ModuleList,
            absolute_embedding: AbsolutePositionalEmbedding = None,
            use_flash_attention: bool = False,
            *args,
            **kwargs,
    ):
        super(ClassicTransformerBase, self).__init__(*args, **kwargs)

        self.embedding = embedding
        self.pos_embedding = absolute_embedding
        self.use_flash_attention = use_flash_attention

        self.layers = layers
        self.num_layers = len(layers) if layers else 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shared logic for encoders and decoders - apply embeddings and positional encoding
        x = self.embedding(x)
        if self.pos_embedding is not None:
            x = self.pos_embedding(x)
        return x


class ClassicTransformerDecoder(ClassicTransformerBase):
    """Classic Transformer decoder - for decoder-only Transformer models"""

    def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
        super(ClassicTransformerDecoder, self).__init__(*args, **kwargs)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        x = super().forward(x)  # apply embeddings
        seq_len = x.size(1)
        if not self.use_flash_attention:
            mask = create_causal_mask(seq_len, device=x.device)
            if attention_mask is not None:
                # Combine the causal mask with the padding mask (both as boolean masks)
                mask = mask.bool() & attention_mask.unsqueeze(1).unsqueeze(1).bool()
        elif attention_mask is not None:
            mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
        else:
            mask = None

        # Process layers
        for i in range(self.num_layers):
            x = self.layers[i](x, mask=mask)
        return self.head(x)


class ClassicTransformerEncoder(ClassicTransformerBase):
    """Classic Transformer encoder - for encoder-only Transformer models"""

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        x = super().forward(x)  # apply embeddings
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()

        hidden_states = []
        # Process layers (the base class stores them as self.layers / self.num_layers)
        for i in range(self.num_layers):
            x = self.layers[i](x, mask=attention_mask)
            hidden_states.append(x)
        return x, torch.stack(hidden_states)
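A hypothetical assembly of a small ClassicTransformerDecoder from the pieces in this diff; the MultiHeadAttention(embed_dim, num_heads) constructor is an assumption, since its definition is not included in this hunk, and the boolean attention mask layout follows the decoder code above.

    import torch
    import torch.nn as nn
    from src.transformers.attention import MultiHeadAttention  # assumed signature: (embed_dim, num_heads)
    from src.transformers.layers import ClassicTransformerLayer
    from src.transformers.models import ClassicTransformerDecoder

    vocab_size, embed_dim, ff_dim, num_heads, num_layers = 1000, 128, 512, 4, 2

    layers = nn.ModuleList([
        ClassicTransformerLayer(embed_dim, ff_dim, MultiHeadAttention(embed_dim, num_heads))
        for _ in range(num_layers)
    ])
    decoder = ClassicTransformerDecoder(
        embed_dim, vocab_size,
        embedding=nn.Embedding(vocab_size, embed_dim),
        layers=layers,
    )

    tokens = torch.randint(0, vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.bool)
    logits = decoder(tokens, attention_mask=attention_mask)  # -> (2, 16, vocab_size)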
src/transformers/moe.py
ADDED
@@ -0,0 +1,139 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoeRouter(nn.Module):
    """Mixture-of-Experts Router layer - computes routing weights for each expert."""

    def __init__(self, embed_dim: int, num_experts: int, top_k: int = 1, *args, **kwargs):
        super(MoeRouter, self).__init__(*args, **kwargs)
        self.top_k = top_k
        self.num_experts = num_experts
        self.gate = nn.Linear(embed_dim, num_experts, bias=False)
        self.aux_loss = 0.0  # For expert load balancing

    def forward(self, x: torch.Tensor):
        # x shape: [batch_size*seq_len, embed_dim]
        logits = self.gate(x)
        probs = F.softmax(logits, dim=-1)

        # Expert load balancing loss
        if self.training:
            probs_for_bal = F.softmax(logits, dim=0)
            self.aux_loss = (probs_for_bal.mean(dim=0) *
                             torch.log(probs_for_bal.mean(dim=0) + 1e-9)).sum()

        top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)

        return top_k_weights, top_k_indices


class MoeFeedForward(nn.Module):
    """Mixture-of-Experts Feed-Forward layer - combines multiple experts into a single model."""

    def __init__(
            self,
            embed_dim: int,
            hidden_dim: int,
            num_experts: int,
            activation: nn.Module,
            top_k: int = 1,
            dropout: float = 0.0,
            *args,
            **kwargs
    ):
        super(MoeFeedForward, self).__init__(*args, **kwargs)
        self.embed_dim = embed_dim
        self.num_experts = num_experts
        self.top_k = top_k

        self.router = MoeRouter(embed_dim, num_experts, top_k)

        # Batch all expert parameters together
        self.w1 = nn.Parameter(torch.empty(num_experts, embed_dim, self._w1_dim_factor(hidden_dim)))
        self.b1 = nn.Parameter(torch.zeros(num_experts, self._w1_dim_factor(hidden_dim)))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, embed_dim))
        self.b2 = nn.Parameter(torch.zeros(num_experts, embed_dim))
        self.activation = activation
        self.dropout = nn.Dropout(dropout)

        # Initialize parameters
        self._init_linear_parameters()
        nn.init.zeros_(self.b1)
        nn.init.zeros_(self.b2)

    def _init_linear_parameters(self):
        nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
        nn.init.kaiming_normal_(self.w2, nonlinearity='relu')

    def _w1_dim_factor(self, hidden_dim: int) -> int:
        return hidden_dim

    def _activate(self, h: torch.Tensor):
        return self.activation(h)

    def forward(self, x: torch.Tensor):
        orig_shape = x.shape
        x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]

        # Get routing weights and indices
        weights, indices = self.router(x)  # [batch*seq_len, top_k]

        # Create one-hot expert masks and combine them with the routing weights
        mask = F.one_hot(indices, self.num_experts).float()  # [batch*seq_len, top_k, num_experts]
        weights = (weights.unsqueeze(-1) * mask).sum(dim=1)  # [batch*seq_len, num_experts]

        # Expert computation
        x = x.unsqueeze(1).expand(-1, self.num_experts, -1)  # [batch*seq_len, num_experts, embed_dim]

        # First linear layer
        h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1  # [batch*seq_len, num_experts, hidden_dim]
        h = self._activate(h)
        h = self.dropout(h)

        # Second linear layer (projection back to embed_dim)
        out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2  # [batch*seq_len, num_experts, embed_dim]

        # Weighted sum of expert outputs
        out = (out * weights.unsqueeze(-1)).sum(dim=1)  # [batch*seq_len, embed_dim]

        return out.view(*orig_shape)


class GatedMoeFeedForward(MoeFeedForward):
    """Gated Mixture-of-Experts Feed-Forward layer - enable GLU-based activations for MoE"""

    def __init__(
            self,
            embed_dim: int,
            hidden_dim: int,
            num_experts: int,
            activation: nn.Module = nn.SiLU(),
            top_k: int = 1,
            dropout: float = 0.1,
            *args,
            **kwargs
    ):
        super(GatedMoeFeedForward, self).__init__(
            embed_dim=embed_dim,
            hidden_dim=hidden_dim,
            num_experts=num_experts,
            activation=activation,
            top_k=top_k,
            dropout=dropout,
            *args,
            **kwargs
        )

    def _init_linear_parameters(self):
        nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
        nn.init.kaiming_normal_(self.w2, nonlinearity='linear')

    def _w1_dim_factor(self, hidden_dim: int) -> int:
        return 2 * hidden_dim

    def _activate(self, h: torch.Tensor):
        a, b = h.chunk(2, dim=-1)
        return a * self.activation(b)
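A short usage sketch of the MoE feed-forward layers defined above; note that this implementation evaluates every expert densely and then mixes the outputs with the routing weights, so top_k only controls how many of those weights are non-zero.

    import torch
    import torch.nn as nn
    from src.transformers.moe import MoeFeedForward, GatedMoeFeedForward

    embed_dim, hidden_dim, num_experts = 128, 512, 4

    moe_ff = MoeFeedForward(embed_dim, hidden_dim, num_experts, nn.GELU(), top_k=2, dropout=0.1)
    x = torch.randn(2, 16, embed_dim)
    y = moe_ff(x)                 # -> (2, 16, embed_dim)
    aux = moe_ff.router.aux_loss  # load-balancing term, only updated in training mode

    # Gated (GLU-style) variant: the first projection is doubled and split inside _activate
    glu_ff = GatedMoeFeedForward(embed_dim, hidden_dim, num_experts, top_k=2)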
src/transformers/positional.py
ADDED
@@ -0,0 +1,105 @@
import torch
from torch import nn
import math


class RotaryPositionalEmbedding(nn.Module):
    """Rotary Positional Embedding (RoPE) layer - recommended for positional encoding"""

    def __init__(self, dim: int, max_seq_len: int = 4096, base: int = 10000, *args, **kwargs):
        super(RotaryPositionalEmbedding, self).__init__(*args, **kwargs)
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.register_buffer('cache', None, persistent=False)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        device = q.device
        seq_len = q.size(-2)
        # Prepare RoPE Frequencies
        freqs = self._prepare_freqs(seq_len, device)

        # Apply the rotation to the queries
        q_embed = self._rotate(q, freqs)
        # Apply the rotation to the keys
        k_embed = self._rotate(k, freqs)

        return q_embed, k_embed

    def forward_one(self, q: torch.Tensor) -> torch.Tensor:
        device = q.device
        seq_len = q.size(-2)
        # Prepare RoPE Frequencies
        freqs = self._prepare_freqs(seq_len, device)

        # Apply the rotation to the queries
        q_embed = self._rotate(q, freqs)

        return q_embed

    def _prepare_freqs(self, seq_len: int, device: torch.device) -> torch.Tensor:
        # Recompute and cache the frequency table when the requested sequence is longer than the cached one
        if self.cache is None or self.cache.size(0) < seq_len:
            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            self.cache = freqs
            return freqs[None, None, :, :]
        else:
            # Slice the cached table to the requested sequence length
            return self.cache[None, None, :seq_len, :]

    def _rotate(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        # Apply the rotation
        x_rotated1 = x1 * torch.cos(freqs) - x2 * torch.sin(freqs)
        x_rotated2 = x1 * torch.sin(freqs) + x2 * torch.cos(freqs)
        # Concatenate the rotated parts back together
        x_rotated = torch.cat((x_rotated1, x_rotated2), dim=-1)
        return x_rotated


class AbsolutePositionalEmbedding(nn.Module):
    """Absolute Positional Embedding layer (legacy) - not recommended for memory-augmented Reactive Transformers"""

    def __init__(self, max_seq_len: int, embed_dim: int, *args, **kwargs):
        super(AbsolutePositionalEmbedding, self).__init__(*args, **kwargs)
        self.max_seq_len = max_seq_len
        self.embed_dim = embed_dim
        self.position_embeddings = nn.Embedding(max_seq_len, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Create position indices
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(x.size(0), -1)
        # Get position embeddings
        pos_embeddings = self.position_embeddings(positions)
        # Add position embeddings to the input embeddings
        return x + pos_embeddings


class RelativePositionalEmbedding(nn.Module):
    """Relative Positional Embedding layer (legacy) - not compatible with Flash Attention and not recommended for positional encoding"""

    def __init__(self, max_seq_len: int, embed_dim: int, *args, **kwargs):
        super(RelativePositionalEmbedding, self).__init__(*args, **kwargs)
        self.max_seq_len = max_seq_len
        self.embed_dim = embed_dim
        self.position_embeddings = nn.Embedding(2 * max_seq_len - 1, embed_dim)
        self.embed_dim_sqrt = math.sqrt(embed_dim)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        q_len = q.size(2)
        k_len = k.size(2)

        # Create relative position indices
        indices = torch.arange(q_len, device=q.device)[:, None] - torch.arange(k_len, device=k.device)[None, :]
        indices += self.max_seq_len - 1  # Shift to non-negative
        indices = torch.clamp(indices, 0, 2 * self.max_seq_len - 2)

        # Get embeddings
        rel_emb = self.position_embeddings(indices)

        rel_emb = rel_emb.permute(2, 0, 1)
        rel_pos_bias = torch.einsum('bhqd, dqk -> bhqk', q, rel_emb)
        return rel_pos_bias / self.embed_dim_sqrt
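A usage sketch of the RoPE module; the (batch, num_heads, seq_len, head_dim) layout for q and k is an assumption based on the sequence length being read from dimension -2.

    import torch
    from src.transformers.positional import RotaryPositionalEmbedding

    batch, num_heads, seq_len, head_dim = 2, 8, 16, 32
    rope = RotaryPositionalEmbedding(head_dim, max_seq_len=4096)

    q = torch.randn(batch, num_heads, seq_len, head_dim)
    k = torch.randn(batch, num_heads, seq_len, head_dim)
    q_rot, k_rot = rope(q, k)     # same shapes, with position-dependent rotations applied

    # Rotate only one stream, e.g. queries attending to a fixed memory
    q_only = rope.forward_one(q)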
src/transformers/sampler.py
ADDED
@@ -0,0 +1,109 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Iterator


def sample(
        logits: torch.Tensor,
        temperature: float = 1.0,
        top_k: int = None,
        top_p: float = None,
) -> torch.Tensor:
    if temperature <= 0:
        raise ValueError("Temperature must be > 0")

    # Apply temperature
    logits = logits / temperature

    # Apply top-k filtering
    if top_k is not None and top_k > 0:
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = float('-inf')

    # Apply top-p (nucleus) sampling
    if top_p is not None and 0 < top_p <= 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right to keep first token above threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter sorted indices back to original positions
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=-1,
            index=sorted_indices,
            src=sorted_indices_to_remove
        )
        logits[indices_to_remove] = float('-inf')

    # Convert to probabilities
    probs = F.softmax(logits, dim=-1)

    # Sample from distribution
    return torch.multinomial(probs, num_samples=1)


class Sampler:
    def __init__(self, model: nn.Module, device: torch.device, end_token_id: int):
        self.model = model.to(device)
        self.device = device
        self.end_token_id = end_token_id

    def _generate_token(
            self,
            input_ids: torch.Tensor,
            temperature: float,
            top_k: int,
            top_p: float,
            attention_mask: torch.Tensor,
    ) -> tuple[int, torch.Tensor, torch.Tensor]:
        # Forward pass to get next token logits
        outputs = self.model(input_ids, attention_mask=attention_mask)
        next_token_logits = outputs[:, -1, :]  # Get logits for next token
        # Apply sampling
        next_token = sample(
            next_token_logits,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        next_token = next_token.item()  # Extract scalar token
        next_token_ten = torch.tensor([[next_token]], device=self.device)
        next_input_ids = torch.cat([input_ids, next_token_ten], dim=1)
        new_one = torch.ones(1, 1, dtype=torch.bool, device=self.device)
        next_mask = torch.cat([attention_mask, new_one], dim=1) if attention_mask is not None else None
        # Return the generated token together with the extended inputs and mask
        return (
            next_token,
            next_input_ids,
            next_mask
        )

    def __call__(
            self,
            initial_tokens: torch.Tensor,
            temperature: float = 1.0,
            top_k: int = None,
            top_p: float = None,
            max_seq_len: int = 50,
            attention_mask: torch.Tensor = None,
            no_grad: bool = True,
    ) -> Iterator[int]:
        # Initial tokens are expected to be a tensor already placed on the target device
        input_ids = initial_tokens

        if no_grad:
            with torch.no_grad():
                for _ in range(max_seq_len):
                    next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p, attention_mask)
                    yield next_token
                    if next_token == self.end_token_id:
                        break
        else:
            for _ in range(max_seq_len):
                next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p, attention_mask)
                yield next_token
                if next_token == self.end_token_id:
                    break
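The standalone sample() function can be used on its own; below is a sketch with random logits (the (batch, vocab_size) shape is an assumption). The commented lines show how the Sampler wrapper would drive it autoregressively, following the __call__ signature above, with model and eos_id as placeholders.

    import torch
    from src.transformers.sampler import sample

    logits = torch.randn(1, 1000)  # assumed shape: (batch, vocab_size)
    token_id = sample(logits, temperature=0.8, top_k=50, top_p=0.9)  # -> LongTensor of shape (1, 1)
    print(token_id.item())

    # Autoregressive generation via the Sampler wrapper:
    # sampler = Sampler(model, device=torch.device('cuda'), end_token_id=eos_id)
    # for token in sampler(initial_tokens, temperature=0.8, top_k=50, max_seq_len=100, attention_mask=mask):
    #     print(token)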
src/utils.py
ADDED
@@ -0,0 +1,14 @@
import torch


def human_format(num: int):
    num = float('{:.3g}'.format(num))
    magnitude = 0
    while abs(num) >= 1000:
        magnitude += 1
        num /= 1000.0
    return '{}{}'.format('{:f}'.format(num).rstrip('0').rstrip('.'), ['', 'K', 'M', 'B', 'T'][magnitude])


def get_model_size(model: torch.nn.Module):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return f'Model params {human_format(trainable_params)}'