ista_daslab_optimizers-1.1.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ista_daslab_optimizers/__init__.py +6 -0
- ista_daslab_optimizers/acdc/__init__.py +5 -0
- ista_daslab_optimizers/acdc/acdc.py +387 -0
- ista_daslab_optimizers/acdc/wd_scheduler.py +31 -0
- ista_daslab_optimizers/dense_mfac/__init__.py +5 -0
- ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +164 -0
- ista_daslab_optimizers/dense_mfac/dense_mfac.py +93 -0
- ista_daslab_optimizers/fft_low_rank/dct_adamw.py +351 -0
- ista_daslab_optimizers/fft_low_rank/fft_projector.py +192 -0
- ista_daslab_optimizers/fft_low_rank/trion.py +242 -0
- ista_daslab_optimizers/ista_optimizer/__init__.py +5 -0
- ista_daslab_optimizers/ista_optimizer/ista_optimizer.py +36 -0
- ista_daslab_optimizers/micro_adam/__init__.py +5 -0
- ista_daslab_optimizers/micro_adam/micro_adam.py +402 -0
- ista_daslab_optimizers/sparse_mfac/__init__.py +7 -0
- ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +226 -0
- ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +87 -0
- ista_daslab_optimizers/tools.py +218 -0
- ista_daslab_optimizers/utils/dct.py +45 -0
- ista_daslab_optimizers/utils/global_cache.py +45 -0
- ista_daslab_optimizers/utils/matrix_storage.py +58 -0
- ista_daslab_optimizers/utils/newton_schulz_triton.py +374 -0
- ista_daslab_optimizers/utils/quantizers.py +71 -0
- ista_daslab_optimizers/utils/schedulers.py +41 -0
- ista_daslab_optimizers-1.1.8.dist-info/METADATA +333 -0
- ista_daslab_optimizers-1.1.8.dist-info/RECORD +29 -0
- ista_daslab_optimizers-1.1.8.dist-info/WHEEL +5 -0
- ista_daslab_optimizers-1.1.8.dist-info/licenses/LICENSE +201 -0
- ista_daslab_optimizers-1.1.8.dist-info/top_level.txt +1 -0
ista_daslab_optimizers/dense_mfac/dense_mfac.py
@@ -0,0 +1,93 @@
+import wandb
+import torch
+from ..tools import get_weights_and_gradients, update_model, get_first_device, get_gpus
+from .dense_core_mfac import DenseCoreMFAC
+
+# Disable tensor cores as they can mess with precision
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+
+class DenseMFAC(torch.optim.Optimizer):
+    def __init__(self,
+                 params,
+                 lr: float,
+                 weight_decay: float,
+                 ngrads: int,
+                 damp: float,
+                 create_G=False):
+
+        super(DenseMFAC, self).__init__(params, dict(lr=lr, weight_decay=weight_decay))
+
+        self.m = ngrads
+        self.lr = lr
+        self.damp = damp
+        self.weight_decay = weight_decay
+        self.device = get_first_device()
+        self.create_G = create_G
+
+        self.model_size = None
+        self.steps = 0
+        self.wandb_data = dict()
+
+        self.model_size = sum([p.numel() for group in self.param_groups for p in group['params']])
+
+        self.hinv = None
+
+    def _create_hinv(self):
+        self.hinv = DenseCoreMFAC(
+            grads=torch.zeros((self.m, self.model_size), dtype=torch.float),
+            dev=self.device,
+            gpus=get_gpus(),
+            damp=self.damp,
+            create_G=self.create_G)
+
+    @torch.no_grad()
+    def empty_buffer(self):
+        self.hinv.empty_buffer()
+
+
+    @torch.no_grad()
+    def integrate_gradient(self, g):
+        _ = self.hinv.integrate_gradient(g)
+
+    @torch.no_grad()
+    def compute_update(self, g, x):
+        update_method = self.hinv.integrate_gradient_and_precondition
+        # if self.use_sq_newton:
+        #     update_method = self.hinv.integrate_gradient_and_precondition_twice
+
+        update = update_method(g, x).to(self.device)
+        return update
+
+    @torch.no_grad()
+    def log_data(self, update, g):
+        lr = self.param_groups[0]['lr']
+        self.wandb_data.update(dict(norm_upd_w_lr=lr * update.norm(p=2), norm_g=g.norm(p=2)))
+        self.wandb_data.update(self.hinv.wandb_data)
+        # self.wandb_data.update(quantify_preconditioning(g=g, u=update.to(g.device), return_distribution=False, use_abs=True, optim_name=self.optim_name))
+        wandb.log(self.wandb_data)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        self.steps += 1
+
+        if self.hinv is None:
+            self._create_hinv()
+
+        g = get_weights_and_gradients(self.param_groups, get_weights=False)
+        update = self.compute_update(g, g)
+
+        update_model(params=self.param_groups, update=update, alpha=None)
+
+        if self.steps % self.m == 0:
+            self.log_data(update, g)
+
+
+    # def kfac_update_rescaling(self, g, u):
+    #     # rescaling_kfac64, # use the alpha rescaling from the K-FAC paper, section 6.4
+    #     T1 = torch.dot(g, u)
+    #     T2 = (self.hinv.grads_matmul(u).norm(p=2) ** 2).mean()
+    #     T3 = self.damp * (u.norm(p=2) ** 2)
+    #     alpha = -T1 / (T2 + T3)
+    #     self.wandb_data.update(dict(kfac_T1=T1, kfac_T2=T2, kfac_T3=T3, kfac_alpha=alpha))
+    #     return alpha
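For context, `DenseMFAC` keeps a buffer of the last `ngrads` gradients (the `(m, model_size)` tensor passed to `DenseCoreMFAC`) and preconditions each new gradient with the resulting inverse-Hessian estimate. Below is a minimal usage sketch, not taken from the package: the top-level import is assumed from the package's `__init__.py`, and the model, data, and hyperparameter values are hypothetical. Since `step()` calls `wandb.log` every `ngrads` steps, an active `wandb.init()` run is assumed as well.

    import torch
    import wandb
    from ista_daslab_optimizers import DenseMFAC  # top-level export assumed; adjust to the actual import path

    wandb.init(project='dense-mfac-demo', mode='offline')  # step() logs via wandb every `ngrads` steps

    model = torch.nn.Linear(64, 10).cuda()
    optimizer = DenseMFAC(
        model.parameters(),
        lr=1e-3,           # hypothetical, untuned values
        weight_decay=1e-4,
        ngrads=32,         # window size m: number of past gradients kept for the preconditioner
        damp=1e-6,         # damping added to the curvature estimate
    )

    for _ in range(100):
        x = torch.randn(8, 64, device='cuda')
        y = torch.randint(0, 10, (8,), device='cuda')
        optimizer.zero_grad()
        torch.nn.functional.cross_entropy(model(x), y).backward()
        optimizer.step()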
ista_daslab_optimizers/fft_low_rank/dct_adamw.py
@@ -0,0 +1,351 @@
+import torch
+import torch.distributed as dist
+
+import math
+import numpy as np
+
+from fast_hadamard_transform import hadamard_transform
+from ista_daslab_optimizers.utils.dct import dct3_matrix
+from ista_daslab_optimizers.utils.quantizers import Quantizer4bit, Quantizer8bit
+from ista_daslab_optimizers.fft_low_rank.fft_projector import FFTLowRankProjector
+
+PROJ_DCT = 'dct'
+PROJ_HDM = 'hdm'
+PROJ_RAND_QR = 'rqr'
+
+ALL_PROJ = [
+    PROJ_DCT,     # DCT projection
+    PROJ_HDM,     # Hadamard projection
+    PROJ_RAND_QR, # Random-QR projection
+]
+
+STATE_M = 'm'
+STATE_V = 'v'
+STATE_Q = 'Q'
+STATE_ID = 'param-id'
+STATE_EF = 'ef'
+# STATE_EF_MIN = 'ef-min-vals'
+# STATE_EF_MAX = 'ef-max-vals'
+STATE_FFT_LRP = 'fft-low-rank-projector'
+STATE_BROADCAST_SOURCE = 'broadcast-src' # the process rank that computes the update for a parameter p will broadcast the parameter p to other workers
+
+
+class DCTAdamW(torch.optim.Optimizer):
+    def __init__(self,
+                 params,
+                 lr,
+                 weight_decay,
+                 rank,
+                 proj,
+                 use_ef=False,
+                 q_ef=False,
+                 distributed=False,
+                 update_proj_gap=1,
+                 rotate_subspace=False,
+                 sim_type='matmul',
+                 ell_norm=1,
+                 max_shape=32_000,
+                 betas=(0.9, 0.999),
+                 eps=1e-8):
+        assert proj in ALL_PROJ
+
+        super().__init__(params, dict(lr=lr, weight_decay=weight_decay))
+
+        self.rank = rank
+        self.proj = proj
+        self.use_ef = use_ef
+        self.q_ef = q_ef
+        self.distributed = distributed
+        self.update_proj_gap = update_proj_gap
+        self.rotate_subspace = rotate_subspace
+        self.sim_type = sim_type
+        self.ell_norm = ell_norm
+        self.max_shape = max_shape # apply low-rank to 2D parameters that have both dimensions smaller than max_shape
+        self.betas = betas
+        self.eps = eps
+
+        self.steps = 0
+        self.is_state_initialized = False
+        self.Q = None # the full transformation matrix (non-truncated, all rows and columns)
+        self.Q_cols_norm = None
+        self.use_theoretical_similarity = (self.ell_norm < 0)
+        self.ell_norm = abs(self.ell_norm)
+
+        if proj == PROJ_DCT:
+            assert sim_type in ['matmul', 'makhoul']
+        else:
+            assert sim_type == 'matmul'
+
+    def setup_Q(self, p):
+        if self.Q is None:
+            size = min(p.shape)
+            if self.proj == PROJ_DCT:
+                Qdct3 = dct3_matrix(size, p.dtype, p.device) # first row is zero
+                if self.sim_type == 'makhoul':
+                    self.Q = Qdct3.t()
+                    print(f'\n\t!!!!! Initialized DCT-2 matrix of size {size} !!!!!\n')
+                elif self.sim_type == 'matmul':
+                    self.Q = Qdct3
+                    print(f'\n\t!!!!! Initialized DCT-3 matrix of size {size} !!!!!\n')
+                else:
+                    raise RuntimeError(f'Unknown sim_type: {self.sim_type}')
+            elif self.proj == PROJ_HDM:
+                self.Q = hadamard_transform(torch.eye(size).to(device=p.device, dtype=p.dtype), scale=1. / math.sqrt(size))
+                print(f'\n\t!!!!! Initialized Hadamard matrix of size {size} !!!!!\n')
+            elif self.proj == PROJ_RAND_QR:
+                random = torch.randn(size, size, dtype=p.dtype, device=p.device)
+                self.Q, _ = torch.linalg.qr(random)
+                del random
+            else:
+                raise RuntimeError(f'Projection {self.proj} is currently not supported!')
+
+            if self.use_theoretical_similarity:
+                self.Q_cols_norm = self.Q.norm(p=self.ell_norm, dim=0)
+
+    def should_compute_update(self, p):
+        """
+        This function returns a boolean indicating whether the update for the parameter p should be computed on the current GPU
+        """
+        state = self.state[p]
+        param_id = state[STATE_ID]
+        return param_id % dist.get_world_size() == dist.get_rank()
+
+    def should_update_projection(self):
+        return self.steps == 1 or self.steps % self.update_proj_gap == 0
+
+    def init_state(self, p):
+        state = self.state[p]
+        if p.ndim == 1: # adam update
+            print(f'Parameter of size {tuple(p.shape)} will receive original AdamW update with state shape {tuple(p.shape)}')
+            state[STATE_M] = torch.zeros_like(p)
+            state[STATE_V] = torch.zeros_like(p)
+        elif p.ndim == 2: # low-rank adam update
+            n, m = p.shape
+            if n >= self.max_shape or m >= self.max_shape: # apply full-rank
+                print(f'Parameter of size {tuple(p.shape)} will receive original AdamW update with state shape {tuple(p.shape)}')
+                state[STATE_M] = torch.zeros_like(p)
+                state[STATE_V] = torch.zeros_like(p)
+            else: # apply low-rank using the DCT transform as orthogonal matrix
+                if n >= m:
+                    low_rank_shape = (n, self.rank)
+                else:
+                    # fix for Llama-3-8B that has a layer of size (1024, 4096)
+                    # fix for Qwen2.5-7B that has a layer of size (512, 3584)
+                    if n in [512, 1024] and m in [3584, 4096]:
+                        low_rank_shape = (n, self.rank)
+                    else:
+                        low_rank_shape = (self.rank, m)
+                # low_rank_shape = (n, self.rank) if n >= m else (self.rank, m)
+                print(f'Parameter of size {tuple(p.shape)} will receive low-rank update with state shape {low_rank_shape}')
+                state[STATE_M] = torch.zeros(*low_rank_shape, dtype=p.dtype, device=p.device)
+                state[STATE_V] = torch.zeros(*low_rank_shape, dtype=p.dtype, device=p.device)
+                state[STATE_FFT_LRP] = FFTLowRankProjector(p,
+                                                           rank=self.rank,
+                                                           proj=self.proj,
+                                                           rotate_subspace=self.rotate_subspace,
+                                                           sim_type=self.sim_type,
+                                                           ell_norm=self.ell_norm,
+                                                           use_th_sim=self.use_theoretical_similarity)
+                if self.use_ef:
+                    if self.q_ef > 0:
+                        # state[STATE_EF] = torch.zeros(p.numel() // 2, dtype=torch.uint8, device=p.device)
+                        # state[STATE_EF_MIN] = torch.zeros(p.shape[0], dtype=torch.bfloat16, device=p.device)
+                        # state[STATE_EF_MAX] = torch.zeros(p.shape[0], dtype=torch.bfloat16, device=p.device)
+                        quantClass = {4: Quantizer4bit, 8: Quantizer8bit}[self.q_ef]
+                        if self.q_ef == 4:
+                            quantClass = Quantizer4bit
+                            print(f'\n\t!!!!! Quantizing EF to 4 bits !!!!!\n')
+                        elif self.q_ef == 8:
+                            quantClass = Quantizer8bit
+                            print(f'\n\t!!!!! Quantizing EF to 8 bits !!!!!\n')
+                        else:
+                            raise RuntimeError(f'Quantization on {self.q_ef} bits is currently not supported!')
+                        state[STATE_EF] = quantClass(shape=p.shape, device=p.device, dtype=p.dtype, bucket_size=p.shape[1])
+                    else:
+                        state[STATE_EF] = torch.zeros_like(p)
+
+                ### initialize Q
+                print('calling setup_Q')
+                self.setup_Q(p)
+        # end if
+
+    def init(self):
+        # init broadcast info
+        self.is_state_initialized = True
+        bcast_src_list = []
+        param_id = 0 # parameter id
+        for group in self.param_groups:
+            for p in group['params']:
+                if p is None: continue
+                if p.grad is None: continue
+
+                state = self.state[p]
+                if len(state) == 0:
+                    if self.distributed:
+                        state[STATE_ID] = param_id
+                        param_id += 1
+                        if self.should_compute_update(p):
+                            # if the current process computes the update, then it will also broadcast the parameters to all other workers
+                            state[STATE_BROADCAST_SOURCE] = torch.tensor(dist.get_rank(), dtype=torch.int32, device=f'cuda:{dist.get_rank()}')
+                            self.init_state(p)
+                        else:
+                            # p.register_hook(lambda grad: None) # set gradient to None
+                            # p.requires_grad = False # disable gradient computation for this layer
+                            state[STATE_BROADCAST_SOURCE] = torch.tensor(0, dtype=torch.int32, device=f'cuda:{dist.get_rank()}') # zero means empty here because we will do an all reduce
+                        bcast_src_list.append(state[STATE_BROADCAST_SOURCE].item())
+                    else:
+                        self.init_state(p)
+        # end for group
+
+        if self.distributed:
+            dist.barrier()
+
+            # with open(f'broadcast-{dist.get_rank()}.txt', 'w') as w:
+            # sync broadcast source
+            # w.write(f'Broadcast SRC on worker {dist.get_rank()} before all_reduce: {",".join(map(str, bcast_src_list))}\n')
+            bcast_src_list = []
+            for group in self.param_groups:
+                for p in group['params']:
+                    if p is None: continue
+                    if p.grad is None: continue
+
+                    state = self.state[p]
+                    dist.all_reduce(state[STATE_BROADCAST_SOURCE], op=dist.ReduceOp.SUM)
+                    state[STATE_BROADCAST_SOURCE] = state[STATE_BROADCAST_SOURCE].item()
+                    bcast_src_list.append(state[STATE_BROADCAST_SOURCE])
+            # end for group
+            # w.write(f'Broadcast SRC on worker {dist.get_rank()} after all_reduce: {",".join(map(str, bcast_src_list))}\n')
+            dist.barrier()
+        # end if
+        torch.cuda.empty_cache()
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        self.steps += 1
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        if not self.is_state_initialized:
+            self.init() # init broadcast info
+
+        for group in self.param_groups:
+            lr = group['lr']
+            wd = group['weight_decay']
+
+            for p in group['params']:
+                if p is None: continue
+                if p.grad is None: continue
+
+                if wd > 0:
+                    p.mul_(1 - lr * wd)
+
+                if self.distributed:
+                    if self.should_compute_update(p):
+                        self.update_step(p, lr)
+                else:
+                    self.update_step(p, lr)
+        # end for group
+
+        if self.distributed:
+            for group in self.param_groups:
+                for p in group['params']:
+                    if p is None: continue
+                    if p.grad is None: continue
+
+                    dist.broadcast(p, src=self.state[p][STATE_BROADCAST_SOURCE])
+
+            # end for group
+            dist.barrier() # wait for all GPUs to compute the update for all layers
+        # end if distributed
+        return loss
+
+    @torch.no_grad()
+    def update_step(self, p, lr):
+        if p.ndim == 1: # adam update
+            self.adamw_step(p, lr)
+        elif p.ndim == 2: # low-rank adam update
+            n, m = p.shape
+            if n >= self.max_shape or m >= self.max_shape: # apply full-rank for parameters that have at least one dimension >= max_shape (e.g. embeddings and lm_head)
+                self.adamw_step(p, lr)
+            else:
+                self.dct_low_rank_step(p, lr)
+
+    def dct_low_rank_step(self, p, lr):
+        beta1, beta2 = self.betas
+        bc1 = 1 - beta1 ** self.steps
+        sqrt_bc2 = math.sqrt(1 - beta2 ** self.steps)
+        adjusted_lr = -lr * sqrt_bc2 / bc1
+
+        A = p.grad # initially, the accumulator stores the gradient; a bit later we will add the error feedback
+        state = self.state[p]
+
+        mt = state[STATE_M]
+        vt = state[STATE_V]
+
+        if self.use_ef:
+            E = state[STATE_EF]
+            if self.q_ef:
+                # see step 4 from Algorithm 1 in the MicroAdam paper https://arxiv.org/pdf/2405.15593
+                A.add_(E.quantize_inv()) # p.grad += Qinv(EF)
+            else:
+                A.add_(E)
+
+        clrp: FFTLowRankProjector = state[STATE_FFT_LRP]
+        clrp.inc_step()
+
+        if self.should_update_projection():
+            a = clrp.change_subspace(self.Q, A, col_norms=self.Q_cols_norm)
+        else:
+            ### compute low-rank accumulator a
+            a = clrp.from_higher_to_lower_dimensions(self.Q, A)
+
+        if self.use_ef:
+            A_reconstructed = clrp.from_lower_to_higher_dimensions(self.Q, a)
+            if self.q_ef:
+                A.sub_(A_reconstructed) # the full-precision EF is now stored in A
+                # see step 8 from Algorithm 1 in the MicroAdam paper https://arxiv.org/pdf/2405.15593
+                E.quantize(A)
+            else:
+                E.copy_(A).sub_(A_reconstructed)
+            del A_reconstructed
+
+        ### update momentum m and v (rotate first, if needed)
+        if self.steps > 1 and self.rotate_subspace and self.should_update_projection():
+            R = clrp.get_subspace_rotation_matrix(self.Q)
+            clrp.rotate_subspace(R, mt)
+            clrp.rotate_subspace(R, vt)
+            vt.abs_() # make sure vt is positive
+            del R
+
+        mt.mul_(beta1).add_(a, alpha=1 - beta1)
+        vt.mul_(beta2).addcmul_(a, a, value=1 - beta2)
+
+        u = mt / (self.eps * sqrt_bc2 + vt.sqrt())
+        clrp.from_lower_to_higher_dimensions(self.Q, u, out=p.grad)
+        del u, a
+
+        p.add_(p.grad, alpha=adjusted_lr)
+
+    @torch.no_grad()
+    def adamw_step(self, p, lr):
+        state = self.state[p]
+        g = p.grad
+
+        mt = state[STATE_M]
+        vt = state[STATE_V]
+
+        beta1, beta2 = self.betas
+        bc1 = 1 - beta1 ** self.steps
+        sqrt_bc2 = math.sqrt(1 - beta2 ** self.steps)
+        adjusted_lr = -lr * sqrt_bc2 / bc1
+
+        # update momentum m and v
+        mt.mul_(beta1).add_(g, alpha=1-beta1)
+        vt.mul_(beta2).addcmul_(g, g, value=1-beta2)
+
+        # u = mt / (self.eps * sqrt_bc2 + vt.sqrt())
+        g.copy_(vt).sqrt_().add_(self.eps * sqrt_bc2).div_(mt).reciprocal_()
+        p.add_(g, alpha=adjusted_lr)
+
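Both `adamw_step` and `dct_low_rank_step` fold the Adam bias corrections into a single `adjusted_lr = -lr * sqrt(1 - beta2^t) / (1 - beta1^t)`, paired with the denominator `eps * sqrt_bc2 + vt.sqrt()`. This is algebraically identical to the textbook bias-corrected update `-lr * m_hat / (sqrt(v_hat) + eps)` with `m_hat = mt / (1 - beta1^t)` and `v_hat = vt / (1 - beta2^t)`: multiplying numerator and denominator by `sqrt_bc2` moves the correction into the step size. A standalone sanity check of that equivalence (not part of the package):

    import math
    import torch

    torch.manual_seed(0)
    steps, lr, eps = 7, 1e-3, 1e-8
    beta1, beta2 = 0.9, 0.999
    mt, vt = torch.randn(5), torch.rand(5)  # random stand-ins for the optimizer state

    bc1 = 1 - beta1 ** steps
    sqrt_bc2 = math.sqrt(1 - beta2 ** steps)

    # fused form used in adamw_step / dct_low_rank_step above
    fused = (-lr * sqrt_bc2 / bc1) * mt / (eps * sqrt_bc2 + vt.sqrt())

    # textbook bias-corrected AdamW direction: -lr * m_hat / (sqrt(v_hat) + eps)
    m_hat = mt / bc1
    v_hat = vt / (1 - beta2 ** steps)
    textbook = -lr * m_hat / (v_hat.sqrt() + eps)

    assert torch.allclose(fused, textbook)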
ista_daslab_optimizers/fft_low_rank/fft_projector.py
@@ -0,0 +1,192 @@
+import torch
+import torch.distributed as dist
+from ista_daslab_optimizers.utils.dct import dct_type2_makhoul
+from ista_daslab_optimizers.utils.global_cache import GlobalCache
+
+class FFTLowRankProjector:
+    def __init__(self, p, rank, proj, rotate_subspace, sim_type='matmul', ell_norm=1, use_th_sim=False):
+        assert sim_type in ['matmul', 'makhoul']
+        self.rank = rank
+        self.proj = proj
+        self.rotate_states = rotate_subspace # allocate indices_prev only if we choose to rotate the subspace
+        self.sim_type = sim_type
+        self.ell_norm = ell_norm
+        self.use_th_sim = use_th_sim
+
+        self.size = None
+        self.indices_crt = None # the indices for the columns/rows
+        self.indices_prev = None # the indices for the columns/rows
+        self.is_right_proj = None
+
+        self.steps = 0
+        self.device = f'cuda:{dist.get_rank()}' if dist.is_initialized() else 'cuda:0'
+
+        GlobalCache.init()
+        self._setup(p)
+
+    def _setup(self, p):
+        n, m = p.shape
+        if n >= m:
+            self.is_right_proj = True
+            self.size = min(n, m)
+        else:
+            # fix for Llama-3-8B that has a layer of size (1024, 4096)
+            # fix for Qwen2.5-7B that has a layer of size (512, 3584)
+            if n in [512, 1024] and m in [3584, 4096]:
+                self.is_right_proj = True
+                self.size = m
+            else:
+                self.is_right_proj = False
+                self.size = min(n, m)
+        # self.is_right_proj = (n >= m) or (n < m and self.size == m)
+
+        self.indices_crt = torch.zeros(self.rank, dtype=torch.int32, device=p.device)
+        if self.rotate_states:
+            self.indices_prev = torch.zeros(self.rank, dtype=torch.int32, device=p.device)
+
+    def inc_step(self):
+        self.steps += 1
+
+    def compute_similarity_matmul(self, Q, A):
+        if self.is_right_proj:
+            S = A @ Q
+            norms = S.norm(p=self.ell_norm, dim=0) # dim = 0 computes the norm of columns (over all rows)
+        else:
+            S = Q.T @ A
+            norms = S.norm(p=self.ell_norm, dim=1) # dim = 1 computes the norm of rows (over all columns)
+        return S, norms
+
+    def compute_similarity_makhoul(self, A):
+        if self.is_right_proj: # R >= C
+            S = dct_type2_makhoul(A)
+            norms = S.norm(p=1, dim=0) # dim = 0 computes the norm of columns (over all rows) to rank columns
+        else: # R < C
+            S = dct_type2_makhoul(A.T)
+            S = S.T # account for the transposition of the input because Makhoul computes the DCT per rows by default
+            norms = S.norm(p=1, dim=1) # dim = 1 computes the norm of rows (over all columns) to rank rows
+        return S, norms
+
+    def change_subspace(self, Q, A, col_norms, out=None):
+        """
+        This method computes P = A @ Q or P = Q.T @ A and then ranks the columns/rows of the matrix P.
+        Once we determine the most important r indices, we can simply select them directly from P
+        without having to multiply again A @ Q[:, indices] or Q[indices, :] @ A.
+        This way, we save some computations.
+        """
+        # if self.steps == 1 or self.steps % self.update_proj_gap == 0:
+        if self.steps > 1:
+            if self.rotate_states:
+                self.indices_prev.copy_(self.indices_crt)
+
+        if self.sim_type == 'matmul':
+            S, norms = self.compute_similarity_matmul(Q, A)
+        else:
+            S, norms = self.compute_similarity_makhoul(A)
+
+        if self.use_th_sim:
+            norms.mul_(col_norms)
+
+        indices = torch.topk(
+            input=norms,
+            k=self.rank,
+            sorted=False,
+        ).indices
+
+        self.indices_crt.copy_(indices)
+        del indices, norms
+
+        # if self.sim_type == 'matmul':
+        if out is None:
+            if self.is_right_proj:
+                return S[:, self.indices_crt]
+            else:
+                return S[self.indices_crt, :]
+        else:
+            if self.is_right_proj:
+                out.copy_(S[:, self.indices_crt])
+            else:
+                out.copy_(S[self.indices_crt, :])
+        # elif self.sim_type == 'makhoul':
+        #     if out is None:
+        #         if self.is_right_proj:
+        #             return S[:, self.indices_crt]
+        #         else:
+        #             return S[:, self.indices_crt].T
+        #     else:
+        #         if self.is_right_proj:
+        #             out.copy_(S[:, self.indices_crt])
+        #         else:
+        #             out.copy_(S[:, self.indices_crt].T)
+        # else:
+        #     raise RuntimeError(f'Unknown similarity sim_type: {self.sim_type}')
+
+    def get_subspace_rotation_matrix(self, Q):
+        assert self.rotate_states, f'The optimizer was not initialized with rotate_subspace=True'
+
+        icrt = self.indices_crt
+        iprev = self.indices_prev
+
+        if self.is_right_proj:
+            return Q[:, iprev].T @ Q[:, icrt] # (m, r).T @ (m, r) = (r, r) # with Q from MatrixStorage @ PhD #11, page 44 (same as with Q from optimizer state @ PhD #11, page 47)
+            # return Q[iprev, :] @ Q[icrt, :].T # (r, m) @ (r, m).T = (r, r)
+        else:
+            # return Q[icrt, :] @ Q[iprev, :].T # (r, n) @ (r, n).T = (r, r) # with Q from MatrixStorage @ PhD #11, page 44
+            return Q[:, icrt].T @ Q[:, iprev] # (n, r).T @ (n, r) = (r, r) # with Q from optimizer state @ PhD #11, page 47
+            # return Q[:, icrt].T @ Q[:, iprev] # (n, r).T @ (n, r) = (r, r)
+
+    def rotate_subspace(self, R, w):
+        assert self.rotate_states, f'The optimizer was not initialized with rotate_subspace=True'
+        if self.is_right_proj:
+            torch.matmul(w, R, out=w)
+        else:
+            torch.matmul(R, w, out=w)
+
+    def from_higher_to_lower_dimensions(self, Q, X):
+        # Q = MatrixStorage.get_matrix(self.size, self.proj, X.dtype, transpose=not self.is_right_proj)
+
+        icrt = self.indices_crt
+
+        if self.is_right_proj:
+            return X @ Q[:, icrt] # (n, m) @ (m, r) = (n, r)
+        else:
+            # return Q[icrt, :] @ X # (r, n) @ (n, m) = (r, m) # with Q from MatrixStorage @ PhD #11, page 44
+            return Q[:, icrt].T @ X # (n, r).T @ (n, m) = (r, m) # with Q from optimizer state @ PhD #11, page 47
+
+    def from_lower_to_higher_dimensions(self, Q, x, out=None):
+        # Q = MatrixStorage.get_matrix(self.size, self.proj, x.dtype, transpose=not self.is_right_proj)
+        icrt = self.indices_crt
+
+        if self.is_right_proj:
+            # (n, r) @ (m, r).T = (n, m)
+            if out is None:
+                return x @ Q[:, icrt].T
+            else:
+                torch.matmul(x, Q[:, icrt].T, out=out)
+        else:
+            # (r, n).T @ (r, m) = (n, m)
+            if out is None:
+                # return Q[icrt, :].T @ x # with Q from MatrixStorage @ PhD #11, page 44
+                return Q[:, icrt] @ x # with Q from optimizer state @ PhD #11, page 47
+            else:
+                # torch.matmul(Q[icrt, :].T, x, out=out) # with Q from MatrixStorage @ PhD #11, page 44
+                torch.matmul(Q[:, icrt], x, out=out) # with Q from optimizer state @ PhD #11, page 47
+
+    # if self.strategy == STRATEGY_FIRST:
+    #     self.indices_crt.copy_(torch.arange(start=0, end=self.rank, dtype=torch.int32, device=self.device))
+    # elif self.strategy == STRATEGY_RANDOM:
+    #     self.indices_crt.copy_(torch.randperm(n=self.size, dtype=torch.int32, device=self.device)[:self.rank])
+    # elif self.strategy == STRATEGY_WINDOW:
+    #     """
+    #     For size=5, range2x will contain [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
+    #     For rank=3, the following indices will be generated:
+    #     step = 1: [0, 1, 2]
+    #     step = 2: [1, 2, 3]
+    #     step = 3: [2, 3, 4]
+    #     step = 4: [3, 4, 0]
+    #     step = 5: [4, 0, 1]
+    #     step = 6: [0, 1, 2] # here we have the same indices as for step 1 (the indices repeat once every `size` steps)
+    #     """
+    #     range2x = torch.arange(self.size, dtype=torch.int32, device=self.device).repeat(1, 2).view(-1)
+    #     start = self.steps % self.size
+    #     self.indices_crt.copy_(range2x[start:start+self.rank]) # rank indices, starting at "start"
+    #     del range2x
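A toy sketch (not from the package) of the right-projection path that the `change_subspace` docstring describes: compute S = A @ Q once, rank its columns by norm, slice the top-r columns directly from S instead of re-multiplying, and reconstruct with the transposed slice. Here `Q` is a random orthogonal stand-in for the package's DCT/Hadamard transforms, and the shapes are hypothetical.

    import torch

    n, m, r = 8, 6, 3  # toy shapes; right projection applies since n >= m

    Q, _ = torch.linalg.qr(torch.randn(m, m))  # orthogonal stand-in for dct3_matrix / hadamard_transform
    A = torch.randn(n, m)                      # a gradient-like accumulator

    S = A @ Q                                  # similarity matrix, as in compute_similarity_matmul
    idx = torch.topk(S.norm(p=1, dim=0), k=r, sorted=False).indices  # rank columns by ell-1 norm

    a = S[:, idx]                              # low-rank accumulator, (n, r); no second matmul needed
    A_rec = a @ Q[:, idx].T                    # from_lower_to_higher_dimensions, back to (n, m)
    E = A - A_rec                              # residual that the optimizer would keep as error feedback
    print(a.shape, A_rec.shape, E.norm())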