fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. fusion_bench/__init__.py +22 -2
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +6 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/constants/runtime.py +57 -0
  9. fusion_bench/dataset/clip_dataset.py +2 -1
  10. fusion_bench/dataset/gpt2_glue.py +9 -9
  11. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  12. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  13. fusion_bench/dataset/image_dataset.py +1 -1
  14. fusion_bench/dataset/nyuv2.py +2 -2
  15. fusion_bench/method/__init__.py +24 -5
  16. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  17. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  19. fusion_bench/method/base_algorithm.py +195 -12
  20. fusion_bench/method/bitdelta/__init__.py +5 -0
  21. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  25. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  26. fusion_bench/method/classification/clip_finetune.py +1 -1
  27. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  28. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  29. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  30. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  31. fusion_bench/method/ensemble.py +12 -12
  32. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  33. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
  34. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  35. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  36. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  37. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  38. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  39. fusion_bench/method/linear/expo.py +2 -1
  40. fusion_bench/method/linear/linear_interpolation.py +6 -4
  41. fusion_bench/method/linear/simple_average_for_llama.py +17 -13
  42. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  43. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  44. fusion_bench/method/model_recombination.py +2 -5
  45. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  46. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  47. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  48. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  49. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  50. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  51. fusion_bench/method/randes/modelsoup.py +1 -3
  52. fusion_bench/method/regmean/clip_regmean.py +2 -2
  53. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  54. fusion_bench/method/regmean/regmean.py +2 -11
  55. fusion_bench/method/regmean_plusplus/__init__.py +1 -1
  56. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
  57. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
  58. fusion_bench/method/simple_average.py +12 -16
  59. fusion_bench/method/slerp/slerp.py +5 -2
  60. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  61. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  62. fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
  63. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  64. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
  65. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  66. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  67. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  68. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  69. fusion_bench/method/we_moe/__init__.py +1 -0
  70. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  71. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  72. fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
  73. fusion_bench/method/we_moe/utils.py +15 -0
  74. fusion_bench/method/we_moe/we_moe.py +6 -6
  75. fusion_bench/method/weighted_average/llama.py +4 -16
  76. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  77. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  78. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  79. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  80. fusion_bench/mixins/__init__.py +10 -2
  81. fusion_bench/mixins/clip_classification.py +15 -45
  82. fusion_bench/mixins/hydra_config.py +105 -7
  83. fusion_bench/mixins/lightning_fabric.py +2 -0
  84. fusion_bench/mixins/serialization.py +275 -48
  85. fusion_bench/modelpool/__init__.py +2 -2
  86. fusion_bench/modelpool/base_pool.py +29 -9
  87. fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
  88. fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
  89. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  90. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  91. fusion_bench/models/__init__.py +7 -1
  92. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  93. fusion_bench/models/hf_utils.py +160 -0
  94. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  95. fusion_bench/models/linearized/vision_model.py +1 -1
  96. fusion_bench/models/model_card_templates/default.md +46 -0
  97. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  98. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  99. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  100. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  101. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  102. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  103. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  104. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  105. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  106. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
  107. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  108. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  109. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  110. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
  111. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  112. fusion_bench/models/parameter_dict.py +1 -1
  113. fusion_bench/models/sparse_we_moe.py +1 -53
  114. fusion_bench/models/utils.py +26 -0
  115. fusion_bench/models/we_moe.py +1 -53
  116. fusion_bench/models/wrappers/ensemble.py +6 -4
  117. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  118. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  119. fusion_bench/programs/base_program.py +81 -2
  120. fusion_bench/programs/fabric_fusion_program.py +46 -61
  121. fusion_bench/scripts/cli.py +38 -5
  122. fusion_bench/taskpool/base_pool.py +4 -3
  123. fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
  124. fusion_bench/taskpool/dummy.py +1 -1
  125. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  126. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  127. fusion_bench/utils/__init__.py +7 -1
  128. fusion_bench/utils/cache_utils.py +101 -1
  129. fusion_bench/utils/devices.py +14 -4
  130. fusion_bench/utils/fabric.py +2 -2
  131. fusion_bench/utils/instantiate_utils.py +3 -1
  132. fusion_bench/utils/lazy_imports.py +23 -0
  133. fusion_bench/utils/lazy_state_dict.py +38 -3
  134. fusion_bench/utils/modelscope.py +127 -8
  135. fusion_bench/utils/parameters.py +2 -2
  136. fusion_bench/utils/path.py +56 -0
  137. fusion_bench/utils/pylogger.py +1 -1
  138. fusion_bench/utils/rich_utils.py +3 -0
  139. fusion_bench/utils/state_dict_arithmetic.py +25 -23
  140. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
  141. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
  142. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  143. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  144. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  145. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  146. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  147. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  148. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  149. fusion_bench_config/hydra/default.yaml +6 -2
  150. fusion_bench_config/llama_full_finetune.yaml +1 -0
  151. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  152. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  153. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  154. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  155. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  156. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  157. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  158. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  159. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
  160. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
  167. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
  168. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  169. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  170. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  171. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  172. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  173. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  174. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  175. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  176. fusion_bench_config/nyuv2_config.yaml +3 -1
  177. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  178. fusion_bench_config/path/default.yaml +28 -0
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  180. fusion_bench_config/method/adamerging.yaml +0 -23
  181. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  182. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  183. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  184. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
  185. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
  186. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
  187. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
  188. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0

fusion_bench/method/smile_upscaling/causal_lm_upscaling.py (new file)
@@ -0,0 +1,371 @@
+ import logging
+ import os
+ from copy import deepcopy
+ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union
+
+ import torch
+ from accelerate import init_empty_weights
+ from tqdm.auto import tqdm
+ from transformers import (
+     AutoConfig,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     LlamaForCausalLM,
+     MistralForCausalLM,
+     PretrainedConfig,
+     PreTrainedModel,
+     Qwen2ForCausalLM,
+ )
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+ from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
+
+ from fusion_bench import BaseAlgorithm, BaseModelPool
+ from fusion_bench.compat.modelpool import to_modelpool
+ from fusion_bench.constants import RuntimeConstants
+ from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
+ from fusion_bench.modelpool import CausalLMPool
+ from fusion_bench.models.hf_utils import (
+     create_default_model_card,
+     save_pretrained_with_remote_code,
+ )
+ from fusion_bench.models.modeling_smile_llama import (
+     SmileLlamaConfig,
+     SmileLlamaForCausalLM,
+     SmileLlamaModel,
+ )
+ from fusion_bench.models.modeling_smile_llama.modeling_smile_llama import (
+     SmileLlamaDecoderLayer,
+ )
+ from fusion_bench.models.modeling_smile_mistral import (
+     SmileMistralConfig,
+     SmileMistralForCausalLM,
+     SmileMistralModel,
+ )
+ from fusion_bench.models.modeling_smile_mistral.modeling_smile_mistral import (
+     SmileMistralDecoderLayer,
+ )
+
+ # Import all SMILE configurations and models
+ from fusion_bench.models.modeling_smile_qwen2 import (
+     SmileQwen2Config,
+     SmileQwen2ForCausalLM,
+     SmileQwen2Model,
+ )
+ from fusion_bench.models.modeling_smile_qwen2.modeling_smile_qwen2 import (
+     SmileQwen2DecoderLayer,
+ )
+ from fusion_bench.models.smile_moe.linear_from_hf_config import (
+     ExpertNotTrainedError,
+     upscale_to_smile_linear,
+ )
+ from fusion_bench.utils.dtype import parse_dtype
+ from fusion_bench.utils.parameters import print_parameters
+
+ log = logging.getLogger(__name__)
+
+ # Model type mappings
+ MODEL_TYPE_MAPPINGS = {
+     "qwen2": {
+         "base_model_cls": Qwen2ForCausalLM,
+         "base_decoder_layer_cls": Qwen2DecoderLayer,
+         "smile_config_cls": SmileQwen2Config,
+         "smile_model_cls": SmileQwen2ForCausalLM,
+         "smile_base_model_cls": SmileQwen2Model,
+         "smile_decoder_layer_cls": SmileQwen2DecoderLayer,
+         "description": "Qwen2",
+     },
+     "llama": {
+         "base_model_cls": LlamaForCausalLM,
+         "base_decoder_layer_cls": LlamaDecoderLayer,
+         "smile_config_cls": SmileLlamaConfig,
+         "smile_model_cls": SmileLlamaForCausalLM,
+         "smile_base_model_cls": SmileLlamaModel,
+         "smile_decoder_layer_cls": SmileLlamaDecoderLayer,
+         "description": "Llama",
+     },
+     "mistral": {
+         "base_model_cls": MistralForCausalLM,
+         "base_decoder_layer_cls": MistralDecoderLayer,
+         "smile_config_cls": SmileMistralConfig,
+         "smile_model_cls": SmileMistralForCausalLM,
+         "smile_base_model_cls": SmileMistralModel,
+         "smile_decoder_layer_cls": SmileMistralDecoderLayer,
+         "description": "Mistral",
+     },
+ }
+
+
+ def detect_model_type(
+     model_or_config: Union[PreTrainedModel, PretrainedConfig, str],
+ ) -> str:
+     """
+     Detect the model type from a model, config, or model name/path.
+
+     Args:
+         model_or_config: Model, config, or model name/path to detect type from
+
+     Returns:
+         str: The detected model type ("qwen2", "llama", "mistral")
+
+     Raises:
+         ValueError: If model type cannot be detected or is not supported
+     """
+     if isinstance(model_or_config, str):
+         # Load config from path/name
+         config = AutoConfig.from_pretrained(model_or_config)
+     elif isinstance(model_or_config, PreTrainedModel):
+         config = model_or_config.config
+     elif isinstance(model_or_config, PretrainedConfig):
+         config = model_or_config
+     else:
+         raise ValueError(
+             f"Unsupported type for model type detection: {type(model_or_config)}"
+         )
+
+     model_type = getattr(config, "model_type", "").lower()
+
+     # Handle various model type variations
+     if model_type in MODEL_TYPE_MAPPINGS:
+         return model_type
+     else:
+         raise ValueError(
+             f"Unsupported model type: {model_type}. Supported types: {list(MODEL_TYPE_MAPPINGS.keys())}"
+         )
+
+
+ @auto_register_config
+ class SmileCausalLMUpscalingAlgorithm(
+     SimpleProfilerMixin,
+     BaseAlgorithm,
+ ):
+     R"""
+     SmileCausalLMUpscalingAlgorithm is a generic model fusion algorithm designed to upscale
+     a pretrained CausalLM model using a set of fine-tuned expert models. The algorithm
+     supports Qwen2, Llama, and Mistral model architectures and leverages Singular Value
+     Decomposition (SVD) to merge the weights of the pretrained model and the expert models
+     into a new upscaled model.
+
+     The algorithm automatically detects the model type and uses the appropriate SMILE
+     configuration and model classes.
+
+     Methods:
+         run(modelpool: BaseModelPool) -> Union[SmileQwen2ForCausalLM, SmileLlamaForCausalLM, SmileMistralForCausalLM]:
+             Executes the upscaling process and returns the upscaled model.
+
+         merge(pretrained_model: PreTrainedModel, finetuned_models: List[PreTrainedModel]) -> PreTrainedModel:
+             Merges the pretrained model with the fine-tuned models to create an upscaled model.
+     """
+
+     modelpool: CausalLMPool
+
+     def __init__(
+         self,
+         device,
+         accelerator,
+         model_save_path,
+         model_dtype,
+         num_experts_per_tok,
+         rank_of_router,
+         rank_of_expert,
+         save_with_remote_code: bool = True,
+         model_type: str = None,  # Optional: explicitly specify model type
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.model_mappings = None  # Will be set during run()
+
+         if not torch.cuda.is_available():
+             if "cuda" in self.device:
+                 self.device = "cpu"
+             if "cuda" in self.accelerator:
+                 self.accelerator = "cpu"
+
+     @torch.no_grad()
+     def run(self, modelpool) -> PreTrainedModel:
+         """
+         Executes the upscaling process.
+
+         Args:
+             modelpool (ModelPool): The pool of models to be used for upscaling.
+
+         Returns:
+             PreTrainedModel: The upscaled model (specific type depends on detected model architecture).
+         """
+         self.modelpool = modelpool = to_modelpool(modelpool)
+         config = self.config
+
+         # Auto-detect model type if not specified
+         if self.model_type is None:
+             self.model_type = detect_model_type(
+                 modelpool.get_model_path("_pretrained_")
+             )
+             log.info(f"Auto-detected model type: {self.model_type}")
+
+         # Get the appropriate model mappings
+         if self.model_type not in MODEL_TYPE_MAPPINGS:
+             raise ValueError(
+                 f"Unsupported model type: {self.model_type}. Supported: {list(MODEL_TYPE_MAPPINGS.keys())}"
+             )
+
+         self.model_mappings = MODEL_TYPE_MAPPINGS[self.model_type]
+         log.info(f"Using {self.model_mappings['description']} model architecture")
+
+         with self.profile("load pretrained model"):
+             pretrained_model = modelpool.load_pretrained_model()
+
+         with self.profile("load fine-tuned model"):
+             finetuned_models = [
+                 m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
+             ]
+
+         if self.device == "cuda" and torch.cuda.is_available():
+             pretrained_model = pretrained_model.cuda()
+             print("parameter count of pretrained model:")
+             print_parameters(pretrained_model)
+             finetuned_models = [m.cuda() for m in finetuned_models]
+
+         with self.profile("merge model"):
+             model = self.merge(pretrained_model, finetuned_models)
+
+         self.print_profile_summary()
+         print("parameter count of upscaled MoE model:")
+         print_parameters(model)
+         print(model)
+
+         if self.model_dtype is not None:
+             model.to(dtype=parse_dtype(self.model_dtype))
+
+         if self.model_save_path is not None:
+             if os.path.dirname(self.model_save_path):
+                 os.makedirs(os.path.dirname(self.model_save_path), exist_ok=True)
+             log.info(f"Saving model to {self.model_save_path}")
+             tokenizer = self.modelpool.load_tokenizer()
+             tokenizer.save_pretrained(self.model_save_path)
+             if not self.save_with_remote_code:
+                 model.save_pretrained(self.model_save_path)
+             else:
+                 # Use the appropriate auto_map for the detected model type
+                 auto_map = {
+                     "AutoConfig": self.model_mappings["smile_config_cls"],
+                     "AutoModel": self.model_mappings["smile_base_model_cls"],
+                     "AutoModelForCausalLM": self.model_mappings["smile_model_cls"],
+                 }
+                 save_pretrained_with_remote_code(
+                     model,
+                     auto_map=auto_map,
+                     save_directory=self.model_save_path,
+                 )
+
+             # save readme
+             model_card_str = create_default_model_card(
+                 models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
+                 description=f"Merged {self.model_mappings['description']} model using SMILE Upscaling",
+                 algorithm_config=self.config,
+                 modelpool_config=modelpool.config,
+             )
+             with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
+                 f.write(model_card_str)
+
+         return model
+
+     def merge(
+         self,
+         pretrained_model: PreTrainedModel,
+         finetuned_models: List[PreTrainedModel],
+     ) -> PreTrainedModel:
+         """
+         Merges the pretrained model with the fine-tuned models to create an upscaled model.
+
+         Args:
+             pretrained_model (PreTrainedModel): The pretrained model.
+             finetuned_models (List[PreTrainedModel]): A list of fine-tuned models.
+
+         Returns:
+             PreTrainedModel: The upscaled model (specific type depends on model architecture).
+         """
+         with init_empty_weights():
+             pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
+             if isinstance(pretrained_model_config, str):
+                 pretrained_path = pretrained_model_config
+             else:
+                 pretrained_path = pretrained_model_config.get(
+                     "path", pretrained_model_config["pretrained_model_name_or_path"]
+                 )
+             base_config = AutoConfig.from_pretrained(pretrained_path)
+
+             # Create the appropriate SMILE config for the detected model type
+             SmileConfigClass = self.model_mappings["smile_config_cls"]
+             model_config = SmileConfigClass(
+                 num_experts_per_tok=self.num_experts_per_tok,
+                 rank_of_router=self.rank_of_router,
+                 rank_of_expert=self.rank_of_expert,
+                 num_local_experts=len(finetuned_models),
+                 **base_config.to_dict(),
+             )
+
+             # Create the appropriate SMILE model for the detected model type
+             SmileModelClass = self.model_mappings["smile_model_cls"]
+             model = SmileModelClass(model_config)
+
+         model.to(dtype=pretrained_model.dtype).to_empty(device="cpu")
+
+         # copy pretrained model weights
+         state_dict = model.state_dict()
+         pretrained_state_dict = pretrained_model.state_dict()
+         for key in list(pretrained_state_dict.keys()):
+             if key not in state_dict:
+                 pretrained_state_dict.pop(key)
+         model.load_state_dict(pretrained_state_dict, strict=False)
+
+         # upscale model
+         BaseDecoderLayerClass = self.model_mappings["base_decoder_layer_cls"]
+         SmileDecoderLayerClass = self.model_mappings["smile_decoder_layer_cls"]
+
+         for layer_idx in tqdm(
+             range(len(pretrained_model.model.layers)),
+             "Upscaling Modules (layer)",
+             dynamic_ncols=True,
+         ):
+             if RuntimeConstants.debug and layer_idx > 0:
+                 log.info(
+                     "Debug mode enabled: processing only the first layer, skipping remaining layers"
+                 )
+                 break
+
+             pretrained_layer = pretrained_model.model.layers[layer_idx]
+             finetuned_layers = [m.model.layers[layer_idx] for m in finetuned_models]
+
+             target_layer = model.model.layers[layer_idx]
+
+             for n in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+                 try:
+                     upscale_to_smile_linear(
+                         base=getattr(pretrained_layer.self_attn, n),
+                         experts=[getattr(m.self_attn, n) for m in finetuned_layers],
+                         target=getattr(target_layer.self_attn, n),
+                         accelerator=self.accelerator,
+                     )
+                 except ExpertNotTrainedError:
+                     setattr(
+                         target_layer.self_attn,
+                         n,
+                         getattr(pretrained_layer.self_attn, n),
+                     )
+
+             for n in ["gate_proj", "up_proj", "down_proj"]:
+                 try:
+                     upscale_to_smile_linear(
+                         base=getattr(pretrained_layer.mlp, n),
+                         experts=[getattr(m.mlp, n) for m in finetuned_layers],
+                         target=getattr(target_layer.mlp, n),
+                         accelerator=self.accelerator,
+                     )
+                 except ExpertNotTrainedError:
+                     setattr(
+                         target_layer.mlp,
+                         n,
+                         getattr(pretrained_layer.mlp, n),
+                     )
+
+         return model
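
Editorial note (not part of the diff): the snippet below is a minimal sketch of how the new SmileCausalLMUpscalingAlgorithm might be driven from Python. The constructor arguments mirror the __init__ signature shown above; the model names, checkpoint paths, and the assumption that CausalLMPool can be built directly from a mapping of names to checkpoints plus a tokenizer path (as its YAML configs suggest) are placeholders, not documented API. In practice this is wired up through the Hydra configs added in this release, e.g. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml.

from fusion_bench.method.smile_upscaling.causal_lm_upscaling import (
    SmileCausalLMUpscalingAlgorithm,
)
from fusion_bench.modelpool import CausalLMPool

# Hypothetical pool: one "_pretrained_" base model plus fine-tuned experts.
modelpool = CausalLMPool(
    models={
        "_pretrained_": "Qwen/Qwen2.5-1.5B",   # placeholder base checkpoint
        "expert_math": "path/to/math-expert",  # placeholder expert checkpoint
        "expert_code": "path/to/code-expert",  # placeholder expert checkpoint
    },
    tokenizer="Qwen/Qwen2.5-1.5B",             # placeholder tokenizer path
)

algorithm = SmileCausalLMUpscalingAlgorithm(
    device="cuda",
    accelerator="cuda",
    model_save_path=None,  # set a directory to also save tokenizer, weights, and README
    model_dtype=None,      # e.g. "bfloat16" to cast the merged model before saving
    num_experts_per_tok=1,
    rank_of_router=8,
    rank_of_expert=32,
    model_type=None,       # None -> detect_model_type() infers qwen2 / llama / mistral
)
upscaled_model = algorithm.run(modelpool)  # Smile{Qwen2,Llama,Mistral}ForCausalLM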

fusion_bench/method/smile_upscaling/error_accumulation.py (new file)
@@ -0,0 +1,177 @@
+ import os
+ from typing import Literal, cast
+
+ import pandas as pd
+ import torch
+ from omegaconf import DictConfig
+ from torch import nn
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+ from transformers import CLIPVisionModel
+
+ from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
+ from fusion_bench.dataset import CLIPDataset
+ from fusion_bench.method import SmileUpscalingAlgorithm
+ from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
+ from fusion_bench.modelpool import CLIPVisionModelPool
+ from fusion_bench.taskpool.clip_vision.taskpool import LayerWiseFeatureSaver
+ from fusion_bench.utils.devices import clear_cuda_cache
+
+
+ @auto_register_config
+ class LowRankApproximation(BaseAlgorithm):
+     def __init__(self, rank: int, device: str = "cuda", **kwargs):
+         """Low-rank approximation of fine-tuned updates."""
+         super().__init__(**kwargs)
+
+     def run(self, modelpool: BaseModelPool):
+         # Implement low-rank approximation logic here
+         base_model = modelpool.load_pretrained_model()
+
+         models = {}
+         for model_name in tqdm(modelpool.model_names, "processing models"):
+             task_model = modelpool.load_model(model_name)
+             for module_name, module in task_model.named_modules():
+                 if isinstance(module, nn.Linear):
+                     w = cast(
+                         nn.Linear, base_model.get_submodule(module_name)
+                     ).weight.to(dtype=torch.float32, device=self.device, copy=True)
+                     w_ft = module.weight.to(
+                         dtype=torch.float32, device=self.device, copy=True
+                     )
+
+                     # Compute low-rank approximation
+                     w_diff = w_ft - w
+                     u, s, vh = torch.linalg.svd(w_diff)
+                     v = vh.T
+
+                     u = u[:, : self.rank]
+                     s = s[: self.rank]
+                     v = v[:, : self.rank]
+
+                     low_rank_w_diff = torch.linalg.multi_dot((u, torch.diag(s), v.T))
+                     low_rank_w = w + low_rank_w_diff
+
+                     module.weight.data = low_rank_w.to(
+                         dtype=module.weight.dtype,
+                         device=module.weight.device,
+                     )
+
+             models[model_name] = task_model
+         return models
+
+
+ @auto_register_config
+ class ErrorAccumulationAnalysisForCLIP(
+     LightningFabricMixin,
+     BaseAlgorithm,
+ ):
+     def __init__(
+         self,
+         gate_k: int,
+         k: int,
+         seed: int = 42,
+         top_k: int = 1,
+         dataset_kwargs: DictConfig = None,
+         max_samples: int = 1024,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         if dataset_kwargs is None:
+             self.dataset_kwargs = DictConfig(
+                 {
+                     "batch_size": 32,
+                     "num_workers": 4,
+                 }
+             )
+
+     def run(self, modelpool: CLIPVisionModelPool):
+         assert self.fabric.world_size == 1, "Distributed inference is not supported."
+         # get the smile model
+         smile_algorithm = SmileUpscalingAlgorithm(
+             gate_k=self.gate_k, k=self.k, top_k=self.top_k, device=self.fabric.device
+         )
+         smile_model = smile_algorithm.run(modelpool)
+         # get low-rank models
+         low_rank_models = LowRankApproximation(rank=self.k).run(modelpool)
+
+         results = {
+             "model_name": [],
+             "method": [],
+             "layer_index": [],
+             "approximation_error": [],
+         }
+
+         for model_name in modelpool.model_names:
+             dataset = modelpool.load_test_dataset(model_name)
+             processor = modelpool.load_processor()
+             dataset = CLIPDataset(dataset, processor)
+             dataloader = DataLoader(dataset, shuffle=True, **self.dataset_kwargs)
+             dataloader = self.fabric.setup_dataloaders(dataloader)
+
+             # finetuned_model
+             finetuned_model = modelpool.load_model(model_name)
+             finetuned_model = self.to_device(finetuned_model)
+             self.collect_hidden_states(
+                 finetuned_model,
+                 dataloader=dataloader,
+                 model_name=f"{model_name}/finetuned",
+             )
+             del finetuned_model
+             clear_cuda_cache()
+
+             # smile model
+             smile_model = self.to_device(smile_model)
+             self.collect_hidden_states(
+                 smile_model, dataloader=dataloader, model_name=f"{model_name}/smile"
+             )
+             smile_model.cpu()
+             clear_cuda_cache()
+
+             # low-rank models
+             model = low_rank_models.pop(model_name)
+             model = self.to_device(model)
+             self.collect_hidden_states(
+                 model, dataloader=dataloader, model_name=f"{model_name}/low-rank"
+             )
+             del model
+             clear_cuda_cache()
+
+             del dataloader
+             clear_cuda_cache()
+
+     @torch.no_grad()
+     def collect_hidden_states(
+         self, model: CLIPVisionModel, dataloader, model_name: str
+     ):
+         self.fabric.seed_everything(
+             self.seed, workers=True
+         )  # make sure to get same data samples
+         # register hooks
+         hooks = {}
+         hook_handles = {}
+         for i, layer in enumerate(model.vision_model.encoder.layers):
+             hooks[i] = LayerWiseFeatureSaver(
+                 save_path=os.path.join(self.log_dir, model_name, f"layer_{i}.pth"),
+                 first_token_only=True,
+             )
+             hook_handles[i] = layer.register_forward_hook(hooks[i])
+
+         # forward pass
+         num_total_samples = 0
+         for images, _ in tqdm(dataloader, desc=f"Collecting features for {model_name}"):
+             batch_size = images.size(0)
+             model(images)
+             num_total_samples += batch_size
+             if num_total_samples >= self.max_samples:
+                 break
+
+         # save features
+         for i, hook in hooks.items():
+             hook.save_features()
+
+         # remove hooks
+         for i, hook_handle in hook_handles.items():
+             hook_handle.remove()
+
+         return hooks
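
Editorial note (not part of the diff): LowRankApproximation above replaces each linear layer's fine-tuning update with its truncated SVD, i.e. the best rank-k approximation of w_ft - w in the Frobenius norm (Eckart-Young). The standalone snippet below illustrates that operation on a random matrix; the shapes and rank are arbitrary stand-ins.

import torch

torch.manual_seed(0)
w = torch.randn(256, 128)               # stand-in for a pretrained weight
w_ft = w + 0.1 * torch.randn(256, 128)  # stand-in for a fine-tuned weight

w_diff = w_ft - w
u, s, vh = torch.linalg.svd(w_diff, full_matrices=False)

k = 16  # truncation rank (the algorithm's `rank` parameter)
low_rank_w_diff = u[:, :k] @ torch.diag(s[:k]) @ vh[:k, :]
low_rank_w = w + low_rank_w_diff  # what the algorithm writes back into the layer

rel_err = torch.linalg.norm(w_diff - low_rank_w_diff) / torch.linalg.norm(w_diff)
print(f"relative error of the rank-{k} update: {rel_err:.3f}")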

fusion_bench/method/smile_upscaling/projected_energy.py (new file)
@@ -0,0 +1,144 @@
+ import os
+ from typing import Literal
+
+ import pandas as pd
+ import torch
+ from tqdm import tqdm
+
+ from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
+ from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
+
+
+ class ProjectedEnergyAnalysis(
+     SimpleProfilerMixin,
+     LightningFabricMixin,
+     BaseAlgorithm,
+ ):
+     def on_run_start(self):
+         self.device = self.fabric.device
+
+     def run(self, modelpool: BaseModelPool):
+         with self.profile("model loading"):
+             base_model = modelpool.load_pretrained_model()
+
+         results = {
+             "model_name": [],
+             "module_index": [],
+             "module_name": [],
+             "projected_energy_I": [],
+             "projected_energy_II": [],
+             "projected_energy_II_III": [],
+         }
+         for model_name in tqdm(
+             modelpool.model_names,
+             "analyzing",
+             dynamic_ncols=True,
+         ):
+             with self.profile("model loading"):
+                 finetuned_model = modelpool.load_model(model_name)
+
+             module_index = 0
+             for module_name, base_module in tqdm(
+                 list(base_model.named_modules()),
+                 "analyzing modules",
+                 dynamic_ncols=True,
+             ):
+                 if isinstance(base_module, torch.nn.Linear):
+                     with self.profile("weight analysis"):
+                         _result = self.analyze_weight(
+                             base_module.weight,
+                             finetuned_model.get_submodule(module_name).weight,
+                         )
+                     results["model_name"].append(model_name)
+                     results["module_index"].append(module_index)
+                     results["module_name"].append(module_name)
+                     for key, value in _result.items():
+                         results[key].append(value)
+
+                     module_index += 1
+
+         # save results as csv
+         results = pd.DataFrame(results)
+         results.to_csv(
+             os.path.join(self.log_dir, "projected_energy_analysis.csv"), index=True
+         )
+
+         self.print_profile_summary()
+         return None
+
+     @torch.no_grad()
+     def analyze_weight(self, w: torch.Tensor, w_ft: torch.Tensor, k: int = -1):
+         w = w.to(dtype=torch.float32, device=self.device)
+         w_ft = w_ft.to(dtype=torch.float32, device=self.device)
+         w_diff = w_ft - w
+
+         # Perform analysis on the weight tensor
+         u, s, vh = torch.linalg.svd(w, full_matrices=False)
+         v = vh.T
+         if k < 0:
+             # find the position where the sum of singular values is larger than 50% of the total sum
+             cumsum = s.cumsum(0)
+             k = (cumsum < cumsum[-1] * 0.5).sum().item() + 1
+
+         # subspace I
+         w_diff_proj = self._project_subspace_low(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
+         projected_energy_I = (
+             torch.linalg.norm(w_diff_proj, ord="fro") ** 2
+             / torch.linalg.norm(w_diff, ord="fro") ** 2
+         )
+
+         # subspace II
+         w_diff_proj = self._project_subspace_high(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
+         projected_energy_II = (
+             torch.linalg.norm(w_diff_proj, ord="fro") ** 2
+             / torch.linalg.norm(w_diff, ord="fro") ** 2
+         )
+
+         ## subspace II+III
+         u, s, vh = torch.linalg.svd(w, full_matrices=True)
+         v = vh.T
+         w_diff_proj = self._project_subspace_high(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
+         projected_energy_II_III = (
+             torch.linalg.norm(w_diff_proj, ord="fro") ** 2
+             / torch.linalg.norm(w_diff, ord="fro") ** 2
+         )
+
+         return {
+             "projected_energy_I": projected_energy_I.item(),
+             "projected_energy_II": projected_energy_II.item(),
+             "projected_energy_II_III": projected_energy_II_III.item(),
+         }
+
+     def _project_subspace_low(
+         self,
+         u: torch.Tensor,
+         s: torch.Tensor,
+         v: torch.Tensor,
+         k: int,
+         w: torch.Tensor,
+         w_ft: torch.Tensor,
+     ):
+         u = u[:, :k]
+         s = s[:k]
+         v = v[:, :k]
+
+         w_diff = w_ft - w
+         w_diff_proj = torch.linalg.multi_dot((u, u.T, w_diff, v, v.T))
+         return w_diff_proj
+
+     def _project_subspace_high(
+         self,
+         u: torch.Tensor,
+         s: torch.Tensor,
+         v: torch.Tensor,
+         k: int,
+         w: torch.Tensor,
+         w_ft: torch.Tensor,
+     ):
+         u = u[:, k:]
+         s = s[k:]
+         v = v[:, k:]
+
+         w_diff = w_ft - w
+         w_diff_proj = torch.linalg.multi_dot((u, u.T, w_diff, v, v.T))
+         return w_diff_proj
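
Editorial note (not part of the diff): analyze_weight reports what fraction of the fine-tuning update's squared Frobenius norm survives projection onto subspaces spanned by singular vectors of the pretrained weight (subspace I: the top-k left/right directions; subspace II: the remaining directions of the thin SVD; subspace II+III: the same with the full SVD). The standalone sketch below computes that ratio with arbitrary stand-in shapes and split point k; the helper name projected_energy is introduced here for illustration only.

import torch

torch.manual_seed(0)
w = torch.randn(128, 64)                # stand-in pretrained weight
w_ft = w + 0.05 * torch.randn(128, 64)  # stand-in fine-tuned weight
w_diff = w_ft - w

u, s, vh = torch.linalg.svd(w, full_matrices=False)
v = vh.T
k = 8  # split between "large singular value" and "small singular value" directions


def projected_energy(u_sub: torch.Tensor, v_sub: torch.Tensor) -> float:
    # || U U^T (w_ft - w) V V^T ||_F^2 / || w_ft - w ||_F^2
    proj = torch.linalg.multi_dot((u_sub, u_sub.T, w_diff, v_sub, v_sub.T))
    return (torch.linalg.norm(proj) ** 2 / torch.linalg.norm(w_diff) ** 2).item()


print("subspace I :", projected_energy(u[:, :k], v[:, :k]))
print("subspace II:", projected_energy(u[:, k:], v[:, k:]))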