sarasa 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sarasa/__init__.py +2 -0
- sarasa/activation_checkpoint.py +81 -0
- sarasa/checkpoint.py +112 -0
- sarasa/config.py +279 -0
- sarasa/data/__init__.py +36 -0
- sarasa/data/hf_datasets.py +115 -0
- sarasa/data/tokenizer.py +63 -0
- sarasa/metrics.py +294 -0
- sarasa/models/__init__.py +95 -0
- sarasa/models/attention.py +84 -0
- sarasa/models/llama3.py +129 -0
- sarasa/models/nanochat_gpt.py +192 -0
- sarasa/models/utils.py +39 -0
- sarasa/optimizers/__init__.py +77 -0
- sarasa/optimizers/utils.py +27 -0
- sarasa/trainer.py +244 -0
- sarasa/utils.py +163 -0
- sarasa-0.0.2.dist-info/METADATA +138 -0
- sarasa-0.0.2.dist-info/RECORD +21 -0
- sarasa-0.0.2.dist-info/WHEEL +4 -0
- sarasa-0.0.2.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
# NanoChat's GPT model, adapted from https://github.com/karpathy/nanochat
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from loguru import logger
|
|
6
|
+
from torch import nn
|
|
7
|
+
from torch.nn import functional as F
|
|
8
|
+
|
|
9
|
+
from sarasa.models import BaseModel, ModelConfig
|
|
10
|
+
from sarasa.models.attention import CausalSelfAttention
|
|
11
|
+
from sarasa.models.utils import RMSNorm, RoPE
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MLP(nn.Module):
    """Position-wise feed-forward sublayer.

    Expands ``hidden_dim`` to ``4 * hidden_dim``, applies a squared ReLU, and projects back down.
    Neither linear layer has a bias.
    """

    def __init__(
        self,
        config: ModelConfig,
    ):
        super().__init__()
        width = config.hidden_dim
        expanded = 4 * width
        # Attribute names double as parameter / state-dict keys and are referenced by GPT.init_weights,
        # so they must stay `c_fc` and `c_proj` (created in this order).
        self.c_fc = nn.Linear(width, expanded, bias=False)
        self.c_proj = nn.Linear(expanded, width, bias=False)

    def forward(self, x):
        hidden = F.relu(self.c_fc(x)).square()  # squared-ReLU activation
        return self.c_proj(hidden)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Block(nn.Module):
    """Pre-norm transformer block: a causal self-attention sublayer followed by an MLP sublayer.

    Each sublayer receives an RMS-normalized input and its output is added back to the residual stream.
    """

    def __init__(
        self,
        config: ModelConfig,
        layer_idx: int,
    ):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)
        # This RMSNorm is constructed without affine parameters, so a single instance can safely
        # serve both sublayers.
        self.norm = RMSNorm(config.hidden_dim)

    def forward(
        self,
        x: torch.Tensor,
        cos_sin: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        attn_out = self.attn(self.norm(x), cos_sin)
        h = x + attn_out
        return h + self.mlp(self.norm(h))
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class GPT(BaseModel):
    """NanoChat-style GPT language model.

    Pipeline: token embedding -> RMSNorm -> ``num_layers`` Blocks -> RMSNorm -> LM head with tanh soft-capping.
    Before each Block, the running residual stream is re-weighted by a learnable per-layer scalar and blended
    with the initial normalized embedding through a second learnable per-layer scalar.
    """

    def __init__(
        self,
        config: ModelConfig,
        pad_vocab_size_to=64,
    ):
        """
        NOTE a major footgun: this __init__ function runs in meta device context (!!)
        Therefore, any calculations inside here are shapes and dtypes only, no actual data.
        => We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.

        Args:
            config: model hyperparameters (num_heads, hidden_dim, head_dim, seq_len, vocab_size, num_layers).
            pad_vocab_size_to: the embedding and LM-head row counts are rounded up to a multiple of this
                value; the extra logit columns are sliced off again in forward().
        """
        super().__init__()
        self.config = config
        self.num_heads = config.num_heads
        self.hidden_dim = config.hidden_dim
        self.seq_len = config.seq_len
        self.vocab_size = config.vocab_size
        self.num_layers = config.num_layers
        # For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see:
        # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
        padded_vocab_size = ((self.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
        if padded_vocab_size != self.vocab_size:
            logger.warning(
                f"Padding vocab_size from {self.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}"
            )
        self.token_emb = nn.Embedding(padded_vocab_size, self.hidden_dim)
        self.blocks = nn.ModuleList([Block(config, layer_idx) for layer_idx in range(self.num_layers)])
        self.lm_head = nn.Linear(self.hidden_dim, padded_vocab_size, bias=False)
        self.norm = RMSNorm(self.hidden_dim)
        # Per-layer learnable scalars (inspired by modded-nanogpt)
        # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
        # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
        # Separate parameters so they can have different optimizer treatment
        self.resid_lambdas = nn.Parameter(torch.ones(self.num_layers))  # fake init, real init in init_weights()
        self.x0_lambdas = nn.Parameter(torch.zeros(self.num_layers))  # fake init, real init in init_weights()
        # To support meta device initialization, the rotary tables are registered here as shape-only placeholders
        # and recomputed with real data in init_weights(). They are small, so over-compute 16X the training
        # sequence length; forward() asserts if a longer sequence ever arrives.
        self.rotary_seq_len = self.seq_len * 16
        cos, sin = RoPE.precompute(self.rotary_seq_len, config.head_dim)
        self.register_buffer("cos", cos, persistent=False)  # persistent=False means it's not saved to the checkpoint
        self.register_buffer("sin", sin, persistent=False)

    @torch.no_grad()
    def init_weights(self):
        """
        Initialize the full model in this one function for maximum clarity.

        token_emb (embedding): normal, std=1.0
        lm_head: normal, std=0.001
        for each block:
            attn.c_q: uniform, std=1/sqrt(hidden_dim)
            attn.c_k: uniform, std=1/sqrt(hidden_dim)
            attn.c_v: uniform, std=1/sqrt(hidden_dim)
            attn.c_proj: zeros
            mlp.c_fc: uniform, std=1/sqrt(hidden_dim)
            mlp.c_proj: zeros
        resid_lambdas: 1.0, x0_lambdas: 0.0
        cos / sin rotary buffers: recomputed on the device the placeholder buffers now live on.
        On CUDA, the token embedding is finally cast to bfloat16.
        """

        # Embedding and unembedding
        torch.nn.init.normal_(self.token_emb.weight, mean=0.0, std=1.0)
        torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)

        # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
        n_embd = self.hidden_dim
        s = 3**0.5 * n_embd**-0.5  # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
        for block in self.blocks:
            torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)  # weights use Uniform to avoid outliers
            torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
            torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
            torch.nn.init.zeros_(block.attn.c_proj.weight)  # projections are zero
            torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
            torch.nn.init.zeros_(block.mlp.c_proj.weight)

        # Per-layer scalars
        self.resid_lambdas.fill_(1.0)  # 1.0 => typical residual connections at init
        self.x0_lambdas.fill_(0.0)  # 0.0 => skip connection to input is disabled at init

        # Rotary embeddings: use config.head_dim exactly as __init__ did when registering the placeholders, so
        # the real tables always match the registered buffer shapes (hidden_dim // num_heads may differ).
        self.cos, self.sin = RoPE.precompute(self.rotary_seq_len, self.config.head_dim, device=self.cos.device)

        # Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
        if self.token_emb.weight.device.type == "cuda":
            self.token_emb.to(dtype=torch.bfloat16)

    def param_groups(
        self,
    ) -> dict[str, list[torch.nn.Parameter]]:
        """Split all parameters into named groups so optimizers can treat them differently.

        Returns:
            A dict with keys "matrix" (all Block parameters), "embedding", "lm_head", "resid_lambdas" and
            "x0_lambdas". The assert guarantees every model parameter lands in exactly one group.
        """
        matrix_params = list(self.blocks.parameters())
        embedding_params = list(self.token_emb.parameters())
        lm_head_params = list(self.lm_head.parameters())
        resid_params = [self.resid_lambdas]
        x0_params = [self.x0_lambdas]
        assert len(list(self.parameters())) == (
            len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params)
        )

        return {
            "matrix": matrix_params,
            "embedding": embedding_params,
            "lm_head": lm_head_params,
            "resid_lambdas": resid_params,
            "x0_lambdas": x0_params,
        }

    def forward(
        self,
        input: torch.Tensor,
    ) -> torch.Tensor:
        """Compute soft-capped float32 logits.

        Args:
            input: integer token ids of shape (B, T).

        Returns:
            Logits of shape (B, T, vocab_size) in float32, squashed into (-15, 15) by tanh soft-capping.
        """
        B, T = input.size()

        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
        assert T <= self.cos.size(1), (
            f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
        )
        assert input.device == self.cos.device, (
            f"Rotary embeddings and idx are on different devices: {input.device} != {self.cos.device}"
        )
        assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
        cos_sin = self.cos[:, :T], self.sin[:, :T]  # truncate cache to current sequence length

        # Forward the trunk of the Transformer
        x = self.token_emb(input)
        x = self.norm(x)
        x0 = x  # save initial normalized embedding for x0 residual
        for block, resid_lambda, x0_lambda in zip(self.blocks, self.resid_lambdas, self.x0_lambdas):
            x = resid_lambda * x + x0_lambda * x0
            x = block(x, cos_sin)
        x = self.norm(x)

        # Forward the lm_head (compute logits)
        softcap = 15  # smoothly cap the logits to the range [-softcap, softcap]
        logits = self.lm_head(x)  # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
        logits = logits[..., : self.vocab_size]  # slice to remove padding
        logits = logits.float()  # switch to fp32 for logit softcap and loss computation
        logits = softcap * torch.tanh(logits / softcap)  # squash the logits

        return logits
|
sarasa/models/utils.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class RMSNorm(torch.nn.RMSNorm):
|
|
5
|
+
# RMSNorm without affine parameters
|
|
6
|
+
def __init__(
|
|
7
|
+
self,
|
|
8
|
+
normalized_shape: int,
|
|
9
|
+
):
|
|
10
|
+
super().__init__(normalized_shape, eps=None, elementwise_affine=False)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class RoPE:
|
|
14
|
+
@staticmethod
|
|
15
|
+
def precompute(
|
|
16
|
+
seq_len: int,
|
|
17
|
+
head_dim: int,
|
|
18
|
+
device: torch.device = None,
|
|
19
|
+
base: float = 10000,
|
|
20
|
+
):
|
|
21
|
+
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
|
|
22
|
+
inv_freq = 1.0 / (base ** (channel_range / head_dim))
|
|
23
|
+
t = torch.arange(seq_len, dtype=torch.float32, device=device)
|
|
24
|
+
freqs = torch.outer(t, inv_freq)[None, :, None, :]
|
|
25
|
+
cos, sin = freqs.cos(), freqs.sin()
|
|
26
|
+
cos, sin = cos.bfloat16(), sin.bfloat16()
|
|
27
|
+
return cos, sin
|
|
28
|
+
|
|
29
|
+
@staticmethod
|
|
30
|
+
def apply(
|
|
31
|
+
x: torch.Tensor,
|
|
32
|
+
cos: torch.Tensor,
|
|
33
|
+
sin: torch.Tensor,
|
|
34
|
+
) -> torch.Tensor:
|
|
35
|
+
assert x.ndim == 4
|
|
36
|
+
x1, x2 = x.chunk(2, dim=-1)
|
|
37
|
+
y1 = x1 * cos + x2 * sin
|
|
38
|
+
y2 = x1 * (-sin) + x2 * cos
|
|
39
|
+
return torch.cat([y1, y2], 3)
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from sarasa.models import BaseModel
|
|
7
|
+
from sarasa.optimizers.utils import GroupedOptimizer
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclasses.dataclass
class AdamW:
    """
    Default optimizer: a single fused torch AdamW over every parameter group the model exposes.
    """

    lr: float = 1e-4
    weight_decay: float = 0.1
    betas: tuple[float, float] = (0.9, 0.95)

    def create(
        self,
        model: BaseModel,
    ) -> torch.optim.Optimizer:
        # Flatten the model's named groups into one parameter list, preserving group and parameter order.
        every_param = [param for group in model.param_groups().values() for param in group]
        # The learning rate is wrapped in a float32 tensor rather than passed as a Python float.
        return torch.optim.AdamW(
            every_param,
            lr=torch.tensor(self.lr, dtype=torch.float32),
            betas=self.betas,
            weight_decay=self.weight_decay,
            fused=True,
        )
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclasses.dataclass
class Muon:
    """
    Muon optimizer for the model's "matrix" parameter group, and fused AdamW for every other group.

    The two optimizers are combined into a single GroupedOptimizer.
    ``adam_lr`` defaults to ``lr`` when left as None.
    """

    lr: float = 1e-4
    weight_decay: float = 0.1
    momentum: float = 0.9

    adam_lr: float | None = None
    adam_betas: tuple[float, float] = (0.9, 0.95)
    adam_weight_decay: float = 0

    adjust_lr_fn: Literal["original", "match_rms_adamw"] = "match_rms_adamw"

    def __post_init__(self):
        # Fall back to `lr` only when adam_lr was not provided. A truthiness test (`adam_lr or lr`) would
        # silently replace an explicit adam_lr=0.0 (e.g. to freeze the non-matrix parameters).
        if self.adam_lr is None:
            self.adam_lr = self.lr

    def create(
        self,
        model: BaseModel,
    ) -> torch.optim.Optimizer:
        """Build Muon for the "matrix" group and AdamW for the remaining groups, wrapped as one optimizer."""
        param_groups = model.param_groups()

        muon = torch.optim.Muon(
            param_groups["matrix"],
            lr=self.lr,
            weight_decay=self.weight_decay,
            momentum=self.momentum,
            adjust_lr_fn=self.adjust_lr_fn,
        )

        # Every non-"matrix" group goes to AdamW, in the model's group order.
        adam_params = [param for name, group in param_groups.items() if name != "matrix" for param in group]
        adam = torch.optim.AdamW(
            adam_params,
            lr=self.adam_lr,
            betas=self.adam_betas,
            weight_decay=self.adam_weight_decay,
            fused=True,
        )

        return GroupedOptimizer(muon, adam)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class GroupedOptimizer(torch.optim.Optimizer):
|
|
5
|
+
def __init__(
|
|
6
|
+
self,
|
|
7
|
+
*optimizers: torch.optim.Optimizer,
|
|
8
|
+
):
|
|
9
|
+
super().__init__(sum([optim.param_groups for optim in optimizers], []), {})
|
|
10
|
+
self.optimizers = optimizers
|
|
11
|
+
|
|
12
|
+
def step(self) -> None:
|
|
13
|
+
for optim in self.optimizers:
|
|
14
|
+
optim.step()
|
|
15
|
+
|
|
16
|
+
def zero_grad(
|
|
17
|
+
self,
|
|
18
|
+
set_to_none: bool = True,
|
|
19
|
+
) -> None:
|
|
20
|
+
for optim in self.optimizers:
|
|
21
|
+
optim.zero_grad(set_to_none=set_to_none)
|
|
22
|
+
|
|
23
|
+
def state_dict(self) -> dict:
|
|
24
|
+
return super().state_dict()
|
|
25
|
+
|
|
26
|
+
def load_state_dict(self, state_dict: dict) -> None:
|
|
27
|
+
super().load_state_dict(state_dict)
|
sarasa/trainer.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import os
|
|
3
|
+
import time
|
|
4
|
+
from collections.abc import Iterable
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.distributed as dist
|
|
8
|
+
from loguru import logger
|
|
9
|
+
from torch.distributed.elastic.multiprocessing.errors import record
|
|
10
|
+
|
|
11
|
+
from sarasa.activation_checkpoint import apply_op_sac
|
|
12
|
+
from sarasa.checkpoint import Checkpointer
|
|
13
|
+
from sarasa.config import Config
|
|
14
|
+
from sarasa.metrics import MetricsProcessor
|
|
15
|
+
from sarasa.utils import GarbageCollector, apply_distributed, init_distributed, set_dtype, update_timeout, world_size
|
|
16
|
+
|
|
17
|
+
# Target label value excluded from training: passed as `ignore_index` to CrossEntropyLoss, and skipped
# when counting the valid tokens that normalize the summed loss in Trainer.train_step.
IGNORE_INDEX = -100
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Trainer:
    """Build every training component from a ``Config`` and run the optimization loop.

    Construction does the following, in order:
    - seed the RNGs and select the accelerator;
    - initialize the process group and build data loaders / tokenizer;
    - create the model on the meta device;
    - apply (S)AC, compilation and distributed wrapping;
    - materialize and initialize the weights;
    - create the optimizer, LR scheduler, metrics processor and (optionally) a checkpointer.
    """

    @record
    def __init__(
        self,
        config: Config,
    ) -> None:
        self.config = config
        logger.info(f"Initializing Trainer with config: {self.config}")

        # set seed
        torch.manual_seed(config.seed)
        os.environ["PYTHONHASHSEED"] = str(config.seed % 2**32)

        # setup device
        torch.accelerator.set_device_index(int(os.environ.get("LOCAL_RANK", 0)))
        self.device = torch.accelerator.current_accelerator(check_available=True)

        self.gc = GarbageCollector(config.train.gc_freq)

        # setup distributed
        init_distributed(config.distributed.backend, config.distributed.init_timeout_seconds)

        # Validate the batch configuration before any expensive setup. Each optimizer step consumes
        # grad_accum_steps micro-batches per rank; 0 would make train_step stack an empty loss list.
        samples_per_micro_step = config.train.local_batch_size * world_size()
        self.grad_accum_steps = config.train.global_batch_size // samples_per_micro_step
        if self.grad_accum_steps < 1:
            raise ValueError(
                f"global_batch_size ({config.train.global_batch_size}) must be at least "
                f"local_batch_size * world_size ({samples_per_micro_step})"
            )
        if config.train.global_batch_size % samples_per_micro_step != 0:
            logger.warning(
                f"global_batch_size ({config.train.global_batch_size}) is not divisible by "
                f"local_batch_size * world_size ({samples_per_micro_step}); effective global batch size is "
                f"{self.grad_accum_steps * samples_per_micro_step}"
            )
        logger.info(f"Gradient accumulation step is set to: {self.grad_accum_steps}")

        # setup data and tokenizer -> use vocab size for model setup
        data = config.data.create(batch_size=config.train.local_batch_size)
        self.data_loader = data["train_loader"]  # setup data loader
        self.val_loader = data.get("val_loader", None)  # setup eval data loader
        self.tokenizer = data["tokenizer"]  # setup tokenizer

        vocab_size = len(self.tokenizer)
        self.config.model.vocab_size = vocab_size

        # todo: support other loss functions
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX, reduction="sum")

        # setup model, optimizer, lr scheduler
        with torch.device("meta"), set_dtype(getattr(torch, config.train.dtype)):
            self.model = self.config.model.create()
            num_params, flops_per_token = self.model.num_params_flops
            model_size = num_params / 1e9
            model_size, unit = (num_params / 1e6, "M") if model_size < 1 else (model_size, "B")
            logger.info(f"Model created with {model_size:.2f}{unit} parameters")

        # following torchtitan, (S)AC -> compilation -> distributed wrapping
        if config.train.use_sac:
            logger.info("Applying Selective Activation Checkpointing (SAC)")
            for i, block in enumerate(self.model.blocks):
                self.model.blocks[i] = apply_op_sac(block)

        if config.train.compile:
            logger.info("Compiling the model")
            for block in self.model.blocks:
                block.compile(fullgraph=True)
            self.model.compile(dynamic=False)
            self.loss_fn.compile()

        if world_size() > 1:
            apply_distributed(
                config.distributed,
                self.model,
                device=self.device,
                compile=config.train.compile,
            )

        self.model.to_empty(device=self.device)
        self.model.init_weights()

        self.optimizer = self.config.optim.create(self.model)
        self.lr_scheduler = self.config.lr_scheduler.create(self.optimizer, config.train.steps)

        # setup metrics and checkpointer
        # todo: configure num_flops_per_token
        self.metrics_processor = MetricsProcessor(config, self.device, flops_per_token)
        self.checkpointer = Checkpointer(config, self.model) if config.checkpoint.save_freq > 0 else None

        dev_mem_stats = self.metrics_processor.device_mem_monitor.get_peak_stats()
        logger.info(
            f"{self.device.type.upper()} memory: {dev_mem_stats.max_reserved_gib:.2f} GiB for model initialization"
        )

        self.step = 0

        self.amp_context = contextlib.nullcontext()
        if config.distributed.name != "fsdp":
            self.amp_context = torch.autocast(device_type=self.device.type, dtype=getattr(torch, config.train.dtype))

        # todo: setup profiler context
        self.profile_context = contextlib.nullcontext()

        if config.train.use_fa4:
            logger.info("Using FA4 flash attention")
            try:
                torch.nn.attention.activate_flash_attention_impl("FA4")
            except Exception as e:
                logger.warning(
                    f"Failed to activate FA4 flash attention: {e}. Install sarasa with `flash_attn` extra for better performance."
                )

    def __del__(self) -> None:
        # cleanup distributed (best effort: failures are only logged)
        if world_size() > 1:
            try:
                dist.destroy_process_group()
            except Exception as e:
                logger.warning(f"Failed to destroy process group: {e}")

    @record
    def train(self):
        """Run up to ``config.train.steps`` optimizer steps.

        Training stops early, with a warning, if the data loader runs out.
        """
        logger.info("Starting training...")

        self.model.train()
        with self.profile_context:
            data_iter = self.batch_generator(self.data_loader)
            for _ in range(self.config.train.steps):
                self.step += 1
                self.gc.collect(self.step)
                try:
                    self.train_step(data_iter)
                except StopIteration:
                    logger.warning("Data loader exhausted during training.")
                    break

                if self.checkpointer is not None:
                    self.checkpointer.save(self.step)

                if self.config.train.val_freq > 0 and self.step % self.config.train.val_freq == 0:
                    self.evaluate()

                if world_size() > 1 and self.step == 1:
                    update_timeout(self.config.distributed.train_timeout_seconds, self.device)

        logger.info("Training completed.")

    def batch_generator(
        self,
        data_iter: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]],
    ) -> Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]:
        """Yield ``(input_dict, target)`` batches, recording token counts and data-loading time.

        The generator finishes normally when ``data_iter`` is exhausted, so a caller's ``next()`` raises
        StopIteration.
        """
        data_iter = iter(data_iter)
        while True:
            begin = time.perf_counter()
            try:
                batch = next(data_iter)
            except StopIteration:
                # PEP 479: a StopIteration escaping a generator body is re-raised as RuntimeError, which
                # would bypass the `except StopIteration` in train(). End the generator explicitly instead.
                return
            input_dict, target = batch
            self.metrics_processor.ntokens_since_last_log += target.numel()
            self.metrics_processor.data_load_times.append(time.perf_counter() - begin)
            yield input_dict, target

    def train_step(
        self,
        batch_iter: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]],
    ) -> None:
        """Run one optimizer step over ``grad_accum_steps`` micro-batches and log metrics when due.

        Each micro-batch loss is a summed cross-entropy divided by the number of non-ignored target tokens
        across all micro-batches and ranks. Summing the micro-batch losses over ranks therefore gives the
        global per-token mean.

        Raises:
            StopIteration: if ``batch_iter`` runs out before ``grad_accum_steps`` micro-batches are read.
        """
        self.optimizer.zero_grad()

        micro_batches = []
        valid_tokens = torch.tensor(0, dtype=torch.long)
        for _ in range(self.grad_accum_steps):
            input_dict, target = next(batch_iter)
            valid_tokens += (target != IGNORE_INDEX).sum()
            micro_batches.append((input_dict, target))

        valid_tokens = valid_tokens.to(self.device)
        local_valid_tokens = valid_tokens
        if world_size() > 1:
            local_valid_tokens = valid_tokens.clone()  # keep this rank's count for the max-loss metric
            dist.all_reduce(valid_tokens, op=dist.ReduceOp.SUM)

        non_blocking = self.device.type == "cuda"
        losses = []
        for input_dict, target in micro_batches:
            input_dict = {k: v.to(self.device, non_blocking=non_blocking) for k, v in input_dict.items()}
            target = target.to(self.device, non_blocking=non_blocking)

            with self.amp_context:
                pred = self.model(**input_dict)
                loss = self.loss_fn(pred.flatten(0, 1), target.flatten(0, 1)) / valid_tokens

            del pred
            loss.backward()
            losses.append(loss.detach())

        # clip_grad_norm_ returns the total gradient norm measured before clipping; reuse it for logging.
        grad_norm = None
        if self.config.train.grad_clip is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.config.train.grad_clip, foreach=self.device.type == "cuda"
            )

        if self.checkpointer is not None:
            self.checkpointer.wait_for_staging()

        self.optimizer.step()
        self.lr_scheduler.step()

        loss = torch.stack(losses).sum()

        if not self.metrics_processor.should_log(self.step):
            return

        if world_size() > 1:
            avg_loss = loss.clone()
            dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
            # `loss` is this rank's share of the global mean (normalized by the global token count).
            # Rescale it to the rank's own per-token loss before taking the max across ranks.
            max_loss = loss * valid_tokens / local_valid_tokens.clamp(min=1)
            dist.all_reduce(max_loss, op=dist.ReduceOp.MAX)
        else:
            avg_loss = max_loss = loss

        if grad_norm is None:
            # Norm of the gradients (not the parameters) when no clipping pass computed it already.
            with torch.no_grad():
                grads = [p.grad for p in self.model.parameters() if p.grad is not None]
                grad_norm = torch.nn.utils.get_total_norm(grads, foreach=self.device.type == "cuda")

        lr = self.lr_scheduler.get_last_lr()[0]
        self.metrics_processor.log(
            self.step,
            global_avg_loss=avg_loss.item(),
            global_max_loss=max_loss.item(),
            extra_metrics={
                "grad_norm": grad_norm.item() if grad_norm >= 0 else float("nan"),
                "lr": lr,
            },
        )

    def evaluate(self):
        raise NotImplementedError

    def evaluation_step(
        self,
        batch_iter: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]],
    ) -> None:
        raise NotImplementedError
|