ista-daslab-optimizers 1.1.7__tar.gz → 1.1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {ista_daslab_optimizers-1.1.7/ista_daslab_optimizers.egg-info → ista_daslab_optimizers-1.1.8}/PKG-INFO +8 -2
  2. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/README.md +4 -0
  3. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +4 -4
  4. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/fft_low_rank/dct_adamw.py +351 -0
  5. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/fft_low_rank/fft_projector.py +192 -0
  6. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/fft_low_rank/trion.py +242 -0
  7. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/micro_adam/micro_adam.py +14 -14
  8. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +10 -10
  9. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/tools.py +2 -2
  10. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/dct.py +45 -0
  11. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/global_cache.py +45 -0
  12. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/matrix_storage.py +58 -0
  13. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/newton_schulz_triton.py +374 -0
  14. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/quantizers.py +71 -0
  15. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers/utils/schedulers.py +41 -0
  16. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8/ista_daslab_optimizers.egg-info}/PKG-INFO +8 -2
  17. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers.egg-info/SOURCES.txt +33 -0
  18. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers.egg-info/requires.txt +1 -0
  19. ista_daslab_optimizers-1.1.8/ista_daslab_optimizers.egg-info/top_level.txt +1 -0
  20. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/pyproject.toml +2 -1
  21. ista_daslab_optimizers-1.1.7/ista_daslab_optimizers.egg-info/SOURCES.txt +0 -50
  22. ista_daslab_optimizers-1.1.7/ista_daslab_optimizers.egg-info/top_level.txt +0 -5
  23. ista_daslab_optimizers-1.1.7/kernels/dense_mfac/dense_mfac.cpp +0 -20
  24. ista_daslab_optimizers-1.1.7/kernels/dense_mfac/dense_mfac_kernel.cu +0 -216
  25. ista_daslab_optimizers-1.1.7/kernels/micro_adam/micro_adam.cpp +0 -62
  26. ista_daslab_optimizers-1.1.7/kernels/micro_adam/micro_adam_asymm_block_quant.cu +0 -64
  27. ista_daslab_optimizers-1.1.7/kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu +0 -83
  28. ista_daslab_optimizers-1.1.7/kernels/micro_adam/micro_adam_update.cu +0 -165
  29. ista_daslab_optimizers-1.1.7/kernels/sparse_mfac/sparse_mfac.cpp +0 -84
  30. ista_daslab_optimizers-1.1.7/kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu +0 -246
  31. ista_daslab_optimizers-1.1.7/kernels/sparse_mfac/sparse_mfac_SP_kernel.cu +0 -251
  32. ista_daslab_optimizers-1.1.7/kernels/sparse_mfac_pruner/sparse_mfac_pruner.cpp +0 -57
  33. ista_daslab_optimizers-1.1.7/kernels/sparse_mfac_pruner/sparse_mfac_pruner.cu +0 -235
  34. ista_daslab_optimizers-1.1.7/kernels/tools/tools.cpp +0 -127
  35. ista_daslab_optimizers-1.1.7/kernels/tools/tools_kernel.cu +0 -315
  36. ista_daslab_optimizers-1.1.7/kernels/utils.h +0 -125
  37. ista_daslab_optimizers-1.1.7/setup.py +0 -56
  38. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/LICENSE +0 -0
  39. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/MANIFEST.in +0 -0
  40. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/__init__.py +0 -0
  41. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/acdc/__init__.py +0 -0
  42. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/acdc/acdc.py +0 -0
  43. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/acdc/wd_scheduler.py +0 -0
  44. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/dense_mfac/__init__.py +0 -0
  45. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/dense_mfac/dense_mfac.py +0 -0
  46. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/ista_optimizer/__init__.py +0 -0
  47. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/ista_optimizer/ista_optimizer.py +0 -0
  48. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/micro_adam/__init__.py +0 -0
  49. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/sparse_mfac/__init__.py +0 -0
  50. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +0 -0
  51. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/ista_daslab_optimizers.egg-info/dependency_links.txt +0 -0
  52. {ista_daslab_optimizers-1.1.7 → ista_daslab_optimizers-1.1.8}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: ista_daslab_optimizers
- Version: 1.1.7
+ Version: 1.1.8
  Summary: Deep Learning optimizers developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
  Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
  Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
@@ -223,6 +223,8 @@ Requires-Dist: timm
  Requires-Dist: einops
  Requires-Dist: psutil
  Requires-Dist: fast-hadamard-transform
+ Requires-Dist: ista-daslab-optimizers-cuda
+ Dynamic: license-file

  # ISTA DAS Lab Optimization Algorithms Package
  This repository contains optimization algorithms for Deep Learning developed by
@@ -296,6 +298,10 @@ optimizer = MicroAdam(
  # Versions summary:

  ---
+ - **1.1.8** @ February 5th, 2026:
+   - moved kernels to [ISTA-DASLab-Optimizers-CUDA](https://github.com/IST-DASLab/ISTA-DASLab-Optimizers-CUDA)
+   - building the package after adding a new optimizer that doesn't require CUDA support would otherwise require compiling
+     the kernels from scratch, which is time-consuming and not needed
  - **1.1.7** @ October 8th, 2025:
    - added code for `Trion & DCT-AdamW`
  - **1.1.6** @ February 19th, 2025:
@@ -70,6 +70,10 @@ optimizer = MicroAdam(
  # Versions summary:

  ---
+ - **1.1.8** @ February 5th, 2026:
+   - moved kernels to [ISTA-DASLab-Optimizers-CUDA](https://github.com/IST-DASLab/ISTA-DASLab-Optimizers-CUDA)
+   - building the package after adding a new optimizer that doesn't require CUDA support would otherwise require compiling
+     the kernels from scratch, which is time-consuming and not needed
  - **1.1.7** @ October 8th, 2025:
    - added code for `Trion & DCT-AdamW`
  - **1.1.6** @ February 19th, 2025:
@@ -4,10 +4,10 @@ import numpy as np

  USE_CUDA = True
  try:
-     import ista_daslab_dense_mfac
+     import ista_daslab_cuda_dense_mfac
  except Exception as e:
      USE_CUDA = False
-     print('\n\t[WARNING] The module "ista_daslab_dense_mfac" is not installed, using slower PyTorch implementation!\n')
+     print('\n\t[WARNING] The module "ista_daslab_cuda_dense_mfac" is not installed, using slower PyTorch implementation!\n')

  class DenseCoreMFAC:
      def __init__(self, grads, dev, gpus, damp=1e-5, create_G=False):
@@ -76,7 +76,7 @@ class DenseCoreMFAC:

          if USE_CUDA:
              diag = torch.diag(torch.full(size=[self.m], fill_value=self.lambd, device=self.dev, dtype=self.dtype))
-             self.coef = ista_daslab_dense_mfac.hinv_setup(tmp, diag)
+             self.coef = ista_daslab_cuda_dense_mfac.hinv_setup(tmp, diag)
          else:
              for i in range(max(self.last, 1), self.m):
                  self.coef[i, :i] = tmp[i, :i].matmul(self.coef[:i, :i])
@@ -130,7 +130,7 @@ class DenseCoreMFAC:
          dots = self.compute_scalar_products(x)
          giHix = self.lambd * dots
          if USE_CUDA:
-             giHix = ista_daslab_dense_mfac.hinv_mul(self.m, self.giHig, giHix)
+             giHix = ista_daslab_cuda_dense_mfac.hinv_mul(self.m, self.giHig, giHix)
          else:
              for i in range(1, self.m):
                  giHix[i:].sub_(self.giHig[i - 1, i:], alpha=giHix[i - 1] / self.denom[i - 1])
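The renames above illustrate the pattern behind the 1.1.8 kernel split: each CUDA extension is imported defensively, and the optimizer falls back to a pure-PyTorch path when the extension is absent. A minimal sketch of that guarded-import pattern follows; only the module name `ista_daslab_cuda_dense_mfac` and the fallback loop come from the diff above, while the wrapper function itself is illustrative and not part of the package:

```python
import torch

USE_CUDA = True
try:
    import ista_daslab_cuda_dense_mfac  # CUDA kernels, shipped separately since 1.1.8
except Exception:
    USE_CUDA = False
    print('[WARNING] "ista_daslab_cuda_dense_mfac" is not installed, '
          'using the slower PyTorch implementation!')

def hinv_mul(m: int, giHig: torch.Tensor, giHix: torch.Tensor, denom: torch.Tensor) -> torch.Tensor:
    """Illustrative wrapper: dispatch to the fused CUDA kernel when available,
    otherwise run the sequential PyTorch loop shown in the diff above."""
    if USE_CUDA:
        return ista_daslab_cuda_dense_mfac.hinv_mul(m, giHig, giHix)
    for i in range(1, m):
        giHix[i:].sub_(giHig[i - 1, i:], alpha=giHix[i - 1] / denom[i - 1])
    return giHix
```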
@@ -0,0 +1,351 @@
+ import torch
+ import torch.distributed as dist
+
+ import math
+ import numpy as np
+
+ from fast_hadamard_transform import hadamard_transform
+ from ista_daslab_optimizers.utils.dct import dct3_matrix
+ from ista_daslab_optimizers.utils.quantizers import Quantizer4bit, Quantizer8bit
+ from ista_daslab_optimizers.fft_low_rank.fft_projector import FFTLowRankProjector
+
+ PROJ_DCT = 'dct'
+ PROJ_HDM = 'hdm'
+ PROJ_RAND_QR = 'rqr'
+
+ ALL_PROJ = [
+     PROJ_DCT,  # DCT projection
+     PROJ_HDM,  # Hadamard projection
+     PROJ_RAND_QR,  # Random-QR projection
+ ]
+
+ STATE_M = 'm'
+ STATE_V = 'v'
+ STATE_Q = 'Q'
+ STATE_ID = 'param-id'
+ STATE_EF = 'ef'
+ # STATE_EF_MIN = 'ef-min-vals'
+ # STATE_EF_MAX = 'ef-max-vals'
+ STATE_FFT_LRP = 'fft-low-rank-projector'
+ STATE_BROADCAST_SOURCE = 'broadcast-src'  # the process rank that computes the update for a parameter p will broadcast the parameter p to the other workers
+
+
+ class DCTAdamW(torch.optim.Optimizer):
+     def __init__(self,
+                  params,
+                  lr,
+                  weight_decay,
+                  rank,
+                  proj,
+                  use_ef=False,
+                  q_ef=False,
+                  distributed=False,
+                  update_proj_gap=1,
+                  rotate_subspace=False,
+                  sim_type='matmul',
+                  ell_norm=1,
+                  max_shape=32_000,
+                  betas=(0.9, 0.999),
+                  eps=1e-8):
+         assert proj in ALL_PROJ
+
+         super().__init__(params, dict(lr=lr, weight_decay=weight_decay))
+
+         self.rank = rank
+         self.proj = proj
+         self.use_ef = use_ef
+         self.q_ef = q_ef
+         self.distributed = distributed
+         self.update_proj_gap = update_proj_gap
+         self.rotate_subspace = rotate_subspace
+         self.sim_type = sim_type
+         self.ell_norm = ell_norm
+         self.max_shape = max_shape  # apply low-rank to 2D parameters that have both dimensions smaller than max_shape
+         self.betas = betas
+         self.eps = eps
+
+         self.steps = 0
+         self.is_state_initialized = False
+         self.Q = None  # the full transformation matrix (non-truncated, all rows and columns)
+         self.Q_cols_norm = None
+         self.use_theoretical_similarity = (self.ell_norm < 0)
+         self.ell_norm = abs(self.ell_norm)
+
+         if proj == PROJ_DCT:
+             assert sim_type in ['matmul', 'makhoul']
+         else:
+             assert sim_type == 'matmul'
+
+     def setup_Q(self, p):
+         if self.Q is None:
+             size = min(p.shape)
+             if self.proj == PROJ_DCT:
+                 Qdct3 = dct3_matrix(size, p.dtype, p.device)  # first row is zero
+                 if self.sim_type == 'makhoul':
+                     self.Q = Qdct3.t()
+                     print(f'\n\t!!!!! Initialized DCT-2 matrix of size {size} !!!!!\n')
+                 elif self.sim_type == 'matmul':
+                     self.Q = Qdct3
+                     print(f'\n\t!!!!! Initialized DCT-3 matrix of size {size} !!!!!\n')
+                 else:
+                     raise RuntimeError(f'Unknown sim_type: {self.sim_type}')
+             elif self.proj == PROJ_HDM:
+                 self.Q = hadamard_transform(torch.eye(size).to(device=p.device, dtype=p.dtype), scale=1. / math.sqrt(size))
+                 print(f'\n\t!!!!! Initialized Hadamard matrix of size {size} !!!!!\n')
+             elif self.proj == PROJ_RAND_QR:
+                 random = torch.randn(size, size, dtype=p.dtype, device=p.device)
+                 self.Q, _ = torch.linalg.qr(random)
+                 del random
+             else:
+                 raise RuntimeError(f'Projection {self.proj} is currently not supported!')
+
+             if self.use_theoretical_similarity:
+                 self.Q_cols_norm = self.Q.norm(p=self.ell_norm, dim=0)
+
+     def should_compute_update(self, p):
+         """
+         Returns a boolean indicating whether the update for the parameter p should be computed on the current GPU.
+         """
+         state = self.state[p]
+         param_id = state[STATE_ID]
+         return param_id % dist.get_world_size() == dist.get_rank()
+
+     def should_update_projection(self):
+         return self.steps == 1 or self.steps % self.update_proj_gap == 0
+
+     def init_state(self, p):
+         state = self.state[p]
+         if p.ndim == 1:  # adam update
+             print(f'Parameter of size {tuple(p.shape)} will receive original AdamW update with state shape {tuple(p.shape)}')
+             state[STATE_M] = torch.zeros_like(p)
+             state[STATE_V] = torch.zeros_like(p)
+         elif p.ndim == 2:  # low-rank adam update
+             n, m = p.shape
+             if n >= self.max_shape or m >= self.max_shape:  # apply full-rank
+                 print(f'Parameter of size {tuple(p.shape)} will receive original AdamW update with state shape {tuple(p.shape)}')
+                 state[STATE_M] = torch.zeros_like(p)
+                 state[STATE_V] = torch.zeros_like(p)
+             else:  # apply low-rank using the DCT transform as orthogonal matrix
+                 if n >= m:
+                     low_rank_shape = (n, self.rank)
+                 else:
+                     # fix for Llama-3-8B, which has a layer of size (1024, 4096)
+                     # fix for Qwen2.5-7B, which has a layer of size (512, 3584)
+                     if n in [512, 1024] and m in [3584, 4096]:
+                         low_rank_shape = (n, self.rank)
+                     else:
+                         low_rank_shape = (self.rank, m)
+                 # low_rank_shape = (n, self.rank) if n >= m else (self.rank, m)
+                 print(f'Parameter of size {tuple(p.shape)} will receive low-rank update with state shape {low_rank_shape}')
+                 state[STATE_M] = torch.zeros(*low_rank_shape, dtype=p.dtype, device=p.device)
+                 state[STATE_V] = torch.zeros(*low_rank_shape, dtype=p.dtype, device=p.device)
+                 state[STATE_FFT_LRP] = FFTLowRankProjector(p,
+                                                            rank=self.rank,
+                                                            proj=self.proj,
+                                                            rotate_subspace=self.rotate_subspace,
+                                                            sim_type=self.sim_type,
+                                                            ell_norm=self.ell_norm,
+                                                            use_th_sim=self.use_theoretical_similarity)
+                 if self.use_ef:
+                     if self.q_ef > 0:
+                         # state[STATE_EF] = torch.zeros(p.numel() // 2, dtype=torch.uint8, device=p.device)
+                         # state[STATE_EF_MIN] = torch.zeros(p.shape[0], dtype=torch.bfloat16, device=p.device)
+                         # state[STATE_EF_MAX] = torch.zeros(p.shape[0], dtype=torch.bfloat16, device=p.device)
+                         # pick the quantizer class based on the requested bit width
+                         if self.q_ef == 4:
+                             quantClass = Quantizer4bit
+                             print('\n\t!!!!! Quantizing EF to 4 bits !!!!!\n')
+                         elif self.q_ef == 8:
+                             quantClass = Quantizer8bit
+                             print('\n\t!!!!! Quantizing EF to 8 bits !!!!!\n')
+                         else:
+                             raise RuntimeError(f'Quantization on {self.q_ef} bits is currently not supported!')
+                         state[STATE_EF] = quantClass(shape=p.shape, device=p.device, dtype=p.dtype, bucket_size=p.shape[1])
+                     else:
+                         state[STATE_EF] = torch.zeros_like(p)
+
+                 ### initialize Q
+                 print('calling setup_Q')
+                 self.setup_Q(p)
+             # end if
+
+     def init(self):
+         # init broadcast info
+         self.is_state_initialized = True
+         bcast_src_list = []
+         param_id = 0  # parameter id
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p is None: continue
+                 if p.grad is None: continue
+
+                 state = self.state[p]
+                 if len(state) == 0:
+                     if self.distributed:
+                         state[STATE_ID] = param_id
+                         param_id += 1
+                         if self.should_compute_update(p):
+                             # if the current process computes the update, then it will also broadcast the parameter p to all other workers
+                             state[STATE_BROADCAST_SOURCE] = torch.tensor(dist.get_rank(), dtype=torch.int32, device=f'cuda:{dist.get_rank()}')
+                             self.init_state(p)
+                         else:
+                             # p.register_hook(lambda grad: None)  # set gradient to None
+                             # p.requires_grad = False  # disable gradient computation for this layer
+                             state[STATE_BROADCAST_SOURCE] = torch.tensor(0, dtype=torch.int32, device=f'cuda:{dist.get_rank()}')  # zero means empty here because we will do an all-reduce
+                         bcast_src_list.append(state[STATE_BROADCAST_SOURCE].item())
+                     else:
+                         self.init_state(p)
+         # end for group
+
+         if self.distributed:
+             dist.barrier()
+
+             # with open(f'broadcast-{dist.get_rank()}.txt', 'w') as w:
+             # sync broadcast source
+             # w.write(f'Broadcast SRC on worker {dist.get_rank()} before all_reduce: {",".join(map(str, bcast_src_list))}\n')
+             bcast_src_list = []
+             for group in self.param_groups:
+                 for p in group['params']:
+                     if p is None: continue
+                     if p.grad is None: continue
+
+                     state = self.state[p]
+                     dist.all_reduce(state[STATE_BROADCAST_SOURCE], op=dist.ReduceOp.SUM)
+                     state[STATE_BROADCAST_SOURCE] = state[STATE_BROADCAST_SOURCE].item()
+                     bcast_src_list.append(state[STATE_BROADCAST_SOURCE])
+             # end for group
+             # w.write(f'Broadcast SRC on worker {dist.get_rank()} after all_reduce: {",".join(map(str, bcast_src_list))}\n')
+             dist.barrier()
+         # end if
+         torch.cuda.empty_cache()
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         self.steps += 1
+
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         if not self.is_state_initialized:
+             self.init()  # init broadcast info
+
+         for group in self.param_groups:
+             lr = group['lr']
+             wd = group['weight_decay']
+
+             for p in group['params']:
+                 if p is None: continue
+                 if p.grad is None: continue
+
+                 if wd > 0:
+                     p.mul_(1 - lr * wd)
+
+                 if self.distributed:
+                     if self.should_compute_update(p):
+                         self.update_step(p, lr)
+                 else:
+                     self.update_step(p, lr)
+         # end for group
+
+         if self.distributed:
+             for group in self.param_groups:
+                 for p in group['params']:
+                     if p is None: continue
+                     if p.grad is None: continue
+
+                     dist.broadcast(p, src=self.state[p][STATE_BROADCAST_SOURCE])
+
+             # end for group
+             dist.barrier()  # wait for all GPUs to compute the update for all layers
+         # end if distributed
+         return loss
+
+     @torch.no_grad()
+     def update_step(self, p, lr):
+         if p.ndim == 1:  # adam update
+             self.adamw_step(p, lr)
+         elif p.ndim == 2:  # low-rank adam update
+             n, m = p.shape
+             if n >= self.max_shape or m >= self.max_shape:  # apply full-rank for parameters that have at least one dimension >= max_shape (e.g. embeddings and lm_head)
+                 self.adamw_step(p, lr)
+             else:
+                 self.dct_low_rank_step(p, lr)
+
+     def dct_low_rank_step(self, p, lr):
+         beta1, beta2 = self.betas
+         bc1 = 1 - beta1 ** self.steps
+         sqrt_bc2 = math.sqrt(1 - beta2 ** self.steps)
+         adjusted_lr = -lr * sqrt_bc2 / bc1
+
+         A = p.grad  # initially the accumulator stores the gradient; the error feedback is added below
+         state = self.state[p]
+
+         mt = state[STATE_M]
+         vt = state[STATE_V]
+
+         if self.use_ef:
+             E = state[STATE_EF]
+             if self.q_ef:
+                 # see step 4 of Algorithm 1 in the MicroAdam paper https://arxiv.org/pdf/2405.15593
+                 A.add_(E.quantize_inv())  # p.grad += Qinv(EF)
+             else:
+                 A.add_(E)
+
+         clrp: FFTLowRankProjector = state[STATE_FFT_LRP]
+         clrp.inc_step()
+
+         if self.should_update_projection():
+             a = clrp.change_subspace(self.Q, A, col_norms=self.Q_cols_norm)
+         else:
+             ### compute low-rank accumulator a
+             a = clrp.from_higher_to_lower_dimensions(self.Q, A)
+
+         if self.use_ef:
+             A_reconstructed = clrp.from_lower_to_higher_dimensions(self.Q, a)
+             if self.q_ef:
+                 A.sub_(A_reconstructed)  # the full-precision EF is now stored in A
+                 # see step 8 of Algorithm 1 in the MicroAdam paper https://arxiv.org/pdf/2405.15593
+                 E.quantize(A)
+             else:
+                 E.copy_(A).sub_(A_reconstructed)
+             del A_reconstructed
+
+         ### update momenta m and v (rotate first, if needed)
+         if self.steps > 1 and self.rotate_subspace and self.should_update_projection():
+             R = clrp.get_subspace_rotation_matrix(self.Q)
+             clrp.rotate_subspace(R, mt)
+             clrp.rotate_subspace(R, vt)
+             vt.abs_()  # make sure vt is positive
+             del R
+
+         mt.mul_(beta1).add_(a, alpha=1 - beta1)
+         vt.mul_(beta2).addcmul_(a, a, value=1 - beta2)
+
+         u = mt / (self.eps * sqrt_bc2 + vt.sqrt())
+         clrp.from_lower_to_higher_dimensions(self.Q, u, out=p.grad)
+         del u, a
+
+         p.add_(p.grad, alpha=adjusted_lr)
+
+     @torch.no_grad()
+     def adamw_step(self, p, lr):
+         state = self.state[p]
+         g = p.grad
+
+         mt = state[STATE_M]
+         vt = state[STATE_V]
+
+         beta1, beta2 = self.betas
+         bc1 = 1 - beta1 ** self.steps
+         sqrt_bc2 = math.sqrt(1 - beta2 ** self.steps)
+         adjusted_lr = -lr * sqrt_bc2 / bc1
+
+         # update momenta m and v
+         mt.mul_(beta1).add_(g, alpha=1 - beta1)
+         vt.mul_(beta2).addcmul_(g, g, value=1 - beta2)
+
+         # U = mt / (self.eps * sqrt_bc2 + vt.sqrt())
+         g.copy_(vt).sqrt_().add_(self.eps * sqrt_bc2).div_(mt).reciprocal_()
+         p.add_(g, alpha=adjusted_lr)
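The file above defines the full DCT-AdamW optimizer; a minimal single-GPU usage sketch follows. Only the constructor's argument names and the import path come from the source above; the model and the hyper-parameter values are chosen purely for illustration:

```python
import torch
from ista_daslab_optimizers.fft_low_rank.dct_adamw import DCTAdamW

model = torch.nn.Linear(1024, 1024).cuda()

optimizer = DCTAdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.1,
    rank=256,             # low-rank dimension r of the optimizer states
    proj='dct',           # one of 'dct', 'hdm' (Hadamard), 'rqr' (random QR)
    use_ef=True,          # keep an error-feedback buffer for the projection residual
    update_proj_gap=200,  # re-select the r most important DCT columns every 200 steps
)

x = torch.randn(8, 1024, device='cuda')
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()   # the 2D weight gets the low-rank update, the 1D bias gets plain AdamW
optimizer.zero_grad()
```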
@@ -0,0 +1,192 @@
+ import torch
+ import torch.distributed as dist
+ from ista_daslab_optimizers.utils.dct import dct_type2_makhoul
+ from ista_daslab_optimizers.utils.global_cache import GlobalCache
+
+ class FFTLowRankProjector:
+     def __init__(self, p, rank, proj, rotate_subspace, sim_type='matmul', ell_norm=1, use_th_sim=False):
+         assert sim_type in ['matmul', 'makhoul']
+         self.rank = rank
+         self.proj = proj
+         self.rotate_states = rotate_subspace  # allocate indices_prev only if we choose to rotate the subspace
+         self.sim_type = sim_type
+         self.ell_norm = ell_norm
+         self.use_th_sim = use_th_sim
+
+         self.size = None
+         self.indices_crt = None  # the indices for the columns/rows
+         self.indices_prev = None  # the indices for the columns/rows
+         self.is_right_proj = None
+
+         self.steps = 0
+         self.device = f'cuda:{dist.get_rank()}' if dist.is_initialized() else 'cuda:0'
+
+         GlobalCache.init()
+         self._setup(p)
+
+     def _setup(self, p):
+         n, m = p.shape
+         if n >= m:
+             self.is_right_proj = True
+             self.size = min(n, m)
+         else:
+             # fix for Llama-3-8B, which has a layer of size (1024, 4096)
+             # fix for Qwen2.5-7B, which has a layer of size (512, 3584)
+             if n in [512, 1024] and m in [3584, 4096]:
+                 self.is_right_proj = True
+                 self.size = m
+             else:
+                 self.is_right_proj = False
+                 self.size = min(n, m)
+         # self.is_right_proj = (n >= m) or (n < m and self.size == m)
+
+         self.indices_crt = torch.zeros(self.rank, dtype=torch.int32, device=p.device)
+         if self.rotate_states:
+             self.indices_prev = torch.zeros(self.rank, dtype=torch.int32, device=p.device)
+
+     def inc_step(self):
+         self.steps += 1
+
+     def compute_similarity_matmul(self, Q, A):
+         if self.is_right_proj:
+             S = A @ Q
+             norms = S.norm(p=self.ell_norm, dim=0)  # dim=0 computes the norm of each column (over all rows)
+         else:
+             S = Q.T @ A
+             norms = S.norm(p=self.ell_norm, dim=1)  # dim=1 computes the norm of each row (over all columns)
+         return S, norms
+
+     def compute_similarity_makhoul(self, A):
+         if self.is_right_proj:  # R >= C
+             S = dct_type2_makhoul(A)
+             norms = S.norm(p=1, dim=0)  # dim=0 computes the norm of each column (over all rows) to rank columns
+         else:  # R < C
+             S = dct_type2_makhoul(A.T)
+             S = S.T  # account for the transposition of the input, because Makhoul computes the DCT per row by default
+             norms = S.norm(p=1, dim=1)  # dim=1 computes the norm of each row (over all columns) to rank rows
+         return S, norms
+
+     def change_subspace(self, Q, A, col_norms, out=None):
+         """
+         This method computes P = A @ Q or P = Q.T @ A and then ranks the columns/rows of matrix P.
+         Once we determine the most important r indices, we can simply select them directly from P
+         without having to multiply again A @ Q[:, indices] or Q[indices, :] @ A.
+         This way, we save some computation.
+         """
+         # if self.steps == 1 or self.steps % self.update_proj_gap == 0:
+         if self.steps > 1:
+             if self.rotate_states:
+                 self.indices_prev.copy_(self.indices_crt)
+
+         if self.sim_type == 'matmul':
+             S, norms = self.compute_similarity_matmul(Q, A)
+         else:
+             S, norms = self.compute_similarity_makhoul(A)
+
+         if self.use_th_sim:
+             norms.mul_(col_norms)
+
+         indices = torch.topk(
+             input=norms,
+             k=self.rank,
+             sorted=False,
+         ).indices
+
+         self.indices_crt.copy_(indices)
+         del indices, norms
+
+         # if self.sim_type == 'matmul':
+         if out is None:
+             if self.is_right_proj:
+                 return S[:, self.indices_crt]
+             else:
+                 return S[self.indices_crt, :]
+         else:
+             if self.is_right_proj:
+                 out.copy_(S[:, self.indices_crt])
+             else:
+                 out.copy_(S[self.indices_crt, :])
+         # elif self.sim_type == 'makhoul':
+         #     if out is None:
+         #         if self.is_right_proj:
+         #             return S[:, self.indices_crt]
+         #         else:
+         #             return S[:, self.indices_crt].T
+         #     else:
+         #         if self.is_right_proj:
+         #             out.copy_(S[:, self.indices_crt])
+         #         else:
+         #             out.copy_(S[:, self.indices_crt].T)
+         # else:
+         #     raise RuntimeError(f'Unknown similarity sim_type: {self.sim_type}')
+
+     def get_subspace_rotation_matrix(self, Q):
+         assert self.rotate_states, 'The optimizer was not initialized with rotate_subspace=True'
+
+         icrt = self.indices_crt
+         iprev = self.indices_prev
+
+         if self.is_right_proj:
+             return Q[:, iprev].T @ Q[:, icrt]  # (m, r).T @ (m, r) = (r, r)  # with Q from MatrixStorage @ PhD #11, page 44 (same as with Q from optimizer state @ PhD #11, page 47)
+             # return Q[iprev, :] @ Q[icrt, :].T  # (r, m) @ (r, m).T = (r, r)
+         else:
+             # return Q[icrt, :] @ Q[iprev, :].T  # (r, n) @ (r, n).T = (r, r)  # with Q from MatrixStorage @ PhD #11, page 44
+             return Q[:, icrt].T @ Q[:, iprev]  # (n, r).T @ (n, r) = (r, r)  # with Q from optimizer state @ PhD #11, page 47
+             # return Q[:, icrt].T @ Q[:, iprev]  # (n, r).T @ (n, r) = (r, r)
+
+     def rotate_subspace(self, R, w):
+         assert self.rotate_states, 'The optimizer was not initialized with rotate_subspace=True'
+         if self.is_right_proj:
+             torch.matmul(w, R, out=w)
+         else:
+             torch.matmul(R, w, out=w)
+
+     def from_higher_to_lower_dimensions(self, Q, X):
+         # Q = MatrixStorage.get_matrix(self.size, self.proj, X.dtype, transpose=not self.is_right_proj)
+
+         icrt = self.indices_crt
+
+         if self.is_right_proj:
+             return X @ Q[:, icrt]  # (n, m) @ (m, r) = (n, r)
+         else:
+             # return Q[icrt, :] @ X  # (r, n) @ (n, m) = (r, m)  # with Q from MatrixStorage @ PhD #11, page 44
+             return Q[:, icrt].T @ X  # (n, r).T @ (n, m) = (r, m)  # with Q from optimizer state @ PhD #11, page 47
+
+     def from_lower_to_higher_dimensions(self, Q, x, out=None):
+         # Q = MatrixStorage.get_matrix(self.size, self.proj, x.dtype, transpose=not self.is_right_proj)
+         icrt = self.indices_crt
+
+         if self.is_right_proj:
+             # (n, r) @ (m, r).T = (n, m)
+             if out is None:
+                 return x @ Q[:, icrt].T
+             else:
+                 torch.matmul(x, Q[:, icrt].T, out=out)
+         else:
+             # (n, r) @ (r, m) = (n, m)
+             if out is None:
+                 # return Q[icrt, :].T @ x  # with Q from MatrixStorage @ PhD #11, page 44
+                 return Q[:, icrt] @ x  # with Q from optimizer state @ PhD #11, page 47
+             else:
+                 # torch.matmul(Q[icrt, :].T, x, out=out)  # with Q from MatrixStorage @ PhD #11, page 44
+                 torch.matmul(Q[:, icrt], x, out=out)  # with Q from optimizer state @ PhD #11, page 47
+
+     # if self.strategy == STRATEGY_FIRST:
+     #     self.indices_crt.copy_(torch.arange(start=0, end=self.rank, dtype=torch.int32, device=self.device))
+     # elif self.strategy == STRATEGY_RANDOM:
+     #     self.indices_crt.copy_(torch.randperm(n=self.size, dtype=torch.int32, device=self.device)[:self.rank])
+     # elif self.strategy == STRATEGY_WINDOW:
+     #     """
+     #     For size=5, range2x will contain [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
+     #     For rank=3, the following indices will be generated:
+     #         step = 1: [0, 1, 2]
+     #         step = 2: [1, 2, 3]
+     #         step = 3: [2, 3, 4]
+     #         step = 4: [3, 4, 0]
+     #         step = 5: [4, 0, 1]
+     #         step = 6: [0, 1, 2]  # the same indices as for step 1 (the indices repeat every `size` steps)
+     #     """
+     #     range2x = torch.arange(self.size, dtype=torch.int32, device=self.device).repeat(1, 2).view(-1)
+     #     start = self.steps % self.size
+     #     self.indices_crt.copy_(range2x[start:start+self.rank])  # rank indices, starting at "start"
+     #     del range2x
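The 'makhoul' similarity path calls `dct_type2_makhoul` from `ista_daslab_optimizers/utils/dct.py`, which is listed in the file summary above but whose content is not displayed in this diff. For reference, a DCT-II can be computed with a single FFT following Makhoul (1980); a minimal sketch of that construction is shown below. This is an assumption about the technique only: the package's actual implementation may scale or normalize its output differently.

```python
import math
import torch

def dct2_makhoul(x: torch.Tensor) -> torch.Tensor:
    """Unnormalized DCT-II along the last dimension using one FFT (Makhoul, 1980).
    A sketch only; dct_type2_makhoul in utils/dct.py is not shown in this diff."""
    N = x.shape[-1]
    # Reorder: even-indexed entries in order, then odd-indexed entries reversed.
    v = torch.cat([x[..., 0::2], x[..., 1::2].flip(-1)], dim=-1)
    V = torch.fft.fft(v, dim=-1)
    k = torch.arange(N, device=x.device)
    phase = torch.exp(-1j * math.pi * k / (2 * N))  # half-sample phase shift
    return 2.0 * (V * phase).real
```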