adv-optm 1.0.0.tar.gz → 1.0.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (24)
  1. {adv_optm-1.0.0 → adv_optm-1.0.1}/PKG-INFO +1 -1
  2. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/__init__.py +1 -1
  3. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/optim/AdamW_adv.py +19 -19
  4. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/optim/Adopt_adv.py +24 -24
  5. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/optim/Lion_Prodigy_adv.py +8 -8
  6. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/optim/Lion_adv.py +8 -8
  7. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/optim/Prodigy_adv.py +475 -475
  8. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/optim/Simplified_AdEMAMix.py +3 -3
  9. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm.egg-info/PKG-INFO +1 -1
  10. {adv_optm-1.0.0 → adv_optm-1.0.1}/setup.py +1 -1
  11. {adv_optm-1.0.0 → adv_optm-1.0.1}/LICENSE +0 -0
  12. {adv_optm-1.0.0 → adv_optm-1.0.1}/README.md +0 -0
  13. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/optim/__init__.py +0 -0
  14. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  15. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/util/Effective_Shape.py +0 -0
  16. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/util/NNMF.py +0 -0
  17. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/util/One_Bit_Boolean.py +0 -0
  18. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/util/OrthoGrad.py +0 -0
  19. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm/util/__init__.py +0 -0
  20. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm.egg-info/SOURCES.txt +0 -0
  21. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm.egg-info/dependency_links.txt +0 -0
  22. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm.egg-info/requires.txt +0 -0
  23. {adv_optm-1.0.0 → adv_optm-1.0.1}/adv_optm.egg-info/top_level.txt +0 -0
  24. {adv_optm-1.0.0 → adv_optm-1.0.1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.0.0
+ Version: 1.0.1
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -16,4 +16,4 @@ __all__ = [
  "Lion_Prodigy_adv",
  ]

- __version__ = "1.0.0"
+ __version__ = "1.0.1"
@@ -30,8 +30,8 @@ class AdamW_adv(torch.optim.Optimizer):
  stochastic_rounding (bool): whether to use stochastic
  rounding for BF16 parameter updates (default: True).
  use_atan2 (bool): whether to use the atan2 update rule. (default: False)
- use_grams (bool): whether to use Grams-style updates. (default: False)
- use_cautious (bool): whether to use cautious masking to align the gradient's
+ grams_moment (bool): whether to use Grams-style updates. (default: False)
+ cautious_mask (bool): whether to use cautious masking to align the gradient's
  direction with the first moment's. (default: False)
  use_orthograd (bool): whether to use OrthoGrad. (default: False)
  use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
@@ -54,7 +54,7 @@ class AdamW_adv(torch.optim.Optimizer):
  as it gradually introduces the stabilizing slow momentum term. During
  the warmup, `alpha` ramps from 0 to its target value. If `None`,
  the scheduler is disabled. (default: None)
- factored (bool): whether to use the factorization or disable it to use
+ nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
  """

@@ -69,14 +69,14 @@ class AdamW_adv(torch.optim.Optimizer):
  vector_reshape: bool = True,
  stochastic_rounding: bool = True,
  use_atan2: bool = False,
- use_cautious: bool = False,
- use_grams: bool = False,
+ cautious_mask: bool = False,
+ grams_moment: bool = False,
  use_orthograd: bool = False,
  use_AdEMAMix: bool = False,
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
  t_alpha: int | None = None,
- factored: bool = False,
+ nnmf_factor: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -86,9 +86,9 @@ class AdamW_adv(torch.optim.Optimizer):
  raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
  if not (weight_decay >= 0.0):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
- if use_cautious and use_grams:
- print("Warning: use_cautious is incompatible with use_grams, Disabling use_cautious.")
- use_cautious = False
+ if cautious_mask and grams_moment:
+ print("Warning: cautious is incompatible with grams, Disabling cautious.")
+ cautious_mask = False

  defaults = {
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
@@ -97,10 +97,10 @@ class AdamW_adv(torch.optim.Optimizer):
  "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
  }
  self.stochastic_rounding = stochastic_rounding
- self.use_cautious = use_cautious
- self.use_grams = use_grams
+ self.cautious_mask = cautious_mask
+ self.grams_moment = grams_moment
  self.use_AdEMAMix = use_AdEMAMix
- self.factored = factored
+ self.factored = nnmf_factor
  super().__init__(params, defaults)

  @property
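These hunks rename the AdamW_adv keyword arguments: use_cautious becomes cautious_mask, use_grams becomes grams_moment, and factored becomes nnmf_factor (internally still stored as self.factored). A minimal usage sketch against the 1.0.1 signature follows; the model, the hyperparameter values, and the top-level import are illustrative assumptions, not taken from this diff:

import torch
from adv_optm import AdamW_adv  # assumes AdamW_adv is exported at the package root

model = torch.nn.Linear(128, 64)
optimizer = AdamW_adv(
    model.parameters(),
    lr=1e-3,
    cautious_mask=True,   # 1.0.0 name: use_cautious
    grams_moment=False,   # 1.0.0 name: use_grams
    nnmf_factor=False,    # 1.0.0 name: factored
)

loss = model(torch.randn(4, 128)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()

Since the diff shows no aliases for the old names, code that still passes use_cautious=, use_grams=, or factored= as keywords will raise a TypeError against 1.0.1.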
@@ -151,7 +151,7 @@ class AdamW_adv(torch.optim.Optimizer):
  if beta1 > 0:
  state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
- if not self.use_grams:
+ if not self.grams_moment:
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
  if self.use_AdEMAMix:
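When the factored path is active and grams_moment is off, the first moment's sign is stored separately from its magnitude: one bit per element, packed eight flags per uint8 byte, which is why the state buffer is sized packed_d2 = (d2 + 7) // 8. The package's real helpers live in adv_optm/util/One_Bit_Boolean.py and are not shown in this diff, so the following is only a sketch of how such packing could work; the function names are placeholders:

import torch

def pack_bools_sketch(mask: torch.Tensor) -> torch.Tensor:
    # (d1, d2) bool -> (d1, ceil(d2 / 8)) uint8, eight flags per byte.
    d1, d2 = mask.shape
    packed_d2 = (d2 + 7) // 8
    padded = torch.zeros(d1, packed_d2 * 8, dtype=torch.bool, device=mask.device)
    padded[:, :d2] = mask
    bits = padded.view(d1, packed_d2, 8).to(torch.uint8)
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=mask.device)
    return (bits * weights).sum(dim=-1, dtype=torch.uint8)

def unpack_bools_sketch(packed: torch.Tensor, original_m: int) -> torch.Tensor:
    # Inverse of pack_bools_sketch: (d1, packed_d2) uint8 -> (d1, original_m) bool.
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device)
    bits = (packed.unsqueeze(-1) & weights) != 0
    return bits.view(packed.shape[0], -1)[:, :original_m]

Storing signs this way costs d1 * ceil(d2 / 8) bytes instead of a full d1 x d2 buffer in the working dtype.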
@@ -192,16 +192,16 @@ class AdamW_adv(torch.optim.Optimizer):
  # Reconstruct momentum from previous step's factors
  if beta1 > 0:
  mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
- if not self.use_grams:
+ if not self.grams_moment:
  unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
  torch.where(unpacked_sign, mt, -mt, out=mt)
  del unpacked_sign
  # Update momentum in full-size
  grad_reshaped = grad.view(d1, d2)
  mt.mul_(beta1).add_(grad_reshaped, alpha=1.0 - beta1)
- if self.use_grams:
+ if self.grams_moment:
  mt.copy_(grad_reshaped.sign() * mt.abs())
- elif self.use_cautious:
+ elif self.cautious_mask:
  mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
  mt.mul_(mask)
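For reference, the two mutually exclusive options renamed above act on the reconstructed first moment like this; a standalone restatement of the hunk's logic, not the package's own function:

import torch

def shape_first_moment(mt, grad, grams_moment=False, cautious_mask=False):
    # Grams: keep the momentum's magnitude but take the current gradient's sign.
    if grams_moment:
        return grad.sign() * mt.abs()
    # Cautious: zero momentum entries whose sign disagrees with the gradient,
    # then rescale by the surviving fraction so the update's scale is preserved.
    if cautious_mask:
        mask = (mt * grad > 0).to(grad.dtype)
        mask.div_(mask.mean().clamp_(min=1e-3))
        mt = mt * mask
    return mt

This also illustrates why the constructor disables cautious_mask when grams_moment is set: once the sign comes entirely from the gradient, every nonzero element already agrees with the gradient's direction and the mask would be essentially a no-op.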
@@ -240,7 +240,7 @@ class AdamW_adv(torch.optim.Optimizer):

  # Compress updated moments and store new factors
  if beta1 > 0:
- if not self.use_grams:
+ if not self.grams_moment:
  state['sign'] = _pack_bools(mt > 0)
  _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
  del mt
@@ -257,9 +257,9 @@ class AdamW_adv(torch.optim.Optimizer):
  if beta1 > 0:
  exp_avg = state['exp_avg']
  exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
- if self.use_grams:
+ if self.grams_moment:
  exp_avg = grad.sign() * exp_avg.abs()
- elif self.use_cautious:
+ elif self.cautious_mask:
  mask = (exp_avg * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
  exp_avg.mul_(mask)
@@ -36,9 +36,9 @@ class Adopt_adv(torch.optim.Optimizer):
  rounding for BF16 parameter updates (default: True).
  use_atan2 (bool): whether to use an atan2-based normalization, which can
  improve stability by removing the need for `eps`. (default: False)
- use_cautious (bool): whether to use cautious masking to align the gradient's
+ cautious_mask (bool): whether to use cautious masking to align the gradient's
  direction with the first moment's. (default: False)
- use_grams (bool): whether to combine the gradient's direction with the
+ grams_moment (bool): whether to combine the gradient's direction with the
  first moment's magnitude (default: False).
  use_orthograd (bool): whether to use OrthoGrad. (default: False)
  use_AdEMAMix (bool): whether to enable the AdEMAMix feature. This adds
@@ -65,14 +65,14 @@ class Adopt_adv(torch.optim.Optimizer):
  Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
  This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
  more responsive, especially for small batch sizes. Enabling this will
- automatically disable `use_AdEMAMix`, `use_cautious`, `use_grams`,
+ automatically disable `use_AdEMAMix`, `cautious_mask`, `grams_moment`,
  and `use_atan2`. (default: False)
  alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
  (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
  current gradient. For small batch sizes, use high values (e.g., 10-100) to be
  more responsive. For large batch sizes, use low values (e.g., 0-1) for
  stability. (default: 100.0)
- factored (bool): whether to use the factorization or disable it to use
+ nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
  """

78
78
 
@@ -87,8 +87,8 @@ class Adopt_adv(torch.optim.Optimizer):
87
87
  vector_reshape: bool = True,
88
88
  stochastic_rounding: bool = True,
89
89
  use_atan2: bool = False,
90
- use_cautious: bool = False,
91
- use_grams: bool = False,
90
+ cautious_mask: bool = False,
91
+ grams_moment: bool = False,
92
92
  use_orthograd: bool = False,
93
93
  use_AdEMAMix: bool = False,
94
94
  beta3_ema: float = 0.9999,
@@ -96,7 +96,7 @@ class Adopt_adv(torch.optim.Optimizer):
96
96
  t_alpha: int | None = None,
97
97
  Simplified_AdEMAMix: bool = False,
98
98
  alpha_grad: float = 100.0,
99
- factored: bool = False,
99
+ nnmf_factor: bool = False,
100
100
  ):
101
101
  if not (lr >= 0.0):
102
102
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -106,17 +106,17 @@ class Adopt_adv(torch.optim.Optimizer):
106
106
  raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
107
107
  if not (weight_decay >= 0.0):
108
108
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
109
- if use_cautious and use_grams:
110
- print("Warning: use_cautious is incompatible with use_grams, Disabling use_cautious.")
111
- use_cautious = False
109
+ if cautious_mask and grams_moment:
110
+ print("Warning: cautious is incompatible with grams, Disabling cautious.")
111
+ cautious_mask = False
112
112
  if betas[0] == 0.0 and Simplified_AdEMAMix:
113
113
  raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
114
114
  if use_AdEMAMix and Simplified_AdEMAMix:
115
115
  print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
116
- if use_grams and Simplified_AdEMAMix:
117
- print("Warning: use_grams is incompatible with Simplified_AdEMAMix, Disabling use_grams.")
118
- if use_cautious and Simplified_AdEMAMix:
119
- print("Warning: use_cautious is incompatible with Simplified_AdEMAMix, Disabling use_cautious.")
116
+ if grams_moment and Simplified_AdEMAMix:
117
+ print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
118
+ if cautious_mask and Simplified_AdEMAMix:
119
+ print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
120
120
  if use_atan2 and Simplified_AdEMAMix:
121
121
  print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
122
122
  use_atan2 = False
@@ -129,12 +129,12 @@ class Adopt_adv(torch.optim.Optimizer):
129
129
  self.clip_lambda = clip_lambda
130
130
  self.stochastic_rounding = stochastic_rounding
131
131
  self.use_atan2 = use_atan2 and not Simplified_AdEMAMix
132
- self.use_cautious = use_cautious and not Simplified_AdEMAMix
133
- self.use_grams = use_grams and not Simplified_AdEMAMix
132
+ self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
133
+ self.grams_moment = grams_moment and not Simplified_AdEMAMix
134
134
  self.use_orthograd = use_orthograd
135
135
  self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
136
136
  self.Simplified_AdEMAMix = Simplified_AdEMAMix
137
- self.factored = factored
137
+ self.factored = nnmf_factor
138
138
  super().__init__(params, defaults)
139
139
 
140
140
  @property
@@ -176,7 +176,7 @@ class Adopt_adv(torch.optim.Optimizer):
176
176
  # m_0 = 0
177
177
  state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
178
178
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
179
- if not self.use_grams:
179
+ if not self.grams_moment:
180
180
  packed_d2 = (d2 + 7) // 8
181
181
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
182
182
  if self.use_AdEMAMix:
@@ -220,7 +220,7 @@ class Adopt_adv(torch.optim.Optimizer):
220
220
 
221
221
  # Reconstruct m_{t-1}
222
222
  mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
223
- if not self.use_grams:
223
+ if not self.grams_moment:
224
224
  if state['sign'].dtype != torch.uint8:
225
225
  state['sign'] = state['sign'].to(torch.uint8)
226
226
  unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
@@ -257,9 +257,9 @@ class Adopt_adv(torch.optim.Optimizer):
257
257
  mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
258
258
  else:
259
259
  mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
260
- if self.use_grams:
260
+ if self.grams_moment:
261
261
  mt = grad_reshaped.sign() * mt.abs()
262
- elif self.use_cautious:
262
+ elif self.cautious_mask:
263
263
  mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
264
264
  mask.div_(mask.mean().clamp_(min=1e-3))
265
265
  mt.mul_(mask)
@@ -284,7 +284,7 @@ class Adopt_adv(torch.optim.Optimizer):
284
284
  del grad_reshaped
285
285
 
286
286
  # Compress and store new factors
287
- if not self.use_grams:
287
+ if not self.grams_moment:
288
288
  state['sign'] = _pack_bools(mt > 0)
289
289
  _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
290
290
  del mt
@@ -322,9 +322,9 @@ class Adopt_adv(torch.optim.Optimizer):
322
322
  else:
323
323
  m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
324
324
 
325
- if self.use_grams:
325
+ if self.grams_moment:
326
326
  m = grad.sign() * m.abs()
327
- elif self.use_cautious:
327
+ elif self.cautious_mask:
328
328
  mask = (m * grad > 0).to(grad.dtype)
329
329
  mask.div_(mask.mean().clamp_(min=1e-3))
330
330
  m.mul_(mask)
@@ -26,12 +26,12 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
26
26
  matrices to apply low-rank compression (default: True).
27
27
  stochastic_rounding (bool, optional): whether to use stochastic
28
28
  rounding for BF16 parameter updates (default: True).
29
- use_cautious (bool): whether to use the cautious masking technique. (default: False).
29
+ cautious_mask (bool): whether to use the cautious masking technique. (default: False).
30
30
  clip_threshold (float, optional): whether to clip the gradients norm
31
31
  per-parameter as proposed in the paper `Lions and Muons: Optimization via
32
32
  Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
33
33
  (default: 0.0).
34
- factored (bool): whether to use the factorization or use the
34
+ nnmf_factor (bool): whether to use the factorization or use the
35
35
  uncompressed optimizer. (default: True)
36
36
  d0 (float):
37
37
  Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
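Lion_Prodigy_adv (and Lion_adv below) get the same renames, with nnmf_factor defaulting to True. A hedged instantiation sketch against the 1.0.1 signature; the top-level import, the model, and the chosen values are assumptions for illustration, since only the changed lines are visible in this diff:

import torch
from adv_optm import Lion_Prodigy_adv  # assumes the class is exported at the package root

model = torch.nn.Linear(128, 64)
optimizer = Lion_Prodigy_adv(
    model.parameters(),
    cautious_mask=True,  # 1.0.0 name: use_cautious
    nnmf_factor=True,    # 1.0.0 name: factored; still defaults to True
    d0=1e-6,             # initial D estimate for D-adaptation, per the docstring
)

For these two classes the only surface changed in 1.0.1 is the keyword names; the defaults and the stored self.factored attribute are unchanged.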
@@ -61,9 +61,9 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  vector_reshape: bool = True,
  stochastic_rounding: bool = True,
  use_orthograd: bool = False,
- use_cautious: bool = False,
+ cautious_mask: bool = False,
  clip_threshold: float = 0.0,
- factored: bool = True,
+ nnmf_factor: bool = True,
  # prodigy parameters
  beta3: float = None,
  d0: float = 1e-6,
@@ -92,8 +92,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  fsdp_in_use=fsdp_in_use,
  )
  self.stochastic_rounding = stochastic_rounding
- self.use_cautious = use_cautious
- self.factored = factored
+ self.cautious_mask = cautious_mask
+ self.factored = nnmf_factor
  self.fsdp_in_use = fsdp_in_use
  super().__init__(params, defaults)
  # Global state for accumulating metrics across parameter updates within a single step.
@@ -197,7 +197,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  # Compute update term c_t = β1*m_{t-1} + (1-β1)*g_t
  signed_update = exp_avg.clone().mul_(self.beta1).add_(grad_reshaped, alpha=(1-self.beta1)).sign_()

- if self.use_cautious:
+ if self.cautious_mask:
  mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
  signed_update.mul_(mask)
@@ -224,7 +224,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  exp_avg = exp_avg.float()
  signed_update = exp_avg.clone().mul_(self.beta1).add_(grad, alpha=(1-self.beta1)).sign_()

- if self.use_cautious:
+ if self.cautious_mask:
  mask = (signed_update * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
  signed_update.mul_(mask)
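Both Lion code paths (factored and uncompressed) compute the same signed update; as a self-contained reference, the rule from these hunks is, in plain PyTorch (the helper name and defaults are local to this sketch):

import torch

def lion_signed_update(exp_avg, grad, beta1=0.9, cautious_mask=False):
    # c_t = sign(beta1 * m_{t-1} + (1 - beta1) * g_t)
    signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=1 - beta1).sign_()
    if cautious_mask:
        # Keep only components whose sign agrees with the current gradient,
        # rescaled by the surviving fraction (clamped to avoid blow-up).
        mask = (signed_update * grad > 0).to(grad.dtype)
        mask.div_(mask.mean().clamp_(min=1e-3))
        signed_update.mul_(mask)
    return signed_update

The momentum-buffer update and the Prodigy step-size logic sit outside the lines shown here.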
@@ -26,12 +26,12 @@ class Lion_adv(torch.optim.Optimizer):
  matrices to apply low-rank compression (default: True).
  stochastic_rounding (bool, optional): whether to use stochastic
  rounding for BF16 parameter updates (default: True).
- use_cautious (bool): whether to use the cautious masking technique. (default: False).
+ cautious_mask (bool): whether to use the cautious masking technique. (default: False).
  clip_threshold (float, optional): whether to clip the gradients norm
  per-parameter as proposed in the paper `Lions and Muons: Optimization via
  Stochastic Frank-Wolfe` (https://arxiv.org/abs/2506.04192) to make Lion more stable
  (default: 0.0).
- factored (bool): whether to use the factorization or use the
+ nnmf_factor (bool): whether to use the factorization or use the
  uncompressed optimizer. (default: True)
  """

@@ -44,9 +44,9 @@ class Lion_adv(torch.optim.Optimizer):
  vector_reshape: bool = True,
  stochastic_rounding: bool = True,
  use_orthograd: bool = False,
- use_cautious: bool = False,
+ cautious_mask: bool = False,
  clip_threshold: float = 0.0,
- factored: bool = True,
+ nnmf_factor: bool = True,
  ):
  if not lr > 0.0:
  raise ValueError(f"Learning rate must be > 0.0, but got {lr}")
@@ -64,8 +64,8 @@ class Lion_adv(torch.optim.Optimizer):
  clip_threshold=clip_threshold,
  )
  self.stochastic_rounding = stochastic_rounding
- self.use_cautious = use_cautious
- self.factored = factored
+ self.cautious_mask = cautious_mask
+ self.factored = nnmf_factor
  super().__init__(params, defaults)

  @property
@@ -140,7 +140,7 @@ class Lion_adv(torch.optim.Optimizer):
  # Compute update term c_t
  signed_update = exp_avg.clone().mul_(beta1).add_(grad_reshaped, alpha=(1-beta1)).sign_()

- if self.use_cautious:
+ if self.cautious_mask:
  mask = (signed_update * grad_reshaped > 0).to(grad_reshaped.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
  signed_update.mul_(mask)
@@ -167,7 +167,7 @@ class Lion_adv(torch.optim.Optimizer):
  exp_avg = exp_avg.float()
  signed_update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_()

- if self.use_cautious:
+ if self.cautious_mask:
  mask = (signed_update * grad > 0).to(grad.dtype)
  mask.div_(mask.mean().clamp_(min=1e-3))
  signed_update.mul_(mask)