ista-daslab-optimizers 1.1.6__tar.gz → 1.1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {ista_daslab_optimizers-1.1.6/ista_daslab_optimizers.egg-info → ista_daslab_optimizers-1.1.8}/PKG-INFO +32 -10
  2. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/README.md +27 -8
  3. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/__init__.py +2 -0
  4. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +4 -4
  5. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/fft_low_rank/dct_adamw.py +351 -0
  6. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/fft_low_rank/fft_projector.py +192 -0
  7. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/fft_low_rank/trion.py +242 -0
  8. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/ista_optimizer/__init__.py +5 -0
  9. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/ista_optimizer/ista_optimizer.py +36 -0
  10. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/micro_adam/micro_adam.py +14 -14
  11. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +10 -10
  12. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/tools.py +4 -3
  13. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/dct.py +45 -0
  14. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/global_cache.py +45 -0
  15. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/matrix_storage.py +58 -0
  16. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/newton_schulz_triton.py +374 -0
  17. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/quantizers.py +71 -0
  18. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/schedulers.py +41 -0
  19. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8/ista_daslab_optimizers.egg-info}/PKG-INFO +32 -10
  20. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers.egg-info/SOURCES.txt +33 -0
  21. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers.egg-info/requires.txt +2 -0
  22. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers.egg-info/top_level.txt +1 -0
  23. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/pyproject.toml +4 -1
  24. ista_daslab_optimizers-1.1.6/ista_daslab_optimizers.egg-info/SOURCES.txt +0 -48
  25. ista_daslab_optimizers-1.1.6/ista_daslab_optimizers.egg-info/top_level.txt +0 -5
  26. ista_daslab_optimizers-1.1.6/kernels/dense_mfac/dense_mfac.cpp +0 -20
  27. ista_daslab_optimizers-1.1.6/kernels/dense_mfac/dense_mfac_kernel.cu +0 -216
  28. ista_daslab_optimizers-1.1.6/kernels/micro_adam/micro_adam.cpp +0 -62
  29. ista_daslab_optimizers-1.1.6/kernels/micro_adam/micro_adam_asymm_block_quant.cu +0 -64
  30. ista_daslab_optimizers-1.1.6/kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu +0 -83
  31. ista_daslab_optimizers-1.1.6/kernels/micro_adam/micro_adam_update.cu +0 -165
  32. ista_daslab_optimizers-1.1.6/kernels/sparse_mfac/sparse_mfac.cpp +0 -84
  33. ista_daslab_optimizers-1.1.6/kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu +0 -246
  34. ista_daslab_optimizers-1.1.6/kernels/sparse_mfac/sparse_mfac_SP_kernel.cu +0 -251
  35. ista_daslab_optimizers-1.1.6/kernels/sparse_mfac_pruner/sparse_mfac_pruner.cpp +0 -57
  36. ista_daslab_optimizers-1.1.6/kernels/sparse_mfac_pruner/sparse_mfac_pruner.cu +0 -235
  37. ista_daslab_optimizers-1.1.6/kernels/tools/tools.cpp +0 -127
  38. ista_daslab_optimizers-1.1.6/kernels/tools/tools_kernel.cu +0 -315
  39. ista_daslab_optimizers-1.1.6/kernels/utils.h +0 -125
  40. ista_daslab_optimizers-1.1.6/setup.py +0 -56
  41. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/LICENSE +0 -0
  42. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/MANIFEST.in +0 -0
  43. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/acdc/__init__.py +0 -0
  44. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/acdc/acdc.py +0 -0
  45. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/acdc/wd_scheduler.py +0 -0
  46. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/dense_mfac/__init__.py +0 -0
  47. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/dense_mfac/dense_mfac.py +0 -0
  48. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/micro_adam/__init__.py +0 -0
  49. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/sparse_mfac/__init__.py +0 -0
  50. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +0 -0
  51. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers.egg-info/dependency_links.txt +0 -0
  52. {ista_daslab_optimizers-1.1.6 → ista_daslab_optimizers-1.1.8}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: ista_daslab_optimizers
- Version: 1.1.6
+ Version: 1.1.8
  Summary: Deep Learning optimizers developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
  Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
  Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
@@ -222,6 +222,9 @@ Requires-Dist: gpustat
  Requires-Dist: timm
  Requires-Dist: einops
  Requires-Dist: psutil
+ Requires-Dist: fast-hadamard-transform
+ Requires-Dist: ista-daslab-optimizers-cuda
+ Dynamic: license-file
 
  # ISTA DAS Lab Optimization Algorithms Package
  This repository contains optimization algorithms for Deep Learning developed by
@@ -240,6 +243,9 @@ The repository contains code for the following optimizers published by DASLab @
  - **MicroAdam**:
    - paper: [MicroAdam: Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence](https://arxiv.org/abs/2405.15593)
    - official repository: [GitHub](https://github.com/IST-DASLab/MicroAdam)
+ - **Trion / DCT-AdamW**:
+   - paper: [FFT-based Dynamic Subspace Selection for Low-Rank Adaptive Optimization of Large Language Models](https://arxiv.org/abs/2505.17967v3)
+   - code: [GitHub](https://github.com/IST-DASLab/ISTA-DASLab-Optimizers/tree/main/ista_daslab_optimizers/fft_low_rank)
 
  ### Installation
  To use the latest stable version of this repository, you can install via pip:
@@ -261,7 +267,8 @@ source install.sh
 
  ## How to use optimizers?
 
- In this repository we provide a minimal working example for CIFAR-10 for optimizers `acdc`, `dense_mfac`, `sparse_mfac` and `micro_adam`:
+ In this repository we provide a minimal working example for CIFAR-10 for optimizers `acdc`,
+ `dense_mfac`, `sparse_mfac` and `micro_adam`:
  ```shell
  cd examples/cifar10
  OPTIMIZER=micro_adam # or any other optimizer listed above
@@ -291,18 +298,33 @@ optimizer = MicroAdam(
  # Versions summary:
 
  ---
+ - **1.1.8** @ February 5th, 2026:
+   - moved kernels to [ISTA-DASLab-Optimizers-CUDA](https://github.com/IST-DASLab/ISTA-DASLab-Optimizers-CUDA)
+   - building the package after adding a new optimizer that doesn't require CUDA support would require
+     compiling the kernels from scratch, which is time-consuming and not needed
+ - **1.1.7** @ October 8th, 2025:
+   - added code for `Trion & DCT-AdamW`
  - **1.1.6** @ February 19th, 2025:
-   - do not update the parameters that have `None` gradient in method `update_model` from `tools.py`. This is useful when using M-FAC for models with more than one classification head in the Continual Learning framework
+   - do not update the parameters that have `None` gradient in method `update_model` from `tools.py`.
+     This is useful when using M-FAC for models with more than one classification head in the Continual Learning framework.
  - **1.1.5** @ February 19th, 2025:
-   - adapted `DenseMFAC` for a model with multiple classification heads for Continual Learning where we have one feature extractor block and a list of classification heads. The issue was related to the model size, which included the feature extractor backbone and all classification heads, but in practice only one classification head will be used for training and inference. This caused some size mismatch errors at runtime in the `DenseCoreMFAC` module because the gradient at runtime had fewer entries than the entire model. When using `DenseMFAC` for such settings, set `optimizer.model_size` to the correct size after calling the constructor and the `DenseCoreMFAC` object will be created automatically in the `step` function.
+   - adapted `DenseMFAC` for a model with multiple classification heads for Continual Learning where
+     we have one feature extractor block and a list of classification heads. The issue was related to
+     the model size, which included the feature extractor backbone and all classification heads, but
+     in practice only one classification head will be used for training and inference. This caused some
+     size mismatch errors at runtime in the `DenseCoreMFAC` module because the gradient at runtime had
+     fewer entries than the entire model. When using `DenseMFAC` for such settings, set `optimizer.model_size`
+     to the correct size after calling the constructor and the `DenseCoreMFAC` object will be created
+     automatically in the `step` function.
  - **1.1.3** @ September 5th, 2024:
    - allow using `SparseCoreMFACwithEF` separately by importing it in `sparse_mfac.__init__.py`
  - **1.1.2** @ August 1st, 2024:
-   - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls the fraction of error feedback
-     (EF) to be integrated into the update to make it dense. Finally, the fraction alpha will be discarded from the EF at
-     the expense of another call to `Qinv` and `Q` (and implicitly quantization statistics computation).
-   - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the `update_step` method
-     instead of MicroAdam constructor
+   - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls
+     the fraction of error feedback (EF) to be integrated into the update to make it dense. Finally, the
+     fraction alpha will be discarded from the EF at the expense of another call to `Qinv` and `Q` (and
+     implicitly quantization statistics computation).
+   - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the
+     `update_step` method instead of MicroAdam constructor
  - **1.0.1** @ June 27th, 2024:
    - removed version in dependencies to avoid conflicts with llm-foundry
  - **1.0.0** @ June 20th, 2024:
@@ -15,6 +15,9 @@ The repository contains code for the following optimizers published by DASLab @
  - **MicroAdam**:
    - paper: [MicroAdam: Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence](https://arxiv.org/abs/2405.15593)
    - official repository: [GitHub](https://github.com/IST-DASLab/MicroAdam)
+ - **Trion / DCT-AdamW**:
+   - paper: [FFT-based Dynamic Subspace Selection for Low-Rank Adaptive Optimization of Large Language Models](https://arxiv.org/abs/2505.17967v3)
+   - code: [GitHub](https://github.com/IST-DASLab/ISTA-DASLab-Optimizers/tree/main/ista_daslab_optimizers/fft_low_rank)
 
  ### Installation
  To use the latest stable version of this repository, you can install via pip:
@@ -36,7 +39,8 @@ source install.sh
 
  ## How to use optimizers?
 
- In this repository we provide a minimal working example for CIFAR-10 for optimizers `acdc`, `dense_mfac`, `sparse_mfac` and `micro_adam`:
+ In this repository we provide a minimal working example for CIFAR-10 for optimizers `acdc`,
+ `dense_mfac`, `sparse_mfac` and `micro_adam`:
  ```shell
  cd examples/cifar10
  OPTIMIZER=micro_adam # or any other optimizer listed above
@@ -66,18 +70,33 @@ optimizer = MicroAdam(
  # Versions summary:
 
  ---
+ - **1.1.8** @ February 5th, 2026:
+   - moved kernels to [ISTA-DASLab-Optimizers-CUDA](https://github.com/IST-DASLab/ISTA-DASLab-Optimizers-CUDA)
+   - building the package after adding a new optimizer that doesn't require CUDA support would require
+     compiling the kernels from scratch, which is time-consuming and not needed
+ - **1.1.7** @ October 8th, 2025:
+   - added code for `Trion & DCT-AdamW`
  - **1.1.6** @ February 19th, 2025:
-   - do not update the parameters that have `None` gradient in method `update_model` from `tools.py`. This is useful when using M-FAC for models with more than one classification head in the Continual Learning framework
+   - do not update the parameters that have `None` gradient in method `update_model` from `tools.py`.
+     This is useful when using M-FAC for models with more than one classification head in the Continual Learning framework.
  - **1.1.5** @ February 19th, 2025:
-   - adapted `DenseMFAC` for a model with multiple classification heads for Continual Learning where we have one feature extractor block and a list of classification heads. The issue was related to the model size, which included the feature extractor backbone and all classification heads, but in practice only one classification head will be used for training and inference. This caused some size mismatch errors at runtime in the `DenseCoreMFAC` module because the gradient at runtime had fewer entries than the entire model. When using `DenseMFAC` for such settings, set `optimizer.model_size` to the correct size after calling the constructor and the `DenseCoreMFAC` object will be created automatically in the `step` function.
+   - adapted `DenseMFAC` for a model with multiple classification heads for Continual Learning where
+     we have one feature extractor block and a list of classification heads. The issue was related to
+     the model size, which included the feature extractor backbone and all classification heads, but
+     in practice only one classification head will be used for training and inference. This caused some
+     size mismatch errors at runtime in the `DenseCoreMFAC` module because the gradient at runtime had
+     fewer entries than the entire model. When using `DenseMFAC` for such settings, set `optimizer.model_size`
+     to the correct size after calling the constructor and the `DenseCoreMFAC` object will be created
+     automatically in the `step` function.
  - **1.1.3** @ September 5th, 2024:
    - allow using `SparseCoreMFACwithEF` separately by importing it in `sparse_mfac.__init__.py`
  - **1.1.2** @ August 1st, 2024:
-   - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls the fraction of error feedback
-     (EF) to be integrated into the update to make it dense. Finally, the fraction alpha will be discarded from the EF at
-     the expense of another call to `Qinv` and `Q` (and implicitly quantization statistics computation).
-   - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the `update_step` method
-     instead of MicroAdam constructor
+   - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls
+     the fraction of error feedback (EF) to be integrated into the update to make it dense. Finally, the
+     fraction alpha will be discarded from the EF at the expense of another call to `Qinv` and `Q` (and
+     implicitly quantization statistics computation).
+   - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the
+     `update_step` method instead of MicroAdam constructor
  - **1.0.1** @ June 27th, 2024:
    - removed version in dependencies to avoid conflicts with llm-foundry
  - **1.0.0** @ June 20th, 2024:
@@ -2,3 +2,5 @@ from .acdc import *
  from .micro_adam import *
  from .sparse_mfac import *
  from .dense_mfac import *
+ from .fft_low_rank.trion import Trion
+ from .fft_low_rank.dct_adamw import DCTAdamW
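With these two new top-level exports, `Trion` and `DCTAdamW` become importable from the package root. A minimal, hypothetical smoke test after upgrading (only the import paths are taken from this diff; the `Trion` constructor is not shown here, so it is not instantiated):

```python
# Hypothetical check that the 1.1.8 package-root exports resolve as the __init__.py diff suggests.
from ista_daslab_optimizers import Trion, DCTAdamW

print(Trion.__module__)     # expected: ista_daslab_optimizers.fft_low_rank.trion
print(DCTAdamW.__module__)  # expected: ista_daslab_optimizers.fft_low_rank.dct_adamw
```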
@@ -4,10 +4,10 @@ import numpy as np
 
  USE_CUDA = True
  try:
-     import ista_daslab_dense_mfac
+     import ista_daslab_cuda_dense_mfac
  except Exception as e:
      USE_CUDA = False
-     print('\n\t[WARNING] The module "ista_daslab_dense_mfac" is not installed, using slower PyTorch implementation!\n')
+     print('\n\t[WARNING] The module "ista_daslab_cuda_dense_mfac" is not installed, using slower PyTorch implementation!\n')
 
  class DenseCoreMFAC:
      def __init__(self, grads, dev, gpus, damp=1e-5, create_G=False):
@@ -76,7 +76,7 @@ class DenseCoreMFAC:
 
          if USE_CUDA:
              diag = torch.diag(torch.full(size=[self.m], fill_value=self.lambd, device=self.dev, dtype=self.dtype))
-             self.coef = ista_daslab_dense_mfac.hinv_setup(tmp, diag)
+             self.coef = ista_daslab_cuda_dense_mfac.hinv_setup(tmp, diag)
          else:
              for i in range(max(self.last, 1), self.m):
                  self.coef[i, :i] = tmp[i, :i].matmul(self.coef[:i, :i])
@@ -130,7 +130,7 @@ class DenseCoreMFAC:
          dots = self.compute_scalar_products(x)
          giHix = self.lambd * dots
          if USE_CUDA:
-             giHix = ista_daslab_cuda_dense_mfac.hinv_mul(self.m, self.giHig, giHix)
+             giHix = ista_daslab_cuda_dense_mfac.hinv_mul(self.m, self.giHig, giHix)
          else:
              for i in range(1, self.m):
                  giHix[i:].sub_(self.giHig[i - 1, i:], alpha=giHix[i - 1] / self.denom[i - 1])
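Per the 1.1.8 changelog entry, the compiled kernels now live in the separate `ista-daslab-optimizers-cuda` package (see the new `Requires-Dist` lines), and `dense_core_mfac.py` falls back to plain PyTorch when the renamed extension is missing. A minimal sketch, assuming only the module and package names visible in this diff, of how a downstream script could check kernel availability:

```python
# Sketch only: mirrors the guarded import used by DenseCoreMFAC above.
# Assumption: `ista_daslab_cuda_dense_mfac` is provided by the ista-daslab-optimizers-cuda dependency.
try:
    import ista_daslab_cuda_dense_mfac  # compiled hinv_setup / hinv_mul kernels
    HAVE_DENSE_MFAC_KERNELS = True
except ImportError:
    HAVE_DENSE_MFAC_KERNELS = False  # DenseCoreMFAC will use its slower pure-PyTorch loops

print(f'dense M-FAC CUDA kernels available: {HAVE_DENSE_MFAC_KERNELS}')
```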
@@ -0,0 +1,351 @@
+ import torch
+ import torch.distributed as dist
+
+ import math
+ import numpy as np
+
+ from fast_hadamard_transform import hadamard_transform
+ from ista_daslab_optimizers.utils.dct import dct3_matrix
+ from ista_daslab_optimizers.utils.quantizers import Quantizer4bit, Quantizer8bit
+ from ista_daslab_optimizers.fft_low_rank.fft_projector import FFTLowRankProjector
+
+ PROJ_DCT = 'dct'
+ PROJ_HDM = 'hdm'
+ PROJ_RAND_QR = 'rqr'
+
+ ALL_PROJ = [
+     PROJ_DCT, # DCT projection
+     PROJ_HDM, # Hadamard projection
+     PROJ_RAND_QR, # Random-QR projection
+ ]
+
+ STATE_M = 'm'
+ STATE_V = 'v'
+ STATE_Q = 'Q'
+ STATE_ID = 'param-id'
+ STATE_EF = 'ef'
+ # STATE_EF_MIN = 'ef-min-vals'
+ # STATE_EF_MAX = 'ef-max-vals'
+ STATE_FFT_LRP = 'fft-low-rank-projector'
+ STATE_BROADCAST_SOURCE = 'broadcast-src' # the process rank that computes the update for a parameter p will broadcast the parameter p to other workers
+
+
+ class DCTAdamW(torch.optim.Optimizer):
+     def __init__(self,
+                  params,
+                  lr,
+                  weight_decay,
+                  rank,
+                  proj,
+                  use_ef=False,
+                  q_ef=False,
+                  distributed=False,
+                  update_proj_gap=1,
+                  rotate_subspace=False,
+                  sim_type='matmul',
+                  ell_norm=1,
+                  max_shape=32_000,
+                  betas=(0.9, 0.999),
+                  eps=1e-8):
+         assert proj in ALL_PROJ
+
+         super().__init__(params, dict(lr=lr, weight_decay=weight_decay))
+
+         self.rank = rank
+         self.proj = proj
+         self.use_ef = use_ef
+         self.q_ef = q_ef
+         self.distributed = distributed
+         self.update_proj_gap = update_proj_gap
+         self.rotate_subspace = rotate_subspace
+         self.sim_type = sim_type
+         self.ell_norm = ell_norm
+         self.max_shape = max_shape # apply low-rank to 2D parameters that have both dimensions smaller than max_shape
+         self.betas = betas
+         self.eps = eps
+
+         self.steps = 0
+         self.is_state_initialized = False
+         self.Q = None # the full transformation matrix (non-truncated, all rows and columns)
+         self.Q_cols_norm = None
+         self.use_theoretical_similarity = (self.ell_norm < 0)
+         self.ell_norm = abs(self.ell_norm)
+
+         if proj == PROJ_DCT:
+             assert sim_type in ['matmul', 'makhoul']
+         else:
+             assert sim_type == 'matmul'
+
+     def setup_Q(self, p):
+         if self.Q is None:
+             size = min(p.shape)
+             if self.proj == PROJ_DCT:
+                 Qdct3 = dct3_matrix(size, p.dtype, p.device) # first row is zero
+                 if self.sim_type == 'makhoul':
+                     self.Q = Qdct3.t()
+                     print(f'\n\t!!!!! Initialized DCT-2 matrix of size {size} !!!!!\n')
+                 elif self.sim_type == 'matmul':
+                     self.Q = Qdct3
+                     print(f'\n\t!!!!! Initialized DCT-3 matrix of size {size} !!!!!\n')
+                 else:
+                     raise RuntimeError(f'Unknown sim_type: {self.sim_type}')
+             elif self.proj == PROJ_HDM:
+                 self.Q = hadamard_transform(torch.eye(size).to(device=p.device, dtype=p.dtype), scale=1. / math.sqrt(size))
+                 print(f'\n\t!!!!! Initialized Hadamard matrix of size {size} !!!!!\n')
+             elif self.proj == PROJ_RAND_QR:
+                 random = torch.randn(size, size, dtype=p.dtype, device=p.device)
+                 self.Q, _ = torch.linalg.qr(random)
+                 del random
+             else:
+                 raise RuntimeError(f'Projection {self.proj} is currently not supported!')
+
+             if self.use_theoretical_similarity:
+                 self.Q_cols_norm = self.Q.norm(p=self.ell_norm, dim=0)
+
+     def should_compute_update(self, p):
+         """
+         This function returns a boolean indicating whether the update for the parameter p should be computed on the current GPU
+         """
+         state = self.state[p]
+         param_id = state[STATE_ID]
+         return param_id % dist.get_world_size() == dist.get_rank()
+
+     def should_update_projection(self):
+         return self.steps == 1 or self.steps % self.update_proj_gap == 0
+
+     def init_state(self, p):
+         state = self.state[p]
+         if p.ndim == 1: # adam update
+             print(f'Parameter of size {tuple(p.shape)} will receive original AdamW update with state shape {tuple(p.shape)}')
+             state[STATE_M] = torch.zeros_like(p)
+             state[STATE_V] = torch.zeros_like(p)
+         elif p.ndim == 2: # low-rank adam update
+             n, m = p.shape
+             if n >= self.max_shape or m >= self.max_shape: # apply full-rank
+                 print(f'Parameter of size {tuple(p.shape)} will receive original AdamW update with state shape {tuple(p.shape)}')
+                 state[STATE_M] = torch.zeros_like(p)
+                 state[STATE_V] = torch.zeros_like(p)
+             else: # apply low-rank using the DCT transform as orthogonal matrix
+                 if n >= m:
+                     low_rank_shape = (n, self.rank)
+                 else:
+                     # fix for Llama-3-8B that has a layer of size (1024, 4096)
+                     # fix for Qwen2.5-7B that has a layer of size (512, 3584)
+                     if n in [512, 1024] and m in [3584, 4096]:
+                         low_rank_shape = (n, self.rank)
+                     else:
+                         low_rank_shape = (self.rank, m)
+                 # low_rank_shape = (n, self.rank) if n >= m else (self.rank, m)
+                 print(f'Parameter of size {tuple(p.shape)} will receive low-rank update with state shape {low_rank_shape}')
+                 state[STATE_M] = torch.zeros(*low_rank_shape, dtype=p.dtype, device=p.device)
+                 state[STATE_V] = torch.zeros(*low_rank_shape, dtype=p.dtype, device=p.device)
+                 state[STATE_FFT_LRP] = FFTLowRankProjector(p,
+                                                            rank=self.rank,
+                                                            proj=self.proj,
+                                                            rotate_subspace=self.rotate_subspace,
+                                                            sim_type=self.sim_type,
+                                                            ell_norm=self.ell_norm,
+                                                            use_th_sim=self.use_theoretical_similarity)
+                 if self.use_ef:
+                     if self.q_ef > 0:
+                         # state[STATE_EF] = torch.zeros(p.numel() // 2, dtype=torch.uint8, device=p.device)
+                         # state[STATE_EF_MIN] = torch.zeros(p.shape[0], dtype=torch.bfloat16, device=p.device)
+                         # state[STATE_EF_MAX] = torch.zeros(p.shape[0], dtype=torch.bfloat16, device=p.device)
+                         quantClass = {4: Quantizer4bit, 8: Quantizer8bit}[self.q_ef]
+                         if self.q_ef == 4:
+                             quantClass = Quantizer4bit
+                             print(f'\n\t!!!!! Quantizing EF to 4 bits !!!!!\n')
+                         elif self.q_ef == 8:
+                             quantClass = Quantizer8bit
+                             print(f'\n\t!!!!! Quantizing EF to 8 bits !!!!!\n')
+                         else:
+                             raise RuntimeError(f'Quantization on {self.q_ef} bits is currently not supported!')
+                         state[STATE_EF] = quantClass(shape=p.shape, device=p.device, dtype=p.dtype, bucket_size=p.shape[1])
+                     else:
+                         state[STATE_EF] = torch.zeros_like(p)
+
+                 ### initialize Q
+                 print('calling setup_Q')
+                 self.setup_Q(p)
+         # end if
+
+     def init(self):
+         # init broadcast info
+         self.is_state_initialized = True
+         bcast_src_list = []
+         param_id = 0 # parameter id
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p is None: continue
+                 if p.grad is None: continue
+
+                 state = self.state[p]
+                 if len(state) == 0:
+                     if self.distributed:
+                         state[STATE_ID] = param_id
+                         param_id += 1
+                         if self.should_compute_update(p):
+                             # if the current process computes the update, then it will also broadcast the parameters to all other workers
+                             state[STATE_BROADCAST_SOURCE] = torch.tensor(dist.get_rank(), dtype=torch.int32, device=f'cuda:{dist.get_rank()}')
+                             self.init_state(p)
+                         else:
+                             # p.register_hook(lambda grad: None) # set gradient to None
+                             # p.requires_grad = False # disable gradient computation for this layer
+                             state[STATE_BROADCAST_SOURCE] = torch.tensor(0, dtype=torch.int32, device=f'cuda:{dist.get_rank()}') # zero means empty here because we will do an all reduce
+                         bcast_src_list.append(state[STATE_BROADCAST_SOURCE].item())
+                     else:
+                         self.init_state(p)
+         # end for group
+
+         if self.distributed:
+             dist.barrier()
+
+             # with open(f'broadcast-{dist.get_rank()}.txt', 'w') as w:
+             # sync broadcast source
+             # w.write(f'Broadcast SRC on worker {dist.get_rank()} before all_reduce: {",".join(map(str, bcast_src_list))}\n')
+             bcast_src_list = []
+             for group in self.param_groups:
+                 for p in group['params']:
+                     if p is None: continue
+                     if p.grad is None: continue
+
+                     state = self.state[p]
+                     dist.all_reduce(state[STATE_BROADCAST_SOURCE], op=dist.ReduceOp.SUM)
+                     state[STATE_BROADCAST_SOURCE] = state[STATE_BROADCAST_SOURCE].item()
+                     bcast_src_list.append(state[STATE_BROADCAST_SOURCE])
+             # end for group
+             # w.write(f'Broadcast SRC on worker {dist.get_rank()} after all_reduce: {",".join(map(str, bcast_src_list))}\n')
+             dist.barrier()
+         # end if
+         torch.cuda.empty_cache()
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         self.steps += 1
+
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         if not self.is_state_initialized:
+             self.init() # init broadcast info
+
+         for group in self.param_groups:
+             lr = group['lr']
+             wd = group['weight_decay']
+
+             for p in group['params']:
+                 if p is None: continue
+                 if p.grad is None: continue
+
+                 if wd > 0:
+                     p.mul_(1 - lr * wd)
+
+                 if self.distributed:
+                     if self.should_compute_update(p):
+                         self.update_step(p, lr)
+                 else:
+                     self.update_step(p, lr)
+         # end for group
+
+         if self.distributed:
+             for group in self.param_groups:
+                 for p in group['params']:
+                     if p is None: continue
+                     if p.grad is None: continue
+
+                     dist.broadcast(p, src=self.state[p][STATE_BROADCAST_SOURCE])
+
+             # end for group
+             dist.barrier() # wait for all GPUs to compute the update for all layers
+         # end if distributed
+         return loss
+
+     @torch.no_grad()
+     def update_step(self, p, lr):
+         if p.ndim == 1: # adam update
+             self.adamw_step(p, lr)
+         elif p.ndim == 2: # low-rank adam update
+             n, m = p.shape
+             if n >= self.max_shape or m >= self.max_shape: # apply full-rank for parameters that have at least one dimension >= max_size (e.g. embeddings and lm_head)
+                 self.adamw_step(p, lr)
+             else:
+                 self.dct_low_rank_step(p, lr)
+
+     def dct_low_rank_step(self, p, lr):
+         beta1, beta2 = self.betas
+         bc1 = 1 - beta1 ** self.steps
+         sqrt_bc2 = math.sqrt(1 - beta2 ** self.steps)
+         adjusted_lr = -lr * sqrt_bc2 / bc1
+
+         A = p.grad # initially, the accumulator stores gradient and a bit later we will add the error feedback
+         state = self.state[p]
+
+         mt = state[STATE_M]
+         vt = state[STATE_V]
+
+         if self.use_ef:
+             E = state[STATE_EF]
+             if self.q_ef:
+                 # see step 4 from Algorithm 1 in the MicroAdam paper https://arxiv.org/pdf/2405.15593
+                 A.add_(E.quantize_inv()) # p.grad += Qinv(EF)
+             else:
+                 A.add_(E)
+
+         clrp: FFTLowRankProjector = state[STATE_FFT_LRP]
+         clrp.inc_step()
+
+         if self.should_update_projection():
+             a = clrp.change_subspace(self.Q, A, col_norms=self.Q_cols_norm)
+         else:
+             ### compute low-rank accumulator a
+             a = clrp.from_higher_to_lower_dimensions(self.Q, A)
+
+         if self.use_ef:
+             A_reconstructed = clrp.from_lower_to_higher_dimensions(self.Q, a)
+             if self.q_ef:
+                 A.sub_(A_reconstructed) # the full precision EF is stored now in A
+                 # see step 8 from Algorithm 1 in the MicroAdam paper https://arxiv.org/pdf/2405.15593
+                 E.quantize(A)
+             else:
+                 E.copy_(A).sub_(A_reconstructed)
+             del A_reconstructed
+
+         ### update momentum m and v (rotate first, if needed)
+         if self.steps > 1 and self.rotate_subspace and self.should_update_projection():
+             R = clrp.get_subspace_rotation_matrix(self.Q)
+             clrp.rotate_subspace(R, mt)
+             clrp.rotate_subspace(R, vt)
+             vt.abs_() # make sure vt is positive
+             del R
+
+         mt.mul_(beta1).add_(a, alpha=1 - beta1)
+         vt.mul_(beta2).addcmul_(a, a, value=1 - beta2)
+
+         u = mt / (self.eps * sqrt_bc2 + vt.sqrt())
+         clrp.from_lower_to_higher_dimensions(self.Q, u, out=p.grad)
+         del u, a
+
+         p.add_(p.grad, alpha=adjusted_lr)
+
+     @torch.no_grad()
+     def adamw_step(self, p, lr):
+         state = self.state[p]
+         g = p.grad
+
+         mt = state[STATE_M]
+         vt = state[STATE_V]
+
+         beta1, beta2 = self.betas
+         bc1 = 1 - beta1 ** self.steps
+         sqrt_bc2 = math.sqrt(1 - beta2 ** self.steps)
+         adjusted_lr = -lr * sqrt_bc2 / bc1
+
+         # update momentum m and v
+         mt.mul_(beta1).add_(g, alpha=1-beta1)
+         vt.mul_(beta2).addcmul_(g, g, value=1-beta2)
+
+         # U = mt / (self.eps * sqrt_bc2 + vt.sqrt())
+         g.copy_(vt).sqrt_().add_(self.eps * sqrt_bc2).div_(mt).reciprocal_()
+         p.add_(g, alpha=adjusted_lr)
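For orientation, a minimal usage sketch of the new optimizer, built only from the constructor signature and `ALL_PROJ` values shown above; the toy model, hyperparameter values, and CUDA placement are illustrative assumptions, not recommendations from the package:

```python
import torch
from ista_daslab_optimizers import DCTAdamW  # re-exported at the package root in 1.1.8

# Placeholder model; 2D weights below max_shape get the low-rank DCT update, 1D params get plain AdamW.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 256),
).cuda()

optimizer = DCTAdamW(
    model.parameters(),
    lr=1e-3,              # illustrative value
    weight_decay=0.0,
    rank=16,              # low-rank dimension of the projected optimizer state
    proj='dct',           # one of 'dct', 'hdm', 'rqr' (ALL_PROJ above)
    use_ef=False,         # optional error feedback, as in the code above
    distributed=False,    # single-process path; True shards per-parameter updates across ranks
)

x = torch.randn(8, 256, device='cuda')
loss = model(x).square().mean()
loss.backward()
optimizer.step()          # lazily initializes the state and the DCT matrix Q on the first call
optimizer.zero_grad()
```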