fusion-bench 0.2.5-py3-none-any.whl → 0.2.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/method/base_algorithm.py +7 -2
- fusion_bench/compat/modelpool/__init__.py +3 -2
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/__init__.py +6 -1
- fusion_bench/dataset/arc_agi/arc.py +26 -7
- fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
- fusion_bench/dataset/arc_agi/np_cache.py +0 -1
- fusion_bench/dataset/arc_agi/preprocess.py +51 -9
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +72 -5
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +4 -1
- fusion_bench/method/adamerging/__init__.py +1 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/simple_average.py +1 -1
- fusion_bench/method/surgery/__init__.py +3 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +60 -12
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +11 -2
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +5 -1
- fusion_bench/taskpool/__init__.py +10 -2
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +5 -3
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
fusion_bench/models/hf_clip.py
CHANGED
@@ -1,4 +1,5 @@
-from typing import Callable, Iterable, List
+import logging
+from typing import TYPE_CHECKING, Callable, Iterable, List  # noqa: F401
 
 import torch
 from torch import Tensor, nn
@@ -7,6 +8,11 @@ from transformers.models.clip.modeling_clip import BaseModelOutputWithPooling
 
 from fusion_bench.utils.devices import get_device
 
+if TYPE_CHECKING:
+    from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
+
+log = logging.getLogger(__name__)
+
 default_templates = [
     lambda c: f"a photo of a {c}",
 ]
@@ -33,6 +39,7 @@ class HFCLIPClassifier(nn.Module):
         self,
         clip_model: CLIPModel,
         processor: CLIPProcessor,
+        extra_module=None,
     ):
         """
         Initialize the HFCLIPClassifier.
@@ -56,6 +63,8 @@ class HFCLIPClassifier(nn.Module):
             persistent=False,
         )
 
+        self.extra_module = extra_module
+
     @property
     def text_model(self):
         """Get the text model component of CLIP."""
@@ -111,7 +120,13 @@ class HFCLIPClassifier(nn.Module):
 
         self.zeroshot_weights = zeroshot_weights
 
-    def forward(self, images: Tensor):
+    def forward(
+        self,
+        images: Tensor,
+        return_image_embeds=False,
+        return_dict=False,
+        task_name=None,
+    ):
         """
         Perform forward pass for zero-shot image classification.
 
@@ -120,6 +135,9 @@ class HFCLIPClassifier(nn.Module):
 
         Args:
             images (Tensor): Input images to classify.
+            return_image_embeds (bool): Whether to return the image embeddings.
+            return_dict (bool): Whether to return a dictionary with logits and image embeddings.
+            task_name (Optional[str]): The name of the task.
 
         Returns:
             Tensor: Classification logits for each input image.
@@ -131,16 +149,22 @@ class HFCLIPClassifier(nn.Module):
             raise ValueError("Must set classification task before forward pass")
         text_embeds = self.zeroshot_weights
 
-        image_embeds = self.vision_model(images)
-        if isinstance(image_embeds, Tensor):
-            pass
-        elif isinstance(image_embeds, BaseModelOutputWithPooling):
-            image_embeds = image_embeds[1]
-        image_embeds = self.clip_model.visual_projection(image_embeds)
-
+        image_embeds = self.get_image_features(images)
         # normalize embeddings
         image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
 
+        if (
+            hasattr(self.vision_model, "is_surgery_model")
+            and self.vision_model.is_surgery_model
+        ):
+            # Dealing with the surgery model, for more details, please refer to:
+            # (ICML 2024) Yang, et.al. Representation Surgery for Multi-Task Model Merging
+            # https://arxiv.org/abs/2402.02705
+            self.vision_model: "SurgeryModelWrapper" = self.vision_model
+            image_embeds, _, _ = self.vision_model.compute_surgery_features(
+                image_embeds, dataset_name=task_name
+            )
+
         # cosine similarity
         logit_scale = self.clip_model.logit_scale.exp()
         logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
@@ -156,3 +180,20 @@ class HFCLIPClassifier(nn.Module):
             return logits_per_image, image_embeds
         else:
             return logits_per_image
+
+    def get_image_features(self, images: Tensor) -> Tensor:
+        """
+        Compute the image embeddings.
+
+        Returns:
+            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
+            applying the projection layer to the pooled output of [`CLIPVisionModel`].
+        """
+
+        image_embeds = self.vision_model(images)
+        if isinstance(image_embeds, Tensor):
+            pass
+        elif isinstance(image_embeds, BaseModelOutputWithPooling):
+            image_embeds = image_embeds[1]
+        image_embeds = self.clip_model.visual_projection(image_embeds)
+        return image_embeds
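For orientation, the sketch below exercises the updated forward signature. It is illustrative rather than taken from the diff: the checkpoint name and class names are placeholders, and it assumes the classifier's existing set_classification_task helper is what populates zeroshot_weights before forward is called.

import torch
from transformers import CLIPModel, CLIPProcessor

from fusion_bench.models.hf_clip import HFCLIPClassifier

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
classifier = HFCLIPClassifier(clip_model, processor)

# Assumed setup step: populate `zeroshot_weights` for the label set.
classifier.set_classification_task(["cat", "dog"])

images = torch.randn(4, 3, 224, 224)  # dummy pixel values
logits, image_embeds = classifier(
    images,
    return_image_embeds=True,  # new keyword in 0.2.7
    task_name="cifar10",  # only consulted when vision_model is a SurgeryModelWrapper
)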
fusion_bench/models/rankone_moe.py
ADDED
@@ -0,0 +1,410 @@
import functools
import logging
from typing import Dict, List, Tuple  # noqa: F401

import torch
import torch.func
from torch import Tensor, nn
from torch.func import functional_call
from torch.nn import functional as F

from fusion_bench.utils.type import StateDictType

log = logging.getLogger(__name__)


def join_list(list_of_list: List[List]):
    ans = []
    for l in list_of_list:
        ans.extend(l)
    return ans


def del_attr(obj, names: List[str]):
    """
    Deletes an attribute from an object recursively.

    Args:
        obj (object): Object to delete attribute from.
        names (list): List of attribute names to delete recursively.
    """
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])


def set_attr(obj, names: List[str], val):
    """
    Sets an attribute of an object recursively.

    Args:
        obj (object): Object to set attribute of.
        names (list): List of attribute names to set recursively.
        val (object): Value to set the attribute to.
    """
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)


def get_attr(obj, names: List[str]):
    """
    Gets an attribute of an object recursively.

    Args:
        obj (object): Object to get attribute of.
        names (list): List of attribute names to get recursively.

    Returns:
        object: The attribute of the object.
    """
    if len(names) == 1:
        return getattr(obj, names[0])
    else:
        return get_attr(getattr(obj, names[0]), names[1:])


class Depth_0_Gate(nn.Module):
    def __init__(self, num_experts: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_experts), requires_grad=True)

    def init_weight(self, init_lambda: float):
        nn.init.constant_(self.weight, init_lambda)

    def forward(self, *args, **kwargs) -> Tensor:
        return self.weight


class Depth_1_Gate(nn.Module):
    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        self.fc = nn.Linear(hidden_size, num_experts, bias=True)

    def init_weight(self, init_lambda: float):
        nn.init.normal_(self.fc.weight, std=0.01)
        nn.init.constant_(self.fc.bias, init_lambda)

    def forward(self, hidden_states: Tensor) -> Tensor:
        return self.fc(hidden_states)


class Depth_2_Gate(nn.Module):
    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, num_experts * 2, bias=True)
        self.fc2 = nn.Linear(num_experts * 2, num_experts, bias=True)

    def init_weight(self, init_lambda: float):
        nn.init.normal_(self.fc1.weight, std=0.01)
        nn.init.zeros_(self.fc1.bias)
        nn.init.normal_(self.fc2.weight, std=0.01)
        nn.init.constant_(self.fc2.bias, init_lambda)

    def forward(self, hidden_states: Tensor) -> Tensor:
        hidden_states = F.relu(self.fc1(hidden_states))
        return self.fc2(hidden_states)


def construct_rankone_moe_gate(
    hidden_size: int,
    num_experts: int,
    init_lambda: float,
    num_hidden_layers: int = 2,
):
    if num_hidden_layers == 0:
        gate = Depth_0_Gate(num_experts)
    elif num_hidden_layers == 1:
        gate = Depth_1_Gate(hidden_size, num_experts)
    elif num_hidden_layers == 2:
        gate = Depth_2_Gate(hidden_size, num_experts)
    else:
        raise ValueError(f"Unsupported number of hidden layers: {num_hidden_layers}")

    gate.num_hidden_layers = num_hidden_layers
    gate.init_weight(init_lambda)
    return gate


class ExpertNotTrainedError(Exception):
    pass


def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
    """
    Check if a tensor or a list of tensors are all zeros.
    """
    if isinstance(tensor, Tensor):
        return torch.allclose(tensor, torch.zeros_like(tensor))
    else:
        return all(_is_all_zeros(t) for t in tensor)


def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Perform Singular Value Decomposition (SVD) on a tensor.
    """
    u, s, vh = torch.linalg.svd(
        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
    )
    v = vh.T
    return u, s, v


def svd(
    w: Tensor, full_matrices=True, accelerator=None
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Perform SVD on a tensor, optionally using a specified accelerator.
    """
    if accelerator is None:
        return _svd(w, full_matrices=full_matrices)
    original_device = w.device
    w = w.to(accelerator)
    u, s, v = _svd(w)
    return u.to(original_device), s.to(original_device), v.to(original_device)


def fun_joint_svd(
    w_list: List[Tensor], accelerator=None
) -> Tuple[Tensor, Tensor, Tensor]:

    w = torch.cat(w_list, dim=1)  # stacked_matrix
    original_device = w.device
    if accelerator is not None:
        w = w.to(accelerator)
    u_c, s_c, vh_c = torch.linalg.svd(
        w, full_matrices=False, driver="gesvd" if w.is_cuda else None
    )

    svd_list = []
    offset = 0
    for matrix in w_list:
        n_cols = matrix.size(1)
        u = u_c
        s = s_c
        vh_ = vh_c[:, offset : offset + n_cols]
        v = vh_.T
        svd_list.append(
            [u.to(original_device), s.to(original_device), v.to(original_device)]
        )

        offset += n_cols
    return svd_list


class RankOneMoE(nn.Module):
    # variable to store the merged state dict temporarily
    _merged_state_dict: StateDictType = None

    def __init__(
        self,
        hidden_size: int,
        base_model: nn.Module,
        expert_models: List[nn.Module],
        init_lambda: float = 0.2,
        batch_first: bool = False,
        router_hidden_layers: int = 2,
        batch_reduce: bool = False,
        svd_accelerator=False,
        rank_k: int = -1,
        select_k: int = -1,
    ):
        """
        Initializes the RankOneMoE class.
        https://github.com/EnnengYang/RankOne-MoE

        Args:
            hidden_size (int): The size of the hidden layer in the models.
            base_model (nn.Module): The base model that will be used as a reference for the expert models.
            expert_models (List[nn.Module]): A list of expert models that will be combined.
            init_lambda (float, optional): The initial lambda value for the weight ensembling gate. Defaults to 0.2.
            batch_first (bool, optional): If True, the input tensors are expected to have the batch size as the first dimension. Defaults to False.
            router_hidden_layers (int, optional): The number of hidden layers in the router. Defaults to 2.
            batch_reduce (bool): If True, the batch dimension of routing weights is reduced. Defaults to False.
        """
        super().__init__()
        self.num_experts = len(expert_models)
        self.hidden_size = hidden_size
        self.batch_first = batch_first
        self.batch_reduce = batch_reduce
        self.svd_accelerator = svd_accelerator
        self.rank_k = rank_k
        self.select_k = select_k
        self.init_lambda = init_lambda

        self.gate = construct_rankone_moe_gate(
            hidden_size=hidden_size,
            num_experts=int(self.num_experts * self.rank_k),
            init_lambda=init_lambda,
            num_hidden_layers=router_hidden_layers,
        )

        # compute the task vectors
        for name, param in base_model.named_parameters():
            if not param.requires_grad:
                for m in expert_models:
                    del_attr(m, name.split("."))
            else:
                for m in expert_models:
                    get_attr(m, name.split(".")).data = (
                        get_attr(m, name.split(".")) - param
                    )

        # fix base model and expert models
        self.base_model = base_model.requires_grad_(False)
        for m in expert_models:
            m.requires_grad_(False)

        # task vecotr (only bias term)
        self.task_vectors_fc1_bias = nn.Parameter(
            torch.stack([e.fc1.bias for e in expert_models], dim=0), requires_grad=False
        )
        self.task_vectors_fc2_bias = nn.Parameter(
            torch.stack([e.fc2.bias for e in expert_models], dim=0), requires_grad=False
        )

        # SVD representation of task vector (only weight term)
        self.task_vectors_fc1_u = nn.ParameterList()
        self.task_vectors_fc1_svh = nn.ParameterList()
        self.task_vectors_fc2_u = nn.ParameterList()
        self.task_vectors_fc2_svh = nn.ParameterList()

        for m in expert_models:
            for name, param in m.named_parameters():
                if ".weight" in name:

                    if _is_all_zeros(param):
                        # All fine-tuned models are identical to the pretrained model
                        raise ExpertNotTrainedError()

                    u, s, v = svd(param, accelerator=self.svd_accelerator)
                    u = u[:, : self.rank_k]
                    s = s[: self.rank_k]
                    v = v[:, : self.rank_k]

                    if "fc1.weight" == name:
                        self.task_vectors_fc1_u.append(
                            nn.Parameter(u.T, requires_grad=False)
                        )
                        self.task_vectors_fc1_svh.append(
                            nn.Parameter((s * v).T, requires_grad=False)
                        )
                    elif "fc2.weight" == name:
                        self.task_vectors_fc2_u.append(
                            nn.Parameter(u.T, requires_grad=False)
                        )
                        self.task_vectors_fc2_svh.append(
                            nn.Parameter((s * v).T, requires_grad=False)
                        )

        # remove the original module from fine-tuned models to save memory
        for name, param in base_model.named_parameters():
            name_list = name.split(".")
            for m in expert_models:
                set_attr(m, name_list, None)

    @property
    def forward_model(self):
        return functools.partial(
            functional_call,
            self.base_model,
            self._merged_state_dict,
        )

    def top_k_soft(self, s, k):
        threshold, _ = torch.topk(s, k, largest=True, sorted=False)
        min_threshold = threshold.min()
        # sigmoid -> mask
        mask = torch.sigmoid(100 * (s - min_threshold))
        result = s * mask
        return result

    def merge_weights(self, expert_weights):
        state_dict = self.base_model.state_dict(keep_vars=True)

        # Select top-k experts from the expert pool for fusion
        if self.select_k > 0:
            expert_weights = self.top_k_soft(expert_weights, self.select_k)

        for name in state_dict:
            if name == "fc1.bias":
                for param in self.task_vectors_fc1_bias:
                    state_dict[name] = state_dict[name] + self.init_lambda * param
            elif name == "fc2.bias":
                for param in self.task_vectors_fc2_bias:
                    state_dict[name] = state_dict[name] + self.init_lambda * param

            elif name == "fc1.weight":
                w_list = torch.split(
                    expert_weights,
                    int(expert_weights.size(-1) / self.num_experts),
                    dim=-1,
                )
                for weight, u, svh in zip(
                    w_list, self.task_vectors_fc1_u, self.task_vectors_fc1_svh
                ):
                    weight_diag = torch.diag(weight)
                    weight_u = torch.mm(weight_diag, u)
                    result = torch.matmul(weight_u.T, svh)
                    state_dict[name] = state_dict[name] + result

            elif name == "fc2.weight":
                w_list = torch.split(
                    expert_weights,
                    int(expert_weights.size(-1) / self.num_experts),
                    dim=-1,
                )
                for weight, u, svh in zip(
                    w_list, self.task_vectors_fc2_u, self.task_vectors_fc2_svh
                ):
                    weight_diag = torch.diag(weight)
                    weight_u = torch.mm(weight_diag, u)
                    result = torch.matmul(weight_u.T, svh)
                    state_dict[name] = state_dict[name] + result

        self._merged_state_dict = state_dict
        return state_dict

    def forward(self, hidden_states: Tensor):
        if self.gate.num_hidden_layers == 0:
            gate_weights = self.gate()
        else:
            gate_weights = self.gate(hidden_states)
            if self.batch_first:
                # the input is in the shape of (batch_size, seq_len, hidden_size)
                gate_weights = gate_weights.mean(dim=1)
            else:
                # the input is in the shape of (seq_len, batch_size, hidden_size)
                gate_weights = gate_weights.mean(dim=0)

        if self.gate.num_hidden_layers == 0:
            self.merge_weights(gate_weights)
            output_hidden_states = self.forward_model(hidden_states)
        elif self.batch_reduce:
            gate_weights = gate_weights.mean(dim=0)
            self.merge_weights(gate_weights)
            output_hidden_states = self.forward_model(hidden_states)
        else:
            output_hidden_states = []
            for sample_idx, weights in enumerate(gate_weights):
                self.merge_weights(weights)
                if self.batch_first:
                    output_hidden_states.append(
                        self.forward_model(hidden_states[sample_idx : sample_idx + 1])
                    )
                else:
                    output_hidden_states.append(
                        self.forward_model(
                            hidden_states[:, sample_idx : sample_idx + 1]
                        )
                    )
            if self.batch_first:
                output_hidden_states = torch.cat(output_hidden_states, dim=0)
            else:
                output_hidden_states = torch.cat(output_hidden_states, dim=1)

        self._merged_state_dict = None
        return output_hidden_states
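To make the construction path concrete, here is a minimal, self-contained sketch of plugging two toy experts into RankOneMoE. The TinyMLP class and every hyperparameter value below are assumptions for illustration only; in FusionBench the experts are the fc1/fc2 MLP blocks of fine-tuned CLIP encoders, and the wiring is normally driven by the new rankone_moe method and config files listed above.

import torch
from torch import nn

from fusion_bench.models.rankone_moe import RankOneMoE


class TinyMLP(nn.Module):
    """Stand-in for a CLIP MLP block: RankOneMoE expects `fc1`/`fc2` submodules."""

    def __init__(self, hidden_size: int = 32):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, 4 * hidden_size)
        self.fc2 = nn.Linear(4 * hidden_size, hidden_size)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


base = TinyMLP()
# Independently initialized experts stand in for fine-tuned models; their task
# vectors (expert minus base) must be non-zero or ExpertNotTrainedError is raised.
experts = [TinyMLP() for _ in range(2)]

moe = RankOneMoE(
    hidden_size=32,
    base_model=base,
    expert_models=experts,
    rank_k=4,              # keep the top-4 rank-one components of each task vector
    select_k=-1,           # disable soft top-k expert selection
    batch_first=True,
    svd_accelerator=None,  # run the SVD on the tensors' current device
)

x = torch.randn(2, 5, 32)  # (batch_size, seq_len, hidden_size)
y = moe(x)                 # per-sample routing, merging, and a functional forward pass
print(y.shape)             # torch.Size([2, 5, 32])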
fusion_bench/models/surgery/surgerymodelwrapper.py
ADDED
@@ -0,0 +1,157 @@
import math
from typing import TYPE_CHECKING, List, Union, Callable, Generic

import torch
from torch import nn
from transformers.models.clip.modeling_clip import (
    CLIPVisionModel,
    CLIPVisionTransformer,
)
from fusion_bench.utils.type import TorchModelType


def regularize_name(name: str):
    name = name.replace("-", "_")
    name = name.replace(".", "_")
    return name


class SurgeryModelWrapper(torch.nn.Module, Generic[TorchModelType]):

    is_surgery_model = True
    """A flag to indicate that this is a surgery model."""

    def __init__(
        self,
        model: TorchModelType,
        test_datasets: List[str],
        projection_dim: int = 512,
        hidden_dim: int = 16,
    ):
        super(SurgeryModelWrapper, self).__init__()
        self.model = model
        self.model.requires_grad_(False)

        self.test_datasets = test_datasets
        self.non_linear_func = torch.nn.ReLU()

        self.projection_dim = projection_dim
        self.hidden_dim = hidden_dim

        for dataset_name in test_datasets:
            self.add_surgery_module(dataset_name)

    def add_surgery_module(self, dataset_name: str):
        """
        Add a surgery module for a given dataset.

        Args:
            dataset_name (str): The name of the dataset.
        """
        dataset_name = regularize_name(dataset_name)

        down_proj = torch.nn.Linear(self.projection_dim, self.hidden_dim, bias=False)
        up_proj = torch.nn.Linear(self.hidden_dim, self.projection_dim, bias=False)

        torch.nn.init.kaiming_uniform_(down_proj.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(up_proj.weight)

        self.add_module(
            "feature_mapping_to_head_down_proj_{}".format(dataset_name), down_proj
        )
        self.add_module(
            "feature_mapping_to_head_up_proj_{}".format(dataset_name), up_proj
        )

    def collect_trainable_params(self):
        trainable_params = []

        # surgery parameter
        for dataset_name in self.test_datasets:
            dataset_name = regularize_name(dataset_name)
            down_proj = getattr(
                self, "feature_mapping_to_head_down_proj_{}".format(dataset_name)
            )
            up_proj = getattr(
                self, "feature_mapping_to_head_up_proj_{}".format(dataset_name)
            )
            trainable_params.append(down_proj.weight)
            trainable_params.append(up_proj.weight)
        return trainable_params

    def collect_surgery_module(self):
        surgery_module = {}

        # surgery parameter
        for dataset_name in self.test_datasets:
            dataset_name = regularize_name(dataset_name)
            down_proj = getattr(
                self, "feature_mapping_to_head_down_proj_{}".format(dataset_name)
            )
            up_proj = getattr(
                self, "feature_mapping_to_head_up_proj_{}".format(dataset_name)
            )
            surgery_module[
                "feature_mapping_to_head_down_proj_{}".format(dataset_name)
            ] = down_proj
            surgery_module[
                "feature_mapping_to_head_up_proj_{}".format(dataset_name)
            ] = up_proj

        surgery_module["non_linear_func"] = self.non_linear_func

        return surgery_module

    def compute_surgery_features(
        self,
        compute_features_fn: Union[
            torch.Tensor, Callable[[TorchModelType], torch.Tensor]
        ],
        dataset_name: str,
    ):
        """
        Compute the surgery features.

        Args:
            compute_features_fn (Union[torch.Tensor, Callable[[nn.Module], torch.Tensor]]): A function that computes the features or a tensor that represents the features.
            dataset_name (str): The name of the dataset.

        Returns:
            feature (torch.Tensor): The surgery features.
            feature0 (torch.Tensor): The original features.
            feature_sub (torch.Tensor): feature0 - feature.
        """
        dataset_name = regularize_name(dataset_name)

        if isinstance(compute_features_fn, torch.Tensor):
            feature = compute_features_fn
        elif callable(compute_features_fn):
            feature = compute_features_fn(self.model)
        else:
            raise ValueError(
                "compute_features_fn must be a tensor or a callable, but got {}".format(
                    type(compute_features_fn)
                )
            )

        feature0 = feature

        # feature bias
        down_proj = getattr(
            self, "feature_mapping_to_head_down_proj_{}".format(dataset_name)
        )
        up_proj = getattr(
            self, "feature_mapping_to_head_up_proj_{}".format(dataset_name)
        )
        feature_sub = down_proj(feature)
        feature_sub = self.non_linear_func(feature_sub)
        feature_sub = up_proj(feature_sub)

        # surgery feature
        feature = feature0 - feature_sub

        return feature, feature0, feature_sub

    def forward(self, *args, **kwargs):
        """The wrappered model should just forward like normal."""
        return self.model(*args, **kwargs)
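A minimal usage sketch of the wrapper follows. The toy encoder, dataset names, and dimensions are illustrative assumptions; in FusionBench the wrapper sits around a merged CLIP vision model (see clip_layer_wise_adamerging_surgery above) and learns one down-/up-projection pair per test dataset.

import torch
from torch import nn

from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper


class ToyEncoder(nn.Module):
    def __init__(self, dim: int = 512):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


wrapper = SurgeryModelWrapper(
    ToyEncoder(),
    test_datasets=["cifar10", "svhn"],
    projection_dim=512,  # must match the wrapped model's feature dimension
    hidden_dim=16,
)

features = torch.randn(8, 512)  # pre-computed features for one batch
surgery_feat, original_feat, estimated_bias = wrapper.compute_surgery_features(
    features, dataset_name="cifar10"
)

# Only the per-dataset projection weights are trainable; the wrapped model is frozen.
optimizer = torch.optim.Adam(wrapper.collect_trainable_params(), lr=1e-3)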
fusion_bench/models/utils.py
CHANGED
@@ -1,5 +1,6 @@
 from typing import List
 
+import torch
 from torch import nn
 
 
@@ -70,3 +71,10 @@ def find_layers_with_type(
         if isinstance(submodule, tuple(layer_types)):
             res[name] = submodule
     return res
+
+
+def disable_dropout(model: torch.nn.Module):
+    """Disable dropout in a model."""
+    for module in model.modules():
+        if isinstance(module, torch.nn.Dropout):
+            module.p = 0