adv-optm 1.2.dev5__tar.gz → 1.2.dev6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic.
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/PKG-INFO +1 -1
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/__init__.py +1 -1
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/util/Kourkoutas.py +21 -5
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm.egg-info/PKG-INFO +1 -1
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/setup.py +1 -1
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/LICENSE +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/README.md +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/optim/AdaMuon_adv.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/optim/AdamW_adv.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/optim/Adopt_adv.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/optim/Lion_Prodigy_adv.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/optim/Lion_adv.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/optim/Muon_adv.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/optim/Prodigy_adv.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/util/MuonAdam_helper.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/util/Newton_Schulz.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm/util/__init__.py +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.2.dev5 → adv_optm-1.2.dev6}/setup.cfg +0 -0
adv_optm/util/Kourkoutas.py

@@ -86,9 +86,17 @@ class KourkoutasHelper:
         # These are just for the sample log, initialize them
         sun, pooled_grad_norm, prev_r_ema_val, r_ema_tensor = (torch.tensor(0.0),)*4

+        # The optimizer that owns this helper holds the master defaults for K-b.
+        # This is crucial in hybrid optimizers where some param_groups might not
+        # have all K-b keys populated, preventing KeyErrors.
+        master_defaults = self.optimizer.defaults
+
         for layer_key, info in self.layer_info.items():
             params, group = info['params'], info['group_ref']

+            if not group.get('kourkoutas_beta', False):
+                continue
+
             first_param_in_layer = info['params'][0]
             param_state = self.optimizer.state[first_param_in_layer]

@@ -100,6 +108,15 @@ class KourkoutasHelper:
             if 'kourkoutas_r_ema' not in param_state:
                 param_state['kourkoutas_r_ema'] = torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)

+            # Use group-specific K-b settings, falling back to the optimizer's master defaults.
+            # This makes the helper robust against param groups that enable kourkoutas_beta
+            # but are missing the other required hyperparameters.
+            ema_alpha = group.get('ema_alpha', master_defaults['ema_alpha'])
+            beta2_max = group.get('betas', master_defaults['betas'])[1]
+            beta2_min = group.get('beta2_min', master_defaults['beta2_min'])
+            tiny_spike = group.get('tiny_spike', master_defaults['tiny_spike'])
+            k_warmup_steps = group.get('k_warmup_steps', master_defaults['k_warmup_steps'])
+
             r_ema_tensor = param_state['kourkoutas_r_ema']
             accumulator = self.layer_state[layer_key]['sum_sq_accumulator']

@@ -107,17 +124,16 @@ class KourkoutasHelper:
             prev_r_ema_val = r_ema_tensor.item()  # for logging

             # Update the persistent EMA tensor in-place.
-            r_ema_tensor.mul_(
+            r_ema_tensor.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)

-            beta2_max = group['betas'][1]
             sun = torch.tensor(0.0, device=r_ema_tensor.device)  # Default sun to 0 for warmup

-            if current_step <
+            if current_step < k_warmup_steps:
                 beta2 = beta2_max
             else:
-                raw = pooled_grad_norm / (r_ema_tensor +
+                raw = pooled_grad_norm / (r_ema_tensor + tiny_spike)
                 sun = raw / (1.0 + raw)
-                beta2 = beta2_max - (beta2_max -
+                beta2 = beta2_max - (beta2_max - beta2_min) * sun

             # Store the final calculated beta2 in the helper's transient state for this step.
             self.layer_state[layer_key]['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
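The net effect of this change is that each param group resolves its Kourkoutas-beta (K-b) hyperparameters with a fallback to the optimizer's master defaults, and beta2 is lowered from beta2_max toward beta2_min whenever the pooled gradient norm spikes relative to its EMA. The following is a minimal standalone sketch of that computation, assuming the hyperparameter keys shown in the diff; the function name dynamic_beta2 is hypothetical and not part of the adv_optm API.

import torch

def dynamic_beta2(pooled_grad_norm, r_ema, group, master_defaults, current_step):
    # Hypothetical helper illustrating the updated logic; only the hyperparameter
    # keys (ema_alpha, betas, beta2_min, tiny_spike, k_warmup_steps) come from the diff.
    # Per-group settings with fallback to the optimizer's defaults,
    # mirroring the group.get(..., master_defaults[...]) calls above.
    ema_alpha = group.get('ema_alpha', master_defaults['ema_alpha'])
    beta2_max = group.get('betas', master_defaults['betas'])[1]
    beta2_min = group.get('beta2_min', master_defaults['beta2_min'])
    tiny_spike = group.get('tiny_spike', master_defaults['tiny_spike'])
    k_warmup_steps = group.get('k_warmup_steps', master_defaults['k_warmup_steps'])

    # EMA of the pooled gradient norm, updated in place (as in the diff).
    r_ema.mul_(ema_alpha).add_(pooled_grad_norm, alpha=1.0 - ema_alpha)

    if current_step < k_warmup_steps:
        return beta2_max  # hold beta2 at its maximum during warmup

    # Spike ratio in [0, 1): a gradient norm well above its EMA pushes
    # beta2 from beta2_max toward beta2_min.
    raw = pooled_grad_norm / (r_ema + tiny_spike)
    sun = raw / (1.0 + raw)
    beta2 = beta2_max - (beta2_max - beta2_min) * sun
    return beta2.item() if isinstance(beta2, torch.Tensor) else beta2

As the diff's comments note, a param group that enables kourkoutas_beta but omits the other K-b keys previously risked a KeyError; with the fallback it now inherits ema_alpha, beta2_min, tiny_spike and k_warmup_steps from the optimizer's defaults.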