Stackformer 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stackformer-0.1.0/LICENSE +21 -0
- stackformer-0.1.0/PKG-INFO +75 -0
- stackformer-0.1.0/README.md +54 -0
- stackformer-0.1.0/Stackformer.egg-info/PKG-INFO +75 -0
- stackformer-0.1.0/Stackformer.egg-info/SOURCES.txt +18 -0
- stackformer-0.1.0/Stackformer.egg-info/dependency_links.txt +1 -0
- stackformer-0.1.0/Stackformer.egg-info/requires.txt +2 -0
- stackformer-0.1.0/Stackformer.egg-info/top_level.txt +2 -0
- stackformer-0.1.0/models/GPT_2.py +238 -0
- stackformer-0.1.0/models/__init__.py +0 -0
- stackformer-0.1.0/modules/Attention.py +533 -0
- stackformer-0.1.0/modules/Feed_forward.py +59 -0
- stackformer-0.1.0/modules/Normalization.py +41 -0
- stackformer-0.1.0/modules/__init__.py +0 -0
- stackformer-0.1.0/modules/mask.py +36 -0
- stackformer-0.1.0/modules/position_embedding.py +61 -0
- stackformer-0.1.0/modules/tokenizer.py +25 -0
- stackformer-0.1.0/pyproject.toml +25 -0
- stackformer-0.1.0/setup.cfg +4 -0
- stackformer-0.1.0/setup.py +31 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 GURUMURTHY
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: Stackformer
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Modular transformer blocks built in PyTorch
|
|
5
|
+
Home-page: https://github.com/Gurumurthy30/Stackformer
|
|
6
|
+
Author: Gurumurthy
|
|
7
|
+
Author-email: Gurumurthy <gurumurthy.00300@gmail.com>
|
|
8
|
+
License: MIT
|
|
9
|
+
Project-URL: Repository, https://github.com/Gurumurthy30/Stackformer
|
|
10
|
+
Project-URL: Issue Tracker, https://github.com/Gurumurthy30/Stackformer/issues
|
|
11
|
+
Project-URL: Discussions, https://github.com/Gurumurthy30/Stackformer/discussions
|
|
12
|
+
Requires-Python: >=3.9
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
License-File: LICENSE
|
|
15
|
+
Requires-Dist: torch>=2.6
|
|
16
|
+
Requires-Dist: tqdm>=4.67
|
|
17
|
+
Dynamic: author
|
|
18
|
+
Dynamic: home-page
|
|
19
|
+
Dynamic: license-file
|
|
20
|
+
Dynamic: requires-python
|
|
21
|
+
|
|
22
|
+
## 🧱 Stackformer
|
|
23
|
+
|
|
24
|
+
**Stackformer** is a modular transformer-building framework written entirely in PyTorch. It is designed primarily for experimentation, providing various transformer blocks such as attention mechanisms, normalization layers, feed-forward networks, and a simple model architecture. The project is a work-in-progress with plans for further enhancements and expansions.
|
|
25
|
+
|
|
26
|
+
---
|
|
27
|
+
|
|
28
|
+
## 📖 About Me
|
|
29
|
+
|
|
30
|
+
My name is **Gurumurthy**, and I am a final-year Bachelor of Engineering student from India. I created this library as my own side project to showcase my skills and knowledge in deep learning and transformer architectures.
|
|
31
|
+
|
|
32
|
+
I am also interested and free to work with others on different projects for knowledge sharing and building connections.
|
|
33
|
+
|
|
34
|
+
---
|
|
35
|
+
|
|
36
|
+
## 🌟 Features
|
|
37
|
+
|
|
38
|
+
- Multiple attention mechanisms including multi-head, group query, linear, local, and KV cache variants
|
|
39
|
+
- Token embedding via `tiktoken`
|
|
40
|
+
- Absolute and sinusoidal positional embeddings
|
|
41
|
+
- Normalization layers like LayerNorm and RMSNorm
|
|
42
|
+
- Several feed-forward network variants with activations such as ReLU, GELU, SiLU, LeakyReLU, and Sigmoid
|
|
43
|
+
- A simple GPT-style transformer model implementation
|
|
44
|
+
|
|
45
|
+
---
|
|
46
|
+
|
|
47
|
+
## 📁 Project Structure
|
|
48
|
+
|
|
49
|
+
stackformer/ \
|
|
50
|
+
|-- modules/ \
|
|
51
|
+
| |-- tokenizer.py # Token embedding using tiktoken \
|
|
52
|
+
| |-- position_embedding.py # Absolute and sinusoidal embeddings \
|
|
53
|
+
| |-- Attention.py # Attention mechanisms \
|
|
54
|
+
| |-- Normalization.py # LayerNorm and RMSNorm \
|
|
55
|
+
| |-- Feed_forward.py # Feed-forward layers with various activations \
|
|
56
|
+
|-- models/ \
|
|
57
|
+
| -- GPT_2.py # GPT-style transformer stack model \
|
|
58
|
+
-- trainer.py # Training loop and utilities \
|
|
59
|
+
|
|
60
|
+
---
|
|
61
|
+
|
|
62
|
+
## 💻 Installation
|
|
63
|
+
|
|
64
|
+
Clone the repository and install in development mode:
|
|
65
|
+
|
|
66
|
+
```bash
|
|
67
|
+
git clone https://github.com/Gurumurthy30/Stackformer
|
|
68
|
+
cd Stackformer
|
|
69
|
+
pip install -e .
|
|
70
|
+
```
|
|
71
|
+
|
|
72
|
+
---
|
|
73
|
+
|
|
74
|
+
## 🚀 Future Plans
|
|
75
|
+
Currently, I am working on improving and optimizing the existing components while fixing known bugs and issues. After stabilizing the current modules, I plan to add more advanced blocks like Mixture of Experts (MoE), mask handling, and other essential transformer components. Eventually, I will expand the library by developing more comprehensive model architectures.
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
## 🧱 Stackformer
|
|
2
|
+
|
|
3
|
+
**Stackformer** is a modular transformer-building framework written entirely in PyTorch. It is designed primarily for experimentation, providing various transformer blocks such as attention mechanisms, normalization layers, feed-forward networks, and a simple model architecture. The project is a work-in-progress with plans for further enhancements and expansions.
|
|
4
|
+
|
|
5
|
+
---
|
|
6
|
+
|
|
7
|
+
## 📖 About Me
|
|
8
|
+
|
|
9
|
+
My name is **Gurumurthy**, and I am a final-year Bachelor of Engineering student from India. I created this library as my own side project to showcase my skills and knowledge in deep learning and transformer architectures.
|
|
10
|
+
|
|
11
|
+
I am also interested and free to work with others on different projects for knowledge sharing and building connections.
|
|
12
|
+
|
|
13
|
+
---
|
|
14
|
+
|
|
15
|
+
## 🌟 Features
|
|
16
|
+
|
|
17
|
+
- Multiple attention mechanisms including multi-head, group query, linear, local, and KV cache variants
|
|
18
|
+
- Token embedding via `tiktoken`
|
|
19
|
+
- Absolute and sinusoidal positional embeddings
|
|
20
|
+
- Normalization layers like LayerNorm and RMSNorm
|
|
21
|
+
- Several feed-forward network variants with activations such as ReLU, GELU, SiLU, LeakyReLU, and Sigmoid
|
|
22
|
+
- A simple GPT-style transformer model implementation
|
|
23
|
+
|
|
24
|
+
---
|
|
25
|
+
|
|
26
|
+
## 📁 Project Structure
|
|
27
|
+
|
|
28
|
+
stackformer/ \
|
|
29
|
+
|-- modules/ \
|
|
30
|
+
| |-- tokenizer.py # Token embedding using tiktoken \
|
|
31
|
+
| |-- position_embedding.py # Absolute and sinusoidal embeddings \
|
|
32
|
+
| |-- Attention.py # Attention mechanisms \
|
|
33
|
+
| |-- Normalization.py # LayerNorm and RMSNorm \
|
|
34
|
+
| |-- Feed_forward.py # Feed-forward layers with various activations \
|
|
35
|
+
|-- models/ \
|
|
36
|
+
| -- GPT_2.py # GPT-style transformer stack model \
|
|
37
|
+
-- trainer.py # Training loop and utilities \
|
|
38
|
+
|
|
39
|
+
---
|
|
40
|
+
|
|
41
|
+
## 💻 Installation
|
|
42
|
+
|
|
43
|
+
Clone the repository and install in development mode:
|
|
44
|
+
|
|
45
|
+
```bash
|
|
46
|
+
git clone https://github.com/Gurumurthy30/Stackformer
|
|
47
|
+
cd Stackformer
|
|
48
|
+
pip install -e .
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
---
|
|
52
|
+
|
|
53
|
+
## 🚀 Future Plans
|
|
54
|
+
Currently, I am working on improving and optimizing the existing components while fixing known bugs and issues. After stabilizing the current modules, I plan to add more advanced blocks like Mixture of Experts (MoE), mask handling, and other essential transformer components. Eventually, I will expand the library by developing more comprehensive model architectures.
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: Stackformer
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Modular transformer blocks built in PyTorch
|
|
5
|
+
Home-page: https://github.com/Gurumurthy30/Stackformer
|
|
6
|
+
Author: Gurumurthy
|
|
7
|
+
Author-email: Gurumurthy <gurumurthy.00300@gmail.com>
|
|
8
|
+
License: MIT
|
|
9
|
+
Project-URL: Repository, https://github.com/Gurumurthy30/Stackformer
|
|
10
|
+
Project-URL: Issue Tracker, https://github.com/Gurumurthy30/Stackformer/issues
|
|
11
|
+
Project-URL: Discussions, https://github.com/Gurumurthy30/Stackformer/discussions
|
|
12
|
+
Requires-Python: >=3.9
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
License-File: LICENSE
|
|
15
|
+
Requires-Dist: torch>=2.6
|
|
16
|
+
Requires-Dist: tqdm>=4.67
|
|
17
|
+
Dynamic: author
|
|
18
|
+
Dynamic: home-page
|
|
19
|
+
Dynamic: license-file
|
|
20
|
+
Dynamic: requires-python
|
|
21
|
+
|
|
22
|
+
## 🧱 Stackformer
|
|
23
|
+
|
|
24
|
+
**Stackformer** is a modular transformer-building framework written entirely in PyTorch. It is designed primarily for experimentation, providing various transformer blocks such as attention mechanisms, normalization layers, feed-forward networks, and a simple model architecture. The project is a work-in-progress with plans for further enhancements and expansions.
|
|
25
|
+
|
|
26
|
+
---
|
|
27
|
+
|
|
28
|
+
## 📖 About Me
|
|
29
|
+
|
|
30
|
+
My name is **Gurumurthy**, and I am a final-year Bachelor of Engineering student from India. I created this library as my own side project to showcase my skills and knowledge in deep learning and transformer architectures.
|
|
31
|
+
|
|
32
|
+
I am also interested and free to work with others on different projects for knowledge sharing and building connections.
|
|
33
|
+
|
|
34
|
+
---
|
|
35
|
+
|
|
36
|
+
## 🌟 Features
|
|
37
|
+
|
|
38
|
+
- Multiple attention mechanisms including multi-head, group query, linear, local, and KV cache variants
|
|
39
|
+
- Token embedding via `tiktoken`
|
|
40
|
+
- Absolute and sinusoidal positional embeddings
|
|
41
|
+
- Normalization layers like LayerNorm and RMSNorm
|
|
42
|
+
- Several feed-forward network variants with activations such as ReLU, GELU, SiLU, LeakyReLU, and Sigmoid
|
|
43
|
+
- A simple GPT-style transformer model implementation
|
|
44
|
+
|
|
45
|
+
---
|
|
46
|
+
|
|
47
|
+
## 📁 Project Structure
|
|
48
|
+
|
|
49
|
+
stackformer/ \
|
|
50
|
+
|-- modules/ \
|
|
51
|
+
| |-- tokenizer.py # Token embedding using tiktoken \
|
|
52
|
+
| |-- position_embedding.py # Absolute and sinusoidal embeddings \
|
|
53
|
+
| |-- Attention.py # Attention mechanisms \
|
|
54
|
+
| |-- Normalization.py # LayerNorm and RMSNorm \
|
|
55
|
+
| |-- Feed_forward.py # Feed-forward layers with various activations \
|
|
56
|
+
|-- models/ \
|
|
57
|
+
| -- GPT_2.py # GPT-style transformer stack model \
|
|
58
|
+
-- trainer.py # Training loop and utilities \
|
|
59
|
+
|
|
60
|
+
---
|
|
61
|
+
|
|
62
|
+
## 💻 Installation
|
|
63
|
+
|
|
64
|
+
Clone the repository and install in development mode:
|
|
65
|
+
|
|
66
|
+
```bash
|
|
67
|
+
git clone https://github.com/Gurumurthy30/Stackformer
|
|
68
|
+
cd Stackformer
|
|
69
|
+
pip install -e .
|
|
70
|
+
```
|
|
71
|
+
|
|
72
|
+
---
|
|
73
|
+
|
|
74
|
+
## 🚀 Future Plans
|
|
75
|
+
Currently, I am working on improving and optimizing the existing components while fixing known bugs and issues. After stabilizing the current modules, I plan to add more advanced blocks like Mixture of Experts (MoE), mask handling, and other essential transformer components. Eventually, I will expand the library by developing more comprehensive model architectures.
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
setup.py
|
|
5
|
+
Stackformer.egg-info/PKG-INFO
|
|
6
|
+
Stackformer.egg-info/SOURCES.txt
|
|
7
|
+
Stackformer.egg-info/dependency_links.txt
|
|
8
|
+
Stackformer.egg-info/requires.txt
|
|
9
|
+
Stackformer.egg-info/top_level.txt
|
|
10
|
+
models/GPT_2.py
|
|
11
|
+
models/__init__.py
|
|
12
|
+
modules/Attention.py
|
|
13
|
+
modules/Feed_forward.py
|
|
14
|
+
modules/Normalization.py
|
|
15
|
+
modules/__init__.py
|
|
16
|
+
modules/mask.py
|
|
17
|
+
modules/position_embedding.py
|
|
18
|
+
modules/tokenizer.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
|
|
6
|
+
# --- position embedding ---
|
|
7
|
+
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal position encodings.

    Builds the classic Transformer position table once at construction:
    even channels carry sin terms, odd channels carry cos terms, with
    geometrically spaced frequencies.
    """

    def __init__(self, seq_len, emb_dim):
        super().__init__()
        self.seq_len = seq_len
        self.emb_dim = emb_dim

        positions = torch.arange(0, seq_len).unsqueeze(1)
        # exp(-log(10000) * 2i / d): one frequency per sin/cos channel pair.
        freqs = torch.exp(torch.arange(0, emb_dim, 2) * -(math.log(10000.0) / emb_dim))

        table = torch.zeros(seq_len, emb_dim)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)

        # Buffer: saved in state_dict, moved with .to(), never trained.
        self.register_buffer("pe", table)

    def forward(self, x):
        """Return encodings shaped (batch, seq, emb_dim) for input x.

        x may be token ids (batch, seq) or embeddings (batch, seq, emb_dim);
        only its first two dimensions are consulted.
        """
        batch, seq = x.shape[0], x.shape[1]
        return self.pe[:seq].unsqueeze(0).expand(batch, seq, -1).to(x.device)
|
|
26
|
+
|
|
27
|
+
# --- Multi Head Attention ---
|
|
28
|
+
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Splits the embedding into `num_heads` independent heads, runs scaled
    dot-product attention with a lower-triangular (causal) mask in each,
    then concatenates the heads and applies an output projection.
    """

    def __init__(self, Emb_dim, num_heads, dropout=0.1, device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.device = device
        self.num_heads = num_heads
        self.head_dim = Emb_dim // num_heads

        self.key = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)

        # 1/sqrt(head_dim) keeps pre-softmax logits well conditioned.
        self.scale = math.sqrt(self.head_dim)
        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (batch, seq, Emb_dim) -> (batch, seq, Emb_dim)."""
        bsz, seq, _ = x.shape

        def split_heads(t):
            # (B, S, E) -> (B, heads, S, head_dim)
            return t.view(bsz, seq, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        logits = (q @ k.transpose(-2, -1)) / self.scale

        # Mask is rebuilt per call on x.device, so any sequence length up to
        # the model context works and .to(device) moves are safe.
        future = torch.triu(torch.ones(seq, seq, dtype=torch.bool, device=x.device), diagonal=1)
        logits = logits.masked_fill(future[None, None, :, :], float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))
        context = weights @ v

        # Merge heads back: (B, heads, S, hd) -> (B, S, E).
        context = context.transpose(1, 2).contiguous().view(bsz, seq, self.Emb_dim)
        return self.out_proj(context)
|
|
72
|
+
|
|
73
|
+
# --- Feed Forward ---
|
|
74
|
+
class FF_ReLU(nn.Module):
    """Position-wise feed-forward network with a ReLU nonlinearity.

    Expands Emb_dim -> hidden_dim, applies ReLU and dropout, then
    projects back down to Emb_dim. Applied independently to every token.
    """

    def __init__(self, Emb_dim, hidden_dim, dropout=0.1, device='cpu', dtype=torch.float32):
        super().__init__()
        # Attribute kept as `relu` so state_dict keys stay compatible.
        self.relu = nn.Sequential(
            nn.Linear(Emb_dim, hidden_dim, device=device, dtype=dtype),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, Emb_dim, device=device, dtype=dtype),
        )

    def forward(self, x):
        """Apply the MLP token-wise; the input shape is preserved."""
        return self.relu(x)
|
|
86
|
+
|
|
87
|
+
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable affine.

    Normalizes each token vector to zero mean and unit variance, then
    applies per-channel scale (`weight`) and shift (`bias`).
    """

    def __init__(self, Emb_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.eps = eps  # added under the sqrt for numerical stability
        self.weight = nn.Parameter(torch.ones(Emb_dim, device=device, dtype=dtype))
        self.bias = nn.Parameter(torch.zeros(Emb_dim, device=device, dtype=dtype))

    def forward(self, x):
        """Normalize along dim=-1 and apply the affine transform."""
        mu = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, matching nn.LayerNorm's convention.
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = (x - mu) * torch.rsqrt(var + self.eps)
        return normalized * self.weight + self.bias
|
|
99
|
+
|
|
100
|
+
# --- Transformer Block ---
|
|
101
|
+
class Block(nn.Module):
    """Pre-norm transformer block.

    Two residual sub-layers: LayerNorm -> causal self-attention, then
    LayerNorm -> feed-forward, each added back onto its input.
    """

    def __init__(self, Emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.attention = MultiHeadAttention(Emb_dim, num_heads, dropout, device=device, dtype=dtype)
        self.norm1 = LayerNorm(Emb_dim, eps=eps, device=device, dtype=dtype)
        self.ff_relu = FF_ReLU(Emb_dim, hidden_dim, dropout, device=device, dtype=dtype)
        self.norm2 = LayerNorm(Emb_dim, eps=eps, device=device, dtype=dtype)

    def forward(self, x):
        """x: (batch, seq, Emb_dim) -> same shape."""
        # Attention sub-layer with pre-norm and residual connection.
        x = x + self.attention(self.norm1(x))
        # Feed-forward sub-layer with pre-norm and residual connection.
        x = x + self.ff_relu(self.norm2(x))
        return x
|
|
123
|
+
|
|
124
|
+
# --- Encoder ---
|
|
125
|
+
class Encoder(nn.Module):
    """A stack of `num_layers` identical transformer Blocks applied in order."""

    def __init__(self, num_layers, Emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        blocks = [
            Block(Emb_dim, num_heads, dropout, hidden_dim, eps, device=device, dtype=dtype)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Thread the activations through each block sequentially."""
        for block in self.layers:
            x = block(x)
        return x
|
|
137
|
+
|
|
138
|
+
class GPTModel(nn.Module):
    """GPT-style decoder-only language model.

    Token embedding + sinusoidal position embedding -> dropout -> a stack
    of pre-norm transformer blocks -> final LayerNorm -> vocab projection.

    Args:
        vocab_size: size of the token vocabulary.
        num_layers: number of transformer blocks.
        Emb_dim: embedding / model width.
        num_heads: attention heads per block.
        seq_len: maximum context length.
        dropout: dropout probability used throughout.
        hidden_dim: feed-forward hidden width.
        eps: LayerNorm epsilon.
        device, dtype: placement/precision for all parameters.
    """

    def __init__(self, vocab_size, num_layers, Emb_dim, num_heads, seq_len,
                 dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        # Bug fix: the original read self.dtype / self.device in __init__ and
        # self.seq_len in generate() without ever assigning them, raising
        # AttributeError on construction. Store the config first.
        self.device = device
        self.dtype = dtype
        self.seq_len = seq_len

        # --- Token embedding ---
        self.embedding = nn.Embedding(vocab_size, Emb_dim, dtype=dtype, device=device)

        # --- Embedding dropout ---
        self.emb_dropout = nn.Dropout(dropout)

        # --- Sinusoidal position embedding (non-learned) ---
        self.position_embedding = SinusoidalPositionalEmbedding(
            emb_dim=Emb_dim,
            seq_len=seq_len
        )

        # --- Transformer stack ---
        self.encoder = Encoder(
            num_layers=num_layers,
            Emb_dim=Emb_dim,
            num_heads=num_heads,
            dropout=dropout,
            hidden_dim=hidden_dim,
            eps=eps,
            device=device,
            dtype=dtype
        )

        # --- Final norm ---
        self.final_norm = LayerNorm(Emb_dim, eps=eps, device=device, dtype=dtype)

        # --- Output projection onto vocabulary logits ---
        self.lm_head = nn.Linear(Emb_dim, vocab_size, bias=False,
                                 dtype=dtype, device=device)

    def forward(self, x):
        """Map token ids (batch, seq) to logits (batch, seq, vocab_size)."""
        emb = self.embedding(x)              # (batch, seq, Emb_dim)
        pos = self.position_embedding(x)     # (batch, seq, Emb_dim)
        x = self.emb_dropout(emb + pos)
        x = self.encoder(x)
        x = self.final_norm(x)
        return self.lm_head(x)

    @torch.no_grad()
    def generate(self, prompt_ids, max_new_tokens=50, temperature=1.0, top_k=None, top_p=1.0, eos_token_id=None):
        """Autoregressively sample up to `max_new_tokens` continuation tokens.

        Args:
            prompt_ids: (seq,) or (batch, seq) tensor of token ids.
            max_new_tokens: maximum tokens to append.
            temperature: softmax temperature (>0); 1.0 disables scaling.
            top_k: keep only the k highest logits, if given.
            top_p: nucleus-sampling cumulative-probability cutoff (<1.0 enables).
            eos_token_id: stop early once every sequence emits this id.

        Returns:
            (batch, prompt_len + generated) tensor of token ids.
        """
        self.eval()
        if prompt_ids.dim() == 1:
            prompt_ids = prompt_ids.unsqueeze(0)  # (1, seq)

        generated = prompt_ids.clone()
        max_context_len = self.seq_len

        for _ in range(max_new_tokens):
            # Sliding window: only the most recent context fits the model.
            if generated.size(1) > max_context_len:
                input_ids = generated[:, -max_context_len:]
            else:
                input_ids = generated

            logits = self.forward(input_ids)[:, -1, :]  # (batch, vocab)

            # --- Temperature scaling ---
            if temperature != 1.0:
                logits = logits / temperature

            # --- Top-k filtering: keep the k best logits, -inf the rest ---
            if top_k is not None and top_k > 0:
                topk_vals, topk_indices = torch.topk(logits, top_k)
                mask = torch.full_like(logits, float('-inf'))
                mask.scatter_(dim=-1, index=topk_indices, src=topk_vals)
                logits = mask

            # --- Top-p (nucleus) filtering ---
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                probs = F.softmax(sorted_logits, dim=-1)
                cum_probs = torch.cumsum(probs, dim=-1)

                sorted_mask = cum_probs > top_p
                # Shift right so the first token past the cutoff is kept.
                sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                sorted_mask[..., 0] = 0

                # sorted_indices is a permutation, so every position is
                # written exactly once by this scatter.
                indices_to_remove = sorted_mask.scatter(dim=-1, index=sorted_indices, src=sorted_mask)
                logits = logits.masked_fill(indices_to_remove, float('-inf'))

            # Sample the next token from the filtered distribution.
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (batch, 1)
            generated = torch.cat([generated, next_token], dim=-1)

            # Bug fix: .item() crashed for batch > 1; stop once every
            # sequence has produced the EOS token (identical for batch 1).
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated
|
|
File without changes
|
|
@@ -0,0 +1,533 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
class Self_Attention(nn.Module):
    """Single-head causal self-attention.

    Projects the input into query/key/value spaces, applies scaled
    dot-product attention with a causal (lower-triangular) mask, and
    returns the re-projected context vectors.
    """

    def __init__(self, Emb_dim, dropout, dtype=torch.float32, device='cpu'):
        super().__init__()
        self.device = device
        # 1/sqrt(d) scaling keeps the pre-softmax logits well conditioned.
        self.scale = torch.tensor(Emb_dim ** 0.5, dtype=dtype, device=device)

        self.key = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)

        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (Batch_size, Seq_len, Emb_dim) -> same shape."""
        Batch_size, Seq_len, Emb_dim = x.size()

        Querys = self.query(x)  # (Batch_size, Seq_len, D)
        Keys = self.key(x)      # (Batch_size, Seq_len, D)
        Values = self.value(x)  # (Batch_size, Seq_len, D)

        # Attention scores
        scores = Querys @ Keys.transpose(-2, -1) / self.scale  # (Batch_size, Seq_len, Seq_len)

        # Bug fix: build the mask on x.device rather than the construction-time
        # self.device, so the module still works after model.to(device).
        causal_mask = torch.triu(torch.ones(Seq_len, Seq_len, dtype=torch.bool, device=x.device), diagonal=1)
        scores = scores.masked_fill(causal_mask, float('-inf'))  # hide future tokens

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = attn @ Values  # (Batch_size, Seq_len, D)
        return self.out_proj(out)
|
|
35
|
+
|
|
36
|
+
class Multi_Head_Attention(nn.Module):
    """Causal multi-head self-attention.

    Splits the embedding into `num_heads` heads, runs scaled dot-product
    attention with a causal mask per head, then concatenates and projects.
    """

    def __init__(self, Emb_dim, num_heads, dropout, device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.num_heads = num_heads
        self.device = device
        self.head_dim = Emb_dim // num_heads

        self.key = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)

        # 1/sqrt(head_dim) scaling of attention logits.
        self.scale = torch.tensor(self.head_dim ** 0.5, device=device, dtype=dtype)
        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (Batch_size, Seq_len, Emb_dim) -> same shape."""
        Batch_size, Seq_len, _ = x.shape

        # Generate Q, K, V and reshape for multi-head attention.
        Keys = self.key(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)    # (B, nh, S, hd)
        Querys = self.query(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Values = self.value(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute attention scores.
        scores = (Querys @ Keys.transpose(-2, -1)) / self.scale  # (B, nh, S, S)

        # Bug fix: build the mask on x.device rather than the construction-time
        # self.device, so the module still works after model.to(device).
        causal_mask = torch.triu(torch.ones(Seq_len, Seq_len, dtype=torch.bool, device=x.device), diagonal=1)
        scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))

        # Softmax over keys, then attention dropout.
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of values.
        out = attn @ Values  # (B, nh, S, hd)

        # Concatenate heads and project back to Emb_dim.
        out = out.transpose(1, 2).contiguous().view(Batch_size, Seq_len, self.Emb_dim)
        return self.out_proj(out)
|
|
80
|
+
|
|
81
|
+
class Cross_MultiHead_Attention(nn.Module):
    """Multi-head attention with an optional external key/value source.

    With `context=None` it behaves as causal self-attention; with a
    `context` tensor it performs (unmasked) cross-attention, e.g. decoder
    queries over encoder outputs.
    """

    def __init__(self, Emb_dim, num_heads, dropout, device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.device = device
        self.num_heads = num_heads
        self.head_dim = Emb_dim // num_heads

        # Query, Key, Value projections.
        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.key = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)

        self.scale = torch.tensor(self.head_dim ** 0.5, device=device, dtype=dtype)

        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None):
        """
        x: (Batch_size, query_seq_len, Emb_dim) — query input (e.g., decoder hidden states)
        context: (Batch_size, KV_seq_len, Emb_dim) — source for keys/values (e.g., encoder
            output). If None, falls back to causal self-attention over x.
        """
        Batch_size, query_seq_len, _ = x.shape
        is_self_attention = context is None
        context = x if is_self_attention else context
        KV_seq_len = context.shape[1]

        # Project Q from x, K/V from the context.
        Querys = self.query(x).view(Batch_size, query_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Keys = self.key(context).view(Batch_size, KV_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Values = self.value(context).view(Batch_size, KV_seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores: (B, nh, query_seq_len, KV_seq_len).
        scores = (Querys @ Keys.transpose(-2, -1)) / self.scale

        # Bug fix: the original unconditionally applied a
        # (query_seq_len, query_seq_len) causal mask, which crashed whenever
        # KV_seq_len != query_seq_len — i.e. in every real cross-attention
        # call. Causal masking only makes sense when keys ARE the query
        # sequence, so apply it only in the self-attention fallback.
        if is_self_attention:
            causal_mask = torch.triu(
                torch.ones(query_seq_len, KV_seq_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            )
            scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = attn @ Values  # (B, nh, query_seq_len, hd)
        out = out.transpose(1, 2).contiguous().view(Batch_size, query_seq_len, self.Emb_dim)

        return self.out_proj(out)
|
|
128
|
+
|
|
129
|
+
class Multi_query_Attention(nn.Module):
    """Multi-query attention: per-head queries, one shared key/value head.

    Queries are split into `num_heads` heads as usual, but K and V are
    projected to a single head of size head_dim and broadcast across the
    query heads — reducing KV projection cost and cache size.
    """

    def __init__(self, Emb_dim, num_heads, dropout, device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.device = device
        self.num_heads = num_heads
        self.head_dim = Emb_dim // num_heads

        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        # K/V project to a single shared head.
        self.key = nn.Linear(Emb_dim, self.head_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, self.head_dim, bias=False, dtype=dtype, device=device)

        self.scale = torch.tensor(self.head_dim ** 0.5, device=device, dtype=dtype)

        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """x: (Batch_size, Seq_len, Emb_dim) -> same shape."""
        Batch_size, Seq_len, C = x.shape

        # Per-head queries; single shared K/V head broadcast over heads.
        Querys = self.query(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (B, nh, S, hd)
        Keys = self.key(x).unsqueeze(1)      # (B, 1, S, hd) — broadcasts over nh
        Values = self.value(x).unsqueeze(1)  # (B, 1, S, hd)

        # Compute attention scores; broadcasting expands the KV head.
        scores = (Querys @ Keys.transpose(-2, -1)) / self.scale  # (B, nh, S, S)

        # Bug fix: build the mask on x.device rather than the construction-time
        # self.device, so the module still works after model.to(device).
        causal_mask = torch.triu(torch.ones(Seq_len, Seq_len, dtype=torch.bool, device=x.device), diagonal=1)
        scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))

        # Softmax over keys, then attention dropout.
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of (shared) values.
        out = attn @ Values  # (B, nh, S, hd)

        # Concatenate heads and project.
        out = out.transpose(1, 2).contiguous().view(Batch_size, Seq_len, self.Emb_dim)

        return self.out_proj(out)
|
|
177
|
+
|
|
178
|
+
class Group_query_Attention(nn.Module):
    """Grouped-Query Attention: num_query_heads query heads share
    num_kv_heads key/value heads (each KV head serves
    num_query_heads // num_kv_heads query heads).

    Args:
        Emb_dim: model width; must be divisible by num_query_heads.
        num_query_heads: number of query heads.
        num_kv_heads: number of shared key/value heads; must divide num_query_heads
            evenly for repeat_interleave below to reconstruct Emb_dim
            (not asserted here — TODO confirm callers guarantee this).
        dropout: dropout probability on attention weights.
        device, dtype: placement/precision of the projections.
    """
    def __init__(self, Emb_dim, num_query_heads, num_kv_heads, dropout,device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_query_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.device = device
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads

        # Per-head width is set by the query heads; each KV head uses the same width.
        self.head_dim = Emb_dim // num_query_heads
        # How many query heads share each KV head.
        self.num_queries_pre_kv = num_query_heads // num_kv_heads

        # K/V project to num_kv_heads * head_dim, which is narrower than Emb_dim.
        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False,dtype=dtype,device=device)
        self.key = nn.Linear(Emb_dim, self. num_kv_heads * self.head_dim, bias=False,dtype=dtype,device=device)
        self.value = nn.Linear(Emb_dim, self.num_kv_heads * self.head_dim, bias=False,dtype=dtype,device=device)

        # sqrt(head_dim) scaling factor, stored as a tensor on the target device.
        self.scale = torch.tensor(self.head_dim ** 0.5,device=device,dtype=dtype)

        self.out_proj = nn.Linear(Emb_dim, Emb_dim,dtype=dtype,device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Causal grouped-query self-attention over x of shape (B, T, Emb_dim)."""
        Batch_size, Seq_len, C = x.shape

        # Generate Q, K, V and reshape for Multiquery_Attention
        Querys = self.query(x).view(Batch_size, Seq_len, self.num_query_heads, self.head_dim).transpose(1, 2)  # (Batch_size, nqh, Seq_len, hd)
        Keys = self.key(x).view(Batch_size, Seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)  # (Batch_size, nkvh, Seq_len, hd)
        Values = self.value(x).view(Batch_size, Seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)  # (Batch_size, nkvh, Seq_len, hd)

        # Duplicate each KV head num_queries_pre_kv times along the head axis
        # so K/V line up one-to-one with the query heads.
        Keys = Keys.repeat_interleave(self.num_queries_pre_kv,dim=1)
        Values = Values.repeat_interleave(self.num_queries_pre_kv,dim=1)

        # Compute attention scores
        scores = (Querys @ Keys.transpose(-2, -1)) / self.scale  # (Batch_size, nh, Seq_len, Seq_len)

        # Mask out future positions (strict upper triangle).
        causal_mask = torch.triu(torch.ones(Seq_len, Seq_len, dtype=torch.bool, device=self.device), diagonal=1)
        scores = scores.masked_fill_(causal_mask[None, None, :, :], float('-inf'))

        # Apply softmax and dropout
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Apply attention to values
        out = attn @ Values  # (Batch_size, nh, Seq_len, hd)

        # Concatenate heads and project
        out = out.transpose(1, 2).contiguous().view(Batch_size, Seq_len, self.Emb_dim)  # (Batch_size, Seq_len, Emb_dim)

        return self.out_proj(out)
|
|
229
|
+
|
|
230
|
+
class Linear_Attention(nn.Module):
    """Causal linear attention with the elu(x)+1 feature map
    (Katharopoulos et al., "Transformers are RNNs").

    Attention is computed as phi(Q) @ cumsum(phi(K) V^T) normalized by
    phi(Q) . cumsum(phi(K)), which realizes causal attention without an
    explicit T x T score matrix.

    NOTE(review): the materialized outer-product tensor below is
    (B, nh, T, hd, hd), so memory grows with T * head_dim^2 — confirm this
    is acceptable for the intended sequence lengths.

    Args:
        Emb_dim: model width; must be divisible by num_heads.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the final projection output.
        eps: numerical-stability constant added to the denominator.
        device, dtype: placement/precision of the projections.
    """
    def __init__(self, Emb_dim, num_heads, dropout, eps = 1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.eps = eps
        self.device = device
        self.dtype=dtype
        self.num_heads = num_heads
        self.head_dim = Emb_dim // self.num_heads

        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False,dtype=dtype,device=device)
        self.key = nn.Linear(Emb_dim, Emb_dim, bias=False,dtype=dtype,device=device)
        self.value = nn.Linear(Emb_dim, Emb_dim, bias=False,dtype=dtype,device=device)

        self.out_proj = nn.Linear(Emb_dim, Emb_dim,dtype=dtype,device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Causal linear attention over x of shape (B, T, Emb_dim)."""
        Batch_size, Seq_len, _ = x.shape

        # Generate Q, K, V and reshape for multi-head attention
        Querys = self.query(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (Batch_size, nh, Seq_len, hd)
        Keys = self.key(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Values = self.value(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Positive feature map phi(x) = elu(x) + 1 (keeps the kernel non-negative).
        phi_q = F.elu(Querys) + 1.0
        phi_k = F.elu(Keys) + 1.0

        # Per-position outer products phi(k_t) v_t^T: (B, nh, T, hd, hd).
        kv_outer_product = torch.matmul(phi_k.unsqueeze(-1),Values.unsqueeze(-2))

        # Prefix sums over the sequence axis implement causality:
        # position t only sees contributions from positions <= t.
        s_cumulative = torch.cumsum(kv_outer_product, dim=2)
        z_cumulative = torch.cumsum(phi_k,dim=2)

        # Numerator phi(q_t)^T S_t and normalizer phi(q_t) . z_t (+ eps).
        numerator = torch.matmul(phi_q.unsqueeze(-2),s_cumulative).squeeze(-2)
        denominator = torch.sum(phi_q * z_cumulative,dim=-1,keepdim=True) + self.eps

        out = numerator / denominator
        out = out.transpose(1, 2).contiguous().view(Batch_size, Seq_len, self.Emb_dim)  # (Batch_size, Seq_len, Emb_dim)
        out = self.out_proj(out)
        return self.dropout(out)
|
|
272
|
+
|
|
273
|
+
class Multi_latent_Attention(nn.Module):
    """Multi-head Latent Attention (MLA, DeepSeek-V2 style): queries and
    keys/values are first compressed through low-rank latent projections,
    normalized, then expanded back to Emb_dim before standard causal SDPA.

    Args:
        Emb_dim: model width; must be divisible by num_heads.
        q_compressed_dim: latent width for the query path.
        kv_compressed_dim: shared latent width for keys and values.
        num_heads: number of attention heads.
        device, dtype: placement/precision of the projections.
        dropout: probability for attention dropout and output dropout.
    """

    def __init__(self, Emb_dim, q_compressed_dim, kv_compressed_dim, num_heads, device='cpu', dtype=torch.float32, dropout=0):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.device = device
        self.q_compressed_dim = q_compressed_dim
        self.kv_compressed_dim = kv_compressed_dim
        self.num_heads = num_heads
        self.head_dim = Emb_dim // self.num_heads

        # Query path: down-project -> LayerNorm -> up-project.
        self.W_dq = nn.Linear(Emb_dim, q_compressed_dim, bias=False, dtype=dtype, device=device)
        self.W_dq_norm = nn.LayerNorm(q_compressed_dim, dtype=dtype, device=device)
        self.W_uq = nn.Linear(q_compressed_dim, Emb_dim, bias=False, dtype=dtype, device=device)

        # Keys and values share one compressed latent.
        self.W_dkv = nn.Linear(Emb_dim, kv_compressed_dim, bias=False, dtype=dtype, device=device)
        self.W_dkv_norm = nn.LayerNorm(kv_compressed_dim, dtype=dtype, device=device)
        self.W_uk = nn.Linear(kv_compressed_dim, Emb_dim, dtype=dtype, device=device)
        self.W_uv = nn.Linear(kv_compressed_dim, Emb_dim, dtype=dtype, device=device)

        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Causal latent attention over x of shape (B, T, Emb_dim)."""
        Batch_size, Seq_len, C = x.shape

        compressed_q_latent = self.W_dq(x)
        q_final = self.W_uq(self.W_dq_norm(compressed_q_latent))

        compressed_kv_latent_norm = self.W_dkv_norm(self.W_dkv(x))
        k_final = self.W_uk(compressed_kv_latent_norm)
        v_final = self.W_uv(compressed_kv_latent_norm)

        Querys = q_final.view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Keys = k_final.view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Values = v_final.view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # BUG FIX: attention dropout must be disabled during evaluation.
        # The original passed dropout_p=self.dropout.p unconditionally, so
        # F.scaled_dot_product_attention kept dropping attention weights even
        # after model.eval() (nn.Dropout gates itself, but the functional
        # dropout_p argument does not).
        out = F.scaled_dot_product_attention(
            query=Querys,
            key=Keys,
            value=Values,
            attn_mask=None,
            is_causal=True,
            dropout_p=self.dropout.p if self.training else 0.0,
        )
        out = self.out_proj(out.transpose(1, 2).contiguous().view(Batch_size, Seq_len, C))
        out = self.dropout(out)
        return out
|
|
324
|
+
|
|
325
|
+
class Local_Attention(nn.Module):
    """Causal sliding-window (local) attention: each position attends only to
    itself and the previous Window_size - 1 positions.

    Args:
        Emb_dim: model width; must be divisible by num_heads.
        num_heads: number of attention heads.
        Window_size: width of the causal attention window (>= 1).
        dropout: dropout probability on the attention weights.
        device, dtype: placement/precision of the projections.
    """

    def __init__(self, Emb_dim, num_heads, Window_size, dropout, device='cpu', dtype=torch.float32):
        super().__init__()
        assert Emb_dim % num_heads == 0, "Emb_dim must be divisible by num_heads"
        self.Emb_dim = Emb_dim
        self.Window_size = Window_size
        self.device = device
        self.num_heads = num_heads
        self.head_dim = Emb_dim // num_heads

        self.key = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.query = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(Emb_dim, Emb_dim, bias=False, dtype=dtype, device=device)

        self.scale = torch.tensor(self.head_dim ** 0.5, device=device, dtype=dtype)
        self.out_proj = nn.Linear(Emb_dim, Emb_dim, dtype=dtype, device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Windowed causal self-attention over x of shape (B, T, Emb_dim)."""
        Batch_size, Seq_len, _ = x.shape

        # Generate Q, K, V and reshape for multi-head attention.
        Keys = self.key(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (B, nh, T, hd)
        Querys = self.query(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Values = self.value(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute attention scores.
        scores = (Querys @ Keys.transpose(-2, -1)) / self.scale  # (B, nh, T, T)

        # PERF FIX: build the band mask once at (T, T) and let masked_fill_
        # broadcast it over (B, nh, T, T). The original allocated a full
        # ones_like(scores) bool tensor of shape (B, nh, T, T) just to derive
        # the same (position-only) band.
        band = torch.tril(torch.ones(Seq_len, Seq_len, dtype=torch.bool, device=scores.device))
        band = torch.triu(band, diagonal=-(self.Window_size - 1))
        scores = scores.masked_fill_(~band, float('-inf'))

        # Apply softmax and dropout.
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Apply attention to values.
        out = attn @ Values  # (B, nh, T, hd)

        # Concatenate heads and project.
        out = out.transpose(1, 2).contiguous().view(Batch_size, Seq_len, self.Emb_dim)  # (B, T, Emb_dim)

        return self.out_proj(out)
|
|
371
|
+
|
|
372
|
+
def precompute_theta_position_frequency(head_dim, seq_len, device='cpu', theta=10000.0):
    """Precompute RoPE rotation factors as complex exponentials.

    Returns a complex tensor of shape (seq_len, head_dim // 2) whose entry
    [m, i] is exp(j * m / theta**(2i / head_dim)).
    """
    assert head_dim % 2 == 0, "head_dim must be even"

    # Inverse frequencies 1 / theta**(2i / head_dim) for the even indices.
    even_idx = torch.arange(0, head_dim, 2, device=device)
    inv_freq = 1.0 / (theta ** (even_idx / head_dim))

    # Angle for every (position, frequency) pair: (seq_len, head_dim // 2).
    positions = torch.arange(seq_len, device=device)
    angles = torch.outer(positions, inv_freq)

    # Unit-magnitude complex numbers exp(j * angle).
    return torch.polar(torch.ones_like(angles), angles)
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def apply_rotry_position_embedding(x, freq_complex, device='cpu', dtype=torch.float32):
    """Rotate x by the precomputed RoPE factors via complex multiplication.

    x: real tensor of shape (batch_size, seq_len, num_head, emb_dim).
    freq_complex: complex factors of shape (>= seq_len, emb_dim // 2).
    Returns a real tensor with the same shape as x.
    """
    batch_size, seq_len, num_head, emb_dim = x.shape
    assert emb_dim % 2 == 0, "emb_dim must be even"

    # Pair up adjacent channels and view each pair as one complex number.
    pairs = x.view(batch_size, seq_len, num_head, emb_dim // 2, 2).to(device=device, dtype=dtype)
    as_complex = torch.view_as_complex(pairs)

    # Broadcast factors to (1, seq_len, 1, emb_dim // 2).
    rot = freq_complex[:seq_len].unsqueeze(0).unsqueeze(2).to(device=device)

    # Rotation is multiplication by a unit complex number.
    rotated = as_complex * rot

    # Back to interleaved real layout.
    result = torch.view_as_real(rotated).contiguous().view(batch_size, seq_len, num_head, emb_dim)
    return result.to(device=device, dtype=dtype)
|
|
408
|
+
|
|
409
|
+
class kv_cache_multihead(nn.Module):
    """Multi-head self-attention with a preallocated KV cache for
    incremental (autoregressive) decoding.

    Args:
        emb_dim: model width; must be divisible by num_heads.
        num_heads: number of attention heads.
        batch_size: fixed batch size the caches are allocated for.
        kv_seq_len: maximum total sequence length held in the caches.
        device, dtype: placement/precision of projections and caches.
        dropout: dropout probability applied to the output projection.

    NOTE(review): the caches are plain attributes here, while
    kv_cache_group_query registers them as buffers — confirm whether these
    should also be buffers so .to()/state_dict handle them.
    """

    def __init__(self, emb_dim, num_heads, batch_size, kv_seq_len, device='cpu', dtype=torch.float32, dropout=0.1):
        super().__init__()
        self.dtype = dtype
        self.device = device

        assert emb_dim % num_heads == 0
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        self.kv_seq_len = kv_seq_len

        self.query = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
        self.key = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)

        self.out_proj = nn.Linear(emb_dim, emb_dim, dtype=dtype, device=device)
        self.dropout = nn.Dropout(dropout)

        self.cache_keys = torch.zeros(batch_size, kv_seq_len, num_heads, self.head_dim, dtype=dtype, device=device)
        self.cache_value = torch.zeros(batch_size, kv_seq_len, num_heads, self.head_dim, dtype=dtype, device=device)

    def forward(self, x, start_pos, RoPE=False):
        """Attend x (B, seq_len, emb_dim) at cache offset start_pos.

        BUG FIX 1: the original signature was `RoPE: False` — an annotation,
        not a default — so calling forward without the third argument raised
        TypeError. It is now a proper keyword default.
        """
        batch_size, seq_len, C = x.shape

        xq = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        xk = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        xv = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

        if RoPE:
            freq_complex = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=seq_len, device=self.device)
            xq = apply_rotry_position_embedding(xq, freq_complex, device=self.device, dtype=self.dtype)
            # NOTE(review): keys are rotated with frequencies starting at
            # position 0 rather than start_pos, which looks wrong for
            # incremental decoding — confirm intended RoPE indexing.
            freq_complex = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=self.kv_seq_len, device=self.device)
            xk = apply_rotry_position_embedding(xk, freq_complex, device=self.device, dtype=self.dtype)

        # Cache keys and values at their absolute positions.
        self.cache_keys[:, start_pos:start_pos+seq_len] = xk
        self.cache_value[:, start_pos:start_pos+seq_len] = xv

        xk_full = self.cache_keys[:, :start_pos+seq_len]
        xv_full = self.cache_value[:, :start_pos+seq_len]

        query = xq.transpose(1, 2)       # (B, nh, seq_len, hd)
        key = xk_full.transpose(1, 2)    # (B, nh, total_len, hd)
        value = xv_full.transpose(1, 2)  # (B, nh, total_len, hd)

        attn_scores = torch.matmul(query, key.transpose(2, 3)) / (self.head_dim ** 0.5)

        # BUG FIX 2: the mask must cover the full cached key length. The
        # original built a (seq_len, seq_len) mask, which fails to broadcast
        # against (..., seq_len, start_pos + seq_len) whenever start_pos > 0.
        # Query i sits at absolute position start_pos + i and may attend any
        # key position <= start_pos + i.
        total_len = start_pos + seq_len
        q_pos = torch.arange(start_pos, total_len, device=self.device).unsqueeze(1)
        k_pos = torch.arange(total_len, device=self.device).unsqueeze(0)
        causal_mask = k_pos > q_pos  # (seq_len, total_len)
        attn_scores.masked_fill_(causal_mask[None, None, :, :], float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        out = torch.matmul(attn_weights, value)

        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.dropout(self.out_proj(out))
|
|
466
|
+
|
|
467
|
+
class kv_cache_group_query(nn.Module):
    """Grouped-query attention with preallocated KV caches (registered as
    buffers) for incremental decoding.

    Args:
        emb_dim: model width; must be divisible by query_num_heads.
        query_num_heads: number of query heads.
        kv_num_heads: number of shared key/value heads; must divide query_num_heads.
        batch_size: fixed batch size the caches are allocated for.
        kv_seq_len: maximum total sequence length held in the caches.
        device, dtype: placement/precision of projections and caches.
        dropout: dropout probability applied to the output projection.
    """

    def __init__(self, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len, device='cpu', dtype=torch.float32, dropout=0.1):
        super().__init__()
        self.dtype = dtype
        self.device = device

        assert query_num_heads % kv_num_heads == 0, "query heads must be divisible by kv heads"
        assert emb_dim % query_num_heads == 0, "embedding must be divisible by query heads"

        self.emb_dim = emb_dim
        self.query_num_heads = query_num_heads
        self.kv_num_heads = kv_num_heads
        self.head_dim = emb_dim // query_num_heads
        self.num_queries_per_kv = query_num_heads // kv_num_heads
        self.kv_seq_len = kv_seq_len

        self.query = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
        self.key = nn.Linear(emb_dim, kv_num_heads * self.head_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(emb_dim, kv_num_heads * self.head_dim, bias=False, dtype=dtype, device=device)

        self.out_proj = nn.Linear(emb_dim, emb_dim, dtype=dtype, device=device)
        self.dropout = nn.Dropout(dropout)

        # KV caches sized for the maximum sequence length.
        self.register_buffer("cache_keys", torch.zeros(batch_size, kv_seq_len, kv_num_heads, self.head_dim, device=device, dtype=dtype))
        self.register_buffer("cache_value", torch.zeros(batch_size, kv_seq_len, kv_num_heads, self.head_dim, device=device, dtype=dtype))

    def forward(self, x, start_pos, RoPE=False):
        """Attend x (B, seq_len, emb_dim) at cache offset start_pos."""
        batch_size, seq_len, _ = x.shape

        xq = self.query(x).view(batch_size, seq_len, self.query_num_heads, self.head_dim)
        xk = self.key(x).view(batch_size, seq_len, self.kv_num_heads, self.head_dim)
        xv = self.value(x).view(batch_size, seq_len, self.kv_num_heads, self.head_dim)

        if RoPE:
            freq_q = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=seq_len, device=self.device)
            xq = apply_rotry_position_embedding(xq, freq_q, device=self.device, dtype=self.dtype)
            # NOTE(review): key rotation starts at position 0 rather than
            # start_pos — looks wrong for incremental decoding; confirm.
            freq_k = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=self.kv_seq_len, device=self.device)
            xk = apply_rotry_position_embedding(xk, freq_k, device=self.device, dtype=self.dtype)

        # Cache
        self.cache_keys[:, start_pos:start_pos+seq_len] = xk
        self.cache_value[:, start_pos:start_pos+seq_len] = xv

        xk_full = self.cache_keys[:, :start_pos+seq_len]  # [B, T, kv_heads, D]
        xv_full = self.cache_value[:, :start_pos+seq_len]

        # Transpose for attention: [B, H, T, D]
        query = xq.transpose(1, 2)       # [B, q_heads, seq_len, D]
        key = xk_full.transpose(1, 2)    # [B, kv_heads, total_kv_len, D]
        value = xv_full.transpose(1, 2)

        # Repeat keys and values to match query heads.
        key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
        value = value.repeat_interleave(self.num_queries_per_kv, dim=1)

        # Attention
        attn_scores = torch.matmul(query, key.transpose(2, 3)) / (self.head_dim ** 0.5)

        # BUG FIX: the mask must cover the full cached key length. The
        # original built a (seq_len, seq_len) mask, which fails to broadcast
        # against (..., seq_len, start_pos + seq_len) whenever start_pos > 0.
        total_len = start_pos + seq_len
        q_pos = torch.arange(start_pos, total_len, device=self.device).unsqueeze(1)
        k_pos = torch.arange(total_len, device=self.device).unsqueeze(0)
        causal_mask = k_pos > q_pos  # (seq_len, total_len)
        attn_scores.masked_fill_(causal_mask[None, None, :, :], float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        out = torch.matmul(attn_weights, value)

        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.emb_dim)
        return self.dropout(self.out_proj(out))
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
class FF_ReLU(nn.Module):
    """Two-layer position-wise feed-forward block with a ReLU non-linearity."""

    def __init__(self, emb_dim, hidden_dim, device='cpu', dtype=torch.float32):
        super().__init__()
        expand = nn.Linear(emb_dim, hidden_dim, device=device, dtype=dtype)
        contract = nn.Linear(hidden_dim, emb_dim, device=device, dtype=dtype)
        self.relu = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        """Project up to hidden_dim, apply ReLU, project back to emb_dim."""
        return self.relu(x)
|
|
15
|
+
|
|
16
|
+
class FF_LeakyReLU(nn.Module):
    """Two-layer position-wise feed-forward block with a LeakyReLU non-linearity."""

    def __init__(self, emb_dim, hidden_dim, negative_slope=0.1, device='cpu', dtype=torch.float32):
        super().__init__()
        expand = nn.Linear(emb_dim, hidden_dim, device=device, dtype=dtype)
        contract = nn.Linear(hidden_dim, emb_dim, device=device, dtype=dtype)
        self.l_relu = nn.Sequential(expand, nn.LeakyReLU(negative_slope), contract)

    def forward(self, x):
        """Project up, apply LeakyReLU(negative_slope), project back to emb_dim."""
        return self.l_relu(x)
|
|
26
|
+
|
|
27
|
+
class FF_GELU(nn.Module):
    """Two-layer position-wise feed-forward block with a GELU non-linearity."""

    def __init__(self, emb_dim, hidden_dim, device='cpu', dtype=torch.float32):
        super().__init__()
        expand = nn.Linear(emb_dim, hidden_dim, device=device, dtype=dtype)
        contract = nn.Linear(hidden_dim, emb_dim, device=device, dtype=dtype)
        self.gelu = nn.Sequential(expand, nn.GELU(), contract)

    def forward(self, x):
        """Project up to hidden_dim, apply GELU, project back to emb_dim."""
        return self.gelu(x)
|
|
37
|
+
|
|
38
|
+
class FF_Sigmoid(nn.Module):
    """Two-layer position-wise feed-forward block with a Sigmoid non-linearity."""

    def __init__(self, emb_dim, hidden_dim, device='cpu', dtype=torch.float32):
        super().__init__()
        expand = nn.Linear(emb_dim, hidden_dim, device=device, dtype=dtype)
        contract = nn.Linear(hidden_dim, emb_dim, device=device, dtype=dtype)
        self.sigmoid = nn.Sequential(expand, nn.Sigmoid(), contract)

    def forward(self, x):
        """Project up to hidden_dim, apply Sigmoid, project back to emb_dim."""
        return self.sigmoid(x)
|
|
48
|
+
|
|
49
|
+
class FF_SiLU(nn.Module):
    """Two-layer position-wise feed-forward block with a SiLU (swish) non-linearity."""

    def __init__(self, emb_dim, hidden_dim, device='cpu', dtype=torch.float32):
        super().__init__()
        expand = nn.Linear(emb_dim, hidden_dim, device=device, dtype=dtype)
        contract = nn.Linear(hidden_dim, emb_dim, device=device, dtype=dtype)
        self.silu = nn.Sequential(expand, nn.SiLU(), contract)

    def forward(self, x):
        """Project up to hidden_dim, apply SiLU, project back to emb_dim."""
        return self.silu(x)
|
|
59
|
+
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, without learnable affine
    parameters.

    NOTE(review): a second class with the same name (and learnable
    weight/bias) is defined later in this module and shadows this one at
    import time — confirm whether this parameter-free variant is still
    intended to be reachable.
    """
    def __init__(self, emb_dim, eps = 1e-5):
        super().__init__()
        # emb_dim is accepted for interface symmetry but not used here.
        self.eps = eps

    def forward(self, x):
        # Normalize over the feature (last) dimension using the biased variance.
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return norm_x
|
|
14
|
+
|
|
15
|
+
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable per-feature
    scale (weight) and shift (bias)."""

    def __init__(self, emb_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(emb_dim, device=device, dtype=dtype))
        self.bias = nn.Parameter(torch.zeros(emb_dim, device=device, dtype=dtype))

    def forward(self, x):
        # Center and scale over the feature axis (biased variance), then
        # apply the learned affine transform.
        centered = x - x.mean(dim=-1, keepdim=True)
        denom = torch.sqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
        return (centered / denom) * self.weight + self.bias
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# RMS = sqrt(Xn ** 2)
|
|
30
|
+
# Norm = Xn / RMS
|
|
31
|
+
class RMSNormilization(nn.Module):
    """RMS normalization: scale each vector by the reciprocal of its
    root-mean-square, then apply a learnable per-feature gain.

    Note: eps is added to the RMS itself (outside the sqrt) in this
    implementation.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        root_mean_square = x.pow(2).mean(-1, keepdim=True).sqrt()
        return self.weight * x / (root_mean_square + self.eps)
|
|
File without changes
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# problem: Random mask and global mask
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
def casual_mask(Seq_len):
    """Boolean causal mask: True strictly above the diagonal (future positions)."""
    return torch.ones(Seq_len, Seq_len, dtype=torch.bool).triu(diagonal=1)
|
|
7
|
+
|
|
8
|
+
def sliding_window(Seq_len, window_size):
    """Boolean mask (True = blocked) for causal sliding-window attention:
    position i may see positions i - window_size + 1 .. i."""
    lower = torch.tril(torch.ones(Seq_len, Seq_len, dtype=bool))
    window = torch.triu(lower, diagonal=-(window_size - 1))
    return ~window
|
|
12
|
+
|
|
13
|
+
def dilated_casual_mask(Seq_len, dilation):
    """Boolean mask (True = blocked): position i may only see past positions j
    with i >= j and (i - j) a multiple of `dilation`."""
    rows = torch.arange(Seq_len).unsqueeze(1)
    cols = torch.arange(Seq_len).unsqueeze(0)
    # causal and dilation conditions combined
    visible = (rows >= cols) & ((rows - cols) % dilation == 0)
    return ~visible
|
|
19
|
+
|
|
20
|
+
def random_mask(Seq_len, num_random):
    """Boolean mask (True = blocked) letting each position attend to up to
    `num_random` randomly chosen strictly-earlier positions.

    BUG FIX: the mask is built as a bool tensor. The original used a float
    tensor (torch.zeros default), and `~` (bitwise_not) raises a
    RuntimeError on float tensors.
    """
    allowed = torch.zeros(Seq_len, Seq_len, dtype=torch.bool)
    for i in range(1, Seq_len):
        # Sample without replacement from positions 0..i-1 (row 0 has none).
        picks = torch.randperm(i)[:min(num_random, i)]
        allowed[i, picks] = True
    return ~allowed
|
|
29
|
+
|
|
30
|
+
def global_mask(Seq_len, global_index):
    """Boolean mask (True = blocked) in which the rows and columns of the
    positions listed in `global_index` are fully visible.

    BUG FIX: the mask is built as a bool tensor. The original used a float
    tensor and then applied `~`, which raises a RuntimeError on floats.
    """
    global_index_tensor = torch.tensor(global_index)
    allowed = torch.zeros(Seq_len, Seq_len, dtype=torch.bool)
    # Global tokens attend everywhere, and everyone attends to them.
    allowed[global_index_tensor, :] = True
    allowed[:, global_index_tensor] = True
    return ~allowed
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
# --- Absolute Positional Embedding ---
|
|
6
|
+
class AbsolutePositionEmbedding(nn.Module):
    """Learned absolute positional embedding.

    Given input of shape (batch, seq_len) or (batch, seq_len, ...), returns a
    (batch, seq_len, emb_dim) tensor of position embeddings on x's device.
    """

    def __init__(self, seq_len, emb_dim):
        super().__init__()
        self.seq_len = seq_len
        self.emb_dim = emb_dim
        self.embedding = nn.Embedding(seq_len, emb_dim)

    def forward(self, x):
        batch_size, seq_len = x.shape[0], x.shape[1]
        # BUG FIX: build position ids on the embedding table's own device.
        # The original used a CPU arange, which breaks the lookup when the
        # module has been moved to GPU.
        positions = torch.arange(seq_len, device=self.embedding.weight.device)
        abs_pos = self.embedding(positions)  # (seq_len, emb_dim)
        return abs_pos.unsqueeze(0).expand(batch_size, seq_len, -1).to(x.device)
|
|
18
|
+
|
|
19
|
+
# --- Sinusoidal Positional Embedding ---
|
|
20
|
+
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed sin/cos positional encodings ("Attention Is All You Need"):
    even channels get sin, odd channels get cos of geometrically spaced
    frequencies. The table is a buffer (not trained)."""

    def __init__(self, seq_len, emb_dim):
        super().__init__()
        self.seq_len = seq_len
        self.emb_dim = emb_dim

        pos = torch.arange(0, seq_len).unsqueeze(1)                                  # (seq_len, 1)
        freq = torch.exp(torch.arange(0, emb_dim, 2) * -(math.log(10000.0) / emb_dim))

        table = torch.zeros(seq_len, emb_dim)
        table[:, 0::2] = torch.sin(pos * freq)
        table[:, 1::2] = torch.cos(pos * freq)
        self.register_buffer("pe", table)

    def forward(self, x):
        # x: (batch_size, seq_len, ...) — only the first two dims are used.
        batch_size, seq_len = x.shape[0], x.shape[1]
        return self.pe[:seq_len].unsqueeze(0).expand(batch_size, seq_len, -1).to(x.device)
|
|
39
|
+
|
|
40
|
+
# --- RoPE ---
|
|
41
|
+
class RoPE(nn.Module):
    """Rotary position embedding applied via complex multiplication; the
    rotation table is precomputed once and stored as a buffer."""

    def __init__(self, head_dim, seq_len, theta=10000.0, device='cpu', dtype=torch.float32):
        super().__init__()
        self.dtype = dtype
        self.device = device
        assert head_dim % 2 == 0, "head_dim must be even"
        # Inverse frequencies 1 / theta**(2i / head_dim).
        even = torch.arange(0, head_dim, 2, device=device, dtype=dtype)
        inv_freq = 1.0 / (theta ** (even / head_dim))
        angles = torch.outer(torch.arange(seq_len, device=device), inv_freq)
        self.register_buffer("freq_complex", torch.polar(torch.ones_like(angles), angles))

    def forward(self, x):
        """Rotate x of shape (batch, seq_len, num_head, emb_dim)."""
        batch_size, seq_len, num_head, emb_dim = x.shape
        assert emb_dim % 2 == 0, "emb_dim must be even"
        # Pair adjacent channels and treat each pair as one complex number.
        pairs = torch.view_as_complex(x.view(batch_size, seq_len, num_head, emb_dim // 2, 2))
        rot = self.freq_complex[:seq_len].unsqueeze(0).unsqueeze(2)  # (1, seq_len, 1, head_dim//2)
        rotated = torch.view_as_real(pairs * rot)
        out = rotated.contiguous().view(batch_size, seq_len, num_head, emb_dim)
        return out.to(device=self.device, dtype=self.dtype)
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import tiktoken
|
|
4
|
+
|
|
5
|
+
# Bite pair (BPE) Embedding using tokienizer
|
|
6
|
+
class Embedding_using_tiktoken:
    """Wraps a tiktoken BPE tokenizer and embeds encoded token ids with a
    torch nn.Embedding.

    NOTE(review): the `data` and `embedding_dim` constructor arguments are
    never stored or used — only `model` is; confirm whether they should be
    retained or dropped from the signature.
    """
    def __init__(self,data,embedding_dim,model: str):
        # `model` is a tiktoken encoding name, e.g. "gpt2" / "cl100k_base".
        self.tokenizer = tiktoken.get_encoding(model)


    def encoding(self,data,embedding_dim):
        """Tokenize `data` and return its embeddings, shape (num_tokens, embedding_dim).

        NOTE(review): a fresh, randomly initialized nn.Embedding is created on
        every call, so repeated calls give unrelated vectors — confirm whether
        a persistent embedding table was intended.
        """
        max_token_id = self.tokenizer.n_vocab
        embedding_layer = nn.Embedding(num_embeddings = max_token_id, embedding_dim = embedding_dim)
        tensors = torch.tensor(self.tokenizer.encode(data))
        embedded = embedding_layer(tensors)
        return embedded

    def decoding(self,data):
        """Decode a sequence of token ids back into text."""
        return self.tokenizer.decode(data)

    def vocab_size(self):
        """Return the tokenizer's vocabulary size."""
        return self.tokenizer.n_vocab

    def model_list(self):
        """List all encoding names tiktoken knows about."""
        return tiktoken.list_encoding_names()
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "Stackformer"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Modular transformer blocks built in PyTorch"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.9"
|
|
7
|
+
license = {text = "MIT"}
|
|
8
|
+
|
|
9
|
+
authors = [
|
|
10
|
+
{name = "Gurumurthy", email = "gurumurthy.00300@gmail.com"}
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
dependencies = [
|
|
14
|
+
"torch>=2.6",
|
|
15
|
+
"tqdm>=4.67"
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
[project.urls]
|
|
19
|
+
"Repository" = "https://github.com/Gurumurthy30/Stackformer"
|
|
20
|
+
"Issue Tracker" = "https://github.com/Gurumurthy30/Stackformer/issues"
|
|
21
|
+
"Discussions" = "https://github.com/Gurumurthy30/Stackformer/discussions"
|
|
22
|
+
|
|
23
|
+
[build-system]
|
|
24
|
+
requires = ["setuptools>=61", "wheel"]
|
|
25
|
+
build-backend = "setuptools.build_meta"
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from setuptools import setup, find_packages

# Packaging metadata for the Stackformer distribution.
# NOTE(review): modules/tokenizer.py imports `tiktoken`, which is not listed
# in install_requires below — confirm whether it should be added or is
# intentionally an optional dependency.
setup(
    name="Stackformer",
    version="0.1.0",
    description="Modular transformer blocks built in PyTorch",
    # long_description=open("README.md", "r", encoding="utf-8").read(),
    # long_description_content_type="text/markdown",
    author="Gurumurthy",
    author_email="gurumurthy.00300@gmail.com",
    url="https://github.com/Gurumurthy30/Stackformer",
    project_urls={
        "Repository": "https://github.com/Gurumurthy30/Stackformer",
        "Issue Tracker": "https://github.com/Gurumurthy30/Stackformer/issues",
        "Discussions": "https://github.com/Gurumurthy30/Stackformer/discussions",
    },
    license="MIT",
    python_requires=">=3.9",
    packages=find_packages(exclude=["tests", "examples"]),
    install_requires=[
        "torch>=2.6",
        "tqdm>=4.67",
    ],
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
)
|