adv-optm 1.2.dev13__py3-none-any.whl → 2.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of adv-optm might be problematic.

@@ -50,12 +50,6 @@ class Prodigy_adv(torch.optim.Optimizer):
  before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
  A higher value increases the stabilizing influence of the slow
  momentum. (default: 5.0)
- t_alpha (Optional[int]): The number of steps for a linear warmup of the
- `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
- highly recommended to prevent instability at the beginning of training,
- as it gradually introduces the stabilizing slow momentum term. During
- the warmup, `alpha` ramps from 0 to its target value. If `None`,
- the scheduler is disabled.
  Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
  This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
  more responsive, especially for small batch sizes. Enabling this will
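
Note on the removed `t_alpha` option: the linear warmup it implemented is deleted outright in this release (the removed formula appears further down in this diff as `alpha_t = min(alpha_step * alpha / t_alpha, alpha)`). A minimal external sketch for anyone who wants to keep that ramp; the helper name and step bookkeeping are illustrative, not part of adv-optm:

    # Sketch: reproduce the removed linear alpha warmup from outside the optimizer.
    # Assumes the optimizer keeps reading group['alpha'] each step, which the code in
    # this diff still does; call this once per training step before optimizer.step().
    def apply_alpha_warmup(optimizer, step, target_alpha=5.0, t_alpha=10_000):
        alpha_t = min((step + 1) * target_alpha / t_alpha, target_alpha)
        for group in optimizer.param_groups:
            group['alpha'] = alpha_t
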
@@ -72,7 +66,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
  d_coef (float):
  Coefficient in the expression for the estimate of d (default 1.0).
- Values such as 0.5 and 2.0 typically work as well.
+ Values such as 0.5 and 2.0 typically work as well.
  Changing this parameter is the preferred way to tune the method.
  growth_rate (float):
  prevent the D estimate from growing faster than this multiplicative rate.
@@ -82,8 +76,8 @@ class Prodigy_adv(torch.optim.Optimizer):
  If you're using sharded parameters, this should be set to True. The optimizer
  will attempt to auto-detect this, but if you're using an implementation other
  than PyTorch's builtin version, the auto-detection won't work.
- slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
- pth entry of each tensor. For values greater than 1 this an an approximation to standard
+ slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
+ pth entry of each tensor. For values greater than 1 this an an approximation to standard
  Prodigy. Values ~11 are reasonable (default 11).
  prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
  after the specified optimiser step and release all state memory required by Prodigy
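
For readers unfamiliar with the `slice_p` trick: the Prodigy statistics further down in this diff are computed on `grad_flat[::slice_p]` and `p_flat[::slice_p]`, so the per-parameter accumulator shrinks by roughly a factor of `slice_p`. A small illustrative sketch (shapes and values are arbitrary, not adv-optm internals verbatim):

    import torch

    slice_p = 11
    p = torch.randn(4096, 4096)
    grad = torch.randn_like(p)
    p0 = p.flatten()[::slice_p].clone()   # sliced reference point stored in state
    s = torch.zeros_like(p0)              # Prodigy accumulator, ~1/11 the size of p
    # The numerator term uses only every slice_p-th entry, as in the step code below.
    numer_term = torch.dot(grad.flatten()[::slice_p], p0 - p.flatten()[::slice_p])
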
@@ -108,7 +102,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
- logging (default: 0).
+ logging (default: 0).
  layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
  and returns a unique, hashable key representing its "layer" or "bucket".
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
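
`layer_key_fn` itself is unchanged here; for context, any callable that takes a parameter and returns a hashable key works. A hypothetical example that buckets parameters by shape instead of by memory ID (the grouping rule is the user's choice, not something adv-optm prescribes):

    # Hypothetical layer_key_fn: all parameters with the same shape share one
    # Kourkoutas-beta bucket instead of being bucketed tensor-wise.
    def shape_key(p):
        return tuple(p.shape)

    # optimizer = Prodigy_adv(model.parameters(), kourkoutas_beta=True,
    #                         layer_key_fn=shape_key)
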
@@ -122,7 +116,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  betas: tuple[float, float] = (0.9, 0.999),
  eps: float = 1e-8,
  weight_decay: float = 0.0,
- vector_reshape: bool = True,
+ vector_reshape: bool = False,
  stochastic_rounding: bool = True,
  use_atan2: bool = False,
  cautious_mask: bool = False,
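
Behavioural change: the default for `vector_reshape` flips from `True` to `False`, so (judging from the `state['factored']` condition later in this diff) 1-D parameters are no longer reshaped for factoring by default. To keep the 1.2.dev13 behaviour, pass the flag explicitly (sketch; `model` is assumed to exist):

    # Keep the old default explicitly after upgrading.
    optimizer = Prodigy_adv(model.parameters(), vector_reshape=True)
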
@@ -131,7 +125,6 @@ class Prodigy_adv(torch.optim.Optimizer):
  use_AdEMAMix: bool = False,
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
- t_alpha: int | None = None,
  Simplified_AdEMAMix: bool = False,
  alpha_grad: float = 100.0,
  nnmf_factor: bool = False,
@@ -153,6 +146,8 @@ class Prodigy_adv(torch.optim.Optimizer):
  k_warmup_steps: int = 0,
  k_logging: int = 0,
  layer_key_fn: Optional[Callable] = None,
+ # Compiled
+ compiled_optimizer: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -188,13 +183,14 @@ class Prodigy_adv(torch.optim.Optimizer):
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
  "vector_reshape": vector_reshape, "use_atan2": use_atan2,
  "orthogonal_gradient": orthogonal_gradient,
- "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
+ "beta3_ema": beta3_ema, "alpha": alpha,
  "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
- "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
+ "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "slice_p": slice_p,
  "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps, "d_limiter": d_limiter,
  "alpha_grad": alpha_grad,
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+ "compiled_optimizer": compiled_optimizer,
  }
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
@@ -203,14 +199,23 @@ class Prodigy_adv(torch.optim.Optimizer):
  self.Simplified_AdEMAMix = Simplified_AdEMAMix
  self.factored = nnmf_factor
  self.fsdp_in_use = fsdp_in_use
-
+
  self.kourkoutas_beta = kourkoutas_beta
  self.layer_key_fn = layer_key_fn

  super().__init__(params, defaults)
+ # Use the device of the first parameter to avoid hardcoding '.cuda()'
+ self.device = self.param_groups[0]['params'][0].device
+
  if self.kourkoutas_beta:
  self.kourkoutas_helper = KourkoutasHelper(self)
+
  self.init_step()
+ self.global_step = 0
+
+ if compiled_optimizer:
+ torch._dynamo.config.cache_size_limit = 8192
+ self.compile(fullgraph=True)

  @property
  def supports_fused_back_pass(self):
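
One caveat worth flagging: `torch._dynamo.config.cache_size_limit` is process-wide Dynamo configuration, so setting it in `__init__` affects every compiled function in the process, not just this optimizer. If that interacts badly with other compiled code, it can be adjusted again after construction (the override value below is only an example, not a recommendation from the package):

    import torch
    import torch._dynamo

    optimizer = Prodigy_adv(model.parameters(), compiled_optimizer=True)  # sets the limit to 8192
    torch._dynamo.config.cache_size_limit = 64                            # example override
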
@@ -240,32 +245,21 @@ class Prodigy_adv(torch.optim.Optimizer):
  self.dlr = self.d * lr
  self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3

- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
- if p.grad is None:
- return
-
- if hasattr(p, "_fsdp_flattened"):
- self.fsdp_in_use = True
+ for group in self.param_groups:
+ for i, p in enumerate(group['params']):
+ self.__init_state(p, group)

- grad = p.grad
- if grad.dtype != torch.float32 and self.factored:
- grad = grad.float()
- if group["orthogonal_gradient"]:
- grad = _orthogonalize_gradient(p, grad)
+ @torch.no_grad()
+ def __init_state(self, p, group):
  state = self.state[p]

- # State Initialization
- if 'step' not in state:
- state['step'] = 0
+ if len(state) == 0:

- should_factor = (
+ state['factored'] = (
  self.factored and
  not (len(p.shape) == 1 and not group['vector_reshape'])
  )

- state['factored'] = should_factor
-
  slice_p = group['slice_p']

  dtype = torch.float32 if self.factored else p.dtype
@@ -277,18 +271,18 @@ class Prodigy_adv(torch.optim.Optimizer):

  # First moment (m)
  if self.beta1 > 0:
- state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  if not self.grams_moment:
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
  if self.use_AdEMAMix:
- state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
  # Second moment (v)
- state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  else: # Fallback to standard AdamW for non-factored tensors
  if self.beta1 > 0:
@@ -303,25 +297,30 @@ class Prodigy_adv(torch.optim.Optimizer):
  else:
  state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)

- current_step = state['step']
+
+ def __step_parameter(self, p: torch.Tensor, group: dict, d: torch.Tensor | float, dlr: torch.Tensor | float):
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["orthogonal_gradient"]:
+ grad = _orthogonalize_gradient(p, grad)
+
+ state = self.state[p]
+
  if group.get('kourkoutas_beta', False):
- # Call prepare_step() once at the beginning of the step for all params
- self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
  self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
  # Get the dynamic beta2 calculated in prepare_step()
- beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+ beta2 = self.kourkoutas_helper.get_beta2(p, group)
  else:
  beta2 = self.beta2_default

  if self.use_AdEMAMix:
  beta3_ema = group['beta3_ema']
  alpha = group['alpha']
- t_alpha = group['t_alpha']
- alpha_step = state['step'] + 1
- alpha_t = alpha
- if t_alpha is not None and t_alpha > 0 and alpha_step < t_alpha:
- alpha_t = min(alpha_step * alpha / t_alpha, alpha)
  if self.Simplified_AdEMAMix:
  alpha_grad = group["alpha_grad"]

@@ -339,11 +338,11 @@ class Prodigy_adv(torch.optim.Optimizer):
  del unpacked_sign
  # Update momentum in full-size
  if self.Simplified_AdEMAMix:
- mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d)
+ mt.mul_(self.beta1).add_(grad_reshaped, alpha=d)
  else:
- mt.mul_(self.beta1).add_(grad_reshaped, alpha=self.d * (1.0 - self.beta1))
+ mt.mul_(self.beta1).add_(grad_reshaped, alpha=d * (1.0 - self.beta1))
  if self.grams_moment:
- mt.copy_(grad_reshaped.sign() * mt.abs())
+ mt = grad_reshaped.sign().mul_(mt.abs())
  elif self.cautious_mask:
  mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
@@ -351,7 +350,7 @@ class Prodigy_adv(torch.optim.Optimizer):
  del mask

  vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
- vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - beta2))
+ vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=d * d * (1.0 - beta2))

  if self.use_AdEMAMix:
  mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
@@ -360,15 +359,15 @@ class Prodigy_adv(torch.optim.Optimizer):
  unpacked_sign_slow = _unpack_bools(state['sign_slow'], original_m=d2)
  torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
  del unpacked_sign_slow
- mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
+ mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=d * (1.0 - beta3_ema))
  if self.beta1 > 0:
- update = torch.add(mt, mt_slow, alpha=alpha_t)
+ update = torch.add(mt, mt_slow, alpha=alpha)
  else:
- update = torch.add(grad_reshaped.mul(self.d), mt_slow, alpha=alpha_t)
+ update = torch.add(grad_reshaped.mul(d), mt_slow, alpha=alpha)
  elif self.Simplified_AdEMAMix:
- update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
+ update = torch.add(mt, grad_reshaped, alpha=alpha_grad * d)
  else:
- update = mt.clone() if self.beta1 > 0 else grad_reshaped.mul(self.d)
+ update = mt.clone() if self.beta1 > 0 else grad_reshaped.mul(d)
  del grad_reshaped

  if group['use_atan2']:
@@ -377,10 +376,10 @@ class Prodigy_adv(torch.optim.Optimizer):
  update.atan2_(denom).mul_(a)
  else:
  denom = vt.sqrt()
- update.div_(denom.add_(self.d * group['eps']))
+ update.div_(denom.add_(d * group['eps']))
  del denom

- update = update.view(p.shape).mul_(self.dlr)
+ update = update.view(p.shape).mul_(dlr)

  # Compress updated moments and store new factors
  if self.beta1 > 0:
@@ -401,11 +400,11 @@ class Prodigy_adv(torch.optim.Optimizer):
  if self.beta1 > 0:
  exp_avg = state['exp_avg']
  if self.Simplified_AdEMAMix:
- exp_avg.mul_(self.beta1).add_(grad, alpha=self.d)
+ exp_avg.mul_(self.beta1).add_(grad, alpha=d)
  else:
- exp_avg.mul_(self.beta1).add_(grad, alpha=self.d * (1.0 - self.beta1))
+ exp_avg.mul_(self.beta1).add_(grad, alpha=d * (1.0 - self.beta1))
  if self.grams_moment:
- exp_avg = grad.sign() * exp_avg.abs()
+ exp_avg = grad.sign().mul_(exp_avg.abs())
  elif self.cautious_mask:
  mask = (exp_avg * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
@@ -414,17 +413,17 @@ class Prodigy_adv(torch.optim.Optimizer):

  if self.use_AdEMAMix:
  exp_avg_slow = state['exp_avg_slow']
- exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
+ exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=d * (1.0 - beta3_ema))
  if self.beta1 > 0:
- update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
  else:
- update = torch.add(grad.mul(self.d), exp_avg_slow, alpha=alpha_t)
+ update = torch.add(grad.mul(d), exp_avg_slow, alpha=alpha)
  elif self.Simplified_AdEMAMix:
- update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
+ update = torch.add(exp_avg, grad, alpha=alpha_grad * d)
  else:
- update = exp_avg.clone() if self.beta1 > 0 else grad.mul(self.d)
+ update = exp_avg.clone() if self.beta1 > 0 else grad.mul(d)

- exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - beta2))
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=d * d * (1.0 - beta2))

  if group['use_atan2']:
  a = 1.2732395
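
The constant `a = 1.2732395` used by the `use_atan2` branch matches 4/π. A quick check of why that scaling keeps the atan2-based update comparable to a plain Adam step (an observation, not something stated in the package):

    import math

    a = 1.2732395
    print(4 / math.pi)   # 1.2732395447351628 -> a is 4/pi to the precision used here
    # atan2(m, sqrt(v)) ~= m / sqrt(v) when |m| << sqrt(v), so the scaled update behaves
    # like a normal Adam step in that regime, while |a * atan2(...)| is always bounded
    # by a * pi/2 = 2 and no eps term is needed in the denominator.
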
@@ -432,25 +431,25 @@ class Prodigy_adv(torch.optim.Optimizer):
  update.atan2_(denom).mul_(a)
  else:
  denom = exp_avg_sq.sqrt()
- update.div_(denom.add_(self.d * group['eps']))
+ update.div_(denom.add_(d * group['eps']))
  del denom

- update.mul_(self.dlr)
+ update.mul_(dlr)

  # --- Accumulate Prodigy stats ---
  prodigy_steps = group['prodigy_steps']
- if prodigy_steps <= 0 or group['k'] < prodigy_steps:
+ if prodigy_steps <= 0 or self.global_step < prodigy_steps:
  d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
  s, p0 = state['s'], state['p0']
  grad_flat = grad.flatten().float()
  p_flat = p.data.flatten().float()
  p0 = p0.float()

- self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
+ self.d_numerator.add_((d / d0) * dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]))

- alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
+ alpha = ((d / d0) * d) if safeguard_warmup else ((d / d0) * dlr)
  s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
- self.d_denom += s.abs().sum().item()
+ self.d_denom.add_(s.abs().sum())

  del s, p0, grad_flat, p_flat, alpha
  else:
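
The pattern change in this hunk is that `d_numerator`/`d_denom` are now accumulated as tensors (`.add_(...)`) instead of Python floats built up via `.item()`. Kept as tensors, the accumulation stays on the device and is friendly to `torch.compile`; the old float version forced a device-to-host sync for every parameter. An isolated illustration of the two styles (`grads` is a placeholder list of gradient tensors, not adv-optm code):

    import torch

    grads = [torch.randn(128, 128) for _ in range(4)]   # stand-in gradients

    # Old style: float accumulator, one .item() sync per tensor.
    total_f = 0.0
    for g in grads:
        total_f += g.abs().sum().item()

    # New style: tensor accumulator, a single .item() only when the value is needed.
    total_t = torch.zeros((), device=grads[0].device)
    for g in grads:
        total_t.add_(g.abs().sum())
    print(total_f, total_t.item())
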
@@ -463,9 +462,9 @@ class Prodigy_adv(torch.optim.Optimizer):
  # Decoupled weight decay
  if group["weight_decay"] != 0:
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * self.dlr)
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * dlr)
  else:
- p.data.add_(p.data, alpha=-group["weight_decay"] * self.dlr)
+ p.data.add_(p.data, alpha=-group["weight_decay"] * dlr)

  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
  add_stochastic_(p.data, -update)
@@ -473,7 +472,31 @@ class Prodigy_adv(torch.optim.Optimizer):
  p.data.add_(-update)
  del update

- state['step'] += 1
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+ if hasattr(p, "_fsdp_flattened"):
+ self.fsdp_in_use = True
+
+ if self.global_step is None and 'step' in self.state[p]:
+ # For backward compatibility
+ self.global_step = self.state[p]['step']
+
+ if self.kourkoutas_beta:
+ self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+
+ if isinstance(self.d_numerator, float):
+ self.d_numerator = torch.tensor(self.d_numerator, device=p.device)
+ self.d_denom = torch.tensor(self.d_denom, device=p.device)
+
+ if not group.get('compiled_optimizer', False):
+ self.__step_parameter(p, group, self.d, self.dlr)
+ else:
+ d_tensor = torch.tensor(self.d, device=p.device)
+ dlr_tensor = torch.tensor(self.dlr, device=p.device)
+ self._compiled_step_parameter(p, group, d_tensor, dlr_tensor)
+
+ def compile(self, *args, **kwargs):
+ self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)

  @torch.no_grad()
  def step(self, closure=None):
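
With `step_parameter` now the public per-parameter entry point (and `supports_fused_back_pass` advertised above), the optimizer can be driven from gradient hooks so each parameter is updated as soon as its gradient is ready. The hook wiring below is an assumption about one way to use the API, not something adv-optm ships; per-step bookkeeping such as the `d` update still has to run once per iteration through the framework's normal step path.

    # Sketch: fused backward pass driving step_parameter from post-accumulate-grad hooks.
    def attach_fused_step(optimizer):
        assert optimizer.supports_fused_back_pass
        for group in optimizer.param_groups:
            for i, p in enumerate(group['params']):
                def hook(param, group=group, i=i):
                    optimizer.step_parameter(param, group, i)
                    param.grad = None   # release the gradient right away
                p.register_post_accumulate_grad_hook(hook)
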
@@ -494,23 +517,21 @@ class Prodigy_adv(torch.optim.Optimizer):
  def calculate_d(self):
  """Calculates the new `d` based on the accumulated stats."""
  g_group = self.param_groups[0]
-
+
  # Only perform d-adaptation if prodigy_steps has not been reached
- prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
+ prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and self.global_step >= g_group['prodigy_steps'])

  if prodigy_active:
  d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
-
+
  if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
- # Use the device of the first parameter to avoid hardcoding '.cuda()'
- device = self.param_groups[0]['params'][0].device
- dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+ dist_tensor = torch.stack([self.d_numerator, self.d_denom])
  dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
  global_d_numerator = dist_tensor[0].item()
  global_d_denom = dist_tensor[1].item()
  else:
- global_d_numerator = self.d_numerator
- global_d_denom = self.d_denom
+ global_d_numerator = self.d_numerator.item()
+ global_d_denom = self.d_denom.item()

  d_hat = self.d
  if global_d_denom > 0:
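
The hunk cuts off just before the actual `d` update. For orientation, Prodigy-style optimizers typically finish this block roughly as sketched below; the exact adv-optm logic (including what `d_limiter` does) is not shown in this diff and may differ:

    # Hedged sketch of the usual Prodigy d update, not adv-optm's verbatim code.
    if global_d_denom > 0:
        d_hat = d_coef * global_d_numerator / global_d_denom
        d = max(self.d, min(d_hat, self.d * growth_rate))  # never shrink d, cap its growth
        d_max = max(d_max, d)
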
@@ -529,7 +550,6 @@ class Prodigy_adv(torch.optim.Optimizer):
  group['d_numerator'] = global_d_numerator
  group['d'] = self.d
  group['d_max'] = d_max
-
+
  # Increment step counter for all groups, regardless of whether d was updated
- for group in self.param_groups:
- group['k'] += 1
+ self.global_step += 1