adafactor8bit 0.2.1__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adafactor8bit-0.2.1/adafactor8bit.egg-info → adafactor8bit-0.2.2}/PKG-INFO +53 -10
- {adafactor8bit-0.2.1 → adafactor8bit-0.2.2}/README.md +52 -9
- {adafactor8bit-0.2.1 → adafactor8bit-0.2.2}/adafactor8bit/kernels.cu +80 -0
- {adafactor8bit-0.2.1 → adafactor8bit-0.2.2}/adafactor8bit/optimizer.py +310 -85
- {adafactor8bit-0.2.1 → adafactor8bit-0.2.2/adafactor8bit.egg-info}/PKG-INFO +53 -10
- {adafactor8bit-0.2.1 → adafactor8bit-0.2.2}/setup.py +1 -1
- {adafactor8bit-0.2.1 → adafactor8bit-0.2.2}/LICENSE +0 -0
- {adafactor8bit-0.2.1 → adafactor8bit-0.2.2}/MANIFEST.in +0 -0
- {adafactor8bit-0.2.1 → adafactor8bit-0.2.2}/adafactor8bit/__init__.py +0 -0
- {adafactor8bit-0.2.1 → adafactor8bit-0.2.2}/adafactor8bit.egg-info/SOURCES.txt +0 -0
- {adafactor8bit-0.2.1 → adafactor8bit-0.2.2}/adafactor8bit.egg-info/dependency_links.txt +0 -0
- {adafactor8bit-0.2.1 → adafactor8bit-0.2.2}/adafactor8bit.egg-info/requires.txt +0 -0
- {adafactor8bit-0.2.1 → adafactor8bit-0.2.2}/adafactor8bit.egg-info/top_level.txt +0 -0
- {adafactor8bit-0.2.1 → adafactor8bit-0.2.2}/setup.cfg +0 -0
--- adafactor8bit-0.2.1/adafactor8bit.egg-info/PKG-INFO
+++ adafactor8bit-0.2.2/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adafactor8bit
-Version: 0.2.
+Version: 0.2.2
 Summary: 8-bit Adafactor Optimizer with Fused CUDA Kernels
 Home-page: https://github.com/yanfeiwong/adafactor-8bit
 Author: WANG YAN
@@ -25,6 +25,13 @@ Dynamic: requires-dist
 Dynamic: requires-python
 Dynamic: summary
 
+<p align="center">
+<a href="https://github.com/yanfeiwong/adafactor-8bit">
+<img src="https://github.com/yanfeiwong/adafactor-8bit/raw/main/assets/banner.png"
+alt="Adafactor8Bit"
+width="80%">
+</a>
+</p>
 <div align="center">
 
 # 8-bit Adafactor with Fused CUDA Kernels
@@ -39,14 +46,15 @@ Dynamic: summary
 
 </div>
 
-An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, optional APOLLO low-rank updates, and
+An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, and optional add-ons including 4-bit packed first moments, APOLLO low-rank updates, and CAME confidence-guided optimization. It delivers substantially lower optimizer memory while preserving the low-overhead and numerical stability that make Adafactor attractive for training LLMs and diffusion models.
 
 
-##
+## ⚡ Key Features
 
 - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
 - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
 - **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
+- **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
 - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
 - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
 - **Zero CPU-GPU Sync**: Eliminates implicit synchronizations (e.g., D2H copies) in the control flow, ensuring the GPU computation pipeline runs without blocking.
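The Log-Space Quantization feature listed in this hunk can be illustrated with a minimal pure-Python sketch. It assumes the dequantization mapping visible in `_log_dequantize_nonneg` and in the CUDA kernel, `log2(v) = q * scale / 255 + MIN_LOG`, with one `scale` per block. The value of `MIN_LOG`, the choice of `scale` as the block's largest shifted log value, and all function names are assumptions for illustration, not the library's implementation.

```python
import math
from typing import List, Tuple

# Assumed constant: floor of the representable log2 range (about the smallest
# normal FP32 exponent). The library's real _FP32_MIN_LOG may differ.
MIN_LOG = -126.0


def log_quantize_block(values: List[float]) -> Tuple[List[int], float]:
    """Map one block of non-negative values to 8-bit codes in log2 space."""
    # Shift log2 values so the representable floor maps to 0.
    shifted = [math.log2(max(v, 2.0 ** MIN_LOG)) - MIN_LOG for v in values]
    scale = max(shifted) or 1.0  # per-block range; guards an all-floor block
    codes = [int(round(s / scale * 255)) for s in shifted]
    return codes, scale


def log_dequantize_block(codes: List[int], scale: float) -> List[float]:
    """Inverse mapping: log2(v) = q * scale / 255 + MIN_LOG."""
    return [2.0 ** (q * scale / 255 + MIN_LOG) for q in codes]


def linear_quantize_block(values: List[float]) -> List[float]:
    """Plain absmax 8-bit round trip, for comparison."""
    absmax = max(values) or 1.0
    return [round(v / absmax * 255) * absmax / 255 for v in values]


def blockwise_roundtrip(values: List[float], block_size: int) -> List[float]:
    """Quantize and dequantize each block independently, one scale per block."""
    out: List[float] = []
    for start in range(0, len(values), block_size):
        codes, scale = log_quantize_block(values[start:start + block_size])
        out.extend(log_dequantize_block(codes, scale))
    return out


if __name__ == "__main__":
    # Long-tailed second-moment estimates spanning 12 orders of magnitude.
    variances = [1e-12, 1e-9, 1e-6, 1e-3, 1.0, 0.5, 2e-4, 3e-8]
    print(blockwise_roundtrip(variances, block_size=4))
    print(linear_quantize_block(variances))  # the tiny entries collapse to 0.0
```

Because the codes are evenly spaced in log2 space, the round-trip error is relative (at most half a code step in the exponent) rather than absolute, so a variance of `1e-12` survives while plain absmax 8-bit quantization rounds it to zero.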
@@ -194,16 +202,16 @@ def get_param_groups(model, lr_emb, weight_decay, apollo_rank=256):
 "weight_decay": weight_decay,
 "quantize": True,
 "apollo_rank": apollo_rank,
-"beta1":0.9,
+"beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
 },
-
+
 # 4. >2D Weights: 8-bit quantization, Weight Decay, Full-Rank
 {
 "params": group_nd,
 "weight_decay": weight_decay,
 "quantize": True,
 "apollo_rank": 0,
-"beta1":0.9,
+"beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
 "factored": False # Disables factorization to preserve spatial structures, enabling finer gradient scaling.
 # Note: This increases state memory for >2D weights, depending on your model architecture.
 # If VRAM is constrained, reverting to factored=True is a safe alternative.
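The `beta1` comments in this hunk refer to the memory cost of the 4-bit first moment. A rough byte count makes the trade-off concrete. It assumes the layout suggested by the CUDA kernel's indexing (`m_q[idx / 2]` packs two 4-bit codes per byte, `m_scale[idx / m_block_size]` adds one FP32 scale per block, and the factored variances keep one 8-bit code per row and per column plus one FP32 scale per block), and it ignores padding and allocator overhead. The helper names are hypothetical.

```python
def fp32_momentum_bytes(numel: int) -> int:
    """Unquantized first moment: 4 bytes per element."""
    return 4 * numel


def momentum_4bit_bytes(numel: int, m_block_size: int = 128) -> int:
    """Two 4-bit codes per byte plus one FP32 scale per m_block_size elements."""
    packed = (numel + 1) // 2
    scales = (numel + m_block_size - 1) // m_block_size
    return packed + 4 * scales


def factored_variance_8bit_bytes(rows: int, cols: int, block_size: int = 2048) -> int:
    """One uint8 code per row and per column statistic, plus FP32 scales per block."""
    row_bytes = rows + 4 * ((rows + block_size - 1) // block_size)
    col_bytes = cols + 4 * ((cols + block_size - 1) // block_size)
    return row_bytes + col_bytes


if __name__ == "__main__":
    rows, cols = 4096, 4096
    print(fp32_momentum_bytes(rows * cols))          # 67108864 (64 MiB)
    print(momentum_4bit_bytes(rows * cols))          # 8912896 (8.5 MiB)
    print(factored_variance_8bit_bytes(rows, cols))  # 8208
```

Under these assumptions, a 4096 × 4096 weight needs about 8.5 MiB for the 4-bit momentum (against 64 MiB in FP32), while its factored 8-bit variance needs only about 8 KB. This is why dropping `beta1` is the main lever when optimizer memory is the priority.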
@@ -270,7 +278,40 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
 - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
 - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
 
+## 🛡️ CAME Confidence-Guided Updates
+
+Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
+
+**Adaptive Scaling ($V$) → Momentum Accumulation ($M$) → Confidence Weighting ($C$)**
+
+### Key Parameters & Tuning
 
+The confidence stage measures the consistency between the current update direction and historical momentum, adaptively suppressing highly oscillatory updates.
+
+- **`beta3`**: EMA decay coefficient for the confidence matrix. Requires `beta1` (momentum) and `factored=True`. Mutually exclusive with `apollo_rank`. Defaults to `None` (disabled).
+- **Learning Rate**: The official CAME implementation recommends **0.5–0.9×** the AdamW learning rate (see [official tuning guide](https://github.com/yangluo7/CAME/tree/master#hyper-parameter-tuning)). To use this learning rate in this library, you need to disable Adafactor's scaling and clipping (`scale_parameter=False`, `d=1e9`) to align with the original CAME behavior.
+- **Warmup**: Since the confidence matrix is zero-initialized without bias correction, a learning rate warmup is recommended to safely establish the confidence baseline.
+- **Choosing `beta3`**: `beta3` should generally be larger than `beta2` so the confidence estimate evolves more slowly than the variance estimate. A practical starting range is **0.9995–0.99995** when `beta2=0.999`.
+
+
+### Configuration Example
+
+To replicate "vanilla" CAME (stripping Adafactor's native modifications), replace the standard 2D APOLLO group in your `param_groups` with the following configuration:
+
+```python
+{
+"params": param_group,
+"lr": lr, # Original CAME recommends 0.5-0.9x AdamW LR
+"weight_decay": weight_decay,
+"quantize": True,
+"beta1": 0.9,
+"beta3": 0.9999, # Enable CAME confidence guidance
+"apollo_rank": 0, # Mutually exclusive with CAME
+"scale_parameter": False, # Disable Adafactor RMS scaling to align with vanilla CAME
+"d": 1e9, # Disable Adafactor global RMS clipping
+"enable_fira_for_adafactor": False, # Disable Fira Limiter to prevent interference with CAME's scaling
+},
+```
 
 ## 📈 Learning Rate Guide for Beginners
 
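The confidence stage described in this hunk follows the update rule of the CAME paper. The sketch below implements one CAME step for a 2-D parameter in pure Python, written from the paper's algorithm rather than from this library's fused kernels: it omits quantization, update RMS clipping (`d`), relative step sizing, and weight decay, uses row and column means for the factored statistics, and applies no bias correction. `CameState`, `came_step`, and the default epsilons are assumptions for illustration.

```python
import math
from typing import List

Matrix = List[List[float]]


def _row_mean(x: Matrix) -> List[float]:
    return [sum(row) / len(row) for row in x]


def _col_mean(x: Matrix) -> List[float]:
    return [sum(col) / len(col) for col in zip(*x)]


def _ema(old: List[float], new: List[float], beta: float) -> List[float]:
    return [beta * o + (1.0 - beta) * n for o, n in zip(old, new)]


def _rank1(r: List[float], c: List[float]) -> Matrix:
    """Factored reconstruction r_i * c_j / mean(r), as in Adafactor and CAME."""
    r_mean = sum(r) / len(r)
    return [[ri * cj / r_mean for cj in c] for ri in r]


class CameState:
    """Per-parameter state; every statistic starts at zero (no bias correction)."""

    def __init__(self, rows: int, cols: int):
        self.r = [0.0] * rows       # row means of G^2 (EMA with beta2)
        self.c = [0.0] * cols       # column means of G^2 (EMA with beta2)
        self.m = [[0.0] * cols for _ in range(rows)]  # first moment (EMA with beta1)
        self.r_conf = [0.0] * rows  # row means of (U - M)^2 (EMA with beta3)
        self.c_conf = [0.0] * cols  # column means of (U - M)^2 (EMA with beta3)


def came_step(param: Matrix, grad: Matrix, st: CameState, lr: float = 1e-3,
              beta1: float = 0.9, beta2: float = 0.999, beta3: float = 0.9999,
              eps1: float = 1e-30, eps2: float = 1e-16) -> Matrix:
    """Apply one CAME update to `param` in place and return the step taken."""
    # 1. Factored second moment V from row/column means of G^2.
    g2 = [[g * g + eps1 for g in row] for row in grad]
    st.r = _ema(st.r, _row_mean(g2), beta2)
    st.c = _ema(st.c, _col_mean(g2), beta2)
    v = _rank1(st.r, st.c)
    # 2. Normalized update U = G / sqrt(V), then momentum M.
    u = [[g / math.sqrt(vv) for g, vv in zip(gr, vr)] for gr, vr in zip(grad, v)]
    st.m = [[beta1 * m + (1.0 - beta1) * uu for m, uu in zip(mr, ur)]
            for mr, ur in zip(st.m, u)]
    # 3. Instability (U - M)^2, factored again with the slower beta3.
    inst = [[(uu - m) ** 2 + eps2 for uu, m in zip(ur, mr)] for ur, mr in zip(u, st.m)]
    st.r_conf = _ema(st.r_conf, _row_mean(inst), beta3)
    st.c_conf = _ema(st.c_conf, _col_mean(inst), beta3)
    s = _rank1(st.r_conf, st.c_conf)
    # 4. Confidence-weighted step M / sqrt(S).
    step = [[lr * m / math.sqrt(ss) for m, ss in zip(mr, sr)] for mr, sr in zip(st.m, s)]
    for i, row in enumerate(step):
        for j, delta in enumerate(row):
            param[i][j] -= delta
    return step
```

Rows and columns where the normalized update keeps disagreeing with the momentum accumulate a larger factored instability estimate `S`, which shrinks the step there; consistent directions keep `S` small.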
@@ -298,16 +339,18 @@ Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the appr
 
 Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
 
+Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
+
 Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
 
 Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
 
+## 🏛️ License
+
+[The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
+
 ## ⭐ Star the Project
 
 If this optimizer has been useful in your work, consider giving the repository a star. It helps others discover the project and supports future development.
 
 [](https://star-history.com/#yanfeiwong/adafactor-8bit&Date)
-
-## 📄 License
-
-[The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
--- adafactor8bit-0.2.1/README.md
+++ adafactor8bit-0.2.2/README.md
@@ -1,3 +1,10 @@
+<p align="center">
+<a href="https://github.com/yanfeiwong/adafactor-8bit">
+<img src="https://github.com/yanfeiwong/adafactor-8bit/raw/main/assets/banner.png"
+alt="Adafactor8Bit"
+width="80%">
+</a>
+</p>
 <div align="center">
 
 # 8-bit Adafactor with Fused CUDA Kernels
@@ -12,14 +19,15 @@
 
 </div>
 
-An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, optional APOLLO low-rank updates, and
+An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, and optional add-ons including 4-bit packed first moments, APOLLO low-rank updates, and CAME confidence-guided optimization. It delivers substantially lower optimizer memory while preserving the low-overhead and numerical stability that make Adafactor attractive for training LLMs and diffusion models.
 
 
-##
+## ⚡ Key Features
 
 - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
 - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
 - **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
+- **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
 - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
 - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
 - **Zero CPU-GPU Sync**: Eliminates implicit synchronizations (e.g., D2H copies) in the control flow, ensuring the GPU computation pipeline runs without blocking.
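The Fira Norm-Growth Limiter feature listed in this hunk can be sketched with the rule from the Fira paper: when the update norm grows by more than a factor `gamma` relative to the previous step, the update is rescaled to `gamma` times the previous norm. Mapping `fira_margin` to `gamma = 1 + fira_margin` is an assumption based on the docstring's "0.01 for 1%" description; the class name and the use of one norm per tensor (rather than per row) are also assumptions.

```python
import math
from typing import List, Optional


class NormGrowthLimiter:
    """Caps the growth of the update norm between consecutive steps (hypothetical helper)."""

    def __init__(self, margin: float = 0.01):
        self.gamma = 1.0 + margin          # allowed growth factor per step
        self.prev_norm: Optional[float] = None

    def __call__(self, update: List[float]) -> List[float]:
        norm = math.sqrt(sum(x * x for x in update))
        # Only growth is limited; a shrinking update passes through unchanged.
        if self.prev_norm is not None and norm > self.gamma * self.prev_norm > 0.0:
            factor = self.gamma * self.prev_norm / norm
            update = [x * factor for x in update]
            norm = self.gamma * self.prev_norm
        self.prev_norm = norm              # remember the (possibly limited) norm
        return update


if __name__ == "__main__":
    limiter = NormGrowthLimiter(margin=0.01)
    print(limiter([3.0, 4.0]))    # first step: unchanged
    print(limiter([30.0, 40.0]))  # 10x spike: rescaled to norm 5.05
```

Because each step may grow by at most 1% over the previous limited norm, a sudden spike is spread over many steps instead of landing in a single update.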
@@ -167,16 +175,16 @@ def get_param_groups(model, lr_emb, weight_decay, apollo_rank=256):
 "weight_decay": weight_decay,
 "quantize": True,
 "apollo_rank": apollo_rank,
-"beta1":0.9,
+"beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
 },
-
+
 # 4. >2D Weights: 8-bit quantization, Weight Decay, Full-Rank
 {
 "params": group_nd,
 "weight_decay": weight_decay,
 "quantize": True,
 "apollo_rank": 0,
-"beta1":0.9,
+"beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
 "factored": False # Disables factorization to preserve spatial structures, enabling finer gradient scaling.
 # Note: This increases state memory for >2D weights, depending on your model architecture.
 # If VRAM is constrained, reverting to factored=True is a safe alternative.
@@ -243,7 +251,40 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
 - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
 - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
 
+## 🛡️ CAME Confidence-Guided Updates
+
+Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
+
+**Adaptive Scaling ($V$) → Momentum Accumulation ($M$) → Confidence Weighting ($C$)**
+
+### Key Parameters & Tuning
 
+The confidence stage measures the consistency between the current update direction and historical momentum, adaptively suppressing highly oscillatory updates.
+
+- **`beta3`**: EMA decay coefficient for the confidence matrix. Requires `beta1` (momentum) and `factored=True`. Mutually exclusive with `apollo_rank`. Defaults to `None` (disabled).
+- **Learning Rate**: The official CAME implementation recommends **0.5–0.9×** the AdamW learning rate (see [official tuning guide](https://github.com/yangluo7/CAME/tree/master#hyper-parameter-tuning)). To use this learning rate in this library, you need to disable Adafactor's scaling and clipping (`scale_parameter=False`, `d=1e9`) to align with the original CAME behavior.
+- **Warmup**: Since the confidence matrix is zero-initialized without bias correction, a learning rate warmup is recommended to safely establish the confidence baseline.
+- **Choosing `beta3`**: `beta3` should generally be larger than `beta2` so the confidence estimate evolves more slowly than the variance estimate. A practical starting range is **0.9995–0.99995** when `beta2=0.999`.
+
+
+### Configuration Example
+
+To replicate "vanilla" CAME (stripping Adafactor's native modifications), replace the standard 2D APOLLO group in your `param_groups` with the following configuration:
+
+```python
+{
+"params": param_group,
+"lr": lr, # Original CAME recommends 0.5-0.9x AdamW LR
+"weight_decay": weight_decay,
+"quantize": True,
+"beta1": 0.9,
+"beta3": 0.9999, # Enable CAME confidence guidance
+"apollo_rank": 0, # Mutually exclusive with CAME
+"scale_parameter": False, # Disable Adafactor RMS scaling to align with vanilla CAME
+"d": 1e9, # Disable Adafactor global RMS clipping
+"enable_fira_for_adafactor": False, # Disable Fira Limiter to prevent interference with CAME's scaling
+},
+```
 
 ## 📈 Learning Rate Guide for Beginners
 
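The CAME section's advice to set `scale_parameter=False` and `d=1e9` can be understood through how these options enter the step in the original Adafactor formulation: with parameter scaling the step size is multiplied by `max(eps2, RMS(param))`, and the update is divided by `max(1, RMS(update) / d)`. The sketch below assumes this library follows the paper's formulation, including the paper's `min(1e-2, 1/sqrt(t))` relative step; the function names are hypothetical.

```python
import math
from typing import List


def rms(x: List[float]) -> float:
    """Root mean square of a flat list."""
    return math.sqrt(sum(v * v for v in x) / len(x))


def adafactor_step_size(lr: float, param: List[float], step: int,
                        relative_step: bool, scale_parameter: bool,
                        eps2: float = 1e-3) -> float:
    """Effective step size as in the Adafactor paper (assumed to match this library)."""
    rho = min(1e-2, 1.0 / math.sqrt(step)) if relative_step else lr
    if scale_parameter:
        return max(eps2, rms(param)) * rho  # larger weights take larger steps
    return rho


def clip_update(update: List[float], d: float) -> List[float]:
    """Divide the update by max(1, RMS(update) / d); a huge d makes this a no-op."""
    denom = max(1.0, rms(update) / d)
    return [u / denom for u in update]


if __name__ == "__main__":
    weights = [0.5, -0.5, 0.5, -0.5]  # RMS 0.5
    print(adafactor_step_size(1e-3, weights, 1, relative_step=False, scale_parameter=True))
    print(clip_update([2.0, -2.0, 2.0, -2.0], d=1.0))  # RMS 2 clipped to RMS 1
    print(clip_update([2.0, -2.0, 2.0, -2.0], d=1e9))  # unchanged
```

With `scale_parameter=False` and `d=1e9`, neither the parameter RMS nor the update RMS rescales the step, which is what makes an externally tuned learning rate meaningful.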
@@ -271,16 +312,18 @@ Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the appr
 
 Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
 
+Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
+
 Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
 
 Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
 
+## 🏛️ License
+
+[The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
+
 ## ⭐ Star the Project
 
 If this optimizer has been useful in your work, consider giving the repository a star. It helps others discover the project and supports future development.
 
 [](https://star-history.com/#yanfeiwong/adafactor-8bit&Date)
-
-## 📄 License
-
-[The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
--- adafactor8bit-0.2.1/adafactor8bit/kernels.cu
+++ adafactor8bit-0.2.2/adafactor8bit/kernels.cu
@@ -1112,6 +1112,84 @@ void apply_update_1d_full_m_cuda(
 }
 
 
+// ==========================================
+// 15. CAME: Compute Residual Variance (Row & Col)
+// ==========================================
+__global__ void came_compute_residual_2d_kernel(
+const unsigned char* __restrict__ m_q, const float* __restrict__ m_scale,
+const unsigned char* __restrict__ row_var_q, const float* __restrict__ row_var_scale,
+const unsigned char* __restrict__ col_var_q, const float* __restrict__ col_var_scale,
+const float* __restrict__ grad,
+const float* __restrict__ row_mean_val_ptr,
+float* __restrict__ res_row_sum, float* __restrict__ res_col_sum,
+float log_eps_sq, int R, int C, int numel, int m_block_size, int v_block_size)
+{
+int stride = gridDim.x * blockDim.x;
+for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < numel; idx += stride) {
+int b = idx / (R * C);
+int r = (idx / C) % R;
+int c = idx % C;
+
+unsigned char packed = m_q[idx / 2];
+int q_int = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
+float m_val = (float)(q_int - 8) * m_scale[idx / m_block_size];
+
+float log_r = (float)row_var_q[b * R + r] * INV_255 * row_var_scale[(b * R + r) / v_block_size] + MIN_LOG;
+float log_c = (float)col_var_q[b * C + c] * INV_255 * col_var_scale[(b * C + c) / v_block_size] + MIN_LOG;
+float log_row_mean = log2f(fmaxf(row_mean_val_ptr[b], MIN_VAL));
+
+float log_v_ij = log_r + log_c - log_row_mean;
+float max_log = fmaxf(log_v_ij, log_eps_sq);
+max_log = fmaxf(max_log, -53.0f);
+float inv_std = exp2f(-0.5f * max_log);
+
+float diff = (grad[idx] - m_val) * inv_std;
+float res = diff * diff;
+
+atomicAdd(&res_col_sum[b * C + c], res);
+
+int row_idx = b * R + r;
+int lane = threadIdx.x % 32;
+
+for (int offset = 16; offset > 0; offset /= 2) {
+int other_row_idx = __shfl_down_sync(0xffffffff, row_idx, offset);
+float other_res = __shfl_down_sync(0xffffffff, res, offset);
+if (lane + offset < 32 && row_idx == other_row_idx) {
+res += other_res;
+}
+}
+
+int prev_row_idx = __shfl_up_sync(0xffffffff, row_idx, 1);
+bool is_first_in_row = (lane == 0) || (row_idx != prev_row_idx);
+
+if (is_first_in_row) {
+atomicAdd(&res_row_sum[row_idx], res);
+}
+}
+}
+
+void came_compute_residual_2d_cuda(
+torch::Tensor m_q, torch::Tensor m_scale,
+torch::Tensor row_var_q, torch::Tensor row_var_scale,
+torch::Tensor col_var_q, torch::Tensor col_var_scale,
+torch::Tensor grad, torch::Tensor row_mean_val,
+torch::Tensor res_row_sum, torch::Tensor res_col_sum,
+float log_eps_sq, int R, int C, int numel, int m_block_size, int v_block_size)
+{
+int threads = 256;
+int blocks = min(1024, (numel + threads - 1) / threads);
+came_compute_residual_2d_kernel<<<blocks, threads>>>(
+m_q.data_ptr<unsigned char>(), m_scale.data_ptr<float>(),
+row_var_q.data_ptr<unsigned char>(), row_var_scale.data_ptr<float>(),
+col_var_q.data_ptr<unsigned char>(), col_var_scale.data_ptr<float>(),
+grad.data_ptr<float>(), row_mean_val.data_ptr<float>(),
+res_row_sum.data_ptr<float>(), res_col_sum.data_ptr<float>(),
+log_eps_sq, R, C, numel, m_block_size, v_block_size
+);
+}
+
+
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 m.def("fused_log_quantize_lerp", &fused_log_quantize_lerp_cuda, "Fused log quantize lerp (CUDA)");
 m.def("fused_4bit_quantize_lerp", &fused_4bit_quantize_lerp_cuda, "Fused 4-bit packed quantize lerp for m_t (CUDA)");
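For testing `came_compute_residual_2d_kernel` on small inputs, a slow pure-Python reference of what it accumulates can help. It mirrors the kernel's per-element arithmetic: 4-bit momentum unpacking (even index in the high nibble, odd index in the low nibble, codes offset by 8, one scale per `m_block_size`), `log2 v = log_r + log_c - log2(row_mean)`, the clamps against `log_eps_sq` and `-53`, and the squared normalized residual `((g - m) * v^(-1/2))^2` added to both the row and the column sums. To stay short it takes already dequantized (linear) row and column variances instead of the 8-bit codes and assumes `row_mean > 0`; both simplifications and the function names are assumptions.

```python
import math
from typing import List, Tuple


def unpack_4bit(m_q: bytes, m_scale: List[float], numel: int,
                m_block_size: int) -> List[float]:
    """Decode packed 4-bit momentum: even idx -> high nibble, odd idx -> low nibble."""
    out: List[float] = []
    for idx in range(numel):
        packed = m_q[idx // 2]
        q_int = (packed & 0x0F) if idx & 1 else (packed >> 4)
        out.append((q_int - 8) * m_scale[idx // m_block_size])  # codes are offset by 8
    return out


def came_residual_reference(m_q: bytes, m_scale: List[float],
                            row_var: List[float], col_var: List[float],
                            grad: List[float], row_mean: List[float],
                            log_eps_sq: float, B: int, R: int, C: int,
                            m_block_size: int) -> Tuple[List[float], List[float]]:
    """Return (res_row_sum, res_col_sum) for a (B, R, C) tensor stored flat."""
    numel = B * R * C
    m = unpack_4bit(m_q, m_scale, numel, m_block_size)
    res_row = [0.0] * (B * R)
    res_col = [0.0] * (B * C)
    for idx in range(numel):
        b, r, c = idx // (R * C), (idx // C) % R, idx % C
        # log2 of the factored variance v_ij = row_var * col_var / row_mean
        log_v = (math.log2(row_var[b * R + r]) + math.log2(col_var[b * C + c])
                 - math.log2(row_mean[b]))
        log_v = max(log_v, log_eps_sq, -53.0)  # same two clamps as the kernel
        inv_std = 2.0 ** (-0.5 * log_v)
        res = ((grad[idx] - m[idx]) * inv_std) ** 2
        res_row[b * R + r] += res  # kernel: warp-segmented sum, then atomicAdd
        res_col[b * C + c] += res  # kernel: one atomicAdd per element
    return res_row, res_col
```

Comparing the kernel's `res_row_sum`/`res_col_sum` against this reference (with a tolerance, since the kernel's atomic additions run in a nondeterministic order) is a cheap way to catch indexing mistakes.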
@@ -1134,4 +1212,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 
 m.def("compute_update_norm_1d_full_m", &compute_update_norm_1d_full_m_cuda, "Compute update norm 1D full precision with momentum (CUDA)");
 m.def("apply_update_1d_full_m", &apply_update_1d_full_m_cuda, "Apply update 1D full precision with momentum (CUDA)");
+
+m.def("came_compute_residual_2d", &came_compute_residual_2d_cuda, "Compute CAME residual row/col sums (CUDA)");
 }
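The row reduction in `came_compute_residual_2d_kernel` is a segmented warp reduction. Consecutive lanes handle consecutive flat indices, so `row_idx` never decreases across a warp; the shuffle-down loop therefore adds a neighbour's partial sum only when it belongs to the same row, and only the first lane of each row segment issues the `atomicAdd`. The simulation below models one warp in Python under the assumption that all 32 lanes are active (a full warp), with `__shfl_down_sync`/`__shfl_up_sync` modelled so that a lane whose source is out of range keeps its own value. The helper names are hypothetical.

```python
from typing import Dict, List


def shfl_down(values: List, offset: int) -> List:
    """Lane i reads lane i + offset; lanes whose source is out of range keep their value."""
    n = len(values)
    return [values[i + offset] if i + offset < n else values[i] for i in range(n)]


def shfl_up(values: List, offset: int) -> List:
    """Lane i reads lane i - offset; lanes whose source is out of range keep their value."""
    n = len(values)
    return [values[i - offset] if i - offset >= 0 else values[i] for i in range(n)]


def warp_segmented_row_sums(row_idx: List[int], res: List[float]) -> Dict[int, float]:
    """Simulate the kernel's segmented reduction for one full warp of 32 lanes."""
    assert len(row_idx) == len(res) == 32
    res = list(res)
    offset = 16
    while offset > 0:
        other_row = shfl_down(row_idx, offset)
        other_res = shfl_down(res, offset)  # all lanes read before any lane writes
        res = [res[lane] + other_res[lane]
               if lane + offset < 32 and row_idx[lane] == other_row[lane] else res[lane]
               for lane in range(32)]
        offset //= 2
    prev_row = shfl_up(row_idx, 1)
    sums: Dict[int, float] = {}
    for lane in range(32):
        if lane == 0 or row_idx[lane] != prev_row[lane]:  # first lane of its row segment
            # Stands in for atomicAdd(&res_row_sum[row_idx], res).
            sums[row_idx[lane]] = sums.get(row_idx[lane], 0.0) + res[lane]
    return sums


if __name__ == "__main__":
    rows = [i // 5 for i in range(32)]  # C = 5 columns: rows change every 5 lanes
    print(warp_segmented_row_sums(rows, [float(i) for i in range(32)]))
```

Lane 0 always counts as a segment start because its row may have begun in the previous warp; that partial row is completed by the `atomicAdd` from the other warp.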
--- adafactor8bit-0.2.1/adafactor8bit/optimizer.py
+++ adafactor8bit-0.2.2/adafactor8bit/optimizer.py
@@ -116,6 +116,9 @@ def _log_quantize_nonneg(tensor: Tensor, block_size: int = 2048) -> Tuple[Tensor
 
 def _log_dequantize_nonneg(q: Tensor, scale: Tensor, shape: torch.Size, pad: int) -> Tensor:
 """Dequantize from log-space back to linear-space FP32."""
+if q.dim() == 1:
+block_size = q.numel() // scale.numel()
+q = q.view(-1, block_size)
 log_blocks = q.float() * scale.unsqueeze(-1) * _INV_255 + _FP32_MIN_LOG
 blocks = torch.pow(2.0, log_blocks)
 flat = blocks.flatten()
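The new branch in `_log_dequantize_nonneg` handles codes that arrive flattened: it infers `block_size` as the ratio of the code count to the scale count and views `q` as `(num_blocks, block_size)`, so that `scale.unsqueeze(-1)` broadcasts one scale per block. A pure-Python sketch of the same regrouping and dequantization follows; the explicit divisibility check and the `MIN_LOG` value are additions of this sketch, not part of the diff.

```python
from typing import List

INV_255 = 1.0 / 255.0
MIN_LOG = -126.0  # assumed stand-in for the library's _FP32_MIN_LOG


def regroup_flat_codes(q: List[int], scale: List[float]) -> List[List[int]]:
    """Split flat codes into one row per scale entry (mirrors q.view(-1, block_size))."""
    if not scale or len(q) % len(scale) != 0:
        raise ValueError("code count must be a positive multiple of the scale count")
    block_size = len(q) // len(scale)
    return [q[i * block_size:(i + 1) * block_size] for i in range(len(scale))]


def log_dequantize_flat(q: List[int], scale: List[float]) -> List[float]:
    """Per block: 2 ** (q * scale * INV_255 + MIN_LOG), then flatten again."""
    out: List[float] = []
    for codes, s in zip(regroup_flat_codes(q, scale), scale):
        out.extend(2.0 ** (c * s * INV_255 + MIN_LOG) for c in codes)
    return out


if __name__ == "__main__":
    print(regroup_flat_codes([1, 2, 3, 4, 5, 6], [1.0, 2.0]))  # two blocks of three
    print(log_dequantize_flat([0, 255, 255, 0], [126.0, 63.0]))
```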
@@ -165,6 +168,8 @@ class Adafactor8Bit(Optimizer):
 
 Args:
 params (Iterable): Iterable of parameters to optimize or dictionaries defining parameter groups.
+
+--- Core Optimization ---
 lr (float, optional): External learning rate. Defaults to 1e-2.
 beta1 (float, optional): Momentum coefficient for first moment (4-bit packed).
 If None, disables first moment (pure Adafactor/RMSProp). Defaults to None.
@@ -174,89 +179,110 @@ class Adafactor8Bit(Optimizer):
|
|
|
174
179
|
beta2_decay (float): Dynamic decay rate coefficient.
|
|
175
180
|
The EMA weight is computed as `step ** beta2_decay`. Ignored if `beta2` is specified.
|
|
176
181
|
Defaults to -0.8.
|
|
182
|
+
beta3 (float, optional): Confidence-guided decay coefficient for CAME
|
|
183
|
+
(Confidence-guided Adaptive Memory Efficient Optimization).
|
|
184
|
+
Computes the instability of the update direction and scales the update accordingly.
|
|
185
|
+
Strictly requires `beta1` and `factored=True`. Mutually exclusive with `apollo_rank`.
|
|
186
|
+
Defaults to None (disabled).
|
|
177
187
|
eps (Tuple[Optional[float], float]): Regularization constants (eps1, eps2).
|
|
178
188
|
- `eps1`: Added to the squared gradient. If `None`, defaults to the machine epsilon
|
|
179
|
-
of the parameter's dtype (e.g., ~1.19e-7 for FP32),
|
|
180
|
-
behavior and preventing underflow.
|
|
189
|
+
of the parameter's dtype (e.g., ~1.19e-7 for FP32), preventing underflow.
|
|
181
190
|
- `eps2`: Lower threshold for parameter RMS scaling. Defaults to (None, 1e-3).
|
|
191
|
+
weight_decay (float): Weight decay (L2 penalty). Defaults to 0.0.
|
|
182
192
|
d (float): Clipping threshold for the final gradient update RMS.
|
|
183
193
|
Setting to an extremely large value (e.g., ``1e9``) effectively disables the global
|
|
184
|
-
clipping constraint,
|
|
185
|
-
|
|
186
|
-
weight_decay (float): Weight decay (L2 penalty). Defaults to 0.0.
|
|
187
|
-
scale_weight_decay (bool): If `True` (default), weight decay is coupled with the
|
|
188
|
-
parameter's RMS scale. If `False`, weight decay is decoupled and only scaled by the
|
|
189
|
-
base learning rate (AdamW-style).
|
|
194
|
+
clipping constraint, useful for decoupling updates in sparse layers like Embeddings.
|
|
195
|
+
Defaults to 1.0.
|
|
190
196
|
maximize (bool): Maximize the params based on the objective. Defaults to False.
|
|
197
|
+
|
|
198
|
+
--- Factorization & Scaling ---
|
|
191
199
|
relative_step (bool): If `True`, uses time-dependent learning rate. Defaults to True.
|
|
192
200
|
scale_parameter (bool): If `True`, scales learning rate by parameter RMS.
|
|
193
201
|
Setting to False decouples the step size from parameter magnitude, which can be useful
|
|
194
202
|
for sparse layers like Embeddings to ensure sufficient update strength. Defaults to True.
|
|
203
|
+
factored (bool): Whether to use row/col factorization for >=2D tensors.
|
|
204
|
+
Setting to False uses element-wise variance (like RMSProp, but still applies Adafactor's
|
|
205
|
+
global RMS clipping). This can be useful for preserving spatial structure in >2D tensors
|
|
206
|
+
such as CNN convolutions, or enabling per-element updates in Embeddings. Defaults to True.
|
|
207
|
+
|
|
208
|
+
--- Quantization Control ---
|
|
195
209
|
quantize (bool): Enable 8-bit log-space quantization for optimizer states. Defaults to True.
|
|
196
210
|
block_size (int): Block size for variance quantization. Must be a multiple of 1024. Defaults to 2048.
|
|
197
211
|
m_block_size (int): Block size for 4-bit momentum quantization.
|
|
198
|
-
|
|
212
|
+
Must be a multiple of 4 and >= 32. Defaults to 128.
|
|
199
213
|
min_8bit_size (int): Minimum number of elements to apply 8-bit quantization. Defaults to 4096.
|
|
200
214
|
use_cuda_kernel (bool): Whether to use custom CUDA kernels. Defaults to True.
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
215
|
+
|
|
216
|
+
--- APOLLO Low-Rank Projection ---
|
|
217
|
+
apollo_rank (int): Rank for APOLLO (An Optimizer for Memory-Efficient Large-Scale Training)
|
|
218
|
+
style random projection to low-rank space. If > 0, enables APOLLO.
|
|
219
|
+
Mutually exclusive with `beta3` (CAME). Defaults to 0 (disabled).
|
|
220
|
+
apollo_update_proj_gap (int): Steps between random projection matrix refreshes.
|
|
221
|
+
Defaults to 200.
|
|
222
|
+
apollo_scale_type (str): Strategy to map low-rank updates back to full-rank:
|
|
223
|
+
'channel' (row-wise norm matching) or 'tensor' (global norm matching).
 Defaults to 'channel'.
-apollo_eps (float): Epsilon for low-rank variance normalization
-
-
-
-
-
-
-
-
-
-
-
-
-Embeddings. Defaults to True.
+apollo_eps (float): Epsilon for low-rank variance normalization to prevent division by zero.
+Defaults to 1e-8.
+apollo_factorize (bool): If True, applies Adafactor-style row/col factorization
+within the low-rank space (FP32, ~16KB state) instead of full matrix variance
+(8-bit, ~100KB+ state) to drastically reduce optimizer state memory. Defaults to False.
+
+--- Stabilizers & Regularization ---
+scale_weight_decay (bool): If `True` (default), weight decay is coupled with the
+parameter's RMS scale. If `False`, decoupled (AdamW-style).
+enable_fira_for_adafactor (bool): If `True`, enables Fira Limiter to prevent gradient
+explosion by smoothing update norms. Defaults to False.
+fira_margin (float): The tolerance margin for Fira Limiter (e.g., 0.01 for 1%).
+Shared with Apollo path. Defaults to 0.01.
 """
+
 def __init__(
 self,
 params: Iterable[Union[Tensor, Dict[str, Any]]],
+# --- Core Optimization ---
 lr: float = 1e-2,
 beta1: Optional[float] = None,
 beta2: Optional[float] = None,
 beta2_decay: float = -0.8,
+beta3: Optional[float] = None,
 eps: Tuple[Optional[float], float] = (None, 1e-3),
-d: float = 1.0,
 weight_decay: float = 0.0,
+d: float = 1.0,
 maximize: bool = False,
+# --- Factorization & Scaling ---
 relative_step: bool = True,
 scale_parameter: bool = True,
-
+factored: bool = True,
+# --- Quantization Control ---
 quantize: bool = True,
 block_size: int = 2048,
 m_block_size: int = 128,
 min_8bit_size: int = 4096,
 use_cuda_kernel: bool = True,
+# --- APOLLO Low-Rank Projection ---
 apollo_rank: int = 0,
 apollo_update_proj_gap: int = 200,
 apollo_scale_type: str = 'channel',
 apollo_eps: float = 1e-8,
 apollo_factorize: bool = False,
+# --- Stabilizers & Regularization ---
+scale_weight_decay: bool = True,
 enable_fira_for_adafactor: bool = False,
 fira_margin: float = 0.01,
-factored: bool = True,
 ):
-
-if
+
+if lr < 0.0: raise ValueError(f"Invalid lr: {lr}, must be >= 0.0")
+if beta1 is not None and (beta1 < 0.0 or beta1 >= 1.0):
 raise ValueError(f"Invalid beta1: {beta1}, must be in [0.0, 1.0)")
-if
+if beta2_decay > 0.0: raise ValueError(f"Invalid beta2_decay: {beta2_decay}, must be <= 0.0")
 eps1, eps2 = eps
-if eps1 is not None and
-if
-if
-if
+if eps1 is not None and eps1 < 0.0: raise ValueError(f"Invalid eps1: {eps1}, must be >= 0.0")
+if eps2 < 0.0: raise ValueError(f"Invalid eps2: {eps2}, must be >= 0.0")
+if d < 1.0: raise ValueError(f"Invalid d: {d}, must be >= 1.0")
+if weight_decay < 0.0: raise ValueError(f"Invalid weight_decay: {weight_decay}, must be >= 0.0")
 
-if beta2 is not None and
+if beta2 is not None and (beta2 < 0.0 or beta2 >= 1.0):
 raise ValueError(f"Invalid beta2: {beta2}, must be in [0.0, 1.0)")
 
 if quantize and block_size % 1024 != 0:
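The `apollo_factorize` docstring contrasts a ~16KB FP32 factorized state with a ~100KB+ 8-bit full-matrix state. A quick size estimate lands in the same ballpark under hypothetical assumptions: a low-rank gradient of shape `(rank, n)` with `rank=32` and `n=4096`, one byte per element for the 8-bit full variance (block scales ignored), and one FP32 value per row and per column for the factorized variant. These shapes are chosen for illustration and are not taken from the package.

```python
def full_variance_bytes(rank: int, n: int) -> int:
    """8-bit full-matrix variance: one byte per element (block scales ignored)."""
    return rank * n


def factored_variance_bytes(rank: int, n: int) -> int:
    """FP32 row and column statistics: (rank + n) floats of 4 bytes each."""
    return (rank + n) * 4


rank, n = 32, 4096  # hypothetical low-rank gradient shape
full = full_variance_bytes(rank, n)
factored = factored_variance_bytes(rank, n)
print(f"full 8-bit: {full / 1024:.1f} KiB, factored FP32: {factored / 1024:.1f} KiB")
```

The factorized state grows with `rank + n`, while the full state grows with `rank * n`. This is consistent with the README's caution that the savings may be marginal for smaller models.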
@@ -268,21 +294,33 @@ class Adafactor8Bit(Optimizer):
 if apollo_rank > 0 and apollo_scale_type not in ('channel', 'tensor'):
 raise ValueError(f"apollo_scale_type must be 'channel' or 'tensor', got {apollo_scale_type}.")
 
-if
+if fira_margin < 0.0: raise ValueError(f"Invalid fira_margin: {fira_margin}, must be >= 0.0")
+
+if beta3 is not None:
+if beta3 < 0.0 or beta3 >= 1.0:
+raise ValueError(f"Invalid beta3: {beta3}, must be in [0.0, 1.0)")
+if beta1 is None:
+raise ValueError("CAME (beta3) strictly requires momentum (beta1) to compute update instability.")
+if apollo_rank > 0:
+raise ValueError("CAME (beta3) and APOLLO (apollo_rank > 0) are mutually exclusive optimization strategies.")
+if not factored:
+raise ValueError("CAME (beta3) requires factored=True (2D row/col factorization). It is not supported for 1D full-rank paths.")
 
 defaults = dict(
-
-
-
-
-
-
-
-
-
-
-
-
+# Core Optimization
+lr=lr, beta1=beta1, beta2=beta2, beta2_decay=beta2_decay, beta3=beta3,
+eps=eps, weight_decay=weight_decay, d=d, maximize=maximize,
+# Factorization & Scaling
+relative_step=relative_step, scale_parameter=scale_parameter, factored=factored,
+# Quantization Control
+quantize=quantize, block_size=block_size, m_block_size=m_block_size,
+min_8bit_size=min_8bit_size, use_cuda_kernel=use_cuda_kernel,
+# APOLLO Low-Rank Projection
+apollo_rank=apollo_rank, apollo_update_proj_gap=apollo_update_proj_gap,
+apollo_scale_type=apollo_scale_type, apollo_eps=apollo_eps, apollo_factorize=apollo_factorize,
+# Stabilizers & Regularization
+scale_weight_decay=scale_weight_decay,
+enable_fira_for_adafactor=enable_fira_for_adafactor, fira_margin=fira_margin,
 )
 super().__init__(params, defaults)
 
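The new `beta3` checks encode three compatibility rules for CAME: it needs momentum (`beta1`), it cannot be combined with APOLLO (`apollo_rank > 0`), and it requires the factored path. The sketch below restates those rules in the order the hunk checks them. `check_came_config` and `is_valid` are hypothetical helpers, not part of the package, and the error messages are shortened.

```python
from typing import Optional


def check_came_config(beta1: Optional[float], beta3: Optional[float],
                      apollo_rank: int = 0, factored: bool = True) -> None:
    """Raise ValueError for beta3 (CAME) settings that the constructor rejects."""
    if beta3 is None:
        return  # CAME disabled: nothing else to check
    if beta3 < 0.0 or beta3 >= 1.0:
        raise ValueError(f"Invalid beta3: {beta3}, must be in [0.0, 1.0)")
    if beta1 is None:
        raise ValueError("CAME (beta3) requires momentum (beta1).")
    if apollo_rank > 0:
        raise ValueError("CAME (beta3) and APOLLO (apollo_rank > 0) are mutually exclusive.")
    if not factored:
        raise ValueError("CAME (beta3) requires factored=True.")


def is_valid(**kwargs) -> bool:
    """Return True if check_came_config accepts the given settings."""
    try:
        check_came_config(**kwargs)
    except ValueError:
        return False
    return True


print(is_valid(beta1=0.9, beta3=0.9999))                   # True
print(is_valid(beta1=None, beta3=0.9999))                  # False: no momentum
print(is_valid(beta1=0.9, beta3=0.9999, apollo_rank=256))  # False: APOLLO enabled
print(is_valid(beta1=0.9, beta3=0.9999, factored=False))   # False: unfactored path
```

When `beta3` is `None`, none of these rules apply. This matches the `if beta3 is not None:` guard in the hunk.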
@@ -330,8 +368,10 @@ class Adafactor8Bit(Optimizer):
 m_block_size = group.get("m_block_size", 128)
 min_8bit_size = group.get("min_8bit_size", 4096)
 apollo_rank = group.get("apollo_rank", 0)
+apollo_factorize = group.get("apollo_factorize", False)
 factored = group.get("factored", True)
 beta1 = group.get("beta1")
+beta3 = group.get("beta3")
 
 for p in group["params"]:
 if p.grad is None: continue
@@ -358,8 +398,8 @@ class Adafactor8Bit(Optimizer):
 state["step"] = step_backup
 needs_init = True
 elif use_apollo and is_apollo_state:
-if state.get("apollo_rank") != apollo_rank:
-logger.warning(f"Adafactor8Bit: Apollo
+if state.get("apollo_rank") != apollo_rank or state.get("apollo_factorize", False) != apollo_factorize:
+logger.warning(f"Adafactor8Bit: Apollo config changed for param shape {p.shape}. Re-initializing state.")
 step_backup = state.get("step", 0)
 state.clear()
 state["step"] = step_backup
@@ -424,11 +464,25 @@ class Adafactor8Bit(Optimizer):
 state["m_q"] = torch.full((m_padded_numel // 2,), 0x88, dtype=torch.uint8, device=p.device)
 state["m_scale"] = torch.ones(m_padded_numel // m_block_size, dtype=torch.float32, device=p.device)
 state["m_block_size"] = m_block_size
+
+if beta3 is not None:
+state["conf_row_q"] = torch.zeros_like(state["row_var_q"])
+state["conf_row_scale"] = torch.ones_like(state["row_var_scale"])
+state["conf_row_shape"] = state["row_var_shape"]
+state["conf_row_pad"] = state["row_var_pad"]
+
+state["conf_col_q"] = torch.zeros_like(state["col_var_q"])
+state["conf_col_scale"] = torch.ones_like(state["col_var_scale"])
+state["conf_col_shape"] = state["col_var_shape"]
+state["conf_col_pad"] = state["col_var_pad"]
 else:
-state["row_var"] = torch.zeros(r_shape, device=p.device)
-state["col_var"] = torch.zeros(c_shape, device=p.device)
+state["row_var"] = torch.zeros(r_shape, dtype=torch.float32, device=p.device)
+state["col_var"] = torch.zeros(c_shape, dtype=torch.float32, device=p.device)
 if beta1 is not None:
-state["m"] = torch.zeros_like(p.grad, device=p.device, memory_format=torch.preserve_format)
+state["m"] = torch.zeros_like(p.grad, dtype=torch.float32, device=p.device, memory_format=torch.preserve_format)
+if beta3 is not None:
+state["conf_row"] = torch.zeros(r_shape, device=p.device)
+state["conf_col"] = torch.zeros(c_shape, device=p.device)
 else:
 if use_quant:
 v_numel = p.grad.numel()
@@ -444,9 +498,9 @@ class Adafactor8Bit(Optimizer):
 state["m_scale"] = torch.ones(m_padded_numel // m_block_size, dtype=torch.float32, device=p.device)
 state["m_block_size"] = m_block_size
 else:
-state["variance"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
+state["variance"] = torch.zeros_like(p.grad, dtype=torch.float32, memory_format=torch.preserve_format)
 if beta1 is not None:
-state["m"] = torch.zeros_like(p.grad, device=p.device, memory_format=torch.preserve_format)
+state["m"] = torch.zeros_like(p.grad, dtype=torch.float32, device=p.device, memory_format=torch.preserve_format)
 else:
 if torch.is_tensor(state["step"]):
 state["step"] = int(state["step"].cpu().item())
@@ -466,6 +520,16 @@ class Adafactor8Bit(Optimizer):
 state_is_factored = ("row_var" in state or "row_var_q" in state)
 
 if use_quant and not state.get("is_quantized", False):
+if isinstance(state.get("v_low"), Tensor) and state.get("v_low_q") is None:
+state["v_low"].clamp_(min=_FP32_TINY)
+q, s, sh, pad = _log_quantize_nonneg(state["v_low"], curr_block_size)
+state["v_low_q"], state["v_low_scale"], state["v_low_shape"], state["v_low_pad"] = q, s, sh, pad
+state["v_low"] = None
+
+if "m_low" in state:
+logger.warning("Adafactor8Bit: Apollo m_low discarded due to quantize flag change.")
+state.pop("m_low", None)
+
 if state_is_factored:
 if "row_var" in state and "row_var_q" not in state:
 state["row_var"].clamp_(min=_FP32_TINY)
@@ -477,6 +541,17 @@ class Adafactor8Bit(Optimizer):
 q, s, sh, pad = _log_quantize_nonneg(state["col_var"], curr_block_size)
 state["col_var_q"], state["col_var_scale"], state["col_var_shape"], state["col_var_pad"] = q, s, sh, pad
 del state["col_var"]
+if beta3 is not None:
+if "conf_row" in state and "conf_row_q" not in state:
+state["conf_row"].clamp_(min=_FP32_TINY)
+q, s, sh, pad = _log_quantize_nonneg(state["conf_row"], curr_block_size)
+state["conf_row_q"], state["conf_row_scale"], state["conf_row_shape"], state["conf_row_pad"] = q, s, sh, pad
+del state["conf_row"]
+if "conf_col" in state and "conf_col_q" not in state:
+state["conf_col"].clamp_(min=_FP32_TINY)
+q, s, sh, pad = _log_quantize_nonneg(state["conf_col"], curr_block_size)
+state["conf_col_q"], state["conf_col_scale"], state["conf_col_shape"], state["conf_col_pad"] = q, s, sh, pad
+del state["conf_col"]
 else:
 if "variance" in state and "variance_q" not in state:
 state["variance"].clamp_(min=_FP32_TINY)
@@ -498,11 +573,27 @@ class Adafactor8Bit(Optimizer):
 state["is_quantized"] = True
 
 elif not use_quant and state.get("is_quantized", False):
+if isinstance(state.get("v_low_q"), Tensor):
+state["v_low"] = _log_dequantize_nonneg(
+state.pop("v_low_q"), state.pop("v_low_scale"),
+state.pop("v_low_shape"), state.pop("v_low_pad")
+)
+
+if "m_low_q" in state:
+logger.warning("Adafactor8Bit: Apollo m_low_q discarded due to quantize flag change.")
+state.pop("m_low_q", None)
+state.pop("m_low_scale", None)
+
 if state_is_factored:
 if "row_var_q" in state:
 state["row_var"] = _log_dequantize_nonneg(state.pop("row_var_q"), state.pop("row_var_scale"), state.pop("row_var_shape"), state.pop("row_var_pad"))
 if "col_var_q" in state:
 state["col_var"] = _log_dequantize_nonneg(state.pop("col_var_q"), state.pop("col_var_scale"), state.pop("col_var_shape"), state.pop("col_var_pad"))
+if beta3 is not None:
+if "conf_row_q" in state:
+state["conf_row"] = _log_dequantize_nonneg(state.pop("conf_row_q"), state.pop("conf_row_scale"), state.pop("conf_row_shape"), state.pop("conf_row_pad"))
+if "conf_col_q" in state:
+state["conf_col"] = _log_dequantize_nonneg(state.pop("conf_col_q"), state.pop("conf_col_scale"), state.pop("conf_col_shape"), state.pop("conf_col_pad"))
 else:
 if "variance_q" in state:
 state["variance"] = _log_dequantize_nonneg(state.pop("variance_q"), state.pop("variance_scale"), state.pop("variance_shape"), state.pop("variance_pad"))
@@ -597,6 +688,7 @@ class Adafactor8Bit(Optimizer):
 enable_fira_for_adafactor=group.get("enable_fira_for_adafactor", False),
 fira_margin=group.get("fira_margin", 0.01),
 factored=group.get("factored", True),
+beta3=group.get("beta3"),
 )
 return loss
 
@@ -667,6 +759,7 @@ def _update_param_8bit(
 enable_fira_for_adafactor: bool = False,
 fira_margin: float = 0.01,
 factored: bool = True,
+beta3: Optional[float] = None,
 ):
 if eps1 is None:
 eps1 = torch.finfo(param.dtype).eps
@@ -893,11 +986,13 @@ def _update_param_8bit(
 C = shape[-1]
 numel = grad_fp32.numel()
 
-
-
+g_sq = grad_fp32.square()
+row_mean = g_sq.mean(dim=-1, keepdim=True)
+col_mean = g_sq.mean(dim=-2, keepdim=True)
 
 if quantize:
 if _load_cuda_module(use_cuda_kernel):
+del g_sq
 _CUDA_MODULE.fused_log_quantize_lerp(state["row_var_q"], state["row_var_scale"], row_mean.reshape(-1), beta_val, curr_block_size, False, row_mean.numel())
 _CUDA_MODULE.fused_log_quantize_lerp(state["col_var_q"], state["col_var_scale"], col_mean.reshape(-1), beta_val, curr_block_size, False, col_mean.numel())
 
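The new code computes `g_sq` once and derives `row_mean` (mean over the last dimension) and `col_mean` (mean over the second-to-last dimension) from it. These two vectors are the inputs to Adafactor's factored second moment: the full matrix is approximated by `row_mean[i] * col_mean[j]` divided by the overall mean. The pure-Python sketch below (no PyTorch, hypothetical helper name, EMA and clamping steps omitted) shows that this rank-1 estimate is exact when the squared gradient is itself rank-1.

```python
def factored_second_moment(g_sq):
    """Rank-1 estimate V_hat[i][j] = row_mean[i] * col_mean[j] / mean(row_mean)."""
    rows, cols = len(g_sq), len(g_sq[0])
    row_mean = [sum(r) / cols for r in g_sq]  # like g_sq.mean(dim=-1)
    col_mean = [sum(g_sq[i][j] for i in range(rows)) / rows for j in range(cols)]  # like g_sq.mean(dim=-2)
    overall = sum(row_mean) / rows  # mean of the row means == mean of all entries
    return [[row_mean[i] * col_mean[j] / overall for j in range(cols)] for i in range(rows)]


# A rank-1 squared gradient g_sq[i][j] = a[i] * b[j] is recovered exactly.
a, b = [1.0, 2.0, 4.0], [0.5, 3.0]
g_sq = [[ai * bj for bj in b] for ai in a]
v_hat = factored_second_moment(g_sq)
max_err = max(abs(v_hat[i][j] - g_sq[i][j]) for i in range(len(a)) for j in range(len(b)))
print(f"max reconstruction error: {max_err:.2e}")
```

Only `rows + cols` numbers are needed per matrix, and `row_var`/`col_var` track these numbers as EMAs. `row_var.mean(dim=-2, ...)` in the following hunk plays the role of `overall`.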
@@ -905,26 +1000,76 @@ def _update_param_8bit(
 row_mean_val_flat = row_var.mean(dim=-2, keepdim=True).clamp_(min=eps1).flatten().contiguous()
 del row_var
 
-grad_flat = grad_fp32.reshape(-1)
-row_var_q_flat = state["row_var_q"].reshape(-1)
-col_var_q_flat = state["col_var_q"].reshape(-1)
-
-total_sum_sq = torch.zeros(1, device=param_work.device, dtype=torch.float32)
-
 if beta1 is not None:
 _CUDA_MODULE.fused_4bit_quantize_lerp(
-state["m_q"], state["m_scale"], grad_fp32.view(-1), beta1, m_curr_block_size,
+state["m_q"], state["m_scale"], grad_fp32.view(-1), beta1, m_curr_block_size, numel
 )
-del grad_fp32
 
+if beta3 is not None and beta1 is not None:
+batch_size = math.prod(shape[:-2]) if len(shape) > 2 else 1
+res_row_sum = torch.zeros(batch_size * R, device=param_work.device, dtype=torch.float32)
+res_col_sum = torch.zeros(batch_size * C, device=param_work.device, dtype=torch.float32)
+
+_CUDA_MODULE.came_compute_residual_2d(
+state["m_q"].view(-1), state["m_scale"].view(-1),
+state["row_var_q"].view(-1), state["row_var_scale"],
+state["col_var_q"].view(-1), state["col_var_scale"],
+grad_fp32.reshape(-1), row_mean_val_flat,
+res_row_sum, res_col_sum,
+log_eps_sq, R, C, numel, m_curr_block_size, curr_block_size
+)
+
+beta3_val = 1.0 - beta3
+u_row_mean = (res_row_sum / C).contiguous().view(-1)
+u_col_mean = (res_col_sum / R).contiguous().view(-1)
+del res_row_sum, res_col_sum
+
+_CUDA_MODULE.fused_log_quantize_lerp(state["conf_row_q"], state["conf_row_scale"], u_row_mean, beta3_val, curr_block_size, False, u_row_mean.numel())
+_CUDA_MODULE.fused_log_quantize_lerp(state["conf_col_q"], state["conf_col_scale"], u_col_mean, beta3_val, curr_block_size, False, u_col_mean.numel())
+
+v_row = _log_dequantize_nonneg(state["row_var_q"], state["row_var_scale"], state["row_var_shape"], state["row_var_pad"])
+v_col = _log_dequantize_nonneg(state["col_var_q"], state["col_var_scale"], state["col_var_shape"], state["col_var_pad"])
+c_row = _log_dequantize_nonneg(state["conf_row_q"], state["conf_row_scale"], state["conf_row_shape"], state["conf_row_pad"])
+c_col = _log_dequantize_nonneg(state["conf_col_q"], state["conf_col_scale"], state["conf_col_shape"], state["conf_col_pad"])
+
+combined_row = (v_row * c_row).clamp_(min=_FP32_TINY)
+combined_col = (v_col * c_col).clamp_(min=_FP32_TINY)
+
+kernel_row_mean = combined_row.mean(dim=-2, keepdim=True).clamp_(min=eps1).flatten().contiguous()
+
+q_r, s_r, _, _ = _log_quantize_nonneg(combined_row, curr_block_size)
+q_c, s_c, _, _ = _log_quantize_nonneg(combined_col, curr_block_size)
+del v_row, v_col, c_row, c_col, combined_row, combined_col
+
+kernel_row_q_flat = q_r.reshape(-1)
+kernel_row_scale = s_r
+kernel_col_q_flat = q_c.reshape(-1)
+kernel_col_scale = s_c
+else:
+kernel_row_mean = row_mean_val_flat
+kernel_row_q_flat = state["row_var_q"].reshape(-1)
+kernel_row_scale = state["row_var_scale"]
+kernel_col_q_flat = state["col_var_q"].reshape(-1)
+kernel_col_scale = state["col_var_scale"]
+
+if beta1 is not None:
+grad_flat = None
+del grad_fp32
+else:
+grad_flat = grad_fp32.reshape(-1)
+del grad_fp32
+
+total_sum_sq = torch.zeros(1, device=param_work.device, dtype=torch.float32)
+
+if beta1 is not None:
 m_q_flat = state["m_q"].view(-1)
 m_scale_flat = state["m_scale"].view(-1)
 
 _CUDA_MODULE.compute_update_norm_m_2d(
 m_q_flat, m_scale_flat,
-
-
-total_sum_sq,
+kernel_row_q_flat, kernel_row_scale,
+kernel_col_q_flat, kernel_col_scale,
+total_sum_sq, kernel_row_mean, log_eps_sq, R, C, numel, m_curr_block_size, curr_block_size
 )
 
 if enable_fira_for_adafactor:
@@ -934,15 +1079,15 @@ def _update_param_8bit(
 _CUDA_MODULE.apply_update_m_2d(
 param_flat,
 m_q_flat, m_scale_flat,
-
-
-total_sum_sq, alpha,
+kernel_row_q_flat, kernel_row_scale,
+kernel_col_q_flat, kernel_col_scale,
+total_sum_sq, alpha, kernel_row_mean, d, log_eps_sq, R, C, numel, m_curr_block_size, curr_block_size
 )
 else:
 _CUDA_MODULE.compute_update_norm_2d(
-
-
-grad_flat, total_sum_sq,
+kernel_row_q_flat, kernel_row_scale,
+kernel_col_q_flat, kernel_col_scale,
+grad_flat, total_sum_sq, kernel_row_mean, log_eps_sq, R, C, numel, curr_block_size
 )
 
 if enable_fira_for_adafactor:
@@ -951,9 +1096,9 @@ def _update_param_8bit(
 param_flat = param_work.reshape(-1)
 _CUDA_MODULE.apply_update_2d(
 param_flat, grad_flat,
-
-
-total_sum_sq, alpha,
+kernel_row_q_flat, kernel_row_scale,
+kernel_col_q_flat, kernel_col_scale,
+total_sum_sq, alpha, kernel_row_mean, d, log_eps_sq, R, C, numel, curr_block_size
 )
 else:
 row_var = _log_dequantize_nonneg(state["row_var_q"], state["row_var_scale"], state["row_var_shape"], state["row_var_pad"])
@@ -975,8 +1120,53 @@ def _update_param_8bit(
 m_temp = _dequantize_4bit(state["m_q"], state["m_scale"], grad_fp32.numel(), grad_fp32.shape, m_curr_block_size, grad_fp32.device)
 else:
 m_temp = torch.zeros_like(grad_fp32)
+
 m_temp.lerp_(grad_fp32, 1.0 - beta1)
-
+
+if beta3 is not None:
+inv_col_sq = inv_col.square()
+inv_row_sq = inv_row.square()
+inv_col_sq_T = inv_col_sq.transpose(-1, -2)
+inv_row_sq_T = inv_row_sq.transpose(-1, -2)
+
+gm = grad_fp32 * m_temp
+m_sq = m_temp.square()
+
+t1 = torch.matmul(g_sq, inv_col_sq_T) / C
+t2 = torch.matmul(gm, inv_col_sq_T) / C
+t3 = torch.matmul(m_sq, inv_col_sq_T) / C
+res_row_mean = (inv_row_sq * (t1 - 2.0 * t2 + t3)).clamp(min=0)
+
+t1c = torch.matmul(inv_row_sq_T, g_sq) / R
+t2c = torch.matmul(inv_row_sq_T, gm) / R
+t3c = torch.matmul(inv_row_sq_T, m_sq) / R
+res_col_mean = (inv_col_sq * (t1c - 2.0 * t2c + t3c)).clamp(min=0)
+
+del gm, m_sq, g_sq
+
+conf_row_temp = _log_dequantize_nonneg(state["conf_row_q"], state["conf_row_scale"], state["conf_row_shape"], state["conf_row_pad"])
+conf_col_temp = _log_dequantize_nonneg(state["conf_col_q"], state["conf_col_scale"], state["conf_col_shape"], state["conf_col_pad"])
+
+conf_row_temp.lerp_(res_row_mean, 1.0 - beta3)
+conf_col_temp.lerp_(res_col_mean, 1.0 - beta3)
+
+q_cr, s_cr, sh_cr, pad_cr = _log_quantize_nonneg(conf_row_temp, curr_block_size)
+state["conf_row_q"], state["conf_row_scale"], state["conf_row_shape"], state["conf_row_pad"] = q_cr, s_cr, sh_cr, pad_cr
+
+q_cc, s_cc, sh_cc, pad_cc = _log_quantize_nonneg(conf_col_temp, curr_block_size)
+state["conf_col_q"], state["conf_col_scale"], state["conf_col_shape"], state["conf_col_pad"] = q_cc, s_cc, sh_cc, pad_cc
+
+combined_row = (row_var * conf_row_temp).clamp(min=eps_sq)
+combined_col = (col_var * conf_col_temp).clamp(min=eps_sq)
+del conf_row_temp, conf_col_temp
+
+combined_row_mean_val = combined_row.mean(dim=-2, keepdim=True).clamp(min=eps1)
+inv_row = combined_row.rsqrt() * combined_row_mean_val.sqrt()
+inv_col = combined_col.rsqrt()
+else:
+del g_sq
+
+del grad_fp32
 
 update = m_temp * inv_row
 update.mul_(inv_col)
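In the PyTorch fallback, the CAME residual row statistic is computed without forming `(g - m)` explicitly. Because `(g - m)**2 = g**2 - 2*g*m + m**2`, the weighted row mean of the scaled residual splits into the three matrix products `t1`, `t2`, and `t3`. The pure-Python sketch below (hypothetical helper names, random test data) checks that the expanded form matches a direct elementwise computation for one 2D matrix. The column statistic `res_col_mean` follows symmetrically, with `inv_row_sq_T` on the left and a division by `R`.

```python
import random


def residual_row_mean_direct(g, m, inv_row, inv_col):
    """Row means of (inv_row[i] * inv_col[j] * (g[i][j] - m[i][j]))**2, elementwise."""
    cols = len(g[0])
    return [
        sum((inv_row[i] * inv_col[j] * (g[i][j] - m[i][j])) ** 2 for j in range(cols)) / cols
        for i in range(len(g))
    ]


def residual_row_mean_expanded(g, m, inv_row, inv_col):
    """Same statistic via g**2 - 2*g*m + m**2, mirroring t1 - 2.0 * t2 + t3."""
    cols = len(g[0])
    w = [c * c for c in inv_col]  # inv_col_sq
    out = []
    for i in range(len(g)):
        t1 = sum(g[i][j] ** 2 * w[j] for j in range(cols)) / cols  # g_sq @ inv_col_sq_T / C
        t2 = sum(g[i][j] * m[i][j] * w[j] for j in range(cols)) / cols  # gm @ inv_col_sq_T / C
        t3 = sum(m[i][j] ** 2 * w[j] for j in range(cols)) / cols  # m_sq @ inv_col_sq_T / C
        out.append(max(inv_row[i] ** 2 * (t1 - 2.0 * t2 + t3), 0.0))  # .clamp(min=0)
    return out


random.seed(0)
R, C = 3, 4
g = [[random.gauss(0.0, 1.0) for _ in range(C)] for _ in range(R)]
m = [[random.gauss(0.0, 1.0) for _ in range(C)] for _ in range(R)]
inv_row = [random.uniform(0.5, 2.0) for _ in range(R)]
inv_col = [random.uniform(0.5, 2.0) for _ in range(C)]

direct = residual_row_mean_direct(g, m, inv_row, inv_col)
expanded = residual_row_mean_expanded(g, m, inv_row, inv_col)
print(all(abs(x - y) < 1e-9 for x, y in zip(direct, expanded)))  # True
```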
@@ -989,6 +1179,7 @@ def _update_param_8bit(
 denom = torch.clamp(torch.linalg.vector_norm(update) / (math.sqrt(update.numel()) * d), min=1.0)
 param_work.add_(update, alpha=-alpha / denom)
 else:
+del g_sq
 update = grad_fp32 * inv_row
 update.mul_(inv_col)
 
@@ -1011,6 +1202,39 @@ def _update_param_8bit(
 if "m" not in state:
 state["m"] = torch.zeros_like(grad_fp32)
 state["m"].lerp_(grad_fp32, 1.0 - beta1)
+
+if beta3 is not None:
+inv_col_sq = inv_col.square()
+inv_row_sq = inv_row.square()
+inv_col_sq_T = inv_col_sq.transpose(-1, -2)
+inv_row_sq_T = inv_row_sq.transpose(-1, -2)
+
+gm = grad_fp32 * state["m"]
+m_sq = state["m"].square()
+
+t1 = torch.matmul(g_sq, inv_col_sq_T) / C
+t2 = torch.matmul(gm, inv_col_sq_T) / C
+t3 = torch.matmul(m_sq, inv_col_sq_T) / C
+res_row_mean = (inv_row_sq * (t1 - 2.0 * t2 + t3)).clamp(min=0)
+
+t1c = torch.matmul(inv_row_sq_T, g_sq) / R
+t2c = torch.matmul(inv_row_sq_T, gm) / R
+t3c = torch.matmul(inv_row_sq_T, m_sq) / R
+res_col_mean = (inv_col_sq * (t1c - 2.0 * t2c + t3c)).clamp(min=0)
+
+del gm, m_sq, g_sq
+
+state["conf_row"].lerp_(res_row_mean, 1.0 - beta3)
+state["conf_col"].lerp_(res_col_mean, 1.0 - beta3)
+
+combined_row = (row_var * state["conf_row"]).clamp(min=eps_sq)
+combined_col = (col_var * state["conf_col"]).clamp(min=eps_sq)
+combined_row_mean_val = combined_row.mean(dim=-2, keepdim=True).clamp(min=eps1)
+inv_row = combined_row.rsqrt() * combined_row_mean_val.sqrt()
+inv_col = combined_col.rsqrt()
+else:
+del g_sq
+
 del grad_fp32
 
 update = state["m"] * inv_row
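With `beta3` set, the row and column statistics are multiplied by their confidence EMAs before the usual Adafactor normalization. The scalar sketch below mirrors that step with hypothetical values, omitting the clamping to `eps_sq`/`eps1`. It checks that `inv_row[i] * inv_col[j]` equals `1 / sqrt(V_hat[i][j])`, where `V_hat` is the rank-1 estimate built from the combined statistics. In other words, the confidence EMAs enter the update as extra factors of the factored second moment.

```python
import math


def came_scales(row_var, col_var, conf_row, conf_col):
    """Mirror the beta3 branch: combine statistics, then build inv_row and inv_col."""
    combined_row = [v * c for v, c in zip(row_var, conf_row)]
    combined_col = [v * c for v, c in zip(col_var, conf_col)]
    row_mean = sum(combined_row) / len(combined_row)  # combined_row.mean(dim=-2)
    inv_row = [math.sqrt(row_mean) / math.sqrt(r) for r in combined_row]  # rsqrt() * mean.sqrt()
    inv_col = [1.0 / math.sqrt(c) for c in combined_col]  # rsqrt()
    return combined_row, combined_col, row_mean, inv_row, inv_col


row_var, col_var = [0.2, 0.5, 1.0], [0.3, 0.6]  # hypothetical variance EMAs
conf_row, conf_col = [1.0, 4.0, 0.25], [2.0, 0.5]  # hypothetical confidence EMAs
cr, cc, mean_r, inv_row, inv_col = came_scales(row_var, col_var, conf_row, conf_col)

# Each update scale equals 1 / sqrt(V_hat) for the rank-1 estimate of the combined statistics.
ok = all(
    math.isclose(inv_row[i] * inv_col[j], 1.0 / math.sqrt(cr[i] * cc[j] / mean_r))
    for i in range(len(cr))
    for j in range(len(cc))
)
print(ok)  # True
```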
@@ -1022,6 +1246,7 @@ def _update_param_8bit(
 denom = torch.clamp(torch.linalg.vector_norm(update) / (math.sqrt(update.numel()) * d), min=1.0)
 param_work.add_(update, alpha=-alpha / denom)
 else:
+del g_sq
 update = grad_fp32 * inv_row
 update.mul_(inv_col)
 
@@ -1156,8 +1381,8 @@ def _update_param_apollo(
 col_mean_low = grad_low.square().mean(dim=-2, keepdim=True)
 
 if "row_var_low" not in state:
-state["row_var_low"] = row_mean_low
-state["col_var_low"] = col_mean_low
+state["row_var_low"] = (row_mean_low * beta_val).clamp(min=_FP32_TINY)
+state["col_var_low"] = (col_mean_low * beta_val).clamp(min=_FP32_TINY)
 else:
 state["row_var_low"].mul_(1.0 - beta_val).add_(row_mean_low, alpha=beta_val)
 state["col_var_low"].mul_(1.0 - beta_val).add_(col_mean_low, alpha=beta_val)
@@ -1213,7 +1438,7 @@ def _update_param_apollo(
 quantize = state.get("is_quantized", True)
 
 if is_first_step:
-v_init = grad_low.flatten().square().clamp(min=_FP32_TINY)
+v_init = (grad_low.flatten().square() * beta_val).clamp(min=_FP32_TINY)
 if quantize:
 q, s, sh, pad = _log_quantize_nonneg(v_init, block_size)
 state["v_low_q"], state["v_low_scale"], state["v_low_shape"], state["v_low_pad"] = q, s, sh, pad
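Both APOLLO hunks change the first-step initialization from the raw statistic to the statistic multiplied by `beta_val`. Read against the `else` branch (`mul_(1.0 - beta_val).add_(..., alpha=beta_val)`), the new value is exactly one EMA step taken from a zero state. A scalar sketch with hypothetical values:

```python
def ema_step(state: float, x: float, beta_val: float) -> float:
    """Scalar version of state.mul_(1.0 - beta_val).add_(x, alpha=beta_val)."""
    return state * (1.0 - beta_val) + x * beta_val


x0, beta_val = 4.0, 0.2  # hypothetical first-step statistic and EMA weight
old_init = x0  # 0.2.1: state starts at the raw statistic
new_init = x0 * beta_val  # 0.2.2: state starts at beta_val * statistic
from_zero = ema_step(0.0, x0, beta_val)  # one EMA step from an all-zero state

print(old_init, new_init, from_zero)  # new_init and from_zero coincide
```

The real code additionally clamps the initial value to `_FP32_TINY`. That clamp is omitted here.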
@@ -1392,4 +1617,4 @@ def _update_param_apollo(
 del update_low
 
 if needs_copy_back:
-param.copy_(param_work.view(original_shape))
+param.copy_(param_work.view(original_shape))
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adafactor8bit
-Version: 0.2.
+Version: 0.2.2
 Summary: 8-bit Adafactor Optimizer with Fused CUDA Kernels
 Home-page: https://github.com/yanfeiwong/adafactor-8bit
 Author: WANG YAN
@@ -25,6 +25,13 @@ Dynamic: requires-dist
 Dynamic: requires-python
 Dynamic: summary
 
+<p align="center">
+<a href="https://github.com/yanfeiwong/adafactor-8bit">
+<img src="https://github.com/yanfeiwong/adafactor-8bit/raw/main/assets/banner.png"
+alt="Adafactor8Bit"
+width="80%">
+</a>
+</p>
 <div align="center">
 
 # 8-bit Adafactor with Fused CUDA Kernels
@@ -39,14 +46,15 @@ Dynamic: summary
 
 </div>
 
-An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, optional APOLLO low-rank updates, and
+An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, and optional add-ons including 4-bit packed first moments, APOLLO low-rank updates, and CAME confidence-guided optimization. It delivers substantially lower optimizer memory while preserving the low-overhead and numerical stability that make Adafactor attractive for training LLMs and diffusion models.
 
 
-##
+## ⚡ Key Features
 
 - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
 - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
 - **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
+- **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
 - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
 - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
 - **Zero CPU-GPU Sync**: Eliminates implicit synchronizations (e.g., D2H copies) in the control flow, ensuring the GPU computation pipeline runs without blocking.
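A toy block-wise quantizer makes the **Log-Space Quantization** bullet concrete. This is a hypothetical sketch, not the package's `_log_quantize_nonneg`, whose code layout, scale format, and rounding are not shown in this diff. Each block stores one float scale, and each element stores a signed 8-bit code of its log2 value. Because the step is uniform in log space, every element keeps a bounded relative error. A linear absmax code of the same block, by contrast, flushes the smallest variances to zero.

```python
import math


def log_quantize_block(values, tiny=1e-30):
    """Toy log2 quantizer for one block of non-negative values -> (codes in [-127, 127], scale)."""
    logs = [math.log2(max(v, tiny)) for v in values]  # clamp first, since log2(0) is undefined
    scale = max(abs(x) for x in logs) / 127 or 1.0  # one float scale per block
    codes = [max(-127, min(127, round(x / scale))) for x in logs]
    return codes, scale


def log_dequantize_block(codes, scale):
    """Map log2 codes back to linear values."""
    return [2.0 ** (c * scale) for c in codes]


def linear_roundtrip_block(values):
    """Toy linear absmax 8-bit round trip, for comparison."""
    scale = max(values) / 127 or 1.0
    return [round(v / scale) * scale for v in values]


block = [1e-8, 3e-6, 2e-3, 0.5]  # made-up, long-tailed second-moment values
codes, scale = log_quantize_block(block)
log_rt = log_dequantize_block(codes, scale)
lin_rt = linear_roundtrip_block(block)

for v, lg, ln in zip(block, log_rt, lin_rt):
    print(f"{v:.1e}  log-space: {lg:.2e}  linear: {ln:.2e}")
```

In the optimizer diff, states are clamped to `_FP32_TINY` before `_log_quantize_nonneg` is called. That clamp plays the same role as `tiny` here.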
@@ -194,16 +202,16 @@ def get_param_groups(model, lr_emb, weight_decay, apollo_rank=256):
 "weight_decay": weight_decay,
 "quantize": True,
 "apollo_rank": apollo_rank,
-"beta1":0.9,
+"beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
 },
-
+
 # 4. >2D Weights: 8-bit quantization, Weight Decay, Full-Rank
 {
 "params": group_nd,
 "weight_decay": weight_decay,
 "quantize": True,
 "apollo_rank": 0,
-"beta1":0.9,
+"beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
 "factored": False # Disables factorization to preserve spatial structures, enabling finer gradient scaling.
 # Note: This increases state memory for >2D weights, depending on your model architecture.
 # If VRAM is constrained, reverting to factored=True is a safe alternative.
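Both `"beta1": 0.9` entries now carry a note that momentum can be removed when optimizer memory is the priority. The optimizer diff allocates the quantized first moment as `m_padded_numel // 2` `uint8` bytes plus `m_padded_numel // m_block_size` `float32` scales (default `m_block_size=128`), so its cost is easy to estimate. The sketch below assumes padding up to a whole number of blocks and a hypothetical 4096×4096 weight. The diff confirms neither detail.

```python
def packed_4bit_momentum_bytes(numel: int, m_block_size: int = 128) -> int:
    """Bytes for a 4-bit packed first moment plus one FP32 scale per block.

    Assumes numel is padded up to a whole number of blocks (hypothetical padding rule).
    """
    padded = -(-numel // m_block_size) * m_block_size  # ceil to a multiple of m_block_size
    packed = padded // 2  # two 4-bit codes per uint8 byte
    scales = (padded // m_block_size) * 4  # one float32 scale per block
    return packed + scales


numel = 4096 * 4096  # hypothetical 4096x4096 weight matrix
q4 = packed_4bit_momentum_bytes(numel)
fp32 = numel * 4  # unquantized FP32 momentum, for comparison
print(f"4-bit packed: {q4 / 2**20:.2f} MiB, FP32: {fp32 / 2**20:.2f} MiB")
```

Under these assumptions, keeping momentum costs about 13% of an FP32 momentum buffer of the same shape. Removing `beta1` frees it entirely.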
@@ -270,7 +278,40 @@ Enable the APOLLO path to compute gradient scaling factors in a memory-efficient
|
|
|
270
278
|
- **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
|
|
271
279
|
- **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
|
|
272
280
|
|
|
281
|
+
## 🛡️ CAME Confidence-Guided Updates
|
|
282
|
+
|
|
283
|
+
Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
|
|
284
|
+
|
|
285
|
+
**Adaptive Scaling ($V$) → Momentum Accumulation ($M$) → Confidence Weighting ($C$)**
|
|
286
|
+
|
|
287
|
+
### Key Parameters & Tuning
|
|
273
288
|
|
|
289
|
+
The confidence stage measures the consistency between the current update direction and historical momentum, adaptively suppressing highly oscillatory updates.
|
|
290
|
+
|
|
291
|
+
- **`beta3`**: EMA decay coefficient for the confidence matrix. Requires `beta1` (momentum) and `factored=True`. Mutually exclusive with `apollo_rank`. Defaults to `None` (disabled).
|
|
292
|
+
- **Learning Rate**: The official CAME implementation recommends **0.5–0.9×** the AdamW learning rate (see the [official tuning guide](https://github.com/yangluo7/CAME/tree/master#hyper-parameter-tuning)). To use a learning rate in this range with this library, disable Adafactor's parameter scaling and clipping (`scale_parameter=False`, `d=1e9`) so that the update matches the original CAME behavior.
|
|
293
|
+
- **Warmup**: Since the confidence matrix is zero-initialized without bias correction, a learning rate warmup is recommended to safely establish the confidence baseline.
|
|
294
|
+
- **Choosing `beta3`**: `beta3` should generally be larger than `beta2` so the confidence estimate evolves more slowly than the variance estimate. A practical starting range is **0.9995–0.99995** when `beta2=0.999`.
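The warmup and `beta3` advice above follows from how a zero-initialized EMA fills up. An EMA with decay β averages over roughly 1/(1 − β) recent steps, and without bias correction it holds only a fraction 1 − β^t of its steady-state value after t steps. The helper names below are hypothetical; the step-inflation estimate in the last line is a rough approximation that assumes a stationary instability signal, using the fact that the CAME update divides by the square root of the confidence estimate.

```python
def ema_window(beta):
    """Rough number of recent steps an EMA with decay `beta` averages over."""
    return 1.0 / (1.0 - beta)


def ema_fill(beta, step):
    """Fraction of its steady-state value a zero-initialized EMA has reached
    after `step` updates when no bias correction is applied: 1 - beta**step."""
    return 1.0 - beta ** step


def linear_warmup_lr(step, warmup_steps, base_lr):
    """Linear warmup (steps counted from 1), constant at base_lr afterwards."""
    return base_lr * min(1.0, step / warmup_steps)


print(ema_window(0.999), ema_window(0.9999))  # ~1000 vs ~10000 steps
print(ema_fill(0.9999, 1000))                  # ~0.095: C is still mostly unfilled
print(ema_fill(0.9999, 1000) ** -0.5)          # ~3.2x larger steps, roughly
```

With `beta3=0.9999`, the confidence estimate is still far from its baseline after the first thousand steps, which is why a warmup that keeps early learning rates small is a sensible safeguard.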
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
### Configuration Example
|
|
298
|
+
|
|
299
|
+
To replicate "vanilla" CAME without Adafactor's native modifications, replace the standard 2D APOLLO group in your `param_groups` with the following configuration:
|
|
300
|
+
|
|
301
|
+
```python
|
|
302
|
+
{
|
|
303
|
+
"params": param_group,
|
|
304
|
+
"lr": lr, # Original CAME recommends 0.5-0.9x AdamW LR
|
|
305
|
+
"weight_decay": weight_decay,
|
|
306
|
+
"quantize": True,
|
|
307
|
+
"beta1": 0.9,
|
|
308
|
+
"beta3": 0.9999, # Enable CAME confidence guidance
|
|
309
|
+
"apollo_rank": 0, # Mutually exclusive with CAME
|
|
310
|
+
"scale_parameter": False, # Disable Adafactor RMS scaling to align with vanilla CAME
|
|
311
|
+
"d": 1e9, # Disable Adafactor global RMS clipping
|
|
312
|
+
"enable_fira_for_adafactor": False, # Disable Fira Limiter to prevent interference with CAME's scaling
|
|
313
|
+
},
|
|
314
|
+
```
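The constraints listed under Key Parameters & Tuning can be made explicit with a small validation helper before the group is handed to the optimizer. `check_came_group` and `came_lr_from_adamw` are hypothetical names, not part of this library's API; treating a missing `factored` key as `True` is an assumption, and the library's own checks may differ.

```python
import warnings


def came_lr_from_adamw(adamw_lr, ratio=0.7):
    """Scale an AdamW learning rate into the 0.5-0.9x band suggested for CAME."""
    if not 0.5 <= ratio <= 0.9:
        raise ValueError("ratio outside the suggested 0.5-0.9 band")
    return ratio * adamw_lr


def check_came_group(group):
    """Validate a param-group dict against the CAME constraints (hypothetical helper)."""
    if group.get("beta3") is None:
        return  # CAME path disabled; nothing to check
    if group.get("beta1") is None:
        raise ValueError("beta3 requires beta1 (momentum)")
    if not group.get("factored", True):  # assumption: factored defaults to True
        raise ValueError("beta3 requires factored=True")
    if group.get("apollo_rank", 0):
        raise ValueError("beta3 is mutually exclusive with apollo_rank > 0")
    beta2 = group.get("beta2")
    if isinstance(beta2, float) and group["beta3"] <= beta2:
        warnings.warn("beta3 is usually chosen larger than beta2")


group = {
    "params": [],                    # placeholder: your 2D weights go here
    "lr": came_lr_from_adamw(1e-4),  # 0.7x an AdamW LR of 1e-4
    "beta1": 0.9,
    "beta3": 0.9999,
    "apollo_rank": 0,
    "scale_parameter": False,
    "d": 1e9,
}
check_came_group(group)  # passes silently
```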
|
|
274
315
|
|
|
275
316
|
## 📈 Learning Rate Guide for Beginners
|
|
276
317
|
|
|
@@ -298,16 +339,18 @@ Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the appr
|
|
|
298
339
|
|
|
299
340
|
Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
|
|
300
341
|
|
|
342
|
+
Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
|
|
343
|
+
|
|
301
344
|
Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
|
|
302
345
|
|
|
303
346
|
Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
|
|
304
347
|
|
|
348
|
+
## 🏛️ License
|
|
349
|
+
|
|
350
|
+
[The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
|
|
351
|
+
|
|
305
352
|
## ⭐ Star the Project
|
|
306
353
|
|
|
307
354
|
If this optimizer has been useful in your work, consider giving the repository a star. It helps others discover the project and supports future development.
|
|
308
355
|
|
|
309
356
|
[Star History Chart](https://star-history.com/#yanfeiwong/adafactor-8bit&Date)
|
|
310
|
-
|
|
311
|
-
## 📄 License
|
|
312
|
-
|
|
313
|
-
[The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
|
|
@@ -9,7 +9,7 @@ long_description = (this_directory / "README.md").read_text(encoding="utf-8")
|
|
|
9
9
|
|
|
10
10
|
setup(
|
|
11
11
|
name="adafactor8bit",
|
|
12
|
-
version="0.2.
|
|
12
|
+
version="0.2.2",
|
|
13
13
|
description="8-bit Adafactor Optimizer with Fused CUDA Kernels",
|
|
14
14
|
author="WANG YAN",
|
|
15
15
|
author_email="yanfeiwong1997@outlook.com",
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|