adv-optm 1.1.0.dev2__py3-none-any.whl → 1.1.0.dev4__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of adv-optm might be problematic. Click here for more details.

@@ -11,25 +11,32 @@ class KourkoutasHelper:
11
11
  if not hasattr(optimizer, 'param_groups'):
12
12
  raise TypeError("optimizer must be a valid torch.optim.Optimizer instance.")
13
13
  self.optimizer = optimizer
14
-
14
+
15
15
  # State managed by the helper
16
16
  self.layer_state = {}
17
17
  self.layer_info = {}
18
18
  self._layer_info_built = False
19
19
  self._current_step_prepared = -1
20
20
 
21
+ # Store stats for external logging (e.g., TensorBoard)
22
+ self.last_beta2_stats = {}
23
+
24
+ # This ensures the map is complete before the first backward pass,
25
+ # making it compatible with fused back pass mechanisms.
26
+ self._build_layer_info_if_needed()
27
+
21
28
  def _build_layer_info_if_needed(self):
22
29
  """Builds a map of layers and the parameters they contain."""
23
30
  if self._layer_info_built:
24
31
  return
25
-
32
+
26
33
  if not hasattr(self.optimizer, 'layer_key_fn') or self.optimizer.layer_key_fn is None:
27
34
  print("Warning: KourkoutasHelper requires 'layer_key_fn' on the optimizer. Defaulting to tensor-wise (id).")
28
35
  self.optimizer.layer_key_fn = lambda p: id(p)
29
36
 
30
37
  for group in self.optimizer.param_groups:
31
38
  for p in group['params']:
32
- if p.grad is None: continue
39
+ # The mapping is static and should not depend on the presence of a gradient.
33
40
  layer_key = self.optimizer.layer_key_fn(p)
34
41
  if layer_key not in self.layer_info:
35
42
  self.layer_info[layer_key] = {'params': [], 'group_ref': group}
@@ -46,14 +53,11 @@ class KourkoutasHelper:
46
53
  Calculates dynamic beta2 for all layers using the completed scalar accumulators
47
54
  from the PREVIOUS step. Should be called once at the start of an optimizer step.
48
55
  """
49
- self._build_layer_info_if_needed()
50
-
51
- # Check if logging is enabled for this step based on the interval
52
- k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
53
- is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
56
+
57
+ beta2_log = []
58
+ # These are just for the sample log, initialize them
59
+ sun, pooled_grad_norm, r_ema = (torch.tensor(0.0),)*3
54
60
 
55
- beta2_log = [] if is_logging_step else None
56
- first_layer_key = next(iter(self.layer_info), None)
57
61
 
58
62
  for layer_key, info in self.layer_info.items():
59
63
  params, group = info['params'], info['group_ref']
@@ -65,16 +69,15 @@ class KourkoutasHelper:
65
69
  }
66
70
 
67
71
  layer_state = self.layer_state[layer_key]
68
-
72
+
69
73
  # Use the completed accumulator from the previous step
70
74
  pooled_grad_norm = torch.sqrt(layer_state['sum_sq_accumulator'])
71
-
75
+
72
76
  r_ema = layer_state['r_ema_grad_norm']
73
- prev_r_ema_val = r_ema.item() # for logging
74
-
77
+
75
78
  # EMA is always updated, even during warmup
76
79
  r_ema.mul_(group['ema_alpha']).add_(pooled_grad_norm, alpha=1.0 - group['ema_alpha'])
77
-
80
+
78
81
  sun = torch.tensor(0.0, device=r_ema.device) # Default sun to 0 for warmup
79
82
  beta2_max = group['betas'][1]
80
83
 
@@ -89,16 +92,22 @@ class KourkoutasHelper:
89
92
  layer_state['dynamic_beta2'] = beta2.item() if isinstance(beta2, torch.Tensor) else beta2
90
93
  layer_state['sum_sq_accumulator'].zero_()
91
94
 
92
- if is_logging_step:
93
- beta2_log.append(layer_state['dynamic_beta2'])
94
- if layer_key == first_layer_key:
95
- print(f"\n[Kourkoutas-β Debug] Step {current_step + 1} - Sample Layer '{layer_key}':")
96
- print(f" - Grad Norm: {pooled_grad_norm.item():.4e}, Prev EMA: {prev_r_ema_val:.4e}, New EMA: {r_ema.item():.4e}")
97
- print(f" - Sunspike: {sun.item():.4f}, Dynamic Beta2: {layer_state['dynamic_beta2']:.4f}")
98
-
99
- if is_logging_step and beta2_log:
95
+ beta2_log.append(layer_state['dynamic_beta2'])
96
+
97
+ # Always compute stats for TensorBoard
98
+ if beta2_log:
100
99
  beta2_tensor = torch.tensor(beta2_log, device='cpu')
101
- print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={beta2_tensor.min():.4f}, Max={beta2_tensor.max():.4f}, Mean={beta2_tensor.mean():.4f}")
100
+ self.last_beta2_stats = {
101
+ 'min': beta2_tensor.min().item(),
102
+ 'max': beta2_tensor.max().item(),
103
+ 'mean': beta2_tensor.mean().item(),
104
+ }
105
+
106
+ # Handle periodic console logging
107
+ k_logging_interval = self.optimizer.param_groups[0].get('k_logging', 0)
108
+ is_logging_step = k_logging_interval > 0 and (current_step + 1) % k_logging_interval == 0
109
+ if is_logging_step and self.last_beta2_stats:
110
+ print(f"[Kourkoutas-β Debug] Step {current_step + 1} Overall Beta2 Stats: Min={self.last_beta2_stats['min']:.4f}, Max={self.last_beta2_stats['max']:.4f}, Mean={self.last_beta2_stats['mean']:.4f}")
102
111
 
103
112
 
104
113
  def maybe_prepare_step(self, current_step: int):
@@ -113,7 +122,6 @@ class KourkoutasHelper:
113
122
  """
114
123
  Accumulates the squared L2 norm of a single gradient for the next step's calculation.
115
124
  """
116
- self._build_layer_info_if_needed()
117
125
  layer_key = self.optimizer.layer_key_fn(p)
118
126
 
119
127
  if layer_key in self.layer_info:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 1.1.0.dev2
3
+ Version: 1.1.0.dev4
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -1,20 +1,20 @@
1
- adv_optm/__init__.py,sha256=hkmbLr1AVDoC6VbnyTkNy-G4g5bmcLFH2Kv4dYWp9uY,311
1
+ adv_optm/__init__.py,sha256=H4E_1__pXxRu4PSgQCzGi7WuFqVjTfex2Yduz3B3peI,311
2
2
  adv_optm/optim/AdamW_adv.py,sha256=H4XlYZELwiFvXt0A9wMlRNiw9c8rmPMspHDCvR_SZIQ,17487
3
- adv_optm/optim/Adopt_adv.py,sha256=PJ3ZaLgzYbvxXDS56FGjzMrVMyHDXSWdUPHnX5NpNAA,21241
4
- adv_optm/optim/Lion_Prodigy_adv.py,sha256=sGzhts9a6gHfCkuHTB5L9IrClo4c6UThzYYErBwqOaA,12844
3
+ adv_optm/optim/Adopt_adv.py,sha256=0uMROjCw3wGOyp0ZX_xjwMVaXHJ395ifntcgY0MZt3M,21460
4
+ adv_optm/optim/Lion_Prodigy_adv.py,sha256=xIrwibQ2i919EHEACLCrKe5JBnS-s2Ai35yeJ1Bn1MA,13159
5
5
  adv_optm/optim/Lion_adv.py,sha256=6G1CukJB_pC7l9HwFEuY1ydsNHZFabVmOvcHDsHHVuQ,8295
6
- adv_optm/optim/Prodigy_adv.py,sha256=-eMTutexbGrUQtSehKaOo6BO_p3QySpSIMgJKWvbxog,25517
6
+ adv_optm/optim/Prodigy_adv.py,sha256=EeSfYu8IIeZX1Dk8MlD71vGOpMadtnW2iMhHxPDL2XQ,25574
7
7
  adv_optm/optim/Simplified_AdEMAMix.py,sha256=b4GaSI-TX6wFBqGxZeoJPbf2nVRCEtB3WVb1olDgY14,12980
8
8
  adv_optm/optim/__init__.py,sha256=pcP865H2j1tut2VfTUhzQh7V8TF_tzPjqFnjMfFed2k,382
9
9
  adv_optm/util/BF16_Stochastic_Rounding.py,sha256=Q5H0BcogmE4atP65dLoI21HKSf50lRdsBDfeF6v9Tbg,1548
10
10
  adv_optm/util/Effective_Shape.py,sha256=TBvIk1V8IuTbbBsxuekJA4e_v8JlR5Nujtut8RTWAm4,318
11
- adv_optm/util/Kourkoutas.py,sha256=6OzK96KJ7Dd9Py8hiGWszF9C_n4uVoDjFCA_EYbhL4c,6600
11
+ adv_optm/util/Kourkoutas.py,sha256=st9hO2I0Xcby0LLq1MhxiEsPyNzEkNpJO_WfYvkioKg,6606
12
12
  adv_optm/util/NNMF.py,sha256=yRf5IP5Sjq0Uf0DxN0Q8NxEGSdD-f1ULziLVDOjY8K4,639
13
13
  adv_optm/util/One_Bit_Boolean.py,sha256=Wat49esdwohuN-OHOFMW8D0aOQgV9cP5Rl8z6yfmpos,1068
14
14
  adv_optm/util/OrthoGrad.py,sha256=NzInuBQGy_Ja__M1R9XbvqVaQ0fhGbtGgFE9YON7B3I,707
15
15
  adv_optm/util/__init__.py,sha256=qoyIF0jcLjs_vSEcsv36clw5LFNBEbifyXrrVxMH-G4,349
16
- adv_optm-1.1.0.dev2.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
17
- adv_optm-1.1.0.dev2.dist-info/METADATA,sha256=Y2F2wkpPmdbRtHft1KdCm1D6feTmiP5kFJ6iYpSLwCo,8427
18
- adv_optm-1.1.0.dev2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
19
- adv_optm-1.1.0.dev2.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
20
- adv_optm-1.1.0.dev2.dist-info/RECORD,,
16
+ adv_optm-1.1.0.dev4.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
17
+ adv_optm-1.1.0.dev4.dist-info/METADATA,sha256=Ue6x-vthnxradX5tH1ver4LVbWMEMmqPjMVO8KjTdhI,8427
18
+ adv_optm-1.1.0.dev4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
19
+ adv_optm-1.1.0.dev4.dist-info/top_level.txt,sha256=iNfBIIzu-lPrQ7jyC56WBCcbkRwitM2nJ15-MRQ_6fg,9
20
+ adv_optm-1.1.0.dev4.dist-info/RECORD,,