Stackformer 0.1.2__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (31)
  1. {stackformer-0.1.2 → stackformer-0.1.3}/PKG-INFO +3 -3
  2. {stackformer-0.1.2 → stackformer-0.1.3}/Stackformer.egg-info/PKG-INFO +3 -3
  3. {stackformer-0.1.2 → stackformer-0.1.3}/Stackformer.egg-info/SOURCES.txt +1 -0
  4. stackformer-0.1.3/Stackformer.egg-info/requires.txt +2 -0
  5. {stackformer-0.1.2 → stackformer-0.1.3}/pyproject.toml +3 -3
  6. {stackformer-0.1.2 → stackformer-0.1.3}/setup.py +3 -3
  7. {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/__init__.py +9 -2
  8. stackformer-0.1.3/stackformer/generate.py +53 -0
  9. stackformer-0.1.3/stackformer/models/Meta.py +159 -0
  10. stackformer-0.1.3/stackformer/models/OpenAI.py +177 -0
  11. stackformer-0.1.3/stackformer/models/Transformer.py +104 -0
  12. {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/modules/Attention.py +73 -62
  13. stackformer-0.1.3/stackformer/modules/Feed_forward.py +90 -0
  14. stackformer-0.1.3/stackformer/modules/Normalization.py +31 -0
  15. {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/modules/position_embedding.py +1 -1
  16. stackformer-0.1.2/Stackformer.egg-info/requires.txt +0 -2
  17. stackformer-0.1.2/stackformer/models/Meta.py +0 -213
  18. stackformer-0.1.2/stackformer/models/OpenAI.py +0 -242
  19. stackformer-0.1.2/stackformer/models/Transformer.py +0 -238
  20. stackformer-0.1.2/stackformer/modules/Feed_forward.py +0 -59
  21. stackformer-0.1.2/stackformer/modules/Normalization.py +0 -41
  22. {stackformer-0.1.2 → stackformer-0.1.3}/LICENSE +0 -0
  23. {stackformer-0.1.2 → stackformer-0.1.3}/README.md +0 -0
  24. {stackformer-0.1.2 → stackformer-0.1.3}/Stackformer.egg-info/dependency_links.txt +0 -0
  25. {stackformer-0.1.2 → stackformer-0.1.3}/Stackformer.egg-info/top_level.txt +0 -0
  26. {stackformer-0.1.2 → stackformer-0.1.3}/setup.cfg +0 -0
  27. {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/models/__init__.py +0 -0
  28. {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/modules/__init__.py +0 -0
  29. {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/modules/mask.py +0 -0
  30. {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/modules/tokenizer.py +0 -0
  31. {stackformer-0.1.2 → stackformer-0.1.3}/stackformer/trainer.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: Stackformer
3
- Version: 0.1.2
3
+ Version: 0.1.3
4
4
  Summary: Modular transformer blocks built in PyTorch
5
5
  Home-page: https://github.com/Gurumurthy30/Stackformer
6
6
  Author: Gurumurthy
@@ -12,8 +12,8 @@ Project-URL: Discussions, https://github.com/Gurumurthy30/Stackformer/discussion
12
12
  Requires-Python: >=3.9
13
13
  Description-Content-Type: text/markdown
14
14
  License-File: LICENSE
15
- Requires-Dist: torch<2.6,>=2.0
16
- Requires-Dist: tqdm>=4.67
15
+ Requires-Dist: torch<2.7,>=2.0
16
+ Requires-Dist: tqdm<5.0,>=4.5
17
17
  Dynamic: author
18
18
  Dynamic: home-page
19
19
  Dynamic: license-file
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: Stackformer
3
- Version: 0.1.2
3
+ Version: 0.1.3
4
4
  Summary: Modular transformer blocks built in PyTorch
5
5
  Home-page: https://github.com/Gurumurthy30/Stackformer
6
6
  Author: Gurumurthy
@@ -12,8 +12,8 @@ Project-URL: Discussions, https://github.com/Gurumurthy30/Stackformer/discussion
12
12
  Requires-Python: >=3.9
13
13
  Description-Content-Type: text/markdown
14
14
  License-File: LICENSE
15
- Requires-Dist: torch<2.6,>=2.0
16
- Requires-Dist: tqdm>=4.67
15
+ Requires-Dist: torch<2.7,>=2.0
16
+ Requires-Dist: tqdm<5.0,>=4.5
17
17
  Dynamic: author
18
18
  Dynamic: home-page
19
19
  Dynamic: license-file
@@ -8,6 +8,7 @@ Stackformer.egg-info/dependency_links.txt
8
8
  Stackformer.egg-info/requires.txt
9
9
  Stackformer.egg-info/top_level.txt
10
10
  stackformer/__init__.py
11
+ stackformer/generate.py
11
12
  stackformer/trainer.py
12
13
  stackformer/models/Meta.py
13
14
  stackformer/models/OpenAI.py
@@ -0,0 +1,2 @@
1
+ torch<2.7,>=2.0
2
+ tqdm<5.0,>=4.5
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "Stackformer"
3
- version = "0.1.2"
3
+ version = "0.1.3"
4
4
  description = "Modular transformer blocks built in PyTorch"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.9"
@@ -11,8 +11,8 @@ authors = [
11
11
  ]
12
12
 
13
13
  dependencies = [
14
- "torch>=2.0,<2.6",
15
- "tqdm>=4.67"
14
+ "torch>=2.0,<2.7",
15
+ "tqdm>=4.5,<5.0"
16
16
  ]
17
17
 
18
18
  [project.urls]
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
2
2
 
3
3
  setup(
4
4
  name="Stackformer",
5
- version="0.1.2",
5
+ version="0.1.3",
6
6
  description="Modular transformer blocks built in PyTorch",
7
7
  # long_description=open("README.md", "r", encoding="utf-8").read(),
8
8
  # long_description_content_type="text/markdown",
@@ -18,8 +18,8 @@ setup(
18
18
  python_requires=">=3.9",
19
19
  packages=find_packages(exclude=["tests", "examples"]),
20
20
  install_requires=[
21
- "torch>=2.0,<2.6",
22
- "tqdm>=4.67",
21
+ "torch>=2.0,<2.7",
22
+ "tqdm>=4.5,<5.0",
23
23
  ],
24
24
  classifiers=[
25
25
  "Programming Language :: Python :: 3",
@@ -9,6 +9,7 @@ from .modules.position_embedding import RoPE
9
9
  # --- Attention mechanisms ---
10
10
  from .modules.Attention import Self_Attention
11
11
  from .modules.Attention import Multi_Head_Attention
12
+ from .modules.Attention import Multi_Head_Attention_with_RoPE
12
13
  from .modules.Attention import Cross_MultiHead_Attention
13
14
  from .modules.Attention import Multi_query_Attention
14
15
  from .modules.Attention import Group_query_Attention
@@ -28,11 +29,17 @@ from .modules.Feed_forward import FF_GELU
28
29
  from .modules.Feed_forward import FF_LeakyReLU
29
30
  from .modules.Feed_forward import FF_Sigmoid
30
31
  from .modules.Feed_forward import FF_SiLU
32
+ from .modules.Feed_forward import FF_SwiGLU
31
33
 
32
34
  # --- Model ---
35
+ from .models.OpenAI import GPT_1
33
36
  from .models.OpenAI import GPT_2
34
- from .models.Meta import Llama_2
37
+ from .models.Meta import llama_1
38
+ from .models.Meta import llama_2
35
39
  from .models.Transformer import transformer
36
40
 
37
41
  # --- Trainer ---
38
- from .trainer import Trainer
42
+ from .trainer import Trainer
43
+
44
+ # --- Generate ---
45
+ from .generate import text_generate
@@ -0,0 +1,53 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
def _filter_logits(logits, temperature, top_k, top_p):
    """Apply temperature scaling, top-k and top-p filtering to next-token logits.

    Args:
        logits: Next-token logits of shape ``(batch, vocab_size)``.
        temperature: Divisor applied to the logits (skipped when ``1.0``).
        top_k: When a positive int, keep only the ``top_k`` largest logits per row.
        top_p: Nucleus-sampling threshold; ``1.0`` disables it.

    Returns:
        Logits of the same shape with filtered-out entries set to ``-inf``.
    """
    if temperature != 1.0:
        logits = logits / temperature

    if top_k is not None and top_k > 0:
        # torch.topk raises when k exceeds the vocabulary size, so clamp it.
        k = min(top_k, logits.size(-1))
        topk_vals, topk_indices = torch.topk(logits, k, dim=-1)
        logits = torch.full_like(logits, float("-inf")).scatter(-1, topk_indices, topk_vals)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens once the cumulative probability exceeds top_p. Shifting the
        # mask right by one keeps the token that crosses the threshold, and the
        # most likely token is never dropped.
        sorted_remove = cum_probs > top_p
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
        logits = logits.masked_fill(remove, float("-inf"))

    return logits


def text_generate(self, prompt_ids, max_context_len=128, max_new_tokens=50, temperature=1.0,
                  top_k=None, top_p=1.0, eos_token_id=None):
    """Autoregressively sample tokens from a causal language model.

    Args:
        self: The model. ``self.forward(ids)`` must return logits shaped
            ``(batch, seq_len, vocab_size)``. If it defines ``seq_len``, the
            context window is additionally capped to that value.
        prompt_ids: Prompt token ids, shape ``(seq_len,)`` or ``(batch, seq_len)``.
        max_context_len: Number of most recent tokens fed to the model at each
            step (sliding window).
        max_new_tokens: Maximum number of tokens to append.
        temperature: Logit divisor; ``0`` selects greedy (argmax) decoding.
        top_k: If a positive int, sample only from the ``top_k`` most likely
            tokens (clamped to the vocabulary size).
        top_p: Nucleus-sampling threshold; ``1.0`` disables it.
        eos_token_id: If given, a row stops once it emits this token. Rows that
            have stopped are padded with ``eos_token_id``, and generation ends
            when every row has stopped.

    Returns:
        Tensor ``(batch, prompt_len + n_generated)``: the prompt followed by the
        generated tokens.
    """
    if prompt_ids.dim() == 1:
        prompt_ids = prompt_ids.unsqueeze(0)

    # Never feed more positions than the model was configured for; the models
    # in this package record that limit as ``seq_len``.
    model_len = getattr(self, "seq_len", None)
    context_len = max_context_len if model_len is None else min(max_context_len, model_len)

    generated = prompt_ids.clone()
    finished = torch.zeros(generated.size(0), dtype=torch.bool, device=generated.device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Negative slicing returns the whole tensor when it is shorter.
            input_ids = generated[:, -context_len:]
            logits = self.forward(input_ids)[:, -1, :]  # (batch, vocab_size)

            if temperature == 0:
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(_filter_logits(logits, temperature, top_k, top_p), dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)  # (batch, 1)

            if eos_token_id is not None:
                # Rows that already finished keep emitting EOS as padding.
                next_token = next_token.masked_fill(finished.unsqueeze(-1), eos_token_id)
            generated = torch.cat([generated, next_token], dim=-1)

            if eos_token_id is not None:
                # Per-row check: `.item()` would fail for batch sizes > 1.
                finished |= next_token.squeeze(-1) == eos_token_id
                if bool(finished.all()):
                    break

    return generated
@@ -0,0 +1,159 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ from stackformer.modules.Attention import kv_cache_group_query, Multi_Head_Attention_with_RoPE
6
+ from stackformer.modules.Feed_forward import FF_SwiGLU
7
+ from stackformer.modules.Normalization import RMSNormilization
8
+ from stackformer.generate import text_generate
9
+
10
+ '''
11
+ llama 1
12
+ Attention: MHA
13
+ Mask: Causal
14
+ position: RoPE
15
+ FF: SwiGLU
16
+ Norm: pre norm (RMS norm)
17
+ '''
18
class llama_1_Block(nn.Module):
    """A single LLaMA-1 layer using pre-normalization.

    Computes ``h = x + attention(norm1(x))`` followed by
    ``h + FF_SwiGLU(norm2(h))``. Attention is ``Multi_Head_Attention_with_RoPE``
    and both norms are ``RMSNormilization``.
    """

    def __init__(self, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        factory = dict(device=device, dtype=dtype)
        self.attention = Multi_Head_Attention_with_RoPE(emb_dim, num_heads, dropout, **factory)
        # NOTE(review): the RMS norms receive no device/dtype; confirm RMSNormilization's signature.
        self.norm1 = RMSNormilization(emb_dim, eps=eps)
        self.FF_SwiGLU = FF_SwiGLU(emb_dim, hidden_dim, dropout, **factory)
        self.norm2 = RMSNormilization(emb_dim, eps=eps)

    def forward(self, x):
        attn_out = self.attention(self.norm1(x))
        hidden = attn_out + x
        ff_out = self.FF_SwiGLU(self.norm2(hidden))
        return ff_out + hidden
38
+
39
+ # --- Encoder ---
40
class llama_1_Encoder(nn.Module):
    """Run ``num_layers`` freshly initialised llama_1_Block layers one after another."""

    def __init__(self, num_layers, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        blocks = [
            llama_1_Block(emb_dim, num_heads, dropout, hidden_dim, eps, device=device, dtype=dtype)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return out
52
+
53
class llama_1(nn.Module):
    """LLaMA-1 style decoder-only language model.

    Pipeline: token embedding -> llama_1_Encoder (``num_layers`` blocks) ->
    final RMSNormilization -> bias-free linear head producing ``vocab_size``
    logits per position.
    """

    def __init__(self, vocab_size, num_layers, emb_dim, num_heads, seq_len,
                 dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.device = device
        self.dtype = dtype
        # Stored on the model; forward() itself does not read it.
        self.seq_len = seq_len

        self.embedding = nn.Embedding(vocab_size, emb_dim, dtype=dtype, device=device)
        self.encoder = llama_1_Encoder(
            num_layers=num_layers, emb_dim=emb_dim, num_heads=num_heads, dropout=dropout,
            hidden_dim=hidden_dim, eps=eps, device=device, dtype=dtype,
        )
        self.final_norm = RMSNormilization(emb_dim, eps=eps)
        self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False, dtype=dtype, device=device)

    def forward(self, x):
        # x: token ids; the result has a trailing vocab_size dimension of logits.
        hidden = self.encoder(self.embedding(x))
        return self.lm_head(self.final_norm(hidden))

    @torch.no_grad()
    def generate(self, prompt_ids, max_context_len=128, max_new_tokens=50, temperature=1.0, top_k=None, top_p=1.0, eos_token_id=None):
        """Sample a continuation of ``prompt_ids`` via ``text_generate``."""
        return text_generate(
            self, prompt_ids,
            max_context_len=max_context_len, max_new_tokens=max_new_tokens,
            temperature=temperature, top_k=top_k, top_p=top_p, eos_token_id=eos_token_id,
        )
85
+
86
+ '''
87
+ llama 2
88
+ Attention: GQA with KV cache
89
+ Mask: Causal
90
+ position: RoPE
91
+ FF: SwiGLU
92
+ Norm: pre norm (RMS norm)
93
+ '''
94
class llama_2_Block(nn.Module):
    """A single LLaMA-2 layer using pre-normalization.

    Computes ``h = x + attn(attn_norm(x), start_pos, rope=True)`` followed by
    ``h + ff(ff_norm(h))``. Attention is ``kv_cache_group_query`` (grouped-query
    attention with ``query_num_heads`` query heads and ``kv_num_heads`` key/value
    heads); the feed-forward network is ``FF_SwiGLU``.

    NOTE(review): ``batch_size``/``kv_seq_len`` are forwarded to the attention,
    presumably to size a KV cache — confirm in Attention.py.
    """

    def __init__(self, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len, hidden_dim,
                 eps=1e-5, dropout=0.1, dtype=torch.float32, device='cpu'):
        super().__init__()
        self.attn_norm = RMSNormilization(emb_dim, eps=eps)
        self.ff_norm = RMSNormilization(emb_dim, eps=eps)
        self.attn = kv_cache_group_query(emb_dim=emb_dim, query_num_heads=query_num_heads, kv_num_heads=kv_num_heads,
                                         batch_size=batch_size, kv_seq_len=kv_seq_len, dtype=dtype, dropout=dropout,
                                         device=device)
        # Pass this block's dropout to the FFN (positionally, exactly as
        # llama_1_Block does). Previously it was omitted, so the FFN silently
        # used FF_SwiGLU's own default instead of the configured value.
        self.ff = FF_SwiGLU(emb_dim, hidden_dim, dropout, device=device, dtype=dtype)

    def forward(self, x, start_pos):
        """Apply the layer; ``start_pos`` is handed unchanged to the attention."""
        residual = x
        x = self.attn_norm(x)
        x = self.attn(x, start_pos, rope=True)
        x = x + residual

        residual = x
        x = self.ff_norm(x)
        x = self.ff(x)
        x = x + residual
        return x
115
+
116
class llama_2_Encoder(nn.Module):
    """Sequential stack of ``num_layers`` llama_2_Block layers.

    Every layer receives the same ``start_pos`` in forward().
    """

    def __init__(self, num_layers, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len,
                 hidden_dim, eps=1e-5, dropout=0.1, dtype=torch.float32, device='cpu'):
        super().__init__()
        block_kwargs = dict(
            emb_dim=emb_dim, query_num_heads=query_num_heads, kv_num_heads=kv_num_heads,
            batch_size=batch_size, kv_seq_len=kv_seq_len, hidden_dim=hidden_dim,
            eps=eps, dropout=dropout, dtype=dtype, device=device,
        )
        self.layers = nn.ModuleList(llama_2_Block(**block_kwargs) for _ in range(num_layers))

    def forward(self, x, start_pos):
        out = x
        for block in self.layers:
            out = block(out, start_pos)
        return out
131
+
132
class llama_2(nn.Module):
    """LLaMA-2 style decoder-only language model.

    Pipeline: token embedding -> llama_2_Encoder (``num_layers`` blocks with
    grouped-query attention and RoPE) -> final RMSNormilization -> bias-free
    linear head producing ``vocab_size`` logits per position.
    """

    def __init__(self, num_layers, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len, vocab_size,
                 hidden_dim, eps=1e-5, dropout=0.1, dtype=torch.float32, device='cpu'):
        super().__init__()
        self.device = device
        self.vocab_size = vocab_size
        self.dtype = dtype
        # Mirrors kv_seq_len; kept for generation-time slicing of the context.
        self.seq_len = kv_seq_len

        self.embedding = nn.Embedding(vocab_size, emb_dim, dtype=dtype, device=device)
        self.llama_2_Encoder = llama_2_Encoder(
            num_layers=num_layers, emb_dim=emb_dim, query_num_heads=query_num_heads,
            kv_num_heads=kv_num_heads, batch_size=batch_size, kv_seq_len=kv_seq_len,
            hidden_dim=hidden_dim, eps=eps, dropout=dropout, dtype=dtype, device=device,
        )
        self.final_norm = RMSNormilization(emb_dim, eps=eps)
        self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False, dtype=dtype, device=device)

    def forward(self, input_ids, start_pos=0):
        # start_pos is passed unchanged to every layer's attention.
        hidden = self.llama_2_Encoder(self.embedding(input_ids), start_pos)
        return self.lm_head(self.final_norm(hidden))

    @torch.no_grad()
    def generate(self, prompt_ids, max_context_len=128, max_new_tokens=50, temperature=1.0, top_k=None, top_p=1.0, eos_token_id=None):
        """Sample a continuation of ``prompt_ids`` via ``text_generate``."""
        return text_generate(
            self, prompt_ids,
            max_context_len=max_context_len, max_new_tokens=max_new_tokens,
            temperature=temperature, top_k=top_k, top_p=top_p, eos_token_id=eos_token_id,
        )
@@ -0,0 +1,177 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from stackformer.modules.Attention import Multi_Head_Attention
6
+ from stackformer.modules.position_embedding import AbsolutePositionEmbedding
7
+ from stackformer.modules.Normalization import LayerNorm
8
+ from stackformer.modules.Feed_forward import FF_GELU
9
+ from stackformer.generate import text_generate
10
+
11
+ '''
12
+ GPT-1
13
+ Attention: MHA
14
+ Mask: Casual
15
+ position: absolute
16
+ FF: GeLU
17
+ Norm: post normalization (layer norm)
18
+ '''
19
+ # --- GPT_1 Encoder Block ---
20
class GPT_1_Block(nn.Module):
    """A single GPT-1 layer using post-normalization.

    Each sub-layer computes ``LayerNorm(x + Sublayer(x))``, as in the original
    Transformer: multi-head self-attention, then a GELU feed-forward network.
    The previous code computed ``LayerNorm(Sublayer(x)) + x``, which is not
    post-norm: the residual stream was never normalized.
    """

    def __init__(self, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.attention = Multi_Head_Attention(emb_dim, num_heads, dropout, device=device, dtype=dtype)
        # Build the norms on the requested device/dtype like the other sub-layers.
        self.norm1 = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)
        self.FF_GELU = FF_GELU(emb_dim, hidden_dim, dropout, device=device, dtype=dtype)
        self.norm2 = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)

    def forward(self, x):
        x = self.norm1(x + self.attention(x))
        x = self.norm2(x + self.FF_GELU(x))
        return x
40
+
41
+ # --- GPT_1 Encoder ---
42
class GPT_1_Encoder(nn.Module):
    """Apply ``num_layers`` GPT_1_Block layers in sequence."""

    def __init__(self, num_layers, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                GPT_1_Block(emb_dim, num_heads, dropout, hidden_dim, eps, device=device, dtype=dtype)
            )

    def forward(self, x):
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden
54
+
55
class GPT_1(nn.Module):
    """GPT-1 style decoder-only language model.

    Pipeline: token embedding + absolute position embedding -> GPT_1_Encoder
    (``num_layers`` blocks) -> final LayerNorm -> bias-free linear head producing
    ``vocab_size`` logits per position.
    """

    def __init__(self, vocab_size, num_layers, emb_dim, num_heads, seq_len,
                 dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.seq_len = seq_len

        # --- Token embedding ---
        self.embedding = nn.Embedding(vocab_size, emb_dim, dtype=self.dtype, device=self.device)

        # --- Absolute position embedding ---
        # NOTE(review): receives no device/dtype; confirm whether
        # AbsolutePositionEmbedding holds tensors that need placing.
        self.position_embedding = AbsolutePositionEmbedding(emb_dim=emb_dim, seq_len=seq_len)

        # --- Encoder ---
        self.encoder = GPT_1_Encoder(num_layers=num_layers, emb_dim=emb_dim, num_heads=num_heads, dropout=dropout,
                                     hidden_dim=hidden_dim, eps=eps, device=self.device, dtype=self.dtype)

        # --- Final norm ---
        # Created on the requested device/dtype like the other layers; previously
        # its parameters ended up on the default device and dtype.
        self.final_norm = LayerNorm(emb_dim, eps=eps, device=self.device, dtype=self.dtype)

        # --- Output projection ---
        self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False, dtype=self.dtype, device=self.device)

    def forward(self, x):
        # x: token ids, e.g. (batch_size, seq_len)
        emb = self.embedding(x)
        pos = self.position_embedding(x)
        x = emb + pos
        x = self.encoder(x)
        x = self.final_norm(x)
        return self.lm_head(x)

    @torch.no_grad()
    def generate(self, prompt_ids, max_context_len=128, max_new_tokens=50, temperature=1.0, top_k=None, top_p=1.0, eos_token_id=None):
        """Sample a continuation of ``prompt_ids`` via ``text_generate``."""
        return text_generate(self, prompt_ids, max_context_len, max_new_tokens, temperature, top_k, top_p, eos_token_id)
92
+
93
+ '''
94
+ GPT-2
95
+ Attention: MHA
96
+ Mask: Casual
97
+ position: absolute
98
+ FF: GeLU
99
+ Norm: pre normalization (layer norm)
100
+ '''
101
+ # --- Encoder Block ---
102
class GPT_2_Block(nn.Module):
    """A single GPT-2 layer using pre-normalization.

    Computes ``h = x + attention(norm1(x))`` followed by
    ``h + FF_GELU(norm2(h))``.
    """

    def __init__(self, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.attention = Multi_Head_Attention(emb_dim, num_heads, dropout, device=device, dtype=dtype)
        # The norms are now created on the requested device/dtype; previously
        # they fell back to the defaults while every other sub-layer did not.
        self.norm1 = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)
        self.FF_GELU = FF_GELU(emb_dim, hidden_dim, dropout, device=device, dtype=dtype)
        self.norm2 = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)

    def forward(self, x):
        residual = x
        x = self.norm1(x)
        x = self.attention(x)
        x = x + residual

        residual = x
        x = self.norm2(x)
        x = self.FF_GELU(x)
        x = x + residual

        return x
122
+
123
+ # --- Encoder ---
124
class GPT_2_Encoder(nn.Module):
    """Apply ``num_layers`` GPT_2_Block layers in sequence."""

    def __init__(self, num_layers, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.layers = nn.ModuleList(
            GPT_2_Block(emb_dim, num_heads, dropout, hidden_dim, eps, device=device, dtype=dtype)
            for _ in range(num_layers)
        )

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return out
136
+
137
class GPT_2(nn.Module):
    """GPT-2 style decoder-only language model.

    Pipeline: token embedding + absolute position embedding -> GPT_2_Encoder
    (``num_layers`` pre-norm blocks) -> final LayerNorm -> bias-free linear head
    producing ``vocab_size`` logits per position.
    """

    def __init__(self, vocab_size, num_layers, emb_dim, num_heads, seq_len,
                 dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.seq_len = seq_len

        # --- Token embedding ---
        self.embedding = nn.Embedding(vocab_size, emb_dim, dtype=self.dtype, device=self.device)

        # --- Absolute position embedding ---
        # NOTE(review): receives no device/dtype; confirm whether
        # AbsolutePositionEmbedding holds tensors that need placing.
        self.position_embedding = AbsolutePositionEmbedding(emb_dim=emb_dim, seq_len=seq_len)

        # --- Encoder ---
        self.encoder = GPT_2_Encoder(num_layers=num_layers, emb_dim=emb_dim, num_heads=num_heads, dropout=dropout,
                                     hidden_dim=hidden_dim, eps=eps, device=self.device, dtype=self.dtype)

        # --- Final norm ---
        # Created on the requested device/dtype like the other layers; previously
        # its parameters ended up on the default device and dtype.
        self.final_norm = LayerNorm(emb_dim, eps=eps, device=self.device, dtype=self.dtype)

        # --- Output projection ---
        self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False,
                                 dtype=self.dtype, device=self.device)

    def forward(self, x):
        # x: token ids, e.g. (batch_size, seq_len)
        emb = self.embedding(x)
        pos = self.position_embedding(x)
        x = emb + pos
        x = self.encoder(x)
        x = self.final_norm(x)
        return self.lm_head(x)

    @torch.no_grad()
    def generate(self, prompt_ids, max_context_len=128, max_new_tokens=50, temperature=1.0, top_k=None, top_p=1.0, eos_token_id=None):
        """Sample a continuation of ``prompt_ids`` via ``text_generate``."""
        return text_generate(self, prompt_ids, max_context_len, max_new_tokens, temperature, top_k, top_p, eos_token_id)
@@ -0,0 +1,104 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from stackformer.modules.Attention import Multi_Head_Attention, Cross_MultiHead_Attention
7
+ from stackformer.modules.position_embedding import SinusoidalPositionalEmbedding
8
+ from stackformer.modules.Feed_forward import FF_ReLU
9
+ from stackformer.modules.Normalization import LayerNorm
10
+
11
class Encoder(nn.Module):
    """One Transformer encoder layer using post-normalization.

    Each sub-layer computes ``LayerNorm(x + Sublayer(x))``, as in the original
    Transformer: multi-head self-attention, then a ReLU feed-forward network.
    The previous code computed ``LayerNorm(Sublayer(x)) + x``, which leaves the
    residual stream un-normalized.

    NOTE(review): this uses the same Multi_Head_Attention as the decoder. If that
    module applies a causal mask, encoder self-attention is causal too — confirm
    this is intended.
    """

    def __init__(self, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.attention = Multi_Head_Attention(emb_dim, num_heads, dropout, device=device, dtype=dtype)
        self.norm1 = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)
        self.ff_relu = FF_ReLU(emb_dim, hidden_dim, dropout, device=device, dtype=dtype)
        self.norm2 = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)

    def forward(self, x):
        x = self.norm1(x + self.attention(x))
        x = self.norm2(x + self.ff_relu(x))
        return x
31
+
32
class Decoder(nn.Module):
    """One Transformer decoder layer using post-normalization.

    There are three sub-layers, each computing ``LayerNorm(x + Sublayer(x))`` as
    in the original Transformer: self-attention, cross-attention over
    ``enc_output`` (passed as ``context``), then a ReLU feed-forward network.
    The previous code computed ``LayerNorm(Sublayer(x)) + x`` instead.
    """

    def __init__(self, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.attention = Multi_Head_Attention(emb_dim, num_heads, dropout, device=device, dtype=dtype)
        self.norm1 = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)
        self.cross_attention = Cross_MultiHead_Attention(emb_dim, num_heads, dropout, device=device, dtype=dtype)
        self.norm2 = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)
        self.ff_relu = FF_ReLU(emb_dim, hidden_dim, dropout, device=device, dtype=dtype)
        self.norm3 = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)

    def forward(self, x, enc_output):
        x = self.norm1(x + self.attention(x))
        x = self.norm2(x + self.cross_attention(x, context=enc_output))
        x = self.norm3(x + self.ff_relu(x))
        return x
59
+
60
class transformer(nn.Module):
    """Encoder-decoder Transformer.

    Source and target tokens share a single token-embedding table and a single
    sinusoidal position embedding. ``forward(source, target)`` encodes
    ``source``, decodes ``target`` against the encoder output, then applies a
    final LayerNorm and a linear projection to ``vocab_size`` logits.
    """

    def __init__(self, vocab_size, emb_dim, num_heads, dropout, hidden_dim,
                 encoder_layers, decoder_layers, seq_len, eps=1e-5, device='cpu', dtype=torch.float32,
                 ):
        super().__init__()
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        factory = dict(device=device, dtype=dtype)

        self.token_emb = nn.Embedding(vocab_size, emb_dim, **factory)
        self.pos = SinusoidalPositionalEmbedding(seq_len=seq_len, emb_dim=emb_dim)

        self.encoder_stack = nn.ModuleList(
            [Encoder(emb_dim, num_heads, dropout, hidden_dim, eps=eps, **factory) for _ in range(encoder_layers)]
        )
        self.decoder_stack = nn.ModuleList(
            [Decoder(emb_dim, num_heads, dropout, hidden_dim, eps=eps, **factory) for _ in range(decoder_layers)]
        )

        self.final_norm = LayerNorm(emb_dim, eps=eps, **factory)
        self.out_proj = nn.Linear(emb_dim, vocab_size, **factory)

    def _embed(self, tokens):
        """Token embedding plus positional embedding, shared by both stacks."""
        return self.token_emb(tokens) + self.pos(tokens)

    def encoder(self, x):
        hidden = self._embed(x)
        for layer in self.encoder_stack:
            hidden = layer(hidden)
        return hidden

    def decoder(self, x, enc_output):
        hidden = self._embed(x)
        for layer in self.decoder_stack:
            hidden = layer(hidden, enc_output)
        return hidden

    def forward(self, source, target):
        memory = self.encoder(source)
        decoded = self.decoder(target, memory)
        return self.out_proj(self.final_norm(decoded))