ista-daslab-optimizers 1.0.1__tar.gz → 1.1.3__tar.gz

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
Files changed (39)
  1. {ista_daslab_optimizers-1.0.1/ista_daslab_optimizers.egg-info → ista_daslab_optimizers-1.1.3}/PKG-INFO +10 -4
  2. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/README.md +9 -3
  3. ista_daslab_optimizers-1.1.3/ista_daslab_optimizers/micro_adam/micro_adam.py +402 -0
  4. ista_daslab_optimizers-1.1.3/ista_daslab_optimizers/sparse_mfac/__init__.py +7 -0
  5. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3/ista_daslab_optimizers.egg-info}/PKG-INFO +10 -4
  6. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers.egg-info/SOURCES.txt +13 -2
  7. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/kernels/micro_adam/micro_adam.cpp +3 -3
  8. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu +9 -7
  9. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu +1 -1
  10. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/pyproject.toml +1 -1
  11. ista_daslab_optimizers-1.0.1/ista_daslab_optimizers/micro_adam/micro_adam.py +0 -247
  12. ista_daslab_optimizers-1.0.1/ista_daslab_optimizers/sparse_mfac/__init__.py +0 -5
  13. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/LICENSE +0 -0
  14. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/MANIFEST.in +0 -0
  15. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers/__init__.py +0 -0
  16. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers/acdc/__init__.py +0 -0
  17. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers/acdc/acdc.py +0 -0
  18. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers/acdc/wd_scheduler.py +0 -0
  19. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers/dense_mfac/__init__.py +0 -0
  20. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +0 -0
  21. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers/dense_mfac/dense_mfac.py +0 -0
  22. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers/micro_adam/__init__.py +0 -0
  23. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +0 -0
  24. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +0 -0
  25. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers/tools.py +0 -0
  26. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers.egg-info/dependency_links.txt +0 -0
  27. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers.egg-info/requires.txt +0 -0
  28. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/ista_daslab_optimizers.egg-info/top_level.txt +0 -0
  29. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/kernels/dense_mfac/dense_mfac.cpp +0 -0
  30. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/kernels/dense_mfac/dense_mfac_kernel.cu +0 -0
  31. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/kernels/micro_adam/micro_adam_asymm_block_quant.cu +0 -0
  32. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/kernels/micro_adam/micro_adam_update.cu +0 -0
  33. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/kernels/sparse_mfac/sparse_mfac.cpp +0 -0
  34. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/kernels/sparse_mfac/sparse_mfac_SP_kernel.cu +0 -0
  35. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/kernels/tools/tools.cpp +0 -0
  36. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/kernels/tools/tools_kernel.cu +0 -0
  37. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/kernels/utils.h +0 -0
  38. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/setup.cfg +0 -0
  39. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.3}/setup.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ista_daslab_optimizers
- Version: 1.0.1
+ Version: 1.1.3
  Summary: Deep Learning optimizers developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
  Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
  Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
@@ -280,6 +280,7 @@ optimizer = MicroAdam(
      lr=1e-5, # change accordingly
      quant_block_size=100_000, # 32 or 64 also works
      k_init=0.01, # float between 0 and 1 meaning percentage: 0.01 means 1%
+     alpha=0, # 0 means sparse update and 0 < alpha < 1 means we integrate fraction alpha from EF to update and then delete it
  )

  # from now on, you can use the variable `optimizer` as any other PyTorch optimizer
@@ -288,15 +289,20 @@ optimizer = MicroAdam(
  # Versions summary:

  ---
+ - **1.1.3** @ September 5th, 2024:
+   - allow using `SparseCoreMFACwithEF` separately by importing it in `sparse_mfac.__init__.py`
+ - **1.1.2** @ August 1st, 2024:
+   - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls the fraction of error feedback
+     (EF) to be integrated into the update to make it dense. Finally, the fraction alpha will be discarded from the EF at
+     the expense of another call to `Qinv` and `Q` (and implicitly quantization statistics computation).
+   - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the `update_step` method
+     instead of MicroAdam constructor

  - **1.0.1** @ June 27th, 2024:
-
  - removed version in dependencies to avoid conflicts with llm-foundry

  - **1.0.0** @ June 20th, 2024:
-
  - changed minimum required Python version to 3.8+ and torch to 2.3.0+

  - **0.0.1** @ June 13th, 2024:
-
  - added initial version of the package for Python 3.9+ and torch 2.3.1+
@@ -55,6 +55,7 @@ optimizer = MicroAdam(
      lr=1e-5, # change accordingly
      quant_block_size=100_000, # 32 or 64 also works
      k_init=0.01, # float between 0 and 1 meaning percentage: 0.01 means 1%
+     alpha=0, # 0 means sparse update and 0 < alpha < 1 means we integrate fraction alpha from EF to update and then delete it
  )

  # from now on, you can use the variable `optimizer` as any other PyTorch optimizer
@@ -63,15 +64,20 @@ optimizer = MicroAdam(
  # Versions summary:

  ---
+ - **1.1.3** @ September 5th, 2024:
+   - allow using `SparseCoreMFACwithEF` separately by importing it in `sparse_mfac.__init__.py`
+ - **1.1.2** @ August 1st, 2024:
+   - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls the fraction of error feedback
+     (EF) to be integrated into the update to make it dense. Finally, the fraction alpha will be discarded from the EF at
+     the expense of another call to `Qinv` and `Q` (and implicitly quantization statistics computation).
+   - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the `update_step` method
+     instead of MicroAdam constructor

  - **1.0.1** @ June 27th, 2024:
-
  - removed version in dependencies to avoid conflicts with llm-foundry

  - **1.0.0** @ June 20th, 2024:
-
  - changed minimum required Python version to 3.8+ and torch to 2.3.0+

  - **0.0.1** @ June 13th, 2024:
-
  - added initial version of the package for Python 3.9+ and torch 2.3.1+
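The README hunk above adds `alpha` to the construction example and notes that `optimizer` can then be used like any other PyTorch optimizer. A minimal, hedged sketch of what that looks like end to end is given below; the model, batch, and the choice of `m=10` are placeholders for illustration, and the import path assumes `MicroAdam` is re-exported from the package root, which this diff does not show:

```python
import torch
from ista_daslab_optimizers import MicroAdam  # import path assumed, not shown in this diff

model = torch.nn.Linear(128, 10).cuda()  # placeholder model; the optimizer's kernels require a CUDA device
optimizer = MicroAdam(
    model.parameters(),
    m=10,                      # gradient history size, illustrative value
    lr=1e-5,
    quant_block_size=100_000,
    k_init=0.01,               # keep 1% of each block
    alpha=0,                   # 0 = sparse update; 0 < alpha < 1 mixes that fraction of EF into the update
)

criterion = torch.nn.CrossEntropyLoss()
x = torch.randn(32, 128, device='cuda')        # placeholder batch
y = torch.randint(0, 10, (32,), device='cuda')

optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()  # standard PyTorch optimizer protocol
```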
@@ -0,0 +1,402 @@
+ import os
+ import torch
+ import math
+ import time
+ import wandb
+ from torch.distributed import is_initialized, get_rank, all_reduce, ReduceOp
+ from ..tools import get_first_device, get_gpu_mem_usage, block_split, CopyDirection
+
+ import ista_daslab_tools
+ import ista_daslab_micro_adam
+
+
+ class MicroAdam(torch.optim.Optimizer):
+     def __init__(self, params, m, lr, quant_block_size, k_init=0.01, alpha=0, betas=(0.9, 0.999), weight_decay=0, eps=1e-8):
+         defaults = dict(lr=lr, weight_decay=weight_decay, eps=eps, alpha=alpha)
+         super(MicroAdam, self).__init__(params, defaults)
+
+         assert (0 <= alpha < 1) or alpha == -2, 'Alpha must be in the [0, 1) interval or -2'
+
+         self.m = m
+         self.lr = lr
+         self.quant_block_size = int(quant_block_size)
+         self.k_init = k_init
+         self.alpha = alpha
+         self.weight_decay = weight_decay
+         self.beta1 = betas[0]
+         self.beta2 = betas[1]
+         self.eps = eps
+
+         self.densify_update_using_ef = (self.alpha > 0)
+         self.densify_update_using_quant_error = (self.alpha == -2)
+
+         self.model_size = sum([p.numel() for group in self.param_groups for p in group['params']])
+
+         self.steps = 0  # how many optimization steps were performed so far
+         self.log_interval = 100
+         self.device = get_first_device()
+         self._is_state_initialized = False
+         self.shared_memory_carveout = 100
+         self.blocks = ista_daslab_tools.get_sm_count() * int(100 / self.shared_memory_carveout)
+         self.threads = 512
+
+         self.max_floats = ista_daslab_tools.get_max_floats_for_shared_memory_per_thread_block()
+         self.d_block_size = self.max_floats // 2 // int(100 / self.shared_memory_carveout)
+
+         if torch.distributed.is_initialized():
+             self.fsdp_dict_size_count = [{} for _ in range(
+                 torch.distributed.get_world_size())]  # key = layer size, value = how many layers of that size the model has (per worker)
+         else:
+             self.fsdp_dict_size_count = [{}]
+
+         self.dict_size_count = {}  # key = layer size, value = how many layers of that size the model has
+         for param in self.param_groups:
+             for p in param['params']:
+                 size = p.numel()
+                 # print(p.shape, p.numel())
+                 self.dict_size_count[size] = 1 + self.dict_size_count.get(size, 0)
+
+         # self._init_state()
+
+     def _initialize_parameter_state(self, p, lr, wd):
+         layer_size = p.numel()
+         st = self.state[p]
+
+         rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+
+         if self.densify_update_using_quant_error:
+             st['quant_err'] = torch.zeros_like(p)
+
+         st['blocks'] = max(1, int(math.floor(self.blocks * layer_size * self.fsdp_dict_size_count[rank][layer_size] / self.model_size)))
+
+         st['lr'] = lr
+         st['weight_decay'] = wd
+         st['d'] = layer_size
+
+         ##### variables for Top-K: d_index_topk is the index where the last, smaller topk block starts
+         st['d_block_size'] = layer_size if layer_size < self.d_block_size else self.d_block_size
+         st['topk_full_blocks_count'], st['d_index_topk'] = block_split(st['d'], st['d_block_size'])
+         st['k_block_size_many'] = int(math.ceil(st['d_block_size'] * self.k_init))
+         st['k_block_size_few'] = int(math.ceil((st['d'] - st['d_index_topk']) * self.k_init))  # 0 for d % self.d_block_size = 0
+         st['k_index'] = st['topk_full_blocks_count'] * st['k_block_size_many']
+         st['k'] = st['k_block_size_many'] * st['topk_full_blocks_count'] + st['k_block_size_few']
+
+         ##### variables for the ring buffer
+         st['index'] = 0  # the position to place a new gradient at
+         st['I'] = torch.zeros(self.m, st['k'], dtype=torch.int16, device=self.device)  # 2mk bytes
+         st['V'] = torch.zeros(self.m, st['k'], dtype=torch.bfloat16, device=self.device)  # 2mk bytes
+
+         ### variables for error feedback: d_index_quant is the index where the last, smaller quantization block starts
+         # st['quant_block_size'] = layer_size if layer_size < self.quant_block_size else self.quant_block_size
+         st['quant_full_blocks_count'], st['d_index_quant'] = block_split(st['d'], self.quant_block_size)
+         st['error'] = torch.zeros(int(math.ceil(st['d'] / 2)), dtype=torch.uint8, device=self.device)  # ceil(d/2) bytes
+         st['min_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device)  # ceil(d/q_bsz)*2 bytes
+         st['max_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device)  # ceil(d/q_bsz)*2 bytes
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         self.steps += 1
+
+         # self._update_lr_wd()
+
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         if self.steps == 1:
+             rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+             for param in self.param_groups:
+                 for p in param['params']:
+                     if p is not None:
+                         size = p.numel()
+                         if size > 0:
+                             self.fsdp_dict_size_count[rank][size] = 1 + self.fsdp_dict_size_count[rank].get(size, 0)
+
+         time_start = time.time()
+
+         norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe = 0, 0, 0, 0, 0, 0
+
+         for group in self.param_groups:
+             lr = group['lr']
+             wd = group.get('weight_decay', self.weight_decay)
+
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+
+                 if p is None:
+                     continue
+
+                 nqe, ng, nu, ne, sp_u, sp_qe = self.update_step(p, lr, wd)
+                 norm_qe += nqe
+                 norm_g += ng
+                 norm_u += nu
+                 norm_e += ne
+                 sparsity_u += sp_u
+                 sparsity_qe += sp_qe
+
+         # torch.cuda.synchronize()
+         time_end = time.time()
+         elapsed_step = time_end - time_start
+         self._log(norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe, elapsed_step)
+
+         return loss
+
+     @torch.no_grad()
+     def update_step(self, p, lr, wd):
+         norm_qe, norm_g, norm_u, norm_e, sp_u, sp_qe = 0, 0, 0, 0, 0, 0
+
+         # if p.grad.dtype != torch.bfloat16:
+         #     grad = p.grad.to(dtype=torch.bfloat16).reshape(-1)
+         # else:
+         grad = p.grad.view(-1)
+
+         if self.steps % self.log_interval == 0:
+             norm_g = grad.norm(p=2) ** 2
+
+         st = self.state[p]
+         if len(st) == 0:
+             self._initialize_parameter_state(p, lr, wd)
+
+         # print('rank=', torch.distributed.get_rank(), 'keys=', st.keys())
+
+         blocks = st['blocks']
+         # lr = st['lr']
+         # wd = st['weight_decay']
+         d = st['d']
+         d_block_size = st['d_block_size']
+         topk_full_blocks_count, d_index_topk = st['topk_full_blocks_count'], st['d_index_topk']
+         k_block_size_many = st['k_block_size_many']
+         k_block_size_few = st['k_block_size_few']
+         k_index = st['k_index']
+         k = st['k']
+
+         # HuggingFace has a setting that converts st['I'] to bfloat16, even though it is declared as int16
+         # This happens somewhere between constructor call and step call. Converting it to int16 was the simplest solution
+         if st['I'].dtype != torch.int16:
+             st['I'] = st['I'].to(torch.int16)
+
+         index = st['index']
+         I = st['I']
+         V = st['V']
+
+         quant_full_blocks_count, d_index_quant = st['quant_full_blocks_count'], st['d_index_quant']
+         error = st['error']
+         min_vals = st['min_vals']
+         max_vals = st['max_vals']
+
+         ##### STEP 4
+         ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1.0)  # alpha=1 here
+
+         ##### STEP 5 + 9 (only for I)
+         I[index, :k_index] = torch.topk(input=grad[0:d_index_topk].abs().view(topk_full_blocks_count, d_block_size),
+                                         k=k_block_size_many,
+                                         sorted=False).indices.to(dtype=torch.int16).view(-1)
+
+         if k_block_size_few > 0:  # there is a small block left
+             I[index, k_index:] = torch.topk(input=grad[d_index_topk:].abs(),
+                                             k=k_block_size_few,  # example: slice has size 1, but ks[-1] is 4
+                                             sorted=False).indices.to(dtype=torch.int16).view(-1)
+
+         ista_daslab_tools.copy_values(d,  # V = error[I[buffer_index, :]]
+                                       k,
+                                       d_block_size,
+                                       k_block_size_many,
+                                       I[index, :],  # indices
+                                       grad,  # inp
+                                       V[index, :],  # output
+                                       CopyDirection.d2k.value)
+
+         st['index'] = (index + 1) % self.m
+
+         ##### STEP 6
+         ista_daslab_tools.zerorize_block_components(grad, I[index, :], d, k, d_block_size, k_block_size_many)  # this does a[I[index]] = 0
+
+         ##### STEP 7
+         def _update_quantization_statistics():
+             if quant_full_blocks_count == 1:
+                 min_vals[:quant_full_blocks_count] = grad[:d_index_quant].min()
+                 max_vals[:quant_full_blocks_count] = grad[:d_index_quant].max()
+             else:
+                 min_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).min(dim=1).values
+                 max_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).max(dim=1).values
+             if d_index_quant < d:
+                 min_vals[quant_full_blocks_count] = grad[d_index_quant:].min()
+                 max_vals[quant_full_blocks_count] = grad[d_index_quant:].max()
+
+         _update_quantization_statistics()
+
+         ##### STEP 8
+         ista_daslab_micro_adam.asymm_block_quant(d, self.quant_block_size, error, min_vals, max_vals, grad)  # error = Q(a, min, max)
+
+         # weight decay step
+         if wd > 0:
+             p.mul_(1 - lr * wd)
+
+         ##### NEW: densify using quant error
+         if self.densify_update_using_quant_error:
+             # When entering this if-statement, we have:
+             #   - p is theta_t
+             #   - p.grad is a_t (from step 6 in algorithm 1)
+             #   - error is e_t+1 (from step 8 in algorithm 1)
+             #
+             # Below we have the formula to update the model parameters:
+             # [alpha = -1] with lr
+             #   theta_t+1 = theta_t - lr * (a_t - Qinv(e_t+1)) - lr * u_t
+             #             = theta_t - lr * a_t + lr * Qinv(e_t+1) - lr * u_t
+             #             = theta_t - lr * a_t         # STEP A below, in this if statement
+             #               + lr * Qinv(e_t+1)         # STEP B below, in this if statement
+             #               - lr * u_t                 # this is steps 10-11
+             #
+             # [alpha = -2] without lr
+             #   theta_t+1 = theta_t - (a_t - Qinv(e_t+1)) - lr * u_t
+             #             = theta_t - a_t + Qinv(e_t+1) - lr * u_t
+             #             = theta_t - a_t              # STEP A below, in this if statement
+             #               + Qinv(e_t+1)              # STEP B below, in this if statement
+             #               - lr * u_t                 # this is steps 10-11
+             quant_err = st['quant_err']
+             quant_err.zero_()
+             quant_err.add_(p.grad)
+
+             ##### STEP A
+             p.add_(p.grad, alpha=-1)
+
+             ##### STEP B
+             p.grad.zero_()  # zerorize to prepare the accumulator for Qinv
+             ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1)
+             p.add_(p.grad)
+
+             quant_err.sub_(p.grad)
+
+             norm_qe = quant_err.norm(p=2) ** 2
+             sp_qe = (quant_err == 0).sum()
+
+         ##### STEPS 10-11
+         grad.zero_()
+         ista_daslab_micro_adam.compute_microadam_update(blocks,  # blocks
+                                                         self.threads,  # threads
+                                                         self.shared_memory_carveout,  # carveout
+                                                         self.steps,  # optimization step
+                                                         self.beta1,  # beta1
+                                                         self.beta2,  # beta2
+                                                         self.eps,  # eps
+                                                         d_block_size,  # d_block_size
+                                                         k_block_size_many,  # k_block_size
+                                                         d,  # d
+                                                         self.m,  # m
+                                                         k,  # k
+                                                         I,  # indices
+                                                         V,  # values
+                                                         grad)  # update will be stored here
+
+         ##### STEP 12: # side idea: only decay the weights that are updated
+
+         ##### if PRETRAINING #1
+         if self.densify_update_using_ef:  # we add alpha * EF to the update that is stored in the grad buffer
+             # p.grad += alpha * Qinv(error), alpha=0.1
+             ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, self.alpha)
+         ##### END IF PRETRAINING #1
+
+         # if alpha > 0, then the update u=p.grad is dense now
+
+         # update model using MicroAdam update stored in p.grad
+         p.add_(p.grad, alpha=-lr)
+
+         if self.steps % self.log_interval == 0:
+             norm_u = grad.norm(p=2) ** 2
+             sp_u = (grad == 0).sum()  # check sparsity before zerorizing
+
+         ##### if PRETRAINING #2
+         if self.densify_update_using_ef:
+             grad.zero_()
+             ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1 - self.alpha)
+
+             _update_quantization_statistics()  # step 7 again
+             ista_daslab_micro_adam.asymm_block_quant(d, self.quant_block_size, error, min_vals, max_vals, grad)  # step 8 again
+         ##### END IF PRETRAINING #2
+
+         # compute error norm
+         if self.steps % self.log_interval == 0:
+             grad.zero_()
+             ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1.0)
+
+             norm_e = grad.norm(p=2) ** 2
+
+         # p.grad = p.grad.to(dtype=original_grad_type)
+
+         return norm_qe, norm_g, norm_u, norm_e, sp_u, sp_qe
+
+     def _log(self, norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe, elapsed_step):
+         if self.steps % self.log_interval == 0:
+             if is_initialized():
+                 sync_data = torch.tensor([norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe, elapsed_step], dtype=torch.float,
+                                          requires_grad=False).cuda()  # correct, loss, size
+                 all_reduce(sync_data, op=ReduceOp.SUM)
+                 norm_qe, norm_g, norm_u, norm_e, sparsity_u, sparsity_qe, elapsed_step = sync_data
+
+             if not is_initialized() or get_rank() == 0:
+                 wandb_data = {
+                     'step/optimizer_steps': self.steps,
+                     'step/gpu_mem_usage': get_gpu_mem_usage(),
+                     'step/norm_quant_err': math.sqrt(norm_qe),
+                     'step/sparsity_quant_err': sparsity_qe / self.model_size * 100.,
+                     'step/norm_g': math.sqrt(norm_g),
+                     'step/norm_u': math.sqrt(norm_u),
+                     'step/norm_error': math.sqrt(norm_e),
+                     'step/sparsity_u': sparsity_u / self.model_size * 100.,
+                     'step/elapsed_step': elapsed_step,
+                 }
+                 wandb.log(wandb_data, commit=False)
+
+     # def _update_lr_wd(self):
+     #     # copy the learning rate group to parameter state because the lr scheduler updates the one in the group
+     #     for group in self.param_groups:
+     #         lr = group['lr']
+     #         wd = group.get('weight_decay', self.weight_decay)  # if the param groups do not have weight decay, then use the external one
+     #         for p in group['params']:
+     #             self.state[p]['lr'] = lr
+     #             self.state[p]['wd'] = wd
+
+
+     # def _init_state(self):
+     #     count = 0
+     #     for group in self.param_groups:
+     #         lr = group['lr']
+     #         wd = group.get('weight_decay', self.weight_decay)  # if the param groups do not have weight decay, then use the external one
+     #         for p in group['params']:
+     #             if not p.requires_grad:
+     #                 continue
+
+     #             print(f'[init_state] rank={torch.distributed.get_rank()}, p.shape={p.shape}')
+
+     #             count += 1
+     #             layer_size = p.numel()
+     #             st = self.state[p]
+
+     #             # B * t / d * nt
+     #             st['blocks'] = max(1, int(math.floor(self.blocks * layer_size * self.dict_size_count[layer_size] / self.model_size)))
+
+     #             st['lr'] = lr
+     #             st['weight_decay'] = wd
+     #             st['d'] = layer_size
+
+     #             ##### variables for Top-K: d_index_topk is the index where the last, smaller topk block starts
+     #             st['d_block_size'] = layer_size if layer_size < self.d_block_size else self.d_block_size
+     #             st['topk_full_blocks_count'], st['d_index_topk'] = block_split(st['d'], st['d_block_size'])
+     #             st['k_block_size_many'] = int(math.ceil(st['d_block_size'] * self.k_init))
+     #             st['k_block_size_few'] = int(math.ceil((st['d'] - st['d_index_topk']) * self.k_init))  # 0 for d % self.d_block_size = 0
+     #             st['k_index'] = st['topk_full_blocks_count'] * st['k_block_size_many']
+     #             st['k'] = st['k_block_size_many'] * st['topk_full_blocks_count'] + st['k_block_size_few']
+
+     #             ##### variables for the ring buffer
+     #             st['index'] = 0  # the position to place a new gradient at
+     #             st['I'] = torch.zeros(self.m, st['k'], dtype=torch.int16, device=self.device)  # 2mk bytes
+     #             st['V'] = torch.zeros(self.m, st['k'], dtype=torch.bfloat16, device=self.device)  # 2mk bytes
+
+     #             ### variables for error feedback: d_index_quant is the index where the last, smaller quantization block starts
+     #             # st['quant_block_size'] = layer_size if layer_size < self.quant_block_size else self.quant_block_size
+     #             st['quant_full_blocks_count'], st['d_index_quant'] = block_split(st['d'], self.quant_block_size)
+     #             st['error'] = torch.zeros(int(math.ceil(st['d'] / 2)), dtype=torch.uint8, device=self.device)  # ceil(d/2) bytes
+     #             st['min_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device)  # ceil(d/q_bsz)*2 bytes
+     #             st['max_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device)  # ceil(d/q_bsz)*2 bytes
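The state initialization and Top-K bookkeeping in the new `micro_adam.py` above lean on `block_split` from `..tools`, which this diff does not include. Judging only from how its two return values are used (a count of full blocks, plus the index where the last, smaller block starts), a hypothetical reference could look like the sketch below; treat it as an assumption about the helper, not its actual implementation:

```python
def block_split(d, block_size):
    """Hypothetical reference for the helper used above (not part of this diff).

    Splits a vector of d elements into full blocks of `block_size` plus an
    optional smaller trailing block, returning:
      - full_blocks_count: number of complete blocks,
      - last_block_start: index where the trailing (smaller) block starts;
        equals d when d is a multiple of block_size (no trailing block).
    """
    full_blocks_count = d // block_size
    last_block_start = full_blocks_count * block_size
    return full_blocks_count, last_block_start


# Example: d = 10, block_size = 4 -> 2 full blocks and a trailing block of 2 elements
assert block_split(10, 4) == (2, 8)
assert block_split(8, 4) == (2, 8)   # divisible case: the trailing block is empty
```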
@@ -0,0 +1,7 @@
+ from .sparse_mfac import SparseMFAC
+ from .sparse_core_mfac_w_ef import SparseCoreMFACwithEF
+
+ __all__ = [
+     'SparseMFAC',
+     'SparseCoreMFACwithEF'
+ ]
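According to the 1.1.3 changelog entry, this re-export is what allows `SparseCoreMFACwithEF` to be used on its own. A one-line usage sketch (no constructor arguments are shown in this diff, so none are assumed):

```python
# Both symbols are now importable directly from the subpackage.
from ista_daslab_optimizers.sparse_mfac import SparseMFAC, SparseCoreMFACwithEF
```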
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ista_daslab_optimizers
- Version: 1.0.1
+ Version: 1.1.3
  Summary: Deep Learning optimizers developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
  Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
  Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
@@ -280,6 +280,7 @@ optimizer = MicroAdam(
      lr=1e-5, # change accordingly
      quant_block_size=100_000, # 32 or 64 also works
      k_init=0.01, # float between 0 and 1 meaning percentage: 0.01 means 1%
+     alpha=0, # 0 means sparse update and 0 < alpha < 1 means we integrate fraction alpha from EF to update and then delete it
  )

  # from now on, you can use the variable `optimizer` as any other PyTorch optimizer
@@ -288,15 +289,20 @@ optimizer = MicroAdam(
  # Versions summary:

  ---
+ - **1.1.3** @ September 5th, 2024:
+   - allow using `SparseCoreMFACwithEF` separately by importing it in `sparse_mfac.__init__.py`
+ - **1.1.2** @ August 1st, 2024:
+   - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls the fraction of error feedback
+     (EF) to be integrated into the update to make it dense. Finally, the fraction alpha will be discarded from the EF at
+     the expense of another call to `Qinv` and `Q` (and implicitly quantization statistics computation).
+   - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the `update_step` method
+     instead of MicroAdam constructor

  - **1.0.1** @ June 27th, 2024:
-
  - removed version in dependencies to avoid conflicts with llm-foundry

  - **1.0.0** @ June 20th, 2024:
-
  - changed minimum required Python version to 3.8+ and torch to 2.3.0+

  - **0.0.1** @ June 13th, 2024:
-
  - added initial version of the package for Python 3.9+ and torch 2.3.1+
@@ -3,7 +3,6 @@ MANIFEST.in
  README.md
  pyproject.toml
  setup.py
- ./kernels/utils.h
  ./kernels/dense_mfac/dense_mfac.cpp
  ./kernels/dense_mfac/dense_mfac_kernel.cu
  ./kernels/micro_adam/micro_adam.cpp
@@ -32,4 +31,16 @@ ista_daslab_optimizers/micro_adam/__init__.py
  ista_daslab_optimizers/micro_adam/micro_adam.py
  ista_daslab_optimizers/sparse_mfac/__init__.py
  ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py
- ista_daslab_optimizers/sparse_mfac/sparse_mfac.py
+ ista_daslab_optimizers/sparse_mfac/sparse_mfac.py
+ kernels/utils.h
+ kernels/dense_mfac/dense_mfac.cpp
+ kernels/dense_mfac/dense_mfac_kernel.cu
+ kernels/micro_adam/micro_adam.cpp
+ kernels/micro_adam/micro_adam_asymm_block_quant.cu
+ kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu
+ kernels/micro_adam/micro_adam_update.cu
+ kernels/sparse_mfac/sparse_mfac.cpp
+ kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu
+ kernels/sparse_mfac/sparse_mfac_SP_kernel.cu
+ kernels/tools/tools.cpp
+ kernels/tools/tools_kernel.cu
@@ -13,7 +13,7 @@ void compute_microadam_update_cuda(int blocks, int threads, int carveout,
                                     torch::Tensor indices, torch::Tensor values, torch::Tensor out);

  void asymm_block_quant_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x);
- void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x);
+ void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x, float alpha);

  // C++ methods
  void compute_microadam_update(int blocks, int threads, int carveout,
@@ -44,14 +44,14 @@ void asymm_block_quant(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xm
      asymm_block_quant_cuda(d, q_block_size, xq, xmin, xmax, x);
  }

- void asymm_block_quant_inv(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x) {
+ void asymm_block_quant_inv(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x, float alpha) {
      CHECK_INPUT(xq);
      CHECK_INPUT(xmin);
      CHECK_INPUT(xmax);
      CHECK_INPUT(x);

      const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
-     asymm_block_quant_inv_cuda(d, q_block_size, xq, xmin, xmax, x);
+     asymm_block_quant_inv_cuda(d, q_block_size, xq, xmin, xmax, x, alpha);
  }

  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -1,8 +1,8 @@
  #include "../utils.h"

- __global__ void asymm_block_quant_inv_kernel_bf16_bf16(LL d, LL q_block_size, uint8_t *xq, bfloat16 *xmin, bfloat16 *xmax, bfloat16 *x) {
+ __global__ void asymm_block_quant_inv_kernel_bf16_bf16(LL d, LL q_block_size, uint8_t *xq, bfloat16 *xmin, bfloat16 *xmax, bfloat16 *x, float alpha) {
      /*
-         This kernel computes x += Q_inv(xq, xmin, xmax) for 4 bits (implements point 1 from PhD notebook page 55)
+         This kernel computes x += alpha * Q_inv(xq, xmin, xmax) for 4 bits (implements point 1 from PhD #9 notebook page 55)
          Here, x is the output buffer and will already contain the dense gradient
          The output buffer x has d components and xq has d/2 components because one uint8_t stores two 4-bit values
          In contrast to "globally" kernel, this kernel works with a single block
@@ -43,20 +43,21 @@ __global__ void asymm_block_quant_inv_kernel_bf16_bf16(LL d, LL q_block_size, ui
          lsb = (vq & 0x0F);

          // += operation happens here
-         vx2.x += __float2bfloat16(msb * u + m);
-         vx2.y += __float2bfloat16(lsb * u + m);
+         vx2.x += __float2bfloat16((msb * u + m) * alpha);
+         vx2.y += __float2bfloat16((lsb * u + m) * alpha);
          x2[half_index] = vx2;
      }
      if((d & 1) && (Bid == B-1) && (Tid == T-1)) {
          bfloat16 vx = x[d - 1];
          vq = xq[half_d];
          msb = (vq & 0xF0) >> 4;
-         vx += __float2bfloat16(msb * u + m);
+         // += operation happens here
+         vx += __float2bfloat16((msb * u + m) * alpha);
          x[d - 1] = vx;
      }
  }

- void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x) {
+ void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x, float alpha) {
      ASSERT_BF16(xmin);
      ASSERT_BF16(xmax);
      ASSERT_BF16(x);
@@ -72,7 +73,8 @@ void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::
          (uint8_t*) xq.data_ptr(),
          (bfloat16*) xmin.data_ptr(),
          (bfloat16*) xmax.data_ptr(),
-         (bfloat16*) x.data_ptr());
+         (bfloat16*) x.data_ptr(),
+         alpha);

      // error checks
      GPU_ERROR_CHECK(cudaGetLastError());
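For readers who do not want to trace the CUDA code, here is a hedged PyTorch sketch of what the extended kernel computes, namely `x += alpha * Q_inv(xq, xmin, xmax)` on 4-bit packed values. The kernel's `u` and `m` are defined outside the hunk shown above; the sketch assumes they are the usual per-block asymmetric-quantization scale `(xmax - xmin) / 15` and offset `xmin`, and it assumes `d` is even and divisible by `q_block_size` to stay short:

```python
import torch

def asymm_block_quant_inv_reference(d, q_block_size, xq, xmin, xmax, x, alpha):
    """Assumed reference (not the package kernel) for x += alpha * Q_inv(xq, xmin, xmax)."""
    # Unpack two 4-bit codes per byte: the high nibble comes first (vx2.x), the low nibble second (vx2.y).
    msb = (xq.to(torch.int32) >> 4) & 0x0F
    lsb = xq.to(torch.int32) & 0x0F
    codes = torch.stack([msb, lsb], dim=1).reshape(-1)[:d].float()

    # Per-block dequantization; u and m are assumed to be the block scale and offset.
    num_blocks = d // q_block_size
    blocks = codes.view(num_blocks, q_block_size)
    u = ((xmax[:num_blocks].float() - xmin[:num_blocks].float()) / 15.0).view(-1, 1)
    m = xmin[:num_blocks].float().view(-1, 1)

    # Accumulate alpha * Q_inv(...) into the (bfloat16) output buffer.
    x += (alpha * (blocks * u + m)).reshape(-1).to(x.dtype)
    return x
```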
@@ -144,7 +144,7 @@ __global__ void LCG_v51_bf16(long d,

      long i; // vector index (for indices and values)
      long k_col; // column index to extract data from indices and values at the current row
-     long index; // the 1-D index to extract data from indices and values at the current row (row * k + k_col)
+     // long index; // the 1-D index to extract data from indices and values at the current row (row * k + k_col)
      int16 ind; // the data from indices at the index "index"
      bfloat16 val; // the data from values at the index "index"

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name='ista_daslab_optimizers'
- version='1.0.1'
+ version='1.1.3'
  dependencies = [
      "torch", # >=2.3.1",
      "torchaudio", # >=2.3.1",
@@ -1,247 +0,0 @@
-
- import torch
- import math
- import time
- import wandb
- from torch.distributed import is_initialized, get_rank
- from ..tools import get_first_device, get_gpu_mem_usage, block_split, CopyDirection
-
- import ista_daslab_tools
- import ista_daslab_micro_adam
-
- class MicroAdam(torch.optim.Optimizer):
-     def __init__(self, params, m, lr, quant_block_size, k_init=0.01, betas=(0.9, 0.999), weight_decay=0, eps=1e-8):
-         defaults = dict(lr=lr, weight_decay=weight_decay, eps=eps)
-         super(MicroAdam, self).__init__(params, defaults)
-
-         self.m = m
-         self.lr = lr
-         self.quant_block_size = int(quant_block_size)
-         self.k_init = k_init
-         self.weight_decay = weight_decay
-         self.beta1 = betas[0]
-         self.beta2 = betas[1]
-         self.eps = eps
-
-         self.model_size = sum([p.numel() for group in self.param_groups for p in group['params']])
-
-         self.steps = 0  # how many optimization steps were performed so far
-         self.log_interval = 100
-         self.device = get_first_device()
-         self._is_state_initialized = False
-         self.shared_memory_carveout = 100
-         self.blocks = ista_daslab_tools.get_sm_count() * int(100 / self.shared_memory_carveout)
-         self.threads = 512
-
-         self.dict_size_count = {}  # key = layer size, value = how many layers of that size the model has
-         for param in self.param_groups:
-             for p in param['params']:
-                 size = p.numel()
-                 self.dict_size_count[size] = 1 + self.dict_size_count.get(size, 0)
-
-         self._init_state()
-
-     def _init_state(self):
-         max_floats = ista_daslab_tools.get_max_floats_for_shared_memory_per_thread_block()
-         d_block_size = max_floats // 2 // int(100 / self.shared_memory_carveout)
-         count = 0
-         for group in self.param_groups:
-             lr = group['lr']
-             wd = group.get('weight_decay', self.weight_decay)  # if the param groups do not have weight decay, then use the external one
-             for p in group['params']:
-                 if not p.requires_grad:
-                     continue
-                 count += 1
-                 layer_size = p.numel()
-                 st = self.state[p]
-
-                 # B * t / d * nt
-                 st['blocks'] = max(1, int(math.floor(self.blocks * layer_size * self.dict_size_count[layer_size] / self.model_size)))
-
-                 st['lr'] = lr
-                 st['weight_decay'] = wd
-                 st['d'] = layer_size
-
-                 ##### variables for Top-K: d_index_topk is the index where the last, smaller topk block starts
-                 st['d_block_size'] = layer_size if layer_size < d_block_size else d_block_size
-                 st['topk_full_blocks_count'], st['d_index_topk'] = block_split(st['d'], st['d_block_size'])
-                 st['k_block_size_many'] = int(math.ceil(st['d_block_size'] * self.k_init))
-                 st['k_block_size_few'] = int(math.ceil((st['d'] - st['d_index_topk']) * self.k_init))  # 0 for d % d_block_size = 0
-                 st['k_index'] = st['topk_full_blocks_count'] * st['k_block_size_many']
-                 st['k'] = st['k_block_size_many'] * st['topk_full_blocks_count'] + st['k_block_size_few']
-
-                 ##### variables for the ring buffer
-                 st['index'] = 0  # the position to place a new gradient at
-                 st['I'] = torch.zeros(self.m, st['k'], dtype=torch.int16, device=self.device)  # 2mk bytes
-                 st['V'] = torch.zeros(self.m, st['k'], dtype=torch.bfloat16, device=self.device)  # 2mk bytes
-
-                 ### variables for error feedback: d_index_quant is the index where the last, smaller quantization block starts
-                 # st['quant_block_size'] = layer_size if layer_size < self.quant_block_size else self.quant_block_size
-                 st['quant_full_blocks_count'], st['d_index_quant'] = block_split(st['d'], self.quant_block_size)
-                 st['error'] = torch.zeros(int(math.ceil(st['d'] / 2)), dtype=torch.uint8, device=self.device)  # ceil(d/2) bytes
-                 st['min_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device)  # ceil(d/q_bsz)*2 bytes
-                 st['max_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device)  # ceil(d/q_bsz)*2 bytes
-
-     @torch.no_grad()
-     def step(self, closure=None):
-         self.steps += 1
-
-         self._update_lr_wd()
-
-         loss = None
-         if closure is not None:
-             with torch.enable_grad():
-                 loss = closure()
-
-         time_start = time.time()
-
-         norm_g, norm_u, norm_e, sparsity_u = 0, 0, 0, 0
-         for group in self.param_groups:
-             for p in group['params']:
-                 if p.grad is None:
-                     continue
-                 ng, nu, ne, sp_u = self.update_step(p)
-                 norm_g += ng
-                 norm_u += nu
-                 norm_e += ne
-                 sparsity_u += sp_u
-
-         # torch.cuda.synchronize()
-         time_end = time.time()
-         elapsed_step = time_end - time_start
-         self._log(norm_g, norm_u, norm_e, sparsity_u, elapsed_step)
-
-         return loss
-
-     @torch.no_grad()
-     def update_step(self, p):
-         norm_g, norm_u, norm_e, sp_u = 0, 0, 0, 0
-
-         st = self.state[p]
-         grad = p.grad.view(-1)
-
-         if self.steps % self.log_interval == 0:
-             norm_g = grad.norm(p=2) ** 2
-
-         blocks = st['blocks']
-         lr = st['lr']
-         wd = st['weight_decay']
-         d = st['d']
-         d_block_size = st['d_block_size']
-         topk_full_blocks_count, d_index_topk = st['topk_full_blocks_count'], st['d_index_topk']
-         k_block_size_many = st['k_block_size_many']
-         k_block_size_few = st['k_block_size_few']
-         k_index = st['k_index']
-         k = st['k']
-
-         # HuggingFace has a setting that converts st['I'] to bfloat16, even though it is declared as int16
-         # This happens somewhere between constructor call and step call. Converting it to int16 was the simplest solution
-         if st['I'].dtype != torch.int16:
-             st['I'] = st['I'].to(torch.int16)
-
-         index = st['index']
-         I = st['I']
-         V = st['V']
-
-         quant_full_blocks_count, d_index_quant = st['quant_full_blocks_count'], st['d_index_quant']
-         error = st['error']
-         min_vals = st['min_vals']
-         max_vals = st['max_vals']
-
-         ##### STEP 4
-         ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad)
-
-         ##### STEP 5 + 9 (only for I)
-         I[index, :k_index] = torch.topk(input=grad[0:d_index_topk].abs().view(topk_full_blocks_count, d_block_size),
-                                         k=k_block_size_many,
-                                         sorted=False).indices.to(dtype=torch.int16).view(-1)
-
-         if k_block_size_few > 0:  # there is a small block left
-             I[index, k_index:] = torch.topk(input=grad[d_index_topk:].abs(),
-                                             k=k_block_size_few,  # example: slice has size 1, but ks[-1] is 4
-                                             sorted=False).indices.to(dtype=torch.int16).view(-1)
-
-         ista_daslab_tools.copy_values(d,  # V = error[I[buffer_index, :]]
-                                       k,
-                                       d_block_size,
-                                       k_block_size_many,
-                                       I[index, :],  # indices
-                                       grad,  # inp
-                                       V[index, :],  # output
-                                       CopyDirection.d2k.value)
-
-         st['index'] = (index + 1) % self.m
-
-         ##### STEP 6
-         ista_daslab_tools.zerorize_block_components(grad, I[index, :], d, k, d_block_size, k_block_size_many)  # this does a[I[index]] = 0
-
-         ##### STEP 7
-         if quant_full_blocks_count == 1:
-             min_vals[:quant_full_blocks_count] = grad[:d_index_quant].min()
-             max_vals[:quant_full_blocks_count] = grad[:d_index_quant].max()
-         else:
-             min_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).min(dim=1).values
-             max_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).max(dim=1).values
-         if d_index_quant < d:
-             min_vals[quant_full_blocks_count] = grad[d_index_quant:].min()
-             max_vals[quant_full_blocks_count] = grad[d_index_quant:].max()
-
-         ##### STEP 8
-         ista_daslab_micro_adam.asymm_block_quant(d, self.quant_block_size, error, min_vals, max_vals, grad)  # error = Q(a, min, max)
-
-         ##### STEPS 10-11
-         grad.zero_()
-         ista_daslab_micro_adam.compute_microadam_update(blocks,  # blocks
-                                                         self.threads,  # threads
-                                                         self.shared_memory_carveout,  # carveout
-                                                         self.steps,  # optimization step
-                                                         self.beta1,  # beta1
-                                                         self.beta2,  # beta2
-                                                         self.eps,  # eps
-                                                         d_block_size,  # d_block_size
-                                                         k_block_size_many,  # k_block_size
-                                                         d,  # d
-                                                         self.m,  # m
-                                                         k,  # k
-                                                         I,  # indices
-                                                         V,  # values
-                                                         grad)  # update will be stored here
-
-         ##### STEP 12
-         p.mul_(1 - lr * wd).add_(p.grad, alpha=-lr)
-
-         # compute error norm
-         if self.steps % self.log_interval == 0:
-             norm_u = grad.norm(p=2) ** 2
-             sp_u = (grad == 0).sum()  # check sparsity before zerorizing
-
-             grad.zero_()
-             ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad)
-
-             norm_e = grad.norm(p=2) ** 2
-
-         return norm_g, norm_u, norm_e, sp_u
-
-     def _log(self, norm_g, norm_u, norm_e, sparsity_u, elapsed_step):
-         if self.steps % self.log_interval == 0:
-             wandb_data = {
-                 'step/optimizer_steps': self.steps,
-                 'step/gpu_mem_usage': get_gpu_mem_usage(),
-                 'step/norm_g': math.sqrt(norm_g),
-                 'step/norm_u': math.sqrt(norm_u),
-                 'step/norm_error': math.sqrt(norm_e),
-                 'step/sparsity_u': sparsity_u / self.model_size * 100.,
-                 'step/elapsed_step': elapsed_step,
-             }
-
-             if not is_initialized() or get_rank() == 0:
-                 wandb.log(wandb_data, commit=False)
-
-     def _update_lr_wd(self):
-         # copy the learning rate group to parameter state because the lr scheduler updates the one in the group
-         for group in self.param_groups:
-             lr = group['lr']
-             wd = group.get('weight_decay', self.weight_decay)  # if the param groups do not have weight decay, then use the external one
-             for p in group['params']:
-                 self.state[p]['lr'] = lr
-                 self.state[p]['wd'] = wd
@@ -1,5 +0,0 @@
- from .sparse_mfac import SparseMFAC
-
- __all__ = [
-     'SparseMFAC',
- ]