fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. fusion_bench/__init__.py +25 -2
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/constants/__init__.py +1 -0
  7. fusion_bench/constants/runtime.py +57 -0
  8. fusion_bench/dataset/gpt2_glue.py +1 -1
  9. fusion_bench/method/__init__.py +12 -4
  10. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  11. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  12. fusion_bench/method/bitdelta/__init__.py +1 -0
  13. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  14. fusion_bench/method/classification/clip_finetune.py +1 -1
  15. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  16. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  17. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  18. fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
  19. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
  20. fusion_bench/method/linear/simple_average_for_llama.py +16 -11
  21. fusion_bench/method/model_stock/__init__.py +1 -0
  22. fusion_bench/method/model_stock/model_stock.py +309 -0
  23. fusion_bench/method/regmean/clip_regmean.py +3 -6
  24. fusion_bench/method/regmean/regmean.py +27 -56
  25. fusion_bench/method/regmean/utils.py +56 -0
  26. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  27. fusion_bench/method/simple_average.py +7 -7
  28. fusion_bench/method/slerp/__init__.py +1 -1
  29. fusion_bench/method/slerp/slerp.py +110 -14
  30. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  31. fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
  32. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  33. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
  34. fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
  35. fusion_bench/method/we_moe/__init__.py +1 -0
  36. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  37. fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
  38. fusion_bench/method/we_moe/utils.py +15 -0
  39. fusion_bench/method/weighted_average/llama.py +1 -1
  40. fusion_bench/mixins/clip_classification.py +37 -48
  41. fusion_bench/mixins/serialization.py +30 -10
  42. fusion_bench/modelpool/base_pool.py +1 -1
  43. fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
  44. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  45. fusion_bench/models/__init__.py +5 -0
  46. fusion_bench/models/hf_utils.py +69 -86
  47. fusion_bench/models/linearized/vision_model.py +6 -6
  48. fusion_bench/models/model_card_templates/default.md +46 -0
  49. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  50. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
  51. fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
  52. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
  53. fusion_bench/models/we_moe.py +8 -8
  54. fusion_bench/programs/fabric_fusion_program.py +29 -60
  55. fusion_bench/scripts/cli.py +34 -1
  56. fusion_bench/taskpool/base_pool.py +99 -17
  57. fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
  58. fusion_bench/taskpool/dummy.py +101 -13
  59. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  60. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  61. fusion_bench/utils/__init__.py +2 -0
  62. fusion_bench/utils/cache_utils.py +101 -1
  63. fusion_bench/utils/data.py +6 -4
  64. fusion_bench/utils/devices.py +7 -4
  65. fusion_bench/utils/dtype.py +3 -2
  66. fusion_bench/utils/fabric.py +2 -2
  67. fusion_bench/utils/lazy_imports.py +23 -0
  68. fusion_bench/utils/lazy_state_dict.py +117 -19
  69. fusion_bench/utils/modelscope.py +3 -3
  70. fusion_bench/utils/packages.py +3 -3
  71. fusion_bench/utils/parameters.py +0 -2
  72. fusion_bench/utils/path.py +56 -0
  73. fusion_bench/utils/pylogger.py +1 -1
  74. fusion_bench/utils/timer.py +92 -10
  75. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
  76. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
  77. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  78. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  79. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  80. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  81. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  82. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  83. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
  84. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  85. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
  86. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
  87. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
  88. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
  89. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0

fusion_bench/method/slerp/slerp.py
@@ -1,16 +1,24 @@
 import logging
-from typing import Any, Dict
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import torch
 from torch import nn
+from tqdm import tqdm
 from typing_extensions import override
 
+from fusion_bench import LazyStateDict, create_default_model_card, timeit_context
 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
+from fusion_bench.modelpool import BaseModelPool, CausalLMPool
 from fusion_bench.utils.type import StateDictType
 
 from .slerp_utils import slerp
 
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
 log = logging.getLogger(__name__)
 
 
@@ -21,6 +29,7 @@ def slerp_on_state_dicts(
     *,
     DOT_THRESHOLD: float = 0.9995,
     epsilon: float = 1e-8,
+    show_pbar: bool = False,
 ) -> StateDictType:
     """
     Perform spherical linear interpolation (slerp) on the state dictionaries of two models.
@@ -36,7 +45,8 @@ def slerp_on_state_dicts(
         dict: The interpolated state dictionary.
     """
     state_dict = {}
-    for key in secondary_state_dict:
+    pbar = secondary_state_dict if not show_pbar else tqdm(secondary_state_dict)
+    for key in pbar:
         v0 = primary_state_dict[key]
         v1 = secondary_state_dict[key]
         if v0.shape != v1.shape:
@@ -49,18 +59,19 @@
     return state_dict
 
 
+@auto_register_config
 class SlerpMergeAlgorithm(BaseAlgorithm):
     """
     General purpose implementation of Slerp (Spherical Linear Interpolation) for PyTorch models.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "t": "t",
-        "DOT_THRESHOLD": "DOT_THRESHOLD",
-        "epsilon": "epsilon",
-    }
-
-    def __init__(self, t: float, DOT_THRESHOLD: float = 0.9995, epsilon: float = 1e-8):
+    def __init__(
+        self,
+        t: float,
+        DOT_THRESHOLD: float = 0.9995,
+        epsilon: float = 1e-8,
+        **kwargs,
+    ):
         """
         Initialize the SlerpMergeAlgorithm.
 
@@ -69,10 +80,7 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
            DOT_THRESHOLD (float, optional): The threshold for the dot product of the two vectors. Defaults to 0.9995.
            epsilon (float, optional): The epsilon value for numerical stability. Defaults to 1e-8.
         """
-        self.t = t
-        self.DOT_THRESHOLD = DOT_THRESHOLD
-        self.epsilon = epsilon
-        super().__init__()
+        super().__init__(**kwargs)
 
     @override
     def run(self, modelpool: BaseModelPool) -> nn.Module:
@@ -102,3 +110,91 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
 
         primary_model.load_state_dict(state_dict)
         return primary_model
+
+
+@auto_register_config
+class SlerpForCausalLM(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
+    """
+    Slerp (Spherical Linear Interpolation) for Causal Language Models.
+    """
+
+    def __init__(
+        self,
+        t: float,
+        DOT_THRESHOLD: float = 0.9995,
+        epsilon: float = 1e-8,
+        model_save_path: Optional[str] = None,
+        show_pbar: bool = False,
+        **kwargs,
+    ):
+        """
+        Initialize the SlerpForCausalLM algorithm.
+
+        Args:
+            t (float): The interpolation parameter. Must be in the range [0, 1].
+                t=0 returns the first model, t=1 returns the second model,
+                t=0.5 provides balanced interpolation.
+            DOT_THRESHOLD (float, optional): The threshold for the dot product of normalized vectors.
+                When the absolute dot product exceeds this threshold,
+                vectors are considered nearly collinear and linear
+                interpolation (LERP) is used instead of SLERP for
+                numerical stability. Defaults to 0.9995.
+            epsilon (float, optional): Small value used for numerical stability to avoid
+                division by zero during vector normalization.
+                Defaults to 1e-8.
+            model_save_path (Optional[str], optional): Path where the merged model should be saved.
+                If None, the model is not saved to disk.
+                Defaults to None.
+            show_pbar (bool, optional): Whether to display a progress bar during the interpolation
+                process. Useful for debugging or monitoring progress with
+                large models. Defaults to False.
+            **kwargs: Additional keyword arguments passed to the parent BaseAlgorithm class.
+        """
+        super().__init__(**kwargs)
+
+    @override
+    def run(self, modelpool: CausalLMPool):
+        assert len(modelpool.all_model_names) == 2, "Slerp expect exactly 2 models"
+        primary_model = modelpool.load_model(modelpool.all_model_names[0])
+        secondary_model = modelpool.load_model(modelpool.all_model_names[1])
+
+        with torch.no_grad():
+            primary_state_dict = primary_model.state_dict()
+            secondary_state_dict = secondary_model.state_dict()
+            state_dict = slerp_on_state_dicts(
+                self.t,
+                primary_state_dict,
+                secondary_state_dict,
+                DOT_THRESHOLD=self.DOT_THRESHOLD,
+                epsilon=self.epsilon,
+            )
+
+        if isinstance(primary_model, nn.Module):
+            model = primary_model
+            model.load_state_dict(state_dict)
+        elif isinstance(primary_model, LazyStateDict):
+            model: "PreTrainedModel" = deepcopy(primary_model.meta_module)
+            model.to(device=primary_model._device)
+            model.load_state_dict(state_dict)
+        else:
+            raise TypeError(
+                f"Unsupported model type: {type(primary_model)}. "
+                "Expected nn.Module or LazyStateDict."
+            )
+        if self.model_save_path is not None:
+            with timeit_context(f"Saving the model to {self.model_save_path}"):
+                tokenizer = modelpool.load_tokenizer()
+                tokenizer.save_pretrained(self.model_save_path)
+                model.save_pretrained(self.model_save_path)
+                model_card_str = create_default_model_card(
+                    models=[modelpool.get_model_path(m) for m in modelpool.model_names],
+                    description="Merged model using Slerp.",
+                    algorithm_config=self.config,
+                    modelpool_config=modelpool.config,
+                )
+                with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
+                    f.write(model_card_str)
+        return model
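
For context, the t, DOT_THRESHOLD, and epsilon parameters documented above follow the standard SLERP formulation: each pair of weight tensors is normalized, the angle between them is computed from the dot product, and the interpolation falls back to plain LERP when the vectors are nearly collinear. The following is a minimal, self-contained sketch of that textbook computation; slerp_tensors is an illustrative name and is not part of fusion_bench (the package's actual routine lives in slerp_utils.slerp).

import torch


def slerp_tensors(
    t: float,
    v0: torch.Tensor,
    v1: torch.Tensor,
    DOT_THRESHOLD: float = 0.9995,
    epsilon: float = 1e-8,
) -> torch.Tensor:
    # Work on flattened float copies; reshape back to the original shape at the end.
    v0_f, v1_f = v0.flatten().float(), v1.flatten().float()
    v0_n = v0_f / (v0_f.norm() + epsilon)
    v1_n = v1_f / (v1_f.norm() + epsilon)
    dot = torch.sum(v0_n * v1_n)
    if dot.abs() > DOT_THRESHOLD:
        # Nearly collinear directions: plain linear interpolation is numerically safer.
        out = (1 - t) * v0_f + t * v1_f
    else:
        theta_0 = torch.acos(dot)      # angle between the two weight directions
        theta_t = theta_0 * t          # angle at the interpolation point t
        sin_theta_0 = torch.sin(theta_0)
        s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
        s1 = torch.sin(theta_t) / sin_theta_0
        out = s0 * v0_f + s1 * v1_f
    return out.reshape(v0.shape).to(v0.dtype)


# Example: blend two parameter tensors halfway between the two models.
merged = slerp_tensors(0.5, torch.randn(4, 4), torch.randn(4, 4))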

fusion_bench/method/smile_upscaling/causal_lm_upscaling.py
@@ -0,0 +1,371 @@
+import logging
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union
+
+import torch
+from accelerate import init_empty_weights
+from tqdm.auto import tqdm
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaForCausalLM,
+    MistralForCausalLM,
+    PretrainedConfig,
+    PreTrainedModel,
+    Qwen2ForCausalLM,
+)
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
+from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
+
+from fusion_bench import BaseAlgorithm, BaseModelPool
+from fusion_bench.compat.modelpool import to_modelpool
+from fusion_bench.constants import RuntimeConstants
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
+from fusion_bench.modelpool import CausalLMPool
+from fusion_bench.models.hf_utils import (
+    create_default_model_card,
+    save_pretrained_with_remote_code,
+)
+from fusion_bench.models.modeling_smile_llama import (
+    SmileLlamaConfig,
+    SmileLlamaForCausalLM,
+    SmileLlamaModel,
+)
+from fusion_bench.models.modeling_smile_llama.modeling_smile_llama import (
+    SmileLlamaDecoderLayer,
+)
+from fusion_bench.models.modeling_smile_mistral import (
+    SmileMistralConfig,
+    SmileMistralForCausalLM,
+    SmileMistralModel,
+)
+from fusion_bench.models.modeling_smile_mistral.modeling_smile_mistral import (
+    SmileMistralDecoderLayer,
+)
+
+# Import all SMILE configurations and models
+from fusion_bench.models.modeling_smile_qwen2 import (
+    SmileQwen2Config,
+    SmileQwen2ForCausalLM,
+    SmileQwen2Model,
+)
+from fusion_bench.models.modeling_smile_qwen2.modeling_smile_qwen2 import (
+    SmileQwen2DecoderLayer,
+)
+from fusion_bench.models.smile_moe.linear_from_hf_config import (
+    ExpertNotTrainedError,
+    upscale_to_smile_linear,
+)
+from fusion_bench.utils.dtype import parse_dtype
+from fusion_bench.utils.parameters import print_parameters
+
+log = logging.getLogger(__name__)
+
+# Model type mappings
+MODEL_TYPE_MAPPINGS = {
+    "qwen2": {
+        "base_model_cls": Qwen2ForCausalLM,
+        "base_decoder_layer_cls": Qwen2DecoderLayer,
+        "smile_config_cls": SmileQwen2Config,
+        "smile_model_cls": SmileQwen2ForCausalLM,
+        "smile_base_model_cls": SmileQwen2Model,
+        "smile_decoder_layer_cls": SmileQwen2DecoderLayer,
+        "description": "Qwen2",
+    },
+    "llama": {
+        "base_model_cls": LlamaForCausalLM,
+        "base_decoder_layer_cls": LlamaDecoderLayer,
+        "smile_config_cls": SmileLlamaConfig,
+        "smile_model_cls": SmileLlamaForCausalLM,
+        "smile_base_model_cls": SmileLlamaModel,
+        "smile_decoder_layer_cls": SmileLlamaDecoderLayer,
+        "description": "Llama",
+    },
+    "mistral": {
+        "base_model_cls": MistralForCausalLM,
+        "base_decoder_layer_cls": MistralDecoderLayer,
+        "smile_config_cls": SmileMistralConfig,
+        "smile_model_cls": SmileMistralForCausalLM,
+        "smile_base_model_cls": SmileMistralModel,
+        "smile_decoder_layer_cls": SmileMistralDecoderLayer,
+        "description": "Mistral",
+    },
+}
+
+
+def detect_model_type(
+    model_or_config: Union[PreTrainedModel, PretrainedConfig, str],
+) -> str:
+    """
+    Detect the model type from a model, config, or model name/path.
+
+    Args:
+        model_or_config: Model, config, or model name/path to detect type from
+
+    Returns:
+        str: The detected model type ("qwen2", "llama", "mistral")
+
+    Raises:
+        ValueError: If model type cannot be detected or is not supported
+    """
+    if isinstance(model_or_config, str):
+        # Load config from path/name
+        config = AutoConfig.from_pretrained(model_or_config)
+    elif isinstance(model_or_config, PreTrainedModel):
+        config = model_or_config.config
+    elif isinstance(model_or_config, PretrainedConfig):
+        config = model_or_config
+    else:
+        raise ValueError(
+            f"Unsupported type for model type detection: {type(model_or_config)}"
+        )
+
+    model_type = getattr(config, "model_type", "").lower()
+
+    # Handle various model type variations
+    if model_type in MODEL_TYPE_MAPPINGS:
+        return model_type
+    else:
+        raise ValueError(
+            f"Unsupported model type: {model_type}. Supported types: {list(MODEL_TYPE_MAPPINGS.keys())}"
+        )
+
+
+@auto_register_config
+class SmileCausalLMUpscalingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
+    R"""
+    SmileCausalLMUpscalingAlgorithm is a generic model fusion algorithm designed to upscale
+    a pretrained CausalLM model using a set of fine-tuned expert models. The algorithm
+    supports Qwen2, Llama, and Mistral model architectures and leverages Singular Value
+    Decomposition (SVD) to merge the weights of the pretrained model and the expert models
+    into a new upscaled model.
+
+    The algorithm automatically detects the model type and uses the appropriate SMILE
+    configuration and model classes.
+
+    Methods:
+        run(modelpool: BaseModelPool) -> Union[SmileQwen2ForCausalLM, SmileLlamaForCausalLM, SmileMistralForCausalLM]:
+            Executes the upscaling process and returns the upscaled model.
+
+        merge(pretrained_model: PreTrainedModel, finetuned_models: List[PreTrainedModel]) -> PreTrainedModel:
+            Merges the pretrained model with the fine-tuned models to create an upscaled model.
+    """
+
+    modelpool: CausalLMPool
+
+    def __init__(
+        self,
+        device,
+        accelerator,
+        model_save_path,
+        model_dtype,
+        num_experts_per_tok,
+        rank_of_router,
+        rank_of_expert,
+        save_with_remote_code: bool = True,
+        model_type: str = None,  # Optional: explicitly specify model type
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.model_mappings = None  # Will be set during run()
+
+        if not torch.cuda.is_available():
+            if "cuda" in self.device:
+                self.device = "cpu"
+            if "cuda" in self.accelerator:
+                self.accelerator = "cpu"
+
+    @torch.no_grad()
+    def run(self, modelpool) -> PreTrainedModel:
+        """
+        Executes the upscaling process.
+
+        Args:
+            modelpool (ModelPool): The pool of models to be used for upscaling.
+
+        Returns:
+            PreTrainedModel: The upscaled model (specific type depends on detected model architecture).
+        """
+        self.modelpool = modelpool = to_modelpool(modelpool)
+        config = self.config
+
+        # Auto-detect model type if not specified
+        if self.model_type is None:
+            self.model_type = detect_model_type(
+                modelpool.get_model_path("_pretrained_")
+            )
+            log.info(f"Auto-detected model type: {self.model_type}")
+
+        # Get the appropriate model mappings
+        if self.model_type not in MODEL_TYPE_MAPPINGS:
+            raise ValueError(
+                f"Unsupported model type: {self.model_type}. Supported: {list(MODEL_TYPE_MAPPINGS.keys())}"
+            )
+
+        self.model_mappings = MODEL_TYPE_MAPPINGS[self.model_type]
+        log.info(f"Using {self.model_mappings['description']} model architecture")
+
+        with self.profile("load pretrained model"):
+            pretrained_model = modelpool.load_pretrained_model()
+
+        with self.profile("load fine-tuned model"):
+            finetuned_models = [
+                m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
+            ]
+
+        if self.device == "cuda" and torch.cuda.is_available():
+            pretrained_model = pretrained_model.cuda()
+            print("parameter count of pretrained model:")
+            print_parameters(pretrained_model)
+            finetuned_models = [m.cuda() for m in finetuned_models]
+
+        with self.profile("merge model"):
+            model = self.merge(pretrained_model, finetuned_models)
+
+        self.print_profile_summary()
+        print("parameter count of upscaled MoE model:")
+        print_parameters(model)
+        print(model)
+
+        if self.model_dtype is not None:
+            model.to(dtype=parse_dtype(self.model_dtype))
+
+        if self.model_save_path is not None:
+            if os.path.dirname(self.model_save_path):
+                os.makedirs(os.path.dirname(self.model_save_path), exist_ok=True)
+            log.info(f"Saving model to {self.model_save_path}")
+            tokenizer = self.modelpool.load_tokenizer()
+            tokenizer.save_pretrained(self.model_save_path)
+            if not self.save_with_remote_code:
+                model.save_pretrained(self.model_save_path)
+            else:
+                # Use the appropriate auto_map for the detected model type
+                auto_map = {
+                    "AutoConfig": self.model_mappings["smile_config_cls"],
+                    "AutoModel": self.model_mappings["smile_base_model_cls"],
+                    "AutoModelForCausalLM": self.model_mappings["smile_model_cls"],
+                }
+                save_pretrained_with_remote_code(
+                    model,
+                    auto_map=auto_map,
+                    save_directory=self.model_save_path,
+                )
+
+            # save readme
+            model_card_str = create_default_model_card(
+                models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
+                description=f"Merged {self.model_mappings['description']} model using SMILE Upscaling",
+                algorithm_config=self.config,
+                modelpool_config=modelpool.config,
+            )
+            with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
+                f.write(model_card_str)
+
+        return model
+
+    def merge(
+        self,
+        pretrained_model: PreTrainedModel,
+        finetuned_models: List[PreTrainedModel],
+    ) -> PreTrainedModel:
+        """
+        Merges the pretrained model with the fine-tuned models to create an upscaled model.
+
+        Args:
+            pretrained_model (PreTrainedModel): The pretrained model.
+            finetuned_models (List[PreTrainedModel]): A list of fine-tuned models.
+
+        Returns:
+            PreTrainedModel: The upscaled model (specific type depends on model architecture).
+        """
+        with init_empty_weights():
+            pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
+            if isinstance(pretrained_model_config, str):
+                pretrained_path = pretrained_model_config
+            else:
+                pretrained_path = pretrained_model_config.get(
+                    "path", pretrained_model_config["pretrained_model_name_or_path"]
+                )
+            base_config = AutoConfig.from_pretrained(pretrained_path)
+
+            # Create the appropriate SMILE config for the detected model type
+            SmileConfigClass = self.model_mappings["smile_config_cls"]
+            model_config = SmileConfigClass(
+                num_experts_per_tok=self.num_experts_per_tok,
+                rank_of_router=self.rank_of_router,
+                rank_of_expert=self.rank_of_expert,
+                num_local_experts=len(finetuned_models),
+                **base_config.to_dict(),
+            )
+
+            # Create the appropriate SMILE model for the detected model type
+            SmileModelClass = self.model_mappings["smile_model_cls"]
+            model = SmileModelClass(model_config)
+
+        model.to(dtype=pretrained_model.dtype).to_empty(device="cpu")
+
+        # copy pretrained model weights
+        state_dict = model.state_dict()
+        pretrained_state_dict = pretrained_model.state_dict()
+        for key in list(pretrained_state_dict.keys()):
+            if key not in state_dict:
+                pretrained_state_dict.pop(key)
+        model.load_state_dict(pretrained_state_dict, strict=False)
+
+        # upscale model
+        BaseDecoderLayerClass = self.model_mappings["base_decoder_layer_cls"]
+        SmileDecoderLayerClass = self.model_mappings["smile_decoder_layer_cls"]
+
+        for layer_idx in tqdm(
+            range(len(pretrained_model.model.layers)),
+            "Upscaling Modules (layer)",
+            dynamic_ncols=True,
+        ):
+            if RuntimeConstants.debug and layer_idx > 0:
+                log.info(
+                    "Debug mode enabled: processing only the first layer, skipping remaining layers"
+                )
+                break
+
+            pretrained_layer = pretrained_model.model.layers[layer_idx]
+            finetuned_layers = [m.model.layers[layer_idx] for m in finetuned_models]
+
+            target_layer = model.model.layers[layer_idx]
+
+            for n in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+                try:
+                    upscale_to_smile_linear(
+                        base=getattr(pretrained_layer.self_attn, n),
+                        experts=[getattr(m.self_attn, n) for m in finetuned_layers],
+                        target=getattr(target_layer.self_attn, n),
+                        accelerator=self.accelerator,
+                    )
+                except ExpertNotTrainedError:
+                    setattr(
+                        target_layer.self_attn,
+                        n,
+                        getattr(pretrained_layer.self_attn, n),
+                    )
+
+            for n in ["gate_proj", "up_proj", "down_proj"]:
+                try:
+                    upscale_to_smile_linear(
+                        base=getattr(pretrained_layer.mlp, n),
+                        experts=[getattr(m.mlp, n) for m in finetuned_layers],
+                        target=getattr(target_layer.mlp, n),
+                        accelerator=self.accelerator,
+                    )
+                except ExpertNotTrainedError:
+                    setattr(
+                        target_layer.mlp,
+                        n,
+                        getattr(pretrained_layer.mlp, n),
+                    )
+
+        return model
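
Because save_pretrained_with_remote_code writes the SMILE model code next to the checkpoint and registers it in the saved config's auto_map, a checkpoint produced with save_with_remote_code=True would normally be reloaded through the standard transformers auto classes with trust_remote_code enabled. A hedged usage sketch follows; the output directory name is hypothetical.

from transformers import AutoModelForCausalLM, AutoTokenizer

path = "./outputs/smile_upscaled_model"  # hypothetical model_save_path

tokenizer = AutoTokenizer.from_pretrained(path)
# trust_remote_code lets transformers import the SMILE classes saved alongside the weights.
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)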

fusion_bench/method/smile_upscaling/projected_energy.py
@@ -3,12 +3,11 @@ from typing import Literal
 
 import pandas as pd
 import torch
+from tqdm import tqdm
 
 from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
 from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 
-from tqdm import tqdm
-
 
 class ProjectedEnergyAnalysis(
     SimpleProfilerMixin,

fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py
@@ -20,6 +20,7 @@ from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.method.simple_average import simple_average
+from fusion_bench.mixins import auto_register_config
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.models.modeling_smile_mistral import (
@@ -40,7 +41,10 @@ from fusion_bench.utils.parameters import print_parameters
 log = logging.getLogger(__name__)
 
 
-class SmileMistralUpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
+class SmileMistralUpscalingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     R"""
     SmileMistralUpscalingAlgorithm is a model fusion algorithm designed to upscale
     a pretrained Mistral model using a set of fine-tuned expert models. The algorithm