fusion-bench 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. fusion_bench/__init__.py +4 -0
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/dataset/gpt2_glue.py +1 -1
  7. fusion_bench/method/__init__.py +12 -2
  8. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  9. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  10. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  11. fusion_bench/method/ensemble.py +17 -2
  12. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  13. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  14. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  15. fusion_bench/method/linear/__init__.py +6 -2
  16. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  17. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  18. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  19. fusion_bench/method/model_stock/__init__.py +1 -0
  20. fusion_bench/method/model_stock/model_stock.py +309 -0
  21. fusion_bench/method/regmean/clip_regmean.py +3 -6
  22. fusion_bench/method/regmean/regmean.py +27 -56
  23. fusion_bench/method/regmean/utils.py +56 -0
  24. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  25. fusion_bench/method/simple_average.py +2 -2
  26. fusion_bench/method/slerp/__init__.py +1 -1
  27. fusion_bench/method/slerp/slerp.py +110 -14
  28. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  29. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  30. fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
  31. fusion_bench/method/wudi/__init__.py +1 -0
  32. fusion_bench/method/wudi/wudi.py +105 -0
  33. fusion_bench/mixins/clip_classification.py +26 -6
  34. fusion_bench/mixins/lightning_fabric.py +4 -0
  35. fusion_bench/mixins/serialization.py +40 -83
  36. fusion_bench/modelpool/base_pool.py +1 -1
  37. fusion_bench/modelpool/causal_lm/causal_lm.py +285 -44
  38. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  39. fusion_bench/models/hf_clip.py +4 -0
  40. fusion_bench/models/hf_utils.py +10 -4
  41. fusion_bench/models/linearized/vision_model.py +6 -6
  42. fusion_bench/models/model_card_templates/default.md +8 -1
  43. fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
  44. fusion_bench/models/we_moe.py +8 -8
  45. fusion_bench/models/wrappers/ensemble.py +136 -7
  46. fusion_bench/scripts/cli.py +2 -2
  47. fusion_bench/taskpool/base_pool.py +99 -17
  48. fusion_bench/taskpool/clip_vision/taskpool.py +12 -5
  49. fusion_bench/taskpool/dummy.py +101 -13
  50. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  51. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  52. fusion_bench/utils/__init__.py +1 -0
  53. fusion_bench/utils/data.py +6 -4
  54. fusion_bench/utils/devices.py +36 -11
  55. fusion_bench/utils/dtype.py +3 -2
  56. fusion_bench/utils/lazy_state_dict.py +85 -19
  57. fusion_bench/utils/packages.py +3 -3
  58. fusion_bench/utils/parameters.py +0 -2
  59. fusion_bench/utils/rich_utils.py +7 -3
  60. fusion_bench/utils/timer.py +92 -10
  61. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
  62. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +77 -64
  63. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  64. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  65. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  66. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  67. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  68. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  69. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  70. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  71. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  72. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  73. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  74. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  75. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
  76. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
  77. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
  78. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
@@ -16,49 +16,9 @@ from fusion_bench.method import BaseAlgorithm
16
16
  from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
17
17
  from fusion_bench.modelpool import BaseModelPool
18
18
 
19
- log = logging.getLogger(__name__)
20
-
21
-
22
- def get_param_names_to_merge(
23
- input_param_names: List[str], exclude_param_names_regex: list
24
- ):
25
- """
26
- get the names of parameters that need to be merged
27
- :param input_param_names: list, names of input parameters
28
- :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
29
- :return:
30
- """
31
- param_names_to_merge = []
32
- for param_name in input_param_names:
33
- exclude = any(
34
- [
35
- re.match(exclude_pattern, param_name)
36
- for exclude_pattern in exclude_param_names_regex
37
- ]
38
- )
39
- if not exclude:
40
- param_names_to_merge.append(param_name)
41
- return param_names_to_merge
42
-
19
+ from .utils import get_modules_to_merge, get_param_names_to_merge
43
20
 
44
- def get_modules_to_merge(model: nn.Module, include_module_types: list):
45
- """
46
- get the model modules that need to be merged, whose type is in include_module_types
47
- :param model: nn.Module, input model
48
- :param include_module_types: list, module types that want to include
49
- :return:
50
- """
51
- modules_to_merge: Dict[str, nn.Module] = {}
52
- for module_name, module in model.named_modules():
53
- is_valid_type = not include_module_types or any(
54
- [
55
- isinstance(module, include_module_type)
56
- for include_module_type in include_module_types
57
- ]
58
- )
59
- if is_valid_type:
60
- modules_to_merge[module_name] = module
61
- return modules_to_merge
21
+ log = logging.getLogger(__name__)
62
22
 
63
23
 
64
24
  def reduce_non_diagonal_elements(
@@ -88,12 +48,16 @@ def merging_with_regmean_weights(
88
48
  ):
89
49
  """
90
50
  merge parameters of different models with computed regmean weights
91
- :param models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
92
- value is a list of the corresponding parameters of all the models that need to be merged
93
- :param models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
94
- each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
95
- :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
96
- :return:
51
+
52
+ Args:
53
+ models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
54
+ value is a list of the corresponding parameters of all the models that need to be merged
55
+ models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
56
+ each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
57
+ reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
58
+
59
+ Returns:
60
+ dict: merged model parameters
97
61
  """
98
62
  # dict, dictionary of model parameters
99
63
  merged_params = {}
@@ -164,13 +128,17 @@ def regmean_merging(
164
128
  reduce_non_diagonal_ratio: float = 1.0,
165
129
  ):
166
130
  """
167
- regmean merging method
168
- :param models_to_merge: list, individual models that need to be merged
169
- :param trainers: list, trainers of individual models
170
- :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
171
- :param nums_regmean_examples: list, numbers of examples to compute regmean weights
172
- :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
173
- :return:
131
+ regmean merging method.
132
+
133
+ Args:
134
+ models_to_merge: list, individual models that need to be merged
135
+ trainers: list, trainers of individual models
136
+ exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
137
+ nums_regmean_examples: list, numbers of examples to compute regmean weights
138
+ reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
139
+
140
+ Returns:
141
+ dict: merged model parameters
174
142
  """
175
143
 
176
144
  def compute_regmean_weights(module_name: str):
@@ -281,7 +249,10 @@ def regmean_merging(
281
249
 
282
250
 
283
251
  @auto_register_config
284
- class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
252
+ class RegMeanAlgorithm(
253
+ SimpleProfilerMixin,
254
+ BaseAlgorithm,
255
+ ):
285
256
  _include_module_type = [nn.Linear]
286
257
 
287
258
  def __init__(
@@ -0,0 +1,56 @@
1
+ import re
2
+ from typing import Dict, List
3
+
4
+ from torch import nn
5
+
6
+
7
+ def get_param_names_to_merge(
8
+ input_param_names: List[str], exclude_param_names_regex: list
9
+ ) -> List[str]:
10
+ """
11
+ get the names of parameters that need to be merged
12
+
13
+ Args:
14
+ input_param_names: list, names of input parameters
15
+ exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
16
+
17
+ Returns:
18
+ list: names of parameters that need to be merged
19
+ """
20
+ param_names_to_merge = []
21
+ for param_name in input_param_names:
22
+ exclude = any(
23
+ [
24
+ re.match(exclude_pattern, param_name)
25
+ for exclude_pattern in exclude_param_names_regex
26
+ ]
27
+ )
28
+ if not exclude:
29
+ param_names_to_merge.append(param_name)
30
+ return param_names_to_merge
31
+
32
+
33
+ def get_modules_to_merge(
34
+ model: nn.Module, include_module_types: list
35
+ ) -> Dict[str, nn.Module]:
36
+ """
37
+ get the model modules that need to be merged, whose type is in include_module_types
38
+
39
+ Args:
40
+ model: nn.Module, input model
41
+ include_module_types: list, module types that want to include
42
+
43
+ Returns:
44
+ Dict[str, nn.Module]: a dictionary of modules to merge
45
+ """
46
+ modules_to_merge: Dict[str, nn.Module] = {}
47
+ for module_name, module in model.named_modules():
48
+ is_valid_type = not include_module_types or any(
49
+ [
50
+ isinstance(module, include_module_type)
51
+ for include_module_type in include_module_types
52
+ ]
53
+ )
54
+ if is_valid_type:
55
+ modules_to_merge[module_name] = module
56
+ return modules_to_merge
@@ -7,55 +7,14 @@ import torch
7
7
  from torch import Tensor, nn
8
8
  from tqdm.autonotebook import tqdm
9
9
 
10
- from fusion_bench.method import BaseAlgorithm
10
+ import fusion_bench.method.regmean.utils as regmean_utils
11
+ from fusion_bench import BaseAlgorithm, auto_register_config
11
12
  from fusion_bench.mixins import SimpleProfilerMixin
12
13
  from fusion_bench.modelpool import BaseModelPool
13
14
 
14
15
  log = logging.getLogger(__name__)
15
16
 
16
17
 
17
- def get_param_names_to_merge(
18
- input_param_names: List[str], exclude_param_names_regex: list
19
- ):
20
- """
21
- get the names of parameters that need to be merged
22
- :param input_param_names: list, names of input parameters
23
- :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
24
- :return:
25
- """
26
- param_names_to_merge = []
27
- for param_name in input_param_names:
28
- exclude = any(
29
- [
30
- re.match(exclude_pattern, param_name)
31
- for exclude_pattern in exclude_param_names_regex
32
- ]
33
- )
34
- if not exclude:
35
- param_names_to_merge.append(param_name)
36
- return param_names_to_merge
37
-
38
-
39
- def get_modules_to_merge(model: nn.Module, include_module_types: list):
40
- """
41
- get the model modules that need to be merged, whose type is in include_module_types
42
- :param model: nn.Module, input model
43
- :param include_module_types: list, module types that want to include
44
- :return:
45
- """
46
- modules_to_merge: Dict[str, nn.Module] = {}
47
- for module_name, module in model.named_modules():
48
- is_valid_type = not include_module_types or any(
49
- [
50
- isinstance(module, include_module_type)
51
- for include_module_type in include_module_types
52
- ]
53
- )
54
- if is_valid_type:
55
- modules_to_merge[module_name] = module
56
- return modules_to_merge
57
-
58
-
59
18
  def reduce_non_diagonal_elements(
60
19
  regmean_weights: torch.Tensor, reduce_non_diagonal_ratio: float
61
20
  ):
@@ -130,12 +89,16 @@ def merging_with_regmean_weights(
130
89
  ):
131
90
  """
132
91
  merge parameters of different models with computed regmean weights
133
- :param models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
134
- value is a list of the corresponding parameters of all the models that need to be merged
135
- :param models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
136
- each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
137
- :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
138
- :return:
92
+
93
+ Asrgs:
94
+ models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
95
+ value is a list of the corresponding parameters of all the models that need to be merged
96
+ models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
97
+ each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
98
+ reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
99
+
100
+ Returns:
101
+ dict: merged model parameters
139
102
  """
140
103
  # dict, dictionary of model parameters
141
104
  merged_params = {}
@@ -176,14 +139,12 @@ def merging_with_regmean_weights(
176
139
  return merged_params
177
140
 
178
141
 
179
- class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
142
+ @auto_register_config
143
+ class RegMeanAlgorithmPlusPlus(
144
+ SimpleProfilerMixin,
145
+ BaseAlgorithm,
146
+ ):
180
147
  _include_module_type = [nn.Linear]
181
- _config_mapping = {
182
- "num_regmean_examples": "num_regmean_examples",
183
- "exclude_param_names_regex": "exclude_param_names_regex",
184
- "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
185
- "weight_transpose": "weight_transpose",
186
- }
187
148
 
188
149
  def __init__(
189
150
  self,
@@ -194,11 +155,11 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
194
155
  weight_transpose: bool,
195
156
  **kwargs,
196
157
  ):
158
+ super().__init__(**kwargs)
197
159
  self.num_regmean_examples = num_regmean_examples
198
160
  self.exclude_param_names_regex = exclude_param_names_regex
199
161
  self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
200
162
  self.weight_transpose = weight_transpose
201
- super().__init__(**kwargs)
202
163
 
203
164
  def run(self, modelpool: BaseModelPool, **kwargs):
204
165
  if not isinstance(modelpool, BaseModelPool):
@@ -262,7 +223,7 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
262
223
 
263
224
  # exclude parameter whose name matches element in exclude_param_names_regex
264
225
  if param_names_to_merge is None:
265
- param_names_to_merge = get_param_names_to_merge(
226
+ param_names_to_merge = regmean_utils.get_param_names_to_merge(
266
227
  input_param_names=list(param_dict.keys()),
267
228
  exclude_param_names_regex=self.config.get(
268
229
  "exclude_param_names_regex", []
@@ -274,7 +235,7 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
274
235
  param_dict[param_name]
275
236
  )
276
237
 
277
- linear_modules_to_merge = get_modules_to_merge(
238
+ linear_modules_to_merge = regmean_utils.get_modules_to_merge(
278
239
  model=layer_to_merge,
279
240
  include_module_types=self._include_module_type,
280
241
  )
@@ -294,7 +255,7 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
294
255
  linear_modules_to_merge=linear_modules_to_merge,
295
256
  )
296
257
 
297
- module_subset = get_param_names_to_merge(
258
+ module_subset = regmean_utils.get_param_names_to_merge(
298
259
  input_param_names=list(param_dict.keys()),
299
260
  exclude_param_names_regex=self.exclude_param_names_regex,
300
261
  )
@@ -89,7 +89,7 @@ class SimpleAverageAlgorithm(
89
89
  modelpool = BaseModelPool(modelpool)
90
90
 
91
91
  log.info(
92
- f"Fusing models using simple average on {len(modelpool.model_names)} models."
92
+ f"Fusing models using simple average on {len(modelpool.model_names)} models. "
93
93
  f"models: {modelpool.model_names}"
94
94
  )
95
95
  sd: Optional[StateDictType] = None
@@ -119,7 +119,7 @@ class SimpleAverageAlgorithm(
119
119
 
120
120
  if isinstance(forward_model, LazyStateDict):
121
121
  # if the model is a LazyStateDict, convert it to an empty module
122
- forward_model = forward_model.meta_module.to_empty(
122
+ forward_model = deepcopy(forward_model.meta_module).to_empty(
123
123
  device=forward_model._device
124
124
  )
125
125
  result = forward_model.load_state_dict(sd, strict=False)
@@ -1,2 +1,2 @@
1
1
  # flake8: noqa F401
2
- from .slerp import SlerpMergeAlgorithm
2
+ from .slerp import SlerpForCausalLM, SlerpMergeAlgorithm
@@ -1,16 +1,24 @@
1
1
  import logging
2
- from typing import Any, Dict
2
+ import os
3
+ from copy import deepcopy
4
+ from typing import TYPE_CHECKING, Any, Dict, Optional
3
5
 
4
6
  import torch
5
7
  from torch import nn
8
+ from tqdm import tqdm
6
9
  from typing_extensions import override
7
10
 
11
+ from fusion_bench import LazyStateDict, create_default_model_card, timeit_context
8
12
  from fusion_bench.method import BaseAlgorithm
9
- from fusion_bench.modelpool import BaseModelPool
13
+ from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
14
+ from fusion_bench.modelpool import BaseModelPool, CausalLMPool
10
15
  from fusion_bench.utils.type import StateDictType
11
16
 
12
17
  from .slerp_utils import slerp
13
18
 
19
+ if TYPE_CHECKING:
20
+ from transformers import PreTrainedModel
21
+
14
22
  log = logging.getLogger(__name__)
15
23
 
16
24
 
@@ -21,6 +29,7 @@ def slerp_on_state_dicts(
21
29
  *,
22
30
  DOT_THRESHOLD: float = 0.9995,
23
31
  epsilon: float = 1e-8,
32
+ show_pbar: bool = False,
24
33
  ) -> StateDictType:
25
34
  """
26
35
  Perform spherical linear interpolation (slerp) on the state dictionaries of two models.
@@ -36,7 +45,8 @@ def slerp_on_state_dicts(
36
45
  dict: The interpolated state dictionary.
37
46
  """
38
47
  state_dict = {}
39
- for key in secondary_state_dict:
48
+ pbar = secondary_state_dict if not show_pbar else tqdm(secondary_state_dict)
49
+ for key in pbar:
40
50
  v0 = primary_state_dict[key]
41
51
  v1 = secondary_state_dict[key]
42
52
  if v0.shape != v1.shape:
@@ -49,18 +59,19 @@ def slerp_on_state_dicts(
49
59
  return state_dict
50
60
 
51
61
 
62
+ @auto_register_config
52
63
  class SlerpMergeAlgorithm(BaseAlgorithm):
53
64
  """
54
65
  General purpose implementation of Slerp (Spherical Linear Interpolation) for PyTorch models.
55
66
  """
56
67
 
57
- _config_mapping = BaseAlgorithm._config_mapping | {
58
- "t": "t",
59
- "DOT_THRESHOLD": "DOT_THRESHOLD",
60
- "epsilon": "epsilon",
61
- }
62
-
63
- def __init__(self, t: float, DOT_THRESHOLD: float = 0.9995, epsilon: float = 1e-8):
68
+ def __init__(
69
+ self,
70
+ t: float,
71
+ DOT_THRESHOLD: float = 0.9995,
72
+ epsilon: float = 1e-8,
73
+ **kwargs,
74
+ ):
64
75
  """
65
76
  Initialize the SlerpMergeAlgorithm.
66
77
 
@@ -69,10 +80,7 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
69
80
  DOT_THRESHOLD (float, optional): The threshold for the dot product of the two vectors. Defaults to 0.9995.
70
81
  epsilon (float, optional): The epsilon value for numerical stability. Defaults to 1e-8.
71
82
  """
72
- self.t = t
73
- self.DOT_THRESHOLD = DOT_THRESHOLD
74
- self.epsilon = epsilon
75
- super().__init__()
83
+ super().__init__(**kwargs)
76
84
 
77
85
  @override
78
86
  def run(self, modelpool: BaseModelPool) -> nn.Module:
@@ -102,3 +110,91 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
102
110
 
103
111
  primary_model.load_state_dict(state_dict)
104
112
  return primary_model
113
+
114
+
115
+ @auto_register_config
116
+ class SlerpForCausalLM(
117
+ SimpleProfilerMixin,
118
+ BaseAlgorithm,
119
+ ):
120
+ """
121
+ Slerp (Spherical Linear Interpolation) for Causal Language Models.
122
+ """
123
+
124
+ def __init__(
125
+ self,
126
+ t: float,
127
+ DOT_THRESHOLD: float = 0.9995,
128
+ epsilon: float = 1e-8,
129
+ model_save_path: Optional[str] = None,
130
+ show_pbar: bool = False,
131
+ **kwargs,
132
+ ):
133
+ """
134
+ Initialize the SlerpForCausalLM algorithm.
135
+
136
+ Args:
137
+ t (float): The interpolation parameter. Must be in the range [0, 1].
138
+ t=0 returns the first model, t=1 returns the second model,
139
+ t=0.5 provides balanced interpolation.
140
+ DOT_THRESHOLD (float, optional): The threshold for the dot product of normalized vectors.
141
+ When the absolute dot product exceeds this threshold,
142
+ vectors are considered nearly collinear and linear
143
+ interpolation (LERP) is used instead of SLERP for
144
+ numerical stability. Defaults to 0.9995.
145
+ epsilon (float, optional): Small value used for numerical stability to avoid
146
+ division by zero during vector normalization.
147
+ Defaults to 1e-8.
148
+ model_save_path (Optional[str], optional): Path where the merged model should be saved.
149
+ If None, the model is not saved to disk.
150
+ Defaults to None.
151
+ show_pbar (bool, optional): Whether to display a progress bar during the interpolation
152
+ process. Useful for debugging or monitoring progress with
153
+ large models. Defaults to False.
154
+ **kwargs: Additional keyword arguments passed to the parent BaseAlgorithm class.
155
+ """
156
+ super().__init__(**kwargs)
157
+
158
+ @override
159
+ def run(self, modelpool: CausalLMPool):
160
+ assert len(modelpool.all_model_names) == 2, "Slerp expect exactly 2 models"
161
+ primary_model = modelpool.load_model(modelpool.all_model_names[0])
162
+ secondary_model = modelpool.load_model(modelpool.all_model_names[1])
163
+
164
+ with torch.no_grad():
165
+ primary_state_dict = primary_model.state_dict()
166
+ secondary_state_dict = secondary_model.state_dict()
167
+ state_dict = slerp_on_state_dicts(
168
+ self.t,
169
+ primary_state_dict,
170
+ secondary_state_dict,
171
+ DOT_THRESHOLD=self.DOT_THRESHOLD,
172
+ epsilon=self.epsilon,
173
+ )
174
+
175
+ if isinstance(primary_model, nn.Module):
176
+ model = primary_model
177
+ model.load_state_dict(state_dict)
178
+ elif isinstance(primary_model, LazyStateDict):
179
+ model: "PreTrainedModel" = deepcopy(primary_model.meta_module)
180
+ model.to(device=primary_model._device)
181
+ model.load_state_dict(state_dict)
182
+ else:
183
+ raise TypeError(
184
+ f"Unsupported model type: {type(primary_model)}. "
185
+ "Expected nn.Module or LazyStateDict."
186
+ )
187
+ if self.model_save_path is not None:
188
+ with timeit_context(f"Saving the model to {self.model_save_path}"):
189
+ tokenizer = modelpool.load_tokenizer()
190
+ tokenizer.save_pretrained(self.model_save_path)
191
+ model.save_pretrained(self.model_save_path)
192
+ model_card_str = create_default_model_card(
193
+ models=[modelpool.get_model_path(m) for m in modelpool.model_names],
194
+ description="Merged model using Slerp.",
195
+ algorithm_config=self.config,
196
+ modelpool_config=modelpool.config,
197
+ )
198
+ with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
199
+ f.write(model_card_str)
200
+ return model
@@ -6,11 +6,20 @@ http://arxiv.org/abs/2212.04089
6
6
 
7
7
  import logging
8
8
  from copy import deepcopy
9
- from typing import Dict, List, Mapping, Optional, TypeVar, Union # noqa: F401
9
+ from typing import ( # noqa: F401
10
+ TYPE_CHECKING,
11
+ Dict,
12
+ List,
13
+ Mapping,
14
+ Optional,
15
+ TypeVar,
16
+ Union,
17
+ )
10
18
 
11
19
  import torch
12
20
  from torch import nn
13
21
 
22
+ from fusion_bench import LazyStateDict
14
23
  from fusion_bench.method.base_algorithm import BaseAlgorithm
15
24
  from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
16
25
  from fusion_bench.modelpool import BaseModelPool
@@ -21,6 +30,8 @@ from fusion_bench.utils.state_dict_arithmetic import (
21
30
  )
22
31
  from fusion_bench.utils.type import StateDictType, TorchModelType
23
32
 
33
+ if TYPE_CHECKING:
34
+ from transformers import PreTrainedModel
24
35
  log = logging.getLogger(__name__)
25
36
 
26
37
 
@@ -125,25 +136,39 @@ class TaskArithmeticAlgorithm(
125
136
  with self.profile("merge weights"):
126
137
  if task_vector is None:
127
138
  task_vector = state_dict_sub(
128
- model.state_dict(keep_vars=True),
129
- pretrained_model.state_dict(keep_vars=True),
139
+ model.state_dict(),
140
+ pretrained_model.state_dict(),
130
141
  )
131
142
  else:
132
143
  task_vector = state_dict_add(
133
144
  task_vector,
134
145
  state_dict_sub(
135
- model.state_dict(keep_vars=True),
136
- pretrained_model.state_dict(keep_vars=True),
146
+ model.state_dict(),
147
+ pretrained_model.state_dict(),
137
148
  ),
138
149
  )
139
150
  with self.profile("merge weights"):
140
151
  # scale the task vector
141
152
  task_vector = state_dict_mul(task_vector, self.config.scaling_factor)
142
153
  # add the task vector to the pretrained model
143
- state_dict = state_dict_add(
144
- pretrained_model.state_dict(keep_vars=True), task_vector
145
- )
154
+ state_dict = state_dict_add(pretrained_model.state_dict(), task_vector)
146
155
 
147
156
  self.print_profile_summary()
148
- pretrained_model.load_state_dict(state_dict)
149
- return pretrained_model
157
+
158
+ # apply state dict to model
159
+ if isinstance(pretrained_model, nn.Module):
160
+ model = pretrained_model
161
+ model.load_state_dict(state_dict)
162
+ elif isinstance(pretrained_model, LazyStateDict):
163
+ model = deepcopy(pretrained_model.meta_module)
164
+ model = model.to_empty(device=pretrained_model._device)
165
+ result = model.load_state_dict(state_dict, strict=False)
166
+ if result.unexpected_keys:
167
+ raise ValueError(
168
+ f"Unexpected keys in state dict: {result.unexpected_keys}"
169
+ )
170
+ if result.missing_keys:
171
+ log.warning(f"Missing keys in state dict: {result.missing_keys}")
172
+ else:
173
+ raise TypeError(f"Unsupported model type: {type(pretrained_model)}")
174
+ return model
@@ -9,11 +9,14 @@ Overview of Ties-Merging:
9
9
  """
10
10
 
11
11
  import logging
12
+ from copy import deepcopy
12
13
  from typing import Any, Dict, List, Literal, Mapping, Union # noqa: F401
13
14
 
14
15
  import torch
15
16
  from torch import Tensor, nn
17
+ from transformers import PreTrainedModel
16
18
 
19
+ from fusion_bench import LazyStateDict
17
20
  from fusion_bench.compat.modelpool import to_modelpool
18
21
  from fusion_bench.method import BaseAlgorithm
19
22
  from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
@@ -98,12 +101,25 @@ class TiesMergingAlgorithm(
98
101
  merge_func=merge_func,
99
102
  )
100
103
  merged_check = flat_ptm + scaling_factor * merged_tv
101
- merged_state_dict = vector_to_state_dict(
104
+ state_dict = vector_to_state_dict(
102
105
  merged_check, ptm_check, remove_keys=remove_keys
103
106
  )
104
-
105
- # Load the merged state dict into the pretrained model
106
- pretrained_model.load_state_dict(merged_state_dict)
107
-
108
107
  self.print_profile_summary()
109
- return pretrained_model
108
+
109
+ # apply state dict to model
110
+ if isinstance(pretrained_model, nn.Module):
111
+ model = pretrained_model
112
+ model.load_state_dict(state_dict)
113
+ elif isinstance(pretrained_model, LazyStateDict):
114
+ model = deepcopy(pretrained_model.meta_module)
115
+ model = model.to_empty(device=pretrained_model._device)
116
+ result = model.load_state_dict(state_dict, strict=False)
117
+ if result.unexpected_keys:
118
+ raise ValueError(
119
+ f"Unexpected keys in state dict: {result.unexpected_keys}"
120
+ )
121
+ if result.missing_keys:
122
+ log.warning(f"Missing keys in state dict: {result.missing_keys}")
123
+ else:
124
+ raise TypeError(f"Unsupported model type: {type(pretrained_model)}")
125
+ return model