adafactor8bit 0.2.2__tar.gz → 0.2.4__tar.gz
- {adafactor8bit-0.2.2/adafactor8bit.egg-info → adafactor8bit-0.2.4}/PKG-INFO +16 -15
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.4}/README.md +15 -14
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.4}/adafactor8bit/kernels.cu +36 -20
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.4}/adafactor8bit/optimizer.py +33 -20
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.4/adafactor8bit.egg-info}/PKG-INFO +16 -15
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.4}/setup.py +1 -1
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.4}/LICENSE +0 -0
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.4}/MANIFEST.in +0 -0
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.4}/adafactor8bit/__init__.py +0 -0
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.4}/adafactor8bit.egg-info/SOURCES.txt +0 -0
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.4}/adafactor8bit.egg-info/dependency_links.txt +0 -0
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.4}/adafactor8bit.egg-info/requires.txt +0 -0
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.4}/adafactor8bit.egg-info/top_level.txt +0 -0
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.4}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: adafactor8bit
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.4
|
|
4
4
|
Summary: 8-bit Adafactor Optimizer with Fused CUDA Kernels
|
|
5
5
|
Home-page: https://github.com/yanfeiwong/adafactor-8bit
|
|
6
6
|
Author: WANG YAN
|
|
@@ -53,7 +53,7 @@ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space bl
|
|
|
53
53
|
|
|
54
54
|
- **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
|
|
55
55
|
- **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
|
|
56
|
-
- **Optional
|
|
56
|
+
- **Optional NF4 First Moment**: Stores the optional first moment (`beta1`) using Normal Float 4-bit (NF4) non-uniform quantization, preserving small momentum updates while keeping memory overhead minimal.
|
|
57
57
|
- **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
|
|
58
58
|
- **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
|
|
59
59
|
- **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
|
|
@@ -278,7 +278,7 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
|
|
|
278
278
|
- **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
|
|
279
279
|
- **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
|
|
280
280
|
|
|
281
|
-
##
|
|
281
|
+
## 🧊 CAME Confidence-Guided Updates
|
|
282
282
|
|
|
283
283
|
Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
|
|
284
284
|
|
|
@@ -326,24 +326,25 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
|
|
|
326
326
|
*These are safe starting points. Always validate on your own task and batch size.*
|
|
327
327
|
|
|
328
328
|
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
329
|
## 🎓 Acknowledgements
|
|
333
330
|
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
Thanks to **Tim Dettmers** for the inspiration from the paper [8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
|
|
337
|
-
|
|
338
|
-
Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the approximated gradient scaling method in the paper [APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
|
|
331
|
+
This project builds upon the foundational work of several researchers and open-source communities. Sincere thanks to the following for their invaluable contributions:
|
|
339
332
|
|
|
340
|
-
|
|
333
|
+
### Core Algorithm & Optimizer Design
|
|
334
|
+
- **Noam Shazeer & Mitchell Stern** for proposing the original **Adafactor** algorithm ([Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)).
|
|
335
|
+
- **Tim Dettmers** for the inspiration from **8-bit block-wise quantization** ([8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861)) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
|
|
336
|
+
- **Hanqing Zhu, Zhenyu Zhang, et al.** for the **APOLLO** algorithm ([APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270)).
|
|
337
|
+
- **Xi Chen, Kaituo Feng, et al.** for the **Norm-Growth Limiter** mechanism in **Fira** ([Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)).
|
|
338
|
+
- **Yang Luo, et al.** for the **confidence-guided strategy** in **CAME** ([CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)).
|
|
341
339
|
|
|
342
|
-
|
|
340
|
+
### Quantization & Implementation
|
|
341
|
+
- **The QLoRA Team** for pioneering the **4-bit NormalFloat (NF4)** quantization format ([QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)) that inspired our first moment quantization.
|
|
342
|
+
- **The PyTorch AO Team** for their work on [4-bit optimizer states](https://github.com/pytorch/ao/tree/main/torchao/optim), validating distribution-aware quantization for optimizer moments.
|
|
343
|
+
- **The PyTorch Team** for providing the foundational optimizer implementation and the C++ Extension toolchain.
|
|
343
344
|
|
|
344
|
-
|
|
345
|
+
### Technical Review & Discussion
|
|
346
|
+
- **Qwen, ChatGLM, and DeepSeek** (large language models) for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
|
|
345
347
|
|
|
346
|
-
Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
|
|
347
348
|
|
|
348
349
|
## 🏛️ License
|
|
349
350
|
|
|
@@ -26,7 +26,7 @@ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space bl
|
|
|
26
26
|
|
|
27
27
|
- **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
|
|
28
28
|
- **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
|
|
29
|
-
- **Optional
|
|
29
|
+
- **Optional NF4 First Moment**: Stores the optional first moment (`beta1`) using Normal Float 4-bit (NF4) non-uniform quantization, preserving small momentum updates while keeping memory overhead minimal.
|
|
30
30
|
- **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
|
|
31
31
|
- **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
|
|
32
32
|
- **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
|
|
@@ -251,7 +251,7 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
|
|
|
251
251
|
- **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
|
|
252
252
|
- **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
|
|
253
253
|
|
|
254
|
-
##
|
|
254
|
+
## 🧊 CAME Confidence-Guided Updates
|
|
255
255
|
|
|
256
256
|
Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
|
|
257
257
|
|
|
@@ -299,24 +299,25 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
|
|
|
299
299
|
*These are safe starting points. Always validate on your own task and batch size.*
|
|
300
300
|
|
|
301
301
|
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
302
|
## 🎓 Acknowledgements
|
|
306
303
|
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
Thanks to **Tim Dettmers** for the inspiration from the paper [8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
|
|
310
|
-
|
|
311
|
-
Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the approximated gradient scaling method in the paper [APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
|
|
304
|
+
This project builds upon the foundational work of several researchers and open-source communities. Sincere thanks to the following for their invaluable contributions:
|
|
312
305
|
|
|
313
|
-
|
|
306
|
+
### Core Algorithm & Optimizer Design
|
|
307
|
+
- **Noam Shazeer & Mitchell Stern** for proposing the original **Adafactor** algorithm ([Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)).
|
|
308
|
+
- **Tim Dettmers** for the inspiration from **8-bit block-wise quantization** ([8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861)) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
|
|
309
|
+
- **Hanqing Zhu, Zhenyu Zhang, et al.** for the **APOLLO** algorithm ([APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270)).
|
|
310
|
+
- **Xi Chen, Kaituo Feng, et al.** for the **Norm-Growth Limiter** mechanism in **Fira** ([Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)).
|
|
311
|
+
- **Yang Luo, et al.** for the **confidence-guided strategy** in **CAME** ([CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)).
|
|
314
312
|
|
|
315
|
-
|
|
313
|
+
### Quantization & Implementation
|
|
314
|
+
- **The QLoRA Team** for pioneering the **4-bit NormalFloat (NF4)** quantization format ([QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)) that inspired our first moment quantization.
|
|
315
|
+
- **The PyTorch AO Team** for their work on [4-bit optimizer states](https://github.com/pytorch/ao/tree/main/torchao/optim), validating distribution-aware quantization for optimizer moments.
|
|
316
|
+
- **The PyTorch Team** for providing the foundational optimizer implementation and the C++ Extension toolchain.
|
|
316
317
|
|
|
317
|
-
|
|
318
|
+
### Technical Review & Discussion
|
|
319
|
+
- **Qwen, ChatGLM, and DeepSeek** (large language models) for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
|
|
318
320
|
|
|
319
|
-
Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
|
|
320
321
|
|
|
321
322
|
## 🏛️ License
|
|
322
323
|
|
|
@@ -7,7 +7,26 @@
|
|
|
7
7
|
__device__ constexpr float INV_255 = 1.0f / 255.0f;
|
|
8
8
|
__device__ constexpr float MIN_LOG = -126.0f;
|
|
9
9
|
__device__ constexpr float MIN_VAL = 1.17549435e-38f;
|
|
10
|
-
|
|
10
|
+
|
|
11
|
+
__constant__ float NF4_QMAP[16] = {
|
|
12
|
+
-1.0f, -0.6961928f, -0.52507306f, -0.3895074f,
|
|
13
|
+
-0.27408478f, -0.17286907f, -0.07958022f, 0.0f,
|
|
14
|
+
0.07958022f, 0.17286907f, 0.27408478f, 0.3895074f,
|
|
15
|
+
0.52507306f, 0.6961928f, 0.8641379f, 1.0f
|
|
16
|
+
};
|
|
17
|
+
|
|
18
|
+
__device__ __forceinline__ int find_nearest_nf4(float x) {
|
|
19
|
+
float x_abs = fabsf(x);
|
|
20
|
+
if (x_abs < 0.0397901f) return 7;
|
|
21
|
+
if (x_abs < 0.1262246f) return (x >= 0.0f) ? 8 : 6;
|
|
22
|
+
if (x_abs < 0.2234769f) return (x >= 0.0f) ? 9 : 5;
|
|
23
|
+
if (x_abs < 0.3317961f) return (x >= 0.0f) ? 10 : 4;
|
|
24
|
+
if (x_abs < 0.4572902f) return (x >= 0.0f) ? 11 : 3;
|
|
25
|
+
if (x_abs < 0.6106329f) return (x >= 0.0f) ? 12 : 2;
|
|
26
|
+
if (x_abs < 0.7801653f) return (x >= 0.0f) ? 13 : 1;
|
|
27
|
+
if (x_abs < 0.9320689f) return (x >= 0.0f) ? 14 : 0;
|
|
28
|
+
return (x >= 0.0f) ? 15 : 0;
|
|
29
|
+
}
|
|
11
30
|
|
|
12
31
|
// ==========================================
|
|
13
32
|
// 1. Fused Log-Quantize Lerp (EMA Update for V_t)
|
|
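The thresholds in `find_nearest_nf4` appear to be the midpoints between adjacent positive `NF4_QMAP` entries, with the sign of `x` selecting the mirrored code. The Python sketch below is a transcription for checking purposes only, assuming nothing beyond the table and thresholds shown in the hunk above; the names `threshold_code` and `argmin_code` are hypothetical and not part of the package. It compares the branch ladder against a brute-force nearest-entry search. Because the table has `0.8641379` at index 14 but no negative counterpart, the two lookups may disagree for negative inputs between roughly -0.848 and -0.780.

```python
# Table and thresholds copied from the kernels.cu hunk above.
NF4_QMAP = [
    -1.0, -0.6961928, -0.52507306, -0.3895074,
    -0.27408478, -0.17286907, -0.07958022, 0.0,
    0.07958022, 0.17286907, 0.27408478, 0.3895074,
    0.52507306, 0.6961928, 0.8641379, 1.0,
]
THRESHOLDS = [
    0.0397901, 0.1262246, 0.2234769, 0.3317961,
    0.4572902, 0.6106329, 0.7801653, 0.9320689,
]


def threshold_code(x: float) -> int:
    """Python transcription of the branch ladder in find_nearest_nf4."""
    x_abs = abs(x)
    if x_abs < THRESHOLDS[0]:
        return 7  # code 7 is the exact-zero entry
    for step, limit in enumerate(THRESHOLDS[1:], start=1):
        if x_abs < limit:
            # Positive inputs walk up from 7, negative inputs walk down.
            return 7 + step if x >= 0.0 else 7 - step
    return 15 if x >= 0.0 else 0


def argmin_code(x: float) -> int:
    """Brute-force nearest table entry (hypothetical reference helper)."""
    return min(range(16), key=lambda i: abs(x - NF4_QMAP[i]))


if __name__ == "__main__":
    for x in (0.05, 0.5, 0.9, -0.5, -0.8):
        print(x, threshold_code(x), argmin_code(x))
```

On a grid of non-negative inputs the two helpers agree. For `x = -0.8` the ladder yields code 0 (`-1.0`) while the nearest entry is code 1 (`-0.6961928`). Whether that asymmetry is intentional is not stated in the diff.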
@@ -226,11 +245,10 @@ __global__ void fused_4bit_quantize_lerp_kernel(
|
|
|
226
245
|
|
|
227
246
|
uchar2 old_q = q_vec[idx];
|
|
228
247
|
|
|
229
|
-
|
|
230
|
-
float
|
|
231
|
-
float
|
|
232
|
-
float
|
|
233
|
-
float m_old3 = (float)((old_q.y & 0x0F) - 8) * old_scale;
|
|
248
|
+
float m_old0 = NF4_QMAP[(old_q.x >> 4)] * old_scale;
|
|
249
|
+
float m_old1 = NF4_QMAP[(old_q.x & 0x0F)] * old_scale;
|
|
250
|
+
float m_old2 = NF4_QMAP[(old_q.y >> 4)] * old_scale;
|
|
251
|
+
float m_old3 = NF4_QMAP[(old_q.y & 0x0F)] * old_scale;
|
|
234
252
|
|
|
235
253
|
float m_new0 = beta * m_old0 + one_minus_b * val_x;
|
|
236
254
|
float m_new1 = beta * m_old1 + one_minus_b * val_y;
|
|
@@ -267,7 +285,7 @@ __global__ void fused_4bit_quantize_lerp_kernel(
|
|
|
267
285
|
__syncthreads();
|
|
268
286
|
|
|
269
287
|
float abs_max = fmaxf(s_max[0], 1e-12f);
|
|
270
|
-
float new_scale = abs_max
|
|
288
|
+
float new_scale = abs_max;
|
|
271
289
|
float inv_scale = 1.0f / new_scale;
|
|
272
290
|
|
|
273
291
|
for (int i = 0; i < vec_iters; i++) {
|
|
@@ -279,12 +297,10 @@ __global__ void fused_4bit_quantize_lerp_kernel(
|
|
|
279
297
|
float m2 = local_m[idx * 4 + 2];
|
|
280
298
|
float m3 = local_m[idx * 4 + 3];
|
|
281
299
|
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
int
|
|
285
|
-
int
|
|
286
|
-
int q2 = max(-8, min(7, __float2int_rn(m2 * inv_scale))) + 8;
|
|
287
|
-
int q3 = max(-8, min(7, __float2int_rn(m3 * inv_scale))) + 8;
|
|
300
|
+
int q0 = find_nearest_nf4(m0 * inv_scale);
|
|
301
|
+
int q1 = find_nearest_nf4(m1 * inv_scale);
|
|
302
|
+
int q2 = find_nearest_nf4(m2 * inv_scale);
|
|
303
|
+
int q3 = find_nearest_nf4(m3 * inv_scale);
|
|
288
304
|
|
|
289
305
|
uchar2 out_q;
|
|
290
306
|
out_q.x = (unsigned char)((q0 << 4) | q1);
|
|
@@ -561,7 +577,7 @@ __global__ void compute_update_norm_m_2d_kernel(
|
|
|
561
577
|
// Unpack 4-bit m_t
|
|
562
578
|
unsigned char packed = m_q[idx / 2];
|
|
563
579
|
int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
|
|
564
|
-
float m_val =
|
|
580
|
+
float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
|
|
565
581
|
|
|
566
582
|
float log_r = (float)row_var_q[b * R + r] * INV_255 * row_var_scale[(b * R + r) / v_block_size] + MIN_LOG;
|
|
567
583
|
float log_c = (float)col_var_q[b * C + c] * INV_255 * col_var_scale[(b * C + c) / v_block_size] + MIN_LOG;
|
|
@@ -636,7 +652,7 @@ __global__ void apply_update_m_2d_kernel(
|
|
|
636
652
|
// Unpack 4-bit m_t
|
|
637
653
|
unsigned char packed = m_q[idx / 2];
|
|
638
654
|
int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
|
|
639
|
-
float m_val =
|
|
655
|
+
float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
|
|
640
656
|
|
|
641
657
|
float log_r = (float)row_var_q[b * R + r] * INV_255 * row_var_scale[(b * R + r) / v_block_size] + MIN_LOG;
|
|
642
658
|
float log_c = (float)col_var_q[b * C + c] * INV_255 * col_var_scale[(b * C + c) / v_block_size] + MIN_LOG;
|
|
@@ -692,7 +708,7 @@ __global__ void compute_update_norm_m_1d_kernel(
|
|
|
692
708
|
// Unpack 4-bit m_t
|
|
693
709
|
unsigned char packed = m_q[idx / 2];
|
|
694
710
|
int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
|
|
695
|
-
float m_val =
|
|
711
|
+
float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
|
|
696
712
|
|
|
697
713
|
float log_v = (float)variance_q[idx] * INV_255 * variance_scale[idx / v_block_size] + MIN_LOG;
|
|
698
714
|
|
|
@@ -751,7 +767,7 @@ __global__ void apply_update_m_1d_kernel(
|
|
|
751
767
|
// Unpack 4-bit m_t
|
|
752
768
|
unsigned char packed = m_q[idx / 2];
|
|
753
769
|
int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
|
|
754
|
-
float m_val =
|
|
770
|
+
float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
|
|
755
771
|
|
|
756
772
|
float log_v = (float)variance_q[idx] * INV_255 * variance_scale[idx / v_block_size] + MIN_LOG;
|
|
757
773
|
|
|
@@ -808,7 +824,7 @@ __global__ void compute_apollo_norms_kernel(
|
|
|
808
824
|
// Unpack 4-bit m
|
|
809
825
|
unsigned char m_byte = m_q[global_idx / 2];
|
|
810
826
|
int m_int = (global_idx & 1) ? (m_byte & 0x0F) : (m_byte >> 4);
|
|
811
|
-
float m_val =
|
|
827
|
+
float m_val = NF4_QMAP[m_int] * m_scale[global_idx / m_block_size];
|
|
812
828
|
// Unpack 8-bit log v
|
|
813
829
|
unsigned char v_byte = v_q[global_idx];
|
|
814
830
|
float log_v = (float)v_byte * INV_255 * v_scale[global_idx / v_block_size] + MIN_LOG;
|
|
@@ -892,7 +908,7 @@ __global__ void dequantize_4bit_kernel(
|
|
|
892
908
|
if (idx >= numel) return;
|
|
893
909
|
unsigned char packed = q[idx / 2];
|
|
894
910
|
int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
|
|
895
|
-
output[idx] =
|
|
911
|
+
output[idx] = NF4_QMAP[q_int] * scale[idx / block_size];
|
|
896
912
|
}
|
|
897
913
|
|
|
898
914
|
void dequantize_4bit_cuda(
|
|
@@ -1132,7 +1148,7 @@ __global__ void came_compute_residual_2d_kernel(
|
|
|
1132
1148
|
|
|
1133
1149
|
unsigned char packed = m_q[idx / 2];
|
|
1134
1150
|
int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
|
|
1135
|
-
float m_val =
|
|
1151
|
+
float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
|
|
1136
1152
|
|
|
1137
1153
|
float log_r = (float)row_var_q[b * R + r] * INV_255 * row_var_scale[(b * R + r) / v_block_size] + MIN_LOG;
|
|
1138
1154
|
float log_c = (float)col_var_q[b * C + c] * INV_255 * col_var_scale[(b * C + c) / v_block_size] + MIN_LOG;
|
|
@@ -21,6 +21,14 @@ _FP32_MIN_LOG = -126.0
|
|
|
21
21
|
_INV_255 = 1.0 / 255.0
|
|
22
22
|
_INV_7 = 1.0 / 7.0
|
|
23
23
|
|
|
24
|
+
_NF4_TABLE = torch.tensor([
|
|
25
|
+
-1.0, -0.6961928, -0.52507306, -0.3895074,
|
|
26
|
+
-0.27408478, -0.17286907, -0.07958022, 0.0,
|
|
27
|
+
0.07958022, 0.17286907, 0.27408478, 0.3895074,
|
|
28
|
+
0.52507306, 0.6961928, 0.8641379, 1.0
|
|
29
|
+
], dtype=torch.float32)
|
|
30
|
+
|
|
31
|
+
|
|
24
32
|
# ==========================================
|
|
25
33
|
# 1. CUDA Kernel JIT Loading
|
|
26
34
|
# ==========================================
|
|
@@ -127,7 +135,7 @@ def _log_dequantize_nonneg(q: Tensor, scale: Tensor, shape: torch.Size, pad: int
|
|
|
127
135
|
return flat.view(shape)
|
|
128
136
|
|
|
129
137
|
def _quantize_4bit_pytorch(m: Tensor, block_size: int) -> Tuple[Tensor, Tensor]:
|
|
130
|
-
"""4-bit
|
|
138
|
+
"""4-bit NF4 non-uniform quantization with physical packing into uint8."""
|
|
131
139
|
flat = m.flatten()
|
|
132
140
|
pad = (block_size - flat.numel() % block_size) % block_size
|
|
133
141
|
if pad:
|
|
@@ -135,12 +143,15 @@ def _quantize_4bit_pytorch(m: Tensor, block_size: int) -> Tuple[Tensor, Tensor]:
|
|
|
135
143
|
|
|
136
144
|
blocks = flat.view(-1, block_size)
|
|
137
145
|
abs_max = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
|
|
138
|
-
scale = abs_max
|
|
146
|
+
scale = abs_max
|
|
139
147
|
|
|
140
|
-
|
|
148
|
+
normalized = blocks / scale
|
|
149
|
+
table = _NF4_TABLE.to(normalized.device)
|
|
150
|
+
diff = (normalized.unsqueeze(-1) - table).abs()
|
|
151
|
+
codes = diff.argmin(dim=-1).to(torch.uint8)
|
|
141
152
|
|
|
142
|
-
q_even =
|
|
143
|
-
q_odd =
|
|
153
|
+
q_even = codes[:, 0::2]
|
|
154
|
+
q_odd = codes[:, 1::2]
|
|
144
155
|
packed = (q_even << 4) | q_odd
|
|
145
156
|
|
|
146
157
|
return packed.view(-1), scale.squeeze(-1)
|
|
@@ -152,11 +163,12 @@ def _dequantize_4bit(m_q: Tensor, m_scale: Tensor, numel: int, shape: torch.Size
|
|
|
152
163
|
_CUDA_MODULE.dequantize_4bit(output, m_q, m_scale, numel, block_size)
|
|
153
164
|
return output.view(shape)
|
|
154
165
|
else:
|
|
155
|
-
high = (
|
|
156
|
-
low = (m_q & 0x0F)
|
|
157
|
-
|
|
158
|
-
m_blocks =
|
|
159
|
-
|
|
166
|
+
high = (m_q >> 4)
|
|
167
|
+
low = (m_q & 0x0F)
|
|
168
|
+
codes = torch.stack((high, low), dim=-1).view(-1)
|
|
169
|
+
m_blocks = codes.view(-1, block_size)
|
|
170
|
+
table = _NF4_TABLE.to(m_q.device)
|
|
171
|
+
result = (table[m_blocks.long()] * m_scale.unsqueeze(-1)).view(-1)[:numel]
|
|
160
172
|
return result.view(shape)
|
|
161
173
|
|
|
162
174
|
# ==========================================
|
|
@@ -461,20 +473,21 @@ class Adafactor8Bit(Optimizer):
|
|
|
461
473
|
|
|
462
474
|
if beta1 is not None:
|
|
463
475
|
m_padded_numel = ((p.numel() + m_block_size - 1) // m_block_size) * m_block_size
|
|
464
|
-
state["m_q"] = torch.full((m_padded_numel // 2,),
|
|
476
|
+
state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
|
|
465
477
|
state["m_scale"] = torch.ones(m_padded_numel // m_block_size, dtype=torch.float32, device=p.device)
|
|
466
478
|
state["m_block_size"] = m_block_size
|
|
467
479
|
|
|
468
480
|
if beta3 is not None:
|
|
469
|
-
state["conf_row_q"] = torch.
|
|
470
|
-
state["conf_row_scale"] = torch.
|
|
481
|
+
state["conf_row_q"] = torch.full_like(state["row_var_q"], 255)
|
|
482
|
+
state["conf_row_scale"] = torch.full_like(state["row_var_scale"], 126.0)
|
|
471
483
|
state["conf_row_shape"] = state["row_var_shape"]
|
|
472
484
|
state["conf_row_pad"] = state["row_var_pad"]
|
|
473
485
|
|
|
474
|
-
state["conf_col_q"] = torch.
|
|
475
|
-
state["conf_col_scale"] = torch.
|
|
486
|
+
state["conf_col_q"] = torch.full_like(state["col_var_q"], 255)
|
|
487
|
+
state["conf_col_scale"] = torch.full_like(state["col_var_scale"], 126.0)
|
|
476
488
|
state["conf_col_shape"] = state["col_var_shape"]
|
|
477
489
|
state["conf_col_pad"] = state["col_var_pad"]
|
|
490
|
+
|
|
478
491
|
else:
|
|
479
492
|
state["row_var"] = torch.zeros(r_shape, dtype=torch.float32, device=p.device)
|
|
480
493
|
state["col_var"] = torch.zeros(c_shape, dtype=torch.float32, device=p.device)
|
|
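The initial values for `conf_row_q` / `conf_col_q` in the hunk above (`255` with scale `126.0`) can be read against the dequantization expression that recurs in the kernels, `q * INV_255 * scale + MIN_LOG`. The sketch below is a rough illustration, assuming the decoded quantity is a base-2 exponent as the README's "log2 space" wording suggests; the helper name `log_dequantize` is hypothetical and not part of the package.

```python
MIN_LOG = -126.0  # same value as MIN_LOG / _FP32_MIN_LOG in the diff


def log_dequantize(q: int, scale: float) -> float:
    """Decode one 8-bit code to a value via a log2-domain affine map.

    Assumption: the stored quantity is a base-2 logarithm (README: "log2 space").
    q / 255.0 stands in for the kernel's q * INV_255.
    """
    log_v = (q / 255.0) * scale + MIN_LOG
    return 2.0 ** log_v


if __name__ == "__main__":
    print(log_dequantize(255, 126.0))  # initial confidence code and scale
    print(log_dequantize(0, 126.0))    # lowest code in the block
```

Under that assumption, code 255 with scale 126.0 decodes to 2^0 = 1.0, and code 0 decodes to 2^-126, which matches the `MIN_VAL` constant in `kernels.cu`.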
@@ -494,7 +507,7 @@ class Adafactor8Bit(Optimizer):
|
|
|
494
507
|
|
|
495
508
|
if beta1 is not None:
|
|
496
509
|
m_padded_numel = ((p.numel() + m_block_size - 1) // m_block_size) * m_block_size
|
|
497
|
-
state["m_q"] = torch.full((m_padded_numel // 2,),
|
|
510
|
+
state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
|
|
498
511
|
state["m_scale"] = torch.ones(m_padded_numel // m_block_size, dtype=torch.float32, device=p.device)
|
|
499
512
|
state["m_block_size"] = m_block_size
|
|
500
513
|
else:
|
|
@@ -566,7 +579,7 @@ class Adafactor8Bit(Optimizer):
|
|
|
566
579
|
state["m_q"], state["m_scale"] = _quantize_4bit_pytorch(state["m"], m_curr_block_size)
|
|
567
580
|
state.pop("m")
|
|
568
581
|
else:
|
|
569
|
-
state["m_q"] = torch.full((m_padded_numel // 2,),
|
|
582
|
+
state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
|
|
570
583
|
state["m_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=p.device)
|
|
571
584
|
state["m_block_size"] = m_curr_block_size
|
|
572
585
|
|
|
@@ -620,7 +633,7 @@ class Adafactor8Bit(Optimizer):
|
|
|
620
633
|
state["m_q"], state["m_scale"] = _quantize_4bit_pytorch(state["m"], m_curr_block_size)
|
|
621
634
|
state.pop("m")
|
|
622
635
|
else:
|
|
623
|
-
state["m_q"] = torch.full((m_padded_numel // 2,),
|
|
636
|
+
state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
|
|
624
637
|
state["m_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=p.device)
|
|
625
638
|
state["m_block_size"] = m_curr_block_size
|
|
626
639
|
|
|
@@ -1401,7 +1414,7 @@ def _update_param_apollo(
|
|
|
1401
1414
|
|
|
1402
1415
|
if state.get("m_low_q") is None:
|
|
1403
1416
|
m_padded_numel = ((grad_low_numel + m_curr_block_size - 1) // m_curr_block_size) * m_curr_block_size
|
|
1404
|
-
state["m_low_q"] = torch.full((m_padded_numel // 2,),
|
|
1417
|
+
state["m_low_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=grad_low.device)
|
|
1405
1418
|
state["m_low_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=grad_low.device)
|
|
1406
1419
|
state["m_block_size"] = m_curr_block_size
|
|
1407
1420
|
|
|
@@ -1481,7 +1494,7 @@ def _update_param_apollo(
|
|
|
1481
1494
|
|
|
1482
1495
|
if state.get("m_low_q") is None:
|
|
1483
1496
|
m_padded_numel = ((grad_low_numel + m_curr_block_size - 1) // m_curr_block_size) * m_curr_block_size
|
|
1484
|
-
state["m_low_q"] = torch.full((m_padded_numel // 2,),
|
|
1497
|
+
state["m_low_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=grad_low.device)
|
|
1485
1498
|
state["m_low_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=grad_low.device)
|
|
1486
1499
|
state["m_block_size"] = m_curr_block_size
|
|
1487
1500
|
|
|
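The `0x77` fill value used for `m_q` and `m_low_q` can be read against the NF4 table: each byte holds two 4-bit codes, high nibble first, and code 7 maps to `0.0`. The sketch below mirrors the `(q_even << 4) | q_odd` packing and the `>> 4` / `& 0x0F` unpacking shown in the diff, in plain Python with a single block scale; the helper names `pack_codes` and `unpack_values` are hypothetical and not part of the package.

```python
# Table copied from _NF4_TABLE in the optimizer.py hunk.
NF4_TABLE = [
    -1.0, -0.6961928, -0.52507306, -0.3895074,
    -0.27408478, -0.17286907, -0.07958022, 0.0,
    0.07958022, 0.17286907, 0.27408478, 0.3895074,
    0.52507306, 0.6961928, 0.8641379, 1.0,
]


def pack_codes(codes):
    """Pack an even-length list of 4-bit codes into bytes, high nibble first."""
    if len(codes) % 2:
        raise ValueError("expected an even number of codes")
    return bytes((hi << 4) | lo for hi, lo in zip(codes[0::2], codes[1::2]))


def unpack_values(packed, scale):
    """Unpack bytes into 4-bit codes and dequantize with one block scale."""
    values = []
    for byte in packed:
        for code in (byte >> 4, byte & 0x0F):
            values.append(NF4_TABLE[code] * scale)
    return values


if __name__ == "__main__":
    zero_byte = pack_codes([7, 7])
    print(hex(zero_byte[0]), unpack_values(zero_byte, 1.0))
```

A buffer filled with `0x77` therefore decodes to all zeros regardless of the block scale, which is consistent with using it as the initial first-moment state.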
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: adafactor8bit
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.4
|
|
4
4
|
Summary: 8-bit Adafactor Optimizer with Fused CUDA Kernels
|
|
5
5
|
Home-page: https://github.com/yanfeiwong/adafactor-8bit
|
|
6
6
|
Author: WANG YAN
|
|
@@ -53,7 +53,7 @@ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space bl
|
|
|
53
53
|
|
|
54
54
|
- **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
|
|
55
55
|
- **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
|
|
56
|
-
- **Optional
|
|
56
|
+
- **Optional NF4 First Moment**: Stores the optional first moment (`beta1`) using Normal Float 4-bit (NF4) non-uniform quantization, preserving small momentum updates while keeping memory overhead minimal.
|
|
57
57
|
- **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
|
|
58
58
|
- **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
|
|
59
59
|
- **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
|
|
@@ -278,7 +278,7 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
|
|
|
278
278
|
- **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
|
|
279
279
|
- **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
|
|
280
280
|
|
|
281
|
-
##
|
|
281
|
+
## 🧊 CAME Confidence-Guided Updates
|
|
282
282
|
|
|
283
283
|
Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
|
|
284
284
|
|
|
@@ -326,24 +326,25 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
|
|
|
326
326
|
*These are safe starting points. Always validate on your own task and batch size.*
|
|
327
327
|
|
|
328
328
|
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
329
|
## 🎓 Acknowledgements
|
|
333
330
|
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
Thanks to **Tim Dettmers** for the inspiration from the paper [8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
|
|
337
|
-
|
|
338
|
-
Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the approximated gradient scaling method in the paper [APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
|
|
331
|
+
This project builds upon the foundational work of several researchers and open-source communities. Sincere thanks to the following for their invaluable contributions:
|
|
339
332
|
|
|
340
|
-
|
|
333
|
+
### Core Algorithm & Optimizer Design
|
|
334
|
+
- **Noam Shazeer & Mitchell Stern** for proposing the original **Adafactor** algorithm ([Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)).
|
|
335
|
+
- **Tim Dettmers** for the inspiration from **8-bit block-wise quantization** ([8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861)) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
|
|
336
|
+
- **Hanqing Zhu, Zhenyu Zhang, et al.** for the **APOLLO** algorithm ([APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270)).
|
|
337
|
+
- **Xi Chen, Kaituo Feng, et al.** for the **Norm-Growth Limiter** mechanism in **Fira** ([Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)).
|
|
338
|
+
- **Yang Luo, et al.** for the **confidence-guided strategy** in **CAME** ([CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)).
|
|
341
339
|
|
|
342
|
-
|
|
340
|
+
### Quantization & Implementation
|
|
341
|
+
- **The QLoRA Team** for pioneering the **4-bit NormalFloat (NF4)** quantization format ([QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)) that inspired our first moment quantization.
|
|
342
|
+
- **The PyTorch AO Team** for their work on [4-bit optimizer states](https://github.com/pytorch/ao/tree/main/torchao/optim), validating distribution-aware quantization for optimizer moments.
|
|
343
|
+
- **The PyTorch Team** for providing the foundational optimizer implementation and the C++ Extension toolchain.
|
|
343
344
|
|
|
344
|
-
|
|
345
|
+
### Technical Review & Discussion
|
|
346
|
+
- **Qwen, ChatGLM, and DeepSeek** (large language models) for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
|
|
345
347
|
|
|
346
|
-
Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
|
|
347
348
|
|
|
348
349
|
## 🏛️ License
|
|
349
350
|
|
|
@@ -9,7 +9,7 @@ long_description = (this_directory / "README.md").read_text(encoding="utf-8")
|
|
|
9
9
|
|
|
10
10
|
setup(
|
|
11
11
|
name="adafactor8bit",
|
|
12
|
-
version="0.2.
|
|
12
|
+
version="0.2.4",
|
|
13
13
|
description="8-bit Adafactor Optimizer with Fused CUDA Kernels",
|
|
14
14
|
author="WANG YAN",
|
|
15
15
|
author_email="yanfeiwong1997@outlook.com",
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|