adv-optm 1.2.dev3-py3-none-any.whl → 1.2.dev4-py3-none-any.whl
This diff shows the changes between publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
Potentially problematic release: this version of adv-optm might be problematic.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/AdaMuon_adv.py +43 -21
- adv_optm/optim/Muon_adv.py +35 -17
- {adv_optm-1.2.dev3.dist-info → adv_optm-1.2.dev4.dist-info}/METADATA +1 -1
- {adv_optm-1.2.dev3.dist-info → adv_optm-1.2.dev4.dist-info}/RECORD +8 -8
- {adv_optm-1.2.dev3.dist-info → adv_optm-1.2.dev4.dist-info}/WHEEL +0 -0
- {adv_optm-1.2.dev3.dist-info → adv_optm-1.2.dev4.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.2.dev3.dist-info → adv_optm-1.2.dev4.dist-info}/top_level.txt +0 -0
adv_optm/__init__.py
CHANGED
adv_optm/optim/AdaMuon_adv.py
CHANGED
@@ -135,7 +135,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
             print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling cautious.")
             nesterov = False
 
-        defaults = {
+        muon_defaults = {
             "lr": lr, "betas": betas, "weight_decay": weight_decay,
             "eps": eps, "rms_target": rms_target, "ns_steps": ns_steps,
             "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
@@ -154,22 +154,41 @@ class AdaMuon_adv(torch.optim.Optimizer):
         self.helper = None
         self.aux_adam = None
 
-        if self.MuonWithAuxAdam:
-
-
-            self.aux_adam = AdamW_adv(
-                [],
-                lr=muon_adam_lr,
-                **adam_kwargs,
-                _is_delegate=True
-            )
-            # Update the defaults dictionary
-            defaults.update(self.aux_adam.defaults)
-
-            super().__init__(params, defaults)
+        if not self.MuonWithAuxAdam:
+            super().__init__(params, muon_defaults)
+            return
 
-
-
+        # HYBRID OPTIMIZER LOGIC
+        adam_kwargs = adam_kwargs or {}
+        self.aux_adam = AdamW_adv(
+            [],
+            lr=muon_adam_lr,
+            **adam_kwargs,
+            _is_delegate=True
+        )
+        adam_defaults = self.aux_adam.defaults
+
+        final_param_groups = []
+        _layer_key_fn = layer_key_fn if layer_key_fn is not None else lambda p: 'muon'
+
+        for group in params:
+            # All params in a group are of the same type
+            first_param = group['params'][0]
+            key = _layer_key_fn(first_param)
+            optim_type = 'adam' if key == 'adam' else 'muon'
+
+            new_group = group.copy()
+            defaults_to_use = adam_defaults if optim_type == 'adam' else muon_defaults
+
+            for key, value in defaults_to_use.items():
+                new_group.setdefault(key, value)
+
+            final_param_groups.append(new_group)
+
+        super().__init__(final_param_groups, {})
+
+        # Now that self is initialized, create the helper
+        self.helper = MuonAdamHelper(self, layer_key_fn)
 
 
     @property
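The new constructor routes each parameter group to Muon or to the auxiliary AdamW by calling layer_key_fn on the group's first parameter ('adam' selects the AdamW path, anything else stays on Muon). A minimal usage sketch, assuming the names visible in this diff (MuonWithAuxAdam, layer_key_fn, muon_adam_lr) are exposed as constructor keyword arguments and that the class is exported from the package root; the exact signature is not shown in this diff:

import torch
from adv_optm import AdaMuon_adv  # import path assumed from this package layout

model = torch.nn.Linear(64, 64)

# Route 1-D tensors (biases, norm scales) to the aux AdamW, matrices to Muon.
def layer_key_fn(p: torch.Tensor) -> str:
    return 'adam' if p.ndim < 2 else 'muon'

# Groups must be homogeneous: the loop above keys each group
# off its first parameter only.
param_groups = [
    {'params': [p for p in model.parameters() if p.ndim >= 2]},
    {'params': [p for p in model.parameters() if p.ndim < 2]},
]

opt = AdaMuon_adv(
    param_groups,
    lr=1e-3,
    MuonWithAuxAdam=True,      # assumed flag name, taken from the attribute above
    layer_key_fn=layer_key_fn,
    muon_adam_lr=3e-4,         # learning rate for the AdamW-routed groups
)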
@@ -196,21 +215,24 @@ class AdaMuon_adv(torch.optim.Optimizer):
 
     @torch.no_grad()
     def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+        if group['_kourkoutas_beta'] and self._kourkoutas_helper is None:
+            self._kourkoutas_helper = KourkoutasHelper(self)
+
         if self.MuonWithAuxAdam:
             optim_type = self.helper.get_optimizer_type(p)
             if optim_type == 'adam':
                 # Delegate to the AdamW_adv optimizer's logic.
                 # We need to temporarily "lend" our state and param_groups
-                # to the delegate so it has the full context to work with,
-                # especially for features like Kourkoutas-beta.
                 self.aux_adam.state = self.state
                 self.aux_adam.param_groups = self.param_groups
+
+                # Ensure the aux optimizer uses the same Kourkoutas helper instance.
+                if self._kourkoutas_helper is not None:
+                    self.aux_adam.kourkoutas_helper = self._kourkoutas_helper
+
                 self.aux_adam.step_parameter(p, group, i)
                 return
 
-        if group['_kourkoutas_beta'] and self._kourkoutas_helper is None:
-            self._kourkoutas_helper = KourkoutasHelper(self)
-
         if p.grad is None:
             return
 
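The Kourkoutas-beta check moves above the delegation branch, so Adam-routed parameters can also trigger lazy helper creation, and the one helper instance is handed to the delegate. A distilled, self-contained illustration of that wiring (stub classes standing in for the package's real KourkoutasHelper and AdamW_adv APIs):

class KourkoutasHelperStub:
    """Stand-in for KourkoutasHelper; only object identity matters here."""
    def __init__(self, owner):
        self.owner = owner

class HybridStub:
    def __init__(self):
        self._kourkoutas_helper = None
        self.aux_adam = type('AuxStub', (), {})()  # stand-in delegate

    def step_parameter(self, group):
        # Initialize once, before any delegation, so both code paths see it.
        if group['_kourkoutas_beta'] and self._kourkoutas_helper is None:
            self._kourkoutas_helper = KourkoutasHelperStub(self)
        # Hand the exact same instance to the delegate.
        if self._kourkoutas_helper is not None:
            self.aux_adam.kourkoutas_helper = self._kourkoutas_helper

opt = HybridStub()
opt.step_parameter({'_kourkoutas_beta': True})
assert opt.aux_adam.kourkoutas_helper is opt._kourkoutas_helper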
adv_optm/optim/Muon_adv.py
CHANGED
@@ -100,7 +100,7 @@ class Muon_adv(torch.optim.Optimizer):
             print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling cautious.")
             nesterov = False
 
-        defaults = {
+        muon_defaults = {
             "lr": lr, "beta1": beta1, "weight_decay": weight_decay,
             "nesterov": nesterov, "ns_steps": ns_steps, "ns_eps": ns_eps,
             "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
@@ -114,23 +114,41 @@ class Muon_adv(torch.optim.Optimizer):
         self.helper = None
         self.aux_adam = None
 
-        if self.MuonWithAuxAdam:
-
-
-            self.aux_adam = AdamW_adv(
-                [],
-                lr=muon_adam_lr,
-                **adam_kwargs,
-                _is_delegate=True
-            )
-            # Update the defaults dictionary
-            defaults.update(self.aux_adam.defaults)
-
-            super().__init__(params, defaults)
+        if not self.MuonWithAuxAdam:
+            super().__init__(params, muon_defaults)
+            return
 
-
-
+        # HYBRID OPTIMIZER LOGIC
+        adam_kwargs = adam_kwargs or {}
+        self.aux_adam = AdamW_adv(
+            [],
+            lr=muon_adam_lr,
+            **adam_kwargs,
+            _is_delegate=True
+        )
+        adam_defaults = self.aux_adam.defaults
+
+        final_param_groups = []
+        _layer_key_fn = layer_key_fn if layer_key_fn is not None else lambda p: 'muon'
+
+        for group in params:
+            first_param = group['params'][0]
+            key = _layer_key_fn(first_param)
+            optim_type = 'adam' if key == 'adam' else 'muon'
+
+            new_group = group.copy()
+            defaults_to_use = adam_defaults if optim_type == 'adam' else muon_defaults
+
+            for key, value in defaults_to_use.items():
+                new_group.setdefault(key, value)
+
+            final_param_groups.append(new_group)
+
+        super().__init__(final_param_groups, {})
 
+        # Now that self is initialized, create the helper
+        self.helper = MuonAdamHelper(self, layer_key_fn)
+
 
     @property
     def supports_fused_back_pass(self):
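Note on the merge semantics: the setdefault loop above means explicit per-group settings always win over the selected defaults dict, and only missing keys are filled in. A quick standalone demonstration with illustrative values:

muon_defaults = {'lr': 0.02, 'ns_steps': 5, 'weight_decay': 0.0}
group = {'params': ['w'], 'lr': 0.005}  # user-specified lr survives the merge

merged = group.copy()
for key, value in muon_defaults.items():
    merged.setdefault(key, value)

print(merged)
# {'params': ['w'], 'lr': 0.005, 'ns_steps': 5, 'weight_decay': 0.0}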
@@ -335,4 +353,4 @@ class Muon_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
 
-        return loss
+        return loss
{adv_optm-1.2.dev3.dist-info → adv_optm-1.2.dev4.dist-info}/RECORD
CHANGED

@@ -1,10 +1,10 @@
-adv_optm/__init__.py,sha256=
-adv_optm/optim/AdaMuon_adv.py,sha256=
+adv_optm/__init__.py,sha256=bB7_VywKpvZbcGCjtZoF8giQgcUgoziISBgIaEUpcAw,379
+adv_optm/optim/AdaMuon_adv.py,sha256=s5UkR2YJ_Z10SiBokT97eq4tCHc2D8BEOFDx5AOMryQ,20983
 adv_optm/optim/AdamW_adv.py,sha256=7IvdD1rqYeHZwQCZU9X0H7x87MCKcHQ5M68GLuMCkvE,17702
 adv_optm/optim/Adopt_adv.py,sha256=C2FsEZGvCk9q4YNKAj0qIxdZ5AfPlda-1lIpSX0a1nE,21256
 adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
 adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
-adv_optm/optim/Muon_adv.py,sha256=
+adv_optm/optim/Muon_adv.py,sha256=vB-Eeh0IqYMd3lkQvIPEbH256bTyYO73OgIzn0N2VCk,14985
 adv_optm/optim/Prodigy_adv.py,sha256=bmwuO8GrJHH4NaEaqE-ffcR9wHhQ57457xoN-P6hyks,25909
 adv_optm/optim/Simplified_AdEMAMix.py,sha256=sY-vThMVgADRh0ar9WHkrM2n8UcgQLQC1YV1Wx8uFz4,12983
 adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
@@ -17,8 +17,8 @@ adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnC
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=jAaUfaAjFrTJ6-Q915ezAbq0efRbpYjriW2OdeCbSzo,433
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
-adv_optm-1.2.
+adv_optm-1.2.dev4.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.2.dev4.dist-info/METADATA,sha256=jNczVxIPq0LuusXuGrZ23CQ4CrMNOfJdBDpDQgulMUw,14022
+adv_optm-1.2.dev4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.2.dev4.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.2.dev4.dist-info/RECORD,,
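For reference, each RECORD line above has the form path,sha256=<digest>,<size>, where the digest is the urlsafe-base64 SHA-256 of the file bytes with '=' padding stripped (per PEP 376/427). A small sketch of how such an entry is computed:

import base64
import hashlib

def record_entry(path: str, data: bytes) -> str:
    # urlsafe base64 of the raw sha256 digest, padding stripped
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest())
    return f"{path},sha256={digest.rstrip(b'=').decode()},{len(data)}"

print(record_entry('adv_optm/__init__.py', b'# example file contents\n'))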
{adv_optm-1.2.dev3.dist-info → adv_optm-1.2.dev4.dist-info}/WHEEL
File without changes

{adv_optm-1.2.dev3.dist-info → adv_optm-1.2.dev4.dist-info}/licenses/LICENSE
File without changes

{adv_optm-1.2.dev3.dist-info → adv_optm-1.2.dev4.dist-info}/top_level.txt
File without changes