Stackformer 0.1.0__tar.gz → 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {stackformer-0.1.0 → stackformer-0.1.2}/PKG-INFO +9 -4
- {stackformer-0.1.0 → stackformer-0.1.2}/README.md +7 -2
- {stackformer-0.1.0 → stackformer-0.1.2}/Stackformer.egg-info/PKG-INFO +9 -4
- stackformer-0.1.2/Stackformer.egg-info/SOURCES.txt +22 -0
- stackformer-0.1.2/Stackformer.egg-info/requires.txt +2 -0
- stackformer-0.1.2/Stackformer.egg-info/top_level.txt +1 -0
- {stackformer-0.1.0 → stackformer-0.1.2}/pyproject.toml +2 -2
- {stackformer-0.1.0 → stackformer-0.1.2}/setup.py +2 -2
- stackformer-0.1.2/stackformer/__init__.py +38 -0
- stackformer-0.1.2/stackformer/models/Meta.py +213 -0
- stackformer-0.1.0/models/GPT_2.py → stackformer-0.1.2/stackformer/models/OpenAI.py +7 -3
- stackformer-0.1.2/stackformer/models/Transformer.py +238 -0
- {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/modules/Attention.py +33 -34
- {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/modules/mask.py +1 -1
- {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/modules/position_embedding.py +6 -4
- stackformer-0.1.2/stackformer/trainer.py +356 -0
- stackformer-0.1.0/Stackformer.egg-info/SOURCES.txt +0 -18
- stackformer-0.1.0/Stackformer.egg-info/requires.txt +0 -2
- stackformer-0.1.0/Stackformer.egg-info/top_level.txt +0 -2
- {stackformer-0.1.0 → stackformer-0.1.2}/LICENSE +0 -0
- {stackformer-0.1.0 → stackformer-0.1.2}/Stackformer.egg-info/dependency_links.txt +0 -0
- {stackformer-0.1.0 → stackformer-0.1.2}/setup.cfg +0 -0
- {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/models/__init__.py +0 -0
- {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/modules/Feed_forward.py +0 -0
- {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/modules/Normalization.py +0 -0
- {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/modules/__init__.py +0 -0
- {stackformer-0.1.0 → stackformer-0.1.2/stackformer}/modules/tokenizer.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: Stackformer
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 0.1.2
|
|
4
4
|
Summary: Modular transformer blocks built in PyTorch
|
|
5
5
|
Home-page: https://github.com/Gurumurthy30/Stackformer
|
|
6
6
|
Author: Gurumurthy
|
|
@@ -12,7 +12,7 @@ Project-URL: Discussions, https://github.com/Gurumurthy30/Stackformer/discussion
|
|
|
12
12
|
Requires-Python: >=3.9
|
|
13
13
|
Description-Content-Type: text/markdown
|
|
14
14
|
License-File: LICENSE
|
|
15
|
-
Requires-Dist: torch
|
|
15
|
+
Requires-Dist: torch<2.6,>=2.0
|
|
16
16
|
Requires-Dist: tqdm>=4.67
|
|
17
17
|
Dynamic: author
|
|
18
18
|
Dynamic: home-page
|
|
@@ -61,11 +61,16 @@ stackformer/ \
|
|
|
61
61
|
|
|
62
62
|
## 💻 Installation
|
|
63
63
|
|
|
64
|
-
|
|
64
|
+
✅ Method 1: Install from PyPI:
|
|
65
|
+
```bash
|
|
66
|
+
pip install Stackformer
|
|
67
|
+
import stackformer
|
|
68
|
+
```
|
|
65
69
|
|
|
70
|
+
🔧 Method 2: Clone the repository:
|
|
66
71
|
```bash
|
|
67
72
|
git clone https://github.com/Gurumurthy30/Stackformer
|
|
68
|
-
cd
|
|
73
|
+
cd Stackformer
|
|
69
74
|
pip install -e .
|
|
70
75
|
```
|
|
71
76
|
|
|
@@ -40,11 +40,16 @@ stackformer/ \
|
|
|
40
40
|
|
|
41
41
|
## 💻 Installation
|
|
42
42
|
|
|
43
|
-
|
|
43
|
+
✅ Method 1: Install from PyPI:
|
|
44
|
+
```bash
|
|
45
|
+
pip install Stackformer
|
|
46
|
+
import stackformer
|
|
47
|
+
```
|
|
44
48
|
|
|
49
|
+
🔧 Method 2: Clone the repository:
|
|
45
50
|
```bash
|
|
46
51
|
git clone https://github.com/Gurumurthy30/Stackformer
|
|
47
|
-
cd
|
|
52
|
+
cd Stackformer
|
|
48
53
|
pip install -e .
|
|
49
54
|
```
|
|
50
55
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: Stackformer
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 0.1.2
|
|
4
4
|
Summary: Modular transformer blocks built in PyTorch
|
|
5
5
|
Home-page: https://github.com/Gurumurthy30/Stackformer
|
|
6
6
|
Author: Gurumurthy
|
|
@@ -12,7 +12,7 @@ Project-URL: Discussions, https://github.com/Gurumurthy30/Stackformer/discussion
|
|
|
12
12
|
Requires-Python: >=3.9
|
|
13
13
|
Description-Content-Type: text/markdown
|
|
14
14
|
License-File: LICENSE
|
|
15
|
-
Requires-Dist: torch
|
|
15
|
+
Requires-Dist: torch<2.6,>=2.0
|
|
16
16
|
Requires-Dist: tqdm>=4.67
|
|
17
17
|
Dynamic: author
|
|
18
18
|
Dynamic: home-page
|
|
@@ -61,11 +61,16 @@ stackformer/ \
|
|
|
61
61
|
|
|
62
62
|
## 💻 Installation
|
|
63
63
|
|
|
64
|
-
|
|
64
|
+
✅ Method 1: Install from PyPI:
|
|
65
|
+
```bash
|
|
66
|
+
pip install Stackformer
|
|
67
|
+
import stackformer
|
|
68
|
+
```
|
|
65
69
|
|
|
70
|
+
🔧 Method 2: Clone the repository:
|
|
66
71
|
```bash
|
|
67
72
|
git clone https://github.com/Gurumurthy30/Stackformer
|
|
68
|
-
cd
|
|
73
|
+
cd Stackformer
|
|
69
74
|
pip install -e .
|
|
70
75
|
```
|
|
71
76
|
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
setup.py
|
|
5
|
+
Stackformer.egg-info/PKG-INFO
|
|
6
|
+
Stackformer.egg-info/SOURCES.txt
|
|
7
|
+
Stackformer.egg-info/dependency_links.txt
|
|
8
|
+
Stackformer.egg-info/requires.txt
|
|
9
|
+
Stackformer.egg-info/top_level.txt
|
|
10
|
+
stackformer/__init__.py
|
|
11
|
+
stackformer/trainer.py
|
|
12
|
+
stackformer/models/Meta.py
|
|
13
|
+
stackformer/models/OpenAI.py
|
|
14
|
+
stackformer/models/Transformer.py
|
|
15
|
+
stackformer/models/__init__.py
|
|
16
|
+
stackformer/modules/Attention.py
|
|
17
|
+
stackformer/modules/Feed_forward.py
|
|
18
|
+
stackformer/modules/Normalization.py
|
|
19
|
+
stackformer/modules/__init__.py
|
|
20
|
+
stackformer/modules/mask.py
|
|
21
|
+
stackformer/modules/position_embedding.py
|
|
22
|
+
stackformer/modules/tokenizer.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
stackformer
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "Stackformer"
|
|
3
|
-
version = "0.1.0"
|
|
3
|
+
version = "0.1.2"
|
|
4
4
|
description = "Modular transformer blocks built in PyTorch"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = ">=3.9"
|
|
@@ -11,7 +11,7 @@ authors = [
|
|
|
11
11
|
]
|
|
12
12
|
|
|
13
13
|
dependencies = [
|
|
14
|
-
"torch>=2.6",
|
|
14
|
+
"torch>=2.0,<2.6",
|
|
15
15
|
"tqdm>=4.67"
|
|
16
16
|
]
|
|
17
17
|
|
|
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
|
|
2
2
|
|
|
3
3
|
setup(
|
|
4
4
|
name="Stackformer",
|
|
5
|
-
version="0.1.0",
|
|
5
|
+
version="0.1.2",
|
|
6
6
|
description="Modular transformer blocks built in PyTorch",
|
|
7
7
|
# long_description=open("README.md", "r", encoding="utf-8").read(),
|
|
8
8
|
# long_description_content_type="text/markdown",
|
|
@@ -18,7 +18,7 @@ setup(
|
|
|
18
18
|
python_requires=">=3.9",
|
|
19
19
|
packages=find_packages(exclude=["tests", "examples"]),
|
|
20
20
|
install_requires=[
|
|
21
|
-
"torch>=2.6",
|
|
21
|
+
"torch>=2.0,<2.6",
|
|
22
22
|
"tqdm>=4.67",
|
|
23
23
|
],
|
|
24
24
|
classifiers=[
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# --- Tokenizer ---
|
|
2
|
+
from .modules.tokenizer import Embedding_using_tiktoken
|
|
3
|
+
|
|
4
|
+
# --- Position Embeddings ---
|
|
5
|
+
from .modules.position_embedding import AbsolutePositionEmbedding
|
|
6
|
+
from .modules.position_embedding import SinusoidalPositionalEmbedding
|
|
7
|
+
from .modules.position_embedding import RoPE
|
|
8
|
+
|
|
9
|
+
# --- Attention mechanisms ---
|
|
10
|
+
from .modules.Attention import Self_Attention
|
|
11
|
+
from .modules.Attention import Multi_Head_Attention
|
|
12
|
+
from .modules.Attention import Cross_MultiHead_Attention
|
|
13
|
+
from .modules.Attention import Multi_query_Attention
|
|
14
|
+
from .modules.Attention import Group_query_Attention
|
|
15
|
+
from .modules.Attention import Linear_Attention
|
|
16
|
+
from .modules.Attention import Multi_latent_Attention
|
|
17
|
+
from .modules.Attention import Local_Attention
|
|
18
|
+
from .modules.Attention import kv_cache_multihead
|
|
19
|
+
from .modules.Attention import kv_cache_group_query
|
|
20
|
+
|
|
21
|
+
# --- Normalization layers ---
|
|
22
|
+
from .modules.Normalization import LayerNorm
|
|
23
|
+
from .modules.Normalization import RMSNormilization
|
|
24
|
+
|
|
25
|
+
# --- Feed Forward layers ---
|
|
26
|
+
from .modules.Feed_forward import FF_ReLU
|
|
27
|
+
from .modules.Feed_forward import FF_GELU
|
|
28
|
+
from .modules.Feed_forward import FF_LeakyReLU
|
|
29
|
+
from .modules.Feed_forward import FF_Sigmoid
|
|
30
|
+
from .modules.Feed_forward import FF_SiLU
|
|
31
|
+
|
|
32
|
+
# --- Model ---
|
|
33
|
+
from .models.OpenAI import GPT_2
|
|
34
|
+
from .models.Meta import Llama_2
|
|
35
|
+
from .models.Transformer import transformer
|
|
36
|
+
|
|
37
|
+
# --- Trainer ---
|
|
38
|
+
from .trainer import Trainer
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def precompute_theta_position_frequency(head_dim, seq_len, device='cpu', theta=10000.0):
|
|
7
|
+
assert head_dim % 2 == 0, "head_dim must be even"
|
|
8
|
+
theta_numerator = torch.arange(0, head_dim, 2, device=device)
|
|
9
|
+
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
|
|
10
|
+
m = torch.arange(seq_len, device=device)
|
|
11
|
+
freqs = torch.outer(m, inv_freq)
|
|
12
|
+
freq_complex = torch.polar(torch.ones_like(freqs), freqs)
|
|
13
|
+
return freq_complex
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def apply_rotry_position_embedding(x, freq_complex, device='cpu', dtype=torch.float32):
|
|
17
|
+
batch_size, seq_len, num_head, emb_dim = x.shape
|
|
18
|
+
assert emb_dim % 2 == 0, "emb_dim must be even"
|
|
19
|
+
x_reshaped = x.view(batch_size, seq_len, num_head, emb_dim // 2, 2).to(device=device, dtype=dtype)
|
|
20
|
+
x_complex = torch.view_as_complex(x_reshaped)
|
|
21
|
+
freq_complex = freq_complex[:seq_len].unsqueeze(0).unsqueeze(2).to(device=device)
|
|
22
|
+
x_rotated = x_complex * freq_complex
|
|
23
|
+
x_out = torch.view_as_real(x_rotated).contiguous().view(batch_size, seq_len, num_head, emb_dim)
|
|
24
|
+
return x_out.to(device=device, dtype=dtype)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class kv_cache_group_query(nn.Module):
|
|
28
|
+
def __init__(self, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len,
|
|
29
|
+
device='cpu', dtype=torch.float32, dropout=0.1):
|
|
30
|
+
super().__init__()
|
|
31
|
+
assert emb_dim % query_num_heads == 0, "Embedding dim must be divisible by query heads"
|
|
32
|
+
assert query_num_heads % kv_num_heads == 0, "query heads must be divisible by kv heads"
|
|
33
|
+
|
|
34
|
+
self.device = device
|
|
35
|
+
self.dtype = dtype
|
|
36
|
+
self.emb_dim = emb_dim
|
|
37
|
+
self.query_num_heads = query_num_heads
|
|
38
|
+
self.kv_num_heads = kv_num_heads
|
|
39
|
+
self.head_dim = emb_dim // query_num_heads
|
|
40
|
+
self.num_queries_per_kv = query_num_heads // kv_num_heads
|
|
41
|
+
self.kv_seq_len = kv_seq_len
|
|
42
|
+
|
|
43
|
+
self.query = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
|
|
44
|
+
self.key = nn.Linear(emb_dim, kv_num_heads * self.head_dim, bias=False, dtype=dtype, device=device)
|
|
45
|
+
self.value = nn.Linear(emb_dim, kv_num_heads * self.head_dim, bias=False, dtype=dtype, device=device)
|
|
46
|
+
|
|
47
|
+
self.out_proj = nn.Linear(query_num_heads * self.head_dim, emb_dim, dtype=dtype, device=device)
|
|
48
|
+
self.dropout = nn.Dropout(dropout)
|
|
49
|
+
|
|
50
|
+
self.register_buffer("cache_keys", torch.zeros(batch_size, kv_seq_len, kv_num_heads, self.head_dim, device=device, dtype=dtype))
|
|
51
|
+
self.register_buffer("cache_value", torch.zeros(batch_size, kv_seq_len, kv_num_heads, self.head_dim, device=device, dtype=dtype))
|
|
52
|
+
|
|
53
|
+
def forward(self, x, start_pos):
|
|
54
|
+
batch_size, seq_len, _ = x.shape
|
|
55
|
+
|
|
56
|
+
xq = self.query(x).view(batch_size, seq_len, self.query_num_heads, self.head_dim)
|
|
57
|
+
xk = self.key(x).view(batch_size, seq_len, self.kv_num_heads, self.head_dim)
|
|
58
|
+
xv = self.value(x).view(batch_size, seq_len, self.kv_num_heads, self.head_dim)
|
|
59
|
+
|
|
60
|
+
freq_q = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=seq_len, device=self.device)
|
|
61
|
+
xq = apply_rotry_position_embedding(xq, freq_q, device=self.device, dtype=self.dtype)
|
|
62
|
+
freq_k = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=self.kv_seq_len, device=self.device)
|
|
63
|
+
xk = apply_rotry_position_embedding(xk, freq_k, device=self.device, dtype=self.dtype)
|
|
64
|
+
|
|
65
|
+
self.cache_keys[:, start_pos:start_pos + seq_len] = xk
|
|
66
|
+
self.cache_value[:, start_pos:start_pos + seq_len] = xv
|
|
67
|
+
|
|
68
|
+
xk_full = self.cache_keys[:, :start_pos + seq_len]
|
|
69
|
+
xv_full = self.cache_value[:, :start_pos + seq_len]
|
|
70
|
+
|
|
71
|
+
query = xq.transpose(1, 2)
|
|
72
|
+
key = xk_full.transpose(1, 2).repeat_interleave(self.num_queries_per_kv, dim=1)
|
|
73
|
+
value = xv_full.transpose(1, 2).repeat_interleave(self.num_queries_per_kv, dim=1)
|
|
74
|
+
|
|
75
|
+
attn_scores = torch.matmul(query, key.transpose(2, 3)) / (self.head_dim ** 0.5)
|
|
76
|
+
|
|
77
|
+
causal_mask = torch.triu(torch.ones(seq_len, attn_scores.shape[-1], dtype=torch.bool, device=self.device), diagonal=1)
|
|
78
|
+
attn_scores.masked_fill_(causal_mask[None, None, :, :], float('-inf'))
|
|
79
|
+
|
|
80
|
+
attn_weights = F.softmax(attn_scores, dim=-1)
|
|
81
|
+
out = torch.matmul(attn_weights, value)
|
|
82
|
+
|
|
83
|
+
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.query_num_heads * self.head_dim)
|
|
84
|
+
return self.dropout(self.out_proj(out))
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class RMSNormilization(nn.Module):
|
|
88
|
+
def __init__(self, emb_dim, eps=1e-5, device='cpu', dtype=torch.float32):
|
|
89
|
+
super().__init__()
|
|
90
|
+
self.eps = eps
|
|
91
|
+
self.scale = nn.Parameter(torch.ones(emb_dim, dtype=dtype, device=device))
|
|
92
|
+
|
|
93
|
+
def forward(self, x):
|
|
94
|
+
norm = x.norm(2, dim=-1, keepdim=True)
|
|
95
|
+
rms = norm / (x.shape[-1] ** 0.5)
|
|
96
|
+
return (x / (rms + self.eps)) * self.scale
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class FF_SiLU(nn.Module):
|
|
100
|
+
def __init__(self, emb_dim, hidden_dim, device='cpu', dtype=torch.float32):
|
|
101
|
+
super().__init__()
|
|
102
|
+
self.silu = nn.Sequential(
|
|
103
|
+
nn.Linear(emb_dim, hidden_dim, device=device, dtype=dtype),
|
|
104
|
+
nn.SiLU(),
|
|
105
|
+
nn.Linear(hidden_dim, emb_dim, device=device, dtype=dtype),
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
def forward(self, x):
|
|
109
|
+
return self.silu(x)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class block(nn.Module):
|
|
113
|
+
def __init__(self, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len, hidden_dim,
|
|
114
|
+
eps=1e-5, dropout=0.1, dtype=torch.float32, device='cpu'):
|
|
115
|
+
super().__init__()
|
|
116
|
+
self.attn_norm = RMSNormilization(emb_dim=emb_dim, eps=eps, device=device, dtype=dtype)
|
|
117
|
+
self.ff_norm = RMSNormilization(emb_dim=emb_dim, eps=eps, device=device, dtype=dtype)
|
|
118
|
+
self.attn = kv_cache_group_query(emb_dim=emb_dim, query_num_heads=query_num_heads, kv_num_heads=kv_num_heads,
|
|
119
|
+
batch_size=batch_size, kv_seq_len=kv_seq_len, dtype=dtype,
|
|
120
|
+
dropout=dropout, device=device)
|
|
121
|
+
self.ff = FF_SiLU(emb_dim=emb_dim, hidden_dim=hidden_dim, device=device, dtype=dtype)
|
|
122
|
+
|
|
123
|
+
def forward(self, x, start_pos):
|
|
124
|
+
residual = x
|
|
125
|
+
x = self.attn_norm(x)
|
|
126
|
+
x = self.attn(x, start_pos)
|
|
127
|
+
x = x + residual
|
|
128
|
+
|
|
129
|
+
residual = x
|
|
130
|
+
x = self.ff_norm(x)
|
|
131
|
+
x = self.ff(x)
|
|
132
|
+
x = x + residual
|
|
133
|
+
return x
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class Encoder(nn.Module):
|
|
137
|
+
def __init__(self, num_layers, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len,
|
|
138
|
+
hidden_dim, eps=1e-5, dropout=0.1, dtype=torch.float32, device='cpu'):
|
|
139
|
+
super().__init__()
|
|
140
|
+
self.layers = nn.ModuleList([
|
|
141
|
+
block(emb_dim=emb_dim, query_num_heads=query_num_heads, kv_num_heads=kv_num_heads,
|
|
142
|
+
batch_size=batch_size, kv_seq_len=kv_seq_len, hidden_dim=hidden_dim,
|
|
143
|
+
eps=eps, dropout=dropout, dtype=dtype, device=device)
|
|
144
|
+
for _ in range(num_layers)
|
|
145
|
+
])
|
|
146
|
+
|
|
147
|
+
def forward(self, x, start_pos):
|
|
148
|
+
for layer in self.layers:
|
|
149
|
+
x = layer(x, start_pos)
|
|
150
|
+
return x
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class Llama_2(nn.Module):
|
|
154
|
+
def __init__(self, num_layers, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len, vocab_size,
|
|
155
|
+
hidden_dim, eps=1e-5, dropout=0.1, dtype=torch.float32, device='cpu'):
|
|
156
|
+
super().__init__()
|
|
157
|
+
self.device = device
|
|
158
|
+
self.vocab_size = vocab_size
|
|
159
|
+
self.dtype = dtype
|
|
160
|
+
self.seq_len = kv_seq_len # For generation slicing
|
|
161
|
+
|
|
162
|
+
self.embedding = nn.Embedding(vocab_size, emb_dim, dtype=dtype, device=device)
|
|
163
|
+
|
|
164
|
+
self.encoder = Encoder(num_layers=num_layers, emb_dim=emb_dim, query_num_heads=query_num_heads,
|
|
165
|
+
kv_num_heads=kv_num_heads, batch_size=batch_size, kv_seq_len=kv_seq_len,
|
|
166
|
+
hidden_dim=hidden_dim, eps=eps, dropout=dropout, dtype=dtype, device=device)
|
|
167
|
+
|
|
168
|
+
self.final_norm = RMSNormilization(emb_dim, eps=eps, device=device, dtype=dtype)
|
|
169
|
+
self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False, dtype=dtype, device=device)
|
|
170
|
+
|
|
171
|
+
def forward(self, input_ids, start_pos=0):
|
|
172
|
+
x = self.embedding(input_ids)
|
|
173
|
+
x = self.encoder(x, start_pos)
|
|
174
|
+
x = self.final_norm(x)
|
|
175
|
+
logits = self.lm_head(x)
|
|
176
|
+
return logits
|
|
177
|
+
|
|
178
|
+
@torch.no_grad()
|
|
179
|
+
def generate(self, prompt_ids, max_new_tokens=50, temperature=1.0, top_k=None, top_p=1.0):
|
|
180
|
+
self.eval()
|
|
181
|
+
if prompt_ids.dim() == 1:
|
|
182
|
+
prompt_ids = prompt_ids.unsqueeze(0)
|
|
183
|
+
|
|
184
|
+
generated = prompt_ids.clone()
|
|
185
|
+
for step in range(max_new_tokens):
|
|
186
|
+
input_ids = generated[:, -self.seq_len:]
|
|
187
|
+
logits = self.forward(input_ids, start_pos=step) # Correct start_pos
|
|
188
|
+
logits = logits[:, -1, :]
|
|
189
|
+
|
|
190
|
+
if temperature != 1.0:
|
|
191
|
+
logits = logits / temperature
|
|
192
|
+
|
|
193
|
+
if top_k is not None and top_k > 0:
|
|
194
|
+
topk_vals, topk_indices = torch.topk(logits, top_k)
|
|
195
|
+
mask = torch.full_like(logits, float('-inf'))
|
|
196
|
+
mask.scatter_(dim=-1, index=topk_indices, src=topk_vals)
|
|
197
|
+
logits = mask
|
|
198
|
+
|
|
199
|
+
if top_p < 1.0:
|
|
200
|
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
|
201
|
+
probs = F.softmax(sorted_logits, dim=-1)
|
|
202
|
+
cum_probs = torch.cumsum(probs, dim=-1)
|
|
203
|
+
sorted_mask = cum_probs > top_p
|
|
204
|
+
sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
|
|
205
|
+
sorted_mask[..., 0] = 0
|
|
206
|
+
indices_to_remove = sorted_mask.scatter(dim=-1, index=sorted_indices, src=sorted_mask)
|
|
207
|
+
logits = logits.masked_fill(indices_to_remove, float('-inf'))
|
|
208
|
+
|
|
209
|
+
probs = F.softmax(logits, dim=-1)
|
|
210
|
+
next_token = torch.multinomial(probs, num_samples=1)
|
|
211
|
+
generated = torch.cat([generated, next_token], dim=-1)
|
|
212
|
+
|
|
213
|
+
return generated
|
|
@@ -22,7 +22,8 @@ class SinusoidalPositionalEmbedding(nn.Module):
|
|
|
22
22
|
def forward(self, x):
|
|
23
23
|
# x shape: (batch_size, seq_len, emb_dim) or (batch_size, seq_len)
|
|
24
24
|
batch_size, seq_len = x.shape[0], x.shape[1]
|
|
25
|
-
|
|
25
|
+
out = self.pe[:seq_len].unsqueeze(0).expand(batch_size, seq_len, -1)
|
|
26
|
+
return out.to(device=x.device,dtype=x.dtype)
|
|
26
27
|
|
|
27
28
|
# --- Multi Head Attention ---
|
|
28
29
|
class MultiHeadAttention(nn.Module):
|
|
@@ -135,11 +136,14 @@ class Encoder(nn.Module):
|
|
|
135
136
|
x = layer(x)
|
|
136
137
|
return x
|
|
137
138
|
|
|
138
|
-
class
|
|
139
|
+
class GPT_2(nn.Module):
|
|
139
140
|
def __init__(self, vocab_size, num_layers, Emb_dim, num_heads, seq_len,
|
|
140
141
|
dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
|
|
141
|
-
|
|
142
142
|
super().__init__()
|
|
143
|
+
self.device = device
|
|
144
|
+
self.dtype = dtype
|
|
145
|
+
self.seq_len = seq_len
|
|
146
|
+
|
|
143
147
|
# --- Token embedding ---
|
|
144
148
|
self.embedding = nn.Embedding(vocab_size, Emb_dim, dtype=self.dtype, device=self.device)
|
|
145
149
|
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
|
|
6
|
+
# --- position embedding ---
|
|
7
|
+
class SinusoidalPositionalEmbedding(nn.Module):
|
|
8
|
+
def __init__(self, seq_len, emb_dim):
|
|
9
|
+
super().__init__()
|
|
10
|
+
self.seq_len = seq_len
|
|
11
|
+
self.emb_dim = emb_dim
|
|
12
|
+
|
|
13
|
+
position = torch.arange(0, seq_len).unsqueeze(1)
|
|
14
|
+
div_term = torch.exp(torch.arange(0, emb_dim, 2) * -(math.log(10000.0) / emb_dim))
|
|
15
|
+
|
|
16
|
+
pe = torch.zeros(seq_len, emb_dim)
|
|
17
|
+
pe[:, 0::2] = torch.sin(position * div_term)
|
|
18
|
+
pe[:, 1::2] = torch.cos(position * div_term)
|
|
19
|
+
|
|
20
|
+
self.register_buffer("pe", pe)
|
|
21
|
+
|
|
22
|
+
def forward(self, x):
|
|
23
|
+
# x shape: (batch_size, seq_len, emb_dim) or (batch_size, seq_len)
|
|
24
|
+
batch_size, seq_len = x.shape[0], x.shape[1]
|
|
25
|
+
out = self.pe[:seq_len].unsqueeze(0).expand(batch_size, seq_len, -1)
|
|
26
|
+
return out.to(device=x.device,dtype=x.dtype)
|
|
27
|
+
|
|
28
|
+
# --- multi-head attention ---
|
|
29
|
+
class Multi_Head_Attention(nn.Module):
|
|
30
|
+
def __init__(self, emb_dim, num_heads, dropout, device='cpu',dtype=torch.float32):
|
|
31
|
+
super().__init__()
|
|
32
|
+
assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"
|
|
33
|
+
self.emb_dim = emb_dim
|
|
34
|
+
self.num_heads = num_heads
|
|
35
|
+
self.device = device
|
|
36
|
+
self.head_dim = emb_dim // num_heads
|
|
37
|
+
|
|
38
|
+
self.key = nn.Linear(emb_dim, emb_dim, bias=False,dtype=dtype,device=device)
|
|
39
|
+
self.query = nn.Linear(emb_dim, emb_dim, bias=False,dtype=dtype,device=device)
|
|
40
|
+
self.value = nn.Linear(emb_dim, emb_dim, bias=False,dtype=dtype,device=device)
|
|
41
|
+
|
|
42
|
+
self.scale = torch.tensor(self.head_dim ** 0.5,device=device,dtype=dtype)
|
|
43
|
+
self.out_proj = nn.Linear(emb_dim, emb_dim,dtype=dtype,device=device)
|
|
44
|
+
|
|
45
|
+
self.dropout = nn.Dropout(dropout)
|
|
46
|
+
|
|
47
|
+
def forward(self, x):
|
|
48
|
+
Batch_size, Seq_len, _ = x.shape
|
|
49
|
+
|
|
50
|
+
# Generate Q, K, V and reshape for multi-head attention
|
|
51
|
+
Keys = self.key(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (Batch_size, nh, Seq_len, hd)
|
|
52
|
+
Querys = self.query(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
53
|
+
Values = self.value(x).view(Batch_size, Seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
54
|
+
|
|
55
|
+
# Compute attention scores
|
|
56
|
+
scores = (Querys @ Keys.transpose(-2, -1)) / self.scale # (Batch_size, nh, Seq_len, Seq_len)
|
|
57
|
+
|
|
58
|
+
# Apply causal mask if requested
|
|
59
|
+
causal_mask = torch.triu(torch.ones(Seq_len, Seq_len, dtype=torch.bool, device=self.device), diagonal=1)
|
|
60
|
+
scores = scores.masked_fill_(causal_mask[None, None, :, :], float('-inf'))
|
|
61
|
+
|
|
62
|
+
# Apply softmax and dropout
|
|
63
|
+
attn = F.softmax(scores, dim=-1)
|
|
64
|
+
attn = self.dropout(attn)
|
|
65
|
+
|
|
66
|
+
# Apply attention to values
|
|
67
|
+
out = attn @ Values # (Batch_size, nh, Seq_len, hd)
|
|
68
|
+
|
|
69
|
+
# Concatenate heads and project
|
|
70
|
+
out = out.transpose(1, 2).contiguous().view(Batch_size, Seq_len, self.emb_dim) # (Batch_size, Seq_len, emb_dim)
|
|
71
|
+
|
|
72
|
+
return self.out_proj(out)
|
|
73
|
+
|
|
74
|
+
# --- cross-attention ---
|
|
75
|
+
class Cross_MultiHead_Attention(nn.Module):
|
|
76
|
+
def __init__(self, emb_dim, num_heads, dropout,device='cpu', dtype=torch.float32):
|
|
77
|
+
super().__init__()
|
|
78
|
+
assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"
|
|
79
|
+
self.emb_dim = emb_dim
|
|
80
|
+
self.device = device
|
|
81
|
+
self.num_heads = num_heads
|
|
82
|
+
self.head_dim = emb_dim // num_heads
|
|
83
|
+
|
|
84
|
+
# Querys, Key, Value projections
|
|
85
|
+
self.query = nn.Linear(emb_dim, emb_dim, bias=False,dtype=dtype,device=device)
|
|
86
|
+
self.key = nn.Linear(emb_dim, emb_dim, bias=False,dtype=dtype,device=device)
|
|
87
|
+
self.value = nn.Linear(emb_dim, emb_dim, bias=False,dtype=dtype,device=device)
|
|
88
|
+
|
|
89
|
+
self.scale = torch.tensor(self.head_dim ** 0.5,device=device,dtype=dtype)
|
|
90
|
+
|
|
91
|
+
self.out_proj = nn.Linear(emb_dim, emb_dim,dtype=dtype,device=device)
|
|
92
|
+
self.dropout = nn.Dropout(dropout)
|
|
93
|
+
|
|
94
|
+
def forward(self, x, context=None):
|
|
95
|
+
Batch_size, query_seq_len, _ = x.shape
|
|
96
|
+
context = x if context is None else context # self-attention fallback
|
|
97
|
+
KV_seq_len = context.shape[1]
|
|
98
|
+
|
|
99
|
+
# Project Q, K, V
|
|
100
|
+
Querys = self.query(x).view(Batch_size, query_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
101
|
+
Keys = self.key(context).view(Batch_size, KV_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
102
|
+
Values = self.value(context).view(Batch_size, KV_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
103
|
+
|
|
104
|
+
# Attention scores
|
|
105
|
+
scores = (Querys @ Keys.transpose(-2, -1)) / self.scale # (Batch_size, nh, query_seq_len, KV_seq_len)
|
|
106
|
+
|
|
107
|
+
causal_mask = torch.triu(torch.ones(query_seq_len, query_seq_len, dtype=torch.bool, device=self.device), diagonal=1)
|
|
108
|
+
scores = scores.masked_fill_(causal_mask[None, None, :, :], float('-inf'))
|
|
109
|
+
|
|
110
|
+
attn = F.softmax(scores, dim=-1)
|
|
111
|
+
attn = self.dropout(attn)
|
|
112
|
+
|
|
113
|
+
out = attn @ Values
|
|
114
|
+
out = out.transpose(1, 2).contiguous().view(Batch_size, query_seq_len, self.emb_dim) # (Batch_size, query_seq_len, emb_dim)
|
|
115
|
+
|
|
116
|
+
return self.out_proj(out)
|
|
117
|
+
|
|
118
|
+
# --- Feed Forward ---
|
|
119
|
+
class FF_ReLU(nn.Module):
|
|
120
|
+
def __init__(self, emb_dim, hidden_dim, dropout=0.1, device='cpu', dtype=torch.float32):
|
|
121
|
+
super().__init__()
|
|
122
|
+
self.relu = nn.Sequential(
|
|
123
|
+
nn.Linear(emb_dim, hidden_dim, device=device, dtype=dtype),
|
|
124
|
+
nn.ReLU(),
|
|
125
|
+
nn.Dropout(dropout),
|
|
126
|
+
nn.Linear(hidden_dim, emb_dim, device=device, dtype=dtype),
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
def forward(self, x):
|
|
130
|
+
return self.relu(x)
|
|
131
|
+
|
|
132
|
+
class LayerNorm(nn.Module):
|
|
133
|
+
def __init__(self, emb_dim, eps=1e-5, device='cpu', dtype=torch.float32):
|
|
134
|
+
super().__init__()
|
|
135
|
+
self.eps = eps
|
|
136
|
+
self.weight = nn.Parameter(torch.ones(emb_dim, device=device, dtype=dtype))
|
|
137
|
+
self.bias = nn.Parameter(torch.zeros(emb_dim, device=device, dtype=dtype))
|
|
138
|
+
|
|
139
|
+
def forward(self, x):
|
|
140
|
+
mean = x.mean(dim=-1, keepdim=True)
|
|
141
|
+
var = x.var(dim=-1, keepdim=True, unbiased=False)
|
|
142
|
+
norm_x = (x - mean) / torch.sqrt(var + self.eps)
|
|
143
|
+
return norm_x * self.weight + self.bias
|
|
144
|
+
|
|
145
|
+
class Encoder(nn.Module):
|
|
146
|
+
def __init__(self, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
|
|
147
|
+
super().__init__()
|
|
148
|
+
self.attention = Multi_Head_Attention(emb_dim, num_heads, dropout, device=device, dtype=dtype)
|
|
149
|
+
self.norm1 = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)
|
|
150
|
+
self.ff_relu = FF_ReLU(emb_dim, hidden_dim, dropout, device=device, dtype=dtype)
|
|
151
|
+
self.norm2 = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)
|
|
152
|
+
|
|
153
|
+
def forward(self, x):
|
|
154
|
+
residual = x
|
|
155
|
+
x = self.attention(x)
|
|
156
|
+
x = self.norm1(x)
|
|
157
|
+
x = x + residual
|
|
158
|
+
|
|
159
|
+
residual = x
|
|
160
|
+
x = self.ff_relu(x)
|
|
161
|
+
x = self.norm2(x)
|
|
162
|
+
x = x + residual
|
|
163
|
+
|
|
164
|
+
return x
|
|
165
|
+
|
|
166
|
+
class Decoder(nn.Module):
|
|
167
|
+
def __init__(self, emb_dim, num_heads, dropout, hidden_dim, eps=1e-5, device='cpu', dtype=torch.float32):
|
|
168
|
+
super().__init__()
|
|
169
|
+
self.attention = Multi_Head_Attention(emb_dim, num_heads, dropout, device=device, dtype=dtype)
|
|
170
|
+
self.norm1 = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)
|
|
171
|
+
self.cross_attention = Cross_MultiHead_Attention(emb_dim, num_heads, dropout, device=device, dtype=dtype)
|
|
172
|
+
self.norm2 = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)
|
|
173
|
+
self.ff_relu = FF_ReLU(emb_dim, hidden_dim, dropout, device=device, dtype=dtype)
|
|
174
|
+
self.norm3 = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)
|
|
175
|
+
|
|
176
|
+
def forward(self, x, enc_output):
|
|
177
|
+
residual = x
|
|
178
|
+
x = self.attention(x)
|
|
179
|
+
x = self.norm1(x)
|
|
180
|
+
x = x + residual
|
|
181
|
+
|
|
182
|
+
residual = x
|
|
183
|
+
x = self.cross_attention(x, context = enc_output)
|
|
184
|
+
x = self.norm2(x)
|
|
185
|
+
x = x + residual
|
|
186
|
+
|
|
187
|
+
residual = x
|
|
188
|
+
x = self.ff_relu(x)
|
|
189
|
+
x = self.norm3(x)
|
|
190
|
+
x = x + residual
|
|
191
|
+
|
|
192
|
+
return x
|
|
193
|
+
|
|
194
|
+
class transformer(nn.Module):
|
|
195
|
+
def __init__(self, vocab_size, emb_dim, num_heads, dropout, hidden_dim,
|
|
196
|
+
encoder_layers, decoder_layers, seq_len, eps=1e-5, device='cpu', dtype=torch.float32,
|
|
197
|
+
):
|
|
198
|
+
super().__init__()
|
|
199
|
+
self.encoder_layers = encoder_layers
|
|
200
|
+
self.decoder_layers = decoder_layers
|
|
201
|
+
|
|
202
|
+
self.token_emb = nn.Embedding(vocab_size, emb_dim, device=device, dtype=dtype)
|
|
203
|
+
self.pos = SinusoidalPositionalEmbedding(seq_len=seq_len, emb_dim=emb_dim)
|
|
204
|
+
|
|
205
|
+
self.encoder_stack = nn.ModuleList([
|
|
206
|
+
Encoder(emb_dim, num_heads, dropout, hidden_dim, eps=eps, device=device, dtype=dtype)
|
|
207
|
+
for _ in range(encoder_layers)
|
|
208
|
+
])
|
|
209
|
+
|
|
210
|
+
self.decoder_stack = nn.ModuleList([
|
|
211
|
+
Decoder(emb_dim, num_heads, dropout, hidden_dim, eps=eps, device=device, dtype=dtype)
|
|
212
|
+
for _ in range(decoder_layers)
|
|
213
|
+
])
|
|
214
|
+
|
|
215
|
+
# --- final norm ---
|
|
216
|
+
self.final_norm = LayerNorm(emb_dim, eps=eps, device=device, dtype=dtype)
|
|
217
|
+
|
|
218
|
+
# --- output projection ---
|
|
219
|
+
self.out_proj = nn.Linear(emb_dim, vocab_size, device=device, dtype=dtype)
|
|
220
|
+
|
|
221
|
+
def encoder(self, x):
|
|
222
|
+
x = self.token_emb(x) + self.pos(x)
|
|
223
|
+
for block in self.encoder_stack:
|
|
224
|
+
x = block(x)
|
|
225
|
+
return x
|
|
226
|
+
|
|
227
|
+
def decoder(self, x, enc_output):
|
|
228
|
+
x = self.token_emb(x) + self.pos(x)
|
|
229
|
+
for block in self.decoder_stack:
|
|
230
|
+
x = block(x, enc_output)
|
|
231
|
+
return x
|
|
232
|
+
|
|
233
|
+
def forward(self, source, target):
|
|
234
|
+
enc_output = self.encoder(source)
|
|
235
|
+
out = self.decoder(target, enc_output)
|
|
236
|
+
out = self.final_norm(out)
|
|
237
|
+
out = self.out_proj(out)
|
|
238
|
+
return out
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import torch.nn as nn
|
|
3
3
|
import torch.nn.functional as F
|
|
4
|
+
|
|
4
5
|
class Self_Attention(nn.Module):
|
|
5
6
|
def __init__(self, Emb_dim, dropout,dtype=torch.float32,device='cpu'):
|
|
6
7
|
super().__init__()
|
|
@@ -412,23 +413,23 @@ class kv_cache_multihead(nn.Module):
|
|
|
412
413
|
self.dtype = dtype
|
|
413
414
|
self.device = device
|
|
414
415
|
|
|
415
|
-
assert emb_dim % num_heads == 0
|
|
416
|
+
assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"
|
|
416
417
|
self.emb_dim = emb_dim
|
|
417
418
|
self.num_heads = num_heads
|
|
418
419
|
self.head_dim = emb_dim // num_heads
|
|
419
420
|
self.kv_seq_len = kv_seq_len
|
|
420
421
|
|
|
421
|
-
self.query = nn.Linear(emb_dim, emb_dim, bias=False,dtype=dtype,device=device)
|
|
422
|
-
self.key = nn.Linear(emb_dim, emb_dim, bias=False,dtype=dtype,device=device)
|
|
423
|
-
self.value = nn.Linear(emb_dim, emb_dim, bias=False,dtype=dtype,device=device)
|
|
422
|
+
self.query = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
|
|
423
|
+
self.key = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
|
|
424
|
+
self.value = nn.Linear(emb_dim, emb_dim, bias=False, dtype=dtype, device=device)
|
|
424
425
|
|
|
425
|
-
self.out_proj = nn.Linear(emb_dim, emb_dim,dtype=dtype,device=device)
|
|
426
|
+
self.out_proj = nn.Linear(emb_dim, emb_dim, dtype=dtype, device=device)
|
|
426
427
|
self.dropout = nn.Dropout(dropout)
|
|
428
|
+
# KV caches
|
|
429
|
+
self.register_buffer("cache_keys", torch.zeros(batch_size, kv_seq_len*2, num_heads, self.head_dim,device=device,dtype=dtype))
|
|
430
|
+
self.register_buffer("cache_value", torch.zeros(batch_size, kv_seq_len*2, num_heads, self.head_dim,device=device,dtype=dtype))
|
|
427
431
|
|
|
428
|
-
|
|
429
|
-
self.cache_value = torch.zeros(batch_size, kv_seq_len, num_heads, self.head_dim,dtype=dtype,device=device)
|
|
430
|
-
|
|
431
|
-
def forward(self, x, start_pos, RoPE: False):
|
|
432
|
+
def forward(self, x, start_pos, RoPE=False):
|
|
432
433
|
batch_size, seq_len, C = x.shape
|
|
433
434
|
|
|
434
435
|
xq = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
|
|
@@ -441,36 +442,35 @@ class kv_cache_multihead(nn.Module):
|
|
|
441
442
|
freq_complex = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=self.kv_seq_len, device=self.device)
|
|
442
443
|
xk = apply_rotry_position_embedding(xk, freq_complex, device=self.device, dtype=self.dtype)
|
|
443
444
|
|
|
444
|
-
# Cache keys and values
|
|
445
|
-
self.cache_keys[
|
|
446
|
-
self.cache_value[
|
|
447
|
-
|
|
448
|
-
xk_full = self.cache_keys[:, :start_pos+seq_len]
|
|
449
|
-
xv_full = self.cache_value[:, :start_pos+seq_len]
|
|
445
|
+
# Cache keys and values - only update the batch_size portion we're using
|
|
446
|
+
self.cache_keys[:batch_size, start_pos:start_pos+seq_len] = xk
|
|
447
|
+
self.cache_value[:batch_size, start_pos:start_pos+seq_len] = xv
|
|
450
448
|
|
|
449
|
+
# Only use the relevant batch portion from cache
|
|
450
|
+
xk_full = self.cache_keys[:batch_size, :start_pos+seq_len]
|
|
451
|
+
xv_full = self.cache_value[:batch_size, :start_pos+seq_len]
|
|
452
|
+
|
|
451
453
|
query = xq.transpose(1, 2) # (batch_size, num_head, seq_len, emb_dim)
|
|
452
454
|
key = xk_full.transpose(1, 2) # (batch_size, num_head, T_total, emb_dim)
|
|
453
455
|
value = xv_full.transpose(1, 2) # (batch_size, num_head, T_total, emb_dim)
|
|
454
456
|
|
|
455
457
|
attn_scores = torch.matmul(query, key.transpose(2, 3)) / (self.head_dim ** 0.5)
|
|
456
|
-
|
|
457
458
|
# Causal mask
|
|
458
|
-
causal_mask = torch.triu(torch.ones(
|
|
459
|
+
causal_mask = torch.triu(torch.ones(attn_scores.shape[-2], attn_scores.shape[-1], dtype=torch.bool, device=self.device), diagonal=1)
|
|
459
460
|
attn_scores.masked_fill_(causal_mask[None, None, :, :], float('-inf'))
|
|
460
461
|
|
|
461
462
|
attn_weights = F.softmax(attn_scores, dim=-1)
|
|
462
463
|
out = torch.matmul(attn_weights, value)
|
|
463
|
-
|
|
464
464
|
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
|
|
465
465
|
return self.dropout(self.out_proj(out))
|
|
466
466
|
|
|
467
467
|
class kv_cache_group_query(nn.Module):
|
|
468
|
-
def __init__(self, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len,device='cpu'
|
|
468
|
+
def __init__(self, emb_dim, query_num_heads, kv_num_heads, batch_size, kv_seq_len, device='cpu', dtype=torch.float32, dropout=0.1):
|
|
469
469
|
super().__init__()
|
|
470
470
|
self.dtype = dtype
|
|
471
471
|
self.device = device
|
|
472
472
|
|
|
473
|
-
assert query_num_heads % kv_num_heads
|
|
473
|
+
assert query_num_heads % kv_num_heads == 0, "query heads must be divisible by kv heads"
|
|
474
474
|
assert emb_dim % query_num_heads == 0, "embedding must be divisible by query heads"
|
|
475
475
|
|
|
476
476
|
self.emb_dim = emb_dim
|
|
@@ -488,8 +488,8 @@ class kv_cache_group_query(nn.Module):
|
|
|
488
488
|
self.dropout = nn.Dropout(dropout)
|
|
489
489
|
|
|
490
490
|
# KV caches
|
|
491
|
-
self.register_buffer("cache_keys", torch.zeros(batch_size, kv_seq_len, kv_num_heads, self.head_dim,device=device,dtype=dtype))
|
|
492
|
-
self.register_buffer("cache_value", torch.zeros(batch_size, kv_seq_len, kv_num_heads, self.head_dim,device=device,dtype=dtype))
|
|
491
|
+
self.register_buffer("cache_keys", torch.zeros(batch_size, kv_seq_len*2, kv_num_heads, self.head_dim,device=device,dtype=dtype))
|
|
492
|
+
self.register_buffer("cache_value", torch.zeros(batch_size, kv_seq_len*2, kv_num_heads, self.head_dim,device=device,dtype=dtype))
|
|
493
493
|
|
|
494
494
|
def forward(self, x, start_pos, RoPE=False):
|
|
495
495
|
batch_size, seq_len, _ = x.shape
|
|
@@ -501,15 +501,16 @@ class kv_cache_group_query(nn.Module):
|
|
|
501
501
|
if RoPE:
|
|
502
502
|
freq_q = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=seq_len, device=self.device)
|
|
503
503
|
xq = apply_rotry_position_embedding(xq, freq_q, device=self.device, dtype=self.dtype)
|
|
504
|
-
freq_k = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=
|
|
504
|
+
freq_k = precompute_theta_position_frequency(head_dim=self.head_dim, seq_len=seq_len, device=self.device)
|
|
505
505
|
xk = apply_rotry_position_embedding(xk, freq_k, device=self.device, dtype=self.dtype)
|
|
506
|
-
# Cache
|
|
507
|
-
self.cache_keys[:, start_pos:start_pos+seq_len] = xk
|
|
508
|
-
self.cache_value[:, start_pos:start_pos+seq_len] = xv
|
|
509
|
-
|
|
510
|
-
xk_full = self.cache_keys[:, :start_pos+seq_len] # [B, T, kv_heads, D]
|
|
511
|
-
xv_full = self.cache_value[:, :start_pos+seq_len]
|
|
512
506
|
|
|
507
|
+
# Cache keys and values - only update the batch_size portion we're using
|
|
508
|
+
self.cache_keys[:batch_size, start_pos:start_pos+seq_len] = xk
|
|
509
|
+
self.cache_value[:batch_size, start_pos:start_pos+seq_len] = xv
|
|
510
|
+
# Only use the relevant batch portion from cache
|
|
511
|
+
xk_full = self.cache_keys[:batch_size, :start_pos+seq_len]
|
|
512
|
+
xv_full = self.cache_value[:batch_size, :start_pos+seq_len]
|
|
513
|
+
|
|
513
514
|
# Transpose for attention: [B, H, T, D]
|
|
514
515
|
query = xq.transpose(1, 2) # [B, q_heads, seq_len, D]
|
|
515
516
|
key = xk_full.transpose(1, 2) # [B, kv_heads, total_kv_len, D]
|
|
@@ -518,16 +519,14 @@ class kv_cache_group_query(nn.Module):
|
|
|
518
519
|
# Repeat keys and values to match query heads
|
|
519
520
|
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
|
|
520
521
|
value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
|
|
521
|
-
|
|
522
522
|
# Attention
|
|
523
523
|
attn_scores = torch.matmul(query, key.transpose(2, 3)) / (self.head_dim ** 0.5)
|
|
524
|
-
|
|
525
524
|
# Causal mask
|
|
526
|
-
causal_mask = torch.triu(torch.ones(
|
|
525
|
+
causal_mask = torch.triu(torch.ones(attn_scores.shape[-2], attn_scores.shape[-1], dtype=torch.bool, device=self.device), diagonal=1)
|
|
527
526
|
attn_scores.masked_fill_(causal_mask[None, None, :, :], float('-inf'))
|
|
528
|
-
|
|
527
|
+
#softmax
|
|
529
528
|
attn_weights = F.softmax(attn_scores, dim=-1)
|
|
529
|
+
# atten weight @ value
|
|
530
530
|
out = torch.matmul(attn_weights, value)
|
|
531
|
-
|
|
532
531
|
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.emb_dim)
|
|
533
532
|
return self.dropout(self.out_proj(out))
|
|
@@ -14,7 +14,8 @@ class AbsolutePositionEmbedding(nn.Module):
|
|
|
14
14
|
batch_size, seq_len = x.shape[0], x.shape[1]
|
|
15
15
|
positions = torch.arange(0, seq_len)
|
|
16
16
|
abs_pos = self.embedding(positions) # (seq_len, emb_dim)
|
|
17
|
-
|
|
17
|
+
out = abs_pos.unsqueeze(0).expand(batch_size, seq_len, -1)
|
|
18
|
+
return out.to(device=x.device,dtype=x.dtype)
|
|
18
19
|
|
|
19
20
|
# --- Sinusoidal Positional Embedding ---
|
|
20
21
|
class SinusoidalPositionalEmbedding(nn.Module):
|
|
@@ -35,8 +36,9 @@ class SinusoidalPositionalEmbedding(nn.Module):
|
|
|
35
36
|
def forward(self, x):
|
|
36
37
|
# x shape: (batch_size, seq_len, emb_dim) or (batch_size, seq_len)
|
|
37
38
|
batch_size, seq_len = x.shape[0], x.shape[1]
|
|
38
|
-
|
|
39
|
-
|
|
39
|
+
out = self.pe[:seq_len].unsqueeze(0).expand(batch_size, seq_len, -1)
|
|
40
|
+
return out.to(device=x.device,dtype=x.dtype)
|
|
41
|
+
|
|
40
42
|
# --- RoPE ---
|
|
41
43
|
class RoPE(nn.Module):
|
|
42
44
|
def __init__(self, head_dim, seq_len, theta=10000.0, device='cpu', dtype=torch.float32):
|
|
@@ -58,4 +60,4 @@ class RoPE(nn.Module):
|
|
|
58
60
|
freqs = self.freq_complex[:seq_len].unsqueeze(0).unsqueeze(2) # (1, seq_len, 1, head_dim//2)
|
|
59
61
|
x_rotated = x_complex * freqs
|
|
60
62
|
x_out = torch.view_as_real(x_rotated).contiguous().view(batch_size, seq_len, num_head, emb_dim)
|
|
61
|
-
return x_out.to(device=
|
|
63
|
+
return x_out.to(device=x.device, dtype=x.dtype)
|
|
@@ -0,0 +1,356 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import os
|
|
3
|
+
from torch.utils.data import DataLoader
|
|
4
|
+
from torch.optim import AdamW, SGD
|
|
5
|
+
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, CosineAnnealingWarmRestarts
|
|
6
|
+
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
|
|
9
|
+
class Trainer:
    """Training loop with gradient accumulation, evaluation and checkpointing.

    Batches from both datasets are expected to unpack as ``(inputs, targets)``
    and the model is called as ``model(inputs)``; the loss is cross-entropy
    over the flattened output with ``ignore_index=-100``.
    """

    def __init__(self,
                 model,
                 train_dataset,
                 eval_dataset,
                 train_batch_size,
                 eval_batch_size,
                 vocab_size,
                 output_dir,
                 num_epoch,
                 lr: float,
                 scheduler_type=None,
                 optimizer_type="adamw",
                 eval_per_epoch = 1,
                 eval_per_step = None,
                 weight_decay=0.01,
                 warmup_steps=0,
                 grad_accumulation_step=1,
                 max_eval_step=None,
                 max_steps=None,
                 Save_step=None,
                 Save_epoch=None,
                 max_epoch=None,
                 model_to_resume=None,
                 resume_training=False,
                 seed=42,
                 device='cpu'):
        """Store the training configuration; no work is done until ``train()``.

        Args:
            model: Module to train.
            train_dataset / eval_dataset: Datasets wrapped in ``DataLoader``s.
            train_batch_size / eval_batch_size: Loader batch sizes.
            vocab_size: Size of the last output dimension, used to flatten the
                training output before the loss.
            output_dir: Directory that receives ``checkpoint_*.pt`` files.
            num_epoch: Number of epochs to run (replaced by the checkpoint's
                value when resuming).
            lr: Learning rate.
            scheduler_type: ``None``, ``"linear"``, ``"cosine"``,
                ``"cosine_restarts"``, ``"cosineannealing"`` or
                ``"cosine_warm_restarts"``.
            optimizer_type: ``"adamw"`` or ``"sgd"`` (case-insensitive).
            eval_per_epoch: Evaluate every this many epochs (``None`` disables).
            eval_per_step: Evaluate every this many optimizer steps
                (``None`` disables).
            weight_decay: Weight decay handed to the optimizer.
            warmup_steps: Warm-up steps for the warm-up schedulers.
            grad_accumulation_step: Micro-batches per optimizer step.
            max_eval_step: Cap on evaluation batches (``None`` = all).
            max_steps: Stop, evaluate and save after this many optimizer steps.
            Save_step: Save a checkpoint every this many optimizer steps.
            Save_epoch: Save a checkpoint every this many epochs.
            max_epoch: Stop, evaluate and save after this epoch number.
            model_to_resume: Checkpoint path used when ``resume_training``.
            resume_training: Whether to restore from ``model_to_resume``.
            seed: Seed for torch and for the loaders' generators.
            device: Device the model and batches are moved to.
        """
        self.model = model
        self.train_dataset = train_dataset
        self.train_batch_size = train_batch_size
        self.eval_dataset = eval_dataset
        self.eval_batch_size = eval_batch_size
        self.vocab_size = vocab_size
        self.num_epoch = num_epoch
        self.max_steps = max_steps
        self.max_epoch = max_epoch
        self.eval_per_epoch = eval_per_epoch
        self.eval_per_step = eval_per_step
        self.max_eval_step = max_eval_step
        self.lr = lr
        self.scheduler_type = scheduler_type
        self.output_dir = output_dir
        self.model_to_resume = model_to_resume
        self.resume_training = resume_training
        self.Save_step = Save_step
        self.Save_epoch = Save_epoch
        self.grad_accumulation_step = grad_accumulation_step
        self.optimizer_type = optimizer_type
        self.weight_decay = weight_decay
        self.warmup_steps = warmup_steps
        self.seed = seed
        self.device = device

    # --- small private helpers ---
    def _uses_cuda(self):
        """Return True for any CUDA device spec ('cuda', 'cuda:0', torch.device)."""
        return str(self.device).startswith('cuda')

    def _should_eval_at_step(self, global_step, is_last_step):
        """Return True when a step-level evaluation is due after ``global_step`` steps."""
        if is_last_step:
            return True
        # ``global_step`` has already been incremented for the step just taken,
        # so a plain modulo is the right test.
        return self.eval_per_step is not None and global_step % self.eval_per_step == 0

    def _optimizer_step(self, model, optimizer, scheduler):
        """Clip gradients, apply one optimizer (and scheduler) step, clear gradients."""
        # Gradient clipping for stable training
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    # --- random seed ---
    def set_seed(self, seed):
        """Seed torch, and all CUDA devices when training on CUDA."""
        torch.manual_seed(seed)
        if self._uses_cuda() and torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # --- scheduler ---
    def get_scheduler(self, scheduler_type, total_training_steps, optimizer):
        """Build the learning-rate scheduler named by ``scheduler_type``.

        Returns:
            The scheduler, or ``None`` when ``scheduler_type`` is ``None``.

        Raises:
            ValueError: If ``scheduler_type`` is not a supported name.
        """
        if scheduler_type is None:
            return None
        elif scheduler_type == "linear":
            return get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=self.warmup_steps,
                num_training_steps=total_training_steps
            )
        elif scheduler_type == "cosine":
            return get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=self.warmup_steps,
                num_training_steps=total_training_steps
            )
        elif scheduler_type == "cosine_restarts":
            return get_cosine_with_hard_restarts_schedule_with_warmup(
                optimizer,
                num_warmup_steps=self.warmup_steps,
                num_training_steps=total_training_steps,
                num_cycles=4  # Number of restarts
            )
        elif scheduler_type == "cosineannealing":
            return CosineAnnealingLR(optimizer, T_max=total_training_steps)
        elif scheduler_type == "cosine_warm_restarts":
            # T_0 must be a positive integer; very short runs would give 0.
            first_cycle = max(1, total_training_steps // 4)
            return CosineAnnealingWarmRestarts(optimizer, T_0=first_cycle, T_mult=2)
        else:
            raise ValueError(f"Unsupported scheduler type: {scheduler_type}")

    # --- optimizer ---
    def get_optimizer(self, optimizer_type, model, lr, weight_decay):
        """Build an ``AdamW`` or ``SGD`` (momentum 0.9) optimizer for ``model``.

        Raises:
            ValueError: If ``optimizer_type`` is neither "adamw" nor "sgd".
        """
        kind = optimizer_type.lower()
        if kind == "adamw":
            return AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        elif kind == "sgd":
            return SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9)
        else:
            raise ValueError(f"Unsupported optimizer {optimizer_type}")

    # --- validate model ---
    def eval_model(self, model, eval_loader, max_val_steps):
        """Return the mean cross-entropy over at most ``max_val_steps`` eval batches.

        The model is switched to eval mode for the pass and always put back
        into train mode afterwards, even if a batch raises. Returns ``nan``
        when the loader yields no batches instead of dividing by zero.
        """
        num_batches = len(eval_loader)
        max_val_steps = min(max_val_steps or num_batches, num_batches)
        eval_loss = 0.0
        steps_done = 0
        model.eval()
        try:
            with torch.no_grad():
                pbar = tqdm(eval_loader, total=max_val_steps, desc="Evaluating", leave=False)
                for step, (inputs, targets) in enumerate(pbar):
                    inputs = inputs.to(self.device)
                    targets = targets.to(self.device)
                    output = model(inputs)  # shape: [B, T, V]
                    loss = torch.nn.functional.cross_entropy(
                        output.view(-1, output.size(-1)),
                        targets.view(-1), ignore_index=-100)
                    pbar.set_postfix(loss=loss.item())
                    eval_loss += loss.item()
                    steps_done += 1
                    if step + 1 >= max_val_steps:
                        break
        finally:
            model.train()
        if steps_done == 0:
            return float('nan')
        return eval_loss / steps_done

    # --- train dataloader ---
    def get_train_loader(self, train_dataset, batch_size, seed):
        """Return a shuffling ``DataLoader`` whose order is driven by ``seed``."""
        generator = torch.Generator()
        generator.manual_seed(seed)
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            generator=generator,
            pin_memory=self._uses_cuda()
        )
        return train_loader

    # --- validation dataloader ---
    def get_eval_loader(self, eval_dataset, batch_size, seed):
        """Return a non-shuffling ``DataLoader`` for evaluation."""
        generator = torch.Generator()
        generator.manual_seed(seed)
        eval_loader = DataLoader(
            eval_dataset,
            batch_size=batch_size,
            shuffle=False,
            generator=generator,
            pin_memory=self._uses_cuda()
        )
        return eval_loader

    # --- save model ---
    def save_model(self, model, optimizer, scheduler, epoch, num_epoch, loss, global_step,
                   accumulated_steps, batch_idx_to_resume, output_dir, name):
        """Write model/optimizer/scheduler state, progress counters and RNG state.

        The file is ``{output_dir}/checkpoint_{name}.pt``. ``epoch`` and
        ``batch_idx_to_resume`` are stored verbatim and are what a later
        ``train()`` resume uses as its starting epoch and batch offset.
        """
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'current_epoch': epoch,
            'num_epoch': num_epoch,
            'loss': loss,
            'accumulated_steps': accumulated_steps,
            'global_step': global_step,
            'batch_idx_to_resume': batch_idx_to_resume,
            'rng_state': {
                'torch': torch.get_rng_state(),
                'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
            }
        }
        os.makedirs(output_dir, exist_ok=True)
        path = f'{output_dir}/checkpoint_{name}.pt'
        torch.save(checkpoint, path)
        print(f'Saved training state to {path}')

    def load_checkpoint(self, path, model, optimizer, scheduler):
        """Restore state from ``path`` in place and return the progress counters.

        Returns:
            dict with ``current_epoch``, ``num_epoch``, ``accumulated_steps``,
            ``batch_idx_to_resume``, ``global_step`` and ``loss``.
        """
        checkpoint = torch.load(path, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if checkpoint.get('scheduler_state_dict') is not None and scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # RNG: ``map_location`` may have moved the saved state tensors onto the
        # training device, but the RNG setters require CPU byte tensors.
        rng_state = checkpoint['rng_state']
        torch.set_rng_state(rng_state['torch'].cpu())
        cuda_states = rng_state['cuda']
        if torch.cuda.is_available() and cuda_states:
            torch.cuda.set_rng_state_all([state.cpu() for state in cuda_states])
        return {
            'current_epoch': checkpoint['current_epoch'],
            'num_epoch': checkpoint['num_epoch'],
            'accumulated_steps': checkpoint['accumulated_steps'],
            'batch_idx_to_resume': checkpoint['batch_idx_to_resume'],
            'global_step': checkpoint['global_step'],
            'loss': checkpoint['loss']
        }

    # --- train ---
    def train(self):
        """Run the training loop described by the constructor arguments.

        Stops early (after evaluating and saving a final checkpoint) when
        ``max_steps`` optimizer steps or ``max_epoch`` epochs are reached.
        """
        # --- seed ---
        self.set_seed(self.seed)
        # --- dataloader ---
        train_loader = self.get_train_loader(self.train_dataset, self.train_batch_size, self.seed)
        eval_loader = self.get_eval_loader(self.eval_dataset, self.eval_batch_size, self.seed)

        # --- Calculate the total step ---
        # Ceiling division: a trailing partial accumulation group is flushed at
        # the end of every epoch and counts as one optimizer step.
        batches_per_epoch = len(train_loader)
        steps_per_epoch = -(-batches_per_epoch // self.grad_accumulation_step)
        total_training_steps = self.max_steps if self.max_steps is not None else steps_per_epoch * self.num_epoch

        os.makedirs(self.output_dir, exist_ok=True)

        model = self.model.to(self.device)
        optimizer = self.get_optimizer(self.optimizer_type, model, self.lr, self.weight_decay)
        criterion = torch.nn.functional.cross_entropy
        scheduler = self.get_scheduler(self.scheduler_type, total_training_steps, optimizer)

        global_step = 0
        start_epoch = 0
        num_epoch = self.num_epoch
        batch_idx_to_resume = 0
        accumulated_steps = 0

        if self.resume_training and self.model_to_resume:
            ckpt_data = self.load_checkpoint(self.model_to_resume, model, optimizer, scheduler)
            start_epoch = ckpt_data['current_epoch']
            global_step = ckpt_data['global_step']
            num_epoch = ckpt_data['num_epoch']
            batch_idx_to_resume = ckpt_data['batch_idx_to_resume']
            accumulated_steps = ckpt_data['accumulated_steps']
            print(f"♻️ Resuming training from epoch {start_epoch}, step {global_step}")

        # --- print info ---
        print(f"🧠 Number of parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        print(f"🍱 Number of train samples: {len(self.train_dataset):,}")
        print(f"📊 Number of eval samples: {len(self.eval_dataset):,}")
        print(f"📦 Train steps per epoch (batches): {len(train_loader):,}")
        print(f"📦 Eval steps per epoch (batches): {len(eval_loader):,}")

        for epoch in range(start_epoch, num_epoch):
            model.train()
            epoch_loss = 0
            # Only the first epoch after a resume skips already-consumed batches.
            # NOTE(review): the loader's generator is re-seeded on every run, so
            # the skipped batches are not guaranteed to be the ones seen before
            # the checkpoint was written.
            skip_until = batch_idx_to_resume if (epoch == start_epoch and self.resume_training) else 0
            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epoch}", leave=False)
            for batch_idx, batch in enumerate(pbar):
                if batch_idx < skip_until:
                    continue

                # --- load the inputs and targets ---
                inputs, targets = batch
                inputs = inputs.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)

                # --- get prediction from model ---
                output = model(inputs)

                # --- calculate loss ---
                loss = criterion(
                    output.view(-1, self.vocab_size),
                    targets.view(-1),
                    ignore_index=-100
                )
                # Scale so the accumulated gradient is the mean over the group.
                loss = loss / self.grad_accumulation_step
                loss.backward()

                batch_loss = loss.item() * self.grad_accumulation_step
                pbar.set_postfix(loss=batch_loss)
                epoch_loss += batch_loss
                accumulated_steps += 1

                # --- gradient accumulation ---
                if accumulated_steps % self.grad_accumulation_step != 0:
                    continue

                self._optimizer_step(model, optimizer, scheduler)
                global_step += 1
                accumulated_steps = 0

                # Everything below is tied to the optimizer step just taken, so
                # it cannot fire again on the following micro-batches.
                is_last_step = (self.max_steps is not None and global_step >= self.max_steps)
                # check eval_per_step
                if self._should_eval_at_step(global_step, is_last_step):
                    avg_eval_loss = self.eval_model(model, eval_loader, self.max_eval_step)
                    print(f"🎯 Eval loss: {avg_eval_loss:.4f}")

                # Check max steps
                if is_last_step:
                    # Mid-epoch checkpoint: store the epoch still in progress and
                    # the next batch, so a resume continues this epoch.
                    self.save_model(
                        model=model, optimizer=optimizer, scheduler=scheduler,
                        epoch=epoch, num_epoch=num_epoch, loss=epoch_loss,
                        global_step=global_step, output_dir=self.output_dir,
                        batch_idx_to_resume=batch_idx+1, accumulated_steps=accumulated_steps,
                        name=f'final_step_epoch_{epoch+1}_step_{global_step}'
                    )
                    return

                # Save at specific steps
                if self.Save_step is not None and global_step % self.Save_step == 0:
                    self.save_model(
                        model=model, optimizer=optimizer, scheduler=scheduler,
                        epoch=epoch, num_epoch=num_epoch, loss=epoch_loss,
                        global_step=global_step, output_dir=self.output_dir,
                        batch_idx_to_resume=batch_idx+1, accumulated_steps=accumulated_steps,
                        name=f'epoch_{epoch+1}_step_{global_step}'
                    )

            # Handle remaining accumulated gradients at epoch end
            if accumulated_steps > 0:
                self._optimizer_step(model, optimizer, scheduler)
                global_step += 1
                # Gradients were just cleared; start the next epoch with a
                # fresh accumulation group.
                accumulated_steps = 0

            is_last_epoch = (self.max_epoch is not None and (epoch+1) == self.max_epoch)

            # Evaluation
            if (self.eval_per_epoch is not None and (epoch+1) % self.eval_per_epoch == 0) or is_last_epoch:
                avg_eval_loss = self.eval_model(model, eval_loader, self.max_eval_step)
                print(f"🎯 Eval loss: {avg_eval_loss:.4f}")

            # Check max epoch
            if is_last_epoch:
                # Epoch-end checkpoint: resume at the first batch of the next epoch.
                self.save_model(
                    model=model, optimizer=optimizer, scheduler=scheduler,
                    epoch=epoch+1, num_epoch=num_epoch, loss=epoch_loss,
                    global_step=global_step, output_dir=self.output_dir,
                    batch_idx_to_resume=0, accumulated_steps=accumulated_steps,
                    name=f'final_model_epoch_{epoch+1}_step_{global_step}'
                )
                return

            # print epoch loss
            avg_epoch_loss = epoch_loss / max(batches_per_epoch, 1)
            print(f"🔥 Epoch {epoch+1} finished - Training Loss: {avg_epoch_loss:.4f}")

            # Save at specific epochs
            if (self.Save_epoch is not None and
                    (epoch + 1) % self.Save_epoch == 0):
                self.save_model(
                    model=model, optimizer=optimizer, scheduler=scheduler,
                    epoch=epoch+1, num_epoch=num_epoch, loss=epoch_loss,
                    global_step=global_step, output_dir=self.output_dir,
                    batch_idx_to_resume=0, accumulated_steps=accumulated_steps,
                    name=f'epoch_{epoch+1}_step_{global_step}'
                )
|
|
@@ -1,18 +0,0 @@
|
|
|
1
|
-
LICENSE
|
|
2
|
-
README.md
|
|
3
|
-
pyproject.toml
|
|
4
|
-
setup.py
|
|
5
|
-
Stackformer.egg-info/PKG-INFO
|
|
6
|
-
Stackformer.egg-info/SOURCES.txt
|
|
7
|
-
Stackformer.egg-info/dependency_links.txt
|
|
8
|
-
Stackformer.egg-info/requires.txt
|
|
9
|
-
Stackformer.egg-info/top_level.txt
|
|
10
|
-
models/GPT_2.py
|
|
11
|
-
models/__init__.py
|
|
12
|
-
modules/Attention.py
|
|
13
|
-
modules/Feed_forward.py
|
|
14
|
-
modules/Normalization.py
|
|
15
|
-
modules/__init__.py
|
|
16
|
-
modules/mask.py
|
|
17
|
-
modules/position_embedding.py
|
|
18
|
-
modules/tokenizer.py
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|