adv-optm 1.0.5__tar.gz → 1.1.0.dev1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of adv-optm might be problematic.
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/PKG-INFO +1 -1
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/__init__.py +1 -1
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/optim/AdamW_adv.py +68 -7
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/optim/Adopt_adv.py +111 -43
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/optim/Prodigy_adv.py +71 -12
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/optim/Simplified_AdEMAMix.py +62 -1
- adv_optm-1.1.0.dev1/adv_optm/util/Kourkoutas.py +108 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm.egg-info/SOURCES.txt +1 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/setup.py +1 -1
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/LICENSE +0 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/README.md +0 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/util/__init__.py +0 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/setup.cfg +0 -0
{adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/optim/AdamW_adv.py +68 -7

@@ -1,11 +1,12 @@
 import torch
-from typing import Optional
+from typing import Optional, Callable

 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.Kourkoutas import KourkoutasHelper

 class AdamW_adv(torch.optim.Optimizer):
     """
@@ -54,6 +55,28 @@ class AdamW_adv(torch.optim.Optimizer):
             as it gradually introduces the stabilizing slow momentum term. During
             the warmup, `alpha` ramps from 0 to its target value. If `None`,
             the scheduler is disabled. (default: None)
+        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+            If `False`, the optimizer behaves as standard AdamW. (default: False)
+        beta2_min (float): The minimum value for dynamic β₂, used during periods of
+            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+            (default: 0.88)
+        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+            the pooled gradient norms. Corresponds to `α` in the paper.
+            (default: 0.93)
+        tiny_spike (float): A small constant added to the denominator of the
+            "sunspike" ratio calculation to prevent division by zero. Corresponds
+            to `ε_spike` in the paper. (default: 1e-9)
+        k_warmup_steps (int): The number of initial steps during which β₂ is held
+            at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+            dynamic logic activates. (default: 0)
+        k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+            every logging steps. Useful for debugging and tuning. Set to 0 to disable
+            logging (default: 0).
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a unique, hashable key representing its "layer" or "bucket".
+            If `None`, parameters are bucketed by their memory ID (tensor-wise).
+            (default: None)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
     """
@@ -76,6 +99,13 @@ class AdamW_adv(torch.optim.Optimizer):
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
         t_alpha: int | None = None,
+        kourkoutas_beta: bool = False,
+        beta2_min: float = 0.88,
+        ema_alpha: float = 0.93,
+        tiny_spike: float = 1e-9,
+        k_warmup_steps: int = 0,
+        k_logging: int = 0,
+        layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
@@ -86,6 +116,8 @@ class AdamW_adv(torch.optim.Optimizer):
             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+        if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+
         if cautious_mask and grams_moment:
             print("Warning: cautious is incompatible with grams, Disabling cautious.")
             cautious_mask = False
@@ -95,6 +127,8 @@ class AdamW_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape, "use_atan2": use_atan2,
             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
             "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
+            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
         }
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
@@ -103,6 +137,12 @@ class AdamW_adv(torch.optim.Optimizer):
         self.factored = nnmf_factor
         super().__init__(params, defaults)

+        self.kourkoutas_beta = kourkoutas_beta
+        self.k_logging= k_logging and kourkoutas_beta
+        self.layer_key_fn = layer_key_fn and kourkoutas_beta
+        if self.kourkoutas_beta:
+            self.kourkoutas_helper = KourkoutasHelper(self)
+
     @property
     def supports_fused_back_pass(self):
         return True
@@ -127,8 +167,6 @@ class AdamW_adv(torch.optim.Optimizer):
             grad = _orthogonalize_gradient(p, grad)
         state = self.state[p]

-        beta1, beta2 = group['betas']
-
         # State Initialization
         if len(state) == 0:
             state['step'] = 0
@@ -148,7 +186,7 @@ class AdamW_adv(torch.optim.Optimizer):
                 d1, d2 = state['effective_shape']

                 # First moment (m)
-                if
+                if group['betas'][0] > 0:
                     state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                     state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                     if not self.grams_moment:
@@ -163,16 +201,29 @@ class AdamW_adv(torch.optim.Optimizer):
                 state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
             else: # Fallback to standard AdamW for non-factored tensors
-                if
+                if group['betas'][0] > 0:
                     state['exp_avg'] = torch.zeros_like(p, device=device, dtype=dtype)
                     if self.use_AdEMAMix:
                         state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
                 state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)

+        current_step = state['step']
+        if group['kourkoutas_beta']:
+            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+
+        beta1, beta2 = group['betas']
+        if group['kourkoutas_beta']:
+            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
         step = state['step'] + 1
         if group['use_bias_correction']:
             bias_correction1 = 1.0 - beta1 ** step
-
+            if group['kourkoutas_beta']:
+                bias_correction2 = 1.0 - group['betas'][1] ** step
+                # Use beta2_max for bias correction
+            else:
+                bias_correction2 = 1.0 - beta2 ** step
         else:
             bias_correction1 = 1
             bias_correction2 = 1
@@ -315,4 +366,14 @@ class AdamW_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)

-
+        if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
+            first_param_state = self.state[self.param_groups[0]['params'][0]]
+            step_num = first_param_state['step']
+
+            if step_num > 0 and step_num % self.k_logging == 0:
+                if self._beta2_log:
+                    beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
+                    print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
+                delattr(self, '_beta2_log')
+
+        return loss
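For orientation, below is a minimal usage sketch of the new flags on AdamW_adv. The import path, the toy model, and the training snippet are illustrative assumptions (they are not taken from the package's README); the keyword arguments themselves come from the diff above.

    import torch
    from adv_optm import AdamW_adv  # assumption: re-exported at the package root; otherwise import from adv_optm.optim.AdamW_adv

    model = torch.nn.Linear(64, 10)
    opt = AdamW_adv(
        model.parameters(),
        lr=1e-3,
        betas=(0.9, 0.999),     # betas[1] serves as beta2_max once kourkoutas_beta is enabled
        kourkoutas_beta=True,   # switch on the layer-wise dynamic beta2
        beta2_min=0.88,
        k_warmup_steps=100,     # hold beta2 at the min/max midpoint for the first 100 steps
        k_logging=50,           # intended to print beta2 stats periodically (see the k_logging docstring)
    )

    x, y = torch.randn(8, 64), torch.randint(0, 10, (8,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    opt.step()
    opt.zero_grad()

Without a layer_key_fn, each tensor forms its own bucket (keyed by id(p)), as documented above.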
{adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/optim/Adopt_adv.py +111 -43

@@ -6,6 +6,7 @@ from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf, _unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.Kourkoutas import KourkoutasHelper

 class Adopt_adv(torch.optim.Optimizer):
     """
@@ -72,6 +73,28 @@ class Adopt_adv(torch.optim.Optimizer):
             current gradient. For small batch sizes, use high values (e.g., 10-100) to be
             more responsive. For large batch sizes, use low values (e.g., 0-1) for
             stability. (default: 100.0)
+        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+            If `False`, the optimizer behaves as standard Adopt. (default: False)
+        beta2_min (float): The minimum value for dynamic β₂, used during periods of
+            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+            (default: 0.88)
+        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+            the pooled gradient norms. Corresponds to `α` in the paper.
+            (default: 0.93)
+        tiny_spike (float): A small constant added to the denominator of the
+            "sunspike" ratio calculation to prevent division by zero. Corresponds
+            to `ε_spike` in the paper. (default: 1e-9)
+        k_warmup_steps (int): The number of initial steps during which β₂ is held
+            at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+            dynamic logic activates. (default: 0)
+        k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+            every logging steps. Useful for debugging and tuning. Set to 0 to disable
+            logging (default: 0).
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a unique, hashable key representing its "layer" or "bucket".
+            If `None`, parameters are bucketed by their memory ID (tensor-wise).
+            (default: None)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
     """
@@ -96,6 +119,13 @@ class Adopt_adv(torch.optim.Optimizer):
         t_alpha: int | None = None,
         Simplified_AdEMAMix: bool = False,
         alpha_grad: float = 100.0,
+        kourkoutas_beta: bool = False,
+        beta2_min: float = 0.88,
+        ema_alpha: float = 0.93,
+        tiny_spike: float = 1e-9,
+        k_warmup_steps: int = 0,
+        k_logging: int = 0,
+        layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
@@ -111,6 +141,7 @@ class Adopt_adv(torch.optim.Optimizer):
             cautious_mask = False
         if betas[0] == 0.0 and Simplified_AdEMAMix:
             raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
+        if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
         if use_AdEMAMix and Simplified_AdEMAMix:
             print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
         if grams_moment and Simplified_AdEMAMix:
@@ -125,6 +156,8 @@ class Adopt_adv(torch.optim.Optimizer):
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
             "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
             "t_alpha": t_alpha, "alpha_grad": alpha_grad,
+            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
         }
         self.clip_lambda = clip_lambda
         self.stochastic_rounding = stochastic_rounding
@@ -137,6 +170,12 @@ class Adopt_adv(torch.optim.Optimizer):
         self.factored = nnmf_factor
         super().__init__(params, defaults)

+        self.kourkoutas_beta = kourkoutas_beta
+        self.k_logging= k_logging and kourkoutas_beta
+        self.layer_key_fn = layer_key_fn and kourkoutas_beta
+        if self.kourkoutas_beta:
+            self.kourkoutas_helper = KourkoutasHelper(self)
+
     @property
     def supports_fused_back_pass(self): return True
     @property
@@ -174,11 +213,12 @@ class Adopt_adv(torch.optim.Optimizer):
                 d1, d2 = state['effective_shape']

                 # m_0 = 0
-
-
-
-
-
+                if group['betas'][0] > 0:
+                    state['mu_m_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                    state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
+                    if not self.grams_moment:
+                        packed_d2 = (d2 + 7) // 8
+                        state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
                 if self.use_AdEMAMix:
                     state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                     state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
@@ -192,17 +232,26 @@ class Adopt_adv(torch.optim.Optimizer):
                 # Initialize v_0 using NMF
                 _nnmf(vt_init, out=(state['mu_v_nmf'], state['mv_v_nmf']))
             else: # Fallback for non-factored tensors
-
+                if group['betas'][0] > 0:
+                    state['exp_avg'] = torch.zeros_like(p, dtype=dtype) # m_0
                 if self.use_AdEMAMix:
                     state['exp_avg_slow'] = torch.zeros_like(p, dtype=dtype)
                 state['exp_avg_sq'] = grad.square() # v_0

+        current_step = state['step']
+        if group['kourkoutas_beta']:
+            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+
+        beta1, beta2 = group['betas']
+        if group['kourkoutas_beta']:
+            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
         # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
         if state['step'] == 0 and not self.use_atan2:
             state['step'] += 1
             return

-        beta1, beta2 = group['betas']
         if self.use_AdEMAMix:
             beta3_ema = group['beta3_ema']
             alpha = group['alpha']
@@ -219,13 +268,14 @@ class Adopt_adv(torch.optim.Optimizer):
             d1, d2 = state['effective_shape']

             # Reconstruct m_{t-1}
-
-
-            if
-            state['sign']
-
-
-
+            if beta1 > 0:
+                mt = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
+                if not self.grams_moment:
+                    if state['sign'].dtype != torch.uint8:
+                        state['sign'] = state['sign'].to(torch.uint8)
+                    unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
+                    torch.where(unpacked_sign, mt, -mt, out=mt)
+                    del unpacked_sign

             # Reconstruct AdEMAMix EMA
             if self.use_AdEMAMix:
@@ -253,25 +303,29 @@ class Adopt_adv(torch.optim.Optimizer):
             del denom

             # ADOPT Step B: Update momentum m_t using normalized gradient
-            if
-
-
-
-
-
-
-
-
-
-
+            if beta1 > 0:
+                if self.Simplified_AdEMAMix:
+                    mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
+                else:
+                    mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
+                if self.grams_moment:
+                    mt = grad_reshaped.sign() * mt.abs()
+                elif self.cautious_mask:
+                    mask = (mt * grad_reshaped > 0).to(grad_reshaped.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    mt.mul_(mask)
+                    del mask

             if self.use_AdEMAMix:
                 mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
-
+                if beta1 > 0:
+                    update = torch.add(mt, mt_slow, alpha=alpha_t)
+                else:
+                    update = torch.add(normalized_grad, mt_slow, alpha=alpha_t)
             elif self.Simplified_AdEMAMix:
-                update = torch.add(mt,
+                update = torch.add(mt, normalized_grad, alpha=alpha_grad)
             else:
-                update = mt.clone()
+                update = mt.clone() if beta1 > 0 else normalized_grad

             update = update.view(p.shape)

@@ -285,10 +339,11 @@ class Adopt_adv(torch.optim.Optimizer):
             del grad_reshaped

             # Compress and store new factors
-            if
-
-
-
+            if beta1 > 0:
+                if not self.grams_moment:
+                    state['sign'] = _pack_bools(mt > 0)
+                _nnmf(mt.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
+                del mt

             if self.use_AdEMAMix:
                 state['sign_slow'] = _pack_bools(mt_slow > 0)
@@ -300,10 +355,7 @@ class Adopt_adv(torch.optim.Optimizer):
             del vt

         else: # Standard ADOPT logic for non-factored tensors
-
-
-            if self.use_AdEMAMix:
-                m_slow = state['exp_avg_slow']
+            v = state['exp_avg_sq'] # v_{t-1}

             # ADOPT Step A: Decorrelate g_t using v_{t-1}
             denom = v.sqrt()
@@ -318,10 +370,12 @@ class Adopt_adv(torch.optim.Optimizer):
             del denom

             # ADOPT Step B: Update momentum m_t
-            if
-            m
-
-
+            if beta1 > 0:
+                m = state['exp_avg'] # m_{t-1},
+                if self.Simplified_AdEMAMix:
+                    m.mul_(beta1).add_(normalized_grad, alpha=1.0)
+                else:
+                    m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)

             if self.grams_moment:
                 m = grad.sign() * m.abs()
@@ -332,12 +386,16 @@ class Adopt_adv(torch.optim.Optimizer):
                 del mask

             if self.use_AdEMAMix:
+                m_slow = state['exp_avg_slow']
                 m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
-
+                if beta1 > 0:
+                    update = torch.add(m, m_slow, alpha=alpha_t)
+                else:
+                    update = torch.add(normalized_grad, m_slow, alpha=alpha_t)
             elif self.Simplified_AdEMAMix:
-                update = torch.add(m,
+                update = torch.add(m, normalized_grad, alpha=alpha_grad)
             else:
-                update = m.clone()
+                update = m.clone() if beta1 > 0 else normalized_grad

             if self.use_atan2:
                 update.mul_(group['lr'] * 1.2732395447351628)
@@ -374,4 +432,14 @@ class Adopt_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)

+        if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
+            first_param_state = self.state[self.param_groups[0]['params'][0]]
+            step_num = first_param_state['step']
+
+            if step_num > 0 and step_num % self.k_logging == 0:
+                if self._beta2_log:
+                    beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
+                    print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
+                delattr(self, '_beta2_log')
+
         return loss
{adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/optim/Prodigy_adv.py +71 -12

@@ -3,11 +3,14 @@ import torch.distributed as dist

 import math

+from typing import Optional, Callable
+
 from ..util.BF16_Stochastic_Rounding import add_stochastic_
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.Kourkoutas import KourkoutasHelper

 class Prodigy_adv(torch.optim.Optimizer):
     """
@@ -85,6 +88,28 @@ class Prodigy_adv(torch.optim.Optimizer):
         prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
             after the specified optimiser step and release all state memory required by Prodigy
             (default: 0).
+        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+            If `False`, the optimizer behaves as standard AdamW/Prodigy. (default: False)
+        beta2_min (float): The minimum value for dynamic β₂, used during periods of
+            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+            (default: 0.88)
+        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+            the pooled gradient norms. Corresponds to `α` in the paper.
+            (default: 0.93)
+        tiny_spike (float): A small constant added to the denominator of the
+            "sunspike" ratio calculation to prevent division by zero. Corresponds
+            to `ε_spike` in the paper. (default: 1e-9)
+        k_warmup_steps (int): The number of initial steps during which β₂ is held
+            at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+            dynamic logic activates. (default: 0)
+        k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+            every logging steps. Useful for debugging and tuning. Set to 0 to disable
+            logging (default: 0).
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a unique, hashable key representing its "layer" or "bucket".
+            If `None`, parameters are bucketed by their memory ID (tensor-wise).
+            (default: None)
     """

     def __init__(
@@ -116,6 +141,13 @@ class Prodigy_adv(torch.optim.Optimizer):
         fsdp_in_use: bool = False,
         slice_p: int = 11,
         prodigy_steps: int = 0,
+        kourkoutas_beta: bool = False,
+        beta2_min: float = 0.88,
+        ema_alpha: float = 0.93,
+        tiny_spike: float = 1e-9,
+        k_warmup_steps: int = 0,
+        k_logging: int = 0,
+        layer_key_fn: Optional[Callable] = None,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -141,6 +173,8 @@ class Prodigy_adv(torch.optim.Optimizer):
         if use_atan2 and Simplified_AdEMAMix:
             print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
             use_atan2 = False
+        if kourkoutas_beta and not (betas[1] > beta2_min):
+            raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
         if Simplified_AdEMAMix and alpha_grad > 0:
             # scales d_coef by alpha_grad, this force prodigy to behave well with Simplified_AdEMAMix
             d_coef = d_coef/alpha_grad
@@ -153,7 +187,9 @@ class Prodigy_adv(torch.optim.Optimizer):
             "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
             "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
             "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
-            "alpha_grad": alpha_grad,
+            "alpha_grad": alpha_grad,
+            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
         }
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask and not Simplified_AdEMAMix
@@ -163,6 +199,13 @@ class Prodigy_adv(torch.optim.Optimizer):
         self.factored = nnmf_factor
         self.fsdp_in_use = fsdp_in_use
         super().__init__(params, defaults)
+
+        self.kourkoutas_beta = kourkoutas_beta
+        self.k_logging= k_logging and kourkoutas_beta
+        self.layer_key_fn = layer_key_fn and kourkoutas_beta
+        if self.kourkoutas_beta:
+            self.kourkoutas_helper = KourkoutasHelper(self)
+
         self.init_step()

     @property
@@ -180,19 +223,17 @@ class Prodigy_adv(torch.optim.Optimizer):
     def init_step(self):
         """Resets accumulators and calculates dlr for the upcoming step."""
         self.d_denom = 0.0
-
+
         g_group = self.param_groups[0]
-        self.beta1, self.
+        self.beta1, self.beta2_default = g_group['betas']
         self.beta3 = g_group['beta3']
         if self.beta3 is None:
-            self.beta3 = math.sqrt(self.
+            self.beta3 = math.sqrt(self.beta2_default)

-        k = g_group['k']
         self.d = g_group['d']
         lr = g_group['lr']

         self.dlr = self.d * lr
-
         self.d_numerator = g_group.get('d_numerator', 0.0) * self.beta3

     @torch.no_grad()
@@ -258,6 +299,15 @@ class Prodigy_adv(torch.optim.Optimizer):
             else:
                 state['p0'] = torch.tensor(0, device=device, dtype=p.dtype)

+        current_step = state['step']
+        if group['kourkoutas_beta']:
+            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+
+        beta2 = self.beta2_default
+        if group['kourkoutas_beta']:
+            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+
         if self.use_AdEMAMix:
             beta3_ema = group['beta3_ema']
             alpha = group['alpha']
@@ -295,7 +345,7 @@ class Prodigy_adv(torch.optim.Optimizer):
                 del mask

             vt = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
-            vt.mul_(
+            vt.mul_(beta2).addcmul_(grad_reshaped, grad_reshaped, value=self.d * self.d * (1.0 - beta2))

             if self.use_AdEMAMix:
                 mt_slow = _unnmf((state['mu_m_slow_nmf'], state['mv_m_slow_nmf']))
@@ -308,11 +358,11 @@ class Prodigy_adv(torch.optim.Optimizer):
                 if self.beta1 > 0:
                     update = torch.add(mt, mt_slow, alpha=alpha_t)
                 else:
-                    update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
+                    update = torch.add(grad_reshaped.mul(self.d), mt_slow, alpha=alpha_t)
             elif self.Simplified_AdEMAMix:
                 update = torch.add(mt, grad_reshaped, alpha=alpha_grad * self.d)
             else:
-                update = mt.clone() if self.beta1 > 0 else grad_reshaped.
+                update = mt.clone() if self.beta1 > 0 else grad_reshaped.mul(self.d)
             del grad_reshaped

             if group['use_atan2']:
@@ -362,13 +412,13 @@ class Prodigy_adv(torch.optim.Optimizer):
                 if self.beta1 > 0:
                     update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
                 else:
-                    update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
+                    update = torch.add(grad.mul(self.d), exp_avg_slow, alpha=alpha_t)
             elif self.Simplified_AdEMAMix:
                 update = torch.add(exp_avg, grad, alpha=alpha_grad * self.d)
             else:
-                update = exp_avg.clone() if self.beta1 > 0 else grad.
+                update = exp_avg.clone() if self.beta1 > 0 else grad.mul(self.d)

-            exp_avg_sq.mul_(
+            exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=self.d * self.d * (1.0 - beta2))

             if group['use_atan2']:
                 a = 1.2732395
@@ -431,6 +481,15 @@ class Prodigy_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)

+        if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
+            first_param_state = self.state[self.param_groups[0]['params'][0]]
+            step_num = first_param_state['step']
+
+            if step_num > 0 and step_num % self.k_logging == 0:
+                if self._beta2_log:
+                    beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
+                    print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
+                delattr(self, '_beta2_log')

         self.calculate_d()
         self.init_step()
{adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm/optim/Simplified_AdEMAMix.py +62 -1

@@ -1,4 +1,5 @@
 import torch
+from typing import Optional, Callable

 import math

@@ -7,6 +8,7 @@ from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
 from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
+from ..util.Kourkoutas import KourkoutasHelper

 # A little helper from the original simplified_AdEMAMix
 def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
@@ -47,6 +49,28 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         stochastic_rounding (bool): whether to use stochastic
             rounding for BF16 parameter updates (default: True).
         orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
+        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
+            If `False`, the optimizer behaves as standard Simplified_AdEMAMix. (default: False)
+        beta2_min (float): The minimum value for dynamic β₂, used during periods of
+            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
+            (default: 0.88)
+        ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
+            the pooled gradient norms. Corresponds to `α` in the paper.
+            (default: 0.93)
+        tiny_spike (float): A small constant added to the denominator of the
+            "sunspike" ratio calculation to prevent division by zero. Corresponds
+            to `ε_spike` in the paper. (default: 1e-9)
+        k_warmup_steps (int): The number of initial steps during which β₂ is held
+            at a fixed average value (`(beta2_min + beta2_max) / 2`) before the
+            dynamic logic activates. (default: 0)
+        k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
+            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
+            every logging steps. Useful for debugging and tuning. Set to 0 to disable
+            logging (default: 0).
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a unique, hashable key representing its "layer" or "bucket".
+            If `None`, parameters are bucketed by their memory ID (tensor-wise).
+            (default: None)
         nnmf_factor (bool): whether to use the factorization or disable it to use
             the uncompressed optimizer. (default: False)
     """
@@ -65,6 +89,13 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         vector_reshape: bool = True,
         stochastic_rounding: bool = True,
         orthogonal_gradient: bool = False,
+        kourkoutas_beta: bool = False,
+        beta2_min: float = 0.88,
+        ema_alpha: float = 0.93,
+        tiny_spike: float = 1e-9,
+        k_warmup_steps: int = 0,
+        k_logging: int = 0,
+        layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
     ):
         if not (lr >= 0.0):
@@ -77,17 +108,26 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
         if not 0.0 <= alpha_grad:
             raise ValueError("Invalid alpha value: {}".format(alpha_grad))
+        if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")

         defaults = {
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
             "alpha_grad": alpha_grad, "beta1_warmup": beta1_warmup, "min_beta1": min_beta1,
             "vector_reshape": vector_reshape,
             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
+            "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
+            "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps,
         }
         self.stochastic_rounding = stochastic_rounding
         self.factored = nnmf_factor
         super().__init__(params, defaults)

+        self.kourkoutas_beta = kourkoutas_beta
+        self.k_logging= k_logging and kourkoutas_beta
+        self.layer_key_fn = layer_key_fn and kourkoutas_beta
+        if self.kourkoutas_beta:
+            self.kourkoutas_helper = KourkoutasHelper(self)
+
     @property
     def supports_fused_back_pass(self):
         return True
@@ -149,9 +189,17 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             state['num_sum'] = 1.0
             state['den_sum'] = 1.0

+        current_step = state['step']
+        if group['kourkoutas_beta']:
+            self.kourkoutas_helper.maybe_prepare_step(current_step)
+            self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
+
         beta1_final, beta2 = group["betas"]
         beta1_warmup = group["beta1_warmup"]
         alpha_grad = group["alpha_grad"]
+
+        if group['kourkoutas_beta']:
+            beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)

         if beta1_warmup is not None:
             step = state['step'] + 1
@@ -161,7 +209,10 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):

         if group['use_bias_correction']:
             state['num_sum'] = beta1 * state['num_sum'] + 1.0
-
+            if group['kourkoutas_beta']:
+                state['den_sum'] = group['betas'][1] * state['den_sum'] + (1.0 - group['betas'][1])
+            else:
+                state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)

         if state['factored']:
             d1, d2 = state['effective_shape']
@@ -243,4 +294,14 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)

+        if self.kourkoutas_beta and self.k_logging > 0 and hasattr(self, '_beta2_log'):
+            first_param_state = self.state[self.param_groups[0]['params'][0]]
+            step_num = first_param_state['step']
+
+            if step_num > 0 and step_num % self.k_logging == 0:
+                if self._beta2_log:
+                    beta2_tensor = torch.tensor(self._beta2_log, device='cpu')
+                    print(f"Step {step_num}: Kourkoutas beta2 stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
+                delattr(self, '_beta2_log')
+
         return loss
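Across all four optimizers above, layer_key_fn defaults to per-tensor bucketing by id(p). Below is a hypothetical sketch of a coarser, module-level bucketing function; the owner mapping, the fallback, and the toy model are illustrative assumptions and not part of the package.

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))

    # Map each parameter tensor to the name of the module that owns it,
    # so e.g. a Linear layer's weight and bias share one dynamic beta2.
    owner = {
        id(p): name.rsplit(".", 1)[0] if "." in name else "<root>"
        for name, p in model.named_parameters()
    }

    # Fall back to tensor-wise bucketing for any parameter not seen in the map.
    layer_key_fn = lambda p: owner.get(id(p), id(p))

The resulting callable can be passed as layer_key_fn= to any of the optimizers above; any hashable return value works as a bucket key.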
adv_optm-1.1.0.dev1/adv_optm/util/Kourkoutas.py +108 -0

@@ -0,0 +1,108 @@
+import torch
+from torch.optim import Optimizer
+from typing import Callable
+
+class KourkoutasHelper:
+    """
+    A helper class to add layer-wise Kourkoutas-β functionality to a PyTorch optimizer.
+    """
+    def __init__(self, optimizer: Optimizer):
+        # We need a reference to the optimizer to access its param_groups and state
+        if not hasattr(optimizer, 'param_groups'):
+            raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
+        self.optimizer = optimizer
+
+        # State managed by the helper
+        self.layer_state = {}
+        self.layer_info = {}
+        self._layer_info_built = False
+        self._current_step_prepared = -1
+
+    def _build_layer_info_if_needed(self):
+        """Builds a map of layers and the parameters they contain."""
+        if self._layer_info_built:
+            return
+
+        if not hasattr(self.optimizer, 'layer_key_fn') or self.optimizer.layer_key_fn is None:
+            print("Warning: KourkoutasHelper requires 'layer_key_fn' on the optimizer. Defaulting to tensor-wise (id).")
+            self.optimizer.layer_key_fn = lambda p: id(p)
+
+        for group in self.optimizer.param_groups:
+            if not group.get('kourkoutas_beta', False):
+                continue
+            for p in group['params']:
+                if p.grad is None: continue
+                layer_key = self.optimizer.layer_key_fn(p)
+                if layer_key not in self.layer_info:
+                    self.layer_info[layer_key] = {'params': [], 'group_ref': group}
+                self.layer_info[layer_key]['params'].append(p)
+        self._layer_info_built = True
+
+    def prepare_step(self):
+        """
+        Calculates dynamic beta2 for all layers using the completed scalar accumulators
+        from the PREVIOUS step. Should be called once at the start of an optimizer step.
+        """
+        self._build_layer_info_if_needed()
+
+        if hasattr(self.optimizer, 'logging') and self.optimizer.logging:
+            if not hasattr(self.optimizer, '_beta2_log'):
+                self.optimizer._beta2_log = []
+
+        for layer_key, info in self.layer_info.items():
+            params, group = info['params'], info['group_ref']
+
+            if layer_key not in self.layer_state:
+                self.layer_state[layer_key] = {
+                    'r_ema_grad_norm': torch.tensor(0.0, device=params[0].device, dtype=torch.float32),
+                    'sum_sq_accumulator': torch.tensor(0.0, device=params[0].device, dtype=torch.float32)
+                }
+
+            layer_state = self.layer_state[layer_key]
+
+            pooled_grad_norm = torch.sqrt(layer_state['sum_sq_accumulator'])
+
+            r_ema = layer_state['r_ema_grad_norm']
+            r_ema.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])
+
+            raw = pooled_grad_norm / (r_ema + group['tiny_spike'])
+            sun = raw / (1.0 + raw)
+            beta2_max = group['betas'][1]
+            beta2 = beta2_max - (beta2_max - group['beta2_min']) * sun
+
+            layer_state['dynamic_beta2'] = beta2.item()
+            layer_state['sum_sq_accumulator'].zero_()
+
+            if hasattr(self.optimizer, 'logging') and self.optimizer.logging and hasattr(self.optimizer, '_beta2_log'):
+                self.optimizer._beta2_log.append(beta2.item())
+
+    def maybe_prepare_step(self, current_step: int):
+        """
+        A universal guard that calls prepare_step() exactly once per training step.
+        """
+        if self._current_step_prepared < current_step:
+            self.prepare_step()
+            self._current_step_prepared = current_step
+
+    def accumulate_gradient_sq_norm(self, p: torch.Tensor, grad: torch.Tensor):
+        """
+        Accumulates the squared L2 norm of a single gradient for the next step's calculation.
+        """
+        layer_key = self.optimizer.layer_key_fn(p)
+        if layer_key not in self.layer_state:
+            self.layer_state[layer_key] = {
+                'r_ema_grad_norm': torch.tensor(0.0, device=p.device, dtype=torch.float32),
+                'sum_sq_accumulator': torch.tensor(0.0, device=p.device, dtype=torch.float32)
+            }
+        self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()
+
+    def get_beta2(self, p: torch.Tensor, group: dict, current_step: int) -> float:
+        """
+        Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
+        """
+        beta2_default = group['betas'][1]
+        if current_step < group['k_warmup_steps']:
+            return 0.5 * (group['beta2_min'] + beta2_default)
+
+        layer_key = self.optimizer.layer_key_fn(p)
+        return self.layer_state.get(layer_key, {}).get('dynamic_beta2', beta2_default)
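For reference, the sunspike-to-β₂ mapping implemented in KourkoutasHelper.prepare_step above can be traced with plain numbers. The values below are made up for illustration; only the formulas mirror the added code.

    # Standalone sketch of the sunspike -> beta2 mapping, with illustrative inputs.
    ema_alpha, tiny_spike = 0.93, 1e-9
    beta2_min, beta2_max = 0.88, 0.999

    r_ema = 1.0              # smoothed pooled gradient norm carried over from previous steps
    pooled_grad_norm = 3.0   # sqrt of this step's accumulated squared gradient norms for one layer

    r_ema = ema_alpha * r_ema + (1.0 - ema_alpha) * pooled_grad_norm  # EMA update -> 1.14
    raw = pooled_grad_norm / (r_ema + tiny_spike)                     # "sunspike" ratio -> ~2.63
    sun = raw / (1.0 + raw)                                           # squashed into (0, 1) -> ~0.72
    beta2 = beta2_max - (beta2_max - beta2_min) * sun                 # -> ~0.913

    print(f"{beta2:.4f}")  # a gradient-norm spike pulls beta2 from beta2_max toward beta2_min

Quiet layers (pooled norm close to its EMA) keep β₂ near beta2_max, while spiking layers drop toward beta2_min, which is the behaviour the four optimizer docstrings describe.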
{adv_optm-1.0.5 → adv_optm-1.1.0.dev1}/adv_optm.egg-info/SOURCES.txt +1 -0

@@ -16,6 +16,7 @@ adv_optm/optim/Simplified_AdEMAMix.py
 adv_optm/optim/__init__.py
 adv_optm/util/BF16_Stochastic_Rounding.py
 adv_optm/util/Effective_Shape.py
+adv_optm/util/Kourkoutas.py
 adv_optm/util/NNMF.py
 adv_optm/util/One_Bit_Boolean.py
 adv_optm/util/OrthoGrad.py