adafactor8bit-0.2.0.tar.gz → adafactor8bit-0.2.2.tar.gz
- {adafactor8bit-0.2.0/adafactor8bit.egg-info → adafactor8bit-0.2.2}/PKG-INFO +97 -29
- {adafactor8bit-0.2.0 → adafactor8bit-0.2.2}/README.md +96 -28
- adafactor8bit-0.2.2/adafactor8bit/kernels.cu +1217 -0
- adafactor8bit-0.2.2/adafactor8bit/optimizer.py +1620 -0
- {adafactor8bit-0.2.0 → adafactor8bit-0.2.2/adafactor8bit.egg-info}/PKG-INFO +97 -29
- {adafactor8bit-0.2.0 → adafactor8bit-0.2.2}/setup.py +1 -1
- adafactor8bit-0.2.0/adafactor8bit/kernels.cu +0 -377
- adafactor8bit-0.2.0/adafactor8bit/optimizer.py +0 -919
- {adafactor8bit-0.2.0 → adafactor8bit-0.2.2}/LICENSE +0 -0
- {adafactor8bit-0.2.0 → adafactor8bit-0.2.2}/MANIFEST.in +0 -0
- {adafactor8bit-0.2.0 → adafactor8bit-0.2.2}/adafactor8bit/__init__.py +0 -0
- {adafactor8bit-0.2.0 → adafactor8bit-0.2.2}/adafactor8bit.egg-info/SOURCES.txt +0 -0
- {adafactor8bit-0.2.0 → adafactor8bit-0.2.2}/adafactor8bit.egg-info/dependency_links.txt +0 -0
- {adafactor8bit-0.2.0 → adafactor8bit-0.2.2}/adafactor8bit.egg-info/requires.txt +0 -0
- {adafactor8bit-0.2.0 → adafactor8bit-0.2.2}/adafactor8bit.egg-info/top_level.txt +0 -0
- {adafactor8bit-0.2.0 → adafactor8bit-0.2.2}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: adafactor8bit
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.2
|
|
4
4
|
Summary: 8-bit Adafactor Optimizer with Fused CUDA Kernels
|
|
5
5
|
Home-page: https://github.com/yanfeiwong/adafactor-8bit
|
|
6
6
|
Author: WANG YAN
|
|
@@ -25,6 +25,13 @@ Dynamic: requires-dist
|
|
|
25
25
|
Dynamic: requires-python
|
|
26
26
|
Dynamic: summary
|
|
27
27
|
|
|
28
|
+
<p align="center">
|
|
29
|
+
<a href="https://github.com/yanfeiwong/adafactor-8bit">
|
|
30
|
+
<img src="https://github.com/yanfeiwong/adafactor-8bit/raw/main/assets/banner.png"
|
|
31
|
+
alt="Adafactor8Bit"
|
|
32
|
+
width="80%">
|
|
33
|
+
</a>
|
|
34
|
+
</p>
|
|
28
35
|
<div align="center">
|
|
29
36
|
|
|
30
37
|
# 8-bit Adafactor with Fused CUDA Kernels
|
|
@@ -39,25 +46,27 @@ Dynamic: summary
|
|
|
39
46
|
|
|
40
47
|
</div>
|
|
41
48
|
|
|
42
|
-
An 8-bit Adafactor optimizer featuring fused CUDA kernels
|
|
49
|
+
An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, and optional add-ons including 4-bit packed first moments, APOLLO low-rank updates, and CAME confidence-guided optimization. It delivers substantially lower optimizer memory while preserving the low overhead and numerical stability that make Adafactor attractive for training LLMs and diffusion models.
|
|
43
50
|
|
|
44
51
|
|
|
45
|
-
## Key Features
|
|
52
|
+
## ⚡ Key Features
|
|
46
53
|
|
|
47
54
|
- **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
|
|
48
55
|
- **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
|
|
56
|
+
- **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
|
|
57
|
+
- **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
|
|
49
58
|
- **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
|
|
50
59
|
- **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
|
|
51
60
|
- **Zero CPU-GPU Sync**: Eliminates implicit synchronizations (e.g., D2H copies) in the control flow, ensuring the GPU computation pipeline runs without blocking.
|
|
52
61
|
- **Cross-Platform JIT**: Uses Just-In-Time (JIT) compilation for straightforward setup across both Windows and Linux environments.
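
The log-space quantization described in the feature list can be illustrated with a minimal pure-Python sketch. It is not the library's kernel: the clamp range (`LOG2_MIN`, `LOG2_MAX`), the single global range instead of per-block ranges, and all function names are assumptions made for illustration.

```python
import math

# Assumed clamp range for log2(v); the real kernels may use per-block ranges.
LOG2_MIN = -40.0
LOG2_MAX = 8.0


def quantize_log2(v: float) -> int:
    """Map a non-negative second-moment value to a uint8 code in log2 space."""
    log_v = math.log2(max(v, 2.0 ** LOG2_MIN))  # clamp tiny values instead of log2(0)
    log_v = min(log_v, LOG2_MAX)
    scaled = (log_v - LOG2_MIN) / (LOG2_MAX - LOG2_MIN)  # position in [0, 1]
    return round(scaled * 255)


def dequantize_log2(code: int) -> float:
    """Invert quantize_log2 up to rounding error."""
    log_v = LOG2_MIN + (code / 255) * (LOG2_MAX - LOG2_MIN)
    return 2.0 ** log_v


def quantize_linear(v: float, v_max: float) -> int:
    """Plain linear uint8 quantization, shown only for comparison."""
    return round(min(v / v_max, 1.0) * 255)
```

With these assumed ranges, a variance of `1e-8` collapses to code `0` under linear quantization against `v_max=1.0`, while the log2 mapping keeps it at a non-zero code and recovers it to within a few percent.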
|
|
53
62
|
|
|
54
|
-
## Performance
|
|
63
|
+
## 📊 Performance
|
|
55
64
|
|
|
56
|
-
- **Memory Footprint**: Due to Adafactor's factorized second-moment estimation
|
|
65
|
+
- **Memory Footprint**: Due to Adafactor's factorized second-moment estimation, 8-bit quantization, and optional 4-bit packed first moments, the optimizer typically consumes substantially less memory than `AdamW8Bit`.
|
|
57
66
|
- **Training Speed**: The fused kernel design and reduced synchronization overhead allow it to achieve step times comparable to other mainstream 8-bit optimizers.
|
|
58
67
|
- **Quantization Precision**: The second moment (variance) in Adafactor is strictly non-negative and spans multiple orders of magnitude. By mapping it to `UINT8` in log2 space rather than linear space, the optimizer preserves relative precision for small variances, mitigating the instability often caused by outlier gradients in standard 8-bit quantization.
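
As a rough back-of-the-envelope illustration of the memory claim above, the sketch below counts optimizer-state bytes for a single 2D weight. It assumes one byte per 8-bit value and half a byte per packed 4-bit value, and it ignores quantization metadata such as per-block scales, so the numbers are estimates rather than measurements. The function names are hypothetical.

```python
def adamw8bit_state_bytes(rows: int, cols: int) -> int:
    """Two 8-bit states (first and second moment) per parameter."""
    return 2 * rows * cols


def adafactor8bit_state_bytes(rows: int, cols: int, packed_first_moment: bool = False) -> int:
    """Factored 8-bit second moment (one row vector plus one column vector),
    with an optional 4-bit first moment packed two values per byte."""
    second_moment = rows + cols
    first_moment = (rows * cols + 1) // 2 if packed_first_moment else 0
    return second_moment + first_moment
```

For a 4096 x 4096 weight, this estimate gives about 32 MiB of state for the two-moment 8-bit layout, about 8 KiB for the factored second moment alone, and about 8 MiB once the packed 4-bit first moment is enabled.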
|
|
59
68
|
|
|
60
|
-
## Installation
|
|
69
|
+
## 📦 Installation
|
|
61
70
|
|
|
62
71
|
This project uses JIT (Just-In-Time) compilation.
|
|
63
72
|
|
|
@@ -77,9 +86,12 @@ pip install -U adafactor8bit
|
|
|
77
86
|
pip install git+https://github.com/yanfeiwong/adafactor-8bit.git
|
|
78
87
|
```
|
|
79
88
|
|
|
80
|
-
|
|
89
|
+
> [!IMPORTANT]
|
|
90
|
+
> **First-Time Compilation**: The first time you instantiate the optimizer (or run the example script), it automatically JIT-compiles the CUDA source code. This may take anywhere from a few seconds to a couple of minutes depending on your system, and the terminal might appear unresponsive in the meantime. Once compiled, the binary is cached, so subsequent runs skip compilation.
|
|
81
91
|
|
|
82
|
-
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
## 🚀 Quick Start
|
|
83
95
|
|
|
84
96
|
Using it is as simple as using a standard PyTorch optimizer.
|
|
85
97
|
|
|
@@ -89,7 +101,8 @@ from adafactor8bit import Adafactor8Bit
|
|
|
89
101
|
optimizer = Adafactor8Bit(model.parameters(), lr=1e-3)
|
|
90
102
|
```
|
|
91
103
|
|
|
92
|
-
|
|
104
|
+
> [!TIP]
|
|
105
|
+
> Passing `model.parameters()` directly works for a quick test. In production, `param_groups` are recommended to protect sensitive layers (Norms, Biases) from quantization and weight decay. For **sparse token embeddings** (large vocabularies + small batch sizes), please refer to the [Advanced Example](#-advanced-example) to avoid cold-start variance explosion.
|
|
93
106
|
|
|
94
107
|
|
|
95
108
|
```python
|
|
@@ -123,15 +136,20 @@ optimizer = Adafactor8Bit(
|
|
|
123
136
|
# Training loop...
|
|
124
137
|
```
|
|
125
138
|
|
|
126
|
-
## Advanced Example
|
|
139
|
+
## 🛠️ Advanced Example
|
|
127
140
|
|
|
128
|
-
Here we demonstrate a **hybrid grouping** strategy for complex hybrid architectures (e.g., Vision-Language Models, Diffusion UNets) to achieve stable and efficient
|
|
141
|
+
Here we demonstrate a **hybrid grouping** strategy for complex hybrid architectures (e.g., Vision-Language Models, Diffusion UNets) to achieve stable and efficient training.
|
|
129
142
|
|
|
130
143
|
📌 **The following strategies are applied:**
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
144
|
+
| Layer Type | Strategy |
|
|
145
|
+
|------------|----------|
|
|
146
|
+
| **1D / Sensitive Parameters** (Norms, Biases) | No quantization, no weight decay |
|
|
147
|
+
| **Embedding Layers** | `factored=False`, `scale_parameter=False`, `d=1e9` → Momentum-free Adam. Paired with an Adam-style learning rate, this allows for fine-grained, per-token updates while avoiding cold-token interference. |
|
|
148
|
+
| **2D Weights** (Linear Layers) | 8-bit quantization, weight decay, **APOLLO** path. Continuously switching random subspace projection captures comprehensive gradient information and acts as a regularizer. |
|
|
149
|
+
| **>2D Weights** (Conv2d, etc.) | 8-bit quantization, weight decay, **Full-Rank** (`factored=False`). Trades some VRAM to preserve complete spatial structures. |
|
|
150
|
+
| **Momentum (`beta1`)** | Enabled only for dense weight matrices, where the optimization benefit typically outweighs the small memory overhead of the packed 4-bit first moment. Sensitive parameters (Norms/Biases) and sparse Embeddings remain momentum-free. |
|
|
151
|
+
|
|
152
|
+
**Implementation:**
|
|
135
153
|
|
|
136
154
|
```python
|
|
137
155
|
from adafactor8bit import Adafactor8Bit
|
|
@@ -179,14 +197,21 @@ def get_param_groups(model, lr_emb, weight_decay, apollo_rank=256):
|
|
|
179
197
|
},
|
|
180
198
|
|
|
181
199
|
# 3. 2D Weights: 8-bit quantization, Weight Decay, APOLLO low-rank projection
|
|
182
|
-
{
|
|
183
|
-
|
|
200
|
+
{
|
|
201
|
+
"params": group_2d,
|
|
202
|
+
"weight_decay": weight_decay,
|
|
203
|
+
"quantize": True,
|
|
204
|
+
"apollo_rank": apollo_rank,
|
|
205
|
+
"beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
|
|
206
|
+
},
|
|
207
|
+
|
|
184
208
|
# 4. >2D Weights: 8-bit quantization, Weight Decay, Full-Rank
|
|
185
209
|
{
|
|
186
210
|
"params": group_nd,
|
|
187
211
|
"weight_decay": weight_decay,
|
|
188
212
|
"quantize": True,
|
|
189
213
|
"apollo_rank": 0,
|
|
214
|
+
"beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
|
|
190
215
|
"factored": False # Disables factorization to preserve spatial structures, enabling finer gradient scaling.
|
|
191
216
|
# Note: This increases state memory for >2D weights, depending on your model architecture.
|
|
192
217
|
# If VRAM is constrained, reverting to factored=True is a safe alternative.
|
|
@@ -206,10 +231,11 @@ optimizer = Adafactor8Bit(
|
|
|
206
231
|
# Training loop...
|
|
207
232
|
```
|
|
208
233
|
|
|
209
|
-
|
|
234
|
+
> [!NOTE]
|
|
235
|
+
> For more complete examples, please refer to the [examples folder](https://github.com/yanfeiwong/adafactor-8bit/tree/main/examples).
|
|
210
236
|
|
|
211
237
|
|
|
212
|
-
## Advanced Configuration
|
|
238
|
+
## ⚙️ Advanced Configuration
|
|
213
239
|
|
|
214
240
|
### Continual Learning (`beta2` & `relative_step`)
|
|
215
241
|
By default, Adafactor's second-moment decay rate is not fixed but varies with the training step, and the internal learning rate schedule (`relative_step`) scales the learning rate accordingly.
|
|
@@ -237,19 +263,57 @@ By default, Adafactor factorizes the second moment of $\ge$ 2D tensors into row
|
|
|
237
263
|
If you are in an environment without a CUDA compiler and want to bypass JIT compilation entirely:
|
|
238
264
|
- Set `use_cuda_kernel=False` to fall back to the pure PyTorch implementation.
|
|
239
265
|
|
|
240
|
-
## APOLLO Low-Rank Subspace Projection
|
|
266
|
+
## 🌌 APOLLO Low-Rank Subspace Projection
|
|
241
267
|
Enable the APOLLO path to compute gradient scaling factors in a memory-efficient low-rank subspace. Compared to Adafactor's standard row/column factorization (which assumes spatial independence), APOLLO uses random subspace projection to capture cross-dimensional covariance information, potentially leading to better generalization while keeping memory overhead extremely low.
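
The norm-preserving property that random subspace projection relies on can be checked with a small pure-Python sketch. It only demonstrates that a scaled Gaussian projection roughly preserves per-channel gradient norms; it is not the library's APOLLO implementation, and the shapes, seed, and helper names are assumptions.

```python
import math
import random


def random_projection(rank, dim, rng):
    """Gaussian matrix of shape (rank, dim), scaled so E[||P x||^2] = ||x||^2."""
    scale = 1.0 / math.sqrt(rank)
    return [[rng.gauss(0.0, 1.0) * scale for _ in range(dim)] for _ in range(rank)]


def project(P, G):
    """Return P @ G for G of shape (dim, channels)."""
    channels = len(G[0])
    return [
        [sum(p_row[k] * G[k][c] for k in range(len(G))) for c in range(channels)]
        for p_row in P
    ]


def column_norms(M):
    """L2 norm of every column (channel) of M."""
    return [math.sqrt(sum(row[c] ** 2 for row in M)) for c in range(len(M[0]))]


rng = random.Random(0)
dim, channels, rank = 512, 4, 128  # project 512 rows down to a 128-dimensional subspace
G = [[rng.gauss(0.0, 1.0) for _ in range(channels)] for _ in range(dim)]
P = random_projection(rank, dim, rng)
ratios = [low / full for low, full in zip(column_norms(project(P, G)), column_norms(G))]
```

Each entry of `ratios` should be close to 1, and the spread widens as `rank` shrinks, which is consistent with the noisier scaling factors reported at very small ranks.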
|
|
242
268
|
|
|
243
|
-
- **`apollo_rank`**: The target rank for the projection subspace. The default is `0` (disabled).
|
|
244
|
-
|
|
269
|
+
- **`apollo_rank`**: The target rank for the projection subspace. The default is `0` (disabled).
|
|
270
|
+
|
|
271
|
+
- The official APOLLO GitHub repository recommends a rank of `256` for 1B and 7B models.
|
|
272
|
+
- The [LLaMA-Factory](https://llamafactory.readthedocs.io/en/latest/advanced/arguments.html#apollo) default is `16`.
|
|
273
|
+
- Setting this to `1` (APOLLO-Mini style) pushes VRAM savings to the limit (saving even more VRAM than the Adafactor path). The original APOLLO-Mini relies on the first moment (`beta1`) to smooth out projection noise. To replicate this, set `beta1=0.9` alongside `apollo_rank=1`. Without `beta1`, `apollo_rank=1` may still work but can exhibit noisier scaling factors, especially at small batch sizes.
|
|
274
|
+
|
|
275
|
+
|
|
245
276
|
- **`apollo_scale_type`**: Determines how the scaling factor is applied. `'channel'` applies it per channel (Standard APOLLO), while `'tensor'` applies it globally (APOLLO-Mini).
|
|
246
277
|
- **`apollo_update_proj_gap`**: Steps between projection matrix refreshes. Defaults to `200`. Setting this too small may cause frequent oscillations due to abrupt basis mutations, while setting it too large might cause the projection space to become stale and fail to track the drift of the gradient manifold.
|
|
247
278
|
- **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
|
|
248
279
|
- **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
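
A norm-growth limiter of the kind described in the last bullet can be sketched as follows. This is a simplified illustration of the general idea (cap the ratio between consecutive update norms), not the library's code. The `gamma` threshold and its default are illustrative, and how it corresponds to `fira_margin` is not specified here.

```python
import math


def limit_norm_growth(update, prev_norm, gamma=1.01):
    """Rescale `update` if its L2 norm grew by more than a factor `gamma`
    relative to the previous step. Returns (update, norm_to_remember)."""
    norm = math.sqrt(sum(x * x for x in update))
    if prev_norm is not None and prev_norm > 0.0 and norm > gamma * prev_norm:
        factor = gamma * prev_norm / norm  # shrink back to the allowed norm
        update = [x * factor for x in update]
        norm = gamma * prev_norm
    return update, norm
```

In a training loop, the returned norm would be stored per parameter and passed back as `prev_norm` on the next step; the first step passes `None` and is never limited.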
|
|
249
280
|
|
|
281
|
+
## 🛡️ CAME Confidence-Guided Updates
|
|
282
|
+
|
|
283
|
+
Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
|
|
284
|
+
|
|
285
|
+
**Adaptive Scaling ($V$) → Momentum Accumulation ($M$) → Confidence Weighting ($C$)**
|
|
286
|
+
|
|
287
|
+
### Key Parameters & Tuning
|
|
250
288
|
|
|
289
|
+
The confidence stage measures the consistency between the current update direction and historical momentum, adaptively suppressing highly oscillatory updates.
|
|
251
290
|
|
|
252
|
-
|
|
291
|
+
- **`beta3`**: EMA decay coefficient for the confidence matrix. Requires `beta1` (momentum) and `factored=True`. Mutually exclusive with `apollo_rank`. Defaults to `None` (disabled).
|
|
292
|
+
- **Learning Rate**: The official CAME implementation recommends **0.5–0.9×** the AdamW learning rate (see [official tuning guide](https://github.com/yangluo7/CAME/tree/master#hyper-parameter-tuning)). To use this learning rate in this library, you need to disable Adafactor's scaling and clipping (`scale_parameter=False`, `d=1e9`) to align with the original CAME behavior.
|
|
293
|
+
- **Warmup**: Since the confidence matrix is zero-initialized without bias correction, a learning rate warmup is recommended to safely establish the confidence baseline.
|
|
294
|
+
- **Choosing `beta3`**: `beta3` should generally be larger than `beta2` so the confidence estimate evolves more slowly than the variance estimate. A practical starting range is **0.9995–0.99995** when `beta2=0.999`.
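
The guidance on `beta3` and warmup can be made concrete with two standard EMA rules of thumb. The helper names are hypothetical, and the second function assumes a constant input signal, so treat the numbers as rough intuition only.

```python
def ema_horizon(beta: float) -> float:
    """Approximate number of steps an EMA with decay `beta` averages over."""
    return 1.0 / (1.0 - beta)


def ema_fill_fraction(beta: float, steps: int) -> float:
    """Fraction of a constant signal that a zero-initialized EMA without
    bias correction has absorbed after `steps` updates."""
    return 1.0 - beta ** steps
```

With `beta2=0.999` the variance estimate averages over roughly 1,000 steps, while `beta3=0.9999` averages over roughly 10,000. After 1,000 steps a zero-initialized EMA with `beta3=0.9999` has only absorbed about 10% of a steady signal, which is one way to see why a warmup is recommended.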
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
### Configuration Example
|
|
298
|
+
|
|
299
|
+
To replicate "vanilla" CAME (stripping Adafactor's native modifications), replace the standard 2D APOLLO group in your `param_groups` with the following configuration:
|
|
300
|
+
|
|
301
|
+
```python
|
|
302
|
+
{
|
|
303
|
+
"params": param_group,
|
|
304
|
+
"lr": lr, # Original CAME recommends 0.5-0.9x AdamW LR
|
|
305
|
+
"weight_decay": weight_decay,
|
|
306
|
+
"quantize": True,
|
|
307
|
+
"beta1": 0.9,
|
|
308
|
+
"beta3": 0.9999, # Enable CAME confidence guidance
|
|
309
|
+
"apollo_rank": 0, # Mutually exclusive with CAME
|
|
310
|
+
"scale_parameter": False, # Disable Adafactor RMS scaling to align with vanilla CAME
|
|
311
|
+
"d": 1e9, # Disable Adafactor global RMS clipping
|
|
312
|
+
"enable_fira_for_adafactor": False, # Disable Fira Limiter to prevent interference with CAME's scaling
|
|
313
|
+
},
|
|
314
|
+
```
|
|
315
|
+
|
|
316
|
+
## 📈 Learning Rate Guide for Beginners
|
|
253
317
|
|
|
254
318
|
If you are migrating from optimizers like AdamW, Adafactor's learning rate behavior might feel a bit different. This is mainly due to the `scale_parameter` option.
|
|
255
319
|
|
|
@@ -265,7 +329,7 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
|
|
|
265
329
|
|
|
266
330
|
|
|
267
331
|
|
|
268
|
-
## Acknowledgements
|
|
332
|
+
## 🎓 Acknowledgements
|
|
269
333
|
|
|
270
334
|
Thanks to **Noam Shazeer** and **Mitchell Stern** for proposing the original Adafactor algorithm in the paper [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235).
|
|
271
335
|
|
|
@@ -275,14 +339,18 @@ Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the appr
|
|
|
275
339
|
|
|
276
340
|
Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
|
|
277
341
|
|
|
342
|
+
Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
|
|
343
|
+
|
|
278
344
|
Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
|
|
279
345
|
|
|
280
|
-
Thanks to the large language models **Qwen** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization
|
|
346
|
+
Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
|
|
281
347
|
|
|
282
|
-
##
|
|
348
|
+
## 🏛️ License
|
|
283
349
|
|
|
284
|
-
[
|
|
350
|
+
[The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
|
|
285
351
|
|
|
286
|
-
##
|
|
352
|
+
## ⭐ Star the Project
|
|
287
353
|
|
|
288
|
-
|
|
354
|
+
If this optimizer has been useful in your work, consider giving the repository a star. It helps others discover the project and supports future development.
|
|
355
|
+
|
|
356
|
+
[Star history chart](https://star-history.com/#yanfeiwong/adafactor-8bit&Date)
|
|
@@ -1,3 +1,10 @@
|
|
|
1
|
+
<p align="center">
|
|
2
|
+
<a href="https://github.com/yanfeiwong/adafactor-8bit">
|
|
3
|
+
<img src="https://github.com/yanfeiwong/adafactor-8bit/raw/main/assets/banner.png"
|
|
4
|
+
alt="Adafactor8Bit"
|
|
5
|
+
width="80%">
|
|
6
|
+
</a>
|
|
7
|
+
</p>
|
|
1
8
|
<div align="center">
|
|
2
9
|
|
|
3
10
|
# 8-bit Adafactor with Fused CUDA Kernels
|
|
@@ -12,25 +19,27 @@
|
|
|
12
19
|
|
|
13
20
|
</div>
|
|
14
21
|
|
|
15
|
-
An 8-bit Adafactor optimizer featuring fused CUDA kernels
|
|
22
|
+
An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, and optional add-ons including 4-bit packed first moments, APOLLO low-rank updates, and CAME confidence-guided optimization. It delivers substantially lower optimizer memory while preserving the low overhead and numerical stability that make Adafactor attractive for training LLMs and diffusion models.
|
|
16
23
|
|
|
17
24
|
|
|
18
|
-
## Key Features
|
|
25
|
+
## ⚡ Key Features
|
|
19
26
|
|
|
20
27
|
- **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
|
|
21
28
|
- **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
|
|
29
|
+
- **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
|
|
30
|
+
- **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
|
|
22
31
|
- **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
|
|
23
32
|
- **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
|
|
24
33
|
- **Zero CPU-GPU Sync**: Eliminates implicit synchronizations (e.g., D2H copies) in the control flow, ensuring the GPU computation pipeline runs without blocking.
|
|
25
34
|
- **Cross-Platform JIT**: Uses Just-In-Time (JIT) compilation for straightforward setup across both Windows and Linux environments.
|
|
26
35
|
|
|
27
|
-
## Performance
|
|
36
|
+
## 📊 Performance
|
|
28
37
|
|
|
29
|
-
- **Memory Footprint**: Due to Adafactor's factorized second-moment estimation
|
|
38
|
+
- **Memory Footprint**: Due to Adafactor's factorized second-moment estimation, 8-bit quantization, and optional 4-bit packed first moments, the optimizer typically consumes substantially less memory than `AdamW8Bit`.
|
|
30
39
|
- **Training Speed**: The fused kernel design and reduced synchronization overhead allow it to achieve step times comparable to other mainstream 8-bit optimizers.
|
|
31
40
|
- **Quantization Precision**: The second moment (variance) in Adafactor is strictly non-negative and spans multiple orders of magnitude. By mapping it to `UINT8` in log2 space rather than linear space, the optimizer preserves relative precision for small variances, mitigating the instability often caused by outlier gradients in standard 8-bit quantization.
|
|
32
41
|
|
|
33
|
-
## Installation
|
|
42
|
+
## 📦 Installation
|
|
34
43
|
|
|
35
44
|
This project uses JIT (Just-In-Time) compilation.
|
|
36
45
|
|
|
@@ -50,9 +59,12 @@ pip install -U adafactor8bit
|
|
|
50
59
|
pip install git+https://github.com/yanfeiwong/adafactor-8bit.git
|
|
51
60
|
```
|
|
52
61
|
|
|
53
|
-
|
|
62
|
+
> [!IMPORTANT]
|
|
63
|
+
> **First-Time Compilation**: The first time you instantiate the optimizer (or run the example script), it automatically JIT-compiles the CUDA source code. This may take anywhere from a few seconds to a couple of minutes depending on your system, and the terminal might appear unresponsive in the meantime. Once compiled, the binary is cached, so subsequent runs skip compilation.
|
|
54
64
|
|
|
55
|
-
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
## 🚀 Quick Start
|
|
56
68
|
|
|
57
69
|
Using it is as simple as using a standard PyTorch optimizer.
|
|
58
70
|
|
|
@@ -62,7 +74,8 @@ from adafactor8bit import Adafactor8Bit
|
|
|
62
74
|
optimizer = Adafactor8Bit(model.parameters(), lr=1e-3)
|
|
63
75
|
```
|
|
64
76
|
|
|
65
|
-
|
|
77
|
+
> [!TIP]
|
|
78
|
+
> Passing `model.parameters()` directly works for a quick test. In production, `param_groups` are recommended to protect sensitive layers (Norms, Biases) from quantization and weight decay. For **sparse token embeddings** (large vocabularies + small batch sizes), please refer to the [Advanced Example](#-advanced-example) to avoid cold-start variance explosion.
|
|
66
79
|
|
|
67
80
|
|
|
68
81
|
```python
|
|
@@ -96,15 +109,20 @@ optimizer = Adafactor8Bit(
|
|
|
96
109
|
# Training loop...
|
|
97
110
|
```
|
|
98
111
|
|
|
99
|
-
## Advanced Example
|
|
112
|
+
## 🛠️ Advanced Example
|
|
100
113
|
|
|
101
|
-
Here we demonstrate a **hybrid grouping** strategy for complex hybrid architectures (e.g., Vision-Language Models, Diffusion UNets) to achieve stable and efficient
|
|
114
|
+
Here we demonstrate a **hybrid grouping** strategy for complex hybrid architectures (e.g., Vision-Language Models, Diffusion UNets) to achieve stable and efficient training.
|
|
102
115
|
|
|
103
116
|
📌 **The following strategies are applied:**
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
117
|
+
| Layer Type | Strategy |
|
|
118
|
+
|------------|----------|
|
|
119
|
+
| **1D / Sensitive Parameters** (Norms, Biases) | No quantization, no weight decay |
|
|
120
|
+
| **Embedding Layers** | `factored=False`, `scale_parameter=False`, `d=1e9` → Momentum-free Adam. Paired with an Adam-style learning rate, this allows for fine-grained, per-token updates while avoiding cold-token interference. |
|
|
121
|
+
| **2D Weights** (Linear Layers) | 8-bit quantization, weight decay, **APOLLO** path. Continuously switching random subspace projection captures comprehensive gradient information and acts as a regularizer. |
|
|
122
|
+
| **>2D Weights** (Conv2d, etc.) | 8-bit quantization, weight decay, **Full-Rank** (`factored=False`). Trades some VRAM to preserve complete spatial structures. |
|
|
123
|
+
| **Momentum (`beta1`)** | Enabled only for dense weight matrices, where the optimization benefit typically outweighs the small memory overhead of the packed 4-bit first moment. Sensitive parameters (Norms/Biases) and sparse Embeddings remain momentum-free. |
|
|
124
|
+
|
|
125
|
+
**Implementation:**
|
|
108
126
|
|
|
109
127
|
```python
|
|
110
128
|
from adafactor8bit import Adafactor8Bit
|
|
@@ -152,14 +170,21 @@ def get_param_groups(model, lr_emb, weight_decay, apollo_rank=256):
|
|
|
152
170
|
},
|
|
153
171
|
|
|
154
172
|
# 3. 2D Weights: 8-bit quantization, Weight Decay, APOLLO low-rank projection
|
|
155
|
-
{
|
|
156
|
-
|
|
173
|
+
{
|
|
174
|
+
"params": group_2d,
|
|
175
|
+
"weight_decay": weight_decay,
|
|
176
|
+
"quantize": True,
|
|
177
|
+
"apollo_rank": apollo_rank,
|
|
178
|
+
"beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
|
|
179
|
+
},
|
|
180
|
+
|
|
157
181
|
# 4. >2D Weights: 8-bit quantization, Weight Decay, Full-Rank
|
|
158
182
|
{
|
|
159
183
|
"params": group_nd,
|
|
160
184
|
"weight_decay": weight_decay,
|
|
161
185
|
"quantize": True,
|
|
162
186
|
"apollo_rank": 0,
|
|
187
|
+
"beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
|
|
163
188
|
"factored": False # Disables factorization to preserve spatial structures, enabling finer gradient scaling.
|
|
164
189
|
# Note: This increases state memory for >2D weights, depending on your model architecture.
|
|
165
190
|
# If VRAM is constrained, reverting to factored=True is a safe alternative.
|
|
@@ -179,10 +204,11 @@ optimizer = Adafactor8Bit(
|
|
|
179
204
|
# Training loop...
|
|
180
205
|
```
|
|
181
206
|
|
|
182
|
-
|
|
207
|
+
> [!NOTE]
|
|
208
|
+
> For more complete examples, please refer to the [examples folder](https://github.com/yanfeiwong/adafactor-8bit/tree/main/examples).
|
|
183
209
|
|
|
184
210
|
|
|
185
|
-
## Advanced Configuration
|
|
211
|
+
## ⚙️ Advanced Configuration
|
|
186
212
|
|
|
187
213
|
### Continual Learning (`beta2` & `relative_step`)
|
|
188
214
|
By default, Adafactor's second-moment decay rate is not fixed but varies with the training step, and the internal learning rate schedule (`relative_step`) scales the learning rate accordingly.
|
|
@@ -210,19 +236,57 @@ By default, Adafactor factorizes the second moment of $\ge$ 2D tensors into row
|
|
|
210
236
|
If you are in an environment without a CUDA compiler and want to bypass JIT compilation entirely:
|
|
211
237
|
- Set `use_cuda_kernel=False` to fall back to the pure PyTorch implementation.
|
|
212
238
|
|
|
213
|
-
## APOLLO Low-Rank Subspace Projection
|
|
239
|
+
## 🌌 APOLLO Low-Rank Subspace Projection
|
|
214
240
|
Enable the APOLLO path to compute gradient scaling factors in a memory-efficient low-rank subspace. Compared to Adafactor's standard row/column factorization (which assumes spatial independence), APOLLO uses random subspace projection to capture cross-dimensional covariance information, potentially leading to better generalization while keeping memory overhead extremely low.
|
|
215
241
|
|
|
216
|
-
- **`apollo_rank`**: The target rank for the projection subspace. The default is `0` (disabled).
|
|
217
|
-
|
|
242
|
+
- **`apollo_rank`**: The target rank for the projection subspace. The default is `0` (disabled).
|
|
243
|
+
|
|
244
|
+
- The official APOLLO GitHub repository recommends a rank of `256` for 1B and 7B models.
|
|
245
|
+
- The [LLaMA-Factory](https://llamafactory.readthedocs.io/en/latest/advanced/arguments.html#apollo) default is `16`.
|
|
246
|
+
- Setting this to `1` (APOLLO-Mini style) pushes VRAM savings to the limit (saving even more VRAM than the Adafactor path). The original APOLLO-Mini relies on the first moment (`beta1`) to smooth out projection noise. To replicate this, set `beta1=0.9` alongside `apollo_rank=1`. Without `beta1`, `apollo_rank=1` may still work but can exhibit noisier scaling factors, especially at small batch sizes.
|
|
247
|
+
|
|
248
|
+
|
|
218
249
|
- **`apollo_scale_type`**: Determines how the scaling factor is applied. `'channel'` applies it per channel (Standard APOLLO), while `'tensor'` applies it globally (APOLLO-Mini).
|
|
219
250
|
- **`apollo_update_proj_gap`**: Steps between projection matrix refreshes. Defaults to `200`. Setting this too small may cause frequent oscillations due to abrupt basis mutations, while setting it too large might cause the projection space to become stale and fail to track the drift of the gradient manifold.
|
|
220
251
|
- **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
|
|
221
252
|
- **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
|
|
222
253
|
|
|
254
|
+
## 🛡️ CAME Confidence-Guided Updates
|
|
255
|
+
|
|
256
|
+
Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
|
|
257
|
+
|
|
258
|
+
**Adaptive Scaling ($V$) → Momentum Accumulation ($M$) → Confidence Weighting ($C$)**
|
|
259
|
+
|
|
260
|
+
### Key Parameters & Tuning
|
|
223
261
|
|
|
262
|
+
The confidence stage measures the consistency between the current update direction and historical momentum, adaptively suppressing highly oscillatory updates.
|
|
224
263
|
|
|
225
|
-
|
|
264
|
+
- **`beta3`**: EMA decay coefficient for the confidence matrix. Requires `beta1` (momentum) and `factored=True`. Mutually exclusive with `apollo_rank`. Defaults to `None` (disabled).
|
|
265
|
+
- **Learning Rate**: The official CAME implementation recommends **0.5–0.9×** the AdamW learning rate (see the [official tuning guide](https://github.com/yangluo7/CAME/tree/master#hyper-parameter-tuning)). To use that range with this library, disable Adafactor's parameter scaling and clipping (`scale_parameter=False`, `d=1e9`) so that the behavior matches the original CAME.
|
|
266
|
+
- **Warmup**: Since the confidence matrix is zero-initialized without bias correction, a learning rate warmup is recommended to safely establish the confidence baseline.
|
|
267
|
+
- **Choosing `beta3`**: `beta3` should generally be larger than `beta2` so the confidence estimate evolves more slowly than the variance estimate. A practical starting range is **0.9995–0.99995** when `beta2=0.999`.
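
The warmup and `beta3` advice above follows from how a zero-initialized EMA behaves without bias correction. The sketch below is a plain-Python illustration, not library code, and `zero_init_ema` and `ema_horizon` are hypothetical helper names. For a constant input the EMA after `t` steps is `value * (1 - beta**t)`, so a larger `beta` stays close to zero for longer.

```python
def zero_init_ema(beta, steps, value=1.0):
    """EMA of a constant `value`, started at zero with no bias correction."""
    ema = 0.0
    for _ in range(steps):
        ema = beta * ema + (1.0 - beta) * value
    return ema

def ema_horizon(beta):
    """Rough averaging window of an EMA, in steps: 1 / (1 - beta)."""
    return 1.0 / (1.0 - beta)

beta2, beta3 = 0.999, 0.9999
for steps in (100, 1000, 10000):
    # Fraction of the true value each estimate has reached after `steps` updates.
    print(steps, round(zero_init_ema(beta2, steps), 4), round(zero_init_ema(beta3, steps), 4))

# beta3 averages over roughly 10x more steps than beta2.
print(round(ema_horizon(beta2)), round(ema_horizon(beta3)))
```

With `beta3` close to 1 the confidence estimate needs thousands of steps to approach its steady-state scale, which is why a learning rate warmup is suggested while that baseline builds up.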
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
### Configuration Example
|
|
271
|
+
|
|
272
|
+
To replicate "vanilla" CAME (stripping Adafactor's native modifications), replace the standard 2D APOLLO group in your `param_groups` with the following configuration:
|
|
273
|
+
|
|
274
|
+
```python
|
|
275
|
+
{
|
|
276
|
+
"params": param_group,
|
|
277
|
+
"lr": lr, # Original CAME recommends 0.5-0.9x AdamW LR
|
|
278
|
+
"weight_decay": weight_decay,
|
|
279
|
+
"quantize": True,
|
|
280
|
+
"beta1": 0.9,
|
|
281
|
+
"beta3": 0.9999, # Enable CAME confidence guidance
|
|
282
|
+
"apollo_rank": 0, # Mutually exclusive with CAME
|
|
283
|
+
"scale_parameter": False, # Disable Adafactor RMS scaling to align with vanilla CAME
|
|
284
|
+
"d": 1e9, # Disable Adafactor global RMS clipping
|
|
285
|
+
"enable_fira_for_adafactor": False, # Disable Fira Limiter to prevent interference with CAME's scaling
|
|
286
|
+
},
|
|
287
|
+
```
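
The constraints stated for `beta3` (it needs `beta1` and `factored=True`, and cannot be combined with `apollo_rank`) can be checked before the group is handed to the optimizer. The helper below is a hypothetical sketch and not part of the library. It assumes that `apollo_rank` set to `0` or left unset means APOLLO is off, as in the example above, and that an omitted `factored` key falls back to a default that keeps factorization on.

```python
def check_came_group(group):
    """Return a list of problems with a CAME-style param group (empty if none).

    Hypothetical pre-flight check; the library may validate differently.
    """
    problems = []
    if group.get("beta3") is None:
        return problems  # CAME is disabled, nothing to check
    if group.get("beta1") is None:
        problems.append("beta3 requires beta1 (momentum)")
    # Assumption: a missing "factored" key is fine; only an explicit False is flagged.
    if group.get("factored", True) is False:
        problems.append("beta3 requires factored=True")
    if group.get("apollo_rank"):  # 0 or None is treated as APOLLO off
        problems.append("beta3 is mutually exclusive with apollo_rank")
    return problems

came_group = {"beta1": 0.9, "beta3": 0.9999, "apollo_rank": 0,
              "scale_parameter": False, "d": 1e9}
bad_group = {"beta3": 0.9999, "apollo_rank": 128, "factored": False}

print(check_came_group(came_group))
print(check_came_group(bad_group))
```

Returning a list of messages instead of raising on the first problem makes it easy to report every misconfigured key at once.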
|
|
288
|
+
|
|
289
|
+
## 📈 Learning Rate Guide for Beginners
|
|
226
290
|
|
|
227
291
|
If you are migrating from optimizers like AdamW, Adafactor's learning rate behavior might feel a bit different. This is mainly due to the `scale_parameter` option.
|
|
228
292
|
|
|
@@ -238,7 +302,7 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
|
|
|
238
302
|
|
|
239
303
|
|
|
240
304
|
|
|
241
|
-
## Acknowledgements
|
|
305
|
+
## 🎓 Acknowledgements
|
|
242
306
|
|
|
243
307
|
Thanks to **Noam Shazeer** and **Mitchell Stern** for proposing the original Adafactor algorithm in the paper [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235).
|
|
244
308
|
|
|
@@ -248,14 +312,18 @@ Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the appr
|
|
|
248
312
|
|
|
249
313
|
Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
|
|
250
314
|
|
|
315
|
+
Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
|
|
316
|
+
|
|
251
317
|
Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
|
|
252
318
|
|
|
253
|
-
Thanks to the large language models **Qwen** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization
|
|
319
|
+
Thanks to the large language models **Qwen**, **ChatGLM** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
|
|
254
320
|
|
|
255
|
-
##
|
|
321
|
+
## 🏛️ License
|
|
256
322
|
|
|
257
|
-
[
|
|
323
|
+
[The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
|
|
258
324
|
|
|
259
|
-
##
|
|
325
|
+
## ⭐ Star the Project
|
|
260
326
|
|
|
261
|
-
|
|
327
|
+
If this optimizer has been useful in your work, consider giving the repository a star. It helps others discover the project and supports future development.
|
|
328
|
+
|
|
329
|
+
[Star History Chart](https://star-history.com/#yanfeiwong/adafactor-8bit&Date)
|