ista-daslab-optimizers 0.0.1__cp39-cp39-manylinux_2_34_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,4 @@
+ from .acdc import *
+ from .micro_adam import *
+ from .sparse_mfac import *
+ from .dense_mfac import *
@@ -0,0 +1,5 @@
+ from .acdc import ACDC
+
+ __all__ = [
+     'ACDC'
+ ]
@@ -0,0 +1,387 @@
+ import math
+ from enum import Enum
+ import torch
+ import wandb
+ from .wd_scheduler import WeightDecayScheduler
+
+ class ACDC_Action(Enum):
+     FILL_MASK_WITH_ONES = 1
+     SPARSIFY_MODEL_AND_UPDATE_MASK = 2
+     KEEP_MASK_FIXED = 3
+
+ class ACDC_Phase(Enum):
+     DENSE = 1
+     SPARSE = 2
+
+ class ACDC_Scheduler:
+     """
+     This class holds the list of epochs in which sparse training is performed
+     """
+     def __init__(self, warmup_epochs, epochs, phase_length_epochs, finetuning_epochs, zero_based=False):
+         """
+         Builds an AC/DC Scheduler
+         :param warmup_epochs: the warm-up length (dense training)
+         :param epochs: total number of epochs for training
+         :param phase_length_epochs: the length of the dense and sparse phases (both are equal)
+         :param finetuning_epochs: the epoch at which finetuning starts; finetuning is treated as sparse training
+         :param zero_based: True if epochs start from zero, False if epochs start from one
+         """
+         # print(f'AC/DC Scheduler: {warmup_epochs=}, {phase_length_epochs=}, {finetuning_epochs=}, {epochs=}')
+         self.warmup_epochs = warmup_epochs
+         self.epochs = epochs
+         self.phase_length_epochs = phase_length_epochs
+         self.finetuning_epochs = finetuning_epochs
+         self.zero_based = zero_based
+
+         self.sparse_epochs = []
+         self._build_sparse_epochs()
+
+     def _build_sparse_epochs(self):
+         is_sparse = True
+         for i, e in enumerate(range(self.warmup_epochs, self.epochs)):
+             if is_sparse or e >= self.finetuning_epochs:
+                 self.sparse_epochs.append(e)
+
+             if (e-self.warmup_epochs) % self.phase_length_epochs == self.phase_length_epochs - 1:
+                 # if e % self.phase_length_epochs == self.phase_length_epochs - 1:
+                 is_sparse = not is_sparse
+
+         if not self.zero_based:
+             for i in range(len(self.sparse_epochs)):
+                 self.sparse_epochs[i] += 1
+
+     def is_sparse_epoch(self, epoch):
+         """
+         Returns True if sparse training should be performed in this epoch, otherwise returns False
+         :param epoch: the epoch number, using the same convention as `zero_based`
+         """
+         return epoch in self.sparse_epochs
+
+     def is_finetuning_epoch(self, epoch):
+         return epoch >= self.finetuning_epochs
+
+     def get_action(self, epoch):
+         is_crt_epoch_sparse = self.is_sparse_epoch(epoch)
+         is_prev_epoch_sparse = self.is_sparse_epoch(epoch-1)
+         if is_crt_epoch_sparse:
+             if is_prev_epoch_sparse:
+                 return ACDC_Action.KEEP_MASK_FIXED # mask was updated with top-k at the first sparse epoch of the phase, so do nothing now
+             return ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK # first sparse epoch of the phase, update the mask using top-k
+         else: # dense epoch
+             if epoch == int(not self.zero_based) or is_prev_epoch_sparse:
+                 return ACDC_Action.FILL_MASK_WITH_ONES # fill mask with ones
+             else:
+                 return ACDC_Action.KEEP_MASK_FIXED # do not change mask
+
+     def get_phase(self, epoch):
+         if self.is_sparse_epoch(epoch):
+             return ACDC_Phase.SPARSE
+         return ACDC_Phase.DENSE
+
+ class ACDC(torch.optim.Optimizer):
+     """
+     This class implements the Default AC/DC schedule from the original paper https://arxiv.org/abs/2106.12379.pdf
+     The model is used only to attach parameter names to the wandb logs and to report sparsity, norms and weight decay.
+     For example, fc.bias (1D, which requires weight decay) and ln/bn.weight/bias (1D, which do not require weight decay).
+     * LN = Layer Normalization
+     * BN = Batch Normalization
+     Example for 100 epochs (from the original AC/DC paper https://arxiv.org/pdf/2106.12379.pdf):
+     - first 10 epochs: warm-up (dense training)
+     - alternate sparse/dense training phases every 5 epochs
+     - last 10 epochs: finetuning (sparse training)
+     !!!!! SEE THE HUGE COMMENT AT THE END OF THIS FILE FOR A MORE DETAILED EXAMPLE !!!!!
+
+     To use this class, make sure you call the method `update_acdc_state` at the beginning of each epoch.
+
+     The following information is required:
+     - params
+     - model
+     - momentum
+     - weight_decay
+     - wd_type
+     - k
+     - total_epochs
+     - warmup_epochs
+     - phase_length_epochs
+     - finetuning_epochs
+     """
+     def __init__(self,
+                  params, model, # optimization set/model
+                  lr, momentum, weight_decay, wd_type, k, # hyper-parameters
+                  total_epochs, warmup_epochs, phase_length_epochs, finetuning_epochs): # acdc schedulers
+         super(ACDC, self).__init__(params, defaults=dict(lr=lr, weight_decay=weight_decay, momentum=momentum, k=k))
+
+         self.model = model
+         self.lr = lr
+         self.momentum = momentum
+         self.weight_decay = weight_decay
+         self.wd_type = wd_type
+         self.k = k
+
+         self.acdc_scheduler = ACDC_Scheduler(
+             warmup_epochs=warmup_epochs,
+             epochs=total_epochs,
+             phase_length_epochs=phase_length_epochs,
+             finetuning_epochs=finetuning_epochs,
+             zero_based=False)
+
+         self.phase = None
+         self.is_finetuning_epoch = False
+         self.update_mask_flag = None
+
+         self.current_epoch = 0
+         self.steps = 0
+         self.log_interval = 250
+
+         self._initialize_param_states()
+
+     def _initialize_param_states(self):
+         for group in self.param_groups:
+             for p in group['params']:
+
+                 # this method is called before the first forward pass and all gradients will be None
+                 # if p.grad is None:
+                 #     continue
+
+                 state = self.state[p]
+
+                 # initialize the state for each individual parameter p
+                 if len(state) == 0:
+                     # v is the momentum buffer
+                     state['v'] = torch.zeros_like(p)
+
+                     # density = number of weights kept by the top-k call (used only for multi-dim tensors)
+                     state['density'] = int(self.k * p.numel())
+
+                     # 1D tensors, like:
+                     # - batch/layer normalization layers
+                     # - biases for other layers
+                     # will always have mask=1 because they will never be pruned
+                     state['mask'] = torch.ones_like(p)
+
+                     # set the weight decay scheduler for each parameter individually
+                     # all biases and batch/layer norm layers are not decayed
+                     if len(p.shape) == 1:
+                         state['wd_scheduler'] = WeightDecayScheduler(weight_decay=0, wd_type='const')
+                     else:
+                         state['wd_scheduler'] = WeightDecayScheduler(weight_decay=self.weight_decay, wd_type=self.wd_type)
+
+     @torch.no_grad()
+     def update_acdc_state(self, epoch):
+         self.current_epoch = epoch
+         phase = self.acdc_scheduler.get_phase(self.current_epoch)
+         action = self.acdc_scheduler.get_action(self.current_epoch)
+
+         print(f'{epoch=}, {phase=}')
+
+         self._set_phase(phase)
+
+         if action == ACDC_Action.FILL_MASK_WITH_ONES:
+             for group in self.param_groups:
+                 for p in group['params']:
+                     self.state[p]['mask'].fill_(1)
+         elif action == ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK:
+             # update mask and sparsify model
+             for group in self.param_groups:
+                 for p in group['params']:
+                     is_multi_dim = (len(p.shape) > 1)
+                     if is_multi_dim:
+                         state = self.state[p]
+                         # original_shape = p.shape.clone()
+                         indices = torch.topk(p.reshape(-1).abs(), k=state['density']).indices
+
+                         # zerorize, view mask as 1D and then set 1 to specific indices. The result will have p.shape
+                         state['mask'].zero_().reshape(-1)[indices] = 1.
+
+                         # apply the mask to the parameters
+                         p.mul_(state['mask'])
+         elif action == ACDC_Action.KEEP_MASK_FIXED:
+             pass # do nothing
+
+     def _zerorize_momentum_buffer(self):
+         # zerorize the momentum buffer of every parameter
+         for group in self.param_groups:
+             for p in group['params']:
+                 self.state[p]['v'].zero_()
+
+     def _set_phase(self, phase):
+         if phase != self.phase:
+             self.phase = phase
+             if self.phase == ACDC_Phase.DENSE:
+                 # AC/DC: zerorize momentum buffer at the transition SPARSE => DENSE
+                 # The following quote is copy-pasted from the original AC/DC paper, page 7:
+                 # "We reset SGD momentum at the beginning of every decompression phase."
+                 self._zerorize_momentum_buffer()
+
+     @torch.no_grad()
+     def _wandb_log(self):
+         if self.steps % self.log_interval == 0:
+             wandb_dict = dict()
+
+             total_params = 0
+             global_sparsity = 0
+             global_params_norm = 0
+             global_grad_norm = 0
+             for name, p in self.model.named_parameters():
+                 total_params += p.numel()
+                 crt_sparsity = (p == 0).sum().item()
+                 norm_param = p.norm(p=2)
+                 norm_grad = p.grad.norm(p=2)
+
+                 wandb_dict[f'weight_sparsity_{name}'] = crt_sparsity / p.numel() * 100.
+                 wandb_dict[f'mask_sparsity_{name}'] = (self.state[p]['mask'] == 0).sum().item() / p.numel() * 100.
+                 wandb_dict[f'norm_param_{name}'] = norm_param
+                 wandb_dict[f'norm_grad_{name}'] = norm_grad
+
+                 if self.wd_type == 'awd':
+                     wandb_dict[f'awd_{name}'] = self.state[p]['wd_scheduler'].get_wd()
+
+                 global_params_norm += norm_param ** 2
+                 global_grad_norm += norm_grad ** 2
+                 global_sparsity += crt_sparsity
+
+             wandb_dict['global_params_norm'] = math.sqrt(global_params_norm)
+             wandb_dict['global_grad_norm'] = math.sqrt(global_grad_norm)
+             wandb_dict['global_sparsity'] = global_sparsity / total_params * 100.
+
+             wandb_dict['optimizer_epoch'] = self.current_epoch
+             wandb_dict['optimizer_step'] = self.steps
+
+             wandb_dict['is_dense_phase'] = int(self.phase == ACDC_Phase.DENSE)
+             wandb_dict['is_sparse_phase'] = int(self.phase == ACDC_Phase.SPARSE)
+             wandb.log(wandb_dict)
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         self.steps += 1
+         for group in self.param_groups:
+             lr = group['lr']
+             momentum = group['momentum']
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+
+                 # holds all buffers for the current parameter
+                 state = self.state[p]
+
+                 ### apply the mask to the gradient in the sparse phase
+                 ### the gradient is modified in place via p.grad.mul_(state['mask']), which also
+                 ### affects the gradient norm statistics in self._wandb_log; the commented-out
+                 ### out-of-place alternative avoids this at the cost of intermediary tensors
+                 # grad = p.grad
+                 if self.phase == ACDC_Phase.SPARSE:
+                     # grad = grad * state['mask']
+                     p.grad.mul_(state['mask']) # sparsify gradient
+
+                 state['v'].mul_(momentum).add_(p.grad)
+
+                 wd = state['wd_scheduler'](w=p, g=p.grad) # use sparsified gradient
+                 p.mul_(1 - lr * wd).sub_(other=state['v'], alpha=lr).mul_(state['mask'])
+
+         self._wandb_log()
+
+ """
+ This is what ACDC_Scheduler outputs for the default ACDC schedule, presented at page 5, Figure 1 in the paper https://arxiv.org/pdf/2106.12379.pdf
+
+ epoch   is_sparse_epoch   phase               action
+ 1       False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+ 2       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 3       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 4       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 5       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 6       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 7       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 8       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 9       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 10      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 11      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+ 12      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 13      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 14      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 15      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 16      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+ 17      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 18      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 19      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 20      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 21      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+ 22      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 23      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 24      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 25      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 26      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+ 27      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 28      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 29      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 30      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 31      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+ 32      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 33      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 34      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 35      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 36      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+ 37      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 38      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 39      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 40      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 41      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+ 42      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 43      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 44      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 45      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 46      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+ 47      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 48      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 49      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 50      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 51      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+ 52      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 53      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 54      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 55      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 56      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+ 57      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 58      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 59      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 60      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 61      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+ 62      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 63      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 64      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 65      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 66      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+ 67      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 68      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 69      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 70      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 71      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+ 72      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 73      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 74      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 75      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 76      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+ 77      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 78      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 79      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 80      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 81      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+ 82      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 83      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 84      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 85      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 86      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+ 87      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 88      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 89      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 90      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+ 91      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+ 92      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 93      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 94      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 95      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 96      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 97      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 98      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 99      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ 100     True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+ """
@@ -0,0 +1,31 @@
+ class WeightDecayScheduler:
+     def __init__(self, weight_decay: float, wd_type: str):
+         # assert wd_type in ['const', 'balanced', 'opt-a', 'opt-b']
+         self.weight_decay = weight_decay
+         self.wd_type = wd_type
+         self.awd = 0 # for AWD from Apple
+
+     def get_wd(self):
+         if self.wd_type == 'const':
+             return self.weight_decay
+         if self.wd_type == 'awd':
+             return self.awd
+
+     def __call__(self, w=None, g=None):
+         """
+         :param w: tensor that contains weights
+         :param g: tensor that contains gradients
+         :return: the value for the weight decay
+         """
+         if self.wd_type == 'const':
+             return self.weight_decay
+
+         if self.wd_type == 'awd': # AWD from the Apple paper: https://openreview.net/pdf?id=ajnThDhuq6
+             assert (w is not None) and (g is not None),\
+                 'The AWD weight decay scheduler requires valid tensors for w and g, but at least one is None!'
+
+             # in the paper, lambda_awd is set by the user and they return a lambda_bar
+             # here, self.awd is the moving average from line 8 in their algorithm and
+             # the input to our scheduler is self.weight_decay for all wd_types!
+             self.awd = 0.1 * self.awd + 0.9 * self.weight_decay * g.norm(p=2) / w.norm(p=2)
+             return self.awd
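
Editor's note: in the 'awd' branch above, the scheduler keeps an exponential moving average awd_t = 0.1 * awd_{t-1} + 0.9 * weight_decay * ||g|| / ||w||, so the effective decay grows when gradients are large relative to the weights. A tiny illustration follows; the tensors are random placeholders and the module path is an assumption based on the relative import `from .wd_scheduler import WeightDecayScheduler` above.

# Illustrative only: random stand-in tensors; the import path below is assumed, not confirmed by this diff
import torch
from ista_daslab_optimizers.acdc.wd_scheduler import WeightDecayScheduler

sched = WeightDecayScheduler(weight_decay=1e-4, wd_type='awd')
w = torch.randn(10_000)
for step in range(3):
    g = 0.1 * torch.randn(10_000)        # stand-in gradient
    wd = sched(w=w, g=g)                 # updates the moving average and returns it
    print(step, float(wd), float(sched.get_wd()))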
@@ -0,0 +1,5 @@
+ from .dense_mfac import DenseMFAC
+
+ __all__ = [
+     'DenseMFAC'
+ ]
@@ -0,0 +1,164 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+ USE_CUDA = True
+ try:
+     import ista_daslab_dense_mfac
+ except Exception as e:
+     USE_CUDA = False
+     print('\n\t[WARNING] The module "ista_daslab_dense_mfac" is not installed, using slower PyTorch implementation!\n')
+
+ class DenseCoreMFAC:
+     def __init__(self, grads, dev, gpus, damp=1e-5, create_G=False):
+         self.m, self.d = grads.shape
+         self.dev = dev
+         self.gpus = gpus
+         self.dtype = grads.dtype
+         self.grads_count = 0
+         self.wandb_data = dict()
+         self.damp = None
+         self.lambd = None
+         self.set_damp(damp)
+         self.create_G = create_G
+         if self.create_G:
+             self.G = grads
+
+         if USE_CUDA and (self.m % 32 != 0 or self.m > 1024):
+             raise ValueError('CUDA implementation currently only supports m <= 1024 and m divisible by 32.')
+
+         self.dper = self.d // len(gpus) + 1
+         self.grads = [] # matrix $G$ in the paper
+         for idx in range(len(gpus)):
+             start, end = idx * self.dper, (idx + 1) * self.dper
+             self.grads.append(grads[:, start:end].to(gpus[idx]))
+         self.dots = torch.zeros((self.m, self.m), device=self.dev, dtype=self.dtype) # matrix $GG^T$
+         for idx in range(len(gpus)):
+             self.dots += self.grads[idx].matmul(self.grads[idx].t()).to(self.dev)
+
+         self.last = 0 # ringbuffer index
+         self.giHig = self.lambd * self.dots # matrix $D$
+         self.denom = torch.zeros(self.m, device=self.dev, dtype=self.dtype) # $D_ii + m$
+         self.coef = self.lambd * torch.eye(self.m, device=self.dev, dtype=self.dtype) # matrix $B$
+
+         self.setup()
+
+     def empty_buffer(self):
+         for g in self.grads:
+             g.zero_()
+
+     def set_damp(self, new_damp):
+         self.damp = new_damp
+         self.lambd = 1. / new_damp
+
+     def reset_optimizer(self):
+         self.grads_count = 0
+         for idx in range(len(self.gpus)):
+             self.grads[idx].zero_()
+         self.dots.zero_()
+         for idx in range(len(self.gpus)):
+             self.dots += self.grads[idx].matmul(self.grads[idx].t()).to(self.dev)
+         self.last = 0
+         self.giHig = self.lambd * self.dots # matrix $D$
+         self.denom = torch.zeros(self.m, device=self.dev, dtype=self.dtype) # $D_ii + m$
+         self.coef = self.lambd * torch.eye(self.m, device=self.dev, dtype=self.dtype) # matrix $B$
+         self.setup()
+
+     # Calculate $D$ / `giHig` and $B$ / `coef`
+     def setup(self):
+         self.giHig = self.lambd * self.dots
+         diag = torch.diag(torch.full(size=[self.m], fill_value=self.m, device=self.dev, dtype=self.dtype))
+         self.giHig = torch.lu(self.giHig + diag, pivot=False)[0]
+         self.giHig = torch.triu(self.giHig - diag)
+         self.denom = self.m + torch.diagonal(self.giHig) # here we should use min(grads_count, m)
+         tmp = -self.giHig.t().contiguous() / self.denom.reshape((1, -1))
+
+         if USE_CUDA:
+             diag = torch.diag(torch.full(size=[self.m], fill_value=self.lambd, device=self.dev, dtype=self.dtype))
+             self.coef = ista_daslab_dense_mfac.hinv_setup(tmp, diag)
+         else:
+             for i in range(max(self.last, 1), self.m):
+                 self.coef[i, :i] = tmp[i, :i].matmul(self.coef[:i, :i])
+
+     # Replace oldest gradient with `g` and then calculate the IHVP with `g`
+     def integrate_gradient_and_precondition(self, g, x):
+         tmp = self.integrate_gradient(g)
+         p = self.precondition(x, tmp)
+         return p
+
+     # Replace oldest gradient with `g`
+     def integrate_gradient(self, g):
+         self.set_grad(self.last, g)
+         tmp = self.compute_scalar_products(g)
+         self.dots[self.last, :] = tmp
+         self.dots[:, self.last] = tmp
+         self.setup()
+         self.last = (self.last + 1) % self.m
+         return tmp
+
+     # Distributed `grads[j, :] = g`
+     def set_grad(self, j, g):
+         # for the eigenvalue experiment:
+         if self.create_G:
+             self.G[j, :] = g
+
+         self.grads_count += 1
+         def f(i):
+             start, end = i * self.dper, (i + 1) * self.dper
+             self.grads[i][j, :] = g[start:end]
+
+         nn.parallel.parallel_apply(
+             [f] * len(self.grads), list(range(len(self.gpus)))
+         )
+
+     # Distributed `grads.matmul(x)`
+     def compute_scalar_products(self, x):
+         def f(i):
+             start, end = i * self.dper, (i + 1) * self.dper
+             G = self.grads[i]
+             return G.matmul(x[start:end].to(G.device)).to(self.dev)
+
+         outputs = nn.parallel.parallel_apply(
+             [f] * len(self.gpus), list(range(len(self.gpus)))
+         )
+         return sum(outputs)
+
+     # Product with inverse of dampened empirical Fisher
+     def precondition(self, x, dots=None):
+         if dots is None:
+             dots = self.compute_scalar_products(x)
+         giHix = self.lambd * dots
+         if USE_CUDA:
+             giHix = ista_daslab_dense_mfac.hinv_mul(self.m, self.giHig, giHix)
+         else:
+             for i in range(1, self.m):
+                 giHix[i:].sub_(self.giHig[i - 1, i:], alpha=giHix[i - 1] / self.denom[i - 1])
+         """
+         giHix size: 1024
+         denom size: 1024
+         coef size: 1024x1024
+         M size: 1024
+         x size: d
+         """
+         M = (giHix / self.denom).matmul(self.coef)
+         partA = self.lambd * x
+         partB = self.compute_linear_combination(M)
+         self.wandb_data.update({f'norm_partA': partA.norm(p=2), f'norm_partB': partB.norm(p=2)})
+         return partA.to(self.dev) - partB.to(self.dev)
+
+     # Distributed `x.matmul(grads)`
+     def compute_linear_combination(self, x):
+         def f(G):
+             return (x.to(G.device).matmul(G)).to(self.dev)
+         outputs = nn.parallel.parallel_apply(
+             [f] * len(self.grads), self.grads
+         )
+         """
+         x size: 1024
+         grads: 1024 x d
+         """
+         x = x.detach().cpu().numpy()
+         norm = np.linalg.norm(x)
+         self.wandb_data.update({f'lin_comb_coef_norm': norm})
+         return torch.cat(outputs)
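
Editor's note: `precondition` above is documented as the product with the inverse of the dampened empirical Fisher, i.e. (damp * I + G^T G / m)^(-1) x, computed from the m x m buffer `dots` = G G^T rather than any d x d matrix. The reference sketch below is illustrative only; it does not call DenseCoreMFAC, and the sizes and the larger damping value are assumptions chosen for a quick, well-conditioned check of the Woodbury identity that such buffer-based computations rely on.

# Reference sketch (not from the package): the same IHVP computed directly in d x d space
# and via the Woodbury identity using only m x m quantities.
import torch

m, d, damp = 32, 200, 1e-2                    # damp larger than the 1e-5 default to keep the toy check well conditioned
G = torch.randn(m, d, dtype=torch.float64)    # rows play the role of the m buffered gradients
x = torch.randn(d, dtype=torch.float64)

# direct O(d^3) computation: (damp * I + G^T G / m)^(-1) x
fisher = damp * torch.eye(d, dtype=torch.float64) + G.t().matmul(G) / m
ihvp_direct = torch.linalg.solve(fisher, x)

# Woodbury form: (damp*I + G^T G / m)^(-1) x = lambd*x - lambd^2 * G^T (m*I + lambd*G G^T)^(-1) G x, lambd = 1/damp
lambd = 1.0 / damp
inner = m * torch.eye(m, dtype=torch.float64) + lambd * G.matmul(G.t())
ihvp_woodbury = lambd * x - lambd ** 2 * G.t().matmul(torch.linalg.solve(inner, G.matmul(x)))

print(torch.allclose(ihvp_direct, ihvp_woodbury))  # True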