adv-optm 1.2.dev2__py3-none-any.whl → 1.2.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

adv_optm/__init__.py CHANGED
@@ -6,6 +6,7 @@ from .optim import (
6
6
  Lion_adv,
7
7
  Lion_Prodigy_adv,
8
8
  Muon_adv,
9
+ AdaMuon_adv,
9
10
  )
10
11
 
11
12
  __all__ = [
@@ -16,6 +17,7 @@ __all__ = [
16
17
  "Lion_adv",
17
18
  "Lion_Prodigy_adv",
18
19
  "Muon_adv",
20
+ "AdaMuon_adv",
19
21
  ]
20
22
 
21
- __version__ = "1.2.dev2"
23
+ __version__ = "1.2.dev4"
adv_optm/optim/AdaMuon_adv.py ADDED
@@ -0,0 +1,465 @@
1
+ import torch
2
+ from typing import Optional, Callable
3
+
4
+ from .AdamW_adv import AdamW_adv
5
+ from ..util.MuonAdam_helper import MuonAdamHelper
6
+ from ..util.Kourkoutas import KourkoutasHelper
7
+
8
+ from ..util.BF16_Stochastic_Rounding import add_stochastic_
9
+ from ..util.Newton_Schulz import _newton_schulz_iteration
10
+ from ..util.Effective_Shape import _get_effective_shape
11
+ from ..util.NNMF import _nnmf,_unnmf
12
+ from ..util.One_Bit_Boolean import _pack_bools, _unpack_bools
13
+
14
+ class AdaMuon_adv(torch.optim.Optimizer):
15
+ """
16
+ Implements the AdaMuon optimizer algorithm.
17
+
18
+ AdaMuon combines the geometry-aware updates of Muon with the element-wise
19
+ adaptivity of Adam. It is designed for 2D parameters (e.g., linear layers)
20
+ and can handle higher-dimensional parameters by flattening.
21
+
22
+ The algorithm incorporates three key mechanisms:
23
+ 1. A sign-stabilized orthogonal update, where the sign of the momentum is
24
+ orthogonalized instead of the momentum itself.
25
+ 2. An element-wise second momentum estimator applied to the orthogonalized
26
+ update directions.
27
+ 3. An RMS-aligned rescaling strategy to match the update magnitude of Adam,
28
+ allowing for reuse of learning rate schedules.
29
+
30
+ It can also operate in a hybrid mode, using an auxiliary AdamW
31
+ optimizer for specific parameters (e.g., biases, norms, embeddings) as
32
+ defined by a `layer_key_fn`.
33
+
34
+ Args:
35
+ params (iterable): iterable of parameters to optimize or dicts defining
36
+ parameter groups.
37
+ lr (float): learning rate (default: 1e-3).
38
+ betas (tuple[float, float]): coefficients used for both first and second moment
39
+ estimation (default: (0.95, 0.95))
40
+ weight_decay (float): weight decay (L2 penalty) (default: 0.1).
41
+ eps (float): term added to the denominator for adaptive scaling to improve
42
+ numerical stability (default: 1e-8).
43
+ rms_target (float): The target Root-Mean-Square value for the final update
44
+ vector, used for RMS-aligned rescaling. Allows for the reuse of existing Adam
45
+ learning rate schedules. (default: 0.2).
46
+ ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
47
+ ns_eps (float): epsilon for Newton-Schulz normalization stability (default: 1e-7).
48
+ ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
49
+ quintic polynomial in the Newton-Schulz iteration.
50
+ (default: (3.4445, -4.7750, 2.0315)).
51
+ stochastic_rounding (bool): whether to use stochastic rounding for
52
+ BF16 parameter updates (default: True).
53
+ nesterov (bool): enables Nesterov momentum (default: False).
54
+ use_atan2 (bool): whether to use the atan2 update rule. (default: False)
55
+ Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
56
+ This changes the update to `alpha_grad * grad + mt`, which can be
57
+ more responsive, especially for small batch sizes. (default: False)
58
+ alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
59
+ (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
60
+ current gradient. For small batch sizes, use high values (e.g., 10-100) to be
61
+ more responsive. For large batch sizes, use low values (e.g., 0-1) for
62
+ stability. (default: 100.0)
63
+ vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
64
+ matrices for the Muon Newton-Schulz step (default: False).
65
+ vector_reshape (bool): whether to reshape 1D vectors into 2D
66
+ matrices to apply low-rank compression (default: False).
67
+ nnmf_factor (bool): whether to use NNMF factorization for the optimizer
68
+ states; if disabled, the uncompressed optimizer states are used. (default: False)
69
+ kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
70
+ If `False`, the fixed `betas[1]` value is used throughout. (default: False)
71
+ beta2_min (float): The minimum value for dynamic β₂, used during periods of
72
+ high gradient variance ("sunspikes"). Must be less than `betas[1]`.
73
+ (default: 0.9)
74
+ ema_alpha (float): The decay rate for the Exponential Moving Average (EMA) of
75
+ the pooled gradient norms. Corresponds to `α` in the paper.
76
+ (default: 0.95)
77
+ tiny_spike (float): A small constant added to the denominator of the
78
+ "sunspike" ratio calculation to prevent division by zero. Corresponds
79
+ to `ε_spike` in the paper. (default: 1e-9)
80
+ k_warmup_steps (int): The number of initial steps during which β₂ is held
81
+ at the fixed `betas[1]` value before the
82
+ dynamic logic activates. (default: 0)
83
+ MuonWithAuxAdam (bool): If True, enables the hybrid optimizer mode.
84
+ Parameters designated by `layer_key_fn` will be optimized with
85
+ AdamW_adv instead of Muon. (default: False)
86
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
87
+ and returns a key. If the key is 'adam', the parameter is handled by
88
+ the auxiliary AdamW optimizer. All other keys are handled by Muon.
89
+ Only used when `MuonWithAuxAdam` is True. (default: None)
90
+ adam_kwargs (Optional[dict]): A dictionary of keyword arguments to pass
91
+ to the auxiliary AdamW_adv optimizer. Only used when
92
+ `MuonWithAuxAdam` is True. (default: None)
93
+ """
94
+
95
+ def __init__(
96
+ self,
97
+ params,
98
+ lr: float = 1e-3,
99
+ betas: tuple[float, float] = (0.95, 0.95),
100
+ weight_decay: float = 0.1,
101
+ eps: float = 1e-8,
102
+ rms_target: float = 0.2,
103
+ ns_steps: int = 5,
104
+ ns_eps: float = 1e-7,
105
+ ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
106
+ stochastic_rounding: bool = True,
107
+ use_atan2: bool = False,
108
+ nesterov: bool = False,
109
+ Simplified_AdEMAMix: bool = False,
110
+ alpha_grad: float = 100.0,
111
+ vector_reshape_muon: bool = False,
112
+ vector_reshape: bool = False,
113
+ nnmf_factor: bool = False,
114
+ # K-b parameters
115
+ kourkoutas_beta: bool = False,
116
+ beta2_min: float = 0.9,
117
+ ema_alpha: float = 0.95,
118
+ tiny_spike: float = 1e-9,
119
+ k_warmup_steps: int = 0,
120
+ k_logging: int = 0,
121
+ layer_key_kb_fn: Optional[Callable] = None,
122
+ # hybrid optimizer mode
123
+ MuonWithAuxAdam: bool = False,
124
+ layer_key_fn: Optional[Callable] = None,
125
+ muon_adam_lr: float = 1e-4,
126
+ adam_kwargs: Optional[dict] = None,
127
+ ):
128
+ if not (lr >= 0.0):
129
+ raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
130
+ if not (weight_decay >= 0.0):
131
+ raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
132
+ if not (ns_steps > 0):
133
+ raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
134
+ if Simplified_AdEMAMix and nesterov:
135
+ print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling cautious.")
136
+ nesterov = False
137
+
138
+ muon_defaults = {
139
+ "lr": lr, "betas": betas, "weight_decay": weight_decay,
140
+ "eps": eps, "rms_target": rms_target, "ns_steps": ns_steps,
141
+ "ns_eps": ns_eps, "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
142
+ "vector_reshape": vector_reshape,
143
+ "vector_reshape_muon": vector_reshape_muon,
144
+ "nesterov":nesterov, "use_atan2":use_atan2,
145
+ "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
146
+ "_kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
147
+ "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
148
+ }
149
+ self.stochastic_rounding = stochastic_rounding
150
+ self._kourkoutas_beta = kourkoutas_beta
151
+ self._kourkoutas_helper = None
152
+ self.layer_key_kb_fn = layer_key_kb_fn
153
+ self.MuonWithAuxAdam = MuonWithAuxAdam
154
+ self.helper = None
155
+ self.aux_adam = None
156
+
157
+ if not self.MuonWithAuxAdam:
158
+ super().__init__(params, muon_defaults)
159
+ return
160
+
161
+ # HYBRID OPTIMIZER LOGIC
162
+ adam_kwargs = adam_kwargs or {}
163
+ self.aux_adam = AdamW_adv(
164
+ [],
165
+ lr=muon_adam_lr,
166
+ **adam_kwargs,
167
+ _is_delegate=True
168
+ )
169
+ adam_defaults = self.aux_adam.defaults
170
+
171
+ final_param_groups = []
172
+ _layer_key_fn = layer_key_fn if layer_key_fn is not None else lambda p: 'muon'
173
+
174
+ for group in params:
175
+ # All params in a group are of the same type
176
+ first_param = group['params'][0]
177
+ key = _layer_key_fn(first_param)
178
+ optim_type = 'adam' if key == 'adam' else 'muon'
179
+
180
+ new_group = group.copy()
181
+ defaults_to_use = adam_defaults if optim_type == 'adam' else muon_defaults
182
+
183
+ for key, value in defaults_to_use.items():
184
+ new_group.setdefault(key, value)
185
+
186
+ final_param_groups.append(new_group)
187
+
188
+ super().__init__(final_param_groups, {})
189
+
190
+ # Now that self is initialized, create the helper
191
+ self.helper = MuonAdamHelper(self, layer_key_fn)
192
+
193
+
194
+ @property
195
+ def supports_fused_back_pass(self):
196
+ return True
197
+
198
+ @property
199
+ def supports_memory_efficient_fp16(self):
200
+ return True
201
+
202
+ @property
203
+ def supports_flat_params(self):
204
+ return False
205
+
206
+ @property
207
+ def kourkoutas_helper(self):
208
+ """
209
+ Exposes the kourkoutas_helper from the auxiliary AdamW optimizer,
210
+ if it exists. This allows external access for logging K-b.
211
+ """
212
+ if self.aux_adam and hasattr(self.aux_adam, 'kourkoutas_helper'):
213
+ return self.aux_adam.kourkoutas_helper
214
+ return None
215
+
216
+ @torch.no_grad()
217
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
218
+ if group.get('_kourkoutas_beta') and self._kourkoutas_helper is None:
219
+ self._kourkoutas_helper = KourkoutasHelper(self)
220
+
221
+ if self.MuonWithAuxAdam:
222
+ optim_type = self.helper.get_optimizer_type(p)
223
+ if optim_type == 'adam':
224
+ # Delegate to the AdamW_adv optimizer's logic.
225
+ # We need to temporarily "lend" our state and param_groups
226
+ self.aux_adam.state = self.state
227
+ self.aux_adam.param_groups = self.param_groups
228
+
229
+ # Ensure the aux optimizer uses the same Kourkoutas helper instance.
230
+ if self._kourkoutas_helper is not None:
231
+ self.aux_adam.kourkoutas_helper = self._kourkoutas_helper
232
+
233
+ self.aux_adam.step_parameter(p, group, i)
234
+ return
235
+
236
+ if p.grad is None:
237
+ return
238
+
239
+ grad = p.grad
240
+ state = self.state[p]
241
+
242
+
243
+ # State Initialization
244
+ if 'step' not in state:
245
+ state['step'] = 0
246
+
247
+ should_factor = (
248
+ group['nnmf_factor'] and
249
+ not (len(p.shape) == 1 and not group['vector_reshape'])
250
+ )
251
+
252
+ state['factored'] = should_factor
253
+
254
+ state['reshaped_1d_muon'] = len(p.shape) == 1 and group['vector_reshape_muon']
255
+
256
+ dtype = torch.float32 if group['nnmf_factor'] else p.dtype
257
+ device = p.device
258
+ if state['factored'] or state['reshaped_1d_muon']:
259
+ state['effective_shape'] = _get_effective_shape(p.numel())
260
+ d1, d2 = state['effective_shape']
261
+ if state['factored']:
262
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
263
+ state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
264
+ packed_d2 = (d2 + 7) // 8
265
+ state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
266
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
267
+ state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
268
+ else:
269
+ if len(p.shape) >= 2:
270
+ state['momentum_buffer'] = torch.zeros_like(p)
271
+ state['second_momentum_buffer'] = torch.zeros_like(p)
272
+ if state['reshaped_1d_muon']:
273
+ state['momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
274
+ state['second_momentum_buffer'] = torch.zeros((d1, d2), device=device, dtype=dtype)
275
+ elif len(p.shape) == 1:
276
+ state['momentum_buffer'] = torch.zeros_like(p)
277
+
278
+ # Retrieve hyperparameters
279
+ beta1, beta2 = group['betas']
280
+ current_step = state['step']
281
+ nesterov = group['nesterov']
282
+ Simplified_AdEMAMix = group['Simplified_AdEMAMix']
283
+ alpha_grad = group['alpha_grad']
284
+
285
+ if state['factored']: # Factored AdaMuon
286
+
287
+ # Reconstruct momentum from previous step's factors & sign
288
+ d1, d2 = state['effective_shape']
289
+ mt_buf = _unnmf((state['mu_m_nmf'], state['mv_m_nmf']))
290
+ unpacked_sign = _unpack_bools(state['sign'], original_m=d2)
291
+ torch.where(unpacked_sign, mt_buf, -mt_buf, out=mt_buf)
292
+ del unpacked_sign
293
+
294
+ # Update momentum in full-size
295
+ grad_reshaped = grad.view(d1, d2)
296
+ mt_buf.mul_(beta1).add_(grad_reshaped)
297
+
298
+ if nesterov:
299
+ signed_m_buf = torch.sign(grad_reshaped.add(mt_buf, alpha=beta1))
300
+ elif Simplified_AdEMAMix:
301
+ signed_m_buf = torch.sign(mt_buf.add(grad_reshaped, alpha=alpha_grad))
302
+ else:
303
+ signed_m_buf = torch.sign(mt_buf)
304
+ del grad_reshaped
305
+
306
+ update = _newton_schulz_iteration(
307
+ signed_m_buf,
308
+ steps=group['ns_steps'],
309
+ eps=group['ns_eps'],
310
+ coeffs=group['ns_coeffs'],
311
+ )
312
+
313
+ if group['_kourkoutas_beta']:
314
+ # Call prepare_step() once at the beginning of the step for all params
315
+ self._kourkoutas_helper.maybe_prepare_step(current_step)
316
+ # Accumulate current sign-stabilized orthogonal update's norm for the *next* step
317
+ self._kourkoutas_helper.accumulate_gradient_sq_norm(p, update.view(p.shape))
318
+ # Get the dynamic beta2 calculated in prepare_step()
319
+ beta2 = self._kourkoutas_helper.get_beta2(p, group, current_step)
320
+
321
+ # Reconstruct second momentum from previous step's factors
322
+ vt_buf = _unnmf((state['mu_v_nmf'], state['mv_v_nmf']))
323
+
324
+ # Update second momentum in full-size
325
+ vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
326
+
327
+ # Apply second momentum update (adaptive scaling)
328
+ if group['use_atan2']:
329
+ a = 1.2732395
330
+ denom = vt_buf.sqrt()
331
+ update.atan2_(denom).mul_(a)
332
+ else:
333
+ denom = vt_buf.sqrt().add_(group['eps'])
334
+ update.div_(denom)
335
+ del denom
336
+
337
+ # RMS-aligned rescaling
338
+ rms_target = group['rms_target']
339
+ num_elements = update.numel()
340
+ scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm())
341
+
342
+ update.mul_(scaling_factor)
343
+ update = update.view(p.shape).mul_(group['lr'])
344
+ del num_elements, scaling_factor
345
+
346
+ # Compress updated moments and store new factors
347
+ state['sign'] = _pack_bools(mt_buf > 0)
348
+ _nnmf(mt_buf.abs(), out=(state['mu_m_nmf'], state['mv_m_nmf']))
349
+ del mt_buf
350
+
351
+ _nnmf(vt_buf.abs(), out=(state['mu_v_nmf'], state['mv_v_nmf']))
352
+ del vt_buf
353
+
354
+ else: # Standard AdaMuon logic for non-factored tensors
355
+
356
+ if len(p.shape) >= 2 or state['reshaped_1d_muon']:
357
+
358
+ # Momentum update
359
+ mt_buf = state['momentum_buffer']
360
+ if state['reshaped_1d_muon']:
361
+ d1, d2 = state['effective_shape']
362
+ grad_reshaped = grad.view(d1, d2)
363
+ mt_buf.mul_(beta1).add_(grad_reshaped)
364
+ if nesterov:
365
+ signed_m_buf = torch.sign(grad_reshaped.add(mt_buf, alpha=beta1))
366
+ elif Simplified_AdEMAMix:
367
+ signed_m_buf = torch.sign(mt_buf.add(grad_reshaped, alpha=alpha_grad))
368
+ else:
369
+ signed_m_buf = torch.sign(mt_buf)
370
+ del grad_reshaped
371
+ else:
372
+ mt_buf.mul_(beta1).add_(grad)
373
+ if nesterov:
374
+ signed_m_buf = torch.sign(grad.add(mt_buf, alpha=beta1))
375
+ elif Simplified_AdEMAMix:
376
+ signed_m_buf = torch.sign(mt_buf.add(grad, alpha=alpha_grad))
377
+ else:
378
+ signed_m_buf = torch.sign(mt_buf)
379
+
380
+ # Flatten if necessary (e.g., for Conv layers)
381
+ if len(p.shape) > 2:
382
+ signed_m_buf = signed_m_buf.view(p.shape[0], -1)
383
+
384
+ # NewtonSchulz
385
+ update = _newton_schulz_iteration(
386
+ signed_m_buf,
387
+ steps=group['ns_steps'],
388
+ eps=group['ns_eps'],
389
+ coeffs=group['ns_coeffs'],
390
+ )
391
+
392
+ if len(p.shape) > 2 or state['reshaped_1d_muon']:
393
+ update = update.view(p.shape)
394
+
395
+ if group['_kourkoutas_beta']:
396
+ # Call prepare_step() once at the beginning of the step for all params
397
+ self._kourkoutas_helper.maybe_prepare_step(current_step)
398
+ # Accumulate current sign-stabilized orthogonal update's norm for the *next* step
399
+ self._kourkoutas_helper.accumulate_gradient_sq_norm(p, update)
400
+ # Get the dynamic beta2 calculated in prepare_step()
401
+ beta2 = self._kourkoutas_helper.get_beta2(p, group, current_step)
402
+
403
+ vt_buf = state['second_momentum_buffer']
404
+ vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
405
+
406
+ # Apply second momentum update (adaptive scaling)
407
+ if group['use_atan2']:
408
+ a = 1.2732395
409
+ denom = vt_buf.sqrt()
410
+ update.atan2_(denom).mul_(a)
411
+ else:
412
+ denom = vt_buf.sqrt().add_(group['eps'])
413
+ update.div_(denom)
414
+ del denom
415
+
416
+ # RMS-aligned rescaling
417
+ rms_target = group['rms_target']
418
+ num_elements = update.numel()
419
+ scaling_factor = rms_target * (num_elements ** 0.5) / (update.norm())
420
+
421
+ update.mul_(scaling_factor)
422
+ del num_elements, scaling_factor
423
+
424
+ update.mul_(group['lr'])
425
+
426
+ else: # Fallback to standard SGD with momentum for 1D params (biases, etc.) when not reshaped
427
+ # Momentum update
428
+ mt_buf = state['momentum_buffer']
429
+ mt_buf.mul_(beta1).add_(grad)
430
+ if nesterov:
431
+ update = grad.add(mt_buf, alpha=beta1)
432
+ elif Simplified_AdEMAMix:
433
+ update = mt_buf.add(grad, alpha=alpha_grad)
434
+ else:
435
+ update = mt_buf.clone()
436
+ update.mul_(group['lr'])
437
+
438
+ # Decoupled weight decay
439
+ if group["weight_decay"] != 0:
440
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
441
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
442
+ else:
443
+ p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
444
+
445
+ if p.dtype == torch.bfloat16 and self.stochastic_rounding:
446
+ add_stochastic_(p.data, -update)
447
+ else:
448
+ p.data.add_(-update)
449
+ del update
450
+
451
+ state['step'] += 1
452
+
453
+ @torch.no_grad()
454
+ def step(self, closure=None):
455
+ """Performs a single optimization step."""
456
+ loss = None
457
+ if closure is not None:
458
+ with torch.enable_grad():
459
+ loss = closure()
460
+
461
+ for group in self.param_groups:
462
+ for i, p in enumerate(group['params']):
463
+ self.step_parameter(p, group, i)
464
+
465
+ return loss
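The snippet below is an illustrative usage sketch for the newly added AdaMuon_adv class; it is not part of the packaged file above. The model and data are placeholders, and the keyword arguments mirror the defaults shown in the constructor.

import torch
from adv_optm import AdaMuon_adv

# Toy model: AdaMuon applies the orthogonalized, adaptively scaled update to the
# 2D weight and falls back to SGD-with-momentum for the 1D bias.
model = torch.nn.Linear(128, 64)

optimizer = AdaMuon_adv(
    model.parameters(),
    lr=1e-3,
    betas=(0.95, 0.95),
    weight_decay=0.1,
    rms_target=0.2,  # RMS-aligned rescaling, so Adam-style LR schedules can be reused
)

x = torch.randn(32, 128)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()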
adv_optm/optim/AdamW_adv.py CHANGED
@@ -73,7 +73,7 @@ class AdamW_adv(torch.optim.Optimizer):
73
73
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
74
74
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
75
75
  logging (default: 0).
76
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
76
+ layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
77
77
  and returns a unique, hashable key representing its "layer" or "bucket".
78
78
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
79
79
  (default: None)
@@ -105,7 +105,7 @@ class AdamW_adv(torch.optim.Optimizer):
105
105
  tiny_spike: float = 1e-9,
106
106
  k_warmup_steps: int = 0,
107
107
  k_logging: int = 0,
108
- layer_key_fn: Optional[Callable] = None,
108
+ layer_key_kb_fn: Optional[Callable] = None,
109
109
  nnmf_factor: bool = False,
110
110
  _is_delegate: bool = False,
111
111
  ):
@@ -137,7 +137,7 @@ class AdamW_adv(torch.optim.Optimizer):
137
137
  self.use_AdEMAMix = use_AdEMAMix
138
138
  self.factored = nnmf_factor
139
139
  self.kourkoutas_beta = kourkoutas_beta
140
- self.layer_key_fn = layer_key_fn
140
+ self.layer_key_kb_fn = layer_key_kb_fn
141
141
  if not _is_delegate:
142
142
  super().__init__(params, defaults)
143
143
  else:
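Because this release renames the Kourkoutas-β bucketing hook from `layer_key_fn` to `layer_key_kb_fn`, callers that pass a custom bucketing function now need the new keyword. A minimal sketch, assuming AdamW_adv exposes the `kourkoutas_beta` and `layer_key_kb_fn` arguments referenced in the hunks above; the shape-based bucketing lambda is a hypothetical example:

import torch
from adv_optm import AdamW_adv

model = torch.nn.Linear(128, 64)

optimizer = AdamW_adv(
    model.parameters(),
    lr=1e-3,
    kourkoutas_beta=True,                      # enable layer-wise dynamic beta2
    layer_key_kb_fn=lambda p: tuple(p.shape),  # bucket parameters by shape
)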
adv_optm/optim/Adopt_adv.py CHANGED
@@ -91,7 +91,7 @@ class Adopt_adv(torch.optim.Optimizer):
91
91
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
92
92
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
93
93
  logging (default: 0).
94
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
94
+ layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
95
95
  and returns a unique, hashable key representing its "layer" or "bucket".
96
96
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
97
97
  (default: None)
@@ -125,7 +125,7 @@ class Adopt_adv(torch.optim.Optimizer):
125
125
  tiny_spike: float = 1e-9,
126
126
  k_warmup_steps: int = 0,
127
127
  k_logging: int = 0,
128
- layer_key_fn: Optional[Callable] = None,
128
+ layer_key_kb_fn: Optional[Callable] = None,
129
129
  nnmf_factor: bool = False,
130
130
  ):
131
131
  if not (lr >= 0.0):
@@ -148,9 +148,6 @@ class Adopt_adv(torch.optim.Optimizer):
148
148
  print("Warning: grams is incompatible with Simplified_AdEMAMix, Disabling grams.")
149
149
  if cautious_mask and Simplified_AdEMAMix:
150
150
  print("Warning: cautious is incompatible with Simplified_AdEMAMix, Disabling cautious.")
151
- if use_atan2 and Simplified_AdEMAMix:
152
- print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
153
- use_atan2 = False
154
151
 
155
152
  defaults = {
156
153
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
@@ -169,7 +166,7 @@ class Adopt_adv(torch.optim.Optimizer):
169
166
  self.Simplified_AdEMAMix = Simplified_AdEMAMix
170
167
  self.factored = nnmf_factor
171
168
  self.kourkoutas_beta = kourkoutas_beta
172
- self.layer_key_fn = layer_key_fn
169
+ self.layer_key_kb_fn = layer_key_kb_fn
173
170
  super().__init__(params, defaults)
174
171
 
175
172
  if self.kourkoutas_beta:
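The hunk above removes the guard that disabled `use_atan2` together with `Simplified_AdEMAMix` in Adopt_adv, so the atan2 scaling can now be combined with that update rule. For reference, a standalone sketch of the two denominator variants used in this release (the constant 1.2732395 is roughly 4/pi, matching the AdaMuon_adv code above); the tensors are placeholders:

import torch

m = torch.randn(4, 4)   # numerator (momentum-style term)
v = torch.rand(4, 4)    # second-moment estimate (non-negative)

# eps-based adaptive scaling: m / (sqrt(v) + eps)
eps_update = m / (v.sqrt() + 1e-8)

# atan2-based scaling: bounded output, no eps term needed
atan2_update = torch.atan2(m, v.sqrt()) * 1.2732395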
adv_optm/optim/Muon_adv.py CHANGED
@@ -22,7 +22,7 @@ class Muon_adv(torch.optim.Optimizer):
22
22
  can handle other-dimensional parameters (e.g., 1D bias, 4D convolutional layers) by
23
23
  flattening/reshaping them.
24
24
 
25
- This version can also operate in a hybrid mode, using an auxiliary AdamW
25
+ It can also operate in a hybrid mode, using an auxiliary AdamW
26
26
  optimizer for specific parameters (e.g., biases, norms, embeddings) as
27
27
  defined by a `layer_key_fn`.
28
28
 
@@ -38,6 +38,14 @@ class Muon_adv(torch.optim.Optimizer):
38
38
  ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
39
39
  quintic polynomial in the Newton-Schulz iteration.
40
40
  (default: (3.4445, -4.7750, 2.0315)).
41
+ Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
42
+ This changes the update to `alpha_grad * grad + mt`, which can be
43
+ more responsive, especially for small batch sizes. (default: False)
44
+ alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
45
+ (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
46
+ current gradient. For small batch sizes, use high values (e.g., 10-100) to be
47
+ more responsive. For large batch sizes, use low values (e.g., 0-1) for
48
+ stability. (default: 100.0)
41
49
  stochastic_rounding (bool): whether to use stochastic rounding for
42
50
  BF16 parameter updates (default: True).
43
51
  vector_reshape_muon (bool): whether to reshape 1D vectors into 2D
@@ -68,9 +76,11 @@ class Muon_adv(torch.optim.Optimizer):
68
76
  ns_steps: int = 5,
69
77
  ns_eps: float = 1e-7,
70
78
  ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
79
+ Simplified_AdEMAMix: bool = False,
80
+ alpha_grad: float = 100.0,
71
81
  stochastic_rounding: bool = True,
72
82
  vector_reshape_muon: bool = False,
73
- vector_reshape: bool = True,
83
+ vector_reshape: bool = False,
74
84
  nnmf_factor: bool = False,
75
85
  # hybrid optimizer mode
76
86
  MuonWithAuxAdam: bool = False,
@@ -86,13 +96,17 @@ class Muon_adv(torch.optim.Optimizer):
86
96
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
87
97
  if not (ns_steps > 0):
88
98
  raise ValueError(f"Newton-Schulz steps should be > 0. Got {ns_steps}")
99
+ if Simplified_AdEMAMix and nesterov:
100
+ print("Warning: nesterov is incompatible with Simplified_AdEMAMix, Disabling cautious.")
101
+ nesterov = False
89
102
 
90
- defaults = {
103
+ muon_defaults = {
91
104
  "lr": lr, "beta1": beta1, "weight_decay": weight_decay,
92
105
  "nesterov": nesterov, "ns_steps": ns_steps, "ns_eps": ns_eps,
93
106
  "ns_coeffs": ns_coeffs, "nnmf_factor": nnmf_factor,
94
107
  "vector_reshape": vector_reshape,
95
108
  "vector_reshape_muon": vector_reshape_muon,
109
+ "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
96
110
  }
97
111
  self.stochastic_rounding = stochastic_rounding
98
112
 
@@ -100,23 +114,41 @@ class Muon_adv(torch.optim.Optimizer):
100
114
  self.helper = None
101
115
  self.aux_adam = None
102
116
 
103
- if self.MuonWithAuxAdam:
104
- adam_kwargs = adam_kwargs or {}
105
- # Create a delegate AdamW optimizer to get its default hyperparameters.
106
- self.aux_adam = AdamW_adv(
107
- [],
108
- lr=muon_adam_lr,
109
- **adam_kwargs,
110
- _is_delegate=True
111
- )
112
- # Update the defaults dictionary
113
- defaults.update(self.aux_adam.defaults)
114
-
115
- super().__init__(params, defaults)
117
+ if not self.MuonWithAuxAdam:
118
+ super().__init__(params, muon_defaults)
119
+ return
116
120
 
117
- if self.MuonWithAuxAdam:
118
- self.helper = MuonAdamHelper(self, layer_key_fn)
121
+ # HYBRID OPTIMIZER LOGIC
122
+ adam_kwargs = adam_kwargs or {}
123
+ self.aux_adam = AdamW_adv(
124
+ [],
125
+ lr=muon_adam_lr,
126
+ **adam_kwargs,
127
+ _is_delegate=True
128
+ )
129
+ adam_defaults = self.aux_adam.defaults
130
+
131
+ final_param_groups = []
132
+ _layer_key_fn = layer_key_fn if layer_key_fn is not None else lambda p: 'muon'
133
+
134
+ for group in params:
135
+ first_param = group['params'][0]
136
+ key = _layer_key_fn(first_param)
137
+ optim_type = 'adam' if key == 'adam' else 'muon'
138
+
139
+ new_group = group.copy()
140
+ defaults_to_use = adam_defaults if optim_type == 'adam' else muon_defaults
141
+
142
+ for key, value in defaults_to_use.items():
143
+ new_group.setdefault(key, value)
144
+
145
+ final_param_groups.append(new_group)
146
+
147
+ super().__init__(final_param_groups, {})
119
148
 
149
+ # Now that self is initialized, create the helper
150
+ self.helper = MuonAdamHelper(self, layer_key_fn)
151
+
120
152
 
121
153
  @property
122
154
  def supports_fused_back_pass(self):
@@ -130,6 +162,16 @@ class Muon_adv(torch.optim.Optimizer):
130
162
  def supports_flat_params(self):
131
163
  return False
132
164
 
165
+ @property
166
+ def kourkoutas_helper(self):
167
+ """
168
+ Exposes the kourkoutas_helper from the auxiliary AdamW optimizer,
169
+ if it exists. This allows external access for logging K-b.
170
+ """
171
+ if self.aux_adam and hasattr(self.aux_adam, 'kourkoutas_helper'):
172
+ return self.aux_adam.kourkoutas_helper
173
+ return None
174
+
133
175
  @torch.no_grad()
134
176
  def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
135
177
  if self.MuonWithAuxAdam:
@@ -165,7 +207,7 @@ class Muon_adv(torch.optim.Optimizer):
165
207
 
166
208
  dtype = torch.float32 if group['nnmf_factor'] else p.dtype
167
209
  device = p.device
168
- if group['vector_reshape'] or state['reshaped_1d_muon']:
210
+ if state['factored'] or state['reshaped_1d_muon']:
169
211
  state['effective_shape'] = _get_effective_shape(p.numel())
170
212
  d1, d2 = state['effective_shape']
171
213
  if state['factored']:
@@ -183,6 +225,8 @@ class Muon_adv(torch.optim.Optimizer):
183
225
 
184
226
  beta1 = group['beta1']
185
227
  nesterov = group['nesterov']
228
+ Simplified_AdEMAMix = group['Simplified_AdEMAMix']
229
+ alpha_grad = group['alpha_grad']
186
230
 
187
231
  if state['factored']: # Factored Muon
188
232
 
@@ -200,6 +244,8 @@ class Muon_adv(torch.optim.Optimizer):
200
244
  if nesterov:
201
245
  # Nesterov momentum
202
246
  update = grad_reshaped.add(mt_buf, alpha=beta1)
247
+ elif Simplified_AdEMAMix:
248
+ update = torch.add(mt_buf, grad_reshaped, alpha=alpha_grad)
203
249
  else:
204
250
  # Standard momentum
205
251
  update = mt_buf.clone()
@@ -238,6 +284,12 @@ class Muon_adv(torch.optim.Optimizer):
238
284
  del grad_reshaped
239
285
  else:
240
286
  update = grad.add(mt_buf, alpha=beta1)
287
+ elif Simplified_AdEMAMix:
288
+ if state['reshaped_1d_muon']:
289
+ update = torch.add(mt_buf, grad_reshaped, alpha=alpha_grad)
290
+ del grad_reshaped
291
+ else:
292
+ update = torch.add(mt_buf, grad, alpha=alpha_grad)
241
293
  else:
242
294
  # Standard momentum
243
295
  update = mt_buf.clone()
@@ -267,6 +319,8 @@ class Muon_adv(torch.optim.Optimizer):
267
319
  if nesterov:
268
320
  # Nesterov momentum
269
321
  update = grad.add(mt_buf, alpha=beta1)
322
+ elif Simplified_AdEMAMix:
323
+ update = torch.add(mt_buf, grad, alpha=alpha_grad)
270
324
  else:
271
325
  # Standard momentum
272
326
  update = mt_buf.clone()
@@ -299,4 +353,4 @@ class Muon_adv(torch.optim.Optimizer):
299
353
  for i, p in enumerate(group['params']):
300
354
  self.step_parameter(p, group, i)
301
355
 
302
- return loss
356
+ return loss
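A minimal sketch of the hybrid MuonWithAuxAdam mode described in the docstrings above. Because the hybrid constructor iterates `group['params']` and keys each group off its first parameter, parameters are passed here as one-param group dicts; the routing function is a hypothetical example that sends 1D tensors (biases, norm weights) to the auxiliary AdamW_adv:

import torch
from adv_optm import Muon_adv

model = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.LayerNorm(64))

def route(p: torch.Tensor) -> str:
    # 'adam' -> handled by the auxiliary AdamW_adv, anything else -> Muon
    return 'adam' if p.dim() == 1 else 'muon'

param_groups = [{'params': [p]} for p in model.parameters()]

optimizer = Muon_adv(
    param_groups,
    lr=1e-3,
    MuonWithAuxAdam=True,
    layer_key_fn=route,
    muon_adam_lr=1e-4,  # learning rate for the auxiliary AdamW_adv
)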
adv_optm/optim/Prodigy_adv.py CHANGED
@@ -109,7 +109,7 @@ class Prodigy_adv(torch.optim.Optimizer):
109
109
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
110
110
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
111
111
  logging (default: 0).
112
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
112
+ layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
113
113
  and returns a unique, hashable key representing its "layer" or "bucket".
114
114
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
115
115
  (default: None)
@@ -152,7 +152,7 @@ class Prodigy_adv(torch.optim.Optimizer):
152
152
  tiny_spike: float = 1e-9,
153
153
  k_warmup_steps: int = 0,
154
154
  k_logging: int = 0,
155
- layer_key_fn: Optional[Callable] = None,
155
+ layer_key_kb_fn: Optional[Callable] = None,
156
156
  ):
157
157
  if not (lr >= 0.0):
158
158
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -205,7 +205,7 @@ class Prodigy_adv(torch.optim.Optimizer):
205
205
  self.fsdp_in_use = fsdp_in_use
206
206
 
207
207
  self.kourkoutas_beta = kourkoutas_beta
208
- self.layer_key_fn = layer_key_fn
208
+ self.layer_key_kb_fn = layer_key_kb_fn
209
209
 
210
210
  super().__init__(params, defaults)
211
211
  if self.kourkoutas_beta:
adv_optm/optim/Simplified_AdEMAMix.py CHANGED
@@ -67,7 +67,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
67
67
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
68
68
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
69
69
  logging (default: 0).
70
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
70
+ layer_key_kb_fn (Optional[Callable]): A function that takes a parameter `p`
71
71
  and returns a unique, hashable key representing its "layer" or "bucket".
72
72
  If `None`, parameters are bucketed by their memory ID (tensor-wise).
73
73
  (default: None)
@@ -95,7 +95,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
95
95
  tiny_spike: float = 1e-9,
96
96
  k_warmup_steps: int = 0,
97
97
  k_logging: int = 0,
98
- layer_key_fn: Optional[Callable] = None,
98
+ layer_key_kb_fn: Optional[Callable] = None,
99
99
  nnmf_factor: bool = False,
100
100
  ):
101
101
  if not (lr >= 0.0):
@@ -121,7 +121,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
121
121
  self.stochastic_rounding = stochastic_rounding
122
122
  self.factored = nnmf_factor
123
123
  self.kourkoutas_beta = kourkoutas_beta
124
- self.layer_key_fn = layer_key_fn
124
+ self.layer_key_kb_fn = layer_key_kb_fn
125
125
  super().__init__(params, defaults)
126
126
 
127
127
  if self.kourkoutas_beta:
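The Simplified AdEMAMix option added to Muon_adv and AdaMuon_adv in this release is documented above as changing the update to `alpha_grad * grad + mt`. A standalone sketch of that mixing step, mirroring the `torch.add(mt_buf, grad, alpha=alpha_grad)` calls in the diff; the tensors and hyperparameters are placeholders:

import torch

beta1, alpha_grad = 0.95, 100.0
grad = torch.randn(64, 64)
mt_buf = torch.zeros(64, 64)

# accumulator-style first moment, as in the diff: mt <- beta1 * mt + grad
mt_buf.mul_(beta1).add_(grad)

# Simplified AdEMAMix mixing: mt + alpha_grad * grad. High alpha_grad weights
# the current gradient heavily (suited to small batches); low values favour
# the momentum buffer (suited to large batches).
update = torch.add(mt_buf, grad, alpha=alpha_grad)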
adv_optm/optim/__init__.py CHANGED
@@ -5,6 +5,7 @@ from .Simplified_AdEMAMix import Simplified_AdEMAMix
5
5
  from .Lion_adv import Lion_adv
6
6
  from .Lion_Prodigy_adv import Lion_Prodigy_adv
7
7
  from .Muon_adv import Muon_adv
8
+ from .AdaMuon_adv import AdaMuon_adv
8
9
 
9
10
  __all__ = [
10
11
  "AdamW_adv",
@@ -14,4 +15,5 @@ __all__ = [
14
15
  "Lion_adv",
15
16
  "Lion_Prodigy_adv",
16
17
  "Muon_adv",
18
+ "AdaMuon_adv",
17
19
  ]
adv_optm/util/Kourkoutas.py CHANGED
@@ -32,12 +32,12 @@ class KourkoutasHelper:
32
32
  if self._layer_info_built:
33
33
  return
34
34
 
35
- if hasattr(self.optimizer, 'layer_key_fn') and self.optimizer.layer_key_fn is not None:
35
+ if hasattr(self.optimizer, 'layer_key_kb_fn') and self.optimizer.layer_key_kb_fn is not None:
36
36
  # A custom key function was provided by the user. We will use it.
37
37
  pass
38
38
  else:
39
39
  # No key function was provided. Default to coarse, shape-based bucketing.
40
- self.optimizer.layer_key_fn = lambda p: \
40
+ self.optimizer.layer_key_kb_fn = lambda p: \
41
41
  (id(p),) if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096} \
42
42
  else tuple(p.shape)
43
43
  # This ensures that we won't mix embeddings with tokens (1 to 10)
@@ -46,7 +46,7 @@ class KourkoutasHelper:
46
46
  for group in self.optimizer.param_groups:
47
47
  for p in group['params']:
48
48
  # The mapping is static and should not depend on the presence of a gradient.
49
- layer_key = self.optimizer.layer_key_fn(p)
49
+ layer_key = self.optimizer.layer_key_kb_fn(p)
50
50
  if layer_key not in self.layer_info:
51
51
  self.layer_info[layer_key] = {'params': [], 'group_ref': group}
52
52
  self.layer_info[layer_key]['params'].append(p)
@@ -158,7 +158,7 @@ class KourkoutasHelper:
158
158
  """
159
159
  Accumulates the squared L2 norm of a single gradient for the next step's calculation.
160
160
  """
161
- layer_key = self.optimizer.layer_key_fn(p)
161
+ layer_key = self.optimizer.layer_key_kb_fn(p)
162
162
 
163
163
  if layer_key in self.layer_info:
164
164
  # Initialize the transient state for this layer if it's the first time in the step.
@@ -173,6 +173,6 @@ class KourkoutasHelper:
173
173
  """
174
174
  Gets the appropriate beta2 for the current parameter, handling warmup and dynamic value fetching.
175
175
  """
176
- layer_key = self.optimizer.layer_key_fn(p)
176
+ layer_key = self.optimizer.layer_key_kb_fn(p)
177
177
  # The default is the max value, which is correct for unmapped params or edge cases
178
178
  return self.layer_state.get(layer_key, {}).get('dynamic_beta2', group['betas'][1])
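To make the renamed hook concrete, the default shape-based bucketing that KourkoutasHelper installs as `layer_key_kb_fn` (copied from the hunk above) can be exercised on its own; the tensors below are placeholders:

import torch

# Default bucketing when no layer_key_kb_fn is supplied: small 2D tensors with
# common hidden sizes (768/1280/4096) get a unique per-tensor key, everything
# else is bucketed by its shape.
layer_key_kb_fn = lambda p: \
    (id(p),) if p.dim() == 2 and 1 <= p.shape[0] <= 10 and p.shape[1] in {768, 1280, 4096} \
    else tuple(p.shape)

w = torch.randn(4, 768)   # gets its own (id,) bucket
b = torch.randn(768)      # shares the (768,) bucket with same-shaped tensors
print(layer_key_kb_fn(w), layer_key_kb_fn(b))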
adv_optm/util/MuonAdam_helper.py CHANGED
@@ -1,3 +1,4 @@
1
+ import torch
1
2
  from torch.optim import Optimizer
2
3
  from typing import Callable, Optional
3
4
 
adv_optm-1.2.dev4.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 1.2.dev2
3
+ Version: 1.2.dev4
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
adv_optm-1.2.dev4.dist-info/RECORD ADDED
@@ -0,0 +1,24 @@
1
+ adv_optm/__init__.py,sha256=bB7_VywKpvZbcGCjtZoF8giQgcUgoziISBgIaEUpcAw,379
2
+ adv_optm/optim/AdaMuon_adv.py,sha256=s5UkR2YJ_Z10SiBokT97eq4tCHc2D8BEOFDx5AOMryQ,20983
3
+ adv_optm/optim/AdamW_adv.py,sha256=7IvdD1rqYeHZwQCZU9X0H7x87MCKcHQ5M68GLuMCkvE,17702
4
+ adv_optm/optim/Adopt_adv.py,sha256=C2FsEZGvCk9q4YNKAj0qIxdZ5AfPlda-1lIpSX0a1nE,21256
5
+ adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
6
+ adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
7
+ adv_optm/optim/Muon_adv.py,sha256=vB-Eeh0IqYMd3lkQvIPEbH256bTyYO73OgIzn0N2VCk,14985
8
+ adv_optm/optim/Prodigy_adv.py,sha256=bmwuO8GrJHH4NaEaqE-ffcR9wHhQ57457xoN-P6hyks,25909
9
+ adv_optm/optim/Simplified_AdEMAMix.py,sha256=sY-vThMVgADRh0ar9WHkrM2n8UcgQLQC1YV1Wx8uFz4,12983
10
+ adv_optm/optim/__init__.py,sha256=hpUWE6CKtt_rvMdgQVb3PtjhfZAvAxTq6hp8H8rIpBo,489
11
+ adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
12
+ adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
13
+ adv_optm/util/Kourkoutas.py,sha256=MDQaNVH8jqzaefks2RShveo44dpYDz88WStwUJ3iF0s,8724
14
+ adv_optm/util/MuonAdam_helper.py,sha256=7rnNMujZVDaqo1g22QscMyPlZvIHQQSLHMED9_I8QWU,1250
15
+ adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
16
+ adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnCg,1377
17
+ adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
18
+ adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
19
+ adv_optm/util/__init__.py,sha256=jAaUfaAjFrTJ6-Q915ezAbq0efRbpYjriW2OdeCbSzo,433
20
+ adv_optm-1.2.dev4.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
21
+ adv_optm-1.2.dev4.dist-info/METADATA,sha256=jNczVxIPq0LuusXuGrZ23CQ4CrMNOfJdBDpDQgulMUw,14022
22
+ adv_optm-1.2.dev4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ adv_optm-1.2.dev4.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
24
+ adv_optm-1.2.dev4.dist-info/RECORD,,
adv_optm-1.2.dev2.dist-info/RECORD DELETED
@@ -1,23 +0,0 @@
1
- adv_optm/__init__.py,sha256=THWhNF8-PI71K9Au4xAkuDs96YcEagJ-yT5r_g2-yKw,341
2
- adv_optm/optim/AdamW_adv.py,sha256=Zym0beeu0ye5_PgpAjpzcYghdPYFWs3gQzDmuPZVR80,17690
3
- adv_optm/optim/Adopt_adv.py,sha256=NXbtPrGm3tZr06cApi5oEHZ2F1zwss3tRi15SGnrYPc,21426
4
- adv_optm/optim/Lion_Prodigy_adv.py,sha256=LEA3UYJpPeFnmxeniLNv1u2LKKj4ufx3Bq_MLw-nWXk,14617
5
- adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
6
- adv_optm/optim/Muon_adv.py,sha256=9K5YR3odaGfDDZzasletHRlqxG8xN9IXj6oiqx1CaEI,12423
7
- adv_optm/optim/Prodigy_adv.py,sha256=0_XG5YnMQTv-zJysJHlJniSo5kGYdX3p3o1e33HLt78,25897
8
- adv_optm/optim/Simplified_AdEMAMix.py,sha256=nEIA3yM11nBooKzHudB5l3x4UdFRBYRwiKVUkGmO0K8,12971
9
- adv_optm/optim/__init__.py,sha256=3o2XJ4J-PUq3rJM2mBnmuHwbKNb4LuW-Ig_9aBC0ycc,431
10
- adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
11
- adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
12
- adv_optm/util/Kourkoutas.py,sha256=woyJfX7l4eieeg0pC5XrILBLvwECwbD3a6ou1K6qjKU,8706
13
- adv_optm/util/MuonAdam_helper.py,sha256=llPCc9MBFen_wodbY4G2E17tBZky8clDiJSZLHkMva8,1236
14
- adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
15
- adv_optm/util/Newton_Schulz.py,sha256=wJ_sKRaGVIsOofQ737my4ng494qX_pfgOqlDDmYtnCg,1377
16
- adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
17
- adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
18
- adv_optm/util/__init__.py,sha256=jAaUfaAjFrTJ6-Q915ezAbq0efRbpYjriW2OdeCbSzo,433
19
- adv_optm-1.2.dev2.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
20
- adv_optm-1.2.dev2.dist-info/METADATA,sha256=JTCPGBJUd4JR7DU26AhX8qSPzWrSVtEwv9Au7I3iEPY,14022
21
- adv_optm-1.2.dev2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
22
- adv_optm-1.2.dev2.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
23
- adv_optm-1.2.dev2.dist-info/RECORD,,