adv-optm 1.2.dev14__py3-none-any.whl → 2.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of adv-optm might be problematic.

@@ -56,13 +56,6 @@ class Adopt_adv(torch.optim.Optimizer):
  before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
  A higher value increases the stabilizing influence of the slow
  momentum. (default: 5.0)
- t_alpha (Optional[int]): The number of steps for a linear warmup of the
- `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
- highly recommended to prevent instability at the beginning of training,
- as it gradually introduces the stabilizing slow momentum term. During
- the warmup, `alpha` ramps from 0 to its target value. If `None`,
- the scheduler is disabled and the full `alpha` value is used from
- the start. (default: None)
  Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
  This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
  more responsive, especially for small batch sizes. Enabling this will
@@ -90,7 +83,7 @@ class Adopt_adv(torch.optim.Optimizer):
  k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
- logging (default: 0).
+ logging (default: 0).
  layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
  and returns a unique, hashable key representing its "layer" or "bucket".
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
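
The `layer_key_fn` hook documented above is the natural place to control how Kourkoutas-β buckets parameters. A minimal sketch of such a callable (the grouping rules below are illustrative assumptions, not defaults shipped by adv-optm):

    # Hypothetical layer_key_fn implementations for Adopt_adv(..., layer_key_fn=...)
    def key_by_shape(p):
        # bucket every parameter that shares a shape into one "layer"
        return tuple(p.shape)

    def key_by_ndim(p):
        # coarser bucketing: vectors (biases/norms) vs. matrices
        return "vector" if p.ndim == 1 else "matrix"
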
@@ -107,7 +100,7 @@ class Adopt_adv(torch.optim.Optimizer):
  eps: float = 1e-6,
  weight_decay: float = 0.0,
  clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
- vector_reshape: bool = True,
+ vector_reshape: bool = False,
  stochastic_rounding: bool = True,
  use_atan2: bool = False,
  cautious_mask: bool = False,
@@ -116,7 +109,6 @@ class Adopt_adv(torch.optim.Optimizer):
  use_AdEMAMix: bool = False,
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
- t_alpha: int | None = None,
  Simplified_AdEMAMix: bool = False,
  alpha_grad: float = 100.0,
  kourkoutas_beta: bool = False,
@@ -127,6 +119,8 @@ class Adopt_adv(torch.optim.Optimizer):
  k_logging: int = 0,
  layer_key_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
+ # Compiled
+ compiled_optimizer: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
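
The signature changes above flip `vector_reshape` to `False` by default and introduce a `compiled_optimizer` flag. A minimal construction sketch, assuming the class is importable as `adv_optm.Adopt_adv` (the import path is inferred from the package name, not confirmed by this diff):

    import torch
    from adv_optm import Adopt_adv  # assumed import path

    model = torch.nn.Linear(128, 64)
    optimizer = Adopt_adv(
        model.parameters(),
        lr=1e-3,
        vector_reshape=False,     # new default in this release
        compiled_optimizer=True,  # new flag: wrap the per-parameter update in torch.compile
    )
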
@@ -141,7 +135,8 @@ class Adopt_adv(torch.optim.Optimizer):
  cautious_mask = False
  if betas[0] == 0.0 and Simplified_AdEMAMix:
  raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
- if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+ if kourkoutas_beta and not (betas[1] > beta2_min):
+ raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
  if use_AdEMAMix and Simplified_AdEMAMix:
  print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
  if grams_moment and Simplified_AdEMAMix:
@@ -152,9 +147,10 @@ class Adopt_adv(torch.optim.Optimizer):
  defaults = {
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
  "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
- "t_alpha": t_alpha, "alpha_grad": alpha_grad,
+ "alpha_grad": alpha_grad,
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+ "compiled_optimizer": compiled_optimizer,
  }
  self.clip_lambda = clip_lambda
  self.stochastic_rounding = stochastic_rounding
@@ -169,9 +165,17 @@ class Adopt_adv(torch.optim.Optimizer):
  self.layer_key_fn = layer_key_fn
  super().__init__(params, defaults)
 
+ self.init_step()
+
  if self.kourkoutas_beta:
  self.kourkoutas_helper = KourkoutasHelper(self)
 
+ self.global_step = 0
+
+ if compiled_optimizer:
+ torch._dynamo.config.cache_size_limit = 8192
+ self.compile(fullgraph=True)
+
  @property
  def supports_fused_back_pass(self): return True
  @property
@@ -179,29 +183,22 @@ class Adopt_adv(torch.optim.Optimizer):
  @property
  def supports_flat_params(self): return False
 
- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
- if p.grad is None:
- return
+ def init_step(self):
+ for group in self.param_groups:
+ for p in group['params']:
+ self.__init_state(p, group)
 
- grad = p.grad
- if self.factored and grad.dtype != torch.float32:
- grad = grad.float()
- if self.orthogonal_gradient:
- grad = _orthogonalize_gradient(p, grad)
+ @torch.no_grad()
+ def __init_state(self, p, group):
  state = self.state[p]
 
- # State Initialization
- if 'step' not in state:
- state['step'] = 0
+ if len(state) == 0:
 
- should_factor = (
+ state['factored'] = (
  self.factored and
  not (len(p.shape) == 1 and not group['vector_reshape'])
  )
 
- state['factored'] = should_factor
-
  dtype = torch.float32 if self.factored else p.dtype
 
  if state['factored']:
@@ -210,55 +207,75 @@ class Adopt_adv(torch.optim.Optimizer):
 
  # m_0 = 0
  if group['betas'][0] > 0:
- state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  if not self.grams_moment:
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
  if self.use_AdEMAMix:
- state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
- # v_0 = g_0^2 (SMMF_ADOPT NMF storage)
- vt_init = grad.view(d1, d2).square_()
- # Allocate NMF factors for v
- state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
- state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
- # Initialize v_0 using NMF
- _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+
  else: # Fallback for non-factored tensors
  if group['betas'][0] > 0:
  state['exp_avg'] = torch.zeros_like(p, dtype=dtype) # m_0
  if self.use_AdEMAMix:
  state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
- state['exp_avg_sq'] = grad.square() # v_0
+
+ @torch.no_grad()
+ def __init_step(self, p, group):
+ if p.grad is None:
+ return
+
+ state = self.state[p]
+
+ if 'exp_avg_sq' in state or 'mu_v_nmf' in state:
+ return
+
+ grad = p.grad
+ dtype = torch.float32 if self.factored else p.dtype
+
+ if state['factored']:
+ d1, d2 = state['effective_shape']
+ # v_0 = g_0^2 (SMMF_ADOPT NMF storage)
+ vt_init = grad.view(d1, d2).square_()
+ # Allocate NMF factors for v
+ state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+ # Initialize v_0 using NMF
+ _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+ del vt_init
+ else:
+ state['exp_avg_sq'] = grad.square() # v_0
+
+
+ @torch.no_grad()
+ def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float):
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ if self.factored and grad.dtype != torch.float32:
+ grad = grad.float()
+ if self.orthogonal_gradient:
+ grad = _orthogonalize_gradient(p, grad)
+ state = self.state[p]
+
 
  beta1, beta2 = group['betas']
 
- current_step = state['step']
  if group.get('kourkoutas_beta', False):
- # Call prepare_step() once at the beginning of the step for all params
- self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
  self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
  # Get the dynamic beta2 calculated in prepare_step()
- beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
-
- # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
- if state['step'] == 0 and not self.use_atan2:
- state['step'] += 1
- return
+ beta2 = self.kourkoutas_helper.get_beta2(p, group)
 
  if self.use_AdEMAMix:
  beta3_ema = group['beta3_ema']
  alpha = group['alpha']
- t_alpha = group['t_alpha']
- # Use step+1 for 1-based step count in scheduler
- alpha_step = state['step'] + 1
- alpha_t = alpha
- if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
- alpha_t = min(alpha_step * alpha / t_alpha, alpha)
+
  if self.Simplified_AdEMAMix:
  alpha_grad = group["alpha_grad"]
 
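
The refactor above defers the second-moment initialization (v_0 = g_0²) from construction time to `__init_step`, where it is stored as a rank-1 non-negative factorization (`mu_v_nmf` of length d1, `mv_v_nmf` of length d2) instead of the full d1×d2 matrix. `_nnmf` is internal to the package; the following is only a rough sketch of a rank-1 factorization of this kind, purely to illustrate the memory saving, not the package's actual algorithm:

    import torch

    def rank1_nnmf_sketch(v, out):
        # v: non-negative (d1, d2) matrix; out: preallocated (d1,) and (d2,) factors
        mu, mv = out
        mv.copy_(v.mean(dim=0))                                 # column factor
        mu.copy_(v @ mv / mv.square().sum().clamp(min=1e-30))   # row factor (least squares for fixed mv)
        return mu, mv

    v0 = torch.rand(8, 16).square()          # stand-in for g_0^2
    mu, mv = torch.zeros(8), torch.zeros(16)
    rank1_nnmf_sketch(v0, (mu, mv))
    v0_approx = torch.outer(mu, mv)          # what the optimizer reconstructs at each step
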
@@ -296,7 +313,7 @@ class Adopt_adv(torch.optim.Optimizer):
  else:
  normalized_grad = grad_reshaped / denom.add_(group['eps'])
  if self.clip_lambda is not None:
- clip_val = self.clip_lambda(state['step'])
+ clip_val = self.clip_lambda(self.global_step)
  normalized_grad.clamp_(-clip_val, clip_val)
  del denom
 
@@ -307,7 +324,7 @@ class Adopt_adv(torch.optim.Optimizer):
  else:
  mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
  if self.grams_moment:
- mt = grad_reshaped.sign() * mt.abs()
+ mt = grad_reshaped.sign().mul_(mt.abs())
  elif self.cautious_mask:
  mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
@@ -317,9 +334,9 @@ class Adopt_adv(torch.optim.Optimizer):
  if self.use_AdEMAMix:
  mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
  if beta1 > 0:
- update = torch.add(mt, mt_slow, alpha=alpha_t)
+ update = torch.add(mt, mt_slow, alpha=alpha)
  else:
- update = torch.add(normalized_grad, mt_slow, alpha=alpha_t)
+ update = torch.add(normalized_grad, mt_slow, alpha=alpha)
  elif self.Simplified_AdEMAMix:
  update = torch.add(mt, normalized_grad, alpha=alpha_grad)
  else:
@@ -328,9 +345,9 @@ class Adopt_adv(torch.optim.Optimizer):
  update = update.view(p.shape)
 
  if self.use_atan2:
- update.mul_(group['lr'] * 1.2732395447351628)
+ update.mul_(lr * 1.2732395447351628)
  else:
- update.mul_(group['lr'])
+ update.mul_(lr)
 
  # Update second moment v_t for the *next* step using raw g_t
  vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
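
A small aside on the literal above: 1.2732395447351628 is 4/π. The diff does not say why this particular scale is applied on the `use_atan2` path, so the only claim made here is the identity itself:

    import math
    print(4 / math.pi)  # 1.2732395447351628, the constant multiplying lr in the atan2 branch
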
@@ -353,7 +370,7 @@ class Adopt_adv(torch.optim.Optimizer):
  del vt
 
  else: # Standard ADOPT logic for non-factored tensors
- v = state['exp_avg_sq'] # v_{t-1}
+ v = state['exp_avg_sq'] # v_{t-1}
 
  # ADOPT Step A: Decorrelate g_t using v_{t-1}
  denom = v.sqrt()
@@ -363,7 +380,7 @@ class Adopt_adv(torch.optim.Optimizer):
  else:
  normalized_grad = grad / denom.add_(group['eps'])
  if self.clip_lambda is not None:
- clip_val = self.clip_lambda(state['step'])
+ clip_val = self.clip_lambda(self.global_step)
  normalized_grad.clamp_(-clip_val, clip_val)
  del denom
 
@@ -376,7 +393,7 @@ class Adopt_adv(torch.optim.Optimizer):
  m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
 
  if self.grams_moment:
- m = grad.sign() * m.abs()
+ m = grad.sign().mul_(m.abs())
  elif self.cautious_mask:
  mask = (m * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
@@ -387,18 +404,18 @@ class Adopt_adv(torch.optim.Optimizer):
  m_slow = state['exp_avg_slow']
  m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
  if beta1 > 0:
- update = torch.add(m, m_slow, alpha=alpha_t)
+ update = torch.add(m, m_slow, alpha=alpha)
  else:
- update = torch.add(normalized_grad, m_slow, alpha=alpha_t)
+ update = torch.add(normalized_grad, m_slow, alpha=alpha)
  elif self.Simplified_AdEMAMix:
  update = torch.add(m, normalized_grad, alpha=alpha_grad)
  else:
  update = m.clone() if beta1 > 0 else normalized_grad
 
  if self.use_atan2:
- update.mul_(group['lr'] * 1.2732395447351628)
+ update.mul_(lr * 1.2732395447351628)
  else:
- update.mul_(group['lr'])
+ update.mul_(lr)
 
  # Update second moment v_t for the next step using raw g_t
  v.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
@@ -406,9 +423,9 @@ class Adopt_adv(torch.optim.Optimizer):
  # Parameter Update
  if group["weight_decay"] != 0:
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
  else:
- p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+ p.data.add_(p.data, alpha=-group["weight_decay"] * lr)
 
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
  add_stochastic_(p.data, -update)
@@ -416,7 +433,33 @@ class Adopt_adv(torch.optim.Optimizer):
  p.data.add_(-update)
  del update
 
- state['step'] += 1
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+ if self.global_step is None and 'step' in self.state[p]:
+ # For backward compatibility
+ self.global_step = self.state[p]['step']
+
+ if self.global_step == 0:
+ self.__init_step(p, group)
+
+ # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
+ if self.global_step == 0 and not self.use_atan2:
+ self.global_step += 1
+ return
+
+ if group.get('kourkoutas_beta', False):
+ # Prepare Kourkoutas-β once per step using the global step counter.
+ self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+
+ if not group.get('compiled_optimizer', False):
+ self.__step_parameter(p, group, group['lr'])
+ else:
+ lr_tensor = torch.tensor(group['lr'], device=p.device)
+ self._compiled_step_parameter(p, group, lr_tensor)
+
+ def compile(self, *args, **kwargs):
+ self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
+
 
  @torch.no_grad()
  def step(self, closure=None):
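
In the new dispatch above, the compiled path receives the learning rate as a 0-dim tensor (`lr_tensor`) while the eager path gets the plain float. A plausible reading (an assumption, not stated in the diff) is that this avoids recompilation: torch.compile specializes on Python scalars, so a float lr that changes under a scheduler would force new graphs, whereas a tensor lr is treated as a graph input. A minimal sketch of the pattern:

    import torch

    def sgd_update(p, g, lr):
        # lr can be a Python float (baked into the graph) or a 0-dim tensor (a graph input)
        p.add_(g * -lr)

    compiled_update = torch.compile(sgd_update, fullgraph=True)

    p, g = torch.zeros(4), torch.ones(4)
    for current_lr in (1e-3, 5e-4, 1e-4):                # lr changes every call, as under a scheduler
        compiled_update(p, g, torch.tensor(current_lr))  # tensor lr: one graph reused across calls
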
@@ -430,4 +473,6 @@ class Adopt_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)
 
- return loss
+ self.global_step += 1
+
+ return loss
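
Finally, the step counter now lives on the optimizer itself: `step()` advances `self.global_step` once per call instead of tracking a per-parameter `state['step']`. Continuing the construction sketch from earlier (same assumed `model` and `optimizer`), a minimal loop looks like:

    x, y = torch.randn(32, 128), torch.randn(32, 64)
    for _ in range(3):
        optimizer.zero_grad()
        torch.nn.functional.mse_loss(model(x), y).backward()
        optimizer.step()            # advances optimizer.global_step
    print(optimizer.global_step)
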