ista-daslab-optimizers 1.0.1__tar.gz → 1.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {ista_daslab_optimizers-1.0.1/ista_daslab_optimizers.egg-info → ista_daslab_optimizers-1.1.2}/PKG-INFO +8 -4
  2. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/README.md +7 -3
  3. ista_daslab_optimizers-1.1.2/ista_daslab_optimizers/micro_adam/micro_adam.py +338 -0
  4. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2/ista_daslab_optimizers.egg-info}/PKG-INFO +8 -4
  5. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/micro_adam/micro_adam.cpp +3 -3
  6. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu +9 -7
  7. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu +1 -1
  8. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/pyproject.toml +1 -1
  9. ista_daslab_optimizers-1.0.1/ista_daslab_optimizers/micro_adam/micro_adam.py +0 -247
  10. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/LICENSE +0 -0
  11. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/MANIFEST.in +0 -0
  12. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/__init__.py +0 -0
  13. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/acdc/__init__.py +0 -0
  14. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/acdc/acdc.py +0 -0
  15. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/acdc/wd_scheduler.py +0 -0
  16. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/dense_mfac/__init__.py +0 -0
  17. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +0 -0
  18. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/dense_mfac/dense_mfac.py +0 -0
  19. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/micro_adam/__init__.py +0 -0
  20. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/sparse_mfac/__init__.py +0 -0
  21. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +0 -0
  22. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +0 -0
  23. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/tools.py +0 -0
  24. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers.egg-info/SOURCES.txt +0 -0
  25. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers.egg-info/dependency_links.txt +0 -0
  26. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers.egg-info/requires.txt +0 -0
  27. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers.egg-info/top_level.txt +0 -0
  28. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/dense_mfac/dense_mfac.cpp +0 -0
  29. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/dense_mfac/dense_mfac_kernel.cu +0 -0
  30. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/micro_adam/micro_adam_asymm_block_quant.cu +0 -0
  31. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/micro_adam/micro_adam_update.cu +0 -0
  32. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/sparse_mfac/sparse_mfac.cpp +0 -0
  33. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/sparse_mfac/sparse_mfac_SP_kernel.cu +0 -0
  34. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/tools/tools.cpp +0 -0
  35. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/tools/tools_kernel.cu +0 -0
  36. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/utils.h +0 -0
  37. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/setup.cfg +0 -0
  38. {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/setup.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ista_daslab_optimizers
- Version: 1.0.1
+ Version: 1.1.2
  Summary: Deep Learning optimizers developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
  Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
  Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
@@ -280,6 +280,7 @@ optimizer = MicroAdam(
      lr=1e-5, # change accordingly
      quant_block_size=100_000, # 32 or 64 also works
      k_init=0.01, # float between 0 and 1 meaning percentage: 0.01 means 1%
+     alpha=0, # 0 means sparse update and 0 < alpha < 1 means we integrate fraction alpha from EF to update and then delete it
  )

  # from now on, you can use the variable `optimizer` as any other PyTorch optimizer
@@ -288,15 +289,18 @@ optimizer = MicroAdam(
  # Versions summary:

  ---
+ - **1.1.2** @ August 1st, 2024:
+   - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls the fraction of error feedback
+     (EF) to be integrated into the update to make it dense. Finally, the fraction alpha will be discarded from the EF at
+     the expense of another call to `Qinv` and `Q` (and implicitly quantization statistics computation).
+   - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the `update_step` method
+     instead of MicroAdam constructor

  - **1.0.1** @ June 27th, 2024:
-
  - removed version in dependencies to avoid conflicts with llm-foundry

  - **1.0.0** @ June 20th, 2024:
-
  - changed minimum required Python version to 3.8+ and torch to 2.3.0+

  - **0.0.1** @ June 13th, 2024:
-
  - added initial version of the package for Python 3.9+ and torch 2.3.1+
@@ -55,6 +55,7 @@ optimizer = MicroAdam(
      lr=1e-5, # change accordingly
      quant_block_size=100_000, # 32 or 64 also works
      k_init=0.01, # float between 0 and 1 meaning percentage: 0.01 means 1%
+     alpha=0, # 0 means sparse update and 0 < alpha < 1 means we integrate fraction alpha from EF to update and then delete it
  )

  # from now on, you can use the variable `optimizer` as any other PyTorch optimizer
@@ -63,15 +64,18 @@ optimizer = MicroAdam(
  # Versions summary:

  ---
+ - **1.1.2** @ August 1st, 2024:
+   - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls the fraction of error feedback
+     (EF) to be integrated into the update to make it dense. Finally, the fraction alpha will be discarded from the EF at
+     the expense of another call to `Qinv` and `Q` (and implicitly quantization statistics computation).
+   - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the `update_step` method
+     instead of MicroAdam constructor

  - **1.0.1** @ June 27th, 2024:
-
  - removed version in dependencies to avoid conflicts with llm-foundry

  - **1.0.0** @ June 20th, 2024:
-
  - changed minimum required Python version to 3.8+ and torch to 2.3.0+

  - **0.0.1** @ June 13th, 2024:
-
  - added initial version of the package for Python 3.9+ and torch 2.3.1+
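The 1.1.x changelog entry above describes the densification step only in prose. The sketch below restates that logic in plain PyTorch: `u` stands for the (sparse) MicroAdam update, `error_deq` for the dequantized error feedback `Qinv(error)`, and `alpha` for the new constructor argument. The free-standing function and its names are illustrative, not part of the package API.

```python
import torch

def densify_update(u: torch.Tensor, error_deq: torch.Tensor, alpha: float):
    """Sketch of the alpha-densification described in the 1.1.x changelog.

    u         -- sparse MicroAdam update (dense tensor, mostly zeros)
    error_deq -- dequantized error feedback Qinv(error), same shape as u
    alpha     -- fraction of the error feedback folded into the update, 0 <= alpha < 1
    Returns the (possibly densified) update and the residual error feedback
    that the optimizer re-quantizes via Q afterwards.
    """
    if alpha == 0:
        return u, error_deq                   # original sparse behaviour
    u_dense = u + alpha * error_deq           # fold a fraction of EF into the update
    error_residual = (1 - alpha) * error_deq  # the folded fraction is removed from EF
    return u_dense, error_residual
```

The extra `Qinv`/`Q` pass mentioned in the changelog corresponds to dequantizing the residual and re-quantizing it with fresh per-block statistics.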
@@ -0,0 +1,338 @@
+ import os
+ import torch
+ import math
+ import time
+ import wandb
+ from torch.distributed import is_initialized, get_rank, all_reduce, ReduceOp
+ from ..tools import get_first_device, get_gpu_mem_usage, block_split, CopyDirection
+
+ import ista_daslab_tools
+ import ista_daslab_micro_adam
+
+
+ class MicroAdam(torch.optim.Optimizer):
+     def __init__(self, params, m, lr, quant_block_size, k_init=0.01, alpha=0, betas=(0.9, 0.999), weight_decay=0, eps=1e-8):
+         defaults = dict(lr=lr, weight_decay=weight_decay, eps=eps, alpha=alpha)
+         super(MicroAdam, self).__init__(params, defaults)
+
+         assert 0 <= alpha < 1, 'Alpha must be in the [0, 1) interval'
+
+         self.m = m
+         self.lr = lr
+         self.quant_block_size = int(quant_block_size)
+         self.k_init = k_init
+         self.alpha = alpha
+         self.weight_decay = weight_decay
+         self.beta1 = betas[0]
+         self.beta2 = betas[1]
+         self.eps = eps
+
+         self.densify_update = (self.alpha > 0)
+         self.model_size = sum([p.numel() for group in self.param_groups for p in group['params']])
+
+         self.steps = 0 # how many optimization steps were performed so far
+         self.log_interval = 100
+         self.device = get_first_device()
+         self._is_state_initialized = False
+         self.shared_memory_carveout = 100
+         self.blocks = ista_daslab_tools.get_sm_count() * int(100 / self.shared_memory_carveout)
+         self.threads = 512
+
+         self.max_floats = ista_daslab_tools.get_max_floats_for_shared_memory_per_thread_block()
+         self.d_block_size = self.max_floats // 2 // int(100 / self.shared_memory_carveout)
+
+         self.fsdp_dict_size_count = [{} for _ in range(
+             torch.distributed.get_world_size())] # key = layer size, value = how many layers of that size the model has (per worker)
+         self.dict_size_count = {} # key = layer size, value = how many layers of that size the model has
+         for param in self.param_groups:
+             for p in param['params']:
+                 size = p.numel()
+                 # print(p.shape, p.numel())
+                 self.dict_size_count[size] = 1 + self.dict_size_count.get(size, 0)
+
+         # self._init_state()
+
+     def _initialize_parameter_state(self, p, lr, wd):
+         layer_size = p.numel()
+         st = self.state[p]
+
+         rank = torch.distributed.get_rank()
+
+         st['blocks'] = max(1, int(math.floor(self.blocks * layer_size * self.fsdp_dict_size_count[rank][layer_size] / self.model_size)))
+
+         st['lr'] = lr
+         st['weight_decay'] = wd
+         st['d'] = layer_size
+
+         ##### variables for Top-K: d_index_topk is the index where the last, smaller topk block starts
+         st['d_block_size'] = layer_size if layer_size < self.d_block_size else self.d_block_size
+         st['topk_full_blocks_count'], st['d_index_topk'] = block_split(st['d'], st['d_block_size'])
+         st['k_block_size_many'] = int(math.ceil(st['d_block_size'] * self.k_init))
+         st['k_block_size_few'] = int(math.ceil((st['d'] - st['d_index_topk']) * self.k_init)) # 0 for d % self.d_block_size = 0
+         st['k_index'] = st['topk_full_blocks_count'] * st['k_block_size_many']
+         st['k'] = st['k_block_size_many'] * st['topk_full_blocks_count'] + st['k_block_size_few']
+
+         ##### variables for the ring buffer
+         st['index'] = 0 # the position to place a new gradient at
+         st['I'] = torch.zeros(self.m, st['k'], dtype=torch.int16, device=self.device) # 2mk bytes
+         st['V'] = torch.zeros(self.m, st['k'], dtype=torch.bfloat16, device=self.device) # 2mk bytes
+
+         ### variables for error feedback: d_index_quant is the index where the last, smaller quantization block starts
+         # st['quant_block_size'] = layer_size if layer_size < self.quant_block_size else self.quant_block_size
+         st['quant_full_blocks_count'], st['d_index_quant'] = block_split(st['d'], self.quant_block_size)
+         st['error'] = torch.zeros(int(math.ceil(st['d'] / 2)), dtype=torch.uint8, device=self.device) # ceil(d/2) bytes
+         st['min_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device) # ceil(d/q_bsz)*2 bytes
+         st['max_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device) # ceil(d/q_bsz)*2 bytes
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         self.steps += 1
+
+         # self._update_lr_wd()
+
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         if self.steps == 1:
+             rank = torch.distributed.get_rank()
+             for param in self.param_groups:
+                 for p in param['params']:
+                     if p is not None:
+                         size = p.numel()
+                         if size > 0:
+                             self.fsdp_dict_size_count[rank][size] = 1 + self.fsdp_dict_size_count[rank].get(size, 0)
+
+         time_start = time.time()
+
+         norm_g, norm_u, norm_e, sparsity_u = 0, 0, 0, 0
+
+         for group in self.param_groups:
+             lr = group['lr']
+             wd = group.get('weight_decay', self.weight_decay)
+
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+
+                 if p is None:
+                     continue
+
+                 ng, nu, ne, sp_u = self.update_step(p, lr, wd)
+                 norm_g += ng
+                 norm_u += nu
+                 norm_e += ne
+                 sparsity_u += sp_u
+
+         # torch.cuda.synchronize()
+         time_end = time.time()
+         elapsed_step = time_end - time_start
+         self._log(norm_g, norm_u, norm_e, sparsity_u, elapsed_step)
+
+         return loss
+
+     @torch.no_grad()
+     def update_step(self, p, lr, wd):
+         norm_g, norm_u, norm_e, sp_u = 0, 0, 0, 0
+
+         grad = p.grad.view(-1)
+
+         if self.steps % self.log_interval == 0:
+             norm_g = grad.norm(p=2) ** 2
+
+         st = self.state[p]
+         if len(st) == 0:
+             self._initialize_parameter_state(p, lr, wd)
+
+         # print('rank=',torch.distributed.get_rank(), 'keys=',st.keys())
+
+         blocks = st['blocks']
+         # lr = st['lr']
+         # wd = st['weight_decay']
+         d = st['d']
+         d_block_size = st['d_block_size']
+         topk_full_blocks_count, d_index_topk = st['topk_full_blocks_count'], st['d_index_topk']
+         k_block_size_many = st['k_block_size_many']
+         k_block_size_few = st['k_block_size_few']
+         k_index = st['k_index']
+         k = st['k']
+
+         # HuggingFace has a setting that converts st['I'] to bfloat16, even though it is declared as int16
+         # This happens somewhere between constructor call and step call. Converting it to int16 was the simplest solution
+         if st['I'].dtype != torch.int16:
+             st['I'] = st['I'].to(torch.int16)
+
+         index = st['index']
+         I = st['I']
+         V = st['V']
+
+         quant_full_blocks_count, d_index_quant = st['quant_full_blocks_count'], st['d_index_quant']
+         error = st['error']
+         min_vals = st['min_vals']
+         max_vals = st['max_vals']
+
+         ##### STEP 4
+         ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1.0) # alpha=1 here
+
+         ##### STEP 5 + 9 (only for I)
+         I[index, :k_index] = torch.topk(input=grad[0:d_index_topk].abs().view(topk_full_blocks_count, d_block_size),
+                                         k=k_block_size_many,
+                                         sorted=False).indices.to(dtype=torch.int16).view(-1)
+
+         if k_block_size_few > 0: # there is a small block left
+             I[index, k_index:] = torch.topk(input=grad[d_index_topk:].abs(),
+                                             k=k_block_size_few, # example: slice has size 1, but ks[-1] is 4
+                                             sorted=False).indices.to(dtype=torch.int16).view(-1)
+
+         ista_daslab_tools.copy_values(d, # V = error[I[buffer_index, :]]
+                                       k,
+                                       d_block_size,
+                                       k_block_size_many,
+                                       I[index, :], # indices
+                                       grad, # inp
+                                       V[index, :], # output
+                                       CopyDirection.d2k.value)
+
+         st['index'] = (index + 1) % self.m
+
+         ##### STEP 6
+         ista_daslab_tools.zerorize_block_components(grad, I[index, :], d, k, d_block_size, k_block_size_many) # this does a[I[index]] = 0
+
+         ##### STEP 7
+         def _update_quantization_statistics():
+             if quant_full_blocks_count == 1:
+                 min_vals[:quant_full_blocks_count] = grad[:d_index_quant].min()
+                 max_vals[:quant_full_blocks_count] = grad[:d_index_quant].max()
+             else:
+                 min_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).min(dim=1).values
+                 max_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).max(dim=1).values
+             if d_index_quant < d:
+                 min_vals[quant_full_blocks_count] = grad[d_index_quant:].min()
+                 max_vals[quant_full_blocks_count] = grad[d_index_quant:].max()
+
+         _update_quantization_statistics()
+
+         ##### STEP 8
+         ista_daslab_micro_adam.asymm_block_quant(d, self.quant_block_size, error, min_vals, max_vals, grad) # error = Q(a, min, max)
+
+         ##### STEPS 10-11
+         grad.zero_()
+         ista_daslab_micro_adam.compute_microadam_update(blocks, # blocks
+                                                         self.threads, # threads
+                                                         self.shared_memory_carveout, # carveout
+                                                         self.steps, # optimization step
+                                                         self.beta1, # beta1
+                                                         self.beta2, # beta2
+                                                         self.eps, # eps
+                                                         d_block_size, # d_block_size
+                                                         k_block_size_many, # k_block_size
+                                                         d, # d
+                                                         self.m, # m
+                                                         k, # k
+                                                         I, # indices
+                                                         V, # values
+                                                         grad) # update will be stored here
+
+         ##### STEP 12: # side idea: only decay the weights that are update
+
+         ##### if PRETRAINING #1
+         if self.densify_update: # we add alpha * EF to update that is stored in grad buffer
+             # p.grad += alpha * Qinv(error), alpha=0.1
+             ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, self.alpha)
+         ##### END IF PRETRAINING #1
+
+         # if alpha > 0, then the update u=p.grad is dense now
+         p.mul_(1 - lr * wd).add_(p.grad, alpha=-lr)
+
+         ##### if PRETRAINING #2
+         if self.densify_update:
+             grad.zero_()
+             ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1-self.alpha)
+
+             _update_quantization_statistics() # step 7 again
+             ista_daslab_micro_adam.asymm_block_quant(d, self.quant_block_size, error, min_vals, max_vals, grad) # step 8 again
+         ##### END IF PRETRAINING #2
+
+         # compute error norm
+         if self.steps % self.log_interval == 0:
+             norm_u = grad.norm(p=2) ** 2
+             sp_u = (grad == 0).sum() # check sparsity before zerorizing
+
+             grad.zero_()
+             ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1.0)
+
+             norm_e = grad.norm(p=2) ** 2
+
+         return norm_g, norm_u, norm_e, sp_u
+
+     def _log(self, norm_g, norm_u, norm_e, sparsity_u, elapsed_step):
+         if self.steps % self.log_interval == 0:
+             sync_data = torch.tensor([norm_g, norm_u, norm_e, sparsity_u, elapsed_step], dtype=torch.float,
+                                      requires_grad=False).cuda() # correct, loss, size
+             all_reduce(sync_data, op=ReduceOp.SUM)
+             norm_g, norm_u, norm_e, sparsity_u, elapsed_step = sync_data
+
+             if not is_initialized() or get_rank() == 0:
+                 wandb_data = {
+                     'step/optimizer_steps': self.steps,
+                     'step/gpu_mem_usage': get_gpu_mem_usage(),
+                     'step/norm_g': math.sqrt(norm_g),
+                     'step/norm_u': math.sqrt(norm_u),
+                     'step/norm_error': math.sqrt(norm_e),
+                     'step/sparsity_u': sparsity_u / self.model_size * 100.,
+                     'step/elapsed_step': elapsed_step,
+                 }
+                 wandb.log(wandb_data, commit=False)
+
+     # def _update_lr_wd(self):
+     #     # copy the learning rate group to parameter state because the lr scheduler updates the one in the group
+     #     for group in self.param_groups:
+     #         lr = group['lr']
+     #         wd = group.get('weight_decay', self.weight_decay) # if the param groups do not have weight decay, then use the external one
+     #         for p in group['params']:
+     #             self.state[p]['lr'] = lr
+     #             self.state[p]['wd'] = wd
+
+
+     # def _init_state(self):
+     #     count = 0
+     #     for group in self.param_groups:
+     #         lr = group['lr']
+     #         wd = group.get('weight_decay', self.weight_decay) # if the param groups do not have weight decay, then use the external one
+     #         for p in group['params']:
+     #             if not p.requires_grad:
+     #                 continue
+
+     #             # print(f'[init_state] rank={torch.distributed.get_rank()}, p.shape={p.shape}')
+
+     #             count += 1
+     #             layer_size = p.numel()
+     #             st = self.state[p]
+
+     #             # B * t / d * nt
+     #             st['blocks'] = max(1, int(math.floor(self.blocks * layer_size * self.dict_size_count[layer_size] / self.model_size)))
+
+     #             st['lr'] = lr
+     #             st['weight_decay'] = wd
+     #             st['d'] = layer_size
+
+     #             ##### variables for Top-K: d_index_topk is the index where the last, smaller topk block starts
+     #             st['d_block_size'] = layer_size if layer_size < self.d_block_size else self.d_block_size
+     #             st['topk_full_blocks_count'], st['d_index_topk'] = block_split(st['d'], st['d_block_size'])
+     #             st['k_block_size_many'] = int(math.ceil(st['d_block_size'] * self.k_init))
+     #             st['k_block_size_few'] = int(math.ceil((st['d'] - st['d_index_topk']) * self.k_init)) # 0 for d % self.d_block_size = 0
+     #             st['k_index'] = st['topk_full_blocks_count'] * st['k_block_size_many']
+     #             st['k'] = st['k_block_size_many'] * st['topk_full_blocks_count'] + st['k_block_size_few']
+
+     #             ##### variables for the ring buffer
+     #             st['index'] = 0 # the position to place a new gradient at
+     #             st['I'] = torch.zeros(self.m, st['k'], dtype=torch.int16, device=self.device) # 2mk bytes
+     #             st['V'] = torch.zeros(self.m, st['k'], dtype=torch.bfloat16, device=self.device) # 2mk bytes
+
+     #             ### variables for error feedback: d_index_quant is the index where the last, smaller quantization block starts
+     #             # st['quant_block_size'] = layer_size if layer_size < self.quant_block_size else self.quant_block_size
+     #             st['quant_full_blocks_count'], st['d_index_quant'] = block_split(st['d'], self.quant_block_size)
+     #             st['error'] = torch.zeros(int(math.ceil(st['d'] / 2)), dtype=torch.uint8, device=self.device) # ceil(d/2) bytes
+     #             st['min_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device) # ceil(d/q_bsz)*2 bytes
+     #             st['max_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device) # ceil(d/q_bsz)*2 bytes
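The new file above creates per-parameter state lazily inside `update_step` (the FSDP-related change noted for 1.0.2), but the constructor and `_log` still call `torch.distributed.get_world_size()`, `get_rank()` and `all_reduce` unconditionally, so a process group must exist even for single-GPU runs. The sketch below is a minimal usage example under that assumption; the single-process `gloo` group, the toy bfloat16 model, `m=10`, and the top-level import path are illustrative choices, not something this diff confirms.

```python
import os
import torch
import torch.distributed as dist
from ista_daslab_optimizers import MicroAdam  # import path assumed; adjust to the package's actual export

# The optimizer uses torch.distributed collectives, so create a 1-process group for single-GPU use.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

# The quantization kernels assert bfloat16 buffers, so use a bf16 model/gradients (toy model, illustrative only).
model = torch.nn.Linear(1024, 1024).cuda().bfloat16()

optimizer = MicroAdam(
    model.parameters(),
    m=10,                      # ring-buffer depth (number of stored compressed gradients); illustrative value
    lr=1e-5,
    quant_block_size=100_000,
    k_init=0.01,               # keep ~1% of entries per Top-K block
    alpha=0,                   # >0 folds a fraction of the error feedback into a dense update
)

x = torch.randn(8, 1024, device='cuda', dtype=torch.bfloat16)
loss = model(x).square().mean()
loss.backward()
optimizer.step()               # per-parameter state is created here, on the first step
optimizer.zero_grad()
# wandb.init(...) would only be needed once logging kicks in (every log_interval = 100 steps).
```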
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ista_daslab_optimizers
- Version: 1.0.1
+ Version: 1.1.2
  Summary: Deep Learning optimizers developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
  Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
  Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
@@ -280,6 +280,7 @@ optimizer = MicroAdam(
      lr=1e-5, # change accordingly
      quant_block_size=100_000, # 32 or 64 also works
      k_init=0.01, # float between 0 and 1 meaning percentage: 0.01 means 1%
+     alpha=0, # 0 means sparse update and 0 < alpha < 1 means we integrate fraction alpha from EF to update and then delete it
  )

  # from now on, you can use the variable `optimizer` as any other PyTorch optimizer
@@ -288,15 +289,18 @@ optimizer = MicroAdam(
  # Versions summary:

  ---
+ - **1.1.2** @ August 1st, 2024:
+   - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls the fraction of error feedback
+     (EF) to be integrated into the update to make it dense. Finally, the fraction alpha will be discarded from the EF at
+     the expense of another call to `Qinv` and `Q` (and implicitly quantization statistics computation).
+   - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the `update_step` method
+     instead of MicroAdam constructor

  - **1.0.1** @ June 27th, 2024:
-
  - removed version in dependencies to avoid conflicts with llm-foundry

  - **1.0.0** @ June 20th, 2024:
-
  - changed minimum required Python version to 3.8+ and torch to 2.3.0+

  - **0.0.1** @ June 13th, 2024:
-
  - added initial version of the package for Python 3.9+ and torch 2.3.1+
@@ -13,7 +13,7 @@ void compute_microadam_update_cuda(int blocks, int threads, int carveout,
                                     torch::Tensor indices, torch::Tensor values, torch::Tensor out);

  void asymm_block_quant_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x);
- void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x);
+ void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x, float alpha);

  // C++ methods
  void compute_microadam_update(int blocks, int threads, int carveout,
@@ -44,14 +44,14 @@ void asymm_block_quant(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xm
      asymm_block_quant_cuda(d, q_block_size, xq, xmin, xmax, x);
  }

- void asymm_block_quant_inv(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x) {
+ void asymm_block_quant_inv(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x, float alpha) {
      CHECK_INPUT(xq);
      CHECK_INPUT(xmin);
      CHECK_INPUT(xmax);
      CHECK_INPUT(x);

      const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
-     asymm_block_quant_inv_cuda(d, q_block_size, xq, xmin, xmax, x);
+     asymm_block_quant_inv_cuda(d, q_block_size, xq, xmin, xmax, x, alpha);
  }

  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
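With the signature change above, every Python caller of `asymm_block_quant_inv` now has to pass an explicit `alpha` (1.0 reproduces the previous `x += Q_inv(...)` behaviour, as the optimizer does in its STEP 4). A hedged round-trip sketch, assuming the compiled `ista_daslab_micro_adam` extension is installed and using buffer shapes consistent with the optimizer code above (one full quantization block, bfloat16 data, packed uint8 codes; the exact values are illustrative):

```python
import math
import torch
import ista_daslab_micro_adam  # compiled extension from this package (must be built for your GPU)

d = 100_000                    # number of elements being (de)quantized
q_block_size = 100_000         # quantization block size, as in the README example
n_blocks = math.ceil(d / q_block_size)

x = torch.randn(d, device='cuda', dtype=torch.bfloat16)               # values to compress
xq = torch.zeros(math.ceil(d / 2), device='cuda', dtype=torch.uint8)  # 4-bit codes, two per byte

# per-block statistics; the +1 slot mirrors the optimizer's min_vals/max_vals buffers
xmin = torch.zeros(n_blocks + 1, device='cuda', dtype=torch.bfloat16)
xmax = torch.zeros(n_blocks + 1, device='cuda', dtype=torch.bfloat16)
xmin[:n_blocks] = x.view(n_blocks, -1).min(dim=1).values
xmax[:n_blocks] = x.view(n_blocks, -1).max(dim=1).values

# error = Q(x, min, max): pack x into 4-bit codes
ista_daslab_micro_adam.asymm_block_quant(d, q_block_size, xq, xmin, xmax, x)

# x_hat += alpha * Q_inv(error): the new trailing alpha argument scales the dequantized values
x_hat = torch.zeros_like(x)
ista_daslab_micro_adam.asymm_block_quant_inv(d, q_block_size, xq, xmin, xmax, x_hat, 1.0)
```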
@@ -1,8 +1,8 @@
  #include "../utils.h"

- __global__ void asymm_block_quant_inv_kernel_bf16_bf16(LL d, LL q_block_size, uint8_t *xq, bfloat16 *xmin, bfloat16 *xmax, bfloat16 *x) {
+ __global__ void asymm_block_quant_inv_kernel_bf16_bf16(LL d, LL q_block_size, uint8_t *xq, bfloat16 *xmin, bfloat16 *xmax, bfloat16 *x, float alpha) {
      /*
-         This kernel computes x += Q_inv(xq, xmin, xmax) for 4 bits (implements point 1 from PhD notebook page 55)
+         This kernel computes x += alpha * Q_inv(xq, xmin, xmax) for 4 bits (implements point 1 from PhD #9 notebook page 55)
          Here, x is the output buffer and will already contain the dense gradient
          The output buffer x has d components and xq has d/2 components because one uint8_t stores two 4-bit values
          In contrast to "globally" kernel, this kernel works with a single block
@@ -43,20 +43,21 @@ __global__ void asymm_block_quant_inv_kernel_bf16_bf16(LL d, LL q_block_size, ui
          lsb = (vq & 0x0F);

          // += operation happens here
-         vx2.x += __float2bfloat16(msb * u + m);
-         vx2.y += __float2bfloat16(lsb * u + m);
+         vx2.x += __float2bfloat16((msb * u + m) * alpha);
+         vx2.y += __float2bfloat16((lsb * u + m) * alpha);
          x2[half_index] = vx2;
      }
      if((d & 1) && (Bid == B-1) && (Tid == T-1)) {
          bfloat16 vx = x[d - 1];
          vq = xq[half_d];
          msb = (vq & 0xF0) >> 4;
-         vx += __float2bfloat16(msb * u + m);
+         // += operation happens here
+         vx += __float2bfloat16((msb * u + m) * alpha);
          x[d - 1] = vx;
      }
  }

- void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x) {
+ void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x, float alpha) {
      ASSERT_BF16(xmin);
      ASSERT_BF16(xmax);
      ASSERT_BF16(x);
@@ -72,7 +73,8 @@ void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::
          (uint8_t*) xq.data_ptr(),
          (bfloat16*) xmin.data_ptr(),
          (bfloat16*) xmax.data_ptr(),
-         (bfloat16*) x.data_ptr());
+         (bfloat16*) x.data_ptr(),
+         alpha);

      // error checks
      GPU_ERROR_CHECK(cudaGetLastError());
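Element-wise, the updated kernel computes `x += alpha * (q * u + m)` for each 4-bit code `q`. The pure-PyTorch reference below is a sketch of that dequantization, assuming the usual asymmetric convention `u = (xmax - xmin) / 15` and `m = xmin` per quantization block; the kernel derives `u` and `m` outside the hunk shown, so those definitions are an assumption, while the packing order (high nibble first) follows the `msb`/`lsb` handling above.

```python
import torch

def asymm_block_quant_inv_reference(d, q_block_size, xq, xmin, xmax, x, alpha):
    """Reference (slow) version of x += alpha * Q_inv(xq, xmin, xmax) for 4-bit codes.

    xq packs two 4-bit codes per uint8 (high nibble first); xmin/xmax hold per-block
    statistics. Assumes scale u = (max - min) / 15 and offset m = min per block.
    """
    codes = torch.empty(2 * xq.numel(), dtype=torch.float32, device=xq.device)
    codes[0::2] = (xq >> 4).float()       # most significant nibble
    codes[1::2] = (xq & 0x0F).float()     # least significant nibble
    codes = codes[:d]                     # drop the padding nibble when d is odd

    block = torch.arange(d, device=xq.device) // q_block_size
    u = (xmax.float() - xmin.float())[block] / 15.0   # per-element scale
    m = xmin.float()[block]                           # per-element offset
    x += (alpha * (codes * u + m)).to(x.dtype)
    return x
```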
@@ -144,7 +144,7 @@ __global__ void LCG_v51_bf16(long d,

      long i; // vector index (for indices and values)
      long k_col; // column index to extract data from indices and values at the current row
-     long index; // the 1-D index to extract data from indices and values at the current row (row * k + k_col)
+     // long index; // the 1-D index to extract data from indices and values at the current row (row * k + k_col)
      int16 ind; // the data from indices at the index "index"
      bfloat16 val; // the data from values at the index "index"

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name='ista_daslab_optimizers'
7
- version='1.0.1'
7
+ version='1.1.2'
8
8
  dependencies = [
9
9
  "torch", # >=2.3.1",
10
10
  "torchaudio", # >=2.3.1",
@@ -1,247 +0,0 @@
-
- import torch
- import math
- import time
- import wandb
- from torch.distributed import is_initialized, get_rank
- from ..tools import get_first_device, get_gpu_mem_usage, block_split, CopyDirection
-
- import ista_daslab_tools
- import ista_daslab_micro_adam
-
- class MicroAdam(torch.optim.Optimizer):
-     def __init__(self, params, m, lr, quant_block_size, k_init=0.01, betas=(0.9, 0.999), weight_decay=0, eps=1e-8):
-         defaults = dict(lr=lr, weight_decay=weight_decay, eps=eps)
-         super(MicroAdam, self).__init__(params, defaults)
-
-         self.m = m
-         self.lr = lr
-         self.quant_block_size = int(quant_block_size)
-         self.k_init = k_init
-         self.weight_decay = weight_decay
-         self.beta1 = betas[0]
-         self.beta2 = betas[1]
-         self.eps = eps
-
-         self.model_size = sum([p.numel() for group in self.param_groups for p in group['params']])
-
-         self.steps = 0 # how many optimization steps were performed so far
-         self.log_interval = 100
-         self.device = get_first_device()
-         self._is_state_initialized = False
-         self.shared_memory_carveout = 100
-         self.blocks = ista_daslab_tools.get_sm_count() * int(100 / self.shared_memory_carveout)
-         self.threads = 512
-
-         self.dict_size_count = {} # key = layer size, value = how many layers of that size the model has
-         for param in self.param_groups:
-             for p in param['params']:
-                 size = p.numel()
-                 self.dict_size_count[size] = 1 + self.dict_size_count.get(size, 0)
-
-         self._init_state()
-
-     def _init_state(self):
-         max_floats = ista_daslab_tools.get_max_floats_for_shared_memory_per_thread_block()
-         d_block_size = max_floats // 2 // int(100 / self.shared_memory_carveout)
-         count = 0
-         for group in self.param_groups:
-             lr = group['lr']
-             wd = group.get('weight_decay', self.weight_decay) # if the param groups do not have weight decay, then use the external one
-             for p in group['params']:
-                 if not p.requires_grad:
-                     continue
-                 count += 1
-                 layer_size = p.numel()
-                 st = self.state[p]
-
-                 # B * t / d * nt
-                 st['blocks'] = max(1, int(math.floor(self.blocks * layer_size * self.dict_size_count[layer_size] / self.model_size)))
-
-                 st['lr'] = lr
-                 st['weight_decay'] = wd
-                 st['d'] = layer_size
-
-                 ##### variables for Top-K: d_index_topk is the index where the last, smaller topk block starts
-                 st['d_block_size'] = layer_size if layer_size < d_block_size else d_block_size
-                 st['topk_full_blocks_count'], st['d_index_topk'] = block_split(st['d'], st['d_block_size'])
-                 st['k_block_size_many'] = int(math.ceil(st['d_block_size'] * self.k_init))
-                 st['k_block_size_few'] = int(math.ceil((st['d'] - st['d_index_topk']) * self.k_init)) # 0 for d % d_block_size = 0
-                 st['k_index'] = st['topk_full_blocks_count'] * st['k_block_size_many']
-                 st['k'] = st['k_block_size_many'] * st['topk_full_blocks_count'] + st['k_block_size_few']
-
-                 ##### variables for the ring buffer
-                 st['index'] = 0 # the position to place a new gradient at
-                 st['I'] = torch.zeros(self.m, st['k'], dtype=torch.int16, device=self.device) # 2mk bytes
-                 st['V'] = torch.zeros(self.m, st['k'], dtype=torch.bfloat16, device=self.device) # 2mk bytes
-
-                 ### variables for error feedback: d_index_quant is the index where the last, smaller quantization block starts
-                 # st['quant_block_size'] = layer_size if layer_size < self.quant_block_size else self.quant_block_size
-                 st['quant_full_blocks_count'], st['d_index_quant'] = block_split(st['d'], self.quant_block_size)
-                 st['error'] = torch.zeros(int(math.ceil(st['d'] / 2)), dtype=torch.uint8, device=self.device) # ceil(d/2) bytes
-                 st['min_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device) # ceil(d/q_bsz)*2 bytes
-                 st['max_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device) # ceil(d/q_bsz)*2 bytes
-
-     @torch.no_grad()
-     def step(self, closure=None):
-         self.steps += 1
-
-         self._update_lr_wd()
-
-         loss = None
-         if closure is not None:
-             with torch.enable_grad():
-                 loss = closure()
-
-         time_start = time.time()
-
-         norm_g, norm_u, norm_e, sparsity_u = 0, 0, 0, 0
-         for group in self.param_groups:
-             for p in group['params']:
-                 if p.grad is None:
-                     continue
-                 ng, nu, ne, sp_u = self.update_step(p)
-                 norm_g += ng
-                 norm_u += nu
-                 norm_e += ne
-                 sparsity_u += sp_u
-
-         # torch.cuda.synchronize()
-         time_end = time.time()
-         elapsed_step = time_end - time_start
-         self._log(norm_g, norm_u, norm_e, sparsity_u, elapsed_step)
-
-         return loss
-
-     @torch.no_grad()
-     def update_step(self, p):
-         norm_g, norm_u, norm_e, sp_u = 0, 0, 0, 0
-
-         st = self.state[p]
-         grad = p.grad.view(-1)
-
-         if self.steps % self.log_interval == 0:
-             norm_g = grad.norm(p=2) ** 2
-
-         blocks = st['blocks']
-         lr = st['lr']
-         wd = st['weight_decay']
-         d = st['d']
-         d_block_size = st['d_block_size']
-         topk_full_blocks_count, d_index_topk = st['topk_full_blocks_count'], st['d_index_topk']
-         k_block_size_many = st['k_block_size_many']
-         k_block_size_few = st['k_block_size_few']
-         k_index = st['k_index']
-         k = st['k']
-
-         # HuggingFace has a setting that converts st['I'] to bfloat16, even though it is declared as int16
-         # This happens somewhere between constructor call and step call. Converting it to int16 was the simplest solution
-         if st['I'].dtype != torch.int16:
-             st['I'] = st['I'].to(torch.int16)
-
-         index = st['index']
-         I = st['I']
-         V = st['V']
-
-         quant_full_blocks_count, d_index_quant = st['quant_full_blocks_count'], st['d_index_quant']
-         error = st['error']
-         min_vals = st['min_vals']
-         max_vals = st['max_vals']
-
-         ##### STEP 4
-         ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad)
-
-         ##### STEP 5 + 9 (only for I)
-         I[index, :k_index] = torch.topk(input=grad[0:d_index_topk].abs().view(topk_full_blocks_count, d_block_size),
-                                         k=k_block_size_many,
-                                         sorted=False).indices.to(dtype=torch.int16).view(-1)
-
-         if k_block_size_few > 0: # there is a small block left
-             I[index, k_index:] = torch.topk(input=grad[d_index_topk:].abs(),
-                                             k=k_block_size_few, # example: slice has size 1, but ks[-1] is 4
-                                             sorted=False).indices.to(dtype=torch.int16).view(-1)
-
-         ista_daslab_tools.copy_values(d, # V = error[I[buffer_index, :]]
-                                       k,
-                                       d_block_size,
-                                       k_block_size_many,
-                                       I[index, :], # indices
-                                       grad, # inp
-                                       V[index, :], # output
-                                       CopyDirection.d2k.value)
-
-         st['index'] = (index + 1) % self.m
-
-         ##### STEP 6
-         ista_daslab_tools.zerorize_block_components(grad, I[index, :], d, k, d_block_size, k_block_size_many) # this does a[I[index]] = 0
-
-         ##### STEP 7
-         if quant_full_blocks_count == 1:
-             min_vals[:quant_full_blocks_count] = grad[:d_index_quant].min()
-             max_vals[:quant_full_blocks_count] = grad[:d_index_quant].max()
-         else:
-             min_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).min(dim=1).values
-             max_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).max(dim=1).values
-         if d_index_quant < d:
-             min_vals[quant_full_blocks_count] = grad[d_index_quant:].min()
-             max_vals[quant_full_blocks_count] = grad[d_index_quant:].max()
-
-         ##### STEP 8
-         ista_daslab_micro_adam.asymm_block_quant(d, self.quant_block_size, error, min_vals, max_vals, grad) # error = Q(a, min, max)
-
-         ##### STEPS 10-11
-         grad.zero_()
-         ista_daslab_micro_adam.compute_microadam_update(blocks, # blocks
-                                                         self.threads, # threads
-                                                         self.shared_memory_carveout, # carveout
-                                                         self.steps, # optimization step
-                                                         self.beta1, # beta1
-                                                         self.beta2, # beta2
-                                                         self.eps, # eps
-                                                         d_block_size, # d_block_size
-                                                         k_block_size_many, # k_block_size
-                                                         d, # d
-                                                         self.m, # m
-                                                         k, # k
-                                                         I, # indices
-                                                         V, # values
-                                                         grad) # update will be stored here
-
-         ##### STEP 12
-         p.mul_(1 - lr * wd).add_(p.grad, alpha=-lr)
-
-         # compute error norm
-         if self.steps % self.log_interval == 0:
-             norm_u = grad.norm(p=2) ** 2
-             sp_u = (grad == 0).sum() # check sparsity before zerorizing
-
-             grad.zero_()
-             ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad)
-
-             norm_e = grad.norm(p=2) ** 2
-
-         return norm_g, norm_u, norm_e, sp_u
-
-     def _log(self, norm_g, norm_u, norm_e, sparsity_u, elapsed_step):
-         if self.steps % self.log_interval == 0:
-             wandb_data = {
-                 'step/optimizer_steps': self.steps,
-                 'step/gpu_mem_usage': get_gpu_mem_usage(),
-                 'step/norm_g': math.sqrt(norm_g),
-                 'step/norm_u': math.sqrt(norm_u),
-                 'step/norm_error': math.sqrt(norm_e),
-                 'step/sparsity_u': sparsity_u / self.model_size * 100.,
-                 'step/elapsed_step': elapsed_step,
-             }
-
-             if not is_initialized() or get_rank() == 0:
-                 wandb.log(wandb_data, commit=False)
-
-     def _update_lr_wd(self):
-         # copy the learning rate group to parameter state because the lr scheduler updates the one in the group
-         for group in self.param_groups:
-             lr = group['lr']
-             wd = group.get('weight_decay', self.weight_decay) # if the param groups do not have weight decay, then use the external one
-             for p in group['params']:
-                 self.state[p]['lr'] = lr
-                 self.state[p]['wd'] = wd