fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +25 -2
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/constants/__init__.py +1 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +12 -4
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/__init__.py +1 -0
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
- fusion_bench/method/linear/simple_average_for_llama.py +16 -11
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/simple_average.py +7 -7
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
- fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/weighted_average/llama.py +1 -1
- fusion_bench/mixins/clip_classification.py +37 -48
- fusion_bench/mixins/serialization.py +30 -10
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/__init__.py +5 -0
- fusion_bench/models/hf_utils.py +69 -86
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
- fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/programs/fabric_fusion_program.py +29 -60
- fusion_bench/scripts/cli.py +34 -1
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +2 -0
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +7 -4
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +117 -19
- fusion_bench/utils/modelscope.py +3 -3
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
fusion_bench/method/model_stock/model_stock.py (new file):

```diff
@@ -0,0 +1,309 @@
+import copy
+import logging
+import math
+import os
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from omegaconf import DictConfig
+from torch import nn
+from transformers import PreTrainedModel
+
+import fusion_bench
+from fusion_bench import BaseAlgorithm, BaseModelPool
+from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.models import create_default_model_card
+from fusion_bench.utils.type import StateDictType
+
+log = logging.getLogger(__name__)
+
+EPS = 1e-8
+
+
+def compute_angle(
+    state_dict_1: StateDictType,
+    state_dict_2: StateDictType,
+    ref_state_dict: StateDictType,
+    ignore_keys: List[str] = [],
+    return_cos: bool = False,
+) -> Dict[str, float]:
+    """
+    Compute the angle between two state dictionaries relative to a reference state dictionary.
+
+    Args:
+        state_dict_1: First state dictionary
+        state_dict_2: Second state dictionary
+        ref_state_dict: Reference state dictionary (typically pre-trained model)
+        ignore_keys: Keys to ignore during computation
+        return_cos: If True, return cosine values instead of angles in degrees
+
+    Returns:
+        Dictionary mapping parameter names to angles (in degrees) or cosine values
+    """
+    # Remove the keys not used for CLIP fine-tuning (from the notebook example)
+
+    return_dict = OrderedDict()
+
+    with torch.no_grad():
+        for key in ref_state_dict:
+            if key in ignore_keys:
+                log.info(f"Ignoring key '{key}'")
+                continue
+
+            state_dict_1_val = state_dict_1[key]
+            state_dict_2_val = state_dict_2[key]
+            ref_val = ref_state_dict[key]
+
+            if not (state_dict_1_val.shape == state_dict_2_val.shape == ref_val.shape):
+                log.warning(
+                    f"Shape mismatch for key '{key}', ignored during merging: "
+                    f"({state_dict_1_val.shape}, {state_dict_2_val.shape}, {ref_val.shape})"
+                )
+                continue
+
+            vector1 = (state_dict_1_val - ref_val).clone().detach()
+            vector2 = (state_dict_2_val - ref_val).clone().detach()
+
+            vector1 = vector1.float()
+            vector2 = vector2.float()
+
+            cosine_val = torch.sum(vector1 * vector2) / (
+                math.sqrt(torch.sum(vector1**2) * torch.sum(vector2**2)) + EPS
+            )
+            cosine_val = torch.clamp(
+                cosine_val, min=-1.0, max=1.0
+            )  # Prevent nan from acos
+
+            if return_cos:
+                return_dict[key] = cosine_val.item()
+            else:
+                return_dict[key] = np.rad2deg(
+                    torch.acos(cosine_val).detach().cpu().item()
+                )
+
+    return return_dict
+
+
+def compute_ratio(angle_dict: Dict[str, float], k: int = 2) -> Dict[str, float]:
+    """
+    Compute interpolation ratios based on angles between fine-tuned models.
+
+    Args:
+        angle_dict: Dictionary mapping parameter names to angles in degrees
+        k: Number of fine-tuned models (default: 2)
+
+    Returns:
+        Dictionary mapping parameter names to interpolation ratios
+    """
+    ratio_dict = {}
+    for key in angle_dict.keys():
+        angle = np.deg2rad(angle_dict[key])
+        ratio_dict[key] = k * np.cos(angle) / ((k - 1) * np.cos(angle) + 1 + EPS)
+    return ratio_dict
+
+
+def merge_weights(
+    w1: StateDictType, w2: StateDictType, w0: StateDictType, ratio: Dict[str, float]
+) -> StateDictType:
+    """
+    Merge model weights using ModelStock formula.
+
+    Args:
+        w1: First fine-tuned model weights
+        w2: Second fine-tuned model weights
+        w0: Pre-trained model weights
+        ratio: Interpolation ratios for each parameter
+
+    Returns:
+        Merged model weights
+    """
+    # Compute w12 = (w1 + w2) / 2
+    w12 = {}
+    for key in w1.keys():
+        w12[key] = (w1[key].clone() + w2[key].clone()) / 2.0
+
+    # Apply ModelStock formula: w_merge = t * w12 + (1-t) * w0
+    w_merge = copy.deepcopy(w12)
+    for key, r in ratio.items():
+        w_merge[key] = w12[key].clone() * r + w0[key].clone() * (1.0 - r)
+
+    return w_merge
+
+
+@fusion_bench.auto_register_config
+class ModelStock(SimpleProfilerMixin, BaseAlgorithm):
+    """
+    Model Stock: All we need is just a few fine-tuned models
+
+    This method merges fine-tuned models by interpolating between their average
+    and a pre-trained anchor model, with interpolation ratios determined by
+    the angle between fine-tuned models in parameter space.
+    """
+
+    def __init__(
+        self,
+        ignore_keys: Optional[List[str]] = None,
+        model_save_path: Optional[str] = None,
+        model_save_kwargs: Optional[DictConfig] = None,
+        **kwargs,
+    ):
+        """
+        Initialize ModelStock algorithm.
+
+        Args:
+            ignore_keys: Additional parameter keys to ignore during merging
+        """
+        super().__init__(**kwargs)
+        if self.ignore_keys is None:
+            self.ignore_keys = []
+        if self.model_save_kwargs is None:
+            self.model_save_kwargs = DictConfig({})
+
+    def run(self, modelpool: BaseModelPool) -> nn.Module:
+        """
+        Run the ModelStock merging algorithm.
+
+        Args:
+            modelpool: Pool of models containing pre-trained and fine-tuned models
+
+        Returns:
+            Merged model
+        """
+        with self.profile("model loading"):
+            # Load the pre-trained model (anchor)
+            pretrained_model = modelpool.load_pretrained_model()
+            if isinstance(pretrained_model, fusion_bench.LazyStateDict):
+                assert (
+                    pretrained_model.meta_module is not None
+                ), "Meta module is not initialized"
+            pretrained_state_dict = pretrained_model.state_dict()
+
+            # Load all fine-tuned models
+            finetuned_models = []
+            finetuned_state_dicts = []
+
+            for model_name in modelpool.model_names:
+                model = modelpool.load_model(model_name)
+                finetuned_models.append(model)
+                finetuned_state_dicts.append(model.state_dict())
+                log.info(f"Loaded fine-tuned model: {model_name}")
+
+        if len(finetuned_models) < 2:
+            raise ValueError("ModelStock requires at least 2 fine-tuned models")
+
+        log.info(f"Running ModelStock with {len(finetuned_models)} fine-tuned models")
+
+        with self.profile("compute angles and ratios"):
+            if len(finetuned_models) == 2:
+                # Two fine-tuned models case
+                angle_dict = compute_angle(
+                    finetuned_state_dicts[0],
+                    finetuned_state_dicts[1],
+                    pretrained_state_dict,
+                    ignore_keys=self.ignore_keys,
+                )
+                ratio_dict = compute_ratio(angle_dict, k=2)
+
+                log.info(f"Computed angles for {len(angle_dict)} parameter groups")
+
+            else:
+                # N fine-tuned models case - compute average angle
+                angles_sum = {}
+                angles_count = {}
+
+                # Compute pairwise angles and average them
+                for i in range(len(finetuned_models)):
+                    for j in range(i + 1, len(finetuned_models)):
+                        angle_dict = compute_angle(
+                            finetuned_state_dicts[i],
+                            finetuned_state_dicts[j],
+                            pretrained_state_dict,
+                            ignore_keys=self.ignore_keys,
+                        )
+
+                        for key, angle in angle_dict.items():
+                            if key not in angles_sum:
+                                angles_sum[key] = 0
+                                angles_count[key] = 0
+                            angles_sum[key] += angle
+                            angles_count[key] += 1
+
+                # Average the angles
+                avg_angle_dict = {}
+                for key in angles_sum:
+                    avg_angle_dict[key] = angles_sum[key] / angles_count[key]
+
+                ratio_dict = compute_ratio(avg_angle_dict, k=len(finetuned_models))
+
+                log.info(
+                    f"Computed average angles for {len(avg_angle_dict)} parameter groups"
+                )
+
+        with self.profile("merge weights"):
+            if len(finetuned_models) == 2:
+                # Direct merging for two models
+                merged_state_dict = merge_weights(
+                    finetuned_state_dicts[0],
+                    finetuned_state_dicts[1],
+                    pretrained_state_dict,
+                    ratio_dict,
+                )
+            else:
+                # For N models, first compute the average of fine-tuned models
+                avg_finetuned_state_dict = {}
+                for key in finetuned_state_dicts[0].keys():
+                    avg_finetuned_state_dict[key] = torch.zeros_like(
+                        finetuned_state_dicts[0][key]
+                    )
+                    for state_dict in finetuned_state_dicts:
+                        avg_finetuned_state_dict[key] += state_dict[key]
+                    avg_finetuned_state_dict[key] /= len(finetuned_state_dicts)
+
+                # Apply ModelStock formula: w_H = t * w_avg + (1-t) * w_0
+                merged_state_dict = copy.deepcopy(avg_finetuned_state_dict)
+                for key, r in ratio_dict.items():
+                    merged_state_dict[key] = avg_finetuned_state_dict[
+                        key
+                    ].clone() * r + pretrained_state_dict[key].clone() * (1.0 - r)
+
+        # Load merged weights into the model
+        if isinstance(pretrained_model, nn.Module):
+            result_model = pretrained_model
+        elif isinstance(pretrained_model, fusion_bench.LazyStateDict):
+            result_model = deepcopy(pretrained_model.meta_module)
+            result_model.to(device=pretrained_model._device)
+        result = result_model.load_state_dict(merged_state_dict, strict=False)
+
+        if result.unexpected_keys:
+            raise RuntimeError(
+                f"Unexpected keys in state dict: {result.unexpected_keys}"
+            )
+        if result.missing_keys:
+            log.warning(f"Missing keys in state dict: {result.missing_keys}")
+
+        if self.model_save_path is not None:
+            with self.profile("model saving"):
+                modelpool.save_model(
+                    model, path=self.model_save_path, **self.model_save_kwargs
+                )
+                if isinstance(model, PreTrainedModel):
+                    modelcard = create_default_model_card(
+                        models=[
+                            modelpool.get_model_path(m)
+                            for m in modelpool.all_model_names
+                        ],
+                        description="Merged model using [Model Stock](https://arxiv.org/abs/2403.19522).",
+                        algorithm_config=self.config,
+                        modelpool_config=modelpool.config,
+                    )
+                    with open(
+                        os.path.join(self.model_save_path, "README.md"), "w"
+                    ) as f:
+                        f.write(modelcard)
+
+        self.print_profile_summary()
+        log.info("ModelStock merging completed successfully")
+        return result_model
```
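For orientation, the ratio computed by `compute_ratio` above is the Model Stock closed form t = k·cos(θ) / ((k−1)·cos(θ) + 1): a small angle between task vectors (the fine-tuned models agree) pushes t toward 1 and keeps the merge near the fine-tuned average, while a large angle pushes t toward 0 and falls back to the pre-trained anchor. A minimal sketch, not part of the diff:

```python
import numpy as np

# Behavior of the Model Stock interpolation ratio for k = 2 fine-tuned models.
EPS = 1e-8
k = 2
for angle_deg in (10.0, 45.0, 90.0):
    cos_t = np.cos(np.deg2rad(angle_deg))
    t = k * cos_t / ((k - 1) * cos_t + 1 + EPS)
    print(f"angle={angle_deg:5.1f} deg -> t={t:.3f}")
# angle= 10.0 deg -> t=0.992
# angle= 45.0 deg -> t=0.828
# angle= 90.0 deg -> t=0.000
```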
fusion_bench/method/regmean/clip_regmean.py:

```diff
@@ -9,6 +9,7 @@ from torch.nn.modules import Module
 from torch.utils.data import DataLoader
 from tqdm.autonotebook import tqdm
 
+from fusion_bench import auto_register_config
 from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import CLIPClassificationMixin
 
@@ -17,17 +18,13 @@ from .regmean import RegMeanAlgorithm
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class RegMeanAlgorithmForCLIP(
-    RegMeanAlgorithm,
     CLIPClassificationMixin,
+    RegMeanAlgorithm,
 ):
-    _config_mapping = {
-        "_dataloader_kwargs": "dataloader_kwargs",
-    }
-
     def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
         super().__init__(**kwargs)
-        self.dataloader_kwargs = dataloader_kwargs
 
     def on_regmean_start(self):
         self.setup_zero_shot_classification_head()
```
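The second hunk swaps the base-class order so `CLIPClassificationMixin` now precedes `RegMeanAlgorithm`. Python resolves attributes left to right along the MRO, so the mixin listed first can override or wrap methods of the algorithm base in cooperative `super()` chains. A minimal sketch with hypothetical classes:

```python
# Illustration only: why base-class order matters for method resolution.
class Base:
    def setup(self) -> str:
        return "base"

class Mixin:
    def setup(self) -> str:
        # cooperative override: runs first, then delegates down the MRO
        return "mixin before " + super().setup()

class Combined(Mixin, Base):  # mixin listed first, as in the diff
    pass

print(Combined().setup())                       # -> "mixin before base"
print([c.__name__ for c in Combined.__mro__])   # Combined, Mixin, Base, object
```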
fusion_bench/method/regmean/regmean.py:

```diff
@@ -16,49 +16,9 @@ from fusion_bench.method import BaseAlgorithm
 from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 
-
-
-
-def get_param_names_to_merge(
-    input_param_names: List[str], exclude_param_names_regex: list
-):
-    """
-    get the names of parameters that need to be merged
-    :param input_param_names: list, names of input parameters
-    :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
-    :return:
-    """
-    param_names_to_merge = []
-    for param_name in input_param_names:
-        exclude = any(
-            [
-                re.match(exclude_pattern, param_name)
-                for exclude_pattern in exclude_param_names_regex
-            ]
-        )
-        if not exclude:
-            param_names_to_merge.append(param_name)
-    return param_names_to_merge
-
+from .utils import get_modules_to_merge, get_param_names_to_merge
 
-def get_modules_to_merge(model: nn.Module, include_module_types: list):
-    """
-    get the model modules that need to be merged, whose type is in include_module_types
-    :param model: nn.Module, input model
-    :param include_module_types: list, module types that want to include
-    :return:
-    """
-    modules_to_merge: Dict[str, nn.Module] = {}
-    for module_name, module in model.named_modules():
-        is_valid_type = not include_module_types or any(
-            [
-                isinstance(module, include_module_type)
-                for include_module_type in include_module_types
-            ]
-        )
-        if is_valid_type:
-            modules_to_merge[module_name] = module
-    return modules_to_merge
+log = logging.getLogger(__name__)
 
 
 def reduce_non_diagonal_elements(
@@ -88,12 +48,16 @@ def merging_with_regmean_weights(
 ):
     """
     merge parameters of different models with computed regmean weights
-
-
-
-
-
-
+
+    Args:
+        models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
+            value is a list of the corresponding parameters of all the models that need to be merged
+        models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
+            each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
+        reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+
+    Returns:
+        dict: merged model parameters
     """
     # dict, dictionary of model parameters
     merged_params = {}
@@ -164,13 +128,17 @@ def regmean_merging(
     reduce_non_diagonal_ratio: float = 1.0,
 ):
     """
-    regmean merging method
-
-    :
-
-
-
-
+    regmean merging method.
+
+    Args:
+        models_to_merge: list, individual models that need to be merged
+        trainers: list, trainers of individual models
+        exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
+        nums_regmean_examples: list, numbers of examples to compute regmean weights
+        reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+
+    Returns:
+        dict: merged model parameters
     """
 
     def compute_regmean_weights(module_name: str):
@@ -281,7 +249,10 @@ def regmean_merging(
 
 
 @auto_register_config
-class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
+class RegMeanAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     _include_module_type = [nn.Linear]
 
     def __init__(
```
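For context, RegMean merges each linear layer in closed form from per-model inner-product (gram) matrices of the layer inputs, and `reduce_non_diagonal_ratio` shrinks the off-diagonal entries of each gram matrix, pulling the result toward a plain average. A minimal sketch of the closed form, not the package implementation (weights are assumed here in (in_features, out_features) orientation; the package's `weight_transpose` flag handles orientation):

```python
import torch

# W_merged = (sum_i G_i)^(-1) (sum_i G_i W_i), with G_i = X_i^T X_i.
def regmean_merge_linear(grams, weights, reduce_non_diagonal_ratio=1.0):
    merged_gram = torch.zeros_like(grams[0])
    merged_gw = torch.zeros_like(weights[0])
    for G, W in zip(grams, weights):
        # scale off-diagonal entries toward zero; ratio 0 recovers simple averaging
        diag = torch.diag(torch.diagonal(G))
        G = reduce_non_diagonal_ratio * G + (1 - reduce_non_diagonal_ratio) * diag
        merged_gram += G
        merged_gw += G @ W
    return torch.linalg.solve(merged_gram, merged_gw)

# toy usage: two "models" with random activations (16 samples, in_dim=4, out_dim=3)
xs = [torch.randn(16, 4) for _ in range(2)]
ws = [torch.randn(4, 3) for _ in range(2)]
merged = regmean_merge_linear([x.T @ x for x in xs], ws, reduce_non_diagonal_ratio=0.9)
```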
fusion_bench/method/regmean/utils.py (new file):

```diff
@@ -0,0 +1,56 @@
+import re
+from typing import Dict, List
+
+from torch import nn
+
+
+def get_param_names_to_merge(
+    input_param_names: List[str], exclude_param_names_regex: list
+) -> List[str]:
+    """
+    get the names of parameters that need to be merged
+
+    Args:
+        input_param_names: list, names of input parameters
+        exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
+
+    Returns:
+        list: names of parameters that need to be merged
+    """
+    param_names_to_merge = []
+    for param_name in input_param_names:
+        exclude = any(
+            [
+                re.match(exclude_pattern, param_name)
+                for exclude_pattern in exclude_param_names_regex
+            ]
+        )
+        if not exclude:
+            param_names_to_merge.append(param_name)
+    return param_names_to_merge
+
+
+def get_modules_to_merge(
+    model: nn.Module, include_module_types: list
+) -> Dict[str, nn.Module]:
+    """
+    get the model modules that need to be merged, whose type is in include_module_types
+
+    Args:
+        model: nn.Module, input model
+        include_module_types: list, module types that want to include
+
+    Returns:
+        Dict[str, nn.Module]: a dictionary of modules to merge
+    """
+    modules_to_merge: Dict[str, nn.Module] = {}
+    for module_name, module in model.named_modules():
+        is_valid_type = not include_module_types or any(
+            [
+                isinstance(module, include_module_type)
+                for include_module_type in include_module_types
+            ]
+        )
+        if is_valid_type:
+            modules_to_merge[module_name] = module
+    return modules_to_merge
```
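An illustrative usage of the two extracted helpers:

```python
from torch import nn
from fusion_bench.method.regmean.utils import (
    get_modules_to_merge,
    get_param_names_to_merge,
)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
# keep every parameter except biases
names = get_param_names_to_merge(
    [n for n, _ in model.named_parameters()],
    exclude_param_names_regex=[r".*\.bias"],
)
# collect only the nn.Linear submodules, keyed by module name
linears = get_modules_to_merge(model, include_module_types=[nn.Linear])
print(names)          # ['0.weight', '2.weight']
print(list(linears))  # ['0', '2']
```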
fusion_bench/method/regmean_plusplus/regmean_plusplus.py:

```diff
@@ -7,55 +7,14 @@ import torch
 from torch import Tensor, nn
 from tqdm.autonotebook import tqdm
 
-
+import fusion_bench.method.regmean.utils as regmean_utils
+from fusion_bench import BaseAlgorithm, auto_register_config
 from fusion_bench.mixins import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
 
 log = logging.getLogger(__name__)
 
 
-def get_param_names_to_merge(
-    input_param_names: List[str], exclude_param_names_regex: list
-):
-    """
-    get the names of parameters that need to be merged
-    :param input_param_names: list, names of input parameters
-    :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
-    :return:
-    """
-    param_names_to_merge = []
-    for param_name in input_param_names:
-        exclude = any(
-            [
-                re.match(exclude_pattern, param_name)
-                for exclude_pattern in exclude_param_names_regex
-            ]
-        )
-        if not exclude:
-            param_names_to_merge.append(param_name)
-    return param_names_to_merge
-
-
-def get_modules_to_merge(model: nn.Module, include_module_types: list):
-    """
-    get the model modules that need to be merged, whose type is in include_module_types
-    :param model: nn.Module, input model
-    :param include_module_types: list, module types that want to include
-    :return:
-    """
-    modules_to_merge: Dict[str, nn.Module] = {}
-    for module_name, module in model.named_modules():
-        is_valid_type = not include_module_types or any(
-            [
-                isinstance(module, include_module_type)
-                for include_module_type in include_module_types
-            ]
-        )
-        if is_valid_type:
-            modules_to_merge[module_name] = module
-    return modules_to_merge
-
-
 def reduce_non_diagonal_elements(
     regmean_weights: torch.Tensor, reduce_non_diagonal_ratio: float
 ):
@@ -130,12 +89,16 @@ def merging_with_regmean_weights(
 ):
     """
     merge parameters of different models with computed regmean weights
-
-
-
-
-
-
+
+    Args:
+        models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
+            value is a list of the corresponding parameters of all the models that need to be merged
+        models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
+            each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
+        reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+
+    Returns:
+        dict: merged model parameters
     """
     # dict, dictionary of model parameters
     merged_params = {}
@@ -176,14 +139,12 @@ def merging_with_regmean_weights(
     return merged_params
 
 
-class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
+@auto_register_config
+class RegMeanAlgorithmPlusPlus(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     _include_module_type = [nn.Linear]
-    _config_mapping = {
-        "num_regmean_examples": "num_regmean_examples",
-        "exclude_param_names_regex": "exclude_param_names_regex",
-        "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
-        "weight_transpose": "weight_transpose",
-    }
 
     def __init__(
         self,
@@ -194,11 +155,11 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
         weight_transpose: bool,
         **kwargs,
     ):
+        super().__init__(**kwargs)
         self.num_regmean_examples = num_regmean_examples
         self.exclude_param_names_regex = exclude_param_names_regex
         self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
         self.weight_transpose = weight_transpose
-        super().__init__(**kwargs)
 
     def run(self, modelpool: BaseModelPool, **kwargs):
         if not isinstance(modelpool, BaseModelPool):
@@ -262,7 +223,7 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
 
         # exclude parameter whose name matches element in exclude_param_names_regex
         if param_names_to_merge is None:
-            param_names_to_merge = get_param_names_to_merge(
+            param_names_to_merge = regmean_utils.get_param_names_to_merge(
                 input_param_names=list(param_dict.keys()),
                 exclude_param_names_regex=self.config.get(
                     "exclude_param_names_regex", []
@@ -274,7 +235,7 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
                 param_dict[param_name]
             )
 
-            linear_modules_to_merge = get_modules_to_merge(
+            linear_modules_to_merge = regmean_utils.get_modules_to_merge(
                 model=layer_to_merge,
                 include_module_types=self._include_module_type,
             )
@@ -294,7 +255,7 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
                 linear_modules_to_merge=linear_modules_to_merge,
             )
 
-            module_subset = get_param_names_to_merge(
+            module_subset = regmean_utils.get_param_names_to_merge(
                 input_param_names=list(param_dict.keys()),
                 exclude_param_names_regex=self.exclude_param_names_regex,
             )
```
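The `__init__` hunk moves `super().__init__(**kwargs)` ahead of the attribute assignments. If a base initializer sets defaults for the same attributes (as config-registering bases typically do), calling it last would clobber the explicit arguments. A minimal sketch with a hypothetical base class:

```python
# Illustration only: ordering of super().__init__ vs. attribute assignment.
class ConfigBase:
    def __init__(self, **kwargs):
        # base sets a default; a real config-registering base behaves similarly
        self.weight_transpose = kwargs.get("weight_transpose", False)

class Wrong(ConfigBase):
    def __init__(self, weight_transpose: bool, **kwargs):
        self.weight_transpose = weight_transpose
        super().__init__(**kwargs)  # silently resets the attribute to False

class Right(ConfigBase):
    def __init__(self, weight_transpose: bool, **kwargs):
        super().__init__(**kwargs)
        self.weight_transpose = weight_transpose  # explicit argument sticks

print(Wrong(weight_transpose=True).weight_transpose)  # False
print(Right(weight_transpose=True).weight_transpose)  # True
```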
fusion_bench/method/simple_average.py:

```diff
@@ -61,8 +61,8 @@ def simple_average(
 
 @auto_register_config
 class SimpleAverageAlgorithm(
-    BaseAlgorithm,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
     def __init__(self, show_pbar: bool = False, **kwargs):
         """
@@ -120,13 +120,13 @@ class SimpleAverageAlgorithm(
         if isinstance(forward_model, LazyStateDict):
             # if the model is a LazyStateDict, convert it to an empty module
             forward_model = forward_model.meta_module.to_empty(
-                device=
-                "cpu"
-                if forward_model._torch_dtype is None
-                else forward_model._torch_dtype
-            )
+                device=forward_model._device
             )
-        forward_model.load_state_dict(sd)
+        result = forward_model.load_state_dict(sd, strict=False)
+        if result.unexpected_keys:
+            raise ValueError(f"Unexpected keys in state dict: {result.unexpected_keys}")
+        if result.missing_keys:
+            log.warning(f"Missing keys in state dict: {result.missing_keys}")
         # print profile report and log the merged models
         self.print_profile_summary()
         log.info(f"merged {len(merged_model_names)} models:")
```
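Background on the `load_state_dict` change: with `strict=False`, PyTorch returns the mismatched keys instead of raising, so the caller can decide which mismatches are fatal, as the new code does:

```python
import torch
from torch import nn

# load_state_dict(strict=False) returns a NamedTuple of missing_keys
# and unexpected_keys rather than raising a RuntimeError.
model = nn.Linear(4, 2)
sd = {"weight": torch.zeros(2, 4)}  # no "bias" entry
result = model.load_state_dict(sd, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # []
```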
fusion_bench/method/slerp/__init__.py:

```diff
@@ -1,2 +1,2 @@
 # flake8: noqa F401
-from .slerp import SlerpMergeAlgorithm
+from .slerp import SlerpForCausalLM, SlerpMergeAlgorithm
```
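The `slerp` module now also exports `SlerpForCausalLM`. For reference, spherical linear interpolation of flattened parameter vectors follows slerp(t; v0, v1) = [sin((1−t)ω)·v0 + sin(tω)·v1] / sin(ω), where ω is the angle between the normalized vectors. A minimal sketch, not the package implementation:

```python
import torch

def slerp(v0: torch.Tensor, v1: torch.Tensor, t: float, eps: float = 1e-8) -> torch.Tensor:
    # angle between the two directions
    v0_u = v0 / (v0.norm() + eps)
    v1_u = v1 / (v1.norm() + eps)
    omega = torch.acos(torch.clamp(torch.dot(v0_u, v1_u), -1.0, 1.0))
    if omega.abs() < eps:  # nearly parallel: fall back to linear interpolation
        return (1 - t) * v0 + t * v1
    return (torch.sin((1 - t) * omega) * v0 + torch.sin(t * omega) * v1) / torch.sin(omega)

merged = slerp(torch.randn(8), torch.randn(8), t=0.5)
```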