adv-optm 1.0.5__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

adv_optm/__init__.py CHANGED
@@ -16,4 +16,4 @@ __all__ = [
     "Lion_Prodigy_adv",
 ]
 
-__version__ = "1.0.5"
+__version__ = "1.1.0"
@@ -1,15 +1,16 @@
 import torch
-from typing import Optional
+from typing import Optional, Callable
 
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.Kourkoutas import KourkoutasHelper
 
 class AdamW_adv(torch.optim.Optimizer):
     """
-    Implements a factored AdamW algorithm.
+    Implements an advanced AdamW algorithm.
     This is an advanced version of AdamW with optional features like
     low-rank factorization of optimizer states (SMMF), OrthoGrad, etc.
 
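The docstring's mention of low-rank factorization of optimizer states refers to keeping a rank-1 approximation of a matrix-shaped state instead of the full matrix. As a rough illustration of the idea only (the package's `_nnmf`/`_unnmf` helpers and the SMMF scheme are not reproduced here), an Adafactor-style row/column factorization of a nonnegative second-moment matrix looks like this:

import torch

def factor_rank1(V: torch.Tensor):
    # V: nonnegative (d1, d2) state, e.g. a squared-gradient accumulator.
    # Keep only a d1-vector and a d2-vector instead of d1 * d2 entries.
    return V.sum(dim=1), V.sum(dim=0)

def unfactor_rank1(row: torch.Tensor, col: torch.Tensor):
    # Rank-1 reconstruction: outer(row, col) normalised by the total mass.
    return torch.outer(row, col) / row.sum().clamp_min(1e-30)

V = torch.rand(64, 128) ** 2          # any nonnegative matrix
row, col = factor_rank1(V)
V_hat = unfactor_rank1(row, col)      # (64, 128) approximation built from 64 + 128 numbers

A signed first moment needs extra handling (for example storing signs separately), which is consistent with the `_pack_bools`/`_unpack_bools` import above.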
@@ -54,6 +55,28 @@ class AdamW_adv(torch.optim.Optimizer):
             as it gradually introduces the stabilizing slow momentum term. During
             the warmup, `alpha` ramps from 0 to its target value. If `None`,
             the scheduler is disabled. (default: None)
+        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+            If `False`, the optimizer behaves as standard AdamW. (default: False)
+        beta2_min (float): The minimum value for dynamic β₂, used during periods of
+            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+            (default: 0.9)
+        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+            the pooled gradient norms. Corresponds to `α` in the paper.
+            (default: 0.95)
+        tiny_spike (float): A small constant added to the denominator of the
+            "sunspike" ratio calculation to prevent division by zero. Corresponds
+            to `ε_spike` in the paper. (default: 1e-9)
+        k_warmup_steps (int): The number of initial steps during which β₂ is held at
+            a fixed beta2 value before the dynamic logic activates.
+            (default: 0)
+        k_logging (int): if > 0 and `kourkoutas_beta=True`, enables periodic console
+            logging of Kourkoutas-β statistics (min, max, mean of β₂ across layers)
+            every `k_logging` steps. Useful for debugging and tuning. Set to 0 to
+            disable logging. (default: 0)
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a unique, hashable key representing its "layer" or "bucket".
+            If `None`, parameters are bucketed by their memory ID (tensor-wise).
+            (default: None)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
     """
@@ -76,6 +99,13 @@ class AdamW_adv(torch.optim.Optimizer):
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
         t_alpha: int | None = None,
+        kourkoutas_beta: bool = False,
+        beta2_min: float = 0.9,
+        ema_alpha: float = 0.95,
+        tiny_spike: float = 1e-9,
+        k_warmup_steps: int = 0,
+        k_logging: int = 0,
+        layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
@@ -86,6 +116,8 @@ class AdamW_adv(torch.optim.Optimizer):
             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+        if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+
         if cautious_mask and grams_moment:
             print("Warning: cautious is incompatible with grams, Disabling cautious.")
             cautious_mask = False
@@ -95,14 +127,21 @@ class AdamW_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape, "use_atan2": use_atan2,
             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
             "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
+            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
         }
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
         self.grams_moment = grams_moment
         self.use_AdEMAMix = use_AdEMAMix
         self.factored = nnmf_factor
+        self.kourkoutas_beta = kourkoutas_beta
+        self.layer_key_fn = layer_key_fn
         super().__init__(params, defaults)
 
+        if self.kourkoutas_beta:
+            self.kourkoutas_helper = KourkoutasHelper(self)
+
     @property
     def supports_fused_back_pass(self):
         return True
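To show how the new constructor arguments fit together, here is a hedged usage sketch. It assumes `AdamW_adv` is importable from the package root (the `__init__.py` hunk above only confirms that other optimizers are exported that way), the hyperparameter values are purely illustrative, and the per-module `layer_key_fn` is hypothetical glue code rather than something shipped with adv-optm.

import torch
from torch import nn
from adv_optm import AdamW_adv  # assumed import path

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))

# Hypothetical layer_key_fn: bucket every parameter by the module that owns it,
# so all parameters of one module share a single dynamic beta2.
param_to_module = {
    id(p): name
    for name, module in model.named_modules()
    for p in module.parameters(recurse=False)
}

def layer_key_fn(p: torch.Tensor):
    # Fall back to tensor-wise bucketing (memory id), mirroring the documented default.
    return param_to_module.get(id(p), id(p))

optimizer = AdamW_adv(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),       # betas[1] acts as beta2_max when kourkoutas_beta=True
    kourkoutas_beta=True,     # enable layer-wise dynamic beta2
    beta2_min=0.9,            # floor used during gradient-norm "sunspikes"
    ema_alpha=0.95,           # decay of the pooled gradient-norm EMA
    k_warmup_steps=100,       # hold beta2 fixed for the first 100 steps
    k_logging=200,            # print Kourkoutas-beta stats every 200 steps
    layer_key_fn=layer_key_fn,
)

loss = model(torch.randn(8, 512)).square().mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()

Note that the constructor check added above requires betas[1] > beta2_min when kourkoutas_beta is enabled, so betas[1] plays the role of beta2_max.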
@@ -127,10 +166,8 @@ class AdamW_adv(torch.optim.Optimizer):
             grad = _orthogonalize_gradient(p, grad)
         state = self.state[p]
 
-        beta1, beta2 = group['betas']
-
         # State Initialization
-        if len(state) == 0:
+        if 'step' not in state:
             state['step'] = 0
 
             should_factor = (
@@ -148,7 +185,7 @@ class AdamW_adv(torch.optim.Optimizer):
                 d1, d2 = state['effective_shape']
 
                 # First moment (m)
-                if beta1 > 0:
+                if group['betas'][0] > 0:
                     state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                     state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                     if not self.grams_moment:
@@ -163,16 +200,31 @@ class AdamW_adv(torch.optim.Optimizer):
                 state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
             else: # Fallback to standard AdamW for non-factored tensors
-                if beta1 > 0:
+                if group['betas'][0] > 0:
                     state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
                     if self.use_AdEMAMix:
                         state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
 
+        beta1, beta2 = group['betas']
+
+        current_step = state['step']
+        if group['kourkoutas_beta']:
+            # Call prepare_step() once at the beginning of the step for all params
+            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            # Accumulate current grad's norm for the *next* step
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+            # Get the dynamic beta2 calculated in prepare_step()
+            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
         step = state['step'] + 1
         if group['use_bias_correction']:
             bias_correction1 = 1.0 - beta1 ** step
-            bias_correction2 = 1.0 - beta2 ** step
+            if group['kourkoutas_beta']:
+                bias_correction2 = 1.0 - group['betas'][1] ** step
+                # Use beta2_max for bias correction
+            else:
+                bias_correction2 = 1.0 - beta2 ** step
         else:
             bias_correction1 = 1
             bias_correction2 = 1
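To make the helper calls above (maybe_prepare_step, accumulate_gradient_sq_norm, get_beta2) more concrete, below is a minimal, self-contained sketch of how a sunspike-driven β₂ can be derived from pooled gradient norms, reusing the knob names documented earlier (beta2_min, ema_alpha, tiny_spike). It illustrates the general idea only and is not the actual KourkoutasHelper implementation, whose bookkeeping across buckets and steps may differ.

import torch

def dynamic_beta2(pooled_sq_norm: float, ema_prev: float,
                  beta2_max: float = 0.999, beta2_min: float = 0.9,
                  ema_alpha: float = 0.95, tiny_spike: float = 1e-9):
    """Illustrative per-bucket beta2 update from one step's pooled gradient norm."""
    pooled_norm = pooled_sq_norm ** 0.5
    # EMA of the pooled gradient norm for this bucket (the prepare_step bookkeeping).
    ema = ema_alpha * ema_prev + (1.0 - ema_alpha) * pooled_norm
    # "Sunspike" ratio: grows when the current norm dwarfs its running average.
    ratio = pooled_norm / (ema + tiny_spike)
    sunspike = ratio / (1.0 + ratio)          # squash into [0, 1)
    # Spikes push beta2 toward beta2_min; calm phases keep it near beta2_max.
    beta2 = beta2_max - (beta2_max - beta2_min) * sunspike
    return beta2, ema

# One bucket (e.g. one layer) over a single step:
params = [torch.randn(16, 16, requires_grad=True)]
loss = sum((p ** 2).sum() for p in params)
loss.backward()

ema_state = 0.0
pooled_sq = sum(float(p.grad.pow(2).sum()) for p in params)   # accumulate_gradient_sq_norm analogue
beta2, ema_state = dynamic_beta2(pooled_sq, ema_state)         # get_beta2 analogue

In the optimizer itself, `k_warmup_steps` keeps β₂ fixed during the first steps, and, as the hunk above shows, bias correction uses `group['betas'][1]` (the maximum) so the correction factor does not depend on the fluctuating β₂.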
@@ -315,4 +367,4 @@ class AdamW_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
 
-        return loss
+        return loss
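Because `supports_fused_back_pass` returns True and the update is applied per parameter through `step_parameter(p, group, i)` (as the loop in the hunk above shows), the optimizer can also be driven from per-parameter gradient hooks. The wiring below is a hypothetical sketch using PyTorch's `Tensor.register_post_accumulate_grad_hook` (available in PyTorch 2.1+); the lookup table and hook are illustrative glue code, not part of adv-optm.

import torch
from torch import nn
from adv_optm import AdamW_adv  # assumed import path

model = nn.Linear(512, 512)
optimizer = AdamW_adv(model.parameters(), lr=1e-3)

# Map each parameter to its (group, index) so a hook can call step_parameter directly.
param_location = {
    id(p): (group, i)
    for group in optimizer.param_groups
    for i, p in enumerate(group['params'])
}

def fused_step_hook(p: torch.Tensor) -> None:
    group, i = param_location[id(p)]
    optimizer.step_parameter(p, group, i)   # update this parameter as soon as its grad is ready
    p.grad = None                           # release the gradient immediately

for p in model.parameters():
    p.register_post_accumulate_grad_hook(fused_step_hook)

loss = model(torch.randn(8, 512)).square().mean()
loss.backward()   # parameters are updated during the backward pass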