adv-optm 1.1.0.dev2__py3-none-any.whl → 1.1.0.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adv_optm/__init__.py +1 -1
- adv_optm/optim/Adopt_adv.py +435 -439
- adv_optm/optim/Lion_Prodigy_adv.py +315 -315
- adv_optm/optim/Prodigy_adv.py +4 -2
- adv_optm/util/Kourkoutas.py +33 -25
- {adv_optm-1.1.0.dev2.dist-info → adv_optm-1.1.0.dev4.dist-info}/METADATA +1 -1
- {adv_optm-1.1.0.dev2.dist-info → adv_optm-1.1.0.dev4.dist-info}/RECORD +10 -10
- {adv_optm-1.1.0.dev2.dist-info → adv_optm-1.1.0.dev4.dist-info}/WHEEL +0 -0
- {adv_optm-1.1.0.dev2.dist-info → adv_optm-1.1.0.dev4.dist-info}/licenses/LICENSE +0 -0
- {adv_optm-1.1.0.dev2.dist-info → adv_optm-1.1.0.dev4.dist-info}/top_level.txt +0 -0
adv_optm/util/Kourkoutas.py
CHANGED
@@ -11,25 +11,32 @@ class KourkoutasHelper:
         if not hasattr(optimizer, 'param_groups'):
             raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
         self.optimizer = optimizer
-
+
         # State managed by the helper
         self.layer_state = {}
         self.layer_info = {}
         self._layer_info_built = False
         self._current_step_prepared = -1
 
+        # Store stats for external logging (e.g., TensorBoard)
+        self.last_beta2_stats = {}
+
+        # This ensures the map is complete before the first backward pass,
+        # making it compatible with fused back pass mechanisms.
+        self._build_layer_info_if_needed()
+
     def _build_layer_info_if_needed(self):
         """Builds a map of layers and the parameters they contain."""
         if self._layer_info_built:
             return
-
+
         if not hasattr(self.optimizer, 'layer_key_fn') or self.optimizer.layer_key_fn is None:
             print("Warning: KourkoutasHelper requires 'layer_key_fn' on the optimizer. Defaulting to tensor-wise (id).")
             self.optimizer.layer_key_fn = lambda p: id(p)
 
         for group in self.optimizer.param_groups:
             for p in group['params']:
-
+                # The mapping is static and should not depend on the presence of a gradient.
                 layer_key = self.optimizer.layer_key_fn(p)
                 if layer_key not in self.layer_info:
                     self.layer_info[layer_key] = {'params': [], 'group_ref': group}
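The constructor now builds the layer map eagerly, before the first backward pass, and the map falls back to per-tensor keys (id(p)) whenever no layer_key_fn is set on the optimizer. Below is a minimal sketch of supplying a coarser per-layer grouping instead, assuming only what the hunk shows (the helper reads a layer_key_fn attribute from the optimizer); the model and the param_to_layer mapping are hypothetical.

import torch

# Hypothetical two-layer model, used only for illustration.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))

# Group each parameter under its parent module name ('0.weight' -> '0'),
# so a layer's weight and bias share one pooled statistic.
param_to_layer = {
    id(p): name.rsplit('.', 1)[0]
    for name, p in model.named_parameters()
}

# The hunk only shows that the helper reads optimizer.layer_key_fn;
# attaching it this way is an assumption, not documented package API.
# optimizer.layer_key_fn = lambda p: param_to_layer.get(id(p), id(p))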
@@ -46,14 +53,11 @@ class KourkoutasHelper:
         Calculates dynamic beta2 for all layers using the completed scalar accumulators
         from the PREVIOUS step. Should be called once at the start of an optimizer step.
         """
-
-
-        #
-
-        is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
+
+        beta2_log = []
+        # These are just for the sample log, initialize them
+        sun, pooled_grad_norm, r_ema = (torch.tensor(0.0),)*3
 
-        beta2_log = [] if is_logging_step else None
-        first_layer_key = next(iter(self.layer_info), None)
 
         for layer_key, info in self.layer_info.items():
             params, group = info['params'], info['group_ref']
|
|
|
65
69
|
}
|
|
66
70
|
|
|
67
71
|
layer_state = self.layer_state[layer_key]
|
|
68
|
-
|
|
72
|
+
|
|
69
73
|
# Use the completed accumulator from the previous step
|
|
70
74
|
pooled_grad_norm = torch.sqrt(layer_state['sum_sq_accumulator'])
|
|
71
|
-
|
|
75
|
+
|
|
72
76
|
r_ema = layer_state['r_ema_grad_norm']
|
|
73
|
-
|
|
74
|
-
|
|
77
|
+
|
|
75
78
|
# EMA is always updated, even during warmup
|
|
76
79
|
r_ema.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])
|
|
77
|
-
|
|
80
|
+
|
|
78
81
|
sun = torch.tensor(0.0, device=r_ema.device) # Default sun to 0 for warmup
|
|
79
82
|
beta2_max = group['betas'][1]
|
|
80
83
|
|
|
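The lines between this hunk and the next (81-88 of the file, unchanged and therefore not shown) are where the sunspike ratio and the dynamic beta2 are actually computed. Below is a hedged sketch of that calculation, inferred only from the names visible here (pooled_grad_norm, r_ema, sun, beta2_max); the beta2_min floor, the epsilon, and the exact interpolation are assumptions, not the package's verbatim code.

import torch

def sketch_dynamic_beta2(pooled_grad_norm, r_ema, beta2_max,
                         beta2_min=0.88, eps=1e-8):
    # "Sunspike": how sharply the pooled gradient norm spikes above its
    # running EMA (assumed form).
    raw = pooled_grad_norm / (r_ema + eps)
    sun = raw / (1.0 + raw)  # squashed into [0, 1)
    # A large spike pulls beta2 toward beta2_min (a faster-reacting second
    # moment); calm gradients keep beta2 near beta2_max.
    return beta2_max - sun * (beta2_max - beta2_min)

# Example: a 10x spike over the EMA drives beta2 well below beta2_max.
print(sketch_dynamic_beta2(torch.tensor(10.0), torch.tensor(1.0), 0.999))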
@@ -89,16 +92,22 @@
             layer_state['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
             layer_state['sum_sq_accumulator'].zero_()
 
-
-
-
-
-            print(f" - Grad Norm: {pooled_grad_norm.item():.4e}, Prev EMA: {prev_r_ema_val:.4e}, New EMA: {r_ema.item():.4e}")
-            print(f" - Sunspike: {sun.item():.4f}, Dynamic Beta2: {layer_state['dynamic_beta2']:.4f}")
-
-        if is_logging_step and beta2_log:
+            beta2_log.append(layer_state['dynamic_beta2'])
+
+        # Always compute stats for TensorBoard
+        if beta2_log:
             beta2_tensor = torch.tensor(beta2_log, device='cpu')
-
+            self.last_beta2_stats = {
+                'min': beta2_tensor.min().item(),
+                'max': beta2_tensor.max().item(),
+                'mean': beta2_tensor.mean().item(),
+            }
+
+        # Handle periodic console logging
+        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+        is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
+        if is_logging_step and self.last_beta2_stats:
+            print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={self.last_beta2_stats['min']:.4f}, Max={self.last_beta2_stats['max']:.4f}, Mean={self.last_beta2_stats['mean']:.4f}")
 
 
     def maybe_prepare_step(self, current_step: int):
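Since last_beta2_stats is now populated on every prepared step (console output stays gated behind the k_logging interval), an external training loop can scrape it for TensorBoard. A minimal sketch, assuming a KourkoutasHelper instance is reachable as helper; the tag names are arbitrary.

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

def log_beta2_stats(helper, step):
    # Per the hunk above, last_beta2_stats holds 'min', 'max', and 'mean';
    # it stays empty until the first prepared step.
    for name, value in helper.last_beta2_stats.items():
        writer.add_scalar(f"kourkoutas/beta2_{name}", value, step)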
@@ -113,7 +122,6 @@
         """
         Accumulates the squared L2 norm of a single gradient for the next step's calculation.
         """
-        self._build_layer_info_if_needed()
         layer_key = self.optimizer.layer_key_fn(p)
 
         if layer_key in self.layer_info:
{adv_optm-1.1.0.dev2.dist-info → adv_optm-1.1.0.dev4.dist-info}/RECORD
CHANGED
@@ -1,20 +1,20 @@
-adv_optm/__init__.py,sha256=
+adv_optm/__init__.py,sha256=H4E_1__pXxRu4PSgQCzGi7WuFqVjTfex2Yduz3B3peI,311
 adv_optm/optim/AdamW_adv.py,sha256=H4XlYZELwiFvXt0A9wMlRNiw9c8rmPMspHDCvR_SZIQ,17487
-adv_optm/optim/Adopt_adv.py,sha256=
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=
+adv_optm/optim/Adopt_adv.py,sha256=0uMROjCw3wGOyp0ZX_xjwMVaXHJ395ifntcgY0MZt3M,21460
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=xIrwibQ2i919EHEACLCrKe5JBnS-s2Ai35yeJ1Bn1MA,13159
 adv_optm/optim/Lion_adv.py,sha256=6G1CukJB_pC7l9HwFEuY1ydsNHZFabVmOvcHDsHHVuQ,8295
-adv_optm/optim/Prodigy_adv.py,sha256
+adv_optm/optim/Prodigy_adv.py,sha256=EeSfYu8IIeZX1Dk8MlD71vGOpMadtnW2iMhHxPDL2XQ,25574
 adv_optm/optim/Simplified_AdEMAMix.py,sha256=b4GaSI-TX6wFBqGxZeoJPbf2nVRCEtB3WVb1olDgY14,12980
 adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
 adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
 adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
-adv_optm/util/Kourkoutas.py,sha256=
+adv_optm/util/Kourkoutas.py,sha256=st9hO2I0Xcby0LLq1MhxiEsPyNzEkNpJO_WfYvkioKg,6606
 adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
 adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
 adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
 adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
-adv_optm-1.1.0.
-adv_optm-1.1.0.
-adv_optm-1.1.0.
-adv_optm-1.1.0.
-adv_optm-1.1.0.
+adv_optm-1.1.0.dev4.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.1.0.dev4.dist-info/METADATA,sha256=Ue6x-vthnxradX5tH1ver4LVbWMEMmqPjMVO8KjTdhI,8427
+adv_optm-1.1.0.dev4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.1.0.dev4.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.1.0.dev4.dist-info/RECORD,,
{adv_optm-1.1.0.dev2.dist-info → adv_optm-1.1.0.dev4.dist-info}/WHEEL
File without changes

{adv_optm-1.1.0.dev2.dist-info → adv_optm-1.1.0.dev4.dist-info}/licenses/LICENSE
File without changes

{adv_optm-1.1.0.dev2.dist-info → adv_optm-1.1.0.dev4.dist-info}/top_level.txt
File without changes