Stackformer 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 GURUMURTHY
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,75 @@
1
+ Metadata-Version: 2.4
2
+ Name: Stackformer
3
+ Version: 0.1.0
4
+ Summary: Modular transformer blocks built in PyTorch
5
+ Home-page: https://github.com/Gurumurthy30/Stackformer
6
+ Author: Gurumurthy
7
+ Author-email: Gurumurthy <gurumurthy.00300@gmail.com>
8
+ License: MIT
9
+ Project-URL: Repository, https://github.com/Gurumurthy30/Stackformer
10
+ Project-URL: Issue Tracker, https://github.com/Gurumurthy30/Stackformer/issues
11
+ Project-URL: Discussions, https://github.com/Gurumurthy30/Stackformer/discussions
12
+ Requires-Python: >=3.9
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE
15
+ Requires-Dist: torch>=2.6
16
+ Requires-Dist: tqdm>=4.67
17
+ Dynamic: author
18
+ Dynamic: home-page
19
+ Dynamic: license-file
20
+ Dynamic: requires-python
21
+
22
+ ## 🧱 Stackformer
23
+
24
+ **Stackformer** is a modular transformer-building framework written entirely in PyTorch. It is designed primarily for experimentation, providing various transformer blocks such as attention mechanisms, normalization layers, feed-forward networks, and a simple model architecture. The project is a work-in-progress with plans for further enhancements and expansions.
25
+
26
+ ---
27
+
28
+ ## 📖 About Me
29
+
30
+ My name is **Gurumurthy**, and I am a final-year Bachelor of Engineering student from India. I created this library as my own side project to showcase my skills and knowledge in deep learning and transformer architectures.
31
+
32
+ I am also interested and free to work with others on different projects for knowledge sharing and building connections.
33
+
34
+ ---
35
+
36
+ ## 🌟 Features
37
+
38
+ - Multiple attention mechanisms including multi-head, group query, linear, local, and KV cache variants
39
+ - Token embedding via `tiktoken`
40
+ - Absolute and sinusoidal positional embeddings
41
+ - Normalization layers like LayerNorm and RMSNorm
42
+ - Several feed-forward network variants with activations such as ReLU, GELU, SiLU, LeakyReLU, and Sigmoid
43
+ - A simple GPT-style transformer model implementation
44
+
45
+ ---
46
+
47
+ ## 📁 Project Structure
48
+
49
+ stackformer/ \
50
+ |-- modules/ \
51
+ | |-- tokenizer.py # Token embedding using tiktoken \
52
+ | |-- position_embedding.py # Absolute and sinusoidal embeddings \
53
+ | |-- Attention.py # Attention mechanisms \
54
+ | |-- Normalization.py # LayerNorm and RMSNorm \
55
+ | |-- Feed_forward.py # Feed-forward layers with various activations \
56
+ |-- models/ \
57
+ | -- GPT_2.py # GPT-style transformer stack model \
58
+ -- trainer.py # Training loop and utilities \
59
+
60
+ ---
61
+
62
+ ## 💻 Installation
63
+
64
+ Clone the repository and install in development mode:
65
+
66
+ ```bash
67
+ git clone https://github.com/Gurumurthy30/Stackformer
68
+ cd Stackformer
69
+ pip install -e .
70
+ ```
71
+
72
+ ---
73
+
74
+ ## 🚀 Future Plans
75
+ Currently, I am working on improving and optimizing the existing components while fixing known bugs and issues. After stabilizing the current modules, I plan to add more advanced blocks like Mixture of Experts (MoE), mask handling, and other essential transformer components. Eventually, I will expand the library by developing more comprehensive model architectures.
@@ -0,0 +1,54 @@
1
+ ## 🧱 Stackformer
2
+
3
+ **Stackformer** is a modular transformer-building framework written entirely in PyTorch. It is designed primarily for experimentation, providing various transformer blocks such as attention mechanisms, normalization layers, feed-forward networks, and a simple model architecture. The project is a work-in-progress with plans for further enhancements and expansions.
4
+
5
+ ---
6
+
7
+ ## 📖 About Me
8
+
9
+ My name is **Gurumurthy**, and I am a final-year Bachelor of Engineering student from India. I created this library as my own side project to showcase my skills and knowledge in deep learning and transformer architectures.
10
+
11
+ I am also interested and free to work with others on different projects for knowledge sharing and building connections.
12
+
13
+ ---
14
+
15
+ ## 🌟 Features
16
+
17
+ - Multiple attention mechanisms including multi-head, group query, linear, local, and KV cache variants
18
+ - Token embedding via `tiktoken`
19
+ - Absolute and sinusoidal positional embeddings
20
+ - Normalization layers like LayerNorm and RMSNorm
21
+ - Several feed-forward network variants with activations such as ReLU, GELU, SiLU, LeakyReLU, and Sigmoid
22
+ - A simple GPT-style transformer model implementation
23
+
24
+ ---
25
+
26
+ ## 📁 Project Structure
27
+
28
+ stackformer/ \
29
+ |-- modules/ \
30
+ | |-- tokenizer.py # Token embedding using tiktoken \
31
+ | |-- position_embedding.py # Absolute and sinusoidal embeddings \
32
+ | |-- Attention.py # Attention mechanisms \
33
+ | |-- Normalization.py # LayerNorm and RMSNorm \
34
+ | |-- Feed_forward.py # Feed-forward layers with various activations \
35
+ |-- models/ \
36
+ | -- GPT_2.py # GPT-style transformer stack model \
37
+ -- trainer.py # Training loop and utilities \
38
+
39
+ ---
40
+
41
+ ## 💻 Installation
42
+
43
+ Clone the repository and install in development mode:
44
+
45
+ ```bash
46
+ git clone https://github.com/Gurumurthy30/Stackformer
47
+ cd Stackformer
48
+ pip install -e .
49
+ ```
50
+
51
+ ---
52
+
53
+ ## 🚀 Future Plans
54
+ Currently, I am working on improving and optimizing the existing components while fixing known bugs and issues. After stabilizing the current modules, I plan to add more advanced blocks like Mixture of Experts (MoE), mask handling, and other essential transformer components. Eventually, I will expand the library by developing more comprehensive model architectures.
@@ -0,0 +1,75 @@
1
+ Metadata-Version: 2.4
2
+ Name: Stackformer
3
+ Version: 0.1.0
4
+ Summary: Modular transformer blocks built in PyTorch
5
+ Home-page: https://github.com/Gurumurthy30/Stackformer
6
+ Author: Gurumurthy
7
+ Author-email: Gurumurthy <gurumurthy.00300@gmail.com>
8
+ License: MIT
9
+ Project-URL: Repository, https://github.com/Gurumurthy30/Stackformer
10
+ Project-URL: Issue Tracker, https://github.com/Gurumurthy30/Stackformer/issues
11
+ Project-URL: Discussions, https://github.com/Gurumurthy30/Stackformer/discussions
12
+ Requires-Python: >=3.9
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE
15
+ Requires-Dist: torch>=2.6
16
+ Requires-Dist: tqdm>=4.67
17
+ Dynamic: author
18
+ Dynamic: home-page
19
+ Dynamic: license-file
20
+ Dynamic: requires-python
21
+
22
+ ## 🧱 Stackformer
23
+
24
+ **Stackformer** is a modular transformer-building framework written entirely in PyTorch. It is designed primarily for experimentation, providing various transformer blocks such as attention mechanisms, normalization layers, feed-forward networks, and a simple model architecture. The project is a work-in-progress with plans for further enhancements and expansions.
25
+
26
+ ---
27
+
28
+ ## 📖 About Me
29
+
30
+ My name is **Gurumurthy**, and I am a final-year Bachelor of Engineering student from India. I created this library as my own side project to showcase my skills and knowledge in deep learning and transformer architectures.
31
+
32
+ I am also interested and free to work with others on different projects for knowledge sharing and building connections.
33
+
34
+ ---
35
+
36
+ ## 🌟 Features
37
+
38
+ - Multiple attention mechanisms including multi-head, group query, linear, local, and KV cache variants
39
+ - Token embedding via `tiktoken`
40
+ - Absolute and sinusoidal positional embeddings
41
+ - Normalization layers like LayerNorm and RMSNorm
42
+ - Several feed-forward network variants with activations such as ReLU, GELU, SiLU, LeakyReLU, and Sigmoid
43
+ - A simple GPT-style transformer model implementation
44
+
45
+ ---
46
+
47
+ ## 📁 Project Structure
48
+
49
+ stackformer/ \
50
+ |-- modules/ \
51
+ | |-- tokenizer.py # Token embedding using tiktoken \
52
+ | |-- position_embedding.py # Absolute and sinusoidal embeddings \
53
+ | |-- Attention.py # Attention mechanisms \
54
+ | |-- Normalization.py # LayerNorm and RMSNorm \
55
+ | |-- Feed_forward.py # Feed-forward layers with various activations \
56
+ |-- models/ \
57
+ | -- GPT_2.py # GPT-style transformer stack model \
58
+ -- trainer.py # Training loop and utilities \
59
+
60
+ ---
61
+
62
+ ## 💻 Installation
63
+
64
+ Clone the repository and install in development mode:
65
+
66
+ ```bash
67
+ git clone https://github.com/Gurumurthy30/Stackformer
68
+ cd Stackformer
69
+ pip install -e .
70
+ ```
71
+
72
+ ---
73
+
74
+ ## 🚀 Future Plans
75
+ Currently, I am working on improving and optimizing the existing components while fixing known bugs and issues. After stabilizing the current modules, I plan to add more advanced blocks like Mixture of Experts (MoE), mask handling, and other essential transformer components. Eventually, I will expand the library by developing more comprehensive model architectures.
@@ -0,0 +1,18 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ setup.py
5
+ Stackformer.egg-info/PKG-INFO
6
+ Stackformer.egg-info/SOURCES.txt
7
+ Stackformer.egg-info/dependency_links.txt
8
+ Stackformer.egg-info/requires.txt
9
+ Stackformer.egg-info/top_level.txt
10
+ models/GPT_2.py
11
+ models/__init__.py
12
+ modules/Attention.py
13
+ modules/Feed_forward.py
14
+ modules/Normalization.py
15
+ modules/__init__.py
16
+ modules/mask.py
17
+ modules/position_embedding.py
18
+ modules/tokenizer.py
@@ -0,0 +1,2 @@
1
+ torch>=2.6
2
+ tqdm>=4.67
@@ -0,0 +1,2 @@
1
+ models
2
+ modules
@@ -0,0 +1,238 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ # --- position embedding ---
7
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal position encodings.

    A (seq_len, emb_dim) table is precomputed once and registered as the
    buffer ``pe``; ``forward`` slices and broadcasts it to the batch.
    """

    def __init__(self, seq_len, emb_dim):
        super().__init__()
        self.seq_len = seq_len
        self.emb_dim = emb_dim

        # Inverse frequencies 10000^(-2i/d), computed in log space.
        even_steps = torch.arange(0, emb_dim, 2)
        inv_freq = torch.exp(even_steps * -(math.log(10000.0) / emb_dim))
        angles = torch.arange(0, seq_len).unsqueeze(1) * inv_freq

        table = torch.zeros(seq_len, emb_dim)
        table[:, 0::2] = torch.sin(angles)  # even dims: sine
        table[:, 1::2] = torch.cos(angles)  # odd dims: cosine
        self.register_buffer("pe", table)

    def forward(self, x):
        """Return (batch, seq, emb_dim) codes; x may be (batch, seq) ids or (batch, seq, emb_dim)."""
        n_batch, n_tok = x.shape[0], x.shape[1]
        codes = self.pe[:n_tok].unsqueeze(0)
        return codes.expand(n_batch, n_tok, -1).to(x.device)
26
+
27
+ # --- Multi Head Attention ---
28
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with separate Q/K/V projections."""

    def __init__(self, Emb_dim, num_heads, dropout=0.1, device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.device = device
        self.num_heads = num_heads
        self.head_dim = Emb_dim // num_heads

        # Bias-free projections for queries, keys and values.
        self.key = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)

        self.scale = math.sqrt(self.head_dim)  # scaled dot-product denominator
        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (batch, seq, Emb_dim) -> (batch, seq, Emb_dim)."""
        n_batch, n_tok, _ = x.shape
        split = (n_batch, n_tok, self.num_heads, self.head_dim)

        # Project and reshape to (batch, heads, seq, head_dim).
        k = self.key(x).view(*split).transpose(1, 2)
        q = self.query(x).view(*split).transpose(1, 2)
        v = self.value(x).view(*split).transpose(1, 2)

        logits = (q @ k.transpose(-2, -1)) / self.scale

        # Causal mask built per-call so any sequence length works.
        future = torch.triu(torch.ones(n_tok, n_tok, dtype=torch.bool, device=x.device), diagonal=1)
        logits = logits.masked_fill(future[None, None, :, :], float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))

        # Weighted sum over values, then merge heads and project out.
        mixed = (weights @ v).transpose(1, 2).contiguous().view(n_batch, n_tok, self.Emb_dim)
        return self.out_proj(mixed)
72
+
73
+ # --- Feed Forward ---
74
class FF_ReLU(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, Emb_dim, hidden_dim, dropout=0.1, device='cpu', dtype=torch.float32):
        super().__init__()
        stages = [
            nn.Linear(Emb_dim, hidden_dim, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, Emb_dim, device=device, dtype=dtype),
        ]
        # Attribute name kept for state_dict compatibility.
        self.relu = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP along the last dimension of x."""
        return self.relu(x)
86
+
87
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable affine terms."""

    def __init__(self, Emb_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.eps = eps  # guards against division by zero for tiny variance
        self.weight = nn.Parameter(torch.ones(Emb_dim, device=device, dtype=dtype))
        self.bias = nn.Parameter(torch.zeros(Emb_dim, device=device, dtype=dtype))

    def forward(self, x):
        """Normalize x feature-wise, then scale and shift."""
        mu = x.mean(dim=-1, keepdim=True)
        # Population variance (unbiased=False), matching nn.LayerNorm.
        sigma2 = x.var(dim=-1, keepdim=True, unbiased=False)
        standardized = (x - mu) / torch.sqrt(sigma2 + self.eps)
        return standardized * self.weight + self.bias
99
+
100
+ # --- Transformer Block ---
101
class Block(nn.Module):
    """Pre-norm transformer block: x + Attn(LN(x)), then x + FFN(LN(x))."""

    def __init__(self, Emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.attention = MultiHeadAttention(Emb_dim, num_heads, dropout, device=device, dtype=dtype)
        self.norm1 = LayerNorm(Emb_dim, eps=eps, device=device, dtype=dtype)
        self.ff_relu = FF_ReLU(Emb_dim, hidden_dim, dropout, device=device, dtype=dtype)
        self.norm2 = LayerNorm(Emb_dim, eps=eps, device=device, dtype=dtype)

    def forward(self, x):
        """Apply attention and feed-forward sublayers, each with a residual skip."""
        # Normalize *before* each sublayer (pre-norm), add the skip after.
        x = x + self.attention(self.norm1(x))
        x = x + self.ff_relu(self.norm2(x))
        return x
123
+
124
+ # --- Encoder ---
125
class Encoder(nn.Module):
    """A sequential stack of `num_layers` identical transformer blocks."""

    def __init__(self, num_layers, Emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        blocks = [
            Block(Emb_dim, num_heads, dropout, hidden_dim, eps, device=device, dtype=dtype)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Run x through each block in order."""
        for block in self.layers:
            x = block(x)
        return x
137
+
138
class GPTModel(nn.Module):
    """GPT-style language model.

    Pipeline: token embedding + sinusoidal position embedding -> dropout ->
    stack of pre-norm transformer blocks -> final LayerNorm -> linear head
    projecting to vocabulary logits.
    """

    def __init__(self, vocab_size, num_layers, Emb_dim, num_heads, seq_len,
                 dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):

        super().__init__()
        # Bug fix: self.device / self.dtype / self.seq_len were read below and
        # in generate() but never assigned, so instantiation raised
        # AttributeError.
        self.device = device
        self.dtype = dtype
        self.seq_len = seq_len

        # --- Token embedding ---
        self.embedding = nn.Embedding(vocab_size, Emb_dim, dtype=self.dtype, device=self.device)

        # --- Embedding dropout ---
        self.emb_dropout = nn.Dropout(dropout)

        # --- Fixed sinusoidal position embedding (not learned) ---
        self.position_embedding = SinusoidalPositionalEmbedding(
            emb_dim=Emb_dim,
            seq_len=seq_len
        )

        # --- Encoder (decoder-style blocks with causal attention) ---
        self.encoder = Encoder(
            num_layers=num_layers,
            Emb_dim=Emb_dim,
            num_heads=num_heads,
            dropout=dropout,
            hidden_dim=hidden_dim,
            eps=eps,
            device=self.device,
            dtype=self.dtype
        )

        # --- Final norm ---
        self.final_norm = LayerNorm(Emb_dim, eps=eps,
                                    device=self.device, dtype=self.dtype)

        # --- Output projection to vocabulary logits ---
        self.lm_head = nn.Linear(Emb_dim, vocab_size, bias=False,
                                 dtype=self.dtype, device=self.device)

    def forward(self, x):
        """Map token ids (batch, seq_len) to logits (batch, seq_len, vocab_size)."""
        emb = self.embedding(x)               # (batch, seq_len, emb_dim)
        pos = self.position_embedding(x)      # (batch, seq_len, emb_dim)
        x = self.emb_dropout(emb + pos)
        x = self.encoder(x)
        x = self.final_norm(x)
        return self.lm_head(x)

    @torch.no_grad()
    def generate(self, prompt_ids, max_new_tokens=50, temperature=1.0, top_k=None, top_p=1.0, eos_token_id=None):
        """Autoregressively sample up to `max_new_tokens` tokens after `prompt_ids`.

        prompt_ids: (seq_len,) or (batch, seq_len) token ids.
        temperature / top_k / top_p control the sampling distribution.
        NOTE: the eos check calls next_token.item(), so early stopping only
        works for a single sequence (batch size 1).
        """
        self.eval()
        if prompt_ids.dim() == 1:
            prompt_ids = prompt_ids.unsqueeze(0)  # (1, seq_len)

        generated = prompt_ids.clone()
        max_context_len = self.seq_len

        for _ in range(max_new_tokens):
            # Sliding window: the model never sees more than seq_len tokens.
            if generated.size(1) > max_context_len:
                input_ids = generated[:, -max_context_len:]
            else:
                input_ids = generated

            logits = self.forward(input_ids)  # (batch, seq_len, vocab_size)
            logits = logits[:, -1, :]         # last position only

            # --- Temperature scaling ---
            if temperature != 1.0:
                logits = logits / temperature

            # --- Top-k filtering: keep only the k highest logits ---
            if top_k is not None and top_k > 0:
                topk_vals, topk_indices = torch.topk(logits, top_k)
                mask = torch.full_like(logits, float('-inf'))
                mask.scatter_(dim=-1, index=topk_indices, src=topk_vals)
                logits = mask

            # --- Top-p (nucleus): drop tokens past cumulative prob top_p ---
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                probs = F.softmax(sorted_logits, dim=-1)
                cum_probs = torch.cumsum(probs, dim=-1)

                sorted_mask = cum_probs > top_p
                # Shift right so the first token over the threshold survives.
                sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                sorted_mask[..., 0] = 0

                indices_to_remove = sorted_mask.scatter(dim=-1, index=sorted_indices, src=sorted_mask)
                logits = logits.masked_fill(indices_to_remove, float('-inf'))

            # Sample the next token from the filtered distribution.
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (batch, 1)
            generated = torch.cat([generated, next_token], dim=-1)

            # Early stop on end-of-sequence (batch size 1 only; see docstring).
            if eos_token_id is not None and next_token.item() == eos_token_id:
                break

        return generated
File without changes
@@ -0,0 +1,533 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
class Self_Attention(nn.Module):
    """Single-head causal self-attention with bias-free Q/K/V projections."""

    def __init__(self, Emb_dim, dropout, dtype=torch.float32, device='cpu'):
        super().__init__()
        self.device = device
        self.scale = torch.tensor(Emb_dim ** 0.5, dtype=dtype, device=device)

        self.key = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)

        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (batch, seq, D) -> (batch, seq, D), masking future positions."""
        _, n_tok, _ = x.size()

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Scaled dot-product scores: (batch, seq, seq).
        logits = q @ k.transpose(-2, -1) / self.scale

        # Strictly-upper-triangular mask hides future tokens.
        future = torch.triu(torch.ones(n_tok, n_tok, dtype=torch.bool, device=self.device), diagonal=1)
        logits = logits.masked_fill_(future, float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))
        return self.out_proj(weights @ v)
35
+
36
class Multi_Head_Attention(nn.Module):
    """Causal multi-head self-attention (heads split from a shared projection)."""

    def __init__(self, Emb_dim, num_heads, dropout, device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.num_heads = num_heads
        self.device = device
        self.head_dim = Emb_dim // num_heads

        self.key = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)

        self.scale = torch.tensor(self.head_dim ** 0.5, device=device, dtype=dtype)
        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (batch, seq, Emb_dim) -> (batch, seq, Emb_dim)."""
        n_batch, n_tok, _ = x.shape
        split = (n_batch, n_tok, self.num_heads, self.head_dim)

        # (batch, heads, seq, head_dim) views of K, Q, V.
        k = self.key(x).view(*split).transpose(1, 2)
        q = self.query(x).view(*split).transpose(1, 2)
        v = self.value(x).view(*split).transpose(1, 2)

        # (batch, heads, seq, seq) scaled dot-product scores.
        logits = (q @ k.transpose(-2, -1)) / self.scale

        # Causal mask: each token may attend only to itself and the past.
        future = torch.triu(torch.ones(n_tok, n_tok, dtype=torch.bool, device=self.device), diagonal=1)
        logits = logits.masked_fill_(future[None, None, :, :], float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))

        # Weighted value mix, heads re-merged, then output projection.
        mixed = (weights @ v).transpose(1, 2).contiguous().view(n_batch, n_tok, self.Emb_dim)
        return self.out_proj(mixed)
80
+
81
class Cross_MultiHead_Attention(nn.Module):
    """Multi-head attention where keys/values may come from a separate context.

    With ``context=None`` it behaves as causal self-attention; with a context
    tensor it performs encoder-decoder style cross-attention.
    """

    def __init__(self, Emb_dim, num_heads, dropout, device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.device = device
        self.num_heads = num_heads
        self.head_dim = Emb_dim // num_heads

        # Query, Key, Value projections (bias-free).
        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.key = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)

        self.scale = torch.tensor(self.head_dim ** 0.5, device=device, dtype=dtype)

        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None):
        """
        x: (Batch_size, query_seq_len, Emb_dim) — query input (e.g., decoder hidden states)
        context: (Batch_size, KV_seq_len, Emb_dim) — source for keys/values (e.g., encoder output). If None, self-attention.
        Returns: (Batch_size, query_seq_len, Emb_dim)
        """
        Batch_size, query_seq_len, _ = x.shape
        context = x if context is None else context  # self-attention fallback
        KV_seq_len = context.shape[1]

        # Project Q from x; K, V from the context.
        Querys = self.query(x).view(Batch_size, query_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Keys = self.key(context).view(Batch_size, KV_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Values = self.value(context).view(Batch_size, KV_seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores: (Batch_size, nh, query_seq_len, KV_seq_len)
        scores = (Querys @ Keys.transpose(-2, -1)) / self.scale

        # Bug fix: the original always built a square (q, q) causal mask and
        # applied it to (q, kv) scores, crashing whenever the context length
        # differed from the query length. Causal masking only makes sense when
        # queries and keys index the same sequence, so apply it only then;
        # cross-attention over a different-length context attends freely.
        if KV_seq_len == query_seq_len:
            causal_mask = torch.triu(
                torch.ones(query_seq_len, KV_seq_len, dtype=torch.bool, device=self.device), diagonal=1
            )
            scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = attn @ Values  # (Batch_size, nh, query_seq_len, hd)
        out = out.transpose(1, 2).contiguous().view(Batch_size, query_seq_len, self.Emb_dim)

        return self.out_proj(out)
128
+
129
class Multi_query_Attention(nn.Module):
    """Multi-query attention: per-head queries share one key/value head."""

    def __init__(self, Emb_dim, num_heads, dropout, device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.device = device
        self.num_heads = num_heads
        self.head_dim = Emb_dim // num_heads

        # Full-width query projection; K and V project to a single head.
        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.key = nn.Linear(Emb_dim, self.head_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, self.head_dim, bias=False, dtype=dtype, device=device)

        self.scale = torch.tensor(self.head_dim ** 0.5, device=device, dtype=dtype)

        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (batch, seq, Emb_dim) -> (batch, seq, Emb_dim)."""
        n_batch, n_tok, _ = x.shape

        q = self.query(x).view(n_batch, n_tok, self.num_heads, self.head_dim).transpose(1, 2)
        # Single shared K/V head, broadcast across all query heads.
        k = self.key(x).unsqueeze(1).expand(n_batch, 1, n_tok, self.head_dim)
        v = self.value(x).unsqueeze(1).expand(n_batch, 1, n_tok, self.head_dim)

        # Head dim broadcasts 1 -> num_heads: (batch, nh, seq, seq).
        logits = (q @ k.transpose(-2, -1)) / self.scale

        future = torch.triu(torch.ones(n_tok, n_tok, dtype=torch.bool, device=self.device), diagonal=1)
        logits = logits.masked_fill_(future[None, None, :, :], float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))

        mixed = weights @ v  # (batch, nh, seq, hd)
        mixed = mixed.transpose(1, 2).contiguous().view(n_batch, n_tok, self.Emb_dim)
        return self.out_proj(mixed)
177
+
178
class Group_query_Attention(nn.Module):
    """Grouped-query attention: several query heads share each key/value head."""

    def __init__(self, Emb_dim, num_query_heads, num_kv_heads, dropout, device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_query_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.device = device
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads

        self.head_dim = Emb_dim // num_query_heads
        # How many query heads map onto each KV head.
        self.num_queries_pre_kv = num_query_heads // num_kv_heads

        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.key = nn.Linear(Emb_dim, self.num_kv_heads * self.head_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, self.num_kv_heads * self.head_dim, bias=False, dtype=dtype, device=device)

        self.scale = torch.tensor(self.head_dim ** 0.5, device=device, dtype=dtype)

        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (batch, seq, Emb_dim) -> (batch, seq, Emb_dim)."""
        n_batch, n_tok, _ = x.shape

        # Queries at full head count; K/V at the reduced KV head count.
        q = self.query(x).view(n_batch, n_tok, self.num_query_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).view(n_batch, n_tok, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).view(n_batch, n_tok, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Duplicate each KV head so every query head has a partner.
        k = k.repeat_interleave(self.num_queries_pre_kv, dim=1)
        v = v.repeat_interleave(self.num_queries_pre_kv, dim=1)

        logits = (q @ k.transpose(-2, -1)) / self.scale

        future = torch.triu(torch.ones(n_tok, n_tok, dtype=torch.bool, device=self.device), diagonal=1)
        logits = logits.masked_fill_(future[None, None, :, :], float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))

        mixed = weights @ v
        mixed = mixed.transpose(1, 2).contiguous().view(n_batch, n_tok, self.Emb_dim)
        return self.out_proj(mixed)
229
+
230
class Linear_Attention(nn.Module):
    """Causal linear attention using the elu(x)+1 feature map.

    Causality comes from prefix sums (cumsum over the time axis) rather than
    an explicit mask.
    """

    def __init__(self, Emb_dim, num_heads, dropout, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.eps = eps  # keeps the normalizer strictly positive
        self.device = device
        self.dtype = dtype
        self.num_heads = num_heads
        self.head_dim = Emb_dim // self.num_heads

        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.key = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)

        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (batch, seq, Emb_dim) -> (batch, seq, Emb_dim)."""
        n_batch, n_tok, _ = x.shape
        split = (n_batch, n_tok, self.num_heads, self.head_dim)

        q = self.query(x).view(*split).transpose(1, 2)  # (batch, nh, seq, hd)
        k = self.key(x).view(*split).transpose(1, 2)
        v = self.value(x).view(*split).transpose(1, 2)

        # Positive feature map phi(u) = elu(u) + 1.
        feat_q = F.elu(q) + 1.0
        feat_k = F.elu(k) + 1.0

        # Outer products phi(k_t) v_t^T, accumulated causally over time.
        outer = torch.matmul(feat_k.unsqueeze(-1), v.unsqueeze(-2))
        state = torch.cumsum(outer, dim=2)
        normalizer = torch.cumsum(feat_k, dim=2)

        # Numerator phi(q_t)^T S_t and scalar denominator phi(q_t)^T z_t.
        num = torch.matmul(feat_q.unsqueeze(-2), state).squeeze(-2)
        den = torch.sum(feat_q * normalizer, dim=-1, keepdim=True) + self.eps

        mixed = num / den
        mixed = mixed.transpose(1, 2).contiguous().view(n_batch, n_tok, self.Emb_dim)
        return self.dropout(self.out_proj(mixed))
272
+
273
class Multi_latent_Attention(nn.Module):
    """Multi-head Latent Attention: low-rank latent compression of Q and KV.

    Queries and keys/values are first projected down to narrow latent spaces
    (W_dq / W_dkv), layer-normalized, then expanded back to Emb_dim before
    standard causal scaled-dot-product attention.

    Args:
        Emb_dim: model embedding width (must be divisible by num_heads).
        q_compressed_dim: latent width for the query path.
        kv_compressed_dim: shared latent width for keys and values.
        num_heads: number of attention heads.
        device / dtype: placement of all parameters.
        dropout: dropout probability (attention weights and output).
    """

    def __init__(self, Emb_dim, q_compressed_dim, kv_compressed_dim, num_heads, device='cpu', dtype=torch.float32, dropout=0):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.device = device
        self.q_compressed_dim = q_compressed_dim
        self.kv_compressed_dim = kv_compressed_dim
        self.num_heads = num_heads
        self.head_dim = Emb_dim // self.num_heads

        # Query path: down-project -> LayerNorm -> up-project.
        self.W_dq = nn.Linear(Emb_dim, q_compressed_dim, bias=False, dtype=dtype, device=device)
        self.W_dq_norm = nn.LayerNorm(q_compressed_dim, dtype=dtype, device=device)
        self.W_uq = nn.Linear(q_compressed_dim, Emb_dim, bias=False, dtype=dtype, device=device)

        # Shared KV latent path.
        self.W_dkv = nn.Linear(Emb_dim, kv_compressed_dim, bias=False, dtype=dtype, device=device)
        self.W_dkv_norm = nn.LayerNorm(kv_compressed_dim, dtype=dtype, device=device)
        self.W_uk = nn.Linear(kv_compressed_dim, Emb_dim, dtype=dtype, device=device)
        self.W_uv = nn.Linear(kv_compressed_dim, Emb_dim, dtype=dtype, device=device)

        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (batch, seq, Emb_dim) -> (batch, seq, Emb_dim), causal attention."""
        Batch_size, Seq_len, C = x.shape

        q_final = self.W_uq(self.W_dq_norm(self.W_dq(x)))

        kv_latent = self.W_dkv_norm(self.W_dkv(x))
        k_final = self.W_uk(kv_latent)
        v_final = self.W_uv(kv_latent)

        Querys = q_final.view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Keys = k_final.view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Values = v_final.view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # BUG FIX: F.scaled_dot_product_attention knows nothing about module
        # train/eval mode, so passing self.dropout.p unconditionally applied
        # attention dropout at inference time too. Disable it when not training.
        out = F.scaled_dot_product_attention(
            query=Querys,
            key=Keys,
            value=Values,
            attn_mask=None,
            is_causal=True,
            dropout_p=self.dropout.p if self.training else 0.0,
        )
        out = self.out_proj(out.transpose(1, 2).contiguous().view(Batch_size, Seq_len, C))
        out = self.dropout(out)
        return out
324
+
325
class Local_Attention(nn.Module):
    """Sliding-window (local) causal multi-head attention.

    Each position attends only to itself and the previous
    ``Window_size - 1`` positions; everything outside that band is masked
    to -inf before the softmax.
    """

    def __init__(self, Emb_dim, num_heads, Window_size, dropout, device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.Window_size = Window_size
        self.device = device
        self.num_heads = num_heads
        self.head_dim = Emb_dim // num_heads

        self.key = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)

        self.scale = torch.tensor(self.head_dim ** 0.5, device=device, dtype=dtype)
        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (batch, seq, Emb_dim) -> (batch, seq, Emb_dim)."""
        bsz, seq, _ = x.shape

        def split_heads(t):
            # (bsz, seq, emb) -> (bsz, heads, seq, head_dim)
            return t.view(bsz, seq, self.num_heads, self.head_dim).transpose(1, 2)

        k = split_heads(self.key(x))
        q = split_heads(self.query(x))
        v = split_heads(self.value(x))

        logits = (q @ k.transpose(-2, -1)) / self.scale  # (bsz, heads, seq, seq)

        # Banded causal mask: keep entries with 0 <= i - j < Window_size.
        lower = torch.tril(torch.ones_like(logits, dtype=bool))
        band = torch.triu(lower, diagonal=-(self.Window_size - 1))
        logits = logits.masked_fill_(~band, float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))
        ctx = weights @ v  # (bsz, heads, seq, head_dim)

        ctx = ctx.transpose(1, 2).contiguous().view(bsz, seq, self.Emb_dim)
        return self.out_proj(ctx)
371
+
372
def precompute_theta_position_frequency(head_dim, seq_len, device='cpu', theta=10000.0):
    """Precompute RoPE rotation factors as complex exponentials.

    Returns a (seq_len, head_dim // 2) complex tensor whose entry (m, i)
    equals exp(j * m / theta**(2i / head_dim)).
    """
    assert head_dim % 2 == 0, "head_dim must be even"

    # Per-pair inverse frequencies: 1 / theta**(2i / head_dim).
    even_idx = torch.arange(0, head_dim, 2, device=device)
    inv_freq = 1.0 / (theta ** (even_idx / head_dim))

    # Absolute positions 0 .. seq_len-1.
    positions = torch.arange(seq_len, device=device)

    # Angle table: outer(position, inverse frequency) -> (seq_len, head_dim // 2).
    angles = torch.outer(positions, inv_freq)

    # Unit-magnitude complex numbers carrying the rotations.
    return torch.polar(torch.ones_like(angles), angles)
388
+
389
+
390
def apply_rotry_position_embedding(x, freq_complex, device='cpu', dtype=torch.float32):
    """Rotate (batch, seq, heads, emb) features with precomputed RoPE factors.

    Adjacent feature pairs are interpreted as complex numbers and multiplied
    by ``freq_complex`` (truncated to the input's sequence length).
    """
    batch_size, seq_len, num_head, emb_dim = x.shape
    assert emb_dim % 2 == 0, "emb_dim must be even"

    # Pair up the last dimension and reinterpret pairs as complex numbers.
    as_pairs = x.view(batch_size, seq_len, num_head, emb_dim // 2, 2).to(device=device, dtype=dtype)
    as_complex = torch.view_as_complex(as_pairs)

    # Broadcast shape (1, seq_len, 1, emb_dim // 2) over batch and heads.
    rot = freq_complex[:seq_len].unsqueeze(0).unsqueeze(2).to(device=device)

    # Complex multiplication performs the 2D rotation of each pair.
    rotated = as_complex * rot

    # Back to interleaved real pairs in the original layout.
    flat = torch.view_as_real(rotated).contiguous().view(batch_size, seq_len, num_head, emb_dim)
    return flat.to(device=device, dtype=dtype)
408
+
409
class kv_cache_multihead(nn.Module):
    """Multi-head attention with a preallocated KV cache for incremental decoding.

    Call forward repeatedly with consecutive chunks; ``start_pos`` is the
    absolute position of the chunk's first token. Each chunk's keys/values are
    written into fixed-size caches so later chunks can attend to the prefix.
    """

    def __init__(self, emb_dim, num_heads, batch_size, kv_seq_len, device='cpu', dtype=torch.float32, dropout=0.1):
        super().__init__()
        self.dtype = dtype
        self.device = device

        assert emb_dim % num_heads == 0
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        self.kv_seq_len = kv_seq_len  # maximum cached sequence length

        self.query = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
        self.key = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)

        self.out_proj = nn.Linear(emb_dim, emb_dim, dtype=dtype, device=device)
        self.dropout = nn.Dropout(dropout)

        # FIX: register caches as (non-persistent) buffers so .to()/device moves
        # track the module, matching kv_cache_group_query. persistent=False keeps
        # state_dict identical to the previous plain-tensor behaviour.
        self.register_buffer("cache_keys", torch.zeros(batch_size, kv_seq_len, num_heads, self.head_dim, dtype=dtype, device=device), persistent=False)
        self.register_buffer("cache_value", torch.zeros(batch_size, kv_seq_len, num_heads, self.head_dim, dtype=dtype, device=device), persistent=False)

    # FIX: the old signature was ``RoPE: False`` — a meaningless annotation and
    # no default; make it a proper keyword with a backward-compatible default.
    def forward(self, x, start_pos, RoPE: bool = False):
        """Attend over cached positions [0, start_pos + seq_len).

        Args:
            x: (batch, seq_len, emb_dim) chunk starting at ``start_pos``.
            start_pos: absolute index of x[:, 0] in the full sequence.
            RoPE: apply rotary embeddings to queries and keys.
        """
        batch_size, seq_len, C = x.shape

        xq = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        xk = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        xv = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

        if RoPE:
            freq_complex = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=seq_len, device=self.device)
            xq = apply_rotry_position_embedding(xq, freq_complex, device=self.device, dtype=self.dtype)
            # NOTE(review): keys are rotated with positions 0..seq_len-1 even
            # when start_pos > 0 (apply_rotry_position_embedding always slices
            # frequencies from 0), so cached-key rotations are absolute only
            # for the first chunk — confirm intended behaviour upstream.
            freq_complex = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=self.kv_seq_len, device=self.device)
            xk = apply_rotry_position_embedding(xk, freq_complex, device=self.device, dtype=self.dtype)

        # Write this chunk into the cache, then read back the full prefix.
        self.cache_keys[:, start_pos:start_pos+seq_len] = xk
        self.cache_value[:, start_pos:start_pos+seq_len] = xv

        xk_full = self.cache_keys[:, :start_pos+seq_len]
        xv_full = self.cache_value[:, :start_pos+seq_len]

        query = xq.transpose(1, 2)      # (batch, num_heads, seq_len, head_dim)
        key = xk_full.transpose(1, 2)   # (batch, num_heads, total_len, head_dim)
        value = xv_full.transpose(1, 2)

        attn_scores = torch.matmul(query, key.transpose(2, 3)) / (self.head_dim ** 0.5)

        # BUG FIX: the mask must be (seq_len, start_pos + seq_len) and offset by
        # start_pos; the previous (seq_len, seq_len) triu mask failed to
        # broadcast against the scores on every call with start_pos > 0.
        # Query at absolute position p may see keys at positions <= p.
        total_len = start_pos + seq_len
        q_pos = torch.arange(start_pos, total_len, device=self.device).unsqueeze(1)
        k_pos = torch.arange(total_len, device=self.device).unsqueeze(0)
        attn_scores.masked_fill_((k_pos > q_pos)[None, None, :, :], float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        out = torch.matmul(attn_weights, value)

        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.dropout(self.out_proj(out))
466
+
467
class kv_cache_group_query(nn.Module):
    """Grouped-query attention with a preallocated KV cache for incremental decoding.

    Several query heads share each key/value head; cached keys/values are
    repeated (repeat_interleave) to match the query heads at attention time.
    ``start_pos`` is the absolute position of the current chunk's first token.
    """

    def __init__(self, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len, device='cpu', dtype=torch.float32, dropout=0.1):
        super().__init__()
        self.dtype = dtype
        self.device = device

        assert query_num_heads % kv_num_heads == 0, "query heads must be divisible by kv heads"
        assert emb_dim % query_num_heads == 0, "embedding must be divisible by query heads"

        self.emb_dim = emb_dim
        self.query_num_heads = query_num_heads
        self.kv_num_heads = kv_num_heads
        self.head_dim = emb_dim // query_num_heads
        self.num_queries_per_kv = query_num_heads // kv_num_heads
        self.kv_seq_len = kv_seq_len

        self.query = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
        self.key = nn.Linear(emb_dim, kv_num_heads * self.head_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(emb_dim, kv_num_heads * self.head_dim, bias=False, dtype=dtype, device=device)

        self.out_proj = nn.Linear(emb_dim, emb_dim, dtype=dtype, device=device)
        self.dropout = nn.Dropout(dropout)

        # KV caches sized for the full decode horizon.
        self.register_buffer("cache_keys", torch.zeros(batch_size, kv_seq_len, kv_num_heads, self.head_dim, device=device, dtype=dtype))
        self.register_buffer("cache_value", torch.zeros(batch_size, kv_seq_len, kv_num_heads, self.head_dim, device=device, dtype=dtype))

    def forward(self, x, start_pos, RoPE=False):
        """Attend over cached positions [0, start_pos + seq_len).

        Args:
            x: (batch, seq_len, emb_dim) chunk starting at ``start_pos``.
            start_pos: absolute index of x[:, 0] in the full sequence.
            RoPE: apply rotary embeddings to queries and keys.
        """
        batch_size, seq_len, _ = x.shape

        xq = self.query(x).view(batch_size, seq_len, self.query_num_heads, self.head_dim)
        xk = self.key(x).view(batch_size, seq_len, self.kv_num_heads, self.head_dim)
        xv = self.value(x).view(batch_size, seq_len, self.kv_num_heads, self.head_dim)

        if RoPE:
            freq_q = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=seq_len, device=self.device)
            xq = apply_rotry_position_embedding(xq, freq_q, device=self.device, dtype=self.dtype)
            # NOTE(review): keys are rotated with positions starting at 0 even
            # when start_pos > 0 — confirm intended behaviour upstream.
            freq_k = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=self.kv_seq_len, device=self.device)
            xk = apply_rotry_position_embedding(xk, freq_k, device=self.device, dtype=self.dtype)

        # Write the chunk into the caches, then read the full prefix back.
        self.cache_keys[:, start_pos:start_pos+seq_len] = xk
        self.cache_value[:, start_pos:start_pos+seq_len] = xv

        xk_full = self.cache_keys[:, :start_pos+seq_len]   # [B, T, kv_heads, D]
        xv_full = self.cache_value[:, :start_pos+seq_len]

        # Transpose for attention: [B, H, T, D]
        query = xq.transpose(1, 2)      # [B, q_heads, seq_len, D]
        key = xk_full.transpose(1, 2)   # [B, kv_heads, total_kv_len, D]
        value = xv_full.transpose(1, 2)

        # Duplicate each KV head across its query-head group.
        key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
        value = value.repeat_interleave(self.num_queries_per_kv, dim=1)

        attn_scores = torch.matmul(query, key.transpose(2, 3)) / (self.head_dim ** 0.5)

        # BUG FIX: the causal mask must be (seq_len, start_pos + seq_len) and
        # offset by start_pos; the old (seq_len, seq_len) triu mask failed to
        # broadcast against the scores whenever start_pos > 0.
        total_len = start_pos + seq_len
        q_pos = torch.arange(start_pos, total_len, device=self.device).unsqueeze(1)
        k_pos = torch.arange(total_len, device=self.device).unsqueeze(0)
        attn_scores.masked_fill_((k_pos > q_pos)[None, None, :, :], float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        out = torch.matmul(attn_weights, value)

        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.emb_dim)
        return self.dropout(self.out_proj(out))
@@ -0,0 +1,59 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class FF_ReLU(nn.Module):
    """Two-layer position-wise feed-forward block with a ReLU nonlinearity."""

    def __init__(self, emb_dim, hidden_dim, device='cpu', dtype=torch.float32):
        super().__init__()
        # Expand to hidden_dim, apply ReLU, project back to emb_dim.
        layers = [
            nn.Linear(emb_dim, hidden_dim, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim, device=device, dtype=dtype),
        ]
        self.relu = nn.Sequential(*layers)

    def forward(self, x):
        # (..., emb_dim) -> (..., emb_dim)
        return self.relu(x)
15
+
16
class FF_LeakyReLU(nn.Module):
    """Two-layer position-wise feed-forward block with a LeakyReLU nonlinearity."""

    def __init__(self, emb_dim, hidden_dim, negative_slope=0.1, device='cpu', dtype=torch.float32):
        super().__init__()
        # Expand, apply LeakyReLU with the given negative slope, project back.
        layers = [
            nn.Linear(emb_dim, hidden_dim, device=device, dtype=dtype),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, emb_dim, device=device, dtype=dtype),
        ]
        self.l_relu = nn.Sequential(*layers)

    def forward(self, x):
        # (..., emb_dim) -> (..., emb_dim)
        return self.l_relu(x)
26
+
27
class FF_GELU(nn.Module):
    """Two-layer position-wise feed-forward block with a GELU nonlinearity."""

    def __init__(self, emb_dim, hidden_dim, device='cpu', dtype=torch.float32):
        super().__init__()
        # Expand to hidden_dim, apply GELU, project back to emb_dim.
        layers = [
            nn.Linear(emb_dim, hidden_dim, device=device, dtype=dtype),
            nn.GELU(),
            nn.Linear(hidden_dim, emb_dim, device=device, dtype=dtype),
        ]
        self.gelu = nn.Sequential(*layers)

    def forward(self, x):
        # (..., emb_dim) -> (..., emb_dim)
        return self.gelu(x)
37
+
38
class FF_Sigmoid(nn.Module):
    """Two-layer position-wise feed-forward block with a Sigmoid nonlinearity."""

    def __init__(self, emb_dim, hidden_dim, device='cpu', dtype=torch.float32):
        super().__init__()
        # Expand to hidden_dim, squash with Sigmoid, project back to emb_dim.
        layers = [
            nn.Linear(emb_dim, hidden_dim, device=device, dtype=dtype),
            nn.Sigmoid(),
            nn.Linear(hidden_dim, emb_dim, device=device, dtype=dtype),
        ]
        self.sigmoid = nn.Sequential(*layers)

    def forward(self, x):
        # (..., emb_dim) -> (..., emb_dim)
        return self.sigmoid(x)
48
+
49
class FF_SiLU(nn.Module):
    """Two-layer position-wise feed-forward block with a SiLU (swish) nonlinearity."""

    def __init__(self, emb_dim, hidden_dim, device='cpu', dtype=torch.float32):
        super().__init__()
        # Expand to hidden_dim, apply SiLU, project back to emb_dim.
        layers = [
            nn.Linear(emb_dim, hidden_dim, device=device, dtype=dtype),
            nn.SiLU(),
            nn.Linear(hidden_dim, emb_dim, device=device, dtype=dtype),
        ]
        self.silu = nn.Sequential(*layers)

    def forward(self, x):
        # (..., emb_dim) -> (..., emb_dim)
        return self.silu(x)
59
+
@@ -0,0 +1,41 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class LayerNorm(nn.Module):
    """Parameter-free layer normalization (no learnable scale/shift).

    NOTE(review): this class is immediately shadowed by the affine LayerNorm
    defined below with the same name, so this definition is dead code —
    confirm whether it should be removed or renamed.
    """
    def __init__(self, emb_dim, eps = 1e-5):
        super().__init__()
        # emb_dim is accepted but unused in this variant (no per-feature params).
        self.eps = eps

    def forward(self, x):
        # Normalize over the last (feature) axis using the biased variance.
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return norm_x
14
+
15
class LayerNorm(nn.Module):
    """Layer normalization with learnable per-feature scale and shift."""

    def __init__(self, emb_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.eps = eps  # numerical floor inside the sqrt
        self.weight = nn.Parameter(torch.ones(emb_dim, device=device, dtype=dtype))
        self.bias = nn.Parameter(torch.zeros(emb_dim, device=device, dtype=dtype))

    def forward(self, x):
        # Normalize over the feature (last) axis using the biased variance,
        # then apply the learned affine transform.
        centered = x - x.mean(dim=-1, keepdim=True)
        denom = torch.sqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
        return (centered / denom) * self.weight + self.bias
27
+
28
+
29
# RMS = sqrt(mean(x**2)); Norm = x / RMS
class RMSNormilization(nn.Module):
    """Root-mean-square normalization with a learnable per-feature gain.

    Note: ``eps`` is added to the RMS value itself (outside the sqrt) —
    this matches the original implementation and differs slightly from
    the sqrt(mean + eps) convention.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # RMS over the feature (last) axis; eps keeps the division finite.
        rms = x.pow(2).mean(-1, keepdim=True).sqrt()
        return self.weight * x / (rms + self.eps)
File without changes
@@ -0,0 +1,36 @@
1
+ # problem: Random mask and global mask
2
+ import torch
3
+
4
def casual_mask(Seq_len):
    """Boolean causal mask: True marks future positions to be blocked."""
    # Strict upper triangle: entry (i, j) is True iff j > i.
    return torch.ones(Seq_len, Seq_len, dtype=torch.bool).triu(diagonal=1)
7
+
8
def sliding_window(Seq_len, window_size):
    """Causal sliding-window mask.

    Returns a boolean (Seq_len, Seq_len) tensor, True where attention must
    be blocked: future positions, or positions more than ``window_size - 1``
    steps in the past.
    """
    i = torch.arange(Seq_len).unsqueeze(1)
    j = torch.arange(Seq_len).unsqueeze(0)
    # Keep entries with 0 <= i - j < window_size.
    allowed = (j <= i) & (j > i - window_size)
    return ~allowed
12
+
13
def dilated_casual_mask(Seq_len, dilation):
    """Dilated causal mask.

    Position i may attend to position j only when j is in the past (or i
    itself) AND i - j is a multiple of ``dilation``. Returns True where
    attention must be blocked.
    """
    pos = torch.arange(Seq_len)
    q = pos.unsqueeze(1)
    k = pos.unsqueeze(0)
    # Causal and dilation conditions combined.
    visible = (q >= k) & ((q - k) % dilation == 0)
    return ~visible
19
+
20
def random_mask(Seq_len, num_random):
    """Random sparse attention mask.

    Each position i may attend to up to ``num_random`` randomly chosen
    earlier positions. Returns a boolean (Seq_len, Seq_len) tensor, True
    where attention must be blocked; row 0 (no earlier positions) is fully
    blocked.

    BUG FIX: the mask was built as a float tensor and ``~`` (bitwise not) is
    not defined for floating-point tensors, so this function always raised a
    RuntimeError. It now builds a boolean mask directly.
    """
    allowed = torch.zeros(Seq_len, Seq_len, dtype=torch.bool)
    for i in range(Seq_len):
        if i == 0:
            continue  # no earlier positions to sample from
        # Sample up to num_random distinct positions from [0, i).
        picks = torch.randperm(i)[:min(num_random, i)]
        allowed[i, picks] = True
    return ~allowed
29
+
30
def global_mask(Seq_len, global_index):
    """Global-attention mask.

    Tokens listed in ``global_index`` attend everywhere (their rows are
    unblocked) and every token may attend to them (their columns are
    unblocked). Returns a boolean (Seq_len, Seq_len) tensor, True where
    attention must be blocked.

    BUG FIX: the mask was built as a float tensor and ``~`` (bitwise not) is
    not defined for floating-point tensors, so this function always raised a
    RuntimeError. It now builds a boolean mask directly.
    """
    idx = torch.tensor(global_index, dtype=torch.long)
    allowed = torch.zeros(Seq_len, Seq_len, dtype=torch.bool)
    allowed[idx, :] = True   # global tokens see everything
    allowed[:, idx] = True   # everyone sees global tokens
    return ~allowed
@@ -0,0 +1,61 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ # --- Absolute Positional Embedding ---
6
class AbsolutePositionEmbedding(nn.Module):
    """Learned absolute positional embedding.

    Returns only the positional embeddings with shape
    (batch, seq_len, emb_dim); callers add them to their token embeddings.
    ``seq_len`` passed to __init__ is the maximum supported sequence length.
    """

    def __init__(self, seq_len, emb_dim):
        super().__init__()
        self.seq_len = seq_len   # maximum supported sequence length
        self.emb_dim = emb_dim
        self.embedding = nn.Embedding(seq_len, emb_dim)

    def forward(self, x):
        """x: (batch, seq_len, ...) — only its leading two dims are read."""
        batch_size, seq_len = x.shape[0], x.shape[1]
        # BUG FIX: position indices must live on the same device as the
        # embedding weights; a bare torch.arange() is always on CPU and made
        # the lookup fail once the module was moved to GPU.
        positions = torch.arange(seq_len, device=self.embedding.weight.device)
        abs_pos = self.embedding(positions)  # (seq_len, emb_dim)
        return abs_pos.unsqueeze(0).expand(batch_size, seq_len, -1).to(x.device)
18
+
19
+ # --- Sinusoidal Positional Embedding ---
20
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed sin/cos positional encodings (Vaswani et al., 2017).

    The (seq_len, emb_dim) table is precomputed once and registered as a
    buffer; forward slices the first ``seq_len`` rows and broadcasts them
    over the batch.
    """

    def __init__(self, seq_len, emb_dim):
        super().__init__()
        self.seq_len = seq_len
        self.emb_dim = emb_dim

        position = torch.arange(0, seq_len).unsqueeze(1)
        # Geometric frequency ladder: 10000**(-2i / emb_dim) per pair.
        div_term = torch.exp(torch.arange(0, emb_dim, 2) * -(math.log(10000.0) / emb_dim))

        pe = torch.zeros(seq_len, emb_dim)
        pe[:, 0::2] = torch.sin(position * div_term)  # even channels
        pe[:, 1::2] = torch.cos(position * div_term)  # odd channels

        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: (batch_size, seq_len, emb_dim) or (batch_size, seq_len);
        # only the leading two dimensions are read.
        batch_size, seq_len = x.shape[0], x.shape[1]
        table = self.pe[:seq_len]
        return table.unsqueeze(0).expand(batch_size, seq_len, -1).to(x.device)
39
+
40
+ # --- RoPE ---
41
class RoPE(nn.Module):
    """Rotary positional embedding for (batch, seq, heads, head_dim) tensors.

    Rotation factors for ``seq_len`` positions are precomputed once and
    stored as a complex buffer; forward rotates adjacent feature pairs in
    the complex plane.
    """

    def __init__(self, head_dim, seq_len, theta=10000.0, device='cpu', dtype=torch.float32):
        super().__init__()
        self.dtype = dtype
        self.device = device
        assert head_dim % 2 == 0, "head_dim must be even"
        # Per-pair inverse frequencies 1 / theta**(2i / head_dim).
        pair_idx = torch.arange(0, head_dim, 2, device=device, dtype=dtype)
        inv_freq = 1.0 / (theta ** (pair_idx / head_dim))
        positions = torch.arange(seq_len, device=device)
        angles = torch.outer(positions, inv_freq)
        # Unit-magnitude complex rotations, shape (seq_len, head_dim // 2).
        self.register_buffer("freq_complex", torch.polar(torch.ones_like(angles), angles))

    def forward(self, x):
        batch_size, seq_len, num_head, emb_dim = x.shape
        assert emb_dim % 2 == 0, "emb_dim must be even"
        # Interpret adjacent feature pairs as complex numbers.
        pairs = torch.view_as_complex(x.view(batch_size, seq_len, num_head, emb_dim // 2, 2))
        # (1, seq_len, 1, head_dim // 2) broadcasts over batch and heads.
        rot = self.freq_complex[:seq_len].unsqueeze(0).unsqueeze(2)
        rotated = pairs * rot
        out = torch.view_as_real(rotated).contiguous().view(batch_size, seq_len, num_head, emb_dim)
        return out.to(device=self.device, dtype=self.dtype)
@@ -0,0 +1,25 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import tiktoken
4
+
5
# Byte-pair-encoding (BPE) embedding helper built on a tiktoken tokenizer.
class Embedding_using_tiktoken:
    """Wrap a tiktoken encoding and turn text into embedding vectors.

    NOTE(review): ``data`` and ``embedding_dim`` are accepted by __init__ but
    never used there — only ``model`` matters. Confirm whether they should be
    stored or dropped from the signature.
    """
    def __init__(self,data,embedding_dim,model: str):
        # model: a tiktoken encoding name, e.g. "cl100k_base".
        self.tokenizer = tiktoken.get_encoding(model)


    def encoding(self,data,embedding_dim):
        """Tokenize ``data`` and embed the resulting token ids.

        NOTE(review): a fresh, randomly initialised nn.Embedding is created on
        every call, so the returned vectors are untrained and differ between
        calls — confirm this is intended (demo/bootstrap use) rather than a bug.
        """
        max_token_id = self.tokenizer.n_vocab
        embedding_layer = nn.Embedding(num_embeddings = max_token_id, embedding_dim = embedding_dim)
        tensors = torch.tensor(self.tokenizer.encode(data))
        embedded = embedding_layer(tensors)
        return embedded

    def decoding(self,data):
        """Decode a sequence of token ids back to text."""
        return self.tokenizer.decode(data)

    def vocab_size(self):
        """Return the tokenizer's vocabulary size."""
        return self.tokenizer.n_vocab

    def model_list(self):
        """Return all encoding names known to tiktoken."""
        return tiktoken.list_encoding_names()
@@ -0,0 +1,25 @@
1
+ [project]
2
+ name = "Stackformer"
3
+ version = "0.1.0"
4
+ description = "Modular transformer blocks built in PyTorch"
5
+ readme = "README.md"
6
+ requires-python = ">=3.9"
7
+ license = {text = "MIT"}
8
+
9
+ authors = [
10
+ {name = "Gurumurthy", email = "gurumurthy.00300@gmail.com"}
11
+ ]
12
+
13
+ dependencies = [
14
+ "torch>=2.6",
15
+ "tqdm>=4.67"
16
+ ]
17
+
18
+ [project.urls]
19
+ "Repository" = "https://github.com/Gurumurthy30/Stackformer"
20
+ "Issue Tracker" = "https://github.com/Gurumurthy30/Stackformer/issues"
21
+ "Discussions" = "https://github.com/Gurumurthy30/Stackformer/discussions"
22
+
23
+ [build-system]
24
+ requires = ["setuptools>=61", "wheel"]
25
+ build-backend = "setuptools.build_meta"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,31 @@
1
+ from setuptools import setup, find_packages
2
+
3
# Packaging entry point for Stackformer.
# NOTE(review): this metadata duplicates pyproject.toml ([project] table);
# keep version, dependencies and URLs in sync between the two when releasing.
setup(
    name="Stackformer",
    version="0.1.0",
    description="Modular transformer blocks built in PyTorch",
    # long_description=open("README.md", "r", encoding="utf-8").read(),
    # long_description_content_type="text/markdown",
    author="Gurumurthy",
    author_email="gurumurthy.00300@gmail.com",
    url="https://github.com/Gurumurthy30/Stackformer",
    project_urls={
        "Repository": "https://github.com/Gurumurthy30/Stackformer",
        "Issue Tracker": "https://github.com/Gurumurthy30/Stackformer/issues",
        "Discussions": "https://github.com/Gurumurthy30/Stackformer/discussions",
    },
    license="MIT",
    python_requires=">=3.9",
    # Ship everything except tests and examples.
    packages=find_packages(exclude=["tests", "examples"]),
    install_requires=[
        "torch>=2.6",
        "tqdm>=4.67",
    ],
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
)