sarasa-0.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,192 @@
+ # NanoChat's GPT model, adapted from https://github.com/karpathy/nanochat
+
+
+ import torch
+ from loguru import logger
+ from torch import nn
+ from torch.nn import functional as F
+
+ from sarasa.models import BaseModel, ModelConfig
+ from sarasa.models.attention import CausalSelfAttention
+ from sarasa.models.utils import RMSNorm, RoPE
+
+
+ class MLP(nn.Module):
+     def __init__(
+         self,
+         config: ModelConfig,
+     ):
+         super().__init__()
+         self.c_fc = nn.Linear(config.hidden_dim, 4 * config.hidden_dim, bias=False)
+         self.c_proj = nn.Linear(4 * config.hidden_dim, config.hidden_dim, bias=False)
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = F.relu(x).square()
+         x = self.c_proj(x)
+         return x
+
+
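
For reference, the MLP above uses a squared-ReLU activation (relu(x)**2) rather than a more common GELU. A minimal standalone check of what F.relu(x).square() computes (plain PyTorch, not part of the package):

    import torch
    import torch.nn.functional as F

    x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
    y = F.relu(x).square()  # max(x, 0) ** 2, elementwise
    print(y)  # tensor([0.0000, 0.0000, 0.0000, 0.2500, 4.0000])
    assert torch.equal(y, torch.clamp(x, min=0) ** 2)

Non-positive inputs give exactly zero output (and zero gradient), positive inputs are squared, and the derivative is continuous at zero.
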
+ class Block(nn.Module):
+     def __init__(
+         self,
+         config: ModelConfig,
+         layer_idx: int,
+     ):
+         super().__init__()
+         self.attn = CausalSelfAttention(config, layer_idx)
+         self.mlp = MLP(config)
+         self.norm = RMSNorm(config.hidden_dim) # parameter-free, so one instance can be shared by both sub-layers
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cos_sin: tuple[torch.Tensor, torch.Tensor],
+     ) -> torch.Tensor:
+         x = x + self.attn(self.norm(x), cos_sin)
+         x = x + self.mlp(self.norm(x))
+         return x
+
+
+ class GPT(BaseModel):
+     def __init__(
+         self,
+         config: ModelConfig,
+         pad_vocab_size_to=64,
+     ):
+         """
+         NOTE a major footgun: this __init__ runs in a meta-device context (!!).
+         Therefore, anything computed in here is shapes and dtypes only, no actual data.
+         => All data (parameters, buffers, etc.) is actually initialized in init_weights() instead.
+         """
+         super().__init__()
+         self.config = config
+         self.num_heads = config.num_heads
+         self.hidden_dim = config.hidden_dim
+         self.seq_len = config.seq_len
+         self.vocab_size = config.vocab_size
+         self.num_layers = config.num_layers
+         # For DDP, we want vocab_size divisible by world_size. Also, there are potential performance benefits, see:
+         # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
+         padded_vocab_size = ((self.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
+         if padded_vocab_size != self.vocab_size:
+             logger.warning(
+                 f"Padding vocab_size from {self.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}"
+             )
+         self.token_emb = nn.Embedding(padded_vocab_size, self.hidden_dim)
+         self.blocks = nn.ModuleList([Block(config, layer_idx) for layer_idx in range(self.num_layers)])
+         self.lm_head = nn.Linear(self.hidden_dim, padded_vocab_size, bias=False)
+         self.norm = RMSNorm(self.hidden_dim)
+         # Per-layer learnable scalars (inspired by modded-nanogpt)
+         # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
+         # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
+         # Separate parameters so they can have different optimizer treatment
+         self.resid_lambdas = nn.Parameter(torch.ones(self.num_layers)) # fake init, real init in init_weights()
+         self.x0_lambdas = nn.Parameter(torch.zeros(self.num_layers)) # fake init, real init in init_weights()
+         # To support meta device initialization, we create the rotary embeddings here, but they are just "fake" meta tensors.
+         # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
+         # so let's just over-compute them by 16x, but assert-fail if we ever reach that amount.
+         # In the future we can dynamically grow the cache; for now it's fine.
+         self.rotary_seq_len = self.seq_len * 16 # 16x over-compute should be enough, TODO make nicer?
+         cos, sin = RoPE.precompute(self.rotary_seq_len, config.head_dim)
+         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
+         self.register_buffer("sin", sin, persistent=False)
+
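
A quick illustration of the vocab-padding arithmetic above, with a made-up vocabulary size (not taken from the package):

    pad_vocab_size_to = 64
    vocab_size = 50257  # illustrative, GPT-2-sized vocabulary
    padded = ((vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
    print(padded)  # 50304, i.e. 786 * 64, the next multiple of 64
    assert padded % pad_vocab_size_to == 0 and padded >= vocab_size

The extra embedding rows and lm_head columns correspond to token ids that never occur in the data; forward() slices the logits back down to vocab_size before they are used.
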
+     @torch.no_grad()
+     def init_weights(self):
+         """
+         Initialize the full model in this one function for maximum clarity.
+
+         token_emb (embedding): normal, std=1.0
+         lm_head: normal, std=0.001
+         for each block:
+             attn.c_q: uniform, std=1/sqrt(n_embd)
+             attn.c_k: uniform, std=1/sqrt(n_embd)
+             attn.c_v: uniform, std=1/sqrt(n_embd)
+             attn.c_proj: zeros
+             mlp.c_fc: uniform, std=1/sqrt(n_embd)
+             mlp.c_proj: zeros
+         """
+
+         # Embedding and unembedding
+         torch.nn.init.normal_(self.token_emb.weight, mean=0.0, std=1.0)
+         torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
+
+         # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
+         n_embd = self.hidden_dim
+         s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
+         for block in self.blocks:
+             torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
+             torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
+             torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
+             torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
+             torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
+             torch.nn.init.zeros_(block.mlp.c_proj.weight)
+
+         # Per-layer scalars
+         self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
+         self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
+
+         # Rotary embeddings
+         head_dim = self.hidden_dim // self.num_heads
+         self.cos, self.sin = RoPE.precompute(self.rotary_seq_len, head_dim, device=self.cos.device)
+
+         # Cast token embeddings to bf16: the optimizer can tolerate it and it saves memory
+         if self.token_emb.weight.device.type == "cuda":
+             self.token_emb.to(dtype=torch.bfloat16)
+
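
The sqrt(3) factor in init_weights comes from the variance of the uniform distribution: U(-s, s) has standard deviation s / sqrt(3), so choosing s = sqrt(3) / sqrt(n_embd) matches the 1 / sqrt(n_embd) std quoted in the docstring. A quick empirical check (plain PyTorch, illustrative width):

    import torch

    n_embd = 768  # illustrative width, not a sarasa default
    s = 3**0.5 * n_embd**-0.5  # same bound as in init_weights above
    w = torch.empty(4096, n_embd)
    torch.nn.init.uniform_(w, -s, s)
    print(w.std().item(), n_embd**-0.5)  # both ~0.0361
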
+     def param_groups(
+         self,
+     ) -> dict[str, list[torch.nn.Parameter]]:
+         # Separate out all parameters into 5 groups (matrix, embedding, lm_head, resid_lambdas, x0_lambdas)
+         matrix_params = list(self.blocks.parameters())
+         embedding_params = list(self.token_emb.parameters())
+         lm_head_params = list(self.lm_head.parameters())
+         resid_params = [self.resid_lambdas]
+         x0_params = [self.x0_lambdas]
+         assert len(list(self.parameters())) == (
+             len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params)
+         )
+
+         return {
+             "matrix": matrix_params,
+             "embedding": embedding_params,
+             "lm_head": lm_head_params,
+             "resid_lambdas": resid_params,
+             "x0_lambdas": x0_params,
+         }
+
+     def forward(
+         self,
+         input: torch.Tensor,
+     ) -> torch.Tensor:
+         B, T = input.size()
+
+         # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
+         assert T <= self.cos.size(1), (
+             f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
+         )
+         assert input.device == self.cos.device, (
+             f"Rotary embeddings and input are on different devices: {input.device} != {self.cos.device}"
+         )
+         assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
+         # there is no KV cache here, so positions always start at 0; with a cache, the rotary embeddings would need to be offset to the current position
+         cos_sin = self.cos[:, :T], self.sin[:, :T] # truncate cache to current sequence length
+
+         # Forward the trunk of the Transformer
+         x = self.token_emb(input)
+         x = self.norm(x)
+         x0 = x # save initial normalized embedding for the x0 residual
+         for block, resid_lambda, x0_lambda in zip(self.blocks, self.resid_lambdas, self.x0_lambdas):
+             x = resid_lambda * x + x0_lambda * x0
+             x = block(x, cos_sin)
+         x = self.norm(x)
+
+         # Forward the lm_head (compute logits)
+         softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
+         logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
+         logits = logits[..., : self.vocab_size] # slice to remove padding
+         logits = logits.float() # switch to fp32 for logit softcap and loss computation
+         logits = softcap * torch.tanh(logits / softcap) # squash the logits
+
+         return logits
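
The tanh soft-cap at the end of forward is close to the identity for small logits and saturates smoothly at +/-softcap. A small standalone demo (the logit values are made up):

    import torch

    softcap = 15.0
    logits = torch.tensor([0.5, 5.0, 20.0, 100.0])
    capped = softcap * torch.tanh(logits / softcap)
    print(capped)  # roughly [0.50, 4.82, 13.05, 15.00]: near-identity early, bounded by softcap
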
sarasa/models/utils.py ADDED
@@ -0,0 +1,39 @@
+ import torch
+
+
+ class RMSNorm(torch.nn.RMSNorm):
+     # RMSNorm without affine (learnable scale) parameters
+     def __init__(
+         self,
+         normalized_shape: int,
+     ):
+         super().__init__(normalized_shape, eps=None, elementwise_affine=False)
+
+
+ class RoPE:
+     @staticmethod
+     def precompute(
+         seq_len: int,
+         head_dim: int,
+         device: torch.device | None = None,
+         base: float = 10000,
+     ):
+         channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
+         inv_freq = 1.0 / (base ** (channel_range / head_dim))
+         t = torch.arange(seq_len, dtype=torch.float32, device=device)
+         freqs = torch.outer(t, inv_freq)[None, :, None, :]
+         cos, sin = freqs.cos(), freqs.sin()
+         cos, sin = cos.bfloat16(), sin.bfloat16()
+         return cos, sin
+
+     @staticmethod
+     def apply(
+         x: torch.Tensor,
+         cos: torch.Tensor,
+         sin: torch.Tensor,
+     ) -> torch.Tensor:
+         assert x.ndim == 4
+         x1, x2 = x.chunk(2, dim=-1)
+         y1 = x1 * cos + x2 * sin
+         y2 = x1 * (-sin) + x2 * cos
+         return torch.cat([y1, y2], 3)
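
A minimal usage sketch of the helpers above, mirroring how GPT.forward slices and applies them (batch size, head count, and lengths are made up for illustration):

    import torch

    B, T, H, D = 2, 8, 4, 16  # illustrative sizes; D is head_dim
    cos, sin = RoPE.precompute(seq_len=32, head_dim=D)
    cos, sin = cos[:, :T], sin[:, :T]  # (1, T, 1, D // 2), as sliced in GPT.forward
    x = torch.randn(B, T, H, D).bfloat16()
    y = RoPE.apply(x, cos, sin)
    print(y.shape)  # torch.Size([2, 8, 4, 16])

Each (x1, x2) channel pair is rotated by a position-dependent angle, so the rotation changes directions but preserves vector norms (up to bfloat16 rounding).
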
@@ -0,0 +1,77 @@
+ import dataclasses
+ from typing import Literal
+
+ import torch
+
+ from sarasa.models import BaseModel
+ from sarasa.optimizers.utils import GroupedOptimizer
+
+
+ @dataclasses.dataclass
+ class AdamW:
+     """
+     Default optimizer: fused AdamW over all parameter groups.
+     """
+
+     lr: float = 1e-4
+     weight_decay: float = 0.1
+     betas: tuple[float, float] = (0.9, 0.95)
+
+     def create(
+         self,
+         model: BaseModel,
+     ) -> torch.optim.Optimizer:
+         param_groups = model.param_groups()
+         params = sum(param_groups.values(), [])
+         optimizer = torch.optim.AdamW(
+             params,
+             lr=torch.tensor(self.lr, dtype=torch.float32),
+             weight_decay=self.weight_decay,
+             betas=self.betas,
+             fused=True,
+         )
+         return optimizer
+
+
+ @dataclasses.dataclass
+ class Muon:
+     """
+     Muon optimizer: Muon for the "matrix" parameter group, AdamW for everything else.
+     """
+
+     lr: float = 1e-4
+     weight_decay: float = 0.1
+     momentum: float = 0.9
+
+     adam_lr: float | None = None
+     adam_betas: tuple[float, float] = (0.9, 0.95)
+     adam_weight_decay: float = 0
+
+     adjust_lr_fn: Literal["original", "match_rms_adamw"] = "match_rms_adamw"
+
+     def __post_init__(self):
+         self.adam_lr = self.adam_lr or self.lr
+
+     def create(
+         self,
+         model: BaseModel,
+     ) -> torch.optim.Optimizer:
+         param_groups = model.param_groups()
+
+         muon = torch.optim.Muon(
+             param_groups["matrix"],
+             lr=self.lr,
+             weight_decay=self.weight_decay,
+             momentum=self.momentum,
+             adjust_lr_fn=self.adjust_lr_fn,
+         )
+
+         adam = torch.optim.AdamW(
+             sum([param_groups[k] for k in param_groups if k != "matrix"], []),
+             lr=self.adam_lr,
+             betas=self.adam_betas,
+             weight_decay=self.adam_weight_decay,
+             fused=True,
+         )
+
+         return GroupedOptimizer(muon, adam)
@@ -0,0 +1,27 @@
+ import torch
+
+
+ class GroupedOptimizer(torch.optim.Optimizer): # drive several optimizers (with disjoint param groups) as one
+     def __init__(
+         self,
+         *optimizers: torch.optim.Optimizer,
+     ):
+         super().__init__(sum([optim.param_groups for optim in optimizers], []), {})
+         self.optimizers = optimizers
+
+     def step(self) -> None:
+         for optim in self.optimizers:
+             optim.step()
+
+     def zero_grad(
+         self,
+         set_to_none: bool = True,
+     ) -> None:
+         for optim in self.optimizers:
+             optim.zero_grad(set_to_none=set_to_none)
+
+     def state_dict(self) -> dict:
+         return super().state_dict()
+
+     def load_state_dict(self, state_dict: dict) -> None:
+         super().load_state_dict(state_dict)
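
A small standalone sketch of how GroupedOptimizer is meant to be driven, using two stock PyTorch optimizers in place of the Muon/AdamW pair built above (the toy model and hyperparameters are made up):

    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))    # toy stand-in model
    opt_a = torch.optim.SGD(model[0].parameters(), lr=0.1)     # one subset of parameters
    opt_b = torch.optim.AdamW(model[1].parameters(), lr=1e-3)  # the remaining parameters
    optim = GroupedOptimizer(opt_a, opt_b)                     # single step()/zero_grad() facade

    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optim.step()       # steps both inner optimizers
    optim.zero_grad()  # clears gradients through both inner optimizers
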
sarasa/trainer.py ADDED
@@ -0,0 +1,244 @@
+ import contextlib
+ import os
+ import time
+ from collections.abc import Iterable
+
+ import torch
+ import torch.distributed as dist
+ from loguru import logger
+ from torch.distributed.elastic.multiprocessing.errors import record
+
+ from sarasa.activation_checkpoint import apply_op_sac
+ from sarasa.checkpoint import Checkpointer
+ from sarasa.config import Config
+ from sarasa.metrics import MetricsProcessor
+ from sarasa.utils import GarbageCollector, apply_distributed, init_distributed, set_dtype, update_timeout, world_size
+
+ IGNORE_INDEX = -100
+
+
+ class Trainer:
+     @record
+     def __init__(
+         self,
+         config: Config,
+     ) -> None:
+         self.config = config
+         logger.info(f"Initializing Trainer with config: {self.config}")
+
+         # set seed
+         torch.manual_seed(config.seed)
+         os.environ["PYTHONHASHSEED"] = str(config.seed % 2**32)
+
+         # setup device
+         torch.accelerator.set_device_index(int(os.environ.get("LOCAL_RANK", 0)))
+         self.device = torch.accelerator.current_accelerator(check_available=True)
+
+         self.gc = GarbageCollector(config.train.gc_freq)
+
+         # setup distributed
+         init_distributed(config.distributed.backend, config.distributed.init_timeout_seconds)
+
+         # setup data and tokenizer -> use vocab size for model setup
+         data = config.data.create(batch_size=config.train.local_batch_size)
+         self.data_loader = data["train_loader"] # setup data loader
+         self.val_loader = data.get("val_loader", None) # setup eval data loader
+         self.tokenizer = data["tokenizer"] # setup tokenizer
+
+         vocab_size = len(self.tokenizer)
+         self.config.model.vocab_size = vocab_size
+
+         # todo: support other loss functions
+         self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX, reduction="sum")
+
+         # setup model, optimizer, lr scheduler
+         with torch.device("meta"), set_dtype(getattr(torch, config.train.dtype)):
+             self.model = self.config.model.create()
+         num_params, flops_per_token = self.model.num_params_flops
+         model_size = num_params / 1e9
+         model_size, unit = (num_params / 1e6, "M") if model_size < 1 else (model_size, "B")
+         logger.info(f"Model created with {model_size:.2f}{unit} parameters")
+
+         # following torchtitan, (S)AC -> compilation -> distributed wrapping
+         if config.train.use_sac:
+             logger.info("Applying Selective Activation Checkpointing (SAC)")
+             for i, block in enumerate(self.model.blocks):
+                 self.model.blocks[i] = apply_op_sac(block)
+
+         if config.train.compile:
+             logger.info("Compiling the model")
+             for block in self.model.blocks:
+                 block.compile(fullgraph=True)
+             self.model.compile(dynamic=False)
+             self.loss_fn.compile()
+
+         if world_size() > 1:
+             apply_distributed(
+                 config.distributed,
+                 self.model,
+                 device=self.device,
+                 compile=config.train.compile,
+             )
+
+         self.model.to_empty(device=self.device)
+         self.model.init_weights()
+
+         self.optimizer = self.config.optim.create(self.model)
+         self.lr_scheduler = self.config.lr_scheduler.create(self.optimizer, config.train.steps)
+
+         # setup metrics and checkpointer
+         # todo: configure num_flops_per_token
+         self.metrics_processor = MetricsProcessor(config, self.device, flops_per_token)
+         self.checkpointer = Checkpointer(config, self.model) if config.checkpoint.save_freq > 0 else None
+
+         dev_mem_stats = self.metrics_processor.device_mem_monitor.get_peak_stats()
+         logger.info(
+             f"{self.device.type.upper()} memory: {dev_mem_stats.max_reserved_gib:.2f} GiB for model initialization"
+         )
+
+         self.step = 0
+         self.grad_accum_steps = config.train.global_batch_size // (config.train.local_batch_size * world_size())
+         logger.info(f"Gradient accumulation steps set to: {self.grad_accum_steps}")
+
+         self.amp_context = contextlib.nullcontext()
+         if config.distributed.name != "fsdp":
+             self.amp_context = torch.autocast(device_type=self.device.type, dtype=getattr(torch, config.train.dtype))
+
+         # todo: setup profiler context
+         self.profile_context = contextlib.nullcontext()
+
+         if config.train.use_fa4:
+             logger.info("Using FA4 flash attention")
+             try:
+                 torch.nn.attention.activate_flash_attention_impl("FA4")
+             except Exception as e:
+                 logger.warning(
+                     f"Failed to activate FA4 flash attention: {e}. Install sarasa with `flash_attn` extra for better performance."
+                 )
+
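
The model setup in __init__ follows the meta-device pattern that GPT.__init__ warns about: build the module under torch.device("meta") (shapes and dtypes only), materialize storage with to_empty(), then fill in real values via init_weights(). A stripped-down sketch of the same flow with a plain nn.Linear (no sarasa components involved):

    import torch
    from torch import nn

    with torch.device("meta"):
        layer = nn.Linear(1024, 1024, bias=False)  # no real memory allocated yet

    layer = layer.to_empty(device="cpu")     # allocate uninitialized storage on a real device
    nn.init.normal_(layer.weight, std=0.02)  # explicit init, analogous to model.init_weights()
    print(layer.weight.device)               # cpu
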
+     def __del__(self) -> None:
+         # cleanup distributed
+         if world_size() > 1:
+             try:
+                 dist.destroy_process_group()
+             except Exception as e:
+                 logger.warning(f"Failed to destroy process group: {e}")
+
+     @record
+     def train(self):
+         logger.info("Starting training...")
+
+         self.model.train()
+         with self.profile_context:
+             data_iter = self.batch_generator(self.data_loader)
+             for _ in range(self.config.train.steps):
+                 self.step += 1
+                 self.gc.collect(self.step)
+                 try:
+                     self.train_step(data_iter)
+                 except StopIteration:
+                     logger.warning("Data loader exhausted during training.")
+                     break
+
+                 if self.checkpointer is not None:
+                     self.checkpointer.save(self.step)
+
+                 if self.config.train.val_freq > 0 and self.step % self.config.train.val_freq == 0:
+                     self.evaluate()
+
+                 if world_size() > 1 and self.step == 1:
+                     update_timeout(self.config.distributed.train_timeout_seconds, self.device)
+
+         logger.info("Training completed.")
+
+     def batch_generator(
+         self,
+         data_iter: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]],
+     ) -> Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]:
+         data_iter = iter(data_iter)
+         while True:
+             begin = time.perf_counter()
+             batch = next(data_iter)
+             input_dict, target = batch
+             self.metrics_processor.ntokens_since_last_log += target.numel()
+             self.metrics_processor.data_load_times.append(time.perf_counter() - begin)
+             yield input_dict, target
+
+     def train_step(
+         self,
+         batch_iter: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]],
+     ) -> None:
+         self.optimizer.zero_grad()
+
+         micro_batches = []
+         valid_tokens = torch.tensor(0, dtype=torch.long)
+         for _ in range(self.grad_accum_steps):
+             input_dict, target = next(batch_iter)
+             valid_tokens += (target != IGNORE_INDEX).sum()
+             micro_batches.append((input_dict, target))
+
+         valid_tokens = valid_tokens.to(self.device)
+         if world_size() > 1:
+             dist.all_reduce(valid_tokens, op=dist.ReduceOp.SUM)
+
+         losses = []
+         for input_dict, target in micro_batches:
+             input_dict = {
+                 k: v.to(self.device, non_blocking=(self.device.type == "cuda")) for k, v in input_dict.items()
+             }
+             target = target.to(self.device, non_blocking=(self.device.type == "cuda"))
+
+             with self.amp_context:
+                 pred = self.model(**input_dict)
+                 loss = self.loss_fn(pred.flatten(0, 1), target.flatten(0, 1)) / valid_tokens
+
+             del pred
+             loss.backward()
+             losses.append(loss.detach())
+
+         if self.config.train.grad_clip is not None:
+             torch.nn.utils.clip_grad_norm_(
+                 self.model.parameters(), self.config.train.grad_clip, foreach=self.device.type == "cuda"
+             )
+
+         if self.checkpointer is not None:
+             self.checkpointer.wait_for_staging()
+
+         self.optimizer.step()
+         self.lr_scheduler.step()
+
+         loss = torch.stack(losses).sum()
+
+         if not self.metrics_processor.should_log(self.step):
+             return
+
+         if world_size() > 1:
+             avg_loss = loss.clone()
+             dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
+             max_loss = loss.clone()
+             dist.all_reduce(max_loss, op=dist.ReduceOp.MAX)
+         else:
+             avg_loss = max_loss = loss
+
+         with torch.no_grad():
+             grad_norm = torch.nn.utils.get_total_norm([p.grad for p in self.model.parameters() if p.grad is not None], foreach=self.device.type == "cuda")
+
+         lr = self.lr_scheduler.get_last_lr()[0]
+         self.metrics_processor.log(
+             self.step,
+             global_avg_loss=avg_loss.item(),
+             global_max_loss=max_loss.item(),
+             extra_metrics={
+                 "grad_norm": grad_norm.item() if grad_norm >= 0 else float("nan"),
+                 "lr": lr,
+             },
+         )
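
The loss bookkeeping in train_step relies on reduction="sum" plus the shared valid-token count: each micro-batch contributes its summed per-token losses divided by the total number of non-ignored tokens (all-reduced across ranks when distributed), so the accumulated value is the mean loss over valid tokens no matter how padding is spread across micro-batches. A single-process sketch of that arithmetic (the tensors are made up):

    import torch

    IGNORE_INDEX = -100
    sum_loss = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX, reduction="sum")

    logits = torch.randn(6, 10)  # 6 token positions, vocabulary of 10
    target = torch.tensor([1, 2, IGNORE_INDEX, 3, IGNORE_INDEX, 4])
    micro_batches = [(logits[:3], target[:3]), (logits[3:], target[3:])]

    valid_tokens = sum((t != IGNORE_INDEX).sum() for _, t in micro_batches)  # 4
    accumulated = sum(sum_loss(lg, t) for lg, t in micro_batches) / valid_tokens

    mean_loss = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)(logits, target)  # default "mean"
    print(torch.allclose(accumulated, mean_loss))  # True
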
+
+     def evaluate(self):
+         raise NotImplementedError
+
+     def evaluation_step(
+         self,
+         batch_iter: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]],
+     ) -> None:
+         raise NotImplementedError