sarasa-0.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sarasa/metrics.py ADDED
@@ -0,0 +1,294 @@
+ import dataclasses
+ import subprocess
+ import time
+ from pathlib import Path
+ from typing import Any
+
+ import torch
+ from loguru import logger
+ from torch._utils import _get_device_module
+
+ from sarasa.config import Config
+ from sarasa.utils import rank
+
+
+ # ported from torchtitan
+ # hardcoded BF16 peak FLOPS for NVIDIA A100, H100, H200, B200, L40S GPUs, AMD MI250X, MI300X, MI325X, MI355X, and Intel PVC
+ def get_peak_flops(device_name: str) -> float:
+     try:
+         # Run the lspci command and capture the output
+         result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)
+         # Filter the output for lines containing both "NVIDIA" and "H100"
+         filtered_lines = [line for line in result.stdout.splitlines() if "NVIDIA" in line and "H100" in line]
+         # Join all filtered lines into a single string
+         device_name = " ".join(filtered_lines) or device_name
+     except FileNotFoundError as e:
+         logger.warning(f"Error running lspci: {e}, falling back to device_name")
+     if "A100" in device_name:
+         # data from https://www.nvidia.com/en-us/data-center/a100/
+         return 312e12
+     elif "H100" in device_name:
+         # data from https://www.nvidia.com/en-us/data-center/h100/
+         # NOTE: dense (without sparsity) figures; the datasheet values with sparsity are 2x higher.
+         if "NVL" in device_name:
+             return 835e12
+         elif "PCIe" in device_name:
+             return 756e12
+         else:  # for H100 SXM and other variants
+             return 989e12
+     elif "H200" in device_name:
+         # data from https://www.nvidia.com/en-us/data-center/h200/
+         return 989e12
+     elif "B200" in device_name:
+         # data from https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703
+         return 2.25e15
+     elif "MI355X" in device_name:
+         # MI355X data from https://www.amd.com/en/products/accelerators/instinct/mi350/mi355x.html
+         return 2500e12
+     elif "MI300X" in device_name or "MI325X" in device_name:
+         # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
+         # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
+         return 1300e12
+     elif "MI250X" in device_name:
+         # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
+         return 191.5e12
+     elif "Data Center GPU Max 1550" in device_name:
+         # Also known as Ponte Vecchio (PVC).
+         # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
+         # Dot Product Accumulate Systolic (DPAS):
+         # - Freq: 1300MHz
+         # - #ops: 512
+         # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
+         # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
+         max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
+         return 512 * max_comp_units * 1300 * 10**6
+     elif "L40S" in device_name:
+         # data from https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413
+         return 362e12
+
+     else:  # for other GPU types, assume A100
+         logger.warning(f"Peak flops undefined for: {device_name}, falling back to A100")
+         return 312e12
+
+
+ @dataclasses.dataclass(slots=True)
+ class DevMemStats:
+     max_active_gib: float
+     max_active_perc: float
+     max_reserved_gib: float
+     max_reserved_perc: float
+     num_alloc_retries: int
+     num_ooms: int
+
+
+ class DeviceMemoryMonitor:
+     def __init__(
+         self,
+         device: torch.device,
+     ) -> None:
+         self.device = device
+         try:
+             _, self.total_mem = torch.accelerator.get_memory_info(self.device)
+             self.device_name = _get_device_module(self.device.type).get_device_name(self.device)
+         except RuntimeError:
+             if self.device.type == "mps":
+                 self.total_mem = torch.mps.recommended_max_memory()
+                 self.device_name = "Apple Silicon GPU"
+             else:
+                 raise NotImplementedError(f"Device memory monitor not implemented for device type: {self.device.type}")
+
+         self.reset_peak_stats()
+         try:
+             torch.accelerator.empty_cache()
+         except RuntimeError:
+             logger.error(f"Failed to empty cache for device type: {self.device.type}")
+
+     @staticmethod
+     def to_gib(bytes: int) -> float:
+         return bytes / (1024**3)
+
+     def reset_peak_stats(self) -> None:
+         try:
+             torch.accelerator.reset_peak_memory_stats(self.device)
+         except RuntimeError:
+             logger.error(f"Failed to reset peak memory stats for device type: {self.device.type}")
+
+     def get_peak_stats(self) -> DevMemStats:
+         try:
+             info = torch.accelerator.memory_stats(self.device)
+         except RuntimeError:
+             logger.error(f"Failed to get peak memory stats for device type: {self.device.type}")
+             info = {}
+
+         max_active = info.get("active_bytes.all.peak", -1)
+         max_reserved = info.get("reserved_bytes.all.peak", -1)
+         num_retries = info.get("num_alloc_retries", -1)
+         num_ooms = info.get("num_ooms", -1)
+
+         if num_retries > 0:
+             logger.warning(f"{num_retries} {self.device.type.upper()} memory allocation retries.")
+         if num_ooms > 0:
+             logger.warning(f"{num_ooms} {self.device.type.upper()} OOM errors thrown.")
+
+         return DevMemStats(
+             max_active_gib=self.to_gib(max_active),
+             max_active_perc=max_active / self.total_mem * 100,
+             max_reserved_gib=self.to_gib(max_reserved),
+             max_reserved_perc=max_reserved / self.total_mem * 100,
+             num_alloc_retries=num_retries,
+             num_ooms=num_ooms,
+         )
+
+
+ class BaseReporter:
+     def config(
+         self,
+         config: dict[str, Any],
+     ) -> None:
+         raise NotImplementedError()
+
+     def log(
+         self,
+         metrics: dict[str, Any],
+         step: int,
+     ) -> None:
+         raise NotImplementedError()
+
+     def close(self) -> None:
+         raise NotImplementedError()
+
+
+ class TensorboardReporter(BaseReporter):
+     def __init__(
+         self,
+         log_dir: Path,
+     ) -> None:
+         from torch.utils.tensorboard import SummaryWriter
+
+         self.writer = SummaryWriter(log_dir=log_dir, max_queue=1000)
+
+         logger.info(f"TensorBoard log is available at {log_dir}")
+
+     def config(
+         self,
+         config: dict[str, Any],
+     ) -> None:
+         for k, v in config.items():
+             self.writer.add_text(f"config/{k}", str(v))
+
+     def log(
+         self,
+         metrics: dict[str, float],
+         step: int,
+     ) -> None:
+         for k, v in metrics.items():
+             self.writer.add_scalar(k, v, step)
+
+     def close(self) -> None:
+         self.writer.close()
+
+
+ class MetricsProcessor:
+     def __init__(
+         self,
+         config: Config,
+         device: torch.device,
+         flops_per_token: int,
+     ) -> None:
+         self.reporters = []
+         if config.metrics.all_node or rank() == 0:
+             if config.metrics.use_tensorboard:
+                 log_dir = config.output_dir / "tensorboard" if config.output_dir else Path("./tensorboard")
+                 self.reporters.append(TensorboardReporter(log_dir=log_dir))
+
+         for reporter in self.reporters:
+             reporter.config(config=dataclasses.asdict(config))
+
+         self.device_mem_monitor = DeviceMemoryMonitor(device)
+         self.log_freq = config.metrics.log_freq
+         self.time_last_log = time.perf_counter()
+         gpu_peak_flops = get_peak_flops(self.device_mem_monitor.device_name)
+         logger.info(f"Detected device: {self.device_mem_monitor.device_name}, Peak FLOPS: {gpu_peak_flops}")
+         self.gpu_peak_flops = gpu_peak_flops
+         self.ntokens_since_last_log = 0
+         self.flops_per_token = flops_per_token
+         self.data_load_times: list[float] = []
+         self.reset()
+
+     def should_log(
+         self,
+         step: int,
+     ) -> bool:
+         return step == 1 or step % self.log_freq == 0
+
+     def log(
+         self,
+         step: int,
+         global_avg_loss: float,
+         global_max_loss: float,
+         extra_metrics: dict[str, float] | None = None,
+     ) -> None:
+         time_delta = time.perf_counter() - self.time_last_log
+         device_mem_stats = self.device_mem_monitor.get_peak_stats()
+         time_ete = time_delta / self.log_freq
+         time_data_load = sum(self.data_load_times) / len(self.data_load_times) if self.data_load_times else 0.0
+         time_data_load_perc = 100 * time_data_load / time_ete if time_ete > 0 else 0.0
+
+         metrics = {
+             "loss/avg": global_avg_loss,
+             "loss/max": global_max_loss,
+             "memory/max_active(GiB)": device_mem_stats.max_active_gib,
+             "memory/max_active(%)": device_mem_stats.max_active_perc,
+             "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib,
+             "memory/max_reserved(%)": device_mem_stats.max_reserved_perc,
+             "memory/num_alloc_retries": device_mem_stats.num_alloc_retries,
+             "memory/num_ooms": device_mem_stats.num_ooms,
+             "time/end-to-end(s)": time_ete,
+             "time/data_load(s)": time_data_load,
+             "time/data_load(%)": time_data_load_perc,
+         }
+
+         log = (
+             f"[Step {step:>10}] loss: {global_avg_loss:.4f}, memory: {device_mem_stats.max_reserved_gib:.2f} GiB, "
+             f"time: {time_ete:.2f}s (data load ratio: {time_data_load_perc:.1f}%)"
+         )
+
+         if extra_metrics is not None:
+             metrics.update(extra_metrics)
+
+         if self.flops_per_token > 0:
+             tps = self.ntokens_since_last_log / time_delta
+             mfu = 100 * self.flops_per_token * tps / self.gpu_peak_flops
+             tflops = self.flops_per_token * tps / 1e12
+
+             metrics.update({
+                 "throughput(tps)": tps,
+                 "tflops": tflops,
+                 "mfu(%)": mfu,
+             })
+             log += f", tflops: {tflops:.2f}, mfu: {mfu:.2f}%"
+
+         for reporter in self.reporters:
+             reporter.log(metrics, step)
+
+         logger.info(log)
+
+         self.reset()
+
+     def reset(self) -> None:
+         self.ntokens_since_last_log = 0
+         self.data_load_times.clear()
+         self.time_last_log = time.perf_counter()
+         self.device_mem_monitor.reset_peak_stats()
+
+     def val_log(
+         self,
+         step: int,
+         val_loss: float,
+         extra_metrics: dict[str, float] | None = None,
+     ) -> None:
+         raise NotImplementedError()
+
+     def close(self) -> None:
+         for reporter in self.reporters:
+             reporter.close()
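For orientation, a minimal usage sketch of MetricsProcessor (editorial example, not part of the package): should_log/log/close and the ntokens_since_last_log and data_load_times attributes come from the code above, while the Config construction, data iterator, and train_step are placeholders.

import time
import torch
from sarasa.config import Config
from sarasa.metrics import MetricsProcessor

config = Config(...)                          # placeholder; fields like config.metrics.log_freq are read above
processor = MetricsProcessor(config, torch.device("cuda"), flops_per_token=model_flops_per_token)

for step in range(1, num_steps + 1):
    t0 = time.perf_counter()
    batch = next(data_iter)                   # placeholder data loading
    processor.data_load_times.append(time.perf_counter() - t0)

    loss = train_step(batch)                  # placeholder training step
    processor.ntokens_since_last_log += batch.numel()

    if processor.should_log(step):
        processor.log(step, global_avg_loss=loss, global_max_loss=loss)

processor.close()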
@@ -0,0 +1,95 @@
+ from __future__ import annotations
+
+ import abc
+ import dataclasses
+ from typing import Literal
+
+ import torch
+ from loguru import logger
+ from torch import nn
+
+
+ @dataclasses.dataclass
+ class ModelConfig:
+     name: Literal["nanochat_gpt", "llama3"] = "nanochat_gpt"
+     num_layers: int = 12
+     head_dim: int = 128
+
+     num_heads: int | None = None  # inferred later if None
+     num_kv_heads: int | None = None  # inferred later if None
+     hidden_dim: int | None = None  # inferred later if None
+     vocab_size: int | None = None  # set later based on tokenizer
+     seq_len: int | None = None  # set later based on data config
+     qk_norm: bool = False  # whether to use RMSNorm on q/k
+
+     def __post_init__(self):
+         # infer hidden_dim, num_heads, num_kv_heads if not provided using the rules presented in nanochat
+         self.hidden_dim = self.hidden_dim or (self.num_layers * 64 + self.head_dim - 1) // self.head_dim * self.head_dim
+         self.num_heads = self.num_heads or self.hidden_dim // self.head_dim
+         self.num_kv_heads = self.num_kv_heads or self.num_heads
+
+         # sanity checks
+         assert self.hidden_dim % self.head_dim == 0
+         assert self.head_dim * self.num_heads == self.hidden_dim
+         assert self.num_kv_heads <= self.num_heads and self.num_heads % self.num_kv_heads == 0
+
+     def create(self) -> BaseModel:
+         if self.vocab_size is None or self.seq_len is None:
+             raise ValueError("vocab_size and seq_len must be set before creating the model")
+
+         match self.name:
+             case "nanochat_gpt":
+                 from .nanochat_gpt import GPT
+
+                 if not self.qk_norm:
+                     logger.warning("nanochat_gpt model without qk_norm is not recommended")
+
+                 return GPT(self)
+
+             case "llama3":
+                 from .llama3 import Llama3
+
+                 if self.qk_norm:
+                     logger.warning("llama3 model with qk_norm is not standard")
+
+                 return Llama3(self)
+
+             case _:
+                 raise ValueError(f"Unknown model name: {self.name}")
+
+
+ class BaseModel(nn.Module, abc.ABC):
+     # Common interface for all models in Sarasa
+
+     blocks: list[nn.Module]  # Transformer blocks
+     config: ModelConfig
+
+     @abc.abstractmethod
+     @torch.no_grad()
+     def init_weights(self) -> None:
+         # Actual initialization of model weights
+         pass
+
+     @abc.abstractmethod
+     def param_groups(self) -> dict[str, list[nn.Parameter]]:
+         # Return parameter groups for optimizer
+         pass
+
+     @property
+     def num_params_flops(
+         self,
+     ) -> tuple[int, int]:
+         # Return number of parameters and FLOPs per token (for dense model)
+         config = self.config
+
+         # embedding parameters (for tied embeddings, subtract num_params_emb from num_params)
+         num_params = sum(p.numel() for p in self.parameters())
+         num_params_emb = self.token_emb.weight.numel()
+
+         # If the forward pass has 1 matmul, the backward pass has 2 matmuls
+         # Each self-attention has 2 matmuls
+         num_flops_per_token = 6 * (
+             (num_params - num_params_emb) + (config.num_layers * config.num_heads * config.head_dim * config.seq_len)
+         )
+
+         return num_params, num_flops_per_token
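As a worked example of the inference in __post_init__ above (editorial, not from the package; the vocab_size and seq_len values are arbitrary): with the defaults num_layers=12 and head_dim=128, hidden_dim becomes num_layers * 64 rounded up to a multiple of head_dim.

from sarasa.models import ModelConfig  # import path as used by the other modules in this diff

cfg = ModelConfig(num_layers=12, head_dim=128)
# hidden_dim   = (12 * 64 + 128 - 1) // 128 * 128 = 768
# num_heads    = 768 // 128 = 6
# num_kv_heads = num_heads = 6 (no GQA unless set explicitly)
assert (cfg.hidden_dim, cfg.num_heads, cfg.num_kv_heads) == (768, 6, 6)

cfg.vocab_size = 50304  # arbitrary example values; both must be set before create()
cfg.seq_len = 2048
model = cfg.create()    # dispatches to GPT or Llama3 based on cfg.name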
@@ -0,0 +1,84 @@
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from sarasa.models import ModelConfig
+ from sarasa.models.utils import RMSNorm, RoPE
+
+
+ class SDPAttention(nn.Module):
+     def __init__(
+         self,
+         is_causal: bool,
+         enable_gqa: bool,
+     ):
+         super().__init__()
+         self.is_causal = is_causal
+         self.enable_gqa = enable_gqa
+
+         if nn.attention.current_flash_attention_impl() == "FA4":
+             self.sdpa_backends = nn.attention.SDPBackend.FLASH_ATTENTION
+         else:
+             self.sdpa_backends = [
+                 nn.attention.SDPBackend.CUDNN_ATTENTION,
+                 nn.attention.SDPBackend.FLASH_ATTENTION,
+                 nn.attention.SDPBackend.MATH,
+             ]
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+     ) -> torch.Tensor:
+         with nn.attention.sdpa_kernel(self.sdpa_backends):
+             return F.scaled_dot_product_attention(
+                 query,
+                 key,
+                 value,
+                 is_causal=self.is_causal,
+                 enable_gqa=self.enable_gqa,
+             )
+
+
+ class CausalSelfAttention(nn.Module):
+     def __init__(
+         self,
+         config: ModelConfig,
+         layer_idx: int | None = None,
+     ):
+         super().__init__()
+         self.layer_idx = layer_idx
+         self.num_heads = config.num_heads
+         self.num_kv_heads = config.num_kv_heads
+         self.hidden_dim = config.hidden_dim
+         self.head_dim = self.hidden_dim // self.num_heads
+         self.c_q = nn.Linear(self.hidden_dim, self.num_heads * self.head_dim, bias=False)
+         self.c_k = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
+         self.c_v = nn.Linear(self.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
+         self.c_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False)
+         self.qk_norm = RMSNorm(self.head_dim) if config.qk_norm else nn.Identity()
+
+         # todo: support varlen etc and kv caching
+         self.attn = SDPAttention(is_causal=True, enable_gqa=self.num_heads != self.num_kv_heads)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos_sin: tuple[torch.Tensor, torch.Tensor],
+     ) -> torch.Tensor:
+         B, T, C = x.size()
+
+         # Project the input to get queries, keys, and values
+         q = self.c_q(x).view(B, T, self.num_heads, self.head_dim)
+         k = self.c_k(x).view(B, T, self.num_kv_heads, self.head_dim)
+         v = self.c_v(x).view(B, T, self.num_kv_heads, self.head_dim)
+
+         # Apply Rotary Embeddings to queries and keys to get relative positional encoding
+         cos, sin = cos_sin
+         q, k = RoPE.apply(q, cos, sin), RoPE.apply(k, cos, sin)
+         q, k = self.qk_norm(q), self.qk_norm(k)
+         y = self.attn(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))  # (B, n_head, T, head_dim)
+         y = y.transpose(1, 2).contiguous().view(B, T, -1)
+         y = self.c_proj(y)
+         return y
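SDPAttention above only sets enable_gqa when the query and key/value head counts differ. A self-contained illustration (editorial; assumes PyTorch >= 2.5, where F.scaled_dot_product_attention accepts enable_gqa) of the grouped-query broadcast this relies on:

import torch
import torch.nn.functional as F

B, T, head_dim = 2, 16, 64
num_heads, num_kv_heads = 8, 2   # each KV head is shared by 4 query heads

q = torch.randn(B, num_heads, T, head_dim)
k = torch.randn(B, num_kv_heads, T, head_dim)
v = torch.randn(B, num_kv_heads, T, head_dim)

# With enable_gqa=True, SDPA broadcasts the 2 KV heads across the 8 query heads
# instead of requiring k/v to be repeated to num_heads beforehand.
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
assert y.shape == (B, num_heads, T, head_dim)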
@@ -0,0 +1,129 @@
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from sarasa.models import BaseModel, ModelConfig
+ from sarasa.models.attention import CausalSelfAttention
+ from sarasa.models.utils import RMSNorm, RoPE
+
+
+ class MLP(nn.Module):
+     def __init__(
+         self,
+         config: ModelConfig,
+         multiple_of: int,
+         ffn_dim_multiplier: float | None,
+     ):
+         super().__init__()
+         hidden_dim = int(8 * config.hidden_dim / 3)
+         # custom dim factor multiplier
+         if ffn_dim_multiplier is not None:
+             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+         self.w1 = nn.Linear(config.hidden_dim, hidden_dim, bias=False)
+         self.w2 = nn.Linear(hidden_dim, config.hidden_dim, bias=False)
+         self.w3 = nn.Linear(config.hidden_dim, hidden_dim, bias=False)
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+ class Block(nn.Module):
+     def __init__(
+         self,
+         config: ModelConfig,
+         layer_idx: int,
+         multiple_of: int,
+         ffn_dim_multiplier: float | None,
+     ):
+         super().__init__()
+         self.layer_idx = layer_idx
+         self.attention = CausalSelfAttention(config)
+         self.mlp = MLP(config, multiple_of, ffn_dim_multiplier)
+         self.norm = RMSNorm(config.hidden_dim)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos_sin: tuple[torch.Tensor, torch.Tensor],
+     ) -> torch.Tensor:
+         x = x + self.attention(self.norm(x), cos_sin)
+         x = x + self.mlp(self.norm(x))
+         return x
+
+
+ class Llama3(BaseModel):
+     def __init__(
+         self,
+         config: ModelConfig,
+         multiple_of: int = 1024,
+         ffn_dim_multiplier: float | None = 1.4,
+     ):
+         super().__init__()
+         self.config = config
+         self.token_emb = nn.Embedding(config.vocab_size, config.hidden_dim)
+         self.max_seq_len = config.seq_len * 16
+         self.head_dim = config.head_dim
+         cos, sin = RoPE.precompute(self.max_seq_len, config.head_dim)
+         self.register_buffer("cos", cos, persistent=False)
+         self.register_buffer("sin", sin, persistent=False)
+         self.blocks = nn.ModuleList([
+             Block(config, layer_idx, multiple_of, ffn_dim_multiplier) for layer_idx in range(config.num_layers)
+         ])
+         self.norm = RMSNorm(config.hidden_dim)
+         self.output = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
+
+     @torch.no_grad()
+     def init_weights(self) -> None:
+         self.cos, self.sin = RoPE.precompute(self.max_seq_len, self.head_dim, device=self.cos.device)
+         torch.nn.init.normal_(self.token_emb.weight)
+         for block in self.blocks:
+             block: Block
+             init_std = 0.02 / (2 * (block.layer_idx + 1)) ** 0.5
+
+             nn.init.trunc_normal_(block.attention.c_q.weight, std=0.02)
+             nn.init.trunc_normal_(block.attention.c_k.weight, std=0.02)
+             nn.init.trunc_normal_(block.attention.c_v.weight, std=0.02)
+             nn.init.trunc_normal_(block.attention.c_proj.weight, std=init_std)
+
+             nn.init.trunc_normal_(block.mlp.w1.weight, std=0.02)
+             nn.init.trunc_normal_(block.mlp.w2.weight, std=init_std)
+             nn.init.trunc_normal_(block.mlp.w3.weight, std=init_std)
+
+         final_out_std = self.output.weight.shape[-1] ** -0.5
+         cutoff_factor = 3
+         nn.init.trunc_normal_(
+             self.output.weight,
+             mean=0.0,
+             std=final_out_std,
+             a=-cutoff_factor * final_out_std,
+             b=cutoff_factor * final_out_std,
+         )
+
+     def param_groups(self) -> dict[str, list[nn.Parameter]]:
+         matrix_params = list(self.blocks.parameters())
+         embedding_params = list(self.token_emb.parameters())
+         lm_head_params = list(self.output.parameters())
+         assert len(list(self.parameters())) == (len(matrix_params) + len(embedding_params) + len(lm_head_params))
+
+         return {
+             "matrix": matrix_params,
+             "embedding": embedding_params,
+             "lm_head": lm_head_params,
+         }
+
+     def forward(
+         self,
+         input: torch.Tensor,
+     ) -> torch.Tensor:
+         B, T = input.size()
+         x = self.token_emb(input)  # (B, T, C)
+         cos_sin = self.cos[:, :T], self.sin[:, :T]
+
+         for block in self.blocks:
+             x = block(x, cos_sin)
+
+         x = self.norm(x)
+         logits = self.output(x)
+         return logits
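Tying this back to the MFU accounting in sarasa/metrics.py: a back-of-the-envelope check (editorial example; the throughput number is hypothetical and RMSNorm parameters are ignored) of BaseModel.num_params_flops for the default ModelConfig (hidden_dim=768, num_heads=6, num_layers=12, head_dim=128) with seq_len=2048, vocab_size=50304, and Llama3's default MLP sizing (ffn hidden dim 3072):

num_layers, hidden_dim, head_dim, num_heads, seq_len, vocab = 12, 768, 128, 6, 2048, 50304
ffn_dim = 3072                                        # int(1.4 * int(8 * 768 / 3)) rounded up to a multiple of 1024

attn_params = 4 * hidden_dim * hidden_dim             # c_q, c_k, c_v, c_proj (num_kv_heads == num_heads here)
mlp_params = 3 * hidden_dim * ffn_dim                 # w1, w2, w3
non_emb_params = num_layers * (attn_params + mlp_params) + hidden_dim * vocab  # blocks + untied lm head, ~152M

# Matches BaseModel.num_params_flops: 6 * (non-embedding params + attention score/value term)
flops_per_token = 6 * (non_emb_params + num_layers * num_heads * head_dim * seq_len)  # ~1.0e9

# MFU as computed in MetricsProcessor.log, against the H100 SXM peak from get_peak_flops
tokens_per_sec = 400_000                              # hypothetical throughput
mfu = 100 * flops_per_token * tokens_per_sec / 989e12
print(f"flops/token = {flops_per_token:,}, MFU = {mfu:.1f}%")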