adafactor8bit 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adafactor8bit
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: 8-bit Adafactor Optimizer with Fused CUDA Kernels
5
5
  Home-page: https://github.com/yanfeiwong/adafactor-8bit
6
6
  Author: WANG YAN
@@ -25,6 +25,13 @@ Dynamic: requires-dist
25
25
  Dynamic: requires-python
26
26
  Dynamic: summary
27
27
 
28
+ <p align="center">
29
+ <a href="https://github.com/yanfeiwong/adafactor-8bit">
30
+ <img src="https://github.com/yanfeiwong/adafactor-8bit/raw/main/assets/banner.png"
31
+ alt="Adafactor8Bit"
32
+ width="80%">
33
+ </a>
34
+ </p>
28
35
  <div align="center">
29
36
 
30
37
  # 8-bit Adafactor with Fused CUDA Kernels
@@ -39,25 +46,27 @@ Dynamic: summary
39
46
 
40
47
  </div>
41
48
 
42
- An 8-bit Adafactor optimizer featuring fused CUDA kernels and log-space block-wise quantization, designed to further reduce optimizer state memory while maintaining low step overhead and stability suitable for large models such as LLMs and diffusion models.
49
+ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, and optional add-ons: 4-bit packed first moments, APOLLO low-rank updates, and CAME confidence-guided optimization. It substantially reduces optimizer memory while preserving the low step overhead and numerical stability that make Adafactor attractive for training LLMs and diffusion models.
43
50
 
44
51
 
45
- ## Key Features
52
+ ## Key Features
46
53
 
47
54
  - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
48
55
  - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
56
+ - **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
57
+ - **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
49
58
  - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
50
59
  - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
51
60
  - **Zero CPU-GPU Sync**: Eliminates implicit synchronizations (e.g., D2H copies) in the control flow, ensuring the GPU computation pipeline runs without blocking.
52
61
  - **Cross-Platform JIT**: Uses Just-In-Time (JIT) compilation for straightforward setup across both Windows and Linux environments.
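The packed 4-bit first moment stores two 4-bit codes in each byte. The following minimal pure-Python sketch illustrates only that packing idea. It assumes a symmetric linear code in [-7, 7] with one absmax scale per block. This README does not describe the library's actual code book, block size, or CUDA memory layout, and `quantize_block`, `pack_int4`, `unpack_int4`, and `dequantize_block` are hypothetical helpers.

```python
# Hypothetical sketch of 4-bit packing; not the library's kernel or code book.

def quantize_block(values):
    """Map floats to signed 4-bit codes in [-7, 7] with one absmax scale per block.

    Assumption: a linear symmetric code; the real kernel may use another mapping.
    """
    scale = max(abs(v) for v in values) or 1.0  # avoid dividing by zero for all-zero blocks
    codes = [round(v / scale * 7) for v in values]
    return codes, scale


def pack_int4(codes):
    """Pack signed 4-bit codes two per byte (low nibble first)."""
    if len(codes) % 2:
        codes = codes + [0]  # pad to an even count
    packed = bytearray()
    for lo, hi in zip(codes[0::2], codes[1::2]):
        packed.append((lo & 0xF) | ((hi & 0xF) << 4))  # two's-complement nibbles
    return bytes(packed)


def unpack_int4(packed, n):
    """Recover the first n signed 4-bit codes from packed bytes."""
    codes = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            codes.append(nibble - 16 if nibble >= 8 else nibble)  # sign-extend
    return codes[:n]


def dequantize_block(codes, scale):
    """Map 4-bit codes back to floats."""
    return [c / 7 * scale for c in codes]


momentum = [0.02, -0.015, 0.001, -0.03, 0.0]
codes, scale = quantize_block(momentum)
packed = pack_int4(codes)
restored = dequantize_block(unpack_int4(packed, len(momentum)), scale)
print(f"{len(momentum)} values -> {len(packed)} bytes")
print([round(r, 4) for r in restored])
```

Because each byte holds two codes, the state costs half a byte per parameter plus the per-block scales.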
53
62
 
54
- ## Performance
63
+ ## 📊 Performance
55
64
 
56
- - **Memory Footprint**: Due to Adafactor's factorized second-moment estimation and 8-bit quantization, the optimizer state memory usage is generally lower than that of `AdamW8Bit`.
65
+ - **Memory Footprint**: Due to Adafactor's factorized second-moment estimation, 8-bit quantization, and optional 4-bit packed first moments, the optimizer typically consumes substantially less memory than `AdamW8Bit`.
57
66
  - **Training Speed**: The fused kernel design and reduced synchronization overhead allow it to achieve step times comparable to other mainstream 8-bit optimizers.
58
67
  - **Quantization Precision**: The second moment (variance) in Adafactor is strictly non-negative and spans multiple orders of magnitude. By mapping it to `UINT8` in log2 space rather than linear space, the optimizer preserves relative precision for small variances, mitigating the instability often caused by outlier gradients in standard 8-bit quantization.
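A small numerical sketch makes this log2-versus-linear argument concrete. It quantizes one block of variances spanning many orders of magnitude to 256 levels twice: once uniformly in log2 space, and once linearly from zero to the block maximum. The per-block layout (storing the minimum log2 value and a step size) and the helper names are assumptions for illustration, not the library's actual format.

```python
import math


def quantize_log2(block, eps=1e-30):
    """Quantize non-negative variances to uint8 codes spread uniformly in log2 space.

    Assumption: each block stores its min log2 value and a step size.
    """
    logs = [math.log2(v + eps) for v in block]
    lo, hi = min(logs), max(logs)
    step = (hi - lo) / 255 or 1.0  # guard against a constant block
    codes = [round((x - lo) / step) for x in logs]
    return codes, lo, step


def dequantize_log2(codes, lo, step):
    return [2.0 ** (lo + c * step) for c in codes]


def quantize_linear(block):
    """Baseline: uint8 codes spread linearly between 0 and the block max."""
    scale = max(block) / 255
    return [round(v / scale) * scale for v in block]


# Variances spanning ten orders of magnitude, as second moments often do.
variances = [1e-10, 3e-8, 2e-6, 5e-4, 1e-2, 0.5]
codes, lo, step = quantize_log2(variances)
log_restored = dequantize_log2(codes, lo, step)
lin_restored = quantize_linear(variances)

for v, a, b in zip(variances, log_restored, lin_restored):
    print(f"{v:.1e}  log2: {a:.2e}  linear: {b:.2e}")
```

With the linear code, every variance below about half a quantization step (roughly 1e-3 here) collapses to zero. The log2 code keeps each value within a fixed relative error of a few percent.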
59
68
 
60
- ## Installation
69
+ ## 📦 Installation
61
70
 
62
71
  This project uses JIT (Just-In-Time) compilation.
63
72
 
@@ -77,9 +86,12 @@ pip install -U adafactor8bit
77
86
  pip install git+https://github.com/yanfeiwong/adafactor-8bit.git
78
87
  ```
79
88
 
80
- **Note**: The first time you instantiate the optimizer (or run the example script), it will automatically trigger the JIT compilation of the CUDA source code in the background. This may take anywhere from a few seconds to a couple of minutes depending on your system, and the terminal might appear unresponsive. Once compiled, the binary will be cached, and all subsequent runs will be instantaneous.
89
+ > [!IMPORTANT]
90
+ > **First-Time Compilation**: The first time you instantiate the optimizer (or run the example script), it automatically JIT-compiles the CUDA source code. Depending on your system, this can take anywhere from a few seconds to a couple of minutes, and the terminal may appear unresponsive in the meantime. The compiled binary is then cached, so subsequent runs skip compilation and start almost immediately.
81
91
 
82
- ## Quick Start
92
+
93
+
94
+ ## 🚀 Quick Start
83
95
 
84
96
  Using it is as simple as using a standard PyTorch optimizer.
85
97
 
@@ -89,7 +101,8 @@ from adafactor8bit import Adafactor8Bit
89
101
  optimizer = Adafactor8Bit(model.parameters(), lr=1e-3)
90
102
  ```
91
103
 
92
- **💡 Note**: Passing `model.parameters()` directly works for a quick test. In production, `param_groups` are recommended to protect sensitive layers (Norms, Biases) from quantization and weight decay. For **sparse token embeddings** (large vocabularies + small batch sizes), please refer to the [Advanced Example](#advanced-example) to avoid cold-start variance explosion.
104
+ > [!TIP]
105
+ > Passing `model.parameters()` directly works for a quick test. In production, `param_groups` are recommended to protect sensitive layers (Norms, Biases) from quantization and weight decay. For **sparse token embeddings** (large vocabularies + small batch sizes), please refer to the [Advanced Example](#-advanced-example) to avoid cold-start variance explosion.
93
106
 
94
107
 
95
108
  ```python
@@ -123,15 +136,20 @@ optimizer = Adafactor8Bit(
123
136
  # Training loop...
124
137
  ```
125
138
 
126
- ## Advanced Example
139
+ ## 🛠️ Advanced Example
127
140
 
128
- Here we demonstrate a **hybrid grouping** strategy for complex hybrid architectures (e.g., Vision-Language Models, Diffusion UNets) to achieve stable and efficient momentum-free training as much as possible.
141
+ Here we demonstrate a **hybrid grouping** strategy for complex hybrid architectures (e.g., Vision-Language Models, Diffusion UNets) to achieve stable and efficient training.
129
142
 
130
143
  📌 **The following strategies are applied:**
131
- 1. **1D / Sensitive Parameters (Norms, Biases)**: No quantization, no weight decay.
132
- 2. **Embedding Layers**: Combines `factored=False`, `scale_parameter=False`, and `d=1e9` to make the optimization behavior equivalent to a **momentum-free Adam**. Paired with an Adam-style learning rate, this allows for fine-grained, per-token updates while avoiding cold-token interference (global clipping penalties).
133
- 3. **2D Weights (Linear Layers)**: 8-bit quantization, weight decay, using the **APOLLO** path. The continuously switching random subspace projection helps capture comprehensive gradient information and acts as a regularizer.
134
- 4. **>2D Weights (Conv2d, etc.)**: 8-bit quantization, weight decay, **Full-Rank** (`factored=False`). Trades a certain amount of VRAM to preserve complete spatial structures for better optimization outcomes.
144
+ | Layer Type | Strategy |
145
+ |------------|----------|
146
+ | **1D / Sensitive Parameters** (Norms, Biases) | No quantization, no weight decay |
147
+ | **Embedding Layers** | `factored=False`, `scale_parameter=False`, `d=1e9`: behaves as momentum-free Adam. Paired with an Adam-style learning rate, this allows fine-grained, per-token updates while avoiding cold-token interference from global update clipping. |
148
+ | **2D Weights** (Linear Layers) | 8-bit quantization, weight decay, **APOLLO** path. Continuously switching random subspace projection captures comprehensive gradient information and acts as a regularizer. |
149
+ | **>2D Weights** (Conv2d, etc.) | 8-bit quantization, weight decay, **Full-Rank** (`factored=False`). Trades some VRAM to preserve complete spatial structures. |
150
+ | **Momentum (`beta1`)** | Enabled only for dense weight matrices, where the optimization benefit typically outweighs the small memory overhead of the packed 4-bit first moment. Sensitive parameters (Norms/Biases) and sparse Embeddings remain momentum-free. |
151
+
152
+ **Implementation:**
135
153
 
136
154
  ```python
137
155
  from adafactor8bit import Adafactor8Bit
@@ -179,14 +197,21 @@ def get_param_groups(model, lr_emb, weight_decay, apollo_rank=256):
179
197
  },
180
198
 
181
199
  # 3. 2D Weights: 8-bit quantization, Weight Decay, APOLLO low-rank projection
182
- {"params": group_2d, "weight_decay": weight_decay, "quantize": True, "apollo_rank": apollo_rank},
183
-
200
+ {
201
+ "params": group_2d,
202
+ "weight_decay": weight_decay,
203
+ "quantize": True,
204
+ "apollo_rank": apollo_rank,
205
+ "beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
206
+ },
207
+
184
208
  # 4. >2D Weights: 8-bit quantization, Weight Decay, Full-Rank
185
209
  {
186
210
  "params": group_nd,
187
211
  "weight_decay": weight_decay,
188
212
  "quantize": True,
189
213
  "apollo_rank": 0,
214
+ "beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
190
215
  "factored": False # Disables factorization to preserve spatial structures, enabling finer gradient scaling.
191
216
  # Note: This increases state memory for >2D weights, depending on your model architecture.
192
217
  # If VRAM is constrained, reverting to factored=True is a safe alternative.
@@ -206,10 +231,11 @@ optimizer = Adafactor8Bit(
206
231
  # Training loop...
207
232
  ```
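The `beta1` entries above are described as adding only a small memory overhead. The back-of-the-envelope sketch below compares state sizes for one 4096 × 4096 weight under these simplifying assumptions:

- Per-block quantization scales and padding are ignored.
- AdamW8Bit is counted as two 1-byte states per parameter.
- The factored 8-bit second moment is counted as one byte per row plus one per column.
- The packed first moment is counted as half a byte per parameter.

`state_bytes` and `adamw8bit_bytes` are hypothetical helpers, and the results are estimates, not measurements.

```python
def state_bytes(rows, cols, *, factored=True, beta1=False):
    """Approximate optimizer-state bytes for a rows x cols weight (scales ignored)."""
    numel = rows * cols
    second = (rows + cols) if factored else numel  # 8-bit second moment
    first = (numel + 1) // 2 if beta1 else 0      # 4-bit packed first moment
    return second + first


def adamw8bit_bytes(rows, cols):
    """Two 8-bit states (m and v) per parameter, scales ignored."""
    return 2 * rows * cols


rows, cols = 4096, 4096
for label, n in [
    ("AdamW8Bit", adamw8bit_bytes(rows, cols)),
    ("Adafactor8Bit, factored, no beta1", state_bytes(rows, cols)),
    ("Adafactor8Bit, factored, beta1=0.9", state_bytes(rows, cols, beta1=True)),
    ("Adafactor8Bit, factored=False, beta1=0.9",
     state_bytes(rows, cols, factored=False, beta1=True)),
]:
    print(f"{label:42s} {n / 2**20:8.2f} MiB")
```

Under these assumptions, the factored path with `beta1` needs about 8 MiB for this matrix, almost all of it for the packed first moment. AdamW8Bit needs about 32 MiB. Setting `factored=False` (as in the >2D group) adds a full byte per parameter.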
208
233
 
209
- For more complete examples, please refer to the [examples folder](https://github.com/yanfeiwong/adafactor-8bit/tree/main/examples).
234
+ > [!NOTE]
235
+ > For more complete examples, please refer to the [examples folder](https://github.com/yanfeiwong/adafactor-8bit/tree/main/examples).
210
236
 
211
237
 
212
- ## Advanced Configuration
238
+ ## ⚙️ Advanced Configuration
213
239
 
214
240
  ### Continual Learning (`beta2` & `relative_step`)
215
241
  By default, Adafactor's second-moment decay rate (`beta2`) is not fixed but follows a schedule tied to the training step, and the internal learning rate schedule (`relative_step`) likewise scales the learning rate with the step count.
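For reference, the original Adafactor paper uses the step-dependent decay $\hat\beta_{2t} = 1 - t^{-0.8}$ and, with relative steps, the step size $\rho_t = \min(10^{-2}, 1/\sqrt{t})$. The Hugging Face `transformers` Adafactor uses the same defaults. The sketch below evaluates both schedules, assuming (not confirmed by this README) that this library's defaults match. `beta2` climbs toward 1, so late in training new gradients barely move the second-moment estimate. The relative step stays at `1e-2` until $t = 10^4$ and then decays as $1/\sqrt{t}$.

```python
import math


def adafactor_beta2(step, decay_rate=-0.8):
    """Second-moment decay from the Adafactor paper: beta2_t = 1 - t**decay_rate."""
    return 1.0 - step ** decay_rate


def relative_step_size(step):
    """Relative step size from the Adafactor paper: min(1e-2, 1/sqrt(t))."""
    return min(1e-2, 1.0 / math.sqrt(step))


for step in (1, 10, 1_000, 10_000, 100_000, 1_000_000):
    print(f"t={step:>9,}  beta2={adafactor_beta2(step):.6f}  rho={relative_step_size(step):.1e}")
```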
@@ -237,19 +263,57 @@ By default, Adafactor factorizes the second moment of $\ge$ 2D tensors into row
237
263
  If you are in an environment without a CUDA compiler and want to bypass JIT compilation entirely:
238
264
  - Set `use_cuda_kernel=False` to fall back to the pure PyTorch implementation.
239
265
 
240
- ## APOLLO Low-Rank Subspace Projection
266
+ ## 🌌 APOLLO Low-Rank Subspace Projection
241
267
  Enable the APOLLO path to compute gradient scaling factors in a memory-efficient low-rank subspace. Compared to Adafactor's standard row/column factorization (which assumes spatial independence), APOLLO uses random subspace projection to capture cross-dimensional covariance information, potentially leading to better generalization while keeping memory overhead extremely low.
242
268
 
243
- - **`apollo_rank`**: The target rank for the projection subspace. The default is `0` (disabled). Setting it to `256` might work well for most 1B to 7B models.
244
- *Note: Setting this to `1` (APOLLO-Mini style) pushes VRAM savings to the limit (saves even more VRAM than the Adafactor path). However, the original APOLLO-Mini relies on Adam's first-moment (beta1) to smooth out noise. Since our implementation uses a pure second-moment architecture, rank=1 may lead to distorted scaling factors and training instability.*
269
+ - **`apollo_rank`**: The target rank for the projection subspace. The default is `0` (disabled).
270
+
271
+ - The official APOLLO GitHub repository recommends a rank of `256` for 1B and 7B models.
272
+ - The [LLaMA-Factory](https://llamafactory.readthedocs.io/en/latest/advanced/arguments.html#apollo) default is `16`.
273
+ - Setting this to `1` (APOLLO-Mini style) pushes VRAM savings to the limit, using even less VRAM than the standard Adafactor path. The original APOLLO-Mini relies on the first moment (`beta1`) to smooth out projection noise. To replicate this, set `beta1=0.9` alongside `apollo_rank=1`, which adds the packed 4-bit first-moment state. Without `beta1`, rank 1 may still work but can produce noisier scaling factors, especially at small batch sizes.
274
+
275
+
245
276
  - **`apollo_scale_type`**: Determines how the scaling factor is applied. `'channel'` applies it per channel (Standard APOLLO), while `'tensor'` applies it globally (APOLLO-Mini).
246
277
  - **`apollo_update_proj_gap`**: Steps between projection matrix refreshes. Defaults to `200`. Setting it too small may cause oscillations from frequent, abrupt basis changes. Setting it too large can leave the projection stale, so that it fails to track the drift of the gradient manifold.
247
278
  - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
248
279
  - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients to prevent sudden gradient rises from causing loss spikes. You can adjust its sensitivity using the global `fira_margin` parameter.
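To make the APOLLO scaling and the Fira limiter concrete, the pure-Python sketch below walks through one step:

1. Project the gradient with a random Gaussian matrix.
2. Keep a momentum-free, bias-corrected second-moment EMA in the low-rank space.
3. Rescale the full-rank gradient per column (`'channel'`) or with one global factor (`'tensor'`), following the APOLLO recipe.
4. Apply a Fira-style cap so the update norm cannot grow by more than a factor `gamma` per step.

Everything here is an illustrative assumption: the helper names, the fixed projection (no refresh every `apollo_update_proj_gap` steps), the bias correction, and `gamma=1.01`. How `gamma` relates to this library's `fira_margin` is not documented here.

```python
import math
import random


def matmul(a, b):
    """Plain-list matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def frob(m):
    """Frobenius norm of a list-of-lists matrix."""
    return math.sqrt(sum(v * v for row in m for v in row))


def col_norms(m):
    """L2 norm of each column."""
    return [math.sqrt(sum(v * v for v in col)) for col in zip(*m)]


def apollo_scale(grad, state, rank, beta2=0.999, eps=1e-8, scale_type="channel", seed=0):
    """Scale a full-rank gradient using second-moment statistics of a random projection.

    Hypothetical sketch: momentum-free, bias-corrected EMA, projection never refreshed.
    """
    m, n = len(grad), len(grad[0])
    if "proj" not in state:
        rng = random.Random(seed)
        state["proj"] = [[rng.gauss(0.0, 1.0 / math.sqrt(rank)) for _ in range(m)]
                         for _ in range(rank)]
        state["v"] = [[0.0] * n for _ in range(rank)]
        state["step"] = 0
    state["step"] += 1
    r = matmul(state["proj"], grad)  # rank x n low-rank view of the gradient
    bias = 1.0 - beta2 ** state["step"]
    r_tilde = []
    for r_row, v_row in zip(r, state["v"]):
        for j, x in enumerate(r_row):
            v_row[j] = beta2 * v_row[j] + (1.0 - beta2) * x * x
        r_tilde.append([x / (math.sqrt(v / bias) + eps) for x, v in zip(r_row, v_row)])
    if scale_type == "channel":  # one factor per column (standard APOLLO)
        s = [a / (b + 1e-12) for a, b in zip(col_norms(r_tilde), col_norms(r))]
    else:  # 'tensor': one global factor (APOLLO-Mini)
        s = [frob(r_tilde) / (frob(r) + 1e-12)] * n
    return [[g * s[j] for j, g in enumerate(row)] for row in grad]


def fira_limit(update, state, gamma=1.01):
    """Fira-style norm-growth limiter: keep ||U_t|| <= gamma * ||U_{t-1}||."""
    norm = frob(update)
    prev = state.get("prev_norm")
    if prev is not None and norm > gamma * prev:
        factor = gamma * prev / norm
        update = [[u * factor for u in row] for row in update]
        norm = gamma * prev
    state["prev_norm"] = norm
    return update


grad = [[0.1, -0.2, 0.05],
        [0.3, 0.0, -0.1],
        [-0.2, 0.4, 0.2],
        [0.05, -0.05, 0.1]]
state = {}
u1 = fira_limit(apollo_scale(grad, state, rank=2), state)
spike = [[10.0 * g for g in row] for row in grad]  # sudden 10x gradient spike
u2 = fira_limit(apollo_scale(spike, state, rank=2), state)
ut = apollo_scale(grad, {}, rank=2, scale_type="tensor")  # APOLLO-Mini-style global factor
print(f"||u1|| = {frob(u1):.4f}, ||u2|| = {frob(u2):.4f}")
```

In the usage lines, the 10× spike raises the raw update norm by roughly 1.4×. The limiter trims it back to 1.01× the previous norm.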
249
280
 
281
+ ## 🛡️ CAME Confidence-Guided Updates
282
+
283
+ Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
284
+
285
+ **Adaptive Scaling ($V$) → Momentum Accumulation ($M$) → Confidence Weighting ($C$)**
286
+
287
+ ### Key Parameters & Tuning
250
288
 
289
+ The confidence stage measures the consistency between the current update direction and historical momentum, adaptively suppressing highly oscillatory updates.
251
290
 
252
- ## Learning Rate Guide for Beginners
291
+ - **`beta3`**: EMA decay coefficient for the confidence matrix. Requires `beta1` (momentum) and `factored=True`. Mutually exclusive with the APOLLO path, so keep `apollo_rank=0`. Defaults to `None` (disabled).
292
+ - **Learning Rate**: The official CAME implementation recommends **0.5–0.9×** the AdamW learning rate (see [official tuning guide](https://github.com/yangluo7/CAME/tree/master#hyper-parameter-tuning)). To use a learning rate on that scale here, disable Adafactor's parameter scaling and update clipping (`scale_parameter=False`, `d=1e9`) so the behavior matches the original CAME.
293
+ - **Warmup**: Since the confidence matrix is zero-initialized without bias correction, a learning rate warmup is recommended to safely establish the confidence baseline.
294
+ - **Choosing `beta3`**: `beta3` should generally be larger than `beta2` so the confidence estimate evolves more slowly than the variance estimate. A practical starting range is **0.9995–0.99995** when `beta2=0.999`.
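Both tuning notes above reduce to simple EMA arithmetic. For illustration, assume the tracked instability signal is constant. A zero-initialized EMA without bias correction then reaches only a fraction $1 - \beta_3^t$ of that value after $t$ steps, so dividing by its square root inflates the step by $1/\sqrt{1 - \beta_3^t}$. An EMA with decay $\beta$ also averages over roughly $1/(1 - \beta)$ steps. The sketch below tabulates both quantities for `beta2=0.999` and `beta3=0.9999`. It is idealized arithmetic, not a trace of the library's kernel, and `inflation` and `horizon` are hypothetical helpers.

```python
import math


def inflation(beta, step):
    """Step inflation 1/sqrt(1 - beta**t) of a zero-initialized EMA with no bias
    correction, assuming the tracked quantity is constant."""
    return 1.0 / math.sqrt(1.0 - beta ** step)


def horizon(beta):
    """Rough number of recent steps an EMA with decay beta averages over."""
    return 1.0 / (1.0 - beta)


beta2, beta3 = 0.999, 0.9999
print(f"averaging horizon: beta2 ~ {horizon(beta2):,.0f} steps, beta3 ~ {horizon(beta3):,.0f} steps")
for step in (1, 100, 1_000, 10_000, 50_000):
    print(f"t={step:>6,}  inflation with beta3={beta3}: {inflation(beta3, step):6.2f}x")
```

At $t = 1$ the inflation is 100×, and it only falls to about 1.26× after 10,000 steps. A warmup that keeps the learning rate small early on is therefore a sensible safeguard.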
295
+
296
+
297
+ ### Configuration Example
298
+
299
+ To replicate "vanilla" CAME (stripping Adafactor's native modifications), replace the standard 2D APOLLO group in your `param_groups` with the following configuration:
300
+
301
+ ```python
302
+ {
303
+ "params": param_group,
304
+ "lr": lr, # Original CAME recommends 0.5-0.9x AdamW LR
305
+ "weight_decay": weight_decay,
306
+ "quantize": True,
307
+ "beta1": 0.9,
308
+ "beta3": 0.9999, # Enable CAME confidence guidance
309
+ "apollo_rank": 0, # Mutually exclusive with CAME
310
+ "scale_parameter": False, # Disable Adafactor RMS scaling to align with vanilla CAME
311
+ "d": 1e9, # Disable Adafactor global RMS clipping
312
+ "enable_fira_for_adafactor": False, # Disable Fira Limiter to prevent interference with CAME's scaling
313
+ },
314
+ ```
315
+
316
+ ## 📈 Learning Rate Guide for Beginners
253
317
 
254
318
  If you are migrating from optimizers like AdamW, Adafactor's learning rate behavior might feel a bit different. This is mainly due to the `scale_parameter` option.
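The original Adafactor paper (and the Hugging Face `transformers` implementation) defines two rules:

- `scale_parameter=True` multiplies the step size by $\max(\epsilon_2, \mathrm{RMS}(\theta))$ with $\epsilon_2 = 10^{-3}$.
- The update is divided by $\max(1, \mathrm{RMS}(U)/d)$.

The sketch below assumes this library applies the same two rules. The `scale_parameter=False` / `d=1e9` advice elsewhere in this README suggests this, but the README does not spell it out. `rms`, `effective_lr`, and `clip_update` are hypothetical helpers.

```python
import math


def rms(values):
    """Root mean square of a flat list."""
    return math.sqrt(sum(v * v for v in values) / len(values))


def effective_lr(lr, params, scale_parameter=True, eps2=1e-3):
    """Step size after Adafactor's optional parameter-scale multiplier."""
    return lr * max(eps2, rms(params)) if scale_parameter else lr


def clip_update(update, d=1.0):
    """Adafactor update clipping: divide by max(1, RMS(update) / d)."""
    denom = max(1.0, rms(update) / d)
    return [u / denom for u in update]


weights = [0.02, -0.03, 0.01, -0.02]  # RMS of about 0.021, typical of a small init
update = [3.0, -4.0, 0.5, -0.5]       # RMS of about 2.5, e.g. a spike on a cold parameter
print(effective_lr(1e-3, weights))                         # about 2.1e-05
print(effective_lr(1e-3, weights, scale_parameter=False))  # 0.001
print(rms(clip_update(update)))         # about 1.0: clipped
print(rms(clip_update(update, d=1e9)))  # about 2.52: d=1e9 effectively disables clipping
```

With a weight RMS of about 0.02, `lr=1e-3` becomes an effective step of about 2e-5, roughly 50 times smaller. An AdamW-style learning rate therefore takes much smaller effective steps when `scale_parameter=True`.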
255
319
 
@@ -265,7 +329,7 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
265
329
 
266
330
 
267
331
 
268
- ## Acknowledgements
332
+ ## 🎓 Acknowledgements
269
333
 
270
334
  Thanks to **Noam Shazeer** and **Mitchell Stern** for proposing the original Adafactor algorithm in the paper [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235).
271
335
 
@@ -275,14 +339,18 @@ Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the appr
275
339
 
276
340
  Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
277
341
 
342
+ Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
343
+
278
344
  Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
279
345
 
280
- Thanks to the large language models **Qwen** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
346
+ Thanks to the large language models **Qwen**, **ChatGLM**, and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
281
347
 
282
- ## Star History
348
+ ## 🏛️ License
283
349
 
284
- [![Star History Chart](https://api.star-history.com/svg?repos=yanfeiwong/adafactor-8bit&type=Date&theme=dark)](https://star-history.com/#yanfeiwong/adafactor-8bit&Date)
350
+ [The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
285
351
 
286
- ## License
352
+ ## ⭐ Star the Project
287
353
 
288
- [The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
354
+ If this optimizer has been useful in your work, consider giving the repository a star. It helps others discover the project and supports future development.
355
+
356
+ [![Star History Chart](https://api.star-history.com/svg?repos=yanfeiwong/adafactor-8bit&type=Date&theme=dark)](https://star-history.com/#yanfeiwong/adafactor-8bit&Date)
@@ -1,3 +1,10 @@
1
+ <p align="center">
2
+ <a href="https://github.com/yanfeiwong/adafactor-8bit">
3
+ <img src="https://github.com/yanfeiwong/adafactor-8bit/raw/main/assets/banner.png"
4
+ alt="Adafactor8Bit"
5
+ width="80%">
6
+ </a>
7
+ </p>
1
8
  <div align="center">
2
9
 
3
10
  # 8-bit Adafactor with Fused CUDA Kernels
@@ -12,25 +19,27 @@
12
19
 
13
20
  </div>
14
21
 
15
- An 8-bit Adafactor optimizer featuring fused CUDA kernels and log-space block-wise quantization, designed to further reduce optimizer state memory while maintaining low step overhead and stability suitable for large models such as LLMs and diffusion models.
22
+ An enhanced 8-bit Adafactor optimizer featuring fused CUDA kernels, log-space block-wise quantization, and optional add-ons: 4-bit packed first moments, APOLLO low-rank updates, and CAME confidence-guided optimization. It substantially reduces optimizer memory while preserving the low step overhead and numerical stability that make Adafactor attractive for training LLMs and diffusion models.
16
23
 
17
24
 
18
- ## Key Features
25
+ ## Key Features
19
26
 
20
27
  - **Log-Space Quantization**: Maps the second moment (variance) to the log2 space before 8-bit quantization. This approach accommodates the long-tail distribution of variances, reducing the risk of small second-moment estimates being truncated to zero and improving overall training stability.
21
28
  - **Fused CUDA Kernels**: Combines dequantization, EMA updates, Warp-Shuffle reductions, and requantization into single kernels. It utilizes `float4` vectorization to optimize memory bandwidth usage.
29
+ - **Optional 4-bit Packed First Moment**: Stores the first moment (`beta1`) in a physically packed 4-bit format when enabled, providing momentum with minimal additional memory overhead.
30
+ - **CAME Confidence Guidance**: Optional Confidence-guided Adaptive Memory Efficient Optimization (CAME) that estimates update confidence from historical momentum and adaptively suppresses unstable update directions, improving training stability and reducing loss spikes.
22
31
  - **APOLLO Subspace Projection**: Opt-in random subspace projection that estimates adaptive gradient scaling in a low-rank space, preventing stale second-moment statistics and potentially improving convergence and generalization.
23
32
  - **Fira Norm-Growth Limiter**: Suppresses destructive gradient spikes by regulating the relative increase of update norms. Originally used for the APOLLO path, it is now available for the standard Adafactor path as well. It improves training stability and often allows the safe removal of external gradient clipping.
24
33
  - **Zero CPU-GPU Sync**: Eliminates implicit synchronizations (e.g., D2H copies) in the control flow, ensuring the GPU computation pipeline runs without blocking.
25
34
  - **Cross-Platform JIT**: Uses Just-In-Time (JIT) compilation for straightforward setup across both Windows and Linux environments.
26
35
 
27
- ## Performance
36
+ ## 📊 Performance
28
37
 
29
- - **Memory Footprint**: Due to Adafactor's factorized second-moment estimation and 8-bit quantization, the optimizer state memory usage is generally lower than that of `AdamW8Bit`.
38
+ - **Memory Footprint**: Due to Adafactor's factorized second-moment estimation, 8-bit quantization, and optional 4-bit packed first moments, the optimizer typically consumes substantially less memory than `AdamW8Bit`.
30
39
  - **Training Speed**: The fused kernel design and reduced synchronization overhead allow it to achieve step times comparable to other mainstream 8-bit optimizers.
31
40
  - **Quantization Precision**: The second moment (variance) in Adafactor is strictly non-negative and spans multiple orders of magnitude. By mapping it to `UINT8` in log2 space rather than linear space, the optimizer preserves relative precision for small variances, mitigating the instability often caused by outlier gradients in standard 8-bit quantization.
32
41
 
33
- ## Installation
42
+ ## 📦 Installation
34
43
 
35
44
  This project uses JIT (Just-In-Time) compilation.
36
45
 
@@ -50,9 +59,12 @@ pip install -U adafactor8bit
50
59
  pip install git+https://github.com/yanfeiwong/adafactor-8bit.git
51
60
  ```
52
61
 
53
- **Note**: The first time you instantiate the optimizer (or run the example script), it will automatically trigger the JIT compilation of the CUDA source code in the background. This may take anywhere from a few seconds to a couple of minutes depending on your system, and the terminal might appear unresponsive. Once compiled, the binary will be cached, and all subsequent runs will be instantaneous.
62
+ > [!IMPORTANT]
63
+ > **First-Time Compilation**: The first time you instantiate the optimizer (or run the example script), it automatically JIT-compiles the CUDA source code. Depending on your system, this can take anywhere from a few seconds to a couple of minutes, and the terminal may appear unresponsive in the meantime. The compiled binary is then cached, so subsequent runs skip compilation and start almost immediately.
54
64
 
55
- ## Quick Start
65
+
66
+
67
+ ## 🚀 Quick Start
56
68
 
57
69
  Using it is as simple as using a standard PyTorch optimizer.
58
70
 
@@ -62,7 +74,8 @@ from adafactor8bit import Adafactor8Bit
62
74
  optimizer = Adafactor8Bit(model.parameters(), lr=1e-3)
63
75
  ```
64
76
 
65
- **💡 Note**: Passing `model.parameters()` directly works for a quick test. In production, `param_groups` are recommended to protect sensitive layers (Norms, Biases) from quantization and weight decay. For **sparse token embeddings** (large vocabularies + small batch sizes), please refer to the [Advanced Example](#advanced-example) to avoid cold-start variance explosion.
77
+ > [!TIP]
78
+ > Passing `model.parameters()` directly works for a quick test. In production, `param_groups` are recommended to protect sensitive layers (Norms, Biases) from quantization and weight decay. For **sparse token embeddings** (large vocabularies + small batch sizes), please refer to the [Advanced Example](#-advanced-example) to avoid cold-start variance explosion.
66
79
 
67
80
 
68
81
  ```python
@@ -96,15 +109,20 @@ optimizer = Adafactor8Bit(
96
109
  # Training loop...
97
110
  ```
98
111
 
99
- ## Advanced Example
112
+ ## 🛠️ Advanced Example
100
113
 
101
- Here we demonstrate a **hybrid grouping** strategy for complex hybrid architectures (e.g., Vision-Language Models, Diffusion UNets) to achieve stable and efficient momentum-free training as much as possible.
114
+ Here we demonstrate a **hybrid grouping** strategy for complex hybrid architectures (e.g., Vision-Language Models, Diffusion UNets) to achieve stable and efficient training.
102
115
 
103
116
  📌 **The following strategies are applied:**
104
- 1. **1D / Sensitive Parameters (Norms, Biases)**: No quantization, no weight decay.
105
- 2. **Embedding Layers**: Combines `factored=False`, `scale_parameter=False`, and `d=1e9` to make the optimization behavior equivalent to a **momentum-free Adam**. Paired with an Adam-style learning rate, this allows for fine-grained, per-token updates while avoiding cold-token interference (global clipping penalties).
106
- 3. **2D Weights (Linear Layers)**: 8-bit quantization, weight decay, using the **APOLLO** path. The continuously switching random subspace projection helps capture comprehensive gradient information and acts as a regularizer.
107
- 4. **>2D Weights (Conv2d, etc.)**: 8-bit quantization, weight decay, **Full-Rank** (`factored=False`). Trades a certain amount of VRAM to preserve complete spatial structures for better optimization outcomes.
117
+ | Layer Type | Strategy |
118
+ |------------|----------|
119
+ | **1D / Sensitive Parameters** (Norms, Biases) | No quantization, no weight decay |
120
+ | **Embedding Layers** | `factored=False`, `scale_parameter=False`, `d=1e9`: behaves as momentum-free Adam. Paired with an Adam-style learning rate, this allows fine-grained, per-token updates while avoiding cold-token interference from global update clipping. |
121
+ | **2D Weights** (Linear Layers) | 8-bit quantization, weight decay, **APOLLO** path. Continuously switching random subspace projection captures comprehensive gradient information and acts as a regularizer. |
122
+ | **>2D Weights** (Conv2d, etc.) | 8-bit quantization, weight decay, **Full-Rank** (`factored=False`). Trades some VRAM to preserve complete spatial structures. |
123
+ | **Momentum (`beta1`)** | Enabled only for dense weight matrices, where the optimization benefit typically outweighs the small memory overhead of the packed 4-bit first moment. Sensitive parameters (Norms/Biases) and sparse Embeddings remain momentum-free. |
124
+
125
+ **Implementation:**
108
126
 
109
127
  ```python
110
128
  from adafactor8bit import Adafactor8Bit
@@ -152,14 +170,21 @@ def get_param_groups(model, lr_emb, weight_decay, apollo_rank=256):
152
170
  },
153
171
 
154
172
  # 3. 2D Weights: 8-bit quantization, Weight Decay, APOLLO low-rank projection
155
- {"params": group_2d, "weight_decay": weight_decay, "quantize": True, "apollo_rank": apollo_rank},
156
-
173
+ {
174
+ "params": group_2d,
175
+ "weight_decay": weight_decay,
176
+ "quantize": True,
177
+ "apollo_rank": apollo_rank,
178
+ "beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
179
+ },
180
+
157
181
  # 4. >2D Weights: 8-bit quantization, Weight Decay, Full-Rank
158
182
  {
159
183
  "params": group_nd,
160
184
  "weight_decay": weight_decay,
161
185
  "quantize": True,
162
186
  "apollo_rank": 0,
187
+ "beta1": 0.9, # Remove if minimizing optimizer memory is the priority.
163
188
  "factored": False # Disables factorization to preserve spatial structures, enabling finer gradient scaling.
164
189
  # Note: This increases state memory for >2D weights, depending on your model architecture.
165
190
  # If VRAM is constrained, reverting to factored=True is a safe alternative.
@@ -179,10 +204,11 @@ optimizer = Adafactor8Bit(
179
204
  # Training loop...
180
205
  ```
181
206
 
182
- For more complete examples, please refer to the [examples folder](https://github.com/yanfeiwong/adafactor-8bit/tree/main/examples).
207
+ > [!NOTE]
208
+ > For more complete examples, please refer to the [examples folder](https://github.com/yanfeiwong/adafactor-8bit/tree/main/examples).
183
209
 
184
210
 
185
- ## Advanced Configuration
211
+ ## ⚙️ Advanced Configuration
186
212
 
187
213
  ### Continual Learning (`beta2` & `relative_step`)
188
214
  By default, Adafactor's second-moment decay rate (`beta2`) is not fixed but follows a schedule tied to the training step, and the internal learning rate schedule (`relative_step`) likewise scales the learning rate with the step count.
@@ -210,19 +236,57 @@ By default, Adafactor factorizes the second moment of $\ge$ 2D tensors into row
210
236
  If you are in an environment without a CUDA compiler and want to bypass JIT compilation entirely:
211
237
  - Set `use_cuda_kernel=False` to fall back to the pure PyTorch implementation.
212
238
 
213
- ## APOLLO Low-Rank Subspace Projection
239
+ ## 🌌 APOLLO Low-Rank Subspace Projection
214
240
  Enable the APOLLO path to compute gradient scaling factors in a memory-efficient low-rank subspace. Compared to Adafactor's standard row/column factorization (which assumes spatial independence), APOLLO uses random subspace projection to capture cross-dimensional covariance information, potentially leading to better generalization while keeping memory overhead extremely low.
215
241
 
216
- - **`apollo_rank`**: The target rank for the projection subspace. The default is `0` (disabled). Setting it to `256` might work well for most 1B to 7B models.
217
- *Note: Setting this to `1` (APOLLO-Mini style) pushes VRAM savings to the limit (saves even more VRAM than the Adafactor path). However, the original APOLLO-Mini relies on Adam's first-moment (beta1) to smooth out noise. Since our implementation uses a pure second-moment architecture, rank=1 may lead to distorted scaling factors and training instability.*
242
+ - **`apollo_rank`**: The target rank for the projection subspace. The default is `0` (disabled).
243
+
244
+ - The official APOLLO GitHub repository recommends a rank of `256` for 1B and 7B models.
245
+ - The [LLaMA-Factory](https://llamafactory.readthedocs.io/en/latest/advanced/arguments.html#apollo) default is `16`.
246
+ - Setting this to `1` (APOLLO-Mini style) maximizes VRAM savings, using even less memory than the Adafactor path. The original APOLLO-Mini relies on the first moment (`beta1`) to smooth out projection noise; to replicate this, set `beta1=0.9` alongside `apollo_rank=1`. Without `beta1`, rank 1 may still work but can produce noisier scaling factors, especially at small batch sizes.
247
+
248
+
218
249
  - **`apollo_scale_type`**: Determines how the scaling factor is applied. `'channel'` applies it per channel (Standard APOLLO), while `'tensor'` applies it globally (APOLLO-Mini).
219
250
  - **`apollo_update_proj_gap`**: Number of steps between projection matrix refreshes. Defaults to `200`. Setting it too small may cause oscillations from frequent, abrupt changes of basis, while setting it too large can leave the projection subspace stale, so that it fails to track the drift of the gradient manifold.
220
251
  - **`apollo_factorize` (Experimental)**: Applies Adafactor's row/column factorization within the low-rank subspace. Mathematically, this leverages the norm-preserving property of random projections to approximate the variance of the primary dimension, while the secondary dimension's variance is estimated across random bases, introducing inherent noise. This dual-compression mechanism drastically reduces optimizer state overhead. Note that for smaller models, the actual VRAM savings might be marginal, and the introduced noise could impact convergence stability. Use with caution.
221
252
  - **Fira Limiter Integration**: The APOLLO path automatically applies the Fira Norm-Growth Limiter to the scaled gradients so that sudden jumps in gradient norm do not cause loss spikes. You can adjust its sensitivity with the global `fira_margin` parameter.
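To make the two `apollo_scale_type` modes and the norm-growth limiter concrete, here is a minimal pure-Python sketch of one APOLLO-style scaling step on a small matrix. It is an illustration under stated assumptions, not this library's CUDA kernel: the projection `P` is Gaussian with shape `(rank, rows)` and scaled by `1/sqrt(rank)`; the subspace state is a second moment only (no first moment) with Adam-style bias correction; and the limiter caps the growth of the update norm at a factor `gamma` per step, following the Fira paper. All helper names are hypothetical, and how `fira_margin` maps onto `gamma` is an assumption. In a real run `P` would be regenerated every `apollo_update_proj_gap` steps; here it stays fixed.

```python
import math
import random
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def frob_norm(m: Matrix) -> float:
    return math.sqrt(sum(x * x for row in m for x in row))


def column_norms(m: Matrix) -> List[float]:
    return [math.sqrt(sum(row[j] ** 2 for row in m)) for j in range(len(m[0]))]


def gaussian_projection(rank: int, rows: int, seed: int = 0) -> Matrix:
    """Random projection of shape (rank, rows), scaled so norms are roughly preserved."""
    rng = random.Random(seed)
    scale = 1.0 / math.sqrt(rank)
    return [[rng.gauss(0.0, 1.0) * scale for _ in range(rows)] for _ in range(rank)]


def apollo_scale(grad: Matrix, proj: Matrix, v: Matrix, step: int,
                 beta2: float = 0.999, eps: float = 1e-8,
                 scale_type: str = "channel") -> Matrix:
    """Scale the full-rank gradient by factors estimated in the low-rank subspace.

    `v` is the (rank, cols) second-moment state and is updated in place.
    """
    r = matmul(proj, grad)  # projected gradient, shape (rank, cols)
    for i, row in enumerate(r):
        for j, x in enumerate(row):
            v[i][j] = beta2 * v[i][j] + (1.0 - beta2) * x * x
    bias = 1.0 - beta2 ** step  # Adam-style bias correction (assumption of this sketch)
    r_tilde = [
        [x / (math.sqrt(v[i][j] / bias) + eps) for j, x in enumerate(row)]
        for i, row in enumerate(r)
    ]
    if scale_type == "channel":
        # One factor per column ("channel"): ||r_tilde[:, j]|| / ||r[:, j]||
        factors = [n / (d + eps) for n, d in zip(column_norms(r_tilde), column_norms(r))]
        return [[g * factors[j] for j, g in enumerate(row)] for row in grad]
    # "tensor": a single factor for the whole matrix (APOLLO-Mini style)
    factor = frob_norm(r_tilde) / (frob_norm(r) + eps)
    return [[g * factor for g in row] for row in grad]


def norm_growth_limit(update: Matrix, prev_norm: Optional[float],
                      gamma: float = 1.01) -> Tuple[Matrix, float]:
    """Fira-style limiter: cap ||update|| at gamma times the previous update norm."""
    norm = frob_norm(update)
    if prev_norm is None or norm <= gamma * prev_norm:
        return update, norm
    shrink = gamma * prev_norm / norm
    return [[x * shrink for x in row] for row in update], gamma * prev_norm
```

The state `v` has shape `(rank, cols)`, so with `rank=1` it collapses to a single row; that is where the APOLLO-Mini memory savings come from, at the cost of estimating all scaling from one random direction.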
222
253
 
254
+ ## 🛡️ CAME Confidence-Guided Updates
255
+
256
+ Enable the CAME (Confidence-guided Adaptive Memory Efficient Optimization) path to add a confidence estimation stage after momentum accumulation:
257
+
258
+ **Adaptive Scaling ($V$) → Momentum Accumulation ($M$) → Confidence Weighting ($C$)**
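In the notation of this pipeline, the factored CAME update for a matrix parameter $\theta \in \mathbb{R}^{n \times m}$ with gradient $G_t$ can be restated from the CAME paper (Algorithm 1) as follows. Squares, square roots, and divisions are element-wise, $\mathbf{1}_k$ is a length-$k$ vector of ones, $d$ is the RMS clipping threshold, and $\epsilon_1, \epsilon_2$ are small constants. This follows the published algorithm; this library's kernels may differ in details such as where the epsilons or the clipping are applied.

$$
\begin{aligned}
r_t &= \beta_2 r_{t-1} + (1-\beta_2)\,(G_t^{2} + \epsilon_1)\,\mathbf{1}_m \\
c_t &= \beta_2 c_{t-1} + (1-\beta_2)\,\mathbf{1}_n^{\top}(G_t^{2} + \epsilon_1) \\
V_t &= r_t c_t \,/\, (\mathbf{1}_n^{\top} r_t) \\
U_t &= G_t / \sqrt{V_t}, \qquad \hat U_t = U_t \,/\, \max\!\big(1, \operatorname{RMS}(U_t)/d\big) \\
M_t &= \beta_1 M_{t-1} + (1-\beta_1)\,\hat U_t \\
D_t &= (\hat U_t - M_t)^{2} \\
a_t &= \beta_3 a_{t-1} + (1-\beta_3)\,(D_t + \epsilon_2)\,\mathbf{1}_m \\
b_t &= \beta_3 b_{t-1} + (1-\beta_3)\,\mathbf{1}_n^{\top}(D_t + \epsilon_2) \\
C_t &= a_t b_t \,/\, (\mathbf{1}_n^{\top} a_t) \\
\theta_t &= \theta_{t-1} - \alpha_t\, M_t / \sqrt{C_t}
\end{aligned}
$$

$D_t$ is large where the current normalized update disagrees with the momentum, so $C_t$ grows there and the step $M_t/\sqrt{C_t}$ shrinks; where they agree, the step is allowed to be larger.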
259
+
260
+ ### Key Parameters & Tuning
223
261
 
262
+ The confidence stage measures the consistency between the current update direction and historical momentum, adaptively suppressing highly oscillatory updates.
224
263
 
225
- ## Learning Rate Guide for Beginners
264
+ - **`beta3`**: EMA decay coefficient for the confidence matrix. Requires `beta1` (momentum) and `factored=True`. Mutually exclusive with `apollo_rank`. Defaults to `None` (disabled).
265
+ - **Learning Rate**: The official CAME implementation recommends **0.5–0.9×** the AdamW learning rate (see [official tuning guide](https://github.com/yangluo7/CAME/tree/master#hyper-parameter-tuning)). To use this learning rate in this library, you need to disable Adafactor's scaling and clipping (`scale_parameter=False`, `d=1e9`) to align with the original CAME behavior.
266
+ - **Warmup**: Since the confidence matrix is zero-initialized without bias correction, a learning rate warmup is recommended to safely establish the confidence baseline.
267
+ - **Choosing `beta3`**: `beta3` should generally be larger than `beta2` so the confidence estimate evolves more slowly than the variance estimate. A practical starting range is **0.9995–0.99995** when `beta2=0.999`.
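A back-of-the-envelope model shows why warmup matters. If the underlying statistics were constant, a zero-initialized EMA without bias correction would reach only a fraction $1-\beta^t$ of its steady-state value after $t$ steps. Assuming, as in the official CAME code, that the momentum is also zero-initialized without bias correction, the numerator $M_t$ is damped by $1-\beta_1^t$ while the denominator $\sqrt{C_t}$ is damped by $\sqrt{1-\beta_3^t}$, so early steps are inflated by their ratio. This simplified model ignores that the instability itself depends on $M_t$; the function name is hypothetical.

```python
import math


def early_step_inflation(step: int, beta1: float = 0.9, beta3: float = 0.9999) -> float:
    """Ratio of the step-t update magnitude to its steady-state value.

    Simplified model (an assumption, not the exact dynamics): zero-initialized
    EMAs without bias correction and constant underlying statistics.
    Momentum reaches (1 - beta1**t) of its steady value; the confidence
    denominator sqrt(C_t) reaches sqrt(1 - beta3**t) of its steady value.
    """
    return (1.0 - beta1 ** step) / math.sqrt(1.0 - beta3 ** step)


for t in (1, 10, 100, 1_000, 10_000, 100_000):
    print(f"step={t:>7}  inflation={early_step_inflation(t):8.2f}x")
```

With `beta1=0.9` and `beta3=0.9999`, this ratio peaks at roughly 20x within the first dozen or so steps and decays only over a horizon of about $1/(1-\beta_3)$ steps (10,000 here), which is why a warmup long enough to cover that transient is the safer choice.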
268
+
269
+
270
+ ### Configuration Example
271
+
272
+ To replicate "vanilla" CAME (stripping Adafactor's native modifications), replace the standard 2D APOLLO group in your `param_groups` with the following configuration:
273
+
274
+ ```python
275
+ {
276
+ "params": param_group,
277
+ "lr": lr, # Original CAME recommends 0.5-0.9x AdamW LR
278
+ "weight_decay": weight_decay,
279
+ "quantize": True,
280
+ "beta1": 0.9,
281
+ "beta3": 0.9999, # Enable CAME confidence guidance
282
+ "apollo_rank": 0, # Mutually exclusive with CAME
283
+ "scale_parameter": False, # Disable Adafactor RMS scaling to align with vanilla CAME
284
+ "d": 1e9, # Disable Adafactor global RMS clipping
285
+ "enable_fira_for_adafactor": False, # Disable Fira Limiter to prevent interference with CAME's scaling
286
+ },
287
+ ```
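The constraints listed for `beta3` can be checked before the optimizer is built. The helper below is hypothetical (it is not part of this library's API) and only encodes the rules stated in this section: `beta3` needs `beta1` and `factored=True`, it cannot be combined with a non-zero `apollo_rank`, and it is usually set larger than `beta2`. Treating a missing `factored` key as `True` is an assumption about the library's default.

```python
from typing import Any, Dict, List


def came_group_problems(group: Dict[str, Any]) -> List[str]:
    """Return a list of problems with a CAME param-group dict (empty if none).

    Hypothetical helper encoding the README rules; assumes `factored`
    defaults to True when the key is absent.
    """
    problems: List[str] = []
    beta3 = group.get("beta3")
    if beta3 is None:
        return problems  # CAME disabled, nothing to check
    if group.get("beta1") is None:
        problems.append("beta3 (CAME) requires beta1 (momentum)")
    if not group.get("factored", True):
        problems.append("beta3 (CAME) requires factored=True")
    if group.get("apollo_rank", 0):
        problems.append("beta3 (CAME) cannot be combined with apollo_rank > 0")
    beta2 = group.get("beta2")
    if isinstance(beta2, float) and beta3 <= beta2:
        problems.append("beta3 is usually set larger than beta2")
    return problems
```

Running every dict in `param_groups` through such a check and raising on a non-empty result catches, for example, an `apollo_rank` left over from the APOLLO group you replaced.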
288
+
289
+ ## 📈 Learning Rate Guide for Beginners
226
290
 
227
291
  If you are migrating from optimizers like AdamW, Adafactor's learning rate behavior might feel a bit different. This is mainly due to the `scale_parameter` option.
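As a rough illustration: in the original Adafactor formulation, `scale_parameter` multiplies the step size by the parameter's root-mean-square, floored at a small constant ($\epsilon_2 = 10^{-3}$ in the paper). The sketch below assumes this library applies the same rule to the user-supplied `lr`; the function name is hypothetical.

```python
import math
from typing import Sequence


def effective_step_size(lr: float, param: Sequence[float],
                        scale_parameter: bool = True, eps2: float = 1e-3) -> float:
    """Step size after Adafactor's optional parameter scaling.

    With scale_parameter=True the step is lr * max(eps2, RMS(param));
    otherwise lr is used as-is, which is closer to AdamW behavior.
    """
    if not scale_parameter:
        return lr
    rms = math.sqrt(sum(x * x for x in param) / len(param))
    return lr * max(eps2, rms)


weights = [0.02, -0.02, 0.02, -0.02]  # a layer whose entries have RMS 0.02
print(effective_step_size(1e-3, weights))
print(effective_step_size(1e-3, weights, scale_parameter=False))
```

Under this rule, an `lr` of `1e-3` on weights with RMS 0.02 becomes an effective step of about `2e-5`, so a learning rate tuned for AdamW does not transfer directly when `scale_parameter=True`.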
228
292
 
@@ -238,7 +302,7 @@ If you are migrating from optimizers like AdamW, Adafactor's learning rate behav
238
302
 
239
303
 
240
304
 
241
- ## Acknowledgements
305
+ ## 🎓 Acknowledgements
242
306
 
243
307
  Thanks to **Noam Shazeer** and **Mitchell Stern** for proposing the original Adafactor algorithm in the paper [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235).
244
308
 
@@ -248,14 +312,18 @@ Thanks to **Hanqing Zhu**, **Zhenyu Zhang**, and the team for proposing the appr
248
312
 
249
313
  Thanks to **Xi Chen**, **Kaituo Feng**, and the team for the Norm-Growth Limiter mechanism introduced in [Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623).
250
314
 
315
+ Thanks to **Yang Luo** and the team for proposing the confidence-guided strategy in the paper [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047).
316
+
251
317
  Thanks to the **PyTorch team** for providing the foundational Optimizer implementation and the C++ Extension toolchain.
252
318
 
253
- Thanks to the large language models **Qwen** and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization, memory safety mechanisms, and cross-platform compilation pipeline design.
319
+ Thanks to the large language models **Qwen**, **ChatGLM**, and **DeepSeek** for valuable technical discussions and code reviews on CUDA low-level optimization and memory safety mechanisms.
254
320
 
255
- ## Star History
321
+ ## 🏛️ License
256
322
 
257
- [![Star History Chart](https://api.star-history.com/svg?repos=yanfeiwong/adafactor-8bit&type=Date&theme=dark)](https://star-history.com/#yanfeiwong/adafactor-8bit&Date)
323
+ [The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
258
324
 
259
- ## License
325
+ ## ⭐ Star the Project
260
326
 
261
- [The project is released under the MIT License.](https://github.com/yanfeiwong/adafactor-8bit/blob/main/LICENSE)
327
+ If this optimizer has been useful in your work, consider giving the repository a star. It helps others discover the project and supports future development.
328
+
329
+ [![Star History Chart](https://api.star-history.com/svg?repos=yanfeiwong/adafactor-8bit&type=Date&theme=dark)](https://star-history.com/#yanfeiwong/adafactor-8bit&Date)