Stackformer 0.1.0__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27) hide show
  1. {stackformer-0.1.0 → stackformer-0.1.2}/PKG-INFO +9 -4
  2. {stackformer-0.1.0 → stackformer-0.1.2}/README.md +7 -2
  3. {stackformer-0.1.0 → stackformer-0.1.2}/Stackformer.egg-info/PKG-INFO +9 -4
  4. stackformer-0.1.2/Stackformer.egg-info/SOURCES.txt +22 -0
  5. stackformer-0.1.2/Stackformer.egg-info/requires.txt +2 -0
  6. stackformer-0.1.2/Stackformer.egg-info/top_level.txt +1 -0
  7. {stackformer-0.1.0 → stackformer-0.1.2}/pyproject.toml +2 -2
  8. {stackformer-0.1.0 → stackformer-0.1.2}/setup.py +2 -2
  9. stackformer-0.1.2/stackformer/__init__.py +38 -0
  10. stackformer-0.1.2/stackformer/models/Meta.py +213 -0
  11. stackformer-0.1.0/models/GPT_2.py → stackformer-0.1.2/stackformer/models/OpenAI.py +7 -3
  12. stackformer-0.1.2/stackformer/models/Transformer.py +238 -0
  13. {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/modules/Attention.py +33 -34
  14. {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/modules/mask.py +1 -1
  15. {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/modules/position_embedding.py +6 -4
  16. stackformer-0.1.2/stackformer/trainer.py +356 -0
  17. stackformer-0.1.0/Stackformer.egg-info/SOURCES.txt +0 -18
  18. stackformer-0.1.0/Stackformer.egg-info/requires.txt +0 -2
  19. stackformer-0.1.0/Stackformer.egg-info/top_level.txt +0 -2
  20. {stackformer-0.1.0 → stackformer-0.1.2}/LICENSE +0 -0
  21. {stackformer-0.1.0 → stackformer-0.1.2}/Stackformer.egg-info/dependency_links.txt +0 -0
  22. {stackformer-0.1.0 → stackformer-0.1.2}/setup.cfg +0 -0
  23. {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/models/__init__.py +0 -0
  24. {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/modules/Feed_forward.py +0 -0
  25. {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/modules/Normalization.py +0 -0
  26. {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/modules/__init__.py +0 -0
  27. {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/modules/tokenizer.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: Stackformer
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: Modular transformer blocks built in PyTorch
5
5
  Home-page: https://github.com/Gurumurthy30/Stackformer
6
6
  Author: Gurumurthy
@@ -12,7 +12,7 @@ Project-URL: Discussions, https://github.com/Gurumurthy30/Stackformer/discussion
12
12
  Requires-Python: >=3.9
13
13
  Description-Content-Type: text/markdown
14
14
  License-File: LICENSE
15
- Requires-Dist: torch>=2.6
15
+ Requires-Dist: torch<2.6,>=2.0
16
16
  Requires-Dist: tqdm>=4.67
17
17
  Dynamic: author
18
18
  Dynamic: home-page
@@ -61,11 +61,16 @@ stackformer/ \
61
61
 
62
62
  ## 💻 Installation
63
63
 
64
- Clone the repository and install in development mode:
64
+ Method 1: Install from PyPI:
65
+ ```bash
66
+ pip install Stackformer
67
+ python -c "import stackformer"  # verify the installation
68
+ ```
65
69
 
70
+ 🔧 Method 2: Clone the repository:
66
71
  ```bash
67
72
  git clone https://github.com/Gurumurthy30/Stackformer
68
- cd transformers
73
+ cd Stackformer
69
74
  pip install -e .
70
75
  ```
71
76
 
@@ -40,11 +40,16 @@ stackformer/ \
40
40
 
41
41
  ## 💻 Installation
42
42
 
43
- Clone the repository and install in development mode:
43
+ Method 1: Install from PyPI:
44
+ ```bash
45
+ pip install Stackformer
46
+ python -c "import stackformer"  # verify the installation
47
+ ```
44
48
 
49
+ 🔧 Method 2: Clone the repository:
45
50
  ```bash
46
51
  git clone https://github.com/Gurumurthy30/Stackformer
47
- cd transformers
52
+ cd Stackformer
48
53
  pip install -e .
49
54
  ```
50
55
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: Stackformer
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: Modular transformer blocks built in PyTorch
5
5
  Home-page: https://github.com/Gurumurthy30/Stackformer
6
6
  Author: Gurumurthy
@@ -12,7 +12,7 @@ Project-URL: Discussions, https://github.com/Gurumurthy30/Stackformer/discussion
12
12
  Requires-Python: >=3.9
13
13
  Description-Content-Type: text/markdown
14
14
  License-File: LICENSE
15
- Requires-Dist: torch>=2.6
15
+ Requires-Dist: torch<2.6,>=2.0
16
16
  Requires-Dist: tqdm>=4.67
17
17
  Dynamic: author
18
18
  Dynamic: home-page
@@ -61,11 +61,16 @@ stackformer/ \
61
61
 
62
62
  ## 💻 Installation
63
63
 
64
- Clone the repository and install in development mode:
64
+ Method 1: Install from PyPI:
65
+ ```bash
66
+ pip install Stackformer
67
+ python -c "import stackformer"  # verify the installation
68
+ ```
65
69
 
70
+ 🔧 Method 2: Clone the repository:
66
71
  ```bash
67
72
  git clone https://github.com/Gurumurthy30/Stackformer
68
- cd transformers
73
+ cd Stackformer
69
74
  pip install -e .
70
75
  ```
71
76
 
@@ -0,0 +1,22 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ setup.py
5
+ Stackformer.egg-info/PKG-INFO
6
+ Stackformer.egg-info/SOURCES.txt
7
+ Stackformer.egg-info/dependency_links.txt
8
+ Stackformer.egg-info/requires.txt
9
+ Stackformer.egg-info/top_level.txt
10
+ stackformer/__init__.py
11
+ stackformer/trainer.py
12
+ stackformer/models/Meta.py
13
+ stackformer/models/OpenAI.py
14
+ stackformer/models/Transformer.py
15
+ stackformer/models/__init__.py
16
+ stackformer/modules/Attention.py
17
+ stackformer/modules/Feed_forward.py
18
+ stackformer/modules/Normalization.py
19
+ stackformer/modules/__init__.py
20
+ stackformer/modules/mask.py
21
+ stackformer/modules/position_embedding.py
22
+ stackformer/modules/tokenizer.py
@@ -0,0 +1,2 @@
1
+ torch<2.6,>=2.0
2
+ tqdm>=4.67
@@ -0,0 +1 @@
1
+ stackformer
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "Stackformer"
3
- version = "0.1.0"
3
+ version = "0.1.2"
4
4
  description = "Modular transformer blocks built in PyTorch"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.9"
@@ -11,7 +11,7 @@ authors = [
11
11
  ]
12
12
 
13
13
  dependencies = [
14
- "torch>=2.6",
14
+ "torch>=2.0,<2.6",
15
15
  "tqdm>=4.67"
16
16
  ]
17
17
 
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
2
2
 
3
3
  setup(
4
4
  name="Stackformer",
5
- version="0.1.0",
5
+ version="0.1.2",
6
6
  description="Modular transformer blocks built in PyTorch",
7
7
  # long_description=open("README.md", "r", encoding="utf-8").read(),
8
8
  # long_description_content_type="text/markdown",
@@ -18,7 +18,7 @@ setup(
18
18
  python_requires=">=3.9",
19
19
  packages=find_packages(exclude=["tests", "examples"]),
20
20
  install_requires=[
21
- "torch>=2.6",
21
+ "torch>=2.0,<2.6",
22
22
  "tqdm>=4.67",
23
23
  ],
24
24
  classifiers=[
@@ -0,0 +1,38 @@
1
+ # --- Tokenizer ---
2
+ from .modules.tokenizer import Embedding_using_tiktoken
3
+
4
+ # --- Position Embeddings ---
5
+ from .modules.position_embedding import AbsolutePositionEmbedding
6
+ from .modules.position_embedding import SinusoidalPositionalEmbedding
7
+ from .modules.position_embedding import RoPE
8
+
9
+ # --- Attention mechanisms ---
10
+ from .modules.Attention import Self_Attention
11
+ from .modules.Attention import Multi_Head_Attention
12
+ from .modules.Attention import Cross_MultiHead_Attention
13
+ from .modules.Attention import Multi_query_Attention
14
+ from .modules.Attention import Group_query_Attention
15
+ from .modules.Attention import Linear_Attention
16
+ from .modules.Attention import Multi_latent_Attention
17
+ from .modules.Attention import Local_Attention
18
+ from .modules.Attention import kv_cache_multihead
19
+ from .modules.Attention import kv_cache_group_query
20
+
21
+ # --- Normalization layers ---
22
+ from .modules.Normalization import LayerNorm
23
+ from .modules.Normalization import RMSNormilization
24
+
25
+ # --- Feed Forward layers ---
26
+ from .modules.Feed_forward import FF_ReLU
27
+ from .modules.Feed_forward import FF_GELU
28
+ from .modules.Feed_forward import FF_LeakyReLU
29
+ from .modules.Feed_forward import FF_Sigmoid
30
+ from .modules.Feed_forward import FF_SiLU
31
+
32
+ # --- Model ---
33
+ from .models.OpenAI import GPT_2
34
+ from .models.Meta import Llama_2
35
+ from .models.Transformer import transformer
36
+
37
+ # --- Trainer ---
38
+ from .trainer import Trainer
@@ -0,0 +1,213 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+
6
def precompute_theta_position_frequency(head_dim, seq_len, device='cpu', theta=10000.0):
    """Build the complex rotation table used by rotary position embeddings.

    Args:
        head_dim: Per-head feature size; must be even because features are
            rotated in pairs.
        seq_len: Number of positions (rows) to generate, starting at 0.
        device: Device on which the table is created.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(seq_len, head_dim // 2)`` whose entry
        ``[m, i]`` is the unit-magnitude number with angle
        ``m / theta ** (2 * i / head_dim)``.
    """
    assert head_dim % 2 == 0, "head_dim must be even"
    # One exponent per feature pair: 0/head_dim, 2/head_dim, 4/head_dim, ...
    pair_exponents = torch.arange(0, head_dim, 2, device=device) / head_dim
    inverse_frequencies = 1.0 / (theta ** pair_exponents)
    positions = torch.arange(seq_len, device=device)
    # angle[m, i] = position m * inverse frequency i
    angles = torch.outer(positions, inverse_frequencies)
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)
14
+
15
+
16
def apply_rotry_position_embedding(x, freq_complex, device='cpu', dtype=torch.float32):
    """Rotate consecutive feature pairs of ``x`` by position-dependent angles.

    Args:
        x: Tensor of shape ``(batch, seq_len, num_head, emb_dim)``.
        freq_complex: Complex rotation table with at least ``seq_len`` rows and
            ``emb_dim // 2`` columns; row ``i`` rotates sequence index ``i``
            of ``x``.
        device: Device on which the rotation is computed and returned.
        dtype: Floating dtype of the returned tensor.

    Returns:
        Tensor with the same shape as ``x`` on ``device`` with ``dtype``.

    Raises:
        AssertionError: If ``emb_dim`` is odd.
    """
    batch_size, seq_len, num_head, emb_dim = x.shape
    assert emb_dim % 2 == 0, "emb_dim must be even"

    # view_as_complex does not accept bfloat16 and complex-half arithmetic is
    # poorly supported, so half-precision inputs are rotated in float32 and
    # cast back at the end.
    if dtype in (torch.float16, torch.bfloat16):
        compute_dtype = torch.float32
    else:
        compute_dtype = dtype

    # reshape (rather than view) also accepts inputs whose last dimension is
    # not contiguous; view_as_complex then needs a contiguous trailing pair.
    pairs = x.to(device=device, dtype=compute_dtype).reshape(
        batch_size, seq_len, num_head, emb_dim // 2, 2
    ).contiguous()
    x_complex = torch.view_as_complex(pairs)

    # (seq_len, emb_dim // 2) -> (1, seq_len, 1, emb_dim // 2) to broadcast
    # over batch and heads.
    rotation = freq_complex[:seq_len].unsqueeze(0).unsqueeze(2).to(device=device)
    x_rotated = x_complex * rotation

    x_out = torch.view_as_real(x_rotated).reshape(batch_size, seq_len, num_head, emb_dim)
    return x_out.to(device=device, dtype=dtype)
25
+
26
+
27
class kv_cache_group_query(nn.Module):
    """Causal grouped-query self-attention with rotary embeddings and a KV cache.

    Queries use ``query_num_heads`` heads while keys/values use the smaller
    ``kv_num_heads``; each key/value head is shared by
    ``query_num_heads // kv_num_heads`` query heads. Keys and values of
    processed positions are kept in fixed-size buffers so later calls can
    attend to them via ``start_pos``.

    Args:
        emb_dim: Model width; must be divisible by ``query_num_heads``.
        query_num_heads: Number of query heads.
        kv_num_heads: Number of key/value heads; must divide ``query_num_heads``.
        batch_size: Number of batch rows allocated in the cache.
        kv_seq_len: Number of positions allocated in the cache.
        device: Device for parameters and cache buffers.
        dtype: Dtype for parameters and cache buffers.
        dropout: Dropout probability applied to the projected output.
    """

    def __init__(self, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len,
                 device='cpu', dtype=torch.float32, dropout=0.1):
        super().__init__()
        assert emb_dim % query_num_heads == 0, "Embedding dim must be divisible by query heads"
        assert query_num_heads % kv_num_heads == 0, "query heads must be divisible by kv heads"

        self.device = device
        self.dtype = dtype
        self.emb_dim = emb_dim
        self.query_num_heads = query_num_heads
        self.kv_num_heads = kv_num_heads
        self.head_dim = emb_dim // query_num_heads
        self.num_queries_per_kv = query_num_heads // kv_num_heads
        self.kv_seq_len = kv_seq_len

        self.query = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
        self.key = nn.Linear(emb_dim, kv_num_heads * self.head_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(emb_dim, kv_num_heads * self.head_dim, bias=False, dtype=dtype, device=device)

        self.out_proj = nn.Linear(query_num_heads * self.head_dim, emb_dim, dtype=dtype, device=device)
        self.dropout = nn.Dropout(dropout)

        self.register_buffer("cache_keys", torch.zeros(batch_size, kv_seq_len, kv_num_heads, self.head_dim, device=device, dtype=dtype))
        self.register_buffer("cache_value", torch.zeros(batch_size, kv_seq_len, kv_num_heads, self.head_dim, device=device, dtype=dtype))

    def forward(self, x, start_pos):
        """Attend ``x`` to itself and to cached positions ``[0, start_pos)``.

        Args:
            x: Input of shape ``(batch, seq_len, emb_dim)``.
            start_pos: Absolute position of ``x[:, 0]`` in the sequence; the
                cache is expected to already hold positions before it.

        Returns:
            Tensor of shape ``(batch, seq_len, emb_dim)``.

        Raises:
            ValueError: If the batch or the addressed positions do not fit in
                the allocated cache.
        """
        batch_size, seq_len, _ = x.shape
        end_pos = start_pos + seq_len

        if batch_size > self.cache_keys.shape[0]:
            raise ValueError(
                f"batch size {batch_size} exceeds the KV cache batch size {self.cache_keys.shape[0]}"
            )
        if start_pos < 0 or end_pos > self.kv_seq_len:
            raise ValueError(
                f"positions [{start_pos}, {end_pos}) do not fit in a KV cache of length {self.kv_seq_len}"
            )

        xq = self.query(x).view(batch_size, seq_len, self.query_num_heads, self.head_dim)
        xk = self.key(x).view(batch_size, seq_len, self.kv_num_heads, self.head_dim)
        xv = self.value(x).view(batch_size, seq_len, self.kv_num_heads, self.head_dim)

        # Rotate with the absolute positions start_pos .. end_pos - 1 so that
        # keys cached by earlier calls and the current queries share one
        # position frame.
        freqs = precompute_theta_position_frequency(
            head_dim=self.head_dim, seq_len=end_pos, device=x.device
        )[start_pos:end_pos]
        xq = apply_rotry_position_embedding(xq, freqs, device=x.device, dtype=self.dtype)
        xk = apply_rotry_position_embedding(xk, freqs, device=x.device, dtype=self.dtype)

        # Store detached copies: writing grad-tracking tensors into the buffers
        # would attach them to this call's autograd graph and break the next
        # backward pass.
        self.cache_keys[:batch_size, start_pos:end_pos] = xk.detach()
        self.cache_value[:batch_size, start_pos:end_pos] = xv.detach()

        # Cached past plus the live (differentiable) keys/values of this call.
        xk_full = torch.cat([self.cache_keys[:batch_size, :start_pos], xk], dim=1)
        xv_full = torch.cat([self.cache_value[:batch_size, :start_pos], xv], dim=1)

        query = xq.transpose(1, 2)
        # Repeat each kv head so it lines up with its group of query heads.
        key = xk_full.transpose(1, 2).repeat_interleave(self.num_queries_per_kv, dim=1)
        value = xv_full.transpose(1, 2).repeat_interleave(self.num_queries_per_kv, dim=1)

        attn_scores = torch.matmul(query, key.transpose(2, 3)) / (self.head_dim ** 0.5)

        # Query i sits at absolute position start_pos + i and may see keys
        # 0 .. start_pos + i, hence the diagonal offset.
        causal_mask = torch.triu(
            torch.ones(seq_len, end_pos, dtype=torch.bool, device=attn_scores.device),
            diagonal=start_pos + 1,
        )
        attn_scores = attn_scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        out = torch.matmul(attn_weights, value)

        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.query_num_heads * self.head_dim)
        return self.dropout(self.out_proj(out))
85
+
86
+
87
class RMSNormilization(nn.Module):
    """Scale features by the inverse of their root-mean-square, with a learned gain.

    Args:
        emb_dim: Size of the last (feature) dimension.
        eps: Constant added to the RMS before dividing.
        device: Device for the gain parameter.
        dtype: Dtype for the gain parameter.
    """

    def __init__(self, emb_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(emb_dim, dtype=dtype, device=device))

    def forward(self, x):
        """Return ``x / (rms(x) + eps) * scale`` computed over the last dimension."""
        feature_count = x.shape[-1]
        l2_norm = x.norm(2, dim=-1, keepdim=True)
        # ||x||_2 / sqrt(n) equals sqrt(mean(x ** 2)).
        root_mean_square = l2_norm / (feature_count ** 0.5)
        normalized = x / (root_mean_square + self.eps)
        return normalized * self.scale
97
+
98
+
99
class FF_SiLU(nn.Module):
    """Two-layer position-wise feed-forward network with a SiLU activation.

    Args:
        emb_dim: Input and output feature size.
        hidden_dim: Width of the intermediate layer.
        device: Device for the linear layers.
        dtype: Dtype for the linear layers.
    """

    def __init__(self, emb_dim, hidden_dim, device='cpu', dtype=torch.float32):
        super().__init__()
        expand = nn.Linear(emb_dim, hidden_dim, device=device, dtype=dtype)
        activation = nn.SiLU()
        contract = nn.Linear(hidden_dim, emb_dim, device=device, dtype=dtype)
        self.silu = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Apply expand -> SiLU -> contract to the last dimension of ``x``."""
        return self.silu(x)
110
+
111
+
112
class block(nn.Module):
    """One pre-norm layer: cached grouped-query attention, then a SiLU feed-forward.

    Each sub-layer receives an RMS-normalised input and its output is added
    back to the un-normalised input (residual connection).

    Args:
        emb_dim: Model width.
        query_num_heads: Number of query heads for the attention.
        kv_num_heads: Number of key/value heads for the attention.
        batch_size: Batch rows allocated in the attention KV cache.
        kv_seq_len: Positions allocated in the attention KV cache.
        hidden_dim: Width of the feed-forward hidden layer.
        eps: Epsilon for both RMS norms.
        dropout: Dropout probability forwarded to the attention.
        dtype: Dtype for all parameters.
        device: Device for all parameters.
    """

    def __init__(self, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len, hidden_dim,
                 eps=1e-5, dropout=0.1, dtype=torch.float32, device='cpu'):
        super().__init__()
        norm_kwargs = dict(emb_dim=emb_dim, eps=eps, device=device, dtype=dtype)
        self.attn_norm = RMSNormilization(**norm_kwargs)
        self.ff_norm = RMSNormilization(**norm_kwargs)
        self.attn = kv_cache_group_query(
            emb_dim=emb_dim,
            query_num_heads=query_num_heads,
            kv_num_heads=kv_num_heads,
            batch_size=batch_size,
            kv_seq_len=kv_seq_len,
            dtype=dtype,
            dropout=dropout,
            device=device,
        )
        self.ff = FF_SiLU(emb_dim=emb_dim, hidden_dim=hidden_dim, device=device, dtype=dtype)

    def forward(self, x, start_pos):
        """Run attention and feed-forward sub-layers; ``start_pos`` is passed to the attention."""
        # Attention sub-layer with residual.
        x = self.attn(self.attn_norm(x), start_pos) + x
        # Feed-forward sub-layer with residual.
        x = self.ff(self.ff_norm(x)) + x
        return x
134
+
135
+
136
class Encoder(nn.Module):
    """Sequential stack of ``num_layers`` identical ``block`` layers.

    All constructor arguments other than ``num_layers`` are forwarded
    unchanged to every ``block``.
    """

    def __init__(self, num_layers, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len,
                 hidden_dim, eps=1e-5, dropout=0.1, dtype=torch.float32, device='cpu'):
        super().__init__()
        stacked = []
        for _ in range(num_layers):
            stacked.append(
                block(
                    emb_dim=emb_dim,
                    query_num_heads=query_num_heads,
                    kv_num_heads=kv_num_heads,
                    batch_size=batch_size,
                    kv_seq_len=kv_seq_len,
                    hidden_dim=hidden_dim,
                    eps=eps,
                    dropout=dropout,
                    dtype=dtype,
                    device=device,
                )
            )
        self.layers = nn.ModuleList(stacked)

    def forward(self, x, start_pos):
        """Feed ``x`` through every layer in order, passing ``start_pos`` to each."""
        hidden = x
        for current_layer in self.layers:
            hidden = current_layer(hidden, start_pos)
        return hidden
151
+
152
+
153
class Llama_2(nn.Module):
    """Decoder-only language model built from token embeddings and a ``block`` stack.

    The pipeline is: token embedding -> ``Encoder`` (stack of ``block``
    layers) -> final RMS norm -> linear projection to vocabulary logits.

    Args:
        num_layers: Number of stacked layers.
        emb_dim: Model width.
        query_num_heads: Number of query heads per attention layer.
        kv_num_heads: Number of key/value heads per attention layer.
        batch_size: Batch rows allocated in each attention KV cache.
        kv_seq_len: Positions allocated in each KV cache; also used as the
            context window during generation.
        vocab_size: Vocabulary size.
        hidden_dim: Feed-forward hidden width.
        eps: Epsilon for the RMS norms.
        dropout: Dropout probability forwarded to the layers.
        dtype: Dtype for all parameters.
        device: Device for all parameters.
    """

    def __init__(self, num_layers, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len, vocab_size,
                 hidden_dim, eps=1e-5, dropout=0.1, dtype=torch.float32, device='cpu'):
        super().__init__()
        self.device = device
        self.vocab_size = vocab_size
        self.dtype = dtype
        self.seq_len = kv_seq_len  # context window used when slicing during generation

        self.embedding = nn.Embedding(vocab_size, emb_dim, dtype=dtype, device=device)

        self.encoder = Encoder(num_layers=num_layers, emb_dim=emb_dim, query_num_heads=query_num_heads,
                               kv_num_heads=kv_num_heads, batch_size=batch_size, kv_seq_len=kv_seq_len,
                               hidden_dim=hidden_dim, eps=eps, dropout=dropout, dtype=dtype, device=device)

        self.final_norm = RMSNormilization(emb_dim, eps=eps, device=device, dtype=dtype)
        self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False, dtype=dtype, device=device)

    def forward(self, input_ids, start_pos=0):
        """Return logits of shape ``(batch, seq_len, vocab_size)`` for ``input_ids``.

        Args:
            input_ids: Integer token ids of shape ``(batch, seq_len)``.
            start_pos: Position offset forwarded to every attention layer.
        """
        x = self.embedding(input_ids)
        x = self.encoder(x, start_pos)
        x = self.final_norm(x)
        logits = self.lm_head(x)
        return logits

    @torch.no_grad()
    def generate(self, prompt_ids, max_new_tokens=50, temperature=1.0, top_k=None, top_p=1.0):
        """Autoregressively extend ``prompt_ids`` by ``max_new_tokens`` tokens.

        Args:
            prompt_ids: Token ids of shape ``(seq_len,)`` or ``(batch, seq_len)``.
            max_new_tokens: Number of tokens to append.
            temperature: Logit divisor; ``0`` selects the arg-max token
                (greedy decoding) instead of sampling.
            top_k: If positive, sample only among the ``top_k`` most likely
                tokens (clamped to the vocabulary size).
            top_p: If below 1, sample only from the smallest set of tokens
                whose cumulative probability exceeds ``top_p``.

        Returns:
            Tensor of shape ``(batch, seq_len + max_new_tokens)`` holding the
            prompt followed by the generated tokens.

        Note:
            The model is switched to eval mode and left there.
        """
        self.eval()
        if prompt_ids.dim() == 1:
            prompt_ids = prompt_ids.unsqueeze(0)

        generated = prompt_ids.clone()
        for _ in range(max_new_tokens):
            # The window already contains every token the model should see, so
            # it is encoded from position 0. Passing a growing start_pos
            # together with the full window would shift the cache writes and
            # misalign the causal mask.
            input_ids = generated[:, -self.seq_len:]
            logits = self.forward(input_ids, start_pos=0)
            logits = logits[:, -1, :]

            if temperature == 0:
                # Greedy decoding; dividing by zero would produce inf/nan logits.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                if temperature != 1.0:
                    logits = logits / temperature

                if top_k is not None and top_k > 0:
                    # torch.topk raises when k exceeds the vocabulary size.
                    k = min(top_k, logits.size(-1))
                    topk_vals, topk_indices = torch.topk(logits, k)
                    filtered = torch.full_like(logits, float('-inf'))
                    filtered.scatter_(dim=-1, index=topk_indices, src=topk_vals)
                    logits = filtered

                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                    probs = F.softmax(sorted_logits, dim=-1)
                    cum_probs = torch.cumsum(probs, dim=-1)
                    sorted_mask = cum_probs > top_p
                    # Shift right so the token that crosses the threshold is
                    # kept, and always keep the most likely token.
                    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                    sorted_mask[..., 0] = False
                    # Map the mask from sorted order back to vocabulary order.
                    indices_to_remove = sorted_mask.scatter(dim=-1, index=sorted_indices, src=sorted_mask)
                    logits = logits.masked_fill(indices_to_remove, float('-inf'))

                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            generated = torch.cat([generated, next_token], dim=-1)

        return generated
@@ -22,7 +22,8 @@ class SinusoidalPositionalEmbedding(nn.Module):
22
22
  def forward(self, x):
23
23
  # x shape: (batch_size, seq_len, emb_dim) or (batch_size, seq_len)
24
24
  batch_size, seq_len = x.shape[0], x.shape[1]
25
- return self.pe[:seq_len].unsqueeze(0).expand(batch_size, seq_len, -1).to(x.device)
25
+ out = self.pe[:seq_len].unsqueeze(0).expand(batch_size, seq_len, -1)
26
+ return out.to(device=x.device,dtype=x.dtype)
26
27
 
27
28
  # --- Multi Head Attention ---
28
29
  class MultiHeadAttention(nn.Module):
@@ -135,11 +136,14 @@ class Encoder(nn.Module):
135
136
  x = layer(x)
136
137
  return x
137
138
 
138
- class GPTModel(nn.Module):
139
+ class GPT_2(nn.Module):
139
140
  def __init__(self, vocab_size, num_layers, Emb_dim, num_heads, seq_len,
140
141
  dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
141
-
142
142
  super().__init__()
143
+ self.device = device
144
+ self.dtype = dtype
145
+ self.seq_len = seq_len
146
+
143
147
  # --- Token embedding ---
144
148
  self.embedding = nn.Embedding(vocab_size, Emb_dim, dtype=self.dtype, device=self.device)
145
149
 
@@ -0,0 +1,238 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ # --- position embedding ---
7
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed sine/cosine positional embedding table.

    Even feature indices hold ``sin(pos * w_i)`` and odd indices hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / emb_dim)``.

    Args:
        seq_len: Number of positions in the precomputed table.
        emb_dim: Embedding size (odd sizes are supported).
    """

    def __init__(self, seq_len, emb_dim):
        super().__init__()
        self.seq_len = seq_len
        self.emb_dim = emb_dim

        position = torch.arange(0, seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, emb_dim, 2) * -(math.log(10000.0) / emb_dim))
        angles = position * div_term  # (seq_len, ceil(emb_dim / 2))

        pe = torch.zeros(seq_len, emb_dim)
        pe[:, 0::2] = torch.sin(angles)
        # For odd emb_dim there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : emb_dim // 2])

        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return the embeddings for the first ``x.shape[1]`` positions.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, emb_dim)`` or
                ``(batch_size, seq_len)``; only its shape, device and dtype
                are used.

        Returns:
            Tensor of shape ``(batch_size, seq_len, emb_dim)`` on ``x.device``.
            It takes ``x.dtype`` when ``x`` is floating point; for integer
            inputs (token ids) the table's own float dtype is kept so the
            values are not truncated.

        Raises:
            ValueError: If ``x`` is longer than the precomputed table.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        if seq_len > self.seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum of {self.seq_len} positions"
            )
        out = self.pe[:seq_len].unsqueeze(0).expand(batch_size, seq_len, -1)
        target_dtype = x.dtype if x.is_floating_point() else self.pe.dtype
        return out.to(device=x.device, dtype=target_dtype)
27
+
28
+ # --- multi-head attention ---
29
class Multi_Head_Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        emb_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        device: Device for the parameters.
        dtype: Dtype for the parameters.
        causal: When True (default, the previous behaviour) each position may
            only attend to itself and earlier positions.
    """

    def __init__(self, emb_dim, num_heads, dropout, device='cpu', dtype=torch.float32, causal=True):
        super().__init__()
        assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.device = device
        self.head_dim = emb_dim // num_heads
        self.causal = causal

        self.key = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
        self.query = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)

        self.scale = torch.tensor(self.head_dim ** 0.5, device=device, dtype=dtype)
        self.out_proj = nn.Linear(emb_dim, emb_dim, dtype=dtype, device=device)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return attended features with the same ``(batch, seq_len, emb_dim)`` shape as ``x``."""
        Batch_size, Seq_len, _ = x.shape

        # Project and split into heads: (Batch_size, num_heads, Seq_len, head_dim)
        Keys = self.key(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Querys = self.query(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Values = self.value(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # (Batch_size, num_heads, Seq_len, Seq_len)
        scores = (Querys @ Keys.transpose(-2, -1)) / self.scale

        if self.causal:
            # Build the mask on the scores' device rather than the
            # construction-time device so the module keeps working after
            # being moved with .to().
            causal_mask = torch.triu(
                torch.ones(Seq_len, Seq_len, dtype=torch.bool, device=scores.device),
                diagonal=1,
            )
            scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = attn @ Values  # (Batch_size, num_heads, Seq_len, head_dim)

        # Merge heads back: (Batch_size, Seq_len, emb_dim)
        out = out.transpose(1, 2).contiguous().view(Batch_size, Seq_len, self.emb_dim)

        return self.out_proj(out)
73
+
74
+ # --- cross-attention ---
75
class Cross_MultiHead_Attention(nn.Module):
    """Multi-head attention whose keys/values come from a separate context.

    Args:
        emb_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        device: Device for the parameters.
        dtype: Dtype for the parameters.
    """

    def __init__(self, emb_dim, num_heads, dropout, device='cpu', dtype=torch.float32):
        super().__init__()
        assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"
        self.emb_dim = emb_dim
        self.device = device
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads

        # Query, key and value projections
        self.query = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
        self.key = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
        self.value = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)

        self.scale = torch.tensor(self.head_dim ** 0.5, device=device, dtype=dtype)

        self.out_proj = nn.Linear(emb_dim, emb_dim, dtype=dtype, device=device)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None):
        """Attend queries from ``x`` to keys/values from ``context``.

        Args:
            x: Query source of shape ``(batch, query_seq_len, emb_dim)``.
            context: Key/value source of shape ``(batch, kv_seq_len, emb_dim)``.
                When omitted, ``x`` is used and the attention is causally
                masked (self-attention fallback). When given, every query may
                attend to every context position; the context length may
                differ from the query length.

        Returns:
            Tensor of shape ``(batch, query_seq_len, emb_dim)``.
        """
        Batch_size, query_seq_len, _ = x.shape
        is_self_attention = context is None
        if is_self_attention:
            context = x
        KV_seq_len = context.shape[1]

        # Project and split into heads.
        Querys = self.query(x).view(Batch_size, query_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Keys = self.key(context).view(Batch_size, KV_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        Values = self.value(context).view(Batch_size, KV_seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # (Batch_size, num_heads, query_seq_len, KV_seq_len)
        scores = (Querys @ Keys.transpose(-2, -1)) / self.scale

        if is_self_attention:
            # A causal mask only makes sense when queries and keys index the
            # same sequence. Applying a (q, q) mask to (q, kv) cross-attention
            # scores fails when the lengths differ and hides context tokens
            # from the decoder when they match.
            causal_mask = torch.triu(
                torch.ones(query_seq_len, KV_seq_len, dtype=torch.bool, device=scores.device),
                diagonal=1,
            )
            scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = attn @ Values
        # Merge heads back: (Batch_size, query_seq_len, emb_dim)
        out = out.transpose(1, 2).contiguous().view(Batch_size, query_seq_len, self.emb_dim)

        return self.out_proj(out)
117
+
118
+ # --- Feed Forward ---
119
class FF_ReLU(nn.Module):
    """Two-layer position-wise feed-forward network with ReLU and dropout.

    Args:
        emb_dim: Input and output feature size.
        hidden_dim: Width of the intermediate layer.
        dropout: Dropout probability applied after the activation.
        device: Device for the linear layers.
        dtype: Dtype for the linear layers.
    """

    def __init__(self, emb_dim, hidden_dim, dropout=0.1, device='cpu', dtype=torch.float32):
        super().__init__()
        expand = nn.Linear(emb_dim, hidden_dim, device=device, dtype=dtype)
        activation = nn.ReLU()
        regularizer = nn.Dropout(dropout)
        contract = nn.Linear(hidden_dim, emb_dim, device=device, dtype=dtype)
        self.relu = nn.Sequential(expand, activation, regularizer, contract)

    def forward(self, x):
        """Apply expand -> ReLU -> dropout -> contract to the last dimension of ``x``."""
        return self.relu(x)
131
+
132
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learned gain and bias.

    Args:
        emb_dim: Size of the last (feature) dimension.
        eps: Constant added to the variance before the square root.
        device: Device for the parameters.
        dtype: Dtype for the parameters.
    """

    def __init__(self, emb_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(emb_dim, device=device, dtype=dtype))
        self.bias = nn.Parameter(torch.zeros(emb_dim, device=device, dtype=dtype))

    def forward(self, x):
        """Standardise ``x`` along its last dimension, then apply ``weight`` and ``bias``."""
        feature_mean = x.mean(dim=-1, keepdim=True)
        # Population (biased) variance.
        feature_var = x.var(dim=-1, keepdim=True, unbiased=False)
        standardized = (x - feature_mean) / torch.sqrt(feature_var + self.eps)
        return standardized * self.weight + self.bias
144
+
145
class Encoder(nn.Module):
    """Encoder layer: self-attention and a ReLU feed-forward, each with a residual.

    In both sub-layers the layer norm is applied to the sub-layer output
    before the residual is added.

    NOTE(review): ``Multi_Head_Attention`` as used here masks future
    positions, so this encoder's self-attention is causal — confirm that this
    is intended for the source sequence.

    Args:
        emb_dim: Model width.
        num_heads: Number of attention heads.
        dropout: Dropout probability for attention and feed-forward.
        hidden_dim: Feed-forward hidden width.
        eps: Epsilon for the layer norms.
        device: Device for the parameters.
        dtype: Dtype for the parameters.
    """

    def __init__(self, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        placement = dict(device=device, dtype=dtype)
        self.attention = Multi_Head_Attention(emb_dim, num_heads, dropout, **placement)
        self.norm1 = LayerNorm(emb_dim, eps=eps, **placement)
        self.ff_relu = FF_ReLU(emb_dim, hidden_dim, dropout, **placement)
        self.norm2 = LayerNorm(emb_dim, eps=eps, **placement)

    def forward(self, x):
        """Return the layer output; same shape as ``x``."""
        # Self-attention sub-layer: norm(attn(x)) + x
        x = self.norm1(self.attention(x)) + x
        # Feed-forward sub-layer: norm(ff(x)) + x
        x = self.norm2(self.ff_relu(x)) + x
        return x
165
+
166
class Decoder(nn.Module):
    """Decoder layer: self-attention, cross-attention and feed-forward sub-layers.

    In each sub-layer the layer norm is applied to the sub-layer output
    before the residual is added.

    Args:
        emb_dim: Model width.
        num_heads: Number of attention heads.
        dropout: Dropout probability for the attentions and feed-forward.
        hidden_dim: Feed-forward hidden width.
        eps: Epsilon for the layer norms.
        device: Device for the parameters.
        dtype: Dtype for the parameters.
    """

    def __init__(self, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
        super().__init__()
        placement = dict(device=device, dtype=dtype)
        self.attention = Multi_Head_Attention(emb_dim, num_heads, dropout, **placement)
        self.norm1 = LayerNorm(emb_dim, eps=eps, **placement)
        self.cross_attention = Cross_MultiHead_Attention(emb_dim, num_heads, dropout, **placement)
        self.norm2 = LayerNorm(emb_dim, eps=eps, **placement)
        self.ff_relu = FF_ReLU(emb_dim, hidden_dim, dropout, **placement)
        self.norm3 = LayerNorm(emb_dim, eps=eps, **placement)

    def forward(self, x, enc_output):
        """Return the layer output for target features ``x`` given ``enc_output``.

        Args:
            x: Target-side features.
            enc_output: Encoder features used as the cross-attention context.
        """
        # Self-attention over the target sequence.
        x = self.norm1(self.attention(x)) + x
        # Cross-attention: target queries, encoder keys/values.
        x = self.norm2(self.cross_attention(x, context=enc_output)) + x
        # Position-wise feed-forward.
        x = self.norm3(self.ff_relu(x)) + x
        return x
193
+
194
class transformer(nn.Module):
    """Encoder-decoder sequence-to-sequence model.

    Source and target share one token embedding and one sinusoidal position
    table. The decoder output goes through a final layer norm and a linear
    projection to vocabulary logits.

    Args:
        vocab_size: Vocabulary size.
        emb_dim: Model width.
        num_heads: Number of attention heads per layer.
        dropout: Dropout probability forwarded to the layers.
        hidden_dim: Feed-forward hidden width.
        encoder_layers: Number of encoder layers.
        decoder_layers: Number of decoder layers.
        seq_len: Maximum sequence length covered by the position table.
        eps: Epsilon for the layer norms.
        device: Device for the parameters.
        dtype: Dtype for the parameters.
    """

    def __init__(self, vocab_size, emb_dim, num_heads, dropout, hidden_dim,
                 encoder_layers, decoder_layers, seq_len, eps=1e-5, device='cpu', dtype=torch.float32,
                 ):
        super().__init__()
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers

        self.token_emb = nn.Embedding(vocab_size, emb_dim, device=device, dtype=dtype)
        self.pos = SinusoidalPositionalEmbedding(seq_len=seq_len, emb_dim=emb_dim)

        self.encoder_stack = nn.ModuleList([
            Encoder(emb_dim, num_heads, dropout, hidden_dim, eps=eps, device=device, dtype=dtype)
            for _ in range(encoder_layers)
        ])

        self.decoder_stack = nn.ModuleList([
            Decoder(emb_dim, num_heads, dropout, hidden_dim, eps=eps, device=device, dtype=dtype)
            for _ in range(decoder_layers)
        ])

        # --- final norm ---
        self.final_norm = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)

        # --- output projection ---
        self.out_proj = nn.Linear(emb_dim, vocab_size, device=device, dtype=dtype)

    def _embed(self, token_ids):
        """Return token embeddings plus positional embeddings for ``token_ids``."""
        tokens = self.token_emb(token_ids)
        # Pass the float embeddings, not the integer ids: the position module
        # casts its table to the dtype of its argument, and an integer dtype
        # would truncate the sine/cosine values.
        return tokens + self.pos(tokens)

    def encoder(self, x):
        """Embed source token ids ``x`` and run them through the encoder stack."""
        x = self._embed(x)
        for block in self.encoder_stack:
            x = block(x)
        return x

    def decoder(self, x, enc_output):
        """Embed target token ids ``x`` and decode them against ``enc_output``."""
        x = self._embed(x)
        for block in self.decoder_stack:
            x = block(x, enc_output)
        return x

    def forward(self, source, target):
        """Return vocabulary logits for ``target`` conditioned on ``source``.

        Args:
            source: Source token ids of shape ``(batch, source_len)``.
            target: Target token ids of shape ``(batch, target_len)``.

        Returns:
            Logits of shape ``(batch, target_len, vocab_size)``.
        """
        enc_output = self.encoder(source)
        out = self.decoder(target, enc_output)
        out = self.final_norm(out)
        out = self.out_proj(out)
        return out
@@ -1,6 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
  import torch.nn.functional as F
4
+
4
5
  class Self_Attention(nn.Module):
5
6
  def __init__(self, Emb_dim, dropout,dtype=torch.float32,device='cpu'):
6
7
  super().__init__()
@@ -412,23 +413,23 @@ class kv_cache_multihead(nn.Module):
412
413
  self.dtype = dtype
413
414
  self.device = device
414
415
 
415
- assert emb_dim % num_heads == 0
416
+ assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"
416
417
  self.emb_dim = emb_dim
417
418
  self.num_heads = num_heads
418
419
  self.head_dim = emb_dim // num_heads
419
420
  self.kv_seq_len = kv_seq_len
420
421
 
421
- self.query = nn.Linear(emb_dim, emb_dim, bias=False,dtype=dtype,device=device)
422
- self.key = nn.Linear(emb_dim, emb_dim, bias=False,dtype=dtype,device=device)
423
- self.value = nn.Linear(emb_dim, emb_dim, bias=False,dtype=dtype,device=device)
422
+ self.query = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
423
+ self.key = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
424
+ self.value = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
424
425
 
425
- self.out_proj = nn.Linear(emb_dim, emb_dim,dtype=dtype,device=device)
426
+ self.out_proj = nn.Linear(emb_dim, emb_dim, dtype=dtype, device=device)
426
427
  self.dropout = nn.Dropout(dropout)
428
+ # KV caches
429
+ self.register_buffer("cache_keys", torch.zeros(batch_size, kv_seq_len*2, num_heads, self.head_dim,device=device,dtype=dtype))
430
+ self.register_buffer("cache_value", torch.zeros(batch_size, kv_seq_len*2, num_heads, self.head_dim,device=device,dtype=dtype))
427
431
 
428
- self.cache_keys = torch.zeros(batch_size, kv_seq_len, num_heads, self.head_dim,dtype=dtype,device=device)
429
- self.cache_value = torch.zeros(batch_size, kv_seq_len, num_heads, self.head_dim,dtype=dtype,device=device)
430
-
431
- def forward(self, x, start_pos, RoPE: False):
432
+ def forward(self, x, start_pos, RoPE=False):
432
433
  batch_size, seq_len, C = x.shape
433
434
 
434
435
  xq = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
@@ -441,36 +442,35 @@ class kv_cache_multihead(nn.Module):
441
442
  freq_complex = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=self.kv_seq_len, device=self.device)
442
443
  xk = apply_rotry_position_embedding(xk, freq_complex, device=self.device, dtype=self.dtype)
443
444
 
444
- # Cache keys and values
445
- self.cache_keys[:, start_pos:start_pos+seq_len] = xk
446
- self.cache_value[:, start_pos:start_pos+seq_len] = xv
447
-
448
- xk_full = self.cache_keys[:, :start_pos+seq_len]
449
- xv_full = self.cache_value[:, :start_pos+seq_len]
445
+ # Cache keys and values - only update the batch_size portion we're using
446
+ self.cache_keys[:batch_size, start_pos:start_pos+seq_len] = xk
447
+ self.cache_value[:batch_size, start_pos:start_pos+seq_len] = xv
450
448
 
449
+ # Only use the relevant batch portion from cache
450
+ xk_full = self.cache_keys[:batch_size, :start_pos+seq_len]
451
+ xv_full = self.cache_value[:batch_size, :start_pos+seq_len]
452
+
451
453
  query = xq.transpose(1, 2) # (batch_size, num_head, seq_len, emb_dim)
452
454
  key = xk_full.transpose(1, 2) # (batch_size, num_head, T_total, emb_dim)
453
455
  value = xv_full.transpose(1, 2) # (batch_size, num_head, T_total, emb_dim)
454
456
 
455
457
  attn_scores = torch.matmul(query, key.transpose(2, 3)) / (self.head_dim ** 0.5)
456
-
457
458
  # Causal mask
458
- causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=self.device), diagonal=1)
459
+ causal_mask = torch.triu(torch.ones(attn_scores.shape[-2], attn_scores.shape[-1], dtype=torch.bool, device=self.device), diagonal=1)
459
460
  attn_scores.masked_fill_(causal_mask[None, None, :, :], float('-inf'))
460
461
 
461
462
  attn_weights = F.softmax(attn_scores, dim=-1)
462
463
  out = torch.matmul(attn_weights, value)
463
-
464
464
  out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
465
465
  return self.dropout(self.out_proj(out))
466
466
 
467
467
  class kv_cache_group_query(nn.Module):
468
- def __init__(self, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len,device='cpu' , dtype=torch.float32 , dropout=0.1):
468
+ def __init__(self, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len, device='cpu', dtype=torch.float32, dropout=0.1):
469
469
  super().__init__()
470
470
  self.dtype = dtype
471
471
  self.device = device
472
472
 
473
- assert query_num_heads % kv_num_heads == 0, "query heads must be divisible by kv heads"
473
+ assert query_num_heads % kv_num_heads == 0, "query heads must be divisible by kv heads"
474
474
  assert emb_dim % query_num_heads == 0, "embedding must be divisible by query heads"
475
475
 
476
476
  self.emb_dim = emb_dim
@@ -488,8 +488,8 @@ class kv_cache_group_query(nn.Module):
488
488
  self.dropout = nn.Dropout(dropout)
489
489
 
490
490
  # KV caches
491
- self.register_buffer("cache_keys", torch.zeros(batch_size, kv_seq_len, kv_num_heads, self.head_dim,device=device,dtype=dtype))
492
- self.register_buffer("cache_value", torch.zeros(batch_size, kv_seq_len, kv_num_heads, self.head_dim,device=device,dtype=dtype))
491
+ self.register_buffer("cache_keys", torch.zeros(batch_size, kv_seq_len*2, kv_num_heads, self.head_dim,device=device,dtype=dtype))
492
+ self.register_buffer("cache_value", torch.zeros(batch_size, kv_seq_len*2, kv_num_heads, self.head_dim,device=device,dtype=dtype))
493
493
 
494
494
  def forward(self, x, start_pos, RoPE=False):
495
495
  batch_size, seq_len, _ = x.shape
@@ -501,15 +501,16 @@ class kv_cache_group_query(nn.Module):
501
501
  if RoPE:
502
502
  freq_q = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=seq_len, device=self.device)
503
503
  xq = apply_rotry_position_embedding(xq, freq_q, device=self.device, dtype=self.dtype)
504
- freq_k = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=self.kv_seq_len, device=self.device)
504
+ freq_k = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=seq_len, device=self.device)
505
505
  xk = apply_rotry_position_embedding(xk, freq_k, device=self.device, dtype=self.dtype)
506
- # Cache
507
- self.cache_keys[:, start_pos:start_pos+seq_len] = xk
508
- self.cache_value[:, start_pos:start_pos+seq_len] = xv
509
-
510
- xk_full = self.cache_keys[:, :start_pos+seq_len] # [B, T, kv_heads, D]
511
- xv_full = self.cache_value[:, :start_pos+seq_len]
512
506
 
507
+ # Cache keys and values - only update the batch_size portion we're using
508
+ self.cache_keys[:batch_size, start_pos:start_pos+seq_len] = xk
509
+ self.cache_value[:batch_size, start_pos:start_pos+seq_len] = xv
510
+ # Only use the relevant batch portion from cache
511
+ xk_full = self.cache_keys[:batch_size, :start_pos+seq_len]
512
+ xv_full = self.cache_value[:batch_size, :start_pos+seq_len]
513
+
513
514
  # Transpose for attention: [B, H, T, D]
514
515
  query = xq.transpose(1, 2) # [B, q_heads, seq_len, D]
515
516
  key = xk_full.transpose(1, 2) # [B, kv_heads, total_kv_len, D]
@@ -518,16 +519,14 @@ class kv_cache_group_query(nn.Module):
518
519
  # Repeat keys and values to match query heads
519
520
  key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
520
521
  value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
521
-
522
522
  # Attention
523
523
  attn_scores = torch.matmul(query, key.transpose(2, 3)) / (self.head_dim ** 0.5)
524
-
525
524
  # Causal mask
526
- causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=self.device), diagonal=1)
525
+ causal_mask = torch.triu(torch.ones(attn_scores.shape[-2], attn_scores.shape[-1], dtype=torch.bool, device=self.device), diagonal=1)
527
526
  attn_scores.masked_fill_(causal_mask[None, None, :, :], float('-inf'))
528
-
527
+ #softmax
529
528
  attn_weights = F.softmax(attn_scores, dim=-1)
529
+ # atten weight @ value
530
530
  out = torch.matmul(attn_weights, value)
531
-
532
531
  out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.emb_dim)
533
532
  return self.dropout(self.out_proj(out))
@@ -33,4 +33,4 @@ def global_mask(Seq_len, global_index):
33
33
  for g in global_index:
34
34
  mask[g,:] = 1
35
35
  mask[:,global_index_tensor] = 1
36
- return ~mask
36
+ return ~mask.bool()
@@ -14,7 +14,8 @@ class AbsolutePositionEmbedding(nn.Module):
14
14
  batch_size, seq_len = x.shape[0], x.shape[1]
15
15
  positions = torch.arange(0, seq_len)
16
16
  abs_pos = self.embedding(positions) # (seq_len, emb_dim)
17
- return abs_pos.unsqueeze(0).expand(batch_size, seq_len, -1).to(x.device)
17
+ out = abs_pos.unsqueeze(0).expand(batch_size, seq_len, -1)
18
+ return out.to(device=x.device,dtype=x.dtype)
18
19
 
19
20
  # --- Sinusoidal Positional Embedding ---
20
21
  class SinusoidalPositionalEmbedding(nn.Module):
@@ -35,8 +36,9 @@ class SinusoidalPositionalEmbedding(nn.Module):
35
36
  def forward(self, x):
36
37
  # x shape: (batch_size, seq_len, emb_dim) or (batch_size, seq_len)
37
38
  batch_size, seq_len = x.shape[0], x.shape[1]
38
- return self.pe[:seq_len].unsqueeze(0).expand(batch_size, seq_len, -1).to(x.device)
39
-
39
+ out = self.pe[:seq_len].unsqueeze(0).expand(batch_size, seq_len, -1)
40
+ return out.to(device=x.device,dtype=x.dtype)
41
+
40
42
  # --- RoPE ---
41
43
  class RoPE(nn.Module):
42
44
  def __init__(self, head_dim, seq_len, theta=10000.0, device='cpu', dtype=torch.float32):
@@ -58,4 +60,4 @@ class RoPE(nn.Module):
58
60
  freqs = self.freq_complex[:seq_len].unsqueeze(0).unsqueeze(2) # (1, seq_len, 1, head_dim//2)
59
61
  x_rotated = x_complex * freqs
60
62
  x_out = torch.view_as_real(x_rotated).contiguous().view(batch_size, seq_len, num_head, emb_dim)
61
- return x_out.to(device=self.device, dtype=self.dtype)
63
+ return x_out.to(device=x.device, dtype=x.dtype)
@@ -0,0 +1,356 @@
1
+ import torch
2
+ import os
3
+ from torch.utils.data import DataLoader
4
+ from torch.optim import AdamW, SGD
5
+ from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, CosineAnnealingWarmRestarts
6
+ from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup
7
+ from tqdm import tqdm
8
+
9
class Trainer:
    """Training loop with gradient accumulation, evaluation and checkpointing.

    The model is called as ``model(inputs)`` and its output is flattened to
    ``(-1, vocab_size)`` and scored against the flattened targets with
    cross-entropy (``ignore_index=-100``). Every batch produced by the
    datasets must unpack as ``(inputs, targets)``.

    Args:
        model: ``torch.nn.Module`` to optimise.
        train_dataset: Dataset used for training batches.
        eval_dataset: Dataset used for evaluation batches.
        train_batch_size: Batch size of the training loader.
        eval_batch_size: Batch size of the evaluation loader.
        vocab_size: Size of the last dimension of the model output.
        output_dir: Directory checkpoints are written to.
        num_epoch: Number of epochs to run.
        lr: Learning rate handed to the optimizer.
        scheduler_type: ``None``, ``"linear"``, ``"cosine"``,
            ``"cosine_restarts"``, ``"cosineannealing"`` or
            ``"cosine_warm_restarts"``.
        optimizer_type: ``"adamw"`` or ``"sgd"`` (case-insensitive).
        eval_per_epoch: Evaluate every this many epochs (falsy disables).
        eval_per_step: Evaluate every this many optimizer steps
            (falsy disables).
        weight_decay: Weight decay handed to the optimizer.
        warmup_steps: Warm-up steps for the warm-up based schedulers.
        grad_accumulation_step: Micro-batches per optimizer step.
        max_eval_step: Upper bound on evaluation batches per evaluation.
        max_steps: Stop (and checkpoint) after this many optimizer steps.
        Save_step: Checkpoint every this many optimizer steps.
        Save_epoch: Checkpoint every this many epochs.
        max_epoch: Stop (and checkpoint) after this 1-based epoch.
        model_to_resume: Path of a checkpoint to resume from.
        resume_training: Whether ``model_to_resume`` should be loaded.
        seed: Seed for torch and for the data-loader generators.
        device: Device the model and batches are moved to.
    """

    def __init__(self,
                 model,
                 train_dataset,
                 eval_dataset,
                 train_batch_size,
                 eval_batch_size,
                 vocab_size,
                 output_dir,
                 num_epoch,
                 lr: float,
                 scheduler_type=None,
                 optimizer_type="adamw",
                 eval_per_epoch=1,
                 eval_per_step=None,
                 weight_decay=0.01,
                 warmup_steps=0,
                 grad_accumulation_step=1,
                 max_eval_step=None,
                 max_steps=None,
                 Save_step=None,
                 Save_epoch=None,
                 max_epoch=None,
                 model_to_resume=None,
                 resume_training=False,
                 seed=42,
                 device='cpu'):
        self.model = model
        self.train_dataset = train_dataset
        self.train_batch_size = train_batch_size
        self.eval_dataset = eval_dataset
        self.eval_batch_size = eval_batch_size
        self.vocab_size = vocab_size
        self.num_epoch = num_epoch
        self.max_steps = max_steps
        self.max_epoch = max_epoch
        self.eval_per_epoch = eval_per_epoch
        self.eval_per_step = eval_per_step
        self.max_eval_step = max_eval_step
        self.lr = lr
        self.scheduler_type = scheduler_type
        self.output_dir = output_dir
        self.model_to_resume = model_to_resume
        self.resume_training = resume_training
        self.Save_step = Save_step
        self.Save_epoch = Save_epoch
        self.grad_accumulation_step = grad_accumulation_step
        self.optimizer_type = optimizer_type
        self.weight_decay = weight_decay
        self.warmup_steps = warmup_steps
        self.seed = seed
        self.device = device

    # --- helpers ---
    def _uses_cuda(self):
        """Return True for any CUDA device spec ('cuda', 'cuda:0', torch.device)."""
        return str(self.device).startswith('cuda')

    @staticmethod
    def _resume_position(epoch, batch_idx, num_batches):
        """Return ``(epoch, batch)`` at which training continues after ``batch_idx``.

        ``epoch`` is the 0-based epoch currently running. When ``batch_idx``
        was the last batch of the epoch the position rolls over to the first
        batch of the next epoch.
        """
        next_batch = batch_idx + 1
        if next_batch >= num_batches:
            return epoch + 1, 0
        return epoch, next_batch

    @staticmethod
    def _optimizer_step(model, optimizer, scheduler):
        """Clip gradients, apply one optimizer/scheduler step and clear gradients."""
        # Gradient clipping for stable training
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    # --- random seed ---
    def set_seed(self, seed):
        """Seed torch, and every CUDA device when training on CUDA."""
        torch.manual_seed(seed)
        if self._uses_cuda() and torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # --- scheduler ---
    def get_scheduler(self, scheduler_type, total_training_steps, optimizer):
        """Build the learning-rate scheduler named by ``scheduler_type``.

        Returns:
            The scheduler, or ``None`` when ``scheduler_type`` is ``None``.

        Raises:
            ValueError: If ``scheduler_type`` is not a supported name.
        """
        if scheduler_type is None:
            return None
        if scheduler_type == "linear":
            return get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=self.warmup_steps,
                num_training_steps=total_training_steps
            )
        if scheduler_type == "cosine":
            return get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=self.warmup_steps,
                num_training_steps=total_training_steps
            )
        if scheduler_type == "cosine_restarts":
            return get_cosine_with_hard_restarts_schedule_with_warmup(
                optimizer,
                num_warmup_steps=self.warmup_steps,
                num_training_steps=total_training_steps,
                num_cycles=4  # Number of restarts
            )
        if scheduler_type == "cosineannealing":
            return CosineAnnealingLR(optimizer, T_max=total_training_steps)
        if scheduler_type == "cosine_warm_restarts":
            # T_0 must be a positive integer; very short runs would floor to 0.
            first_cycle = max(1, total_training_steps // 4)
            return CosineAnnealingWarmRestarts(optimizer, T_0=first_cycle, T_mult=2)
        raise ValueError(f"Unsupported scheduler type: {scheduler_type}")

    # --- optimizer ---
    def get_optimizer(self, optimizer_type, model, lr, weight_decay):
        """Build the optimizer named by ``optimizer_type`` (case-insensitive).

        Raises:
            ValueError: If ``optimizer_type`` is neither "adamw" nor "sgd".
        """
        kind = optimizer_type.lower()
        if kind == "adamw":
            return AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        if kind == "sgd":
            return SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9)
        raise ValueError(f"Unsupported optimizer {optimizer_type}")

    # --- validate model ---
    def eval_model(self, model, eval_loader, max_val_steps):
        """Return the mean cross-entropy loss over at most ``max_val_steps`` batches.

        The model is put in eval mode for the duration of the call and switched
        back to train mode afterwards, even if evaluation raises. Returns
        ``nan`` when there is nothing to evaluate.
        """
        num_batches = len(eval_loader)
        max_val_steps = min(max_val_steps or num_batches, num_batches)
        if max_val_steps <= 0:
            # Empty loader: there is no loss to average.
            return float('nan')

        eval_loss = 0.0
        steps_run = 0
        model.eval()
        try:
            with torch.no_grad():
                pbar = tqdm(eval_loader, total=max_val_steps, desc="Evaluating", leave=False)
                for inputs, targets in pbar:
                    inputs = inputs.to(self.device)
                    targets = targets.to(self.device)
                    output = model(inputs)
                    loss = torch.nn.functional.cross_entropy(
                        output.view(-1, output.size(-1)),
                        targets.view(-1), ignore_index=-100)
                    pbar.set_postfix(loss=loss.item())
                    eval_loss += loss.item()
                    steps_run += 1
                    if steps_run >= max_val_steps:
                        break
        finally:
            model.train()
        if steps_run == 0:
            return float('nan')
        # Average over the batches actually evaluated.
        return eval_loss / steps_run

    # --- train dataloader ---
    def get_train_loader(self, train_dataset, batch_size, seed):
        """Return a shuffling DataLoader whose generator is seeded with ``seed``."""
        generator = torch.Generator()
        generator.manual_seed(seed)
        return DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            generator=generator,
            pin_memory=self._uses_cuda()
        )

    # --- validation dataloader ---
    def get_eval_loader(self, eval_dataset, batch_size, seed):
        """Return a non-shuffling DataLoader whose generator is seeded with ``seed``."""
        generator = torch.Generator()
        generator.manual_seed(seed)
        return DataLoader(
            eval_dataset,
            batch_size=batch_size,
            shuffle=False,
            generator=generator,
            pin_memory=self._uses_cuda()
        )

    # --- save model ---
    def save_model(self, model, optimizer, scheduler, epoch, num_epoch, loss, global_step,
                   accumulated_steps, batch_idx_to_resume, output_dir, name):
        """Write model/optimizer/scheduler state and RNG state to a checkpoint.

        The file is ``{output_dir}/checkpoint_{name}.pt``. ``epoch`` is stored
        under ``'current_epoch'`` and ``batch_idx_to_resume`` under the key of
        the same name; ``train`` reads both back when resuming.
        """
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'current_epoch': epoch,
            'num_epoch': num_epoch,
            'loss': loss,
            'accumulated_steps': accumulated_steps,
            'global_step': global_step,
            'batch_idx_to_resume': batch_idx_to_resume,
            'rng_state': {
                'torch': torch.get_rng_state(),
                'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
            }
        }
        os.makedirs(output_dir, exist_ok=True)
        path = f'{output_dir}/checkpoint_{name}.pt'
        torch.save(checkpoint, path)
        print(f'Saved training state to {path}')

    def load_checkpoint(self, path, model, optimizer, scheduler):
        """Restore training state from ``path`` and return the bookkeeping values.

        Loads the model, optimizer and (when both are present) scheduler state
        in place and restores the torch / CUDA RNG state.

        Returns:
            dict with ``current_epoch``, ``num_epoch``, ``accumulated_steps``,
            ``batch_idx_to_resume``, ``global_step`` and ``loss``.
        """
        checkpoint = torch.load(path, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if checkpoint.get('scheduler_state_dict') is not None and scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # RNG: map_location may have moved the saved states to the training
        # device, but the RNG setters expect CPU byte tensors.
        rng_state = checkpoint['rng_state']
        torch.set_rng_state(rng_state['torch'].cpu())
        if torch.cuda.is_available() and rng_state['cuda']:
            torch.cuda.set_rng_state_all([state.cpu() for state in rng_state['cuda']])

        return {
            'current_epoch': checkpoint['current_epoch'],
            'num_epoch': checkpoint['num_epoch'],
            'accumulated_steps': checkpoint['accumulated_steps'],
            'batch_idx_to_resume': checkpoint['batch_idx_to_resume'],
            'global_step': checkpoint['global_step'],
            'loss': checkpoint['loss']
        }

    # --- train ---
    def train(self):
        """Run the training loop.

        Returns early (after writing a final checkpoint) when ``max_steps``
        optimizer steps or ``max_epoch`` epochs have been reached.
        """
        # --- seed ---
        self.set_seed(self.seed)
        # --- dataloader ---
        train_loader = self.get_train_loader(self.train_dataset, self.train_batch_size, self.seed)
        eval_loader = self.get_eval_loader(self.eval_dataset, self.eval_batch_size, self.seed)

        # --- Calculate the total step ---
        # Ceil division: leftover micro-batches are flushed as one extra
        # optimizer step at the end of every epoch.
        num_train_batches = len(train_loader)
        accum = self.grad_accumulation_step
        steps_per_epoch = (num_train_batches + accum - 1) // accum
        if self.max_steps is not None:
            total_training_steps = self.max_steps
        else:
            total_training_steps = steps_per_epoch * self.num_epoch

        os.makedirs(self.output_dir, exist_ok=True)

        model = self.model.to(self.device)
        optimizer = self.get_optimizer(self.optimizer_type, model, self.lr, self.weight_decay)
        criterion = torch.nn.functional.cross_entropy
        scheduler = self.get_scheduler(self.scheduler_type, total_training_steps, optimizer)

        global_step = 0
        start_epoch = 0
        num_epoch = self.num_epoch
        batch_idx_to_resume = 0
        accumulated_steps = 0

        if self.resume_training and self.model_to_resume:
            ckpt_data = self.load_checkpoint(self.model_to_resume, model, optimizer, scheduler)
            start_epoch = ckpt_data['current_epoch']
            global_step = ckpt_data['global_step']
            num_epoch = ckpt_data['num_epoch']
            batch_idx_to_resume = ckpt_data['batch_idx_to_resume']
            # Checkpoints hold no gradients, so a partially filled accumulation
            # window cannot be restored; always start a fresh window.
            accumulated_steps = 0
            # A resume index at or past the end of the epoch means the
            # checkpoint was taken on an epoch boundary: start at batch 0.
            if batch_idx_to_resume >= num_train_batches:
                batch_idx_to_resume = 0
            print(f"♻️ Resuming training from epoch {start_epoch}, step {global_step}")

        # --- print info ---
        print(f"🧠 Number of parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        print(f"🍱 Number of train samples: {len(self.train_dataset):,}")
        print(f"📊 Number of eval samples: {len(self.eval_dataset):,}")
        print(f"📦 Train steps per epoch (batches): {len(train_loader):,}")
        print(f"📦 Eval steps per epoch (batches): {len(eval_loader):,}")

        for epoch in range(start_epoch, num_epoch):
            model.train()
            epoch_loss = 0
            batches_seen = 0
            # Only the first epoch after a resume skips already-consumed batches.
            skip_until = batch_idx_to_resume if epoch == start_epoch else 0
            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epoch}", leave=False)
            for batch_idx, batch in enumerate(pbar):
                if batch_idx < skip_until:
                    continue

                # --- load the inputs and targets ---
                inputs, targets = batch
                inputs = inputs.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)

                # --- get prediction from model ---
                output = model(inputs)

                # --- calculate loss ---
                loss = criterion(
                    output.view(-1, self.vocab_size),
                    targets.view(-1),
                    ignore_index=-100
                )
                batch_loss = loss.item()
                # Scale so the accumulated gradient is the mean over the window.
                (loss / accum).backward()

                pbar.set_postfix(loss=batch_loss)
                epoch_loss += batch_loss
                batches_seen += 1
                accumulated_steps += 1

                # --- gradient accumulation ---
                if accumulated_steps % accum != 0:
                    continue

                self._optimizer_step(model, optimizer, scheduler)
                global_step += 1
                accumulated_steps = 0

                # Everything below runs exactly once per optimizer step.
                resume_epoch, resume_batch = self._resume_position(
                    epoch, batch_idx, num_train_batches)
                is_last_step = (self.max_steps is not None and global_step >= self.max_steps)

                # check eval_per_step (global_step has already been incremented)
                if (self.eval_per_step and global_step % self.eval_per_step == 0) or is_last_step:
                    avg_eval_loss = self.eval_model(model, eval_loader, self.max_eval_step)
                    print(f"🎯 Eval loss: {avg_eval_loss:.4f}")

                # Check max steps
                if is_last_step:
                    self.save_model(
                        model=model, optimizer=optimizer, scheduler=scheduler,
                        epoch=resume_epoch, num_epoch=num_epoch, loss=epoch_loss,
                        global_step=global_step, output_dir=self.output_dir,
                        batch_idx_to_resume=resume_batch, accumulated_steps=accumulated_steps,
                        name=f'final_step_epoch_{epoch+1}_step_{global_step}'
                    )
                    return

                # Save at specific steps
                if self.Save_step and global_step % self.Save_step == 0:
                    self.save_model(
                        model=model, optimizer=optimizer, scheduler=scheduler,
                        epoch=resume_epoch, num_epoch=num_epoch, loss=epoch_loss,
                        global_step=global_step, output_dir=self.output_dir,
                        batch_idx_to_resume=resume_batch, accumulated_steps=accumulated_steps,
                        name=f'epoch_{epoch+1}_step_{global_step}'
                    )

            # Handle remaining accumulated gradients at epoch end
            if accumulated_steps > 0:
                self._optimizer_step(model, optimizer, scheduler)
                global_step += 1
                # Start the next epoch with an empty accumulation window.
                accumulated_steps = 0

            is_last_epoch = (self.max_epoch is not None and (epoch+1) == self.max_epoch)

            # Evaluation
            if (self.eval_per_epoch and (epoch+1) % self.eval_per_epoch == 0) or is_last_epoch:
                avg_eval_loss = self.eval_model(model, eval_loader, self.max_eval_step)
                print(f"🎯 Eval loss: {avg_eval_loss:.4f}")

            # Check max epoch
            if is_last_epoch:
                # Epoch boundary: a resume continues at batch 0 of the next epoch.
                self.save_model(
                    model=model, optimizer=optimizer, scheduler=scheduler,
                    epoch=epoch+1, num_epoch=num_epoch, loss=epoch_loss,
                    global_step=global_step, output_dir=self.output_dir,
                    batch_idx_to_resume=0, accumulated_steps=accumulated_steps,
                    name=f'final_model_epoch_{epoch+1}_step_{global_step}'
                )
                return

            # print epoch loss (averaged over the batches actually processed)
            avg_epoch_loss = epoch_loss / max(batches_seen, 1)
            print(f"🔥 Epoch {epoch+1} finished - Training Loss: {avg_epoch_loss:.4f}")

            # Save at specific epochs
            if self.Save_epoch and (epoch + 1) % self.Save_epoch == 0:
                self.save_model(
                    model=model, optimizer=optimizer, scheduler=scheduler,
                    epoch=epoch+1, num_epoch=num_epoch, loss=epoch_loss,
                    global_step=global_step, output_dir=self.output_dir,
                    batch_idx_to_resume=0, accumulated_steps=accumulated_steps,
                    name=f'epoch_{epoch+1}_step_{global_step}'
                )
@@ -1,18 +0,0 @@
1
- LICENSE
2
- README.md
3
- pyproject.toml
4
- setup.py
5
- Stackformer.egg-info/PKG-INFO
6
- Stackformer.egg-info/SOURCES.txt
7
- Stackformer.egg-info/dependency_links.txt
8
- Stackformer.egg-info/requires.txt
9
- Stackformer.egg-info/top_level.txt
10
- models/GPT_2.py
11
- models/__init__.py
12
- modules/Attention.py
13
- modules/Feed_forward.py
14
- modules/Normalization.py
15
- modules/__init__.py
16
- modules/mask.py
17
- modules/position_embedding.py
18
- modules/tokenizer.py
@@ -1,2 +0,0 @@
1
- torch>=2.6
2
- tqdm>=4.67
@@ -1,2 +0,0 @@
1
- models
2
- modules
File without changes
File without changes