ista_daslab_optimizers-1.1.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ista_daslab_optimizers/__init__.py +6 -0
- ista_daslab_optimizers/acdc/__init__.py +5 -0
- ista_daslab_optimizers/acdc/acdc.py +387 -0
- ista_daslab_optimizers/acdc/wd_scheduler.py +31 -0
- ista_daslab_optimizers/dense_mfac/__init__.py +5 -0
- ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +164 -0
- ista_daslab_optimizers/dense_mfac/dense_mfac.py +93 -0
- ista_daslab_optimizers/fft_low_rank/dct_adamw.py +351 -0
- ista_daslab_optimizers/fft_low_rank/fft_projector.py +192 -0
- ista_daslab_optimizers/fft_low_rank/trion.py +242 -0
- ista_daslab_optimizers/ista_optimizer/__init__.py +5 -0
- ista_daslab_optimizers/ista_optimizer/ista_optimizer.py +36 -0
- ista_daslab_optimizers/micro_adam/__init__.py +5 -0
- ista_daslab_optimizers/micro_adam/micro_adam.py +402 -0
- ista_daslab_optimizers/sparse_mfac/__init__.py +7 -0
- ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +226 -0
- ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +87 -0
- ista_daslab_optimizers/tools.py +218 -0
- ista_daslab_optimizers/utils/dct.py +45 -0
- ista_daslab_optimizers/utils/global_cache.py +45 -0
- ista_daslab_optimizers/utils/matrix_storage.py +58 -0
- ista_daslab_optimizers/utils/newton_schulz_triton.py +374 -0
- ista_daslab_optimizers/utils/quantizers.py +71 -0
- ista_daslab_optimizers/utils/schedulers.py +41 -0
- ista_daslab_optimizers-1.1.8.dist-info/METADATA +333 -0
- ista_daslab_optimizers-1.1.8.dist-info/RECORD +29 -0
- ista_daslab_optimizers-1.1.8.dist-info/WHEEL +5 -0
- ista_daslab_optimizers-1.1.8.dist-info/licenses/LICENSE +201 -0
- ista_daslab_optimizers-1.1.8.dist-info/top_level.txt +1 -0
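Based on the constructor signature in micro_adam.py below (MicroAdam(params, m, lr, quant_block_size, k_init=0.01, alpha=0, betas=(0.9, 0.999), weight_decay=0, eps=1e-8)), a minimal usage sketch might look as follows. It assumes MicroAdam is re-exported at the package root (otherwise import it from ista_daslab_optimizers.micro_adam.micro_adam), a CUDA device, and the compiled ista_daslab_cuda_tools / ista_daslab_cuda_micro_adam extensions; the hyperparameter values are illustrative, not package defaults.

import torch
from ista_daslab_optimizers import MicroAdam  # assumed re-export from the package root

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = MicroAdam(
    model.parameters(),
    m=10,                      # depth of the ring buffer of compressed gradients (illustrative)
    lr=1e-5,
    quant_block_size=100_000,  # block size for the 4-bit error-feedback quantization (illustrative)
    k_init=0.01,               # fraction of entries kept by the block-wise Top-K
)

for _ in range(10):
    x = torch.randn(8, 1024, device='cuda')
    loss = model(x).square().mean()
    loss.backward()
    optimizer.step()           # note: step() reuses p.grad as a workspace for the update
    optimizer.zero_grad()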
ista_daslab_optimizers/micro_adam/micro_adam.py
@@ -0,0 +1,402 @@
import os
import torch
import math
import time
import wandb
from torch.distributed import is_initialized, get_rank, all_reduce, ReduceOp
from ..tools import get_first_device, get_gpu_mem_usage, block_split, CopyDirection

import ista_daslab_cuda_tools
import ista_daslab_cuda_micro_adam


class MicroAdam(torch.optim.Optimizer):
    def __init__(self, params, m, lr, quant_block_size, k_init=0.01, alpha=0, betas=(0.9, 0.999), weight_decay=0, eps=1e-8):
        defaults = dict(lr=lr, weight_decay=weight_decay, eps=eps, alpha=alpha)
        super(MicroAdam, self).__init__(params, defaults)

        assert (0 <= alpha < 1) or alpha == -2, 'Alpha must be in the [0, 1) interval or -2'

        self.m = m
        self.lr = lr
        self.quant_block_size = int(quant_block_size)
        self.k_init = k_init
        self.alpha = alpha
        self.weight_decay = weight_decay
        self.beta1 = betas[0]
        self.beta2 = betas[1]
        self.eps = eps

        self.densify_update_using_ef = (self.alpha > 0)
        self.densify_update_using_quant_error = (self.alpha == -2)

        self.model_size = sum([p.numel() for group in self.param_groups for p in group['params']])

        self.steps = 0  # how many optimization steps were performed so far
        self.log_interval = 100
        self.device = get_first_device()
        self._is_state_initialized = False
        self.shared_memory_carveout = 100
        self.blocks = ista_daslab_cuda_tools.get_sm_count() * int(100 / self.shared_memory_carveout)
        self.threads = 512

        self.max_floats = ista_daslab_cuda_tools.get_max_floats_for_shared_memory_per_thread_block()
        self.d_block_size = self.max_floats // 2 // int(100 / self.shared_memory_carveout)

        if torch.distributed.is_initialized():
            # key = layer size, value = how many layers of that size the model has (per worker)
            self.fsdp_dict_size_count = [{} for _ in range(torch.distributed.get_world_size())]
        else:
            self.fsdp_dict_size_count = [{}]

        self.dict_size_count = {}  # key = layer size, value = how many layers of that size the model has
        for param in self.param_groups:
            for p in param['params']:
                size = p.numel()
                # print(p.shape, p.numel())
                self.dict_size_count[size] = 1 + self.dict_size_count.get(size, 0)

        # self._init_state()

    def _initialize_parameter_state(self, p, lr, wd):
        layer_size = p.numel()
        st = self.state[p]

        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0

        if self.densify_update_using_quant_error:
            st['quant_err'] = torch.zeros_like(p)

        st['blocks'] = max(1, int(math.floor(self.blocks * layer_size * self.fsdp_dict_size_count[rank][layer_size] / self.model_size)))

        st['lr'] = lr
        st['weight_decay'] = wd
        st['d'] = layer_size

        ##### variables for Top-K: d_index_topk is the index where the last, smaller topk block starts
        st['d_block_size'] = layer_size if layer_size < self.d_block_size else self.d_block_size
        st['topk_full_blocks_count'], st['d_index_topk'] = block_split(st['d'], st['d_block_size'])
        st['k_block_size_many'] = int(math.ceil(st['d_block_size'] * self.k_init))
        st['k_block_size_few'] = int(math.ceil((st['d'] - st['d_index_topk']) * self.k_init))  # 0 for d % self.d_block_size = 0
        st['k_index'] = st['topk_full_blocks_count'] * st['k_block_size_many']
        st['k'] = st['k_block_size_many'] * st['topk_full_blocks_count'] + st['k_block_size_few']

        ##### variables for the ring buffer
        st['index'] = 0  # the position to place a new gradient at
        st['I'] = torch.zeros(self.m, st['k'], dtype=torch.int16, device=self.device)  # 2mk bytes
        st['V'] = torch.zeros(self.m, st['k'], dtype=torch.bfloat16, device=self.device)  # 2mk bytes

        ### variables for error feedback: d_index_quant is the index where the last, smaller quantization block starts
        # st['quant_block_size'] = layer_size if layer_size < self.quant_block_size else self.quant_block_size
        st['quant_full_blocks_count'], st['d_index_quant'] = block_split(st['d'], self.quant_block_size)
        st['error'] = torch.zeros(int(math.ceil(st['d'] / 2)), dtype=torch.uint8, device=self.device)  # ceil(d/2) bytes
        st['min_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device)  # ceil(d/q_bsz)*2 bytes
        st['max_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device)  # ceil(d/q_bsz)*2 bytes

    @torch.no_grad()
    def step(self, closure=None):
        self.steps += 1

        # self._update_lr_wd()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if self.steps == 1:
            rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
            for param in self.param_groups:
                for p in param['params']:
                    if p is not None:
                        size = p.numel()
                        if size > 0:
                            self.fsdp_dict_size_count[rank][size] = 1 + self.fsdp_dict_size_count[rank].get(size, 0)

        time_start = time.time()

        norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe = 0, 0, 0, 0, 0, 0

        for group in self.param_groups:
            lr = group['lr']
            wd = group.get('weight_decay', self.weight_decay)

            for p in group['params']:
                if p.grad is None:
                    continue

                if p is None:
                    continue

                nqe, ng, nu, ne, sp_u, sp_qe = self.update_step(p, lr, wd)
                norm_qe += nqe
                norm_g += ng
                norm_u += nu
                norm_e += ne
                sparsity_u += sp_u
                sparsity_qe += sp_qe

        # torch.cuda.synchronize()
        time_end = time.time()
        elapsed_step = time_end - time_start
        self._log(norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe, elapsed_step)

        return loss

    @torch.no_grad()
    def update_step(self, p, lr, wd):
        norm_qe, norm_g, norm_u, norm_e, sp_u, sp_qe = 0, 0, 0, 0, 0, 0

        # if p.grad.dtype != torch.bfloat16:
        #     grad = p.grad.to(dtype=torch.bfloat16).reshape(-1)
        # else:
        grad = p.grad.view(-1)

        if self.steps % self.log_interval == 0:
            norm_g = grad.norm(p=2) ** 2

        st = self.state[p]
        if len(st) == 0:
            self._initialize_parameter_state(p, lr, wd)

        # print('rank=', torch.distributed.get_rank(), 'keys=', st.keys())

        blocks = st['blocks']
        # lr = st['lr']
        # wd = st['weight_decay']
        d = st['d']
        d_block_size = st['d_block_size']
        topk_full_blocks_count, d_index_topk = st['topk_full_blocks_count'], st['d_index_topk']
        k_block_size_many = st['k_block_size_many']
        k_block_size_few = st['k_block_size_few']
        k_index = st['k_index']
        k = st['k']

        # HuggingFace has a setting that converts st['I'] to bfloat16, even though it is declared as int16.
        # This happens somewhere between the constructor call and the step call. Converting it back to int16 was the simplest solution.
        if st['I'].dtype != torch.int16:
            st['I'] = st['I'].to(torch.int16)

        index = st['index']
        I = st['I']
        V = st['V']

        quant_full_blocks_count, d_index_quant = st['quant_full_blocks_count'], st['d_index_quant']
        error = st['error']
        min_vals = st['min_vals']
        max_vals = st['max_vals']

        ##### STEP 4
        ista_daslab_cuda_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1.0)  # alpha=1 here

        ##### STEP 5 + 9 (only for I)
        I[index, :k_index] = torch.topk(input=grad[0:d_index_topk].abs().view(topk_full_blocks_count, d_block_size),
                                        k=k_block_size_many,
                                        sorted=False).indices.to(dtype=torch.int16).view(-1)

        if k_block_size_few > 0:  # there is a small block left
            I[index, k_index:] = torch.topk(input=grad[d_index_topk:].abs(),
                                            k=k_block_size_few,  # example: slice has size 1, but ks[-1] is 4
                                            sorted=False).indices.to(dtype=torch.int16).view(-1)

        ista_daslab_cuda_tools.copy_values(d,  # V = error[I[buffer_index, :]]
                                           k,
                                           d_block_size,
                                           k_block_size_many,
                                           I[index, :],  # indices
                                           grad,  # inp
                                           V[index, :],  # output
                                           CopyDirection.d2k.value)

        st['index'] = (index + 1) % self.m

        ##### STEP 6
        ista_daslab_cuda_tools.zerorize_block_components(grad, I[index, :], d, k, d_block_size, k_block_size_many)  # this does a[I[index]] = 0

        ##### STEP 7
        def _update_quantization_statistics():
            if quant_full_blocks_count == 1:
                min_vals[:quant_full_blocks_count] = grad[:d_index_quant].min()
                max_vals[:quant_full_blocks_count] = grad[:d_index_quant].max()
            else:
                min_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).min(dim=1).values
                max_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).max(dim=1).values
            if d_index_quant < d:
                min_vals[quant_full_blocks_count] = grad[d_index_quant:].min()
                max_vals[quant_full_blocks_count] = grad[d_index_quant:].max()

        _update_quantization_statistics()

        ##### STEP 8
        ista_daslab_cuda_micro_adam.asymm_block_quant(d, self.quant_block_size, error, min_vals, max_vals, grad)  # error = Q(a, min, max)

        # weight decay step
        if wd > 0:
            p.mul_(1 - lr * wd)

        ##### NEW: densify using quant error
        if self.densify_update_using_quant_error:
            # When entering this if-statement, we have:
            # - p is theta_t
            # - p.grad is a_t (from step 6 in algorithm 1)
            # - error is e_{t+1} (from step 8 in algorithm 1)
            #
            # Below is the formula to update the model parameters:
            # [a = -1] with lr
            # theta_{t+1} = theta_t - lr * (a_t - Qinv(e_{t+1})) - lr * u_t
            #             = theta_t - lr * a_t + lr * Qinv(e_{t+1}) - lr * u_t
            #             = theta_t - lr * a_t        # STEP A below, in this if-statement
            #               + lr * Qinv(e_{t+1})      # STEP B below, in this if-statement
            #               - lr * u_t                # this is steps 10-11
            #
            # [a = -2] without lr
            # theta_{t+1} = theta_t - (a_t - Qinv(e_{t+1})) - lr * u_t
            #             = theta_t - a_t + Qinv(e_{t+1}) - lr * u_t
            #             = theta_t - a_t             # STEP A below, in this if-statement
            #               + Qinv(e_{t+1})           # STEP B below, in this if-statement
            #               - lr * u_t                # this is steps 10-11
            quant_err = st['quant_err']
            quant_err.zero_()
            quant_err.add_(p.grad)

            ##### STEP A
            p.add_(p.grad, alpha=-1)

            ##### STEP B
            p.grad.zero_()  # zerorize to prepare the accumulator for Qinv
            ista_daslab_cuda_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1)
            p.add_(p.grad)

            quant_err.sub_(p.grad)

            norm_qe = quant_err.norm(p=2) ** 2
            sp_qe = (quant_err == 0).sum()

        ##### STEPS 10-11
        grad.zero_()
        ista_daslab_cuda_micro_adam.compute_microadam_update(blocks,  # blocks
                                                             self.threads,  # threads
                                                             self.shared_memory_carveout,  # carveout
                                                             self.steps,  # optimization step
                                                             self.beta1,  # beta1
                                                             self.beta2,  # beta2
                                                             self.eps,  # eps
                                                             d_block_size,  # d_block_size
                                                             k_block_size_many,  # k_block_size
                                                             d,  # d
                                                             self.m,  # m
                                                             k,  # k
                                                             I,  # indices
                                                             V,  # values
                                                             grad)  # update will be stored here

        ##### STEP 12: side idea: only decay the weights that are updated

        ##### if PRETRAINING #1
        if self.densify_update_using_ef:  # we add alpha * EF to the update that is stored in the grad buffer
            # p.grad += alpha * Qinv(error), alpha=0.1
            ista_daslab_cuda_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, self.alpha)
        ##### END IF PRETRAINING #1

        # if alpha > 0, then the update u = p.grad is dense now

        # update model using the MicroAdam update stored in p.grad
        p.add_(p.grad, alpha=-lr)

        if self.steps % self.log_interval == 0:
            norm_u = grad.norm(p=2) ** 2
            sp_u = (grad == 0).sum()  # check sparsity before zerorizing

        ##### if PRETRAINING #2
        if self.densify_update_using_ef:
            grad.zero_()
            ista_daslab_cuda_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1 - self.alpha)

            _update_quantization_statistics()  # step 7 again
            ista_daslab_cuda_micro_adam.asymm_block_quant(d, self.quant_block_size, error, min_vals, max_vals, grad)  # step 8 again
        ##### END IF PRETRAINING #2

        # compute error norm
        if self.steps % self.log_interval == 0:
            grad.zero_()
            ista_daslab_cuda_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1.0)

            norm_e = grad.norm(p=2) ** 2

        # p.grad = p.grad.to(dtype=original_grad_type)

        return norm_qe, norm_g, norm_u, norm_e, sp_u, sp_qe

    def _log(self, norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe, elapsed_step):
        if self.steps % self.log_interval == 0:
            if is_initialized():
                sync_data = torch.tensor([norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe, elapsed_step],
                                         dtype=torch.float, requires_grad=False).cuda()
                all_reduce(sync_data, op=ReduceOp.SUM)
                norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe, elapsed_step = sync_data

            if not is_initialized() or get_rank() == 0:
                wandb_data = {
                    'step/optimizer_steps': self.steps,
                    'step/gpu_mem_usage': get_gpu_mem_usage(),
                    'step/norm_quant_err': math.sqrt(norm_qe),
                    'step/sparsity_quant_err': sparsity_qe / self.model_size * 100.,
                    'step/norm_g': math.sqrt(norm_g),
                    'step/norm_u': math.sqrt(norm_u),
                    'step/norm_error': math.sqrt(norm_e),
                    'step/sparsity_u': sparsity_u / self.model_size * 100.,
                    'step/elapsed_step': elapsed_step,
                }
                wandb.log(wandb_data, commit=False)

    # def _update_lr_wd(self):
    #     # copy the learning rate group to parameter state because the lr scheduler updates the one in the group
    #     for group in self.param_groups:
    #         lr = group['lr']
    #         wd = group.get('weight_decay', self.weight_decay)  # if the param groups do not have weight decay, then use the external one
    #         for p in group['params']:
    #             self.state[p]['lr'] = lr
    #             self.state[p]['wd'] = wd


    # def _init_state(self):
    #     count = 0
    #     for group in self.param_groups:
    #         lr = group['lr']
    #         wd = group.get('weight_decay', self.weight_decay)  # if the param groups do not have weight decay, then use the external one
    #         for p in group['params']:
    #             if not p.requires_grad:
    #                 continue

    #             print(f'[init_state] rank={torch.distributed.get_rank()}, p.shape={p.shape}')

    #             count += 1
    #             layer_size = p.numel()
    #             st = self.state[p]

    #             # B * t / d * nt
    #             st['blocks'] = max(1, int(math.floor(self.blocks * layer_size * self.dict_size_count[layer_size] / self.model_size)))

    #             st['lr'] = lr
    #             st['weight_decay'] = wd
    #             st['d'] = layer_size

    #             ##### variables for Top-K: d_index_topk is the index where the last, smaller topk block starts
    #             st['d_block_size'] = layer_size if layer_size < self.d_block_size else self.d_block_size
    #             st['topk_full_blocks_count'], st['d_index_topk'] = block_split(st['d'], st['d_block_size'])
    #             st['k_block_size_many'] = int(math.ceil(st['d_block_size'] * self.k_init))
    #             st['k_block_size_few'] = int(math.ceil((st['d'] - st['d_index_topk']) * self.k_init))  # 0 for d % self.d_block_size = 0
    #             st['k_index'] = st['topk_full_blocks_count'] * st['k_block_size_many']
    #             st['k'] = st['k_block_size_many'] * st['topk_full_blocks_count'] + st['k_block_size_few']

    #             ##### variables for the ring buffer
    #             st['index'] = 0  # the position to place a new gradient at
    #             st['I'] = torch.zeros(self.m, st['k'], dtype=torch.int16, device=self.device)  # 2mk bytes
    #             st['V'] = torch.zeros(self.m, st['k'], dtype=torch.bfloat16, device=self.device)  # 2mk bytes

    #             ### variables for error feedback: d_index_quant is the index where the last, smaller quantization block starts
    #             # st['quant_block_size'] = layer_size if layer_size < self.quant_block_size else self.quant_block_size
    #             st['quant_full_blocks_count'], st['d_index_quant'] = block_split(st['d'], self.quant_block_size)
    #             st['error'] = torch.zeros(int(math.ceil(st['d'] / 2)), dtype=torch.uint8, device=self.device)  # ceil(d/2) bytes
    #             st['min_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device)  # ceil(d/q_bsz)*2 bytes
    #             st['max_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device)  # ceil(d/q_bsz)*2 bytes
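For intuition, here is a minimal, assumed pure-PyTorch sketch of what steps 5-8 of update_step above do for a single flattened tensor: keep the block-wise Top-K entries of the accumulator (these feed the (I, V) ring buffer), zero them out, and compress the remainder into 4-bit asymmetric codes using per-block min/max statistics. The fused asymm_block_quant / asymm_block_quant_inv / zerorize_block_components kernels additionally pack two 4-bit codes per byte, run the Top-K per d_block_size block, and handle the last, smaller block; none of that is reproduced here.

import torch

def topk_then_quantize(acc: torch.Tensor, k_frac: float, q_block: int):
    """Split a dense accumulator into a sparse Top-K part and a 4-bit quantized remainder (illustrative only)."""
    flat = acc.view(-1)
    d = flat.numel()

    # Top-K on absolute values (one global selection here; per-block in the real kernel)
    k = max(1, int(d * k_frac))
    idx = torch.topk(flat.abs(), k=k, sorted=False).indices
    vals = flat[idx].clone()   # sparse part: what would be written into (I, V)
    flat[idx] = 0              # step 6: remove what the sparse part already carries

    # steps 7-8: per-block min/max statistics, then asymmetric 4-bit codes (16 levels)
    full = (d // q_block) * q_block
    blocks = flat[:full].view(-1, q_block)
    mins = blocks.min(dim=1).values
    maxs = blocks.max(dim=1).values
    scale = (maxs - mins).clamp_min(1e-12) / 15.0
    codes = ((blocks - mins[:, None]) / scale[:, None]).round().clamp_(0, 15).to(torch.uint8)
    return idx, vals, codes, mins, maxs

g = torch.randn(10_000)
idx, vals, codes, mins, maxs = topk_then_quantize(g.clone(), k_frac=0.01, q_block=2_000)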
ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py
@@ -0,0 +1,226 @@
import math
import torch
from ..tools import block_split, KernelVersionsManager, CopyDirection

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

import ista_daslab_cuda_tools
import ista_daslab_cuda_dense_mfac
import ista_daslab_cuda_sparse_mfac

USE_CUDA = True

class SparseCoreMFACwithEF:
    def __init__(self, m, d, k_init, dev, gpus, damp, use_bf16):
        if USE_CUDA and (m % 32 != 0 or m > 1024):
            raise ValueError('The CUDA implementation currently only supports m <= 1024 with m divisible by 32.')
        self.m = m
        self.d = d
        self.k_init = k_init
        self.device = dev
        self.gpus = gpus
        self.dtype = torch.float
        self.lamda = 1. / damp
        self.damp = damp
        self.use_bf16 = use_bf16

        ##### Error Feedback & Top-K related buffers
        self.error = torch.zeros(self.d, dtype=torch.bfloat16 if use_bf16 else torch.float32, device=self.device)

        self.d_block_size = ista_daslab_cuda_tools.get_max_floats_for_shared_memory_per_thread_block()
        self.k_block_size = math.ceil(self.d_block_size * self.k_init)
        self.blocks_count, self.start_index_of_last_block = block_split(self.d, self.d_block_size)
        self.k = self.k_block_size * self.blocks_count

        self.last_k = 0
        if self.start_index_of_last_block < self.d:
            last_block_size = self.d - self.start_index_of_last_block
            self.last_k = math.ceil(last_block_size * self.k_init)
            self.k += self.last_k
            print(f'Last block has size {last_block_size} and the associated k for it is {self.last_k}')

        print(f'{self.d=}, {self.k=}')
        print(f'{self.d_block_size=}, {self.k_block_size=}')
        print(f'{self.blocks_count=}, {self.start_index_of_last_block=}')
        print(f'{self.last_k=}')

        self.log_interval = 0
        self.steps = 0
        self.wandb_data = dict()

        self.gpus_count = len(self.gpus)
        self.dtype_indices = torch.int16
        self.dtype_values = torch.bfloat16 if use_bf16 else torch.float

        self.scalar_products = torch.zeros(self.m, dtype=torch.float, device=self.device)

        self.I = torch.zeros(self.m, self.k, dtype=self.dtype_indices, device=self.device)
        self.V = torch.zeros(self.m, self.k, dtype=self.dtype_values, device=self.device)

        self.dots = torch.zeros((self.m, self.m), device=self.device, dtype=self.dtype)  # matrix $GG^T$
        self.buffer_index = 0  # ring buffer index
        self.giHig = None  # matrix $D$
        self.denom = torch.zeros(self.m, device=self.device, dtype=self.dtype)  # $D_ii + m$
        self.coef = self.lamda * torch.eye(self.m, device=self.device, dtype=self.dtype)  # matrix $B$
        self.setup()

        self.kvm = KernelVersionsManager(version_SP=23, version_LCG=51, m=self.m, d=self.d, d_block_size=self.d_block_size)

    def setup(self):
        self.giHig = self.lamda * self.dots
        diag_m = torch.diag(torch.full([self.m], self.m, device=self.device, dtype=self.dtype))
        self.giHig = torch.lu(self.giHig + diag_m, pivot=False)[0]
        self.giHig = torch.triu(self.giHig - diag_m)
        self.denom = self.m + torch.diagonal(self.giHig)
        tmp = -self.giHig.t().contiguous() / self.denom.reshape((1, -1))

        if USE_CUDA:
            diag_lambd = torch.diag(torch.full([self.m], self.lamda, device=self.device, dtype=self.dtype))
            self.coef = ista_daslab_cuda_dense_mfac.hinv_setup(tmp, diag_lambd)
        else:
            for i in range(max(self.buffer_index, 1), self.m):
                self.coef[i, :i] = tmp[i, :i].matmul(self.coef[:i, :i])

    def _apply_ef_then_topk(self, g):
        """
        See PhD #9 page 70 for the pseudocode
        """
        self.error.add_(g)  # the error feedback is the accumulator here

        self.I[self.buffer_index, :self.k - self.last_k] = torch.topk(
            input=self.error[0:self.start_index_of_last_block].abs().view(self.blocks_count, self.d_block_size),
            k=self.k_block_size,  # k is the same for all of the first n-1 blocks
            sorted=False).indices.to(torch.int16).view(-1)  # will have 2D shape: (blocks_count, self.block_size)

        if self.start_index_of_last_block < self.d:
            self.I[self.buffer_index, self.k - self.last_k:] = torch.topk(
                input=self.error[self.start_index_of_last_block:].abs(),
                k=self.last_k,
                sorted=False).indices.to(torch.int16)

        ### copy the values from the error feedback accumulator to the values V (this is the G update),
        ### the large tensor (error, size d) is copied to the small tensor (V, size k)
        # norm_last_v_1 = self.V[self.buffer_index, :].norm(p=2).item()
        ista_daslab_cuda_tools.copy_values(self.d,  # V = error[I[buffer_index, :]]
                                           self.k,
                                           self.d_block_size,
                                           self.k_block_size,
                                           self.I[self.buffer_index, :],  # indices
                                           self.error,  # inp
                                           self.V[self.buffer_index, :],  # output
                                           CopyDirection.d2k.value)
        # norm_last_v_2 = self.V[self.buffer_index, :].norm(p=2).item()

        ### the small tensor (V, size k) is copied to the large tensor (g, size d)
        g.zero_()  # this will contain the values in V, at the right indices, but will also contain zeros
        # norm_g_before = g.norm(p=2).item()
        ista_daslab_cuda_tools.copy_values(self.d,  # this does g[I[buffer_index]] = V
                                           self.k,
                                           self.d_block_size,
                                           self.k_block_size,
                                           self.I[self.buffer_index, :],  # indices
                                           self.V[self.buffer_index, :],  # inp
                                           g,  # out
                                           CopyDirection.k2d.value)
        # norm_g_after = g.norm(p=2).item()
        # norm_last_v_3 = self.V[self.buffer_index, :].norm(p=2).item()
        # print(f'[_apply_ef_then_topk]{self.steps=}\n\t{norm_g_before=}, {norm_g_after=}\n\t{norm_last_v_1=}, {norm_last_v_2=}, {norm_last_v_3=}')

        # zerorize error: subtract the top-k values (saved in V[index, :]), which are also present in g
        self.error.sub_(g)

    def apply_ef_then_update_buffer_then_precondition(self, g):
        """
        The function name says it all.
        Returns the update inv(F) * g = 1/lambda * g - linear_combination_of_gradients (tmp contains the linear combination coefficients)
        :param g: the dense gradient
        :return: the preconditioned sparse gradient
        """
        self.steps += 1
        self._apply_ef_then_topk(g)  # after this call, g will contain the top-k values and zeros

        # norm_g = g.norm(p=2).item()
        # norm_last_v = self.V[self.buffer_index, :].norm(p=2).item()
        # print(f'{self.steps=}, {norm_g=}, {norm_last_v=}')

        dots = self._integrate_gradient(topk_values_w_zeros=g)
        p = self._precondition(g, dots)  # precondition the sparse gradient, i.e. only the top-k values, stored in the d-dim tensor g
        return p

    def _integrate_gradient(self, topk_values_w_zeros):
        tmp = self.compute_scalar_products(topk_values_w_zeros)
        tmp = tmp.squeeze()  # (d, 1) becomes (d,)

        self.dots[self.buffer_index, :] = tmp
        self.dots[:, self.buffer_index] = tmp

        self.setup()

        self.buffer_index = (self.buffer_index + 1) % self.m
        return tmp

    def _precondition(self, g, dots=None):
        """
        Returns the update inv(F) * x
        The matrix M stores the coefficients of the linear combination
        x: usually the sparse gradient
        """
        # print(f'[precondition]')
        if dots is None:
            dots = self.compute_scalar_products(g)
        giHix = self.lamda * dots
        if USE_CUDA:
            giHix = ista_daslab_cuda_dense_mfac.hinv_mul(self.m, self.giHig, giHix)
            torch.cuda.synchronize()
        else:
            for i in range(1, self.m):
                giHix[i:].sub_(self.giHig[i - 1, i:], alpha=giHix[i - 1] / self.denom[i - 1])
        M = (giHix / self.denom).matmul(self.coef)  # .view(-1, 1) # the view is linked to matmul_grads_sequential_batch

        partA = self.lamda * g
        partB = self.compute_linear_combination(M, out=g)  # out will be returned such that partB = out
        if self.steps > 0 and self.log_interval > 0 and self.steps % self.log_interval == 0:
            self.wandb_data.update(dict(norm_partA=partA.norm(p=2), norm_partB=partB.norm(p=2)))
        return partA - partB

    def compute_scalar_products(self, g):
        self.scalar_products.zero_()

        ista_daslab_cuda_sparse_mfac.SP(
            self.kvm.get_SP_blocks(),
            self.kvm.get_SP_threads(),
            self.kvm.version_SP,
            self.d,
            min(self.m, self.steps),
            self.k,
            self.d_block_size,
            self.k_block_size,
            g,
            self.I,
            self.V,
            self.scalar_products,
            int(self.use_bf16))

        return self.scalar_products

    def compute_linear_combination(self, M, out):
        out.zero_()

        ista_daslab_cuda_sparse_mfac.LCG(
            self.kvm.get_LCG_blocks(),
            self.kvm.get_LCG_threads(),
            self.kvm.version_LCG,
            self.d,
            min(self.m, self.steps),
            self.k,
            self.d_block_size,
            self.k_block_size,
            M,
            self.I,
            self.V,
            out,
            int(self.use_bf16))

        if self.steps > 0 and self.log_interval > 0 and self.steps % self.log_interval == 0:
            self.wandb_data.update(dict(lin_comb_coef_norm=M.norm(p=2)))

        return out
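A hypothetical driver for SparseCoreMFACwithEF, based only on the constructor and method signatures above: it needs a CUDA GPU and the compiled ista_daslab_cuda_* extensions, m must be a multiple of 32 and at most 1024 per the check in __init__, and the damping value and tensor sizes are illustrative. Note that the input gradient is overwritten in place (after the call it holds the sparsified values), so pass a copy if the dense gradient is still needed.

import torch
from ista_daslab_optimizers.sparse_mfac.sparse_core_mfac_w_ef import SparseCoreMFACwithEF

dev = torch.device('cuda:0')
core = SparseCoreMFACwithEF(m=32, d=1_000_000, k_init=0.01, dev=dev, gpus=[dev],
                            damp=1e-6, use_bf16=False)

g = torch.randn(1_000_000, device=dev)                       # flattened dense gradient
u = core.apply_ef_then_update_buffer_then_precondition(g)    # preconditioned sparse update
print(u.shape)                                               # torch.Size([1000000])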