fusion-bench 0.2.5-py3-none-any.whl → 0.2.7-py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
Files changed (105)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -2
  3. fusion_bench/compat/modelpool/__init__.py +3 -2
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/__init__.py +6 -1
  6. fusion_bench/dataset/arc_agi/arc.py +26 -7
  7. fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
  8. fusion_bench/dataset/arc_agi/np_cache.py +0 -1
  9. fusion_bench/dataset/arc_agi/preprocess.py +51 -9
  10. fusion_bench/dataset/llama/__init__.py +1 -0
  11. fusion_bench/dataset/llama/alpaca.py +93 -3
  12. fusion_bench/dataset/llama/collate.py +72 -5
  13. fusion_bench/dataset/llama/metamathqa.py +50 -0
  14. fusion_bench/dataset/llama/preference_700k.py +70 -0
  15. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  16. fusion_bench/dataset/llama/ultrachat.py +58 -0
  17. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  18. fusion_bench/method/__init__.py +4 -1
  19. fusion_bench/method/adamerging/__init__.py +1 -1
  20. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  21. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  22. fusion_bench/method/linear/expo.py +39 -0
  23. fusion_bench/method/lm_finetune/__init__.py +1 -0
  24. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  25. fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
  26. fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
  27. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  28. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  29. fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
  30. fusion_bench/method/rankone_moe/__init__.py +3 -0
  31. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  32. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  33. fusion_bench/method/simple_average.py +1 -1
  34. fusion_bench/method/surgery/__init__.py +3 -0
  35. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  36. fusion_bench/mixins/__init__.py +2 -0
  37. fusion_bench/mixins/clip_classification.py +60 -12
  38. fusion_bench/mixins/fabric_training.py +320 -0
  39. fusion_bench/mixins/lightning_fabric.py +11 -2
  40. fusion_bench/modelpool/__init__.py +2 -0
  41. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  42. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  43. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  44. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  45. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  46. fusion_bench/models/chat_templates/__init__.py +1 -0
  47. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  48. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  49. fusion_bench/models/hf_clip.py +50 -9
  50. fusion_bench/models/rankone_moe.py +410 -0
  51. fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
  52. fusion_bench/models/utils.py +8 -0
  53. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  54. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  55. fusion_bench/optim/__init__.py +2 -0
  56. fusion_bench/optim/exception.py +47 -0
  57. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  58. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  59. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  60. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  61. fusion_bench/optim/mezo.py +0 -2
  62. fusion_bench/programs/fabric_fusion_program.py +5 -1
  63. fusion_bench/taskpool/__init__.py +10 -2
  64. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  65. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  66. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  67. fusion_bench/taskpool/llama/reward_model.py +157 -0
  68. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  69. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
  70. fusion_bench/utils/hydra_utils.py +22 -0
  71. fusion_bench/utils/plot/__init__.py +0 -0
  72. fusion_bench/utils/plot/token.py +52 -0
  73. fusion_bench/utils/plot/token_notebook.py +127 -0
  74. fusion_bench/utils/type.py +5 -3
  75. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
  76. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
  77. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  78. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  79. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  80. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  81. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  84. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  85. fusion_bench_config/llama_full_finetune.yaml +19 -0
  86. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  87. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
  88. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
  89. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  90. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
  91. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  92. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  93. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  94. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  95. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  96. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  97. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  98. fusion_bench_config/nyuv2_config.yaml +5 -1
  99. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  100. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  101. fusion_bench_config/llama_weighted_average.yaml +0 -26
  102. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
  103. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
  104. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
  105. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
fusion_bench/models/hf_clip.py
@@ -1,4 +1,5 @@
- from typing import Callable, Iterable, List  # noqa: F401
+ import logging
+ from typing import TYPE_CHECKING, Callable, Iterable, List  # noqa: F401

  import torch
  from torch import Tensor, nn
@@ -7,6 +8,11 @@ from transformers.models.clip.modeling_clip import BaseModelOutputWithPooling

  from fusion_bench.utils.devices import get_device

+ if TYPE_CHECKING:
+     from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
+
+ log = logging.getLogger(__name__)
+
  default_templates = [
      lambda c: f"a photo of a {c}",
  ]
@@ -33,6 +39,7 @@ class HFCLIPClassifier(nn.Module):
          self,
          clip_model: CLIPModel,
          processor: CLIPProcessor,
+         extra_module=None,
      ):
          """
          Initialize the HFCLIPClassifier.
@@ -56,6 +63,8 @@ class HFCLIPClassifier(nn.Module):
              persistent=False,
          )

+         self.extra_module = extra_module
+
      @property
      def text_model(self):
          """Get the text model component of CLIP."""
@@ -111,7 +120,13 @@ class HFCLIPClassifier(nn.Module):

          self.zeroshot_weights = zeroshot_weights

-     def forward(self, images, return_image_embeds=False, return_dict=False):
+     def forward(
+         self,
+         images: Tensor,
+         return_image_embeds=False,
+         return_dict=False,
+         task_name=None,
+     ):
          """
          Perform forward pass for zero-shot image classification.

@@ -120,6 +135,9 @@ class HFCLIPClassifier(nn.Module):

          Args:
              images (Tensor): Input images to classify.
+             return_image_embeds (bool): Whether to return the image embeddings.
+             return_dict (bool): Whether to return a dictionary with logits and image embeddings.
+             task_name (Optional[str]): The name of the task.

          Returns:
              Tensor: Classification logits for each input image.
@@ -131,16 +149,22 @@ class HFCLIPClassifier(nn.Module):
              raise ValueError("Must set classification task before forward pass")
          text_embeds = self.zeroshot_weights

-         image_embeds = self.vision_model(images)
-         if isinstance(image_embeds, Tensor):
-             pass
-         elif isinstance(image_embeds, BaseModelOutputWithPooling):
-             image_embeds = image_embeds[1]
-             image_embeds = self.clip_model.visual_projection(image_embeds)
-
+         image_embeds = self.get_image_features(images)
          # normalize embeddings
          image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

+         if (
+             hasattr(self.vision_model, "is_surgery_model")
+             and self.vision_model.is_surgery_model
+         ):
+             # Dealing with the surgery model; for more details, please refer to:
+             # (ICML 2024) Yang, et al. Representation Surgery for Multi-Task Model Merging
+             # https://arxiv.org/abs/2402.02705
+             self.vision_model: "SurgeryModelWrapper" = self.vision_model
+             image_embeds, _, _ = self.vision_model.compute_surgery_features(
+                 image_embeds, dataset_name=task_name
+             )
+
          # cosine similarity
          logit_scale = self.clip_model.logit_scale.exp()
          logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
@@ -156,3 +180,20 @@ class HFCLIPClassifier(nn.Module):
              return logits_per_image, image_embeds
          else:
              return logits_per_image
+
+     def get_image_features(self, images: Tensor) -> Tensor:
+         """
+         Compute the image embeddings.
+
+         Returns:
+             image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained
+                 by applying the projection layer to the pooled output of [`CLIPVisionModel`].
+         """
+
+         image_embeds = self.vision_model(images)
+         if isinstance(image_embeds, Tensor):
+             pass
+         elif isinstance(image_embeds, BaseModelOutputWithPooling):
+             image_embeds = image_embeds[1]
+             image_embeds = self.clip_model.visual_projection(image_embeds)
+         return image_embeds
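
For orientation, a minimal sketch of how the updated forward signature is exercised. The checkpoint name, the class labels, and the set_classification_task call are illustrative assumptions, not part of this diff:

import torch
from transformers import CLIPModel, CLIPProcessor

from fusion_bench.models.hf_clip import HFCLIPClassifier

# Hypothetical checkpoint; any CLIP checkpoint should work here.
model_name = "openai/clip-vit-base-patch32"
classifier = HFCLIPClassifier(
    CLIPModel.from_pretrained(model_name),
    CLIPProcessor.from_pretrained(model_name),
)
# Assumed existing API for building the zero-shot text weights.
classifier.set_classification_task(["cat", "dog"])

images = torch.randn(4, 3, 224, 224)  # dummy pixel-value batch
# task_name is only consulted when vision_model is a SurgeryModelWrapper;
# for a plain CLIP vision tower it is ignored.
logits = classifier(images, task_name="cifar10")
print(logits.shape)  # torch.Size([4, 2])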
fusion_bench/models/rankone_moe.py (new file)
@@ -0,0 +1,410 @@
+ import functools
+ import logging
+ from typing import Dict, List, Tuple  # noqa: F401
+
+ import torch
+ import torch.func
+ from torch import Tensor, nn
+ from torch.func import functional_call
+ from torch.nn import functional as F
+
+ from fusion_bench.utils.type import StateDictType
+
+ log = logging.getLogger(__name__)
+
+
+ def join_list(list_of_list: List[List]):
+     ans = []
+     for l in list_of_list:
+         ans.extend(l)
+     return ans
+
+
+ def del_attr(obj, names: List[str]):
+     """
+     Deletes an attribute from an object recursively.
+
+     Args:
+         obj (object): Object to delete attribute from.
+         names (list): List of attribute names to delete recursively.
+     """
+     if len(names) == 1:
+         delattr(obj, names[0])
+     else:
+         del_attr(getattr(obj, names[0]), names[1:])
+
+
+ def set_attr(obj, names: List[str], val):
+     """
+     Sets an attribute of an object recursively.
+
+     Args:
+         obj (object): Object to set attribute of.
+         names (list): List of attribute names to set recursively.
+         val (object): Value to set the attribute to.
+     """
+     if len(names) == 1:
+         setattr(obj, names[0], val)
+     else:
+         set_attr(getattr(obj, names[0]), names[1:], val)
+
+
+ def get_attr(obj, names: List[str]):
+     """
+     Gets an attribute of an object recursively.
+
+     Args:
+         obj (object): Object to get attribute of.
+         names (list): List of attribute names to get recursively.
+
+     Returns:
+         object: The attribute of the object.
+     """
+     if len(names) == 1:
+         return getattr(obj, names[0])
+     else:
+         return get_attr(getattr(obj, names[0]), names[1:])
+
+
+ class Depth_0_Gate(nn.Module):
+     def __init__(self, num_experts: int):
+         super().__init__()
+         self.weight = nn.Parameter(torch.empty(num_experts), requires_grad=True)
+
+     def init_weight(self, init_lambda: float):
+         nn.init.constant_(self.weight, init_lambda)
+
+     def forward(self, *args, **kwargs) -> Tensor:
+         return self.weight
+
+
+ class Depth_1_Gate(nn.Module):
+     def __init__(self, hidden_size: int, num_experts: int):
+         super().__init__()
+         self.fc = nn.Linear(hidden_size, num_experts, bias=True)
+
+     def init_weight(self, init_lambda: float):
+         nn.init.normal_(self.fc.weight, std=0.01)
+         nn.init.constant_(self.fc.bias, init_lambda)
+
+     def forward(self, hidden_states: Tensor) -> Tensor:
+         return self.fc(hidden_states)
+
+
+ class Depth_2_Gate(nn.Module):
+     def __init__(self, hidden_size: int, num_experts: int):
+         super().__init__()
+         self.fc1 = nn.Linear(hidden_size, num_experts * 2, bias=True)
+         self.fc2 = nn.Linear(num_experts * 2, num_experts, bias=True)
+
+     def init_weight(self, init_lambda: float):
+         nn.init.normal_(self.fc1.weight, std=0.01)
+         nn.init.zeros_(self.fc1.bias)
+         nn.init.normal_(self.fc2.weight, std=0.01)
+         nn.init.constant_(self.fc2.bias, init_lambda)
+
+     def forward(self, hidden_states: Tensor) -> Tensor:
+         hidden_states = F.relu(self.fc1(hidden_states))
+         return self.fc2(hidden_states)
+
+
+ def construct_rankone_moe_gate(
+     hidden_size: int,
+     num_experts: int,
+     init_lambda: float,
+     num_hidden_layers: int = 2,
+ ):
+     if num_hidden_layers == 0:
+         gate = Depth_0_Gate(num_experts)
+     elif num_hidden_layers == 1:
+         gate = Depth_1_Gate(hidden_size, num_experts)
+     elif num_hidden_layers == 2:
+         gate = Depth_2_Gate(hidden_size, num_experts)
+     else:
+         raise ValueError(f"Unsupported number of hidden layers: {num_hidden_layers}")
+
+     gate.num_hidden_layers = num_hidden_layers
+     gate.init_weight(init_lambda)
+     return gate
+
+
+ class ExpertNotTrainedError(Exception):
+     pass
+
+
+ def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
+     """
+     Check if a tensor or a list of tensors are all zeros.
+     """
+     if isinstance(tensor, Tensor):
+         return torch.allclose(tensor, torch.zeros_like(tensor))
+     else:
+         return all(_is_all_zeros(t) for t in tensor)
+
+
+ def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
+     """
+     Perform Singular Value Decomposition (SVD) on a tensor.
+     """
+     u, s, vh = torch.linalg.svd(
+         w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
+     )
+     v = vh.T
+     return u, s, v
+
+
+ def svd(
+     w: Tensor, full_matrices=True, accelerator=None
+ ) -> Tuple[Tensor, Tensor, Tensor]:
+     """
+     Perform SVD on a tensor, optionally using a specified accelerator.
+     """
+     if accelerator is None:
+         return _svd(w, full_matrices=full_matrices)
+     original_device = w.device
+     w = w.to(accelerator)
+     u, s, v = _svd(w)
+     return u.to(original_device), s.to(original_device), v.to(original_device)
+
+
+ def fun_joint_svd(
+     w_list: List[Tensor], accelerator=None
+ ) -> Tuple[Tensor, Tensor, Tensor]:
+
+     w = torch.cat(w_list, dim=1)  # stacked_matrix
+     original_device = w.device
+     if accelerator is not None:
+         w = w.to(accelerator)
+     u_c, s_c, vh_c = torch.linalg.svd(
+         w, full_matrices=False, driver="gesvd" if w.is_cuda else None
+     )
+
+     svd_list = []
+     offset = 0
+     for matrix in w_list:
+         n_cols = matrix.size(1)
+         u = u_c
+         s = s_c
+         vh_ = vh_c[:, offset : offset + n_cols]
+         v = vh_.T
+         svd_list.append(
+             [u.to(original_device), s.to(original_device), v.to(original_device)]
+         )
+
+         offset += n_cols
+     return svd_list
+
+
+ class RankOneMoE(nn.Module):
+     # variable to store the merged state dict temporarily
+     _merged_state_dict: StateDictType = None
+
+     def __init__(
+         self,
+         hidden_size: int,
+         base_model: nn.Module,
+         expert_models: List[nn.Module],
+         init_lambda: float = 0.2,
+         batch_first: bool = False,
+         router_hidden_layers: int = 2,
+         batch_reduce: bool = False,
+         svd_accelerator=False,
+         rank_k: int = -1,
+         select_k: int = -1,
+     ):
+         """
+         Initializes the RankOneMoE class.
+         https://github.com/EnnengYang/RankOne-MoE
+
+         Args:
+             hidden_size (int): The size of the hidden layer in the models.
+             base_model (nn.Module): The base model that will be used as a reference for the expert models.
+             expert_models (List[nn.Module]): A list of expert models that will be combined.
+             init_lambda (float, optional): The initial lambda value for the weight ensembling gate. Defaults to 0.2.
+             batch_first (bool, optional): If True, the input tensors are expected to have the batch size as the first dimension. Defaults to False.
+             router_hidden_layers (int, optional): The number of hidden layers in the router. Defaults to 2.
+             batch_reduce (bool): If True, the batch dimension of routing weights is reduced. Defaults to False.
+         """
+         super().__init__()
+         self.num_experts = len(expert_models)
+         self.hidden_size = hidden_size
+         self.batch_first = batch_first
+         self.batch_reduce = batch_reduce
+         self.svd_accelerator = svd_accelerator
+         self.rank_k = rank_k
+         self.select_k = select_k
+         self.init_lambda = init_lambda
+
+         self.gate = construct_rankone_moe_gate(
+             hidden_size=hidden_size,
+             num_experts=int(self.num_experts * self.rank_k),
+             init_lambda=init_lambda,
+             num_hidden_layers=router_hidden_layers,
+         )
+
+         # compute the task vectors
+         for name, param in base_model.named_parameters():
+             if not param.requires_grad:
+                 for m in expert_models:
+                     del_attr(m, name.split("."))
+             else:
+                 for m in expert_models:
+                     get_attr(m, name.split(".")).data = (
+                         get_attr(m, name.split(".")) - param
+                     )
+
+         # fix base model and expert models
+         self.base_model = base_model.requires_grad_(False)
+         for m in expert_models:
+             m.requires_grad_(False)
+
+         # task vector (only bias term)
+         self.task_vectors_fc1_bias = nn.Parameter(
+             torch.stack([e.fc1.bias for e in expert_models], dim=0), requires_grad=False
+         )
+         self.task_vectors_fc2_bias = nn.Parameter(
+             torch.stack([e.fc2.bias for e in expert_models], dim=0), requires_grad=False
+         )
+
+         # SVD representation of task vector (only weight term)
+         self.task_vectors_fc1_u = nn.ParameterList()
+         self.task_vectors_fc1_svh = nn.ParameterList()
+         self.task_vectors_fc2_u = nn.ParameterList()
+         self.task_vectors_fc2_svh = nn.ParameterList()
+
+         for m in expert_models:
+             for name, param in m.named_parameters():
+                 if ".weight" in name:
+
+                     if _is_all_zeros(param):
+                         # All fine-tuned models are identical to the pretrained model
+                         raise ExpertNotTrainedError()
+
+                     u, s, v = svd(param, accelerator=self.svd_accelerator)
+                     u = u[:, : self.rank_k]
+                     s = s[: self.rank_k]
+                     v = v[:, : self.rank_k]
+
+                     if "fc1.weight" == name:
+                         self.task_vectors_fc1_u.append(
+                             nn.Parameter(u.T, requires_grad=False)
+                         )
+                         self.task_vectors_fc1_svh.append(
+                             nn.Parameter((s * v).T, requires_grad=False)
+                         )
+                     elif "fc2.weight" == name:
+                         self.task_vectors_fc2_u.append(
+                             nn.Parameter(u.T, requires_grad=False)
+                         )
+                         self.task_vectors_fc2_svh.append(
+                             nn.Parameter((s * v).T, requires_grad=False)
+                         )
+
+         # remove the original module from fine-tuned models to save memory
+         for name, param in base_model.named_parameters():
+             name_list = name.split(".")
+             for m in expert_models:
+                 set_attr(m, name_list, None)
+
+     @property
+     def forward_model(self):
+         return functools.partial(
+             functional_call,
+             self.base_model,
+             self._merged_state_dict,
+         )
+
+     def top_k_soft(self, s, k):
+         threshold, _ = torch.topk(s, k, largest=True, sorted=False)
+         min_threshold = threshold.min()
+         # sigmoid -> mask
+         mask = torch.sigmoid(100 * (s - min_threshold))
+         result = s * mask
+         return result
+
+     def merge_weights(self, expert_weights):
+         state_dict = self.base_model.state_dict(keep_vars=True)
+
+         # Select top-k experts from the expert pool for fusion
+         if self.select_k > 0:
+             expert_weights = self.top_k_soft(expert_weights, self.select_k)
+
+         for name in state_dict:
+             if name == "fc1.bias":
+                 for param in self.task_vectors_fc1_bias:
+                     state_dict[name] = state_dict[name] + self.init_lambda * param
+             elif name == "fc2.bias":
+                 for param in self.task_vectors_fc2_bias:
+                     state_dict[name] = state_dict[name] + self.init_lambda * param
+
+             elif name == "fc1.weight":
+                 w_list = torch.split(
+                     expert_weights,
+                     int(expert_weights.size(-1) / self.num_experts),
+                     dim=-1,
+                 )
+                 for weight, u, svh in zip(
+                     w_list, self.task_vectors_fc1_u, self.task_vectors_fc1_svh
+                 ):
+                     weight_diag = torch.diag(weight)
+                     weight_u = torch.mm(weight_diag, u)
+                     result = torch.matmul(weight_u.T, svh)
+                     state_dict[name] = state_dict[name] + result
+
+             elif name == "fc2.weight":
+                 w_list = torch.split(
+                     expert_weights,
+                     int(expert_weights.size(-1) / self.num_experts),
+                     dim=-1,
+                 )
+                 for weight, u, svh in zip(
+                     w_list, self.task_vectors_fc2_u, self.task_vectors_fc2_svh
+                 ):
+                     weight_diag = torch.diag(weight)
+                     weight_u = torch.mm(weight_diag, u)
+                     result = torch.matmul(weight_u.T, svh)
+                     state_dict[name] = state_dict[name] + result
+
+         self._merged_state_dict = state_dict
+         return state_dict
+
+     def forward(self, hidden_states: Tensor):
+         if self.gate.num_hidden_layers == 0:
+             gate_weights = self.gate()
+         else:
+             gate_weights = self.gate(hidden_states)
+             if self.batch_first:
+                 # the input is in the shape of (batch_size, seq_len, hidden_size)
+                 gate_weights = gate_weights.mean(dim=1)
+             else:
+                 # the input is in the shape of (seq_len, batch_size, hidden_size)
+                 gate_weights = gate_weights.mean(dim=0)
+
+         if self.gate.num_hidden_layers == 0:
+             self.merge_weights(gate_weights)
+             output_hidden_states = self.forward_model(hidden_states)
+         elif self.batch_reduce:
+             gate_weights = gate_weights.mean(dim=0)
+             self.merge_weights(gate_weights)
+             output_hidden_states = self.forward_model(hidden_states)
+         else:
+             output_hidden_states = []
+             for sample_idx, weights in enumerate(gate_weights):
+                 self.merge_weights(weights)
+                 if self.batch_first:
+                     output_hidden_states.append(
+                         self.forward_model(hidden_states[sample_idx : sample_idx + 1])
+                     )
+                 else:
+                     output_hidden_states.append(
+                         self.forward_model(
+                             hidden_states[:, sample_idx : sample_idx + 1]
+                         )
+                     )
+             if self.batch_first:
+                 output_hidden_states = torch.cat(output_hidden_states, dim=0)
+             else:
+                 output_hidden_states = torch.cat(output_hidden_states, dim=1)
+
+         self._merged_state_dict = None
+         return output_hidden_states
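
A toy sketch of how the new module composes, assuming a small two-layer MLP with the fc1/fc2 attribute names that merge_weights keys on; the shapes and hyperparameters are illustrative, not from the diff. Note that svd_accelerator is passed as None here so the SVD stays on the input's device (the default False is not None, so it would be forwarded to Tensor.to):

import torch
from torch import nn

from fusion_bench.models.rankone_moe import RankOneMoE

class MLP(nn.Module):
    """Toy expert with the fc1/fc2 layout that merge_weights expects."""
    def __init__(self, d: int = 32):
        super().__init__()
        self.fc1 = nn.Linear(d, 4 * d)
        self.fc2 = nn.Linear(4 * d, d)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

base = MLP()
experts = [MLP(), MLP()]  # stand-ins for fine-tuned copies of `base`

moe = RankOneMoE(
    hidden_size=32,
    base_model=base,
    expert_models=experts,
    batch_first=True,
    svd_accelerator=None,  # keep the SVD on the original device
    rank_k=8,              # keep the top-8 rank-one components per expert
    select_k=4,            # softly select 4 of the 2 * 8 routed components
)

x = torch.randn(2, 5, 32)  # (batch_size, seq_len, hidden_size)
out = moe(x)               # routes, then merges a rank-one state dict per sample
print(out.shape)           # torch.Size([2, 5, 32])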
fusion_bench/models/surgery/surgerymodelwrapper.py (new file)
@@ -0,0 +1,157 @@
+ import math
+ from typing import TYPE_CHECKING, List, Union, Callable, Generic
+
+ import torch
+ from torch import nn
+ from transformers.models.clip.modeling_clip import (
+     CLIPVisionModel,
+     CLIPVisionTransformer,
+ )
+ from fusion_bench.utils.type import TorchModelType
+
+
+ def regularize_name(name: str):
+     name = name.replace("-", "_")
+     name = name.replace(".", "_")
+     return name
+
+
+ class SurgeryModelWrapper(torch.nn.Module, Generic[TorchModelType]):
+
+     is_surgery_model = True
+     """A flag to indicate that this is a surgery model."""
+
+     def __init__(
+         self,
+         model: TorchModelType,
+         test_datasets: List[str],
+         projection_dim: int = 512,
+         hidden_dim: int = 16,
+     ):
+         super(SurgeryModelWrapper, self).__init__()
+         self.model = model
+         self.model.requires_grad_(False)
+
+         self.test_datasets = test_datasets
+         self.non_linear_func = torch.nn.ReLU()
+
+         self.projection_dim = projection_dim
+         self.hidden_dim = hidden_dim
+
+         for dataset_name in test_datasets:
+             self.add_surgery_module(dataset_name)
+
+     def add_surgery_module(self, dataset_name: str):
+         """
+         Add a surgery module for a given dataset.
+
+         Args:
+             dataset_name (str): The name of the dataset.
+         """
+         dataset_name = regularize_name(dataset_name)
+
+         down_proj = torch.nn.Linear(self.projection_dim, self.hidden_dim, bias=False)
+         up_proj = torch.nn.Linear(self.hidden_dim, self.projection_dim, bias=False)
+
+         torch.nn.init.kaiming_uniform_(down_proj.weight, a=math.sqrt(5))
+         torch.nn.init.zeros_(up_proj.weight)
+
+         self.add_module(
+             "feature_mapping_to_head_down_proj_{}".format(dataset_name), down_proj
+         )
+         self.add_module(
+             "feature_mapping_to_head_up_proj_{}".format(dataset_name), up_proj
+         )
+
+     def collect_trainable_params(self):
+         trainable_params = []
+
+         # surgery parameters
+         for dataset_name in self.test_datasets:
+             dataset_name = regularize_name(dataset_name)
+             down_proj = getattr(
+                 self, "feature_mapping_to_head_down_proj_{}".format(dataset_name)
+             )
+             up_proj = getattr(
+                 self, "feature_mapping_to_head_up_proj_{}".format(dataset_name)
+             )
+             trainable_params.append(down_proj.weight)
+             trainable_params.append(up_proj.weight)
+         return trainable_params
+
+     def collect_surgery_module(self):
+         surgery_module = {}
+
+         # surgery parameters
+         for dataset_name in self.test_datasets:
+             dataset_name = regularize_name(dataset_name)
+             down_proj = getattr(
+                 self, "feature_mapping_to_head_down_proj_{}".format(dataset_name)
+             )
+             up_proj = getattr(
+                 self, "feature_mapping_to_head_up_proj_{}".format(dataset_name)
+             )
+             surgery_module[
+                 "feature_mapping_to_head_down_proj_{}".format(dataset_name)
+             ] = down_proj
+             surgery_module[
+                 "feature_mapping_to_head_up_proj_{}".format(dataset_name)
+             ] = up_proj
+
+         surgery_module["non_linear_func"] = self.non_linear_func
+
+         return surgery_module
+
+     def compute_surgery_features(
+         self,
+         compute_features_fn: Union[
+             torch.Tensor, Callable[[TorchModelType], torch.Tensor]
+         ],
+         dataset_name: str,
+     ):
+         """
+         Compute the surgery features.
+
+         Args:
+             compute_features_fn (Union[torch.Tensor, Callable[[nn.Module], torch.Tensor]]): A tensor that already holds the features, or a function that computes them from the wrapped model.
+             dataset_name (str): The name of the dataset.
+
+         Returns:
+             feature (torch.Tensor): The surgery features.
+             feature0 (torch.Tensor): The original features.
+             feature_sub (torch.Tensor): feature0 - feature.
+         """
+         dataset_name = regularize_name(dataset_name)
+
+         if isinstance(compute_features_fn, torch.Tensor):
+             feature = compute_features_fn
+         elif callable(compute_features_fn):
+             feature = compute_features_fn(self.model)
+         else:
+             raise ValueError(
+                 "compute_features_fn must be a tensor or a callable, but got {}".format(
+                     type(compute_features_fn)
+                 )
+             )
+
+         feature0 = feature
+
+         # feature bias
+         down_proj = getattr(
+             self, "feature_mapping_to_head_down_proj_{}".format(dataset_name)
+         )
+         up_proj = getattr(
+             self, "feature_mapping_to_head_up_proj_{}".format(dataset_name)
+         )
+         feature_sub = down_proj(feature)
+         feature_sub = self.non_linear_func(feature_sub)
+         feature_sub = up_proj(feature_sub)
+
+         # surgery feature
+         feature = feature0 - feature_sub
+
+         return feature, feature0, feature_sub
+
+     def forward(self, *args, **kwargs):
+         """The wrapped model forwards as normal."""
+         return self.model(*args, **kwargs)
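
A minimal sketch of the wrapper's intended use, derived from the code above; the backbone here is a stand-in linear layer and the dataset names are illustrative:

import torch

from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper

backbone = torch.nn.Linear(512, 512)  # stand-in for a merged vision encoder
wrapper = SurgeryModelWrapper(
    backbone, test_datasets=["cifar10", "svhn"], projection_dim=512, hidden_dim=16
)

features = torch.randn(4, 512)  # precomputed, projected image features
surgery_feat, original_feat, bias = wrapper.compute_surgery_features(
    features, dataset_name="cifar10"
)
assert torch.allclose(surgery_feat, original_feat - bias)

# Only the low-rank down/up projections are trained; the backbone stays frozen.
optimizer = torch.optim.Adam(wrapper.collect_trainable_params(), lr=1e-3)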
fusion_bench/models/utils.py
@@ -1,5 +1,6 @@
  from typing import List

+ import torch
  from torch import nn


@@ -70,3 +71,10 @@ def find_layers_with_type(
          if isinstance(submodule, tuple(layer_types)):
              res[name] = submodule
      return res
+
+
+ def disable_dropout(model: torch.nn.Module):
+     """Disable dropout in a model."""
+     for module in model.modules():
+         if isinstance(module, torch.nn.Dropout):
+             module.p = 0
+ module.p = 0