adv-optm 1.2.dev14__py3-none-any.whl → 2.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of adv-optm has been flagged as potentially problematic.

adv_optm/__init__.py CHANGED
@@ -20,4 +20,4 @@ __all__ = [
  "AdaMuon_adv",
  ]
 
- __version__ = "1.2.dev14"
+ __version__ = "2.dev1"
@@ -49,12 +49,6 @@ class AdamW_adv(torch.optim.Optimizer):
  before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
  A higher value increases the stabilizing influence of the slow
  momentum. (default: 5.0)
- t_alpha (Optional[int]): The number of steps for a linear warmup of the
- `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
- highly recommended to prevent instability at the beginning of training,
- as it gradually introduces the stabilizing slow momentum term. During
- the warmup, `alpha` ramps from 0 to its target value. If `None`,
- the scheduler is disabled. (default: None)
  kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
  If `False`, the optimizer behaves as standard AdamW. (default: False)
  beta2_min (float): The minimum value for dynamic β₂, used during periods of
@@ -72,11 +66,7 @@ class AdamW_adv(torch.optim.Optimizer):
  k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
- logging (default: 0).
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
- and returns a unique, hashable key representing its "layer" or "bucket".
- If `None`, parameters are bucketed by their memory ID (tensor-wise).
- (default: None)
+ logging (default: 0).
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
  """
@@ -89,7 +79,7 @@ class AdamW_adv(torch.optim.Optimizer):
  eps: float = 1e-8,
  weight_decay: float = 0.0,
  use_bias_correction: bool = True,
- vector_reshape: bool = True,
+ vector_reshape: bool = False,
  stochastic_rounding: bool = True,
  use_atan2: bool = False,
  cautious_mask: bool = False,
@@ -98,15 +88,15 @@ class AdamW_adv(torch.optim.Optimizer):
  use_AdEMAMix: bool = False,
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
- t_alpha: int | None = None,
  kourkoutas_beta: bool = False,
  beta2_min: float = 0.9,
  ema_alpha: float = 0.95,
  tiny_spike: float = 1e-9,
  k_warmup_steps: int = 0,
  k_logging: int = 0,
- layer_key_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
+ # Compiled
+ compiled_optimizer: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -116,7 +106,8 @@ class AdamW_adv(torch.optim.Optimizer):
  raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
  if not (weight_decay >= 0.0):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
- if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+ if kourkoutas_beta and not (betas[1] > beta2_min):
+ raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
 
  if cautious_mask and grams_moment:
  print("Warning: cautious is incompatible with grams, Disabling cautious.")
@@ -126,9 +117,10 @@ class AdamW_adv(torch.optim.Optimizer):
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
  "vector_reshape": vector_reshape, "use_atan2": use_atan2,
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
- "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
+ "beta3_ema": beta3_ema, "alpha": alpha,
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+ "compiled_optimizer": compiled_optimizer,
  }
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
@@ -136,12 +128,20 @@ class AdamW_adv(torch.optim.Optimizer):
  self.use_AdEMAMix = use_AdEMAMix
  self.factored = nnmf_factor
  self.kourkoutas_beta = kourkoutas_beta
- self.layer_key_fn = layer_key_fn
+
  super().__init__(params, defaults)
 
+ self.init_step()
+
  if self.kourkoutas_beta:
  self.kourkoutas_helper = KourkoutasHelper(self)
 
+ self.global_step = 0
+
+ if compiled_optimizer:
+ torch._dynamo.config.cache_size_limit = 8192
+ self.compile(fullgraph=True)
+
  @property
  def supports_fused_back_pass(self):
  return True
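
With the constructor above, `compiled_optimizer=True` raises `torch._dynamo.config.cache_size_limit` and compiles the per-parameter step at construction time. A minimal construction sketch under the new signature (assuming `AdamW_adv` is importable from `adv_optm`; argument values are illustrative only):

import torch
from adv_optm import AdamW_adv  # assumed export, matching the package name

model = torch.nn.Linear(128, 128)
opt = AdamW_adv(
    model.parameters(),
    lr=1e-3,
    use_AdEMAMix=True,        # alpha applied at full strength; t_alpha no longer exists
    kourkoutas_beta=True,     # layer-wise dynamic beta2 (layer_key_fn is no longer a parameter)
    compiled_optimizer=True,  # new flag: the step logic is run through torch.compile
)

loss = model(torch.randn(8, 128)).sum()
loss.backward()
opt.step()
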
@@ -154,29 +154,22 @@ class AdamW_adv(torch.optim.Optimizer):
  def supports_flat_params(self):
  return False
 
- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
- if p.grad is None:
- return
+ def init_step(self):
+ for group in self.param_groups:
+ for p in group['params']:
+ self.__init_state(p, group)
 
- grad = p.grad
- if grad.dtype != torch.float32 and self.factored:
- grad = grad.float()
- if group["orthogonal_gradient"]:
- grad = _orthogonalize_gradient(p, grad)
+ @torch.no_grad()
+ def __init_state(self, p, group):
  state = self.state[p]
 
- # State Initialization
- if 'step' not in state:
- state['step'] = 0
+ if len(state) == 0:
 
- should_factor = (
+ state['factored'] = (
  self.factored and
  not (len(p.shape) == 1 and not group['vector_reshape'])
  )
 
- state['factored'] = should_factor
-
  dtype = torch.float32 if self.factored else p.dtype
  device = p.device
 
@@ -186,18 +179,18 @@ class AdamW_adv(torch.optim.Optimizer):
 
  # First moment (m)
  if group['betas'][0] > 0:
- state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  if not self.grams_moment:
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
  if self.use_AdEMAMix:
- state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
  # Second moment (v)
- state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  else: # Fallback to standard AdamW for non-factored tensors
  if group['betas'][0] > 0:
@@ -206,37 +199,32 @@ class AdamW_adv(torch.optim.Optimizer):
  state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
  state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
 
+ @torch.no_grad()
+ def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float, bias_correction1: torch.Tensor | float, bias_correction2: torch.Tensor | float):
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["orthogonal_gradient"]:
+ grad = _orthogonalize_gradient(p, grad)
+ state = self.state[p]
+
+
  beta1, beta2 = group['betas']
 
- current_step = state['step']
  if group.get('kourkoutas_beta', False):
- # Call prepare_step() once at the beginning of the step for all params
- self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
  self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
  # Get the dynamic beta2 calculated in prepare_step()
- beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+ beta2 = self.kourkoutas_helper.get_beta2(p, group)
 
- step = state['step'] + 1
- if group['use_bias_correction']:
- bias_correction1 = 1.0 - beta1 ** step
- if group.get('kourkoutas_beta', False):
- bias_correction2 = 1.0 - group['betas'][1] ** step
- # Use beta2_max for bias correction
- else:
- bias_correction2 = 1.0 - beta2 ** step
- else:
- bias_correction1 = 1
- bias_correction2 = 1
- step_size = group['lr'] / bias_correction1
+ step_size = lr / bias_correction1
 
  if self.use_AdEMAMix:
  beta3_ema = group['beta3_ema']
  alpha = group['alpha']
- t_alpha = group['t_alpha']
- alpha_t = alpha
- if t_alpha is not None and t_alpha > 0 and step < t_alpha:
- alpha_t = min(step * alpha / t_alpha, alpha)
 
  if state['factored']:
  d1, d2 = state['effective_shape']
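
For reference, the bias-correction terms that are now precomputed and passed in (`step_size = lr / bias_correction1`) are the standard Adam ones; a minimal worked sketch, illustrative only:

# Standard Adam bias correction, computed once per optimizer step:
#   bias_correction1 = 1 - beta1 ** t
#   bias_correction2 = 1 - beta2 ** t
#   step_size = lr / bias_correction1; in standard Adam the denominator uses sqrt(v_t / bias_correction2)
beta1, beta2, lr, t = 0.9, 0.999, 1e-3, 100
bias_correction1 = 1.0 - beta1 ** t   # ~0.99997
bias_correction2 = 1.0 - beta2 ** t   # ~0.0952
step_size = lr / bias_correction1     # ~1.000e-3
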
@@ -252,7 +240,7 @@ class AdamW_adv(torch.optim.Optimizer):
  # Update momentum in full-size
  mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
  if self.grams_moment:
- mt.copy_(grad_reshaped.sign() * mt.abs())
+ mt = (grad_reshaped.sign().mul_(mt.abs()))
  elif self.cautious_mask:
  mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
@@ -272,9 +260,9 @@ class AdamW_adv(torch.optim.Optimizer):
 
  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
  if beta1 > 0:
- update = torch.add(mt, mt_slow, alpha=alpha_t)
+ update = torch.add(mt, mt_slow, alpha=alpha)
  else:
- update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha)
  else:
  update = mt.clone() if beta1 > 0 else grad_reshaped.clone()
  del grad_reshaped
@@ -310,7 +298,7 @@ class AdamW_adv(torch.optim.Optimizer):
  exp_avg = state['exp_avg']
  exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
  if self.grams_moment:
- exp_avg = grad.sign() * exp_avg.abs()
+ exp_avg = grad.sign().mul_(exp_avg.abs())
  elif self.cautious_mask:
  mask = (exp_avg * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
@@ -321,9 +309,9 @@ class AdamW_adv(torch.optim.Optimizer):
  exp_avg_slow = state['exp_avg_slow']
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
  if beta1 > 0:
- update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
  else:
- update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
+ update = torch.add(grad, exp_avg_slow, alpha=alpha)
  else:
  update = exp_avg.clone() if beta1 > 0 else grad.clone()
 
@@ -343,9 +331,9 @@ class AdamW_adv(torch.optim.Optimizer):
  # Decoupled weight decay
  if group["weight_decay"] != 0:
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
  else:
- p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+ p.data.add_(p.data, alpha=-group["weight_decay"] * lr)
 
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
  add_stochastic_(p.data, -update)
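
The decoupled weight-decay branch above shrinks the parameter directly by `lr * weight_decay` before the gradient-based update is applied. Ignoring the bfloat16 stochastic-rounding path, it is equivalent to this sketch (illustrative only):

import torch

# Decoupled (AdamW-style) weight decay:
#   p <- p - lr * weight_decay * p  ==  p <- p * (1 - lr * weight_decay)
def decoupled_weight_decay_(p: torch.Tensor, lr: float, weight_decay: float) -> None:
    p.mul_(1.0 - lr * weight_decay)
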
@@ -353,7 +341,38 @@ class AdamW_adv(torch.optim.Optimizer):
  p.data.add_(-update)
  del update
 
- state['step'] += 1
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+ # if 'exp_avg_sq' not in self.state[p] and 'mu_v_nmf' not in self.state[p]:
+ # return
+
+ if self.global_step is None and 'step' in self.state[p]:
+ # For backward compatibility
+ self.global_step = self.state[p]['step']
+
+ if group['use_bias_correction']:
+ current_step = self.global_step + 1
+ beta1, beta2 = group['betas']
+ bias_correction1 = 1.0 - beta1 ** current_step
+ bias_correction2 = 1.0 - beta2 ** current_step
+ else:
+ bias_correction1 = 1.0
+ bias_correction2 = 1.0
+
+ if group.get('kourkoutas_beta', False):
+ # Prepare Kourkoutas-β once per step using the global step counter.
+ self.kourkoutas_helper.maybe_prepare_step(self.global_step)
+
+ if not group.get('compiled_optimizer', False):
+ self.__step_parameter(p, group, group['lr'], bias_correction1, bias_correction2)
+ else:
+ lr_tensor = torch.tensor(group['lr'], device=p.device)
+ bias_correction1_tensor = torch.tensor(bias_correction1, device=p.device)
+ bias_correction2_tensor = torch.tensor(bias_correction2, device=p.device)
+ self._compiled_step_parameter(p, group, lr_tensor, bias_correction1_tensor, bias_correction2_tensor)
+
+ def compile(self, *args, **kwargs):
+ self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
 
  @torch.no_grad()
  def step(self, closure=None):
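
In the compiled path above, `lr` and the bias corrections are wrapped in 0-dim tensors before being handed to the `torch.compile`d step. Passing changing scalars as tensors is a common way to keep a single compiled graph rather than recompiling whenever the value changes; a minimal sketch of the pattern (not the package's code):

import torch

@torch.compile(fullgraph=True)
def apply_update(p: torch.Tensor, update: torch.Tensor, lr: torch.Tensor) -> None:
    # lr arrives as a 0-dim tensor, so a new learning rate reuses the same graph
    # instead of being specialized into the compiled code as a constant.
    p.sub_(update * lr)

p, update = torch.randn(4, 4), torch.randn(4, 4)
for step_lr in (1e-3, 5e-4):
    apply_update(p, update, torch.tensor(step_lr))
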
@@ -367,4 +386,6 @@ class AdamW_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)
 
+ self.global_step += 1
+
  return loss