ista_daslab_optimizers-1.1.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. ista_daslab_optimizers/__init__.py +6 -0
  2. ista_daslab_optimizers/acdc/__init__.py +5 -0
  3. ista_daslab_optimizers/acdc/acdc.py +387 -0
  4. ista_daslab_optimizers/acdc/wd_scheduler.py +31 -0
  5. ista_daslab_optimizers/dense_mfac/__init__.py +5 -0
  6. ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +164 -0
  7. ista_daslab_optimizers/dense_mfac/dense_mfac.py +93 -0
  8. ista_daslab_optimizers/fft_low_rank/dct_adamw.py +351 -0
  9. ista_daslab_optimizers/fft_low_rank/fft_projector.py +192 -0
  10. ista_daslab_optimizers/fft_low_rank/trion.py +242 -0
  11. ista_daslab_optimizers/ista_optimizer/__init__.py +5 -0
  12. ista_daslab_optimizers/ista_optimizer/ista_optimizer.py +36 -0
  13. ista_daslab_optimizers/micro_adam/__init__.py +5 -0
  14. ista_daslab_optimizers/micro_adam/micro_adam.py +402 -0
  15. ista_daslab_optimizers/sparse_mfac/__init__.py +7 -0
  16. ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +226 -0
  17. ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +87 -0
  18. ista_daslab_optimizers/tools.py +218 -0
  19. ista_daslab_optimizers/utils/dct.py +45 -0
  20. ista_daslab_optimizers/utils/global_cache.py +45 -0
  21. ista_daslab_optimizers/utils/matrix_storage.py +58 -0
  22. ista_daslab_optimizers/utils/newton_schulz_triton.py +374 -0
  23. ista_daslab_optimizers/utils/quantizers.py +71 -0
  24. ista_daslab_optimizers/utils/schedulers.py +41 -0
  25. ista_daslab_optimizers-1.1.8.dist-info/METADATA +333 -0
  26. ista_daslab_optimizers-1.1.8.dist-info/RECORD +29 -0
  27. ista_daslab_optimizers-1.1.8.dist-info/WHEEL +5 -0
  28. ista_daslab_optimizers-1.1.8.dist-info/licenses/LICENSE +201 -0
  29. ista_daslab_optimizers-1.1.8.dist-info/top_level.txt +1 -0
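
Before the file contents, a brief orientation: the headline optimizer in this release is MicroAdam (micro_adam.py below), a memory-efficient Adam variant that keeps a ring buffer of m Top-K-compressed gradients plus a block-quantized error-feedback buffer. The snippet below is an editorial usage sketch, not taken from the package: the constructor signature matches micro_adam.py, while the model, data, and hyperparameter values are placeholders, and it assumes the compiled CUDA extensions (ista_daslab_cuda_tools, ista_daslab_cuda_micro_adam) are installed and a CUDA device is available.

# Editorial usage sketch (not part of the package).
import torch
from ista_daslab_optimizers.micro_adam.micro_adam import MicroAdam

model = torch.nn.Linear(1024, 1024).cuda()   # placeholder model
optimizer = MicroAdam(
    model.parameters(),
    m=10,                 # size of the ring buffer of compressed gradients
    lr=1e-5,              # placeholder learning rate
    quant_block_size=64,  # block size for the quantized error-feedback buffer
    k_init=0.01,          # keep ~1% of each gradient block after Top-K
)

for _ in range(10):
    batch = torch.randn(8, 1024, device='cuda')  # placeholder data
    loss = model(batch).square().mean()          # placeholder loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()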
ista_daslab_optimizers/micro_adam/micro_adam.py
@@ -0,0 +1,402 @@
+ import os
+ import torch
+ import math
+ import time
+ import wandb
+ from torch.distributed import is_initialized, get_rank, all_reduce, ReduceOp
+ from ..tools import get_first_device, get_gpu_mem_usage, block_split, CopyDirection
+
+ import ista_daslab_cuda_tools
+ import ista_daslab_cuda_micro_adam
+
+
+ class MicroAdam(torch.optim.Optimizer):
+     def __init__(self, params, m, lr, quant_block_size, k_init=0.01, alpha=0, betas=(0.9, 0.999), weight_decay=0, eps=1e-8):
+         defaults = dict(lr=lr, weight_decay=weight_decay, eps=eps, alpha=alpha)
+         super(MicroAdam, self).__init__(params, defaults)
+
+         assert (0 <= alpha < 1) or alpha == -2, 'Alpha must be in the [0, 1) interval or -2'
+
+         self.m = m
+         self.lr = lr
+         self.quant_block_size = int(quant_block_size)
+         self.k_init = k_init
+         self.alpha = alpha
+         self.weight_decay = weight_decay
+         self.beta1 = betas[0]
+         self.beta2 = betas[1]
+         self.eps = eps
+
+         self.densify_update_using_ef = (self.alpha > 0)
+         self.densify_update_using_quant_error = (self.alpha == -2)
+
+         self.model_size = sum([p.numel() for group in self.param_groups for p in group['params']])
+
+         self.steps = 0  # how many optimization steps were performed so far
+         self.log_interval = 100
+         self.device = get_first_device()
+         self._is_state_initialized = False
+         self.shared_memory_carveout = 100
+         self.blocks = ista_daslab_cuda_tools.get_sm_count() * int(100 / self.shared_memory_carveout)
+         self.threads = 512
+
+         self.max_floats = ista_daslab_cuda_tools.get_max_floats_for_shared_memory_per_thread_block()
+         self.d_block_size = self.max_floats // 2 // int(100 / self.shared_memory_carveout)
+
+         if torch.distributed.is_initialized():
+             # key = layer size, value = how many layers of that size the model has (per worker)
+             self.fsdp_dict_size_count = [{} for _ in range(torch.distributed.get_world_size())]
+         else:
+             self.fsdp_dict_size_count = [{}]
+
+         self.dict_size_count = {}  # key = layer size, value = how many layers of that size the model has
+         for param in self.param_groups:
+             for p in param['params']:
+                 size = p.numel()
+                 # print(p.shape, p.numel())
+                 self.dict_size_count[size] = 1 + self.dict_size_count.get(size, 0)
+
+         # self._init_state()
+
+     def _initialize_parameter_state(self, p, lr, wd):
+         layer_size = p.numel()
+         st = self.state[p]
+
+         rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+
+         if self.densify_update_using_quant_error:
+             st['quant_err'] = torch.zeros_like(p)
+
+         st['blocks'] = max(1, int(math.floor(self.blocks * layer_size * self.fsdp_dict_size_count[rank][layer_size] / self.model_size)))
+
+         st['lr'] = lr
+         st['weight_decay'] = wd
+         st['d'] = layer_size
+
+         ##### variables for Top-K: d_index_topk is the index where the last, smaller topk block starts
+         st['d_block_size'] = layer_size if layer_size < self.d_block_size else self.d_block_size
+         st['topk_full_blocks_count'], st['d_index_topk'] = block_split(st['d'], st['d_block_size'])
+         st['k_block_size_many'] = int(math.ceil(st['d_block_size'] * self.k_init))
+         st['k_block_size_few'] = int(math.ceil((st['d'] - st['d_index_topk']) * self.k_init))  # 0 for d % self.d_block_size = 0
+         st['k_index'] = st['topk_full_blocks_count'] * st['k_block_size_many']
+         st['k'] = st['k_block_size_many'] * st['topk_full_blocks_count'] + st['k_block_size_few']
+
+         ##### variables for the ring buffer
+         st['index'] = 0  # the position to place a new gradient at
+         st['I'] = torch.zeros(self.m, st['k'], dtype=torch.int16, device=self.device)  # 2mk bytes
+         st['V'] = torch.zeros(self.m, st['k'], dtype=torch.bfloat16, device=self.device)  # 2mk bytes
+
+         ### variables for error feedback: d_index_quant is the index where the last, smaller quantization block starts
+         # st['quant_block_size'] = layer_size if layer_size < self.quant_block_size else self.quant_block_size
+         st['quant_full_blocks_count'], st['d_index_quant'] = block_split(st['d'], self.quant_block_size)
+         st['error'] = torch.zeros(int(math.ceil(st['d'] / 2)), dtype=torch.uint8, device=self.device)  # ceil(d/2) bytes
+         st['min_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device)  # ceil(d/q_bsz)*2 bytes
+         st['max_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device)  # ceil(d/q_bsz)*2 bytes
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         self.steps += 1
+
+         # self._update_lr_wd()
+
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         if self.steps == 1:
+             rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+             for param in self.param_groups:
+                 for p in param['params']:
+                     if p is not None:
+                         size = p.numel()
+                         if size > 0:
+                             self.fsdp_dict_size_count[rank][size] = 1 + self.fsdp_dict_size_count[rank].get(size, 0)
+
+         time_start = time.time()
+
+         norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe = 0, 0, 0, 0, 0, 0
+
+         for group in self.param_groups:
+             lr = group['lr']
+             wd = group.get('weight_decay', self.weight_decay)
+
+             for p in group['params']:
+                 if p is None:
+                     continue
+
+                 if p.grad is None:
+                     continue
+
+                 nqe, ng, nu, ne, sp_u, sp_qe = self.update_step(p, lr, wd)
+                 norm_qe += nqe
+                 norm_g += ng
+                 norm_u += nu
+                 norm_e += ne
+                 sparsity_u += sp_u
+                 sparsity_qe += sp_qe
+
+         # torch.cuda.synchronize()
+         time_end = time.time()
+         elapsed_step = time_end - time_start
+         self._log(norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe, elapsed_step)
+
+         return loss
+
+     @torch.no_grad()
+     def update_step(self, p, lr, wd):
+         norm_qe, norm_g, norm_u, norm_e, sp_u, sp_qe = 0, 0, 0, 0, 0, 0
+
+         # if p.grad.dtype != torch.bfloat16:
+         #     grad = p.grad.to(dtype=torch.bfloat16).reshape(-1)
+         # else:
+         grad = p.grad.view(-1)
+
+         if self.steps % self.log_interval == 0:
+             norm_g = grad.norm(p=2) ** 2
+
+         st = self.state[p]
+         if len(st) == 0:
+             self._initialize_parameter_state(p, lr, wd)
+
+         # print('rank=', torch.distributed.get_rank(), 'keys=', st.keys())
+
+         blocks = st['blocks']
+         # lr = st['lr']
+         # wd = st['weight_decay']
+         d = st['d']
+         d_block_size = st['d_block_size']
+         topk_full_blocks_count, d_index_topk = st['topk_full_blocks_count'], st['d_index_topk']
+         k_block_size_many = st['k_block_size_many']
+         k_block_size_few = st['k_block_size_few']
+         k_index = st['k_index']
+         k = st['k']
+
+         # HuggingFace has a setting that converts st['I'] to bfloat16, even though it is declared as int16.
+         # This happens somewhere between the constructor call and the step call. Converting it back to int16 was the simplest solution.
+         if st['I'].dtype != torch.int16:
+             st['I'] = st['I'].to(torch.int16)
+
+         index = st['index']
+         I = st['I']
+         V = st['V']
+
+         quant_full_blocks_count, d_index_quant = st['quant_full_blocks_count'], st['d_index_quant']
+         error = st['error']
+         min_vals = st['min_vals']
+         max_vals = st['max_vals']
+
+         ##### STEP 4
+         ista_daslab_cuda_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1.0)  # alpha=1 here
+
+         ##### STEP 5 + 9 (only for I)
+         I[index, :k_index] = torch.topk(input=grad[0:d_index_topk].abs().view(topk_full_blocks_count, d_block_size),
+                                         k=k_block_size_many,
+                                         sorted=False).indices.to(dtype=torch.int16).view(-1)
+
+         if k_block_size_few > 0:  # there is a small block left
+             I[index, k_index:] = torch.topk(input=grad[d_index_topk:].abs(),
+                                             k=k_block_size_few,  # example: slice has size 1, but ks[-1] is 4
+                                             sorted=False).indices.to(dtype=torch.int16).view(-1)
+
+         ista_daslab_cuda_tools.copy_values(d,  # V = error[I[buffer_index, :]]
+                                            k,
+                                            d_block_size,
+                                            k_block_size_many,
+                                            I[index, :],  # indices
+                                            grad,  # inp
+                                            V[index, :],  # output
+                                            CopyDirection.d2k.value)
+
+         st['index'] = (index + 1) % self.m
+
+         ##### STEP 6
+         ista_daslab_cuda_tools.zerorize_block_components(grad, I[index, :], d, k, d_block_size, k_block_size_many)  # this does a[I[index]] = 0
+
+         ##### STEP 7
+         def _update_quantization_statistics():
+             if quant_full_blocks_count == 1:
+                 min_vals[:quant_full_blocks_count] = grad[:d_index_quant].min()
+                 max_vals[:quant_full_blocks_count] = grad[:d_index_quant].max()
+             else:
+                 min_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).min(dim=1).values
+                 max_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).max(dim=1).values
+             if d_index_quant < d:
+                 min_vals[quant_full_blocks_count] = grad[d_index_quant:].min()
+                 max_vals[quant_full_blocks_count] = grad[d_index_quant:].max()
+
+         _update_quantization_statistics()
+
+         ##### STEP 8
+         ista_daslab_cuda_micro_adam.asymm_block_quant(d, self.quant_block_size, error, min_vals, max_vals, grad)  # error = Q(a, min, max)
+
+         # weight decay step
+         if wd > 0:
+             p.mul_(1 - lr * wd)
+
+         ##### NEW: densify using quant error
+         if self.densify_update_using_quant_error:
+             # When entering this if-statement, we have:
+             # - p is theta_t
+             # - p.grad is a_t (from step 6 in algorithm 1)
+             # - error is e_t+1 (from step 8 in algorithm 1)
+             #
+             # Below we have the formula to update the model parameters:
+             # [a = -1] with lr
+             #     theta_t+1 = theta_t - lr * (a_t - Qinv(e_t+1)) - lr * u_t
+             #               = theta_t - lr * a_t + lr * Qinv(e_t+1) - lr * u_t
+             #               = theta_t - lr * a_t          # STEP A below, in this if statement
+             #                         + lr * Qinv(e_t+1)  # STEP B below, in this if statement
+             #                         - lr * u_t          # this is steps 10-11
+             #
+             # [a = -2] without lr
+             #     theta_t+1 = theta_t - (a_t - Qinv(e_t+1)) - lr * u_t
+             #               = theta_t - a_t + Qinv(e_t+1) - lr * u_t
+             #               = theta_t - a_t          # STEP A below, in this if statement
+             #                         + Qinv(e_t+1)  # STEP B below, in this if statement
+             #                         - lr * u_t     # this is steps 10-11
+             quant_err = st['quant_err']
+             quant_err.zero_()
+             quant_err.add_(p.grad)
+
+             ##### STEP A
+             p.add_(p.grad, alpha=-1)
+
+             ##### STEP B
+             p.grad.zero_()  # zerorize to prepare the accumulator for Qinv
+             ista_daslab_cuda_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1)
+             p.add_(p.grad)
+
+             quant_err.sub_(p.grad)
+
+             norm_qe = quant_err.norm(p=2) ** 2
+             sp_qe = (quant_err == 0).sum()
+
+         ##### STEPS 10-11
+         grad.zero_()
+         ista_daslab_cuda_micro_adam.compute_microadam_update(blocks,  # blocks
+                                                              self.threads,  # threads
+                                                              self.shared_memory_carveout,  # carveout
+                                                              self.steps,  # optimization step
+                                                              self.beta1,  # beta1
+                                                              self.beta2,  # beta2
+                                                              self.eps,  # eps
+                                                              d_block_size,  # d_block_size
+                                                              k_block_size_many,  # k_block_size
+                                                              d,  # d
+                                                              self.m,  # m
+                                                              k,  # k
+                                                              I,  # indices
+                                                              V,  # values
+                                                              grad)  # update will be stored here
+
+         ##### STEP 12: side idea: only decay the weights that are updated
+
+         ##### if PRETRAINING #1
+         if self.densify_update_using_ef:  # we add alpha * EF to the update that is stored in the grad buffer
+             # p.grad += alpha * Qinv(error), alpha=0.1
+             ista_daslab_cuda_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, self.alpha)
+         ##### END IF PRETRAINING #1
+
+         # if alpha > 0, then the update u=p.grad is dense now
+
+         # update model using MicroAdam update stored in p.grad
+         p.add_(p.grad, alpha=-lr)
+
+         if self.steps % self.log_interval == 0:
+             norm_u = grad.norm(p=2) ** 2
+             sp_u = (grad == 0).sum()  # check sparsity before zerorizing
+
+         ##### if PRETRAINING #2
+         if self.densify_update_using_ef:
+             grad.zero_()
+             ista_daslab_cuda_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1 - self.alpha)
+
+             _update_quantization_statistics()  # step 7 again
+             ista_daslab_cuda_micro_adam.asymm_block_quant(d, self.quant_block_size, error, min_vals, max_vals, grad)  # step 8 again
+         ##### END IF PRETRAINING #2
+
+         # compute error norm
+         if self.steps % self.log_interval == 0:
+             grad.zero_()
+             ista_daslab_cuda_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1.0)
+
+             norm_e = grad.norm(p=2) ** 2
+
+         # p.grad = p.grad.to(dtype=original_grad_type)
+
+         return norm_qe, norm_g, norm_u, norm_e, sp_u, sp_qe
+
+     def _log(self, norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe, elapsed_step):
+         if self.steps % self.log_interval == 0:
+             if is_initialized():
+                 sync_data = torch.tensor([norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe, elapsed_step],
+                                          dtype=torch.float, requires_grad=False).cuda()  # correct, loss, size
+                 all_reduce(sync_data, op=ReduceOp.SUM)
+                 norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe, elapsed_step = sync_data
+
+             if not is_initialized() or get_rank() == 0:
+                 wandb_data = {
+                     'step/optimizer_steps': self.steps,
+                     'step/gpu_mem_usage': get_gpu_mem_usage(),
+                     'step/norm_quant_err': math.sqrt(norm_qe),
+                     'step/sparsity_quant_err': sparsity_qe / self.model_size * 100.,
+                     'step/norm_g': math.sqrt(norm_g),
+                     'step/norm_u': math.sqrt(norm_u),
+                     'step/norm_error': math.sqrt(norm_e),
+                     'step/sparsity_u': sparsity_u / self.model_size * 100.,
+                     'step/elapsed_step': elapsed_step,
+                 }
+                 wandb.log(wandb_data, commit=False)
+
+     # def _update_lr_wd(self):
+     #     # copy the learning rate group to parameter state because the lr scheduler updates the one in the group
+     #     for group in self.param_groups:
+     #         lr = group['lr']
+     #         wd = group.get('weight_decay', self.weight_decay)  # if the param groups do not have weight decay, then use the external one
+     #         for p in group['params']:
+     #             self.state[p]['lr'] = lr
+     #             self.state[p]['wd'] = wd
+
+
+     # def _init_state(self):
+     #     count = 0
+     #     for group in self.param_groups:
+     #         lr = group['lr']
+     #         wd = group.get('weight_decay', self.weight_decay)  # if the param groups do not have weight decay, then use the external one
+     #         for p in group['params']:
+     #             if not p.requires_grad:
+     #                 continue
+
+     #             # print(f'[init_state] rank={torch.distributed.get_rank()}, p.shape={p.shape}')
+
+     #             count += 1
+     #             layer_size = p.numel()
+     #             st = self.state[p]
+
+     #             # B * t / d * nt
+     #             st['blocks'] = max(1, int(math.floor(self.blocks * layer_size * self.dict_size_count[layer_size] / self.model_size)))
+
+     #             st['lr'] = lr
+     #             st['weight_decay'] = wd
+     #             st['d'] = layer_size
+
+     #             ##### variables for Top-K: d_index_topk is the index where the last, smaller topk block starts
+     #             st['d_block_size'] = layer_size if layer_size < self.d_block_size else self.d_block_size
+     #             st['topk_full_blocks_count'], st['d_index_topk'] = block_split(st['d'], st['d_block_size'])
+     #             st['k_block_size_many'] = int(math.ceil(st['d_block_size'] * self.k_init))
+     #             st['k_block_size_few'] = int(math.ceil((st['d'] - st['d_index_topk']) * self.k_init))  # 0 for d % self.d_block_size = 0
+     #             st['k_index'] = st['topk_full_blocks_count'] * st['k_block_size_many']
+     #             st['k'] = st['k_block_size_many'] * st['topk_full_blocks_count'] + st['k_block_size_few']
+
+     #             ##### variables for the ring buffer
+     #             st['index'] = 0  # the position to place a new gradient at
+     #             st['I'] = torch.zeros(self.m, st['k'], dtype=torch.int16, device=self.device)  # 2mk bytes
+     #             st['V'] = torch.zeros(self.m, st['k'], dtype=torch.bfloat16, device=self.device)  # 2mk bytes
+
+     #             ### variables for error feedback: d_index_quant is the index where the last, smaller quantization block starts
+     #             # st['quant_block_size'] = layer_size if layer_size < self.quant_block_size else self.quant_block_size
+     #             st['quant_full_blocks_count'], st['d_index_quant'] = block_split(st['d'], self.quant_block_size)
+     #             st['error'] = torch.zeros(int(math.ceil(st['d'] / 2)), dtype=torch.uint8, device=self.device)  # ceil(d/2) bytes
+     #             st['min_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device)  # ceil(d/q_bsz)*2 bytes
+     #             st['max_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device)  # ceil(d/q_bsz)*2 bytes
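
The error-feedback state allocated in _initialize_parameter_state (a uint8 `error` buffer of ceil(d/2) bytes plus per-block bfloat16 `min_vals`/`max_vals`) points to a 4-bit asymmetric block quantization scheme, and update_step relies on the asymm_block_quant / asymm_block_quant_inv CUDA kernels, which are not part of this diff. The following pure-PyTorch sketch is an editorial illustration of the presumed semantics (two 4-bit codes packed per byte; dequantize and accumulate scaled by alpha); the real kernels are fused GPU implementations and may differ in details such as rounding.

# Editorial sketch (not from the package): presumed semantics of the 4-bit
# asymmetric block quantization behind st['error'], st['min_vals'], st['max_vals'].
import math
import torch

def asymm_block_quant_ref(d, q_block_size, error, min_vals, max_vals, x):
    """error = Q(x): per-block 4-bit codes in [0, 15], two codes packed per uint8."""
    codes = torch.empty(d, dtype=torch.uint8, device=x.device)
    for b in range(math.ceil(d / q_block_size)):
        lo, hi = b * q_block_size, min((b + 1) * q_block_size, d)
        mn, mx = min_vals[b].float(), max_vals[b].float()
        scale = (mx - mn) / 15.0
        scale = scale if scale > 0 else torch.tensor(1.0, device=x.device)
        codes[lo:hi] = ((x[lo:hi].float() - mn) / scale).round().clamp_(0, 15).to(torch.uint8)
    if d % 2 == 1:                                   # pad to an even length so codes can be paired
        codes = torch.cat([codes, codes.new_zeros(1)])
    error.copy_(codes[0::2] | (codes[1::2] << 4))    # low nibble = even index, high nibble = odd index

def asymm_block_quant_inv_ref(d, q_block_size, error, min_vals, max_vals, out, alpha):
    """out += alpha * Qinv(error): unpack the 4-bit codes and rescale per block."""
    codes = torch.stack([error & 0x0F, error >> 4], dim=1).view(-1)[:d].float()
    for b in range(math.ceil(d / q_block_size)):
        lo, hi = b * q_block_size, min((b + 1) * q_block_size, d)
        mn, mx = min_vals[b].float(), max_vals[b].float()
        scale = (mx - mn) / 15.0
        out[lo:hi] += alpha * (codes[lo:hi] * scale + mn).to(out.dtype)

In the optimizer above, STEP 8 quantizes the Top-K residual into `error`, while STEP 4 and the two PRETRAINING branches call the inverse with alpha equal to 1.0, self.alpha, or 1 - self.alpha respectively.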
ista_daslab_optimizers/sparse_mfac/__init__.py
@@ -0,0 +1,7 @@
+ from .sparse_mfac import SparseMFAC
+ from .sparse_core_mfac_w_ef import SparseCoreMFACwithEF
+
+ __all__ = [
+     'SparseMFAC',
+     'SparseCoreMFACwithEF'
+ ]
ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py
@@ -0,0 +1,226 @@
+ import math
+ import torch
+ from ..tools import block_split, KernelVersionsManager, CopyDirection
+
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.allow_tf32 = False
+
+ import ista_daslab_cuda_tools
+ import ista_daslab_cuda_dense_mfac
+ import ista_daslab_cuda_sparse_mfac
+
+ USE_CUDA = True
+
+ class SparseCoreMFACwithEF:
+     def __init__(self, m, d, k_init, dev, gpus, damp, use_bf16):
+         if USE_CUDA and (m % 32 != 0 or m > 1024):
+             raise ValueError('CUDA implementation currently only supports $m$ <= 1024, with $m$ divisible by 32.')
+         self.m = m
+         self.d = d
+         self.k_init = k_init
+         self.device = dev
+         self.gpus = gpus
+         self.dtype = torch.float
+         self.lamda = 1. / damp
+         self.damp = damp
+         self.use_bf16 = use_bf16
+
+         ##### Error Feedback & Top-K related members
+         self.error = torch.zeros(self.d, dtype=torch.bfloat16 if use_bf16 else torch.float32, device=self.device)
+
+         self.d_block_size = ista_daslab_cuda_tools.get_max_floats_for_shared_memory_per_thread_block()
+         self.k_block_size = math.ceil(self.d_block_size * self.k_init)
+         self.blocks_count, self.start_index_of_last_block = block_split(self.d, self.d_block_size)
+         self.k = self.k_block_size * self.blocks_count
+
+         self.last_k = 0
+         if self.start_index_of_last_block < self.d:
+             last_block_size = self.d - self.start_index_of_last_block
+             self.last_k = math.ceil(last_block_size * self.k_init)
+             self.k += self.last_k
+             print(f'Last block has size {last_block_size} and associated k for it is {self.last_k}')
+
+         print(f'{self.d=}, {self.k=}')
+         print(f'{self.d_block_size=}, {self.k_block_size=}')
+         print(f'{self.blocks_count=}, {self.start_index_of_last_block=}')
+         print(f'{self.last_k=}')
+
+         self.log_interval = 0
+         self.steps = 0
+         self.wandb_data = dict()
+
+         self.gpus_count = len(self.gpus)
+         self.dtype_indices = torch.int16
+         self.dtype_values = torch.bfloat16 if use_bf16 else torch.float
+
+         self.scalar_products = torch.zeros(self.m, dtype=torch.float, device=self.device)
+
+         self.I = torch.zeros(self.m, self.k, dtype=self.dtype_indices, device=self.device)
+         self.V = torch.zeros(self.m, self.k, dtype=self.dtype_values, device=self.device)
+
+         self.dots = torch.zeros((self.m, self.m), device=self.device, dtype=self.dtype)  # matrix $GG^T$
+         self.buffer_index = 0  # ring buffer index
+         self.giHig = None  # matrix $D$
+         self.denom = torch.zeros(self.m, device=self.device, dtype=self.dtype)  # $D_ii + m$
+         self.coef = self.lamda * torch.eye(self.m, device=self.device, dtype=self.dtype)  # matrix $B$
+         self.setup()
+
+         self.kvm = KernelVersionsManager(version_SP=23, version_LCG=51, m=self.m, d=self.d, d_block_size=self.d_block_size)
+
+     def setup(self):
+         self.giHig = self.lamda * self.dots
+         diag_m = torch.diag(torch.full([self.m], self.m, device=self.device, dtype=self.dtype))
+         self.giHig = torch.lu(self.giHig + diag_m, pivot=False)[0]
+         self.giHig = torch.triu(self.giHig - diag_m)
+         self.denom = self.m + torch.diagonal(self.giHig)
+         tmp = -self.giHig.t().contiguous() / self.denom.reshape((1, -1))
+
+         if USE_CUDA:
+             diag_lambd = torch.diag(torch.full([self.m], self.lamda, device=self.device, dtype=self.dtype))
+             self.coef = ista_daslab_cuda_dense_mfac.hinv_setup(tmp, diag_lambd)
+         else:
+             for i in range(max(self.buffer_index, 1), self.m):
+                 self.coef[i, :i] = tmp[i, :i].matmul(self.coef[:i, :i])
+
+     def _apply_ef_then_topk(self, g):
+         """
+         See PhD #9 page 70 for the pseudocode
+         """
+         self.error.add_(g)  # the error feedback is the accumulator here
+
+         self.I[self.buffer_index, :self.k - self.last_k] = torch.topk(
+             input=self.error[0:self.start_index_of_last_block].abs().view(self.blocks_count, self.d_block_size),
+             k=self.k_block_size,  # k is the same for all first n-1 blocks
+             sorted=False).indices.to(torch.int16).view(-1)  # will have 2D shape: (blocks_count, self.block_size)
+
+         if self.start_index_of_last_block < self.d:
+             self.I[self.buffer_index, self.k - self.last_k:] = torch.topk(
+                 input=self.error[self.start_index_of_last_block:].abs(),
+                 k=self.last_k,
+                 sorted=False).indices.to(torch.int16)
+
+         ### copy the values from the error feedback accumulator to values V (this is the G update),
+         ### the large tensor (error, size d) is copied to the small tensor (V, size k)
+         # norm_last_v_1 = self.V[self.buffer_index, :].norm(p=2).item()
+         ista_daslab_cuda_tools.copy_values(self.d,  # V = error[I[buffer_index, :]]
+                                            self.k,
+                                            self.d_block_size,
+                                            self.k_block_size,
+                                            self.I[self.buffer_index, :],  # indices
+                                            self.error,  # inp
+                                            self.V[self.buffer_index, :],  # output
+                                            CopyDirection.d2k.value)
+         # norm_last_v_2 = self.V[self.buffer_index, :].norm(p=2).item()
+         ### the small tensor (V, size k) is copied to the large tensor (g, size d)
+         g.zero_()  # this will contain the values in V, at the right indices, but will also contain zeros
+         # norm_g_before = g.norm(p=2).item()
+         ista_daslab_cuda_tools.copy_values(self.d,  # this does g[I[buffer_index]] = V
+                                            self.k,
+                                            self.d_block_size,
+                                            self.k_block_size,
+                                            self.I[self.buffer_index, :],  # indices
+                                            self.V[self.buffer_index, :],  # inp
+                                            g,  # out
+                                            CopyDirection.k2d.value)
+         # norm_g_after = g.norm(p=2).item()
+         # norm_last_v_3 = self.V[self.buffer_index, :].norm(p=2).item()
+         # print(f'[_apply_ef_then_topk]{self.steps=}\n\t{norm_g_before=}, {norm_g_after=}\n\t{norm_last_v_1=}, {norm_last_v_2=}, {norm_last_v_3=}')
+         # zerorize error: subtract the top-k values (saved in V[index, :]), which are also present in g
+         self.error.sub_(g)
+
+     def apply_ef_then_update_buffer_then_precondition(self, g):
+         """
+         The function name says it all.
+         Returns the update inv(F) * g = 1/lambda * g - linear_combination_of_gradients (tmp contains the linear combination parameters)
+         :param g: the dense gradient
+         :return: the preconditioned sparse gradient
+         """
+         self.steps += 1
+         self._apply_ef_then_topk(g)  # after this call, g will contain the top-k values and zeros
+
+         # norm_g = g.norm(p=2).item()
+         # norm_last_v = self.V[self.buffer_index, :].norm(p=2).item()
+         # print(f'{self.steps=}, {norm_g=}, {norm_last_v=}')
+
+         dots = self._integrate_gradient(topk_values_w_zeros=g)
+         p = self._precondition(g, dots)  # here we precondition the sparse gradient, i.e. only the top-k values, stored in the d-dim tensor g
+         return p
+
+     def _integrate_gradient(self, topk_values_w_zeros):
+         tmp = self.compute_scalar_products(topk_values_w_zeros)
+         tmp = tmp.squeeze()  # (m, 1) becomes (m,)
+
+         self.dots[self.buffer_index, :] = tmp
+         self.dots[:, self.buffer_index] = tmp
+
+         self.setup()
+
+         self.buffer_index = (self.buffer_index + 1) % self.m
+         return tmp
+
+     def _precondition(self, g, dots=None):
+         """
+         Returns the update inv(F) * g
+         The matrix M stores the coefficients of the linear combination
+         g: usually the sparse gradient
+         """
+         # print(f'[precondition]')
+         if dots is None:
+             dots = self.compute_scalar_products(g)
+         giHix = self.lamda * dots
+         if USE_CUDA:
+             giHix = ista_daslab_cuda_dense_mfac.hinv_mul(self.m, self.giHig, giHix)
+             torch.cuda.synchronize()
+         else:
+             for i in range(1, self.m):
+                 giHix[i:].sub_(self.giHig[i - 1, i:], alpha=giHix[i - 1] / self.denom[i - 1])
+         M = (giHix / self.denom).matmul(self.coef)  # .view(-1, 1) # view is linked to matmul_grads_sequential_batch
+
+         partA = self.lamda * g
+         partB = self.compute_linear_combination(M, out=g)  # out will be returned such that partB = out
+         if self.steps > 0 and self.log_interval > 0 and self.steps % self.log_interval == 0:
+             self.wandb_data.update(dict(norm_partA=partA.norm(p=2), norm_partB=partB.norm(p=2)))
+         return partA - partB
+
+     def compute_scalar_products(self, g):
+         self.scalar_products.zero_()
+
+         ista_daslab_cuda_sparse_mfac.SP(
+             self.kvm.get_SP_blocks(),
+             self.kvm.get_SP_threads(),
+             self.kvm.version_SP,
+             self.d,
+             min(self.m, self.steps),
+             self.k,
+             self.d_block_size,
+             self.k_block_size,
+             g,
+             self.I,
+             self.V,
+             self.scalar_products,
+             int(self.use_bf16))
+
+         return self.scalar_products
+
+     def compute_linear_combination(self, M, out):
+         out.zero_()
+
+         ista_daslab_cuda_sparse_mfac.LCG(
+             self.kvm.get_LCG_blocks(),
+             self.kvm.get_LCG_threads(),
+             self.kvm.version_LCG,
+             self.d,
+             min(self.m, self.steps),
+             self.k,
+             self.d_block_size,
+             self.k_block_size,
+             M,
+             self.I,
+             self.V,
+             out,
+             int(self.use_bf16))
+
+         if self.steps > 0 and self.log_interval > 0 and self.steps % self.log_interval == 0:
+             self.wandb_data.update(dict(lin_comb_coef_norm=M.norm(p=2)))
+
+         return out
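
To place this core in context: sparse_mfac.py (not shown in this hunk) presumably wraps SparseCoreMFACwithEF in a torch optimizer. The sketch below is editorial and drives the core directly under that assumption: the gradients of all parameters are flattened into one d-dimensional vector, the core applies error feedback, blockwise Top-K and the M-FAC preconditioning, and the returned update is scattered back onto the parameters. Hyperparameter values and the flatten/scatter loop are illustrative placeholders, not taken from the package, and the compiled ista_daslab_cuda_* extensions plus a CUDA device are assumed to be available.

# Editorial sketch (not from the package): driving SparseCoreMFACwithEF directly.
import torch
from ista_daslab_optimizers.sparse_mfac import SparseCoreMFACwithEF

model = torch.nn.Linear(512, 512).cuda()                 # placeholder model
params = [p for p in model.parameters() if p.requires_grad]
d = sum(p.numel() for p in params)

core = SparseCoreMFACwithEF(
    m=32,          # gradient history size: divisible by 32 and at most 1024
    d=d,
    k_init=0.01,   # keep ~1% of each block after Top-K
    dev='cuda',
    gpus=['cuda'],
    damp=1e-6,     # dampening; the preconditioner scales by lamda = 1 / damp
    use_bf16=False,
)

lr = 1e-3
loss = model(torch.randn(8, 512, device='cuda')).square().mean()   # placeholder loss
loss.backward()

g = torch.cat([p.grad.reshape(-1) for p in params])     # flatten all gradients into one vector
u = core.apply_ef_then_update_buffer_then_precondition(g)

with torch.no_grad():
    offset = 0
    for p in params:                                     # scatter the preconditioned update back
        p.add_(u[offset:offset + p.numel()].view_as(p), alpha=-lr)
        offset += p.numel()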