adv-optm 0.1.0__tar.gz → 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-0.1.0 → adv_optm-0.1.1}/PKG-INFO +1 -1
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm/__init__.py +1 -1
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm/optim/AdamW_adv.py +6 -6
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm/optim/Adopt_adv.py +6 -6
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm/optim/Prodigy_adv.py +35 -3
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-0.1.0 → adv_optm-0.1.1}/setup.py +1 -1
- {adv_optm-0.1.0 → adv_optm-0.1.1}/LICENSE +0 -0
- {adv_optm-0.1.0 → adv_optm-0.1.1}/README.md +0 -0
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm/util/Randomized_SVD.py +0 -0
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm/util/__init__.py +0 -0
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-0.1.0 → adv_optm-0.1.1}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-0.1.0 → adv_optm-0.1.1}/setup.cfg +0 -0
adv_optm/optim/AdamW_adv.py

@@ -37,7 +37,7 @@ class AdamW_adv(torch.optim.Optimizer):
 combined with the primary momentum (`mt`) to stabilize updates,
 especially in noisy, small-batch settings. If `False`, the
 optimizer behaves as standard AdamW. (default: False)
-
+beta3_ema (float): The decay rate for the slow exponential moving average of
 the momentum (only used when `use_AdEMAMix` is `True`). A higher
 value (e.g., 0.9999) gives the EMA a longer memory, making it more
 stable but slower to adapt. A lower value (e.g., 0.999) is often
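For intuition on the trade-off the new docstring describes: an exponential moving average with decay rate β averages over roughly 1/(1−β) recent steps. A quick, package-independent illustration:

```python
# Effective averaging horizon of an EMA with decay beta is roughly 1 / (1 - beta).
for beta in (0.999, 0.9999):
    print(f"beta3_ema={beta}: ~{1 / (1 - beta):,.0f} recent steps of memory")
# beta3_ema=0.999:  ~1,000 recent steps of memory
# beta3_ema=0.9999: ~10,000 recent steps of memory
```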
@@ -71,7 +71,7 @@ class AdamW_adv(torch.optim.Optimizer):
 use_grams: bool = False,
 use_orthograd: bool = False,
 use_AdEMAMix: bool = False,
-
+beta3_ema: float = 0.9999,
 alpha: float = 5.0,
 t_alpha: int | None = None,
 factored: bool = True,
@@ -89,7 +89,7 @@ class AdamW_adv(torch.optim.Optimizer):
 "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
 "vector_reshape": vector_reshape, "use_atan2": use_atan2,
 "use_orthograd": use_orthograd, "use_bias_correction": use_bias_correction,
-"
+"beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
 }
 self.stochastic_rounding = stochastic_rounding
 self.use_cautious = use_cautious
@@ -162,7 +162,7 @@ class AdamW_adv(torch.optim.Optimizer):

 beta1, beta2 = group['betas']
 if self.use_AdEMAMix:
-
+beta3_ema = group['beta3_ema']
 alpha = group['alpha']
 t_alpha = group['t_alpha']
 current_step = state['step'] + 1
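`alpha` and `t_alpha` feed the mixing coefficient `alpha_t` used further down. The exact formula is not visible in this diff; the AdEMAMix paper warms alpha in linearly over `t_alpha` steps, so a plausible sketch of such a scheduler looks like the following (a hypothetical helper, not the package's code):

```python
def alpha_scheduler(step: int, alpha: float, t_alpha: int | None) -> float:
    """Linear warmup of the slow-EMA mixing coefficient, as in the AdEMAMix paper.

    Illustrative sketch only; adv_optm's actual schedule is not shown in this diff.
    """
    if t_alpha is None:  # scheduler disabled: use alpha directly
        return alpha
    return min(step / t_alpha, 1.0) * alpha
```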
@@ -201,7 +201,7 @@ class AdamW_adv(torch.optim.Optimizer):
 torch.where(unpacked_sign_slow, mt_slow, -mt_slow, out=mt_slow)
 del unpacked_sign_slow

-mt_slow.mul_(
+mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
 update_m = mt + (alpha_t * mt_slow)
 else:
 update_m = mt
@@ -245,7 +245,7 @@ class AdamW_adv(torch.optim.Optimizer):

 if self.use_AdEMAMix:
 exp_avg_slow = state['exp_avg_slow']
-exp_avg_slow.mul_(
+exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
 update_m = exp_avg + (alpha_t * exp_avg_slow)
 else:
 update_m = exp_avg
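The two changed update lines above follow the same pattern: the slow buffer decays with `beta3_ema`, and the result is blended into the fast momentum with weight `alpha_t`. A stripped-down, standalone sketch of that pattern (buffer names are illustrative, not the optimizer's internal state):

```python
import torch

beta1, beta3_ema, alpha_t = 0.9, 0.9999, 5.0

param = torch.zeros(4)
exp_avg = torch.zeros_like(param)       # fast EMA, decays with beta1
exp_avg_slow = torch.zeros_like(param)  # slow EMA, decays with beta3_ema

grad = torch.randn(4)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)  # the line added in 0.1.1
update_m = exp_avg + alpha_t * exp_avg_slow                   # mixed first moment
```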
adv_optm/optim/Adopt_adv.py

@@ -48,7 +48,7 @@ class Adopt_adv(torch.optim.Optimizer):
 combined with the primary momentum (`mt`) to stabilize updates,
 especially in noisy, small-batch settings. If `False`, the
 optimizer behaves as standard ADOPT. (default: False)
-
+beta3_ema (float): The decay rate for the slow exponential moving average of
 the momentum (only used when `use_AdEMAMix` is `True`). A higher
 value (e.g., 0.9999) gives the EMA a longer memory, making it more
 stable but slower to adapt. A lower value (e.g., 0.999) is often
@@ -83,7 +83,7 @@ class Adopt_adv(torch.optim.Optimizer):
 use_grams: bool = False,
 use_orthograd: bool = False,
 use_AdEMAMix: bool = False,
-
+beta3_ema: float = 0.9999,
 alpha: float = 5.0,
 t_alpha: int | None = None,
 factored: bool = True,
@@ -99,7 +99,7 @@ class Adopt_adv(torch.optim.Optimizer):

 defaults = {
 "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
-"vector_reshape": vector_reshape, "
+"vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
 "t_alpha": t_alpha,
 }
 self.clip_lambda = clip_lambda
@@ -179,7 +179,7 @@ class Adopt_adv(torch.optim.Optimizer):

 beta1, beta2 = group['betas']
 if self.use_AdEMAMix:
-
+beta3_ema = group['beta3_ema']
 alpha = group['alpha']
 t_alpha = group['t_alpha']
 # Use step+1 for 1-based step count in scheduler
@@ -236,7 +236,7 @@ class Adopt_adv(torch.optim.Optimizer):
 del mask

 if self.use_AdEMAMix:
-mt_slow = mt_slow_prev.mul_(
+mt_slow = mt_slow_prev.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
 update = mt + (alpha_t * mt_slow)
 update = update.view(p.shape)
 else:
@@ -293,7 +293,7 @@ class Adopt_adv(torch.optim.Optimizer):
 del mask

 if self.use_AdEMAMix:
-m_slow.mul_(
+m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
 update = m + (alpha_t * m_slow)
 else:
 update = m
adv_optm/optim/Prodigy_adv.py

@@ -1,5 +1,6 @@
 import torch
-
+import torch.distributed as dist
+
 import math

 from ..util.BF16_Stochastic_Rounding import add_stochastic_
@@ -54,6 +55,23 @@ class Prodigy_adv(torch.optim.Optimizer):
 the scheduler is disabled and th
 factored (bool): whether to use the factorization or disable it to use
 the uncompressed optimizer. (default: True)
+d0 (float):
+Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
+d_coef (float):
+Coefficient in the expression for the estimate of d (default 1.0).
+Values such as 0.5 and 2.0 typically work as well.
+Changing this parameter is the preferred way to tune the method.
+growth_rate (float):
+prevent the D estimate from growing faster than this multiplicative rate.
+Default is inf, for unrestricted. Values like 1.02 give a kind of learning
+rate warmup effect.
+fsdp_in_use (bool):
+If you're using sharded parameters, this should be set to True. The optimizer
+will attempt to auto-detect this, but if you're using an implementation other
+than PyTorch's builtin version, the auto-detection won't work.
+slice_p (int): Reduce memory usage by calculating LR adaptation statistics on only every
+pth entry of each tensor. For values greater than 1 this is an approximation to standard
+Prodigy. Values ~11 are reasonable (default 1).
 """

 def __init__(
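The new docstring entries mirror upstream Prodigy's tuning knobs. A hedged usage sketch, assuming the class is importable from the package root (the changed `adv_optm/__init__.py` suggests it is) and that arguments not shown keep their defaults:

```python
import torch
from adv_optm import Prodigy_adv  # assumed export path; adjust if the package nests it differently

model = torch.nn.Linear(128, 10)
optimizer = Prodigy_adv(
    model.parameters(),
    lr=1.0,                    # Prodigy-style D-adaptation is typically run with lr=1.0
    d_coef=1.0,                # the docstring's preferred tuning knob (0.5 and 2.0 also work)
    growth_rate=float('inf'),  # e.g. 1.02 for a warmup-like ramp of the D estimate
    fsdp_in_use=False,         # set True for sharded setups the auto-detection misses
    slice_p=11,                # estimate d from every 11th tensor entry to save memory
)
```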
@@ -80,6 +98,7 @@ class Prodigy_adv(torch.optim.Optimizer):
 d_coef: float = 1,
 growth_rate: float = float('inf'),
 safeguard_warmup: bool = False,
+fsdp_in_use: bool = False,
 slice_p: int = 11,
 ):
 if not (lr >= 0.0):
@@ -98,12 +117,14 @@ class Prodigy_adv(torch.optim.Optimizer):
 "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
 "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
 "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
+"fsdp_in_use": fsdp_in_use,
 }
 self.stochastic_rounding = stochastic_rounding
 self.use_cautious = use_cautious
 self.use_grams = use_grams
 self.use_AdEMAMix = use_AdEMAMix
 self.factored = factored
+self.fsdp_in_use = fsdp_in_use
 super().__init__(params, defaults)
 self.init_step()
@@ -142,6 +163,9 @@ class Prodigy_adv(torch.optim.Optimizer):
 if p.grad is None:
 return

+if hasattr(p, "_fsdp_flattened"):
+self.fsdp_in_use = True
+
 grad = p.grad
 if grad.dtype != torch.float32 and self.factored:
 grad = grad.float()
@@ -349,8 +373,16 @@ class Prodigy_adv(torch.optim.Optimizer):
 g_group = self.param_groups[0]
 d_max, d_coef, growth_rate = g_group['d_max'], g_group['d_coef'], g_group['growth_rate']

-
-
+if self.fsdp_in_use and dist.is_available() and dist.is_initialized():
+# Use the device of the first parameter to avoid hardcoding '.cuda()'
+device = self.param_groups[0]['params'][0].device
+dist_tensor = torch.tensor([self.d_numerator, self.d_denom], device=device)
+dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+global_d_numerator = dist_tensor[0].item()
+global_d_denom = dist_tensor[1].item()
+else:
+global_d_numerator = self.d_numerator
+global_d_denom = self.d_denom

 d_hat = self.d
 if global_d_denom > 0:
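With sharded parameters each rank sees only its local slice of the gradients, so its `d_numerator` and `d_denom` are partial sums; the new branch all-reduces them before the D update so every rank computes the same step size. The same pattern in isolation, as a minimal sketch (assumes a process group is already initialized, e.g. via `torchrun`):

```python
import torch
import torch.distributed as dist

def reduce_d_stats(d_numerator: float, d_denom: float, device: torch.device):
    """Sum per-rank Prodigy statistics across the process group (sketch of the diff's pattern)."""
    if dist.is_available() and dist.is_initialized():
        stats = torch.tensor([d_numerator, d_denom], device=device)
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)  # every rank receives the global sums
        return stats[0].item(), stats[1].item()
    return d_numerator, d_denom  # single-process fallback
```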