adafactor8bit 0.2.2__tar.gz → 0.2.5__tar.gz
- {adafactor8bit-0.2.2/adafactor8bit.egg-info → adafactor8bit-0.2.5}/PKG-INFO +16 -15
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.5}/README.md +15 -14
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.5}/adafactor8bit/kernels.cu +50 -29
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.5}/adafactor8bit/optimizer.py +59 -28
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.5/adafactor8bit.egg-info}/PKG-INFO +16 -15
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.5}/setup.py +1 -1
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.5}/LICENSE +0 -0
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.5}/MANIFEST.in +0 -0
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.5}/adafactor8bit/__init__.py +0 -0
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.5}/adafactor8bit.egg-info/SOURCES.txt +0 -0
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.5}/adafactor8bit.egg-info/dependency_links.txt +0 -0
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.5}/adafactor8bit.egg-info/requires.txt +0 -0
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.5}/adafactor8bit.egg-info/top_level.txt +0 -0
- {adafactor8bit-0.2.2 → adafactor8bit-0.2.5}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adafactor8bit
-Version: 0.2.
+Version: 0.2.5
 Summary: 8-bit Adafactor Optimizer with Fused CUDA Kernels
 Home-page: https://github.com/yanfeiwong/adafactor-8bit
 Author: WANG YAN
@@ -53,7 +53,7 @@ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space bl
 
 - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
 - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
-- **Optional
+- **Optional NF4 First Moment**: Stores the optional first moment (`beta1`) using Normal Float 4-bit (NF4) non-uniform quantization, preserving small momentum updates while keeping memory overhead minimal.
 - **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
 - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
 - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
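The log-space idea in the first bullet can be illustrated with a small round trip. This is a sketch under stated assumptions, not the package's code: it borrows `MIN_LOG`, `MIN_VAL`, the `126.0` clamp on the block maximum and the `255 / (max_log - MIN_LOG)` step from the `kernels.cu` hunks in this diff, treats one Python list as one quantization block, and uses hypothetical helper names. Because each code is a uniform step in log2 space, the worst-case error is a constant relative factor of `2 ** (scale / 510)`, and tiny variances decode to small positive numbers instead of zero.

```python
import math

MIN_LOG = -126.0            # mirrors MIN_LOG in kernels.cu
MIN_VAL = 1.17549435e-38    # mirrors MIN_VAL (smallest normal float32)


def log_quantize_block(values):
    """Encode one block of non-negative values as 8-bit codes in log2 space (hypothetical helper)."""
    logs = [math.log2(max(v, MIN_VAL)) for v in values]
    # Clamp the block maximum like the updated kernel line: [MIN_LOG + 1e-12, 126.0].
    max_log = min(max(max(logs), MIN_LOG + 1e-12), 126.0)
    scale = max_log - MIN_LOG                # per-block scale that gets stored
    inv_scale = 255.0 / scale
    codes = [min(255, max(0, round((lg - MIN_LOG) * inv_scale))) for lg in logs]
    return codes, scale


def log_dequantize_block(codes, scale):
    """Decode with log_v = q / 255 * scale + MIN_LOG, then v = 2 ** log_v."""
    return [2.0 ** (q / 255.0 * scale + MIN_LOG) for q in codes]


vals = [1.0, 1e-3, 1e-8, 1e-30, 0.0]
codes, scale = log_quantize_block(vals)
print(codes, log_dequantize_block(codes, scale))
```

A zero input lands on code 0 and decodes to `2 ** -126`, the floor of the representable range, rather than to an exact zero.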
@@ -278,7 +278,7 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
 - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
 - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
 
-##
+## 🧊 CAME Confidence-Guided Updates
 
 Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
 
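For readers unfamiliar with the row/column factorization that `apollo_factorize` reuses, here is the textbook Adafactor estimate from Shazeer & Stern on a plain matrix. It assumes mean-based row and column statistics and uses a hypothetical helper name; how the package applies the factorization inside the APOLLO subspace is not visible in this diff. The estimate stores `R + C` numbers instead of `R * C`, is exact when the squared gradient is rank-1, and always preserves the row means.

```python
def factored_second_moment(g, eps=1e-30):
    """Adafactor-style rank-1 estimate of g**2 from row and column means (hypothetical helper)."""
    rows, cols = len(g), len(g[0])
    sq = [[x * x for x in row] for row in g]
    row_var = [sum(row) / cols for row in sq]                                   # R values
    col_var = [sum(sq[i][j] for i in range(rows)) / rows for j in range(cols)]  # C values
    mean_row = max(sum(row_var) / rows, eps)                                    # guards an all-zero g
    # V_hat[i][j] = row_var[i] * col_var[j] / mean(row_var)
    return [[row_var[i] * col_var[j] / mean_row for j in range(cols)] for i in range(rows)]


# g = outer([1, 3], [1, 2]), so g**2 is rank-1 and is recovered exactly.
print(factored_second_moment([[1.0, 2.0], [3.0, 6.0]]))  # [[1.0, 4.0], [9.0, 36.0]]
```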
@@ -326,24 +326,25 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
 *These are safe starting points. Always validate on your own task and batch size.*
 
 
-
-
-
 ## 🎓 Acknowledgements
 
-
-
-Thanks to **Tim Dettmers** for the inspiration from the paper [8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
-
-Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the approximated gradient scaling method in the paper [APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
+This project builds upon the foundational work of several researchers and open-source communities. Sincere thanks to the following for their invaluable contributions:
 
-
+### Core Algorithm & Optimizer Design
+- **Noam Shazeer & Mitchell Stern** for proposing the original **Adafactor** algorithm ([Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)).
+- **Tim Dettmers** for the inspiration from **8-bit block-wise quantization** ([8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861)) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
+- **Hanqing Zhu, Zhenyu Zhang, et al.** for the **APOLLO** algorithm ([APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270)).
+- **Xi Chen, Kaituo Feng, et al.** for the **Norm-Growth Limiter** mechanism in **Fira** ([Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)).
+- **Yang Luo, et al.** for the **confidence-guided strategy** in **CAME** ([CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)).
 
-
+### Quantization & Implementation
+- **The QLoRA Team** for pioneering the **4-bit NormalFloat (NF4)** quantization format ([QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)) that inspired our first moment quantization.
+- **The PyTorch AO Team** for their work on [4-bit optimizer states](https://github.com/pytorch/ao/tree/main/torchao/optim), validating distribution-aware quantization for optimizer moments.
+- **The PyTorch Team** for providing the foundational optimizer implementation and the C++ Extension toolchain.
 
-
+### Technical Review & Discussion
+- **Qwen, ChatGLM, and DeepSeek** (large language models) for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
 
-Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
 
 ## 🏛️ License
 
@@ -26,7 +26,7 @@ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space bl
 
 - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
 - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
-- **Optional
+- **Optional NF4 First Moment**: Stores the optional first moment (`beta1`) using Normal Float 4-bit (NF4) non-uniform quantization, preserving small momentum updates while keeping memory overhead minimal.
 - **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
 - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
 - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
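The feature list mentions Warp-Shuffle reductions. The kernels' own reduction code lies outside the hunks shown here, so the following only simulates the usual shuffle-down pattern in Python; it is an assumption about the technique, not a transcription. At each step every lane combines its value with the lane `offset` positions higher, halving `offset` from 16 to 1, and lane 0 ends up holding the result for the whole warp (here a maximum, the kind of value the kernels later read from `s_max[0]`).

```python
WARP_SIZE = 32


def warp_reduce_max(lanes):
    """Simulate a shuffle-down max reduction across one 32-lane warp (illustrative only).

    In CUDA each step would look like
        v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset));
    here the "shuffle" is a list lookup at lane + offset.
    """
    if len(lanes) != WARP_SIZE:
        raise ValueError("expected exactly one value per lane")
    vals = list(lanes)
    offset = WARP_SIZE // 2
    while offset > 0:
        # All lanes read their partner's value from before this step (shuffles are simultaneous).
        # A source lane past the end returns the caller's own value, as __shfl_down_sync does.
        shifted = [vals[i + offset] if i + offset < WARP_SIZE else vals[i] for i in range(WARP_SIZE)]
        vals = [max(a, b) for a, b in zip(vals, shifted)]
        offset //= 2
    return vals[0]  # lane 0 holds the warp-wide maximum after log2(32) = 5 steps


print(warp_reduce_max([float(i % 7) for i in range(WARP_SIZE)]))  # 6.0
```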
@@ -251,7 +251,7 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
 - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
 - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
 
-##
+## 🧊 CAME Confidence-Guided Updates
 
 Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
 
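The `apollo_factorize` note relies on random projections approximately preserving norms. A quick way to see this property, assuming i.i.d. Gaussian entries with variance `1 / rank` (the distribution and the helper names are assumptions; APOLLO's projection generation is not part of this diff): for a fixed vector `x`, `||P x||^2 / ||x||^2` follows a chi-square distribution with `rank` degrees of freedom divided by `rank`, so the norm ratio concentrates around 1 as `rank` grows.

```python
import math
import random


def gaussian_projection(rank, dim, rng):
    """rank x dim matrix with i.i.d. N(0, 1/rank) entries (assumed distribution, hypothetical helper)."""
    std = 1.0 / math.sqrt(rank)
    return [[rng.gauss(0.0, std) for _ in range(dim)] for _ in range(rank)]


def project(P, x):
    """Matrix-vector product P @ x."""
    return [sum(p * xi for p, xi in zip(row, x)) for row in P]


def norm(x):
    return math.sqrt(sum(v * v for v in x))


rng = random.Random(0)
P = gaussian_projection(rank=64, dim=256, rng=rng)
x = [rng.gauss(0.0, 1.0) for _ in range(256)]
print(norm(project(P, x)) / norm(x))  # close to 1; spread shrinks roughly like 1/sqrt(2 * rank)
```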
@@ -299,24 +299,25 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
 *These are safe starting points. Always validate on your own task and batch size.*
 
 
-
-
-
 ## 🎓 Acknowledgements
 
-
-
-Thanks to **Tim Dettmers** for the inspiration from the paper [8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
-
-Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the approximated gradient scaling method in the paper [APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
+This project builds upon the foundational work of several researchers and open-source communities. Sincere thanks to the following for their invaluable contributions:
 
-
+### Core Algorithm & Optimizer Design
+- **Noam Shazeer & Mitchell Stern** for proposing the original **Adafactor** algorithm ([Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)).
+- **Tim Dettmers** for the inspiration from **8-bit block-wise quantization** ([8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861)) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
+- **Hanqing Zhu, Zhenyu Zhang, et al.** for the **APOLLO** algorithm ([APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270)).
+- **Xi Chen, Kaituo Feng, et al.** for the **Norm-Growth Limiter** mechanism in **Fira** ([Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)).
+- **Yang Luo, et al.** for the **confidence-guided strategy** in **CAME** ([CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)).
 
-
+### Quantization & Implementation
+- **The QLoRA Team** for pioneering the **4-bit NormalFloat (NF4)** quantization format ([QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)) that inspired our first moment quantization.
+- **The PyTorch AO Team** for their work on [4-bit optimizer states](https://github.com/pytorch/ao/tree/main/torchao/optim), validating distribution-aware quantization for optimizer moments.
+- **The PyTorch Team** for providing the foundational optimizer implementation and the C++ Extension toolchain.
 
-
+### Technical Review & Discussion
+- **Qwen, ChatGLM, and DeepSeek** (large language models) for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
 
-Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
 
 ## 🏛️ License
 
@@ -7,7 +7,26 @@
 __device__ constexpr float INV_255 = 1.0f / 255.0f;
 __device__ constexpr float MIN_LOG = -126.0f;
 __device__ constexpr float MIN_VAL = 1.17549435e-38f;
-
+
+__constant__ float NF4_QMAP[16] = {
+-1.0f, -0.6961928f, -0.52507306f, -0.3895074f,
+-0.27408478f, -0.17286907f, -0.07958022f, 0.0f,
+0.07958022f, 0.17286907f, 0.27408478f, 0.3895074f,
+0.52507306f, 0.6961928f, 0.8641379f, 1.0f
+};
+
+__device__ __forceinline__ int find_nearest_nf4(float x) {
+float x_abs = fabsf(x);
+if (x_abs < 0.0397901f) return 7;
+if (x_abs < 0.1262246f) return (x >= 0.0f) ? 8 : 6;
+if (x_abs < 0.2234769f) return (x >= 0.0f) ? 9 : 5;
+if (x_abs < 0.3317961f) return (x >= 0.0f) ? 10 : 4;
+if (x_abs < 0.4572902f) return (x >= 0.0f) ? 11 : 3;
+if (x_abs < 0.6106329f) return (x >= 0.0f) ? 12 : 2;
+if (x_abs < 0.7801653f) return (x >= 0.0f) ? 13 : 1;
+if (x_abs < 0.9320689f) return (x >= 0.0f) ? 14 : 0;
+return (x >= 0.0f) ? 15 : 0;
+}
 
 // ==========================================
 // 1. Fused Log-Quantize Lerp (EMA Update for V_t)
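The new `find_nearest_nf4` replaces a search with a ladder of comparisons. The sketch below transcribes it to Python and checks it against the brute-force nearest level that the PyTorch fallback (`_quantize_4bit_pytorch`) computes with `argmin`; the helper names are hypothetical. Under this transcription every bound is the midpoint between two adjacent positive levels, and the two encoders agree on `[-1, 1]` except for negative inputs with magnitude between about 0.780 and 0.848: `NF4_QMAP` has no negative counterpart to `0.8641379`, so the mirrored bound `0.7801653` sends those inputs to `-1.0` even though `-0.6961928` is closer. If that reading is right, the CUDA and PyTorch paths can encode the same momentum value differently, which may be worth checking.

```python
NF4_QMAP = [
    -1.0, -0.6961928, -0.52507306, -0.3895074,
    -0.27408478, -0.17286907, -0.07958022, 0.0,
    0.07958022, 0.17286907, 0.27408478, 0.3895074,
    0.52507306, 0.6961928, 0.8641379, 1.0,
]

# (upper bound on |x|, code if x >= 0, code if x < 0), copied from find_nearest_nf4.
BRANCHES = [
    (0.0397901, 7, 7),
    (0.1262246, 8, 6),
    (0.2234769, 9, 5),
    (0.3317961, 10, 4),
    (0.4572902, 11, 3),
    (0.6106329, 12, 2),
    (0.7801653, 13, 1),
    (0.9320689, 14, 0),
]


def nf4_by_thresholds(x):
    """Python transcription of the CUDA branch ladder."""
    x_abs = abs(x)
    for bound, pos_code, neg_code in BRANCHES:
        if x_abs < bound:
            return pos_code if x >= 0.0 else neg_code
    return 15 if x >= 0.0 else 0


def nf4_by_argmin(x):
    """Brute-force nearest level, like diff.argmin(dim=-1) in the PyTorch fallback."""
    return min(range(len(NF4_QMAP)), key=lambda i: abs(x - NF4_QMAP[i]))


# Midpoints between adjacent positive levels, to compare with the hard-coded bounds.
midpoints = [(NF4_QMAP[i] + NF4_QMAP[i + 1]) / 2 for i in range(7, 15)]

# Inputs on a 0.001 grid over [-1, 1] where the two encoders pick different codes.
mismatches = [i / 1000 for i in range(-1000, 1001)
              if nf4_by_thresholds(i / 1000) != nf4_by_argmin(i / 1000)]
print(min(mismatches), max(mismatches))  # -0.848 -0.781
```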
@@ -120,7 +139,7 @@ __global__ void fused_log_quantize_lerp_kernel(
 }
 __syncthreads();
 
-float max_log = fminf(fmaxf(s_max[0], MIN_LOG + 1e-12f),
+float max_log = fminf(fmaxf(s_max[0], MIN_LOG + 1e-12f), 126.0f);
 float new_scale = max_log - MIN_LOG;
 float inv_scale = 255.0f / (max_log - MIN_LOG);
 
@@ -226,11 +245,10 @@ __global__ void fused_4bit_quantize_lerp_kernel(
 
 uchar2 old_q = q_vec[idx];
 
-
-float
-float
-float
-float m_old3 = (float)((old_q.y & 0x0F) - 8) * old_scale;
+float m_old0 = NF4_QMAP[(old_q.x >> 4)] * old_scale;
+float m_old1 = NF4_QMAP[(old_q.x & 0x0F)] * old_scale;
+float m_old2 = NF4_QMAP[(old_q.y >> 4)] * old_scale;
+float m_old3 = NF4_QMAP[(old_q.y & 0x0F)] * old_scale;
 
 float m_new0 = beta * m_old0 + one_minus_b * val_x;
 float m_new1 = beta * m_old1 + one_minus_b * val_y;
@@ -267,7 +285,7 @@ __global__ void fused_4bit_quantize_lerp_kernel(
 __syncthreads();
 
 float abs_max = fmaxf(s_max[0], 1e-12f);
-float new_scale = abs_max
+float new_scale = abs_max;
 float inv_scale = 1.0f / new_scale;
 
 for (int i = 0; i < vec_iters; i++) {
@@ -279,12 +297,10 @@ __global__ void fused_4bit_quantize_lerp_kernel(
 float m2 = local_m[idx * 4 + 2];
 float m3 = local_m[idx * 4 + 3];
 
-
-
-int
-int
-int q2 = max(-8, min(7, __float2int_rn(m2 * inv_scale))) + 8;
-int q3 = max(-8, min(7, __float2int_rn(m3 * inv_scale))) + 8;
+int q0 = find_nearest_nf4(m0 * inv_scale);
+int q1 = find_nearest_nf4(m1 * inv_scale);
+int q2 = find_nearest_nf4(m2 * inv_scale);
+int q3 = find_nearest_nf4(m3 * inv_scale);
 
 uchar2 out_q;
 out_q.x = (unsigned char)((q0 << 4) | q1);
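Together with the `NF4_QMAP` lookups earlier in `fused_4bit_quantize_lerp_kernel`, these hunks turn each momentum update into four steps: dequantize, lerp, rescale, re-encode. A minimal per-block sketch, assuming one Python list is one quantization block and using a brute-force nearest-level search in place of `find_nearest_nf4` (helper names are hypothetical). Because `new_scale` is now the block's absolute maximum, normalized values lie in `[-1, 1]`, and the reconstruction error per element is at most half of the widest gap in the table, about `0.152 * new_scale`.

```python
NF4_QMAP = [
    -1.0, -0.6961928, -0.52507306, -0.3895074,
    -0.27408478, -0.17286907, -0.07958022, 0.0,
    0.07958022, 0.17286907, 0.27408478, 0.3895074,
    0.52507306, 0.6961928, 0.8641379, 1.0,
]


def nearest_code(x):
    """Index of the closest NF4 level (brute force)."""
    return min(range(16), key=lambda i: abs(x - NF4_QMAP[i]))


def nf4_ema_step(codes, old_scale, grads, beta):
    """One block of the momentum EMA carried out in the NF4 domain (hypothetical helper)."""
    m_old = [NF4_QMAP[q] * old_scale for q in codes]                  # dequantize with the old scale
    m_new = [beta * m + (1.0 - beta) * g for m, g in zip(m_old, grads)]  # lerp toward the gradient
    new_scale = max(max(abs(v) for v in m_new), 1e-12)                # new_scale = abs_max
    inv_scale = 1.0 / new_scale
    new_codes = [nearest_code(v * inv_scale) for v in m_new]          # re-encode in [-1, 1]
    return new_codes, new_scale, m_new


# Start from an all-zero block (code 7 everywhere) and apply one gradient.
print(nf4_ema_step([7, 7, 7, 7], 1.0, [0.5, -0.25, 0.0, 1.0], beta=0.9)[0])  # [12, 4, 7, 15]
```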
@@ -349,7 +365,8 @@ __global__ void compute_update_norm_2d_kernel(
 max_log = fmaxf(max_log, -53.0f);
 float inv_std = exp2f(-0.5f * max_log);
 
-float
+float g_val = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
+float u_ij = g_val * inv_std;
 
 sq += u_ij * u_ij;
 }
@@ -418,7 +435,8 @@ __global__ void apply_update_2d_kernel(
 max_log = fmaxf(max_log, -53.0f);
 float inv_std = exp2f(-0.5f * max_log);
 
-float
+float g_val = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
+float u_ij = g_val * inv_std;
 
 float p_val = static_cast<float>(param[idx]);
 p_val -= step_size * u_ij;
@@ -464,7 +482,8 @@ __global__ void compute_update_norm_1d_kernel(
 max_log = fmaxf(max_log, -53.0f);
 float inv_std = exp2f(-0.5f * max_log);
 
-float
+float g_val = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
+float u_val = g_val * inv_std;
 
 sq += u_val * u_val;
 }
@@ -516,7 +535,8 @@ __global__ void apply_update_1d_kernel(
 max_log = fmaxf(max_log, -53.0f);
 float inv_std = exp2f(-0.5f * max_log);
 
-float
+float g_val = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
+float u_val = g_val * inv_std;
 
 float p_val = static_cast<float>(param[idx]);
 p_val -= step_size * u_val;
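The `compute_update_norm_*` and `apply_update_*` kernels now replace non-finite gradient entries with `0.0` before scaling (`isnan(x) || isinf(x)` is the same test as `!isfinite(x)`). A small Python model of the norm accumulation, with hypothetical helper names, shows the effect: one bad entry no longer turns the whole sum into NaN, but it is also dropped silently rather than surfacing as an error.

```python
import math


def sanitize(x):
    """Same rule as the new kernel lines: non-finite gradient entries become 0.0."""
    return x if math.isfinite(x) else 0.0


def update_sq_norm(grads, inv_std):
    """Sum of squared scaled updates u = g * inv_std, after sanitizing g (hypothetical helper)."""
    total = 0.0
    for g in grads:
        u = sanitize(g) * inv_std
        total += u * u
    return total


grads = [1.0, float("nan"), float("inf"), -float("inf"), 2.0]
print(update_sq_norm(grads, 0.5))  # 1.25: only 1.0 and 2.0 contribute
```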
@@ -561,7 +581,7 @@ __global__ void compute_update_norm_m_2d_kernel(
 // Unpack 4-bit m_t
 unsigned char packed = m_q[idx / 2];
 int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
-float m_val =
+float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
 
 float log_r = (float)row_var_q[b * R + r] * INV_255 * row_var_scale[(b * R + r) / v_block_size] + MIN_LOG;
 float log_c = (float)col_var_q[b * C + c] * INV_255 * col_var_scale[(b * C + c) / v_block_size] + MIN_LOG;
@@ -636,7 +656,7 @@ __global__ void apply_update_m_2d_kernel(
 // Unpack 4-bit m_t
 unsigned char packed = m_q[idx / 2];
 int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
-float m_val =
+float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
 
 float log_r = (float)row_var_q[b * R + r] * INV_255 * row_var_scale[(b * R + r) / v_block_size] + MIN_LOG;
 float log_c = (float)col_var_q[b * C + c] * INV_255 * col_var_scale[(b * C + c) / v_block_size] + MIN_LOG;
@@ -692,7 +712,7 @@ __global__ void compute_update_norm_m_1d_kernel(
 // Unpack 4-bit m_t
 unsigned char packed = m_q[idx / 2];
 int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
-float m_val =
+float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
 
 float log_v = (float)variance_q[idx] * INV_255 * variance_scale[idx / v_block_size] + MIN_LOG;
 
@@ -751,7 +771,7 @@ __global__ void apply_update_m_1d_kernel(
 // Unpack 4-bit m_t
 unsigned char packed = m_q[idx / 2];
 int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
-float m_val =
+float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
 
 float log_v = (float)variance_q[idx] * INV_255 * variance_scale[idx / v_block_size] + MIN_LOG;
 
@@ -808,7 +828,7 @@ __global__ void compute_apollo_norms_kernel(
 // Unpack 4-bit m
 unsigned char m_byte = m_q[global_idx / 2];
 int m_int = (global_idx & 1) ? (m_byte & 0x0F) : (m_byte >> 4);
-float m_val =
+float m_val = NF4_QMAP[m_int] * m_scale[global_idx / m_block_size];
 // Unpack 8-bit log v
 unsigned char v_byte = v_q[global_idx];
 float log_v = (float)v_byte * INV_255 * v_scale[global_idx / v_block_size] + MIN_LOG;
@@ -892,7 +912,7 @@ __global__ void dequantize_4bit_kernel(
 if (idx >= numel) return;
 unsigned char packed = q[idx / 2];
 int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
-output[idx] =
+output[idx] = NF4_QMAP[q_int] * scale[idx / block_size];
 }
 
 void dequantize_4bit_cuda(
@@ -922,7 +942,7 @@ __global__ void compute_update_norm_1d_full_kernel(
 float one_minus_b = 1.0f - beta;
 
 for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < numel; idx += stride) {
-float g = grad[idx];
+float g = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
 float g2 = g * g;
 float v = one_minus_b * variance[idx] + beta * g2;
 variance[idx] = v;
@@ -979,7 +999,7 @@ __global__ void apply_update_1d_full_kernel(
 
 int stride = gridDim.x * blockDim.x;
 for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < numel; idx += stride) {
-float g = grad[idx];
+float g = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
 float v = variance[idx];
 float inv_std = rsqrtf(fmaxf(v, eps_sq));
 float u = g * inv_std;
@@ -1023,7 +1043,7 @@ __global__ void compute_update_norm_1d_full_m_kernel(
 float one_minus_bv = 1.0f - beta_val;
 
 for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < numel; idx += stride) {
-float g = grad[idx];
+float g = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
 float g2 = g * g;
 
 float v = one_minus_bv * variance[idx] + beta_val * g2;
@@ -1132,7 +1152,7 @@ __global__ void came_compute_residual_2d_kernel(
 
 unsigned char packed = m_q[idx / 2];
 int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
-float m_val =
+float m_val = NF4_QMAP[q_int] * m_scale[idx / m_block_size];
 
 float log_r = (float)row_var_q[b * R + r] * INV_255 * row_var_scale[(b * R + r) / v_block_size] + MIN_LOG;
 float log_c = (float)col_var_q[b * C + c] * INV_255 * col_var_scale[(b * C + c) / v_block_size] + MIN_LOG;
@@ -1143,7 +1163,8 @@ __global__ void came_compute_residual_2d_kernel(
 max_log = fmaxf(max_log, -53.0f);
 float inv_std = exp2f(-0.5f * max_log);
 
-float
+float g_val = (isnan(grad[idx]) || isinf(grad[idx])) ? 0.0f : grad[idx];
+float diff = (g_val - m_val) * inv_std;
 float res = diff * diff;
 
 atomicAdd(&res_col_sum[b * C + c], res);
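`came_compute_residual_2d_kernel` squares `(g_val - m_val) * inv_std` per element and accumulates it into per-column sums. The steps that follow (an EMA over `beta3` and the factored confidence used to rescale the momentum) are not in the shown hunks, so this sketch fills them in with the CAME paper's formulation as an assumption: a row-sum buffer alongside the column sums, `S[i][j] = row[i] * col[j] / sum(row)`, and an update of `m / sqrt(S)`. All helper names are hypothetical.

```python
import math


def came_residual_sums(g, m, inv_std):
    """Row and column sums of ((g - m) * inv_std) ** 2 for one R x C matrix.

    The column sums mirror res_col_sum in the kernel; the row sums are an assumption.
    """
    R, C = len(g), len(g[0])
    row_sum, col_sum = [0.0] * R, [0.0] * C
    for r in range(R):
        for c in range(C):
            diff = (g[r][c] - m[r][c]) * inv_std[r][c]
            res = diff * diff
            row_sum[r] += res
            col_sum[c] += res
    return row_sum, col_sum


def ema(prev, new, beta3):
    """Exponential moving average of a statistic vector."""
    return [beta3 * p + (1.0 - beta3) * n for p, n in zip(prev, new)]


def factored_confidence(row, col, eps=1e-30):
    """Rank-1 reconstruction S[i][j] = row[i] * col[j] / sum(row)."""
    total = max(sum(row), eps)
    return [[ri * cj / total for cj in col] for ri in row]


def confidence_scaled(m, S, eps=1e-30):
    """Divide the momentum by sqrt(S): directions with large residuals get smaller steps."""
    return [[mij / math.sqrt(sij + eps) for mij, sij in zip(mr, sr)] for mr, sr in zip(m, S)]


rows, cols = came_residual_sums([[1.0, 2.0], [3.0, 6.0]], [[0.0, 0.0], [0.0, 0.0]],
                                [[1.0, 1.0], [1.0, 1.0]])
print(factored_confidence(rows, cols))  # [[1.0, 4.0], [9.0, 36.0]]
```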
@@ -21,6 +21,14 @@ _FP32_MIN_LOG = -126.0
 _INV_255 = 1.0 / 255.0
 _INV_7 = 1.0 / 7.0
 
+_NF4_TABLE = torch.tensor([
+-1.0, -0.6961928, -0.52507306, -0.3895074,
+-0.27408478, -0.17286907, -0.07958022, 0.0,
+0.07958022, 0.17286907, 0.27408478, 0.3895074,
+0.52507306, 0.6961928, 0.8641379, 1.0
+], dtype=torch.float32)
+
+
 # ==========================================
 # 1. CUDA Kernel JIT Loading
 # ==========================================
@@ -127,7 +135,7 @@ def _log_dequantize_nonneg(q: Tensor, scale: Tensor, shape: torch.Size, pad: int
 return flat.view(shape)
 
 def _quantize_4bit_pytorch(m: Tensor, block_size: int) -> Tuple[Tensor, Tensor]:
-"""4-bit
+"""4-bit NF4 non-uniform quantization with physical packing into uint8."""
 flat = m.flatten()
 pad = (block_size - flat.numel() % block_size) % block_size
 if pad:
@@ -135,12 +143,15 @@ def _quantize_4bit_pytorch(m: Tensor, block_size: int) -> Tuple[Tensor, Tensor]:
 
 blocks = flat.view(-1, block_size)
 abs_max = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
-scale = abs_max
+scale = abs_max
 
-
+normalized = blocks / scale
+table = _NF4_TABLE.to(normalized.device)
+diff = (normalized.unsqueeze(-1) - table).abs()
+codes = diff.argmin(dim=-1).to(torch.uint8)
 
-q_even =
-q_odd =
+q_even = codes[:, 0::2]
+q_odd = codes[:, 1::2]
 packed = (q_even << 4) | q_odd
 
 return packed.view(-1), scale.squeeze(-1)
@@ -152,11 +163,12 @@ def _dequantize_4bit(m_q: Tensor, m_scale: Tensor, numel: int, shape: torch.Size
 _CUDA_MODULE.dequantize_4bit(output, m_q, m_scale, numel, block_size)
 return output.view(shape)
 else:
-high = (
-low = (m_q & 0x0F)
-
-m_blocks =
-
+high = (m_q >> 4)
+low = (m_q & 0x0F)
+codes = torch.stack((high, low), dim=-1).view(-1)
+m_blocks = codes.view(-1, block_size)
+table = _NF4_TABLE.to(m_q.device)
+result = (table[m_blocks.long()] * m_scale.unsqueeze(-1)).view(-1)[:numel]
 return result.view(shape)
 
 # ==========================================
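The rewritten fallback in `_dequantize_4bit` interleaves `high` and `low` with `torch.stack((high, low), dim=-1)`, which matches the kernels' `(idx & 1) ? (packed & 0x0F) : (packed >> 4)` indexing: even positions live in the high nibble and odd positions in the low nibble. A stdlib-only sketch of that layout, with hypothetical helper names:

```python
def pack_nibbles(codes):
    """Pack 4-bit codes two per byte: even index -> high nibble, odd index -> low nibble."""
    if len(codes) % 2:
        raise ValueError("expected an even number of codes (blocks are padded)")
    return bytes((codes[i] << 4) | codes[i + 1] for i in range(0, len(codes), 2))


def unpack_nibbles(packed):
    """Inverse of pack_nibbles; same order as stacking (high, low) and flattening."""
    out = []
    for byte in packed:
        out.append(byte >> 4)     # high nibble first
        out.append(byte & 0x0F)   # then low nibble
    return out


def code_at(packed, idx):
    """Random access as in the kernels: (idx & 1) ? low nibble : high nibble."""
    byte = packed[idx // 2]
    return (byte & 0x0F) if (idx & 1) else (byte >> 4)


packed = pack_nibbles([12, 4, 7, 15, 0, 9])
print(packed.hex())  # c47f09
```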
@@ -461,20 +473,21 @@ class Adafactor8Bit(Optimizer):
 
 if beta1 is not None:
 m_padded_numel = ((p.numel() + m_block_size - 1) // m_block_size) * m_block_size
-state["m_q"] = torch.full((m_padded_numel // 2,),
+state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
 state["m_scale"] = torch.ones(m_padded_numel // m_block_size, dtype=torch.float32, device=p.device)
 state["m_block_size"] = m_block_size
 
 if beta3 is not None:
-state["conf_row_q"] = torch.
-state["conf_row_scale"] = torch.
+state["conf_row_q"] = torch.full_like(state["row_var_q"], 255)
+state["conf_row_scale"] = torch.full_like(state["row_var_scale"], 126.0)
 state["conf_row_shape"] = state["row_var_shape"]
 state["conf_row_pad"] = state["row_var_pad"]
 
-state["conf_col_q"] = torch.
-state["conf_col_scale"] = torch.
+state["conf_col_q"] = torch.full_like(state["col_var_q"], 255)
+state["conf_col_scale"] = torch.full_like(state["col_var_scale"], 126.0)
 state["conf_col_shape"] = state["col_var_shape"]
 state["conf_col_pad"] = state["col_var_pad"]
+
 else:
 state["row_var"] = torch.zeros(r_shape, dtype=torch.float32, device=p.device)
 state["col_var"] = torch.zeros(c_shape, dtype=torch.float32, device=p.device)
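The new initial values decode to sensible defaults. Because `NF4_QMAP[7] == 0.0`, the byte `0x77` holds two zero-valued momentum entries, and a confidence code of `255` with scale `126.0` decodes through the log-space formula `q / 255 * scale + MIN_LOG` to `2 ** 0 = 1.0`, presumably a neutral starting confidence. A short check, with hypothetical helper names:

```python
NF4_QMAP = [
    -1.0, -0.6961928, -0.52507306, -0.3895074,
    -0.27408478, -0.17286907, -0.07958022, 0.0,
    0.07958022, 0.17286907, 0.27408478, 0.3895074,
    0.52507306, 0.6961928, 0.8641379, 1.0,
]
MIN_LOG = -126.0


def decode_m_byte(byte, scale):
    """Two momentum values from one packed byte: high nibble first, then low nibble."""
    return NF4_QMAP[byte >> 4] * scale, NF4_QMAP[byte & 0x0F] * scale


def decode_log_byte(q, scale):
    """8-bit log-space decode, as in the kernels: 2 ** (q / 255 * scale + MIN_LOG)."""
    return 2.0 ** (q / 255.0 * scale + MIN_LOG)


print(decode_m_byte(0x77, 1.0))     # (0.0, 0.0): fresh m_q with m_scale = 1.0
print(decode_log_byte(255, 126.0))  # 1.0: fresh conf_row_q / conf_col_q with scale 126.0
```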
@@ -494,7 +507,7 @@ class Adafactor8Bit(Optimizer):
 
 if beta1 is not None:
 m_padded_numel = ((p.numel() + m_block_size - 1) // m_block_size) * m_block_size
-state["m_q"] = torch.full((m_padded_numel // 2,),
+state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
 state["m_scale"] = torch.ones(m_padded_numel // m_block_size, dtype=torch.float32, device=p.device)
 state["m_block_size"] = m_block_size
 else:
@@ -566,7 +579,7 @@ class Adafactor8Bit(Optimizer):
 state["m_q"], state["m_scale"] = _quantize_4bit_pytorch(state["m"], m_curr_block_size)
 state.pop("m")
 else:
-state["m_q"] = torch.full((m_padded_numel // 2,),
+state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
 state["m_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=p.device)
 state["m_block_size"] = m_curr_block_size
 
@@ -620,7 +633,7 @@ class Adafactor8Bit(Optimizer):
 state["m_q"], state["m_scale"] = _quantize_4bit_pytorch(state["m"], m_curr_block_size)
 state.pop("m")
 else:
-state["m_q"] = torch.full((m_padded_numel // 2,),
+state["m_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=p.device)
 state["m_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=p.device)
 state["m_block_size"] = m_curr_block_size
 
@@ -697,6 +710,11 @@ class Adafactor8Bit(Optimizer):
 # ==========================================
 def _apply_fira_cuda(state: Dict[str, Any], total_sum_sq: Tensor, alpha: Tensor, fira_margin: float) -> Tuple[Tensor, Tensor]:
 current_norm = total_sum_sq.sqrt().view([])
+
+is_finite = torch.isfinite(current_norm)
+current_norm = torch.where(is_finite, current_norm, torch.zeros_like(current_norm))
+total_sum_sq = torch.where(is_finite, total_sum_sq, torch.zeros_like(total_sum_sq))
+
 fira_threshold = 1.0 + fira_margin
 
 prev_norm = state.get("fira_prev_norm", None)
@@ -704,13 +722,14 @@ def _apply_fira_cuda(state: Dict[str, Any], total_sum_sq: Tensor, alpha: Tensor,
 if not isinstance(prev_norm, Tensor):
 prev_norm = torch.tensor(prev_norm, device=total_sum_sq.device, dtype=torch.float32)
 
+is_reset = prev_norm < 1e-6
 ratio = current_norm / (prev_norm + 1e-8)
 limiter = torch.clamp_min(ratio, fira_threshold) / fira_threshold
-final_scale = 1.0 / limiter
+final_scale = torch.where(is_reset, torch.ones_like(current_norm), 1.0 / limiter)
+state["fira_prev_norm"] = torch.where(is_reset, current_norm, current_norm * final_scale)
 else:
 final_scale = torch.tensor(1.0, device=total_sum_sq.device, dtype=torch.float32)
-
-state["fira_prev_norm"] = current_norm * final_scale
+state["fira_prev_norm"] = current_norm
 
 alpha_scaled = alpha * final_scale
 total_sum_sq.mul_(final_scale.square())
@@ -719,19 +738,26 @@ def _apply_fira_cuda(state: Dict[str, Any], total_sum_sq: Tensor, alpha: Tensor,
 
 def _apply_fira_pytorch(state: Dict[str, Any], update: Tensor, fira_margin: float, numel: int, d: float) -> Tuple[Tensor, Tensor]:
 current_norm = torch.linalg.vector_norm(update)
+
+is_finite = torch.isfinite(current_norm)
+current_norm = torch.where(is_finite, current_norm, torch.zeros_like(current_norm))
+update = torch.where(is_finite, update, torch.zeros_like(update))
+
 fira_threshold = 1.0 + fira_margin
 
 prev_norm = state.get("fira_prev_norm", None)
 if prev_norm is not None:
 if not isinstance(prev_norm, Tensor):
 prev_norm = torch.tensor(prev_norm, device=update.device, dtype=torch.float32)
+
+is_reset = prev_norm < 1e-6
 ratio = current_norm / (prev_norm + 1e-8)
 limiter = torch.clamp_min(ratio, fira_threshold) / fira_threshold
-final_scale = 1.0 / limiter
+final_scale = torch.where(is_reset, torch.ones_like(current_norm), 1.0 / limiter)
+state["fira_prev_norm"] = torch.where(is_reset, current_norm, current_norm * final_scale)
 else:
 final_scale = torch.tensor(1.0, device=update.device, dtype=torch.float32)
-
-state["fira_prev_norm"] = current_norm * final_scale
+state["fira_prev_norm"] = current_norm
 
 update_scaled = update * final_scale
 norm_final = current_norm * final_scale
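The `is_reset` branch added to the Fira limiter is easiest to follow on scalars. The plain-Python sketch below mirrors the removed and added lines of `_apply_fira_pytorch`, with floats in place of tensors. The names `fira_scale_old` and `fira_scale` are hypothetical, not functions in the package.

Under the removed logic, a stored norm near zero makes `ratio` roughly `1e8`. The update is then multiplied by about `1e-8`, and the stored norm stays tiny for many later steps. With the reset, the current norm simply becomes the new baseline. Because the non-finite guard stores a norm of 0, a NaN/Inf step also triggers the reset on the following step. The same pattern appears in `_apply_fira_cuda` and in the APOLLO branch of `_update_param_apollo`.

```python
import math
from typing import Optional, Tuple


def fira_scale_old(current_norm: float, prev_norm: Optional[float],
                   fira_margin: float = 0.01) -> Tuple[float, float]:
    """Scalar mirror of the removed lines: returns (final_scale, stored_norm)."""
    threshold = 1.0 + fira_margin
    if prev_norm is None:
        return 1.0, current_norm
    ratio = current_norm / (prev_norm + 1e-8)
    limiter = max(ratio, threshold) / threshold
    final_scale = 1.0 / limiter
    return final_scale, current_norm * final_scale


def fira_scale(current_norm: float, prev_norm: Optional[float],
               fira_margin: float = 0.01) -> Tuple[float, float]:
    """Scalar mirror of the added lines: returns (final_scale, stored_norm)."""
    # Non-finite norm is treated as 0 (the is_finite guard); the real code also
    # zeroes `update` (PyTorch path) or `total_sum_sq` (CUDA path) in this case.
    if not math.isfinite(current_norm):
        current_norm = 0.0
    threshold = 1.0 + fira_margin
    if prev_norm is None:
        # First step: nothing to compare against, store the raw norm.
        return 1.0, current_norm
    if prev_norm < 1e-6:
        # Reset: the stored norm is (near) zero, so take this norm as a new baseline.
        return 1.0, current_norm
    ratio = current_norm / (prev_norm + 1e-8)
    # Growth beyond `threshold` is scaled back so the stored norm ends up
    # at roughly threshold * prev_norm.
    limiter = max(ratio, threshold) / threshold
    final_scale = 1.0 / limiter
    return final_scale, current_norm * final_scale


if __name__ == "__main__":
    # A zero stored norm, e.g. after an all-zero or non-finite step.
    print(fira_scale_old(1.0, 0.0))  # scale ~1e-8: update almost erased
    print(fira_scale(1.0, 0.0))      # (1.0, 1.0): limiter resets instead
```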
@@ -1295,6 +1321,7 @@ def _update_param_apollo(
 fira_margin: float = 0.01,
 ):
 grad_work = grad.neg().float() if maximize else grad.float()
+grad_work = torch.where(torch.isfinite(grad_work), grad_work, torch.zeros_like(grad_work))
 update_low = None
 
 if apollo_factorize:
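The non-finite guards in this release work at two granularities:

- In `_update_param_apollo`, the added `torch.where(torch.isfinite(grad_work), ...)` line zeroes individual NaN/Inf entries and keeps the rest of the gradient.
- In `_apply_fira_pytorch`, a non-finite norm zeroes the whole update.

The list-based sketch below contrasts the two. `l2_norm`, `sanitize_elementwise` and `sanitize_by_norm` are hypothetical helpers written for illustration; they are not package functions.

```python
import math
from typing import List, Tuple


def l2_norm(xs: List[float]) -> float:
    """Euclidean norm; any NaN or Inf entry makes the result non-finite."""
    return math.sqrt(sum(x * x for x in xs))


def sanitize_elementwise(grad: List[float]) -> List[float]:
    """APOLLO-path style: replace only the non-finite entries with 0."""
    return [g if math.isfinite(g) else 0.0 for g in grad]


def sanitize_by_norm(update: List[float]) -> Tuple[List[float], float]:
    """Fira PyTorch-path style: if the norm is non-finite, drop the whole update."""
    norm = l2_norm(update)
    if not math.isfinite(norm):
        return [0.0] * len(update), 0.0
    return list(update), norm


if __name__ == "__main__":
    grad = [3.0, float("inf"), 4.0, float("nan")]
    print(sanitize_elementwise(grad), l2_norm(sanitize_elementwise(grad)))
    print(sanitize_by_norm(grad))
```

Elementwise sanitization keeps the usable signal from a mostly finite gradient. The norm-level guard is the coarser fallback: it skips the step entirely rather than applying a partially corrupted update.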
@@ -1401,7 +1428,7 @@ def _update_param_apollo(
 
 if state.get("m_low_q") is None:
 m_padded_numel = ((grad_low_numel + m_curr_block_size - 1) // m_curr_block_size) * m_curr_block_size
-state["m_low_q"] = torch.full((m_padded_numel // 2,),
+state["m_low_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=grad_low.device)
 state["m_low_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=grad_low.device)
 state["m_block_size"] = m_curr_block_size
 
@@ -1481,7 +1508,7 @@ def _update_param_apollo(
 
 if state.get("m_low_q") is None:
 m_padded_numel = ((grad_low_numel + m_curr_block_size - 1) // m_curr_block_size) * m_curr_block_size
-state["m_low_q"] = torch.full((m_padded_numel // 2,),
+state["m_low_q"] = torch.full((m_padded_numel // 2,), 0x77, dtype=torch.uint8, device=grad_low.device)
 state["m_low_scale"] = torch.ones(m_padded_numel // m_curr_block_size, dtype=torch.float32, device=grad_low.device)
 state["m_block_size"] = m_curr_block_size
 
@@ -1575,16 +1602,20 @@ def _update_param_apollo(
 else:
 current_norm_t = torch.linalg.vector_norm(grad_work, ord=2, dtype=torch.float32) * scaling_factor
 
+is_finite = torch.isfinite(current_norm_t)
+current_norm_t = torch.where(is_finite, current_norm_t, torch.zeros_like(current_norm_t))
+
 fira_threshold = 1.0 + fira_margin
 if "scaled_grad_norm_prev" in state:
 prev_norm_t = state["scaled_grad_norm_prev"]
 if not isinstance(prev_norm_t, Tensor):
 prev_norm_t = torch.tensor(prev_norm_t, device=param_work.device, dtype=torch.float32)
 
+is_reset = prev_norm_t < 1e-6
 ratio = current_norm_t / (prev_norm_t + 1e-8)
 limiter = torch.clamp_min(ratio, fira_threshold) / fira_threshold
-final_scale = scaling_factor / limiter
-state["scaled_grad_norm_prev"] = current_norm_t / limiter
+final_scale = torch.where(is_reset, scaling_factor, scaling_factor / limiter)
+state["scaled_grad_norm_prev"] = torch.where(is_reset, current_norm_t, current_norm_t / limiter)
 else:
 final_scale = scaling_factor
 state["scaled_grad_norm_prev"] = current_norm_t
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adafactor8bit
-Version: 0.2.
+Version: 0.2.5
 Summary: 8-bit Adafactor Optimizer with Fused CUDA Kernels
 Home-page: https://github.com/yanfeiwong/adafactor-8bit
 Author: WANG YAN
@@ -53,7 +53,7 @@ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space bl
 
 - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
 - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
-- **Optional
+- **Optional NF4 First Moment**: Stores the optional first moment (`beta1`) using Normal Float 4-bit (NF4) non-uniform quantization, preserving small momentum updates while keeping memory overhead minimal.
 - **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
 - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
 - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
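The **Log-Space Quantization** bullet can be illustrated with a toy comparison. The sketch below is one plausible scheme, not the package's CUDA implementation. Each block stores the minimum and the step of its log2 values, and codes use 255 uniform levels. The helper names `quantize_log2`, `dequantize_log2` and `quantize_linear_absmax` are hypothetical.

With linear absmax quantization, second-moment values several orders of magnitude below the block maximum round to code 0. In log2 space, the same values keep a small relative error and never dequantize to zero.

```python
import math
from typing import List, Tuple


def quantize_log2(values: List[float], bits: int = 8) -> Tuple[List[int], float, float]:
    """Quantize positive values uniformly in log2 space; returns (codes, log_min, step)."""
    logs = [math.log2(v) for v in values]  # assumes every v > 0
    lo, hi = min(logs), max(logs)
    levels = (1 << bits) - 1               # 255 for 8 bits
    step = (hi - lo) / levels if hi > lo else 1.0
    codes = [round((x - lo) / step) for x in logs]
    return codes, lo, step


def dequantize_log2(codes: List[int], lo: float, step: float) -> List[float]:
    """Map codes back through 2 ** (...); the result is always > 0."""
    return [2.0 ** (lo + c * step) for c in codes]


def quantize_linear_absmax(values: List[float], bits: int = 8) -> Tuple[List[int], float]:
    """Unsigned linear absmax quantization for comparison; assumes max(values) > 0."""
    levels = (1 << bits) - 1
    scale = max(values) / levels
    return [round(v / scale) for v in values], scale


if __name__ == "__main__":
    variances = [1e-8, 1e-3, 1.0]          # toy second-moment block
    codes, lo, step = quantize_log2(variances)
    print(codes, dequantize_log2(codes, lo, step))
    print(quantize_linear_absmax(variances)[0])  # small entries collapse to 0
```

The trade-off is coarser resolution near the block maximum. Here, half a quantization step is a relative error of about 3.7%. That matters less for a second moment, which enters the Adafactor update through a square root.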
@@ -278,7 +278,7 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
 - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
 - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
 
-##
+## 🧊 CAME Confidence-Guided Updates
 
 Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
 
@@ -326,24 +326,25 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
 *These are safe starting points. Always validate on your own task and batch size.*
 
 
-
-
-
 ## 🎓 Acknowledgements
 
-
-
-Thanks to **Tim Dettmers** for the inspiration from the paper [8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
-
-Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the approximated gradient scaling method in the paper [APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270).
+This project builds upon the foundational work of several researchers and open-source communities. Sincere thanks to the following for their invaluable contributions:
 
-
+### Core Algorithm & Optimizer Design
+- **Noam Shazeer & Mitchell Stern** for proposing the original **Adafactor** algorithm ([Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)).
+- **Tim Dettmers** for the inspiration from **8-bit block-wise quantization** ([8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION](https://arxiv.org/abs/2110.02861)) and the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library.
+- **Hanqing Zhu, Zhenyu Zhang, et al.** for the **APOLLO** algorithm ([APOLLO: SGD-Like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270)).
+- **Xi Chen, Kaituo Feng, et al.** for the **Norm-Growth Limiter** mechanism in **Fira** ([Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)).
+- **Yang Luo, et al.** for the **confidence-guided strategy** in **CAME** ([CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)).
 
-
+### Quantization & Implementation
+- **The QLoRA Team** for pioneering the **4-bit NormalFloat (NF4)** quantization format ([QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)) that inspired our first moment quantization.
+- **The PyTorch AO Team** for their work on [4-bit optimizer states](https://github.com/pytorch/ao/tree/main/torchao/optim), validating distribution-aware quantization for optimizer moments.
+- **The PyTorch Team** for providing the foundational optimizer implementation and the C++ Extension toolchain.
 
-
+### Technical Review & Discussion
+- **Qwen, ChatGLM, and DeepSeek** (large language models) for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
 
-Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
 
 ## 🏛️ License
 
@@ -9,7 +9,7 @@ long_description = (this_directory / "README.md").read_text(encoding="utf-8")
 
 setup(
 name="adafactor8bit",
-version="0.2.
+version="0.2.5",
 description="8-bit Adafactor Optimizer with Fused CUDA Kernels",
 author="WANG YAN",
 author_email="yanfeiwong1997@outlook.com",
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes