fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. fusion_bench/__init__.py +22 -2
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +6 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/constants/runtime.py +57 -0
  9. fusion_bench/dataset/clip_dataset.py +2 -1
  10. fusion_bench/dataset/gpt2_glue.py +9 -9
  11. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  12. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  13. fusion_bench/dataset/image_dataset.py +1 -1
  14. fusion_bench/dataset/nyuv2.py +2 -2
  15. fusion_bench/method/__init__.py +24 -5
  16. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  17. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  19. fusion_bench/method/base_algorithm.py +195 -12
  20. fusion_bench/method/bitdelta/__init__.py +5 -0
  21. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  25. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  26. fusion_bench/method/classification/clip_finetune.py +1 -1
  27. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  28. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  29. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  30. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  31. fusion_bench/method/ensemble.py +12 -12
  32. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  33. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
  34. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  35. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  36. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  37. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  38. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  39. fusion_bench/method/linear/expo.py +2 -1
  40. fusion_bench/method/linear/linear_interpolation.py +6 -4
  41. fusion_bench/method/linear/simple_average_for_llama.py +17 -13
  42. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  43. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  44. fusion_bench/method/model_recombination.py +2 -5
  45. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  46. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  47. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  48. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  49. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  50. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  51. fusion_bench/method/randes/modelsoup.py +1 -3
  52. fusion_bench/method/regmean/clip_regmean.py +2 -2
  53. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  54. fusion_bench/method/regmean/regmean.py +2 -11
  55. fusion_bench/method/regmean_plusplus/__init__.py +1 -1
  56. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
  57. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
  58. fusion_bench/method/simple_average.py +12 -16
  59. fusion_bench/method/slerp/slerp.py +5 -2
  60. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  61. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  62. fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
  63. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  64. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
  65. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  66. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  67. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  68. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  69. fusion_bench/method/we_moe/__init__.py +1 -0
  70. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  71. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  72. fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
  73. fusion_bench/method/we_moe/utils.py +15 -0
  74. fusion_bench/method/we_moe/we_moe.py +6 -6
  75. fusion_bench/method/weighted_average/llama.py +4 -16
  76. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  77. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  78. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  79. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  80. fusion_bench/mixins/__init__.py +10 -2
  81. fusion_bench/mixins/clip_classification.py +15 -45
  82. fusion_bench/mixins/hydra_config.py +105 -7
  83. fusion_bench/mixins/lightning_fabric.py +2 -0
  84. fusion_bench/mixins/serialization.py +275 -48
  85. fusion_bench/modelpool/__init__.py +2 -2
  86. fusion_bench/modelpool/base_pool.py +29 -9
  87. fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
  88. fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
  89. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  90. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  91. fusion_bench/models/__init__.py +7 -1
  92. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  93. fusion_bench/models/hf_utils.py +160 -0
  94. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  95. fusion_bench/models/linearized/vision_model.py +1 -1
  96. fusion_bench/models/model_card_templates/default.md +46 -0
  97. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  98. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  99. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  100. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  101. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  102. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  103. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  104. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  105. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  106. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
  107. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  108. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  109. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  110. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
  111. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  112. fusion_bench/models/parameter_dict.py +1 -1
  113. fusion_bench/models/sparse_we_moe.py +1 -53
  114. fusion_bench/models/utils.py +26 -0
  115. fusion_bench/models/we_moe.py +1 -53
  116. fusion_bench/models/wrappers/ensemble.py +6 -4
  117. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  118. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  119. fusion_bench/programs/base_program.py +81 -2
  120. fusion_bench/programs/fabric_fusion_program.py +46 -61
  121. fusion_bench/scripts/cli.py +38 -5
  122. fusion_bench/taskpool/base_pool.py +4 -3
  123. fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
  124. fusion_bench/taskpool/dummy.py +1 -1
  125. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  126. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  127. fusion_bench/utils/__init__.py +7 -1
  128. fusion_bench/utils/cache_utils.py +101 -1
  129. fusion_bench/utils/devices.py +14 -4
  130. fusion_bench/utils/fabric.py +2 -2
  131. fusion_bench/utils/instantiate_utils.py +3 -1
  132. fusion_bench/utils/lazy_imports.py +23 -0
  133. fusion_bench/utils/lazy_state_dict.py +38 -3
  134. fusion_bench/utils/modelscope.py +127 -8
  135. fusion_bench/utils/parameters.py +2 -2
  136. fusion_bench/utils/path.py +56 -0
  137. fusion_bench/utils/pylogger.py +1 -1
  138. fusion_bench/utils/rich_utils.py +3 -0
  139. fusion_bench/utils/state_dict_arithmetic.py +25 -23
  140. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
  141. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
  142. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  143. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  144. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  145. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  146. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  147. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  148. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  149. fusion_bench_config/hydra/default.yaml +6 -2
  150. fusion_bench_config/llama_full_finetune.yaml +1 -0
  151. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  152. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  153. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  154. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  155. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  156. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  157. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  158. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  159. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
  160. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
  167. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
  168. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  169. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  170. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  171. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  172. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  173. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  174. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  175. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  176. fusion_bench_config/nyuv2_config.yaml +3 -1
  177. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  178. fusion_bench_config/path/default.yaml +28 -0
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  180. fusion_bench_config/method/adamerging.yaml +0 -23
  181. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  182. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  183. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  184. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
  185. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
  186. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
  187. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
  188. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0

fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py
@@ -20,6 +20,7 @@ from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.method.simple_average import simple_average
+from fusion_bench.mixins import auto_register_config
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.models.modeling_smile_mistral import (
@@ -40,7 +41,10 @@ from fusion_bench.utils.parameters import print_parameters
 log = logging.getLogger(__name__)
 
 
-class SmileMistralUpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
+class SmileMistralUpscalingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     R"""
     SmileMistralUpscalingAlgorithm is a model fusion algorithm designed to upscale
     a pretrained Mistral model using a set of fine-tuned expert models. The algorithm

fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py
@@ -16,10 +16,17 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.compat.modelpool import to_modelpool
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.constants import RuntimeConstants
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
+from fusion_bench.modelpool import CausalLMPool
+from fusion_bench.models.hf_utils import (
+    create_default_model_card,
+    save_pretrained_with_remote_code,
+)
 from fusion_bench.models.modeling_smile_qwen2 import (
     SmileQwen2Config,
     SmileQwen2ForCausalLM,
+    SmileQwen2Model,
 )
 from fusion_bench.models.modeling_smile_qwen2.modeling_smile_qwen2 import (
     SmileQwen2DecoderLayer,
@@ -34,7 +41,11 @@ from fusion_bench.utils.parameters import print_parameters
 log = logging.getLogger(__name__)
 
 
-class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
+@auto_register_config
+class SmileQwen2UpscalingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     R"""
     SmileQwen2UpscalingAlgorithm is a model fusion algorithm designed to upscale
     a pretrained Qwen2 model using a set of fine-tuned expert models. The algorithm
@@ -49,39 +60,29 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     Merges the pretrained model with the fine-tuned models to create an upscaled model.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "device": "device",
-        "accelerator": "accelerator",
-        "model_path": "model_path",
-        "model_dtype": "model_dtype",
-        "num_experts_per_tok": "num_experts_per_tok",
-        "rank_of_router": "rank_of_router",
-        "rank_of_expert": "rank_of_expert",
-    }
+    modelpool: CausalLMPool
 
     def __init__(
         self,
         device,
         accelerator,
-        model_path,
+        model_save_path,
         model_dtype,
         num_experts_per_tok,
         rank_of_router,
        rank_of_expert,
+        save_with_remote_code: bool = True,
         **kwargs,
     ):
-        self.device = device
-        self.accelerator = accelerator
-        self.model_path = model_path
-        self.model_dtype = model_dtype
-        # SmileMoE parameters, except `num_local_experts` which is set later according to the number of finetuned models
-        self.num_experts_per_tok = num_experts_per_tok
-        self.rank_of_router = rank_of_router
-        self.rank_of_expert = rank_of_expert
         super().__init__(**kwargs)
+        if not torch.cuda.is_available():
+            if "cuda" in self.device:
+                self.device = "cpu"
+            if "cuda" in self.accelerator:
+                self.accelerator = "cpu"
 
     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool) -> SmileQwen2ForCausalLM:
+    def run(self, modelpool) -> SmileQwen2ForCausalLM:
         """
         Executes the upscaling process.
 
@@ -94,13 +95,6 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         self.modelpool = modelpool = to_modelpool(modelpool)
         config = self.config
 
-        # load model from path if provided and return directly
-        if config.model_path is not None and os.path.exists(config.model_path):
-            log.info(f"Loading model from {config.model_path}")
-            model = AutoModelForCausalLM.from_pretrained(config.model_path)
-            print_parameters(model)
-            return model
-
         with self.profile("load pretrained model"):
             pretrained_model = modelpool.load_pretrained_model()
         with self.profile("load fine-tuned model"):
@@ -108,7 +102,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
                 m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
             ]
 
-        if config.device == "cuda" and torch.cuda.is_available():
+        if self.device == "cuda" and torch.cuda.is_available():
            pretrained_model = pretrained_model.cuda()
         print("parameter count of pretrained model:")
         print_parameters(pretrained_model)
@@ -122,20 +116,37 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         print_parameters(model)
         print(model)
 
-        if config.model_dtype is not None:
-            model.to(dtype=parse_dtype(config.model_dtype))
-
-        if config.model_path is not None:
-            if os.path.dirname(config.model_path):
-                os.makedirs(os.path.dirname(config.model_path), exist_ok=True)
-            log.info(f"Saving model to {config.model_path}")
-            pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
-            pretrained_path = pretrained_model_config.get(
-                "path", pretrained_model_config["pretrained_model_name_or_path"]
+        if self.model_dtype is not None:
+            model.to(dtype=parse_dtype(self.model_dtype))
+
+        if self.model_save_path is not None:
+            if os.path.dirname(self.model_save_path):
+                os.makedirs(os.path.dirname(self.model_save_path), exist_ok=True)
+            log.info(f"Saving model to {self.model_save_path}")
+            tokenizer = self.modelpool.load_tokenizer()
+            tokenizer.save_pretrained(self.model_save_path)
+            if not self.save_with_remote_code:
+                model.save_pretrained(self.model_save_path)
+            else:
+                save_pretrained_with_remote_code(
+                    model,
+                    auto_map={
+                        "AutoConfig": SmileQwen2Config,
+                        "AutoModel": SmileQwen2Model,
+                        "AutoModelForCausalLM": SmileQwen2ForCausalLM,
+                    },
+                    save_directory=self.model_save_path,
+                )
+
+            # save readme
+            model_card_str = create_default_model_card(
+                models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
+                description="Merged Qwen model using SMILE Upscaling",
+                algorithm_config=self.config,
+                modelpool_config=modelpool.config,
             )
-            tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
-            tokenizer.save_pretrained(config.model_path)
-            model.save_pretrained(config.model_path)
+            with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
+                f.write(model_card_str)
 
         return model
 
@@ -158,14 +169,17 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
 
         with init_empty_weights():
             pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
-            pretrained_path = pretrained_model_config.get(
-                "path", pretrained_model_config["pretrained_model_name_or_path"]
-            )
+            if isinstance(pretrained_model_config, str):
+                pretrained_path = pretrained_model_config
+            else:
+                pretrained_path = pretrained_model_config.get(
+                    "path", pretrained_model_config["pretrained_model_name_or_path"]
+                )
             base_config = AutoConfig.from_pretrained(pretrained_path)
             model_config = SmileQwen2Config(
-                num_experts_per_tok=config.num_experts_per_tok,
-                rank_of_router=config.rank_of_router,
-                rank_of_expert=config.rank_of_expert,
+                num_experts_per_tok=self.num_experts_per_tok,
+                rank_of_router=self.rank_of_router,
+                rank_of_expert=self.rank_of_expert,
                 num_local_experts=len(finetuned_models),
                 **base_config.to_dict(),
             )
@@ -175,7 +189,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
 
         # copy pretrained model weights
         state_dict = model.state_dict()
-        pretrained_state_dict = dict(pretrained_model.state_dict())
+        pretrained_state_dict = pretrained_model.state_dict()
         for key in list(pretrained_state_dict.keys()):
             if key not in state_dict:
                 pretrained_state_dict.pop(key)
@@ -187,6 +201,12 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
             "Upscaling Modules (layer)",
             dynamic_ncols=True,
         ):
+            if RuntimeConstants.debug and layer_idx > 0:
+                log.info(
+                    "Debug mode enabled: processing only the first layer, skipping remaining layers"
+                )
+                break
+
             pretrained_layer: Qwen2DecoderLayer = pretrained_model.model.layers[
                 layer_idx
             ]
@@ -202,7 +222,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
                         base=getattr(pretrained_layer.self_attn, n),
                         experts=[getattr(m.self_attn, n) for m in finetuned_layers],
                         target=getattr(target_layer.self_attn, n),
-                        accelerator=config.accelerator,
+                        accelerator=self.accelerator,
                     )
                 except ExpertNotTrainedError:
                     setattr(
@@ -217,7 +237,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
                         base=getattr(pretrained_layer.mlp, n),
                         experts=[getattr(m.mlp, n) for m in finetuned_layers],
                         target=getattr(target_layer.mlp, n),
-                        accelerator=config.accelerator,
+                        accelerator=self.accelerator,
                     )
                 except ExpertNotTrainedError:
                     setattr(
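
Because the merged checkpoint is now saved with save_pretrained_with_remote_code and an auto_map entry for the SmileQwen2 classes (when save_with_remote_code=True), it should be loadable back through the standard transformers auto classes. A minimal sketch, assuming a hypothetical model_save_path of "outputs/smile_qwen2" (the path is illustrative, not part of this release):

from transformers import AutoModelForCausalLM, AutoTokenizer

save_path = "outputs/smile_qwen2"  # assumed value of `model_save_path`
tokenizer = AutoTokenizer.from_pretrained(save_path)
# trust_remote_code=True lets transformers import the SmileQwen2* classes that
# save_pretrained_with_remote_code copied next to the weights.
model = AutoModelForCausalLM.from_pretrained(save_path, trust_remote_code=True)

inputs = tokenizer("Hello, world!", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))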

fusion_bench/method/smile_upscaling/smile_upscaling.py
@@ -1,7 +1,7 @@
 import logging
 import os
 from copy import deepcopy
-from typing import Dict, List, Tuple  # noqa: F401
+from typing import Any, Dict, List, Tuple  # noqa: F401
 
 import torch
 import torch.nn.functional as F
@@ -20,6 +20,7 @@ from fusion_bench.models.smile_moe.linear_from_module import (
     SmileMoELinear,
 )
 from fusion_bench.models.utils import get_attr, set_attr
+from fusion_bench.utils.devices import get_device
 from fusion_bench.utils.parameters import print_parameters
 
 log = logging.getLogger(__name__)
@@ -54,7 +55,7 @@ class SmileUpscalingAlgorithm(
         routing_use_diff: bool = True,
         average_experts: bool = False,
         model_path: str = None,
-        **kwargs,
+        **kwargs: Any,
     ):
         """
         Initialize the SmileUpscalingAlgorithm.
@@ -91,7 +92,7 @@ class SmileUpscalingAlgorithm(
         print(f"=== Config for `{type(self).__name__}` ===")
 
     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool):
+    def run(self, modelpool: BaseModelPool) -> nn.Module:
         """
         Executes the upscaling process.
 
@@ -142,7 +143,7 @@ class SmileUpscalingAlgorithm(
         pretrained_model: nn.Module,
         finetuned_models: List[nn.Module],
         in_place: bool = True,
-    ):
+    ) -> nn.Module:
         """
         Merges the pretrained model with the fine-tuned models to create an upscaled model.
 
@@ -180,7 +181,12 @@ class SmileUpscalingAlgorithm(
 
         name_list = name.split(".")
         module = get_attr(pretrained_model, name_list)
-        experts = [get_attr(m, name_list) for m in finetuned_models]
+        original_device = get_device(module)
+        module = module.to(self.device, non_blocking=True)
+        experts = [
+            get_attr(m, name_list).to(self.device, non_blocking=True)
+            for m in finetuned_models
+        ]
         try:
             moe_linear = SmileMoELinear(
                 module,
@@ -192,6 +198,7 @@ class SmileUpscalingAlgorithm(
                 full_matrices=self.full_matrices,
                 upscaling_accelerator=self.upscaling_accelerator,
             )
+            moe_linear = moe_linear.to(original_device, non_blocking=True)
         except ExpertNotTrainedError:
             print(f"skip {name} because the experts are not trained.")
             return
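
The smile_upscaling.py hunks above add a device round trip: each linear module and its expert counterparts are moved to self.device before the SmileMoELinear construction, and the fused module is moved back to wherever the original weights lived. A self-contained sketch of that pattern in isolation; the weight averaging below is only a placeholder for the real SVD-based fusion, not FusionBench's implementation:

import torch
from torch import nn


def fuse_on_device(module: nn.Linear, experts, device=None) -> nn.Linear:
    """Run an expensive fusion step on a fast device, then restore the original placement."""
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    original_device = next(module.parameters()).device

    module = module.to(device, non_blocking=True)
    experts = [e.to(device, non_blocking=True) for e in experts]

    # Placeholder for the real work (e.g. SVD over expert deltas): average the expert weights.
    fused = nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(device)
    with torch.no_grad():
        fused.weight.copy_(torch.stack([e.weight for e in experts]).mean(dim=0))
        if module.bias is not None:
            fused.bias.copy_(module.bias)

    # Move the result back so the surrounding model keeps its original device layout.
    return fused.to(original_device, non_blocking=True)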

fusion_bench/method/tall_mask/task_arithmetic.py
@@ -9,7 +9,7 @@ from copy import deepcopy
 import torch
 
 from fusion_bench import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
@@ -58,16 +58,11 @@ def generate_task_masks(
     return final_mask
 
 
+@auto_register_config
 class TallMaskTaskArithmeticAlgorithm(
-    BaseAlgorithm,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "tall_mask_lambda": "tall_mask_lambda",
-        "debug": "debug",
-        "verbose": "verbose",
-    }
-
     def __init__(
         self,
         tall_mask_lambda: float,
@@ -76,9 +71,6 @@ class TallMaskTaskArithmeticAlgorithm(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.tall_mask_lambda = tall_mask_lambda
-        self.debug = debug
-        self.verbose = verbose
 
     @torch.no_grad()
     def run(self, modelpool: BaseModelPool):
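
The tall_mask hunks above, like the task arithmetic and TIES merging hunks below, swap the hand-written _config_mapping dictionary and the self.x = x assignments for a single @auto_register_config decorator from fusion_bench.mixins. The diff does not show the decorator's implementation; the snippet below is only a rough guess at the general idea (bind the named __init__ arguments as attributes automatically), not FusionBench's actual code:

import inspect
from functools import wraps


def auto_register_config_sketch(cls):
    """Hypothetical stand-in: set every named __init__ argument on `self`."""
    original_init = cls.__init__
    signature = inspect.signature(original_init)

    @wraps(original_init)
    def __init__(self, *args, **kwargs):
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            if name not in ("self", "args", "kwargs"):
                setattr(self, name, value)
        original_init(self, *args, **kwargs)

    cls.__init__ = __init__
    return cls


@auto_register_config_sketch
class TaskArithmeticLike:
    def __init__(self, scaling_factor: float, **kwargs):
        pass  # attributes are already set by the decorator


print(TaskArithmeticLike(scaling_factor=0.3).scaling_factor)  # 0.3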

fusion_bench/method/task_arithmetic/task_arithmetic.py
@@ -12,7 +12,7 @@ import torch
 from torch import nn
 
 from fusion_bench.method.base_algorithm import BaseAlgorithm
-from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
@@ -74,9 +74,10 @@ def task_arithmetic_merge(
     return pretrained_model
 
 
+@auto_register_config
 class TaskArithmeticAlgorithm(
-    BaseAlgorithm,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
     """
     Task Arithmetic Algorithm for model fusion.
@@ -89,22 +90,17 @@ class TaskArithmeticAlgorithm(
         scaling_factor (int): The factor by which the task vectors will be scaled before merging.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "scaling_factor": "scaling_factor"
-    }
-
-    def __init__(self, scaling_factor: int):
+    def __init__(self, scaling_factor: int, **kwargs):
         """
         Initializes the TaskArithmeticAlgorithm with the given scaling factor.
 
         Args:
             scaling_factor (int): The factor by which the task vectors will be scaled before merging.
         """
-        self.scaling_factor = scaling_factor
-        super().__init__()
+        super().__init__(**kwargs)
 
     @torch.no_grad()
-    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
+    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]) -> nn.Module:
         """
         Runs the Task Arithmetic Algorithm to fuse models in the given model pool.
 

fusion_bench/method/ties_merging/ties_merging.py
@@ -9,14 +9,14 @@ Overview of Ties-Merging:
 """
 
 import logging
-from typing import Dict, List, Literal, Mapping, Union  # noqa: F401
+from typing import Any, Dict, List, Literal, Mapping, Union  # noqa: F401
 
 import torch
 from torch import Tensor, nn
 
 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.type import StateDictType
 
@@ -25,33 +25,22 @@ from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_st
 log = logging.getLogger(__name__)
 
 
-class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
-    """
-    TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
-
-    Attributes:
-        scaling_factor (float): The scaling factor to apply to the merged task vector.
-        threshold (float): The threshold for resetting values in the task vector.
-        remove_keys (List[str]): List of keys to remove from the state dictionary.
-        merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
-    """
-
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "scaling_factor": "scaling_factor",
-        "threshold": "threshold",
-        "remove_keys": "remove_keys",
-        "merge_func": "merge_func",
-    }
-
+@auto_register_config
+class TiesMergingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     def __init__(
         self,
         scaling_factor: float,
         threshold: float,
         remove_keys: List[str],
         merge_func: Literal["sum", "mean", "max"],
-        **kwargs,
+        **kwargs: Any,
     ):
         """
+        TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
+
         Initialize the TiesMergingAlgorithm with the given parameters.
 
         Args:
@@ -61,14 +50,12 @@ class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
             merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
             **kwargs: Additional keyword arguments for the base class.
         """
-        self.scaling_factor = scaling_factor
-        self.threshold = threshold
-        self.remove_keys = remove_keys
-        self.merge_func = merge_func
         super().__init__(**kwargs)
 
     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs):
+    def run(
+        self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs: Any
+    ) -> nn.Module:
         """
         Run the TIES merging algorithm to fuse models in the model pool.
 

fusion_bench/method/we_moe/__init__.py
@@ -1,2 +1,3 @@
 # flake8: noqa F401
 from .clip_we_moe import CLIPWeightEnsemblingMoEAlgorithm
+from .flan_t5_we_moe import FlanT5WeightEnsemblingMoEAlgorithm

fusion_bench/method/we_moe/clip_we_moe.py
@@ -2,6 +2,7 @@ import functools
 import logging
 import os
 from copy import deepcopy
+from typing import Any, Iterator
 
 import torch
 from torch import Tensor
@@ -38,7 +39,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
 
     modelpool: CLIPVisionModelPool = None
 
-    def load_checkpoint(self, model, checkpoint):
+    def load_checkpoint(self, model: Any, checkpoint: Any):
         """
         Load the checkpoint file.
 
@@ -49,7 +50,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
         state = {"model": model}
         self._fabric.load(checkpoint, state)
 
-    def save_checkpoint(self, model, checkpoint):
+    def save_checkpoint(self, model: Any, checkpoint: Any):
         """
         Save the checkpoint file.
 
@@ -102,7 +103,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
         return moe_model
 
     @functools.cache
-    def get_shuffled_test_loader_iter(self, tta_dataset: str):
+    def get_shuffled_test_loader_iter(self, tta_dataset: str) -> Iterator:
         """
         Get an iterator for the shuffled test data loader.
 
@@ -131,7 +132,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
         """
         self.setup_zero_shot_classification_head()
 
-    def compute_logits(self, module, batch, task) -> Tensor:
+    def compute_logits(self, module: Any, batch: Any, task: Any) -> Tensor:
         """
         Compute the logits for the given batch and task.
 

fusion_bench/method/we_moe/entropy_loss.py
@@ -0,0 +1,25 @@
+import torch
+from torch import Tensor
+
+
+def entropy_loss(logits: Tensor, eps: float = 1e-8) -> Tensor:
+    """
+    Compute the entropy loss of a set of logits.
+
+    Args:
+        logits (Tensor): The logits to compute the entropy loss of.
+        eps (float): A small value to avoid log(0). Default is 1e-8.
+
+    Returns:
+        Tensor: The entropy loss of the logits.
+    """
+    # Ensure the logits tensor has 2 dimensions
+    assert (
+        logits.dim() == 2
+    ), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}"
+
+    # Compute the softmax probabilities
+    probs = torch.softmax(logits, dim=-1)
+
+    # Compute the entropy loss
+    return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
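
The new entropy_loss helper is presumably the test-time adaptation objective used by the new FlanT5WeightEnsemblingMoEAlgorithm in the same package. A quick, purely illustrative check of its behaviour:

import torch

from fusion_bench.method.we_moe.entropy_loss import entropy_loss

# A confident prediction has low entropy; a uniform one approaches log(num_classes).
confident = torch.tensor([[10.0, 0.0, 0.0]])
uniform = torch.zeros(1, 3)

print(entropy_loss(confident))  # close to 0
print(entropy_loss(uniform))    # ≈ 1.0986, i.e. log(3)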