adafactor8bit 0.2.1__tar.gz → 0.2.2__tar.gz

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adafactor8bit
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: 8-bit Adafactor Optimizer with Fused CUDA Kernels
5
5
  Home-page: https://github.com/yanfeiwong/adafactor-8bit
6
6
  Author: WANG YAN
@@ -25,6 +25,13 @@ Dynamic: requires-dist
25
25
  Dynamic: requires-python
26
26
  Dynamic: summary
27
27
 
28
+ <p align="center">
29
+ <a href="https://github.com/yanfeiwong/adafactor-8bit">
30
+ <img src="https://github.com/yanfeiwong/adafactor-8bit/raw/main/assets/banner.png"
31
+ alt="Adafactor8Bit"
32
+ width="80%">
33
+ </a>
34
+ </p>
28
35
  <div align="center">
29
36
 
30
37
  # 8-bit Adafactor with Fused CUDA Kernels
@@ -39,14 +46,15 @@ Dynamic: summary
39
46
 
40
47
  </div>
41
48
 
42
- An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, optional APOLLO low-rank updates, and 4-bit packed first moments, delivering substantially lower optimizer memory while preserving the low-overhead and numerical stability that make Adafactor attractive for training LLMs and diffusion models.
49
+ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, and optional add-ons including 4-bit packed first moments, APOLLO low-rank updates, and CAME confidence-guided optimization. It delivers substantially lower optimizer memory while preserving the low-overhead and numerical stability that make Adafactor attractive for training LLMs and diffusion models.
43
50
 
44
51
 
45
- ## 🔥 Key Features
52
+ ## Key Features
46
53
 
47
54
  - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
48
55
  - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
49
56
  - **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
57
+ - **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
50
58
  - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
51
59
  - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
52
60
  - **Zero CPU-GPU Sync**: Eliminates implicit synchronizations (e.g., D2H copies) in the control flow, ensuring the GPU computation pipeline runs without blocking.
@@ -194,16 +202,16 @@ def get_param_groups(model, lr_emb, weight_decay, apollo_rank=256):
194
202
  "weight_decay": weight_decay,
195
203
  "quantize": True,
196
204
  "apollo_rank": apollo_rank,
197
- "beta1":0.9, # Remove if minimizing optimizer memory is the priority.
205
+ "beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
198
206
  },
199
-
207
+
200
208
  # 4. >2D Weights: 8-bit quantization, Weight Decay, Full-Rank
201
209
  {
202
210
  "params": group_nd,
203
211
  "weight_decay": weight_decay,
204
212
  "quantize": True,
205
213
  "apollo_rank": 0,
206
- "beta1":0.9, # Remove if minimizing optimizer memory is the priority.
214
+ "beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
207
215
  "factored": False # Disables factorization to preserve spatial structures, enabling finer gradient scaling.
208
216
  # Note: This increases state memory for >2D weights, depending on your model architecture.
209
217
  # If VRAM is constrained, reverting to factored=True is a safe alternative.
@@ -270,7 +278,40 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
270
278
  - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
271
279
  - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
272
280
 
281
+ ## 🛡️ CAME Confidence-Guided Updates
282
+
283
+ Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
284
+
285
+ **Adaptive Scaling ($V$) → Momentum Accumulation ($M$) → Confidence Weighting ($C$)**
286
+
287
+ ### Key Parameters & Tuning
273
288
 
289
+ The confidence stage measures the consistency between the current update direction and historical momentum, adaptively suppressing highly oscillatory updates.
290
+
291
+ - **`beta3`**: EMA decay coefficient for the confidence matrix. Requires `beta1` (momentum) and `factored=True`. Mutually exclusive with `apollo_rank`. Defaults to `None` (disabled).
292
+ - **Learning Rate**: The official CAME implementation recommends **0.5–0.9×** the AdamW learning rate (see [official tuning guide](https://github.com/yangluo7/CAME/tree/master#hyper-parameter-tuning)). To use this learning rate in this library, you need to disable Adafactor's scaling and clipping (`scale_parameter=False`, `d=1e9`) to align with the original CAME behavior.
293
+ - **Warmup**: Since the confidence matrix is zero-initialized without bias correction, a learning rate warmup is recommended to safely establish the confidence baseline.
294
+ - **Choosing `beta3`**: `beta3` should generally be larger than `beta2` so the confidence estimate evolves more slowly than the variance estimate. A practical starting range is **0.9995–0.99995** when `beta2=0.999`.
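
Note on the `beta3` guidance above: an EMA with decay `beta` averages over roughly `1 / (1 - beta)` steps, which is one way to read "evolves more slowly". The sketch below only illustrates that rule of thumb; `ema_horizon` is a hypothetical helper and is not part of `adafactor8bit`.

```python
def ema_horizon(beta: float) -> float:
    """Approximate number of steps an EMA with decay `beta` averages over."""
    if not 0.0 <= beta < 1.0:
        raise ValueError(f"beta must be in [0.0, 1.0), got {beta}")
    return 1.0 / (1.0 - beta)


beta2 = 0.999
# Candidate beta3 values taken from the suggested starting range.
for beta in (beta2, 0.9995, 0.9999, 0.99995):
    print(f"beta={beta}: ~{ema_horizon(beta):,.0f} steps")
```

With `beta2=0.999` the variance estimate looks back about 1,000 steps, while the suggested `beta3` range corresponds to roughly 2,000 to 20,000 steps.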
295
+
296
+
297
+ ### Configuration Example
298
+
299
+ To replicate "vanilla" CAME (stripping Adafactor's native modifications), replace the standard 2D APOLLO group in your `param_groups` with the following configuration:
300
+
301
+ ```python
302
+ {
303
+ "params": param_group,
304
+ "lr": lr, # Original CAME recommends 0.5-0.9x AdamW LR
305
+ "weight_decay": weight_decay,
306
+ "quantize": True,
307
+ "beta1": 0.9,
308
+ "beta3": 0.9999, # Enable CAME confidence guidance
309
+ "apollo_rank": 0, # Mutually exclusive with CAME
310
+ "scale_parameter": False, # Disable Adafactor RMS scaling to align with vanilla CAME
311
+ "d": 1e9, # Disable Adafactor global RMS clipping
312
+ "enable_fira_for_adafactor": False, # Disable Fira Limiter to prevent interference with CAME's scaling
313
+ },
314
+ ```
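
Note on the configuration above: the README states that `beta3` requires `beta1` and `factored=True` and is mutually exclusive with `apollo_rank`. A group dictionary could be checked against those rules before it reaches the optimizer. This is a minimal sketch under that reading; `came_group_problems` is a hypothetical helper, and the fallback defaults are the ones documented in the constructor signature (`beta1=None`, `beta3=None`, `apollo_rank=0`, `factored=True`).

```python
from typing import Any, Dict, List


def came_group_problems(group: Dict[str, Any]) -> List[str]:
    """Return a list of rule violations for a CAME param group (empty if none).

    Hypothetical helper, not part of adafactor8bit.
    """
    beta3 = group.get("beta3")
    if beta3 is None:
        return []  # CAME is disabled, nothing to check
    problems = []
    if not 0.0 <= beta3 < 1.0:
        problems.append(f"beta3={beta3} must be in [0.0, 1.0)")
    if group.get("beta1") is None:
        problems.append("beta3 requires beta1 (momentum)")
    if group.get("apollo_rank", 0) > 0:
        problems.append("beta3 is mutually exclusive with apollo_rank > 0")
    if not group.get("factored", True):
        problems.append("beta3 requires factored=True")
    return problems


came_group = {"beta1": 0.9, "beta3": 0.9999, "apollo_rank": 0, "d": 1e9}
print(came_group_problems(came_group))
```

The example group mirrors the keys of the configuration above without the `params` entry, so it runs without a model.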
274
315
 
275
316
  ## 📈 Learning Rate Guide for Beginners
276
317
 
@@ -298,16 +339,18 @@ Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the appr
298
339
 
299
340
  Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
300
341
 
342
+ Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
343
+
301
344
  Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
302
345
 
303
346
  Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
304
347
 
348
+ ## 🏛️ License
349
+
350
+ [The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
351
+
305
352
  ## ⭐ Star the Project
306
353
 
307
354
  If this optimizer has been useful in your work, consider giving the repository a star. It helps others discover the project and supports future development.
308
355
 
309
356
  [![Star History Chart](https://api.star-history.com/svg?repos=yanfeiwong/adafactor-8bit&type=Date&theme=dark)](https://star-history.com/#yanfeiwong/adafactor-8bit&Date)
310
-
311
- ## 📄 License
312
-
313
- [The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
@@ -1,3 +1,10 @@
1
+ <p align="center">
2
+ <a href="https://github.com/yanfeiwong/adafactor-8bit">
3
+ <img src="https://github.com/yanfeiwong/adafactor-8bit/raw/main/assets/banner.png"
4
+ alt="Adafactor8Bit"
5
+ width="80%">
6
+ </a>
7
+ </p>
1
8
  <div align="center">
2
9
 
3
10
  # 8-bit Adafactor with Fused CUDA Kernels
@@ -12,14 +19,15 @@
12
19
 
13
20
  </div>
14
21
 
15
- An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, optional APOLLO low-rank updates, and 4-bit packed first moments, delivering substantially lower optimizer memory while preserving the low-overhead and numerical stability that make Adafactor attractive for training LLMs and diffusion models.
22
+ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, and optional add-ons including 4-bit packed first moments, APOLLO low-rank updates, and CAME confidence-guided optimization. It delivers substantially lower optimizer memory while preserving the low-overhead and numerical stability that make Adafactor attractive for training LLMs and diffusion models.
16
23
 
17
24
 
18
- ## 🔥 Key Features
25
+ ## Key Features
19
26
 
20
27
  - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
21
28
  - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
22
29
  - **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
30
+ - **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
23
31
  - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
24
32
  - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
25
33
  - **Zero CPU-GPU Sync**: Eliminates implicit synchronizations (e.g., D2H copies) in the control flow, ensuring the GPU computation pipeline runs without blocking.
@@ -167,16 +175,16 @@ def get_param_groups(model, lr_emb, weight_decay, apollo_rank=256):
167
175
  "weight_decay": weight_decay,
168
176
  "quantize": True,
169
177
  "apollo_rank": apollo_rank,
170
- "beta1":0.9, # Remove if minimizing optimizer memory is the priority.
178
+ "beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
171
179
  },
172
-
180
+
173
181
  # 4. >2D Weights: 8-bit quantization, Weight Decay, Full-Rank
174
182
  {
175
183
  "params": group_nd,
176
184
  "weight_decay": weight_decay,
177
185
  "quantize": True,
178
186
  "apollo_rank": 0,
179
- "beta1":0.9, # Remove if minimizing optimizer memory is the priority.
187
+ "beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
180
188
  "factored": False # Disables factorization to preserve spatial structures, enabling finer gradient scaling.
181
189
  # Note: This increases state memory for >2D weights, depending on your model architecture.
182
190
  # If VRAM is constrained, reverting to factored=True is a safe alternative.
@@ -243,7 +251,40 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
243
251
  - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
244
252
  - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
245
253
 
254
+ ## 🛡️ CAME Confidence-Guided Updates
255
+
256
+ Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
257
+
258
+ **Adaptive Scaling ($V$) → Momentum Accumulation ($M$) → Confidence Weighting ($C$)**
259
+
260
+ ### Key Parameters & Tuning
246
261
 
262
+ The confidence stage measures the consistency between the current update direction and historical momentum, adaptively suppressing highly oscillatory updates.
263
+
264
+ - **`beta3`**: EMA decay coefficient for the confidence matrix. Requires `beta1` (momentum) and `factored=True`. Mutually exclusive with `apollo_rank`. Defaults to `None` (disabled).
265
+ - **Learning Rate**: The official CAME implementation recommends **0.5–0.9×** the AdamW learning rate (see [official tuning guide](https://github.com/yangluo7/CAME/tree/master#hyper-parameter-tuning)). To use this learning rate in this library, you need to disable Adafactor's scaling and clipping (`scale_parameter=False`, `d=1e9`) to align with the original CAME behavior.
266
+ - **Warmup**: Since the confidence matrix is zero-initialized without bias correction, a learning rate warmup is recommended to safely establish the confidence baseline.
267
+ - **Choosing `beta3`**: `beta3` should generally be larger than `beta2` so the confidence estimate evolves more slowly than the variance estimate. A practical starting range is **0.9995–0.99995** when `beta2=0.999`.
268
+
269
+
270
+ ### Configuration Example
271
+
272
+ To replicate "vanilla" CAME (stripping Adafactor's native modifications), replace the standard 2D APOLLO group in your `param_groups` with the following configuration:
273
+
274
+ ```python
275
+ {
276
+ "params": param_group,
277
+ "lr": lr, # Original CAME recommends 0.5-0.9x AdamW LR
278
+ "weight_decay": weight_decay,
279
+ "quantize": True,
280
+ "beta1": 0.9,
281
+ "beta3": 0.9999, # Enable CAME confidence guidance
282
+ "apollo_rank": 0, # Mutually exclusive with CAME
283
+ "scale_parameter": False, # Disable Adafactor RMS scaling to align with vanilla CAME
284
+ "d": 1e9, # Disable Adafactor global RMS clipping
285
+ "enable_fira_for_adafactor": False, # Disable Fira Limiter to prevent interference with CAME's scaling
286
+ },
287
+ ```
247
288
 
248
289
  ## 📈 Learning Rate Guide for Beginners
249
290
 
@@ -271,16 +312,18 @@ Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the appr
271
312
 
272
313
  Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
273
314
 
315
+ Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
316
+
274
317
  Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
275
318
 
276
319
  Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
277
320
 
321
+ ## 🏛️ License
322
+
323
+ [The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
324
+
278
325
  ## ⭐ Star the Project
279
326
 
280
327
  If this optimizer has been useful in your work, consider giving the repository a star. It helps others discover the project and supports future development.
281
328
 
282
329
  [![Star History Chart](https://api.star-history.com/svg?repos=yanfeiwong/adafactor-8bit&type=Date&theme=dark)](https://star-history.com/#yanfeiwong/adafactor-8bit&Date)
283
-
284
- ## 📄 License
285
-
286
- [The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
@@ -1112,6 +1112,84 @@ void apply_update_1d_full_m_cuda(
1112
1112
  }
1113
1113
 
1114
1114
 
1115
+ // ==========================================
1116
+ // 15. CAME: Compute Residual Variance (Row & Col)
1117
+ // ==========================================
1118
+ __global__ void came_compute_residual_2d_kernel(
1119
+ const unsigned char* __restrict__ m_q, const float* __restrict__ m_scale,
1120
+ const unsigned char* __restrict__ row_var_q, const float* __restrict__ row_var_scale,
1121
+ const unsigned char* __restrict__ col_var_q, const float* __restrict__ col_var_scale,
1122
+ const float* __restrict__ grad,
1123
+ const float* __restrict__ row_mean_val_ptr,
1124
+ float* __restrict__ res_row_sum, float* __restrict__ res_col_sum,
1125
+ float log_eps_sq, int R, int C, int numel, int m_block_size, int v_block_size)
1126
+ {
1127
+ int stride = gridDim.x * blockDim.x;
1128
+ for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < numel; idx += stride) {
1129
+ int b = idx / (R * C);
1130
+ int r = (idx / C) % R;
1131
+ int c = idx % C;
1132
+
1133
+ unsigned char packed = m_q[idx / 2];
1134
+ int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
1135
+ float m_val = (float)(q_int - 8) * m_scale[idx / m_block_size];
1136
+
1137
+ float log_r = (float)row_var_q[b * R + r] * INV_255 * row_var_scale[(b * R + r) / v_block_size] + MIN_LOG;
1138
+ float log_c = (float)col_var_q[b * C + c] * INV_255 * col_var_scale[(b * C + c) / v_block_size] + MIN_LOG;
1139
+ float log_row_mean = log2f(fmaxf(row_mean_val_ptr[b], MIN_VAL));
1140
+
1141
+ float log_v_ij = log_r + log_c - log_row_mean;
1142
+ float max_log = fmaxf(log_v_ij, log_eps_sq);
1143
+ max_log = fmaxf(max_log, -53.0f);
1144
+ float inv_std = exp2f(-0.5f * max_log);
1145
+
1146
+ float diff = (grad[idx] - m_val) * inv_std;
1147
+ float res = diff * diff;
1148
+
1149
+ atomicAdd(&res_col_sum[b * C + c], res);
1150
+
1151
+ int row_idx = b * R + r;
1152
+ int lane = threadIdx.x % 32;
1153
+
1154
+ for (int offset = 16; offset > 0; offset /= 2) {
1155
+ int other_row_idx = __shfl_down_sync(0xffffffff, row_idx, offset);
1156
+ float other_res = __shfl_down_sync(0xffffffff, res, offset);
1157
+ if (lane + offset < 32 && row_idx == other_row_idx) {
1158
+ res += other_res;
1159
+ }
1160
+ }
1161
+
1162
+ int prev_row_idx = __shfl_up_sync(0xffffffff, row_idx, 1);
1163
+ bool is_first_in_row = (lane == 0) || (row_idx != prev_row_idx);
1164
+
1165
+ if (is_first_in_row) {
1166
+ atomicAdd(&res_row_sum[row_idx], res);
1167
+ }
1168
+ }
1169
+ }
1170
+
1171
+ void came_compute_residual_2d_cuda(
1172
+ torch::Tensor m_q, torch::Tensor m_scale,
1173
+ torch::Tensor row_var_q, torch::Tensor row_var_scale,
1174
+ torch::Tensor col_var_q, torch::Tensor col_var_scale,
1175
+ torch::Tensor grad, torch::Tensor row_mean_val,
1176
+ torch::Tensor res_row_sum, torch::Tensor res_col_sum,
1177
+ float log_eps_sq, int R, int C, int numel, int m_block_size, int v_block_size)
1178
+ {
1179
+ int threads = 256;
1180
+ int blocks = min(1024, (numel + threads - 1) / threads);
1181
+ came_compute_residual_2d_kernel<<<blocks, threads>>>(
1182
+ m_q.data_ptr<unsigned char>(), m_scale.data_ptr<float>(),
1183
+ row_var_q.data_ptr<unsigned char>(), row_var_scale.data_ptr<float>(),
1184
+ col_var_q.data_ptr<unsigned char>(), col_var_scale.data_ptr<float>(),
1185
+ grad.data_ptr<float>(), row_mean_val.data_ptr<float>(),
1186
+ res_row_sum.data_ptr<float>(), res_col_sum.data_ptr<float>(),
1187
+ log_eps_sq, R, C, numel, m_block_size, v_block_size
1188
+ );
1189
+ }
1190
+
1191
+
1192
+
1115
1193
  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1116
1194
  m.def("fused_log_quantize_lerp", &fused_log_quantize_lerp_cuda, "Fused log quantize lerp (CUDA)");
1117
1195
  m.def("fused_4bit_quantize_lerp", &fused_4bit_quantize_lerp_cuda, "Fused 4-bit packed quantize lerp for m_t (CUDA)");
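
Note on the hunk above: `came_compute_residual_2d_kernel` reads the 4-bit momentum by taking the high nibble of a byte for even indices and the low nibble for odd indices, subtracting 8, and multiplying by the per-block scale. The Python sketch below restates only that indexing for illustration. The function name and the toy block size of 4 are assumptions; the optimizer itself requires `m_block_size >= 32`.

```python
from typing import List, Sequence


def unpack_4bit_momentum(
    packed: bytes, scales: Sequence[float], numel: int, m_block_size: int
) -> List[float]:
    """Decode `numel` packed 4-bit codes into floats, two codes per byte."""
    values = []
    for idx in range(numel):
        byte = packed[idx // 2]
        # Even index -> high nibble, odd index -> low nibble (as in the kernel).
        q_int = (byte & 0x0F) if (idx & 1) else (byte >> 4)
        # Codes are stored with an offset of 8, so 8 decodes to zero.
        values.append((q_int - 8) * scales[idx // m_block_size])
    return values


packed = bytes([0x8F, 0x07])  # nibbles in order: 8, 15, 0, 7
print(unpack_4bit_momentum(packed, scales=[0.5], numel=4, m_block_size=4))
# -> [0.0, 3.5, -4.0, -0.5]
```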
@@ -1134,4 +1212,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1134
1212
 
1135
1213
  m.def("compute_update_norm_1d_full_m", &compute_update_norm_1d_full_m_cuda, "Compute update norm 1D full precision with momentum (CUDA)");
1136
1214
  m.def("apply_update_1d_full_m", &apply_update_1d_full_m_cuda, "Apply update 1D full precision with momentum (CUDA)");
1215
+
1216
+ m.def("came_compute_residual_2d", &came_compute_residual_2d_cuda, "Compute CAME residual row/col sums (CUDA)");
1137
1217
  }
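
Note on the new kernel: the row sums are accumulated with a segmented warp reduction, so each run of lanes sharing a `row_idx` issues a single `atomicAdd`. The Python simulation below is a sketch of that pattern, not the library's code. It assumes all 32 lanes are active and that row indices are non-decreasing across lanes, which holds when consecutive lanes process consecutive elements. `shfl_down` models `__shfl_down_sync`, where a lane whose source is out of range keeps its own value.

```python
from typing import Dict, List

WARP = 32


def shfl_down(values: List, offset: int) -> List:
    """Lane i reads lane i + offset; out-of-range lanes read their own value."""
    return [values[i + offset] if i + offset < WARP else values[i] for i in range(WARP)]


def segmented_row_sums(row_idx: List[int], res: List[float]) -> Dict[int, float]:
    """Simulate the shuffle-based segmented reduction for one warp."""
    res = list(res)
    offset = 16
    while offset > 0:
        other_row = shfl_down(row_idx, offset)
        other_res = shfl_down(res, offset)
        # Every lane updates at once, so build a new list from the old values.
        res = [
            res[i] + other_res[i]
            if (i + offset < WARP and row_idx[i] == other_row[i])
            else res[i]
            for i in range(WARP)
        ]
        offset //= 2
    sums: Dict[int, float] = {}
    for lane in range(WARP):
        # Only the first lane of each row segment would issue the atomicAdd.
        if lane == 0 or row_idx[lane] != row_idx[lane - 1]:
            sums[row_idx[lane]] = sums.get(row_idx[lane], 0.0) + res[lane]
    return sums


rows = [0] * 10 + [1] * 22
print(segmented_row_sums(rows, [1.0] * WARP))
# -> {0: 10.0, 1: 22.0}
```

The simulation does not model lanes that leave the grid-stride loop early; in that case a full-mask shuffle would involve inactive lanes, which this sketch deliberately ignores.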
@@ -116,6 +116,9 @@ def _log_quantize_nonneg(tensor: Tensor, block_size: int = 2048) -> Tuple[Tensor
116
116
 
117
117
  def _log_dequantize_nonneg(q: Tensor, scale: Tensor, shape: torch.Size, pad: int) -> Tensor:
118
118
  """Dequantize from log-space back to linear-space FP32."""
119
+ if q.dim() == 1:
120
+ block_size = q.numel() // scale.numel()
121
+ q = q.view(-1, block_size)
119
122
  log_blocks = q.float() * scale.unsqueeze(-1) * _INV_255 + _FP32_MIN_LOG
120
123
  blocks = torch.pow(2.0, log_blocks)
121
124
  flat = blocks.flatten()
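
Note on the hunk above: the added branch lets `_log_dequantize_nonneg` accept a flat 1D code tensor by inferring the block size from `q.numel() // scale.numel()`. A pure-Python sketch of the same arithmetic follows. `MIN_LOG = -126.0` is an assumed stand-in for the library's `_FP32_MIN_LOG` constant, whose value is not shown in this diff, and the sketch ignores the `shape` and `pad` arguments.

```python
from typing import List, Sequence

INV_255 = 1.0 / 255.0
MIN_LOG = -126.0  # assumption: placeholder for the library's _FP32_MIN_LOG


def log_dequantize_flat(q: Sequence[int], scale: Sequence[float]) -> List[float]:
    """Map flat 8-bit codes back to linear space, one scale per block."""
    if len(scale) == 0 or len(q) % len(scale) != 0:
        raise ValueError("len(q) must be a positive multiple of len(scale)")
    block_size = len(q) // len(scale)  # inferred, as in the added 1D branch
    out = []
    for i, code in enumerate(q):
        log_val = code * scale[i // block_size] * INV_255 + MIN_LOG
        out.append(2.0 ** log_val)
    return out


codes = [0, 255, 0, 255]      # two blocks of two codes each
scales = [126.0, 63.0]        # per-block log-range
print(log_dequantize_flat(codes, scales))
```

Code 0 maps to the floor `2 ** MIN_LOG`, and code 255 maps to `2 ** (scale + MIN_LOG)`, so each block's scale sets how many octaves the 8-bit codes span.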
@@ -165,6 +168,8 @@ class Adafactor8Bit(Optimizer):
165
168
 
166
169
  Args:
167
170
  params (Iterable): Iterable of parameters to optimize or dictionaries defining parameter groups.
171
+
172
+ --- Core Optimization ---
168
173
  lr (float, optional): External learning rate. Defaults to 1e-2.
169
174
  beta1 (float, optional): Momentum coefficient for first moment (4-bit packed).
170
175
  If None, disables first moment (pure Adafactor/RMSProp). Defaults to None.
@@ -174,89 +179,110 @@ class Adafactor8Bit(Optimizer):
174
179
  beta2_decay (float): Dynamic decay rate coefficient.
175
180
  The EMA weight is computed as `step ** beta2_decay`. Ignored if `beta2` is specified.
176
181
  Defaults to -0.8.
182
+ beta3 (float, optional): Confidence-guided decay coefficient for CAME
183
+ (Confidence-guided Adaptive Memory Efficient Optimization).
184
+ Computes the instability of the update direction and scales the update accordingly.
185
+ Strictly requires `beta1` and `factored=True`. Mutually exclusive with `apollo_rank`.
186
+ Defaults to None (disabled).
177
187
  eps (Tuple[Optional[float], float]): Regularization constants (eps1, eps2).
178
188
  - `eps1`: Added to the squared gradient. If `None`, defaults to the machine epsilon
179
- of the parameter's dtype (e.g., ~1.19e-7 for FP32), aligning with PyTorch official
180
- behavior and preventing underflow.
189
+ of the parameter's dtype (e.g., ~1.19e-7 for FP32), preventing underflow.
181
190
  - `eps2`: Lower threshold for parameter RMS scaling. Defaults to (None, 1e-3).
191
+ weight_decay (float): Weight decay (L2 penalty). Defaults to 0.0.
182
192
  d (float): Clipping threshold for the final gradient update RMS.
183
193
  Setting to an extremely large value (e.g., ``1e9``) effectively disables the global
184
- clipping constraint, which can be useful for decoupling updates in highly sparse layers
185
- like Embeddings. Defaults to 1.0.
186
- weight_decay (float): Weight decay (L2 penalty). Defaults to 0.0.
187
- scale_weight_decay (bool): If `True` (default), weight decay is coupled with the
188
- parameter's RMS scale. If `False`, weight decay is decoupled and only scaled by the
189
- base learning rate (AdamW-style).
194
+ clipping constraint, useful for decoupling updates in sparse layers like Embeddings.
195
+ Defaults to 1.0.
190
196
  maximize (bool): Maximize the params based on the objective. Defaults to False.
197
+
198
+ --- Factorization & Scaling ---
191
199
  relative_step (bool): If `True`, uses time-dependent learning rate. Defaults to True.
192
200
  scale_parameter (bool): If `True`, scales learning rate by parameter RMS.
193
201
  Setting to False decouples the step size from parameter magnitude, which can be useful
194
202
  for sparse layers like Embeddings to ensure sufficient update strength. Defaults to True.
203
+ factored (bool): Whether to use row/col factorization for >=2D tensors.
204
+ Setting to False uses element-wise variance (like RMSProp, but still applies Adafactor's
205
+ global RMS clipping). This can be useful for preserving spatial structure in >2D tensors
206
+ such as CNN convolutions, or enabling per-element updates in Embeddings. Defaults to True.
207
+
208
+ --- Quantization Control ---
195
209
  quantize (bool): Enable 8-bit log-space quantization for optimizer states. Defaults to True.
196
210
  block_size (int): Block size for variance quantization. Must be a multiple of 1024. Defaults to 2048.
197
211
  m_block_size (int): Block size for 4-bit momentum quantization.
198
- Balance outlier robustness and memory overhead. Must be a multiple of 4 and >= 32. Defaults to 128.
212
+ Must be a multiple of 4 and >= 32. Defaults to 128.
199
213
  min_8bit_size (int): Minimum number of elements to apply 8-bit quantization. Defaults to 4096.
200
214
  use_cuda_kernel (bool): Whether to use custom CUDA kernels. Defaults to True.
201
- apollo_rank (int): If > 0, enables APOLLO-style random projection to low-rank space
202
- before applying Adafactor. Defaults to 0 (disabled).
203
- apollo_update_proj_gap (int): Steps between projection matrix updates. Defaults to 200.
204
- apollo_scale_type (str): How to compute the gradient scaling factor: 'channel' or 'tensor'.
215
+
216
+ --- APOLLO Low-Rank Projection ---
217
+ apollo_rank (int): Rank for APOLLO (An Optimizer for Memory-Efficient Large-Scale Training)
218
+ style random projection to low-rank space. If > 0, enables APOLLO.
219
+ Mutually exclusive with `beta3` (CAME). Defaults to 0 (disabled).
220
+ apollo_update_proj_gap (int): Steps between random projection matrix refreshes.
221
+ Defaults to 200.
222
+ apollo_scale_type (str): Strategy to map low-rank updates back to full-rank:
223
+ 'channel' (row-wise norm matching) or 'tensor' (global norm matching).
205
224
  Defaults to 'channel'.
206
- apollo_eps (float): Epsilon for low-rank variance normalization. Defaults to 1e-8.
207
- apollo_factorize (bool): If True, uses Adafactor-style row/col factorization in the
208
- low-rank space (FP32, ~16KB state) instead of full matrix variance (8-bit, ~100KB+ state).
209
- For large models to drastically reduce optimizer state memory. Defaults to False.
210
- enable_fira_for_adafactor (bool): If `True`, enables Fira Limiter for the standard Adafactor path
211
- to prevent gradient explosion by smoothing update norms. Defaults to False.
212
- fira_margin (float): The tolerance margin for Fira Limiter. The limiter activates when the
213
- update norm grows by more than `fira_margin` (e.g., 0.01 for 1%). Shared with Apollo path.
214
- Defaults to 0.01.
215
- factored (bool): Whether to use row/col factorization for >=2D tensors.
216
- Setting to False uses element-wise variance (like RMSProp, but still applies Adafactor's
217
- global RMS clipping), which can be useful for preserving spatial structure in >2D tensors
218
- such as CNN convolutions, or enabling per-element updates in highly sparse layers like
219
- Embeddings. Defaults to True.
225
+ apollo_eps (float): Epsilon for low-rank variance normalization to prevent division by zero.
226
+ Defaults to 1e-8.
227
+ apollo_factorize (bool): If True, applies Adafactor-style row/col factorization
228
+ within the low-rank space (FP32, ~16KB state) instead of full matrix variance
229
+ (8-bit, ~100KB+ state) to drastically reduce optimizer state memory. Defaults to False.
230
+
231
+ --- Stabilizers & Regularization ---
232
+ scale_weight_decay (bool): If `True` (default), weight decay is coupled with the
233
+ parameter's RMS scale. If `False`, decoupled (AdamW-style).
234
+ enable_fira_for_adafactor (bool): If `True`, enables Fira Limiter to prevent gradient
235
+ explosion by smoothing update norms. Defaults to False.
236
+ fira_margin (float): The tolerance margin for Fira Limiter (e.g., 0.01 for 1%).
237
+ Shared with Apollo path. Defaults to 0.01.
220
238
  """
239
+
221
240
  def __init__(
222
241
  self,
223
242
  params: Iterable[Union[Tensor, Dict[str, Any]]],
243
+ # --- Core Optimization ---
224
244
  lr: float = 1e-2,
225
245
  beta1: Optional[float] = None,
226
246
  beta2: Optional[float] = None,
227
247
  beta2_decay: float = -0.8,
248
+ beta3: Optional[float] = None,
228
249
  eps: Tuple[Optional[float], float] = (None, 1e-3),
229
- d: float = 1.0,
230
250
  weight_decay: float = 0.0,
251
+ d: float = 1.0,
231
252
  maximize: bool = False,
253
+ # --- Factorization & Scaling ---
232
254
  relative_step: bool = True,
233
255
  scale_parameter: bool = True,
234
- scale_weight_decay: bool = True,
256
+ factored: bool = True,
257
+ # --- Quantization Control ---
235
258
  quantize: bool = True,
236
259
  block_size: int = 2048,
237
260
  m_block_size: int = 128,
238
261
  min_8bit_size: int = 4096,
239
262
  use_cuda_kernel: bool = True,
263
+ # --- APOLLO Low-Rank Projection ---
240
264
  apollo_rank: int = 0,
241
265
  apollo_update_proj_gap: int = 200,
242
266
  apollo_scale_type: str = 'channel',
243
267
  apollo_eps: float = 1e-8,
244
268
  apollo_factorize: bool = False,
269
+ # --- Stabilizers & Regularization ---
270
+ scale_weight_decay: bool = True,
245
271
  enable_fira_for_adafactor: bool = False,
246
272
  fira_margin: float = 0.01,
247
- factored: bool = True,
248
273
  ):
249
- if not 0.0 <= lr: raise ValueError(f"Invalid lr: {lr}")
250
- if beta1 is not None and not (0.0 <= beta1 < 1.0):
274
+
275
+ if lr < 0.0: raise ValueError(f"Invalid lr: {lr}, must be >= 0.0")
276
+ if beta1 is not None and (beta1 < 0.0 or beta1 >= 1.0):
251
277
  raise ValueError(f"Invalid beta1: {beta1}, must be in [0.0, 1.0)")
252
- if not 0.0 >= beta2_decay: raise ValueError(f"Invalid beta2_decay: {beta2_decay}")
278
+ if beta2_decay > 0.0: raise ValueError(f"Invalid beta2_decay: {beta2_decay}, must be <= 0.0")
253
279
  eps1, eps2 = eps
254
- if eps1 is not None and not 0.0 <= eps1: raise ValueError(f"Invalid eps1: {eps1}")
255
- if not 0.0 <= eps2: raise ValueError(f"Invalid eps2: {eps2}")
256
- if not 1.0 <= d: raise ValueError(f"Invalid d: {d}")
257
- if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay: {weight_decay}")
280
+ if eps1 is not None and eps1 < 0.0: raise ValueError(f"Invalid eps1: {eps1}, must be >= 0.0")
281
+ if eps2 < 0.0: raise ValueError(f"Invalid eps2: {eps2}, must be >= 0.0")
282
+ if d < 1.0: raise ValueError(f"Invalid d: {d}, must be >= 1.0")
283
+ if weight_decay < 0.0: raise ValueError(f"Invalid weight_decay: {weight_decay}, must be >= 0.0")
258
284
 
259
- if beta2 is not None and not (0.0 <= beta2 < 1.0):
285
+ if beta2 is not None and (beta2 < 0.0 or beta2 >= 1.0):
260
286
  raise ValueError(f"Invalid beta2: {beta2}, must be in [0.0, 1.0)")
261
287
 
262
288
  if quantize and block_size % 1024 != 0:
@@ -268,21 +294,33 @@ class Adafactor8Bit(Optimizer):
268
294
  if apollo_rank > 0 and apollo_scale_type not in ('channel', 'tensor'):
269
295
  raise ValueError(f"apollo_scale_type must be 'channel' or 'tensor', got {apollo_scale_type}.")
270
296
 
271
- if not 0.0 <= fira_margin: raise ValueError(f"Invalid fira_margin: {fira_margin}")
297
+ if fira_margin < 0.0: raise ValueError(f"Invalid fira_margin: {fira_margin}, must be >= 0.0")
298
+
299
+ if beta3 is not None:
300
+ if beta3 < 0.0 or beta3 >= 1.0:
301
+ raise ValueError(f"Invalid beta3: {beta3}, must be in [0.0, 1.0)")
302
+ if beta1 is None:
303
+ raise ValueError("CAME (beta3) strictly requires momentum (beta1) to compute update instability.")
304
+ if apollo_rank > 0:
305
+ raise ValueError("CAME (beta3) and APOLLO (apollo_rank > 0) are mutually exclusive optimization strategies.")
306
+ if not factored:
307
+ raise ValueError("CAME (beta3) requires factored=True (2D row/col factorization). It is not supported for 1D full-rank paths.")
272
308
 
273
309
  defaults = dict(
274
- lr=lr, beta1=beta1, beta2_decay=beta2_decay, beta2=beta2, eps=eps, d=d, weight_decay=weight_decay,
275
- maximize=maximize, relative_step=relative_step, scale_parameter=scale_parameter,
276
- scale_weight_decay=scale_weight_decay,
277
- quantize=quantize, block_size=block_size, m_block_size=m_block_size, min_8bit_size=min_8bit_size,
278
- use_cuda_kernel=use_cuda_kernel,
279
- apollo_rank=apollo_rank,
280
- apollo_update_proj_gap=apollo_update_proj_gap,
281
- apollo_scale_type=apollo_scale_type, apollo_eps=apollo_eps,
282
- apollo_factorize=apollo_factorize,
283
- enable_fira_for_adafactor=enable_fira_for_adafactor,
284
- fira_margin=fira_margin,
285
- factored=factored,
310
+ # Core Optimization
311
+ lr=lr, beta1=beta1, beta2=beta2, beta2_decay=beta2_decay, beta3=beta3,
312
+ eps=eps, weight_decay=weight_decay, d=d, maximize=maximize,
313
+ # Factorization & Scaling
314
+ relative_step=relative_step, scale_parameter=scale_parameter, factored=factored,
315
+ # Quantization Control
316
+ quantize=quantize, block_size=block_size, m_block_size=m_block_size,
317
+ min_8bit_size=min_8bit_size, use_cuda_kernel=use_cuda_kernel,
318
+ # APOLLO Low-Rank Projection
319
+ apollo_rank=apollo_rank, apollo_update_proj_gap=apollo_update_proj_gap,
320
+ apollo_scale_type=apollo_scale_type, apollo_eps=apollo_eps, apollo_factorize=apollo_factorize,
321
+ # Stabilizers & Regularization
322
+ scale_weight_decay=scale_weight_decay,
323
+ enable_fira_for_adafactor=enable_fira_for_adafactor, fira_margin=fira_margin,
286
324
  )
287
325
  super().__init__(params, defaults)
288
326
 
@@ -330,8 +368,10 @@ class Adafactor8Bit(Optimizer):
330
368
  m_block_size = group.get("m_block_size", 128)
331
369
  min_8bit_size = group.get("min_8bit_size", 4096)
332
370
  apollo_rank = group.get("apollo_rank", 0)
371
+ apollo_factorize = group.get("apollo_factorize", False)
333
372
  factored = group.get("factored", True)
334
373
  beta1 = group.get("beta1")
374
+ beta3 = group.get("beta3")
335
375
 
336
376
  for p in group["params"]:
337
377
  if p.grad is None: continue
@@ -358,8 +398,8 @@ class Adafactor8Bit(Optimizer):
358
398
  state["step"] = step_backup
359
399
  needs_init = True
360
400
  elif use_apollo and is_apollo_state:
361
- if state.get("apollo_rank") != apollo_rank:
362
- logger.warning(f"Adafactor8Bit: Apollo rank changed for param shape {p.shape}. Re-initializing state.")
401
+ if state.get("apollo_rank") != apollo_rank or state.get("apollo_factorize", False) != apollo_factorize:
402
+ logger.warning(f"Adafactor8Bit: Apollo config changed for param shape {p.shape}. Re-initializing state.")
363
403
  step_backup = state.get("step", 0)
364
404
  state.clear()
365
405
  state["step"] = step_backup
@@ -424,11 +464,25 @@ class Adafactor8Bit(Optimizer):
424
464
  state["m_q"] = torch.full((m_padded_numel // 2,), 0x88, dtype=torch.uint8, device=p.device)
425
465
  state["m_scale"] = torch.ones(m_padded_numel // m_block_size, dtype=torch.float32, device=p.device)
426
466
  state["m_block_size"] = m_block_size
467
+
468
+ if beta3 is not None:
469
+ state["conf_row_q"] = torch.zeros_like(state["row_var_q"])
470
+ state["conf_row_scale"] = torch.ones_like(state["row_var_scale"])
471
+ state["conf_row_shape"] = state["row_var_shape"]
472
+ state["conf_row_pad"] = state["row_var_pad"]
473
+
474
+ state["conf_col_q"] = torch.zeros_like(state["col_var_q"])
475
+ state["conf_col_scale"] = torch.ones_like(state["col_var_scale"])
476
+ state["conf_col_shape"] = state["col_var_shape"]
477
+ state["conf_col_pad"] = state["col_var_pad"]
427
478
  else:
428
- state["row_var"] = torch.zeros(r_shape, device=p.device)
429
- state["col_var"] = torch.zeros(c_shape, device=p.device)
479
+ state["row_var"] = torch.zeros(r_shape, dtype=torch.float32, device=p.device)
480
+ state["col_var"] = torch.zeros(c_shape, dtype=torch.float32, device=p.device)
430
481
  if beta1 is not None:
431
- state["m"] = torch.zeros_like(p.grad, device=p.device, memory_format=torch.preserve_format)
482
+ state["m"] = torch.zeros_like(p.grad, dtype=torch.float32, device=p.device, memory_format=torch.preserve_format)
483
+ if beta3 is not None:
484
+ state["conf_row"] = torch.zeros(r_shape, device=p.device)
485
+ state["conf_col"] = torch.zeros(c_shape, device=p.device)
432
486
  else:
433
487
  if use_quant:
434
488
  v_numel = p.grad.numel()
@@ -444,9 +498,9 @@ class Adafactor8Bit(Optimizer):
444
498
  state["m_scale"] = torch.ones(m_padded_numel // m_block_size, dtype=torch.float32, device=p.device)
445
499
  state["m_block_size"] = m_block_size
446
500
  else:
447
- state["variance"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
501
+ state["variance"] = torch.zeros_like(p.grad, dtype=torch.float32, memory_format=torch.preserve_format)
448
502
  if beta1 is not None:
449
- state["m"] = torch.zeros_like(p.grad, device=p.device, memory_format=torch.preserve_format)
503
+ state["m"] = torch.zeros_like(p.grad, dtype=torch.float32, device=p.device, memory_format=torch.preserve_format)
450
504
  else:
451
505
  if torch.is_tensor(state["step"]):
452
506
  state["step"] = int(state["step"].cpu().item())
@@ -466,6 +520,16 @@ class Adafactor8Bit(Optimizer):
466
520
  state_is_factored = ("row_var" in state or "row_var_q" in state)
467
521
 
468
522
  if use_quant and not state.get("is_quantized", False):
523
+ if isinstance(state.get("v_low"), Tensor) and state.get("v_low_q") is None:
524
+ state["v_low"].clamp_(min=_FP32_TINY)
525
+ q, s, sh, pad = _log_quantize_nonneg(state["v_low"], curr_block_size)
526
+ state["v_low_q"], state["v_low_scale"], state["v_low_shape"], state["v_low_pad"] = q, s, sh, pad
527
+ state["v_low"] = None
528
+
529
+ if "m_low" in state:
530
+ logger.warning("Adafactor8Bit: Apollo m_low discarded due to quantize flag change.")
531
+ state.pop("m_low", None)
532
+
469
533
  if state_is_factored:
470
534
  if "row_var" in state and "row_var_q" not in state:
471
535
  state["row_var"].clamp_(min=_FP32_TINY)
@@ -477,6 +541,17 @@ class Adafactor8Bit(Optimizer):
477
541
  q, s, sh, pad = _log_quantize_nonneg(state["col_var"], curr_block_size)
478
542
  state["col_var_q"], state["col_var_scale"], state["col_var_shape"], state["col_var_pad"] = q, s, sh, pad
479
543
  del state["col_var"]
544
+ if beta3 is not None:
545
+ if "conf_row" in state and "conf_row_q" not in state:
546
+ state["conf_row"].clamp_(min=_FP32_TINY)
547
+ q, s, sh, pad = _log_quantize_nonneg(state["conf_row"], curr_block_size)
548
+ state["conf_row_q"], state["conf_row_scale"], state["conf_row_shape"], state["conf_row_pad"] = q, s, sh, pad
549
+ del state["conf_row"]
550
+ if "conf_col" in state and "conf_col_q" not in state:
551
+ state["conf_col"].clamp_(min=_FP32_TINY)
552
+ q, s, sh, pad = _log_quantize_nonneg(state["conf_col"], curr_block_size)
553
+ state["conf_col_q"], state["conf_col_scale"], state["conf_col_shape"], state["conf_col_pad"] = q, s, sh, pad
554
+ del state["conf_col"]
480
555
  else:
481
556
  if "variance" in state and "variance_q" not in state:
482
557
  state["variance"].clamp_(min=_FP32_TINY)
@@ -498,11 +573,27 @@ class Adafactor8Bit(Optimizer):
498
573
  state["is_quantized"] = True
499
574
 
500
575
  elif not use_quant and state.get("is_quantized", False):
576
+ if isinstance(state.get("v_low_q"), Tensor):
577
+ state["v_low"] = _log_dequantize_nonneg(
578
+ state.pop("v_low_q"), state.pop("v_low_scale"),
579
+ state.pop("v_low_shape"), state.pop("v_low_pad")
580
+ )
581
+
582
+ if "m_low_q" in state:
583
+ logger.warning("Adafactor8Bit: Apollo m_low_q discarded due to quantize flag change.")
584
+ state.pop("m_low_q", None)
585
+ state.pop("m_low_scale", None)
586
+
501
587
  if state_is_factored:
502
588
  if "row_var_q" in state:
503
589
  state["row_var"] = _log_dequantize_nonneg(state.pop("row_var_q"), state.pop("row_var_scale"), state.pop("row_var_shape"), state.pop("row_var_pad"))
504
590
  if "col_var_q" in state:
505
591
  state["col_var"] = _log_dequantize_nonneg(state.pop("col_var_q"), state.pop("col_var_scale"), state.pop("col_var_shape"), state.pop("col_var_pad"))
592
+ if beta3 is not None:
593
+ if "conf_row_q" in state:
594
+ state["conf_row"] = _log_dequantize_nonneg(state.pop("conf_row_q"), state.pop("conf_row_scale"), state.pop("conf_row_shape"), state.pop("conf_row_pad"))
595
+ if "conf_col_q" in state:
596
+ state["conf_col"] = _log_dequantize_nonneg(state.pop("conf_col_q"), state.pop("conf_col_scale"), state.pop("conf_col_shape"), state.pop("conf_col_pad"))
506
597
  else:
507
598
  if "variance_q" in state:
508
599
  state["variance"] = _log_dequantize_nonneg(state.pop("variance_q"), state.pop("variance_scale"), state.pop("variance_shape"), state.pop("variance_pad"))
@@ -597,6 +688,7 @@ class Adafactor8Bit(Optimizer):
597
688
  enable_fira_for_adafactor=group.get("enable_fira_for_adafactor", False),
598
689
  fira_margin=group.get("fira_margin", 0.01),
599
690
  factored=group.get("factored", True),
691
+ beta3=group.get("beta3"),
600
692
  )
601
693
  return loss
602
694
 
@@ -667,6 +759,7 @@ def _update_param_8bit(
667
759
  enable_fira_for_adafactor: bool = False,
668
760
  fira_margin: float = 0.01,
669
761
  factored: bool = True,
762
+ beta3: Optional[float] = None,
670
763
  ):
671
764
  if eps1 is None:
672
765
  eps1 = torch.finfo(param.dtype).eps
@@ -893,11 +986,13 @@ def _update_param_8bit(
893
986
  C = shape[-1]
894
987
  numel = grad_fp32.numel()
895
988
 
896
- row_mean = grad_fp32.square().mean(dim=-1, keepdim=True)
897
- col_mean = grad_fp32.square().mean(dim=-2, keepdim=True)
989
+ g_sq = grad_fp32.square()
990
+ row_mean = g_sq.mean(dim=-1, keepdim=True)
991
+ col_mean = g_sq.mean(dim=-2, keepdim=True)
898
992
 
899
993
  if quantize:
900
994
  if _load_cuda_module(use_cuda_kernel):
995
+ del g_sq
901
996
  _CUDA_MODULE.fused_log_quantize_lerp(state["row_var_q"], state["row_var_scale"], row_mean.reshape(-1), beta_val, curr_block_size, False, row_mean.numel())
902
997
  _CUDA_MODULE.fused_log_quantize_lerp(state["col_var_q"], state["col_var_scale"], col_mean.reshape(-1), beta_val, curr_block_size, False, col_mean.numel())
903
998
 
@@ -905,26 +1000,76 @@ def _update_param_8bit(
905
1000
  row_mean_val_flat = row_var.mean(dim=-2, keepdim=True).clamp_(min=eps1).flatten().contiguous()
906
1001
  del row_var
907
1002
 
908
- grad_flat = grad_fp32.reshape(-1)
909
- row_var_q_flat = state["row_var_q"].reshape(-1)
910
- col_var_q_flat = state["col_var_q"].reshape(-1)
911
-
912
- total_sum_sq = torch.zeros(1, device=param_work.device, dtype=torch.float32)
913
-
914
1003
  if beta1 is not None:
915
1004
  _CUDA_MODULE.fused_4bit_quantize_lerp(
916
- state["m_q"], state["m_scale"], grad_fp32.view(-1), beta1, m_curr_block_size, N
1005
+ state["m_q"], state["m_scale"], grad_fp32.view(-1), beta1, m_curr_block_size, numel
917
1006
  )
918
- del grad_fp32
919
1007
 
1008
+ if beta3 is not None and beta1 is not None:
1009
+ batch_size = math.prod(shape[:-2]) if len(shape) > 2 else 1
1010
+ res_row_sum = torch.zeros(batch_size * R, device=param_work.device, dtype=torch.float32)
1011
+ res_col_sum = torch.zeros(batch_size * C, device=param_work.device, dtype=torch.float32)
1012
+
1013
+ _CUDA_MODULE.came_compute_residual_2d(
1014
+ state["m_q"].view(-1), state["m_scale"].view(-1),
1015
+ state["row_var_q"].view(-1), state["row_var_scale"],
1016
+ state["col_var_q"].view(-1), state["col_var_scale"],
1017
+ grad_fp32.reshape(-1), row_mean_val_flat,
1018
+ res_row_sum, res_col_sum,
1019
+ log_eps_sq, R, C, numel, m_curr_block_size, curr_block_size
1020
+ )
1021
+
1022
+ beta3_val = 1.0 - beta3
1023
+ u_row_mean = (res_row_sum / C).contiguous().view(-1)
1024
+ u_col_mean = (res_col_sum / R).contiguous().view(-1)
1025
+ del res_row_sum, res_col_sum
1026
+
1027
+ _CUDA_MODULE.fused_log_quantize_lerp(state["conf_row_q"], state["conf_row_scale"], u_row_mean, beta3_val, curr_block_size, False, u_row_mean.numel())
1028
+ _CUDA_MODULE.fused_log_quantize_lerp(state["conf_col_q"], state["conf_col_scale"], u_col_mean, beta3_val, curr_block_size, False, u_col_mean.numel())
1029
+
1030
+ v_row = _log_dequantize_nonneg(state["row_var_q"], state["row_var_scale"], state["row_var_shape"], state["row_var_pad"])
1031
+ v_col = _log_dequantize_nonneg(state["col_var_q"], state["col_var_scale"], state["col_var_shape"], state["col_var_pad"])
1032
+ c_row = _log_dequantize_nonneg(state["conf_row_q"], state["conf_row_scale"], state["conf_row_shape"], state["conf_row_pad"])
1033
+ c_col = _log_dequantize_nonneg(state["conf_col_q"], state["conf_col_scale"], state["conf_col_shape"], state["conf_col_pad"])
1034
+
1035
+ combined_row = (v_row * c_row).clamp_(min=_FP32_TINY)
1036
+ combined_col = (v_col * c_col).clamp_(min=_FP32_TINY)
1037
+
1038
+ kernel_row_mean = combined_row.mean(dim=-2, keepdim=True).clamp_(min=eps1).flatten().contiguous()
1039
+
1040
+ q_r, s_r, _, _ = _log_quantize_nonneg(combined_row, curr_block_size)
1041
+ q_c, s_c, _, _ = _log_quantize_nonneg(combined_col, curr_block_size)
1042
+ del v_row, v_col, c_row, c_col, combined_row, combined_col
1043
+
1044
+ kernel_row_q_flat = q_r.reshape(-1)
1045
+ kernel_row_scale = s_r
1046
+ kernel_col_q_flat = q_c.reshape(-1)
1047
+ kernel_col_scale = s_c
1048
+ else:
1049
+ kernel_row_mean = row_mean_val_flat
1050
+ kernel_row_q_flat = state["row_var_q"].reshape(-1)
1051
+ kernel_row_scale = state["row_var_scale"]
1052
+ kernel_col_q_flat = state["col_var_q"].reshape(-1)
1053
+ kernel_col_scale = state["col_var_scale"]
1054
+
1055
+ if beta1 is not None:
1056
+ grad_flat = None
1057
+ del grad_fp32
1058
+ else:
1059
+ grad_flat = grad_fp32.reshape(-1)
1060
+ del grad_fp32
1061
+
1062
+ total_sum_sq = torch.zeros(1, device=param_work.device, dtype=torch.float32)
1063
+
1064
+ if beta1 is not None:
920
1065
  m_q_flat = state["m_q"].view(-1)
921
1066
  m_scale_flat = state["m_scale"].view(-1)
922
1067
 
923
1068
  _CUDA_MODULE.compute_update_norm_m_2d(
924
1069
  m_q_flat, m_scale_flat,
925
- row_var_q_flat, state["row_var_scale"],
926
- col_var_q_flat, state["col_var_scale"],
927
- total_sum_sq, row_mean_val_flat, log_eps_sq, R, C, numel, m_curr_block_size, curr_block_size
1070
+ kernel_row_q_flat, kernel_row_scale,
1071
+ kernel_col_q_flat, kernel_col_scale,
1072
+ total_sum_sq, kernel_row_mean, log_eps_sq, R, C, numel, m_curr_block_size, curr_block_size
928
1073
  )
929
1074
 
930
1075
  if enable_fira_for_adafactor:
@@ -934,15 +1079,15 @@ def _update_param_8bit(
934
1079
  _CUDA_MODULE.apply_update_m_2d(
935
1080
  param_flat,
936
1081
  m_q_flat, m_scale_flat,
937
- row_var_q_flat, state["row_var_scale"],
938
- col_var_q_flat, state["col_var_scale"],
939
- total_sum_sq, alpha, row_mean_val_flat, d, log_eps_sq, R, C, numel, m_curr_block_size, curr_block_size
1082
+ kernel_row_q_flat, kernel_row_scale,
1083
+ kernel_col_q_flat, kernel_col_scale,
1084
+ total_sum_sq, alpha, kernel_row_mean, d, log_eps_sq, R, C, numel, m_curr_block_size, curr_block_size
940
1085
  )
941
1086
  else:
942
1087
  _CUDA_MODULE.compute_update_norm_2d(
943
- row_var_q_flat, state["row_var_scale"],
944
- col_var_q_flat, state["col_var_scale"],
945
- grad_flat, total_sum_sq, row_mean_val_flat, log_eps_sq, R, C, numel, curr_block_size
1088
+ kernel_row_q_flat, kernel_row_scale,
1089
+ kernel_col_q_flat, kernel_col_scale,
1090
+ grad_flat, total_sum_sq, kernel_row_mean, log_eps_sq, R, C, numel, curr_block_size
946
1091
  )
947
1092
 
948
1093
  if enable_fira_for_adafactor:
@@ -951,9 +1096,9 @@ def _update_param_8bit(
951
1096
  param_flat = param_work.reshape(-1)
952
1097
  _CUDA_MODULE.apply_update_2d(
953
1098
  param_flat, grad_flat,
954
- row_var_q_flat, state["row_var_scale"],
955
- col_var_q_flat, state["col_var_scale"],
956
- total_sum_sq, alpha, row_mean_val_flat, d, log_eps_sq, R, C, numel, curr_block_size
1099
+ kernel_row_q_flat, kernel_row_scale,
1100
+ kernel_col_q_flat, kernel_col_scale,
1101
+ total_sum_sq, alpha, kernel_row_mean, d, log_eps_sq, R, C, numel, curr_block_size
957
1102
  )
958
1103
  else:
959
1104
  row_var = _log_dequantize_nonneg(state["row_var_q"], state["row_var_scale"], state["row_var_shape"], state["row_var_pad"])
@@ -975,8 +1120,53 @@ def _update_param_8bit(
975
1120
  m_temp = _dequantize_4bit(state["m_q"], state["m_scale"], grad_fp32.numel(), grad_fp32.shape, m_curr_block_size, grad_fp32.device)
976
1121
  else:
977
1122
  m_temp = torch.zeros_like(grad_fp32)
1123
+
978
1124
  m_temp.lerp_(grad_fp32, 1.0 - beta1)
979
- del grad_fp32
1125
+
1126
+ if beta3 is not None:
1127
+ inv_col_sq = inv_col.square()
1128
+ inv_row_sq = inv_row.square()
1129
+ inv_col_sq_T = inv_col_sq.transpose(-1, -2)
1130
+ inv_row_sq_T = inv_row_sq.transpose(-1, -2)
1131
+
1132
+ gm = grad_fp32 * m_temp
1133
+ m_sq = m_temp.square()
1134
+
1135
+ t1 = torch.matmul(g_sq, inv_col_sq_T) / C
1136
+ t2 = torch.matmul(gm, inv_col_sq_T) / C
1137
+ t3 = torch.matmul(m_sq, inv_col_sq_T) / C
1138
+ res_row_mean = (inv_row_sq * (t1 - 2.0 * t2 + t3)).clamp(min=0)
1139
+
1140
+ t1c = torch.matmul(inv_row_sq_T, g_sq) / R
1141
+ t2c = torch.matmul(inv_row_sq_T, gm) / R
1142
+ t3c = torch.matmul(inv_row_sq_T, m_sq) / R
1143
+ res_col_mean = (inv_col_sq * (t1c - 2.0 * t2c + t3c)).clamp(min=0)
1144
+
1145
+ del gm, m_sq, g_sq
1146
+
1147
+ conf_row_temp = _log_dequantize_nonneg(state["conf_row_q"], state["conf_row_scale"], state["conf_row_shape"], state["conf_row_pad"])
1148
+ conf_col_temp = _log_dequantize_nonneg(state["conf_col_q"], state["conf_col_scale"], state["conf_col_shape"], state["conf_col_pad"])
1149
+
1150
+ conf_row_temp.lerp_(res_row_mean, 1.0 - beta3)
1151
+ conf_col_temp.lerp_(res_col_mean, 1.0 - beta3)
1152
+
1153
+ q_cr, s_cr, sh_cr, pad_cr = _log_quantize_nonneg(conf_row_temp, curr_block_size)
1154
+ state["conf_row_q"], state["conf_row_scale"], state["conf_row_shape"], state["conf_row_pad"] = q_cr, s_cr, sh_cr, pad_cr
1155
+
1156
+ q_cc, s_cc, sh_cc, pad_cc = _log_quantize_nonneg(conf_col_temp, curr_block_size)
1157
+ state["conf_col_q"], state["conf_col_scale"], state["conf_col_shape"], state["conf_col_pad"] = q_cc, s_cc, sh_cc, pad_cc
1158
+
1159
+ combined_row = (row_var * conf_row_temp).clamp(min=eps_sq)
1160
+ combined_col = (col_var * conf_col_temp).clamp(min=eps_sq)
1161
+ del conf_row_temp, conf_col_temp
1162
+
1163
+ combined_row_mean_val = combined_row.mean(dim=-2, keepdim=True).clamp(min=eps1)
1164
+ inv_row = combined_row.rsqrt() * combined_row_mean_val.sqrt()
1165
+ inv_col = combined_col.rsqrt()
1166
+ else:
1167
+ del g_sq
1168
+
1169
+ del grad_fp32
980
1170
 
981
1171
  update = m_temp * inv_row
982
1172
  update.mul_(inv_col)
@@ -989,6 +1179,7 @@ def _update_param_8bit(
989
1179
  denom = torch.clamp(torch.linalg.vector_norm(update) / (math.sqrt(update.numel()) * d), min=1.0)
990
1180
  param_work.add_(update, alpha=-alpha / denom)
991
1181
  else:
1182
+ del g_sq
992
1183
  update = grad_fp32 * inv_row
993
1184
  update.mul_(inv_col)
994
1185
 
@@ -1011,6 +1202,39 @@ def _update_param_8bit(
1011
1202
  if "m" not in state:
1012
1203
  state["m"] = torch.zeros_like(grad_fp32)
1013
1204
  state["m"].lerp_(grad_fp32, 1.0 - beta1)
1205
+
1206
+ if beta3 is not None:
1207
+ inv_col_sq = inv_col.square()
1208
+ inv_row_sq = inv_row.square()
1209
+ inv_col_sq_T = inv_col_sq.transpose(-1, -2)
1210
+ inv_row_sq_T = inv_row_sq.transpose(-1, -2)
1211
+
1212
+ gm = grad_fp32 * state["m"]
1213
+ m_sq = state["m"].square()
1214
+
1215
+ t1 = torch.matmul(g_sq, inv_col_sq_T) / C
1216
+ t2 = torch.matmul(gm, inv_col_sq_T) / C
1217
+ t3 = torch.matmul(m_sq, inv_col_sq_T) / C
1218
+ res_row_mean = (inv_row_sq * (t1 - 2.0 * t2 + t3)).clamp(min=0)
1219
+
1220
+ t1c = torch.matmul(inv_row_sq_T, g_sq) / R
1221
+ t2c = torch.matmul(inv_row_sq_T, gm) / R
1222
+ t3c = torch.matmul(inv_row_sq_T, m_sq) / R
1223
+ res_col_mean = (inv_col_sq * (t1c - 2.0 * t2c + t3c)).clamp(min=0)
1224
+
1225
+ del gm, m_sq, g_sq
1226
+
1227
+ state["conf_row"].lerp_(res_row_mean, 1.0 - beta3)
1228
+ state["conf_col"].lerp_(res_col_mean, 1.0 - beta3)
1229
+
1230
+ combined_row = (row_var * state["conf_row"]).clamp(min=eps_sq)
1231
+ combined_col = (col_var * state["conf_col"]).clamp(min=eps_sq)
1232
+ combined_row_mean_val = combined_row.mean(dim=-2, keepdim=True).clamp(min=eps1)
1233
+ inv_row = combined_row.rsqrt() * combined_row_mean_val.sqrt()
1234
+ inv_col = combined_col.rsqrt()
1235
+ else:
1236
+ del g_sq
1237
+
1014
1238
  del grad_fp32
1015
1239
 
1016
1240
  update = state["m"] * inv_row
@@ -1022,6 +1246,7 @@ def _update_param_8bit(
1022
1246
  denom = torch.clamp(torch.linalg.vector_norm(update) / (math.sqrt(update.numel()) * d), min=1.0)
1023
1247
  param_work.add_(update, alpha=-alpha / denom)
1024
1248
  else:
1249
+ del g_sq
1025
1250
  update = grad_fp32 * inv_row
1026
1251
  update.mul_(inv_col)
1027
1252
 
@@ -1156,8 +1381,8 @@ def _update_param_apollo(
1156
1381
  col_mean_low = grad_low.square().mean(dim=-2, keepdim=True)
1157
1382
 
1158
1383
  if "row_var_low" not in state:
1159
- state["row_var_low"] = row_mean_low.clone().clamp(min=_FP32_TINY)
1160
- state["col_var_low"] = col_mean_low.clone().clamp(min=_FP32_TINY)
1384
+ state["row_var_low"] = (row_mean_low * beta_val).clamp(min=_FP32_TINY)
1385
+ state["col_var_low"] = (col_mean_low * beta_val).clamp(min=_FP32_TINY)
1161
1386
  else:
1162
1387
  state["row_var_low"].mul_(1.0 - beta_val).add_(row_mean_low, alpha=beta_val)
1163
1388
  state["col_var_low"].mul_(1.0 - beta_val).add_(col_mean_low, alpha=beta_val)
@@ -1213,7 +1438,7 @@ def _update_param_apollo(
1213
1438
  quantize = state.get("is_quantized", True)
1214
1439
 
1215
1440
  if is_first_step:
1216
- v_init = grad_low.flatten().square().clamp(min=_FP32_TINY)
1441
+ v_init = (grad_low.flatten().square() * beta_val).clamp(min=_FP32_TINY)
1217
1442
  if quantize:
1218
1443
  q, s, sh, pad = _log_quantize_nonneg(v_init, block_size)
1219
1444
  state["v_low_q"], state["v_low_scale"], state["v_low_shape"], state["v_low_pad"] = q, s, sh, pad
@@ -1392,4 +1617,4 @@ def _update_param_apollo(
1392
1617
  del update_low
1393
1618
 
1394
1619
  if needs_copy_back:
1395
- param.copy_(param_work.view(original_shape))
1620
+ param.copy_(param_work.view(original_shape))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adafactor8bit
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: 8-bit Adafactor Optimizer with Fused CUDA Kernels
5
5
  Home-page: https://github.com/yanfeiwong/adafactor-8bit
6
6
  Author: WANG YAN
@@ -25,6 +25,13 @@ Dynamic: requires-dist
25
25
  Dynamic: requires-python
26
26
  Dynamic: summary
27
27
 
28
+ <p align="center">
29
+ <a href="https://github.com/yanfeiwong/adafactor-8bit">
30
+ <img src="https://github.com/yanfeiwong/adafactor-8bit/raw/main/assets/banner.png"
31
+ alt="Adafactor8Bit"
32
+ width="80%">
33
+ </a>
34
+ </p>
28
35
  <div align="center">
29
36
 
30
37
  # 8-bit Adafactor with Fused CUDA Kernels
@@ -39,14 +46,15 @@ Dynamic: summary
39
46
 
40
47
  </div>
41
48
 
42
- An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, optional APOLLO low-rank updates, and 4-bit packed first moments, delivering substantially lower optimizer memory while preserving the low-overhead and numerical stability that make Adafactor attractive for training LLMs and diffusion models.
49
+ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, and optional add-ons including 4-bit packed first moments, APOLLO low-rank updates, and CAME confidence-guided optimization. It delivers substantially lower optimizer memory while preserving the low-overhead and numerical stability that make Adafactor attractive for training LLMs and diffusion models.
43
50
 
44
51
 
45
- ## 🔥 Key Features
52
+ ## Key Features
46
53
 
47
54
  - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
48
55
  - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
49
56
  - **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
57
+ - **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
50
58
  - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
51
59
  - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
52
60
  - **Zero CPU-GPU Sync**: Eliminates implicit synchronizations (e.g., D2H copies) in the control flow, ensuring the GPU computation pipeline runs without blocking.
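
To make the Log-Space Quantization bullet concrete, here is a minimal pure-Python sketch of the idea. It is not the library's kernel: the helper names (`log_quantize_block`, `log_dequantize_block`, `linear_quantize_block`), the per-block offset/step encoding, and the `tiny` floor are assumptions made for illustration, and the fused CUDA implementation may lay out its codes and scales differently.

```python
import math

def log_quantize_block(values, tiny=1e-38):
    """Quantize one block of non-negative values to 8-bit codes in log2 space."""
    # Floor at `tiny` so that log2 is always defined.
    logs = [math.log2(max(v, tiny)) for v in values]
    lo, hi = min(logs), max(logs)
    # One step per code; fall back to 1.0 when the block is constant.
    step = (hi - lo) / 255.0 if hi > lo else 1.0
    codes = [round((x - lo) / step) for x in logs]
    return codes, lo, step

def log_dequantize_block(codes, lo, step):
    """Invert log_quantize_block: map codes back to positive values."""
    return [2.0 ** (lo + c * step) for c in codes]

def linear_quantize_block(values):
    """Plain linear 8-bit codes, shown only for comparison."""
    peak = max(values)
    return [round(v / peak * 255) for v in values]

# A long-tailed block of second-moment estimates.
variances = [1e-12, 1e-6, 1e-3, 1.0]
codes, lo, step = log_quantize_block(variances)
restored = log_dequantize_block(codes, lo, step)
linear_codes = linear_quantize_block(variances)
```

In this toy block the linear codes collapse the three small variances to 0, while the log-space codes keep them distinct and strictly positive after dequantization, which is the behavior the bullet describes.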
@@ -194,16 +202,16 @@ def get_param_groups(model, lr_emb, weight_decay, apollo_rank=256):
194
202
  "weight_decay": weight_decay,
195
203
  "quantize": True,
196
204
  "apollo_rank": apollo_rank,
197
- "beta1":0.9, # Remove if minimizing optimizer memory is the priority.
205
+ "beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
198
206
  },
199
-
207
+
200
208
  # 4. >2D Weights: 8-bit quantization, Weight Decay, Full-Rank
201
209
  {
202
210
  "params": group_nd,
203
211
  "weight_decay": weight_decay,
204
212
  "quantize": True,
205
213
  "apollo_rank": 0,
206
- "beta1":0.9, # Remove if minimizing optimizer memory is the priority.
214
+ "beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
207
215
  "factored": False # Disables factorization to preserve spatial structures, enabling finer gradient scaling.
208
216
  # Note: This increases state memory for >2D weights, depending on your model architecture.
209
217
  # If VRAM is constrained, reverting to factored=True is a safe alternative.
@@ -270,7 +278,40 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
270
278
  - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
271
279
  - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients so that sudden jumps in gradient magnitude do not cause loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
272
280
 
281
+ ## 🛡️ CAME Confidence-Guided Updates
282
+
283
+ Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
284
+
285
+ **Adaptive Scaling ($V$) → Momentum Accumulation ($M$) → Confidence Weighting ($C$)**
286
+
287
+ ### Key Parameters & Tuning
273
288
 
289
+ The confidence stage measures the consistency between the current update direction and historical momentum, adaptively suppressing highly oscillatory updates.
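
One way to picture the consistency measure: the dense (non-CUDA) fallback in this release appears to compute, per row, the mean of the squared and scaled residual between the gradient `g` and the momentum `m`, by expanding `(g - m)^2` into `g^2 - 2*g*m + m^2`. The sketch below uses plain Python lists and hypothetical function names, and only checks that the expanded form agrees with the direct one; it is not the library's code path.

```python
def row_residual_direct(g, m, inv_row, inv_col):
    """Row-wise mean of ((g - m) * inv_row * inv_col)^2, computed directly."""
    cols = len(g[0])
    return [
        sum(((g[i][j] - m[i][j]) * inv_row[i] * inv_col[j]) ** 2 for j in range(cols)) / cols
        for i in range(len(g))
    ]

def row_residual_expanded(g, m, inv_row, inv_col):
    """Same quantity via the expansion g^2 - 2*g*m + m^2, clamped at zero."""
    cols = len(g[0])
    out = []
    for i in range(len(g)):
        t1 = sum(g[i][j] ** 2 * inv_col[j] ** 2 for j in range(cols)) / cols
        t2 = sum(g[i][j] * m[i][j] * inv_col[j] ** 2 for j in range(cols)) / cols
        t3 = sum(m[i][j] ** 2 * inv_col[j] ** 2 for j in range(cols)) / cols
        # Clamp guards against tiny negative values from floating-point cancellation.
        out.append(max(inv_row[i] ** 2 * (t1 - 2.0 * t2 + t3), 0.0))
    return out

# Illustrative 2x3 inputs (made-up numbers).
g = [[0.5, -1.0, 2.0], [0.1, 0.2, -0.3]]
m = [[0.4, -0.8, 1.5], [-0.1, 0.2, 0.3]]
inv_row = [1.5, 0.7]
inv_col = [2.0, 1.0, 0.5]

direct = row_residual_direct(g, m, inv_row, inv_col)
expanded = row_residual_expanded(g, m, inv_row, inv_col)
```

A row whose gradient closely tracks its momentum yields a small residual, and a row that oscillates yields a large one. The column-wise statistic follows the same pattern with the roles of rows and columns swapped.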
290
+
291
+ - **`beta3`**: EMA decay coefficient for the confidence matrix. Requires `beta1` (momentum) and `factored=True`. Mutually exclusive with `apollo_rank`. Defaults to `None` (disabled).
292
+ - **Learning Rate**: The official CAME implementation recommends **0.5–0.9×** the AdamW learning rate (see [official tuning guide](https://github.com/yangluo7/CAME/tree/master#hyper-parameter-tuning)). To use this learning rate in this library, you need to disable Adafactor's scaling and clipping (`scale_parameter=False`, `d=1e9`) to align with the original CAME behavior.
293
+ - **Warmup**: Since the confidence matrix is zero-initialized without bias correction, a learning rate warmup is recommended to safely establish the confidence baseline.
294
+ - **Choosing `beta3`**: `beta3` should generally be larger than `beta2` so the confidence estimate evolves more slowly than the variance estimate. A practical starting range is **0.9995–0.99995** when `beta2=0.999`.
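
The warmup and `beta3` bullets above can be checked with a little EMA arithmetic. This is a rough sketch with hypothetical helper names: it treats `1 / (1 - beta)` as the effective averaging horizon of an EMA and assumes a constant input signal to show how slowly a zero-initialized average without bias correction fills up.

```python
def ema_horizon(beta):
    """Approximate number of steps an EMA with decay `beta` averages over."""
    return 1.0 / (1.0 - beta)

def zero_init_ema_fraction(beta, steps):
    """Fraction of a constant signal absorbed by a zero-initialized EMA
    after `steps` updates, with no bias correction applied."""
    return 1.0 - beta ** steps

variance_horizon = ema_horizon(0.999)      # beta2 from the bullet above
confidence_horizon = ema_horizon(0.9999)   # a beta3 inside the suggested range

# How much of the steady-state confidence is in place after 1k and 10k steps.
early = zero_init_ema_fraction(0.9999, 1_000)
later = zero_init_ema_fraction(0.9999, 10_000)
```

With these numbers the confidence estimate averages over roughly ten times as many steps as the variance estimate, and after 1,000 steps it has absorbed only about a tenth of a constant signal, which is why a learning rate warmup is a sensible precaution.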
295
+
296
+
297
+ ### Configuration Example
298
+
299
+ To replicate "vanilla" CAME (stripping Adafactor's native modifications), replace the standard 2D APOLLO group in your `param_groups` with the following configuration:
300
+
301
+ ```python
302
+ {
303
+ "params": param_group,
304
+ "lr": lr, # Original CAME recommends 0.5-0.9x AdamW LR
305
+ "weight_decay": weight_decay,
306
+ "quantize": True,
307
+ "beta1": 0.9,
308
+ "beta3": 0.9999, # Enable CAME confidence guidance
309
+ "apollo_rank": 0, # Mutually exclusive with CAME
310
+ "scale_parameter": False, # Disable Adafactor RMS scaling to align with vanilla CAME
311
+ "d": 1e9, # Disable Adafactor global RMS clipping
312
+ "enable_fira_for_adafactor": False, # Disable Fira Limiter to prevent interference with CAME's scaling
313
+ },
314
+ ```
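
If you build several such groups, a small helper can keep the constraints in one place. The function below is a hypothetical convenience wrapper, not part of the package: the name `make_came_group` and its defaults are assumptions, and its checks simply mirror the `beta3` rules stated above (momentum required, `apollo_rank` set to 0, factorization enabled).

```python
def make_came_group(params, lr, weight_decay=0.0, beta1=0.9, beta3=0.9999):
    """Hypothetical helper: build a param-group dict for the vanilla-CAME setup."""
    if beta1 is None:
        raise ValueError("CAME (beta3) requires momentum (beta1).")
    if not 0.0 <= beta1 < 1.0:
        raise ValueError(f"Invalid beta1: {beta1}, must be in [0.0, 1.0)")
    if not 0.0 <= beta3 < 1.0:
        raise ValueError(f"Invalid beta3: {beta3}, must be in [0.0, 1.0)")
    return {
        "params": params,
        "lr": lr,
        "weight_decay": weight_decay,
        "quantize": True,
        "beta1": beta1,
        "beta3": beta3,                      # enables CAME confidence guidance
        "apollo_rank": 0,                    # CAME and APOLLO are mutually exclusive
        "factored": True,                    # CAME needs row/column factorization
        "scale_parameter": False,            # align with vanilla CAME
        "d": 1e9,                            # effectively disables RMS clipping
        "enable_fira_for_adafactor": False,
    }

# Usage with a placeholder parameter list.
group = make_came_group(params=[], lr=1e-4, weight_decay=0.01)
```

The returned dict can be placed in the `param_groups` list passed to the optimizer. The optimizer's own constructor still validates the defaults, so this helper is only an early, readable guard.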
274
315
 
275
316
  ## 📈 Learning Rate Guide for Beginners
276
317
 
@@ -298,16 +339,18 @@ Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the appr
298
339
 
299
340
  Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
300
341
 
342
+ Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
343
+
301
344
  Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
302
345
 
303
346
  Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
304
347
 
348
+ ## 🏛️ License
349
+
350
+ [The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
351
+
305
352
  ## ⭐ Star the Project
306
353
 
307
354
  If this optimizer has been useful in your work, consider giving the repository a star. It helps others discover the project and supports future development.
308
355
 
309
356
  [![Star History Chart](https://api.star-history.com/svg?repos=yanfeiwong/adafactor-8bit&type=Date&theme=dark)](https://star-history.com/#yanfeiwong/adafactor-8bit&Date)
310
-
311
- ## 📄 License
312
-
313
- [The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
@@ -9,7 +9,7 @@ long_description = (this_directory / "README.md").read_text(encoding="utf-8")
9
9
 
10
10
  setup(
11
11
  name="adafactor8bit",
12
- version="0.2.1",
12
+ version="0.2.2",
13
13
  description="8-bit Adafactor Optimizer with Fused CUDA Kernels",
14
14
  author="WANG YAN",
15
15
  author_email="yanfeiwong1997@outlook.com",
File without changes
File without changes
File without changes