ista-daslab-optimizers 1.0.1__tar.gz → 1.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ista_daslab_optimizers-1.0.1/ista_daslab_optimizers.egg-info → ista_daslab_optimizers-1.1.2}/PKG-INFO +8 -4
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/README.md +7 -3
- ista_daslab_optimizers-1.1.2/ista_daslab_optimizers/micro_adam/micro_adam.py +338 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2/ista_daslab_optimizers.egg-info}/PKG-INFO +8 -4
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/micro_adam/micro_adam.cpp +3 -3
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu +9 -7
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu +1 -1
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/pyproject.toml +1 -1
- ista_daslab_optimizers-1.0.1/ista_daslab_optimizers/micro_adam/micro_adam.py +0 -247
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/LICENSE +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/MANIFEST.in +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/__init__.py +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/acdc/__init__.py +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/acdc/acdc.py +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/acdc/wd_scheduler.py +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/dense_mfac/__init__.py +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/dense_mfac/dense_mfac.py +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/micro_adam/__init__.py +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/sparse_mfac/__init__.py +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers/tools.py +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers.egg-info/SOURCES.txt +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers.egg-info/dependency_links.txt +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers.egg-info/requires.txt +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/ista_daslab_optimizers.egg-info/top_level.txt +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/dense_mfac/dense_mfac.cpp +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/dense_mfac/dense_mfac_kernel.cu +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/micro_adam/micro_adam_asymm_block_quant.cu +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/micro_adam/micro_adam_update.cu +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/sparse_mfac/sparse_mfac.cpp +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/sparse_mfac/sparse_mfac_SP_kernel.cu +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/tools/tools.cpp +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/tools/tools_kernel.cu +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/utils.h +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/setup.cfg +0 -0
- {ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/setup.py +0 -0

{ista_daslab_optimizers-1.0.1/ista_daslab_optimizers.egg-info → ista_daslab_optimizers-1.1.2}/PKG-INFO
RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ista_daslab_optimizers
-Version: 1.0.1
+Version: 1.1.2
 Summary: Deep Learning optimizers developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
 Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
 Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>

@@ -280,6 +280,7 @@ optimizer = MicroAdam(
     lr=1e-5, # change accordingly
     quant_block_size=100_000, # 32 or 64 also works
     k_init=0.01, # float between 0 and 1 meaning percentage: 0.01 means 1%
+    alpha=0, # 0 means sparse update and 0 < alpha < 1 means we integrate fraction alpha from EF to update and then delete it
 )
 
 # from now on, you can use the variable `optimizer` as any other PyTorch optimizer

@@ -288,15 +289,18 @@ optimizer = MicroAdam(
 # Versions summary:
 
 ---
+- **1.1.2** @ August 1st, 2024:
+  - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls the fraction of error feedback
+    (EF) to be integrated into the update to make it dense. Finally, the fraction alpha will be discarded from the EF at
+    the expense of another call to `Qinv` and `Q` (and implicitly quantization statistics computation).
+  - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the `update_step` method
+    instead of MicroAdam constructor
 
 - **1.0.1** @ June 27th, 2024:
-
   - removed version in dependencies to avoid conflicts with llm-foundry
 
 - **1.0.0** @ June 20th, 2024:
-
   - changed minimum required Python version to 3.8+ and torch to 2.3.0+
 
 - **0.0.1** @ June 13th, 2024:
-
   - added initial version of the package for Python 3.9+ and torch 2.3.1+

{ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/README.md
RENAMED

@@ -55,6 +55,7 @@ optimizer = MicroAdam(
     lr=1e-5, # change accordingly
     quant_block_size=100_000, # 32 or 64 also works
     k_init=0.01, # float between 0 and 1 meaning percentage: 0.01 means 1%
+    alpha=0, # 0 means sparse update and 0 < alpha < 1 means we integrate fraction alpha from EF to update and then delete it
 )
 
 # from now on, you can use the variable `optimizer` as any other PyTorch optimizer

@@ -63,15 +64,18 @@ optimizer = MicroAdam(
 # Versions summary:
 
 ---
+- **1.1.2** @ August 1st, 2024:
+  - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls the fraction of error feedback
+    (EF) to be integrated into the update to make it dense. Finally, the fraction alpha will be discarded from the EF at
+    the expense of another call to `Qinv` and `Q` (and implicitly quantization statistics computation).
+  - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the `update_step` method
+    instead of MicroAdam constructor
 
 - **1.0.1** @ June 27th, 2024:
-
   - removed version in dependencies to avoid conflicts with llm-foundry
 
 - **1.0.0** @ June 20th, 2024:
-
   - changed minimum required Python version to 3.8+ and torch to 2.3.0+
 
 - **0.0.1** @ June 13th, 2024:
-
   - added initial version of the package for Python 3.9+ and torch 2.3.1+
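
For readers skimming the diff, the snippet below is a minimal, self-contained PyTorch sketch of the `alpha` densification described in the 1.1.x changelog above: a fraction `alpha` of the error feedback (EF) is folded into the otherwise sparse update and is then removed from the EF. It deliberately omits the 4-bit quantization of the EF; the function and variable names are illustrative, not part of the package.

```python
import torch

def densify_update(sparse_update: torch.Tensor, ef: torch.Tensor, alpha: float) -> torch.Tensor:
    """Fold a fraction `alpha` of the error feedback into the update, then drop it from the EF."""
    update = sparse_update + alpha * ef  # the update becomes dense when alpha > 0
    ef.mul_(1.0 - alpha)                 # the integrated fraction is discarded from the EF
    return update

ef = torch.randn(8)                      # stands in for the (dequantized) error feedback
sparse_update = torch.zeros(8)
sparse_update[::4] = 1.0                 # pretend Top-K kept every 4th coordinate
u = densify_update(sparse_update, ef, alpha=0.1)
print(u, ef)
```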

ista_daslab_optimizers-1.1.2/ista_daslab_optimizers/micro_adam/micro_adam.py
ADDED

@@ -0,0 +1,338 @@
+import os
+import torch
+import math
+import time
+import wandb
+from torch.distributed import is_initialized, get_rank, all_reduce, ReduceOp
+from ..tools import get_first_device, get_gpu_mem_usage, block_split, CopyDirection
+
+import ista_daslab_tools
+import ista_daslab_micro_adam
+
+
+class MicroAdam(torch.optim.Optimizer):
+    def __init__(self, params, m, lr, quant_block_size, k_init=0.01, alpha=0, betas=(0.9, 0.999), weight_decay=0, eps=1e-8):
+        defaults = dict(lr=lr, weight_decay=weight_decay, eps=eps, alpha=alpha)
+        super(MicroAdam, self).__init__(params, defaults)
+
+        assert 0 <= alpha < 1, 'Alpha must be in the [0, 1) interval'
+
+        self.m = m
+        self.lr = lr
+        self.quant_block_size = int(quant_block_size)
+        self.k_init = k_init
+        self.alpha = alpha
+        self.weight_decay = weight_decay
+        self.beta1 = betas[0]
+        self.beta2 = betas[1]
+        self.eps = eps
+
+        self.densify_update = (self.alpha > 0)
+        self.model_size = sum([p.numel() for group in self.param_groups for p in group['params']])
+
+        self.steps = 0 # how many optimization steps were performed so far
+        self.log_interval = 100
+        self.device = get_first_device()
+        self._is_state_initialized = False
+        self.shared_memory_carveout = 100
+        self.blocks = ista_daslab_tools.get_sm_count() * int(100 / self.shared_memory_carveout)
+        self.threads = 512
+
+        self.max_floats = ista_daslab_tools.get_max_floats_for_shared_memory_per_thread_block()
+        self.d_block_size = self.max_floats // 2 // int(100 / self.shared_memory_carveout)
+
+        self.fsdp_dict_size_count = [{} for _ in range(
+            torch.distributed.get_world_size())] # key = layer size, value = how many layers of that size the model has (per worker)
+        self.dict_size_count = {} # key = layer size, value = how many layers of that size the model has
+        for param in self.param_groups:
+            for p in param['params']:
+                size = p.numel()
+                # print(p.shape, p.numel())
+                self.dict_size_count[size] = 1 + self.dict_size_count.get(size, 0)
+
+        # self._init_state()
+
+    def _initialize_parameter_state(self, p, lr, wd):
+        layer_size = p.numel()
+        st = self.state[p]
+
+        rank = torch.distributed.get_rank()
+
+        st['blocks'] = max(1, int(math.floor(self.blocks * layer_size * self.fsdp_dict_size_count[rank][layer_size] / self.model_size)))
+
+        st['lr'] = lr
+        st['weight_decay'] = wd
+        st['d'] = layer_size
+
+        ##### variables for Top-K: d_index_topk is the index where the last, smaller topk block starts
+        st['d_block_size'] = layer_size if layer_size < self.d_block_size else self.d_block_size
+        st['topk_full_blocks_count'], st['d_index_topk'] = block_split(st['d'], st['d_block_size'])
+        st['k_block_size_many'] = int(math.ceil(st['d_block_size'] * self.k_init))
+        st['k_block_size_few'] = int(math.ceil((st['d'] - st['d_index_topk']) * self.k_init)) # 0 for d % self.d_block_size = 0
+        st['k_index'] = st['topk_full_blocks_count'] * st['k_block_size_many']
+        st['k'] = st['k_block_size_many'] * st['topk_full_blocks_count'] + st['k_block_size_few']
+
+        ##### variables for the ring buffer
+        st['index'] = 0 # the position to place a new gradient at
+        st['I'] = torch.zeros(self.m, st['k'], dtype=torch.int16, device=self.device) # 2mk bytes
+        st['V'] = torch.zeros(self.m, st['k'], dtype=torch.bfloat16, device=self.device) # 2mk bytes
+
+        ### variables for error feedback: d_index_quant is the index where the last, smaller quantization block starts
+        # st['quant_block_size'] = layer_size if layer_size < self.quant_block_size else self.quant_block_size
+        st['quant_full_blocks_count'], st['d_index_quant'] = block_split(st['d'], self.quant_block_size)
+        st['error'] = torch.zeros(int(math.ceil(st['d'] / 2)), dtype=torch.uint8, device=self.device) # ceil(d/2) bytes
+        st['min_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device) # ceil(d/q_bsz)*2 bytes
+        st['max_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device) # ceil(d/q_bsz)*2 bytes
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        self.steps += 1
+
+        # self._update_lr_wd()
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        if self.steps == 1:
+            rank = torch.distributed.get_rank()
+            for param in self.param_groups:
+                for p in param['params']:
+                    if p is not None:
+                        size = p.numel()
+                        if size > 0:
+                            self.fsdp_dict_size_count[rank][size] = 1 + self.fsdp_dict_size_count[rank].get(size, 0)
+
+        time_start = time.time()
+
+        norm_g, norm_u, norm_e, sparsity_u = 0, 0, 0, 0
+
+        for group in self.param_groups:
+            lr = group['lr']
+            wd = group.get('weight_decay', self.weight_decay)
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                if p is None:
+                    continue
+
+                ng, nu, ne, sp_u = self.update_step(p, lr, wd)
+                norm_g += ng
+                norm_u += nu
+                norm_e += ne
+                sparsity_u += sp_u
+
+        # torch.cuda.synchronize()
+        time_end = time.time()
+        elapsed_step = time_end - time_start
+        self._log(norm_g, norm_u, norm_e, sparsity_u, elapsed_step)
+
+        return loss
+
+    @torch.no_grad()
+    def update_step(self, p, lr, wd):
+        norm_g, norm_u, norm_e, sp_u = 0, 0, 0, 0
+
+        grad = p.grad.view(-1)
+
+        if self.steps % self.log_interval == 0:
+            norm_g = grad.norm(p=2) ** 2
+
+        st = self.state[p]
+        if len(st) == 0:
+            self._initialize_parameter_state(p, lr, wd)
+
+        # print('rank=',torch.distributed.get_rank(), 'keys=',st.keys())
+
+        blocks = st['blocks']
+        # lr = st['lr']
+        # wd = st['weight_decay']
+        d = st['d']
+        d_block_size = st['d_block_size']
+        topk_full_blocks_count, d_index_topk = st['topk_full_blocks_count'], st['d_index_topk']
+        k_block_size_many = st['k_block_size_many']
+        k_block_size_few = st['k_block_size_few']
+        k_index = st['k_index']
+        k = st['k']
+
+        # HuggingFace has a setting that converts st['I'] to bfloat16, even though it is declared as int16
+        # This happens somewhere between constructor call and step call. Converting it to int16 was the simplest solution
+        if st['I'].dtype != torch.int16:
+            st['I'] = st['I'].to(torch.int16)
+
+        index = st['index']
+        I = st['I']
+        V = st['V']
+
+        quant_full_blocks_count, d_index_quant = st['quant_full_blocks_count'], st['d_index_quant']
+        error = st['error']
+        min_vals = st['min_vals']
+        max_vals = st['max_vals']
+
+        ##### STEP 4
+        ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1.0) # alpha=1 here
+
+        ##### STEP 5 + 9 (only for I)
+        I[index, :k_index] = torch.topk(input=grad[0:d_index_topk].abs().view(topk_full_blocks_count, d_block_size),
+                                        k=k_block_size_many,
+                                        sorted=False).indices.to(dtype=torch.int16).view(-1)
+
+        if k_block_size_few > 0: # there is a small block left
+            I[index, k_index:] = torch.topk(input=grad[d_index_topk:].abs(),
+                                            k=k_block_size_few, # example: slice has size 1, but ks[-1] is 4
+                                            sorted=False).indices.to(dtype=torch.int16).view(-1)
+
+        ista_daslab_tools.copy_values(d, # V = error[I[buffer_index, :]]
+                                      k,
+                                      d_block_size,
+                                      k_block_size_many,
+                                      I[index, :], # indices
+                                      grad, # inp
+                                      V[index, :], # output
+                                      CopyDirection.d2k.value)
+
+        st['index'] = (index + 1) % self.m
+
+        ##### STEP 6
+        ista_daslab_tools.zerorize_block_components(grad, I[index, :], d, k, d_block_size, k_block_size_many) # this does a[I[index]] = 0
+
+        ##### STEP 7
+        def _update_quantization_statistics():
+            if quant_full_blocks_count == 1:
+                min_vals[:quant_full_blocks_count] = grad[:d_index_quant].min()
+                max_vals[:quant_full_blocks_count] = grad[:d_index_quant].max()
+            else:
+                min_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).min(dim=1).values
+                max_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).max(dim=1).values
+            if d_index_quant < d:
+                min_vals[quant_full_blocks_count] = grad[d_index_quant:].min()
+                max_vals[quant_full_blocks_count] = grad[d_index_quant:].max()
+
+        _update_quantization_statistics()
+
+        ##### STEP 8
+        ista_daslab_micro_adam.asymm_block_quant(d, self.quant_block_size, error, min_vals, max_vals, grad) # error = Q(a, min, max)
+
+        ##### STEPS 10-11
+        grad.zero_()
+        ista_daslab_micro_adam.compute_microadam_update(blocks, # blocks
+                                                        self.threads, # threads
+                                                        self.shared_memory_carveout, # carveout
+                                                        self.steps, # optimization step
+                                                        self.beta1, # beta1
+                                                        self.beta2, # beta2
+                                                        self.eps, # eps
+                                                        d_block_size, # d_block_size
+                                                        k_block_size_many, # k_block_size
+                                                        d, # d
+                                                        self.m, # m
+                                                        k, # k
+                                                        I, # indices
+                                                        V, # values
+                                                        grad) # update will be stored here
+
+        ##### STEP 12: # side idea: only decay the weights that are update
+
+        ##### if PRETRAINING #1
+        if self.densify_update: # we add alpha * EF to update that is stored in grad buffer
+            # p.grad += alpha * Qinv(error), alpha=0.1
+            ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, self.alpha)
+        ##### END IF PRETRAINING #1
+
+        # if alpha > 0, then the update u=p.grad is dense now
+        p.mul_(1 - lr * wd).add_(p.grad, alpha=-lr)
+
+        ##### if PRETRAINING #2
+        if self.densify_update:
+            grad.zero_()
+            ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1-self.alpha)
+
+            _update_quantization_statistics() # step 7 again
+            ista_daslab_micro_adam.asymm_block_quant(d, self.quant_block_size, error, min_vals, max_vals, grad) # step 8 again
+        ##### END IF PRETRAINING #2
+
+        # compute error norm
+        if self.steps % self.log_interval == 0:
+            norm_u = grad.norm(p=2) ** 2
+            sp_u = (grad == 0).sum() # check sparsity before zerorizing
+
+            grad.zero_()
+            ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad, 1.0)
+
+            norm_e = grad.norm(p=2) ** 2
+
+        return norm_g, norm_u, norm_e, sp_u
+
+    def _log(self, norm_g, norm_u, norm_e, sparsity_u, elapsed_step):
+        if self.steps % self.log_interval == 0:
+            sync_data = torch.tensor([norm_g, norm_u, norm_e, sparsity_u, elapsed_step], dtype=torch.float,
+                                     requires_grad=False).cuda() # correct, loss, size
+            all_reduce(sync_data, op=ReduceOp.SUM)
+            norm_g, norm_u, norm_e, sparsity_u, elapsed_step = sync_data
+
+            if not is_initialized() or get_rank() == 0:
+                wandb_data = {
+                    'step/optimizer_steps': self.steps,
+                    'step/gpu_mem_usage': get_gpu_mem_usage(),
+                    'step/norm_g': math.sqrt(norm_g),
+                    'step/norm_u': math.sqrt(norm_u),
+                    'step/norm_error': math.sqrt(norm_e),
+                    'step/sparsity_u': sparsity_u / self.model_size * 100.,
+                    'step/elapsed_step': elapsed_step,
+                }
+                wandb.log(wandb_data, commit=False)
+
+    # def _update_lr_wd(self):
+    #     # copy the learning rate group to parameter state because the lr scheduler updates the one in the group
+    #     for group in self.param_groups:
+    #         lr = group['lr']
+    #         wd = group.get('weight_decay', self.weight_decay) # if the param groups do not have weight decay, then use the external one
+    #         for p in group['params']:
+    #             self.state[p]['lr'] = lr
+    #             self.state[p]['wd'] = wd
+
+
+    # def _init_state(self):
+    #     count = 0
+    #     for group in self.param_groups:
+    #         lr = group['lr']
+    #         wd = group.get('weight_decay', self.weight_decay) # if the param groups do not have weight decay, then use the external one
+    #         for p in group['params']:
+    #             if not p.requires_grad:
+    #                 continue
+
+    #             print(f'[init_state] rank={torch.distributed.get_rank()}, p.shape={p.shape}')
+
+    #             count += 1
+    #             layer_size = p.numel()
+    #             st = self.state[p]
+
+    #             # B * t / d * nt
+    #             st['blocks'] = max(1, int(math.floor(self.blocks * layer_size * self.dict_size_count[layer_size] / self.model_size)))
+
+    #             st['lr'] = lr
+    #             st['weight_decay'] = wd
+    #             st['d'] = layer_size
+
+    #             ##### variables for Top-K: d_index_topk is the index where the last, smaller topk block starts
+    #             st['d_block_size'] = layer_size if layer_size < self.d_block_size else self.d_block_size
+    #             st['topk_full_blocks_count'], st['d_index_topk'] = block_split(st['d'], st['d_block_size'])
+    #             st['k_block_size_many'] = int(math.ceil(st['d_block_size'] * self.k_init))
+    #             st['k_block_size_few'] = int(math.ceil((st['d'] - st['d_index_topk']) * self.k_init)) # 0 for d % self.d_block_size = 0
+    #             st['k_index'] = st['topk_full_blocks_count'] * st['k_block_size_many']
+    #             st['k'] = st['k_block_size_many'] * st['topk_full_blocks_count'] + st['k_block_size_few']
+
+    #             ##### variables for the ring buffer
+    #             st['index'] = 0 # the position to place a new gradient at
+    #             st['I'] = torch.zeros(self.m, st['k'], dtype=torch.int16, device=self.device) # 2mk bytes
+    #             st['V'] = torch.zeros(self.m, st['k'], dtype=torch.bfloat16, device=self.device) # 2mk bytes
+
+    #             ### variables for error feedback: d_index_quant is the index where the last, smaller quantization block starts
+    #             # st['quant_block_size'] = layer_size if layer_size < self.quant_block_size else self.quant_block_size
+    #             st['quant_full_blocks_count'], st['d_index_quant'] = block_split(st['d'], self.quant_block_size)
+    #             st['error'] = torch.zeros(int(math.ceil(st['d'] / 2)), dtype=torch.uint8, device=self.device) # ceil(d/2) bytes
+    #             st['min_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device) # ceil(d/q_bsz)*2 bytes
+    #             st['max_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device) # ceil(d/q_bsz)*2 bytes
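
The new file above also shows the FSDP-friendly change from 1.0.2: per-parameter state is created lazily inside `update_step` (when `len(self.state[p]) == 0`) rather than in the constructor, so sharded parameters are sized as they actually appear at step time. Below is a minimal sketch of that pattern, using an illustrative momentum buffer instead of MicroAdam's compressed state; the class and field names are not part of the package.

```python
import torch

class LazyStateSGD(torch.optim.Optimizer):
    """Toy optimizer showing lazy state creation on the first step that sees a parameter."""
    def __init__(self, params, lr=1e-3, momentum=0.9):
        super().__init__(params, dict(lr=lr, momentum=momentum))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                st = self.state[p]
                if len(st) == 0:  # first time this (possibly sharded) parameter is seen
                    st['buf'] = torch.zeros_like(p)
                st['buf'].mul_(group['momentum']).add_(p.grad)
                p.add_(st['buf'], alpha=-group['lr'])
        return loss
```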

{ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2/ista_daslab_optimizers.egg-info}/PKG-INFO
RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ista_daslab_optimizers
-Version: 1.0.1
+Version: 1.1.2
 Summary: Deep Learning optimizers developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
 Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
 Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>

@@ -280,6 +280,7 @@ optimizer = MicroAdam(
     lr=1e-5, # change accordingly
     quant_block_size=100_000, # 32 or 64 also works
     k_init=0.01, # float between 0 and 1 meaning percentage: 0.01 means 1%
+    alpha=0, # 0 means sparse update and 0 < alpha < 1 means we integrate fraction alpha from EF to update and then delete it
 )
 
 # from now on, you can use the variable `optimizer` as any other PyTorch optimizer

@@ -288,15 +289,18 @@ optimizer = MicroAdam(
 # Versions summary:
 
 ---
+- **1.1.2** @ August 1st, 2024:
+  - ***[1.1.0]:*** added support to densify the final update: introduced parameter alpha that controls the fraction of error feedback
+    (EF) to be integrated into the update to make it dense. Finally, the fraction alpha will be discarded from the EF at
+    the expense of another call to `Qinv` and `Q` (and implicitly quantization statistics computation).
+  - ***[1.0.2]:*** added FSDP-compatible implementation by initializing the parameter states in the `update_step` method
+    instead of MicroAdam constructor
 
 - **1.0.1** @ June 27th, 2024:
-
   - removed version in dependencies to avoid conflicts with llm-foundry
 
 - **1.0.0** @ June 20th, 2024:
-
   - changed minimum required Python version to 3.8+ and torch to 2.3.0+
 
 - **0.0.1** @ June 13th, 2024:
-
   - added initial version of the package for Python 3.9+ and torch 2.3.1+
{ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/micro_adam/micro_adam.cpp
RENAMED

@@ -13,7 +13,7 @@ void compute_microadam_update_cuda(int blocks, int threads, int carveout,
                                    torch::Tensor indices, torch::Tensor values, torch::Tensor out);
 
 void asymm_block_quant_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x);
-void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x);
+void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x, float alpha);
 
 // C++ methods
 void compute_microadam_update(int blocks, int threads, int carveout,

@@ -44,14 +44,14 @@ void asymm_block_quant(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xm
     asymm_block_quant_cuda(d, q_block_size, xq, xmin, xmax, x);
 }
 
-void asymm_block_quant_inv(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x) {
+void asymm_block_quant_inv(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x, float alpha) {
     CHECK_INPUT(xq);
     CHECK_INPUT(xmin);
     CHECK_INPUT(xmax);
     CHECK_INPUT(x);
 
     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
-    asymm_block_quant_inv_cuda(d, q_block_size, xq, xmin, xmax, x);
+    asymm_block_quant_inv_cuda(d, q_block_size, xq, xmin, xmax, x, alpha);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

{ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu
RENAMED

@@ -1,8 +1,8 @@
 #include "../utils.h"
 
-__global__ void asymm_block_quant_inv_kernel_bf16_bf16(LL d, LL q_block_size, uint8_t *xq, bfloat16 *xmin, bfloat16 *xmax, bfloat16 *x) {
+__global__ void asymm_block_quant_inv_kernel_bf16_bf16(LL d, LL q_block_size, uint8_t *xq, bfloat16 *xmin, bfloat16 *xmax, bfloat16 *x, float alpha) {
     /*
-        This kernel computes x += Q_inv(xq, xmin, xmax) for 4 bits (implements point 1 from PhD notebook page 55)
+        This kernel computes x += alpha * Q_inv(xq, xmin, xmax) for 4 bits (implements point 1 from PhD #9 notebook page 55)
         Here, x is the output buffer and will already contain the dense gradient
         The output buffer x has d components and xq has d/2 components because one uint8_t stores two 4-bit values
         In contrast to "globally" kernel, this kernel works with a single block

@@ -43,20 +43,21 @@ __global__ void asymm_block_quant_inv_kernel_bf16_bf16(LL d, LL q_block_size, ui
         lsb = (vq & 0x0F);
 
         // += operation happens here
-        vx2.x += __float2bfloat16(msb * u + m);
-        vx2.y += __float2bfloat16(lsb * u + m);
+        vx2.x += __float2bfloat16((msb * u + m) * alpha);
+        vx2.y += __float2bfloat16((lsb * u + m) * alpha);
         x2[half_index] = vx2;
     }
     if((d & 1) && (Bid == B-1) && (Tid == T-1)) {
         bfloat16 vx = x[d - 1];
         vq = xq[half_d];
         msb = (vq & 0xF0) >> 4;
-
+        // += operation happens here
+        vx += __float2bfloat16((msb * u + m) * alpha);
         x[d - 1] = vx;
     }
 }
 
-void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x) {
+void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x, float alpha) {
     ASSERT_BF16(xmin);
     ASSERT_BF16(xmax);
     ASSERT_BF16(x);

@@ -72,7 +73,8 @@ void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::
         (uint8_t*) xq.data_ptr(),
         (bfloat16*) xmin.data_ptr(),
         (bfloat16*) xmax.data_ptr(),
-        (bfloat16*) x.data_ptr()
+        (bfloat16*) x.data_ptr(),
+        alpha);
 
     // error checks
     GPU_ERROR_CHECK(cudaGetLastError());
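
As a readability aid, here is a pure-PyTorch reference for what the updated kernel computes, `x += alpha * Qinv(xq, xmin, xmax)`, assuming the 4-bit codes map linearly onto `[xmin, xmax]` with 16 levels (scale `u = (xmax - xmin) / 15`, offset `m = xmin`); the packing of two codes per byte and the exact scale used by the CUDA kernel are not reproduced here.

```python
import torch

def asymm_block_dequant_add(x, codes, xmin, xmax, q_block_size, alpha):
    """x += alpha * dequant(codes) with per-block (xmin, xmax) statistics; codes are unpacked 4-bit values."""
    d = x.numel()
    for start in range(0, d, q_block_size):
        b = start // q_block_size
        u = (xmax[b] - xmin[b]).float() / 15.0   # assumed 16 quantization levels per block
        m = xmin[b].float()
        block = codes[start:start + q_block_size].float()
        x[start:start + q_block_size] += alpha * (block * u + m)
    return x

x = torch.zeros(10)
codes = torch.randint(0, 16, (10,))
xmin = torch.tensor([-1.0, -0.5])
xmax = torch.tensor([1.0, 0.5])
asymm_block_dequant_add(x, codes, xmin, xmax, q_block_size=5, alpha=0.1)
print(x)
```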

{ista_daslab_optimizers-1.0.1 → ista_daslab_optimizers-1.1.2}/kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu
RENAMED

@@ -144,7 +144,7 @@ __global__ void LCG_v51_bf16(long d,
 
     long i; // vector index (for indices and values)
     long k_col; // column index to extract data from indices and values at the current row
-
+    // long index; // the 1-D index to extract data from indices and values at the current row (row * k + k_col)
     int16 ind; // the data from indices at the index "index"
     bfloat16 val; // the data from values at the index "index"
 

ista_daslab_optimizers-1.0.1/ista_daslab_optimizers/micro_adam/micro_adam.py
DELETED

@@ -1,247 +0,0 @@
-
-import torch
-import math
-import time
-import wandb
-from torch.distributed import is_initialized, get_rank
-from ..tools import get_first_device, get_gpu_mem_usage, block_split, CopyDirection
-
-import ista_daslab_tools
-import ista_daslab_micro_adam
-
-class MicroAdam(torch.optim.Optimizer):
-    def __init__(self, params, m, lr, quant_block_size, k_init=0.01, betas=(0.9, 0.999), weight_decay=0, eps=1e-8):
-        defaults = dict(lr=lr, weight_decay=weight_decay, eps=eps)
-        super(MicroAdam, self).__init__(params, defaults)
-
-        self.m = m
-        self.lr = lr
-        self.quant_block_size = int(quant_block_size)
-        self.k_init = k_init
-        self.weight_decay = weight_decay
-        self.beta1 = betas[0]
-        self.beta2 = betas[1]
-        self.eps = eps
-
-        self.model_size = sum([p.numel() for group in self.param_groups for p in group['params']])
-
-        self.steps = 0 # how many optimization steps were performed so far
-        self.log_interval = 100
-        self.device = get_first_device()
-        self._is_state_initialized = False
-        self.shared_memory_carveout = 100
-        self.blocks = ista_daslab_tools.get_sm_count() * int(100 / self.shared_memory_carveout)
-        self.threads = 512
-
-        self.dict_size_count = {} # key = layer size, value = how many layers of that size the model has
-        for param in self.param_groups:
-            for p in param['params']:
-                size = p.numel()
-                self.dict_size_count[size] = 1 + self.dict_size_count.get(size, 0)
-
-        self._init_state()
-
-    def _init_state(self):
-        max_floats = ista_daslab_tools.get_max_floats_for_shared_memory_per_thread_block()
-        d_block_size = max_floats // 2 // int(100 / self.shared_memory_carveout)
-        count = 0
-        for group in self.param_groups:
-            lr = group['lr']
-            wd = group.get('weight_decay', self.weight_decay) # if the param groups do not have weight decay, then use the external one
-            for p in group['params']:
-                if not p.requires_grad:
-                    continue
-                count += 1
-                layer_size = p.numel()
-                st = self.state[p]
-
-                # B * t / d * nt
-                st['blocks'] = max(1, int(math.floor(self.blocks * layer_size * self.dict_size_count[layer_size] / self.model_size)))
-
-                st['lr'] = lr
-                st['weight_decay'] = wd
-                st['d'] = layer_size
-
-                ##### variables for Top-K: d_index_topk is the index where the last, smaller topk block starts
-                st['d_block_size'] = layer_size if layer_size < d_block_size else d_block_size
-                st['topk_full_blocks_count'], st['d_index_topk'] = block_split(st['d'], st['d_block_size'])
-                st['k_block_size_many'] = int(math.ceil(st['d_block_size'] * self.k_init))
-                st['k_block_size_few'] = int(math.ceil((st['d'] - st['d_index_topk']) * self.k_init)) # 0 for d % d_block_size = 0
-                st['k_index'] = st['topk_full_blocks_count'] * st['k_block_size_many']
-                st['k'] = st['k_block_size_many'] * st['topk_full_blocks_count'] + st['k_block_size_few']
-
-                ##### variables for the ring buffer
-                st['index'] = 0 # the position to place a new gradient at
-                st['I'] = torch.zeros(self.m, st['k'], dtype=torch.int16, device=self.device) # 2mk bytes
-                st['V'] = torch.zeros(self.m, st['k'], dtype=torch.bfloat16, device=self.device) # 2mk bytes
-
-                ### variables for error feedback: d_index_quant is the index where the last, smaller quantization block starts
-                # st['quant_block_size'] = layer_size if layer_size < self.quant_block_size else self.quant_block_size
-                st['quant_full_blocks_count'], st['d_index_quant'] = block_split(st['d'], self.quant_block_size)
-                st['error'] = torch.zeros(int(math.ceil(st['d'] / 2)), dtype=torch.uint8, device=self.device) # ceil(d/2) bytes
-                st['min_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device) # ceil(d/q_bsz)*2 bytes
-                st['max_vals'] = torch.zeros(st['quant_full_blocks_count'] + 1, dtype=torch.bfloat16, device=self.device) # ceil(d/q_bsz)*2 bytes
-
-    @torch.no_grad()
-    def step(self, closure=None):
-        self.steps += 1
-
-        self._update_lr_wd()
-
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        time_start = time.time()
-
-        norm_g, norm_u, norm_e, sparsity_u = 0, 0, 0, 0
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                ng, nu, ne, sp_u = self.update_step(p)
-                norm_g += ng
-                norm_u += nu
-                norm_e += ne
-                sparsity_u += sp_u
-
-        # torch.cuda.synchronize()
-        time_end = time.time()
-        elapsed_step = time_end - time_start
-        self._log(norm_g, norm_u, norm_e, sparsity_u, elapsed_step)
-
-        return loss
-
-    @torch.no_grad()
-    def update_step(self, p):
-        norm_g, norm_u, norm_e, sp_u = 0, 0, 0, 0
-
-        st = self.state[p]
-        grad = p.grad.view(-1)
-
-        if self.steps % self.log_interval == 0:
-            norm_g = grad.norm(p=2) ** 2
-
-        blocks = st['blocks']
-        lr = st['lr']
-        wd = st['weight_decay']
-        d = st['d']
-        d_block_size = st['d_block_size']
-        topk_full_blocks_count, d_index_topk = st['topk_full_blocks_count'], st['d_index_topk']
-        k_block_size_many = st['k_block_size_many']
-        k_block_size_few = st['k_block_size_few']
-        k_index = st['k_index']
-        k = st['k']
-
-        # HuggingFace has a setting that converts st['I'] to bfloat16, even though it is declared as int16
-        # This happens somewhere between constructor call and step call. Converting it to int16 was the simplest solution
-        if st['I'].dtype != torch.int16:
-            st['I'] = st['I'].to(torch.int16)
-
-        index = st['index']
-        I = st['I']
-        V = st['V']
-
-        quant_full_blocks_count, d_index_quant = st['quant_full_blocks_count'], st['d_index_quant']
-        error = st['error']
-        min_vals = st['min_vals']
-        max_vals = st['max_vals']
-
-        ##### STEP 4
-        ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad)
-
-        ##### STEP 5 + 9 (only for I)
-        I[index, :k_index] = torch.topk(input=grad[0:d_index_topk].abs().view(topk_full_blocks_count, d_block_size),
-                                        k=k_block_size_many,
-                                        sorted=False).indices.to(dtype=torch.int16).view(-1)
-
-        if k_block_size_few > 0: # there is a small block left
-            I[index, k_index:] = torch.topk(input=grad[d_index_topk:].abs(),
-                                            k=k_block_size_few, # example: slice has size 1, but ks[-1] is 4
-                                            sorted=False).indices.to(dtype=torch.int16).view(-1)
-
-        ista_daslab_tools.copy_values(d, # V = error[I[buffer_index, :]]
-                                      k,
-                                      d_block_size,
-                                      k_block_size_many,
-                                      I[index, :], # indices
-                                      grad, # inp
-                                      V[index, :], # output
-                                      CopyDirection.d2k.value)
-
-        st['index'] = (index + 1) % self.m
-
-        ##### STEP 6
-        ista_daslab_tools.zerorize_block_components(grad, I[index, :], d, k, d_block_size, k_block_size_many) # this does a[I[index]] = 0
-
-        ##### STEP 7
-        if quant_full_blocks_count == 1:
-            min_vals[:quant_full_blocks_count] = grad[:d_index_quant].min()
-            max_vals[:quant_full_blocks_count] = grad[:d_index_quant].max()
-        else:
-            min_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).min(dim=1).values
-            max_vals[:quant_full_blocks_count] = grad[:d_index_quant].view(quant_full_blocks_count, self.quant_block_size).max(dim=1).values
-        if d_index_quant < d:
-            min_vals[quant_full_blocks_count] = grad[d_index_quant:].min()
-            max_vals[quant_full_blocks_count] = grad[d_index_quant:].max()
-
-        ##### STEP 8
-        ista_daslab_micro_adam.asymm_block_quant(d, self.quant_block_size, error, min_vals, max_vals, grad) # error = Q(a, min, max)
-
-        ##### STEPS 10-11
-        grad.zero_()
-        ista_daslab_micro_adam.compute_microadam_update(blocks, # blocks
-                                                        self.threads, # threads
-                                                        self.shared_memory_carveout, # carveout
-                                                        self.steps, # optimization step
-                                                        self.beta1, # beta1
-                                                        self.beta2, # beta2
-                                                        self.eps, # eps
-                                                        d_block_size, # d_block_size
-                                                        k_block_size_many, # k_block_size
-                                                        d, # d
-                                                        self.m, # m
-                                                        k, # k
-                                                        I, # indices
-                                                        V, # values
-                                                        grad) # update will be stored here
-
-        ##### STEP 12
-        p.mul_(1 - lr * wd).add_(p.grad, alpha=-lr)
-
-        # compute error norm
-        if self.steps % self.log_interval == 0:
-            norm_u = grad.norm(p=2) ** 2
-            sp_u = (grad == 0).sum() # check sparsity before zerorizing
-
-            grad.zero_()
-            ista_daslab_micro_adam.asymm_block_quant_inv(d, self.quant_block_size, error, min_vals, max_vals, grad)
-
-            norm_e = grad.norm(p=2) ** 2
-
-        return norm_g, norm_u, norm_e, sp_u
-
-    def _log(self, norm_g, norm_u, norm_e, sparsity_u, elapsed_step):
-        if self.steps % self.log_interval == 0:
-            wandb_data = {
-                'step/optimizer_steps': self.steps,
-                'step/gpu_mem_usage': get_gpu_mem_usage(),
-                'step/norm_g': math.sqrt(norm_g),
-                'step/norm_u': math.sqrt(norm_u),
-                'step/norm_error': math.sqrt(norm_e),
-                'step/sparsity_u': sparsity_u / self.model_size * 100.,
-                'step/elapsed_step': elapsed_step,
-            }
-
-            if not is_initialized() or get_rank() == 0:
-                wandb.log(wandb_data, commit=False)
-
-    def _update_lr_wd(self):
-        # copy the learning rate group to parameter state because the lr scheduler updates the one in the group
-        for group in self.param_groups:
-            lr = group['lr']
-            wd = group.get('weight_decay', self.weight_decay) # if the param groups do not have weight decay, then use the external one
-            for p in group['params']:
-                self.state[p]['lr'] = lr
-                self.state[p]['wd'] = wd