fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. fusion_bench/__init__.py +22 -2
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +6 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/constants/runtime.py +57 -0
  9. fusion_bench/dataset/clip_dataset.py +2 -1
  10. fusion_bench/dataset/gpt2_glue.py +9 -9
  11. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  12. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  13. fusion_bench/dataset/image_dataset.py +1 -1
  14. fusion_bench/dataset/nyuv2.py +2 -2
  15. fusion_bench/method/__init__.py +24 -5
  16. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  17. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  19. fusion_bench/method/base_algorithm.py +195 -12
  20. fusion_bench/method/bitdelta/__init__.py +5 -0
  21. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  25. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  26. fusion_bench/method/classification/clip_finetune.py +1 -1
  27. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  28. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  29. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  30. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  31. fusion_bench/method/ensemble.py +12 -12
  32. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  33. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
  34. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  35. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  36. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  37. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  38. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  39. fusion_bench/method/linear/expo.py +2 -1
  40. fusion_bench/method/linear/linear_interpolation.py +6 -4
  41. fusion_bench/method/linear/simple_average_for_llama.py +17 -13
  42. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  43. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  44. fusion_bench/method/model_recombination.py +2 -5
  45. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  46. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  47. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  48. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  49. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  50. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  51. fusion_bench/method/randes/modelsoup.py +1 -3
  52. fusion_bench/method/regmean/clip_regmean.py +2 -2
  53. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  54. fusion_bench/method/regmean/regmean.py +2 -11
  55. fusion_bench/method/regmean_plusplus/__init__.py +1 -1
  56. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
  57. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
  58. fusion_bench/method/simple_average.py +12 -16
  59. fusion_bench/method/slerp/slerp.py +5 -2
  60. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  61. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  62. fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
  63. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  64. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
  65. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  66. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  67. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  68. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  69. fusion_bench/method/we_moe/__init__.py +1 -0
  70. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  71. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  72. fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
  73. fusion_bench/method/we_moe/utils.py +15 -0
  74. fusion_bench/method/we_moe/we_moe.py +6 -6
  75. fusion_bench/method/weighted_average/llama.py +4 -16
  76. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  77. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  78. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  79. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  80. fusion_bench/mixins/__init__.py +10 -2
  81. fusion_bench/mixins/clip_classification.py +15 -45
  82. fusion_bench/mixins/hydra_config.py +105 -7
  83. fusion_bench/mixins/lightning_fabric.py +2 -0
  84. fusion_bench/mixins/serialization.py +275 -48
  85. fusion_bench/modelpool/__init__.py +2 -2
  86. fusion_bench/modelpool/base_pool.py +29 -9
  87. fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
  88. fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
  89. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  90. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  91. fusion_bench/models/__init__.py +7 -1
  92. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  93. fusion_bench/models/hf_utils.py +160 -0
  94. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  95. fusion_bench/models/linearized/vision_model.py +1 -1
  96. fusion_bench/models/model_card_templates/default.md +46 -0
  97. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  98. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  99. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  100. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  101. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  102. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  103. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  104. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  105. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  106. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
  107. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  108. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  109. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  110. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
  111. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  112. fusion_bench/models/parameter_dict.py +1 -1
  113. fusion_bench/models/sparse_we_moe.py +1 -53
  114. fusion_bench/models/utils.py +26 -0
  115. fusion_bench/models/we_moe.py +1 -53
  116. fusion_bench/models/wrappers/ensemble.py +6 -4
  117. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  118. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  119. fusion_bench/programs/base_program.py +81 -2
  120. fusion_bench/programs/fabric_fusion_program.py +46 -61
  121. fusion_bench/scripts/cli.py +38 -5
  122. fusion_bench/taskpool/base_pool.py +4 -3
  123. fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
  124. fusion_bench/taskpool/dummy.py +1 -1
  125. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  126. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  127. fusion_bench/utils/__init__.py +7 -1
  128. fusion_bench/utils/cache_utils.py +101 -1
  129. fusion_bench/utils/devices.py +14 -4
  130. fusion_bench/utils/fabric.py +2 -2
  131. fusion_bench/utils/instantiate_utils.py +3 -1
  132. fusion_bench/utils/lazy_imports.py +23 -0
  133. fusion_bench/utils/lazy_state_dict.py +38 -3
  134. fusion_bench/utils/modelscope.py +127 -8
  135. fusion_bench/utils/parameters.py +2 -2
  136. fusion_bench/utils/path.py +56 -0
  137. fusion_bench/utils/pylogger.py +1 -1
  138. fusion_bench/utils/rich_utils.py +3 -0
  139. fusion_bench/utils/state_dict_arithmetic.py +25 -23
  140. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
  141. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
  142. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  143. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  144. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  145. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  146. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  147. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  148. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  149. fusion_bench_config/hydra/default.yaml +6 -2
  150. fusion_bench_config/llama_full_finetune.yaml +1 -0
  151. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  152. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  153. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  154. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  155. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  156. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  157. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  158. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  159. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
  160. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
  167. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
  168. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  169. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  170. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  171. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  172. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  173. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  174. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  175. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  176. fusion_bench_config/nyuv2_config.yaml +3 -1
  177. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  178. fusion_bench_config/path/default.yaml +28 -0
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  180. fusion_bench_config/method/adamerging.yaml +0 -23
  181. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  182. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  183. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  184. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
  185. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
  186. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
  187. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
  188. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
--- a/fusion_bench/modelpool/causal_lm/causal_lm.py
+++ b/fusion_bench/modelpool/causal_lm/causal_lm.py
@@ -8,7 +8,7 @@ from copy import deepcopy
 from typing import Any, Dict, Optional, TypeAlias, Union, cast  # noqa: F401
 
 import peft
-from omegaconf import DictConfig, flag_override
+from omegaconf import DictConfig, OmegaConf, flag_override
 from torch import nn
 from torch.nn.modules import Module
 from transformers import (
@@ -19,43 +19,51 @@ from transformers import (
 )
 from typing_extensions import override
 
-from fusion_bench.modelpool import BaseModelPool
-from fusion_bench.utils import instantiate
-from fusion_bench.utils.dtype import parse_dtype
+from fusion_bench import (
+    BaseModelPool,
+    auto_register_config,
+    import_object,
+    instantiate,
+    parse_dtype,
+)
 from fusion_bench.utils.lazy_state_dict import LazyStateDict
-from fusion_bench.utils.packages import import_object
 
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class CausalLMPool(BaseModelPool):
-    _config_mapping = BaseModelPool._config_mapping | {
-        "_tokenizer": "tokenizer",
-        "_model_kwargs": "model_kwargs",
-        "load_lazy": "load_lazy",
-    }
-
     def __init__(
         self,
         models,
         *,
-        tokenizer: Optional[DictConfig],
+        tokenizer: Optional[DictConfig | str],
         model_kwargs: Optional[DictConfig] = None,
-        load_lazy: bool = False,
+        enable_lazy_loading: bool = False,
         **kwargs,
    ):
         super().__init__(models, **kwargs)
-        # process `model_kwargs`
-        self._tokenizer = tokenizer
-        self._model_kwargs = model_kwargs
-        if self._model_kwargs is None:
-            self._model_kwargs = DictConfig({})
-        with flag_override(self._model_kwargs, "allow_objects", True):
-            if hasattr(self._model_kwargs, "torch_dtype"):
-                self._model_kwargs.torch_dtype = parse_dtype(
-                    self._model_kwargs.torch_dtype
-                )
-        self.load_lazy = load_lazy
+        if model_kwargs is None:
+            self.model_kwargs = DictConfig({})
+
+    def get_model_path(self, model_name: str):
+        model_name_or_config = self._models[model_name]
+        if isinstance(model_name_or_config, str):
+            return model_name_or_config
+        elif isinstance(model_name_or_config, (DictConfig, dict)):
+            return model_name_or_config.get("pretrained_model_name_or_path")
+        else:
+            raise RuntimeError("Invalid model configuration")
+
+    def get_model_kwargs(self):
+        model_kwargs = (
+            OmegaConf.to_container(self.model_kwargs, resolve=True)
+            if isinstance(self.model_kwargs, DictConfig)
+            else self.model_kwargs
+        )
+        if "torch_dtype" in model_kwargs:
+            model_kwargs["torch_dtype"] = parse_dtype(model_kwargs["torch_dtype"])
+        return model_kwargs
 
     @override
     def load_model(
@@ -89,7 +97,7 @@ class CausalLMPool(BaseModelPool):
                 pretrained_model_name_or_path: path_to_model_b
             ```
         """
-        model_kwargs = deepcopy(self._model_kwargs)
+        model_kwargs = self.get_model_kwargs()
         model_kwargs.update(kwargs)
 
         if isinstance(model_name_or_config, str):
@@ -99,7 +107,7 @@ class CausalLMPool(BaseModelPool):
             model_config = self._models[model_name_or_config]
             if isinstance(model_config, str):
                 # model_config is a string
-                if not self.load_lazy:
+                if not self.enable_lazy_loading:
                     model = AutoModelForCausalLM.from_pretrained(
                         model_config,
                         *args,
@@ -117,7 +125,7 @@ class CausalLMPool(BaseModelPool):
         elif isinstance(model_name_or_config, (DictConfig, Dict)):
             model_config = model_name_or_config
 
-            if not self.load_lazy:
+            if not self.enable_lazy_loading:
                 model = instantiate(model_config, *args, **model_kwargs)
             else:
                 meta_module_class = model_config.pop("_target_")
@@ -149,12 +157,12 @@ class CausalLMPool(BaseModelPool):
         Returns:
             PreTrainedTokenizer: The tokenizer.
         """
-        assert self._tokenizer is not None, "Tokenizer is not defined in the config"
+        assert self.tokenizer is not None, "Tokenizer is not defined in the config"
         log.info("Loading tokenizer.", stacklevel=2)
-        if isinstance(self._tokenizer, str):
-            tokenizer = AutoTokenizer.from_pretrained(self._tokenizer, *args, **kwargs)
+        if isinstance(self.tokenizer, str):
+            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer, *args, **kwargs)
         else:
-            tokenizer = instantiate(self._tokenizer, *args, **kwargs)
+            tokenizer = instantiate(self.tokenizer, *args, **kwargs)
         return tokenizer
 
     @override
@@ -204,12 +212,12 @@ class CausalLMBackbonePool(CausalLMPool):
     def load_model(
         self, model_name_or_config: str | DictConfig, *args, **kwargs
     ) -> Module:
-        if self.load_lazy:
+        if self.enable_lazy_loading:
             log.warning(
                 "CausalLMBackbonePool does not support lazy loading. "
                 "Falling back to normal loading."
            )
-            self.load_lazy = False
+            self.enable_lazy_loading = False
         model: AutoModelForCausalLM = super().load_model(
             model_name_or_config, *args, **kwargs
         )
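The `CausalLMPool` changes above rename `load_lazy` to `enable_lazy_loading`, replace the hand-written `_config_mapping` with the `@auto_register_config` decorator, and move dtype handling into the new `get_model_kwargs()`. The following is a minimal sketch (not part of the diff) of how the updated pool would typically be constructed; the model names and paths are placeholders and the top-level import path is assumed from the package layout.

```python
# Hedged usage sketch; names and paths below are illustrative placeholders.
from omegaconf import DictConfig

from fusion_bench.modelpool import CausalLMPool  # import path assumed

pool = CausalLMPool(
    models={
        # plain string entry: returned by get_model_path() as-is
        "base": "path/to/base-model",
        # structured entry: get_model_path() reads `pretrained_model_name_or_path`
        "expert": DictConfig({"pretrained_model_name_or_path": "path/to/expert-model"}),
    },
    tokenizer="path/to/base-model",
    model_kwargs=DictConfig({"torch_dtype": "bfloat16"}),
    enable_lazy_loading=False,  # renamed from `load_lazy` in 0.2.22
)

print(pool.get_model_path("expert"))  # "path/to/expert-model"
print(pool.get_model_kwargs())        # torch_dtype resolved via parse_dtype
model = pool.load_model("base")       # falls through to AutoModelForCausalLM.from_pretrained
```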
--- a/fusion_bench/modelpool/clip_vision/modelpool.py
+++ b/fusion_bench/modelpool/clip_vision/modelpool.py
@@ -11,9 +11,7 @@ from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 from typing_extensions import override
 
 from fusion_bench.utils import instantiate, timeit_context
-from fusion_bench.utils.modelscope import (
-    resolve_repo_path,
-)
+from fusion_bench.utils.modelscope import resolve_repo_path
 
 from ..base_pool import BaseModelPool
 
--- a/fusion_bench/modelpool/seq_classification_lm/__init__.py
+++ b/fusion_bench/modelpool/seq_classification_lm/__init__.py
@@ -1,2 +1,2 @@
 from .reward_model import create_reward_model_from_pretrained
-from .seq_classification_lm import SeqenceClassificationModelPool
+from .seq_classification_lm import SequenceClassificationModelPool
--- a/fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py
+++ b/fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)
 
 
-class SeqenceClassificationModelPool(BaseModelPool):
+class SequenceClassificationModelPool(BaseModelPool):
 
     def __init__(
         self,
--- a/fusion_bench/models/__init__.py
+++ b/fusion_bench/models/__init__.py
@@ -1,4 +1,10 @@
 # flake8: noqa F401
+from fusion_bench.utils import LazyStateDict
+
 from . import separate_io, utils
+from .hf_utils import (
+    create_default_model_card,
+    load_model_card_template,
+    save_pretrained_with_remote_code,
+)
 from .parameter_dict import ParameterDictModel
-from fusion_bench.utils import LazyStateDict
--- a/fusion_bench/models/expert_sparsity/mixtral/__init__.py
+++ b/fusion_bench/models/expert_sparsity/mixtral/__init__.py
@@ -10,6 +10,6 @@ Reference:
 """
 
 from .wrapper import (
-    PrunableMixtralSparseMoeBlockWrapper,
     DynamicSkippingMixtralSparseMoeBlockWrapper,
+    PrunableMixtralSparseMoeBlockWrapper,
 )
--- /dev/null
+++ b/fusion_bench/models/hf_utils.py
@@ -0,0 +1,160 @@
+"""
+This module contains utilities for working with Hugging Face models.
+"""
+
+import inspect
+import os
+import shutil
+from typing import List, Optional, cast
+
+from omegaconf import DictConfig, OmegaConf
+from transformers.modeling_utils import PreTrainedModel
+
+from fusion_bench.utils.pylogger import get_rankzero_logger
+
+log = get_rankzero_logger(__name__)
+
+__all__ = [
+    "load_model_card_template",
+    "save_pretrained_with_remote_code",
+    "create_default_model_card",
+]
+
+MODEL_CARD_TEMPLATE_DIRS = [
+    os.path.join(os.path.dirname(__file__), "model_card_templates")
+]
+
+
+def load_model_card_template(basename: str) -> str:
+    """
+    Load a model card template from file.
+
+    Searches for a template file by name, first checking if the name is a direct file path,
+    then searching through predefined template directories.
+
+    Args:
+        name (str): The name of the template file or a direct file path to the template.
+
+    Returns:
+        str: The contents of the template file as a string.
+
+    Raises:
+        FileNotFoundError: If the template file is not found in any of the search locations.
+    """
+    if os.path.exists(basename):
+        return open(basename).read()
+
+    for template_dir in MODEL_CARD_TEMPLATE_DIRS:
+        template_path = os.path.join(template_dir, basename)
+        if os.path.exists(template_path):
+            return open(template_path).read()
+
+    raise FileNotFoundError(f"Model card template '{basename}' not found.")
+
+
+def try_to_yaml(config):
+    if config is None:
+        return None
+
+    try:
+        return OmegaConf.to_yaml(config, resolve=True, sort_keys=True)
+    except Exception as e:
+        log.error(f"Failed to convert config to YAML: {e}. Return `None`.")
+        return None
+
+
+def save_pretrained_with_remote_code(
+    model: PreTrainedModel,
+    auto_map: dict[str, object],
+    save_directory,
+    **kwargs,
+):
+    """
+    Saves a model with custom code to a directory.
+
+    This function facilitates saving a Hugging Face `PreTrainedModel` along with its
+    associated custom code. It inspects the objects provided in the `auto_map`,
+    copies their source files to the `save_directory`, and generates an `__init__.py`
+    to make them importable. It also updates the model's configuration with an
+    `auto_map` attribute, which allows `AutoModel.from_pretrained` to correctly
+    instantiate the custom model classes when `trust_remote_code=True`.
+
+    Args:
+        model (PreTrainedModel): The model instance to be saved.
+        auto_map (dict[str, object]): A dictionary mapping auto class names
+            (e.g., "AutoModelForCausalLM") to the corresponding custom class objects.
+        save_directory (str or os.PathLike): The directory where the model and
+            custom code files will be saved.
+        **kwargs: Additional keyword arguments to be passed to the
+            `model.save_pretrained` method.
+
+    Example:
+        ```python
+        # Assuming `model` is an instance of `SmileQwen2ForCausalLM`
+        # and `SmileQwen2Config`, `SmileQwen2Model`, `SmileQwen2ForCausalLM`
+        # are custom classes defined in your project.
+
+        save_pretrained_with_remote_code(
+            model,
+            auto_map={
+                "AutoConfig": SmileQwen2Config,
+                "AutoModel": SmileQwen2Model,
+                "AutoModelForCausalLM": SmileQwen2ForCausalLM,
+            },
+            save_directory="./my-custom-model",
+        )
+
+        # The model can then be loaded with `trust_remote_code=True`:
+        # from transformers import AutoModelForCausalLM
+        # loaded_model = AutoModelForCausalLM.from_pretrained(
+        #     "./my-custom-model", trust_remote_code=True
+        # )
+        ```
+    """
+    auto_map_files = {}
+    auto_map_strs = {}
+    for key, obj in auto_map.items():
+        auto_map_files[key] = inspect.getfile(obj)
+
+    for key, obj in auto_map.items():
+        auto_map_strs[key] = (
+            f"{(inspect.getmodule(obj).__name__).split('.')[-1]}.{obj.__name__}"
+        )
+
+    model.config.auto_map = auto_map_strs
+
+    # save model to `save_directory`
+    model.save_pretrained(save_directory=save_directory, **kwargs)
+
+    # copy source files to `save_directory`
+    for key, file_path in auto_map_files.items():
+        shutil.copy(
+            src=file_path, dst=os.path.join(save_directory, os.path.basename(file_path))
+        )
+    # construct `__init__.py`
+    init_file = os.path.join(save_directory, "__init__.py")
+    with open(init_file, "w") as f:
+        for key, file_name in auto_map_files.items():
+            base_name = os.path.basename(file_name).split(".")[0]
+            f.write(f"from .{base_name} import {auto_map[key].__name__}\n")
+
+
+def create_default_model_card(
+    models: list[str],
+    description=None,
+    algorithm_config: DictConfig = None,
+    modelpool_config: DictConfig = None,
+):
+    from jinja2 import Template
+
+    template: Template = Template(load_model_card_template("default.md"))
+    card = template.render(
+        models=models,
+        library_name="transformers",
+        tags=["fusion-bench", "merge"],
+        title="Deep Model Fusion",
+        description=description,
+        algorithm_config_str=try_to_yaml(algorithm_config),
+        modelpool_config_str=try_to_yaml(modelpool_config),
+    )
+    return card
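The new `create_default_model_card` ties the `default.md` Jinja2 template (added further below) to the merge configuration via `try_to_yaml`. A minimal, hedged usage sketch follows; the model names, configuration contents, and output path are illustrative placeholders, not values from the release.

```python
# Hedged sketch (not part of the diff): render a model card and write it next to a merged checkpoint.
from omegaconf import DictConfig

from fusion_bench.models.hf_utils import create_default_model_card

card = create_default_model_card(
    models=["org/model-a", "org/model-b"],            # placeholder base models
    description="Merged with fusion-bench.",
    algorithm_config=DictConfig({"name": "simple_average"}),  # illustrative config
    modelpool_config=None,                              # omitted sections are skipped by the template
)
with open("merged-model/README.md", "w") as f:          # placeholder output path
    f.write(card)
```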
--- a/fusion_bench/models/linearized/linearized_model_utils.py
+++ b/fusion_bench/models/linearized/linearized_model_utils.py
@@ -1,7 +1,7 @@
 import logging
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Optional
+from typing import Any, Dict, Optional, Tuple
 
 import torch.nn as nn
 from torch.func import functional_call, jvp
@@ -9,7 +9,7 @@ from torch.func import functional_call, jvp
 log = logging.getLogger(__name__)
 
 
-def dict_params_to_tuple(dict_params: dict):
+def dict_params_to_tuple(dict_params: dict) -> Tuple:
     return tuple(v for k, v in dict_params.items())
 
 
@@ -33,7 +33,7 @@ class LinearizedModelWraper(nn.Module):
         for p in self.params0_values:
             p.requires_grad_(False)
 
-    def tuple_params_to_dict(self, tuple_params):
+    def tuple_params_to_dict(self, tuple_params) -> Dict[str, Any]:
         """
         Converts a tuple of parameters to a dictionary with keys corresponding to the parameter names.
 
@@ -50,7 +50,7 @@ class LinearizedModelWraper(nn.Module):
             state_dict[k] = p
         return state_dict
 
-    def forward(self, *args, **kwargs):
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
         """
         Computes the linearized model output using a first-order Taylor decomposition.
 
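For reference, the first-order Taylor decomposition mentioned in the `forward` docstring above is the standard linearized-model formulation; the directional-derivative term is what `torch.func.jvp`, imported in this file, computes. This is background context, not code taken from the diff.

```latex
% Linearization of f around the initial parameters \theta_0,
% evaluated at the current parameters \theta:
f_{\mathrm{lin}}(x; \theta)
  = f(x; \theta_0)
  + \nabla_{\theta} f(x; \theta_0)^{\top} (\theta - \theta_0)
```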
--- a/fusion_bench/models/linearized/vision_model.py
+++ b/fusion_bench/models/linearized/vision_model.py
@@ -70,7 +70,7 @@ def load_lora_vision_model_hf(
     peft_name: str,
     merge_and_unload: bool = False,
     return_vison_model=True,
-):
+) -> PeftModel:
     """
     Load a LoRA (Low-Rank Adaptation) vision model from Hugging Face.
 
--- /dev/null
+++ b/fusion_bench/models/model_card_templates/default.md
@@ -0,0 +1,46 @@
+---
+base_model:
+{%- for model in models %}
+- {{ model }}
+{%- endfor %}
+library_name: {{ library_name }}
+tags:
+{%- for tag in tags %}
+- {{ tag }}
+{%- endfor %}
+---
+# {{ title }}
+
+{% if description is not none %}{{ description }}{% endif %}
+
+## Models Merged
+
+This is a merged model created using [fusion-bench](https://github.com/tanganke/fusion_bench).
+
+The following models were included in the merge:
+{% for model in models %}
+- {{ model }}
+{%- endfor %}
+
+{% if algorithm_config_str is not none or modelpool_config_str is not none %}
+## Configuration
+
+The following YAML configuration was used to produce this model:
+
+{% if algorithm_config_str is not none -%}
+### Algorithm Configuration
+
+```yaml
+{{ algorithm_config_str -}}
+```
+{%- endif %}
+
+{% if modelpool_config_str is not none -%}
+### Model Pool Configuration
+
+```yaml
+{{ modelpool_config_str -}}
+```
+{%- endif %}
+
+{% endif %}
--- a/fusion_bench/models/modeling_deepseek_v2/__init__.py
+++ b/fusion_bench/models/modeling_deepseek_v2/__init__.py
@@ -4,12 +4,12 @@ This is a direct copy of the DeepSeek-V2-Lite model from HuggingFace https://hug
 
 from .configuration_deepseek import DeepseekV2Config
 from .modeling_deepseek import (
+    DeepseekV2DecoderLayer,
     DeepseekV2ForCausalLM,
     DeepseekV2ForSequenceClassification,
     DeepseekV2MLP,
     DeepseekV2Model,
     DeepseekV2MoE,
-    DeepseekV2DecoderLayer,
 )
 from .modeling_deepseek import MoEGate as DeepseekV2MoEGate
 from .tokenization_deepseek_fast import DeepseekTokenizerFast
--- a/fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py
+++ b/fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py
@@ -17,17 +17,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch DeepSeek model."""
+"""PyTorch DeepSeek model."""
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
+import torch.distributed as dist
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import (
@@ -54,9 +55,8 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.utils.import_utils import is_torch_fx_available
+
 from .configuration_deepseek import DeepseekV2Config
-import torch.distributed as dist
-import numpy as np
 
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
--- a/fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py
+++ b/fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py
@@ -1,6 +1,5 @@
 from typing import List, Optional, Union
 
-
 from transformers.models.llama import LlamaTokenizerFast
 
 
--- /dev/null
+++ b/fusion_bench/models/modeling_smile_gemma2/__init__.py
@@ -0,0 +1,9 @@
+from . import register
+from .configuration_smile_gemma2 import SmileGemma2Config
+from .modeling_smile_gemma2 import (
+    SmileGemma2ForCausalLM,
+    SmileGemma2ForSequenceClassification,
+    SmileGemma2ForTokenClassification,
+    SmileGemma2Model,
+    SmileGemma2PreTrainedModel,
+)
--- /dev/null
+++ b/fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py
@@ -0,0 +1,20 @@
+from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
+
+
+class SmileGemma2Config(Gemma2Config):
+    model_type = "smile_gemma2"
+
+    def __init__(
+        self,
+        num_experts_per_tok: int = 1,
+        rank_of_router: int = None,
+        rank_of_expert: int = None,
+        num_local_experts: int = None,
+        **kwargs,
+    ):
+        self.num_experts_per_tok = num_experts_per_tok
+        self.rank_of_router = rank_of_router
+        self.rank_of_expert = rank_of_expert
+        self.num_local_experts = num_local_experts
+
+        super().__init__(**kwargs)
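`SmileGemma2Config` simply extends `Gemma2Config` with the SMILE routing hyperparameters; everything else falls through to the base configuration. A minimal sketch of instantiating it is shown below; the rank and expert counts are illustrative assumptions, not values shipped in this release.

```python
# Hedged sketch (not part of the diff): the extra SMILE fields ride on top of Gemma2Config.
from fusion_bench.models.modeling_smile_gemma2 import SmileGemma2Config

config = SmileGemma2Config(
    num_local_experts=4,     # number of low-rank experts (illustrative)
    num_experts_per_tok=1,   # top-1 routing
    rank_of_router=8,        # illustrative rank
    rank_of_expert=32,       # illustrative rank
    # any remaining kwargs (hidden_size, num_hidden_layers, ...) go to Gemma2Config
)
print(config.model_type)  # "smile_gemma2"
```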