adafactor8bit 0.2.2__tar.gz → 0.2.5__tar.gz

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adafactor8bit
- Version: 0.2.2
+ Version: 0.2.5
  Summary: 8-bit Adafactor Optimizer with Fused CUDA Kernels
  Home-page: https://github.com/yanfeiwong/adafactor-8bit
  Author: WANG YAN
@@ -53,7 +53,7 @@ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space bl
 
  - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
  - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
- - **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
+ - **Optional NF4 First Moment**: Stores the optional first moment (`beta1`) using Normal Float 4-bit (NF4) non-uniform quantization, preserving small momentum updates while keeping memory overhead minimal.
  - **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
  - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
  - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
@@ -278,7 +278,7 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
  - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
  - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
 
- ## 🛡️ CAME Confidence-Guided Updates
+ ## 🧊 CAME Confidence-Guided Updates
 
  Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
 
@@ -326,24 +326,25 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
  *These are safe starting points. Always validate on your own task and batch size.*
 
 
-
-
-
  ## 🎓 Acknowledgements
 
- Thanks to **Noam Shazeer** and **Mitchell Stern** for proposing the original Adafactor algorithm in the paper [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235).
-
- Thanks to **Tim Dettmers** for the inspiration from the paper [8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
-
- Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the approximated gradient scaling method in the paper [APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
+ This project builds upon the foundational work of several researchers and open-source communities. Sincere thanks to the following for their invaluable contributions:
 
- Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
+ ### Core Algorithm & Optimizer Design
+ - **Noam Shazeer & Mitchell Stern** for proposing the original **Adafactor** algorithm ([Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)).
+ - **Tim Dettmers** for the inspiration from **8-bit block-wise quantization** ([8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861)) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
+ - **Hanqing Zhu, Zhenyu Zhang, et al.** for the **APOLLO** algorithm ([APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270)).
+ - **Xi Chen, Kaituo Feng, et al.** for the **Norm-Growth Limiter** mechanism in **Fira** ([Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)).
+ - **Yang Luo, et al.** for the **confidence-guided strategy** in **CAME** ([CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)).
 
- Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
+ ### Quantization & Implementation
+ - **The QLoRA Team** for pioneering the **4-bit NormalFloat (NF4)** quantization format ([QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)) that inspired our first moment quantization.
+ - **The PyTorch AO Team** for their work on [4-bit optimizer states](https://github.com/pytorch/ao/tree/main/torchao/optim), validating distribution-aware quantization for optimizer moments.
+ - **The PyTorch Team** for providing the foundational optimizer implementation and the C++ Extension toolchain.
 
- Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
+ ### Technical Review & Discussion
+ - **Qwen, ChatGLM, and DeepSeek** (large language models) for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
 
- Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
 
  ## 🏛️ License
 
@@ -26,7 +26,7 @@ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space bl
 
  - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
  - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
- - **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
+ - **Optional NF4 First Moment**: Stores the optional first moment (`beta1`) using Normal Float 4-bit (NF4) non-uniform quantization, preserving small momentum updates while keeping memory overhead minimal.
  - **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
  - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
  - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
@@ -251,7 +251,7 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
  - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
  - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
 
- ## 🛡️ CAME Confidence-Guided Updates
+ ## 🧊 CAME Confidence-Guided Updates
 
  Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
 
@@ -299,24 +299,25 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
  *These are safe starting points. Always validate on your own task and batch size.*
 
 
-
-
-
  ## 🎓 Acknowledgements
 
- Thanks to **Noam Shazeer** and **Mitchell Stern** for proposing the original Adafactor algorithm in the paper [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235).
-
- Thanks to **Tim Dettmers** for the inspiration from the paper [8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
-
- Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the approximated gradient scaling method in the paper [APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
+ This project builds upon the foundational work of several researchers and open-source communities. Sincere thanks to the following for their invaluable contributions:
 
- Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
+ ### Core Algorithm & Optimizer Design
+ - **Noam Shazeer & Mitchell Stern** for proposing the original **Adafactor** algorithm ([Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)).
+ - **Tim Dettmers** for the inspiration from **8-bit block-wise quantization** ([8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861)) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
+ - **Hanqing Zhu, Zhenyu Zhang, et al.** for the **APOLLO** algorithm ([APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270)).
+ - **Xi Chen, Kaituo Feng, et al.** for the **Norm-Growth Limiter** mechanism in **Fira** ([Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)).
+ - **Yang Luo, et al.** for the **confidence-guided strategy** in **CAME** ([CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)).
 
- Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
+ ### Quantization & Implementation
+ - **The QLoRA Team** for pioneering the **4-bit NormalFloat (NF4)** quantization format ([QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)) that inspired our first moment quantization.
+ - **The PyTorch AO Team** for their work on [4-bit optimizer states](https://github.com/pytorch/ao/tree/main/torchao/optim), validating distribution-aware quantization for optimizer moments.
+ - **The PyTorch Team** for providing the foundational optimizer implementation and the C++ Extension toolchain.
 
- Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
+ ### Technical Review & Discussion
+ - **Qwen, ChatGLM, and DeepSeek** (large language models) for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
 
- Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
 
  ## 🏛️ License
 
@@ -7,7 +7,26 @@
  __device__ constexpr float INV_255 = 1.0f / 255.0f;
  __device__ constexpr float MIN_LOG = -126.0f;
  __device__ constexpr float MIN_VAL = 1.17549435e-38f;
- __device__ constexpr float INV_7 = 1.0f / 7.0f;
+
+ __constant__ float NF4_QMAP[16] = {
+ -1.0f, -0.6961928f, -0.52507306f, -0.3895074f,
+ -0.27408478f, -0.17286907f, -0.07958022f, 0.0f,
+ 0.07958022f, 0.17286907f, 0.27408478f, 0.3895074f,
+ 0.52507306f, 0.6961928f, 0.8641379f, 1.0f
+ };
+
+ __device__ __forceinline__ int find_nearest_nf4(float x) {
+ float x_abs = fabsf(x);
+ if (x_abs < 0.0397901f) return 7;
+ if (x_abs < 0.1262246f) return (x >= 0.0f) ? 8 : 6;
+ if (x_abs < 0.2234769f) return (x >= 0.0f) ? 9 : 5;
+ if (x_abs < 0.3317961f) return (x >= 0.0f) ? 10 : 4;
+ if (x_abs < 0.4572902f) return (x >= 0.0f) ? 11 : 3;
+ if (x_abs < 0.6106329f) return (x >= 0.0f) ? 12 : 2;
+ if (x_abs < 0.7801653f) return (x >= 0.0f) ? 13 : 1;
+ if (x_abs < 0.9320689f) return (x >= 0.0f) ? 14 : 0;
+ return (x >= 0.0f) ? 15 : 0;
+ }
 
  // ==========================================
  // 1. Fused Log-Quantize Lerp (EMA Update for V_t)
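The threshold ladder in `find_nearest_nf4` can be cross-checked against a brute-force nearest-entry search. The following is a minimal Python sketch, not part of the package: the names `ladder_code` and `nearest_code` are hypothetical, and the constants are copied from the hunk above. Under those assumptions the two agree for non-negative inputs. For negative inputs between roughly -0.848 and -0.780 the ladder returns code 0 (-1.0) where a strict nearest search returns code 1 (-0.6961928), because the table has no negative counterpart of 0.8641379.

```python
# Sketch only: mirrors the constants from the CUDA hunk in plain Python.
NF4_QMAP = [
    -1.0, -0.6961928, -0.52507306, -0.3895074,
    -0.27408478, -0.17286907, -0.07958022, 0.0,
    0.07958022, 0.17286907, 0.27408478, 0.3895074,
    0.52507306, 0.6961928, 0.8641379, 1.0,
]

# (upper bound on |x|, code if x >= 0, code if x < 0), in the order the kernel tests them.
_LADDER = [
    (0.0397901, 7, 7),
    (0.1262246, 8, 6),
    (0.2234769, 9, 5),
    (0.3317961, 10, 4),
    (0.4572902, 11, 3),
    (0.6106329, 12, 2),
    (0.7801653, 13, 1),
    (0.9320689, 14, 0),
]


def ladder_code(x: float) -> int:
    """Hypothetical Python port of the threshold ladder in find_nearest_nf4."""
    x_abs = abs(x)
    for bound, pos_code, neg_code in _LADDER:
        if x_abs < bound:
            return pos_code if x >= 0.0 else neg_code
    return 15 if x >= 0.0 else 0


def nearest_code(x: float) -> int:
    """Brute-force reference: index of the table entry closest to x."""
    return min(range(16), key=lambda i: abs(x - NF4_QMAP[i]))


if __name__ == "__main__":
    # Print the input, the ladder result, and the brute-force result side by side.
    for x in (0.0, 0.5, 1.0, -0.3, -0.8):
        print(x, ladder_code(x), nearest_code(x))
```

The brute-force version corresponds to the `argmin` used by the PyTorch fallback later in this diff, so the two code paths may encode values in that narrow negative band differently.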
@@ -120,7 +139,7 @@ __global__ void fused_log_quantize_lerp_kernel(
  }
  __syncthreads();
 
- float max_log = fminf(fmaxf(s_max[0], MIN_LOG + 1e-12f), 50.0f);
+ float max_log = fminf(fmaxf(s_max[0], MIN_LOG + 1e-12f), 126.0f);
  float new_scale = max_log - MIN_LOG;
  float inv_scale = 255.0f / (max_log - MIN_LOG);
 
@@ -226,11 +245,10 @@ __global__ void fused_4bit_quantize_lerp_kernel(
 
  uchar2 old_q = q_vec[idx];
 
- // Unpack old m_t (biased by +8 for unsigned storage)
- float m_old0 = (float)((old_q.x >> 4) - 8) * old_scale;
- float m_old1 = (float)((old_q.x & 0x0F) - 8) * old_scale;
- float m_old2 = (float)((old_q.y >> 4) - 8) * old_scale;
- float m_old3 = (float)((old_q.y & 0x0F) - 8) * old_scale;
+ float m_old0 = NF4_QMAP[(old_q.x >> 4)] * old_scale;
+ float m_old1 = NF4_QMAP[(old_q.x & 0x0F)] * old_scale;
+ float m_old2 = NF4_QMAP[(old_q.y >> 4)] * old_scale;
+ float m_old3 = NF4_QMAP[(old_q.y & 0x0F)] * old_scale;
 
  float m_new0 = beta * m_old0 + one_minus_b * val_x;
  float m_new1 = beta * m_old1 + one_minus_b * val_y;
@@ -267,7 +285,7 @@ __global__ void fused_4bit_quantize_lerp_kernel(
  __syncthreads();
 
  float abs_max = fmaxf(s_max[0], 1e-12f);
- float new_scale = abs_max * INV_7;
+ float new_scale = abs_max;
  float inv_scale = 1.0f / new_scale;
 
  for (int i = 0; i < vec_iters; i++) {
@@ -279,12 +297,10 @@ __global__ void fused_4bit_quantize_lerp_kernel(
  float m2 = local_m[idx * 4 + 2];
  float m3 = local_m[idx * 4 + 3];
 
- // Pure integer clamping, then bias by +8 for unsigned 4-bit packing
- // Using standard 4-bit signed range [-8, 7] mapping to [0, 15]
- int q0 = max(-8, min(7, __float2int_rn(m0 * inv_scale))) + 8;
- int q1 = max(-8, min(7, __float2int_rn(m1 * inv_scale))) + 8;
- int q2 = max(-8, min(7, __float2int_rn(m2 * inv_scale))) + 8;
- int q3 = max(-8, min(7, __float2int_rn(m3 * inv_scale))) + 8;
+ int q0 = find_nearest_nf4(m0 * inv_scale);
+ int q1 = find_nearest_nf4(m1 * inv_scale);
+ int q2 = find_nearest_nf4(m2 * inv_scale);
+ int q3 = find_nearest_nf4(m3 * inv_scale);
 
  uchar2 out_q;
  out_q.x = (unsigned char)((q0 << 4) | q1);
@@ -349,7 +365,8 @@ __global__ void compute_update_norm_2d_kernel(
  max_log = fmaxf(max_log, -53.0f);
  float inv_std = exp2f(-0.5f * max_log);
 
- float u_ij = grad[idx] * inv_std;
+ float g_val = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
+ float u_ij = g_val * inv_std;
 
  sq += u_ij * u_ij;
  }
@@ -418,7 +435,8 @@ __global__ void apply_update_2d_kernel(
  max_log = fmaxf(max_log, -53.0f);
  float inv_std = exp2f(-0.5f * max_log);
 
- float u_ij = grad[idx] * inv_std;
+ float g_val = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
+ float u_ij = g_val * inv_std;
 
  float p_val = static_cast<float>(param[idx]);
  p_val -= step_size * u_ij;
@@ -464,7 +482,8 @@ __global__ void compute_update_norm_1d_kernel(
  max_log = fmaxf(max_log, -53.0f);
  float inv_std = exp2f(-0.5f * max_log);
 
- float u_val = grad[idx] * inv_std;
+ float g_val = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
+ float u_val = g_val * inv_std;
 
  sq += u_val * u_val;
  }
@@ -516,7 +535,8 @@ __global__ void apply_update_1d_kernel(
  max_log = fmaxf(max_log, -53.0f);
  float inv_std = exp2f(-0.5f * max_log);
 
- float u_val = grad[idx] * inv_std;
+ float g_val = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
+ float u_val = g_val * inv_std;
 
  float p_val = static_cast<float>(param[idx]);
  p_val -= step_size * u_val;
@@ -561,7 +581,7 @@ __global__ void compute_update_norm_m_2d_kernel(
  // Unpack 4-bit m_t
  unsigned char packed = m_q[idx / 2];
  int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
- float m_val = (float)(q_int - 8) * m_scale[idx / m_block_size];
+ float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
 
  float log_r = (float)row_var_q[b * R + r] * INV_255 * row_var_scale[(b * R + r) / v_block_size] + MIN_LOG;
  float log_c = (float)col_var_q[b * C + c] * INV_255 * col_var_scale[(b * C + c) / v_block_size] + MIN_LOG;
@@ -636,7 +656,7 @@ __global__ void apply_update_m_2d_kernel(
  // Unpack 4-bit m_t
  unsigned char packed = m_q[idx / 2];
  int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
- float m_val = (float)(q_int - 8) * m_scale[idx / m_block_size];
+ float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
 
  float log_r = (float)row_var_q[b * R + r] * INV_255 * row_var_scale[(b * R + r) / v_block_size] + MIN_LOG;
  float log_c = (float)col_var_q[b * C + c] * INV_255 * col_var_scale[(b * C + c) / v_block_size] + MIN_LOG;
@@ -692,7 +712,7 @@ __global__ void compute_update_norm_m_1d_kernel(
  // Unpack 4-bit m_t
  unsigned char packed = m_q[idx / 2];
  int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
- float m_val = (float)(q_int - 8) * m_scale[idx / m_block_size];
+ float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
 
  float log_v = (float)variance_q[idx] * INV_255 * variance_scale[idx / v_block_size] + MIN_LOG;
 
@@ -751,7 +771,7 @@ __global__ void apply_update_m_1d_kernel(
  // Unpack 4-bit m_t
  unsigned char packed = m_q[idx / 2];
  int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
- float m_val = (float)(q_int - 8) * m_scale[idx / m_block_size];
+ float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
 
  float log_v = (float)variance_q[idx] * INV_255 * variance_scale[idx / v_block_size] + MIN_LOG;
 
@@ -808,7 +828,7 @@ __global__ void compute_apollo_norms_kernel(
  // Unpack 4-bit m
  unsigned char m_byte = m_q[global_idx / 2];
  int m_int = (global_idx & 1) ? (m_byte & 0x0F) : (m_byte >> 4);
- float m_val = ((float)m_int - 8.0f) * m_scale[global_idx / m_block_size];
+ float m_val = NF4_QMAP[m_int] * m_scale[global_idx / m_block_size];
  // Unpack 8-bit log v
  unsigned char v_byte = v_q[global_idx];
  float log_v = (float)v_byte * INV_255 * v_scale[global_idx / v_block_size] + MIN_LOG;
@@ -892,7 +912,7 @@ __global__ void dequantize_4bit_kernel(
  if (idx >= numel) return;
  unsigned char packed = q[idx / 2];
  int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
- output[idx] = (float)(q_int - 8) * scale[idx / block_size];
+ output[idx] = NF4_QMAP[q_int] * scale[idx / block_size];
  }
 
  void dequantize_4bit_cuda(
@@ -922,7 +942,7 @@ __global__ void compute_update_norm_1d_full_kernel(
  float one_minus_b = 1.0f - beta;
 
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < numel; idx += stride) {
- float g = grad[idx];
+ float g = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
  float g2 = g * g;
  float v = one_minus_b * variance[idx] + beta * g2;
  variance[idx] = v;
@@ -979,7 +999,7 @@ __global__ void apply_update_1d_full_kernel(
 
  int stride = gridDim.x * blockDim.x;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < numel; idx += stride) {
- float g = grad[idx];
+ float g = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
  float v = variance[idx];
  float inv_std = rsqrtf(fmaxf(v, eps_sq));
  float u = g * inv_std;
@@ -1023,7 +1043,7 @@ __global__ void compute_update_norm_1d_full_m_kernel(
  float one_minus_bv = 1.0f - beta_val;
 
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < numel; idx += stride) {
- float g = grad[idx];
+ float g = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
  float g2 = g * g;
 
  float v = one_minus_bv * variance[idx] + beta_val * g2;
@@ -1132,7 +1152,7 @@ __global__ void came_compute_residual_2d_kernel(
 
  unsigned char packed = m_q[idx / 2];
  int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
- float m_val = (float)(q_int - 8) * m_scale[idx / m_block_size];
+ float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
 
  float log_r = (float)row_var_q[b * R + r] * INV_255 * row_var_scale[(b * R + r) / v_block_size] + MIN_LOG;
  float log_c = (float)col_var_q[b * C + c] * INV_255 * col_var_scale[(b * C + c) / v_block_size] + MIN_LOG;
@@ -1143,7 +1163,8 @@ __global__ void came_compute_residual_2d_kernel(
  max_log = fmaxf(max_log, -53.0f);
  float inv_std = exp2f(-0.5f * max_log);
 
- float diff = (grad[idx] - m_val) * inv_std;
+ float g_val = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
+ float diff = (g_val - m_val) * inv_std;
  float res = diff * diff;
 
  atomicAdd(&res_col_sum[b * C + c], res);
@@ -21,6 +21,14 @@ _FP32_MIN_LOG = -126.0
  _INV_255 = 1.0 / 255.0
  _INV_7 = 1.0 / 7.0
 
+ _NF4_TABLE = torch.tensor([
+ -1.0, -0.6961928, -0.52507306, -0.3895074,
+ -0.27408478, -0.17286907, -0.07958022, 0.0,
+ 0.07958022, 0.17286907, 0.27408478, 0.3895074,
+ 0.52507306, 0.6961928, 0.8641379, 1.0
+ ], dtype=torch.float32)
+
+
  # ==========================================
  # 1. CUDA Kernel JIT Loading
  # ==========================================
@@ -127,7 +135,7 @@ def _log_dequantize_nonneg(q: Tensor, scale: Tensor, shape: torch.Size, pad: int
  return flat.view(shape)
 
  def _quantize_4bit_pytorch(m: Tensor, block_size: int) -> Tuple[Tensor, Tensor]:
- """4-bit symmetric min-max quantization with physical packing into uint8."""
+ """4-bit NF4 non-uniform quantization with physical packing into uint8."""
  flat = m.flatten()
  pad = (block_size - flat.numel() % block_size) % block_size
  if pad:
@@ -135,12 +143,15 @@ def _quantize_4bit_pytorch(m: Tensor, block_size: int) -> Tuple[Tensor, Tensor]:
 
  blocks = flat.view(-1, block_size)
  abs_max = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
- scale = abs_max * _INV_7 # _INV_7 = 1.0 / 7.0
+ scale = abs_max
 
- q = (torch.round(blocks / scale).clamp(-8, 7) + 8).to(torch.uint8)
+ normalized = blocks / scale
+ table = _NF4_TABLE.to(normalized.device)
+ diff = (normalized.unsqueeze(-1) - table).abs()
+ codes = diff.argmin(dim=-1).to(torch.uint8)
 
- q_even = q[:, 0::2]
- q_odd = q[:, 1::2]
+ q_even = codes[:, 0::2]
+ q_odd = codes[:, 1::2]
  packed = (q_even << 4) | q_odd
 
  return packed.view(-1), scale.squeeze(-1)
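To make the packing layout concrete, the following pure-Python sketch (no `torch`, hypothetical helper names `quantize_block` and `dequantize_block`) mirrors the steps of `_quantize_4bit_pytorch` for a single block: scale by the block's absolute maximum, pick the nearest table entry, and pack two codes per byte with the even-indexed element in the high nibble. It assumes an even block length and is an illustration rather than the package's implementation.

```python
from typing import List, Tuple

# Table values copied from the _NF4_TABLE hunk in this diff.
NF4_TABLE = [
    -1.0, -0.6961928, -0.52507306, -0.3895074,
    -0.27408478, -0.17286907, -0.07958022, 0.0,
    0.07958022, 0.17286907, 0.27408478, 0.3895074,
    0.52507306, 0.6961928, 0.8641379, 1.0,
]


def quantize_block(block: List[float]) -> Tuple[bytes, float]:
    """Quantize one even-length block to packed 4-bit codes plus an abs-max scale."""
    if len(block) % 2:
        raise ValueError("block length must be even")
    # Same floor as clamp(min=1e-12) in the diff, to avoid dividing by zero.
    scale = max(max(abs(v) for v in block), 1e-12)
    # Nearest table entry for each normalized value (brute-force argmin).
    codes = [
        min(range(16), key=lambda i, n=v / scale: abs(n - NF4_TABLE[i]))
        for v in block
    ]
    # Even-indexed element goes to the high nibble, odd-indexed to the low nibble.
    packed = bytes((codes[i] << 4) | codes[i + 1] for i in range(0, len(codes), 2))
    return packed, scale


def dequantize_block(packed: bytes, scale: float) -> List[float]:
    """Unpack two codes per byte and map them back through the table."""
    out: List[float] = []
    for byte in packed:
        out.append(NF4_TABLE[byte >> 4] * scale)
        out.append(NF4_TABLE[byte & 0x0F] * scale)
    return out


if __name__ == "__main__":
    packed, scale = quantize_block([0.02, -0.5, 0.25, 1.0])
    print(packed.hex(), scale, dequantize_block(packed, scale))
```

Because the scale is now the block's absolute maximum itself rather than `abs_max / 7`, the largest-magnitude element of each block maps to a table endpoint (`-1.0` or `1.0`) and is reconstructed without rounding error in this sketch.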
@@ -152,11 +163,12 @@ def _dequantize_4bit(m_q: Tensor, m_scale: Tensor, numel: int, shape: torch.Size
  _CUDA_MODULE.dequantize_4bit(output, m_q, m_scale, numel, block_size)
  return output.view(shape)
  else:
- high = ((m_q >> 4) & 0x0F).to(torch.float32) - 8.0
- low = (m_q & 0x0F).to(torch.float32) - 8.0
- m_flat = torch.stack((high, low), dim=-1).view(-1)
- m_blocks = m_flat.view(-1, block_size)
- result = (m_blocks * m_scale.unsqueeze(-1)).view(-1)[:numel]
+ high = (m_q >> 4)
+ low = (m_q & 0x0F)
+ codes = torch.stack((high, low), dim=-1).view(-1)
+ m_blocks = codes.view(-1, block_size)
+ table = _NF4_TABLE.to(m_q.device)
+ result = (table[m_blocks.long()] * m_scale.unsqueeze(-1)).view(-1)[:numel]
  return result.view(shape)
 
  # ==========================================
@@ -461,20 +473,21 @@ class Adafactor8Bit(Optimizer):
 
  if beta1 is not None:
  m_padded_numel = ((p.numel() + m_block_size - 1) // m_block_size) * m_block_size
- state["m_q"] = torch.full((m_padded_numel // 2,), 0x88, dtype=torch.uint8, device=p.device)
+ state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
  state["m_scale"] = torch.ones(m_padded_numel // m_block_size, dtype=torch.float32, device=p.device)
  state["m_block_size"] = m_block_size
 
  if beta3 is not None:
- state["conf_row_q"] = torch.zeros_like(state["row_var_q"])
- state["conf_row_scale"] = torch.ones_like(state["row_var_scale"])
+ state["conf_row_q"] = torch.full_like(state["row_var_q"], 255)
+ state["conf_row_scale"] = torch.full_like(state["row_var_scale"], 126.0)
  state["conf_row_shape"] = state["row_var_shape"]
  state["conf_row_pad"] = state["row_var_pad"]
 
- state["conf_col_q"] = torch.zeros_like(state["col_var_q"])
- state["conf_col_scale"] = torch.ones_like(state["col_var_scale"])
+ state["conf_col_q"] = torch.full_like(state["col_var_q"], 255)
+ state["conf_col_scale"] = torch.full_like(state["col_var_scale"], 126.0)
  state["conf_col_shape"] = state["col_var_shape"]
  state["conf_col_pad"] = state["col_var_pad"]
+
  else:
  state["row_var"] = torch.zeros(r_shape, dtype=torch.float32, device=p.device)
  state["col_var"] = torch.zeros(c_shape, dtype=torch.float32, device=p.device)
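The new initial values in this hunk can be decoded by hand. The sketch below is a plain-Python illustration with hypothetical helper names. It assumes that the 8-bit log-space states decode as `q / 255 * scale + MIN_LOG` followed by a base-2 exponential, as the `log_r` and `log_c` expressions in the CUDA hunks suggest, and that each 4-bit code indexes the NF4 table. Under those assumptions `0x77` unpacks to two zeros, whereas the old fill value `0x88` would now index `0.07958022`, and a confidence state of `q = 255` with `scale = 126.0` decodes to `1.0` rather than the `2**-126` implied by the old zero-filled state.

```python
from typing import Tuple

MIN_LOG = -126.0  # same constant as MIN_LOG / _FP32_MIN_LOG in the diff

# Table values copied from the NF4 hunks in this diff.
NF4_TABLE = [
    -1.0, -0.6961928, -0.52507306, -0.3895074,
    -0.27408478, -0.17286907, -0.07958022, 0.0,
    0.07958022, 0.17286907, 0.27408478, 0.3895074,
    0.52507306, 0.6961928, 0.8641379, 1.0,
]


def unpack_byte(packed: int) -> Tuple[float, float]:
    """Return the two normalized values stored in one packed byte."""
    return NF4_TABLE[packed >> 4], NF4_TABLE[packed & 0x0F]


def decode_log8(q: int, scale: float) -> float:
    """Assumed decoding of an 8-bit log2-space code back to a linear value."""
    log2_value = q / 255.0 * scale + MIN_LOG
    return 2.0 ** log2_value


if __name__ == "__main__":
    print(unpack_byte(0x77))        # new fill value for m_q
    print(unpack_byte(0x88))        # old fill value read through the new table
    print(decode_log8(255, 126.0))  # new confidence initialization
    print(decode_log8(0, 1.0))      # old confidence initialization
```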
@@ -494,7 +507,7 @@ class Adafactor8Bit(Optimizer):
 
  if beta1 is not None:
  m_padded_numel = ((p.numel() + m_block_size - 1) // m_block_size) * m_block_size
- state["m_q"] = torch.full((m_padded_numel // 2,), 0x88, dtype=torch.uint8, device=p.device)
+ state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
  state["m_scale"] = torch.ones(m_padded_numel // m_block_size, dtype=torch.float32, device=p.device)
  state["m_block_size"] = m_block_size
  else:
@@ -566,7 +579,7 @@ class Adafactor8Bit(Optimizer):
  state["m_q"], state["m_scale"] = _quantize_4bit_pytorch(state["m"], m_curr_block_size)
  state.pop("m")
  else:
- state["m_q"] = torch.full((m_padded_numel // 2,), 0x88, dtype=torch.uint8, device=p.device)
+ state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
  state["m_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=p.device)
  state["m_block_size"] = m_curr_block_size
 
@@ -620,7 +633,7 @@ class Adafactor8Bit(Optimizer):
  state["m_q"], state["m_scale"] = _quantize_4bit_pytorch(state["m"], m_curr_block_size)
  state.pop("m")
  else:
- state["m_q"] = torch.full((m_padded_numel // 2,), 0x88, dtype=torch.uint8, device=p.device)
+ state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
  state["m_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=p.device)
  state["m_block_size"] = m_curr_block_size
 
@@ -697,6 +710,11 @@ class Adafactor8Bit(Optimizer):
  # ==========================================
  def _apply_fira_cuda(state: Dict[str, Any], total_sum_sq: Tensor, alpha: Tensor, fira_margin: float) -> Tuple[Tensor, Tensor]:
  current_norm = total_sum_sq.sqrt().view([])
+
+ is_finite = torch.isfinite(current_norm)
+ current_norm = torch.where(is_finite, current_norm, torch.zeros_like(current_norm))
+ total_sum_sq = torch.where(is_finite, total_sum_sq, torch.zeros_like(total_sum_sq))
+
  fira_threshold = 1.0 + fira_margin
 
  prev_norm = state.get("fira_prev_norm", None)
@@ -704,13 +722,14 @@ def _apply_fira_cuda(state: Dict[str, Any], total_sum_sq: Tensor, alpha: Tensor,
  if not isinstance(prev_norm, Tensor):
  prev_norm = torch.tensor(prev_norm, device=total_sum_sq.device, dtype=torch.float32)
 
+ is_reset = prev_norm < 1e-6
  ratio = current_norm / (prev_norm + 1e-8)
  limiter = torch.clamp_min(ratio, fira_threshold) / fira_threshold
- final_scale = 1.0 / limiter
+ final_scale = torch.where(is_reset, torch.ones_like(current_norm), 1.0 / limiter)
+ state["fira_prev_norm"] = torch.where(is_reset, current_norm, current_norm * final_scale)
  else:
  final_scale = torch.tensor(1.0, device=total_sum_sq.device, dtype=torch.float32)
-
- state["fira_prev_norm"] = current_norm * final_scale
+ state["fira_prev_norm"] = current_norm
 
  alpha_scaled = alpha * final_scale
  total_sum_sq.mul_(final_scale.square())
@@ -719,19 +738,26 @@ def _apply_fira_cuda(state: Dict[str, Any], total_sum_sq: Tensor, alpha: Tensor,
 
  def _apply_fira_pytorch(state: Dict[str, Any], update: Tensor, fira_margin: float, numel: int, d: float) -> Tuple[Tensor, Tensor]:
  current_norm = torch.linalg.vector_norm(update)
+
+ is_finite = torch.isfinite(current_norm)
+ current_norm = torch.where(is_finite, current_norm, torch.zeros_like(current_norm))
+ update = torch.where(is_finite, update, torch.zeros_like(update))
+
  fira_threshold = 1.0 + fira_margin
 
  prev_norm = state.get("fira_prev_norm", None)
  if prev_norm is not None:
  if not isinstance(prev_norm, Tensor):
  prev_norm = torch.tensor(prev_norm, device=update.device, dtype=torch.float32)
+
+ is_reset = prev_norm < 1e-6
  ratio = current_norm / (prev_norm + 1e-8)
  limiter = torch.clamp_min(ratio, fira_threshold) / fira_threshold
- final_scale = 1.0 / limiter
+ final_scale = torch.where(is_reset, torch.ones_like(current_norm), 1.0 / limiter)
+ state["fira_prev_norm"] = torch.where(is_reset, current_norm, current_norm * final_scale)
  else:
  final_scale = torch.tensor(1.0, device=update.device, dtype=torch.float32)
-
- state["fira_prev_norm"] = current_norm * final_scale
+ state["fira_prev_norm"] = current_norm
 
  update_scaled = update * final_scale
  norm_final = current_norm * final_scale
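Reviewer note: the two Fira hunks above add the same pair of guards, and the control flow is easier to follow on plain floats. The sketch below is a scalar restatement, not the package's code: `fira_limit` is a hypothetical helper name, and it uses `if` branches where the diff uses `torch.where` (presumably to keep the computation on-device). Under that assumption it shows why the reset guard matters. Once a non-finite step stores a zero norm, the old formula divides by roughly `1e-8`, so the next update is scaled down to almost nothing and the stored norm can only recover by a factor of `1 + fira_margin` per step.

```python
import math


def fira_limit(prev_norm, current_norm, fira_margin=0.01):
    """Scalar sketch of the norm-growth limiter (hypothetical helper).

    Returns (final_scale, new_prev_norm).
    """
    # Guard 1: a non-finite norm is replaced by zero, as in the added isfinite lines.
    if not math.isfinite(current_norm):
        current_norm = 0.0

    # First step: no history yet, so the update passes through unscaled.
    if prev_norm is None:
        return 1.0, current_norm

    # Guard 2: a near-zero stored norm is treated as a reset, so limiting is skipped.
    if prev_norm < 1e-6:
        return 1.0, current_norm

    fira_threshold = 1.0 + fira_margin
    ratio = current_norm / (prev_norm + 1e-8)
    # limiter is 1.0 while growth stays under the threshold, larger otherwise.
    limiter = max(ratio, fira_threshold) / fira_threshold
    final_scale = 1.0 / limiter
    return final_scale, current_norm * final_scale


if __name__ == "__main__":
    prev = None
    for norm in [1.0, 1.005, 10.0, float("inf"), 2.0]:
        scale, prev = fira_limit(prev, norm)
        print(f"norm={norm} scale={scale:.4f} stored={prev:.4f}")
```

In this sketch a tenfold jump is capped so that the stored norm grows by about `fira_margin`, an infinite norm stores zero, and the step after that is passed through unscaled instead of being suppressed.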
@@ -1295,6 +1321,7 @@ def _update_param_apollo(
  fira_margin: float = 0.01,
  ):
  grad_work = grad.neg().float() if maximize else grad.float()
+ grad_work = torch.where(torch.isfinite(grad_work), grad_work, torch.zeros_like(grad_work))
  update_low = None
 
  if apollo_factorize:
@@ -1401,7 +1428,7 @@ def _update_param_apollo(
 
  if state.get("m_low_q") is None:
  m_padded_numel = ((grad_low_numel + m_curr_block_size - 1) // m_curr_block_size) * m_curr_block_size
- state["m_low_q"] = torch.full((m_padded_numel // 2,), 0x88, dtype=torch.uint8, device=grad_low.device)
+ state["m_low_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=grad_low.device)
  state["m_low_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=grad_low.device)
  state["m_block_size"] = m_curr_block_size
 
@@ -1481,7 +1508,7 @@ def _update_param_apollo(
 
  if state.get("m_low_q") is None:
  m_padded_numel = ((grad_low_numel + m_curr_block_size - 1) // m_curr_block_size) * m_curr_block_size
- state["m_low_q"] = torch.full((m_padded_numel // 2,), 0x88, dtype=torch.uint8, device=grad_low.device)
+ state["m_low_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=grad_low.device)
  state["m_low_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=grad_low.device)
  state["m_block_size"] = m_curr_block_size
 
@@ -1575,16 +1602,20 @@ def _update_param_apollo(
  else:
  current_norm_t = torch.linalg.vector_norm(grad_work, ord=2, dtype=torch.float32) * scaling_factor
 
+ is_finite = torch.isfinite(current_norm_t)
+ current_norm_t = torch.where(is_finite, current_norm_t, torch.zeros_like(current_norm_t))
+
  fira_threshold = 1.0 + fira_margin
  if "scaled_grad_norm_prev" in state:
  prev_norm_t = state["scaled_grad_norm_prev"]
  if not isinstance(prev_norm_t, Tensor):
  prev_norm_t = torch.tensor(prev_norm_t, device=param_work.device, dtype=torch.float32)
 
+ is_reset = prev_norm_t < 1e-6
  ratio = current_norm_t / (prev_norm_t + 1e-8)
  limiter = torch.clamp_min(ratio, fira_threshold) / fira_threshold
- final_scale = scaling_factor / limiter
- state["scaled_grad_norm_prev"] = current_norm_t / limiter
+ final_scale = torch.where(is_reset, scaling_factor, scaling_factor / limiter)
+ state["scaled_grad_norm_prev"] = torch.where(is_reset, current_norm_t, current_norm_t / limiter)
  else:
  final_scale = scaling_factor
  state["scaled_grad_norm_prev"] = current_norm_t
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adafactor8bit
- Version: 0.2.2
+ Version: 0.2.5
  Summary: 8-bit Adafactor Optimizer with Fused CUDA Kernels
  Home-page: https://github.com/yanfeiwong/adafactor-8bit
  Author: WANG YAN
@@ -53,7 +53,7 @@ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space bl
 
  - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
  - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
- - **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
+ - **Optional NF4 First Moment**: Stores the optional first moment (`beta1`) using Normal Float 4-bit (NF4) non-uniform quantization, preserving small momentum updates while keeping memory overhead minimal.
  - **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
  - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
  - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
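Reviewer note: the NF4 bullet in this hunk is a plausible reason for the `0x88` to `0x77` initialization change in the optimizer hunks of this diff. The sketch below assumes the package uses the 16-level NF4 code book published with QLoRA and bitsandbytes (shown rounded to four decimals) and packs two codes per byte. Neither the table nor the nibble order is confirmed by the diff, and `NF4_LEVELS`, `unpack_nibbles`, and `dequantize_nf4` are illustrative names. Under those assumptions, code 7 is the exact zero level, so a buffer filled with `0x77` decodes to zero momentum, while `0x88` would decode to a small positive value.

```python
# NF4 code book as published with QLoRA / bitsandbytes, rounded to 4 decimals.
# Assumption: this package uses the same table; the diff does not show it.
NF4_LEVELS = [
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
]


def unpack_nibbles(byte):
    """Split one packed byte into its (high, low) 4-bit codes."""
    return (byte >> 4) & 0xF, byte & 0xF


def dequantize_nf4(byte, scale=1.0):
    """Decode both 4-bit codes in a byte through the NF4 table."""
    hi, lo = unpack_nibbles(byte)
    return NF4_LEVELS[hi] * scale, NF4_LEVELS[lo] * scale


if __name__ == "__main__":
    print(unpack_nibbles(0x77), dequantize_nf4(0x77))
    print(unpack_nibbles(0x88), dequantize_nf4(0x88))
```

Because both nibbles of `0x77` and of `0x88` are equal, the conclusion does not depend on which nibble is stored first.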
@@ -278,7 +278,7 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
  - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
  - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
 
- ## 🛡️ CAME Confidence-Guided Updates
+ ## 🧊 CAME Confidence-Guided Updates
 
  Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
 
@@ -326,24 +326,25 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
  *These are safe starting points. Always validate on your own task and batch size.*
 
 
-
-
-
  ## 🎓 Acknowledgements
 
- Thanks to **Noam Shazeer** and **Mitchell Stern** for proposing the original Adafactor algorithm in the paper [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235).
-
- Thanks to **Tim Dettmers** for the inspiration from the paper [8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
-
- Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the approximated gradient scaling method in the paper [APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
+ This project builds upon the foundational work of several researchers and open-source communities. Sincere thanks to the following for their invaluable contributions:
 
- Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
+ ### Core Algorithm & Optimizer Design
+ - **Noam Shazeer & Mitchell Stern** for proposing the original **Adafactor** algorithm ([Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)).
+ - **Tim Dettmers** for the inspiration from **8-bit block-wise quantization** ([8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861)) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
+ - **Hanqing Zhu, Zhenyu Zhang, et al.** for the **APOLLO** algorithm ([APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270)).
+ - **Xi Chen, Kaituo Feng, et al.** for the **Norm-Growth Limiter** mechanism in **Fira** ([Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)).
+ - **Yang Luo, et al.** for the **confidence-guided strategy** in **CAME** ([CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)).
 
- Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
+ ### Quantization & Implementation
+ - **The QLoRA Team** for pioneering the **4-bit NormalFloat (NF4)** quantization format ([QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)) that inspired our first moment quantization.
+ - **The PyTorch AO Team** for their work on [4-bit optimizer states](https://github.com/pytorch/ao/tree/main/torchao/optim), validating distribution-aware quantization for optimizer moments.
+ - **The PyTorch Team** for providing the foundational optimizer implementation and the C++ Extension toolchain.
 
- Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
+ ### Technical Review & Discussion
+ - **Qwen, ChatGLM, and DeepSeek** (large language models) for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
 
- Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
 
  ## 🏛️ License
 
@@ -9,7 +9,7 @@ long_description = (this_directory / "README.md").read_text(encoding="utf-8")
 
  setup(
  name="adafactor8bit",
- version="0.2.2",
+ version="0.2.5",
  description="8-bit Adafactor Optimizer with Fused CUDA Kernels",
  author="WANG YAN",
  author_email="yanfeiwong1997@outlook.com",
File without changes
File without changes
File without changes