adv-optm 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adv_optm/__init__.py +23 -0
- adv_optm/optim/AdaMuon_adv.py +720 -0
- adv_optm/optim/AdamW_adv.py +374 -0
- adv_optm/optim/Adopt_adv.py +437 -0
- adv_optm/optim/Lion_Prodigy_adv.py +341 -0
- adv_optm/optim/Lion_adv.py +210 -0
- adv_optm/optim/Muon_adv.py +723 -0
- adv_optm/optim/Prodigy_adv.py +539 -0
- adv_optm/optim/Simplified_AdEMAMix.py +298 -0
- adv_optm/optim/__init__.py +19 -0
- adv_optm/util/BF16_Stochastic_Rounding.py +47 -0
- adv_optm/util/Effective_Shape.py +8 -0
- adv_optm/util/Kourkoutas.py +159 -0
- adv_optm/util/NNMF.py +18 -0
- adv_optm/util/Newton_Schulz.py +87 -0
- adv_optm/util/One_Bit_Boolean.py +22 -0
- adv_optm/util/OrthoGrad.py +16 -0
- adv_optm/util/__init__.py +13 -0
- adv_optm-1.2.0.dist-info/METADATA +222 -0
- adv_optm-1.2.0.dist-info/RECORD +23 -0
- adv_optm-1.2.0.dist-info/WHEEL +5 -0
- adv_optm-1.2.0.dist-info/licenses/LICENSE +201 -0
- adv_optm-1.2.0.dist-info/top_level.txt +1 -0
adv_optm/optim/Simplified_AdEMAMix.py
ADDED
@@ -0,0 +1,298 @@
import torch
from typing import Optional, Callable

import math

from ..util.BF16_Stochastic_Rounding import add_stochastic_
from ..util.Effective_Shape import _get_effective_shape
from ..util.NNMF import _nnmf, _unnmf
from ..util.OrthoGrad import _orthogonalize_gradient
from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
from ..util.Kourkoutas import KourkoutasHelper

# A little helper from the original simplified_AdEMAMix
def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):

    def f(beta, eps=1e-8):
        return math.log(0.5)/math.log(beta+eps)-1

    def f_inv(t):
        return math.pow(0.5, 1/(t+1))

    if step < warmup:
        a = step / float(warmup)
        return f_inv((1.0-a) * f(beta_start) + a * f(beta_end))
    return beta_end

class Simplified_AdEMAMix(torch.optim.Optimizer):
    """
    Implements the Simplified AdEMAMix algorithm.
    Refactored from:
    https://github.com/DepenM/Simplified-AdEMAMix/blob/main/simplified_AdEMAMix.py

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate (default: 1e-5)
        betas (tuple[float, float]): coefficients used for computing running
            averages of gradient and its square (default: (0.99, 0.999))
        eps (float): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float): weight decay (L2 penalty) (default: 0)
        alpha_grad (float): coefficient for mixing the current gradient and the EMA.
            For small batch sizes set it to a high value (up to 100); for large
            batch sizes set it to a small value (down to 0). (default: 100)
        beta1_warmup (int, optional): number of warmup steps used to increase beta1 (default: None)
        min_beta1 (float, optional): minimum value of beta1 to start from (default: 0.9)
        use_bias_correction (bool): whether to apply bias correction to the update (default: True)
        vector_reshape (bool): whether to reshape 1D vectors into 2D
            matrices to apply low-rank compression (default: True)
        stochastic_rounding (bool): whether to use stochastic
            rounding for BF16 parameter updates (default: True)
        orthogonal_gradient (bool): whether to use OrthoGrad (default: False)
        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
            If `False`, the optimizer behaves as standard Simplified_AdEMAMix. (default: False)
        beta2_min (float): the minimum value for dynamic β₂, used during periods of
            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
            (default: 0.9)
        ema_alpha (float): the decay rate for the Exponential Moving Average (EMA) of
            the pooled gradient norms. Corresponds to `α` in the paper. (default: 0.95)
        tiny_spike (float): a small constant added to the denominator of the
            "sunspike" ratio calculation to prevent division by zero. Corresponds
            to `ε_spike` in the paper. (default: 1e-9)
        k_warmup_steps (int): the number of initial steps during which β₂ is held
            at a fixed beta2 value before the dynamic logic activates. (default: 0)
        k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
            every `k_logging` steps. Useful for debugging and tuning. Set to 0 to
            disable logging. (default: 0)
        layer_key_fn (Optional[Callable]): a function that takes a parameter `p`
            and returns a unique, hashable key representing its "layer" or "bucket".
            If `None`, parameters are bucketed by their memory ID (tensor-wise).
            (default: None)
        nnmf_factor (bool): whether to use the factorization or disable it to use
            the uncompressed optimizer. (default: False)
    """

    def __init__(
        self,
        params,
        lr: float = 1e-5,
        betas: tuple[float, float] = (0.99, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        alpha_grad: float = 100.0,
        beta1_warmup: int | None = None,
        min_beta1: float | None = 0.9,
        use_bias_correction: bool = True,
        vector_reshape: bool = True,
        stochastic_rounding: bool = True,
        orthogonal_gradient: bool = False,
        kourkoutas_beta: bool = False,
        beta2_min: float = 0.9,
        ema_alpha: float = 0.95,
        tiny_spike: float = 1e-9,
        k_warmup_steps: int = 0,
        k_logging: int = 0,
        layer_key_fn: Optional[Callable] = None,
        nnmf_factor: bool = False,
    ):
        if not (lr >= 0.0):
            raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
        if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
            raise ValueError(f"Betas should be in [0.0, 1.0). Got {betas}")
        if not (eps >= 0.0):
            raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
        if not (weight_decay >= 0.0):
            raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
        if not 0.0 <= alpha_grad:
            raise ValueError("Invalid alpha value: {}".format(alpha_grad))
        if kourkoutas_beta and not (betas[1] > beta2_min):
            raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")

        defaults = {
            "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
            "alpha_grad": alpha_grad, "beta1_warmup": beta1_warmup, "min_beta1": min_beta1,
            "vector_reshape": vector_reshape,
            "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
        }
        self.stochastic_rounding = stochastic_rounding
        self.factored = nnmf_factor
        self.kourkoutas_beta = kourkoutas_beta
        self.layer_key_fn = layer_key_fn
        super().__init__(params, defaults)

        if self.kourkoutas_beta:
            self.kourkoutas_helper = KourkoutasHelper(self)

    @property
    def supports_fused_back_pass(self):
        return True

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return False

    @torch.no_grad()
    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
        if p.grad is None:
            return

        grad = p.grad
        if grad.dtype != torch.float32 and self.factored:
            grad = grad.float()
        if group["orthogonal_gradient"]:
            grad = _orthogonalize_gradient(p, grad)
        state = self.state[p]

        # State Initialization
        if 'step' not in state:
            state['step'] = 0

            should_factor = (
                self.factored and
                not (len(p.shape) == 1 and not group['vector_reshape'])
            )

            state['factored'] = should_factor

            dtype = torch.float32 if self.factored else p.dtype
            device = p.device

            if state['factored']:
                state['effective_shape'] = _get_effective_shape(p.numel())
                d1, d2 = state['effective_shape']

                # First moment (m)
                state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                packed_d2 = (d2 + 7) // 8
                state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
                # Second moment (v)
                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
            else:  # Fallback to standard optimizer for non-factored tensors
                state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
                state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)

            if group['use_bias_correction']:
                state['num_sum'] = 0.0
                state['den_sum'] = 0.0
            else:
                state['num_sum'] = 1.0
                state['den_sum'] = 1.0

        beta1_final, beta2 = group["betas"]

        current_step = state['step']
        if group.get('kourkoutas_beta', False):
            # Call prepare_step() once at the beginning of the step for all params
            self.kourkoutas_helper.maybe_prepare_step(current_step)
            # Accumulate current grad's norm for the *next* step
            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
            # Get the dynamic beta2 calculated in prepare_step()
            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)

        beta1_warmup = group["beta1_warmup"]
        alpha_grad = group["alpha_grad"]

        if beta1_warmup is not None:
            step = state['step'] + 1
            beta1 = linear_hl_warmup_scheduler(step, beta_end=beta1_final, beta_start=group['min_beta1'], warmup=beta1_warmup)
        else:
            beta1 = beta1_final

        if group['use_bias_correction']:
            state['num_sum'] = beta1 * state['num_sum'] + 1.0
            if group.get('kourkoutas_beta', False):
                state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
            else:
                state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)

        if state['factored']:
            d1, d2 = state['effective_shape']

            # Reconstruct momentum from previous step's factors
            mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
            unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
            torch.where(unpacked_sign, mt, -mt, out=mt)
            del unpacked_sign
            # Update momentum in full-size
            grad_reshaped = grad.view(d1, d2)
            mt.mul_(beta1).add_(grad_reshaped, alpha=1.0)

            vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
            vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)

            update = torch.add(mt, grad_reshaped, alpha=alpha_grad)
            del grad_reshaped

            denom = vt.sqrt().add_(group['eps'] * math.sqrt(state['den_sum']))
            update.div_(denom)
            del denom

            if group['use_bias_correction']:
                update = (update / state['num_sum']) * math.sqrt(state['den_sum'])

            update = update.view(p.shape).mul_(group['lr'])

            # Compress updated moments and store new factors
            state['sign'] = _pack_bools(mt > 0)
            _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
            del mt
            _nnmf(vt, out=(state['mu_v_nmf'], state['mv_v_nmf']))
            del vt

        else:  # Standard optimizer logic for non-factored tensors
            exp_avg_sq = state['exp_avg_sq']

            exp_avg = state['exp_avg']
            exp_avg.mul_(beta1).add_(grad, alpha=1.0)

            update = torch.add(exp_avg, grad, alpha=alpha_grad)

            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            denom = exp_avg_sq.sqrt().add_(group['eps'] * math.sqrt(state['den_sum']))
            update.div_(denom)
            del denom

            if group['use_bias_correction']:
                update = (update / state['num_sum']) * math.sqrt(state['den_sum'])

            update.mul_(group['lr'])

        # Decoupled weight decay
        if group["weight_decay"] != 0:
            if p.dtype == torch.bfloat16 and self.stochastic_rounding:
                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
            else:
                p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])

        if p.dtype == torch.bfloat16 and self.stochastic_rounding:
            add_stochastic_(p.data, -update)
        else:
            p.data.add_(-update)
        del update

        state['step'] += 1

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for i, p in enumerate(group['params']):
                self.step_parameter(p, group, i)

        return loss
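Note: the snippet below is an editorial sketch, not part of the wheel contents. It shows one way the optimizer might be driven, importing through the adv_optm.optim package shown later in this diff; the model, data, and loss are placeholders.

import torch
from adv_optm.optim import Simplified_AdEMAMix

model = torch.nn.Linear(128, 64)          # placeholder module
opt = Simplified_AdEMAMix(
    model.parameters(),
    lr=1e-5,
    betas=(0.99, 0.999),
    alpha_grad=100.0,                     # high mixing value, i.e. the small-batch regime
    kourkoutas_beta=False,                # set True to enable the layer-wise dynamic beta2
)

x = torch.randn(32, 128)                  # placeholder batch
loss = model(x).pow(2).mean()
loss.backward()
opt.step()                                # or step_parameter(p, group, i) in a fused back pass
opt.zero_grad()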
adv_optm/optim/__init__.py
ADDED
@@ -0,0 +1,19 @@
from .AdamW_adv import AdamW_adv
from .Prodigy_adv import Prodigy_adv
from .Adopt_adv import Adopt_adv
from .Simplified_AdEMAMix import Simplified_AdEMAMix
from .Lion_adv import Lion_adv
from .Lion_Prodigy_adv import Lion_Prodigy_adv
from .Muon_adv import Muon_adv
from .AdaMuon_adv import AdaMuon_adv

__all__ = [
    "AdamW_adv",
    "Prodigy_adv",
    "Adopt_adv",
    "Simplified_AdEMAMix",
    "Lion_adv",
    "Lion_Prodigy_adv",
    "Muon_adv",
    "AdaMuon_adv",
]
adv_optm/util/BF16_Stochastic_Rounding.py
ADDED
@@ -0,0 +1,47 @@
import torch
from torch import Tensor

def copy_stochastic_(target: Tensor, source: Tensor):
    """
    Nerogar's implementation of the stochastic rounding described in "Revisiting BFloat16 Training"
    (https://arxiv.org/abs/2010.06192).
    see:
    https://github.com/pytorch/pytorch/issues/120376
    https://github.com/Nerogar/OneTrainer/blob/daae18eaed8c0fa39289b2ff79cc2c1e08577fcb/modules/util/bf16_stochastic_rounding.py

    Args:
        target: the target tensor with dtype=bfloat16
        source: the source tensor with dtype=float32
    """
    # create a random 16 bit integer
    result = torch.randint_like(
        source,
        dtype=torch.int32,
        low=0,
        high=(1 << 16),
    )

    # add the random number to the lower 16 bit of the mantissa
    result.add_(source.view(dtype=torch.int32))

    # mask off the lower 16 bit of the mantissa
    result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32

    # copy the higher 16 bit into the target tensor
    target.copy_(result.view(dtype=torch.float32))

    del result

def add_stochastic_(input: Tensor, other: Tensor, alpha: float = 1.0):
    """
    adds other to input using stochastic rounding

    Args:
        input: the input tensor with dtype=bfloat16
        other: the other tensor
        alpha: a multiplier for input (as written, alpha scales `input` when it is
            accumulated into the float32 copy of `other`)
    """
    result = other.clone() if other.dtype == torch.float32 else other.to(dtype=torch.float32)

    result.add_(input, alpha=alpha)
    copy_stochastic_(input, result)
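Note: an editorial sketch of the helper in isolation, not part of the wheel; the tensors and the 1e-4 update value are illustrative.

import torch
from adv_optm.util.BF16_Stochastic_Rounding import add_stochastic_

p = torch.full((4,), 1.0, dtype=torch.bfloat16)        # bf16 "parameter"
update = torch.full((4,), 1e-4, dtype=torch.float32)   # update far below bf16 precision at 1.0

# Nearest-rounding would leave p unchanged every step; stochastic rounding instead
# rounds each element up with probability proportional to the truncated bits, so the
# update survives in expectation.
add_stochastic_(p, update)
print(p)   # each element is either 1.0 or the next representable bf16 above it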
adv_optm/util/Effective_Shape.py
ADDED
@@ -0,0 +1,8 @@
def _get_effective_shape(numel: int) -> tuple[int, int]:
    """Finds two factors of numel that are closest to its square root."""
    if numel <= 0:
        return (0, 0)
    for i in reversed(range(1, int(numel ** 0.5) + 1)):
        if numel % i == 0:
            return (numel // i, i)
    return (numel, 1)
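Note: a few hand-checked calls, added for reference (not part of the wheel).

from adv_optm.util.Effective_Shape import _get_effective_shape

_get_effective_shape(12)   # (4, 3): 3 is the largest factor <= sqrt(12)
_get_effective_shape(64)   # (8, 8): perfect squares split evenly
_get_effective_shape(13)   # (13, 1): primes end up as a thin column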
adv_optm/util/Kourkoutas.py
ADDED
@@ -0,0 +1,159 @@
import torch
from torch.optim import Optimizer
from typing import Callable

class KourkoutasHelper:
    """
    A helper class to add layer-wise Kourkoutas-β functionality to a PyTorch optimizer.
    """
    def __init__(self, optimizer: Optimizer):
        # We need a reference to the optimizer to access its param_groups and state
        if not hasattr(optimizer, 'param_groups'):
            raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
        self.optimizer = optimizer
        self.layer_state = {}

        self.layer_info = {}
        self._layer_info_built = False
        self._current_step_prepared = -1

        # Store stats for external logging (e.g., TensorBoard)
        self.last_beta2_stats = {}

        # This ensures the map is complete before the first backward pass,
        # making it compatible with fused back pass mechanisms.
        self._build_layer_info_if_needed()

    def _build_layer_info_if_needed(self):
        """Builds a map of layers and the parameters they contain."""
        if self._layer_info_built:
            return

        if hasattr(self.optimizer, 'layer_key_fn') and self.optimizer.layer_key_fn is not None:
            # A custom key function was provided by the user. We will use it.
            pass
        else:
            # No key function was provided. Default to coarse, shape-based bucketing.
            self.optimizer.layer_key_fn = lambda p: \
                (id(p),) if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096} \
                else tuple(p.shape)
            # This ensures that we won't mix embeddings with tokens (1 to 10)
            # TODO find a better way to safeguard the embeddings

        for group in self.optimizer.param_groups:
            if not group.get('kourkoutas_beta', False) and not group.get('adam_kourkoutas_beta', False):
                continue

            for p in group['params']:
                # The mapping is static and should not depend on the presence of a gradient.
                layer_key = self.optimizer.layer_key_fn(p)
                if layer_key not in self.layer_info:
                    self.layer_info[layer_key] = {'params': [], 'group_ref': group}
                self.layer_info[layer_key]['params'].append(p)

        self._layer_info_built = True

    def prepare_step(self, current_step: int):
        """
        Calculates dynamic beta2 for all layers using the completed scalar accumulators
        from the PREVIOUS step. Should be called once at the start of an optimizer step.
        """

        beta2_log = []
        # These are just for the sample log, initialize them
        sun, pooled_grad_norm, r_ema_tensor = (torch.tensor(0.0),)*3

        # The optimizer that owns this helper holds the master defaults for K-b.
        # This is crucial in hybrid optimizers where some param_groups might not
        # have all K-b keys populated, preventing KeyErrors.
        master_defaults = self.optimizer.defaults

        for layer_key, info in self.layer_info.items():
            params, group = info['params'], info['group_ref']

            if not group.get('kourkoutas_beta', False) and not group.get('adam_kourkoutas_beta', False):
                continue

            first_param_in_layer = info['params'][0]
            param_state = self.optimizer.state[first_param_in_layer]

            if layer_key not in self.layer_state:
                self.layer_state[layer_key] = {
                    'sum_sq_accumulator': torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)
                }

            if 'kourkoutas_r_ema' not in param_state:
                param_state['kourkoutas_r_ema'] = torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)

            # Use group-specific K-b settings, falling back to the optimizer's master defaults.
            # This makes the helper robust against param groups that enable kourkoutas_beta
            # but are missing the other required hyperparameters.
            ema_alpha = group.get('ema_alpha', master_defaults['ema_alpha'])
            beta2_max = group.get('betas', master_defaults['betas'])[1]
            beta2_min = group.get('beta2_min', master_defaults['beta2_min'])
            tiny_spike = group.get('tiny_spike', master_defaults['tiny_spike'])
            k_warmup_steps = group.get('k_warmup_steps', master_defaults['k_warmup_steps'])

            r_ema_tensor = param_state['kourkoutas_r_ema']
            accumulator = self.layer_state[layer_key]['sum_sq_accumulator']

            pooled_grad_norm = torch.sqrt(accumulator)

            # Update the persistent EMA tensor in-place.
            r_ema_tensor.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)

            sun = torch.tensor(0.0, device=r_ema_tensor.device)  # Default sun to 0 for warmup

            if current_step < k_warmup_steps:
                beta2 = beta2_max
            else:
                raw = pooled_grad_norm / (r_ema_tensor + tiny_spike)
                sun = raw / (1.0 + raw)
                beta2 = beta2_max - (beta2_max - beta2_min) * sun

            # Store the final calculated beta2 in the helper's transient state for this step.
            self.layer_state[layer_key]['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2

            # Reset the accumulator for the next optimizer step.
            accumulator.zero_()

            beta2_log.append(self.layer_state[layer_key]['dynamic_beta2'])

        # Always compute stats for TensorBoard
        if beta2_log:
            beta2_tensor = torch.tensor(beta2_log, device='cpu')
            self.last_beta2_stats = {
                'mean': beta2_tensor.mean().item(),
            }

    def maybe_prepare_step(self, current_step: int):
        """
        A universal guard that calls prepare_step() exactly once per training step.
        """
        if self._current_step_prepared < current_step:
            self.prepare_step(current_step)
            self._current_step_prepared = current_step

    def accumulate_gradient_sq_norm(self, p: torch.Tensor, grad: torch.Tensor):
        """
        Accumulates the squared L2 norm of a single gradient for the next step's calculation.
        """
        layer_key = self.optimizer.layer_key_fn(p)

        if layer_key in self.layer_info:
            # Initialize the transient state for this layer if it's the first time in the step.
            if layer_key not in self.layer_state:
                self.layer_state[layer_key] = {
                    'sum_sq_accumulator': torch.tensor(0.0, device=p.device, dtype=torch.float32)
                }
            # Accumulate for the *next* step's prepare_step call
            self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()

    def get_beta2(self, p: torch.Tensor, group: dict, current_step: int) -> float:
        """
        Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
        """
        layer_key = self.optimizer.layer_key_fn(p)
        # The default is the max value, which is correct for unmapped params or edge cases
        beta2_default = group.get('betas', group.get('adam_betas'))[1] if group.get('betas', group.get('adam_betas')) else 0.999
        return self.layer_state.get(layer_key, {}).get('dynamic_beta2', beta2_default)
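Note: an editorial sketch (not part of the wheel) of supplying a custom layer_key_fn so that all parameters owned by one submodule share a single dynamic β₂ bucket; the model and the module-name bucketing scheme are hypothetical.

import torch
from adv_optm.optim import Simplified_AdEMAMix

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 4),
)

# Hypothetical bucketing: every parameter owned by the same submodule maps to that
# module's name, so their pooled gradient norm drives one shared beta2.
param_to_module = {
    id(p): name
    for name, m in model.named_modules()
    for p in m.parameters(recurse=False)
}

opt = Simplified_AdEMAMix(
    model.parameters(),
    kourkoutas_beta=True,
    layer_key_fn=lambda p: param_to_module.get(id(p), id(p)),
)

loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
opt.step()
print(opt.kourkoutas_helper.last_beta2_stats)   # mean ~0.999 on the first step (no spike history yet)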
adv_optm/util/NNMF.py
ADDED
@@ -0,0 +1,18 @@
import torch

def _unnmf(row_col: tuple) -> torch.Tensor:
    """Reconstructs a matrix from its rank-1 factors (outer product)."""
    return torch.outer(row_col[0], row_col[1])

def _nnmf(matrix: torch.Tensor, out: tuple):
    """Performs a rank-1 non-negative matrix factorization."""
    shape = matrix.shape
    torch.sum(matrix, dim=1, out=out[0])
    torch.sum(matrix, dim=0, out=out[1])
    # Normalize one of the factors for stability
    if shape[0] < shape[1]:
        scale = out[0].sum()
        if scale != 0: out[0].div_(scale)
    else:
        scale = out[1].sum()
        if scale != 0: out[1].div_(scale)
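Note: an editorial round-trip check, not part of the wheel. For a rank-1 non-negative matrix the factorization reconstructs the input exactly; for general matrices it is only an approximation.

import torch
from adv_optm.util.NNMF import _nnmf, _unnmf

m = torch.outer(torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0]))   # rank-1, non-negative

row = torch.empty(2)
col = torch.empty(3)
_nnmf(m, out=(row, col))            # row/col now hold the (normalized) row/column sums
approx = _unnmf((row, col))         # outer-product reconstruction

print(torch.allclose(approx, m))    # True, because rank(m) == 1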
adv_optm/util/Newton_Schulz.py
ADDED
@@ -0,0 +1,87 @@
import torch

@torch.no_grad()
def _newton_schulz_iteration(
    G: torch.Tensor,
    steps: int = 5,
    eps: float = 1e-7,
    coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
    cns: bool = False,
    cns_a_bound: float = 1e-4,
) -> torch.Tensor:
    """
    Performs the Newton-Schulz iteration to find the nearest orthogonal matrix.
    This is the core computation of the Muon optimizer.

    Args:
        G (torch.Tensor): The 2D input matrix (momentum-accumulated gradient).
        steps (int): The number of iterations to run.
        eps (float): Small constant for numerical stability during normalization.
        coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
            quintic polynomial update.
        cns (bool): If True, enables Chebyshev-accelerated Newton-Schulz (CANS)
            using an iterative 3rd-order polynomial with optimal coefficients
            derived at each step.
        cns_a_bound (float): The initial lower bound for singular values when
            using CANS. The upper bound is assumed to be 1.0 after normalization.
    Returns:
        torch.Tensor: The orthogonalized matrix.
    """
    assert G.ndim >= 2

    a, b, c = coeffs

    X = G.to(torch.bfloat16)

    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT

    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)

    if cns:
        # Chebyshev-accelerated Newton-Schulz (CANS) from
        # "Accelerating Newton-Schulz Iteration for Orthogonalization via Chebyshev-type Polynomials"
        # This implements the iterative scheme from Algorithm 1, using the
        # closed-form 3rd-order polynomial from Proposition 2.
        lower_bound = cns_a_bound
        upper_bound = 1.0  # Matrix is normalized, so largest singular value is approx 1.

        for _ in range(steps):
            # Calculate optimal 3rd-order coefficients c1, c3 for p(x) = c1*x + c3*x^3
            # based on the current singular value bounds [lower_bound, upper_bound].
            # Formulas are derived from Proposition 2 and its proof in Appendix B of the paper.
            a_bound, b_bound = lower_bound, upper_bound
            term = a_bound*a_bound + a_bound*b_bound + b_bound*b_bound
            e_sq = term / 3.0

            # Calculate alpha, which scales the polynomial
            common_den_part = 2.0 * (e_sq**1.5)
            ab_part = a_bound*a_bound*b_bound + b_bound*b_bound*a_bound
            alpha_den = common_den_part + ab_part
            alpha = 6.0 / alpha_den

            c1 = alpha * e_sq
            c3 = -alpha / 3.0

            # Apply the 3rd-order Newton-Schulz update
            A = X @ X.mT
            X = c1 * X + c3 * (A @ X)

            # Update the singular value bounds for the next iteration based on the error
            eps_num = common_den_part - ab_part
            eps_val = eps_num / alpha_den
            lower_bound = 1.0 - eps_val
            upper_bound = 1.0 + eps_val
    else:
        # Perform the iterative updates
        for _ in range(steps):
            A = X @ X.mT
            B = b * A + c * (A @ A)
            X = a * X + B @ X

    # Transpose back if necessary
    if transposed:
        X = X.mT

    return X.to(G.dtype)
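Note: an editorial sanity check, not part of the wheel, assuming your PyTorch build supports bfloat16 matmul on the device you run it on. The standard Muon coefficients drive singular values toward 1 without making them exactly 1, so the orthogonality error is loose by design.

import torch
from adv_optm.util.Newton_Schulz import _newton_schulz_iteration

G = torch.randn(64, 32)                      # stand-in for a momentum-accumulated gradient
X = _newton_schulz_iteration(G, steps=5)

# X should be approximately semi-orthogonal: X^T X close to the identity,
# up to bfloat16 compute noise and the deliberately inexact quintic coefficients.
err = (X.mT @ X - torch.eye(32)).abs().max()
print(X.shape, err.item())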
adv_optm/util/One_Bit_Boolean.py
ADDED
@@ -0,0 +1,22 @@
import torch

@torch.no_grad()
def _pack_bools(tensor: torch.Tensor) -> torch.Tensor:
    """Packs a boolean tensor into a uint8 tensor to achieve 1-bit storage."""
    n, m = tensor.shape
    packed_m = (m + 7) // 8
    padded_tensor = torch.nn.functional.pad(tensor, (0, packed_m * 8 - m), 'constant', 0)
    reshaped = padded_tensor.view(n, packed_m, 8)
    shifter = torch.arange(8, device=tensor.device, dtype=torch.uint8)
    packed = (reshaped.to(torch.uint8) * (2**shifter)).sum(dim=2).to(torch.uint8)
    return packed

@torch.no_grad()
def _unpack_bools(packed_tensor: torch.Tensor, original_m: int) -> torch.Tensor:
    """Unpacks a uint8 tensor back into a boolean tensor."""
    if packed_tensor.dtype != torch.uint8:
        packed_tensor = packed_tensor.to(torch.uint8)
    shifter = (2**torch.arange(8, device=packed_tensor.device, dtype=torch.uint8)).view(1, 1, 8)
    unpacked_padded = (packed_tensor.unsqueeze(2) & shifter) != 0
    unpacked = unpacked_padded.view(packed_tensor.shape[0], -1)[:, :original_m]
    return unpacked
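Note: an editorial pack/unpack round trip, not part of the wheel; the shape is arbitrary and deliberately not a multiple of 8.

import torch
from adv_optm.util.One_Bit_Boolean import _pack_bools, _unpack_bools

signs = torch.rand(5, 13) > 0.5              # boolean matrix, width not a multiple of 8
packed = _pack_bools(signs)                  # shape (5, 2), dtype uint8: 1 bit per flag plus padding
restored = _unpack_bools(packed, original_m=13)

print(packed.shape, bool((restored == signs).all()))   # torch.Size([5, 2]) True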
adv_optm/util/OrthoGrad.py
ADDED
@@ -0,0 +1,16 @@
import torch

def _orthogonalize_gradient(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """Projects the gradient `grad` to be orthogonal to the parameter `p`."""
    if grad.is_sparse: raise RuntimeError("OrthoGrad logic does not support sparse gradients.")
    original_shape = grad.shape
    original_dtype = grad.dtype
    w = p.view(-1).float()
    g = grad.view(-1).float()
    w_norm_sq = torch.dot(w, w).add_(1e-30)
    proj = torch.dot(w, g) / w_norm_sq
    g_orth = g.sub(w, alpha=proj)
    g_norm = g.norm(2)
    g_orth_norm = g_orth.norm(2).add_(1e-30)
    g_orth_scaled = g_orth * (g_norm / g_orth_norm)
    return g_orth_scaled.view(original_shape).to(original_dtype)
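Note: an editorial numerical check, not part of the wheel, that the projected gradient is orthogonal to the flattened parameter and keeps the original gradient norm; the tensors are illustrative.

import torch
from adv_optm.util.OrthoGrad import _orthogonalize_gradient

p = torch.randn(10, 4)        # stand-in parameter
g = torch.randn(10, 4)        # stand-in gradient

g_orth = _orthogonalize_gradient(p, g)

print(torch.dot(p.view(-1), g_orth.view(-1)).item())   # ~0: the component along p is removed
print(g.norm().item(), g_orth.norm().item())           # equal up to float32 rounding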