sarasa 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sarasa/__init__.py +2 -0
- sarasa/activation_checkpoint.py +81 -0
- sarasa/checkpoint.py +112 -0
- sarasa/config.py +279 -0
- sarasa/data/__init__.py +36 -0
- sarasa/data/hf_datasets.py +115 -0
- sarasa/data/tokenizer.py +63 -0
- sarasa/metrics.py +294 -0
- sarasa/models/__init__.py +95 -0
- sarasa/models/attention.py +84 -0
- sarasa/models/llama3.py +129 -0
- sarasa/models/nanochat_gpt.py +192 -0
- sarasa/models/utils.py +39 -0
- sarasa/optimizers/__init__.py +77 -0
- sarasa/optimizers/utils.py +27 -0
- sarasa/trainer.py +244 -0
- sarasa/utils.py +163 -0
- sarasa-0.0.2.dist-info/METADATA +138 -0
- sarasa-0.0.2.dist-info/RECORD +21 -0
- sarasa-0.0.2.dist-info/WHEEL +4 -0
- sarasa-0.0.2.dist-info/licenses/LICENSE +201 -0
sarasa/metrics.py
ADDED
@@ -0,0 +1,294 @@
import dataclasses
import subprocess
import time
from pathlib import Path
from typing import Any

import torch
from loguru import logger
from torch._utils import _get_device_module

from sarasa.config import Config
from sarasa.utils import rank


# ported from torchtitan
# hardcoded BF16 type peak flops for NVIDIA A100, H100, H200, B200 GPU and AMD MI250, MI300X, MI325X, MI355X and Intel PVC
def get_peak_flops(device_name: str) -> float:
    try:
        # Run the lspci command and capture the output
        result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)
        # Filter the output for lines containing both "NVIDIA" and "H100"
        filtered_lines = [line for line in result.stdout.splitlines() if "NVIDIA" in line and "H100" in line]
        # Join all filtered lines into a single string
        device_name = " ".join(filtered_lines) or device_name
    except FileNotFoundError as e:
        logger.warning(f"Error running lspci: {e}, fallback to use device_name")
    if "A100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/a100/
        return 312e12
    elif "H100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h100/
        # NOTE: Specifications are one-half lower without sparsity.
        if "NVL" in device_name:
            return 835e12
        elif "PCIe" in device_name:
            return 756e12
        else:  # for H100 SXM and other variants
            return 989e12
    elif "H200" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h200/
        return 989e12
    elif "B200" in device_name:
        # data from https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703
        return 2.25e15
    elif "MI355X" in device_name:
        # MI355X data from https://www.amd.com/en/products/accelerators/instinct/mi350/mi355x.html
        return 2500e12
    elif "MI300X" in device_name or "MI325X" in device_name:
        # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
        # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
        return 1300e12
    elif "MI250X" in device_name:
        # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
        return 191.5e12
    elif "Data Center GPU Max 1550" in device_name:
        # Also known as Ponte Vecchio (PVC).
        # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
        # Dot Product Accumulate Systolic (DPAS):
        # - Freq: 1300MHz
        # - #ops: 512
        # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
        # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6
    elif "l40s" in device_name:
        # data from: "https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413"
        return 362e12

    else:  # for other GPU types, assume A100
        logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
        return 312e12


@dataclasses.dataclass(slots=True)
class DevMemStats:
    max_active_gib: float
    max_active_perc: float
    max_reserved_gib: float
    max_reserved_perc: float
    num_alloc_retries: int
    num_ooms: int


class DeviceMemoryMonitor:
    def __init__(
        self,
        device: torch.device,
    ) -> None:
        self.device = device
        try:
            _, self.total_mem = torch.accelerator.get_memory_info(self.device)
            self.device_name = _get_device_module(self.device.type).get_device_name(self.device)
        except RuntimeError:
            if self.device.type == "mps":
                self.total_mem = torch.mps.recommended_max_memory()
                self.device_name = "Apple Silicon GPU"
            else:
                raise NotImplementedError(f"Device memory monitor not implemented for device type: {self.device.type}")

        self.reset_peak_stats()
        try:
            torch.accelerator.empty_cache()
        except RuntimeError:
            logger.error(f"Failed to empty cache for device type: {self.device.type}")

    @staticmethod
    def to_gib(bytes: int) -> float:
        return bytes / (1024**3)

    def reset_peak_stats(self) -> None:
        try:
            torch.accelerator.reset_peak_memory_stats(self.device)
        except RuntimeError:
            logger.error(f"Failed to reset peak memory stats for device type: {self.device.type}")

    def get_peak_stats(self) -> DevMemStats:
        try:
            info = torch.accelerator.memory_stats(self.device)
        except RuntimeError:
            logger.error(f"Failed to get peak memory stats for device type: {self.device.type}")
            info = {}

        max_active = info.get("active_bytes.all.peak", -1)
        max_reserved = info.get("reserved_bytes.all.peak", -1)
        num_retries = info.get("num_alloc_retries", -1)
        num_ooms = info.get("num_ooms", -1)

        if num_retries > 0:
            logger.warning(f"{num_retries} {self.device.type.upper()} memory allocation retries.")
        if num_ooms > 0:
            logger.warning(f"{num_ooms} {self.device.type.upper()} OOM errors thrown.")

        return DevMemStats(
            max_active_gib=self.to_gib(max_active),
            max_active_perc=max_active / self.total_mem * 100,
            max_reserved_gib=self.to_gib(max_reserved),
            max_reserved_perc=max_reserved / self.total_mem * 100,
            num_alloc_retries=num_retries,
            num_ooms=num_ooms,
        )


class BaseReporter:
    def config(
        self,
        config: dict[str, Any],
    ) -> None:
        raise NotImplementedError()

    def log(
        self,
        metrics: dict[str, Any],
        step: int,
    ) -> None:
        raise NotImplementedError()

    def close(self) -> None:
        raise NotImplementedError()


class TensorboardReporter(BaseReporter):
    def __init__(
        self,
        log_dir: Path,
    ) -> None:
        from torch.utils.tensorboard import SummaryWriter

        self.writer = SummaryWriter(log_dir=log_dir, max_queue=1000)

        logger.info(f"TensorBoard log is available at {log_dir}")

    def config(
        self,
        config: dict[str, Any],
    ) -> None:
        for k, v in config.items():
            self.writer.add_text(f"config/{k}", str(v))

    def log(
        self,
        metrics: dict[str, float],
        step: int,
    ) -> None:
        for k, v in metrics.items():
            self.writer.add_scalar(k, v, step)

    def close(self) -> None:
        self.writer.close()


class MetricsProcessor:
    def __init__(
        self,
        config: Config,
        device: torch.device,
        flops_per_token: int,
    ) -> None:
        self.reporters = []
        if config.metrics.all_node or rank() == 0:
            if config.metrics.use_tensorboard:
                log_dir = config.output_dir / "tensorboard" if config.output_dir else Path("./tensorboard")
                self.reporters.append(TensorboardReporter(log_dir=log_dir))

        for reporter in self.reporters:
            reporter.config(config=dataclasses.asdict(config))

        self.device_mem_monitor = DeviceMemoryMonitor(device)
        self.log_freq = config.metrics.log_freq
        self.time_last_log = time.perf_counter()
        gpu_peak_flops = get_peak_flops(self.device_mem_monitor.device_name)
        logger.info(f"Detected device: {self.device_mem_monitor.device_name}, Peak FLOPS: {gpu_peak_flops}")
        self.gpu_peak_flops = gpu_peak_flops
        self.ntokens_since_last_log = 0
        self.flops_per_token = flops_per_token
        self.data_load_times: list[float] = []
        self.reset()

    def should_log(
        self,
        step: int,
    ) -> bool:
        return step == 1 or step % self.log_freq == 0

    def log(
        self,
        step: int,
        global_avg_loss: float,
        global_max_loss: float,
        extra_metrics: dict[str, float] | None = None,
    ) -> None:
        time_delta = time.perf_counter() - self.time_last_log
        device_mem_stats = self.device_mem_monitor.get_peak_stats()
        time_ete = time_delta / self.log_freq
        time_data_load = sum(self.data_load_times) / len(self.data_load_times) if self.data_load_times else 0.0
        time_data_load_perc = 100 * time_data_load / time_ete if time_ete > 0 else 0.0

        metrics = {
            "loss/avg": global_avg_loss,
            "loss/max": global_max_loss,
            "memory/max_active(GiB)": device_mem_stats.max_active_gib,
            "memory/max_active(%)": device_mem_stats.max_active_perc,
            "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib,
            "memory/max_reserved(%)": device_mem_stats.max_reserved_perc,
            "memory/num_alloc_retries": device_mem_stats.num_alloc_retries,
            "memory/num_ooms": device_mem_stats.num_ooms,
            "time/end-to-end(s)": time_ete,
            "time/data_load(s)": time_data_load,
            "time/data_load(%)": time_data_load_perc,
        }

        log = (
            f"[Step {step:>10}] loss: {global_avg_loss:.4f}, memory: {device_mem_stats.max_reserved_gib:.2f} GiB, "
            f"time(s): {time_ete:.2f}sec (data load ratio: {time_data_load_perc:.1f}%)"
        )

        if extra_metrics is not None:
            metrics.update(extra_metrics)

        if self.flops_per_token > 0:
            tps = self.ntokens_since_last_log / time_delta
            mfu = 100 * self.flops_per_token * tps / self.gpu_peak_flops
            tflops = self.flops_per_token * tps / 1e12

            metrics.update({
                "throughput(tps)": tps,
                "tflops": tflops,
                "mfu(%)": mfu,
            })
            log += f", tflops: {tflops:.2f}, mfu: {mfu:.2f}%"

        for reporter in self.reporters:
            reporter.log(metrics, step)

        logger.info(log)

        self.reset()

    def reset(self) -> None:
        self.ntokens_since_last_log = 0
        self.data_load_times.clear()
        self.time_last_log = time.perf_counter()
        self.device_mem_monitor.reset_peak_stats()

    def val_log(
        self,
        step: int,
        val_loss: float,
        extra_metrics: dict[str, float] | None = None,
    ) -> None:
        raise NotImplementedError()

    def close(self) -> None:
        for reporter in self.reporters:
            reporter.close()
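For context on the throughput figures that MetricsProcessor.log reports, the tps/TFLOPS/MFU values reduce to a few lines of arithmetic. The sketch below is illustrative only: the token count, elapsed time, model size, and the 989e12 H100 SXM peak are assumed numbers, not values from a real run.

# Illustrative mirror of the throughput math in MetricsProcessor.log (assumed numbers)
flops_per_token = 6 * 350e6        # roughly a 350M-parameter dense model, per BaseModel.num_params_flops
ntokens_since_last_log = 524_288   # tokens processed since the previous log step
time_delta = 4.0                   # seconds since the previous log step
gpu_peak_flops = 989e12            # H100 SXM BF16 peak, as returned by get_peak_flops

tps = ntokens_since_last_log / time_delta            # tokens per second
tflops = flops_per_token * tps / 1e12                # achieved TFLOPS
mfu = 100 * flops_per_token * tps / gpu_peak_flops   # model FLOPS utilization, in percent
print(f"tps: {tps:.0f}, tflops: {tflops:.2f}, mfu: {mfu:.2f}%")  # ~131072 tps, ~275 TFLOPS, ~27.8% MFU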
sarasa/models/__init__.py
ADDED
@@ -0,0 +1,95 @@
from __future__ import annotations

import abc
import dataclasses
from typing import Literal

import torch
from loguru import logger
from torch import nn


@dataclasses.dataclass
class ModelConfig:
    name: Literal["nanochat_gpt", "llama3"] = "nanochat_gpt"
    num_layers: int = 12
    head_dim: int = 128

    num_heads: int | None = None  # inferred later if None
    num_kv_heads: int | None = None  # inferred later if None
    hidden_dim: int | None = None  # inferred later if None
    vocab_size: int | None = None  # set later based on tokenizer
    seq_len: int | None = None  # set later based on data config
    qk_norm: bool = False  # whether to use RMSNorm on q/k

    def __post_init__(self):
        # infer hidden_dim, num_heads, num_kv_heads if not provided using the rules presented in nanochat
        self.hidden_dim = self.hidden_dim or (self.num_layers * 64 + self.head_dim - 1) // self.head_dim * self.head_dim
        self.num_heads = self.num_heads or self.hidden_dim // self.head_dim
        self.num_kv_heads = self.num_kv_heads or self.num_heads

        # sanity checks
        assert self.hidden_dim % self.head_dim == 0
        assert self.head_dim * self.num_heads == self.hidden_dim
        assert self.num_kv_heads <= self.num_heads and self.num_heads % self.num_kv_heads == 0

    def create(self) -> BaseModel:
        if self.vocab_size is None or self.seq_len is None:
            raise ValueError("vocab_size and seq_len must be set before creating the model")

        match self.name:
            case "nanochat_gpt":
                from .nanochat_gpt import GPT

                if not self.qk_norm:
                    logger.warning("nanochat_gpt model without qk_norm is not recommended")

                return GPT(self)

            case "llama3":
                from .llama3 import Llama3

                if self.qk_norm:
                    logger.warning("llama3 model with qk_norm is not standard")

                return Llama3(self)

            case _:
                raise ValueError(f"Unknown model name: {self.name}")


class BaseModel(nn.Module, abc.ABC):
    # Common interface for all models in Sarasa

    blocks: list[nn.Module]  # TF blocks
    config: ModelConfig

    @abc.abstractmethod
    @torch.no_grad()
    def init_weights(self) -> None:
        # Actual initialization of model weights
        pass

    @abc.abstractmethod
    def param_groups(self) -> dict[str, list[nn.Parameter]]:
        # Return parameter groups for optimizer
        pass

    @property
    def num_params_flops(
        self,
    ) -> tuple[int, int]:
        # Return number of parameters and FLOPs per token (for dense model)
        config = self.config

        # for tied embeddings, num_params -= num_params_emb
        num_params = sum(p.numel() for p in self.parameters())
        num_params_emb = self.token_emb.weight.numel()

        # If forward pass has 1 matmul, then backward pass has 2 matmuls
        # Each self-attention has 2 matmuls
        num_flops_per_token = 6 * (
            (num_params - num_params_emb) + (config.num_layers * config.num_heads * config.head_dim * config.seq_len)
        )

        return num_params, num_flops_per_token
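The dimension inference in ModelConfig.__post_init__ is easy to verify by hand. A minimal sketch with the default settings (num_layers=12, head_dim=128); the numbers are worked out only to illustrate the rounding rule:

# Worked example of the nanochat-style inference rules in ModelConfig.__post_init__
num_layers, head_dim = 12, 128
hidden_dim = (num_layers * 64 + head_dim - 1) // head_dim * head_dim  # round 12 * 64 = 768 up to a multiple of 128 -> 768
num_heads = hidden_dim // head_dim                                    # 768 // 128 -> 6
num_kv_heads = num_heads                                              # defaults to plain multi-head attention
assert (hidden_dim, num_heads, num_kv_heads) == (768, 6, 6)

The same fields feed BaseModel.num_params_flops, which estimates 6 * ((num_params - num_params_emb) + num_layers * num_heads * head_dim * seq_len) FLOPs per token: the usual 6N rule for dense matmuls (forward plus backward) plus an attention term.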
sarasa/models/attention.py
ADDED
@@ -0,0 +1,84 @@
import torch
from torch import nn
from torch.nn import functional as F

from sarasa.models import ModelConfig
from sarasa.models.utils import RMSNorm, RoPE


class SDPAttention(nn.Module):
    def __init__(
        self,
        is_causal: bool,
        enable_gqa: bool,
    ):
        super().__init__()
        self.is_causal = is_causal
        self.enable_gqa = enable_gqa

        if nn.attention.current_flash_attention_impl() == "FA4":
            self.sdpa_backends = nn.attention.SDPBackend.FLASH_ATTENTION
        else:
            self.sdpa_backends = [
                nn.attention.SDPBackend.CUDNN_ATTENTION,
                nn.attention.SDPBackend.FLASH_ATTENTION,
                nn.attention.SDPBackend.MATH,
            ]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        with nn.attention.sdpa_kernel(self.sdpa_backends):
            return F.scaled_dot_product_attention(
                query,
                key,
                value,
                is_causal=self.is_causal,
                enable_gqa=self.enable_gqa,
            )


class CausalSelfAttention(nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        layer_idx: int | None = None,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.hidden_dim = config.hidden_dim
        self.head_dim = self.hidden_dim // self.num_heads
        self.c_q = nn.Linear(self.hidden_dim, self.num_heads * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
        self.qk_norm = RMSNorm(self.head_dim) if config.qk_norm else nn.Identity()

        # todo: support varlen etc and kv caching
        self.attn = SDPAttention(is_causal=True, enable_gqa=self.num_heads != self.num_kv_heads)

    def forward(
        self,
        x: torch.Tensor,
        cos_sin: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        B, T, C = x.size()

        # Project the input to get queries, keys, and values
        q = self.c_q(x).view(B, T, self.num_heads, self.head_dim)
        k = self.c_k(x).view(B, T, self.num_kv_heads, self.head_dim)
        v = self.c_v(x).view(B, T, self.num_kv_heads, self.head_dim)

        # Apply Rotary Embeddings to queries and keys to get relative positional encoding
        cos, sin = cos_sin
        q, k = RoPE.apply(q, cos, sin), RoPE.apply(k, cos, sin)
        q, k = self.qk_norm(q), self.qk_norm(k)
        y = self.attn(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))  # (B, n_head, T, head_dim)
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y
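A minimal shape check for the grouped-query SDPA path used above, with RoPE, qk_norm, and the linear projections omitted; the head counts and sizes are arbitrary illustration values, and enable_gqa assumes a PyTorch version that supports it (2.5 or later):

# Toy GQA call mirroring the (B, n_head, T, head_dim) layout that SDPAttention receives
import torch
from torch.nn import functional as F

B, T, num_heads, num_kv_heads, head_dim = 2, 16, 8, 2, 64
q = torch.randn(B, num_heads, T, head_dim)     # one query tensor slice per attention head
k = torch.randn(B, num_kv_heads, T, head_dim)  # fewer key/value heads -> grouped-query attention
v = torch.randn(B, num_kv_heads, T, head_dim)

# enable_gqa broadcasts the kv heads across groups of query heads inside SDPA
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
assert y.shape == (B, num_heads, T, head_dim)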
sarasa/models/llama3.py
ADDED
@@ -0,0 +1,129 @@
import torch
from torch import nn
from torch.nn import functional as F

from sarasa.models import BaseModel, ModelConfig
from sarasa.models.attention import CausalSelfAttention
from sarasa.models.utils import RMSNorm, RoPE


class MLP(nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        multiple_of: int,
        ffn_dim_multiplier: float | None,
    ):
        super().__init__()
        hidden_dim = int(8 * config.hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(config.hidden_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, config.hidden_dim, bias=False)
        self.w3 = nn.Linear(config.hidden_dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Block(nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        layer_idx: int,
        multiple_of: int,
        ffn_dim_multiplier: float | None,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.attention = CausalSelfAttention(config)
        self.mlp = MLP(config, multiple_of, ffn_dim_multiplier)
        self.norm = RMSNorm(config.hidden_dim)

    def forward(
        self,
        x: torch.Tensor,
        cos_sin: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        x = x + self.attention(self.norm(x), cos_sin)
        x = x + self.mlp(self.norm(x))
        return x


class Llama3(BaseModel):
    def __init__(
        self,
        config: ModelConfig,
        multiple_of: int = 1024,
        ffn_dim_multiplier: float | None = 1.4,
    ):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.max_seq_len = config.seq_len * 16
        self.head_dim = config.head_dim
        cos, sin = RoPE.precompute(self.max_seq_len, config.head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.blocks = nn.ModuleList([
            Block(config, layer_idx, multiple_of, ffn_dim_multiplier) for layer_idx in range(config.num_layers)
        ])
        self.norm = RMSNorm(config.hidden_dim)
        self.output = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)

    @torch.no_grad()
    def init_weights(self) -> None:
        self.cos, self.sin = RoPE.precompute(self.max_seq_len, self.head_dim, device=self.cos.device)
        torch.nn.init.normal_(self.token_emb.weight)
        for block in self.blocks:
            block: Block
            init_std = 0.02 / (2 * (block.layer_idx + 1)) ** 0.5

            nn.init.trunc_normal_(block.attention.c_q.weight, std=0.02)
            nn.init.trunc_normal_(block.attention.c_k.weight, std=0.02)
            nn.init.trunc_normal_(block.attention.c_v.weight, std=0.02)
            nn.init.trunc_normal_(block.attention.c_proj.weight, std=init_std)

            nn.init.trunc_normal_(block.mlp.w1.weight, std=0.02)
            nn.init.trunc_normal_(block.mlp.w2.weight, std=init_std)
            nn.init.trunc_normal_(block.mlp.w3.weight, std=init_std)

        final_out_std = self.output.weight.shape[-1] ** -0.5
        cutoff_factor = 3
        nn.init.trunc_normal_(
            self.output.weight,
            mean=0.0,
            std=final_out_std,
            a=-cutoff_factor * final_out_std,
            b=cutoff_factor * final_out_std,
        )

    def param_groups(self) -> dict[str, list[nn.Parameter]]:
        matrix_params = list(self.blocks.parameters())
        embedding_params = list(self.token_emb.parameters())
        lm_head_params = list(self.output.parameters())
        assert len(list(self.parameters())) == (len(matrix_params) + len(embedding_params) + len(lm_head_params))

        return {
            "matrix": matrix_params,
            "embedding": embedding_params,
            "lm_head": lm_head_params,
        }

    def forward(
        self,
        input: torch.Tensor,
    ) -> torch.Tensor:
        B, T = input.size()
        x = self.token_emb(input)  # (B, T, C)
        cos_sin = self.cos[:, :T], self.sin[:, :T]

        for block in self.blocks:
            x = block(x, cos_sin)

        x = self.norm(x)
        logits = self.output(x)
        return logits
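To round out the section, a minimal end-to-end sketch of building this model through ModelConfig.create and running a forward pass; it assumes the sarasa package is importable, that RoPE.precompute defaults to CPU tensors, and it uses toy sizes chosen only for illustration:

# Illustrative usage sketch (toy sizes; not taken from the package's own examples)
import torch
from sarasa.models import ModelConfig

cfg = ModelConfig(name="llama3", num_layers=2, head_dim=64, vocab_size=1024, seq_len=128)
model = cfg.create()   # builds Llama3; hidden_dim=128 and num_heads=2 are inferred in __post_init__
model.init_weights()   # truncated-normal init with depth-scaled std for residual projections

tokens = torch.randint(0, cfg.vocab_size, (2, cfg.seq_len))
logits = model(tokens)  # (batch, seq_len, vocab_size)
print(logits.shape)     # torch.Size([2, 128, 1024])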