adv-optm 1.2.dev13__py3-none-any.whl → 2.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdamW_adv.py +85 -64
- adv_optm/optim/Adopt_adv.py +114 -69
- adv_optm/optim/Lion_Prodigy_adv.py +79 -81
- adv_optm/optim/Lion_adv.py +37 -42
- adv_optm/optim/Prodigy_adv.py +105 -85
- adv_optm/optim/Simplified_AdEMAMix.py +92 -51
- adv_optm/optim/__init__.py +1 -1
- adv_optm/util/BF16_Stochastic_Rounding.py +1 -1
- adv_optm/util/Effective_Shape.py +1 -1
- adv_optm/util/Kourkoutas.py +11 -12
- adv_optm/util/NNMF.py +7 -2
- adv_optm/util/Newton_Schulz.py +1 -2
- adv_optm/util/One_Bit_Boolean.py +1 -1
- adv_optm/util/OrthoGrad.py +4 -3
- adv_optm/util/__init__.py +1 -1
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/METADATA +20 -20
- adv_optm-2.dev1.dist-info/RECORD +23 -0
- adv_optm-1.2.dev13.dist-info/RECORD +0 -23
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev13.dist-info → adv_optm-2.dev1.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -49,12 +49,6 @@ class AdamW_adv(torch.optim.Optimizer):
             before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
             A higher value increases the stabilizing influence of the slow
             momentum. (default: 5.0)
-        t_alpha (Optional[int]): The number of steps for a linear warmup of the
-            `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
-            highly recommended to prevent instability at the beginning of training,
-            as it gradually introduces the stabilizing slow momentum term. During
-            the warmup, `alpha` ramps from 0 to its target value. If `None`,
-            the scheduler is disabled. (default: None)
         kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
             If `False`, the optimizer behaves as standard AdamW. (default: False)
         beta2_min (float): The minimum value for dynamic β₂, used during periods of
@@ -72,11 +66,7 @@ class AdamW_adv(torch.optim.Optimizer):
         k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
             logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
             every logging steps. Useful for debugging and tuning. Set to 0 to disable
-            logging (default: 0).
-        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
-            and returns a unique, hashable key representing its "layer" or "bucket".
-            If `None`, parameters are bucketed by their memory ID (tensor-wise).
-            (default: None)
+            logging (default: 0).
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
     """
@@ -89,7 +79,7 @@ class AdamW_adv(torch.optim.Optimizer):
         eps: float = 1e-8,
         weight_decay: float = 0.0,
         use_bias_correction: bool = True,
-        vector_reshape: bool =
+        vector_reshape: bool = False,
         stochastic_rounding: bool = True,
         use_atan2: bool = False,
         cautious_mask: bool = False,
@@ -98,15 +88,15 @@ class AdamW_adv(torch.optim.Optimizer):
         use_AdEMAMix: bool = False,
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
-        t_alpha: int | None = None,
         kourkoutas_beta: bool = False,
         beta2_min: float = 0.9,
         ema_alpha: float = 0.95,
         tiny_spike: float = 1e-9,
         k_warmup_steps: int = 0,
         k_logging: int = 0,
-        layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
+        # Compiled
+        compiled_optimizer: bool = False,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
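
For reference, a minimal construction sketch against the updated signature above (keyword names are taken from the diff; the import path, model, and hyperparameter values are illustrative assumptions):

    import torch
    from adv_optm import AdamW_adv  # import path assumed from the package layout listed above

    model = torch.nn.Linear(128, 64)  # placeholder module
    opt = AdamW_adv(
        model.parameters(),
        lr=1e-3,
        betas=(0.9, 0.999),
        use_AdEMAMix=True,         # enable the slow EMA term
        alpha=5.0,                 # fixed mixing coefficient; the t_alpha warmup schedule is gone in 2.dev1
        kourkoutas_beta=True,      # layer-wise dynamic beta2 (betas[1] acts as beta2_max)
        beta2_min=0.9,
        compiled_optimizer=False,  # new flag; True routes the per-parameter step through torch.compile
    )
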
@@ -116,7 +106,8 @@ class AdamW_adv(torch.optim.Optimizer):
             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
-        if kourkoutas_beta and not (betas[1] > beta2_min):
+        if kourkoutas_beta and not (betas[1] > beta2_min):
+            raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")

         if cautious_mask and grams_moment:
             print("Warning: cautious is incompatible with grams, Disabling cautious.")
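
The check treats `betas[1]` as the ceiling for the dynamic β₂. A rough sketch of how a spike-driven β₂ could stay inside that `[beta2_min, beta2_max]` band (an illustration of the constraint only; the actual KourkoutasHelper logic is not part of this diff):

    def dynamic_beta2(beta2_min: float, beta2_max: float, spike_ratio: float) -> float:
        # spike_ratio is assumed normalized to [0, 1]: 0 = calm gradients, 1 = strong spike.
        assert beta2_max > beta2_min, "beta2_max must exceed beta2_min"
        return beta2_max - (beta2_max - beta2_min) * spike_ratio

    print(dynamic_beta2(0.9, 0.999, 0.0))  # 0.999 -> behaves like plain AdamW
    print(dynamic_beta2(0.9, 0.999, 1.0))  # 0.9   -> short second-moment memory during a spike
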
@@ -126,9 +117,10 @@ class AdamW_adv(torch.optim.Optimizer):
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
             "vector_reshape": vector_reshape, "use_atan2": use_atan2,
             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
-            "beta3_ema": beta3_ema, "alpha": alpha,
+            "beta3_ema": beta3_ema, "alpha": alpha,
             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
             "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+            "compiled_optimizer": compiled_optimizer,
         }
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
@@ -136,12 +128,20 @@ class AdamW_adv(torch.optim.Optimizer):
         self.use_AdEMAMix = use_AdEMAMix
         self.factored = nnmf_factor
         self.kourkoutas_beta = kourkoutas_beta
-
+
         super().__init__(params, defaults)

+        self.init_step()
+
         if self.kourkoutas_beta:
             self.kourkoutas_helper = KourkoutasHelper(self)

+        self.global_step = 0
+
+        if compiled_optimizer:
+            torch._dynamo.config.cache_size_limit = 8192
+            self.compile(fullgraph=True)
+
     @property
     def supports_fused_back_pass(self):
         return True
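
The `compiled_optimizer` branch raises Dynamo's graph-cache limit and compiles the per-parameter step once at construction. A standalone sketch of the same pattern (only the torch APIs are real; the step function here is a toy stand-in):

    import torch

    def single_tensor_step(p: torch.Tensor, lr: torch.Tensor) -> None:
        # toy in-place update standing in for the real per-parameter step
        p.mul_(1.0 - 1e-2 * lr).add_(p.grad, alpha=-1.0)

    torch._dynamo.config.cache_size_limit = 8192      # room for one graph per distinct parameter shape
    compiled_step = torch.compile(single_tensor_step, fullgraph=True)

    p = torch.randn(4, 4, requires_grad=True)
    p.grad = torch.randn_like(p)
    with torch.no_grad():
        compiled_step(p, torch.tensor(1e-3))          # scalars passed as tensors to avoid recompiles
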
@@ -154,29 +154,22 @@ class AdamW_adv(torch.optim.Optimizer):
     def supports_flat_params(self):
         return False

-
-
-
-
+    def init_step(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                self.__init_state(p, group)

-
-
-        grad = grad.float()
-        if group["orthogonal_gradient"]:
-            grad = _orthogonalize_gradient(p, grad)
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]

-
-        if 'step' not in state:
-            state['step'] = 0
+        if len(state) == 0:

-
+            state['factored'] = (
                 self.factored and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
             )

-            state['factored'] = should_factor
-
             dtype = torch.float32 if self.factored else p.dtype
             device = p.device

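
The `factored` flag above compresses a tensor only when factorization is enabled and the tensor is not a plain 1-D vector (unless `vector_reshape` folds vectors into 2-D shapes). The same predicate in isolation, as a hypothetical helper mirroring the condition shown in the diff:

    import torch

    def should_factor(p: torch.Tensor, nnmf_factor: bool, vector_reshape: bool) -> bool:
        # Factor only 2-D+ tensors, or 1-D tensors when vector_reshape is requested.
        return nnmf_factor and not (len(p.shape) == 1 and not vector_reshape)

    print(should_factor(torch.zeros(10), nnmf_factor=True, vector_reshape=False))    # False
    print(should_factor(torch.zeros(10), nnmf_factor=True, vector_reshape=True))     # True
    print(should_factor(torch.zeros(8, 8), nnmf_factor=True, vector_reshape=False))  # True
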
@@ -186,18 +179,18 @@ class AdamW_adv(torch.optim.Optimizer):

                 # First moment (m)
                 if group['betas'][0] > 0:
-                    state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                    state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                     state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                     if not self.grams_moment:
                         packed_d2 = (d2 + 7) // 8
                         state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
                 if self.use_AdEMAMix:
-                    state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                    state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                     state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                     packed_d2 = (d2 + 7) // 8
                     state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
                 # Second moment (v)
-                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
             else:  # Fallback to standard AdamW for non-factored tensors
                 if group['betas'][0] > 0:
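
The `(d2 + 7) // 8` buffers store one sign bit per column, eight signs packed into each uint8. A self-contained sketch of such packing (illustrative helper only, not the package's One_Bit_Boolean utilities):

    import torch

    def pack_signs(x: torch.Tensor) -> torch.Tensor:
        """Pack the sign bits of a (d1, d2) tensor into a (d1, (d2 + 7) // 8) uint8 tensor."""
        d1, d2 = x.shape
        bits = (x >= 0).to(torch.uint8)              # 1 where non-negative, 0 where negative
        packed_d2 = (d2 + 7) // 8
        padded = torch.zeros(d1, packed_d2 * 8, dtype=torch.uint8)
        padded[:, :d2] = bits
        shifts = torch.arange(8, dtype=torch.uint8)
        return (padded.view(d1, packed_d2, 8) << shifts).sum(dim=-1, dtype=torch.uint8)

    x = torch.randn(4, 10)
    print(pack_signs(x).shape)  # torch.Size([4, 2]) because (10 + 7) // 8 == 2
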
@@ -206,37 +199,32 @@ class AdamW_adv(torch.optim.Optimizer):
                     state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)

+    @torch.no_grad()
+    def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float, bias_correction1: torch.Tensor | float, bias_correction2: torch.Tensor | float):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["orthogonal_gradient"]:
+            grad = _orthogonalize_gradient(p, grad)
+        state = self.state[p]
+
+
         beta1, beta2 = group['betas']

-        current_step = state['step']
         if group.get('kourkoutas_beta', False):
-            # Call prepare_step() once at the beginning of the step for all params
-            self.kourkoutas_helper.maybe_prepare_step(current_step)
             # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
             # Get the dynamic beta2 calculated in prepare_step()
-            beta2 = self.kourkoutas_helper.get_beta2(p, group
+            beta2 = self.kourkoutas_helper.get_beta2(p, group)

-
-        if group['use_bias_correction']:
-            bias_correction1 = 1.0 - beta1 ** step
-            if group.get('kourkoutas_beta', False):
-                bias_correction2 = 1.0 - group['betas'][1] ** step
-                # Use beta2_max for bias correction
-            else:
-                bias_correction2 = 1.0 - beta2 ** step
-        else:
-            bias_correction1 = 1
-            bias_correction2 = 1
-        step_size = group['lr'] / bias_correction1
+        step_size = lr / bias_correction1

         if self.use_AdEMAMix:
             beta3_ema = group['beta3_ema']
             alpha = group['alpha']
-            t_alpha = group['t_alpha']
-            alpha_t = alpha
-            if t_alpha is not None and t_alpha > 0 and step < t_alpha:
-                alpha_t = min(step * alpha / t_alpha, alpha)

         if state['factored']:
             d1, d2 = state['effective_shape']
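
Bias correction is now computed once per step outside this method and passed in, with the effective step size `lr / bias_correction1`. A quick numeric check of the two divisors for typical betas (values illustrative):

    beta1, beta2 = 0.9, 0.999
    for t in (1, 10, 100, 1000):
        bc1 = 1.0 - beta1 ** t
        bc2 = 1.0 - beta2 ** t
        # step_size = lr / bc1; the denominator side uses v_hat = v / bc2
        print(f"t={t:4d}  bias_correction1={bc1:.6f}  bias_correction2={bc2:.6f}")
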
@@ -252,7 +240,7 @@ class AdamW_adv(torch.optim.Optimizer):
                 # Update momentum in full-size
                 mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
                 if self.grams_moment:
-                    mt
+                    mt = (grad_reshaped.sign().mul_(mt.abs()))
                 elif self.cautious_mask:
                     mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
                     mask.div_(mask.mean().clamp_(min=1e-3))
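
For context, a self-contained sketch of the two (mutually exclusive) masking modes applied to a momentum tensor; this is a simplified illustration, not the factored code path above:

    import torch

    def grams_update(mt: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        # Grams-style: keep the momentum magnitude, take the sign from the current gradient.
        return grad.sign() * mt.abs()

    def cautious_update(mt: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        # Cautious: zero components where momentum and gradient disagree in sign,
        # then rescale so the surviving components compensate on average.
        mask = (mt * grad > 0).to(grad.dtype)
        mask.div_(mask.mean().clamp_(min=1e-3))
        return mt * mask

    mt = torch.tensor([0.5, -0.2, 0.1, -0.4])
    grad = torch.tensor([1.0, 1.0, -1.0, -1.0])
    print(grams_update(mt, grad))     # signs follow grad, magnitudes come from mt
    print(cautious_update(mt, grad))  # disagreeing entries zeroed, the rest scaled by 1 / mask.mean()
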
@@ -272,9 +260,9 @@ class AdamW_adv(torch.optim.Optimizer):

                 mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
                 if beta1 > 0:
-                    update = torch.add(mt, mt_slow, alpha=
+                    update = torch.add(mt, mt_slow, alpha=alpha)
                 else:
-                    update = torch.add(grad_reshaped, mt_slow, alpha=
+                    update = torch.add(grad_reshaped, mt_slow, alpha=alpha)
             else:
                 update = mt.clone() if beta1 > 0 else grad_reshaped.clone()
             del grad_reshaped
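
The AdEMAMix numerator combines the fast EMA with `alpha` times the slow EMA (or the raw gradient with the slow EMA when `beta1 == 0`). A minimal sketch of that combination on plain tensors, mirroring the lines above:

    import torch

    beta1, beta3_ema, alpha = 0.9, 0.9999, 5.0
    grad = torch.randn(3, 3)
    mt = torch.zeros_like(grad)       # fast EMA
    mt_slow = torch.zeros_like(grad)  # slow EMA

    mt.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    mt_slow.mul_(beta3_ema).add_(grad, alpha=1.0 - beta3_ema)

    # numerator of the update, as in the diff: update = mt + alpha * mt_slow
    update = torch.add(mt, mt_slow, alpha=alpha)
    print(update.shape)
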
@@ -310,7 +298,7 @@ class AdamW_adv(torch.optim.Optimizer):
                 exp_avg = state['exp_avg']
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                 if self.grams_moment:
-                    exp_avg = grad.sign()
+                    exp_avg = grad.sign().mul_(exp_avg.abs())
                 elif self.cautious_mask:
                     mask = (exp_avg * grad > 0).to(grad.dtype)
                     mask.div_(mask.mean().clamp_(min=1e-3))
@@ -321,9 +309,9 @@ class AdamW_adv(torch.optim.Optimizer):
                 exp_avg_slow = state['exp_avg_slow']
                 exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
                 if beta1 > 0:
-                    update = torch.add(exp_avg, exp_avg_slow, alpha=
+                    update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
                 else:
-                    update = torch.add(grad, exp_avg_slow, alpha=
+                    update = torch.add(grad, exp_avg_slow, alpha=alpha)
             else:
                 update = exp_avg.clone() if beta1 > 0 else grad.clone()

@@ -343,9 +331,9 @@ class AdamW_adv(torch.optim.Optimizer):
         # Decoupled weight decay
         if group["weight_decay"] != 0:
             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
-                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] *
+                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
             else:
-                p.data.add_(p.data, alpha=-group["weight_decay"] *
+                p.data.add_(p.data, alpha=-group["weight_decay"] * lr)

         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
             add_stochastic_(p.data, -update)
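
Decoupled weight decay now scales by the `lr` value handed to `__step_parameter` (p ← p − lr·λ·p) instead of reading `group['lr']` inside the method, and the bf16 path applies it through stochastic rounding. The float32 branch on a plain tensor:

    import torch

    p = torch.randn(4, 4)
    lr, weight_decay = 1e-3, 0.01

    # decoupled weight decay: shrink weights toward zero, independent of the gradient
    p.add_(p, alpha=-weight_decay * lr)  # equivalent to p *= (1 - lr * weight_decay)
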
@@ -353,7 +341,38 @@ class AdamW_adv(torch.optim.Optimizer):
             p.data.add_(-update)
         del update

-
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        # if 'exp_avg_sq' not in self.state[p] and 'mu_v_nmf' not in self.state[p]:
+        #     return
+
+        if self.global_step is None and 'step' in self.state[p]:
+            # For backward compatibility
+            self.global_step = self.state[p]['step']
+
+        if group['use_bias_correction']:
+            current_step = self.global_step + 1
+            beta1, beta2 = group['betas']
+            bias_correction1 = 1.0 - beta1 ** current_step
+            bias_correction2 = 1.0 - beta2 ** current_step
+        else:
+            bias_correction1 = 1.0
+            bias_correction2 = 1.0
+
+        if group.get('kourkoutas_beta', False):
+            # Prepare Kourkoutas-β once per step using the global step counter.
+            self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+
+        if not group.get('compiled_optimizer', False):
+            self.__step_parameter(p, group, group['lr'], bias_correction1, bias_correction2)
+        else:
+            lr_tensor = torch.tensor(group['lr'], device=p.device)
+            bias_correction1_tensor = torch.tensor(bias_correction1, device=p.device)
+            bias_correction2_tensor = torch.tensor(bias_correction2, device=p.device)
+            self._compiled_step_parameter(p, group, lr_tensor, bias_correction1_tensor, bias_correction2_tensor)
+
+    def compile(self, *args, **kwargs):
+        self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)

     @torch.no_grad()
     def step(self, closure=None):
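
Wrapping `lr` and the bias corrections in 0-d tensors before calling the compiled step keeps those scalars from being baked into the compiled graph, so changing the learning rate mid-training should not force a recompile. A toy demonstration of the same idea, independent of this optimizer:

    import torch

    @torch.compile(fullgraph=True)
    def scaled_update(p: torch.Tensor, g: torch.Tensor, lr: torch.Tensor) -> torch.Tensor:
        return p - lr * g

    p, g = torch.randn(8), torch.randn(8)
    for lr in (1e-3, 3e-4, 1e-4):
        # each lr arrives as a 0-d tensor, so the same graph is reused across values
        p = scaled_update(p, g, torch.tensor(lr))
    print(p.shape)
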
@@ -367,4 +386,6 @@ class AdamW_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)

+        self.global_step += 1
+
         return loss
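
Putting the pieces together, a conventional training-loop usage of the class shown in this file (sketch only; model, data, and the import path are placeholder assumptions):

    import torch
    from adv_optm import AdamW_adv  # import path assumed from the file list above

    model = torch.nn.Linear(16, 1)
    opt = AdamW_adv(model.parameters(), lr=1e-3)

    for _ in range(3):
        loss = model(torch.randn(32, 16)).pow(2).mean()
        loss.backward()
        opt.step()       # runs step_parameter for every param, then increments global_step
        opt.zero_grad()
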