adafactor8bit 0.2.2.tar.gz → 0.2.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adafactor8bit
3
- Version: 0.2.2
3
+ Version: 0.2.4
4
4
  Summary: 8-bit Adafactor Optimizer with Fused CUDA Kernels
5
5
  Home-page: https://github.com/yanfeiwong/adafactor-8bit
6
6
  Author: WANG YAN
@@ -53,7 +53,7 @@ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space bl
53
53
 
54
54
  - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
55
55
  - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
56
- - **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
56
+ - **Optional NF4 First Moment**: Stores the optional first moment (`beta1`) using Normal Float 4-bit (NF4) non-uniform quantization, preserving small momentum updates while keeping memory overhead minimal.
57
57
  - **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
58
58
  - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
59
59
  - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
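The Fira paper describes its norm-growth limiter as a cap on the ratio between consecutive update norms: if ‖S_t‖ / ‖S_{t-1}‖ exceeds a threshold γ, S_t is rescaled so that its norm becomes γ‖S_{t-1}‖. The pure-Python sketch below shows only that rule. `l2_norm` and `limit_norm_growth` are hypothetical helpers, not the package API; γ = 1.01 is an illustrative value, and how the package's `fira_margin` maps onto γ is an assumption that this diff does not confirm.

```python
import math
from typing import List, Optional, Tuple


def l2_norm(x: List[float]) -> float:
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(v * v for v in x))


def limit_norm_growth(
    update: List[float], prev_norm: Optional[float], gamma: float = 1.01
) -> Tuple[List[float], float]:
    """Fira-style norm-growth limiter (hypothetical helper, not the package API).

    If ||update|| > gamma * prev_norm, rescale the update so its norm equals
    gamma * prev_norm. Returns the (possibly rescaled) update and the norm to
    carry into the next step. gamma = 1.01 is only an illustrative default.
    """
    norm = l2_norm(update)
    if prev_norm is not None and prev_norm > 0.0 and norm > gamma * prev_norm:
        factor = gamma * prev_norm / norm
        update = [v * factor for v in update]
        norm = gamma * prev_norm
    return update, norm


# A sudden 10x spike is clipped to 1% growth over the previous step's norm.
spiked, carried = limit_norm_growth([30.0, 40.0], prev_norm=5.0)
```

Because the cap is relative to the previous step rather than an absolute threshold, it adapts to each parameter's natural update scale. That is one plausible reading of why the README says the limiter can often replace external gradient clipping.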
@@ -278,7 +278,7 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
278
278
  - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
279
279
  - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
280
280
 
281
- ## 🛡️ CAME Confidence-Guided Updates
281
+ ## 🧊 CAME Confidence-Guided Updates
282
282
 
283
283
  Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
284
284
 
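The `apollo_factorize` entry relies on the fact that a Gaussian random projection preserves squared norms in expectation. The sketch below illustrates only that property. It assumes a projection with i.i.d. N(0, 1/r) entries, a common Johnson-Lindenstrauss scaling. Whether the package samples its projection this way is not shown in this diff, and `random_projection` is a hypothetical helper.

```python
import math
import random
from typing import List


def random_projection(x: List[float], rank: int, seed: int = 0) -> List[float]:
    """Project x into `rank` dims with i.i.d. N(0, 1/rank) weights (hypothetical helper).

    Each output coordinate is distributed as N(0, ||x||^2 / rank), so the squared
    norm of the result equals ||x||^2 in expectation, with relative variance 2 / rank.
    """
    rng = random.Random(seed)  # seeded so the sketch is reproducible
    std = 1.0 / math.sqrt(rank)
    return [sum(rng.gauss(0.0, std) * xi for xi in x) for _ in range(rank)]


x = [float(i % 7 - 3) for i in range(64)]
px = random_projection(x, rank=2048)
ratio = sum(v * v for v in px) / sum(v * v for v in x)  # close to 1.0
```

The relative spread of that ratio scales as sqrt(2 / r), so a small rank gives a noisy per-step estimate. This is consistent with the README's warning that the factorized low-rank path introduces noise.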
@@ -326,24 +326,25 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
326
326
  *These are safe starting points. Always validate on your own task and batch size.*
327
327
 
328
328
 
329
-
330
-
331
-
332
329
  ## 🎓 Acknowledgements
333
330
 
334
- Thanks to **Noam Shazeer** and **Mitchell Stern** for proposing the original Adafactor algorithm in the paper [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235).
335
-
336
- Thanks to **Tim Dettmers** for the inspiration from the paper [8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
337
-
338
- Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the approximated gradient scaling method in the paper [APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
331
+ This project builds upon the foundational work of several researchers and open-source communities. Sincere thanks to the following for their invaluable contributions:
339
332
 
340
- Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
333
+ ### Core Algorithm & Optimizer Design
334
+ - **Noam Shazeer & Mitchell Stern** for proposing the original **Adafactor** algorithm ([Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)).
335
+ - **Tim Dettmers** for the inspiration from **8-bit block-wise quantization** ([8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861)) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
336
+ - **Hanqing Zhu, Zhenyu Zhang, et al.** for the **APOLLO** algorithm ([APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270)).
337
+ - **Xi Chen, Kaituo Feng, et al.** for the **Norm-Growth Limiter** mechanism in **Fira** ([Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)).
338
+ - **Yang Luo, et al.** for the **confidence-guided strategy** in **CAME** ([CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)).
341
339
 
342
- Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
340
+ ### Quantization & Implementation
341
+ - **The QLoRA Team** for pioneering the **4-bit NormalFloat (NF4)** quantization format ([QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)) that inspired our first moment quantization.
342
+ - **The PyTorch AO Team** for their work on [4-bit optimizer states](https://github.com/pytorch/ao/tree/main/torchao/optim), validating distribution-aware quantization for optimizer moments.
343
+ - **The PyTorch Team** for providing the foundational optimizer implementation and the C++ Extension toolchain.
343
344
 
344
- Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
345
+ ### Technical Review & Discussion
346
+ - **Qwen, ChatGLM, and DeepSeek** (large language models) for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
345
347
 
346
- Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
347
348
 
348
349
  ## 🏛️ License
349
350
 
@@ -26,7 +26,7 @@ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space bl
26
26
 
27
27
  - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
28
28
  - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
29
- - **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
29
+ - **Optional NF4 First Moment**: Stores the optional first moment (`beta1`) using Normal Float 4-bit (NF4) non-uniform quantization, preserving small momentum updates while keeping memory overhead minimal.
30
30
  - **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
31
31
  - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
32
32
  - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
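The NF4 first-moment entry can be illustrated with a per-block round trip. This pure-Python sketch reuses the 16-entry table added to the CUDA and Python sources in this diff and the same absmax scaling (scale = block absmax, clamped at 1e-12). Nibble packing and padding are omitted, and the helper names are hypothetical.

```python
from typing import List, Tuple

# 16-entry code table copied from _NF4_TABLE / NF4_QMAP in this diff.
NF4_TABLE: List[float] = [
    -1.0, -0.6961928, -0.52507306, -0.3895074,
    -0.27408478, -0.17286907, -0.07958022, 0.0,
    0.07958022, 0.17286907, 0.27408478, 0.3895074,
    0.52507306, 0.6961928, 0.8641379, 1.0,
]


def nf4_quantize_block(block: List[float]) -> Tuple[List[int], float]:
    """Hypothetical helper: scale by block absmax, pick the nearest table index per value."""
    scale = max(max(abs(v) for v in block), 1e-12)  # same clamp as the diff
    codes = []
    for v in block:
        x = v / scale
        # Exhaustive nearest search; min() keeps the first index on ties, like argmin.
        codes.append(min(range(len(NF4_TABLE)), key=lambda k: abs(x - NF4_TABLE[k])))
    return codes, scale


def nf4_dequantize_block(codes: List[int], scale: float) -> List[float]:
    """Hypothetical helper: look up each code's table value and multiply by the scale."""
    return [NF4_TABLE[c] * scale for c in codes]


codes, scale = nf4_quantize_block([0.5, -0.01, 0.0, -0.5, 0.2])
restored = nf4_dequantize_block(codes, scale)  # 0.2 comes back as 0.3895074 * 0.5
```

Here -0.01 (2% of the block absmax) rounds to zero, as it would on the old uniform grid. The two schemes differ for values between roughly 4% and 7% of the absmax.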
@@ -251,7 +251,7 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
251
251
  - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
252
252
  - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
253
253
 
254
- ## 🛡️ CAME Confidence-Guided Updates
254
+ ## 🧊 CAME Confidence-Guided Updates
255
255
 
256
256
  Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
257
257
 
@@ -299,24 +299,25 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
299
299
  *These are safe starting points. Always validate on your own task and batch size.*
300
300
 
301
301
 
302
-
303
-
304
-
305
302
  ## 🎓 Acknowledgements
306
303
 
307
- Thanks to **Noam Shazeer** and **Mitchell Stern** for proposing the original Adafactor algorithm in the paper [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235).
308
-
309
- Thanks to **Tim Dettmers** for the inspiration from the paper [8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
310
-
311
- Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the approximated gradient scaling method in the paper [APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
304
+ This project builds upon the foundational work of several researchers and open-source communities. Sincere thanks to the following for their invaluable contributions:
312
305
 
313
- Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
306
+ ### Core Algorithm & Optimizer Design
307
+ - **Noam Shazeer & Mitchell Stern** for proposing the original **Adafactor** algorithm ([Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)).
308
+ - **Tim Dettmers** for the inspiration from **8-bit block-wise quantization** ([8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861)) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
309
+ - **Hanqing Zhu, Zhenyu Zhang, et al.** for the **APOLLO** algorithm ([APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270)).
310
+ - **Xi Chen, Kaituo Feng, et al.** for the **Norm-Growth Limiter** mechanism in **Fira** ([Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)).
311
+ - **Yang Luo, et al.** for the **confidence-guided strategy** in **CAME** ([CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)).
314
312
 
315
- Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
313
+ ### Quantization & Implementation
314
+ - **The QLoRA Team** for pioneering the **4-bit NormalFloat (NF4)** quantization format ([QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)) that inspired our first moment quantization.
315
+ - **The PyTorch AO Team** for their work on [4-bit optimizer states](https://github.com/pytorch/ao/tree/main/torchao/optim), validating distribution-aware quantization for optimizer moments.
316
+ - **The PyTorch Team** for providing the foundational optimizer implementation and the C++ Extension toolchain.
316
317
 
317
- Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
318
+ ### Technical Review & Discussion
319
+ - **Qwen, ChatGLM, and DeepSeek** (large language models) for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
318
320
 
319
- Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
320
321
 
321
322
  ## 🏛️ License
322
323
 
@@ -7,7 +7,26 @@
7
7
  __device__ constexpr float INV_255 = 1.0f / 255.0f;
8
8
  __device__ constexpr float MIN_LOG = -126.0f;
9
9
  __device__ constexpr float MIN_VAL = 1.17549435e-38f;
10
- __device__ constexpr float INV_7 = 1.0f / 7.0f;
10
+
11
+ __constant__ float NF4_QMAP[16] = {
12
+ -1.0f, -0.6961928f, -0.52507306f, -0.3895074f,
13
+ -0.27408478f, -0.17286907f, -0.07958022f, 0.0f,
14
+ 0.07958022f, 0.17286907f, 0.27408478f, 0.3895074f,
15
+ 0.52507306f, 0.6961928f, 0.8641379f, 1.0f
16
+ };
17
+
18
+ __device__ __forceinline__ int find_nearest_nf4(float x) {
19
+ float x_abs = fabsf(x);
20
+ if (x_abs < 0.0397901f) return 7;
21
+ if (x_abs < 0.1262246f) return (x >= 0.0f) ? 8 : 6;
22
+ if (x_abs < 0.2234769f) return (x >= 0.0f) ? 9 : 5;
23
+ if (x_abs < 0.3317961f) return (x >= 0.0f) ? 10 : 4;
24
+ if (x_abs < 0.4572902f) return (x >= 0.0f) ? 11 : 3;
25
+ if (x_abs < 0.6106329f) return (x >= 0.0f) ? 12 : 2;
26
+ if (x_abs < 0.7801653f) return (x >= 0.0f) ? 13 : 1;
27
+ if (x_abs < 0.9320689f) return (x >= 0.0f) ? 14 : 0;
28
+ return (x >= 0.0f) ? 15 : 0;
29
+ }
11
30
 
12
31
  // ==========================================
13
32
  // 1. Fused Log-Quantize Lerp (EMA Update for V_t)
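`find_nearest_nf4` replaces a 16-way search with a chain of magnitude thresholds. The sketch below derives decision boundaries as midpoints between adjacent table entries and compares them with our own line-by-line Python transcription of that chain. The table has eight positive non-zero levels but only seven negative ones (there is no -0.8641379), so the boundary between -1.0 and -0.6961928 lies near -0.848, not -0.780. In the transcription, inputs in roughly (-0.848, -0.780] therefore map to code 0 (-1.0), while the exhaustive argmin used by `_quantize_4bit_pytorch` maps them to code 1 (-0.6961928). `nearest_code` and `kernel_chain` are hypothetical helpers.

```python
import bisect
from typing import List

# Table copied from NF4_QMAP in this diff (ascending order).
NF4_TABLE: List[float] = [
    -1.0, -0.6961928, -0.52507306, -0.3895074,
    -0.27408478, -0.17286907, -0.07958022, 0.0,
    0.07958022, 0.17286907, 0.27408478, 0.3895074,
    0.52507306, 0.6961928, 0.8641379, 1.0,
]

# Decision boundaries: midpoint between each pair of adjacent entries (15 values).
MIDPOINTS: List[float] = [(a + b) / 2.0 for a, b in zip(NF4_TABLE, NF4_TABLE[1:])]

# Magnitude thresholds as written in find_nearest_nf4.
KERNEL_THRESHOLDS: List[float] = [
    0.0397901, 0.1262246, 0.2234769, 0.3317961,
    0.4572902, 0.6106329, 0.7801653, 0.9320689,
]


def nearest_code(x: float) -> int:
    """Hypothetical helper: nearest index by binary search; ties go to the lower index."""
    return bisect.bisect_left(MIDPOINTS, x)


def kernel_chain(x: float) -> int:
    """Hypothetical helper: Python transcription of the threshold chain in find_nearest_nf4."""
    x_abs = abs(x)
    if x_abs < KERNEL_THRESHOLDS[0]:
        return 7
    for k in range(1, 8):
        if x_abs < KERNEL_THRESHOLDS[k]:
            return 7 + k if x >= 0.0 else 7 - k
    return 15 if x >= 0.0 else 0


print(kernel_chain(-0.8), nearest_code(-0.8))  # 0 1
print(kernel_chain(0.8), nearest_code(0.8))    # 14 14
```

Apart from that negative band, the two agree except within about 5e-8 of a boundary, where the kernel's constants are rounded.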
@@ -226,11 +245,10 @@ __global__ void fused_4bit_quantize_lerp_kernel(
226
245
 
227
246
  uchar2 old_q = q_vec[idx];
228
247
 
229
- // Unpack old m_t (biased by +8 for unsigned storage)
230
- float m_old0 = (float)((old_q.x >> 4) - 8) * old_scale;
231
- float m_old1 = (float)((old_q.x & 0x0F) - 8) * old_scale;
232
- float m_old2 = (float)((old_q.y >> 4) - 8) * old_scale;
233
- float m_old3 = (float)((old_q.y & 0x0F) - 8) * old_scale;
248
+ float m_old0 = NF4_QMAP[(old_q.x >> 4)] * old_scale;
249
+ float m_old1 = NF4_QMAP[(old_q.x & 0x0F)] * old_scale;
250
+ float m_old2 = NF4_QMAP[(old_q.y >> 4)] * old_scale;
251
+ float m_old3 = NF4_QMAP[(old_q.y & 0x0F)] * old_scale;
234
252
 
235
253
  float m_new0 = beta * m_old0 + one_minus_b * val_x;
236
254
  float m_new1 = beta * m_old1 + one_minus_b * val_y;
@@ -267,7 +285,7 @@ __global__ void fused_4bit_quantize_lerp_kernel(
267
285
  __syncthreads();
268
286
 
269
287
  float abs_max = fmaxf(s_max[0], 1e-12f);
270
- float new_scale = abs_max * INV_7;
288
+ float new_scale = abs_max;
271
289
  float inv_scale = 1.0f / new_scale;
272
290
 
273
291
  for (int i = 0; i < vec_iters; i++) {
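Dropping `INV_7` follows from the change of code book. With the old uniform grid, each block was divided by absmax/7 and rounded to an integer in [-8, 7]. The NF4 table already spans [-1, 1], so the block is divided by the absmax itself and mapped to the nearest table entry $T_k$. Writing $a$ for the block absmax (clamped at $10^{-12}$) and $m$ for one momentum value:

```latex
\begin{aligned}
\text{0.2.2 (uniform):}\quad & s = a/7, \quad q = \operatorname{clamp}\big(\operatorname{round}(m/s),\,-8,\,7\big) + 8, \quad \hat m = (q - 8)\,s \\
\text{0.2.4 (NF4):}\quad & s = a, \quad q = \arg\min_{k \in \{0,\dots,15\}} \lvert m/s - T_k \rvert, \quad \hat m = T_q\,s
\end{aligned}
```

Under the uniform grid a value is flushed to zero when $|m|$ is below about $a/14 \approx 0.0714\,a$. Under NF4 this happens only below $0.0398\,a$, the midpoint between $T_7 = 0$ and $T_8 = 0.07958022$. For example, $m = 0.05\,a$ decoded to 0 before and decodes to $0.0796\,a$ now. This is one concrete sense in which the README says NF4 preserves small momentum updates.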
@@ -279,12 +297,10 @@ __global__ void fused_4bit_quantize_lerp_kernel(
279
297
  float m2 = local_m[idx * 4 + 2];
280
298
  float m3 = local_m[idx * 4 + 3];
281
299
 
282
- // Pure integer clamping, then bias by +8 for unsigned 4-bit packing
283
- // Using standard 4-bit signed range [-8, 7] mapping to [0, 15]
284
- int q0 = max(-8, min(7, __float2int_rn(m0 * inv_scale))) + 8;
285
- int q1 = max(-8, min(7, __float2int_rn(m1 * inv_scale))) + 8;
286
- int q2 = max(-8, min(7, __float2int_rn(m2 * inv_scale))) + 8;
287
- int q3 = max(-8, min(7, __float2int_rn(m3 * inv_scale))) + 8;
300
+ int q0 = find_nearest_nf4(m0 * inv_scale);
301
+ int q1 = find_nearest_nf4(m1 * inv_scale);
302
+ int q2 = find_nearest_nf4(m2 * inv_scale);
303
+ int q3 = find_nearest_nf4(m3 * inv_scale);
288
304
 
289
305
  uchar2 out_q;
290
306
  out_q.x = (unsigned char)((q0 << 4) | q1);
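Both this kernel and `_quantize_4bit_pytorch` store two 4-bit codes per byte, with the even-indexed element in the high nibble. On the read side, the kernels take `(packed & 0x0F)` for odd indices and `(packed >> 4)` for even ones. A pure-Python sketch of that layout follows; the helper names are hypothetical.

```python
from typing import List


def pack_nibbles(codes: List[int]) -> bytes:
    """Hypothetical helper: two 4-bit codes per byte, even index in the high nibble."""
    if len(codes) % 2:
        raise ValueError("pad to an even number of codes before packing")
    return bytes(
        ((codes[i] & 0x0F) << 4) | (codes[i + 1] & 0x0F)
        for i in range(0, len(codes), 2)
    )


def unpack_nibble(packed: bytes, idx: int) -> int:
    """Hypothetical helper: read code `idx` as the kernels do (odd -> low, even -> high)."""
    byte = packed[idx // 2]
    return (byte & 0x0F) if (idx & 1) else (byte >> 4)


packed = pack_nibbles([15, 7, 0, 11])  # b'\xf7\x0b'
```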
@@ -561,7 +577,7 @@ __global__ void compute_update_norm_m_2d_kernel(
561
577
  // Unpack 4-bit m_t
562
578
  unsigned char packed = m_q[idx / 2];
563
579
  int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
564
- float m_val = (float)(q_int - 8) * m_scale[idx / m_block_size];
580
+ float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
565
581
 
566
582
  float log_r = (float)row_var_q[b * R + r] * INV_255 * row_var_scale[(b * R + r) / v_block_size] + MIN_LOG;
567
583
  float log_c = (float)col_var_q[b * C + c] * INV_255 * col_var_scale[(b * C + c) / v_block_size] + MIN_LOG;
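The kernels decode an 8-bit second-moment code as `q * INV_255 * scale + MIN_LOG`, so each block's scale spans its range of log2 values above MIN_LOG = -126. The encoder is not part of this diff. The sketch below assumes its natural inverse: clamp to MIN_VAL = 2^-126, shift by -MIN_LOG, divide by the block's largest shifted log, and round to 0..255. It also assumes the decoded log is exponentiated with base 2. Both helpers are hypothetical.

```python
import math
from typing import List, Tuple

MIN_LOG = -126.0
MIN_VAL = 2.0 ** -126  # smallest normal float32, matching the kernel's MIN_VAL
INV_255 = 1.0 / 255.0


def log_quantize_block(block: List[float]) -> Tuple[List[int], float]:
    """Hypothetical encoder: log2, shift by -MIN_LOG, scale block max to 255, round."""
    shifted = [max(0.0, math.log2(max(v, MIN_VAL)) - MIN_LOG) for v in block]
    scale = max(max(shifted), 1e-12)
    codes = [min(255, max(0, round(s / scale * 255.0))) for s in shifted]
    return codes, scale


def log_dequantize_block(codes: List[int], scale: float) -> List[float]:
    """Decode as the kernels do (q * INV_255 * scale + MIN_LOG), then assume exp2."""
    return [2.0 ** (q * INV_255 * scale + MIN_LOG) for q in codes]


values = [1.0, 1e-3, 1e-20, 0.0]
codes, scale = log_quantize_block(values)
restored = log_dequantize_block(codes, scale)
```

Under these assumptions, rounding moves the log by at most half a code step. The relative error is therefore bounded by 2^(scale/510) - 1 for every value in the block, however small, instead of small variances collapsing to zero.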
@@ -636,7 +652,7 @@ __global__ void apply_update_m_2d_kernel(
636
652
  // Unpack 4-bit m_t
637
653
  unsigned char packed = m_q[idx / 2];
638
654
  int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
639
- float m_val = (float)(q_int - 8) * m_scale[idx / m_block_size];
655
+ float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
640
656
 
641
657
  float log_r = (float)row_var_q[b * R + r] * INV_255 * row_var_scale[(b * R + r) / v_block_size] + MIN_LOG;
642
658
  float log_c = (float)col_var_q[b * C + c] * INV_255 * col_var_scale[(b * C + c) / v_block_size] + MIN_LOG;
@@ -692,7 +708,7 @@ __global__ void compute_update_norm_m_1d_kernel(
692
708
  // Unpack 4-bit m_t
693
709
  unsigned char packed = m_q[idx / 2];
694
710
  int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
695
- float m_val = (float)(q_int - 8) * m_scale[idx / m_block_size];
711
+ float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
696
712
 
697
713
  float log_v = (float)variance_q[idx] * INV_255 * variance_scale[idx / v_block_size] + MIN_LOG;
698
714
 
@@ -751,7 +767,7 @@ __global__ void apply_update_m_1d_kernel(
751
767
  // Unpack 4-bit m_t
752
768
  unsigned char packed = m_q[idx / 2];
753
769
  int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
754
- float m_val = (float)(q_int - 8) * m_scale[idx / m_block_size];
770
+ float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
755
771
 
756
772
  float log_v = (float)variance_q[idx] * INV_255 * variance_scale[idx / v_block_size] + MIN_LOG;
757
773
 
@@ -808,7 +824,7 @@ __global__ void compute_apollo_norms_kernel(
808
824
  // Unpack 4-bit m
809
825
  unsigned char m_byte = m_q[global_idx / 2];
810
826
  int m_int = (global_idx & 1) ? (m_byte & 0x0F) : (m_byte >> 4);
811
- float m_val = ((float)m_int - 8.0f) * m_scale[global_idx / m_block_size];
827
+ float m_val = NF4_QMAP[m_int] * m_scale[global_idx / m_block_size];
812
828
  // Unpack 8-bit log v
813
829
  unsigned char v_byte = v_q[global_idx];
814
830
  float log_v = (float)v_byte * INV_255 * v_scale[global_idx / v_block_size] + MIN_LOG;
@@ -892,7 +908,7 @@ __global__ void dequantize_4bit_kernel(
892
908
  if (idx >= numel) return;
893
909
  unsigned char packed = q[idx / 2];
894
910
  int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
895
- output[idx] = (float)(q_int - 8) * scale[idx / block_size];
911
+ output[idx] = NF4_QMAP[q_int] * scale[idx / block_size];
896
912
  }
897
913
 
898
914
  void dequantize_4bit_cuda(
@@ -1132,7 +1148,7 @@ __global__ void came_compute_residual_2d_kernel(
1132
1148
 
1133
1149
  unsigned char packed = m_q[idx / 2];
1134
1150
  int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
1135
- float m_val = (float)(q_int - 8) * m_scale[idx / m_block_size];
1151
+ float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
1136
1152
 
1137
1153
  float log_r = (float)row_var_q[b * R + r] * INV_255 * row_var_scale[(b * R + r) / v_block_size] + MIN_LOG;
1138
1154
  float log_c = (float)col_var_q[b * C + c] * INV_255 * col_var_scale[(b * C + c) / v_block_size] + MIN_LOG;
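`came_compute_residual_2d_kernel` computes a residual from the dequantized momentum. In the CAME paper, the instability signal is the squared difference between the current normalized update u_t and its momentum m_t. That signal is tracked by an EMA with β3, and the step is m_t / sqrt(S_t). The sketch below is an unfactored, element-wise version in pure Python. The 2-D kernels factorize this statistic into row and column terms, as Adafactor does for the variance. The helper name, the β and ε values, and the order of operations are assumptions, not the package's implementation.

```python
import math
from typing import List, Tuple


def came_step_1d(
    u: List[float],
    m_prev: List[float],
    s_prev: List[float],
    beta1: float = 0.9,
    beta3: float = 0.9999,
    eps2: float = 1e-16,
) -> Tuple[List[float], List[float], List[float]]:
    """One unfactored CAME-style step on an already variance-normalized update u.

    Hypothetical helper; beta1, beta3 and eps2 are illustrative values.
    Returns (step, new momentum, new confidence statistic).
    """
    m = [beta1 * mp + (1.0 - beta1) * ui for mp, ui in zip(m_prev, u)]
    # Instability: squared residual between the fresh update and its momentum.
    resid = [(ui - mi) ** 2 + eps2 for ui, mi in zip(u, m)]
    s = [beta3 * sp + (1.0 - beta3) * ri for sp, ri in zip(s_prev, resid)]
    # Directions with a small residual history get a larger step, and vice versa.
    step = [mi / math.sqrt(si) for mi, si in zip(m, s)]
    return step, m, s


step, m, s = came_step_1d([0.5, -0.5], [0.0, 0.0], [1.0, 1.0])
```

When u_t keeps pointing the same way as m_t, the residual stays small, S_t decays and the step grows. When u_t oscillates around m_t, S_t grows and the step is damped.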
@@ -21,6 +21,14 @@ _FP32_MIN_LOG = -126.0
21
21
  _INV_255 = 1.0 / 255.0
22
22
  _INV_7 = 1.0 / 7.0
23
23
 
24
+ _NF4_TABLE = torch.tensor([
25
+ -1.0, -0.6961928, -0.52507306, -0.3895074,
26
+ -0.27408478, -0.17286907, -0.07958022, 0.0,
27
+ 0.07958022, 0.17286907, 0.27408478, 0.3895074,
28
+ 0.52507306, 0.6961928, 0.8641379, 1.0
29
+ ], dtype=torch.float32)
30
+
31
+
24
32
  # ==========================================
25
33
  # 1. CUDA Kernel JIT Loading
26
34
  # ==========================================
@@ -127,7 +135,7 @@ def _log_dequantize_nonneg(q: Tensor, scale: Tensor, shape: torch.Size, pad: int
127
135
  return flat.view(shape)
128
136
 
129
137
  def _quantize_4bit_pytorch(m: Tensor, block_size: int) -> Tuple[Tensor, Tensor]:
130
- """4-bit symmetric min-max quantization with physical packing into uint8."""
138
+ """4-bit NF4 non-uniform quantization with physical packing into uint8."""
131
139
  flat = m.flatten()
132
140
  pad = (block_size - flat.numel() % block_size) % block_size
133
141
  if pad:
@@ -135,12 +143,15 @@ def _quantize_4bit_pytorch(m: Tensor, block_size: int) -> Tuple[Tensor, Tensor]:
135
143
 
136
144
  blocks = flat.view(-1, block_size)
137
145
  abs_max = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
138
- scale = abs_max * _INV_7 # _INV_7 = 1.0 / 7.0
146
+ scale = abs_max
139
147
 
140
- q = (torch.round(blocks / scale).clamp(-8, 7) + 8).to(torch.uint8)
148
+ normalized = blocks / scale
149
+ table = _NF4_TABLE.to(normalized.device)
150
+ diff = (normalized.unsqueeze(-1) - table).abs()
151
+ codes = diff.argmin(dim=-1).to(torch.uint8)
141
152
 
142
- q_even = q[:, 0::2]
143
- q_odd = q[:, 1::2]
153
+ q_even = codes[:, 0::2]
154
+ q_odd = codes[:, 1::2]
144
155
  packed = (q_even << 4) | q_odd
145
156
 
146
157
  return packed.view(-1), scale.squeeze(-1)
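`_quantize_4bit_pytorch` now finds the nearest code by broadcasting each normalized value against all 16 table entries. `normalized.unsqueeze(-1) - table` is a float32 tensor of shape (n_blocks, block_size, 16), and `.abs()` allocates a second tensor of the same size before `argmin` reduces it. Each temporary therefore costs about 64 bytes per padded element, 16 times the fp32 moment itself. A quick estimate, using a hypothetical helper and an illustrative block size of 128:

```python
def argmin_scratch_bytes(numel: int, block_size: int, table_size: int = 16,
                         bytes_per_elem: int = 4) -> int:
    """Hypothetical helper: bytes for ONE (n_blocks, block_size, table_size) temporary."""
    padded = ((numel + block_size - 1) // block_size) * block_size  # same padding as the diff
    return padded * table_size * bytes_per_elem


# A 100M-element first moment with 128-element blocks:
gb = argmin_scratch_bytes(100_000_000, 128) / 1e9  # 6.4 GB per temporary
```

This diff calls the function at least when an existing full-precision `m` is converted to the packed form. If that peak matters, `torch.bucketize` against the 15 sorted midpoints of the table should return the same nearest codes, up to floating-point ties, without the 16x intermediate.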
@@ -152,11 +163,12 @@ def _dequantize_4bit(m_q: Tensor, m_scale: Tensor, numel: int, shape: torch.Size
152
163
  _CUDA_MODULE.dequantize_4bit(output, m_q, m_scale, numel, block_size)
153
164
  return output.view(shape)
154
165
  else:
155
- high = ((m_q >> 4) & 0x0F).to(torch.float32) - 8.0
156
- low = (m_q & 0x0F).to(torch.float32) - 8.0
157
- m_flat = torch.stack((high, low), dim=-1).view(-1)
158
- m_blocks = m_flat.view(-1, block_size)
159
- result = (m_blocks * m_scale.unsqueeze(-1)).view(-1)[:numel]
166
+ high = (m_q >> 4)
167
+ low = (m_q & 0x0F)
168
+ codes = torch.stack((high, low), dim=-1).view(-1)
169
+ m_blocks = codes.view(-1, block_size)
170
+ table = _NF4_TABLE.to(m_q.device)
171
+ result = (table[m_blocks.long()] * m_scale.unsqueeze(-1)).view(-1)[:numel]
160
172
  return result.view(shape)
161
173
 
162
174
  # ==========================================
@@ -461,20 +473,21 @@ class Adafactor8Bit(Optimizer):
461
473
 
462
474
  if beta1 is not None:
463
475
  m_padded_numel = ((p.numel() + m_block_size - 1) // m_block_size) * m_block_size
464
- state["m_q"] = torch.full((m_padded_numel // 2,), 0x88, dtype=torch.uint8, device=p.device)
476
+ state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
465
477
  state["m_scale"] = torch.ones(m_padded_numel // m_block_size, dtype=torch.float32, device=p.device)
466
478
  state["m_block_size"] = m_block_size
467
479
 
468
480
  if beta3 is not None:
469
- state["conf_row_q"] = torch.zeros_like(state["row_var_q"])
470
- state["conf_row_scale"] = torch.ones_like(state["row_var_scale"])
481
+ state["conf_row_q"] = torch.full_like(state["row_var_q"], 255)
482
+ state["conf_row_scale"] = torch.full_like(state["row_var_scale"], 126.0)
471
483
  state["conf_row_shape"] = state["row_var_shape"]
472
484
  state["conf_row_pad"] = state["row_var_pad"]
473
485
 
474
- state["conf_col_q"] = torch.zeros_like(state["col_var_q"])
475
- state["conf_col_scale"] = torch.ones_like(state["col_var_scale"])
486
+ state["conf_col_q"] = torch.full_like(state["col_var_q"], 255)
487
+ state["conf_col_scale"] = torch.full_like(state["col_var_scale"], 126.0)
476
488
  state["conf_col_shape"] = state["col_var_shape"]
477
489
  state["conf_col_pad"] = state["col_var_pad"]
490
+
478
491
  else:
479
492
  state["row_var"] = torch.zeros(r_shape, dtype=torch.float32, device=p.device)
480
493
  state["col_var"] = torch.zeros(c_shape, dtype=torch.float32, device=p.device)
@@ -494,7 +507,7 @@ class Adafactor8Bit(Optimizer):
494
507
 
495
508
  if beta1 is not None:
496
509
  m_padded_numel = ((p.numel() + m_block_size - 1) // m_block_size) * m_block_size
497
- state["m_q"] = torch.full((m_padded_numel // 2,), 0x88, dtype=torch.uint8, device=p.device)
510
+ state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
498
511
  state["m_scale"] = torch.ones(m_padded_numel // m_block_size, dtype=torch.float32, device=p.device)
499
512
  state["m_block_size"] = m_block_size
500
513
  else:
@@ -566,7 +579,7 @@ class Adafactor8Bit(Optimizer):
566
579
  state["m_q"], state["m_scale"] = _quantize_4bit_pytorch(state["m"], m_curr_block_size)
567
580
  state.pop("m")
568
581
  else:
569
- state["m_q"] = torch.full((m_padded_numel // 2,), 0x88, dtype=torch.uint8, device=p.device)
582
+ state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
570
583
  state["m_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=p.device)
571
584
  state["m_block_size"] = m_curr_block_size
572
585
 
@@ -620,7 +633,7 @@ class Adafactor8Bit(Optimizer):
620
633
  state["m_q"], state["m_scale"] = _quantize_4bit_pytorch(state["m"], m_curr_block_size)
621
634
  state.pop("m")
622
635
  else:
623
- state["m_q"] = torch.full((m_padded_numel // 2,), 0x88, dtype=torch.uint8, device=p.device)
636
+ state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
624
637
  state["m_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=p.device)
625
638
  state["m_block_size"] = m_curr_block_size
626
639
 
@@ -1401,7 +1414,7 @@ def _update_param_apollo(
1401
1414
 
1402
1415
  if state.get("m_low_q") is None:
1403
1416
  m_padded_numel = ((grad_low_numel + m_curr_block_size - 1) // m_curr_block_size) * m_curr_block_size
1404
- state["m_low_q"] = torch.full((m_padded_numel // 2,), 0x88, dtype=torch.uint8, device=grad_low.device)
1417
+ state["m_low_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=grad_low.device)
1405
1418
  state["m_low_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=grad_low.device)
1406
1419
  state["m_block_size"] = m_curr_block_size
1407
1420
 
@@ -1481,7 +1494,7 @@ def _update_param_apollo(
1481
1494
 
1482
1495
  if state.get("m_low_q") is None:
1483
1496
  m_padded_numel = ((grad_low_numel + m_curr_block_size - 1) // m_curr_block_size) * m_curr_block_size
1484
- state["m_low_q"] = torch.full((m_padded_numel // 2,), 0x88, dtype=torch.uint8, device=grad_low.device)
1497
+ state["m_low_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=grad_low.device)
1485
1498
  state["m_low_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=grad_low.device)
1486
1499
  state["m_block_size"] = m_curr_block_size
1487
1500
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adafactor8bit
3
- Version: 0.2.2
3
+ Version: 0.2.4
4
4
  Summary: 8-bit Adafactor Optimizer with Fused CUDA Kernels
5
5
  Home-page: https://github.com/yanfeiwong/adafactor-8bit
6
6
  Author: WANG YAN
@@ -53,7 +53,7 @@ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space bl
53
53
 
54
54
  - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
55
55
  - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
56
- - **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
56
+ - **Optional NF4 First Moment**: Stores the optional first moment (`beta1`) using Normal Float 4-bit (NF4) non-uniform quantization, preserving small momentum updates while keeping memory overhead minimal.
57
57
  - **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
58
58
  - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
59
59
  - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
@@ -278,7 +278,7 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
278
278
  - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
279
279
  - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
280
280
 
281
- ## 🛡️ CAME Confidence-Guided Updates
281
+ ## 🧊 CAME Confidence-Guided Updates
282
282
 
283
283
  Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
284
284
 
@@ -326,24 +326,25 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
326
326
  *These are safe starting points. Always validate on your own task and batch size.*
327
327
 
328
328
 
329
-
330
-
331
-
332
329
  ## 🎓 Acknowledgements
333
330
 
334
- Thanks to **Noam Shazeer** and **Mitchell Stern** for proposing the original Adafactor algorithm in the paper [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235).
335
-
336
- Thanks to **Tim Dettmers** for the inspiration from the paper [8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
337
-
338
- Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the approximated gradient scaling method in the paper [APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
331
+ This project builds upon the foundational work of several researchers and open-source communities. Sincere thanks to the following for their invaluable contributions:
339
332
 
340
- Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
333
+ ### Core Algorithm & Optimizer Design
334
+ - **Noam Shazeer & Mitchell Stern** for proposing the original **Adafactor** algorithm ([Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)).
335
+ - **Tim Dettmers** for the inspiration from **8-bit block-wise quantization** ([8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861)) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
336
+ - **Hanqing Zhu, Zhenyu Zhang, et al.** for the **APOLLO** algorithm ([APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270)).
337
+ - **Xi Chen, Kaituo Feng, et al.** for the **Norm-Growth Limiter** mechanism in **Fira** ([Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)).
338
+ - **Yang Luo, et al.** for the **confidence-guided strategy** in **CAME** ([CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)).
341
339
 
342
- Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
340
+ ### Quantization & Implementation
341
+ - **The QLoRA Team** for pioneering the **4-bit NormalFloat (NF4)** quantization format ([QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)) that inspired our first moment quantization.
342
+ - **The PyTorch AO Team** for their work on [4-bit optimizer states](https://github.com/pytorch/ao/tree/main/torchao/optim), validating distribution-aware quantization for optimizer moments.
343
+ - **The PyTorch Team** for providing the foundational optimizer implementation and the C++ Extension toolchain.
343
344
 
344
- Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
345
+ ### Technical Review & Discussion
346
+ - **Qwen, ChatGLM, and DeepSeek** (large language models) for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
345
347
 
346
- Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
347
348
 
348
349
  ## 🏛️ License
349
350
 
@@ -9,7 +9,7 @@ long_description = (this_directory / "README.md").read_text(encoding="utf-8")
9
9
 
10
10
  setup(
11
11
  name="adafactor8bit",
12
- version="0.2.2",
12
+ version="0.2.4",
13
13
  description="8-bit Adafactor Optimizer with Fused CUDA Kernels",
14
14
  author="WANG YAN",
15
15
  author_email="yanfeiwong1997@outlook.com",
File without changes
File without changes
File without changes