ista-daslab-optimizers 0.0.1__cp39-cp39-manylinux_2_34_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ista_daslab_dense_mfac.cpython-39-x86_64-linux-gnu.so +0 -0
- ista_daslab_micro_adam.cpython-39-x86_64-linux-gnu.so +0 -0
- ista_daslab_optimizers/__init__.py +4 -0
- ista_daslab_optimizers/acdc/__init__.py +5 -0
- ista_daslab_optimizers/acdc/acdc.py +387 -0
- ista_daslab_optimizers/acdc/wd_scheduler.py +31 -0
- ista_daslab_optimizers/dense_mfac/__init__.py +5 -0
- ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +164 -0
- ista_daslab_optimizers/dense_mfac/dense_mfac.py +89 -0
- ista_daslab_optimizers/micro_adam/__init__.py +5 -0
- ista_daslab_optimizers/micro_adam/micro_adam.py +247 -0
- ista_daslab_optimizers/sparse_mfac/__init__.py +5 -0
- ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +226 -0
- ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +87 -0
- ista_daslab_optimizers/tools.py +215 -0
- ista_daslab_optimizers-0.0.1.dist-info/LICENSE +201 -0
- ista_daslab_optimizers-0.0.1.dist-info/METADATA +279 -0
- ista_daslab_optimizers-0.0.1.dist-info/RECORD +22 -0
- ista_daslab_optimizers-0.0.1.dist-info/WHEEL +5 -0
- ista_daslab_optimizers-0.0.1.dist-info/top_level.txt +5 -0
- ista_daslab_sparse_mfac.cpython-39-x86_64-linux-gnu.so +0 -0
- ista_daslab_tools.cpython-39-x86_64-linux-gnu.so +0 -0
Binary file

Binary file
@@ -0,0 +1,387 @@
+import math
+from enum import Enum
+import torch
+import wandb
+from .wd_scheduler import WeightDecayScheduler
+
+class ACDC_Action(Enum):
+    FILL_MASK_WITH_ONES = 1
+    SPARSIFY_MODEL_AND_UPDATE_MASK = 2
+    KEEP_MASK_FIXED = 3
+
+class ACDC_Phase(Enum):
+    DENSE = 1
+    SPARSE = 2
+
+class ACDC_Scheduler:
+    """
+    This class holds the list of epochs where sparse training is performed.
+    """
+    def __init__(self, warmup_epochs, epochs, phase_length_epochs, finetuning_epochs, zero_based=False):
+        """
+        Builds an AC/DC scheduler.
+        :param warmup_epochs: the warm-up length (dense training)
+        :param epochs: total number of training epochs
+        :param phase_length_epochs: the length of the dense and sparse phases (both are equal)
+        :param finetuning_epochs: the epoch when finetuning starts; finetuning is considered sparse training
+        :param zero_based: True if epochs start from zero, False if epochs start from one
+        """
+        # print(f'AC/DC Scheduler: {warmup_epochs=}, {phase_length_epochs=}, {finetuning_epochs=}, {epochs=}')
+        self.warmup_epochs = warmup_epochs
+        self.epochs = epochs
+        self.phase_length_epochs = phase_length_epochs
+        self.finetuning_epochs = finetuning_epochs
+        self.zero_based = zero_based
+
+        self.sparse_epochs = []
+        self._build_sparse_epochs()
+
+    def _build_sparse_epochs(self):
+        is_sparse = True
+        for i, e in enumerate(range(self.warmup_epochs, self.epochs)):
+            if is_sparse or e >= self.finetuning_epochs:
+                self.sparse_epochs.append(e)
+
+            if (e - self.warmup_epochs) % self.phase_length_epochs == self.phase_length_epochs - 1:
+                # if e % self.phase_length_epochs == self.phase_length_epochs - 1:
+                is_sparse = not is_sparse
+
+        if not self.zero_based:
+            for i in range(len(self.sparse_epochs)):
+                self.sparse_epochs[i] += 1
+
+    def is_sparse_epoch(self, epoch):
+        """
+        Returns True if sparse training should be performed, otherwise returns False.
+        :param epoch: a zero-based epoch number
+        """
+        return epoch in self.sparse_epochs
+
+    def is_finetuning_epoch(self, epoch):
+        return epoch >= self.finetuning_epochs
+
+    def get_action(self, epoch):
+        is_crt_epoch_sparse = self.is_sparse_epoch(epoch)
+        is_prev_epoch_sparse = self.is_sparse_epoch(epoch - 1)
+        if is_crt_epoch_sparse:
+            if is_prev_epoch_sparse:
+                return ACDC_Action.KEEP_MASK_FIXED # mask was updated with top-k at the first sparse epoch, now do nothing
+            return ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK # first sparse epoch, update the mask using top-k
+        else: # dense epoch
+            if epoch == int(not self.zero_based) or is_prev_epoch_sparse:
+                return ACDC_Action.FILL_MASK_WITH_ONES # fill mask with ones
+            else:
+                return ACDC_Action.KEEP_MASK_FIXED # do not change mask
+
+    def get_phase(self, epoch):
+        if self.is_sparse_epoch(epoch):
+            return ACDC_Phase.SPARSE
+        return ACDC_Phase.DENSE
+
+class ACDC(torch.optim.Optimizer):
+    """
+    This class implements the default AC/DC schedule from the original paper https://arxiv.org/abs/2106.12379.
+    We use the model only to attach parameter names to the wandb logs and to report sparsity, norms and weight decay,
+    for example fc.bias (1D, requires weight decay) versus ln/bn.weight/bias (1D, not decayed).
+        * LN = Layer Normalization
+        * BN = Batch Normalization
+    Example for 100 epochs (from the original AC/DC paper https://arxiv.org/pdf/2106.12379.pdf):
+        - first 10 epochs: warmup (dense training)
+        - alternate sparse/dense training phases every 5 epochs
+        - last 10 epochs: finetuning (sparse training)
+    !!!!! SEE THE HUGE COMMENT AT THE END OF THIS FILE FOR A MORE DETAILED EXAMPLE !!!!!
+
+    To use this class, make sure you call the method `update_acdc_state` at the beginning of each epoch.
+
+    The following information is required:
+        - params
+        - model
+        - momentum
+        - weight_decay
+        - wd_type
+        - k
+        - total_epochs
+        - warmup_epochs
+        - phase_length_epochs
+        - finetuning_epochs
+    """
+    def __init__(self,
+                 params, model, # optimization set/model
+                 lr, momentum, weight_decay, wd_type, k, # hyper-parameters
+                 total_epochs, warmup_epochs, phase_length_epochs, finetuning_epochs): # acdc schedulers
+        super(ACDC, self).__init__(params, defaults=dict(lr=lr, weight_decay=weight_decay, momentum=momentum, k=k))
+
+        self.model = model
+        self.lr = lr
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.wd_type = wd_type
+        self.k = k
+
+        self.acdc_scheduler = ACDC_Scheduler(
+            warmup_epochs=warmup_epochs,
+            epochs=total_epochs,
+            phase_length_epochs=phase_length_epochs,
+            finetuning_epochs=finetuning_epochs,
+            zero_based=False)
+
+        self.phase = None
+        self.is_finetuning_epoch = False
+        self.update_mask_flag = None
+
+        self.current_epoch = 0
+        self.steps = 0
+        self.log_interval = 250
+
+        self._initialize_param_states()
+
+    def _initialize_param_states(self):
+        for group in self.param_groups:
+            for p in group['params']:
+
+                # this method is called before the first forward pass and all gradients will be None
+                # if p.grad is None:
+                #     continue
+
+                state = self.state[p]
+
+                # initialize the state for each individual parameter p
+                if len(state) == 0:
+                    # v is the momentum buffer
+                    state['v'] = torch.zeros_like(p)
+
+                    # set the density to be used in the top-k call (only for multi-dim tensors)
+                    state['density'] = int(self.k * p.numel())
+
+                    # 1D tensors, like:
+                    # - batch/layer normalization layers
+                    # - biases of other layers
+                    # will always have mask=1 because they are never pruned
+                    state['mask'] = torch.ones_like(p)
+
+                    # set the weight decay scheduler for each parameter individually
+                    # biases and batch/layer norm layers are not decayed
+                    if len(p.shape) == 1:
+                        state['wd_scheduler'] = WeightDecayScheduler(weight_decay=0, wd_type='const')
+                    else:
+                        state['wd_scheduler'] = WeightDecayScheduler(weight_decay=self.weight_decay, wd_type=self.wd_type)
+
+    @torch.no_grad()
+    def update_acdc_state(self, epoch):
+        self.current_epoch = epoch
+        phase = self.acdc_scheduler.get_phase(self.current_epoch)
+        action = self.acdc_scheduler.get_action(self.current_epoch)
+
+        print(f'{epoch=}, {phase=}')
+
+        self._set_phase(phase)
+
+        if action == ACDC_Action.FILL_MASK_WITH_ONES:
+            for group in self.param_groups:
+                for p in group['params']:
+                    self.state[p]['mask'].fill_(1)
+        elif action == ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK:
+            # update mask and sparsify model
+            for group in self.param_groups:
+                for p in group['params']:
+                    is_multi_dim = (len(p.shape) > 1)
+                    if is_multi_dim:
+                        state = self.state[p]
+                        # original_shape = p.shape.clone()
+                        indices = torch.topk(p.reshape(-1).abs(), k=state['density']).indices
+
+                        # zerorize, view the mask as 1D and then set 1 at the selected indices. The result keeps p.shape
+                        state['mask'].zero_().reshape(-1)[indices] = 1.
+
+                        # apply the mask to the parameters
+                        p.mul_(state['mask'])
+        elif action == ACDC_Action.KEEP_MASK_FIXED:
+            pass # do nothing
+
+    def _zerorize_momentum_buffer(self):
+        # zerorize the momentum buffer only for the multi-dimensional parameters
+        for group in self.param_groups:
+            for p in group['params']:
+                self.state[p]['v'].zero_()
+
+    def _set_phase(self, phase):
+        if phase != self.phase:
+            self.phase = phase
+            if self.phase == ACDC_Phase.DENSE:
+                # AC/DC: zerorize the momentum buffer at the SPARSE => DENSE transition
+                # The following quote is copy-pasted from the original AC/DC paper, page 7:
+                # "We reset SGD momentum at the beginning of every decompression phase."
+                self._zerorize_momentum_buffer()
+
+    @torch.no_grad()
+    def _wandb_log(self):
+        if self.steps % self.log_interval == 0:
+            wandb_dict = dict()
+
+            total_params = 0
+            global_sparsity = 0
+            global_params_norm = 0
+            global_grad_norm = 0
+            for name, p in self.model.named_parameters():
+                total_params += p.numel()
+                crt_sparsity = (p == 0).sum().item()
+                norm_param = p.norm(p=2)
+                norm_grad = p.grad.norm(p=2)
+
+                wandb_dict[f'weight_sparsity_{name}'] = crt_sparsity / p.numel() * 100.
+                wandb_dict[f'mask_sparsity_{name}'] = (self.state[p]['mask'] == 0).sum().item() / p.numel() * 100.
+                wandb_dict[f'norm_param_{name}'] = norm_param
+                wandb_dict[f'norm_grad_{name}'] = norm_grad
+
+                if self.wd_type == 'awd':
+                    wandb_dict[f'awd_{name}'] = self.state[p]['wd_scheduler'].get_wd()
+
+                global_params_norm += norm_param ** 2
+                global_grad_norm += norm_grad ** 2
+                global_sparsity += crt_sparsity
+
+            wandb_dict['global_params_norm'] = math.sqrt(global_params_norm)
+            wandb_dict['global_grad_norm'] = math.sqrt(global_grad_norm)
+            wandb_dict['global_sparsity'] = global_sparsity / total_params * 100.
+
+            wandb_dict['optimizer_epoch'] = self.current_epoch
+            wandb_dict['optimizer_step'] = self.steps
+
+            wandb_dict['is_dense_phase'] = int(self.phase == ACDC_Phase.DENSE)
+            wandb_dict['is_sparse_phase'] = int(self.phase == ACDC_Phase.SPARSE)
+            wandb.log(wandb_dict)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        self.steps += 1
+        for group in self.param_groups:
+            lr = group['lr']
+            momentum = group['momentum']
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                # holds all buffers for the current parameter
+                state = self.state[p]
+
+                ### apply the mask to the gradient in the sparse phase
+                ### do not modify the gradient in place via p.grad.mul_(state['mask'])
+                ### because this will affect the norm statistics in self._wandb_log
+                ### this will create intermediary tensors
+                # grad = p.grad
+                if self.phase == ACDC_Phase.SPARSE:
+                    # grad = grad * state['mask']
+                    p.grad.mul_(state['mask']) # sparsify gradient
+
+                state['v'].mul_(momentum).add_(p.grad)
+
+                wd = state['wd_scheduler'](w=p, g=p.grad) # use the sparsified gradient
+                p.mul_(1 - lr * wd).sub_(other=state['v'], alpha=lr).mul_(state['mask'])
+
+        self._wandb_log()
+
+"""
+This is what ACDC_Scheduler outputs for the default AC/DC schedule, presented at page 5, Figure 1 in the paper https://arxiv.org/pdf/2106.12379.pdf
+
+epoch   is_sparse_epoch   phase               action
+1       False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+2       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+3       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+4       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+5       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+6       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+7       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+8       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+9       False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+10      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+11      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+12      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+13      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+14      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+15      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+16      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+17      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+18      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+19      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+20      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+21      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+22      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+23      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+24      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+25      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+26      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+27      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+28      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+29      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+30      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+31      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+32      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+33      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+34      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+35      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+36      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+37      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+38      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+39      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+40      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+41      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+42      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+43      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+44      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+45      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+46      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+47      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+48      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+49      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+50      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+51      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+52      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+53      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+54      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+55      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+56      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+57      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+58      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+59      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+60      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+61      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+62      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+63      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+64      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+65      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+66      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+67      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+68      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+69      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+70      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+71      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+72      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+73      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+74      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+75      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+76      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+77      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+78      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+79      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+80      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+81      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+82      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+83      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+84      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+85      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+86      False             ACDC_Phase.DENSE    ACDC_Action.FILL_MASK_WITH_ONES
+87      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+88      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+89      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+90      False             ACDC_Phase.DENSE    ACDC_Action.KEEP_MASK_FIXED
+91      True              ACDC_Phase.SPARSE   ACDC_Action.SPARSIFY_MODEL_AND_UPDATE_MASK
+92      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+93      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+94      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+95      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+96      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+97      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+98      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+99      True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+100     True              ACDC_Phase.SPARSE   ACDC_Action.KEEP_MASK_FIXED
+"""
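A minimal training-loop sketch for the ACDC optimizer above, following its docstring ("call `update_acdc_state` at the beginning of each epoch"). The import path is assumed from the file layout in this wheel; the toy model, dummy batches, and hyper-parameter values are illustrative only, and wandb is initialized in disabled mode because step() logs through wandb internally.

import torch
import wandb
from ista_daslab_optimizers.acdc.acdc import ACDC

wandb.init(mode='disabled')                      # step() calls wandb.log internally
model = torch.nn.Linear(784, 10)                 # toy model (illustrative)
data = [(torch.randn(32, 784), torch.randint(0, 10, (32,))) for _ in range(4)]
loss_fn = torch.nn.CrossEntropyLoss()

optimizer = ACDC(
    model.parameters(), model,
    lr=0.1, momentum=0.9, weight_decay=1e-4, wd_type='const', k=0.1,  # k=0.1 keeps ~10% of each weight tensor
    total_epochs=100, warmup_epochs=10, phase_length_epochs=5, finetuning_epochs=90,
)

for epoch in range(1, 101):             # epochs are one-based (the scheduler is built with zero_based=False)
    optimizer.update_acdc_state(epoch)  # sets the dense/sparse phase and updates the masks
    for x, y in data:
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()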
@@ -0,0 +1,31 @@
+class WeightDecayScheduler:
+    def __init__(self, weight_decay: float, wd_type: str):
+        # assert wd_type in ['const', 'balanced', 'opt-a', 'opt-b']
+        self.weight_decay = weight_decay
+        self.wd_type = wd_type
+        self.awd = 0 # for AWD from Apple
+
+    def get_wd(self):
+        if self.wd_type == 'const':
+            return self.weight_decay
+        if self.wd_type == 'awd':
+            return self.awd
+
+    def __call__(self, w=None, g=None):
+        """
+        :param w: tensor that contains the weights
+        :param g: tensor that contains the gradients
+        :return: the value of the weight decay
+        """
+        if self.wd_type == 'const':
+            return self.weight_decay
+
+        if self.wd_type == 'awd': # AWD from the Apple paper: https://openreview.net/pdf?id=ajnThDhuq6
+            assert (w is not None) and (g is not None),\
+                'The balanced weight decay scheduler requires valid values for w and g, but at least one is None!'
+
+            # in the paper, lambda_awd is set by the user and they return a lambda_bar
+            # here, self.awd is the moving average from line 8 of their algorithm and
+            # the input to our algorithm is self.weight_decay for all wd_types!
+            self.awd = 0.1 * self.awd + 0.9 * self.weight_decay * g.norm(p=2) / w.norm(p=2)
+            return self.awd
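A short sketch of the scheduler above with toy tensors, assuming the module is importable under the path shown in the file listing. It illustrates that the 'awd' mode returns an exponential moving average of weight_decay * ||g|| / ||w||, while 'const' always returns the base value.

import torch
from ista_daslab_optimizers.acdc.wd_scheduler import WeightDecayScheduler

w = torch.ones(10)        # toy weights,   ||w||_2 = sqrt(10)
g = 0.5 * torch.ones(10)  # toy gradients, ||g||_2 = 0.5 * sqrt(10)

awd = WeightDecayScheduler(weight_decay=1e-4, wd_type='awd')
for _ in range(5):
    wd = awd(w=w, g=g)    # awd <- 0.1 * awd + 0.9 * weight_decay * ||g|| / ||w||
    print(float(wd))      # converges towards 1e-4 * 0.5 = 5e-5

const = WeightDecayScheduler(weight_decay=1e-4, wd_type='const')
print(const())            # always returns the base value 1e-4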
@@ -0,0 +1,164 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+USE_CUDA = True
+try:
+    import ista_daslab_dense_mfac
+except Exception as e:
+    USE_CUDA = False
+    print('\n\t[WARNING] The module "ista_daslab_dense_mfac" is not installed, using slower PyTorch implementation!\n')
+
+class DenseCoreMFAC:
+    def __init__(self, grads, dev, gpus, damp=1e-5, create_G=False):
+        self.m, self.d = grads.shape
+        self.dev = dev
+        self.gpus = gpus
+        self.dtype = grads.dtype
+        self.gpus = gpus
+        self.grads_count = 0
+        self.wandb_data = dict()
+        self.damp = None
+        self.lambd = None
+        self.set_damp(damp)
+        self.create_G = create_G
+        if self.create_G:
+            self.G = grads
+
+        if USE_CUDA and self.m % 32 != 0 or self.m > 1024:
+            raise ValueError('CUDA implementation currently only supports $m$ <= 1024 and divisible by 32.')
+
+        self.dper = self.d // len(gpus) + 1
+        self.grads = [] # matrix $G$ in the paper
+        for idx in range(len(gpus)):
+            start, end = idx * self.dper, (idx + 1) * self.dper
+            self.grads.append(grads[:, start:end].to(gpus[idx]))
+        self.dots = torch.zeros((self.m, self.m), device=self.dev, dtype=self.dtype) # matrix $GG^T$
+        for idx in range(len(gpus)):
+            self.dots += self.grads[idx].matmul(self.grads[idx].t()).to(self.dev)
+
+        self.last = 0 # ring-buffer index
+        self.giHig = self.lambd * self.dots # matrix $D$
+        self.denom = torch.zeros(self.m, device=self.dev, dtype=self.dtype) # $D_{ii} + m$
+        self.coef = self.lambd * torch.eye(self.m, device=self.dev, dtype=self.dtype) # matrix $B$
+
+        self.setup()
+
+    def empty_buffer(self):
+        for g in self.grads:
+            g.zero_()
+
+    def set_damp(self, new_damp):
+        self.damp = new_damp
+        self.lambd = 1. / new_damp
+
+    def reset_optimizer(self):
+        self.grads_count = 0
+        for idx in range(len(self.gpus)):
+            self.grads[idx].zero_()
+        self.dots.zero_()
+        for idx in range(len(self.gpus)):
+            self.dots += self.grads[idx].matmul(self.grads[idx].t()).to(self.dev)
+        self.last = 0
+        self.giHig = self.lambd * self.dots # matrix $D$
+        self.denom = torch.zeros(self.m, device=self.dev, dtype=self.dtype) # $D_{ii} + m$
+        self.coef = self.lambd * torch.eye(self.m, device=self.dev, dtype=self.dtype) # matrix $B$
+        self.setup()
+
+    # Calculate $D$ / `giHig` and $B$ / `coef`
+    def setup(self):
+        self.giHig = self.lambd * self.dots
+        diag = torch.diag(torch.full(size=[self.m], fill_value=self.m, device=self.dev, dtype=self.dtype))
+        self.giHig = torch.lu(self.giHig + diag, pivot=False)[0]
+        self.giHig = torch.triu(self.giHig - diag)
+        self.denom = self.m + torch.diagonal(self.giHig) # here we should use min(grads_count, m)
+        tmp = -self.giHig.t().contiguous() / self.denom.reshape((1, -1))
+
+        if USE_CUDA:
+            diag = torch.diag(torch.full(size=[self.m], fill_value=self.lambd, device=self.dev, dtype=self.dtype))
+            self.coef = ista_daslab_dense_mfac.hinv_setup(tmp, diag)
+        else:
+            for i in range(max(self.last, 1), self.m):
+                self.coef[i, :i] = tmp[i, :i].matmul(self.coef[:i, :i])
+
+    # Replace the oldest gradient with `g` and then calculate the IHVP with `g`
+    def integrate_gradient_and_precondition(self, g, x):
+        tmp = self.integrate_gradient(g)
+        p = self.precondition(x, tmp)
+        return p
+
+    # Replace the oldest gradient with `g`
+    def integrate_gradient(self, g):
+        self.set_grad(self.last, g)
+        tmp = self.compute_scalar_products(g)
+        self.dots[self.last, :] = tmp
+        self.dots[:, self.last] = tmp
+        self.setup()
+        self.last = (self.last + 1) % self.m
+        return tmp
+
+    # Distributed `grads[j, :] = g`
+    def set_grad(self, j, g):
+        # for the eigenvalue experiment:
+        if self.create_G:
+            self.G[j, :] = g
+
+        self.grads_count += 1
+        def f(i):
+            start, end = i * self.dper, (i + 1) * self.dper
+            self.grads[i][j, :] = g[start:end]
+
+        nn.parallel.parallel_apply(
+            [f] * len(self.grads), list(range(len(self.gpus)))
+        )
+
+    # Distributed `grads.matmul(x)`
+    def compute_scalar_products(self, x):
+        def f(i):
+            start, end = i * self.dper, (i + 1) * self.dper
+            G = self.grads[i]
+            return G.matmul(x[start:end].to(G.device)).to(self.dev)
+
+        outputs = nn.parallel.parallel_apply(
+            [f] * len(self.gpus), list(range(len(self.gpus)))
+        )
+        return sum(outputs)
+
+    # Product with the inverse of the dampened empirical Fisher
+    def precondition(self, x, dots=None):
+        if dots is None:
+            dots = self.compute_scalar_products(x)
+        giHix = self.lambd * dots
+        if USE_CUDA:
+            giHix = ista_daslab_dense_mfac.hinv_mul(self.m, self.giHig, giHix)
+        else:
+            for i in range(1, self.m):
+                giHix[i:].sub_(self.giHig[i - 1, i:], alpha=giHix[i - 1] / self.denom[i - 1])
+        """
+        giHix size: 1024
+        denom size: 1024
+        coef size: 1024x1024
+        M size: 1024
+        x size: d
+        """
+        M = (giHix / self.denom).matmul(self.coef)
+        partA = self.lambd * x
+        partB = self.compute_linear_combination(M)
+        self.wandb_data.update({'norm_partA': partA.norm(p=2), 'norm_partB': partB.norm(p=2)})
+        return partA.to(self.dev) - partB.to(self.dev)
+
+    # Distributed `x.matmul(grads)`
+    def compute_linear_combination(self, x):
+        def f(G):
+            return (x.to(G.device).matmul(G)).to(self.dev)
+        outputs = nn.parallel.parallel_apply(
+            [f] * len(self.grads), self.grads
+        )
+        """
+        x size: 1024
+        grads: 1024 x d
+        """
+        x = x.detach().cpu().numpy()
+        norm = np.linalg.norm(x)
+        self.wandb_data.update({'lin_comb_coef_norm': norm})
+        return torch.cat(outputs)
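A hypothetical call-pattern sketch for DenseCoreMFAC above, guarded so it only runs where a CUDA device is present. The import path is assumed from the file listing, the buffer size m, dimension d, and device list are illustrative, and it presumes the compiled ista_daslab_dense_mfac kernels from this wheel plus a compatible PyTorch version, so treat it as a sketch of the intended API rather than a tested script.

import torch
from ista_daslab_optimizers.dense_mfac.dense_core_mfac import DenseCoreMFAC

if torch.cuda.is_available():
    device = torch.device('cuda:0')
    m, d = 32, 10_000                         # ring-buffer size (multiple of 32, at most 1024) and model dimension
    grads = torch.zeros(m, d, device=device)  # initial (empty) gradient buffer G
    core = DenseCoreMFAC(grads, dev=device, gpus=[device], damp=1e-5)

    for _ in range(5):                        # pretend optimization steps
        g = torch.randn(d, device=device)     # stand-in for a flattened stochastic gradient
        # insert g into the ring buffer, refresh GG^T, and return the preconditioned (IHVP) update
        update = core.integrate_gradient_and_precondition(g, g)
        print(update.shape)                   # torch.Size([10000])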