adv_optm-1.2.0-py3-none-any.whl → adv_optm-1.2.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdaMuon_adv.py +14 -5
- adv_optm/optim/AdamW_adv.py +9 -2
- adv_optm/optim/Adopt_adv.py +9 -2
- adv_optm/optim/Lion_Prodigy_adv.py +8 -1
- adv_optm/optim/Lion_adv.py +8 -1
- adv_optm/optim/Muon_adv.py +16 -9
- adv_optm/optim/Prodigy_adv.py +9 -2
- adv_optm/optim/Simplified_AdEMAMix.py +9 -2
- adv_optm/util/BF16_Stochastic_Rounding.py +29 -4
- adv_optm/util/Kourkoutas.py +12 -6
- {adv_optm-1.2.0.dist-info → adv_optm-1.2.4.dist-info}/METADATA +1 -1
- adv_optm-1.2.4.dist-info/RECORD +23 -0
- adv_optm-1.2.0.dist-info/RECORD +0 -23
- {adv_optm-1.2.0.dist-info → adv_optm-1.2.4.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.0.dist-info → adv_optm-1.2.4.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.0.dist-info → adv_optm-1.2.4.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdaMuon_adv.py
CHANGED
@@ -1,6 +1,6 @@
 import torch
 
-from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.BF16_Stochastic_Rounding import add_stochastic_, set_seed as set_stochastic_rounding_seed
 from ..util.Newton_Schulz import _newton_schulz_iteration
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
@@ -184,6 +184,13 @@ class AdaMuon_adv(torch.optim.Optimizer):
             torch._dynamo.config.cache_size_limit = 8192
             self.compile(fullgraph=True)
 
+        if self.stochastic_rounding:
+            # For deterministic stochastic rounding, we need to seed the generator
+            # for each device used by the parameters.
+            devices = {p.device for group in self.param_groups for p in group['params'] if p.dtype == torch.bfloat16}
+            for device in devices:
+                set_stochastic_rounding_seed(device)
+
     @property
     def supports_fused_back_pass(self):
         return True
@@ -241,6 +248,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
             elif len(p.shape) >= 2:
                 state['normuon_v'] = torch.zeros(p.shape[0], device=p.device, dtype=torch.float32)
 
+                group['adam_kourkoutas_beta'] = False
+
         elif optim_type == 'adam':
 
             state['step'] = 0
@@ -441,8 +450,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
                 )
                 del signed_m_buf
 
-                update = update.view(original_shape)
-
                 if group['normuon_variant']:
                     # NorMuon Logic
                     v_t = state['normuon_v']
@@ -452,7 +459,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
                     # Normalize update
                     update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
                     del mean_squared_update
+                    update = update.view(original_shape)
                 else:
+                    update = update.view(original_shape)
                     # Original AdaMuon Logic
                     vt_buf = state['second_momentum_buffer']
                     vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
@@ -470,10 +479,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
                 if group['rms_rescaling']:
                     rms_target = 0.2 # default (Adam) value for RMS
                     update_norm = torch.linalg.vector_norm(update)
-                    update
+                    update.mul_(rms_target * lr * (p.numel()**0.5) / update_norm.add_(1e-8))
                     del update_norm
                 else:
-                    update
+                    update.mul_(lr)
 
             else: # Fallback to standard SGD with momentum for 1D params (biases, etc.)
                 # Momentum update
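The constructor hook added above is repeated in every optimizer of this release: collect the devices of all bfloat16 parameters and seed one rounding generator per device. Below is a minimal standalone sketch of that pattern; the stand-in set_seed helper and the toy param_groups layout are assumptions for illustration, not the package's internals.

import torch

# Stand-in for adv_optm.util.BF16_Stochastic_Rounding.set_seed (assumption):
# one fixed-seed generator per device keeps stochastic rounding reproducible.
_generators = {}

def set_seed(device: torch.device, seed: int = 42):
    gen = _generators.setdefault(device, torch.Generator(device=device))
    gen.manual_seed(seed)

# Toy param_groups mirroring the shape of torch.optim.Optimizer.param_groups.
param_groups = [
    {"params": [torch.zeros(4, 4, dtype=torch.bfloat16),
                torch.zeros(8, dtype=torch.float32)]},
]

# Same device-collection logic as the hunk above: only bf16 parameters are
# updated with stochastic rounding, so only their devices need a generator.
devices = {p.device for group in param_groups for p in group["params"]
           if p.dtype == torch.bfloat16}
for device in devices:
    set_seed(device)

print(sorted(str(d) for d in _generators))  # ['cpu'] when run on CPU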
adv_optm/optim/AdamW_adv.py
CHANGED
@@ -1,7 +1,7 @@
 import torch
 from typing import Optional, Callable
 
-from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.BF16_Stochastic_Rounding import add_stochastic_, set_seed as set_stochastic_rounding_seed
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
@@ -142,6 +142,13 @@ class AdamW_adv(torch.optim.Optimizer):
         if self.kourkoutas_beta:
             self.kourkoutas_helper = KourkoutasHelper(self)
 
+        if self.stochastic_rounding:
+            # For deterministic stochastic rounding, we need to seed the generator
+            # for each device used by the parameters.
+            devices = {p.device for group in self.param_groups for p in group['params'] if p.dtype == torch.bfloat16}
+            for device in devices:
+                set_stochastic_rounding_seed(device)
+
     @property
     def supports_fused_back_pass(self):
         return True
@@ -215,7 +222,7 @@ class AdamW_adv(torch.optim.Optimizer):
             # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
             # Get the dynamic beta2 calculated in prepare_step()
-            beta2 = self.kourkoutas_helper.get_beta2(p, group
+            beta2 = self.kourkoutas_helper.get_beta2(p, group)
 
         step = state['step'] + 1
         if group['use_bias_correction']:
adv_optm/optim/Adopt_adv.py
CHANGED
@@ -1,7 +1,7 @@
 import torch
 from typing import Callable, Optional
 
-from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.BF16_Stochastic_Rounding import add_stochastic_, set_seed as set_stochastic_rounding_seed
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf, _unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
@@ -172,6 +172,13 @@ class Adopt_adv(torch.optim.Optimizer):
         if self.kourkoutas_beta:
             self.kourkoutas_helper = KourkoutasHelper(self)
 
+        if self.stochastic_rounding:
+            # For deterministic stochastic rounding, we need to seed the generator
+            # for each device used by the parameters.
+            devices = {p.device for group in self.param_groups for p in group['params'] if p.dtype == torch.bfloat16}
+            for device in devices:
+                set_stochastic_rounding_seed(device)
+
     @property
     def supports_fused_back_pass(self): return True
     @property
@@ -243,7 +250,7 @@ class Adopt_adv(torch.optim.Optimizer):
             # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
             # Get the dynamic beta2 calculated in prepare_step()
-            beta2 = self.kourkoutas_helper.get_beta2(p, group
+            beta2 = self.kourkoutas_helper.get_beta2(p, group)
 
         # The first step is for initialization only (skip when use_atan2 as it's scale invariant).
         if state['step'] == 0 and not self.use_atan2:
adv_optm/optim/Lion_Prodigy_adv.py
CHANGED
@@ -5,7 +5,7 @@ import math
 
 from typing import Tuple, Optional
 
-from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.BF16_Stochastic_Rounding import add_stochastic_, set_seed as set_stochastic_rounding_seed
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
@@ -109,6 +109,13 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
         # Global state for accumulating metrics across parameter updates within a single step.
         self.init_step()
 
+        if self.stochastic_rounding:
+            # For deterministic stochastic rounding, we need to seed the generator
+            # for each device used by the parameters.
+            devices = {p.device for group in self.param_groups for p in group['params'] if p.dtype == torch.bfloat16}
+            for device in devices:
+                set_stochastic_rounding_seed(device)
+
     @property
     def supports_fused_back_pass(self) -> bool:
         return True
adv_optm/optim/Lion_adv.py
CHANGED
@@ -2,7 +2,7 @@ import torch
 
 from typing import Tuple, Optional
 
-from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.BF16_Stochastic_Rounding import add_stochastic_, set_seed as set_stochastic_rounding_seed
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
@@ -68,6 +68,13 @@ class Lion_adv(torch.optim.Optimizer):
         self.factored = nnmf_factor
         super().__init__(params, defaults)
 
+        if self.stochastic_rounding:
+            # For deterministic stochastic rounding, we need to seed the generator
+            # for each device used by the parameters.
+            devices = {p.device for group in self.param_groups for p in group['params'] if p.dtype == torch.bfloat16}
+            for device in devices:
+                set_stochastic_rounding_seed(device)
+
     @property
     def supports_fused_back_pass(self) -> bool:
         return True
adv_optm/optim/Muon_adv.py
CHANGED
@@ -1,6 +1,6 @@
 import torch
 
-from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.BF16_Stochastic_Rounding import add_stochastic_, set_seed as set_stochastic_rounding_seed
 from ..util.Newton_Schulz import _newton_schulz_iteration
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
@@ -186,6 +186,13 @@ class Muon_adv(torch.optim.Optimizer):
             torch._dynamo.config.cache_size_limit = 8192
             self.compile(fullgraph=True)
 
+        if self.stochastic_rounding:
+            # For deterministic stochastic rounding, we need to seed the generator
+            # for each device used by the parameters.
+            devices = {p.device for group in self.param_groups for p in group['params'] if p.dtype == torch.bfloat16}
+            for device in devices:
+                set_stochastic_rounding_seed(device)
+
     @property
     def supports_fused_back_pass(self):
         return True
@@ -363,14 +370,14 @@ class Muon_adv(torch.optim.Optimizer):
                     update.div_(v_t.sqrt().unsqueeze(1).add_(group['normuon_eps']))
                     del mean_squared_update
 
-
-
-
-
-
-
-
-
+                # RMS-aligned rescaling
+                if group['rms_rescaling']:
+                    rms_target = 0.2 # default (Adam) value for RMS
+                    update_norm = torch.linalg.vector_norm(update)
+                    update = update.view(p.shape).mul_(rms_target * lr * (p.numel()**0.5) / update_norm.add_(1e-8))
+                    del update_norm
+                else:
+                    update = update.view(p.shape).mul_(lr)
 
                 state['sign_buf'] = _pack_bools(mt_buf > 0)
                 _nnmf(mt_buf.abs(), out=(state['mu_mbuf_nmf'], state['mv_mbuf_nmf']))
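The new rms_rescaling branch scales the orthogonalized update so that its root-mean-square equals rms_target (0.2, the value typical of Adam updates) times the learning rate: since RMS(u) = ||u|| / sqrt(numel), multiplying by rms_target * lr * sqrt(numel) / ||u|| achieves exactly that. A small sketch of just this scaling step, using illustrative names rather than the optimizer's internals:

import torch

def rms_rescale(update: torch.Tensor, lr: float, rms_target: float = 0.2) -> torch.Tensor:
    """Scale `update` so that its RMS equals rms_target * lr.

    RMS(u) = ||u||_2 / sqrt(numel), so the factor below makes
    RMS(result) == rms_target * lr (up to the 1e-8 guard).
    """
    update_norm = torch.linalg.vector_norm(update)
    return update * (rms_target * lr * (update.numel() ** 0.5) / (update_norm + 1e-8))

u = torch.randn(128, 64)
scaled = rms_rescale(u, lr=1e-3)
# RMS of the rescaled update is ~ rms_target * lr
print(scaled.pow(2).mean().sqrt())  # ≈ 2e-4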
adv_optm/optim/Prodigy_adv.py
CHANGED
@@ -5,7 +5,7 @@ import math
 
 from typing import Optional, Callable
 
-from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.BF16_Stochastic_Rounding import add_stochastic_, set_seed as set_stochastic_rounding_seed
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
@@ -212,6 +212,13 @@ class Prodigy_adv(torch.optim.Optimizer):
             self.kourkoutas_helper = KourkoutasHelper(self)
         self.init_step()
 
+        if self.stochastic_rounding:
+            # For deterministic stochastic rounding, we need to seed the generator
+            # for each device used by the parameters.
+            devices = {p.device for group in self.param_groups for p in group['params'] if p.dtype == torch.bfloat16}
+            for device in devices:
+                set_stochastic_rounding_seed(device)
+
     @property
     def supports_fused_back_pass(self):
         return True
@@ -310,7 +317,7 @@ class Prodigy_adv(torch.optim.Optimizer):
             # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
             # Get the dynamic beta2 calculated in prepare_step()
-            beta2 = self.kourkoutas_helper.get_beta2(p, group
+            beta2 = self.kourkoutas_helper.get_beta2(p, group)
         else:
             beta2 = self.beta2_default
 
adv_optm/optim/Simplified_AdEMAMix.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Optional, Callable
 
 import math
 
-from ..util.BF16_Stochastic_Rounding import add_stochastic_
+from ..util.BF16_Stochastic_Rounding import add_stochastic_, set_seed as set_stochastic_rounding_seed
 from ..util.Effective_Shape import _get_effective_shape
 from ..util.NNMF import _nnmf,_unnmf
 from ..util.OrthoGrad import _orthogonalize_gradient
@@ -127,6 +127,13 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         if self.kourkoutas_beta:
             self.kourkoutas_helper = KourkoutasHelper(self)
 
+        if self.stochastic_rounding:
+            # For deterministic stochastic rounding, we need to seed the generator
+            # for each device used by the parameters.
+            devices = {p.device for group in self.param_groups for p in group['params'] if p.dtype == torch.bfloat16}
+            for device in devices:
+                set_stochastic_rounding_seed(device)
+
     @property
     def supports_fused_back_pass(self):
         return True
@@ -197,7 +204,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
             # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
             # Get the dynamic beta2 calculated in prepare_step()
-            beta2 = self.kourkoutas_helper.get_beta2(p, group
+            beta2 = self.kourkoutas_helper.get_beta2(p, group)
 
         beta1_warmup = group["beta1_warmup"]
         alpha_grad = group["alpha_grad"]
adv_optm/util/BF16_Stochastic_Rounding.py
CHANGED
@@ -1,10 +1,25 @@
 import torch
 from torch import Tensor
 
+from typing import Dict, Any
+
+_generators: Dict[torch.device, torch.Generator] = {}
+
+def set_seed(device: torch.device):
+    """
+    Initializes or resets the deterministic generator for a specific device.
+    This ensures that the sequence of random numbers used for stochastic
+    rounding is reproducible.
+    """
+    global _generators
+    if device not in _generators:
+        _generators[device] = torch.Generator(device=device)
+    _generators[device].manual_seed(42)
+
 def copy_stochastic_(target: Tensor, source: Tensor):
     """
     Nerogar's implementation of stochastic rounding in the paper "Revisiting BFloat16 Training"
-    (https://arxiv.org/abs/2010.06192).
+    (https://arxiv.org/abs/2010.06192). Made deterministic.
     see:
     https://github.com/pytorch/pytorch/issues/120376
     https://github.com/Nerogar/OneTrainer/blob/daae18eaed8c0fa39289b2ff79cc2c1e08577fcb/modules/util/bf16_stochastic_rounding.py
@@ -13,12 +28,21 @@ def copy_stochastic_(target: Tensor, source: Tensor):
     target: the target tensor with dtype=bfloat16
     source: the target tensor with dtype=float32
     """
+    global _generators
+    device = source.device
+    if device not in _generators:
+        set_seed(device)
+
+    generator = _generators[device]
+
     # create a random 16 bit integer
-    result = torch.
-        source,
+    result = torch.randint(
+        size=source.shape,
+        device=source.device,
         dtype=torch.int32,
         low=0,
         high=(1 << 16),
+        generator=generator,
     )
 
     # add the random number to the lower 16 bit of the mantissa
@@ -32,6 +56,7 @@ def copy_stochastic_(target: Tensor, source: Tensor):
 
     del result
 
+
 def add_stochastic_(input: Tensor, other: Tensor, alpha: float = 1.0):
     """
     adds other to input using stochastic rounding
@@ -44,4 +69,4 @@ def add_stochastic_(input: Tensor, other: Tensor, alpha: float = 1.0):
     result = other.clone() if other.dtype == torch.float32 else other.to(dtype=torch.float32)
 
     result.add_(input, alpha=alpha)
-    copy_stochastic_(input, result)
+    copy_stochastic_(input, result)
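Taken together, the module-level _generators dict and set_seed make the fp32-to-bf16 rounding decision reproducible: the random 16-bit value added to the low mantissa bits now comes from a per-device, fixed-seed torch.Generator instead of the global RNG. A self-contained sketch of the same idea, simplified and not the package's exact code:

import torch

_generators: dict = {}

def _generator_for(device: torch.device) -> torch.Generator:
    # One fixed-seed generator per device -> reproducible rounding decisions.
    if device not in _generators:
        g = torch.Generator(device=device)
        g.manual_seed(42)
        _generators[device] = g
    return _generators[device]

def stochastic_round_to_bf16(source: torch.Tensor) -> torch.Tensor:
    """Round an fp32 tensor to bf16 stochastically, using the seeded generator."""
    noise = torch.randint(
        low=0, high=(1 << 16), size=source.shape,
        dtype=torch.int32, device=source.device,
        generator=_generator_for(source.device),
    )
    # Add the noise to the low 16 mantissa bits, then clear them: values
    # closer to the next representable bf16 number round up more often.
    bits = source.view(torch.int32) + noise
    bits &= -65536  # keep only the top 16 bits (0xFFFF0000 as signed int32)
    return bits.view(torch.float32).to(torch.bfloat16)

x = torch.full((1000,), 1.0 + 2**-10)  # sits between two bf16 values
torch.manual_seed(0)                    # no effect: rounding uses its own generator
y = stochastic_round_to_bf16(x)
print(y.float().mean())                 # ≈ 1.0 + 2**-10 on average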
adv_optm/util/Kourkoutas.py
CHANGED
@@ -88,11 +88,17 @@ class KourkoutasHelper:
         # Use group-specific K-b settings, falling back to the optimizer's master defaults.
         # This makes the helper robust against param groups that enable kourkoutas_beta
         # but are missing the other required hyperparameters.
-
-
-
-
-
+        # In hybrid optimizers like Muon_adv, the Kourkoutas-related keys in the
+        # defaults and param_groups are prefixed with 'adam_' to avoid conflicts.
+        # We must detect this case and use the correct key names.
+        prefix = 'adam_' if group.get('adam_kourkoutas_beta', False) else ''
+
+        ema_alpha = group.get(f'{prefix}ema_alpha', master_defaults[f'{prefix}ema_alpha'])
+        betas_tuple = group.get(f'{prefix}betas', master_defaults[f'{prefix}betas'])
+        beta2_max = betas_tuple[1]
+        beta2_min = group.get(f'{prefix}beta2_min', master_defaults[f'{prefix}beta2_min'])
+        tiny_spike = group.get(f'{prefix}tiny_spike', master_defaults[f'{prefix}tiny_spike'])
+        k_warmup_steps = group.get(f'{prefix}k_warmup_steps', master_defaults[f'{prefix}k_warmup_steps'])
 
         r_ema_tensor = param_state['kourkoutas_r_ema']
         accumulator = self.layer_state[layer_key]['sum_sq_accumulator']
@@ -149,7 +155,7 @@ class KourkoutasHelper:
         # Accumulate for the *next* step's prepare_step call
         self.layer_state[layer_key]['sum_sq_accumulator'] += torch.sum(grad.detach().pow(2)).float()
 
-    def get_beta2(self, p: torch.Tensor, group: dict
+    def get_beta2(self, p: torch.Tensor, group: dict) -> float:
        """
        Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
        """
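The rewritten block resolves the 'adam_'-prefixed key names that hybrid optimizers such as Muon_adv use for their Kourkoutas hyperparameters, falling back to the optimizer's master defaults whenever a param group omits a key. A minimal sketch of that lookup; the group and master_defaults dictionaries below are illustrative stand-ins, not the helper's real state:

# Hypothetical defaults/group dicts, mirroring the key layout the helper expects.
master_defaults = {
    "adam_kourkoutas_beta": True,
    "adam_ema_alpha": 0.98,
    "adam_betas": (0.9, 0.999),
    "adam_beta2_min": 0.9,
    "adam_tiny_spike": 1e-9,
    "adam_k_warmup_steps": 0,
}
group = {"adam_kourkoutas_beta": True, "adam_beta2_min": 0.95}  # partial override

# Same resolution logic as the hunk above: pick the 'adam_' prefix when the
# group enables the hybrid flag, then fall back to master defaults per key.
prefix = "adam_" if group.get("adam_kourkoutas_beta", False) else ""

ema_alpha = group.get(f"{prefix}ema_alpha", master_defaults[f"{prefix}ema_alpha"])
beta2_max = group.get(f"{prefix}betas", master_defaults[f"{prefix}betas"])[1]
beta2_min = group.get(f"{prefix}beta2_min", master_defaults[f"{prefix}beta2_min"])

print(prefix, ema_alpha, beta2_max, beta2_min)  # adam_ 0.98 0.999 0.95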
adv_optm-1.2.4.dist-info/RECORD
ADDED
@@ -0,0 +1,23 @@
+adv_optm/__init__.py,sha256=fxQlsNgh5Su63jHu4hPZt_1NCsoFuEsZmAa7cvUn3I0,376
+adv_optm/optim/AdaMuon_adv.py,sha256=miib3NlnBZyT0K4wLliR7I9Vw4xsVdC3ewWfgP88mxE,34686
+adv_optm/optim/AdamW_adv.py,sha256=ZvZkqOIqT_013sCqRoL4drEKCCXbsQY-JRrRngoN9f8,18068
+adv_optm/optim/Adopt_adv.py,sha256=8es_ot1EgJa3SZHfKQ_PU4fYM4TJMAGOmOGq_876IOs,21870
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=nns9Oz_0EKmGRN8p6kWMlRWKI-tHx8v8eg8TH-hXjJY,15047
+adv_optm/optim/Lion_adv.py,sha256=ug4uuQk3PmdkggsuzqNpZ6vieLUAbTHGr1Q_pvuLLVs,8729
+adv_optm/optim/Muon_adv.py,sha256=HEbyFYak4aDRfxJwwKD7PlvHXTE2TOpadWO0vRVnNf8,34119
+adv_optm/optim/Prodigy_adv.py,sha256=Wiukv1Hn6KFSslI6Dk4QXFFwNNtRjQsJ4GNEYkC4dFc,26662
+adv_optm/optim/Simplified_AdEMAMix.py,sha256=1j3M9t0Dza2dLVabwC0ft36sANx-QHBeLRp2WJlU_3s,13387
+adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
+adv_optm/util/BF16_Stochastic_Rounding.py,sha256=VXfv7U4-Yhyq1o6gZMApvW1DUUwZ15-eob98daQW9uc,2288
+adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
+adv_optm/util/Kourkoutas.py,sha256=eSE2KUnvbxP2Kg4sUCFmqGLvX5eov4OUjULKBKHBLoc,8131
+adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
+adv_optm/util/Newton_Schulz.py,sha256=bBboYw_jm5_FMf0Citl79uqNedkHOTjQnUI7rZgLBmY,3341
+adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
+adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
+adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
+adv_optm-1.2.4.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.2.4.dist-info/METADATA,sha256=FUjhBc16Ab58N06TRXq7087T9EwZnZqtLbw5niYzIS4,11917
+adv_optm-1.2.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.2.4.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.2.4.dist-info/RECORD,,
adv_optm-1.2.0.dist-info/RECORD
DELETED
@@ -1,23 +0,0 @@
-adv_optm/__init__.py,sha256=lQwVoYMSCofAxJ_CadX2NssB1jldn9JetoyAVMJPDrs,376
-adv_optm/optim/AdaMuon_adv.py,sha256=sqMd1cdBDMpwVmEoU1w3vE_Fj3nfx-_bZjf8mf5st4Y,34189
-adv_optm/optim/AdamW_adv.py,sha256=pDKwdOV90qxTkRuIez0kU_VdI0ztJygY-MxhhQT10Yw,17652
-adv_optm/optim/Adopt_adv.py,sha256=eSLJS0RVJ0MAE5pMFK-Q00vJF6NuxKJbefAg8F58XD4,21454
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
-adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
-adv_optm/optim/Muon_adv.py,sha256=_odkBjwIuY895fh6wAs_9ljXyyPCg9V-tAQnjMVf4Po,33721
-adv_optm/optim/Prodigy_adv.py,sha256=H0xuVhaCDJF6ilts_It20teZZCN4MSbOSPQ-fsy6pEg,26246
-adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
-adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
-adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
-adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
-adv_optm/util/Kourkoutas.py,sha256=C_Qn6I0Qao_9D_nCv4ZYmC_SgJLoPwhrMb5FkRQ-k1M,7693
-adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
-adv_optm/util/Newton_Schulz.py,sha256=bBboYw_jm5_FMf0Citl79uqNedkHOTjQnUI7rZgLBmY,3341
-adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
-adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
-adv_optm/util/__init__.py,sha256=CXzS703GB4gil85khZi7sgKOnbzXGBOltshIOSPqj18,435
-adv_optm-1.2.0.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-1.2.0.dist-info/METADATA,sha256=qsz3TfHskcMBhNRYT_YD58_pr2kfFdsm8LQ5WOXoBlE,11917
-adv_optm-1.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-1.2.0.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-1.2.0.dist-info/RECORD,,
{adv_optm-1.2.0.dist-info → adv_optm-1.2.4.dist-info}/WHEEL
File without changes
{adv_optm-1.2.0.dist-info → adv_optm-1.2.4.dist-info}/licenses/LICENSE
File without changes
{adv_optm-1.2.0.dist-info → adv_optm-1.2.4.dist-info}/top_level.txt
File without changes