fusion-bench 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
Files changed (36)
  1. fusion_bench/compat/method/__init__.py +1 -0
  2. fusion_bench/compat/method/base_algorithm.py +0 -1
  3. fusion_bench/compat/modelpool/__init__.py +2 -1
  4. fusion_bench/dataset/arc_agi/__init__.py +6 -1
  5. fusion_bench/dataset/arc_agi/arc.py +21 -7
  6. fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
  7. fusion_bench/dataset/arc_agi/np_cache.py +0 -1
  8. fusion_bench/dataset/arc_agi/preprocess.py +50 -8
  9. fusion_bench/dataset/llama/collate.py +10 -3
  10. fusion_bench/method/__init__.py +3 -0
  11. fusion_bench/method/adamerging/__init__.py +1 -1
  12. fusion_bench/method/lm_finetune/fullfinetune_sft.py +47 -5
  13. fusion_bench/method/lm_finetune/peftfinetune_sft.py +58 -23
  14. fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
  15. fusion_bench/method/rankone_moe/__init__.py +3 -0
  16. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  17. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  18. fusion_bench/method/simple_average.py +1 -1
  19. fusion_bench/mixins/clip_classification.py +2 -7
  20. fusion_bench/mixins/lightning_fabric.py +2 -2
  21. fusion_bench/models/rankone_moe.py +410 -0
  22. fusion_bench/taskpool/__init__.py +10 -2
  23. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  24. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  25. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
  26. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/METADATA +1 -1
  27. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/RECORD +36 -29
  28. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +4 -4
  29. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +13 -7
  30. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  31. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
  32. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  33. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/LICENSE +0 -0
  34. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/WHEEL +0 -0
  35. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/entry_points.txt +0 -0
  36. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/top_level.txt +0 -0
fusion_bench/models/rankone_moe.py
@@ -0,0 +1,410 @@
+import functools
+import logging
+from typing import Dict, List, Tuple  # noqa: F401
+
+import torch
+import torch.func
+from torch import Tensor, nn
+from torch.func import functional_call
+from torch.nn import functional as F
+
+from fusion_bench.utils.type import StateDictType
+
+log = logging.getLogger(__name__)
+
+
+def join_list(list_of_list: List[List]):
+    ans = []
+    for l in list_of_list:
+        ans.extend(l)
+    return ans
+
+
+def del_attr(obj, names: List[str]):
+    """
+    Deletes an attribute from an object recursively.
+
+    Args:
+        obj (object): Object to delete attribute from.
+        names (list): List of attribute names to delete recursively.
+    """
+    if len(names) == 1:
+        delattr(obj, names[0])
+    else:
+        del_attr(getattr(obj, names[0]), names[1:])
+
+
+def set_attr(obj, names: List[str], val):
+    """
+    Sets an attribute of an object recursively.
+
+    Args:
+        obj (object): Object to set attribute of.
+        names (list): List of attribute names to set recursively.
+        val (object): Value to set the attribute to.
+    """
+    if len(names) == 1:
+        setattr(obj, names[0], val)
+    else:
+        set_attr(getattr(obj, names[0]), names[1:], val)
+
+
+def get_attr(obj, names: List[str]):
+    """
+    Gets an attribute of an object recursively.
+
+    Args:
+        obj (object): Object to get attribute of.
+        names (list): List of attribute names to get recursively.
+
+    Returns:
+        object: The attribute of the object.
+    """
+    if len(names) == 1:
+        return getattr(obj, names[0])
+    else:
+        return get_attr(getattr(obj, names[0]), names[1:])
+
+
+class Depth_0_Gate(nn.Module):
+    def __init__(self, num_experts: int):
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty(num_experts), requires_grad=True)
+
+    def init_weight(self, init_lambda: float):
+        nn.init.constant_(self.weight, init_lambda)
+
+    def forward(self, *args, **kwargs) -> Tensor:
+        return self.weight
+
+
+class Depth_1_Gate(nn.Module):
+    def __init__(self, hidden_size: int, num_experts: int):
+        super().__init__()
+        self.fc = nn.Linear(hidden_size, num_experts, bias=True)
+
+    def init_weight(self, init_lambda: float):
+        nn.init.normal_(self.fc.weight, std=0.01)
+        nn.init.constant_(self.fc.bias, init_lambda)
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        return self.fc(hidden_states)
+
+
+class Depth_2_Gate(nn.Module):
+    def __init__(self, hidden_size: int, num_experts: int):
+        super().__init__()
+        self.fc1 = nn.Linear(hidden_size, num_experts * 2, bias=True)
+        self.fc2 = nn.Linear(num_experts * 2, num_experts, bias=True)
+
+    def init_weight(self, init_lambda: float):
+        nn.init.normal_(self.fc1.weight, std=0.01)
+        nn.init.zeros_(self.fc1.bias)
+        nn.init.normal_(self.fc2.weight, std=0.01)
+        nn.init.constant_(self.fc2.bias, init_lambda)
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        hidden_states = F.relu(self.fc1(hidden_states))
+        return self.fc2(hidden_states)
+
+
+def construct_rankone_moe_gate(
+    hidden_size: int,
+    num_experts: int,
+    init_lambda: float,
+    num_hidden_layers: int = 2,
+):
+    if num_hidden_layers == 0:
+        gate = Depth_0_Gate(num_experts)
+    elif num_hidden_layers == 1:
+        gate = Depth_1_Gate(hidden_size, num_experts)
+    elif num_hidden_layers == 2:
+        gate = Depth_2_Gate(hidden_size, num_experts)
+    else:
+        raise ValueError(f"Unsupported number of hidden layers: {num_hidden_layers}")
+
+    gate.num_hidden_layers = num_hidden_layers
+    gate.init_weight(init_lambda)
+    return gate
+
+
+class ExpertNotTrainedError(Exception):
+    pass
+
+
+def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
+    """
+    Check if a tensor or a list of tensors are all zeros.
+    """
+    if isinstance(tensor, Tensor):
+        return torch.allclose(tensor, torch.zeros_like(tensor))
+    else:
+        return all(_is_all_zeros(t) for t in tensor)
+
+
+def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Perform Singular Value Decomposition (SVD) on a tensor.
+    """
+    u, s, vh = torch.linalg.svd(
+        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
+    )
+    v = vh.T
+    return u, s, v
+
+
+def svd(
+    w: Tensor, full_matrices=True, accelerator=None
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Perform SVD on a tensor, optionally using a specified accelerator.
+    """
+    if accelerator is None:
+        return _svd(w, full_matrices=full_matrices)
+    original_device = w.device
+    w = w.to(accelerator)
+    u, s, v = _svd(w)
+    return u.to(original_device), s.to(original_device), v.to(original_device)
+
+
+def fun_joint_svd(
+    w_list: List[Tensor], accelerator=None
+) -> List[List[Tensor]]:
+
+    w = torch.cat(w_list, dim=1)  # stacked matrix
+    original_device = w.device
+    if accelerator is not None:
+        w = w.to(accelerator)
+    u_c, s_c, vh_c = torch.linalg.svd(
+        w, full_matrices=False, driver="gesvd" if w.is_cuda else None
+    )
+
+    svd_list = []
+    offset = 0
+    for matrix in w_list:
+        n_cols = matrix.size(1)
+        u = u_c
+        s = s_c
+        vh_ = vh_c[:, offset : offset + n_cols]
+        v = vh_.T
+        svd_list.append(
+            [u.to(original_device), s.to(original_device), v.to(original_device)]
+        )
+
+        offset += n_cols
+    return svd_list
+
+
+class RankOneMoE(nn.Module):
+    # variable to store the merged state dict temporarily
+    _merged_state_dict: StateDictType = None
+
+    def __init__(
+        self,
+        hidden_size: int,
+        base_model: nn.Module,
+        expert_models: List[nn.Module],
+        init_lambda: float = 0.2,
+        batch_first: bool = False,
+        router_hidden_layers: int = 2,
+        batch_reduce: bool = False,
+        svd_accelerator=False,
+        rank_k: int = -1,
+        select_k: int = -1,
+    ):
+        """
+        Initializes the RankOneMoE class.
+        https://github.com/EnnengYang/RankOne-MoE
+
+        Args:
+            hidden_size (int): The size of the hidden layer in the models.
+            base_model (nn.Module): The base model that will be used as a reference for the expert models.
+            expert_models (List[nn.Module]): A list of expert models that will be combined.
+            init_lambda (float, optional): The initial lambda value for the weight ensembling gate. Defaults to 0.2.
+            batch_first (bool, optional): If True, the input tensors are expected to have the batch size as the first dimension. Defaults to False.
+            router_hidden_layers (int, optional): The number of hidden layers in the router. Defaults to 2.
+            batch_reduce (bool): If True, the batch dimension of routing weights is reduced. Defaults to False.
+        """
+        super().__init__()
+        self.num_experts = len(expert_models)
+        self.hidden_size = hidden_size
+        self.batch_first = batch_first
+        self.batch_reduce = batch_reduce
+        self.svd_accelerator = svd_accelerator
+        self.rank_k = rank_k
+        self.select_k = select_k
+        self.init_lambda = init_lambda
+
+        self.gate = construct_rankone_moe_gate(
+            hidden_size=hidden_size,
+            num_experts=int(self.num_experts * self.rank_k),
+            init_lambda=init_lambda,
+            num_hidden_layers=router_hidden_layers,
+        )
+
+        # compute the task vectors
+        for name, param in base_model.named_parameters():
+            if not param.requires_grad:
+                for m in expert_models:
+                    del_attr(m, name.split("."))
+            else:
+                for m in expert_models:
+                    get_attr(m, name.split(".")).data = (
+                        get_attr(m, name.split(".")) - param
+                    )
+
+        # fix base model and expert models
+        self.base_model = base_model.requires_grad_(False)
+        for m in expert_models:
+            m.requires_grad_(False)
+
+        # task vector (only bias term)
+        self.task_vectors_fc1_bias = nn.Parameter(
+            torch.stack([e.fc1.bias for e in expert_models], dim=0), requires_grad=False
+        )
+        self.task_vectors_fc2_bias = nn.Parameter(
+            torch.stack([e.fc2.bias for e in expert_models], dim=0), requires_grad=False
+        )
+
+        # SVD representation of task vector (only weight term)
+        self.task_vectors_fc1_u = nn.ParameterList()
+        self.task_vectors_fc1_svh = nn.ParameterList()
+        self.task_vectors_fc2_u = nn.ParameterList()
+        self.task_vectors_fc2_svh = nn.ParameterList()
+
+        for m in expert_models:
+            for name, param in m.named_parameters():
+                if ".weight" in name:
+
+                    if _is_all_zeros(param):
+                        # All fine-tuned models are identical to the pretrained model
+                        raise ExpertNotTrainedError()
+
+                    u, s, v = svd(param, accelerator=self.svd_accelerator)
+                    u = u[:, : self.rank_k]
+                    s = s[: self.rank_k]
+                    v = v[:, : self.rank_k]
+
+                    if "fc1.weight" == name:
+                        self.task_vectors_fc1_u.append(
+                            nn.Parameter(u.T, requires_grad=False)
+                        )
+                        self.task_vectors_fc1_svh.append(
+                            nn.Parameter((s * v).T, requires_grad=False)
+                        )
+                    elif "fc2.weight" == name:
+                        self.task_vectors_fc2_u.append(
+                            nn.Parameter(u.T, requires_grad=False)
+                        )
+                        self.task_vectors_fc2_svh.append(
+                            nn.Parameter((s * v).T, requires_grad=False)
+                        )
+
+        # remove the original module from fine-tuned models to save memory
+        for name, param in base_model.named_parameters():
+            name_list = name.split(".")
+            for m in expert_models:
+                set_attr(m, name_list, None)
+
+    @property
+    def forward_model(self):
+        return functools.partial(
+            functional_call,
+            self.base_model,
+            self._merged_state_dict,
+        )
+
+    def top_k_soft(self, s, k):
+        threshold, _ = torch.topk(s, k, largest=True, sorted=False)
+        min_threshold = threshold.min()
+        # sigmoid -> mask
+        mask = torch.sigmoid(100 * (s - min_threshold))
+        result = s * mask
+        return result
+
+    def merge_weights(self, expert_weights):
+        state_dict = self.base_model.state_dict(keep_vars=True)
+
+        # Select top-k experts from the expert pool for fusion
+        if self.select_k > 0:
+            expert_weights = self.top_k_soft(expert_weights, self.select_k)
+
+        for name in state_dict:
+            if name == "fc1.bias":
+                for param in self.task_vectors_fc1_bias:
+                    state_dict[name] = state_dict[name] + self.init_lambda * param
+            elif name == "fc2.bias":
+                for param in self.task_vectors_fc2_bias:
+                    state_dict[name] = state_dict[name] + self.init_lambda * param
+
+            elif name == "fc1.weight":
+                w_list = torch.split(
+                    expert_weights,
+                    int(expert_weights.size(-1) / self.num_experts),
+                    dim=-1,
+                )
+                for weight, u, svh in zip(
+                    w_list, self.task_vectors_fc1_u, self.task_vectors_fc1_svh
+                ):
+                    weight_diag = torch.diag(weight)
+                    weight_u = torch.mm(weight_diag, u)
+                    result = torch.matmul(weight_u.T, svh)
+                    state_dict[name] = state_dict[name] + result
+
+            elif name == "fc2.weight":
+                w_list = torch.split(
+                    expert_weights,
+                    int(expert_weights.size(-1) / self.num_experts),
+                    dim=-1,
+                )
+                for weight, u, svh in zip(
+                    w_list, self.task_vectors_fc2_u, self.task_vectors_fc2_svh
+                ):
+                    weight_diag = torch.diag(weight)
+                    weight_u = torch.mm(weight_diag, u)
+                    result = torch.matmul(weight_u.T, svh)
+                    state_dict[name] = state_dict[name] + result
+
+        self._merged_state_dict = state_dict
+        return state_dict
+
+    def forward(self, hidden_states: Tensor):
+        if self.gate.num_hidden_layers == 0:
+            gate_weights = self.gate()
+        else:
+            gate_weights = self.gate(hidden_states)
+            if self.batch_first:
+                # the input is in the shape of (batch_size, seq_len, hidden_size)
+                gate_weights = gate_weights.mean(dim=1)
+            else:
+                # the input is in the shape of (seq_len, batch_size, hidden_size)
+                gate_weights = gate_weights.mean(dim=0)
+
+        if self.gate.num_hidden_layers == 0:
+            self.merge_weights(gate_weights)
+            output_hidden_states = self.forward_model(hidden_states)
+        elif self.batch_reduce:
+            gate_weights = gate_weights.mean(dim=0)
+            self.merge_weights(gate_weights)
+            output_hidden_states = self.forward_model(hidden_states)
+        else:
+            output_hidden_states = []
+            for sample_idx, weights in enumerate(gate_weights):
+                self.merge_weights(weights)
+                if self.batch_first:
+                    output_hidden_states.append(
+                        self.forward_model(hidden_states[sample_idx : sample_idx + 1])
+                    )
+                else:
+                    output_hidden_states.append(
+                        self.forward_model(
+                            hidden_states[:, sample_idx : sample_idx + 1]
+                        )
+                    )
+            if self.batch_first:
+                output_hidden_states = torch.cat(output_hidden_states, dim=0)
+            else:
+                output_hidden_states = torch.cat(output_hidden_states, dim=1)
+
+        self._merged_state_dict = None
+        return output_hidden_states
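
The RankOneMoE layer above decomposes each expert's task vector (fine-tuned weights minus base weights) into rank-one SVD components and lets a learned gate weight those components when merging them back into the base fc1/fc2 weights at forward time. Below is a hedged, minimal usage sketch, not part of the diff: the MLP class, sizes, and rank_k/select_k values are hypothetical; only the fc1/fc2 attribute layout is what RankOneMoE.__init__ expects.

# Hypothetical usage sketch: wrap a toy two-layer MLP and two "fine-tuned" copies
# in a RankOneMoE layer.
import torch
from torch import nn

from fusion_bench.models.rankone_moe import RankOneMoE


class MLP(nn.Module):
    def __init__(self, hidden_size: int = 16, intermediate_size: int = 32):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


base = MLP()
experts = [MLP(), MLP()]  # stand-ins for fine-tuned copies of `base`

moe = RankOneMoE(
    hidden_size=16,
    base_model=base,
    expert_models=experts,
    batch_first=True,      # inputs are (batch, seq_len, hidden_size)
    rank_k=4,              # keep the top-4 rank-one components per expert weight
    select_k=2,            # softly select 2 components when merging
    svd_accelerator=None,  # None keeps the SVD on the parameters' own device
)

x = torch.randn(2, 5, 16)
print(moe(x).shape)  # torch.Size([2, 5, 16])

In the real benchmark this wrapping is done per CLIP MLP block by the rankone_moe method (see fusion_bench/method/rankone_moe/ in the file list), with hidden_size set to the transformer hidden dimension.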
fusion_bench/taskpool/__init__.py
@@ -7,7 +7,11 @@ from fusion_bench.utils.lazy_imports import LazyImporter
 
 _import_structure = {
     "base_pool": ["BaseTaskPool"],
-    "clip_vision": ["CLIPVisionModelTaskPool", "SparseWEMoECLIPVisionModelTaskPool"],
+    "clip_vision": [
+        "CLIPVisionModelTaskPool",
+        "SparseWEMoECLIPVisionModelTaskPool",
+        "RankoneWEMoECLIPVisionModelTaskPool",
+    ],
     "dummy": ["DummyTaskPool"],
     "gpt2_text_classification": ["GPT2TextClassificationTaskPool"],
     "nyuv2_taskpool": ["NYUv2TaskPool"],
@@ -17,7 +21,11 @@ _import_structure = {
 
 if TYPE_CHECKING:
     from .base_pool import BaseTaskPool
-    from .clip_vision import CLIPVisionModelTaskPool, SparseWEMoECLIPVisionModelTaskPool
+    from .clip_vision import (
+        CLIPVisionModelTaskPool,
+        RankoneWEMoECLIPVisionModelTaskPool,
+        SparseWEMoECLIPVisionModelTaskPool,
+    )
     from .dummy import DummyTaskPool
     from .gpt2_text_classification import GPT2TextClassificationTaskPool
     from .llama import LlamaTestGenerationTaskPool
fusion_bench/taskpool/clip_vision/__init__.py
@@ -1,3 +1,4 @@
 # flake8: noqa F401
+from .clip_rankone_moe_taskpool import RankoneMoECLIPVisionModelTaskPool
 from .clip_sparse_wemoe_taskpool import SparseWEMoECLIPVisionModelTaskPool
 from .taskpool import CLIPVisionModelTaskPool
fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py
@@ -0,0 +1,112 @@
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+from torch import Tensor
+from torch.utils.hooks import RemovableHandle
+from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
+from transformers.models.clip.modeling_clip import CLIPVisionTransformer
+
+from fusion_bench.models.hf_clip import HFCLIPClassifier
+from fusion_bench.models.rankone_moe import RankOneMoE
+
+from .taskpool import CLIPVisionModelTaskPool
+
+
+class LayerWiseRoutingWeightSaver:
+    def __init__(self, save_path: Path, max_num: Optional[int] = None):
+        self.save_path = save_path
+        self.max_num = max_num
+        self.routing_weights = []
+
+    def __call__(self, module, input: Tuple[Tensor], output: Tensor):
+        assert isinstance(output, Tensor), "Output is expected to be a Tensor"
+        # (batch_size, num_tokens, num_experts)
+        routing_weights = output.detach().cpu()
+        if self.max_num is not None and self.max_num > 0:
+            if len(self.routing_weights) > self.max_num:
+                return
+            elif routing_weights.size(0) + len(self.routing_weights) > self.max_num:
+                self.routing_weights.append(
+                    routing_weights[: self.max_num - len(self.routing_weights)]
+                )
+            else:
+                self.routing_weights.append(routing_weights)
+        else:
+            self.routing_weights.append(routing_weights)
+
+    def save_routing_weights(self):
+        routing_weights = torch.cat(self.routing_weights, dim=0)
+        if self.save_path is not None:
+            self.save_path.parent.mkdir(parents=True, exist_ok=True)
+            print(f"Saving routing weights to {self.save_path}")
+            torch.save(routing_weights, self.save_path)
+
+
+class RankoneMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
+
+    # hooks and handles for saving layer-wise routing weights
+    _layer_wise_routing_weights_save_hooks: Dict[Any, LayerWiseRoutingWeightSaver] = {}
+    _layer_wise_routing_weights_save_hook_handles: Dict[Any, RemovableHandle] = {}
+
+    _config_mapping = CLIPVisionModelTaskPool._config_mapping | {
+        "_layer_wise_routing_weights_save_path": "layer_wise_routing_weights_save_path",
+    }
+
+    def __init__(
+        self,
+        layer_wise_routing_weights_save_path: Optional[str],
+        layer_wise_routing_weights_max_num: Optional[int] = None,
+        **kwargs,
+    ):
+        # save path for layer-wise routing weights
+        self._layer_wise_routing_weights_save_path = (
+            layer_wise_routing_weights_save_path
+        )
+        self.layer_wise_routing_weights_save_path = (
+            Path(layer_wise_routing_weights_save_path)
+            if layer_wise_routing_weights_save_path is not None
+            else None
+        )
+        self.layer_wise_routing_weights_max_num = layer_wise_routing_weights_max_num
+        super().__init__(**kwargs)
+
+    def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
+        super().on_task_evaluation_begin(classifier, task_name)
+        if self.layer_wise_routing_weights_save_path is not None:
+            # setup hooks for saving layer-wise routing weights
+            assert isinstance(
+                classifier.clip_model.vision_model,
+                (CLIPVisionTransformer, CLIPVisionModel),
+            ), "Vision model is expected to be a CLIPVisionTransformer or CLIPVisionModel"
+            vision_model = classifier.clip_model.vision_model
+            if isinstance(vision_model, CLIPVisionModel):
+                vision_model = vision_model.vision_model
+            # assign forward hooks for each layer
+
+            for i, layer in enumerate(vision_model.encoder.layers):
+                mlp = layer.mlp
+                assert isinstance(
+                    mlp,
+                    (RankOneMoE),
+                ), f"MLP is expected to be a RankOneMoE, but got {type(mlp)}"
+                # layer-wise routing weights
+                hook = LayerWiseRoutingWeightSaver(
+                    self.layer_wise_routing_weights_save_path
+                    / task_name
+                    / f"layer_{i}.pt",
+                    max_num=self.layer_wise_routing_weights_max_num,
+                )
+                self._layer_wise_routing_weights_save_hooks[i] = hook
+                self._layer_wise_routing_weights_save_hook_handles[i] = (
+                    mlp.gate.register_forward_hook(hook)
+                )
+
+    def on_task_evaluation_end(self):
+        super().on_task_evaluation_end()
+        if self.layer_wise_routing_weights_save_path is not None:
+            # remove hooks for saving layer-wise routing weights
+            for i, handle in self._layer_wise_routing_weights_save_hook_handles.items():
+                self._layer_wise_routing_weights_save_hooks[i].save_routing_weights()
+                handle.remove()
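
The taskpool above records per-layer routing weights by registering LayerWiseRoutingWeightSaver as a forward hook on each RankOneMoE gate and writing the collected tensors to disk once evaluation of a task ends. The following self-contained sketch illustrates that hook mechanism in isolation, using a plain nn.Linear as a stand-in gate; the output path, sizes, and max_num value are illustrative assumptions, not taken from the package.

# Hypothetical illustration of the routing-weight hook used by the taskpool.
from pathlib import Path

import torch
from torch import nn

from fusion_bench.taskpool.clip_vision.clip_rankone_moe_taskpool import (
    LayerWiseRoutingWeightSaver,
)

gate = nn.Linear(16, 8)  # stand-in for a RankOneMoE routing gate
saver = LayerWiseRoutingWeightSaver(Path("outputs/routing/layer_0.pt"), max_num=64)
handle = gate.register_forward_hook(saver)  # called after every gate forward

for _ in range(3):
    gate(torch.randn(4, 5, 16))  # (batch, tokens, hidden) -> routing weights

handle.remove()
saver.save_routing_weights()  # concatenates along dim 0 and writes the .pt file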
fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py
@@ -3,9 +3,10 @@ import os
 from typing import Optional
 
 from datasets import load_dataset, load_from_disk
+from omegaconf import DictConfig
 
 from fusion_bench.utils import instantiate, timeit_context
-from omegaconf import DictConfig
+
 
 from .glue_preprocessors import glue_processors
 from .glue_prompt_templates import glue_prompt_templates
{fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: fusion_bench
-Version: 0.2.5
+Version: 0.2.6
 Summary: A Comprehensive Benchmark of Deep Model Fusion
 Author-email: Anke Tang <tang.anke@foxmail.com>
 License: MIT License