nextrec-0.4.5-py3-none-any.whl → nextrec-0.4.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/__version__.py CHANGED
@@ -1 +1 @@
-__version__ = "0.4.5"
+__version__ = "0.4.7"
nextrec/basic/callback.py CHANGED
@@ -1,35 +1,413 @@
 """
-EarlyStopper definitions
+Callback System for Training Process
 
 Date: create on 27/10/2025
+Checkpoint: edit on 17/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
 import copy
+import logging
+from typing import Optional
+from pathlib import Path
+import torch
+import pickle
+from nextrec import __version__
 
 
-class EarlyStopper(object):
-    def __init__(self, patience: int = 20, mode: str = "max"):
+class Callback:
+    """
+    Base callback.
+
+    Notes (DDP):
+    - In distributed training, the training loop runs on every rank.
+    - For callbacks with side effects (saving, logging, etc.), set
+      ``run_on_main_process_only=True`` to avoid multi-rank duplication.
+    """
+
+    run_on_main_process_only: bool = False
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        pass
+
+    def on_train_end(self, logs: Optional[dict] = None):
+        pass
+
+    def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None):
+        pass
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        pass
+
+    def on_batch_begin(self, batch: int, logs: Optional[dict] = None):
+        pass
+
+    def on_batch_end(self, batch: int, logs: Optional[dict] = None):
+        pass
+
+    def on_validation_begin(self, logs: Optional[dict] = None):
+        pass
+
+    def on_validation_end(self, logs: Optional[dict] = None):
+        pass
+
+    def set_model(self, model):
+        self.model = model
+
+    def set_params(self, params: dict):
+        self.params = params
+
+    def should_run(self) -> bool:
+        if not getattr(self, "run_on_main_process_only", False):
+            return True
+        model = getattr(self, "model", None)
+        if model is None:
+            return True
+        return bool(getattr(model, "is_main_process", True))
+
+
+class CallbackList:
+    """Container for managing multiple callbacks."""
+
+    def __init__(self, callbacks: Optional[list[Callback]] = None):
+        self.callbacks = callbacks or []
+
+    def append(self, callback: Callback):
+        self.callbacks.append(callback)
+
+    def set_model(self, model):
+        for callback in self.callbacks:
+            callback.set_model(model)
+
+    def set_params(self, params: dict):
+        for callback in self.callbacks:
+            callback.set_params(params)
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_train_begin(logs)
+
+    def on_train_end(self, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_train_end(logs)
+
+    def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_epoch_begin(epoch, logs)
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_epoch_end(epoch, logs)
+
+    def on_batch_begin(self, batch: int, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_batch_begin(batch, logs)
+
+    def on_batch_end(self, batch: int, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_batch_end(batch, logs)
+
+    def on_validation_begin(self, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_validation_begin(logs)
+
+    def on_validation_end(self, logs: Optional[dict] = None):
+        for callback in self.callbacks:
+            if not callback.should_run():
+                continue
+            callback.on_validation_end(logs)
+
+
+class EarlyStopper(Callback):
+
+    def __init__(
+        self,
+        monitor: str = "val_auc",
+        patience: int = 20,
+        mode: str = "max",
+        min_delta: float = 0.0,
+        restore_best_weights: bool = True,
+        verbose: int = 1,
+    ):
+        super().__init__()
+        self.monitor = monitor
         self.patience = patience
-        self.trial_counter = 0
-        self.best_metrics = 0
+        self.mode = mode
+        self.min_delta = abs(min_delta)
+        self.restore_best_weights = restore_best_weights
+        self.verbose = verbose
+
+        self.wait = 0
+        self.stopped_epoch = 0
         self.best_weights = None
+        self.best_epoch = 0
+
+        if mode == "min":
+            self.best_value = float("inf")
+            self.monitor_op = lambda current, best: current < (best - self.min_delta)
+        elif mode == "max":
+            self.best_value = float("-inf")
+            self.monitor_op = lambda current, best: current > (best + self.min_delta)
+        else:
+            raise ValueError(f"mode must be 'min' or 'max', got {mode}")
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        self.wait = 0
+        self.stopped_epoch = 0
+        self.best_weights = None
+        self.best_epoch = 0
+        if self.mode == "min":
+            self.best_value = float("inf")
+        else:
+            self.best_value = float("-inf")
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        logs = logs or {}
+        current = logs.get(self.monitor)
+
+        if current is None:
+            if self.verbose > 0:
+                logging.warning(
+                    f"Early stopping conditioned on metric `{self.monitor}` "
+                    f"which is not available. Available metrics are: {','.join(list(logs.keys()))}"
+                )
+            return
+
+        if self.monitor_op(current, self.best_value):
+            self.best_value = current
+            self.best_epoch = epoch
+            self.wait = 0
+            if self.restore_best_weights:
+                self.best_weights = copy.deepcopy(self.model.state_dict())
+        else:
+            self.wait += 1
+            if self.wait >= self.patience:
+                self.stopped_epoch = epoch
+                if hasattr(self.model, "stop_training"):
+                    self.model.stop_training = True
+                if self.verbose > 0:
+                    logging.info(
+                        f"Early stopping triggered at epoch {epoch + 1}. "
+                        f"Best {self.monitor}: {self.best_value:.6f} at epoch {self.best_epoch + 1}"
+                    )
+
+    def on_train_end(self, logs: Optional[dict] = None):
+        if self.restore_best_weights and self.best_weights is not None:
+            if self.verbose > 0:
+                logging.info(
+                    f"Restoring model weights from epoch {self.best_epoch + 1} "
+                    f"with best {self.monitor}: {self.best_value:.6f}"
+                )
+            self.model.load_state_dict(self.best_weights)
+
+
+class CheckpointSaver(Callback):
+    """Callback to save model checkpoints during training.
+
+    Args:
+        save_path: Path to save checkpoints.
+        monitor: Metric name to monitor for saving best model.
+        mode: One of {'min', 'max'}.
+        save_best_only: If True, only save when the model is considered the "best".
+        save_freq: Frequency of checkpoint saving ('epoch' or integer for every N epochs).
+        verbose: Verbosity mode.
+    """
+
+    def __init__(
+        self,
+        save_path: str | Path,
+        monitor: str = "val_auc",
+        mode: str = "max",
+        save_best_only: bool = False,
+        save_freq: str | int = "epoch",
+        verbose: int = 1,
+        run_on_main_process_only: bool = True,
+    ):
+        super().__init__()
+        self.run_on_main_process_only = run_on_main_process_only
+        self.save_path = Path(save_path)
+        self.monitor = monitor
         self.mode = mode
+        self.save_best_only = save_best_only
+        self.save_freq = save_freq
+        self.verbose = verbose
 
-    def stop_training(self, val_metrics, weights):
-        if self.mode == "max":
-            if val_metrics > self.best_metrics:
-                self.best_metrics = val_metrics
-                self.trial_counter = 0
-                self.best_weights = copy.deepcopy(weights)
-        elif self.mode == "min":
-            if val_metrics < self.best_metrics:
-                self.best_metrics = val_metrics
-                self.trial_counter = 0
-                self.best_weights = copy.deepcopy(weights)
-                return False
-        elif self.trial_counter + 1 < self.patience:
-            self.trial_counter += 1
-            return False
+        if mode == "min":
+            self.best_value = float("inf")
+            self.monitor_op = lambda current, best: current < best
+        elif mode == "max":
+            self.best_value = float("-inf")
+            self.monitor_op = lambda current, best: current > best
         else:
-            return True
+            raise ValueError(f"mode must be 'min' or 'max', got {mode}")
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        if self.mode == "min":
+            self.best_value = float("inf")
+        else:
+            self.best_value = float("-inf")
+
+        # Create directory if it doesn't exist
+        self.save_path.parent.mkdir(parents=True, exist_ok=True)
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        logs = logs or {}
+
+        # Check if we should save this epoch
+        should_save = False
+        if self.save_freq == "epoch":
+            should_save = True
+        elif isinstance(self.save_freq, int) and (epoch + 1) % self.save_freq == 0:
+            should_save = True
+
+        if not should_save and self.save_best_only:
+            should_save = False
+
+        # Check if this is the best model
+        current = logs.get(self.monitor)
+        is_best = False
+
+        if current is not None and self.monitor_op(current, self.best_value):
+            self.best_value = current
+            is_best = True
+            should_save = True
+
+        if should_save:
+            if not self.save_best_only or is_best:
+                checkpoint_path = (
+                    self.save_path.parent
+                    / f"{self.save_path.stem}_epoch_{epoch + 1}{self.save_path.suffix}"
+                )
+                self.save_checkpoint(checkpoint_path, epoch, logs)
+
+            if is_best:
+                # Use save_path directly without adding _best suffix since it may already contain it
+                self.save_checkpoint(self.save_path, epoch, logs)
+                if self.verbose > 0:
+                    logging.info(
+                        f"Saved best model to {self.save_path} with {self.monitor}: {current:.6f}"
+                    )
+
+    def save_checkpoint(self, path: Path, epoch: int, logs: dict):
+
+        # Get the actual model (unwrap DDP if needed)
+        model_to_save = (
+            self.model.ddp_model.module
+            if getattr(self.model, "ddp_model", None) is not None
+            else self.model
+        )
+
+        # Save only state_dict to match BaseModel.save_model() format
+        torch.save(model_to_save.state_dict(), path)
+
+        # Also save features_config.pkl if it doesn't exist
+        config_path = path.parent / "features_config.pkl"
+        if not config_path.exists():
+            features_config = {
+                "all_features": self.model.all_features,
+                "target": self.model.target_columns,
+                "id_columns": self.model.id_columns,
+                "version": __version__,
+            }
+            with open(config_path, "wb") as f:
+                pickle.dump(features_config, f)
+
+        if self.verbose > 1:
+            logging.info(f"Saved checkpoint to {path}")
+
+
+class LearningRateScheduler(Callback):
+    """Callback for learning rate scheduling.
+
+    Args:
+        scheduler: Learning rate scheduler instance or name.
+        verbose: Verbosity mode.
+    """
+
+    def __init__(self, scheduler=None, verbose: int = 0):
+        super().__init__()
+        self.scheduler = scheduler
+        self.verbose = verbose
+
+    def on_train_begin(self, logs: Optional[dict] = None):
+        if self.scheduler is None and hasattr(self.model, "scheduler_fn"):
+            self.scheduler = self.model.scheduler_fn
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        if self.scheduler is not None:
+            # Get current lr before step
+            if hasattr(self.model, "optimizer_fn"):
+                old_lr = self.model.optimizer_fn.param_groups[0]["lr"]
+
+            # Step the scheduler
+            if hasattr(self.scheduler, "step"):
+                # Some schedulers need metrics
+                if "val_loss" in (logs or {}) and hasattr(self.scheduler, "mode"):
+                    self.scheduler.step(logs["val_loss"])
+                else:
+                    self.scheduler.step()
+
+            # Log new lr
+            if self.verbose > 0 and hasattr(self.model, "optimizer_fn"):
+                if getattr(self.model, "is_main_process", True):
+                    new_lr = self.model.optimizer_fn.param_groups[0]["lr"]
+                    if new_lr != old_lr:
+                        logging.info(
+                            f"Learning rate changed from {old_lr:.6e} to {new_lr:.6e}"
+                        )
+
+
+class MetricsLogger(Callback):
+    """Callback for logging training metrics.
+
+    Args:
+        log_freq: Frequency of logging ('epoch', 'batch', or integer for every N epochs/batches).
+        verbose: Verbosity mode.
+    """
+
+    def __init__(self, log_freq: str | int = "epoch", verbose: int = 1):
+        super().__init__()
+        self.run_on_main_process_only = True
+        self.log_freq = log_freq
+        self.verbose = verbose
+        self.batch_count = 0
+
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
+        if self.verbose > 0 and (
+            self.log_freq == "epoch"
+            or (isinstance(self.log_freq, int) and (epoch + 1) % self.log_freq == 0)
+        ):
+            logs = logs or {}
+            metrics_str = " - ".join(
+                [
+                    f"{k}: {v:.6f}" if isinstance(v, float) else f"{k}: {v}"
+                    for k, v in logs.items()
+                ]
+            )
+            logging.info(f"Epoch {epoch + 1}: {metrics_str}")
+
+    def on_batch_end(self, batch: int, logs: Optional[dict] = None):
+        self.batch_count += 1
+        if self.verbose > 1 and self.log_freq == "batch":
+            logs = logs or {}
+            metrics_str = " - ".join(
+                [
+                    f"{k}: {v:.6f}" if isinstance(v, float) else f"{k}: {v}"
+                    for k, v in logs.items()
+                ]
+            )
+            logging.info(f"Batch {batch}: {metrics_str}")
nextrec/basic/features.py CHANGED
@@ -33,6 +33,8 @@ class SequenceFeature(BaseFeature):
         l1_reg: float = 0.0,
         l2_reg: float = 1e-5,
         trainable: bool = True,
+        pretrained_weight: torch.Tensor | None = None,
+        freeze_pretrained: bool = False,
     ):
         self.name = name
         self.vocab_size = vocab_size
@@ -47,6 +49,8 @@ class SequenceFeature(BaseFeature):
         self.l1_reg = l1_reg
         self.l2_reg = l2_reg
         self.trainable = trainable
+        self.pretrained_weight = pretrained_weight
+        self.freeze_pretrained = freeze_pretrained
 
 
 class SparseFeature(BaseFeature):
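The two new SequenceFeature fields carry a pretrained embedding table and a freeze flag; how nextrec's embedding layer consumes them is not part of this hunk. As an illustration only, the snippet below shows the standard PyTorch pattern such a (pretrained_weight, freeze_pretrained) pair typically maps to; the vocabulary size and dimensions are made up.

import torch
import torch.nn as nn

# Hypothetical pretrained item vectors (row i = embedding of token id i),
# e.g. exported from a word2vec or matrix-factorization run.
vocab_size, embed_dim = 10000, 32
pretrained_weight = torch.randn(vocab_size, embed_dim)

# freeze_pretrained=True would correspond to freeze=True here,
# keeping the table fixed during training.
embedding = nn.Embedding.from_pretrained(pretrained_weight, freeze=True)

token_ids = torch.randint(0, vocab_size, (4, 20))  # [batch, seq_len]
sequence_embeddings = embedding(token_ids)         # [4, 20, 32]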
nextrec/basic/layers.py CHANGED
@@ -496,12 +496,18 @@ class HadamardInteractionLayer(nn.Module):
 
 
 class MultiHeadSelfAttention(nn.Module):
+    """
+    Multi-Head Self-Attention layer with Flash Attention support.
+    Uses PyTorch 2.0+ scaled_dot_product_attention when available for better performance.
+    """
+
     def __init__(
         self,
         embedding_dim: int,
         num_heads: int = 2,
         dropout: float = 0.0,
         use_residual: bool = True,
+        use_layer_norm: bool = False,
     ):
         super().__init__()
         if embedding_dim % num_heads != 0:
@@ -512,45 +518,100 @@ class MultiHeadSelfAttention(nn.Module):
         self.num_heads = num_heads
         self.head_dim = embedding_dim // num_heads
         self.use_residual = use_residual
+        self.dropout_rate = dropout
+
         self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_V = nn.Linear(embedding_dim, embedding_dim, bias=False)
+        self.W_O = nn.Linear(embedding_dim, embedding_dim, bias=False)
+
         if self.use_residual:
             self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
+        if use_layer_norm:
+            self.layer_norm = nn.LayerNorm(embedding_dim)
+        else:
+            self.layer_norm = None
+
         self.dropout = nn.Dropout(dropout)
+        # Check if Flash Attention is available
+        self.use_flash_attention = hasattr(F, "scaled_dot_product_attention")
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        batch_size, num_fields, _ = x.shape
-        Q = self.W_Q(x)  # [batch_size, num_fields, embedding_dim]
+    def forward(
+        self, x: torch.Tensor, attention_mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: [batch_size, seq_len, embedding_dim]
+            attention_mask: [batch_size, seq_len] or [batch_size, seq_len, seq_len], boolean mask where True indicates valid positions
+        Returns:
+            output: [batch_size, seq_len, embedding_dim]
+        """
+        batch_size, seq_len, _ = x.shape
+        Q = self.W_Q(x)  # [batch_size, seq_len, embedding_dim]
         K = self.W_K(x)
         V = self.W_V(x)
-        # Split into multiple heads: [batch_size, num_heads, num_fields, head_dim]
-        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
-            1, 2
-        )
-        K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
-            1, 2
-        )
-        V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
-            1, 2
-        )
-        # Attention scores
-        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
-        attention_weights = F.softmax(scores, dim=-1)
-        attention_weights = self.dropout(attention_weights)
-        attention_output = torch.matmul(
-            attention_weights, V
-        )  # [batch_size, num_heads, num_fields, head_dim]
+
+        # Split into multiple heads: [batch_size, num_heads, seq_len, head_dim]
+        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        if self.use_flash_attention:
+            # Use PyTorch 2.0+ Flash Attention
+            if attention_mask is not None:
+                # Convert mask to [batch_size, 1, seq_len, seq_len] format
+                if attention_mask.dim() == 2:
+                    # [B, L] -> [B, 1, 1, L]
+                    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+                elif attention_mask.dim() == 3:
+                    # [B, L, L] -> [B, 1, L, L]
+                    attention_mask = attention_mask.unsqueeze(1)
+            attention_output = F.scaled_dot_product_attention(
+                Q,
+                K,
+                V,
+                attn_mask=attention_mask,
+                dropout_p=self.dropout_rate if self.training else 0.0,
+            )
+            # Handle potential NaN values
+            attention_output = torch.nan_to_num(attention_output, nan=0.0)
+        else:
+            # Fallback to standard attention
+            scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
+
+            if attention_mask is not None:
+                # Process mask for standard attention
+                if attention_mask.dim() == 2:
+                    # [B, L] -> [B, 1, 1, L]
+                    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+                elif attention_mask.dim() == 3:
+                    # [B, L, L] -> [B, 1, L, L]
+                    attention_mask = attention_mask.unsqueeze(1)
+                scores = scores.masked_fill(~attention_mask, float("-1e9"))
+
+            attention_weights = F.softmax(scores, dim=-1)
+            attention_weights = self.dropout(attention_weights)
+            attention_output = torch.matmul(
+                attention_weights, V
+            )  # [batch_size, num_heads, seq_len, head_dim]
+
         # Concatenate heads
         attention_output = attention_output.transpose(1, 2).contiguous()
         attention_output = attention_output.view(
-            batch_size, num_fields, self.embedding_dim
+            batch_size, seq_len, self.embedding_dim
         )
+
+        # Output projection
+        output = self.W_O(attention_output)
+
         # Residual connection
         if self.use_residual:
-            output = attention_output + self.W_Res(x)
-        else:
-            output = attention_output
+            output = output + self.W_Res(x)
+
+        # Layer normalization
+        if self.layer_norm is not None:
+            output = self.layer_norm(output)
+
         output = F.relu(output)
         return output
 
@@ -653,3 +714,21 @@ class AttentionPoolingLayer(nn.Module):
         # Weighted sum over keys: (B, L, 1) * (B, L, D) -> (B, D)
         output = torch.sum(attention_weights * keys, dim=1)
         return output
+
+
+class RMSNorm(torch.nn.Module):
+    """
+    Root Mean Square Layer Normalization.
+    Reference: https://arxiv.org/abs/1910.07467
+    """
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # RMS(x) = sqrt(mean(x^2) + eps)
+        variance = torch.mean(x**2, dim=-1, keepdim=True)
+        x_normalized = x * torch.rsqrt(variance + self.eps)
+        return self.weight * x_normalized
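A quick smoke test of the reworked attention layer and the new RMSNorm, assuming both are imported from nextrec.basic.layers as defined above; the batch size, sequence length, and padding pattern are arbitrary.

import torch
from nextrec.basic.layers import MultiHeadSelfAttention, RMSNorm

batch_size, seq_len, dim = 4, 16, 64
x = torch.randn(batch_size, seq_len, dim)

# Boolean padding mask: True marks valid positions, last 4 steps are padding.
mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
mask[:, -4:] = False

attention = MultiHeadSelfAttention(
    embedding_dim=dim, num_heads=4, dropout=0.1, use_layer_norm=True
)
output = attention(x, attention_mask=mask)  # [4, 16, 64]

norm = RMSNorm(hidden_size=dim)
print(output.shape, norm(output).shape)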