adv-optm 1.1.0.dev3__py3-none-any.whl → 1.1.0.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/Adopt_adv.py +435 -439
- adv_optm/optim/Lion_Prodigy_adv.py +315 -315
- adv_optm/util/Kourkoutas.py +28 -22
- {adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev4.dist-info}/METADATA +1 -1
- {adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev4.dist-info}/RECORD +9 -9
- {adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev4.dist-info}/WHEEL +0 -0
- {adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev4.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev4.dist-info}/top_level.txt +0 -0
adv_optm/util/Kourkoutas.py
CHANGED
@@ -11,13 +11,16 @@ class KourkoutasHelper:
         if not hasattr(optimizer, 'param_groups'):
             raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
         self.optimizer = optimizer
-
+
         # State managed by the helper
         self.layer_state = {}
         self.layer_info = {}
         self._layer_info_built = False
         self._current_step_prepared = -1
 
+        # Store stats for external logging (e.g., TensorBoard)
+        self.last_beta2_stats = {}
+
         # This ensures the map is complete before the first backward pass,
         # making it compatible with fused back pass mechanisms.
         self._build_layer_info_if_needed()
@@ -26,7 +29,7 @@ class KourkoutasHelper:
         """Builds a map of layers and the parameters they contain."""
         if self._layer_info_built:
             return
-
+
         if not hasattr(self.optimizer, 'layer_key_fn') or self.optimizer.layer_key_fn is None:
             print("Warning: KourkoutasHelper requires 'layer_key_fn' on the optimizer. Defaulting to tensor-wise (id).")
             self.optimizer.layer_key_fn = lambda p: id(p)
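Note: the fallback above keys every tensor by `id(p)`, so each parameter is treated as its own "layer". Supplying a `layer_key_fn` that pools parameters per module avoids the warning. The sketch below is illustrative only; the grouping strategy and the stand-in `torch.optim.AdamW` are not part of adv-optm, whose own optimizers wire this attribute up themselves:

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 4),
)

# Map each parameter to the name of the module that owns it, so the weight and
# bias of one Linear share a pooled gradient norm and one dynamic beta2.
param_to_module = {
    id(p): name
    for name, module in model.named_modules()
    for p in module.parameters(recurse=False)
}

optimizer = torch.optim.AdamW(model.parameters())
# KourkoutasHelper reads this attribute off the optimizer (see the hunk above);
# unmapped parameters fall back to the same tensor-wise id(p) default.
optimizer.layer_key_fn = lambda p: param_to_module.get(id(p), id(p))
```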
@@ -50,13 +53,11 @@ class KourkoutasHelper:
         Calculates dynamic beta2 for all layers using the completed scalar accumulators
         from the PREVIOUS step. Should be called once at the start of an optimizer step.
         """
+
+        beta2_log = []
+        # These are just for the sample log, initialize them
+        sun, pooled_grad_norm, r_ema = (torch.tensor(0.0),)*3
 
-        # Check if logging is enabled for this step based on the interval
-        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
-        is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
-
-        beta2_log = [] if is_logging_step else None
-        first_layer_key = next(iter(self.layer_info), None)
 
         for layer_key, info in self.layer_info.items():
             params, group = info['params'], info['group_ref']
@@ -68,16 +69,15 @@ class KourkoutasHelper:
             }
 
             layer_state = self.layer_state[layer_key]
-
+
             # Use the completed accumulator from the previous step
             pooled_grad_norm = torch.sqrt(layer_state['sum_sq_accumulator'])
-
+
             r_ema = layer_state['r_ema_grad_norm']
-
-
+
             # EMA is always updated, even during warmup
             r_ema.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])
-
+
             sun = torch.tensor(0.0, device=r_ema.device) # Default sun to 0 for warmup
             beta2_max = group['betas'][1]
 
@@ -92,16 +92,22 @@
             layer_state['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
             layer_state['sum_sq_accumulator'].zero_()
 
-
-
-
-
-            print(f" - Grad Norm: {pooled_grad_norm.item():.4e}, Prev EMA: {prev_r_ema_val:.4e}, New EMA: {r_ema.item():.4e}")
-            print(f" - Sunspike: {sun.item():.4f}, Dynamic Beta2: {layer_state['dynamic_beta2']:.4f}")
-
-        if is_logging_step and beta2_log:
+            beta2_log.append(layer_state['dynamic_beta2'])
+
+        # Always compute stats for TensorBoard
+        if beta2_log:
             beta2_tensor = torch.tensor(beta2_log, device='cpu')
-
+            self.last_beta2_stats = {
+                'min': beta2_tensor.min().item(),
+                'max': beta2_tensor.max().item(),
+                'mean': beta2_tensor.mean().item(),
+            }
+
+        # Handle periodic console logging
+        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+        is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
+        if is_logging_step and self.last_beta2_stats:
+            print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={self.last_beta2_stats['min']:.4f}, Max={self.last_beta2_stats['max']:.4f}, Mean={self.last_beta2_stats['mean']:.4f}")
 
 
     def maybe_prepare_step(self, current_step: int):
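Note: the unchanged middle of the loop (roughly lines 84-91, collapsed by the diff viewer) is where the sunspike ratio and the per-layer beta2 are actually computed. Below is a self-contained sketch of that idea, consistent with the names visible here (`pooled_grad_norm`, `r_ema`, `beta2_max`, and `sun` defaulting to 0 during warmup); the `beta2_min`/`tiny` parameters and the exact squashing are assumptions, not read from this file:

```python
import torch

def dynamic_beta2(pooled_grad_norm: torch.Tensor,
                  r_ema: torch.Tensor,
                  beta2_max: float,
                  beta2_min: float = 0.88,   # assumed default, not from the diff
                  tiny: float = 1e-8) -> float:
    """Map a per-layer gradient-norm spike to a beta2 in [beta2_min, beta2_max]."""
    raw = pooled_grad_norm / (r_ema + tiny)  # how far the norm exceeds its EMA
    sun = raw / (1.0 + raw)                  # squash the spike into [0, 1)
    # A calm layer (sun ~ 0) keeps beta2 near beta2_max; a spike pulls beta2
    # down, so the second-moment estimate reacts faster to the new scale.
    return float(beta2_max - sun * (beta2_max - beta2_min))

# During warmup the helper pins sun to 0, which reduces to beta2 == beta2_max:
print(dynamic_beta2(torch.tensor(0.0), torch.tensor(0.0), beta2_max=0.999))
```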
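Note: the net effect of this refactor is that `last_beta2_stats` is now filled on every step instead of only on `k_logging` steps, so an external training loop can chart the per-layer beta2 spread. A minimal consumption sketch; `model`, `optimizer`, `batches`, and the `helper` instance are assumed from surrounding training code (adv-optm's optimizers manage the helper internally), and the TensorBoard tag names are ours:

```python
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/kourkoutas")  # log directory is illustrative

for step, batch in enumerate(batches):     # `batches`, `model`, `optimizer`, `helper` assumed
    loss = model(batch).pow(2).mean()      # placeholder forward pass and loss
    loss.backward()
    optimizer.step()                       # helper recomputes per-layer beta2 inside the step
    optimizer.zero_grad()

    stats = helper.last_beta2_stats        # {'min', 'max', 'mean'} per the diff above
    if stats:
        writer.add_scalar("kourkoutas/beta2_min", stats["min"], step)
        writer.add_scalar("kourkoutas/beta2_max", stats["max"], step)
        writer.add_scalar("kourkoutas/beta2_mean", stats["mean"], step)

writer.close()
```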
{adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev4.dist-info}/RECORD
CHANGED
@@ -1,20 +1,20 @@
-adv_optm/__init__.py,sha256=
+adv_optm/__init__.py,sha256=H4E_1__pXxRu4PSgQCzGi7WuFqVjTfex2Yduz3B3peI,311
 adv_optm/optim/AdamW_adv.py,sha256=H4XlYZELwiFvXt0A9wMlRNiw9c8rmPMspHDCvR_SZIQ,17487
-adv_optm/optim/Adopt_adv.py,sha256=
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=
+adv_optm/optim/Adopt_adv.py,sha256=0uMROjCw3wGOyp0ZX_xjwMVaXHJ395ifntcgY0MZt3M,21460
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=xIrwibQ2i919EHEACLCrKe5JBnS-s2Ai35yeJ1Bn1MA,13159
 adv_optm/optim/Lion_adv.py,sha256=6G1CukJB_pC7l9HwFEuY1ydsNHZFabVmOvcHDsHHVuQ,8295
 adv_optm/optim/Prodigy_adv.py,sha256=EeSfYu8IIeZX1Dk8MlD71vGOpMadtnW2iMhHxPDL2XQ,25574
 adv_optm/optim/Simplified_AdEMAMix.py,sha256=b4GaSI-TX6wFBqGxZeoJPbf2nVRCEtB3WVb1olDgY14,12980
 adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
 adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
 adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
-adv_optm/util/Kourkoutas.py,sha256=
+adv_optm/util/Kourkoutas.py,sha256=st9hO2I0Xcby0LLq1MhxiEsPyNzEkNpJO_WfYvkioKg,6606
 adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
-adv_optm-1.1.0.
-adv_optm-1.1.0.
-adv_optm-1.1.0.
-adv_optm-1.1.0.
-adv_optm-1.1.0.
+adv_optm-1.1.0.dev4.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.1.0.dev4.dist-info/METADATA,sha256=Ue6x-vthnxradX5tH1ver4LVbWMEMmqPjMVO8KjTdhI,8427
+adv_optm-1.1.0.dev4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.1.0.dev4.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.1.0.dev4.dist-info/RECORD,,
{adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev4.dist-info}/WHEEL
File without changes

{adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev4.dist-info}/licenses/LICENSE
File without changes

{adv_optm-1.1.0.dev3.dist-info → adv_optm-1.1.0.dev4.dist-info}/top_level.txt
File without changes