adv-optm 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of adv-optm might be problematic. Click here for more details.

adv_optm/__init__.py CHANGED
@@ -16,4 +16,4 @@ __all__ = [
16
16
  "Lion_Prodigy_adv",
17
17
  ]
18
18
 
19
- __version__ = "0.1.7"
19
+ __version__ = "0.1.8"
@@ -55,7 +55,7 @@ class AdamW_adv(torch.optim.Optimizer):
55
55
  the warmup, `alpha` ramps from 0 to its target value. If `None`,
56
56
  the scheduler is disabled. (default: None)
57
57
  factored (bool): whether to use the factorization or disable it to use
58
- the uncompressed optimizer. (default: True)
58
+ the uncompressed optimizer. (default: False)
59
59
  """
60
60
 
61
61
  def __init__(
@@ -76,7 +76,7 @@ class AdamW_adv(torch.optim.Optimizer):
76
76
  beta3_ema: float = 0.9999,
77
77
  alpha: float = 5.0,
78
78
  t_alpha: int | None = None,
79
- factored: bool = True,
79
+ factored: bool = False,
80
80
  ):
81
81
  if not (lr >= 0.0):
82
82
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -216,7 +216,10 @@ class AdamW_adv(torch.optim.Optimizer):
216
216
  del unpacked_sign_slow
217
217
 
218
218
  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
219
- update = mt + (alpha_t * mt_slow) if beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
219
+ if beta1 > 0:
220
+ update = torch.add(mt, mt_slow, alpha=alpha_t)
221
+ else:
222
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
220
223
  else:
221
224
  update = mt.clone() if beta1 > 0 else grad_reshaped.clone()
222
225
  del grad_reshaped
@@ -262,7 +265,10 @@ class AdamW_adv(torch.optim.Optimizer):
262
265
  if self.use_AdEMAMix:
263
266
  exp_avg_slow = state['exp_avg_slow']
264
267
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
265
- update = exp_avg + (alpha_t * exp_avg_slow) if beta1 > 0 else grad + (alpha_t * exp_avg_slow)
268
+ if beta1 > 0:
269
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
270
+ else:
271
+ update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
266
272
  else:
267
273
  update = exp_avg.clone() if beta1 > 0 else grad.clone()
268
274
 
@@ -63,7 +63,7 @@ class Adopt_adv(torch.optim.Optimizer):
63
63
  the scheduler is disabled and the full `alpha` value is used from
64
64
  the start. (default: None)
65
65
  factored (bool): whether to use the factorization or disable it to use
66
- the uncompressed optimizer. (default: True)
66
+ the uncompressed optimizer. (default: False)
67
67
  """
68
68
 
69
69
  def __init__(
@@ -84,7 +84,7 @@ class Adopt_adv(torch.optim.Optimizer):
84
84
  beta3_ema: float = 0.9999,
85
85
  alpha: float = 5.0,
86
86
  t_alpha: int | None = None,
87
- factored: bool = True,
87
+ factored: bool = False,
88
88
  ):
89
89
  if not (lr >= 0.0):
90
90
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -235,7 +235,7 @@ class Adopt_adv(torch.optim.Optimizer):
235
235
 
236
236
  if self.use_AdEMAMix:
237
237
  mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
238
- update = mt + (alpha_t * mt_slow)
238
+ update = torch.add(mt, mt_slow, alpha=alpha_t)
239
239
  update = update.view(p.shape)
240
240
  else:
241
241
  update = mt.view(p.shape)
@@ -295,9 +295,9 @@ class Adopt_adv(torch.optim.Optimizer):
295
295
 
296
296
  if self.use_AdEMAMix:
297
297
  m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
298
- update = m + (alpha_t * m_slow)
298
+ update = torch.add(m, m_slow, alpha=alpha_t)
299
299
  else:
300
- update = m
300
+ update = m.clone()
301
301
 
302
302
  if self.use_atan2:
303
303
  update.mul_(group['lr'] * 1.2732395447351628)
@@ -33,8 +33,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
33
33
  (default: 0.0).
34
34
  factored (bool): whether to use the factorization or use the
35
35
  uncompressed optimizer. (default: True)
36
- variance_reduction (bool): whether to use the variance reduction technique
37
- from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
38
36
  d0 (float):
39
37
  Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
40
38
  d_coef (float):
@@ -66,7 +64,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
66
64
  use_cautious: bool = False,
67
65
  clip_threshold: float = 0.0,
68
66
  factored: bool = True,
69
- variance_reduction: bool = False,
70
67
  # prodigy parameters
71
68
  beta3: float = None,
72
69
  d0: float = 1e-6,
@@ -97,7 +94,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
97
94
  self.stochastic_rounding = stochastic_rounding
98
95
  self.use_cautious = use_cautious
99
96
  self.factored = factored
100
- self.variance_reduction = variance_reduction
101
97
  self.fsdp_in_use = fsdp_in_use
102
98
  super().__init__(params, defaults)
103
99
  # Global state for accumulating metrics across parameter updates within a single step.
@@ -183,12 +179,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
183
179
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
184
180
  packed_d2 = (d2 + 7) // 8
185
181
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
186
- if self.variance_reduction:
187
- state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
188
182
  else: # Fallback to standard Lion
189
183
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
190
- if self.variance_reduction:
191
- state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
192
184
 
193
185
  if state['factored']:
194
186
  # Factored Path
@@ -215,20 +207,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
215
207
  update_for_param = signed_update.view(p.shape).mul(self.dlr)
216
208
 
217
209
  # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
218
- if self.variance_reduction:
219
- if state['step'] == 1:
220
- exp_avg.copy_(grad_reshaped)
221
- else:
222
- # Heuristic Prodigy-STORM update
223
- correction = exp_avg.sub(state['prev_grad'])
224
- grad_alpha = self.d * (1 - self.beta2) + self.beta2
225
- exp_avg.copy_(grad_reshaped).mul_(grad_alpha).add_(correction, alpha=self.beta2)
226
- del correction, grad_alpha
227
- state['prev_grad'].copy_(grad_reshaped)
228
- else:
229
- # Standard Prodigy-Lion
230
- alpha = self.d * (1 - self.beta2)
231
- exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=alpha)
210
+ exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1 - self.beta2))
232
211
  del grad_reshaped
233
212
 
234
213
  # Compress new momentum m_t and store factors
@@ -254,20 +233,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
254
233
  update_for_param = signed_update.mul(self.dlr)
255
234
 
256
235
  # Update momentum
257
- if self.variance_reduction:
258
- if state['step'] == 1:
259
- exp_avg.copy_(grad)
260
- else:
261
- # Heuristic Prodigy-STORM update
262
- correction = exp_avg.sub(state['prev_grad'])
263
- grad_alpha = self.d * (1 - self.beta2) + self.beta2
264
- exp_avg.copy_(grad).mul_(grad_alpha).add_(correction, alpha=self.beta2)
265
- del grad_alpha, correction
266
- state['prev_grad'].copy_(grad)
267
- else:
268
- # Standard Prodigy-Lion
269
- alpha = self.d * (1 - self.beta2)
270
- exp_avg.mul_(self.beta2).add_(grad, alpha=alpha)
236
+ exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))
271
237
 
272
238
  # --- Accumulate Prodigy stats ---
273
239
  d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
@@ -298,7 +264,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
298
264
  else:
299
265
  p.data.add_(-update_for_param)
300
266
 
301
- del update_for_param
267
+ del update_for_param
302
268
 
303
269
  @torch.no_grad()
304
270
  def step(self, closure: Optional[callable] = None):
@@ -33,8 +33,6 @@ class Lion_adv(torch.optim.Optimizer):
33
33
  (default: 0.0).
34
34
  factored (bool): whether to use the factorization or use the
35
35
  uncompressed optimizer. (default: True)
36
- variance_reduction (bool): whether to use the variance reduction technique
37
- from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
38
36
  """
39
37
 
40
38
  def __init__(
@@ -49,7 +47,6 @@ class Lion_adv(torch.optim.Optimizer):
49
47
  use_cautious: bool = False,
50
48
  clip_threshold: float = 0.0,
51
49
  factored: bool = True,
52
- variance_reduction: bool = False,
53
50
  ):
54
51
  if not lr > 0.0:
55
52
  raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -69,7 +66,6 @@ class Lion_adv(torch.optim.Optimizer):
69
66
  self.stochastic_rounding = stochastic_rounding
70
67
  self.use_cautious = use_cautious
71
68
  self.factored = factored
72
- self.variance_reduction = variance_reduction
73
69
  super().__init__(params, defaults)
74
70
 
75
71
  @property
@@ -122,12 +118,8 @@ class Lion_adv(torch.optim.Optimizer):
122
118
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
123
119
  packed_d2 = (d2 + 7) // 8
124
120
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
125
- if self.variance_reduction:
126
- state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
127
121
  else: # Fallback to standard Lion
128
122
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
129
- if self.variance_reduction:
130
- state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
131
123
 
132
124
  state['step'] += 1
133
125
  beta1, beta2 = group["betas"]
@@ -157,21 +149,9 @@ class Lion_adv(torch.optim.Optimizer):
157
149
  # Parameter update
158
150
  update_for_param = signed_update.view(p.shape).mul_(lr)
159
151
 
160
- # Update momentum
161
- if self.variance_reduction:
162
- if state['step'] == 1:
163
- exp_avg.copy_(grad_reshaped)
164
- else:
165
- # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
166
- correction = exp_avg.sub(state['prev_grad'])
167
- # Calculate the new momentum and store it back into exp_avg
168
- exp_avg.copy_(grad_reshaped).add_(correction, alpha=beta2)
169
- del correction
170
- # Update prev_grad for the next iteration
171
- state['prev_grad'].copy_(grad_reshaped)
172
- else:
173
- # Standard Lion momentum update
174
- exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
152
+ # Standard Lion momentum update
153
+ exp_avg.mul_(beta2).add_(grad_reshaped, alpha=1-beta2)
154
+ del grad_reshaped
175
155
 
176
156
  # Compress new momentum m_t and store factors
177
157
  state['sign'] = _pack_bools(exp_avg > 0)
@@ -195,21 +175,8 @@ class Lion_adv(torch.optim.Optimizer):
195
175
 
196
176
  update_for_param = signed_update.mul_(lr)
197
177
 
198
- # Update momentum
199
- if self.variance_reduction:
200
- if state['step'] == 1:
201
- exp_avg.copy_(grad)
202
- else:
203
- # Use the simplified STORM update: m_t = g_t + β₂ * (m_{t-1} - g_{t-1})
204
- correction = exp_avg.sub(state['prev_grad'])
205
- # Calculate the new momentum and store it back into exp_avg
206
- exp_avg.copy_(grad).add_(correction, alpha=beta2)
207
- del correction
208
- # Update prev_grad for the next iteration
209
- state['prev_grad'].copy_(grad)
210
- else:
211
- # Standard Lion momentum update
212
- exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
178
+ # Standard Lion momentum update
179
+ exp_avg.mul_(beta2).add_(grad, alpha=1-beta2)
213
180
 
214
181
  if group["weight_decay"] != 0:
215
182
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
@@ -225,7 +192,7 @@ class Lion_adv(torch.optim.Optimizer):
225
192
  else:
226
193
  p.data.add_(-update_for_param)
227
194
 
228
- del update_for_param
195
+ del update_for_param
229
196
 
230
197
  @torch.no_grad()
231
198
  def step(self, closure: Optional[callable] = None):
@@ -64,7 +64,7 @@ class Prodigy_adv(torch.optim.Optimizer):
64
64
  more responsive. For large batch sizes, use low values (e.g., 0-1) for
65
65
  stability. (default: 100.0)
66
66
  factored (bool): whether to use the factorization or disable it to use
67
- the uncompressed optimizer. (default: True)
67
+ the uncompressed optimizer. (default: False)
68
68
  d0 (float):
69
69
  Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
70
70
  d_coef (float):
@@ -82,6 +82,9 @@ class Prodigy_adv(torch.optim.Optimizer):
82
82
  slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
83
83
pth entry of each tensor. For values greater than 1 this is an approximation to standard
84
84
  Prodigy. Values ~11 are reasonable (default 11).
85
+ prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
86
+ after the specified optimiser step and release all state memory required by Prodigy
87
+ (default: 0).
85
88
  """
86
89
 
87
90
  def __init__(
@@ -103,7 +106,7 @@ class Prodigy_adv(torch.optim.Optimizer):
103
106
  t_alpha: int | None = None,
104
107
  Simplified_AdEMAMix: bool = False,
105
108
  alpha_grad: float = 100.0,
106
- factored: bool = True,
109
+ factored: bool = False,
107
110
  # prodigy parameters
108
111
  beta3: float = None,
109
112
  d0: float = 1e-6,
@@ -112,6 +115,7 @@ class Prodigy_adv(torch.optim.Optimizer):
112
115
  safeguard_warmup: bool = False,
113
116
  fsdp_in_use: bool = False,
114
117
  slice_p: int = 11,
118
+ prodigy_steps: int = 0,
115
119
  ):
116
120
  if not (lr >= 0.0):
117
121
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -121,6 +125,8 @@ class Prodigy_adv(torch.optim.Optimizer):
121
125
  raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
122
126
  if not (weight_decay >= 0.0):
123
127
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
128
+ if not (prodigy_steps >= 0):
129
+ raise ValueError(f"prodigy_steps should be >= 0. Got {prodigy_steps}")
124
130
  if betas[0] == 0.0 and Simplified_AdEMAMix:
125
131
  raise ValueError(f"Beta 1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
126
132
  if use_AdEMAMix and Simplified_AdEMAMix:
@@ -132,6 +138,9 @@ class Prodigy_adv(torch.optim.Optimizer):
132
138
  if use_atan2 and Simplified_AdEMAMix:
133
139
  print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
134
140
  use_atan2 = False
141
+ if Simplified_AdEMAMix and alpha_grad > 0:
142
+ # scales d_coef by alpha_grad; this forces Prodigy to behave well with Simplified_AdEMAMix
143
+ d_coef = d_coef/alpha_grad
135
144
 
136
145
  defaults = {
137
146
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
@@ -140,7 +149,7 @@ class Prodigy_adv(torch.optim.Optimizer):
140
149
  "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
141
150
  "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
142
151
  "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
143
- "fsdp_in_use": fsdp_in_use,
152
+ "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
144
153
  "alpha_grad": alpha_grad,
145
154
  }
146
155
  self.stochastic_rounding = stochastic_rounding
@@ -293,7 +302,10 @@ class Prodigy_adv(torch.optim.Optimizer):
293
302
  torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
294
303
  del unpacked_sign_slow
295
304
  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=self.d * (1.0 - beta3_ema))
296
- update = mt + (alpha_t * mt_slow) if self.beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
305
+ if self.beta1 > 0:
306
+ update = torch.add(mt, mt_slow, alpha=alpha_t)
307
+ else:
308
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
297
309
  elif self.Simplified_AdEMAMix:
298
310
  update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
299
311
  else:
@@ -344,7 +356,10 @@ class Prodigy_adv(torch.optim.Optimizer):
344
356
  if self.use_AdEMAMix:
345
357
  exp_avg_slow = state['exp_avg_slow']
346
358
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=self.d * (1.0 - beta3_ema))
347
- update = exp_avg + (alpha_t * exp_avg_slow) if self.beta1 > 0 else grad + (alpha_t * exp_avg_slow)
359
+ if self.beta1 > 0:
360
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
361
+ else:
362
+ update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
348
363
  elif self.Simplified_AdEMAMix:
349
364
  update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
350
365
  else:
@@ -364,19 +379,27 @@ class Prodigy_adv(torch.optim.Optimizer):
364
379
  update.mul_(self.dlr)
365
380
 
366
381
  # --- Accumulate Prodigy stats ---
367
- d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
368
- s, p0 = state['s'], state['p0']
369
- grad_flat = grad.flatten().float()
370
- p_flat = p.data.flatten().float()
371
- p0 = p0.float()
382
+ prodigy_steps = group['prodigy_steps']
383
+ if prodigy_steps <= 0 or group['k'] < prodigy_steps:
384
+ d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
385
+ s, p0 = state['s'], state['p0']
386
+ grad_flat = grad.flatten().float()
387
+ p_flat = p.data.flatten().float()
388
+ p0 = p0.float()
372
389
 
373
- self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
390
+ self.d_numerator += (self.d / d0) * self.dlr * torch.dot(grad_flat[::slice_p], p0.data - p_flat[::slice_p]).item()
374
391
 
375
- alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
376
- s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
377
- self.d_denom += s.abs().sum().item()
392
+ alpha = ((self.d / d0) * self.d) if safeguard_warmup else ((self.d / d0) * self.dlr)
393
+ s.mul_(self.beta3).add_(grad_flat[::slice_p], alpha=alpha)
394
+ self.d_denom += s.abs().sum().item()
378
395
 
379
- del s, p0, grad_flat, p_flat, alpha
396
+ del s, p0, grad_flat, p_flat, alpha
397
+ else:
398
+ # Free memory if prodigy_steps is reached
399
+ if 's' in state:
400
+ del state['s']
401
+ if 'p0' in state:
402
+ del state['p0']
380
403
 
381
404
  # Decoupled weight decay
382
405
  if group["weight_decay"] != 0:
@@ -413,29 +436,37 @@ class Prodigy_adv(torch.optim.Optimizer):
413
436
  def calculate_d(self):
414
437
  """Calculates the new `d` based on the accumulated stats."""
415
438
  g_group = self.param_groups[0]
416
- d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
417
439
 
418
- if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
419
- # Use the device of the first parameter to avoid hardcoding '.cuda()'
420
- device = self.param_groups[0]['params'][0].device
421
- dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
422
- dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
423
- global_d_numerator = dist_tensor[0].item()
424
- global_d_denom = dist_tensor[1].item()
425
- else:
426
- global_d_numerator = self.d_numerator
427
- global_d_denom = self.d_denom
428
-
429
- d_hat = self.d
430
- if global_d_denom > 0:
431
- d_hat = d_coef * global_d_numerator / global_d_denom
432
- if self.d == g_group['d0']:
433
- self.d = max(self.d, d_hat)
434
- d_max = max(d_max, d_hat)
435
- self.d = min(d_max, self.d * growth_rate)
436
-
440
+ # Only perform d-adaptation if prodigy_steps has not been reached
441
+ prodigy_active = not (g_group.get('prodigy_steps', 0) > 0 and g_group['k'] >= g_group['prodigy_steps'])
442
+
443
+ if prodigy_active:
444
+ d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']
445
+
446
+ if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
447
+ # Use the device of the first parameter to avoid hardcoding '.cuda()'
448
+ device = self.param_groups[0]['params'][0].device
449
+ dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
450
+ dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
451
+ global_d_numerator = dist_tensor[0].item()
452
+ global_d_denom = dist_tensor[1].item()
453
+ else:
454
+ global_d_numerator = self.d_numerator
455
+ global_d_denom = self.d_denom
456
+
457
+ d_hat = self.d
458
+ if global_d_denom > 0:
459
+ d_hat = d_coef * global_d_numerator / global_d_denom
460
+ if self.d == g_group['d0']:
461
+ self.d = max(self.d, d_hat)
462
+ d_max = max(d_max, d_hat)
463
+ self.d = min(d_max, self.d * growth_rate)
464
+
465
+ for group in self.param_groups:
466
+ group['d_numerator'] = global_d_numerator
467
+ group['d'] = self.d
468
+ group['d_max'] = d_max
469
+
470
+ # Increment step counter for all groups, regardless of whether d was updated
437
471
  for group in self.param_groups:
438
- group['d_numerator'] = global_d_numerator
439
- group['d'] = self.d
440
- group['d_max'] = d_max
441
472
  group['k'] += 1
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 0.1.7
3
+ Version: 0.1.8
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -0,0 +1,19 @@
1
+ adv_optm/__init__.py,sha256=csc19AmU_h7daI3bo4hDVBouMqGiHejfipPIOGFAUQ8,306
2
+ adv_optm/optim/AdamW_adv.py,sha256=Had6kzSBI0eEMiL2yI1wa1nEBoPfgwHQGtnRcDJ8tXI,14078
3
+ adv_optm/optim/Adopt_adv.py,sha256=-iAKhPbEnzdL0Mx96h2BBlJB85TyHdkjULRjWvNbTyY,14833
4
+ adv_optm/optim/Lion_Prodigy_adv.py,sha256=kIAGXoMbDNRg5reKXtUC_vQQ2gyM-NXPB-Pv9zSpiE8,12787
5
+ adv_optm/optim/Lion_adv.py,sha256=05j_j6LIzHW5b79DVwMIf1FZHVNB8xnStNVjlOdVkCE,8256
6
+ adv_optm/optim/Prodigy_adv.py,sha256=U4grKRumzDJRYSI-QHmmZZ7ed_67tyiC3OPSXqJVBx8,21759
7
+ adv_optm/optim/Simplified_AdEMAMix.py,sha256=opIZjnGJ03-DDAIHTZyJBMReVfgusGDb8FZSWMU3-UM,9774
8
+ adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
9
+ adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
10
+ adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
11
+ adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
12
+ adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
13
+ adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
14
+ adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
15
+ adv_optm-0.1.8.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
16
+ adv_optm-0.1.8.dist-info/METADATA,sha256=Ydu5_f_d19hoYMf9zvP3eu9ci8XsLWyDuY99JYJVR9o,5846
17
+ adv_optm-0.1.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
18
+ adv_optm-0.1.8.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
19
+ adv_optm-0.1.8.dist-info/RECORD,,
@@ -1,19 +0,0 @@
1
- adv_optm/__init__.py,sha256=CZ_tjWWk5d5D8q_R0rcr8vvwlZyY_44zyAcIAmN_SDY,306
2
- adv_optm/optim/AdamW_adv.py,sha256=ZeNzk2tWbyd2QDI5hp4InwG3iuHHfqLrlhr_VmcQfRM,13884
3
- adv_optm/optim/Adopt_adv.py,sha256=rzBWfFOPrMuC6vwETsw7QPKmVXcv4IJRDCTj-6eU1Qk,14798
4
- adv_optm/optim/Lion_Prodigy_adv.py,sha256=JMss9X8lRpIU4E34PfFpWMMal_XNvZ8Yuqc6i7R5wIQ,14588
5
- adv_optm/optim/Lion_adv.py,sha256=BA4bSEhJiQ7BhGLDRn9nuMlBrLVh-OMscbmSTeGgRmI,10137
6
- adv_optm/optim/Prodigy_adv.py,sha256=gJL2r32R3xGD62jMR55ZyKxRv0yL70XHxj4FzEJbFc4,20196
7
- adv_optm/optim/Simplified_AdEMAMix.py,sha256=opIZjnGJ03-DDAIHTZyJBMReVfgusGDb8FZSWMU3-UM,9774
8
- adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
9
- adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
10
- adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
11
- adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
12
- adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
13
- adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
14
- adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
15
- adv_optm-0.1.7.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
16
- adv_optm-0.1.7.dist-info/METADATA,sha256=BEKyVG9zVdb9WThOw9YtgWZ_zqDmErumpY5Fr-AkbX0,5846
17
- adv_optm-0.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
18
- adv_optm-0.1.7.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
19
- adv_optm-0.1.7.dist-info/RECORD,,