fusion-bench 0.2.5-py3-none-any.whl → 0.2.6-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- fusion_bench/compat/method/__init__.py +1 -0
- fusion_bench/compat/method/base_algorithm.py +0 -1
- fusion_bench/compat/modelpool/__init__.py +2 -1
- fusion_bench/dataset/arc_agi/__init__.py +6 -1
- fusion_bench/dataset/arc_agi/arc.py +21 -7
- fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
- fusion_bench/dataset/arc_agi/np_cache.py +0 -1
- fusion_bench/dataset/arc_agi/preprocess.py +50 -8
- fusion_bench/dataset/llama/collate.py +10 -3
- fusion_bench/method/__init__.py +3 -0
- fusion_bench/method/adamerging/__init__.py +1 -1
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +47 -5
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +58 -23
- fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/simple_average.py +1 -1
- fusion_bench/mixins/clip_classification.py +2 -7
- fusion_bench/mixins/lightning_fabric.py +2 -2
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/taskpool/__init__.py +10 -2
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/RECORD +36 -29
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +4 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +13 -7
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/top_level.txt +0 -0
fusion_bench/models/rankone_moe.py (new file, +410 lines)

@@ -0,0 +1,410 @@
+import functools
+import logging
+from typing import Dict, List, Tuple  # noqa: F401
+
+import torch
+import torch.func
+from torch import Tensor, nn
+from torch.func import functional_call
+from torch.nn import functional as F
+
+from fusion_bench.utils.type import StateDictType
+
+log = logging.getLogger(__name__)
+
+
+def join_list(list_of_list: List[List]):
+    ans = []
+    for l in list_of_list:
+        ans.extend(l)
+    return ans
+
+
+def del_attr(obj, names: List[str]):
+    """
+    Deletes an attribute from an object recursively.
+
+    Args:
+        obj (object): Object to delete attribute from.
+        names (list): List of attribute names to delete recursively.
+    """
+    if len(names) == 1:
+        delattr(obj, names[0])
+    else:
+        del_attr(getattr(obj, names[0]), names[1:])
+
+
+def set_attr(obj, names: List[str], val):
+    """
+    Sets an attribute of an object recursively.
+
+    Args:
+        obj (object): Object to set attribute of.
+        names (list): List of attribute names to set recursively.
+        val (object): Value to set the attribute to.
+    """
+    if len(names) == 1:
+        setattr(obj, names[0], val)
+    else:
+        set_attr(getattr(obj, names[0]), names[1:], val)
+
+
+def get_attr(obj, names: List[str]):
+    """
+    Gets an attribute of an object recursively.
+
+    Args:
+        obj (object): Object to get attribute of.
+        names (list): List of attribute names to get recursively.
+
+    Returns:
+        object: The attribute of the object.
+    """
+    if len(names) == 1:
+        return getattr(obj, names[0])
+    else:
+        return get_attr(getattr(obj, names[0]), names[1:])
+
+
+class Depth_0_Gate(nn.Module):
+    def __init__(self, num_experts: int):
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty(num_experts), requires_grad=True)
+
+    def init_weight(self, init_lambda: float):
+        nn.init.constant_(self.weight, init_lambda)
+
+    def forward(self, *args, **kwargs) -> Tensor:
+        return self.weight
+
+
+class Depth_1_Gate(nn.Module):
+    def __init__(self, hidden_size: int, num_experts: int):
+        super().__init__()
+        self.fc = nn.Linear(hidden_size, num_experts, bias=True)
+
+    def init_weight(self, init_lambda: float):
+        nn.init.normal_(self.fc.weight, std=0.01)
+        nn.init.constant_(self.fc.bias, init_lambda)
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        return self.fc(hidden_states)
+
+
+class Depth_2_Gate(nn.Module):
+    def __init__(self, hidden_size: int, num_experts: int):
+        super().__init__()
+        self.fc1 = nn.Linear(hidden_size, num_experts * 2, bias=True)
+        self.fc2 = nn.Linear(num_experts * 2, num_experts, bias=True)
+
+    def init_weight(self, init_lambda: float):
+        nn.init.normal_(self.fc1.weight, std=0.01)
+        nn.init.zeros_(self.fc1.bias)
+        nn.init.normal_(self.fc2.weight, std=0.01)
+        nn.init.constant_(self.fc2.bias, init_lambda)
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        hidden_states = F.relu(self.fc1(hidden_states))
+        return self.fc2(hidden_states)
+
+
+def construct_rankone_moe_gate(
+    hidden_size: int,
+    num_experts: int,
+    init_lambda: float,
+    num_hidden_layers: int = 2,
+):
+    if num_hidden_layers == 0:
+        gate = Depth_0_Gate(num_experts)
+    elif num_hidden_layers == 1:
+        gate = Depth_1_Gate(hidden_size, num_experts)
+    elif num_hidden_layers == 2:
+        gate = Depth_2_Gate(hidden_size, num_experts)
+    else:
+        raise ValueError(f"Unsupported number of hidden layers: {num_hidden_layers}")
+
+    gate.num_hidden_layers = num_hidden_layers
+    gate.init_weight(init_lambda)
+    return gate
+
+
+class ExpertNotTrainedError(Exception):
+    pass
+
+
+def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
+    """
+    Check if a tensor or a list of tensors are all zeros.
+    """
+    if isinstance(tensor, Tensor):
+        return torch.allclose(tensor, torch.zeros_like(tensor))
+    else:
+        return all(_is_all_zeros(t) for t in tensor)
+
+
+def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Perform Singular Value Decomposition (SVD) on a tensor.
+    """
+    u, s, vh = torch.linalg.svd(
+        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
+    )
+    v = vh.T
+    return u, s, v
+
+
+def svd(
+    w: Tensor, full_matrices=True, accelerator=None
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Perform SVD on a tensor, optionally using a specified accelerator.
+    """
+    if accelerator is None:
+        return _svd(w, full_matrices=full_matrices)
+    original_device = w.device
+    w = w.to(accelerator)
+    u, s, v = _svd(w)
+    return u.to(original_device), s.to(original_device), v.to(original_device)
+
+
+def fun_joint_svd(
+    w_list: List[Tensor], accelerator=None
+) -> Tuple[Tensor, Tensor, Tensor]:
+
+    w = torch.cat(w_list, dim=1)  # stacked_matrix
+    original_device = w.device
+    if accelerator is not None:
+        w = w.to(accelerator)
+    u_c, s_c, vh_c = torch.linalg.svd(
+        w, full_matrices=False, driver="gesvd" if w.is_cuda else None
+    )
+
+    svd_list = []
+    offset = 0
+    for matrix in w_list:
+        n_cols = matrix.size(1)
+        u = u_c
+        s = s_c
+        vh_ = vh_c[:, offset : offset + n_cols]
+        v = vh_.T
+        svd_list.append(
+            [u.to(original_device), s.to(original_device), v.to(original_device)]
+        )
+
+        offset += n_cols
+    return svd_list
+
+
+class RankOneMoE(nn.Module):
+    # variable to store the merged state dict temporarily
+    _merged_state_dict: StateDictType = None
+
+    def __init__(
+        self,
+        hidden_size: int,
+        base_model: nn.Module,
+        expert_models: List[nn.Module],
+        init_lambda: float = 0.2,
+        batch_first: bool = False,
+        router_hidden_layers: int = 2,
+        batch_reduce: bool = False,
+        svd_accelerator=False,
+        rank_k: int = -1,
+        select_k: int = -1,
+    ):
+        """
+        Initializes the RankOneMoE class.
+        https://github.com/EnnengYang/RankOne-MoE
+
+        Args:
+            hidden_size (int): The size of the hidden layer in the models.
+            base_model (nn.Module): The base model that will be used as a reference for the expert models.
+            expert_models (List[nn.Module]): A list of expert models that will be combined.
+            init_lambda (float, optional): The initial lambda value for the weight ensembling gate. Defaults to 0.2.
+            batch_first (bool, optional): If True, the input tensors are expected to have the batch size as the first dimension. Defaults to False.
+            router_hidden_layers (int, optional): The number of hidden layers in the router. Defaults to 2.
+            batch_reduce (bool): If True, the batch dimension of routing weights is reduced. Defaults to False.
+        """
+        super().__init__()
+        self.num_experts = len(expert_models)
+        self.hidden_size = hidden_size
+        self.batch_first = batch_first
+        self.batch_reduce = batch_reduce
+        self.svd_accelerator = svd_accelerator
+        self.rank_k = rank_k
+        self.select_k = select_k
+        self.init_lambda = init_lambda
+
+        self.gate = construct_rankone_moe_gate(
+            hidden_size=hidden_size,
+            num_experts=int(self.num_experts * self.rank_k),
+            init_lambda=init_lambda,
+            num_hidden_layers=router_hidden_layers,
+        )
+
+        # compute the task vectors
+        for name, param in base_model.named_parameters():
+            if not param.requires_grad:
+                for m in expert_models:
+                    del_attr(m, name.split("."))
+            else:
+                for m in expert_models:
+                    get_attr(m, name.split(".")).data = (
+                        get_attr(m, name.split(".")) - param
+                    )
+
+        # fix base model and expert models
+        self.base_model = base_model.requires_grad_(False)
+        for m in expert_models:
+            m.requires_grad_(False)
+
+        # task vecotr (only bias term)
+        self.task_vectors_fc1_bias = nn.Parameter(
+            torch.stack([e.fc1.bias for e in expert_models], dim=0), requires_grad=False
+        )
+        self.task_vectors_fc2_bias = nn.Parameter(
+            torch.stack([e.fc2.bias for e in expert_models], dim=0), requires_grad=False
+        )
+
+        # SVD representation of task vector (only weight term)
+        self.task_vectors_fc1_u = nn.ParameterList()
+        self.task_vectors_fc1_svh = nn.ParameterList()
+        self.task_vectors_fc2_u = nn.ParameterList()
+        self.task_vectors_fc2_svh = nn.ParameterList()
+
+        for m in expert_models:
+            for name, param in m.named_parameters():
+                if ".weight" in name:
+
+                    if _is_all_zeros(param):
+                        # All fine-tuned models are identical to the pretrained model
+                        raise ExpertNotTrainedError()
+
+                    u, s, v = svd(param, accelerator=self.svd_accelerator)
+                    u = u[:, : self.rank_k]
+                    s = s[: self.rank_k]
+                    v = v[:, : self.rank_k]
+
+                    if "fc1.weight" == name:
+                        self.task_vectors_fc1_u.append(
+                            nn.Parameter(u.T, requires_grad=False)
+                        )
+                        self.task_vectors_fc1_svh.append(
+                            nn.Parameter((s * v).T, requires_grad=False)
+                        )
+                    elif "fc2.weight" == name:
+                        self.task_vectors_fc2_u.append(
+                            nn.Parameter(u.T, requires_grad=False)
+                        )
+                        self.task_vectors_fc2_svh.append(
+                            nn.Parameter((s * v).T, requires_grad=False)
+                        )
+
+        # remove the original module from fine-tuned models to save memory
+        for name, param in base_model.named_parameters():
+            name_list = name.split(".")
+            for m in expert_models:
+                set_attr(m, name_list, None)
+
+    @property
+    def forward_model(self):
+        return functools.partial(
+            functional_call,
+            self.base_model,
+            self._merged_state_dict,
+        )
+
+    def top_k_soft(self, s, k):
+        threshold, _ = torch.topk(s, k, largest=True, sorted=False)
+        min_threshold = threshold.min()
+        # sigmoid -> mask
+        mask = torch.sigmoid(100 * (s - min_threshold))
+        result = s * mask
+        return result
+
+    def merge_weights(self, expert_weights):
+        state_dict = self.base_model.state_dict(keep_vars=True)
+
+        # Select top-k experts from the expert pool for fusion
+        if self.select_k > 0:
+            expert_weights = self.top_k_soft(expert_weights, self.select_k)
+
+        for name in state_dict:
+            if name == "fc1.bias":
+                for param in self.task_vectors_fc1_bias:
+                    state_dict[name] = state_dict[name] + self.init_lambda * param
+            elif name == "fc2.bias":
+                for param in self.task_vectors_fc2_bias:
+                    state_dict[name] = state_dict[name] + self.init_lambda * param
+
+            elif name == "fc1.weight":
+                w_list = torch.split(
+                    expert_weights,
+                    int(expert_weights.size(-1) / self.num_experts),
+                    dim=-1,
+                )
+                for weight, u, svh in zip(
+                    w_list, self.task_vectors_fc1_u, self.task_vectors_fc1_svh
+                ):
+                    weight_diag = torch.diag(weight)
+                    weight_u = torch.mm(weight_diag, u)
+                    result = torch.matmul(weight_u.T, svh)
+                    state_dict[name] = state_dict[name] + result
+
+            elif name == "fc2.weight":
+                w_list = torch.split(
+                    expert_weights,
+                    int(expert_weights.size(-1) / self.num_experts),
+                    dim=-1,
+                )
+                for weight, u, svh in zip(
+                    w_list, self.task_vectors_fc2_u, self.task_vectors_fc2_svh
+                ):
+                    weight_diag = torch.diag(weight)
+                    weight_u = torch.mm(weight_diag, u)
+                    result = torch.matmul(weight_u.T, svh)
+                    state_dict[name] = state_dict[name] + result
+
+        self._merged_state_dict = state_dict
+        return state_dict
+
+    def forward(self, hidden_states: Tensor):
+        if self.gate.num_hidden_layers == 0:
+            gate_weights = self.gate()
+        else:
+            gate_weights = self.gate(hidden_states)
+            if self.batch_first:
+                # the input is in the shape of (batch_size, seq_len, hidden_size)
+                gate_weights = gate_weights.mean(dim=1)
+            else:
+                # the input is in the shape of (seq_len, batch_size, hidden_size)
+                gate_weights = gate_weights.mean(dim=0)
+
+        if self.gate.num_hidden_layers == 0:
+            self.merge_weights(gate_weights)
+            output_hidden_states = self.forward_model(hidden_states)
+        elif self.batch_reduce:
+            gate_weights = gate_weights.mean(dim=0)
+            self.merge_weights(gate_weights)
+            output_hidden_states = self.forward_model(hidden_states)
+        else:
+            output_hidden_states = []
+            for sample_idx, weights in enumerate(gate_weights):
+                self.merge_weights(weights)
+                if self.batch_first:
+                    output_hidden_states.append(
+                        self.forward_model(hidden_states[sample_idx : sample_idx + 1])
+                    )
+                else:
+                    output_hidden_states.append(
+                        self.forward_model(
+                            hidden_states[:, sample_idx : sample_idx + 1]
+                        )
+                    )
+            if self.batch_first:
+                output_hidden_states = torch.cat(output_hidden_states, dim=0)
+            else:
+                output_hidden_states = torch.cat(output_hidden_states, dim=1)
+
+        self._merged_state_dict = None
+        return output_hidden_states
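For orientation, the sketch below shows one hypothetical way the new `RankOneMoE` module could be exercised in isolation. It is not taken from the package: `TinyMLP`, the dimensions, and the argument values are illustrative assumptions. The only requirements visible in the code above are that the base and expert models expose `fc1`/`fc2` linear layers and that the experts differ from the base model (otherwise `ExpertNotTrainedError` is raised).

```python
import torch
from torch import nn

from fusion_bench.models.rankone_moe import RankOneMoE


class TinyMLP(nn.Module):
    """Hypothetical stand-in for a transformer MLP block with fc1/fc2 layers."""

    def __init__(self, hidden_size: int = 32):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, 4 * hidden_size)
        self.fc2 = nn.Linear(4 * hidden_size, hidden_size)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


hidden_size = 32
base = TinyMLP(hidden_size)

# "Fine-tuned" experts: copies of the base model with perturbed weights, so the
# task vectors (expert - base) are non-zero.
experts = []
for _ in range(2):
    expert = TinyMLP(hidden_size)
    expert.load_state_dict(base.state_dict())
    with torch.no_grad():
        for p in expert.parameters():
            p.add_(0.01 * torch.randn_like(p))
    experts.append(expert)

moe = RankOneMoE(
    hidden_size=hidden_size,
    base_model=base,
    expert_models=experts,
    init_lambda=0.2,
    batch_first=True,      # inputs as (batch_size, seq_len, hidden_size)
    rank_k=4,              # keep the top-4 singular directions per expert
    select_k=-1,           # disable soft top-k selection over rank-one experts
    svd_accelerator=None,  # the svd() helper above checks `accelerator is None`
)

x = torch.randn(2, 8, hidden_size)
out = moe(x)               # per-sample routing and merging; shape (2, 8, hidden_size)
```

In the package itself, this module is wired in place of CLIP vision MLP blocks by the new `rankone_moe` method files and the task pool below, which asserts that each `layer.mlp` is a `RankOneMoE` instance.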
fusion_bench/taskpool/__init__.py

@@ -7,7 +7,11 @@ from fusion_bench.utils.lazy_imports import LazyImporter
 
 _import_structure = {
     "base_pool": ["BaseTaskPool"],
-    "clip_vision": [
+    "clip_vision": [
+        "CLIPVisionModelTaskPool",
+        "SparseWEMoECLIPVisionModelTaskPool",
+        "RankoneWEMoECLIPVisionModelTaskPool",
+    ],
     "dummy": ["DummyTaskPool"],
     "gpt2_text_classification": ["GPT2TextClassificationTaskPool"],
     "nyuv2_taskpool": ["NYUv2TaskPool"],
@@ -17,7 +21,11 @@ _import_structure = {
 
 if TYPE_CHECKING:
     from .base_pool import BaseTaskPool
-    from .clip_vision import
+    from .clip_vision import (
+        CLIPVisionModelTaskPool,
+        RankoneWEMoECLIPVisionModelTaskPool,
+        SparseWEMoECLIPVisionModelTaskPool,
+    )
     from .dummy import DummyTaskPool
     from .gpt2_text_classification import GPT2TextClassificationTaskPool
     from .llama import LlamaTestGenerationTaskPool
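With these entries registered in `_import_structure`, the new task pool class is resolved lazily through the `LazyImporter` on first attribute access. A minimal, hypothetical usage sketch (assuming the `clip_vision` subpackage re-exports the class under the name registered here):

```python
# The name below is resolved through the LazyImporter registered in
# fusion_bench.taskpool; its submodule is imported only when first accessed.
from fusion_bench.taskpool import RankoneWEMoECLIPVisionModelTaskPool

print(RankoneWEMoECLIPVisionModelTaskPool)
```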
fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py (new file, +112 lines)

@@ -0,0 +1,112 @@
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+from torch import Tensor
+from torch.utils.hooks import RemovableHandle
+from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
+from transformers.models.clip.modeling_clip import CLIPVisionTransformer
+
+from fusion_bench.models.hf_clip import HFCLIPClassifier
+from fusion_bench.models.rankone_moe import RankOneMoE
+
+from .taskpool import CLIPVisionModelTaskPool
+
+
+class LayerWiseRoutingWeightSaver:
+    def __init__(self, save_path: Path, max_num: Optional[int] = None):
+        self.save_path = save_path
+        self.max_num = max_num
+        self.routing_weights = []
+
+    def __call__(self, module, input: Tuple[Tensor], output: Tensor):
+        assert isinstance(output, Tensor), "Output is expected to be a Tensor"
+        # (batch_size, num_tokens, num_experts)
+        routing_weights = output.detach().cpu()
+        if self.max_num is not None and self.max_num > 0:
+            if len(self.routing_weights) > self.max_num:
+                return
+            elif routing_weights.size(0) + len(self.routing_weights) > self.max_num:
+                self.routing_weights.append(
+                    routing_weights[: self.max_num - len(self.routing_weights)]
+                )
+            else:
+                self.routing_weights.append(routing_weights)
+        else:
+            self.routing_weights.append(routing_weights)
+
+    def save_routing_weights(self):
+        routing_weights = torch.cat(self.routing_weights, dim=0)
+        if self.save_path is not None:
+            self.save_path.parent.mkdir(parents=True, exist_ok=True)
+            print(f"Saving routing weights to {self.save_path}")
+            torch.save(routing_weights, self.save_path)
+
+
+class RankoneMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
+
+    # hooks and handles for saving layer-wise routing weights
+    _layer_wise_routing_weights_save_hooks: Dict[Any, LayerWiseRoutingWeightSaver] = {}
+    _layer_wise_routing_weights_save_hook_handles: Dict[Any, RemovableHandle] = {}
+
+    _config_mapping = CLIPVisionModelTaskPool._config_mapping | {
+        "_layer_wise_routing_weights_save_path": "layer_wise_routing_weights_save_path",
+    }
+
+    def __init__(
+        self,
+        layer_wise_routing_weights_save_path: Optional[str],
+        layer_wise_routing_weights_max_num: Optional[int] = None,
+        **kwargs,
+    ):
+        # save path for layer-wise routing weights
+        self._layer_wise_routing_weights_save_path = (
+            layer_wise_routing_weights_save_path
+        )
+        self.layer_wise_routing_weights_save_path = (
+            Path(layer_wise_routing_weights_save_path)
+            if layer_wise_routing_weights_save_path is not None
+            else None
+        )
+        self.layer_wise_routing_weights_max_num = layer_wise_routing_weights_max_num
+        super().__init__(**kwargs)
+
+    def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
+        super().on_task_evaluation_begin(classifier, task_name)
+        if self.layer_wise_routing_weights_save_path is not None:
+            # setup hooks for saving layer-wise routing weights
+            assert isinstance(
+                classifier.clip_model.vision_model,
+                (CLIPVisionTransformer, CLIPVisionModel),
+            ), "Vision model is expected to be a CLIPVisionTransformer"
+            vision_model = classifier.clip_model.vision_model
+            if isinstance(vision_model, CLIPVisionModel):
+                vision_model = vision_model.vision_model
+            # assign forward hooks for each layer
+
+            for i, layer in enumerate(vision_model.encoder.layers):
+                mlp = layer.mlp
+                assert isinstance(
+                    mlp,
+                    (RankOneMoE),
+                ), f"MLP is expected to be a RankOneWeightEnsemblingMoE, but got {type(mlp)}"
+                # layer-wise routing weights
+                hook = LayerWiseRoutingWeightSaver(
+                    self.layer_wise_routing_weights_save_path
+                    / task_name
+                    / f"layer_{i}.pt",
+                    max_num=self.layer_wise_routing_weights_max_num,
+                )
+                self._layer_wise_routing_weights_save_hooks[i] = hook
+                self._layer_wise_routing_weights_save_hook_handles[i] = (
+                    mlp.gate.register_forward_hook(hook)
+                )
+
+    def on_task_evaluation_end(self):
+        super().on_task_evaluation_end()
+        if self.layer_wise_routing_weights_save_path is not None:
+            # remove hooks for saving layer-wise routing weights
+            for i, handle in self._layer_wise_routing_weights_save_hook_handles.items():
+                self._layer_wise_routing_weights_save_hooks[i].save_routing_weights()
+                handle.remove()
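As a small illustration of the hook mechanism used above, the sketch below attaches a `LayerWiseRoutingWeightSaver` to a plain `nn.Linear` standing in for `mlp.gate`. The gate stand-in, shapes, and output path are illustrative; the module path is assumed to follow the new file's location in the package layout.

```python
from pathlib import Path

import torch
from torch import nn

# Assumed import path, matching the new file's location shown above.
from fusion_bench.taskpool.clip_vision.clip_rankone_moe_taskpool import (
    LayerWiseRoutingWeightSaver,
)

gate = nn.Linear(768, 8)  # stand-in for mlp.gate; emits per-token routing weights
saver = LayerWiseRoutingWeightSaver(
    Path("outputs/routing_weights/example/layer_0.pt"), max_num=1000
)
handle = gate.register_forward_hook(saver)

_ = gate(torch.randn(4, 197, 768))  # the hook captures a (4, 197, 8) routing tensor
saver.save_routing_weights()        # concatenates captures and writes the .pt file
handle.remove()
```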
fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py

@@ -3,9 +3,10 @@ import os
 from typing import Optional
 
 from datasets import load_dataset, load_from_disk
+from omegaconf import DictConfig
 
 from fusion_bench.utils import instantiate, timeit_context
-
+
 from .glue_preprocessors import glue_processors
 from .glue_prompt_templates import glue_prompt_templates
 