ista-daslab-optimizers 1.1.6__tar.gz → 1.1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ista_daslab_optimizers-1.1.6/ista_daslab_optimizers.egg-info → ista_daslab_optimizers-1.1.8}/PKG-INFO +32 -10
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/README.md +27 -8
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/__init__.py +2 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +4 -4
- ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/fft_low_rank/dct_adamw.py +351 -0
- ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/fft_low_rank/fft_projector.py +192 -0
- ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/fft_low_rank/trion.py +242 -0
- ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/ista_optimizer/__init__.py +5 -0
- ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/ista_optimizer/ista_optimizer.py +36 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/micro_adam/micro_adam.py +14 -14
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +10 -10
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/tools.py +4 -3
- ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/dct.py +45 -0
- ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/global_cache.py +45 -0
- ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/matrix_storage.py +58 -0
- ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/newton_schulz_triton.py +374 -0
- ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/quantizers.py +71 -0
- ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/schedulers.py +41 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8/ista_daslab_optimizers.egg-info}/PKG-INFO +32 -10
- ista_daslab_optimizers-1.1.8/ista_daslab_optimizers.egg-info/SOURCES.txt +33 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers.egg-info/requires.txt +2 -0
- ista_daslab_optimizers-1.1.8/ista_daslab_optimizers.egg-info/top_level.txt +1 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/pyproject.toml +4 -1
- ista_daslab_optimizers-1.1.6/ista_daslab_optimizers.egg-info/SOURCES.txt +0 -48
- ista_daslab_optimizers-1.1.6/ista_daslab_optimizers.egg-info/top_level.txt +0 -5
- ista_daslab_optimizers-1.1.6/kernels/dense_mfac/dense_mfac.cpp +0 -20
- ista_daslab_optimizers-1.1.6/kernels/dense_mfac/dense_mfac_kernel.cu +0 -216
- ista_daslab_optimizers-1.1.6/kernels/micro_adam/micro_adam.cpp +0 -62
- ista_daslab_optimizers-1.1.6/kernels/micro_adam/micro_adam_asymm_block_quant.cu +0 -64
- ista_daslab_optimizers-1.1.6/kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu +0 -83
- ista_daslab_optimizers-1.1.6/kernels/micro_adam/micro_adam_update.cu +0 -165
- ista_daslab_optimizers-1.1.6/kernels/sparse_mfac/sparse_mfac.cpp +0 -84
- ista_daslab_optimizers-1.1.6/kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu +0 -246
- ista_daslab_optimizers-1.1.6/kernels/sparse_mfac/sparse_mfac_SP_kernel.cu +0 -251
- ista_daslab_optimizers-1.1.6/kernels/sparse_mfac_pruner/sparse_mfac_pruner.cpp +0 -57
- ista_daslab_optimizers-1.1.6/kernels/sparse_mfac_pruner/sparse_mfac_pruner.cu +0 -235
- ista_daslab_optimizers-1.1.6/kernels/tools/tools.cpp +0 -127
- ista_daslab_optimizers-1.1.6/kernels/tools/tools_kernel.cu +0 -315
- ista_daslab_optimizers-1.1.6/kernels/utils.h +0 -125
- ista_daslab_optimizers-1.1.6/setup.py +0 -56
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/LICENSE +0 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/MANIFEST.in +0 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/acdc/__init__.py +0 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/acdc/acdc.py +0 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/acdc/wd_scheduler.py +0 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/dense_mfac/__init__.py +0 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/dense_mfac/dense_mfac.py +0 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/micro_adam/__init__.py +0 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/sparse_mfac/__init__.py +0 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +0 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers.egg-info/dependency_links.txt +0 -0
- {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/setup.cfg +0 -0
--- ista_daslab_optimizers-1.1.6/ista_daslab_optimizers.egg-info/PKG-INFO
+++ ista_daslab_optimizers-1.1.8/PKG-INFO
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: ista_daslab_optimizers
-Version: 1.1.6
+Version: 1.1.8
 Summary: Deep Learning optimizers developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
 Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
 Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
@@ -222,6 +222,9 @@ Requires-Dist: gpustat
 Requires-Dist: timm
 Requires-Dist: einops
 Requires-Dist: psutil
+Requires-Dist: fast-hadamard-transform
+Requires-Dist: ista-daslab-optimizers-cuda
+Dynamic: license-file

 # ISTA DAS Lab Optimization Algorithms Package
 This repository contains optimization algorithms for Deep Learning developed by
@@ -240,6 +243,9 @@ The repository contains code for the following optimizers published by DASLab @
 - **MicroAdam**:
   - paper: [MicroAdam: Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence](https://arxiv.org/abs/2405.15593)
   - official repository: [GitHub](https://github.com/IST-DASLab/MicroAdam)
+- **Trion / DCT-AdamW**:
+  - paper: [FFT-based Dynamic Subspace Selection for Low-Rank Adaptive Optimization of Large Language Models](https://arxiv.org/abs/2505.17967v3)
+  - code: [GitHub](https://github.com/IST-DASLab/ISTA-DASLab-Optimizers/tree/main/ista_daslab_optimizers/fft_low_rank)

 ### Installation
 To use the latest stable version of this repository, you can install via pip:
@@ -261,7 +267,8 @@ source install.sh

 ## How to use optimizers?

-In this repository we provide a minimal working example for CIFAR-10 for optimizers `acdc`, `dense_mfac`, `sparse_mfac` and `micro_adam`:
+In this repository we provide a minimal working example for CIFAR-10 for optimizers `acdc`,
+`dense_mfac`, `sparse_mfac` and `micro_adam`:
 ```shell
 cd examples/cifar10
 OPTIMIZER=micro_adam # or any other optimizer listed above
@@ -291,18 +298,33 @@ optimizer = MicroAdam(
 # Versions summary:

 ---
+- **1.1.8** @ February 5th, 2026:
+  - moved kernels to [ISTA-DASLab-Optimizers-CUDA](https://github.com/IST-DASLab/ISTA-DASLab-Optimizers-CUDA)
+  - building the package after adding a new optimizer that doesn't require CUDA support would require compiling
+    the kernels from scratch, which is time consuming and not needed
+- **1.1.7** @ October 8th, 2025:
+  - added code for `Trion & DCT-AdamW`
 - **1.1.6** @ February 19th, 2025:
-  - do not update the parameters that have `None` gradient in method `update_model` from `tools.py`. This is useful when using M-FAC for models with more than one classification head in the Continual Learning framework.
+  - do not update the parameters that have `None` gradient in method `update_model` from `tools.py`.
+    This is useful when using M-FAC for models with more than one classification head in the Continual Learning framework.
 - **1.1.5** @ February 19th, 2025:
-  - adapted `DenseMFAC` for a model with multiple classification heads for Continual Learning where we have one feature extractor block and a list of classification heads. The issue was related to the model size, which included the feature extractor backbone and all classification heads, but in practice only one classification head will be used for training and inference. This caused some size mismatch errors at runtime in the `DenseCoreMFAC` module because the gradient at runtime had fewer entries than the entire model. When using `DenseMFAC` for such settings, set `optimizer.model_size` to the correct size after calling the constructor and the `DenseCoreMFAC` object will be created automatically in the `step` function.
+  - adapted `DenseMFAC` for a model with multiple classification heads for Continual Learning where
+    we have one feature extractor block and a list of classification heads. The issue was related to
+    the model size, which included the feature extractor backbone and all classification heads, but
+    in practice only one classification head will be used for training and inference. This caused some
+    size mismatch errors at runtime in the `DenseCoreMFAC` module because the gradient at runtime had
+    fewer entries than the entire model. When using `DenseMFAC` for such settings, set `optimizer.model_size`
+    to the correct size after calling the constructor and the `DenseCoreMFAC` object will be created
+    automatically in the `step` function.
 - **1.1.3** @ September 5th, 2024:
   - allow using `SparseCoreMFACwithEF` separately by importing it in `sparse_mfac.__init__.py`
 - **1.1.2** @ August 1st, 2024:
-  - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls
-    (EF) to be integrated into the update to make it dense. Finally, the
-    the expense of another call to `Qinv` and `Q` (and
-
-
+  - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls
+    the fraction of error feedback (EF) to be integrated into the update to make it dense. Finally, the
+    fraction alpha will be discarded from the EF at the expense of another call to `Qinv` and `Q` (and
+    implicitly quantization statistics computation).
+  - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the
+    `update_step` method instead of MicroAdam constructor
 - **1.0.1** @ June 27th, 2024:
   - removed version in dependencies to avoid conflicts with llm-foundry
 - **1.0.0** @ June 20th, 2024:
--- ista_daslab_optimizers-1.1.6/README.md
+++ ista_daslab_optimizers-1.1.8/README.md
@@ -15,6 +15,9 @@ The repository contains code for the following optimizers published by DASLab @
 - **MicroAdam**:
   - paper: [MicroAdam: Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence](https://arxiv.org/abs/2405.15593)
   - official repository: [GitHub](https://github.com/IST-DASLab/MicroAdam)
+- **Trion / DCT-AdamW**:
+  - paper: [FFT-based Dynamic Subspace Selection for Low-Rank Adaptive Optimization of Large Language Models](https://arxiv.org/abs/2505.17967v3)
+  - code: [GitHub](https://github.com/IST-DASLab/ISTA-DASLab-Optimizers/tree/main/ista_daslab_optimizers/fft_low_rank)

 ### Installation
 To use the latest stable version of this repository, you can install via pip:
@@ -36,7 +39,8 @@ source install.sh

 ## How to use optimizers?

-In this repository we provide a minimal working example for CIFAR-10 for optimizers `acdc`, `dense_mfac`, `sparse_mfac` and `micro_adam`:
+In this repository we provide a minimal working example for CIFAR-10 for optimizers `acdc`,
+`dense_mfac`, `sparse_mfac` and `micro_adam`:
 ```shell
 cd examples/cifar10
 OPTIMIZER=micro_adam # or any other optimizer listed above
@@ -66,18 +70,33 @@ optimizer = MicroAdam(
 # Versions summary:

 ---
+- **1.1.8** @ February 5th, 2026:
+  - moved kernels to [ISTA-DASLab-Optimizers-CUDA](https://github.com/IST-DASLab/ISTA-DASLab-Optimizers-CUDA)
+  - building the package after adding a new optimizer that doesn't require CUDA support would require compiling
+    the kernels from scratch, which is time consuming and not needed
+- **1.1.7** @ October 8th, 2025:
+  - added code for `Trion & DCT-AdamW`
 - **1.1.6** @ February 19th, 2025:
-  - do not update the parameters that have `None` gradient in method `update_model` from `tools.py`. This is useful when using M-FAC for models with more than one classification head in the Continual Learning framework.
+  - do not update the parameters that have `None` gradient in method `update_model` from `tools.py`.
+    This is useful when using M-FAC for models with more than one classification head in the Continual Learning framework.
 - **1.1.5** @ February 19th, 2025:
-  - adapted `DenseMFAC` for a model with multiple classification heads for Continual Learning where we have one feature extractor block and a list of classification heads. The issue was related to the model size, which included the feature extractor backbone and all classification heads, but in practice only one classification head will be used for training and inference. This caused some size mismatch errors at runtime in the `DenseCoreMFAC` module because the gradient at runtime had fewer entries than the entire model. When using `DenseMFAC` for such settings, set `optimizer.model_size` to the correct size after calling the constructor and the `DenseCoreMFAC` object will be created automatically in the `step` function.
+  - adapted `DenseMFAC` for a model with multiple classification heads for Continual Learning where
+    we have one feature extractor block and a list of classification heads. The issue was related to
+    the model size, which included the feature extractor backbone and all classification heads, but
+    in practice only one classification head will be used for training and inference. This caused some
+    size mismatch errors at runtime in the `DenseCoreMFAC` module because the gradient at runtime had
+    fewer entries than the entire model. When using `DenseMFAC` for such settings, set `optimizer.model_size`
+    to the correct size after calling the constructor and the `DenseCoreMFAC` object will be created
+    automatically in the `step` function.
 - **1.1.3** @ September 5th, 2024:
   - allow using `SparseCoreMFACwithEF` separately by importing it in `sparse_mfac.__init__.py`
 - **1.1.2** @ August 1st, 2024:
-  - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls
-    (EF) to be integrated into the update to make it dense. Finally, the
-    the expense of another call to `Qinv` and `Q` (and
-
-
+  - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls
+    the fraction of error feedback (EF) to be integrated into the update to make it dense. Finally, the
+    fraction alpha will be discarded from the EF at the expense of another call to `Qinv` and `Q` (and
+    implicitly quantization statistics computation).
+  - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the
+    `update_step` method instead of MicroAdam constructor
 - **1.0.1** @ June 27th, 2024:
   - removed version in dependencies to avoid conflicts with llm-foundry
 - **1.0.0** @ June 20th, 2024:
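The 1.1.5 note in the versions summary above tells users to set `optimizer.model_size` by hand when `DenseMFAC` is applied to a multi-head Continual Learning model. A minimal sketch of that workflow follows; the `DenseMFAC` import path and constructor arguments are illustrative assumptions, and only the `optimizer.model_size` attribute and the lazy creation of `DenseCoreMFAC` inside `step` come from the release note.

```python
# Hypothetical sketch of the workflow from the 1.1.5 release note above.
# The DenseMFAC import path and constructor arguments are assumptions,
# not the exact API; only `optimizer.model_size` comes from the note.
import torch.nn as nn
from ista_daslab_optimizers import DenseMFAC  # assumed import path

backbone = nn.Linear(32, 16)                                   # shared feature extractor
heads = nn.ModuleList([nn.Linear(16, 10) for _ in range(3)])   # one head per task
model = nn.ModuleDict(dict(backbone=backbone, heads=heads))
task_id = 0                                                    # head being trained right now

optimizer = DenseMFAC(model.parameters(), lr=1e-3)  # assumed arguments

# Count only the parameters that actually receive gradients at this stage:
# the backbone plus the single active classification head.
active_size = sum(p.numel() for p in backbone.parameters())
active_size += sum(p.numel() for p in heads[task_id].parameters())

# Per the release note, DenseCoreMFAC is then created lazily inside step()
# using this corrected size instead of the full multi-head model size.
optimizer.model_size = active_size
```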
--- ista_daslab_optimizers-1.1.6/ista_daslab_optimizers/dense_mfac/dense_core_mfac.py
+++ ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/dense_mfac/dense_core_mfac.py
@@ -4,10 +4,10 @@ import numpy as np

 USE_CUDA = True
 try:
-    import
+    import ista_daslab_cuda_dense_mfac
 except Exception as e:
     USE_CUDA = False
-    print('\n\t[WARNING] The module "
+    print('\n\t[WARNING] The module "ista_daslab_cuda_dense_mfac" is not installed, using slower PyTorch implementation!\n')

 class DenseCoreMFAC:
     def __init__(self, grads, dev, gpus, damp=1e-5, create_G=False):
@@ -76,7 +76,7 @@ class DenseCoreMFAC:

         if USE_CUDA:
             diag = torch.diag(torch.full(size=[self.m], fill_value=self.lambd, device=self.dev, dtype=self.dtype))
-            self.coef =
+            self.coef = ista_daslab_cuda_dense_mfac.hinv_setup(tmp, diag)
         else:
             for i in range(max(self.last, 1), self.m):
                 self.coef[i, :i] = tmp[i, :i].matmul(self.coef[:i, :i])
@@ -130,7 +130,7 @@ class DenseCoreMFAC:
         dots = self.compute_scalar_products(x)
         giHix = self.lambd * dots
         if USE_CUDA:
-            giHix =
+            giHix = ista_daslab_cuda_dense_mfac.hinv_mul(self.m, self.giHig, giHix)
         else:
             for i in range(1, self.m):
                 giHix[i:].sub_(self.giHig[i - 1, i:], alpha=giHix[i - 1] / self.denom[i - 1])
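Since 1.1.8 the compiled kernels ship in the separate `ista-daslab-optimizers-cuda` dependency (see the new `Requires-Dist` entries above), and the guarded import in `dense_core_mfac.py` silently falls back to pure PyTorch when the extension is missing. A quick sanity check, a sketch that simply mirrors that guarded import, is:

```python
# Mirrors the guarded import in dense_core_mfac.py above: when the compiled
# extension from ISTA-DASLab-Optimizers-CUDA is importable, the hinv_setup /
# hinv_mul kernels are used; otherwise the slower PyTorch loops run.
try:
    import ista_daslab_cuda_dense_mfac  # noqa: F401  (shipped by ista-daslab-optimizers-cuda)
    print('Dense M-FAC CUDA kernels available')
except Exception:
    print('CUDA kernels not installed; the pure-PyTorch fallback will be used')
```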
--- /dev/null
+++ ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/fft_low_rank/dct_adamw.py
@@ -0,0 +1,351 @@
+import torch
+import torch.distributed as dist
+
+import math
+import numpy as np
+
+from fast_hadamard_transform import hadamard_transform
+from ista_daslab_optimizers.utils.dct import dct3_matrix
+from ista_daslab_optimizers.utils.quantizers import Quantizer4bit, Quantizer8bit
+from ista_daslab_optimizers.fft_low_rank.fft_projector import FFTLowRankProjector
+
+PROJ_DCT = 'dct'
+PROJ_HDM = 'hdm'
+PROJ_RAND_QR = 'rqr'
+
+ALL_PROJ = [
+    PROJ_DCT, # DCT projection
+    PROJ_HDM, # Hadamard projection
+    PROJ_RAND_QR, # Random-QR projection
+]
+
+STATE_M = 'm'
+STATE_V = 'v'
+STATE_Q = 'Q'
+STATE_ID = 'param-id'
+STATE_EF = 'ef'
+# STATE_EF_MIN = 'ef-min-vals'
+# STATE_EF_MAX = 'ef-max-vals'
+STATE_FFT_LRP = 'fft-low-rank-projector'
+STATE_BROADCAST_SOURCE = 'broadcast-src' # the process rank that computes the update for a parameter p will broadcast the parameter p to other workers
+
+
+class DCTAdamW(torch.optim.Optimizer):
+    def __init__(self,
+                 params,
+                 lr,
+                 weight_decay,
+                 rank,
+                 proj,
+                 use_ef=False,
+                 q_ef=False,
+                 distributed=False,
+                 update_proj_gap=1,
+                 rotate_subspace=False,
+                 sim_type='matmul',
+                 ell_norm=1,
+                 max_shape=32_000,
+                 betas=(0.9, 0.999),
+                 eps=1e-8):
+        assert proj in ALL_PROJ
+
+        super().__init__(params, dict(lr=lr, weight_decay=weight_decay))
+
+        self.rank = rank
+        self.proj = proj
+        self.use_ef = use_ef
+        self.q_ef = q_ef
+        self.distributed = distributed
+        self.update_proj_gap = update_proj_gap
+        self.rotate_subspace = rotate_subspace
+        self.sim_type = sim_type
+        self.ell_norm = ell_norm
+        self.max_shape = max_shape # apply low-rank to 2D parameters that have both dimensions smaller than max_shape
+        self.betas = betas
+        self.eps = eps
+
+        self.steps = 0
+        self.is_state_initialized = False
+        self.Q = None # the full transformation matrix (non-truncated, all rows and columns)
+        self.Q_cols_norm = None
+        self.use_theoretical_similarity = (self.ell_norm < 0)
+        self.ell_norm = abs(self.ell_norm)
+
+        if proj == PROJ_DCT:
+            assert sim_type in ['matmul', 'makhoul']
+        else:
+            assert sim_type == 'matmul'
+
+    def setup_Q(self, p):
+        if self.Q is None:
+            size = min(p.shape)
+            if self.proj == PROJ_DCT:
+                Qdct3 = dct3_matrix(size, p.dtype, p.device) # first row is zero
+                if self.sim_type == 'makhoul':
+                    self.Q = Qdct3.t()
+                    print(f'\n\t!!!!! Initialized DCT-2 matrix of size {size} !!!!!\n')
+                elif self.sim_type == 'matmul':
+                    self.Q = Qdct3
+                    print(f'\n\t!!!!! Initialized DCT-3 matrix of size {size} !!!!!\n')
+                else:
+                    raise RuntimeError(f'Unknown sim_type: {self.sim_type}')
+            elif self.proj == PROJ_HDM:
+                self.Q = hadamard_transform(torch.eye(size).to(device=p.device, dtype=p.dtype), scale=1. / math.sqrt(size))
+                print(f'\n\t!!!!! Initialized Hadamard matrix of size {size} !!!!!\n')
+            elif self.proj == PROJ_RAND_QR:
+                random = torch.randn(size, size, dtype=p.dtype, device=p.device)
+                self.Q, _ = torch.linalg.qr(random)
+                del random
+            else:
+                raise RuntimeError(f'Projection {self.proj} is currently not supported!')
+
+            if self.use_theoretical_similarity:
+                self.Q_cols_norm = self.Q.norm(p=self.ell_norm, dim=0)
+
+    def should_compute_update(self, p):
+        """
+        This function returns a boolean indicating whether the update for the parameter p should be computed on the current GPU
+        """
+        state = self.state[p]
+        param_id = state[STATE_ID]
+        return param_id % dist.get_world_size() == dist.get_rank()
+
+    def should_update_projection(self):
+        return self.steps == 1 or self.steps % self.update_proj_gap == 0
+
+    def init_state(self, p):
+        state = self.state[p]
+        if p.ndim == 1: # adam update
+            print(f'Parameter of size {tuple(p.shape)} will receive original AdamW update with state shape {tuple(p.shape)}')
+            state[STATE_M] = torch.zeros_like(p)
+            state[STATE_V] = torch.zeros_like(p)
+        elif p.ndim == 2: # low-rank adam update
+            n, m = p.shape
+            if n >= self.max_shape or m >= self.max_shape: # apply full-rank
+                print(f'Parameter of size {tuple(p.shape)} will receive original AdamW update with state shape {tuple(p.shape)}')
+                state[STATE_M] = torch.zeros_like(p)
+                state[STATE_V] = torch.zeros_like(p)
+            else: # apply low-rank using the DCT transform as orthogonal matrix
+                if n >= m:
+                    low_rank_shape = (n, self.rank)
+                else:
+                    # fix for Llama-3-8B that has a layer of size (1024, 4096)
+                    # fix for Qwen2.5-7B that has a layer of size (512, 3584)
+                    if n in [512, 1024] and m in [3584, 4096]:
+                        low_rank_shape = (n, self.rank)
+                    else:
+                        low_rank_shape = (self.rank, m)
+                # low_rank_shape = (n, self.rank) if n >= m else (self.rank, m)
+                print(f'Parameter of size {tuple(p.shape)} will receive low-rank update with state shape {low_rank_shape}')
+                state[STATE_M] = torch.zeros(*low_rank_shape, dtype=p.dtype, device=p.device)
+                state[STATE_V] = torch.zeros(*low_rank_shape, dtype=p.dtype, device=p.device)
+                state[STATE_FFT_LRP] = FFTLowRankProjector(p,
+                                                           rank=self.rank,
+                                                           proj=self.proj,
+                                                           rotate_subspace=self.rotate_subspace,
+                                                           sim_type=self.sim_type,
+                                                           ell_norm=self.ell_norm,
+                                                           use_th_sim=self.use_theoretical_similarity)
+                if self.use_ef:
+                    if self.q_ef > 0:
+                        # state[STATE_EF] = torch.zeros(p.numel() // 2, dtype=torch.uint8, device=p.device)
+                        # state[STATE_EF_MIN] = torch.zeros(p.shape[0], dtype=torch.bfloat16, device=p.device)
+                        # state[STATE_EF_MAX] = torch.zeros(p.shape[0], dtype=torch.bfloat16, device=p.device)
+                        quantClass = {4: Quantizer4bit, 8: Quantizer8bit}[self.q_ef]
+                        if self.q_ef == 4:
+                            quantClass = Quantizer4bit
+                            print(f'\n\t!!!!! Quantizing EF to 4 bits !!!!!\n')
+                        elif self.q_ef == 8:
+                            quantClass = Quantizer8bit
+                            print(f'\n\t!!!!! Quantizing EF to 8 bits !!!!!\n')
+                        else:
+                            raise RuntimeError(f'Quantization on {self.q_ef} bits is currently not supported!')
+                        state[STATE_EF] = quantClass(shape=p.shape, device=p.device, dtype=p.dtype, bucket_size=p.shape[1])
+                    else:
+                        state[STATE_EF] = torch.zeros_like(p)
+
+                ### initialize Q
+                print('calling setup_Q')
+                self.setup_Q(p)
+        # end if
+
+    def init(self):
+        # init broadcast info
+        self.is_state_initialized = True
+        bcast_src_list = []
+        param_id = 0 # parameter id
+        for group in self.param_groups:
+            for p in group['params']:
+                if p is None: continue
+                if p.grad is None: continue
+
+                state = self.state[p]
+                if len(state) == 0:
+                    if self.distributed:
+                        state[STATE_ID] = param_id
+                        param_id += 1
+                        if self.should_compute_update(p):
+                            # if the current process computes the update, then it will also broadcast the parameters to all other workers
+                            state[STATE_BROADCAST_SOURCE] = torch.tensor(dist.get_rank(), dtype=torch.int32, device=f'cuda:{dist.get_rank()}')
+                            self.init_state(p)
+                        else:
+                            # p.register_hook(lambda grad: None) # set gradient to None
+                            # p.requires_grad = False # disable gradient computation for this layer
+                            state[STATE_BROADCAST_SOURCE] = torch.tensor(0, dtype=torch.int32, device=f'cuda:{dist.get_rank()}') # zero means empty here because we will do an all reduce
+                        bcast_src_list.append(state[STATE_BROADCAST_SOURCE].item())
+                    else:
+                        self.init_state(p)
+        # end for group
+
+        if self.distributed:
+            dist.barrier()
+
+            # with open(f'broadcast-{dist.get_rank()}.txt', 'w') as w:
+            # sync broadcast source
+            # w.write(f'Broadcast SRC on worker {dist.get_rank()} before all_reduce: {",".join(map(str, bcast_src_list))}\n')
+            bcast_src_list = []
+            for group in self.param_groups:
+                for p in group['params']:
+                    if p is None: continue
+                    if p.grad is None: continue
+
+                    state = self.state[p]
+                    dist.all_reduce(state[STATE_BROADCAST_SOURCE], op=dist.ReduceOp.SUM)
+                    state[STATE_BROADCAST_SOURCE] = state[STATE_BROADCAST_SOURCE].item()
+                    bcast_src_list.append(state[STATE_BROADCAST_SOURCE])
+            # end for group
+            # w.write(f'Broadcast SRC on worker {dist.get_rank()} after all_reduce: {",".join(map(str, bcast_src_list))}\n')
+            dist.barrier()
+        # end if
+        torch.cuda.empty_cache()
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        self.steps += 1
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        if not self.is_state_initialized:
+            self.init() # init broadcast info
+
+        for group in self.param_groups:
+            lr = group['lr']
+            wd = group['weight_decay']
+
+            for p in group['params']:
+                if p is None: continue
+                if p.grad is None: continue
+
+                if wd > 0:
+                    p.mul_(1 - lr * wd)
+
+                if self.distributed:
+                    if self.should_compute_update(p):
+                        self.update_step(p, lr)
+                else:
+                    self.update_step(p, lr)
+        # end for group
+
+        if self.distributed:
+            for group in self.param_groups:
+                for p in group['params']:
+                    if p is None: continue
+                    if p.grad is None: continue
+
+                    dist.broadcast(p, src=self.state[p][STATE_BROADCAST_SOURCE])
+
+            # end for group
+            dist.barrier() # wait for all GPUs to compute the update for all layers
+        # end if distributed
+        return loss
+
+    @torch.no_grad()
+    def update_step(self, p, lr):
+        if p.ndim == 1: # adam update
+            self.adamw_step(p, lr)
+        elif p.ndim == 2: # low-rank adam update
+            n, m = p.shape
+            if n >= self.max_shape or m >= self.max_shape: # apply full-rank for parameters that have at least one dimension >= max_size (e.g. embeddings and lm_head)
+                self.adamw_step(p, lr)
+            else:
+                self.dct_low_rank_step(p, lr)
+
+    def dct_low_rank_step(self, p, lr):
+        beta1, beta2 = self.betas
+        bc1 = 1 - beta1 ** self.steps
+        sqrt_bc2 = math.sqrt(1 - beta2 ** self.steps)
+        adjusted_lr = -lr * sqrt_bc2 / bc1
+
+        A = p.grad # initially, the accumulator stores gradient and a bit later we will add the error feedback
+        state = self.state[p]
+
+        mt = state[STATE_M]
+        vt = state[STATE_V]
+
+        if self.use_ef:
+            E = state[STATE_EF]
+            if self.q_ef:
+                # see step 4 from Algorithm 1 in the MicroAdam paper https://arxiv.black/pdf/2405.15593
+                A.add_(E.quantize_inv()) # p.grad += Qinv(EF)
+            else:
+                A.add_(E)
+
+        clrp: FFTLowRankProjector = state[STATE_FFT_LRP]
+        clrp.inc_step()
+
+        if self.should_update_projection():
+            a = clrp.change_subspace(self.Q, A, col_norms=self.Q_cols_norm)
+        else:
+            ### compute low-rank accumulator a
+            a = clrp.from_higher_to_lower_dimensions(self.Q, A)
+
+        if self.use_ef:
+            A_reconstructed = clrp.from_lower_to_higher_dimensions(self.Q, a)
+            if self.q_ef:
+                A.sub_(A_reconstructed) # the full precision EF is stored now in A
+                # see step 8 from Algorithm 1 in the MicroAdam paper https://arxiv.black/pdf/2405.15593
+                E.quantize(A)
+            else:
+                E.copy_(A).sub_(A_reconstructed)
+            del A_reconstructed
+
+        ### update momentum m and v (rotate first, if needed)
+        if self.steps > 1 and self.rotate_subspace and self.should_update_projection():
+            R = clrp.get_subspace_rotation_matrix(self.Q)
+            clrp.rotate_subspace(R, mt)
+            clrp.rotate_subspace(R, vt)
+            vt.abs_() # make sure vt is positive
+            del R
+
+        mt.mul_(beta1).add_(a, alpha=1 - beta1)
+        vt.mul_(beta2).addcmul_(a, a, value=1 - beta2)
+
+        u = mt / (self.eps * sqrt_bc2 + vt.sqrt())
+        clrp.from_lower_to_higher_dimensions(self.Q, u, out=p.grad)
+        del u, a
+
+        p.add_(p.grad, alpha=adjusted_lr)
+
+    @torch.no_grad()
+    def adamw_step(self, p, lr):
+        state = self.state[p]
+        g = p.grad
+
+        mt = state[STATE_M]
+        vt = state[STATE_V]
+
+        beta1, beta2 = self.betas
+        bc1 = 1 - beta1 ** self.steps
+        sqrt_bc2 = math.sqrt(1 - beta2 ** self.steps)
+        adjusted_lr = -lr * sqrt_bc2 / bc1
+
+        # update momentum m and v
+        mt.mul_(beta1).add_(g, alpha=1-beta1)
+        vt.mul_(beta2).addcmul_(g, g, value=1-beta2)
+
+        # U = mt / (self.eps * sqrt_bc2 + vt.sqrt())
+        g.copy_(vt).sqrt_().add_(self.eps * sqrt_bc2).div_(mt).reciprocal_()
+        p.add_(g, alpha=adjusted_lr)