fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. fusion_bench/__init__.py +25 -2
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/constants/__init__.py +1 -0
  7. fusion_bench/constants/runtime.py +57 -0
  8. fusion_bench/dataset/gpt2_glue.py +1 -1
  9. fusion_bench/method/__init__.py +12 -4
  10. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  11. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  12. fusion_bench/method/bitdelta/__init__.py +1 -0
  13. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  14. fusion_bench/method/classification/clip_finetune.py +1 -1
  15. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  16. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  17. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  18. fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
  19. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
  20. fusion_bench/method/linear/simple_average_for_llama.py +16 -11
  21. fusion_bench/method/model_stock/__init__.py +1 -0
  22. fusion_bench/method/model_stock/model_stock.py +309 -0
  23. fusion_bench/method/regmean/clip_regmean.py +3 -6
  24. fusion_bench/method/regmean/regmean.py +27 -56
  25. fusion_bench/method/regmean/utils.py +56 -0
  26. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  27. fusion_bench/method/simple_average.py +7 -7
  28. fusion_bench/method/slerp/__init__.py +1 -1
  29. fusion_bench/method/slerp/slerp.py +110 -14
  30. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  31. fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
  32. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  33. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
  34. fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
  35. fusion_bench/method/we_moe/__init__.py +1 -0
  36. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  37. fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
  38. fusion_bench/method/we_moe/utils.py +15 -0
  39. fusion_bench/method/weighted_average/llama.py +1 -1
  40. fusion_bench/mixins/clip_classification.py +37 -48
  41. fusion_bench/mixins/serialization.py +30 -10
  42. fusion_bench/modelpool/base_pool.py +1 -1
  43. fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
  44. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  45. fusion_bench/models/__init__.py +5 -0
  46. fusion_bench/models/hf_utils.py +69 -86
  47. fusion_bench/models/linearized/vision_model.py +6 -6
  48. fusion_bench/models/model_card_templates/default.md +46 -0
  49. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  50. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
  51. fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
  52. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
  53. fusion_bench/models/we_moe.py +8 -8
  54. fusion_bench/programs/fabric_fusion_program.py +29 -60
  55. fusion_bench/scripts/cli.py +34 -1
  56. fusion_bench/taskpool/base_pool.py +99 -17
  57. fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
  58. fusion_bench/taskpool/dummy.py +101 -13
  59. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  60. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  61. fusion_bench/utils/__init__.py +2 -0
  62. fusion_bench/utils/cache_utils.py +101 -1
  63. fusion_bench/utils/data.py +6 -4
  64. fusion_bench/utils/devices.py +7 -4
  65. fusion_bench/utils/dtype.py +3 -2
  66. fusion_bench/utils/fabric.py +2 -2
  67. fusion_bench/utils/lazy_imports.py +23 -0
  68. fusion_bench/utils/lazy_state_dict.py +117 -19
  69. fusion_bench/utils/modelscope.py +3 -3
  70. fusion_bench/utils/packages.py +3 -3
  71. fusion_bench/utils/parameters.py +0 -2
  72. fusion_bench/utils/path.py +56 -0
  73. fusion_bench/utils/pylogger.py +1 -1
  74. fusion_bench/utils/timer.py +92 -10
  75. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
  76. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
  77. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  78. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  79. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  80. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  81. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  82. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  83. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
  84. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  85. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
  86. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
  87. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
  88. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
  89. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,309 @@
+ import copy
+ import logging
+ import math
+ import os
+ from collections import OrderedDict
+ from copy import deepcopy
+ from typing import Dict, List, Optional, Union
+
+ import numpy as np
+ import torch
+ from omegaconf import DictConfig
+ from torch import nn
+ from transformers import PreTrainedModel
+
+ import fusion_bench
+ from fusion_bench import BaseAlgorithm, BaseModelPool
+ from fusion_bench.mixins import SimpleProfilerMixin
+ from fusion_bench.models import create_default_model_card
+ from fusion_bench.utils.type import StateDictType
+
+ log = logging.getLogger(__name__)
+
+ EPS = 1e-8
+
+
+ def compute_angle(
+     state_dict_1: StateDictType,
+     state_dict_2: StateDictType,
+     ref_state_dict: StateDictType,
+     ignore_keys: List[str] = [],
+     return_cos: bool = False,
+ ) -> Dict[str, float]:
+     """
+     Compute the angle between two state dictionaries relative to a reference state dictionary.
+
+     Args:
+         state_dict_1: First state dictionary
+         state_dict_2: Second state dictionary
+         ref_state_dict: Reference state dictionary (typically pre-trained model)
+         ignore_keys: Keys to ignore during computation
+         return_cos: If True, return cosine values instead of angles in degrees
+
+     Returns:
+         Dictionary mapping parameter names to angles (in degrees) or cosine values
+     """
+     # Remove the keys not used for CLIP fine-tuning (from the notebook example)
+
+     return_dict = OrderedDict()
+
+     with torch.no_grad():
+         for key in ref_state_dict:
+             if key in ignore_keys:
+                 log.info(f"Ignoring key '{key}'")
+                 continue
+
+             state_dict_1_val = state_dict_1[key]
+             state_dict_2_val = state_dict_2[key]
+             ref_val = ref_state_dict[key]
+
+             if not (state_dict_1_val.shape == state_dict_2_val.shape == ref_val.shape):
+                 log.warning(
+                     f"Shape mismatch for key '{key}', ignored during merging: "
+                     f"({state_dict_1_val.shape}, {state_dict_2_val.shape}, {ref_val.shape})"
+                 )
+                 continue
+
+             vector1 = (state_dict_1_val - ref_val).clone().detach()
+             vector2 = (state_dict_2_val - ref_val).clone().detach()
+
+             vector1 = vector1.float()
+             vector2 = vector2.float()
+
+             cosine_val = torch.sum(vector1 * vector2) / (
+                 math.sqrt(torch.sum(vector1**2) * torch.sum(vector2**2)) + EPS
+             )
+             cosine_val = torch.clamp(
+                 cosine_val, min=-1.0, max=1.0
+             )  # Prevent nan from acos
+
+             if return_cos:
+                 return_dict[key] = cosine_val.item()
+             else:
+                 return_dict[key] = np.rad2deg(
+                     torch.acos(cosine_val).detach().cpu().item()
+                 )
+
+     return return_dict
+
+
+ def compute_ratio(angle_dict: Dict[str, float], k: int = 2) -> Dict[str, float]:
+     """
+     Compute interpolation ratios based on angles between fine-tuned models.
+
+     Args:
+         angle_dict: Dictionary mapping parameter names to angles in degrees
+         k: Number of fine-tuned models (default: 2)
+
+     Returns:
+         Dictionary mapping parameter names to interpolation ratios
+     """
+     ratio_dict = {}
+     for key in angle_dict.keys():
+         angle = np.deg2rad(angle_dict[key])
+         ratio_dict[key] = k * np.cos(angle) / ((k - 1) * np.cos(angle) + 1 + EPS)
+     return ratio_dict
+
+
+ def merge_weights(
+     w1: StateDictType, w2: StateDictType, w0: StateDictType, ratio: Dict[str, float]
+ ) -> StateDictType:
+     """
+     Merge model weights using ModelStock formula.
+
+     Args:
+         w1: First fine-tuned model weights
+         w2: Second fine-tuned model weights
+         w0: Pre-trained model weights
+         ratio: Interpolation ratios for each parameter
+
+     Returns:
+         Merged model weights
+     """
+     # Compute w12 = (w1 + w2) / 2
+     w12 = {}
+     for key in w1.keys():
+         w12[key] = (w1[key].clone() + w2[key].clone()) / 2.0
+
+     # Apply ModelStock formula: w_merge = t * w12 + (1-t) * w0
+     w_merge = copy.deepcopy(w12)
+     for key, r in ratio.items():
+         w_merge[key] = w12[key].clone() * r + w0[key].clone() * (1.0 - r)
+
+     return w_merge
+
+
+ @fusion_bench.auto_register_config
+ class ModelStock(SimpleProfilerMixin, BaseAlgorithm):
+     """
+     Model Stock: All we need is just a few fine-tuned models
+
+     This method merges fine-tuned models by interpolating between their average
+     and a pre-trained anchor model, with interpolation ratios determined by
+     the angle between fine-tuned models in parameter space.
+     """
+
+     def __init__(
+         self,
+         ignore_keys: Optional[List[str]] = None,
+         model_save_path: Optional[str] = None,
+         model_save_kwargs: Optional[DictConfig] = None,
+         **kwargs,
+     ):
+         """
+         Initialize ModelStock algorithm.
+
+         Args:
+             ignore_keys: Additional parameter keys to ignore during merging
+         """
+         super().__init__(**kwargs)
+         if self.ignore_keys is None:
+             self.ignore_keys = []
+         if self.model_save_kwargs is None:
+             self.model_save_kwargs = DictConfig({})
+
+     def run(self, modelpool: BaseModelPool) -> nn.Module:
+         """
+         Run the ModelStock merging algorithm.
+
+         Args:
+             modelpool: Pool of models containing pre-trained and fine-tuned models
+
+         Returns:
+             Merged model
+         """
+         with self.profile("model loading"):
+             # Load the pre-trained model (anchor)
+             pretrained_model = modelpool.load_pretrained_model()
+             if isinstance(pretrained_model, fusion_bench.LazyStateDict):
+                 assert (
+                     pretrained_model.meta_module is not None
+                 ), "Meta module is not initialized"
+             pretrained_state_dict = pretrained_model.state_dict()
+
+             # Load all fine-tuned models
+             finetuned_models = []
+             finetuned_state_dicts = []
+
+             for model_name in modelpool.model_names:
+                 model = modelpool.load_model(model_name)
+                 finetuned_models.append(model)
+                 finetuned_state_dicts.append(model.state_dict())
+                 log.info(f"Loaded fine-tuned model: {model_name}")
+
+         if len(finetuned_models) < 2:
+             raise ValueError("ModelStock requires at least 2 fine-tuned models")
+
+         log.info(f"Running ModelStock with {len(finetuned_models)} fine-tuned models")
+
+         with self.profile("compute angles and ratios"):
+             if len(finetuned_models) == 2:
+                 # Two fine-tuned models case
+                 angle_dict = compute_angle(
+                     finetuned_state_dicts[0],
+                     finetuned_state_dicts[1],
+                     pretrained_state_dict,
+                     ignore_keys=self.ignore_keys,
+                 )
+                 ratio_dict = compute_ratio(angle_dict, k=2)
+
+                 log.info(f"Computed angles for {len(angle_dict)} parameter groups")
+
+             else:
+                 # N fine-tuned models case - compute average angle
+                 angles_sum = {}
+                 angles_count = {}
+
+                 # Compute pairwise angles and average them
+                 for i in range(len(finetuned_models)):
+                     for j in range(i + 1, len(finetuned_models)):
+                         angle_dict = compute_angle(
+                             finetuned_state_dicts[i],
+                             finetuned_state_dicts[j],
+                             pretrained_state_dict,
+                             ignore_keys=self.ignore_keys,
+                         )
+
+                         for key, angle in angle_dict.items():
+                             if key not in angles_sum:
+                                 angles_sum[key] = 0
+                                 angles_count[key] = 0
+                             angles_sum[key] += angle
+                             angles_count[key] += 1
+
+                 # Average the angles
+                 avg_angle_dict = {}
+                 for key in angles_sum:
+                     avg_angle_dict[key] = angles_sum[key] / angles_count[key]
+
+                 ratio_dict = compute_ratio(avg_angle_dict, k=len(finetuned_models))
+
+                 log.info(
+                     f"Computed average angles for {len(avg_angle_dict)} parameter groups"
+                 )
+
+         with self.profile("merge weights"):
+             if len(finetuned_models) == 2:
+                 # Direct merging for two models
+                 merged_state_dict = merge_weights(
+                     finetuned_state_dicts[0],
+                     finetuned_state_dicts[1],
+                     pretrained_state_dict,
+                     ratio_dict,
+                 )
+             else:
+                 # For N models, first compute the average of fine-tuned models
+                 avg_finetuned_state_dict = {}
+                 for key in finetuned_state_dicts[0].keys():
+                     avg_finetuned_state_dict[key] = torch.zeros_like(
+                         finetuned_state_dicts[0][key]
+                     )
+                     for state_dict in finetuned_state_dicts:
+                         avg_finetuned_state_dict[key] += state_dict[key]
+                     avg_finetuned_state_dict[key] /= len(finetuned_state_dicts)
+
+                 # Apply ModelStock formula: w_H = t * w_avg + (1-t) * w_0
+                 merged_state_dict = copy.deepcopy(avg_finetuned_state_dict)
+                 for key, r in ratio_dict.items():
+                     merged_state_dict[key] = avg_finetuned_state_dict[
+                         key
+                     ].clone() * r + pretrained_state_dict[key].clone() * (1.0 - r)
+
+         # Load merged weights into the model
+         if isinstance(pretrained_model, nn.Module):
+             result_model = pretrained_model
+         elif isinstance(pretrained_model, fusion_bench.LazyStateDict):
+             result_model = deepcopy(pretrained_model.meta_module)
+             result_model.to(device=pretrained_model._device)
+         result = result_model.load_state_dict(merged_state_dict, strict=False)
+
+         if result.unexpected_keys:
+             raise RuntimeError(
+                 f"Unexpected keys in state dict: {result.unexpected_keys}"
+             )
+         if result.missing_keys:
+             log.warning(f"Missing keys in state dict: {result.missing_keys}")
+
+         if self.model_save_path is not None:
+             with self.profile("model saving"):
+                 modelpool.save_model(
+                     result_model, path=self.model_save_path, **self.model_save_kwargs
+                 )
+                 if isinstance(result_model, PreTrainedModel):
+                     modelcard = create_default_model_card(
+                         models=[
+                             modelpool.get_model_path(m)
+                             for m in modelpool.all_model_names
+                         ],
+                         description="Merged model using [Model Stock](https://arxiv.org/abs/2403.19522).",
+                         algorithm_config=self.config,
+                         modelpool_config=modelpool.config,
+                     )
+                     with open(
+                         os.path.join(self.model_save_path, "README.md"), "w"
+                     ) as f:
+                         f.write(modelcard)
+
+         self.print_profile_summary()
+         log.info("ModelStock merging completed successfully")
+         return result_model
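
Note: the interpolation ratio in `compute_ratio` above is the Model Stock formula t = k·cos(θ) / ((k − 1)·cos(θ) + 1), so small angles between task vectors keep the fine-tuned average while near-orthogonal ones pull the merge back toward the pre-trained anchor. A minimal standalone sketch of that formula, with made-up angles for illustration:

```python
# Standalone sketch of compute_ratio's formula; the angles are illustrative.
import numpy as np

EPS = 1e-8  # matches the EPS constant in model_stock.py

def ratio(angle_deg: float, k: int = 2) -> float:
    cos = np.cos(np.deg2rad(angle_deg))
    return k * cos / ((k - 1) * cos + 1 + EPS)

print(ratio(60.0))  # 2*0.5 / (0.5 + 1) ~= 0.667: lean toward the fine-tuned average
print(ratio(89.0))  # ~= 0.034: nearly orthogonal task vectors fall back to the anchor
```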
@@ -9,6 +9,7 @@ from torch.nn.modules import Module
  from torch.utils.data import DataLoader
  from tqdm.autonotebook import tqdm
 
+ from fusion_bench import auto_register_config
  from fusion_bench.dataset.clip_dataset import CLIPDataset
  from fusion_bench.mixins import CLIPClassificationMixin
 
@@ -17,17 +18,13 @@ from .regmean import RegMeanAlgorithm
  log = logging.getLogger(__name__)
 
 
+ @auto_register_config
  class RegMeanAlgorithmForCLIP(
-     RegMeanAlgorithm,
      CLIPClassificationMixin,
+     RegMeanAlgorithm,
  ):
-     _config_mapping = {
-         "_dataloader_kwargs": "dataloader_kwargs",
-     }
-
      def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
          super().__init__(**kwargs)
-         self.dataloader_kwargs = dataloader_kwargs
 
      def on_regmean_start(self):
          self.setup_zero_shot_classification_head()
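
Note: the hunk above puts `CLIPClassificationMixin` ahead of `RegMeanAlgorithm` in the base-class list, which changes Python's method resolution order, and `@auto_register_config` replaces the hand-written `_config_mapping` and attribute assignment. A toy sketch of why mixin-first ordering matters with cooperative `super().__init__()` (hypothetical classes, not fusion_bench's actual bases):

```python
# Toy classes illustrating Python's MRO with cooperative __init__;
# hypothetical, for illustration only.
class Base:
    def __init__(self, **kwargs):
        print("Base.__init__")

class Mixin:
    def __init__(self, **kwargs):
        print("Mixin.__init__")
        super().__init__(**kwargs)  # forwards along the MRO to Base

class Algo(Mixin, Base):  # mixin listed first, as in RegMeanAlgorithmForCLIP
    pass

Algo()  # prints Mixin.__init__, then Base.__init__
print([c.__name__ for c in Algo.__mro__])  # ['Algo', 'Mixin', 'Base', 'object']
```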
@@ -16,49 +16,9 @@ from fusion_bench.method import BaseAlgorithm
  from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
  from fusion_bench.modelpool import BaseModelPool
 
- log = logging.getLogger(__name__)
-
-
- def get_param_names_to_merge(
-     input_param_names: List[str], exclude_param_names_regex: list
- ):
-     """
-     get the names of parameters that need to be merged
-     :param input_param_names: list, names of input parameters
-     :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
-     :return:
-     """
-     param_names_to_merge = []
-     for param_name in input_param_names:
-         exclude = any(
-             [
-                 re.match(exclude_pattern, param_name)
-                 for exclude_pattern in exclude_param_names_regex
-             ]
-         )
-         if not exclude:
-             param_names_to_merge.append(param_name)
-     return param_names_to_merge
-
+ from .utils import get_modules_to_merge, get_param_names_to_merge
 
- def get_modules_to_merge(model: nn.Module, include_module_types: list):
-     """
-     get the model modules that need to be merged, whose type is in include_module_types
-     :param model: nn.Module, input model
-     :param include_module_types: list, module types that want to include
-     :return:
-     """
-     modules_to_merge: Dict[str, nn.Module] = {}
-     for module_name, module in model.named_modules():
-         is_valid_type = not include_module_types or any(
-             [
-                 isinstance(module, include_module_type)
-                 for include_module_type in include_module_types
-             ]
-         )
-         if is_valid_type:
-             modules_to_merge[module_name] = module
-     return modules_to_merge
+ log = logging.getLogger(__name__)
 
 
  def reduce_non_diagonal_elements(
@@ -88,12 +48,16 @@ def merging_with_regmean_weights(
  ):
      """
      merge parameters of different models with computed regmean weights
-     :param models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
-         value is a list of the corresponding parameters of all the models that need to be merged
-     :param models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
-         each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
-     :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
-     :return:
+
+     Args:
+         models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
+             value is a list of the corresponding parameters of all the models that need to be merged
+         models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
+             each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
+         reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+
+     Returns:
+         dict: merged model parameters
      """
      # dict, dictionary of model parameters
      merged_params = {}
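
Note: for reference, the closed form these functions implement (RegMean, Jin et al., ICLR 2023) merges each linear layer as W* = (Σᵢ Gᵢ)⁻¹ (Σᵢ Gᵢ Wᵢ), where Gᵢ = XᵢᵀXᵢ is the Gram matrix of model i's input activations and the non-diagonal entries of each Gᵢ may be scaled down. A minimal numeric sketch under that reading, with made-up shapes and data and weights in the (in, out) orientation:

```python
# Sketch of the RegMean closed form; data and shapes are illustrative.
import torch

def regmean_merge(weights, grams, reduce_non_diagonal_ratio=1.0):
    """W* = (sum_i G_i)^(-1) (sum_i G_i W_i) with optional off-diagonal shrinkage."""
    sum_gw = torch.zeros_like(weights[0])
    sum_g = torch.zeros_like(grams[0])
    for w, g in zip(weights, grams):
        # shrink non-diagonal entries, as reduce_non_diagonal_elements does
        diag = torch.diag(torch.diagonal(g))
        g = (g - diag) * reduce_non_diagonal_ratio + diag
        sum_gw += g @ w
        sum_g += g
    return torch.linalg.solve(sum_g, sum_gw)

torch.manual_seed(0)
xs = [torch.randn(32, 4) for _ in range(2)]  # per-model input activations
ws = [torch.randn(4, 8) for _ in range(2)]   # per-model (in, out) weights
merged = regmean_merge(ws, [x.T @ x for x in xs], reduce_non_diagonal_ratio=0.9)
print(merged.shape)  # torch.Size([4, 8])
```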
@@ -164,13 +128,17 @@ def regmean_merging(
      reduce_non_diagonal_ratio: float = 1.0,
  ):
      """
-     regmean merging method
-     :param models_to_merge: list, individual models that need to be merged
-     :param trainers: list, trainers of individual models
-     :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
-     :param nums_regmean_examples: list, numbers of examples to compute regmean weights
-     :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
-     :return:
+     regmean merging method.
+
+     Args:
+         models_to_merge: list, individual models that need to be merged
+         trainers: list, trainers of individual models
+         exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
+         nums_regmean_examples: list, numbers of examples to compute regmean weights
+         reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+
+     Returns:
+         dict: merged model parameters
      """
 
      def compute_regmean_weights(module_name: str):
@@ -281,7 +249,10 @@ def regmean_merging(
 
 
  @auto_register_config
- class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
+ class RegMeanAlgorithm(
+     SimpleProfilerMixin,
+     BaseAlgorithm,
+ ):
      _include_module_type = [nn.Linear]
 
      def __init__(
@@ -0,0 +1,56 @@
+ import re
+ from typing import Dict, List
+
+ from torch import nn
+
+
+ def get_param_names_to_merge(
+     input_param_names: List[str], exclude_param_names_regex: list
+ ) -> List[str]:
+     """
+     get the names of parameters that need to be merged
+
+     Args:
+         input_param_names: list, names of input parameters
+         exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
+
+     Returns:
+         list: names of parameters that need to be merged
+     """
+     param_names_to_merge = []
+     for param_name in input_param_names:
+         exclude = any(
+             [
+                 re.match(exclude_pattern, param_name)
+                 for exclude_pattern in exclude_param_names_regex
+             ]
+         )
+         if not exclude:
+             param_names_to_merge.append(param_name)
+     return param_names_to_merge
+
+
+ def get_modules_to_merge(
+     model: nn.Module, include_module_types: list
+ ) -> Dict[str, nn.Module]:
+     """
+     get the model modules that need to be merged, whose type is in include_module_types
+
+     Args:
+         model: nn.Module, input model
+         include_module_types: list, module types that want to include
+
+     Returns:
+         Dict[str, nn.Module]: a dictionary of modules to merge
+     """
+     modules_to_merge: Dict[str, nn.Module] = {}
+     for module_name, module in model.named_modules():
+         is_valid_type = not include_module_types or any(
+             [
+                 isinstance(module, include_module_type)
+                 for include_module_type in include_module_types
+             ]
+         )
+         if is_valid_type:
+             modules_to_merge[module_name] = module
+     return modules_to_merge
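
Note: a short usage sketch of the two helpers above, importable from `fusion_bench.method.regmean.utils` as of this release; the toy model is illustrative:

```python
from torch import nn

from fusion_bench.method.regmean.utils import (
    get_modules_to_merge,
    get_param_names_to_merge,
)

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Keep every parameter except biases.
names = get_param_names_to_merge(
    input_param_names=[n for n, _ in model.named_parameters()],
    exclude_param_names_regex=[r".*\.bias"],
)
print(names)  # ['0.weight', '2.weight']

# Collect only the nn.Linear submodules, as RegMean does.
linears = get_modules_to_merge(model, include_module_types=[nn.Linear])
print(list(linears.keys()))  # ['0', '2']
```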
@@ -7,55 +7,14 @@ import torch
  from torch import Tensor, nn
  from tqdm.autonotebook import tqdm
 
- from fusion_bench.method import BaseAlgorithm
+ import fusion_bench.method.regmean.utils as regmean_utils
+ from fusion_bench import BaseAlgorithm, auto_register_config
  from fusion_bench.mixins import SimpleProfilerMixin
  from fusion_bench.modelpool import BaseModelPool
 
  log = logging.getLogger(__name__)
 
 
- def get_param_names_to_merge(
-     input_param_names: List[str], exclude_param_names_regex: list
- ):
-     """
-     get the names of parameters that need to be merged
-     :param input_param_names: list, names of input parameters
-     :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
-     :return:
-     """
-     param_names_to_merge = []
-     for param_name in input_param_names:
-         exclude = any(
-             [
-                 re.match(exclude_pattern, param_name)
-                 for exclude_pattern in exclude_param_names_regex
-             ]
-         )
-         if not exclude:
-             param_names_to_merge.append(param_name)
-     return param_names_to_merge
-
-
- def get_modules_to_merge(model: nn.Module, include_module_types: list):
-     """
-     get the model modules that need to be merged, whose type is in include_module_types
-     :param model: nn.Module, input model
-     :param include_module_types: list, module types that want to include
-     :return:
-     """
-     modules_to_merge: Dict[str, nn.Module] = {}
-     for module_name, module in model.named_modules():
-         is_valid_type = not include_module_types or any(
-             [
-                 isinstance(module, include_module_type)
-                 for include_module_type in include_module_types
-             ]
-         )
-         if is_valid_type:
-             modules_to_merge[module_name] = module
-     return modules_to_merge
-
-
  def reduce_non_diagonal_elements(
      regmean_weights: torch.Tensor, reduce_non_diagonal_ratio: float
  ):
@@ -130,12 +89,16 @@ def merging_with_regmean_weights(
  ):
      """
      merge parameters of different models with computed regmean weights
-     :param models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
-         value is a list of the corresponding parameters of all the models that need to be merged
-     :param models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
-         each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
-     :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
-     :return:
+
+     Args:
+         models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
+             value is a list of the corresponding parameters of all the models that need to be merged
+         models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
+             each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
+         reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+
+     Returns:
+         dict: merged model parameters
      """
      # dict, dictionary of model parameters
      merged_params = {}
@@ -176,14 +139,12 @@ def merging_with_regmean_weights(
      return merged_params
 
 
- class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
+ @auto_register_config
+ class RegMeanAlgorithmPlusPlus(
+     SimpleProfilerMixin,
+     BaseAlgorithm,
+ ):
      _include_module_type = [nn.Linear]
-     _config_mapping = {
-         "num_regmean_examples": "num_regmean_examples",
-         "exclude_param_names_regex": "exclude_param_names_regex",
-         "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
-         "weight_transpose": "weight_transpose",
-     }
 
      def __init__(
          self,
@@ -194,11 +155,11 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
          weight_transpose: bool,
          **kwargs,
      ):
+         super().__init__(**kwargs)
          self.num_regmean_examples = num_regmean_examples
          self.exclude_param_names_regex = exclude_param_names_regex
          self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
          self.weight_transpose = weight_transpose
-         super().__init__(**kwargs)
 
      def run(self, modelpool: BaseModelPool, **kwargs):
          if not isinstance(modelpool, BaseModelPool):
@@ -262,7 +223,7 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
 
          # exclude parameter whose name matches element in exclude_param_names_regex
          if param_names_to_merge is None:
-             param_names_to_merge = get_param_names_to_merge(
+             param_names_to_merge = regmean_utils.get_param_names_to_merge(
                  input_param_names=list(param_dict.keys()),
                  exclude_param_names_regex=self.config.get(
                      "exclude_param_names_regex", []
@@ -274,7 +235,7 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
                  param_dict[param_name]
              )
 
-         linear_modules_to_merge = get_modules_to_merge(
+         linear_modules_to_merge = regmean_utils.get_modules_to_merge(
              model=layer_to_merge,
              include_module_types=self._include_module_type,
          )
@@ -294,7 +255,7 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
                  linear_modules_to_merge=linear_modules_to_merge,
              )
 
-             module_subset = get_param_names_to_merge(
+             module_subset = regmean_utils.get_param_names_to_merge(
                  input_param_names=list(param_dict.keys()),
                  exclude_param_names_regex=self.exclude_param_names_regex,
              )
@@ -61,8 +61,8 @@ def simple_average(
 
  @auto_register_config
  class SimpleAverageAlgorithm(
-     BaseAlgorithm,
      SimpleProfilerMixin,
+     BaseAlgorithm,
  ):
      def __init__(self, show_pbar: bool = False, **kwargs):
          """
@@ -120,13 +120,13 @@ class SimpleAverageAlgorithm(
          if isinstance(forward_model, LazyStateDict):
              # if the model is a LazyStateDict, convert it to an empty module
              forward_model = forward_model.meta_module.to_empty(
-                 device=(
-                     "cpu"
-                     if forward_model._torch_dtype is None
-                     else forward_model._torch_dtype
-                 )
+                 device=forward_model._device
              )
-             forward_model.load_state_dict(sd)
+             result = forward_model.load_state_dict(sd, strict=False)
+             if result.unexpected_keys:
+                 raise ValueError(f"Unexpected keys in state dict: {result.unexpected_keys}")
+             if result.missing_keys:
+                 log.warning(f"Missing keys in state dict: {result.missing_keys}")
          # print profile report and log the merged models
          self.print_profile_summary()
          log.info(f"merged {len(merged_model_names)} models:")
@@ -1,2 +1,2 @@
  # flake8: noqa F401
- from .slerp import SlerpMergeAlgorithm
+ from .slerp import SlerpForCausalLM, SlerpMergeAlgorithm
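
Note: the newly exported `SlerpForCausalLM` builds on spherical linear interpolation, slerp(w1, w2; t) = sin((1 − t)Ω)/sin(Ω) · w1 + sin(tΩ)/sin(Ω) · w2, where Ω is the angle between the flattened weight vectors. A generic sketch of that math (not the library's actual API):

```python
# Generic SLERP between two flattened weight tensors; illustrative only.
import torch

def slerp(w1: torch.Tensor, w2: torch.Tensor, t: float, eps: float = 1e-8) -> torch.Tensor:
    cos = torch.sum(w1 * w2) / (w1.norm() * w2.norm() + eps)
    omega = torch.acos(torch.clamp(cos, -1.0, 1.0))  # angle between the vectors
    so = torch.sin(omega)
    if so.abs() < eps:  # nearly parallel: fall back to linear interpolation
        return (1.0 - t) * w1 + t * w2
    return torch.sin((1.0 - t) * omega) / so * w1 + torch.sin(t * omega) / so * w2

a, b = torch.randn(10), torch.randn(10)
print(slerp(a, b, 0.5).shape)  # torch.Size([10])
```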