adv-optm 1.2.dev18__py3-none-any.whl → 2.dev2__py3-none-any.whl

This diff shows the changes between two package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.

Potentially problematic release: this version of adv-optm might be problematic.

@@ -56,13 +56,6 @@ class Adopt_adv(torch.optim.Optimizer):
  before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
  A higher value increases the stabilizing influence of the slow
  momentum. (default: 5.0)
- t_alpha (Optional[int]): The number of steps for a linear warmup of the
- `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
- highly recommended to prevent instability at the beginning of training,
- as it gradually introduces the stabilizing slow momentum term. During
- the warmup, `alpha` ramps from 0 to its target value. If `None`,
- the scheduler is disabled and the full `alpha` value is used from
- the start. (default: None)
  Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
  This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
  more responsive, especially for small batch sizes. Enabling this will
@@ -90,10 +83,10 @@ class Adopt_adv(torch.optim.Optimizer):
  k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
- logging (default: 0).
+ logging (default: 0).
  layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
  and returns a unique, hashable key representing its "layer" or "bucket".
- If `None`, parameters are bucketed by their memory ID (tensor-wise).
+ If `None`, parameters are bucketed by their shape.
  (default: None)
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
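Illustrative note (not part of the package diff): the new default bucketing for Kourkoutas-β groups parameters by shape when `layer_key_fn` is `None`. A minimal sketch of supplying a custom bucketing function instead; the `adv_optm` import path is assumed, and the keyword arguments mirror the constructor signature shown in this diff.

```python
# Sketch only; the import path is an assumption.
import torch
from adv_optm import Adopt_adv  # assumed import path

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 10))

def bucket_by_layer(p: torch.Tensor):
    # Group all 1-D tensors (biases, norm weights) into one bucket and all
    # other tensors by shape, instead of the default shape-only bucketing.
    return "vectors" if p.ndim == 1 else tuple(p.shape)

opt = Adopt_adv(
    model.parameters(),
    lr=1e-3,
    kourkoutas_beta=True,
    k_logging=100,          # print beta2 statistics every 100 steps
    layer_key_fn=bucket_by_layer,
)
```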
@@ -107,7 +100,7 @@ class Adopt_adv(torch.optim.Optimizer):
  eps: float = 1e-6,
  weight_decay: float = 0.0,
  clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
- vector_reshape: bool = True,
+ vector_reshape: bool = False,
  stochastic_rounding: bool = True,
  use_atan2: bool = False,
  cautious_mask: bool = False,
@@ -116,7 +109,6 @@ class Adopt_adv(torch.optim.Optimizer):
  use_AdEMAMix: bool = False,
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
- t_alpha: int | None = None,
  Simplified_AdEMAMix: bool = False,
  alpha_grad: float = 100.0,
  kourkoutas_beta: bool = False,
@@ -127,6 +119,8 @@ class Adopt_adv(torch.optim.Optimizer):
  k_logging: int = 0,
  layer_key_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
+ # Compiled
+ compiled_optimizer: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -141,7 +135,8 @@ class Adopt_adv(torch.optim.Optimizer):
  cautious_mask = False
  if betas[0] == 0.0 and Simplified_AdEMAMix:
  raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
- if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+ if kourkoutas_beta and not (betas[1] > beta2_min):
+ raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
  if use_AdEMAMix and Simplified_AdEMAMix:
  print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
  if grams_moment and Simplified_AdEMAMix:
@@ -152,9 +147,10 @@ class Adopt_adv(torch.optim.Optimizer):
  defaults = {
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
  "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
- "t_alpha": t_alpha, "alpha_grad": alpha_grad,
+ "alpha_grad": alpha_grad,
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+ "compiled_optimizer": compiled_optimizer,
  }
  self.clip_lambda = clip_lambda
  self.stochastic_rounding = stochastic_rounding
@@ -169,9 +165,17 @@ class Adopt_adv(torch.optim.Optimizer):
  self.layer_key_fn = layer_key_fn
  super().__init__(params, defaults)

+ self.init_step()
+
  if self.kourkoutas_beta:
  self.kourkoutas_helper = KourkoutasHelper(self)

+ self.global_step = 0
+
+ if compiled_optimizer:
+ torch._dynamo.config.cache_size_limit = 8192
+ self.compile(fullgraph=True)
+
  @property
  def supports_fused_back_pass(self): return True
  @property
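Illustrative note (not part of the package diff): the constructor now takes a `compiled_optimizer` flag that raises `torch._dynamo.config.cache_size_limit` and wraps the per-parameter update in `torch.compile(fullgraph=True)`. A minimal usage sketch, assuming the `adv_optm` import path:

```python
# Sketch only; the import path is an assumption. The flag itself, and the fact
# that it compiles the inner update with fullgraph=True, come from the diff above.
import torch
from adv_optm import Adopt_adv  # assumed import path

model = torch.nn.Linear(128, 128)
opt = Adopt_adv(model.parameters(), lr=1e-3, compiled_optimizer=True)

for _ in range(10):
    loss = model(torch.randn(32, 128)).pow(2).mean()
    loss.backward()
    opt.step()       # first call triggers compilation of the inner update
    opt.zero_grad()
```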
@@ -179,29 +183,22 @@ class Adopt_adv(torch.optim.Optimizer):
  @property
  def supports_flat_params(self): return False

- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
- if p.grad is None:
- return
+ def init_step(self):
+ for group in self.param_groups:
+ for p in group['params']:
+ self.__init_state(p, group)

- grad = p.grad
- if self.factored and grad.dtype != torch.float32:
- grad = grad.float()
- if self.orthogonal_gradient:
- grad = _orthogonalize_gradient(p, grad)
+ @torch.no_grad()
+ def __init_state(self, p, group):
  state = self.state[p]

- # State Initialization
- if 'step' not in state:
- state['step'] = 0
+ if len(state) == 0:

- should_factor = (
+ state['factored'] = (
  self.factored and
  not (len(p.shape) == 1 and not group['vector_reshape'])
  )

- state['factored'] = should_factor
-
  dtype = torch.float32 if self.factored else p.dtype

  if state['factored']:
@@ -210,55 +207,75 @@ class Adopt_adv(torch.optim.Optimizer):

  # m_0 = 0
  if group['betas'][0] > 0:
- state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  if not self.grams_moment:
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
  if self.use_AdEMAMix:
- state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
- # v_0 = g_0^2 (SMMF_ADOPT NMF storage)
- vt_init = grad.view(d1, d2).square_()
- # Allocate NMF factors for v
- state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
- state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
- # Initialize v_0 using NMF
- _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+
  else: # Fallback for non-factored tensors
  if group['betas'][0] > 0:
  state['exp_avg'] = torch.zeros_like(p, dtype=dtype) # m_0
  if self.use_AdEMAMix:
  state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
- state['exp_avg_sq'] = grad.square() # v_0
+
+ @torch.no_grad()
+ def __init_step(self, p, group):
+ if p.grad is None:
+ return
+
+ state = self.state[p]
+
+ if 'exp_avg_sq' in state or 'mu_v_nmf' in state:
+ return
+
+ grad = p.grad
+ dtype = torch.float32 if self.factored else p.dtype
+
+ if state['factored']:
+ d1, d2 = state['effective_shape']
+ # v_0 = g_0^2 (SMMF_ADOPT NMF storage)
+ vt_init = grad.view(d1, d2).square_()
+ # Allocate NMF factors for v
+ state['mu_v_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mv_v_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+ # Initialize v_0 using NMF
+ _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
+ del vt_init
+ else:
+ state['exp_avg_sq'] = grad.square() # v_0
+
+
+ @torch.no_grad()
+ def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float):
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ if self.factored and grad.dtype != torch.float32:
+ grad = grad.float()
+ if self.orthogonal_gradient:
+ grad = _orthogonalize_gradient(p, grad)
+ state = self.state[p]
+

  beta1, beta2 = group['betas']

- current_step = state['step']
  if group.get('kourkoutas_beta', False):
- # Call prepare_step() once at the beginning of the step for all params
- self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
  self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
  # Get the dynamic beta2 calculated in prepare_step()
- beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
-
- # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
- if state['step'] == 0 and not self.use_atan2:
- state['step'] += 1
- return
+ beta2 = self.kourkoutas_helper.get_beta2(p, group)

  if self.use_AdEMAMix:
  beta3_ema = group['beta3_ema']
  alpha = group['alpha']
- t_alpha = group['t_alpha']
- # Use step+1 for 1-based step count in scheduler
- alpha_step = state['step'] + 1
- alpha_t = alpha
- if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
- alpha_t = min(alpha_step * alpha / t_alpha, alpha)
+
  if self.Simplified_AdEMAMix:
  alpha_grad = group["alpha_grad"]

@@ -296,7 +313,7 @@ class Adopt_adv(torch.optim.Optimizer):
  else:
  normalized_grad = grad_reshaped / denom.add_(group['eps'])
  if self.clip_lambda is not None:
- clip_val = self.clip_lambda(state['step'])
+ clip_val = self.clip_lambda(self.global_step)
  normalized_grad.clamp_(-clip_val, clip_val)
  del denom

@@ -317,9 +334,9 @@ class Adopt_adv(torch.optim.Optimizer):
  if self.use_AdEMAMix:
  mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
  if beta1 > 0:
- update = torch.add(mt, mt_slow, alpha=alpha_t)
+ update = torch.add(mt, mt_slow, alpha=alpha)
  else:
- update = torch.add(normalized_grad, mt_slow, alpha=alpha_t)
+ update = torch.add(normalized_grad, mt_slow, alpha=alpha)
  elif self.Simplified_AdEMAMix:
  update = torch.add(mt, normalized_grad, alpha=alpha_grad)
  else:
@@ -328,9 +345,9 @@ class Adopt_adv(torch.optim.Optimizer):
  update = update.view(p.shape)

  if self.use_atan2:
- update.mul_(group['lr'] * 1.2732395447351628)
+ update.mul_(lr * 1.2732395447351628)
  else:
- update.mul_(group['lr'])
+ update.mul_(lr)

  # Update second moment v_t for the *next* step using raw g_t
  vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=1.0 - beta2)
@@ -353,7 +370,7 @@ class Adopt_adv(torch.optim.Optimizer):
  del vt

  else: # Standard ADOPT logic for non-factored tensors
- v = state['exp_avg_sq'] # v_{t-1}
+ v = state['exp_avg_sq'] # v_{t-1}

  # ADOPT Step A: Decorrelate g_t using v_{t-1}
  denom = v.sqrt()
@@ -363,7 +380,7 @@ class Adopt_adv(torch.optim.Optimizer):
  else:
  normalized_grad = grad / denom.add_(group['eps'])
  if self.clip_lambda is not None:
- clip_val = self.clip_lambda(state['step'])
+ clip_val = self.clip_lambda(self.global_step)
  normalized_grad.clamp_(-clip_val, clip_val)
  del denom

@@ -387,18 +404,18 @@ class Adopt_adv(torch.optim.Optimizer):
  m_slow = state['exp_avg_slow']
  m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
  if beta1 > 0:
- update = torch.add(m, m_slow, alpha=alpha_t)
+ update = torch.add(m, m_slow, alpha=alpha)
  else:
- update = torch.add(normalized_grad, m_slow, alpha=alpha_t)
+ update = torch.add(normalized_grad, m_slow, alpha=alpha)
  elif self.Simplified_AdEMAMix:
  update = torch.add(m, normalized_grad, alpha=alpha_grad)
  else:
  update = m.clone() if beta1 > 0 else normalized_grad

  if self.use_atan2:
- update.mul_(group['lr'] * 1.2732395447351628)
+ update.mul_(lr * 1.2732395447351628)
  else:
- update.mul_(group['lr'])
+ update.mul_(lr)

  # Update second moment v_t for the next step using raw g_t
  v.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
@@ -406,9 +423,9 @@ class Adopt_adv(torch.optim.Optimizer):
  # Parameter Update
  if group["weight_decay"] != 0:
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
  else:
- p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+ p.data.add_(p.data, alpha=-group["weight_decay"] * lr)

  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
  add_stochastic_(p.data, -update)
@@ -416,7 +433,33 @@ class Adopt_adv(torch.optim.Optimizer):
  p.data.add_(-update)
  del update

- state['step'] += 1
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+ if self.global_step is None and 'step' in self.state[p]:
+ # For backward compatibility
+ self.global_step = self.state[p]['step']
+
+ if self.global_step == 0:
+ self.__init_step(p, group)
+
+ # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
+ if self.global_step == 0 and not self.use_atan2:
+ self.global_step += 1
+ return
+
+ if group.get('kourkoutas_beta', False):
+ # Prepare Kourkoutas-β once per step using the global step counter.
+ self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+
+ if not group.get('compiled_optimizer', False):
+ self.__step_parameter(p, group, group['lr'])
+ else:
+ lr_tensor = torch.tensor(group['lr'], device=p.device)
+ self._compiled_step_parameter(p, group, lr_tensor)
+
+ def compile(self, *args, **kwargs):
+ self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
+

  @torch.no_grad()
  def step(self, closure=None):
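Illustrative note (not part of the package diff): in the compiled path above, `group['lr']` is wrapped in a tensor before being handed to the compiled update. A plain Python float would be treated as a compile-time constant, so a scheduler changing it could trigger guard failures and recompiles; a tensor stays a runtime input. A standalone sketch of the same pattern, not code from the package:

```python
# Standalone sketch of float-vs-tensor arguments under torch.compile.
import torch

def scaled_update(p: torch.Tensor, g: torch.Tensor, lr):
    return p - lr * g

compiled = torch.compile(scaled_update)

p, g = torch.randn(4), torch.randn(4)

# A float lr is specialized into the graph: new values can cause recompiles.
compiled(p, g, 0.1)

# A 0-dim tensor lr remains a graph input, so its value can change freely
# (e.g. via an LR scheduler) without invalidating the compiled graph.
lr = torch.tensor(0.1)
compiled(p, g, lr)
lr.fill_(0.05)
compiled(p, g, lr)
```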
@@ -430,4 +473,6 @@ class Adopt_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

- return loss
+ self.global_step += 1
+
+ return loss
@@ -27,17 +27,13 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  stochastic_rounding (bool, optional): whether to use stochastic
  rounding for BF16 parameter updates (default: True).
  cautious_mask (bool): whether to use the cautious masking technique. (default: False).
- clip_threshold (float, optional): whether to clip the gradients norm
- per-parameter as proposed in the paper `Lions and Muons: Optimization via
- Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
- (default: 0.0).
  nnmf_factor (bool): whether to use the factorization or use the
  uncompressed optimizer. (default: True)
  d0 (float):
  Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
  d_coef (float):
  Coefficient in the expression for the estimate of d (default 1.0).
- Values such as 0.5 and 2.0 typically work as well.
+ Values such as 0.5 and 2.0 typically work as well.
  Changing this parameter is the preferred way to tune the method.
  growth_rate (float):
  prevent the D estimate from growing faster than this multiplicative rate.
@@ -47,8 +43,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  If you're using sharded parameters, this should be set to True. The optimizer
  will attempt to auto-detect this, but if you're using an implementation other
  than PyTorch's builtin version, the auto-detection won't work.
- slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
- pth entry of each tensor. For values greater than 1 this an an approximation to standard
+ slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
+ pth entry of each tensor. For values greater than 1 this an an approximation to standard
  Prodigy. Values ~11 are reasonable (default 11).
  prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
  after the specified optimiser step and release all state memory required by Prodigy
@@ -64,11 +60,10 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  lr: float = 1,
  betas: Tuple[float, float] = (0.9, 0.99),
  weight_decay: float = 0.0,
- vector_reshape: bool = True,
+ vector_reshape: bool = False,
  stochastic_rounding: bool = True,
  orthogonal_gradient: bool = False,
  cautious_mask: bool = False,
- clip_threshold: float = 0.0,
  nnmf_factor: bool = False,
  # prodigy parameters
  beta3: float = None,
@@ -80,6 +75,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  slice_p: int = 11,
  prodigy_steps: int = 0,
  d_limiter: bool = True,
+ # Compiled
+ compiled_optimizer: bool = False,
  ):
  if not lr > 0.0:
  raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -94,21 +91,28 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  weight_decay=weight_decay,
  vector_reshape=vector_reshape,
  orthogonal_gradient=orthogonal_gradient,
- clip_threshold=clip_threshold,
  beta3=beta3, d=d0, d0=d0, d_max=d0, d_numerator=0.0, d_coef=d_coef,
- growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, k=0, slice_p=slice_p,
+ growth_rate=growth_rate, safeguard_warmup=safeguard_warmup, slice_p=slice_p,
  fsdp_in_use=fsdp_in_use,
  prodigy_steps=prodigy_steps,
  d_limiter=d_limiter,
+ compiled_optimizer=compiled_optimizer,
  )
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
  self.factored = nnmf_factor
  self.fsdp_in_use = fsdp_in_use
  super().__init__(params, defaults)
- # Global state for accumulating metrics across parameter updates within a single step.
+ # Use the device of the first parameter to avoid hardcoding '.cuda()'
+ self.device = self.param_groups[0]['params'][0].device
+
+ self.global_step = 0
  self.init_step()

+ if compiled_optimizer:
+ torch._dynamo.config.cache_size_limit = 8192
+ self.compile(fullgraph=True)
+
  @property
  def supports_fused_back_pass(self) -> bool:
  return True
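Illustrative note (not part of the package diff): `Lion_Prodigy_adv` gains the same `compiled_optimizer` switch as `Adopt_adv`. A minimal usage sketch under the same assumed import path, with keyword names taken from the constructor signature in this diff:

```python
# Sketch only; the import path is an assumption.
import torch
from adv_optm import Lion_Prodigy_adv  # assumed import path

model = torch.nn.Linear(64, 64)
opt = Lion_Prodigy_adv(
    model.parameters(),
    lr=1.0,                 # Prodigy adapts the effective step size via d
    prodigy_steps=1000,     # stop adapting d and free its state after 1000 steps
    compiled_optimizer=True,
)

for _ in range(5):
    loss = model(torch.randn(8, 64)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
```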
@@ -124,14 +128,13 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  def init_step(self):
  """Resets accumulators and calculates dlr for the upcoming step."""
  self.d_denom = 0.0
-
+
  g_group = self.param_groups[0]
  self.beta1, self.beta2 = g_group['betas']
  self.beta3 = g_group['beta3']
  if self.beta3 is None:
  self.beta3 = math.sqrt(self.beta2)
-
- k = g_group['k']
+
  self.d = g_group['d']
  lr = g_group['lr']

@@ -139,38 +142,21 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):

  self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3

- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: Optional[int] = None):
- """Performs a single optimization step on a single parameter."""
- if p.grad is None:
- return
-
- if hasattr(p, "_fsdp_flattened"):
- self.fsdp_in_use = True
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.__init_state(p, group)

- grad = p.grad
- if grad.dtype != torch.float32 and self.factored:
- grad = grad.float()
- if group["clip_threshold"] > 0.0:
- grad_norm = torch.norm(grad.detach())
- if grad_norm > group["clip_threshold"]:
- clip_coef = group["clip_threshold"] / grad_norm
- grad.mul_(clip_coef)
- if group["orthogonal_gradient"]:
- grad = _orthogonalize_gradient(p, grad)
+ @torch.no_grad()
+ def __init_state(self, p, group):
  state = self.state[p]

- # State Initialization
- if 'step' not in state:
- state['step'] = 0
+ if len(state) == 0:

- should_factor = (
+ state['factored'] = (
  self.factored and
  not (len(p.shape) == 1 and not group['vector_reshape'])
  )

- state['factored'] = should_factor
-
  dtype = torch.float32 if self.factored else p.dtype

  slice_p = group['slice_p']
@@ -185,13 +171,28 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  if state['factored']:
  state['effective_shape'] = _get_effective_shape(p.numel())
  d1, d2 = state['effective_shape']
- state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
  else: # Fallback to standard Lion
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)

+ @torch.no_grad()
+ def __step_parameter(self, p: torch.Tensor, group: dict, d: torch.Tensor | float, dlr: torch.Tensor | float):
+ """Performs a single optimization step on a single parameter."""
+ if p.grad is None:
+ return
+
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["orthogonal_gradient"]:
+ grad = _orthogonalize_gradient(p, grad)
+ state = self.state[p]
+
+
  if state['factored']:
  # Factored Path
  d1, d2 = state['effective_shape']
@@ -205,7 +206,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  exp_avg = exp_avg.float()

  # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
- signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1-self.beta1)).sign_()
+ signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=d * (1-self.beta1)).sign_()

  if self.cautious_mask:
  mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
@@ -214,10 +215,10 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  del mask

  # Parameter update: p_t = p_{t-1} - lr * sign(c_t)
- update_for_param = signed_update.view(p.shape).mul(self.dlr)
+ update_for_param = signed_update.view(p.shape).mul(dlr)

  # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
- exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1 - self.beta2))
+ exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=d * (1 - self.beta2))
  del grad_reshaped

  # Compress new momentum m_t and store factors
@@ -232,7 +233,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  # Compute update term and sign for the update
  if exp_avg.dtype != torch.float32 and self.factored:
  exp_avg = exp_avg.float()
- signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=self.d * (1-self.beta1)).sign_()
+ signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=d * (1-self.beta1)).sign_()

  if self.cautious_mask:
  mask = (signed_update * grad > 0).to(grad.dtype)
@@ -240,41 +241,18 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  signed_update.mul_(mask)
  del mask

- update_for_param = signed_update.mul(self.dlr)
-
- # Update momentum
- exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))
-
- prodigy_steps = group['prodigy_steps']
- if prodigy_steps <= 0 or group['k'] < prodigy_steps:
- # --- Accumulate Prodigy stats ---
- d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
- s, p0 = state['s'], state['p0']
- grad_flat = grad.flatten().float()
- p_flat = p.data.flatten().float()
- p0 = p0.float()
-
- self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
-
- alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
- s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
- self.d_denom += s.abs().sum().item()
+ update_for_param = signed_update.mul(dlr)

- del s, p0, grad_flat, p_flat, alpha
- else:
- # Free memory if prodigy_steps is reached
- if 's' in state:
- del state['s']
- if 'p0' in state:
- del state['p0']
+ # Update momentum
+ exp_avg.mul_(self.beta2).add_(grad, alpha=d * (1 - self.beta2))

  if group["weight_decay"] != 0:
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
  add_stochastic_(p.data, p.data,
- alpha=-group["weight_decay"] * self.dlr)
+ alpha=-group["weight_decay"] * dlr)
  else:
  p.data.add_(
- p.data, alpha=-group["weight_decay"] * self.dlr
+ p.data, alpha=-group["weight_decay"] * dlr
  )

  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
@@ -284,6 +262,29 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):

  del update_for_param

+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+ if hasattr(p, "_fsdp_flattened"):
+ self.fsdp_in_use = True
+
+ if self.global_step is None and 'step' in self.state[p]:
+ # For backward compatibility
+ self.global_step = self.state[p]['step']
+
+ if isinstance(self.d_numerator, float):
+ self.d_numerator = torch.tensor(self.d_numerator, device=p.device)
+ self.d_denom = torch.tensor(self.d_denom, device=p.device)
+
+ if not group.get('compiled_optimizer', False):
+ self.__step_parameter(p, group, self.d, self.dlr)
+ else:
+ d_tensor = torch.tensor(self.d, device=p.device)
+ dlr_tensor = torch.tensor(self.dlr, device=p.device)
+ self._compiled_step_parameter(p, group, d_tensor, dlr_tensor)
+
+ def compile(self, *args, **kwargs):
+ self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
+
  @torch.no_grad()
  def step(self, closure: Optional[callable] = None):
  """Performs a single optimization step."""
@@ -306,21 +307,19 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  """Calculates the new `d` based on the accumulated stats."""
  g_group = self.param_groups[0]
  # Only perform d-adaptation if prodigy_steps has not been reached
- prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
+ prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and self.global_step >= g_group['prodigy_steps'])

  if prodigy_active:
  d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
-
+
  if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
- # Use the device of the first parameter to avoid hardcoding '.cuda()'
- device = self.param_groups[0]['params'][0].device
- dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+ dist_tensor = torch.stack([self.d_numerator, self.d_denom])
  dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
  global_d_numerator = dist_tensor[0].item()
  global_d_denom = dist_tensor[1].item()
  else:
- global_d_numerator = self.d_numerator
- global_d_denom = self.d_denom
+ global_d_numerator = self.d_numerator.item()
+ global_d_denom = self.d_denom.item()

  d_hat = self.d
  if global_d_denom > 0:
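Illustrative note (not part of the package diff): with `d_numerator` and `d_denom` now kept as tensors, the FSDP branch above can reduce them with a single `torch.stack` plus `all_reduce` instead of rebuilding a tensor on a hardcoded device. A standalone sketch of that collective pattern, assuming a process group has already been initialized (e.g. via `torchrun`):

```python
# Standalone sketch of the stacked all-reduce shown above; not package code.
import torch
import torch.distributed as dist

def reduce_prodigy_stats(d_numerator: torch.Tensor, d_denom: torch.Tensor):
    if dist.is_available() and dist.is_initialized():
        stats = torch.stack([d_numerator, d_denom])
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)  # sum across all ranks
        return stats[0].item(), stats[1].item()
    # Single-process fallback mirrors the else-branch in the diff.
    return d_numerator.item(), d_denom.item()
```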
@@ -337,5 +336,4 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  group['d'] = self.d
  group['d_max'] = d_max
  # Increment step counter for all groups, regardless of whether d was updated
- for group in self.param_groups:
- group['k'] += 1
+ self.global_step += 1