adv-optm 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
adv_optm/Simplified_AdEMAMix.py ADDED
@@ -0,0 +1,298 @@
+ import torch
+ from typing import Optional, Callable
+
+ import math
+
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
+ from ..util.Effective_Shape import _get_effective_shape
+ from ..util.NNMF import _nnmf, _unnmf
+ from ..util.OrthoGrad import _orthogonalize_gradient
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+ from ..util.Kourkoutas import KourkoutasHelper
+
+ # A little helper from the original simplified_AdEMAMix
+ def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
+
+     def f(beta, eps=1e-8):
+         return math.log(0.5) / math.log(beta + eps) - 1
+
+     def f_inv(t):
+         return math.pow(0.5, 1 / (t + 1))
+
+     if step < warmup:
+         a = step / float(warmup)
+         return f_inv((1.0 - a) * f(beta_start) + a * f(beta_end))
+     return beta_end
+
+ class Simplified_AdEMAMix(torch.optim.Optimizer):
+     """
+     Implements the Simplified AdEMAMix algorithm.
+     Refactored from:
+     https://github.com/DepenM/Simplified-AdEMAMix/blob/main/simplified_AdEMAMix.py
+
+     Args:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups
+         lr (float): learning rate (default: 1e-5)
+         betas (tuple[float, float]): coefficients used for computing running
+             averages of the gradient and its square (default: (0.99, 0.999))
+         eps (float): term added to the denominator to improve
+             numerical stability (default: 1e-8)
+         weight_decay (float): weight decay (L2 penalty) (default: 0.0)
+         alpha_grad (float): coefficient for mixing the current gradient with the EMA.
+             For small batch sizes set it to a high value, up to 100; for large batch
+             sizes set it to a small value, down to 0. (default: 100)
+         beta1_warmup (int, optional): number of warmup steps used to increase beta1 (default: None)
+         min_beta1 (float, optional): minimum value of beta1 to start from (default: 0.9)
+         use_bias_correction (bool): whether to apply bias correction to the update (default: True)
+         vector_reshape (bool): whether to reshape 1D vectors into 2D
+             matrices to apply low-rank compression (default: True)
+         stochastic_rounding (bool): whether to use stochastic
+             rounding for BF16 parameter updates (default: True)
+         orthogonal_gradient (bool): whether to use OrthoGrad (default: False)
+         kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+             If `False`, the optimizer behaves as standard Simplified_AdEMAMix. (default: False)
+         beta2_min (float): the minimum value for dynamic β₂, used during periods of
+             high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+             (default: 0.9)
+         ema_alpha (float): the decay rate for the exponential moving average (EMA) of
+             the pooled gradient norms. Corresponds to `α` in the paper.
+             (default: 0.95)
+         tiny_spike (float): a small constant added to the denominator of the
+             "sunspike" ratio calculation to prevent division by zero. Corresponds
+             to `ε_spike` in the paper. (default: 1e-9)
+         k_warmup_steps (int): the number of initial steps during which β₂ is held
+             fixed at `betas[1]` before the dynamic logic activates. (default: 0)
+         k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+             logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+             every `k_logging` steps. Useful for debugging and tuning. Set to 0 to
+             disable logging. (default: 0)
+         layer_key_fn (Optional[Callable]): a function that takes a parameter `p`
+             and returns a unique, hashable key representing its "layer" or "bucket".
+             If `None`, a coarse shape-based default bucketing is used (see
+             KourkoutasHelper). (default: None)
+         nnmf_factor (bool): whether to compress optimizer state with a rank-1
+             non-negative matrix factorization; if `False`, the uncompressed
+             optimizer is used. (default: False)
+     """
+
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-5,
+         betas: tuple[float, float] = (0.99, 0.999),
+         eps: float = 1e-8,
+         weight_decay: float = 0.0,
+         alpha_grad: float = 100.0,
+         beta1_warmup: int | None = None,
+         min_beta1: float | None = 0.9,
+         use_bias_correction: bool = True,
+         vector_reshape: bool = True,
+         stochastic_rounding: bool = True,
+         orthogonal_gradient: bool = False,
+         kourkoutas_beta: bool = False,
+         beta2_min: float = 0.9,
+         ema_alpha: float = 0.95,
+         tiny_spike: float = 1e-9,
+         k_warmup_steps: int = 0,
+         k_logging: int = 0,
+         layer_key_fn: Optional[Callable] = None,
+         nnmf_factor: bool = False,
+     ):
+         if not (lr >= 0.0):
+             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
+         if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
+             raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
+         if not (eps >= 0.0):
+             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
+         if not (weight_decay >= 0.0):
+             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+         if not (alpha_grad >= 0.0):
+             raise ValueError(f"Invalid alpha value: {alpha_grad}")
+         if kourkoutas_beta and not (betas[1] > beta2_min):
+             raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+
+         defaults = {
+             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
+             "alpha_grad": alpha_grad, "beta1_warmup": beta1_warmup, "min_beta1": min_beta1,
+             "vector_reshape": vector_reshape,
+             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
+             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+             "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+         }
+         self.stochastic_rounding = stochastic_rounding
+         self.factored = nnmf_factor
+         self.kourkoutas_beta = kourkoutas_beta
+         self.layer_key_fn = layer_key_fn
+         super().__init__(params, defaults)
+
+         if self.kourkoutas_beta:
+             self.kourkoutas_helper = KourkoutasHelper(self)
+
+     @property
+     def supports_fused_back_pass(self):
+         return True
+
+     @property
+     def supports_memory_efficient_fp16(self):
+         return True
+
+     @property
+     def supports_flat_params(self):
+         return False
+
+     @torch.no_grad()
+     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+         if p.grad is None:
+             return
+
+         grad = p.grad
+         if grad.dtype != torch.float32 and self.factored:
+             grad = grad.float()
+         if group["orthogonal_gradient"]:
+             grad = _orthogonalize_gradient(p, grad)
+         state = self.state[p]
+
+         # State Initialization
+         if 'step' not in state:
+             state['step'] = 0
+
+             should_factor = (
+                 self.factored and
+                 not (len(p.shape) == 1 and not group['vector_reshape'])
+             )
+
+             state['factored'] = should_factor
+
+             dtype = torch.float32 if self.factored else p.dtype
+             device = p.device
+
+             if state['factored']:
+                 state['effective_shape'] = _get_effective_shape(p.numel())
+                 d1, d2 = state['effective_shape']
+
+                 # First moment (m)
+                 state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                 state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+                 packed_d2 = (d2 + 7) // 8
+                 state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
+                 # Second moment (v)
+                 state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
+             else:  # Fallback to standard optimizer for non-factored tensors
+                 state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
+                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
+
+             if group['use_bias_correction']:
+                 state['num_sum'] = 0.0
+                 state['den_sum'] = 0.0
+             else:
+                 state['num_sum'] = 1.0
+                 state['den_sum'] = 1.0
+
+         beta1_final, beta2 = group["betas"]
+
+         current_step = state['step']
+         if group.get('kourkoutas_beta', False):
+             # Call prepare_step() once at the beginning of the step for all params
+             self.kourkoutas_helper.maybe_prepare_step(current_step)
+             # Accumulate current grad's norm for the *next* step
+             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+             # Get the dynamic beta2 calculated in prepare_step()
+             beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
+         beta1_warmup = group["beta1_warmup"]
+         alpha_grad = group["alpha_grad"]
+
+         if beta1_warmup is not None:
+             step = state['step'] + 1
+             beta1 = linear_hl_warmup_scheduler(step, beta_end=beta1_final, beta_start=group['min_beta1'], warmup=beta1_warmup)
+         else:
+             beta1 = beta1_final
+
+         if group['use_bias_correction']:
+             state['num_sum'] = beta1 * state['num_sum'] + 1.0
+             if group.get('kourkoutas_beta', False):
+                 state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
+             else:
+                 state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)
+
+         if state['factored']:
+             d1, d2 = state['effective_shape']
+
+             # Reconstruct momentum from previous step's factors
+             mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+             unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+             torch.where(unpacked_sign, mt, -mt, out=mt)
+             del unpacked_sign
+             # Update momentum in full size
+             grad_reshaped = grad.view(d1, d2)
+             mt.mul_(beta1).add_(grad_reshaped, alpha=1.0)
+
+             vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
+             vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
+
+             update = torch.add(mt, grad_reshaped, alpha=alpha_grad)
+             del grad_reshaped
+
+             denom = vt.sqrt().add_(group['eps'] * math.sqrt(state['den_sum']))
+             update.div_(denom)
+             del denom
+
+             if group['use_bias_correction']:
+                 update = (update / state['num_sum']) * math.sqrt(state['den_sum'])
+
+             update = update.view(p.shape).mul_(group['lr'])
+
+             # Compress updated moments and store new factors
+             state['sign'] = _pack_bools(mt > 0)
+             _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+             del mt
+             _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+             del vt
+
+         else:  # Standard optimizer logic for non-factored tensors
+             exp_avg_sq = state['exp_avg_sq']
+
+             exp_avg = state['exp_avg']
+             exp_avg.mul_(beta1).add_(grad, alpha=1.0)
+
+             update = torch.add(exp_avg, grad, alpha=alpha_grad)
+
+             exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+             denom = exp_avg_sq.sqrt().add_(group['eps'] * math.sqrt(state['den_sum']))
+             update.div_(denom)
+             del denom
+
+             if group['use_bias_correction']:
+                 update = (update / state['num_sum']) * math.sqrt(state['den_sum'])
+
+             update.mul_(group['lr'])
+
+         # Decoupled weight decay
+         if group["weight_decay"] != 0:
+             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+                 add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+             else:
+                 p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+
+         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
+             add_stochastic_(p.data, -update)
+         else:
+             p.data.add_(-update)
+         del update
+
+         state['step'] += 1
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         """Performs a single optimization step."""
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             for i, p in enumerate(group['params']):
+                 self.step_parameter(p, group, i)
+
+         return loss
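
For orientation, here is a minimal usage sketch of the class added above. It is not part of the package; the `adv_optm` import path is inferred from the file paths in this diff, and the hyperparameters are illustrative rather than recommended values.

    import torch
    from adv_optm import Simplified_AdEMAMix

    model = torch.nn.Linear(128, 10)
    optimizer = Simplified_AdEMAMix(
        model.parameters(),
        lr=1e-5,
        betas=(0.99, 0.999),
        alpha_grad=100.0,       # high value, suited to small batches (see docstring)
        kourkoutas_beta=True,   # enable the layer-wise dynamic beta2
        nnmf_factor=False,      # keep the uncompressed optimizer state
    )

    for _ in range(10):
        x = torch.randn(32, 128)
        loss = model(x).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
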
adv_optm/__init__.py ADDED
@@ -0,0 +1,19 @@
+ from .AdamW_adv import AdamW_adv
+ from .Prodigy_adv import Prodigy_adv
+ from .Adopt_adv import Adopt_adv
+ from .Simplified_AdEMAMix import Simplified_AdEMAMix
+ from .Lion_adv import Lion_adv
+ from .Lion_Prodigy_adv import Lion_Prodigy_adv
+ from .Muon_adv import Muon_adv
+ from .AdaMuon_adv import AdaMuon_adv
+
+ __all__ = [
+     "AdamW_adv",
+     "Prodigy_adv",
+     "Adopt_adv",
+     "Simplified_AdEMAMix",
+     "Lion_adv",
+     "Lion_Prodigy_adv",
+     "Muon_adv",
+     "AdaMuon_adv",
+ ]
adv_optm/util/BF16_Stochastic_Rounding.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ from torch import Tensor
+
+ def copy_stochastic_(target: Tensor, source: Tensor):
+     """
+     Nerogar's implementation of the stochastic rounding described in
+     "Revisiting BFloat16 Training" (https://arxiv.org/abs/2010.06192).
+     see:
+     https://github.com/pytorch/pytorch/issues/120376
+     https://github.com/Nerogar/OneTrainer/blob/daae18eaed8c0fa39289b2ff79cc2c1e08577fcb/modules/util/bf16_stochastic_rounding.py
+
+     Args:
+         target: the target tensor with dtype=bfloat16
+         source: the source tensor with dtype=float32
+     """
+     # create a random 16 bit integer
+     result = torch.randint_like(
+         source,
+         dtype=torch.int32,
+         low=0,
+         high=(1 << 16),
+     )
+
+     # add the random number to the lower 16 bits of the mantissa
+     result.add_(source.view(dtype=torch.int32))
+
+     # mask off the lower 16 bits of the mantissa
+     result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32
+
+     # copy the upper 16 bits into the target tensor
+     target.copy_(result.view(dtype=torch.float32))
+
+     del result
+
+ def add_stochastic_(input: Tensor, other: Tensor, alpha: float = 1.0):
+     """
+     Adds other to input using stochastic rounding.
+
+     Args:
+         input: the input tensor with dtype=bfloat16
+         other: the other tensor
+         alpha: a multiplier applied to input during the accumulation
+     """
+     result = other.clone() if other.dtype == torch.float32 else other.to(dtype=torch.float32)
+
+     result.add_(input, alpha=alpha)
+     copy_stochastic_(input, result)
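
To see why this matters for bf16 training, here is a small sketch that is not part of the package; it assumes the module path `adv_optm.util.BF16_Stochastic_Rounding` visible in the imports above. An increment of 1e-3 is below half of bf16's spacing at 1.0 (about 3.9e-3), so nearest rounding silently drops it, while stochastic rounding preserves it in expectation.

    import torch
    from adv_optm.util.BF16_Stochastic_Rounding import add_stochastic_

    torch.manual_seed(0)
    step = torch.full((10_000,), 1e-3, dtype=torch.float32)

    x_nearest = torch.ones(10_000, dtype=torch.bfloat16)
    x_sr = torch.ones(10_000, dtype=torch.bfloat16)
    for _ in range(1000):
        x_nearest.add_(step.to(torch.bfloat16))  # rounds back to 1.0 every time
        add_stochastic_(x_sr, step)              # keeps the increment in expectation

    print(x_nearest.float().mean().item())  # 1.0: the thousand updates were lost
    print(x_sr.float().mean().item())       # ~2.0: close to the true accumulated value
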
adv_optm/util/Effective_Shape.py ADDED
@@ -0,0 +1,8 @@
+ def _get_effective_shape(numel: int) -> tuple[int, int]:
+     """Finds two factors of numel that are closest to its square root."""
+     if numel <= 0:
+         return (0, 0)
+     for i in reversed(range(1, int(numel ** 0.5) + 1)):
+         if numel % i == 0:
+             return (numel // i, i)
+     return (numel, 1)
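
A few illustrative values, assuming `_get_effective_shape` is imported from `adv_optm.util.Effective_Shape` (path inferred from the imports above): a flattened tensor of 4096 elements becomes a square 64 x 64 matrix, while a prime-sized tensor degenerates to a single column.

    from adv_optm.util.Effective_Shape import _get_effective_shape

    print(_get_effective_shape(4096))  # (64, 64)
    print(_get_effective_shape(12))    # (4, 3)
    print(_get_effective_shape(97))    # (97, 1): 97 is prime, so no better factor exists
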
adv_optm/util/Kourkoutas.py ADDED
@@ -0,0 +1,159 @@
+ import torch
+ from torch.optim import Optimizer
+ from typing import Callable
+
+ class KourkoutasHelper:
+     """
+     A helper class that adds layer-wise Kourkoutas-β functionality to a PyTorch optimizer.
+     """
+     def __init__(self, optimizer: Optimizer):
+         # We need a reference to the optimizer to access its param_groups and state
+         if not hasattr(optimizer, 'param_groups'):
+             raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
+         self.optimizer = optimizer
+         self.layer_state = {}
+
+         self.layer_info = {}
+         self._layer_info_built = False
+         self._current_step_prepared = -1
+
+         # Store stats for external logging (e.g., TensorBoard)
+         self.last_beta2_stats = {}
+
+         # This ensures the map is complete before the first backward pass,
+         # making it compatible with fused back pass mechanisms.
+         self._build_layer_info_if_needed()
+
+     def _build_layer_info_if_needed(self):
+         """Builds a map of layers and the parameters they contain."""
+         if self._layer_info_built:
+             return
+
+         if hasattr(self.optimizer, 'layer_key_fn') and self.optimizer.layer_key_fn is not None:
+             # A custom key function was provided by the user. We will use it.
+             pass
+         else:
+             # No key function was provided. Default to coarse, shape-based bucketing.
+             self.optimizer.layer_key_fn = lambda p: \
+                 (id(p),) if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096} \
+                 else tuple(p.shape)
+             # The id-based branch keeps small embedding-like matrices (1 to 10 rows) in their
+             # own buckets instead of mixing them with other tensors of the same shape.
+             # TODO find a better way to safeguard the embeddings
+
+         for group in self.optimizer.param_groups:
+             if not group.get('kourkoutas_beta', False) and not group.get('adam_kourkoutas_beta', False):
+                 continue
+
+             for p in group['params']:
+                 # The mapping is static and should not depend on the presence of a gradient.
+                 layer_key = self.optimizer.layer_key_fn(p)
+                 if layer_key not in self.layer_info:
+                     self.layer_info[layer_key] = {'params': [], 'group_ref': group}
+                 self.layer_info[layer_key]['params'].append(p)
+
+         self._layer_info_built = True
+
+     def prepare_step(self, current_step: int):
+         """
+         Calculates dynamic beta2 for all layers using the completed scalar accumulators
+         from the PREVIOUS step. Should be called once at the start of an optimizer step.
+         """
+         beta2_log = []
+         # Initialized up front; these are only used for logging
+         sun, pooled_grad_norm, r_ema_tensor = (torch.tensor(0.0),) * 3
+
+         # The optimizer that owns this helper holds the master defaults for K-b.
+         # This is crucial in hybrid optimizers where some param_groups might not
+         # have all K-b keys populated, preventing KeyErrors.
+         master_defaults = self.optimizer.defaults
+
+         for layer_key, info in self.layer_info.items():
+             params, group = info['params'], info['group_ref']
+
+             if not group.get('kourkoutas_beta', False) and not group.get('adam_kourkoutas_beta', False):
+                 continue
+
+             first_param_in_layer = info['params'][0]
+             param_state = self.optimizer.state[first_param_in_layer]
+
+             if layer_key not in self.layer_state:
+                 self.layer_state[layer_key] = {
+                     'sum_sq_accumulator': torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)
+                 }
+
+             if 'kourkoutas_r_ema' not in param_state:
+                 param_state['kourkoutas_r_ema'] = torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)
+
+             # Use group-specific K-b settings, falling back to the optimizer's master defaults.
+             # This makes the helper robust against param groups that enable kourkoutas_beta
+             # but are missing the other required hyperparameters.
+             ema_alpha = group.get('ema_alpha', master_defaults['ema_alpha'])
+             beta2_max = group.get('betas', master_defaults['betas'])[1]
+             beta2_min = group.get('beta2_min', master_defaults['beta2_min'])
+             tiny_spike = group.get('tiny_spike', master_defaults['tiny_spike'])
+             k_warmup_steps = group.get('k_warmup_steps', master_defaults['k_warmup_steps'])
+
+             r_ema_tensor = param_state['kourkoutas_r_ema']
+             accumulator = self.layer_state[layer_key]['sum_sq_accumulator']
+
+             pooled_grad_norm = torch.sqrt(accumulator)
+
+             # Update the persistent EMA tensor in-place.
+             r_ema_tensor.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)
+
+             sun = torch.tensor(0.0, device=r_ema_tensor.device)  # Default sun to 0 for warmup
+
+             if current_step < k_warmup_steps:
+                 beta2 = beta2_max
+             else:
+                 raw = pooled_grad_norm / (r_ema_tensor + tiny_spike)
+                 sun = raw / (1.0 + raw)
+                 beta2 = beta2_max - (beta2_max - beta2_min) * sun
+
+             # Store the final calculated beta2 in the helper's transient state for this step.
+             self.layer_state[layer_key]['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
+
+             # Reset the accumulator for the next optimizer step.
+             accumulator.zero_()
+
+             beta2_log.append(self.layer_state[layer_key]['dynamic_beta2'])
+
+         # Always compute stats for TensorBoard
+         if beta2_log:
+             beta2_tensor = torch.tensor(beta2_log, device='cpu')
+             self.last_beta2_stats = {
+                 'mean': beta2_tensor.mean().item(),
+             }
+
+     def maybe_prepare_step(self, current_step: int):
+         """
+         A universal guard that calls prepare_step() exactly once per training step.
+         """
+         if self._current_step_prepared < current_step:
+             self.prepare_step(current_step)
+             self._current_step_prepared = current_step
+
+     def accumulate_gradient_sq_norm(self, p: torch.Tensor, grad: torch.Tensor):
+         """
+         Accumulates the squared L2 norm of a single gradient for the next step's calculation.
+         """
+         layer_key = self.optimizer.layer_key_fn(p)
+
+         if layer_key in self.layer_info:
+             # Initialize the transient state for this layer if it's the first time in the step.
+             if layer_key not in self.layer_state:
+                 self.layer_state[layer_key] = {
+                     'sum_sq_accumulator': torch.tensor(0.0, device=p.device, dtype=torch.float32)
+                 }
+             # Accumulate for the *next* step's prepare_step call
+             self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()
+
+     def get_beta2(self, p: torch.Tensor, group: dict, current_step: int) -> float:
+         """
+         Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
+         """
+         layer_key = self.optimizer.layer_key_fn(p)
+         # The default is the max value, which is correct for unmapped params or edge cases
+         beta2_default = group.get('betas', group.get('adam_betas'))[1] if group.get('betas', group.get('adam_betas')) else 0.999
+         return self.layer_state.get(layer_key, {}).get('dynamic_beta2', beta2_default)
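
The mapping from gradient activity to β₂ can be read directly off prepare_step(); the following standalone arithmetic sketch (illustrative numbers, not from the package) shows both a spiky and a calm layer.

    # beta2 = beta2_max - (beta2_max - beta2_min) * sun, where sun = raw / (1 + raw)
    # and raw compares the layer's pooled gradient norm against its EMA.
    beta2_max, beta2_min, tiny_spike = 0.999, 0.9, 1e-9

    def dynamic_beta2(pooled_grad_norm, r_ema):
        raw = pooled_grad_norm / (r_ema + tiny_spike)
        sun = raw / (1.0 + raw)
        return beta2_max - (beta2_max - beta2_min) * sun

    print(dynamic_beta2(2.0, 1.0))   # ~0.933: a "sunspike" pulls beta2 toward beta2_min
    print(dynamic_beta2(0.25, 1.0))  # ~0.979: a calm layer stays near beta2_max
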
adv_optm/util/NNMF.py ADDED
@@ -0,0 +1,18 @@
+ import torch
+
+ def _unnmf(row_col: tuple) -> torch.Tensor:
+     """Reconstructs a matrix from its rank-1 factors (outer product)."""
+     return torch.outer(row_col[0], row_col[1])
+
+ def _nnmf(matrix: torch.Tensor, out: tuple):
+     """Performs a rank-1 non-negative matrix factorization."""
+     shape = matrix.shape
+     torch.sum(matrix, dim=1, out=out[0])
+     torch.sum(matrix, dim=0, out=out[1])
+     # Normalize one of the factors for stability
+     if shape[0] < shape[1]:
+         scale = out[0].sum()
+         if scale != 0:
+             out[0].div_(scale)
+     else:
+         scale = out[1].sum()
+         if scale != 0:
+             out[1].div_(scale)
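
A short round-trip check of the two helpers above (not part of the package; the import path comes from the file name). Because the factors are just row and column sums with one normalization, a non-negative rank-1 matrix is reconstructed exactly; higher-rank matrices are only approximated.

    import torch
    from adv_optm.util.NNMF import _nnmf, _unnmf

    u = torch.tensor([1.0, 2.0, 3.0])
    v = torch.tensor([0.5, 1.5])
    M = torch.outer(u, v)                         # non-negative rank-1 matrix, shape (3, 2)

    row = torch.empty(3)
    col = torch.empty(2)
    _nnmf(M, out=(row, col))                      # factor into row/column sums
    print(torch.allclose(_unnmf((row, col)), M))  # True: exact for rank-1 input
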
@@ -0,0 +1,87 @@
+ import torch
+
+ @torch.no_grad()
+ def _newton_schulz_iteration(
+     G: torch.Tensor,
+     steps: int = 5,
+     eps: float = 1e-7,
+     coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
+     cns: bool = False,
+     cns_a_bound: float = 1e-4,
+ ) -> torch.Tensor:
+     """
+     Performs the Newton-Schulz iteration to find the nearest orthogonal matrix.
+     This is the core computation of the Muon optimizer.
+
+     Args:
+         G (torch.Tensor): The 2D input matrix (momentum-accumulated gradient).
+         steps (int): The number of iterations to run.
+         eps (float): Small constant for numerical stability during normalization.
+         coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
+             quintic polynomial update.
+         cns (bool): If True, enables Chebyshev-accelerated Newton-Schulz (CANS)
+             using an iterative 3rd-order polynomial with optimal coefficients
+             derived at each step.
+         cns_a_bound (float): The initial lower bound for singular values when
+             using CANS. The upper bound is assumed to be 1.0 after normalization.
+
+     Returns:
+         torch.Tensor: The orthogonalized matrix.
+     """
+     assert G.ndim >= 2
+
+     a, b, c = coeffs
+
+     X = G.to(torch.bfloat16)
+
+     transposed = G.size(-2) > G.size(-1)
+     if transposed:
+         X = X.mT
+
+     X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
+
+     if cns:
+         # Chebyshev-accelerated Newton-Schulz (CANS) from
+         # "Accelerating Newton-Schulz Iteration for Orthogonalization via Chebyshev-type Polynomials".
+         # This implements the iterative scheme from Algorithm 1, using the
+         # closed-form 3rd-order polynomial from Proposition 2.
+         lower_bound = cns_a_bound
+         upper_bound = 1.0  # Matrix is normalized, so largest singular value is approx 1.
+
+         for _ in range(steps):
+             # Calculate optimal 3rd-order coefficients c1, c3 for p(x) = c1*x + c3*x^3
+             # based on the current singular value bounds [lower_bound, upper_bound].
+             # Formulas are derived from Proposition 2 and its proof in Appendix B of the paper.
+             a_bound, b_bound = lower_bound, upper_bound
+             term = a_bound*a_bound + a_bound*b_bound + b_bound*b_bound
+             e_sq = term / 3.0
+
+             # Calculate alpha, which scales the polynomial
+             common_den_part = 2.0 * (e_sq**1.5)
+             ab_part = a_bound*a_bound*b_bound + b_bound*b_bound*a_bound
+             alpha_den = common_den_part + ab_part
+             alpha = 6.0 / alpha_den
+
+             c1 = alpha * e_sq
+             c3 = -alpha / 3.0
+
+             # Apply the 3rd-order Newton-Schulz update
+             A = X @ X.mT
+             X = c1 * X + c3 * (A @ X)
+
+             # Update the singular value bounds for the next iteration based on the error
+             eps_num = common_den_part - ab_part
+             eps_val = eps_num / alpha_den
+             lower_bound = 1.0 - eps_val
+             upper_bound = 1.0 + eps_val
+     else:
+         # Perform the iterative quintic updates
+         for _ in range(steps):
+             A = X @ X.mT
+             B = b * A + c * (A @ A)
+             X = a * X + B @ X
+
+     # Transpose back if necessary
+     if transposed:
+         X = X.mT
+
+     return X.to(G.dtype)
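
A quick sanity check, not from the package (the module name for this file is not shown in the diff, so the function is assumed to be in scope): after a few iterations the singular values of the output cluster around 1, though the aggressive default coefficients and the bf16 compute keep the result only approximately orthonormal.

    import torch

    G = torch.randn(256, 128)
    X = _newton_schulz_iteration(G, steps=5)

    s = torch.linalg.svdvals(X.float())
    print(s.min().item(), s.max().item())  # both near 1.0, not exactly 1
    print(X.shape == G.shape)              # True: the orthogonalized update keeps G's shape
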
adv_optm/util/One_Bit_Boolean.py ADDED
@@ -0,0 +1,22 @@
+ import torch
+
+ @torch.no_grad()
+ def _pack_bools(tensor: torch.Tensor) -> torch.Tensor:
+     """Packs a boolean tensor into a uint8 tensor to achieve 1-bit storage."""
+     n, m = tensor.shape
+     packed_m = (m + 7) // 8
+     padded_tensor = torch.nn.functional.pad(tensor, (0, packed_m * 8 - m), 'constant', 0)
+     reshaped = padded_tensor.view(n, packed_m, 8)
+     shifter = torch.arange(8, device=tensor.device, dtype=torch.uint8)
+     packed = (reshaped.to(torch.uint8) * (2**shifter)).sum(dim=2).to(torch.uint8)
+     return packed
+
+ @torch.no_grad()
+ def _unpack_bools(packed_tensor: torch.Tensor, original_m: int) -> torch.Tensor:
+     """Unpacks a uint8 tensor back into a boolean tensor."""
+     if packed_tensor.dtype != torch.uint8:
+         packed_tensor = packed_tensor.to(torch.uint8)
+     shifter = (2**torch.arange(8, device=packed_tensor.device, dtype=torch.uint8)).view(1, 1, 8)
+     unpacked_padded = (packed_tensor.unsqueeze(2) & shifter) != 0
+     unpacked = unpacked_padded.view(packed_tensor.shape[0], -1)[:, :original_m]
+     return unpacked
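
A round-trip sketch for the helpers above (import path taken from the earlier `..util.One_Bit_Boolean` import; not part of the package). Ten sign bits per row are stored in two bytes and recovered losslessly.

    import torch
    from adv_optm.util.One_Bit_Boolean import _pack_bools, _unpack_bools

    signs = torch.rand(4, 10) > 0.5                  # boolean matrix, shape (4, 10)
    packed = _pack_bools(signs)                      # uint8 matrix, shape (4, 2)
    restored = _unpack_bools(packed, original_m=10)  # boolean matrix, shape (4, 10)

    print(packed.shape)                  # torch.Size([4, 2])
    print(torch.equal(signs, restored))  # True: the round trip is lossless
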
adv_optm/util/OrthoGrad.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+
+ def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
+     """Projects the gradient `grad` to be orthogonal to the parameter `p`."""
+     if grad.is_sparse:
+         raise RuntimeError("OrthoGrad logic does not support sparse gradients.")
+     original_shape = grad.shape
+     original_dtype = grad.dtype
+     w = p.view(-1).float()
+     g = grad.view(-1).float()
+     w_norm_sq = torch.dot(w, w).add_(1e-30)
+     proj = torch.dot(w, g) / w_norm_sq
+     g_orth = g.sub(w, alpha=proj)
+     g_norm = g.norm(2)
+     g_orth_norm = g_orth.norm(2).add_(1e-30)
+     g_orth_scaled = g_orth * (g_norm / g_orth_norm)
+     return g_orth_scaled.view(original_shape).to(original_dtype)
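
A small check, not part of the package (import path taken from the earlier `..util.OrthoGrad` import), that the projected gradient is orthogonal to the weights while keeping the original gradient norm.

    import torch
    from adv_optm.util.OrthoGrad import _orthogonalize_gradient

    p = torch.randn(64, 64)
    g = torch.randn(64, 64)
    g_orth = _orthogonalize_gradient(p, g)

    print(torch.dot(p.view(-1), g_orth.view(-1)).item())  # close to 0 relative to the norms
    print(g.norm().item(), g_orth.norm().item())          # nearly equal: the projection is rescaled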