fusion-bench 0.2.21__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fusion_bench/__init__.py +21 -2
  2. fusion_bench/constants/__init__.py +1 -0
  3. fusion_bench/constants/runtime.py +57 -0
  4. fusion_bench/method/__init__.py +8 -2
  5. fusion_bench/method/bitdelta/__init__.py +1 -0
  6. fusion_bench/method/classification/clip_finetune.py +1 -1
  7. fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
  8. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
  9. fusion_bench/method/linear/simple_average_for_llama.py +16 -11
  10. fusion_bench/method/simple_average.py +7 -7
  11. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  12. fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
  13. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  14. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
  15. fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
  16. fusion_bench/method/we_moe/__init__.py +1 -0
  17. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  18. fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
  19. fusion_bench/method/we_moe/utils.py +15 -0
  20. fusion_bench/method/weighted_average/llama.py +1 -1
  21. fusion_bench/mixins/clip_classification.py +11 -42
  22. fusion_bench/mixins/serialization.py +18 -8
  23. fusion_bench/modelpool/causal_lm/causal_lm.py +32 -33
  24. fusion_bench/models/__init__.py +5 -0
  25. fusion_bench/models/hf_utils.py +65 -87
  26. fusion_bench/models/model_card_templates/default.md +46 -0
  27. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  28. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
  29. fusion_bench/models/modeling_smile_mistral/__init__.py +1 -1
  30. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
  31. fusion_bench/programs/fabric_fusion_program.py +29 -60
  32. fusion_bench/scripts/cli.py +34 -1
  33. fusion_bench/taskpool/clip_vision/taskpool.py +9 -4
  34. fusion_bench/utils/__init__.py +1 -0
  35. fusion_bench/utils/cache_utils.py +101 -1
  36. fusion_bench/utils/fabric.py +2 -2
  37. fusion_bench/utils/lazy_imports.py +23 -0
  38. fusion_bench/utils/lazy_state_dict.py +38 -3
  39. fusion_bench/utils/modelscope.py +3 -3
  40. fusion_bench/utils/path.py +56 -0
  41. fusion_bench/utils/pylogger.py +1 -1
  42. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +1 -23
  43. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +53 -45
  44. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  45. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  46. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  47. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
  48. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  49. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
  50. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
  51. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
  52. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
  53. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py CHANGED
@@ -19,9 +19,28 @@ from . import (
     tasks,
     utils,
 )
+from .constants import RuntimeConstants
 from .method import BaseAlgorithm, BaseModelFusionAlgorithm
 from .mixins import auto_register_config
 from .modelpool import BaseModelPool
-from .models import separate_io
+from .models import (
+    create_default_model_card,
+    load_model_card_template,
+    save_pretrained_with_remote_code,
+    separate_io,
+)
+from .programs import BaseHydraProgram
 from .taskpool import BaseTaskPool
-from .utils import parse_dtype, print_parameters, timeit_context
+from .utils import (
+    cache_with_joblib,
+    get_rankzero_logger,
+    import_object,
+    instantiate,
+    parse_dtype,
+    print_parameters,
+    seed_everything_by_time,
+    set_default_cache_dir,
+    set_print_function_call,
+    set_print_function_call_permeanent,
+    timeit_context,
+)
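The hunk above widens the package's top-level API: the runtime singleton, model-card helpers, and several utilities become importable directly from `fusion_bench`. A small sketch of the imports this enables (names taken from the hunk; the logger call mirrors its use elsewhere in this release):

```python
# Top-level re-exports available as of 0.2.22, per the hunk above.
from fusion_bench import (
    RuntimeConstants,
    cache_with_joblib,
    create_default_model_card,
    get_rankzero_logger,
    instantiate,
    save_pretrained_with_remote_code,
)

# Rank-zero logger, as used in simple_average_for_llama.py below.
log = get_rankzero_logger(__name__)
```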
fusion_bench/constants/__init__.py CHANGED
@@ -2,6 +2,7 @@
 import importlib.metadata
 
 from .paths import *
+from .runtime import RuntimeConstants
 
 # fusionbench version
 FUSION_BENCH_VERSION = importlib.metadata.version("fusion-bench")
fusion_bench/constants/runtime.py ADDED
@@ -0,0 +1,57 @@
+import threading
+from pathlib import Path
+from typing import Optional, Union
+
+
+class RuntimeConstants:
+    """
+    This class holds constants related to the runtime environment of the Fusion Bench framework.
+    It includes default values for cache directories and other runtime configurations.
+
+    Implemented as a thread-safe singleton to ensure consistent runtime configuration
+    across the entire application.
+    """
+
+    _instance: Optional["RuntimeConstants"] = None
+    _lock = threading.Lock()
+
+    def __new__(cls) -> "RuntimeConstants":
+        """Create a new instance using singleton pattern with thread safety."""
+        with cls._lock:
+            # Double-check locking pattern
+            if cls._instance is None:
+                cls._instance = super(RuntimeConstants, cls).__new__(cls)
+                cls._instance._initialized = False
+            return cls._instance
+
+    def __init__(self):
+        """Initialize the singleton instance only once."""
+        if not self._initialized:
+            # Add your runtime constants here
+            self._initialized = True
+
+    debug = False
+
+    @property
+    def cache_dir(self) -> Path:
+        from fusion_bench.utils.cache_utils import DEFAULT_CACHE_DIR
+
+        return DEFAULT_CACHE_DIR
+
+    @cache_dir.setter
+    def cache_dir(self, path: Union[str, Path]) -> None:
+        from fusion_bench.utils.cache_utils import set_default_cache_dir
+
+        set_default_cache_dir(path)
+
+    @property
+    def print_function_call(self) -> bool:
+        from fusion_bench.utils.instantiate_utils import PRINT_FUNCTION_CALL
+
+        return PRINT_FUNCTION_CALL
+
+    @print_function_call.setter
+    def print_function_call(self, enable: bool) -> None:
+        from fusion_bench.utils.instantiate_utils import set_print_function_call
+
+        set_print_function_call(enable)
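`RuntimeConstants` is a thread-safe singleton, so every part of the framework sees the same runtime configuration. A minimal usage sketch, assuming the properties behave as shown in the hunk above (the cache path is illustrative):

```python
from pathlib import Path

from fusion_bench.constants import RuntimeConstants

# Every call returns the same instance.
constants = RuntimeConstants()
assert constants is RuntimeConstants()

# Redirect the default cache directory; the setter delegates to
# fusion_bench.utils.cache_utils.set_default_cache_dir.
constants.cache_dir = Path("/tmp/fusion_bench_cache")

# Log instantiation calls; delegates to set_print_function_call.
constants.print_function_call = True

# Class-level debug flag, read e.g. by the new SMILE CausalLM upscaling code.
RuntimeConstants.debug = True
```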
fusion_bench/method/__init__.py CHANGED
@@ -90,7 +90,10 @@ _import_structure = {
         "MixtralForCausalLMMergingAlgorithm",
     ],
     "dawe": ["DataAdaptiveWeightEnsemblingForCLIP"],
-    "we_moe": ["CLIPWeightEnsemblingMoEAlgorithm"],
+    "we_moe": [
+        "CLIPWeightEnsemblingMoEAlgorithm",
+        "FlanT5WeightEnsemblingMoEAlgorithm",
+    ],
     "rankone_moe": ["CLIPRankOneMoEAlgorithm", "RankOneMoEAlgorithm"],
     "sparse_we_moe": [
         "SparseWeightEnsemblingMoEAlgorithm",
@@ -228,7 +231,10 @@ if TYPE_CHECKING:
     from .task_arithmetic import TaskArithmeticAlgorithm
     from .task_singular_vector import TaskSingularVectorMerging
     from .ties_merging import TiesMergingAlgorithm
-    from .we_moe import CLIPWeightEnsemblingMoEAlgorithm
+    from .we_moe import (
+        CLIPWeightEnsemblingMoEAlgorithm,
+        FlanT5WeightEnsemblingMoEAlgorithm,
+    )
     from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
 
 else:
fusion_bench/method/bitdelta/__init__.py CHANGED
@@ -1,4 +1,5 @@
 """
 Adapted from https://github.com/FasterDecoding/BitDelta
 """
+
 from .bitdelta import BitDeltaAlgorithm
fusion_bench/method/classification/clip_finetune.py CHANGED
@@ -393,7 +393,7 @@ def convert_l_lora_state_dict_to_hf(
     base_model_name: Optional[str] = None,
 ):
     """
-    Convert a linearized Lora model's checkpoint to Hugggingface's format.
+    Convert a linearized Lora model's checkpoint to huggingface's format.
 
     Args:
         pretrained_path (str): The path to the pretrained model.
fusion_bench/method/fisher_merging/clip_fisher_merging.py CHANGED
@@ -32,7 +32,6 @@ class FisherMergingForCLIPVisionModel(
     zeroshot_weights = {}
 
     _config_mapping = FisherMergingAlgorithm._config_mapping | {
-        "zeroshot_weights_cache_dir": "zeroshot_weights_cache_dir",
         "_dataloader_kwargs": "dataloader_kwargs",
     }
 
@@ -44,7 +43,6 @@ class FisherMergingForCLIPVisionModel(
         minimal_fisher_weight,
         num_fisher_examples,
         dataloader_kwargs: DictConfig,
-        zeroshot_weights_cache_dir=None,
         **kwargs,
     ):
         """
@@ -56,7 +54,6 @@ class FisherMergingForCLIPVisionModel(
             minimal_fisher_weight (float): Minimal value for Fisher weights to avoid numerical issues.
            num_fisher_examples (int): Number of examples to compute Fisher weights.
            dataloader_kwargs (DictConfig): Configuration for the dataloader.
-            zeroshot_weights_cache_dir (str, optional): Directory to cache zero-shot weights. Defaults to None.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(
@@ -66,7 +63,6 @@ class FisherMergingForCLIPVisionModel(
             num_fisher_examples=num_fisher_examples,
         )
         self.dataloader_kwargs = dataloader_kwargs
-        self.zeroshot_weights_cache_dir = zeroshot_weights_cache_dir
         for key, value in kwargs.items():
             log.warning(f"Unused argument: {key}={value}")
             setattr(self, key, value)
fusion_bench/method/fisher_merging/gpt2_fisher_merging.py CHANGED
@@ -15,10 +15,10 @@ from transformers import GPT2ForSequenceClassification, GPT2Model
 from transformers.data import default_data_collator
 from transformers.models.gpt2.modeling_gpt2 import Conv1D
 
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, auto_register_config
 from fusion_bench.modelpool import GPT2ForSequenceClassificationPool
 from fusion_bench.utils import timeit_context
-from fusion_bench.mixins import auto_register_config
+
 from .fisher_merging import FisherMergingAlgorithm, get_param_squared_gradients
 
 
fusion_bench/method/linear/simple_average_for_llama.py CHANGED
@@ -1,3 +1,4 @@
+import os
 from copy import deepcopy
 from typing import TYPE_CHECKING, Optional
 
@@ -7,13 +8,16 @@ from typing_extensions import override
 from fusion_bench import timeit_context
 from fusion_bench.method.base_algorithm import BaseAlgorithm
 from fusion_bench.method.simple_average import SimpleAverageAlgorithm
+from fusion_bench.mixins import auto_register_config
 from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
+from fusion_bench.models.hf_utils import create_default_model_card
 from fusion_bench.utils import instantiate
-from fusion_bench.utils.pylogger import getRankZeroLogger
+from fusion_bench.utils.pylogger import get_rankzero_logger
 
-log = getRankZeroLogger(__name__)
+log = get_rankzero_logger(__name__)
 
 
+@auto_register_config
 class SimpleAverageForLlama(BaseAlgorithm):
     R"""
     A simple averaging algorithm for LLama models. If `merge_backbone` is set to `True`, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.
@@ -29,21 +33,14 @@ class SimpleAverageForLlama(BaseAlgorithm):
     ```
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "merge_backbone": "merge_backbone",
-        "show_pbar": "show_pbar",
-    }
-
     def __init__(
         self,
         merge_backbone: bool,
         model_save_path: Optional[str] = None,
         show_pbar: bool = False,
+        **kwargs,
     ):
-        super().__init__()
-        self.merge_backbone = merge_backbone
-        self.model_save_path = model_save_path
-        self.show_pbar = show_pbar
+        super().__init__(**kwargs)
 
     @override
     def run(self, modelpool: CausalLMPool):
@@ -75,4 +72,12 @@ class SimpleAverageForLlama(BaseAlgorithm):
         with timeit_context(f"Saving the model to {self.model_save_path}"):
             tokenizer.save_pretrained(self.model_save_path)
             model.save_pretrained(self.model_save_path)
+            model_card_str = create_default_model_card(
+                models=[modelpool.get_model_path(m) for m in modelpool.model_names],
+                description="Merged model using simple averaging.",
+                algorithm_config=self.config,
+                modelpool_config=modelpool.config,
+            )
+            with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
+                f.write(model_card_str)
         return model
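`create_default_model_card` (new in `fusion_bench.models.hf_utils`) renders the README written in the hunk above. A hedged sketch of calling it on its own; the OmegaConf configs and model paths are illustrative stand-ins for `self.config` and `modelpool.config`:

```python
from omegaconf import OmegaConf

from fusion_bench.models.hf_utils import create_default_model_card

# Illustrative configs; inside the algorithm these come from
# self.config and modelpool.config.
algorithm_config = OmegaConf.create({"merge_backbone": False, "show_pbar": False})
modelpool_config = OmegaConf.create(
    {"models": {"expert_a": "path/to/expert_a", "expert_b": "path/to/expert_b"}}
)

card = create_default_model_card(
    models=["path/to/expert_a", "path/to/expert_b"],
    description="Merged model using simple averaging.",
    algorithm_config=algorithm_config,
    modelpool_config=modelpool_config,
)
with open("README.md", "w") as f:
    f.write(card)
```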
fusion_bench/method/simple_average.py CHANGED
@@ -61,8 +61,8 @@ def simple_average(
 
 @auto_register_config
 class SimpleAverageAlgorithm(
-    BaseAlgorithm,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
     def __init__(self, show_pbar: bool = False, **kwargs):
         """
@@ -120,13 +120,13 @@ class SimpleAverageAlgorithm(
         if isinstance(forward_model, LazyStateDict):
             # if the model is a LazyStateDict, convert it to an empty module
             forward_model = forward_model.meta_module.to_empty(
-                device=(
-                    "cpu"
-                    if forward_model._torch_dtype is None
-                    else forward_model._torch_dtype
-                )
+                device=forward_model._device
             )
-            forward_model.load_state_dict(sd)
+            result = forward_model.load_state_dict(sd, strict=False)
+            if result.unexpected_keys:
+                raise ValueError(f"Unexpected keys in state dict: {result.unexpected_keys}")
+            if result.missing_keys:
+                log.warning(f"Missing keys in state dict: {result.missing_keys}")
         # print profile report and log the merged models
         self.print_profile_summary()
         log.info(f"merged {len(merged_model_names)} models:")
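The `load_state_dict(sd, strict=False)` change above relies on the key report that PyTorch returns from a non-strict load. A generic sketch of that pattern with a plain module (the model and state dict here are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
state_dict = {"weight": torch.zeros(2, 4)}  # deliberately omits "bias"

# load_state_dict returns a named tuple of missing_keys and unexpected_keys.
result = model.load_state_dict(state_dict, strict=False)
if result.unexpected_keys:
    raise ValueError(f"Unexpected keys in state dict: {result.unexpected_keys}")
if result.missing_keys:
    print(f"Tolerated missing keys: {result.missing_keys}")  # ['bias']
```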
fusion_bench/method/smile_upscaling/causal_lm_upscaling.py ADDED
@@ -0,0 +1,371 @@
+import logging
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union
+
+import torch
+from accelerate import init_empty_weights
+from tqdm.auto import tqdm
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaForCausalLM,
+    MistralForCausalLM,
+    PretrainedConfig,
+    PreTrainedModel,
+    Qwen2ForCausalLM,
+)
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
+from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
+
+from fusion_bench import BaseAlgorithm, BaseModelPool
+from fusion_bench.compat.modelpool import to_modelpool
+from fusion_bench.constants import RuntimeConstants
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
+from fusion_bench.modelpool import CausalLMPool
+from fusion_bench.models.hf_utils import (
+    create_default_model_card,
+    save_pretrained_with_remote_code,
+)
+from fusion_bench.models.modeling_smile_llama import (
+    SmileLlamaConfig,
+    SmileLlamaForCausalLM,
+    SmileLlamaModel,
+)
+from fusion_bench.models.modeling_smile_llama.modeling_smile_llama import (
+    SmileLlamaDecoderLayer,
+)
+from fusion_bench.models.modeling_smile_mistral import (
+    SmileMistralConfig,
+    SmileMistralForCausalLM,
+    SmileMistralModel,
+)
+from fusion_bench.models.modeling_smile_mistral.modeling_smile_mistral import (
+    SmileMistralDecoderLayer,
+)
+
+# Import all SMILE configurations and models
+from fusion_bench.models.modeling_smile_qwen2 import (
+    SmileQwen2Config,
+    SmileQwen2ForCausalLM,
+    SmileQwen2Model,
+)
+from fusion_bench.models.modeling_smile_qwen2.modeling_smile_qwen2 import (
+    SmileQwen2DecoderLayer,
+)
+from fusion_bench.models.smile_moe.linear_from_hf_config import (
+    ExpertNotTrainedError,
+    upscale_to_smile_linear,
+)
+from fusion_bench.utils.dtype import parse_dtype
+from fusion_bench.utils.parameters import print_parameters
+
+log = logging.getLogger(__name__)
+
+# Model type mappings
+MODEL_TYPE_MAPPINGS = {
+    "qwen2": {
+        "base_model_cls": Qwen2ForCausalLM,
+        "base_decoder_layer_cls": Qwen2DecoderLayer,
+        "smile_config_cls": SmileQwen2Config,
+        "smile_model_cls": SmileQwen2ForCausalLM,
+        "smile_base_model_cls": SmileQwen2Model,
+        "smile_decoder_layer_cls": SmileQwen2DecoderLayer,
+        "description": "Qwen2",
+    },
+    "llama": {
+        "base_model_cls": LlamaForCausalLM,
+        "base_decoder_layer_cls": LlamaDecoderLayer,
+        "smile_config_cls": SmileLlamaConfig,
+        "smile_model_cls": SmileLlamaForCausalLM,
+        "smile_base_model_cls": SmileLlamaModel,
+        "smile_decoder_layer_cls": SmileLlamaDecoderLayer,
+        "description": "Llama",
+    },
+    "mistral": {
+        "base_model_cls": MistralForCausalLM,
+        "base_decoder_layer_cls": MistralDecoderLayer,
+        "smile_config_cls": SmileMistralConfig,
+        "smile_model_cls": SmileMistralForCausalLM,
+        "smile_base_model_cls": SmileMistralModel,
+        "smile_decoder_layer_cls": SmileMistralDecoderLayer,
+        "description": "Mistral",
+    },
+}
+
+
+def detect_model_type(
+    model_or_config: Union[PreTrainedModel, PretrainedConfig, str],
+) -> str:
+    """
+    Detect the model type from a model, config, or model name/path.
+
+    Args:
+        model_or_config: Model, config, or model name/path to detect type from
+
+    Returns:
+        str: The detected model type ("qwen2", "llama", "mistral")
+
+    Raises:
+        ValueError: If model type cannot be detected or is not supported
+    """
+    if isinstance(model_or_config, str):
+        # Load config from path/name
+        config = AutoConfig.from_pretrained(model_or_config)
+    elif isinstance(model_or_config, PreTrainedModel):
+        config = model_or_config.config
+    elif isinstance(model_or_config, PretrainedConfig):
+        config = model_or_config
+    else:
+        raise ValueError(
+            f"Unsupported type for model type detection: {type(model_or_config)}"
+        )
+
+    model_type = getattr(config, "model_type", "").lower()
+
+    # Handle various model type variations
+    if model_type in MODEL_TYPE_MAPPINGS:
+        return model_type
+    else:
+        raise ValueError(
+            f"Unsupported model type: {model_type}. Supported types: {list(MODEL_TYPE_MAPPINGS.keys())}"
+        )
+
+
+@auto_register_config
+class SmileCausalLMUpscalingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
+    R"""
+    SmileCausalLMUpscalingAlgorithm is a generic model fusion algorithm designed to upscale
+    a pretrained CausalLM model using a set of fine-tuned expert models. The algorithm
+    supports Qwen2, Llama, and Mistral model architectures and leverages Singular Value
+    Decomposition (SVD) to merge the weights of the pretrained model and the expert models
+    into a new upscaled model.
+
+    The algorithm automatically detects the model type and uses the appropriate SMILE
+    configuration and model classes.
+
+    Methods:
+        run(modelpool: BaseModelPool) -> Union[SmileQwen2ForCausalLM, SmileLlamaForCausalLM, SmileMistralForCausalLM]:
+            Executes the upscaling process and returns the upscaled model.
+
+        merge(pretrained_model: PreTrainedModel, finetuned_models: List[PreTrainedModel]) -> PreTrainedModel:
+            Merges the pretrained model with the fine-tuned models to create an upscaled model.
+    """
+
+    modelpool: CausalLMPool
+
+    def __init__(
+        self,
+        device,
+        accelerator,
+        model_save_path,
+        model_dtype,
+        num_experts_per_tok,
+        rank_of_router,
+        rank_of_expert,
+        save_with_remote_code: bool = True,
+        model_type: str = None,  # Optional: explicitly specify model type
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.model_mappings = None  # Will be set during run()
+
+        if not torch.cuda.is_available():
+            if "cuda" in self.device:
+                self.device = "cpu"
+            if "cuda" in self.accelerator:
+                self.accelerator = "cpu"
+
+    @torch.no_grad()
+    def run(self, modelpool) -> PreTrainedModel:
+        """
+        Executes the upscaling process.
+
+        Args:
+            modelpool (ModelPool): The pool of models to be used for upscaling.
+
+        Returns:
+            PreTrainedModel: The upscaled model (specific type depends on detected model architecture).
+        """
+        self.modelpool = modelpool = to_modelpool(modelpool)
+        config = self.config
+
+        # Auto-detect model type if not specified
+        if self.model_type is None:
+            self.model_type = detect_model_type(
+                modelpool.get_model_path("_pretrained_")
+            )
+            log.info(f"Auto-detected model type: {self.model_type}")
+
+        # Get the appropriate model mappings
+        if self.model_type not in MODEL_TYPE_MAPPINGS:
+            raise ValueError(
+                f"Unsupported model type: {self.model_type}. Supported: {list(MODEL_TYPE_MAPPINGS.keys())}"
+            )
+
+        self.model_mappings = MODEL_TYPE_MAPPINGS[self.model_type]
+        log.info(f"Using {self.model_mappings['description']} model architecture")
+
+        with self.profile("load pretrained model"):
+            pretrained_model = modelpool.load_pretrained_model()
+
+        with self.profile("load fine-tuned model"):
+            finetuned_models = [
+                m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
+            ]
+
+        if self.device == "cuda" and torch.cuda.is_available():
+            pretrained_model = pretrained_model.cuda()
+            print("parameter count of pretrained model:")
+            print_parameters(pretrained_model)
+            finetuned_models = [m.cuda() for m in finetuned_models]
+
+        with self.profile("merge model"):
+            model = self.merge(pretrained_model, finetuned_models)
+
+        self.print_profile_summary()
+        print("parameter count of upscaled MoE model:")
+        print_parameters(model)
+        print(model)
+
+        if self.model_dtype is not None:
+            model.to(dtype=parse_dtype(self.model_dtype))
+
+        if self.model_save_path is not None:
+            if os.path.dirname(self.model_save_path):
+                os.makedirs(os.path.dirname(self.model_save_path), exist_ok=True)
+            log.info(f"Saving model to {self.model_save_path}")
+            tokenizer = self.modelpool.load_tokenizer()
+            tokenizer.save_pretrained(self.model_save_path)
+            if not self.save_with_remote_code:
+                model.save_pretrained(self.model_save_path)
+            else:
+                # Use the appropriate auto_map for the detected model type
+                auto_map = {
+                    "AutoConfig": self.model_mappings["smile_config_cls"],
+                    "AutoModel": self.model_mappings["smile_base_model_cls"],
+                    "AutoModelForCausalLM": self.model_mappings["smile_model_cls"],
+                }
+                save_pretrained_with_remote_code(
+                    model,
+                    auto_map=auto_map,
+                    save_directory=self.model_save_path,
+                )
+
+            # save readme
+            model_card_str = create_default_model_card(
+                models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
+                description=f"Merged {self.model_mappings['description']} model using SMILE Upscaling",
+                algorithm_config=self.config,
+                modelpool_config=modelpool.config,
+            )
+            with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
+                f.write(model_card_str)
+
+        return model
+
+    def merge(
+        self,
+        pretrained_model: PreTrainedModel,
+        finetuned_models: List[PreTrainedModel],
+    ) -> PreTrainedModel:
+        """
+        Merges the pretrained model with the fine-tuned models to create an upscaled model.
+
+        Args:
+            pretrained_model (PreTrainedModel): The pretrained model.
+            finetuned_models (List[PreTrainedModel]): A list of fine-tuned models.
+
+        Returns:
+            PreTrainedModel: The upscaled model (specific type depends on model architecture).
+        """
+        with init_empty_weights():
+            pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
+            if isinstance(pretrained_model_config, str):
+                pretrained_path = pretrained_model_config
+            else:
+                pretrained_path = pretrained_model_config.get(
+                    "path", pretrained_model_config["pretrained_model_name_or_path"]
+                )
+            base_config = AutoConfig.from_pretrained(pretrained_path)
+
+            # Create the appropriate SMILE config for the detected model type
+            SmileConfigClass = self.model_mappings["smile_config_cls"]
+            model_config = SmileConfigClass(
+                num_experts_per_tok=self.num_experts_per_tok,
+                rank_of_router=self.rank_of_router,
+                rank_of_expert=self.rank_of_expert,
+                num_local_experts=len(finetuned_models),
+                **base_config.to_dict(),
+            )
+
+            # Create the appropriate SMILE model for the detected model type
+            SmileModelClass = self.model_mappings["smile_model_cls"]
+            model = SmileModelClass(model_config)
+
+        model.to(dtype=pretrained_model.dtype).to_empty(device="cpu")
+
+        # copy pretrained model weights
+        state_dict = model.state_dict()
+        pretrained_state_dict = pretrained_model.state_dict()
+        for key in list(pretrained_state_dict.keys()):
+            if key not in state_dict:
+                pretrained_state_dict.pop(key)
+        model.load_state_dict(pretrained_state_dict, strict=False)
+
+        # upscale model
+        BaseDecoderLayerClass = self.model_mappings["base_decoder_layer_cls"]
+        SmileDecoderLayerClass = self.model_mappings["smile_decoder_layer_cls"]
+
+        for layer_idx in tqdm(
+            range(len(pretrained_model.model.layers)),
+            "Upscaling Modules (layer)",
+            dynamic_ncols=True,
+        ):
+            if RuntimeConstants.debug and layer_idx > 0:
+                log.info(
+                    "Debug mode enabled: processing only the first layer, skipping remaining layers"
+                )
+                break
+
+            pretrained_layer = pretrained_model.model.layers[layer_idx]
+            finetuned_layers = [m.model.layers[layer_idx] for m in finetuned_models]
+
+            target_layer = model.model.layers[layer_idx]
+
+            for n in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+                try:
+                    upscale_to_smile_linear(
+                        base=getattr(pretrained_layer.self_attn, n),
+                        experts=[getattr(m.self_attn, n) for m in finetuned_layers],
+                        target=getattr(target_layer.self_attn, n),
+                        accelerator=self.accelerator,
+                    )
+                except ExpertNotTrainedError:
+                    setattr(
+                        target_layer.self_attn,
+                        n,
+                        getattr(pretrained_layer.self_attn, n),
+                    )
+
+            for n in ["gate_proj", "up_proj", "down_proj"]:
+                try:
+                    upscale_to_smile_linear(
+                        base=getattr(pretrained_layer.mlp, n),
+                        experts=[getattr(m.mlp, n) for m in finetuned_layers],
+                        target=getattr(target_layer.mlp, n),
+                        accelerator=self.accelerator,
+                    )
+                except ExpertNotTrainedError:
+                    setattr(
+                        target_layer.mlp,
+                        n,
+                        getattr(pretrained_layer.mlp, n),
+                    )
+
+        return model
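A hedged sketch of driving the new algorithm from Python. The constructor arguments mirror `__init__` in the hunk above, the values are illustrative, and the model pool is assumed to be a `CausalLMPool` with a `_pretrained_` entry plus expert checkpoints (for example the Qwen2.5-1.5B math/coder pool listed in the config changes):

```python
from fusion_bench.method.smile_upscaling.causal_lm_upscaling import (
    SmileCausalLMUpscalingAlgorithm,
    detect_model_type,
)

# Architecture detection accepts a model name/path, a config, or a loaded model
# (downloads the config when given a hub name).
print(detect_model_type("Qwen/Qwen2.5-1.5B"))  # expected: "qwen2"

algorithm = SmileCausalLMUpscalingAlgorithm(
    device="cuda",
    accelerator="cuda",
    model_save_path="outputs/smile_qwen2_upscaled",
    model_dtype="bfloat16",
    num_experts_per_tok=1,       # illustrative hyperparameters
    rank_of_router=8,
    rank_of_expert=32,
    save_with_remote_code=True,  # saves auto_map so the model loads with trust_remote_code
    model_type=None,             # None lets run() auto-detect from the pretrained model
)
upscaled_model = algorithm.run(modelpool)  # modelpool: a CausalLMPool configured elsewhere
```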
fusion_bench/method/smile_upscaling/projected_energy.py CHANGED
@@ -3,12 +3,11 @@ from typing import Literal
 
 import pandas as pd
 import torch
+from tqdm import tqdm
 
 from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
 from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 
-from tqdm import tqdm
-
 
 class ProjectedEnergyAnalysis(
     SimpleProfilerMixin,
fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py CHANGED
@@ -20,6 +20,7 @@ from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.method.simple_average import simple_average
+from fusion_bench.mixins import auto_register_config
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.models.modeling_smile_mistral import (
@@ -40,7 +41,10 @@ from fusion_bench.utils.parameters import print_parameters
 log = logging.getLogger(__name__)
 
 
-class SmileMistralUpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
+class SmileMistralUpscalingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     R"""
     SmileMistralUpscalingAlgorithm is a model fusion algorithm designed to upscale
     a pretrained Mistral model using a set of fine-tuned expert models. The algorithm