fusion-bench 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. fusion_bench/__init__.py +4 -0
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/dataset/gpt2_glue.py +1 -1
  7. fusion_bench/method/__init__.py +12 -2
  8. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  9. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  10. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  11. fusion_bench/method/ensemble.py +17 -2
  12. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  13. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  14. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  15. fusion_bench/method/linear/__init__.py +6 -2
  16. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  17. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  18. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  19. fusion_bench/method/model_stock/__init__.py +1 -0
  20. fusion_bench/method/model_stock/model_stock.py +309 -0
  21. fusion_bench/method/regmean/clip_regmean.py +3 -6
  22. fusion_bench/method/regmean/regmean.py +27 -56
  23. fusion_bench/method/regmean/utils.py +56 -0
  24. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  25. fusion_bench/method/simple_average.py +2 -2
  26. fusion_bench/method/slerp/__init__.py +1 -1
  27. fusion_bench/method/slerp/slerp.py +110 -14
  28. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  29. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  30. fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
  31. fusion_bench/method/wudi/__init__.py +1 -0
  32. fusion_bench/method/wudi/wudi.py +105 -0
  33. fusion_bench/mixins/clip_classification.py +26 -6
  34. fusion_bench/mixins/lightning_fabric.py +4 -0
  35. fusion_bench/mixins/serialization.py +40 -83
  36. fusion_bench/modelpool/base_pool.py +1 -1
  37. fusion_bench/modelpool/causal_lm/causal_lm.py +285 -44
  38. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  39. fusion_bench/models/hf_clip.py +4 -0
  40. fusion_bench/models/hf_utils.py +10 -4
  41. fusion_bench/models/linearized/vision_model.py +6 -6
  42. fusion_bench/models/model_card_templates/default.md +8 -1
  43. fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
  44. fusion_bench/models/we_moe.py +8 -8
  45. fusion_bench/models/wrappers/ensemble.py +136 -7
  46. fusion_bench/scripts/cli.py +2 -2
  47. fusion_bench/taskpool/base_pool.py +99 -17
  48. fusion_bench/taskpool/clip_vision/taskpool.py +12 -5
  49. fusion_bench/taskpool/dummy.py +101 -13
  50. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  51. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  52. fusion_bench/utils/__init__.py +1 -0
  53. fusion_bench/utils/data.py +6 -4
  54. fusion_bench/utils/devices.py +36 -11
  55. fusion_bench/utils/dtype.py +3 -2
  56. fusion_bench/utils/lazy_state_dict.py +85 -19
  57. fusion_bench/utils/packages.py +3 -3
  58. fusion_bench/utils/parameters.py +0 -2
  59. fusion_bench/utils/rich_utils.py +7 -3
  60. fusion_bench/utils/timer.py +92 -10
  61. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
  62. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +77 -64
  63. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  64. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  65. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  66. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  67. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  68. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  69. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  70. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  71. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  72. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  73. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  74. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  75. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
  76. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
  77. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
  78. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
fusion_bench/method/ensemble.py
@@ -17,7 +17,21 @@ from fusion_bench.models.wrappers.ensemble import (
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class SimpleEnsembleAlgorithm(BaseAlgorithm):
+    def __init__(
+        self,
+        device_map: Optional[Mapping[int, Union[str, torch.device]]] = None,
+        **kwargs,
+    ):
+        """
+        Initializes the SimpleEnsembleAlgorithm with an optional device map.
+
+        Args:
+            device_map (Optional[Mapping[int, Union[str, torch.device]]], optional): A mapping from model index to device. Defaults to None.
+        """
+        super().__init__(**kwargs)
+
     @torch.no_grad()
     def run(self, modelpool: BaseModelPool | List[nn.Module]) -> EnsembleModule:
         """
@@ -30,9 +44,10 @@ class SimpleEnsembleAlgorithm(BaseAlgorithm):
             EnsembleModule: The ensembled model.
         """
         log.info(f"Running ensemble algorithm with {len(modelpool)} models")
-
         models = [modelpool.load_model(m) for m in modelpool.model_names]
-        ensemble = EnsembleModule(models=models)
+
+        log.info("creating ensemble module")
+        ensemble = EnsembleModule(models=models, device_map=self.device_map)
         return ensemble
 
 
fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py
@@ -23,6 +23,7 @@ from transformers import MixtralForCausalLM
 from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
 
 import fusion_bench as fb
+from fusion_bench import auto_register_config
 from fusion_bench.method.expert_sparsity.utils.calibration_data import (
     build_calib_loader,
 )
@@ -97,6 +98,7 @@ def dynamic_skipping(
     return model, (res_median, res_mean)
 
 
+@auto_register_config
 class DynamicSkippingPruningForMixtral(
     fb.BaseAlgorithm,
     fb.mixins.LightningFabricMixin,
fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py
@@ -22,6 +22,7 @@ from transformers import MixtralForCausalLM
 from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
 
 import fusion_bench as fb
+from fusion_bench import auto_register_config
 from fusion_bench.method.expert_sparsity.utils.calibration_data import (
     build_calib_loader,
 )
@@ -81,6 +82,7 @@ def layerwise_pruning(
     return model, (global_loss_history,)
 
 
+@auto_register_config
 class LayerWisePruningForMixtral(
     fb.BaseAlgorithm,
     fb.mixins.LightningFabricMixin,
fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py
@@ -20,6 +20,7 @@ from tqdm import tqdm
 from transformers import MixtralForCausalLM
 
 import fusion_bench as fb
+from fusion_bench import auto_register_config
 from fusion_bench.method.expert_sparsity.utils.calibration_data import (
     build_calib_loader,
 )
@@ -95,6 +96,7 @@ def progressive_pruning(
     return model, (global_loss_history,)
 
 
+@auto_register_config
 class ProgressivePruningForMixtral(
     fb.BaseAlgorithm,
     fb.mixins.LightningFabricMixin,
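These three hunks are identical in spirit: each pruning class picks up the `@auto_register_config` decorator and its import. Judging from the other hunks in this release, where constructor keyword arguments such as `device_map` are read back as `self.device_map` without an explicit assignment, the decorator binds `__init__` kwargs as attributes and config entries. A toy sketch under that assumption:

```python
from fusion_bench import BaseAlgorithm, auto_register_config


@auto_register_config
class MyToyAlgorithm(BaseAlgorithm):
    def __init__(self, scaling_factor: float = 0.5, **kwargs):
        super().__init__(**kwargs)

    def run(self, modelpool):
        return modelpool


algo = MyToyAlgorithm(scaling_factor=0.3)
print(algo.scaling_factor)  # 0.3, bound by the decorator rather than by hand
print(algo.config)          # the same values exposed as the algorithm config
```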
fusion_bench/method/linear/__init__.py
@@ -2,5 +2,9 @@
 from .expo import ExPOAlgorithm
 from .linear_interpolation import LinearInterpolationAlgorithm
 from .llama_expo import ExPOAlgorithmForLlama
-from .simple_average_for_llama import SimpleAverageForLlama
-from .task_arithmetic_for_llama import TaskArithmeticForLlama
+from .simple_average_for_causallm import SimpleAverageForCausalLM, SimpleAverageForLlama
+from .task_arithmetic_for_causallm import (
+    TaskArithmeticForCausalLM,
+    TaskArithmeticForLlama,
+)
+from .ties_merging_for_causallm import TiesMergingForCausalLM
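The old `*ForLlama` names remain importable as aliases of the renamed classes, so downstream code should not break. A quick check, assuming only the re-exports shown above:

```python
from fusion_bench.method.linear import (
    SimpleAverageForCausalLM,
    SimpleAverageForLlama,
    TaskArithmeticForCausalLM,
    TaskArithmeticForLlama,
)

# Both legacy names are plain aliases of the renamed classes.
assert SimpleAverageForLlama is SimpleAverageForCausalLM
assert TaskArithmeticForLlama is TaskArithmeticForCausalLM
```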
fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py}
@@ -18,16 +18,16 @@ log = get_rankzero_logger(__name__)
 
 
 @auto_register_config
-class SimpleAverageForLlama(BaseAlgorithm):
+class SimpleAverageForCausalLM(BaseAlgorithm):
     R"""
     A simple averaging algorithm for LLama models. If `merge_backbone` is set to `True`, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.
 
     Examples:
-        The following example demonstrates how to use the `SimpleAverageForLlama` algorithm to merge Mistral models.
+        The following example demonstrates how to use the `SimpleAverageForCausalLM` algorithm to merge Mistral models.
 
         ```bash
         fusion_bench \
-            method=linear/simple_average_for_llama \
+            method=linear/simple_average_for_causallm \
             method.model_save_path=outputs/simle_mixtral_exp_v4/simple_average \
             modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
         ```
@@ -35,7 +35,7 @@ class SimpleAverageForLlama(BaseAlgorithm):
 
     def __init__(
         self,
-        merge_backbone: bool,
+        merge_backbone: bool = False,
         model_save_path: Optional[str] = None,
        show_pbar: bool = False,
         **kwargs,
@@ -81,3 +81,7 @@ class SimpleAverageForLlama(BaseAlgorithm):
         with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
             f.write(model_card_str)
         return model
+
+
+SimpleAverageForLlama = SimpleAverageForCausalLM
+"""Alias for SimpleAverageForCausalLM"""
fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py}
@@ -1,22 +1,27 @@
 import logging
+import os
 from typing import Dict, List, Mapping, Optional, TypeVar, Union  # noqa: F401
 
 from typing_extensions import override
 
-from fusion_bench import timeit_context
+from fusion_bench import auto_register_config, timeit_context
 from fusion_bench.method import TaskArithmeticAlgorithm
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
+from fusion_bench.models.hf_utils import create_default_model_card
 
 log = logging.getLogger(__name__)
 
 
-class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
+@auto_register_config
+class TaskArithmeticForCausalLM(
+    TaskArithmeticAlgorithm,
+):
     R"""
     Examples:
 
         fusion_bench \
-            method=linear/task_arithmetic_for_llama \
+            method=linear/task_arithmetic_for_causallm \
             method.scaling_factor=0.3 \
             method.model_save_path=outputs/simle_mixtral_exp_v4/task_arithmetic_0.3 \
             modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
@@ -29,18 +34,14 @@ class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
     def __init__(
         self,
         scaling_factor: float,
-        merge_backbone: bool,
+        merge_backbone: bool = False,
         model_save_path: Optional[str] = None,
+        **kwargs,
     ):
-        self.merge_backbone = merge_backbone
-        self.model_save_path = model_save_path
-        super().__init__(scaling_factor=scaling_factor)
+        super().__init__(scaling_factor=scaling_factor, **kwargs)
 
     @override
     def run(self, modelpool: CausalLMPool):
-        if self.model_save_path:
-            tokenizer = modelpool.load_tokenizer()
-
         if self.merge_backbone:
             assert modelpool.has_pretrained
             backbone_modelpool = CausalLMBackbonePool(**modelpool.config)
@@ -52,6 +53,15 @@ class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
 
         if self.model_save_path is not None:
             with timeit_context(f"Saving the model to {self.model_save_path}"):
-                tokenizer.save_pretrained(self.model_save_path)
-                model.save_pretrained(self.model_save_path)
+                description = f"Merged model using task arithmetic with scaling factor {self.scaling_factor}."
+                modelpool.save_model(
+                    model=model,
+                    path=self.model_save_path,
+                    save_tokenizer=True,
+                    algorithm_config=self.config,
+                    description=description,
+                )
         return model
+
+
+TaskArithmeticForLlama = TaskArithmeticForCausalLM
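The merge rule itself comes unchanged from the parent `TaskArithmeticAlgorithm`: add the scaled sum of task vectors to the pre-trained weights, θ_merged = θ_pre + λ · Σᵢ (θᵢ − θ_pre). A toy check with made-up numbers, where `scaling_factor` plays the role of λ:

```python
import torch

pretrained = torch.tensor([1.0, 1.0, 1.0])
finetuned = [torch.tensor([1.5, 1.0, 0.5]), torch.tensor([1.0, 2.0, 1.0])]
scaling_factor = 0.3  # method.scaling_factor in the CLI example above

# Task vectors are the per-model deltas from the pre-trained weights.
task_vectors = [ft - pretrained for ft in finetuned]
merged = pretrained + scaling_factor * sum(task_vectors)
print(merged)  # tensor([1.1500, 1.3000, 0.8500])
```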
fusion_bench/method/linear/ties_merging_for_causallm.py (new file)
@@ -0,0 +1,70 @@
+import logging
+import os
+from typing import Dict, List, Mapping, Optional, TypeVar, Union  # noqa: F401
+
+from typing_extensions import override
+
+from fusion_bench import auto_register_config, timeit_context
+from fusion_bench.method import TiesMergingAlgorithm
+from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
+from fusion_bench.models.hf_utils import create_default_model_card
+
+log = logging.getLogger(__name__)
+
+
+@auto_register_config
+class TiesMergingForCausalLM(
+    TiesMergingAlgorithm,
+):
+    R"""
+    TIES merging algorithm for CausalLM models.
+
+    This class extends the TiesMergingAlgorithm to work specifically with CausalLM models,
+    providing model saving capabilities and backbone merging support.
+    """
+
+    _config_mapping = TiesMergingAlgorithm._config_mapping | {
+        "merge_backbone": "merge_backbone",
+    }
+
+    def __init__(
+        self,
+        scaling_factor: float,
+        threshold: float,
+        remove_keys: List[str] = None,
+        merge_func: str = "sum",
+        merge_backbone: bool = False,
+        model_save_path: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            scaling_factor=scaling_factor,
+            threshold=threshold,
+            remove_keys=remove_keys,
+            merge_func=merge_func,
+            **kwargs,
+        )
+
+    @override
+    def run(self, modelpool: CausalLMPool):
+        if self.merge_backbone:
+            assert modelpool.has_pretrained
+            backbone_modelpool = CausalLMBackbonePool(**modelpool.config)
+            model = modelpool.load_model("_pretrained_")
+            backbone_model = super().run(backbone_modelpool)
+            model.model.layers = backbone_model
+        else:
+            model = super().run(modelpool)
+
+        if self.model_save_path is not None:
+            with timeit_context(f"Saving the model to {self.model_save_path}"):
+                description = f"Merged model using TIES merging with scaling factor {self.scaling_factor} and threshold {self.threshold}."
+                modelpool.save_model(
+                    model=model,
+                    path=self.model_save_path,
+                    save_tokenizer=True,
+                    algorithm_config=self.config,
+                    description=description,
+                )
+        return model
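The TIES logic itself (trim low-magnitude updates, elect a per-parameter sign, merge only the agreeing entries) is inherited from `TiesMergingAlgorithm`; this subclass only adds backbone handling and saving. A self-contained toy sketch of those three steps on flat task vectors — an illustration of the idea, not the library's code, and the kept fraction `k` here need not match the library's `threshold` convention:

```python
import torch


def ties_merge(task_vectors: torch.Tensor, k: float = 0.2) -> torch.Tensor:
    """task_vectors: (num_models, num_params); keep the top-k fraction by magnitude."""
    # 1. Trim: zero out all but the largest-magnitude k fraction per model.
    n_keep = max(1, int(k * task_vectors.shape[1]))
    magnitudes = task_vectors.abs()
    thresholds = magnitudes.topk(n_keep, dim=1).values[:, -1:]
    trimmed = torch.where(magnitudes >= thresholds, task_vectors, 0.0)
    # 2. Elect sign: take the sign with the larger total mass per parameter.
    elected_sign = torch.sign(trimmed.sum(dim=0))
    # 3. Disjoint merge: average only the entries that agree with the elected
    #    sign (merge_func="sum" in the class above would sum them instead).
    agree = torch.sign(trimmed) == elected_sign
    selected = trimmed * agree
    counts = agree.sum(dim=0).clamp(min=1)
    return selected.sum(dim=0) / counts


tvs = torch.tensor([[0.9, -0.1, 0.4], [-0.8, 0.2, 0.5]])
print(ties_merge(tvs, k=0.67))  # tensor([0.9000, 0.0000, 0.4500])
```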
fusion_bench/method/model_stock/__init__.py (new file)
@@ -0,0 +1 @@
+from .model_stock import ModelStock
fusion_bench/method/model_stock/model_stock.py (new file)
@@ -0,0 +1,309 @@
+import copy
+import logging
+import math
+import os
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from omegaconf import DictConfig
+from torch import nn
+from transformers import PreTrainedModel
+
+import fusion_bench
+from fusion_bench import BaseAlgorithm, BaseModelPool
+from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.models import create_default_model_card
+from fusion_bench.utils.type import StateDictType
+
+log = logging.getLogger(__name__)
+
+EPS = 1e-8
+
+
+def compute_angle(
+    state_dict_1: StateDictType,
+    state_dict_2: StateDictType,
+    ref_state_dict: StateDictType,
+    ignore_keys: List[str] = [],
+    return_cos: bool = False,
+) -> Dict[str, float]:
+    """
+    Compute the angle between two state dictionaries relative to a reference state dictionary.
+
+    Args:
+        state_dict_1: First state dictionary
+        state_dict_2: Second state dictionary
+        ref_state_dict: Reference state dictionary (typically pre-trained model)
+        ignore_keys: Keys to ignore during computation
+        return_cos: If True, return cosine values instead of angles in degrees
+
+    Returns:
+        Dictionary mapping parameter names to angles (in degrees) or cosine values
+    """
+    # Remove the keys not used for CLIP fine-tuning (from the notebook example)
+
+    return_dict = OrderedDict()
+
+    with torch.no_grad():
+        for key in ref_state_dict:
+            if key in ignore_keys:
+                log.info(f"Ignoring key '{key}'")
+                continue
+
+            state_dict_1_val = state_dict_1[key]
+            state_dict_2_val = state_dict_2[key]
+            ref_val = ref_state_dict[key]
+
+            if not (state_dict_1_val.shape == state_dict_2_val.shape == ref_val.shape):
+                log.warning(
+                    f"Shape mismatch for key '{key}', ignored during merging: "
+                    f"({state_dict_1_val.shape}, {state_dict_2_val.shape}, {ref_val.shape})"
+                )
+                continue
+
+            vector1 = (state_dict_1_val - ref_val).clone().detach()
+            vector2 = (state_dict_2_val - ref_val).clone().detach()
+
+            vector1 = vector1.float()
+            vector2 = vector2.float()
+
+            cosine_val = torch.sum(vector1 * vector2) / (
+                math.sqrt(torch.sum(vector1**2) * torch.sum(vector2**2)) + EPS
+            )
+            cosine_val = torch.clamp(
+                cosine_val, min=-1.0, max=1.0
+            )  # Prevent nan from acos
+
+            if return_cos:
+                return_dict[key] = cosine_val.item()
+            else:
+                return_dict[key] = np.rad2deg(
+                    torch.acos(cosine_val).detach().cpu().item()
+                )
+
+    return return_dict
+
+
+def compute_ratio(angle_dict: Dict[str, float], k: int = 2) -> Dict[str, float]:
+    """
+    Compute interpolation ratios based on angles between fine-tuned models.
+
+    Args:
+        angle_dict: Dictionary mapping parameter names to angles in degrees
+        k: Number of fine-tuned models (default: 2)
+
+    Returns:
+        Dictionary mapping parameter names to interpolation ratios
+    """
+    ratio_dict = {}
+    for key in angle_dict.keys():
+        angle = np.deg2rad(angle_dict[key])
+        ratio_dict[key] = k * np.cos(angle) / ((k - 1) * np.cos(angle) + 1 + EPS)
+    return ratio_dict
+
+
+def merge_weights(
+    w1: StateDictType, w2: StateDictType, w0: StateDictType, ratio: Dict[str, float]
+) -> StateDictType:
+    """
+    Merge model weights using ModelStock formula.
+
+    Args:
+        w1: First fine-tuned model weights
+        w2: Second fine-tuned model weights
+        w0: Pre-trained model weights
+        ratio: Interpolation ratios for each parameter
+
+    Returns:
+        Merged model weights
+    """
+    # Compute w12 = (w1 + w2) / 2
+    w12 = {}
+    for key in w1.keys():
+        w12[key] = (w1[key].clone() + w2[key].clone()) / 2.0
+
+    # Apply ModelStock formula: w_merge = t * w12 + (1-t) * w0
+    w_merge = copy.deepcopy(w12)
+    for key, r in ratio.items():
+        w_merge[key] = w12[key].clone() * r + w0[key].clone() * (1.0 - r)
+
+    return w_merge
+
+
+@fusion_bench.auto_register_config
+class ModelStock(SimpleProfilerMixin, BaseAlgorithm):
+    """
+    Model Stock: All we need is just a few fine-tuned models
+
+    This method merges fine-tuned models by interpolating between their average
+    and a pre-trained anchor model, with interpolation ratios determined by
+    the angle between fine-tuned models in parameter space.
+    """
+
+    def __init__(
+        self,
+        ignore_keys: Optional[List[str]] = None,
+        model_save_path: Optional[str] = None,
+        model_save_kwargs: Optional[DictConfig] = None,
+        **kwargs,
+    ):
+        """
+        Initialize ModelStock algorithm.
+
+        Args:
+            ignore_keys: Additional parameter keys to ignore during merging
+        """
+        super().__init__(**kwargs)
+        if self.ignore_keys is None:
+            self.ignore_keys = []
+        if self.model_save_kwargs is None:
+            self.model_save_kwargs = DictConfig({})
+
+    def run(self, modelpool: BaseModelPool) -> nn.Module:
+        """
+        Run the ModelStock merging algorithm.
+
+        Args:
+            modelpool: Pool of models containing pre-trained and fine-tuned models
+
+        Returns:
+            Merged model
+        """
+        with self.profile("model loading"):
+            # Load the pre-trained model (anchor)
+            pretrained_model = modelpool.load_pretrained_model()
+            if isinstance(pretrained_model, fusion_bench.LazyStateDict):
+                assert (
+                    pretrained_model.meta_module is not None
+                ), "Meta module is not initialized"
+            pretrained_state_dict = pretrained_model.state_dict()
+
+            # Load all fine-tuned models
+            finetuned_models = []
+            finetuned_state_dicts = []
+
+            for model_name in modelpool.model_names:
+                model = modelpool.load_model(model_name)
+                finetuned_models.append(model)
+                finetuned_state_dicts.append(model.state_dict())
+                log.info(f"Loaded fine-tuned model: {model_name}")
+
+        if len(finetuned_models) < 2:
+            raise ValueError("ModelStock requires at least 2 fine-tuned models")
+
+        log.info(f"Running ModelStock with {len(finetuned_models)} fine-tuned models")
+
+        with self.profile("compute angles and ratios"):
+            if len(finetuned_models) == 2:
+                # Two fine-tuned models case
+                angle_dict = compute_angle(
+                    finetuned_state_dicts[0],
+                    finetuned_state_dicts[1],
+                    pretrained_state_dict,
+                    ignore_keys=self.ignore_keys,
+                )
+                ratio_dict = compute_ratio(angle_dict, k=2)
+
+                log.info(f"Computed angles for {len(angle_dict)} parameter groups")
+
+            else:
+                # N fine-tuned models case - compute average angle
+                angles_sum = {}
+                angles_count = {}
+
+                # Compute pairwise angles and average them
+                for i in range(len(finetuned_models)):
+                    for j in range(i + 1, len(finetuned_models)):
+                        angle_dict = compute_angle(
+                            finetuned_state_dicts[i],
+                            finetuned_state_dicts[j],
+                            pretrained_state_dict,
+                            ignore_keys=self.ignore_keys,
+                        )
+
+                        for key, angle in angle_dict.items():
+                            if key not in angles_sum:
+                                angles_sum[key] = 0
+                                angles_count[key] = 0
+                            angles_sum[key] += angle
+                            angles_count[key] += 1
+
+                # Average the angles
+                avg_angle_dict = {}
+                for key in angles_sum:
+                    avg_angle_dict[key] = angles_sum[key] / angles_count[key]
+
+                ratio_dict = compute_ratio(avg_angle_dict, k=len(finetuned_models))
+
+                log.info(
+                    f"Computed average angles for {len(avg_angle_dict)} parameter groups"
+                )
+
+        with self.profile("merge weights"):
+            if len(finetuned_models) == 2:
+                # Direct merging for two models
+                merged_state_dict = merge_weights(
+                    finetuned_state_dicts[0],
+                    finetuned_state_dicts[1],
+                    pretrained_state_dict,
+                    ratio_dict,
+                )
+            else:
+                # For N models, first compute the average of fine-tuned models
+                avg_finetuned_state_dict = {}
+                for key in finetuned_state_dicts[0].keys():
+                    avg_finetuned_state_dict[key] = torch.zeros_like(
+                        finetuned_state_dicts[0][key]
+                    )
+                    for state_dict in finetuned_state_dicts:
+                        avg_finetuned_state_dict[key] += state_dict[key]
+                    avg_finetuned_state_dict[key] /= len(finetuned_state_dicts)
+
+                # Apply ModelStock formula: w_H = t * w_avg + (1-t) * w_0
+                merged_state_dict = copy.deepcopy(avg_finetuned_state_dict)
+                for key, r in ratio_dict.items():
+                    merged_state_dict[key] = avg_finetuned_state_dict[
+                        key
+                    ].clone() * r + pretrained_state_dict[key].clone() * (1.0 - r)
+
+        # Load merged weights into the model
+        if isinstance(pretrained_model, nn.Module):
+            result_model = pretrained_model
+        elif isinstance(pretrained_model, fusion_bench.LazyStateDict):
+            result_model = deepcopy(pretrained_model.meta_module)
+            result_model.to(device=pretrained_model._device)
+        result = result_model.load_state_dict(merged_state_dict, strict=False)
+
+        if result.unexpected_keys:
+            raise RuntimeError(
+                f"Unexpected keys in state dict: {result.unexpected_keys}"
+            )
+        if result.missing_keys:
+            log.warning(f"Missing keys in state dict: {result.missing_keys}")
+
+        if self.model_save_path is not None:
+            with self.profile("model saving"):
+                modelpool.save_model(
+                    result_model, path=self.model_save_path, **self.model_save_kwargs
+                )
+                if isinstance(result_model, PreTrainedModel):
+                    modelcard = create_default_model_card(
+                        models=[
+                            modelpool.get_model_path(m)
+                            for m in modelpool.all_model_names
+                        ],
+                        description="Merged model using [Model Stock](https://arxiv.org/abs/2403.19522).",
+                        algorithm_config=self.config,
+                        modelpool_config=modelpool.config,
+                    )
+                    with open(
+                        os.path.join(self.model_save_path, "README.md"), "w"
+                    ) as f:
+                        f.write(modelcard)
+
+        self.print_profile_summary()
+        log.info("ModelStock merging completed successfully")
+        return result_model
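The heart of the method is `compute_ratio`, the Model Stock interpolation ratio t = k·cos(θ) / ((k−1)·cos(θ) + 1): near-parallel task vectors keep the fine-tuned average, near-orthogonal ones fall back to the pre-trained anchor. A quick numeric check of the k = 2 case, mirroring the function above:

```python
import numpy as np

EPS = 1e-8


def ratio(angle_deg: float, k: int = 2) -> float:
    cos = np.cos(np.deg2rad(angle_deg))
    return k * cos / ((k - 1) * cos + 1 + EPS)


print(ratio(0.0))   # ~1.0 -> identical task vectors: keep the fine-tuned average
print(ratio(60.0))  # ~0.667
print(ratio(90.0))  # ~0.0 -> orthogonal task vectors: fall back to the anchor
```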
fusion_bench/method/regmean/clip_regmean.py
@@ -9,6 +9,7 @@ from torch.nn.modules import Module
 from torch.utils.data import DataLoader
 from tqdm.autonotebook import tqdm
 
+from fusion_bench import auto_register_config
 from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import CLIPClassificationMixin
 
@@ -17,17 +18,13 @@ from .regmean import RegMeanAlgorithm
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class RegMeanAlgorithmForCLIP(
-    RegMeanAlgorithm,
     CLIPClassificationMixin,
+    RegMeanAlgorithm,
 ):
-    _config_mapping = {
-        "_dataloader_kwargs": "dataloader_kwargs",
-    }
-
     def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
         super().__init__(**kwargs)
-        self.dataloader_kwargs = dataloader_kwargs
 
     def on_regmean_start(self):
         self.setup_zero_shot_classification_head()
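Reordering the bases changes Python's method resolution order, so `CLIPClassificationMixin` now wins over `RegMeanAlgorithm` wherever both define the same method. A toy illustration with stand-in classes (names hypothetical):

```python
class RegMean:
    def on_start(self):
        return "regmean"


class ClipMixin:
    def on_start(self):
        return "clip mixin"


class OldOrder(RegMean, ClipMixin):
    pass


class NewOrder(ClipMixin, RegMean):
    pass


print(OldOrder().on_start())  # "regmean"
print(NewOrder().on_start())  # "clip mixin"
```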