adv-optm: 1.1.0.dev3-py3-none-any.whl → 1.1.0.dev5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm has been flagged as potentially problematic.

adv_optm/optim/Prodigy_adv.py

@@ -88,6 +88,9 @@ class Prodigy_adv(torch.optim.Optimizer):
         prodigy_steps (int): If greater than zero, disable Prodigy's stepsize adjustments
             after the specified optimiser step and release all state memory required by Prodigy
             (default: 0).
+        d_limiter (bool): whether to clamp the new step size estimate (`d_hat`)
+            to prevent sudden, volatile increases in the adaptive step size (`d`).
+            (default: False)
         kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
             If `False`, the optimizer behaves as standard AdamW/Prodigy. (default: False)
         beta2_min (float): The minimum value for dynamic β₂, used during periods of
@@ -141,9 +144,11 @@ class Prodigy_adv(torch.optim.Optimizer):
         fsdp_in_use: bool = False,
         slice_p: int = 11,
         prodigy_steps: int = 0,
+        d_limiter: bool = False,
+        # K-b parameters
         kourkoutas_beta: bool = False,
-        beta2_min: float = 0.88,
-        ema_alpha: float = 0.93,
+        beta2_min: float = 0.9,
+        ema_alpha: float = 0.95,
         tiny_spike: float = 1e-9,
         k_warmup_steps: int = 0,
         k_logging: int = 0,
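
For orientation, the two hunks above change both the documented and the actual constructor defaults. A minimal usage sketch of the dev5 keyword arguments follows; only the keyword names and defaults shown in the diff are taken from the package, while the import path and toy model are illustrative assumptions.

# Usage sketch for the dev5 keyword arguments (assumptions: the import path and
# the toy model; the keyword names and defaults are the ones shown in the diff).
import torch
from adv_optm import Prodigy_adv  # import path assumed from the package layout

model = torch.nn.Linear(128, 10)
optimizer = Prodigy_adv(
    model.parameters(),
    d_limiter=True,        # new in dev5: clamp jumps in the step-size estimate d_hat
    kourkoutas_beta=True,  # layer-wise dynamic beta2
    beta2_min=0.9,         # default raised from 0.88 in this release
    ema_alpha=0.95,        # default raised from 0.93 in this release
)
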
@@ -175,8 +180,8 @@ class Prodigy_adv(torch.optim.Optimizer):
             use_atan2 = False
         if kourkoutas_beta and not (betas[1] > beta2_min):
             raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
-        if Simplified_AdEMAMix and alpha_grad > 0:
-            # scales d_coef by alpha_grad, this force prodigy to behave well with Simplified_AdEMAMix
+        if Simplified_AdEMAMix and alpha_grad > 0 and not d_limiter:
+            # scales d_coef by alpha_grad, this force prodigy to behave well with Simplified_AdEMAMix.
             d_coef = d_coef/alpha_grad
 
         defaults = {
@@ -186,7 +191,7 @@ class Prodigy_adv(torch.optim.Optimizer):
             "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
             "beta3": beta3, "d": d0, "d0": d0, "d_max": d0, "d_numerator": 0.0, "d_coef": d_coef,
             "growth_rate": growth_rate, "safeguard_warmup": safeguard_warmup, "k": 0, "slice_p": slice_p,
-            "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps,
+            "fsdp_in_use": fsdp_in_use, "prodigy_steps": prodigy_steps, "d_limiter": d_limiter,
             "alpha_grad": alpha_grad,
             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
             "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
@@ -251,7 +256,7 @@ class Prodigy_adv(torch.optim.Optimizer):
         state = self.state[p]
 
         # State Initialization
-        if len(state) == 0:
+        if 'step' not in state:
             state['step'] = 0
 
             should_factor = (
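
A likely motivation for loosening this guard: the release can now write auxiliary entries (such as the per-parameter `kourkoutas_r_ema` tensor, see the Kourkoutas.py hunks below) into a parameter's state before its first step, so an empty-dict check would skip initialization. A small illustrative sketch, not package code:

# Illustrative sketch (not package code) of why the guard changed.
state = {}
state['kourkoutas_r_ema'] = 0.0   # an auxiliary entry written before the first step

if len(state) == 0:               # old guard: never fires here, 'step' stays missing
    state['step'] = 0

if 'step' not in state:           # new guard: still initializes the step counter
    state['step'] = 0

assert state['step'] == 0
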
@@ -512,6 +517,8 @@ class Prodigy_adv(torch.optim.Optimizer):
         d_hat = self.d
         if global_d_denom > 0:
             d_hat = d_coef * global_d_numerator / global_d_denom
+            if g_group['d_limiter']:
+                d_hat = min(self.d * (2 ** 0.25), d_hat)
             if self.d == g_group['d0']:
                 self.d = max(self.d, d_hat)
             d_max = max(d_max, d_hat)
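
The clamp above is the whole of the new d_limiter behaviour visible in this diff: a proposed estimate may exceed the current `d` by at most a factor of 2 ** 0.25 (about 1.19) per step. A standalone sketch with made-up numbers:

# Standalone sketch of the clamp with made-up numbers.
d = 1e-6                  # current Prodigy step-size estimate (illustrative value)
d_hat_proposed = 5e-6     # hypothetical spike in the new estimate

d_hat = min(d * (2 ** 0.25), d_hat_proposed)
print(f"{d_hat:.3e}")     # ~1.189e-06: the spike is capped at ~1.19x growth

# Even if the spike persists, at least four capped steps are needed to double d.
for _ in range(4):
    d = min(d * (2 ** 0.25), d_hat_proposed)
print(f"{d:.3e}")         # ~2.000e-06
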

adv_optm/optim/Simplified_AdEMAMix.py

@@ -90,8 +90,8 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         stochastic_rounding: bool = True,
         orthogonal_gradient: bool = False,
         kourkoutas_beta: bool = False,
-        beta2_min: float = 0.88,
-        ema_alpha: float = 0.93,
+        beta2_min: float = 0.9,
+        ema_alpha: float = 0.95,
         tiny_spike: float = 1e-9,
         k_warmup_steps: int = 0,
         k_logging: int = 0,
@@ -152,7 +152,7 @@ class Simplified_AdEMAMix(torch.optim.Optimizer):
         state = self.state[p]
 
         # State Initialization
-        if len(state) == 0:
+        if 'step' not in state:
             state['step'] = 0
 
             should_factor = (

adv_optm/util/Kourkoutas.py

@@ -11,22 +11,27 @@ class KourkoutasHelper:
         if not hasattr(optimizer, 'param_groups'):
             raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
         self.optimizer = optimizer
-
-        # State managed by the helper
         self.layer_state = {}
+
         self.layer_info = {}
         self._layer_info_built = False
         self._current_step_prepared = -1
 
+        # Store stats for external logging (e.g., TensorBoard)
+        self.last_beta2_stats = {}
+
         # This ensures the map is complete before the first backward pass,
         # making it compatible with fused back pass mechanisms.
         self._build_layer_info_if_needed()
 
+        if self.optimizer.param_groups[0].get('k_logging', 0) > 0:
+            self.print_layer_info()
+
     def _build_layer_info_if_needed(self):
         """Builds a map of layers and the parameters they contain."""
         if self._layer_info_built:
             return
-
+
         if not hasattr(self.optimizer, 'layer_key_fn') or self.optimizer.layer_key_fn is None:
             print("Warning: KourkoutasHelper requires 'layer_key_fn' on the optimizer. Defaulting to tensor-wise (id).")
             self.optimizer.layer_key_fn = lambda p: id(p)
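
The new `last_beta2_stats` dict is described in the hunk above as being for external logging (e.g., TensorBoard). A hedged sketch of how a training loop might consume it; the `'min'`/`'max'`/`'mean'` keys match the dict populated in `prepare_step` below, while the way the helper is reached from the optimizer is an assumed attribute name, not confirmed by this diff.

# Hedged sketch: forwarding the helper's beta2 stats to TensorBoard.
# `optimizer.kourkoutas_helper` is an assumed attribute name, not confirmed here.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

def log_dynamic_beta2(optimizer, step):
    helper = getattr(optimizer, 'kourkoutas_helper', None)  # assumed attribute
    if helper is None or not helper.last_beta2_stats:
        return
    for key, value in helper.last_beta2_stats.items():      # keys: 'min', 'max', 'mean'
        writer.add_scalar(f"kourkoutas/beta2_{key}", value, step)
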
@@ -45,64 +50,94 @@ class KourkoutasHelper:
 
         self._layer_info_built = True
 
+    def print_layer_info(self):
+        """Prints the contents of self.layer_info for debugging."""
+        print("\n--- BEGIN self.layer_info DUMP ---")
+        if not self.layer_info:
+            print("Layer info is empty. Make sure the optimizer has parameters.")
+            return
+
+        for layer_key, info in self.layer_info.items():
+            param_count = len(info['params'])
+            first_param_details = ""
+            if param_count > 0:
+                p = info['params'][0]
+                first_param_details = f" (Example param shape: {list(p.shape)}, dtype: {p.dtype})"
+
+            print(f"Key: {layer_key}, Params: {param_count}{first_param_details}")
+
+        print("--- END self.layer_info DUMP ---\n")
+
     def prepare_step(self, current_step: int):
         """
         Calculates dynamic beta2 for all layers using the completed scalar accumulators
         from the PREVIOUS step. Should be called once at the start of an optimizer step.
         """
-
-        # Check if logging is enabled for this step based on the interval
-        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
-        is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
-
-        beta2_log = [] if is_logging_step else None
+
+        beta2_log = []
         first_layer_key = next(iter(self.layer_info), None)
+        # These are just for the sample log, initialize them
+        sun, pooled_grad_norm, prev_r_ema_val, r_ema_tensor = (torch.tensor(0.0),)*4
 
         for layer_key, info in self.layer_info.items():
             params, group = info['params'], info['group_ref']
-
+
+            first_param_in_layer = info['params'][0]
+            param_state = self.optimizer.state[first_param_in_layer]
+
             if layer_key not in self.layer_state:
                 self.layer_state[layer_key] = {
-                    'r_ema_grad_norm': torch.tensor(0.0, device=params[0].device, dtype=torch.float32),
-                    'sum_sq_accumulator': torch.tensor(0.0, device=params[0].device, dtype=torch.float32)
+                    'sum_sq_accumulator': torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)
                 }
 
-            layer_state = self.layer_state[layer_key]
-
-            # Use the completed accumulator from the previous step
-            pooled_grad_norm = torch.sqrt(layer_state['sum_sq_accumulator'])
+            if 'kourkoutas_r_ema' not in param_state:
+                param_state['kourkoutas_r_ema'] = torch.tensor(0.0, device=first_param_in_layer.device, dtype=torch.float32)
+
+            r_ema_tensor = param_state['kourkoutas_r_ema']
+            accumulator = self.layer_state[layer_key]['sum_sq_accumulator']
 
-            r_ema = layer_state['r_ema_grad_norm']
-            prev_r_ema_val = r_ema.item() # for logging
+            pooled_grad_norm = torch.sqrt(accumulator)
+            prev_r_ema_val = r_ema_tensor.item() # for logging
 
-            # EMA is always updated, even during warmup
-            r_ema.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])
+            # Update the persistent EMA tensor in-place.
+            r_ema_tensor.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])
 
-            sun = torch.tensor(0.0, device=r_ema.device) # Default sun to 0 for warmup
             beta2_max = group['betas'][1]
-
-            # --- CONSOLIDATED WARMUP LOGIC ---
+            sun = torch.tensor(0.0, device=r_ema_tensor.device) # Default sun to 0 for warmup
+
             if current_step < group['k_warmup_steps']:
                 beta2 = beta2_max
             else:
-                raw = pooled_grad_norm / (r_ema + group['tiny_spike'])
+                raw = pooled_grad_norm / (r_ema_tensor + group['tiny_spike'])
                 sun = raw / (1.0 + raw)
                 beta2 = beta2_max - (beta2_max - group['beta2_min']) * sun
 
-            layer_state['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
-            layer_state['sum_sq_accumulator'].zero_()
+            # Store the final calculated beta2 in the helper's transient state for this step.
+            self.layer_state[layer_key]['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
+
+            # Reset the accumulator for the next optimizer step.
+            accumulator.zero_()
 
-            if is_logging_step:
-                beta2_log.append(layer_state['dynamic_beta2'])
-                if layer_key == first_layer_key:
-                    print(f"\n[Kourkoutas-β Debug] Step {current_step + 1} - Sample Layer '{layer_key}':")
-                    print(f" - Grad Norm: {pooled_grad_norm.item():.4e}, Prev EMA: {prev_r_ema_val:.4e}, New EMA: {r_ema.item():.4e}")
-                    print(f" - Sunspike: {sun.item():.4f}, Dynamic Beta2: {layer_state['dynamic_beta2']:.4f}")
-
-        if is_logging_step and beta2_log:
+            beta2_log.append(self.layer_state[layer_key]['dynamic_beta2'])
+
+        # Always compute stats for TensorBoard
+        if beta2_log:
             beta2_tensor = torch.tensor(beta2_log, device='cpu')
-            print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
+            self.last_beta2_stats = {
+                'min': beta2_tensor.min().item(),
+                'max': beta2_tensor.max().item(),
+                'mean': beta2_tensor.mean().item(),
+            }
 
+        # Handle periodic console logging
+        k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
+        is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
+        if is_logging_step and self.last_beta2_stats:
+            if first_layer_key:
+                print(f"\n[Kourkoutas-β Debug] Step {current_step + 1} - Sample Layer '{first_layer_key}':")
+                print(f" - Grad Norm: {pooled_grad_norm.item():.4e}, Prev EMA: {prev_r_ema_val:.4e}, New EMA: {r_ema_tensor.item():.4e}")
+                print(f" - Sunspike: {sun.item():.4f}, Dynamic Beta2: {self.layer_state[first_layer_key]['dynamic_beta2']:.4f}")
+            print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={self.last_beta2_stats['min']:.4f}, Max={self.last_beta2_stats['max']:.4f}, Mean={self.last_beta2_stats['mean']:.4f}")
 
     def maybe_prepare_step(self, current_step: int):
         """
@@ -119,9 +154,9 @@ class KourkoutasHelper:
             layer_key = self.optimizer.layer_key_fn(p)
 
             if layer_key in self.layer_info:
+                # Initialize the transient state for this layer if it's the first time in the step.
                 if layer_key not in self.layer_state:
                     self.layer_state[layer_key] = {
-                        'r_ema_grad_norm': torch.tensor(0.0, device=p.device, dtype=torch.float32),
                         'sum_sq_accumulator': torch.tensor(0.0, device=p.device, dtype=torch.float32)
                     }
                 # Accumulate for the *next* step's prepare_step call

adv_optm-1.1.0.dev5.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 1.1.0.dev3
+Version: 1.1.0.dev5
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu

adv_optm-1.1.0.dev5.dist-info/RECORD (new file)

@@ -0,0 +1,20 @@
+adv_optm/__init__.py,sha256=lOHXiF0KmYmUnaQGIoUYeIxdEfYE8T1hFSVq5FVujDs,311
+adv_optm/optim/AdamW_adv.py,sha256=gVVpaKIbpv8pkfvfgVGCQN6No8A4atO7eRSPDBUVqq8,17490
+adv_optm/optim/Adopt_adv.py,sha256=K7z1iiln_HxuEPLl9yGtCngBfdZHxJISQ5dKgNBV-s4,21463
+adv_optm/optim/Lion_Prodigy_adv.py,sha256=aJ9orEEw0QYbrDzn1be0SHvOBlIkLwWG9RpWFuNMskM,13163
+adv_optm/optim/Lion_adv.py,sha256=aGNAplZlyXYgVllYcV_s4bK8iC4fv6EizFoWIMNLdBc,8299
+adv_optm/optim/Prodigy_adv.py,sha256=ecdnnbRgclcG49sGzxAmPHPE_0KkaQWtaiynsBYudoM,25979
+adv_optm/optim/Simplified_AdEMAMix.py,sha256=Cm-8tdCaTahdz45EExgn2W3a5Xl44T9MW-IMrUDbJFk,12983
+adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
+adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
+adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
+adv_optm/util/Kourkoutas.py,sha256=DCsIcZ1sEeSwthN8KZH7OTKoIZJ3ah4t5DNiqxsSuCk,8344
+adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
+adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
+adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
+adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
+adv_optm-1.1.0.dev5.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+adv_optm-1.1.0.dev5.dist-info/METADATA,sha256=2xyGCRbIN54aIuAWnRIpR49okoVgVJb2AGHl2-jgVx8,8427
+adv_optm-1.1.0.dev5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+adv_optm-1.1.0.dev5.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
+adv_optm-1.1.0.dev5.dist-info/RECORD,,

adv_optm-1.1.0.dev3.dist-info/RECORD (removed)

@@ -1,20 +0,0 @@
-adv_optm/__init__.py,sha256=aSPtwpl2S7i_-KYXTDDeKoQlcLjZc6whVUNOINl6TEA,311
-adv_optm/optim/AdamW_adv.py,sha256=H4XlYZELwiFvXt0A9wMlRNiw9c8rmPMspHDCvR_SZIQ,17487
-adv_optm/optim/Adopt_adv.py,sha256=PJ3ZaLgzYbvxXDS56FGjzMrVMyHDXSWdUPHnX5NpNAA,21241
-adv_optm/optim/Lion_Prodigy_adv.py,sha256=sGzhts9a6gHfCkuHTB5L9IrClo4c6UThzYYErBwqOaA,12844
-adv_optm/optim/Lion_adv.py,sha256=6G1CukJB_pC7l9HwFEuY1ydsNHZFabVmOvcHDsHHVuQ,8295
-adv_optm/optim/Prodigy_adv.py,sha256=EeSfYu8IIeZX1Dk8MlD71vGOpMadtnW2iMhHxPDL2XQ,25574
-adv_optm/optim/Simplified_AdEMAMix.py,sha256=b4GaSI-TX6wFBqGxZeoJPbf2nVRCEtB3WVb1olDgY14,12980
-adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
-adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
-adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
-adv_optm/util/Kourkoutas.py,sha256=UN_EAbG-9p98Qp2c_vSUy1Gw1K55SQ_e0TmnNBb-OFQ,6748
-adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
-adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
-adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
-adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
-adv_optm-1.1.0.dev3.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
-adv_optm-1.1.0.dev3.dist-info/METADATA,sha256=03sDh1nQ1CQXxu4TbRnRblX1IZ9S-Eka7hP1LNs54WA,8427
-adv_optm-1.1.0.dev3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-adv_optm-1.1.0.dev3.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
-adv_optm-1.1.0.dev3.dist-info/RECORD,,