Stackformer 0.1.2__tar.gz → 0.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {stackformer-0.1.2 → stackformer-0.1.3}/PKG-INFO +3 -3
- {stackformer-0.1.2 → stackformer-0.1.3}/Stackformer.egg-info/PKG-INFO +3 -3
- {stackformer-0.1.2 → stackformer-0.1.3}/Stackformer.egg-info/SOURCES.txt +1 -0
- stackformer-0.1.3/Stackformer.egg-info/requires.txt +2 -0
- {stackformer-0.1.2 → stackformer-0.1.3}/pyproject.toml +3 -3
- {stackformer-0.1.2 → stackformer-0.1.3}/setup.py +3 -3
- {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/__init__.py +9 -2
- stackformer-0.1.3/stackformer/generate.py +53 -0
- stackformer-0.1.3/stackformer/models/Meta.py +159 -0
- stackformer-0.1.3/stackformer/models/OpenAI.py +177 -0
- stackformer-0.1.3/stackformer/models/Transformer.py +104 -0
- {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/modules/Attention.py +73 -62
- stackformer-0.1.3/stackformer/modules/Feed_forward.py +90 -0
- stackformer-0.1.3/stackformer/modules/Normalization.py +31 -0
- {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/modules/position_embedding.py +1 -1
- stackformer-0.1.2/Stackformer.egg-info/requires.txt +0 -2
- stackformer-0.1.2/stackformer/models/Meta.py +0 -213
- stackformer-0.1.2/stackformer/models/OpenAI.py +0 -242
- stackformer-0.1.2/stackformer/models/Transformer.py +0 -238
- stackformer-0.1.2/stackformer/modules/Feed_forward.py +0 -59
- stackformer-0.1.2/stackformer/modules/Normalization.py +0 -41
- {stackformer-0.1.2 → stackformer-0.1.3}/LICENSE +0 -0
- {stackformer-0.1.2 → stackformer-0.1.3}/README.md +0 -0
- {stackformer-0.1.2 → stackformer-0.1.3}/Stackformer.egg-info/dependency_links.txt +0 -0
- {stackformer-0.1.2 → stackformer-0.1.3}/Stackformer.egg-info/top_level.txt +0 -0
- {stackformer-0.1.2 → stackformer-0.1.3}/setup.cfg +0 -0
- {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/models/__init__.py +0 -0
- {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/modules/__init__.py +0 -0
- {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/modules/mask.py +0 -0
- {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/modules/tokenizer.py +0 -0
- {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/trainer.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: Stackformer
|
|
3
|
-
Version: 0.1.2
|
|
3
|
+
Version: 0.1.3
|
|
4
4
|
Summary: Modular transformer blocks built in PyTorch
|
|
5
5
|
Home-page: https://github.com/Gurumurthy30/Stackformer
|
|
6
6
|
Author: Gurumurthy
|
|
@@ -12,8 +12,8 @@ Project-URL: Discussions, https://github.com/Gurumurthy30/Stackformer/discussion
|
|
|
12
12
|
Requires-Python: >=3.9
|
|
13
13
|
Description-Content-Type: text/markdown
|
|
14
14
|
License-File: LICENSE
|
|
15
|
-
Requires-Dist: torch<2.
|
|
16
|
-
Requires-Dist: tqdm
|
|
15
|
+
Requires-Dist: torch<2.7,>=2.0
|
|
16
|
+
Requires-Dist: tqdm<5.0,>=4.5
|
|
17
17
|
Dynamic: author
|
|
18
18
|
Dynamic: home-page
|
|
19
19
|
Dynamic: license-file
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: Stackformer
|
|
3
|
-
Version: 0.1.2
|
|
3
|
+
Version: 0.1.3
|
|
4
4
|
Summary: Modular transformer blocks built in PyTorch
|
|
5
5
|
Home-page: https://github.com/Gurumurthy30/Stackformer
|
|
6
6
|
Author: Gurumurthy
|
|
@@ -12,8 +12,8 @@ Project-URL: Discussions, https://github.com/Gurumurthy30/Stackformer/discussion
|
|
|
12
12
|
Requires-Python: >=3.9
|
|
13
13
|
Description-Content-Type: text/markdown
|
|
14
14
|
License-File: LICENSE
|
|
15
|
-
Requires-Dist: torch<2.
|
|
16
|
-
Requires-Dist: tqdm
|
|
15
|
+
Requires-Dist: torch<2.7,>=2.0
|
|
16
|
+
Requires-Dist: tqdm<5.0,>=4.5
|
|
17
17
|
Dynamic: author
|
|
18
18
|
Dynamic: home-page
|
|
19
19
|
Dynamic: license-file
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "Stackformer"
|
|
3
|
-
version = "0.1.2"
|
|
3
|
+
version = "0.1.3"
|
|
4
4
|
description = "Modular transformer blocks built in PyTorch"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = ">=3.9"
|
|
@@ -11,8 +11,8 @@ authors = [
|
|
|
11
11
|
]
|
|
12
12
|
|
|
13
13
|
dependencies = [
|
|
14
|
-
"torch>=2.0,<2.
|
|
15
|
-
"tqdm>=4.
|
|
14
|
+
"torch>=2.0,<2.7",
|
|
15
|
+
"tqdm>=4.5,<5.0"
|
|
16
16
|
]
|
|
17
17
|
|
|
18
18
|
[project.urls]
|
|
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
|
|
2
2
|
|
|
3
3
|
setup(
|
|
4
4
|
name="Stackformer",
|
|
5
|
-
version="0.1.2",
|
|
5
|
+
version="0.1.3",
|
|
6
6
|
description="Modular transformer blocks built in PyTorch",
|
|
7
7
|
# long_description=open("README.md", "r", encoding="utf-8").read(),
|
|
8
8
|
# long_description_content_type="text/markdown",
|
|
@@ -18,8 +18,8 @@ setup(
|
|
|
18
18
|
python_requires=">=3.9",
|
|
19
19
|
packages=find_packages(exclude=["tests", "examples"]),
|
|
20
20
|
install_requires=[
|
|
21
|
-
|
|
22
|
-
|
|
21
|
+
"torch>=2.0,<2.7",
|
|
22
|
+
"tqdm>=4.5,<5.0",
|
|
23
23
|
],
|
|
24
24
|
classifiers=[
|
|
25
25
|
"Programming Language :: Python :: 3",
|
|
@@ -9,6 +9,7 @@ from .modules.position_embedding import RoPE
|
|
|
9
9
|
# --- Attention mechanisms ---
|
|
10
10
|
from .modules.Attention import Self_Attention
|
|
11
11
|
from .modules.Attention import Multi_Head_Attention
|
|
12
|
+
from .modules.Attention import Multi_Head_Attention_with_RoPE
|
|
12
13
|
from .modules.Attention import Cross_MultiHead_Attention
|
|
13
14
|
from .modules.Attention import Multi_query_Attention
|
|
14
15
|
from .modules.Attention import Group_query_Attention
|
|
@@ -28,11 +29,17 @@ from .modules.Feed_forward import FF_GELU
|
|
|
28
29
|
from .modules.Feed_forward import FF_LeakyReLU
|
|
29
30
|
from .modules.Feed_forward import FF_Sigmoid
|
|
30
31
|
from .modules.Feed_forward import FF_SiLU
|
|
32
|
+
from .modules.Feed_forward import FF_SwiGLU
|
|
31
33
|
|
|
32
34
|
# --- Model ---
|
|
35
|
+
from .models.OpenAI import GPT_1
|
|
33
36
|
from .models.OpenAI import GPT_2
|
|
34
|
-
from .models.Meta import
|
|
37
|
+
from .models.Meta import llama_1
|
|
38
|
+
from .models.Meta import llama_2
|
|
35
39
|
from .models.Transformer import transformer
|
|
36
40
|
|
|
37
41
|
# --- Trainer ---
|
|
38
|
-
from .trainer import Trainer
|
|
42
|
+
from .trainer import Trainer
|
|
43
|
+
|
|
44
|
+
# --- Generate ---
|
|
45
|
+
from .generate import text_generate
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
def text_generate(self, prompt_ids, max_context_len=128, max_new_tokens=50, temperature=1.0, top_k=None, top_p=1.0, eos_token_id=None):
    """Autoregressively extend ``prompt_ids`` by sampling from a model's logits.

    Args:
        self: Model object. ``self.forward(ids)`` is called with a 2-D tensor of
            token ids and must return logits indexed as
            ``(batch, position, vocab)``.
        prompt_ids: 1-D or 2-D tensor of token ids. A 1-D prompt is treated as
            a batch of one.
        max_context_len: Only the last ``max_context_len`` tokens are fed to the
            model at each step (sliding window).
        max_new_tokens: Upper bound on the number of tokens appended.
        temperature: Logits are divided by this value before sampling. ``0``
            selects the arg-max token (greedy decoding) instead of dividing by
            zero. Negative values are rejected.
        top_k: If a positive int, sampling is restricted to the ``top_k``
            highest logits. Values above the vocabulary size are clamped.
        top_p: If below ``1.0``, nucleus filtering is applied: tokens are kept
            in descending probability order until the cumulative probability
            first exceeds ``top_p``.
        eos_token_id: Optional stop token. A row that has produced it keeps
            receiving ``eos_token_id`` as padding, and generation stops once
            every row in the batch has produced it.

    Returns:
        2-D tensor holding the prompt followed by the generated tokens.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")

    if prompt_ids.dim() == 1:
        prompt_ids = prompt_ids.unsqueeze(0)

    generated = prompt_ids.clone()
    # Per-row flag: True once that row has emitted eos_token_id.
    finished = torch.zeros(generated.size(0), dtype=torch.bool, device=generated.device)

    for _ in range(max_new_tokens):
        # Sliding window: slicing past the start simply returns the whole sequence.
        input_ids = generated[:, -max_context_len:]

        logits = self.forward(input_ids)  # (batch_size, seq_len, vocab_size)
        logits = logits[:, -1, :]         # (batch_size, vocab_size)

        if temperature == 0:
            # Greedy decoding; dividing by zero would yield NaN probabilities.
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            # --- Temperature scaling ---
            if temperature != 1.0:
                logits = logits / temperature

            # --- Top-k filtering ---
            if top_k is not None and top_k > 0:
                # torch.topk raises when k exceeds the size of the last dimension.
                k = min(top_k, logits.size(-1))
                topk_vals, topk_indices = torch.topk(logits, k)
                filtered = torch.full_like(logits, float('-inf'))
                filtered.scatter_(dim=-1, index=topk_indices, src=topk_vals)
                logits = filtered

            # --- Top-p ---
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                sorted_mask = cum_probs > top_p
                # Shift right so the token that crosses the threshold is kept,
                # and always keep the most likely token.
                sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                sorted_mask[..., 0] = 0

                # Map the mask from sorted order back to vocabulary order.
                indices_to_remove = sorted_mask.scatter(dim=-1, index=sorted_indices, src=sorted_mask)
                logits = logits.masked_fill(indices_to_remove, float('-inf'))

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        if eos_token_id is not None:
            # Rows that already finished are padded with EOS instead of new samples.
            next_token = torch.where(
                finished.unsqueeze(1),
                torch.full_like(next_token, eos_token_id),
                next_token,
            )
            finished = finished | (next_token.squeeze(1) == eos_token_id)

        generated = torch.cat([generated, next_token], dim=-1)

        # Stop only when every sequence in the batch has reached EOS.
        if eos_token_id is not None and bool(finished.all()):
            break

    return generated
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
from stackformer.modules.Attention import kv_cache_group_query, Multi_Head_Attention_with_RoPE
|
|
6
|
+
from stackformer.modules.Feed_forward import FF_SwiGLU
|
|
7
|
+
from stackformer.modules.Normalization import RMSNormilization
|
|
8
|
+
from stackformer.generate import text_generate
|
|
9
|
+
|
|
10
|
+
'''
|
|
11
|
+
llama 1
|
|
12
|
+
Attention: MHA
|
|
13
|
+
Mask: Causal
|
|
14
|
+
position: RoPE
|
|
15
|
+
FF: SwiGLU
|
|
16
|
+
Norm: pre norm (RMS norm)
|
|
17
|
+
'''
|
|
18
|
+
class llama_1_Block(nn.Module):
    """One pre-norm llama-1 layer: RoPE attention, then a SwiGLU feed-forward.

    Each sub-layer is applied as ``sublayer(norm(x)) + x``.
    """

    def __init__(self, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        placement = {'device': device, 'dtype': dtype}
        # Registration order is kept so parameter ordering is unchanged.
        self.attention = Multi_Head_Attention_with_RoPE(emb_dim, num_heads, dropout, **placement)
        self.norm1 = RMSNormilization(emb_dim, eps=eps)
        self.FF_SwiGLU = FF_SwiGLU(emb_dim, hidden_dim, dropout, **placement)
        self.norm2 = RMSNormilization(emb_dim, eps=eps)

    def forward(self, x):
        """Run the attention and feed-forward sub-layers, each with a skip connection."""
        # Attention sub-layer (normalise first, then add the skip path).
        attn_out = self.attention(self.norm1(x))
        x = attn_out + x

        # Feed-forward sub-layer, same pattern.
        ff_out = self.FF_SwiGLU(self.norm2(x))
        return ff_out + x
|
|
38
|
+
|
|
39
|
+
# --- Encoder ---
|
|
40
|
+
class llama_1_Encoder(nn.Module):
    """Stack of ``num_layers`` llama_1_Block modules applied one after another."""

    def __init__(self, num_layers, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                llama_1_Block(emb_dim, num_heads, dropout, hidden_dim, eps, device=device, dtype=dtype)
            )
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden
|
|
52
|
+
|
|
53
|
+
class llama_1(nn.Module):
    """llama-1 style language model: token embedding, block stack, RMS norm, LM head."""

    def __init__(self, vocab_size, num_layers, emb_dim, num_heads, seq_len,
                 dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.seq_len = seq_len

        # Token ids -> embedding vectors.
        self.embedding = nn.Embedding(vocab_size, emb_dim, dtype=self.dtype, device=self.device)

        # Stack of llama_1_Block layers.
        self.encoder = llama_1_Encoder(
            num_layers=num_layers,
            emb_dim=emb_dim,
            num_heads=num_heads,
            dropout=dropout,
            hidden_dim=hidden_dim,
            eps=eps,
            device=self.device,
            dtype=self.dtype,
        )

        # Normalisation applied to the stack output.
        self.final_norm = RMSNormilization(emb_dim, eps=eps)

        # Projection from embedding space to vocabulary logits (no bias).
        self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False, dtype=self.dtype, device=self.device)

    def forward(self, x):
        """Map token ids ``x`` to vocabulary logits."""
        hidden = self.embedding(x)
        hidden = self.encoder(hidden)
        hidden = self.final_norm(hidden)
        return self.lm_head(hidden)

    @torch.no_grad()
    def generate(self, prompt_ids, max_context_len=128, max_new_tokens=50, temperature=1.0, top_k=None, top_p=1.0, eos_token_id=None):
        """Sample a continuation of ``prompt_ids``; delegates to ``text_generate``."""
        return text_generate(
            self, prompt_ids, max_context_len, max_new_tokens,
            temperature, top_k, top_p, eos_token_id,
        )
|
|
85
|
+
|
|
86
|
+
'''
|
|
87
|
+
llama 2
|
|
88
|
+
Attention: GQA with KV cache
|
|
89
|
+
Mask: Causal
|
|
90
|
+
position: RoPE
|
|
91
|
+
FF: SwiGLU
|
|
92
|
+
Norm: pre norm (RMS norm)
|
|
93
|
+
'''
|
|
94
|
+
class llama_2_Block(nn.Module):
    """One pre-norm llama-2 layer: KV-cached grouped-query attention, then SwiGLU.

    Each sub-layer is applied as ``sublayer(norm(x)) + x``.
    """

    def __init__(self, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len, hidden_dim,
                 eps=1e-5, dropout=0.1, dtype=torch.float32, device='cpu'):
        super().__init__()
        # Registration order is kept so parameter ordering is unchanged.
        self.attn_norm = RMSNormilization(emb_dim, eps=eps)
        self.ff_norm = RMSNormilization(emb_dim, eps=eps)
        self.attn = kv_cache_group_query(
            emb_dim=emb_dim,
            query_num_heads=query_num_heads,
            kv_num_heads=kv_num_heads,
            batch_size=batch_size,
            kv_seq_len=kv_seq_len,
            dtype=dtype,
            dropout=dropout,
            device=device,
        )
        # NOTE(review): dropout is not forwarded here, so FF_SwiGLU uses its own
        # default; llama_1_Block passes dropout explicitly. Confirm this is intended.
        self.ff = FF_SwiGLU(emb_dim=emb_dim, hidden_dim=hidden_dim, device=device, dtype=dtype)

    def forward(self, x, start_pos):
        """Run both sub-layers; ``start_pos`` is handed to the attention module."""
        attn_out = self.attn(self.attn_norm(x), start_pos, rope=True)
        x = attn_out + x

        ff_out = self.ff(self.ff_norm(x))
        return ff_out + x
|
|
115
|
+
|
|
116
|
+
class llama_2_Encoder(nn.Module):
    """Stack of ``num_layers`` llama_2_Block modules applied one after another."""

    def __init__(self, num_layers, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len,
                 hidden_dim, eps=1e-5, dropout=0.1, dtype=torch.float32, device='cpu'):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                llama_2_Block(
                    emb_dim=emb_dim,
                    query_num_heads=query_num_heads,
                    kv_num_heads=kv_num_heads,
                    batch_size=batch_size,
                    kv_seq_len=kv_seq_len,
                    hidden_dim=hidden_dim,
                    eps=eps,
                    dropout=dropout,
                    dtype=dtype,
                    device=device,
                )
            )
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, start_pos):
        """Feed ``x`` through every block in order, passing ``start_pos`` to each."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, start_pos)
        return hidden
|
|
131
|
+
|
|
132
|
+
class llama_2(nn.Module):
    """llama-2 style language model: token embedding, block stack, RMS norm, LM head."""

    def __init__(self, num_layers, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len, vocab_size,
                 hidden_dim, eps=1e-5, dropout=0.1, dtype=torch.float32, device='cpu'):
        super().__init__()
        self.device = device
        self.vocab_size = vocab_size
        self.dtype = dtype
        # Stored under the generic name seq_len (for generation slicing).
        self.seq_len = kv_seq_len

        self.embedding = nn.Embedding(vocab_size, emb_dim, dtype=dtype, device=device)

        self.llama_2_Encoder = llama_2_Encoder(
            num_layers=num_layers,
            emb_dim=emb_dim,
            query_num_heads=query_num_heads,
            kv_num_heads=kv_num_heads,
            batch_size=batch_size,
            kv_seq_len=kv_seq_len,
            hidden_dim=hidden_dim,
            eps=eps,
            dropout=dropout,
            dtype=dtype,
            device=device,
        )

        self.final_norm = RMSNormilization(emb_dim, eps=eps)
        self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False, dtype=dtype, device=device)

    def forward(self, input_ids, start_pos=0):
        """Map token ids to vocabulary logits; ``start_pos`` is passed to every block."""
        hidden = self.embedding(input_ids)
        hidden = self.llama_2_Encoder(hidden, start_pos)
        hidden = self.final_norm(hidden)
        return self.lm_head(hidden)

    @torch.no_grad()
    def generate(self, prompt_ids, max_context_len=128, max_new_tokens=50, temperature=1.0, top_k=None, top_p=1.0, eos_token_id=None):
        """Sample a continuation of ``prompt_ids``; delegates to ``text_generate``."""
        # NOTE(review): text_generate calls forward(input_ids) only, so start_pos
        # stays at its default of 0 on every step. Confirm that is what the
        # KV-cache attention expects.
        return text_generate(
            self, prompt_ids, max_context_len, max_new_tokens,
            temperature, top_k, top_p, eos_token_id,
        )
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
from stackformer.modules.Attention import Multi_Head_Attention
|
|
6
|
+
from stackformer.modules.position_embedding import AbsolutePositionEmbedding
|
|
7
|
+
from stackformer.modules.Normalization import LayerNorm
|
|
8
|
+
from stackformer.modules.Feed_forward import FF_GELU
|
|
9
|
+
from stackformer.generate import text_generate
|
|
10
|
+
|
|
11
|
+
'''
|
|
12
|
+
GPT-1
|
|
13
|
+
Attention: MHA
|
|
14
|
+
Mask: Causal
|
|
15
|
+
position: absolute
|
|
16
|
+
FF: GeLU
|
|
17
|
+
Norm: post normalization (layer norm)
|
|
18
|
+
'''
|
|
19
|
+
# --- GPT_1 Encoder Block ---
|
|
20
|
+
class GPT_1_Block(nn.Module):
    """One GPT-1 layer: each sub-layer output is layer-normalised, then added to its input.

    Sub-layers are applied as ``norm(sublayer(x)) + x``.
    """

    def __init__(self, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        placement = {'device': device, 'dtype': dtype}
        # Registration order is kept so parameter ordering is unchanged.
        self.attention = Multi_Head_Attention(emb_dim, num_heads, dropout, **placement)
        # NOTE(review): LayerNorm is built without device/dtype here, while
        # Transformer.py passes both. Confirm the defaults match the other modules.
        self.norm1 = LayerNorm(emb_dim, eps=eps)
        self.FF_GELU = FF_GELU(emb_dim, hidden_dim, dropout, **placement)
        self.norm2 = LayerNorm(emb_dim, eps=eps)

    def forward(self, x):
        """Run attention then feed-forward, normalising each output before the skip add."""
        attn_out = self.norm1(self.attention(x))
        x = attn_out + x

        ff_out = self.norm2(self.FF_GELU(x))
        return ff_out + x
|
|
40
|
+
|
|
41
|
+
# --- GPT_1 Encoder ---
|
|
42
|
+
class GPT_1_Encoder(nn.Module):
    """Stack of ``num_layers`` GPT_1_Block modules applied one after another."""

    def __init__(self, num_layers, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                GPT_1_Block(emb_dim, num_heads, dropout, hidden_dim, eps, device=device, dtype=dtype)
            )
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden
|
|
54
|
+
|
|
55
|
+
class GPT_1(nn.Module):
    """GPT-1 style language model: token + position embeddings, block stack, norm, LM head."""

    def __init__(self, vocab_size, num_layers, emb_dim, num_heads, seq_len,
                 dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.seq_len = seq_len

        # Token ids -> embedding vectors.
        self.embedding = nn.Embedding(vocab_size, emb_dim, dtype=self.dtype, device=self.device)

        # Absolute position embedding.
        self.position_embedding = AbsolutePositionEmbedding(emb_dim=emb_dim, seq_len=seq_len)

        # Stack of GPT_1_Block layers.
        self.encoder = GPT_1_Encoder(
            num_layers=num_layers,
            emb_dim=emb_dim,
            num_heads=num_heads,
            dropout=dropout,
            hidden_dim=hidden_dim,
            eps=eps,
            device=self.device,
            dtype=self.dtype,
        )

        # Normalisation applied to the stack output.
        self.final_norm = LayerNorm(emb_dim, eps=eps)

        # Projection from embedding space to vocabulary logits (no bias).
        self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False, dtype=self.dtype, device=self.device)

    def forward(self, x):
        """Map token ids ``x`` to vocabulary logits."""
        # The position embedding module is called on the token ids themselves.
        token_part = self.embedding(x)
        position_part = self.position_embedding(x)
        hidden = token_part + position_part
        hidden = self.encoder(hidden)
        hidden = self.final_norm(hidden)
        return self.lm_head(hidden)

    @torch.no_grad()
    def generate(self, prompt_ids, max_context_len=128, max_new_tokens=50, temperature=1.0, top_k=None, top_p=1.0, eos_token_id=None):
        """Sample a continuation of ``prompt_ids``; delegates to ``text_generate``."""
        return text_generate(
            self, prompt_ids, max_context_len, max_new_tokens,
            temperature, top_k, top_p, eos_token_id,
        )
|
|
92
|
+
|
|
93
|
+
'''
|
|
94
|
+
GPT-2
|
|
95
|
+
Attention: MHA
|
|
96
|
+
Mask: Causal
|
|
97
|
+
position: absolute
|
|
98
|
+
FF: GeLU
|
|
99
|
+
Norm: pre normalization (layer norm)
|
|
100
|
+
'''
|
|
101
|
+
# --- Encoder Block ---
|
|
102
|
+
class GPT_2_Block(nn.Module):
    """One pre-norm GPT-2 layer: multi-head attention, then a GELU feed-forward.

    Each sub-layer is applied as ``sublayer(norm(x)) + x``.
    """

    def __init__(self, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        placement = {'device': device, 'dtype': dtype}
        # Registration order is kept so parameter ordering is unchanged.
        self.attention = Multi_Head_Attention(emb_dim, num_heads, dropout, **placement)
        # NOTE(review): LayerNorm is built without device/dtype here, while
        # Transformer.py passes both. Confirm the defaults match the other modules.
        self.norm1 = LayerNorm(emb_dim, eps=eps)
        self.FF_GELU = FF_GELU(emb_dim, hidden_dim, dropout, **placement)
        self.norm2 = LayerNorm(emb_dim, eps=eps)

    def forward(self, x):
        """Run the attention and feed-forward sub-layers, each with a skip connection."""
        attn_out = self.attention(self.norm1(x))
        x = attn_out + x

        ff_out = self.FF_GELU(self.norm2(x))
        return ff_out + x
|
|
122
|
+
|
|
123
|
+
# --- Encoder ---
|
|
124
|
+
class GPT_2_Encoder(nn.Module):
    """Stack of ``num_layers`` GPT_2_Block modules applied one after another."""

    def __init__(self, num_layers, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                GPT_2_Block(emb_dim, num_heads, dropout, hidden_dim, eps, device=device, dtype=dtype)
            )
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden
|
|
136
|
+
|
|
137
|
+
class GPT_2(nn.Module):
    """GPT-2 style language model: token + position embeddings, block stack, norm, LM head."""

    def __init__(self, vocab_size, num_layers, emb_dim, num_heads, seq_len,
                 dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.seq_len = seq_len

        # Token ids -> embedding vectors.
        self.embedding = nn.Embedding(vocab_size, emb_dim, dtype=self.dtype, device=self.device)

        # Absolute position embedding.
        self.position_embedding = AbsolutePositionEmbedding(emb_dim=emb_dim, seq_len=seq_len)

        # Stack of GPT_2_Block layers.
        self.encoder = GPT_2_Encoder(
            num_layers=num_layers,
            emb_dim=emb_dim,
            num_heads=num_heads,
            dropout=dropout,
            hidden_dim=hidden_dim,
            eps=eps,
            device=self.device,
            dtype=self.dtype,
        )

        # Normalisation applied to the stack output.
        self.final_norm = LayerNorm(emb_dim, eps=eps)

        # Projection from embedding space to vocabulary logits (no bias).
        self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False,
                                 dtype=self.dtype, device=self.device)

    def forward(self, x):
        """Map token ids ``x`` to vocabulary logits."""
        # The position embedding module is called on the token ids themselves.
        token_part = self.embedding(x)
        position_part = self.position_embedding(x)
        hidden = token_part + position_part
        hidden = self.encoder(hidden)
        hidden = self.final_norm(hidden)
        return self.lm_head(hidden)

    @torch.no_grad()
    def generate(self, prompt_ids, max_context_len=128, max_new_tokens=50, temperature=1.0, top_k=None, top_p=1.0, eos_token_id=None):
        """Sample a continuation of ``prompt_ids``; delegates to ``text_generate``."""
        return text_generate(
            self, prompt_ids, max_context_len, max_new_tokens,
            temperature, top_k, top_p, eos_token_id,
        )
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
|
|
6
|
+
from stackformer.modules.Attention import Multi_Head_Attention, Cross_MultiHead_Attention
|
|
7
|
+
from stackformer.modules.position_embedding import SinusoidalPositionalEmbedding
|
|
8
|
+
from stackformer.modules.Feed_forward import FF_ReLU
|
|
9
|
+
from stackformer.modules.Normalization import LayerNorm
|
|
10
|
+
|
|
11
|
+
class Encoder(nn.Module):
    """Transformer encoder layer: self-attention, then a ReLU feed-forward.

    Each sub-layer is applied as ``norm(sublayer(x)) + x``.
    """

    def __init__(self, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        placement = {'device': device, 'dtype': dtype}
        # Registration order is kept so parameter ordering is unchanged.
        self.attention = Multi_Head_Attention(emb_dim, num_heads, dropout, **placement)
        self.norm1 = LayerNorm(emb_dim, eps=eps, **placement)
        self.ff_relu = FF_ReLU(emb_dim, hidden_dim, dropout, **placement)
        self.norm2 = LayerNorm(emb_dim, eps=eps, **placement)

    def forward(self, x):
        """Run attention then feed-forward, normalising each output before the skip add."""
        attn_out = self.norm1(self.attention(x))
        x = attn_out + x

        ff_out = self.norm2(self.ff_relu(x))
        return ff_out + x
|
|
31
|
+
|
|
32
|
+
class Decoder(nn.Module):
    """Transformer decoder layer: self-attention, cross-attention, ReLU feed-forward.

    Each sub-layer is applied as ``norm(sublayer(x)) + x``.
    """

    def __init__(self, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        placement = {'device': device, 'dtype': dtype}
        # Registration order is kept so parameter ordering is unchanged.
        self.attention = Multi_Head_Attention(emb_dim, num_heads, dropout, **placement)
        self.norm1 = LayerNorm(emb_dim, eps=eps, **placement)
        self.cross_attention = Cross_MultiHead_Attention(emb_dim, num_heads, dropout, **placement)
        self.norm2 = LayerNorm(emb_dim, eps=eps, **placement)
        self.ff_relu = FF_ReLU(emb_dim, hidden_dim, dropout, **placement)
        self.norm3 = LayerNorm(emb_dim, eps=eps, **placement)

    def forward(self, x, enc_output):
        """Run the three sub-layers; ``enc_output`` is the cross-attention context."""
        # Self-attention over the decoder input.
        self_out = self.norm1(self.attention(x))
        x = self_out + x

        # Cross-attention against the encoder output.
        cross_out = self.norm2(self.cross_attention(x, context=enc_output))
        x = cross_out + x

        # Position-wise feed-forward.
        ff_out = self.norm3(self.ff_relu(x))
        return ff_out + x
|
|
59
|
+
|
|
60
|
+
class transformer(nn.Module):
    """Encoder-decoder transformer with one token embedding shared by source and target."""

    def __init__(self, vocab_size, emb_dim, num_heads, dropout, hidden_dim,
                 encoder_layers, decoder_layers, seq_len, eps=1e-5, device='cpu', dtype=torch.float32,
                 ):
        super().__init__()
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers

        placement = {'device': device, 'dtype': dtype}

        self.token_emb = nn.Embedding(vocab_size, emb_dim, **placement)
        self.pos = SinusoidalPositionalEmbedding(seq_len=seq_len, emb_dim=emb_dim)

        encoder_blocks = []
        for _ in range(encoder_layers):
            encoder_blocks.append(Encoder(emb_dim, num_heads, dropout, hidden_dim, eps=eps, **placement))
        self.encoder_stack = nn.ModuleList(encoder_blocks)

        decoder_blocks = []
        for _ in range(decoder_layers):
            decoder_blocks.append(Decoder(emb_dim, num_heads, dropout, hidden_dim, eps=eps, **placement))
        self.decoder_stack = nn.ModuleList(decoder_blocks)

        # Normalisation applied to the decoder output.
        self.final_norm = LayerNorm(emb_dim, eps=eps, **placement)

        # Projection from embedding space to vocabulary logits.
        self.out_proj = nn.Linear(emb_dim, vocab_size, **placement)

    def encoder(self, x):
        """Embed source token ids and run them through the encoder stack."""
        hidden = self.token_emb(x) + self.pos(x)
        for layer in self.encoder_stack:
            hidden = layer(hidden)
        return hidden

    def decoder(self, x, enc_output):
        """Embed target token ids and run them through the decoder stack."""
        hidden = self.token_emb(x) + self.pos(x)
        for layer in self.decoder_stack:
            hidden = layer(hidden, enc_output)
        return hidden

    def forward(self, source, target):
        """Encode ``source``, decode ``target`` against it, and return vocabulary logits."""
        memory = self.encoder(source)
        decoded = self.decoder(target, memory)
        return self.out_proj(self.final_norm(decoded))
|