fusion-bench 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. fusion_bench/__init__.py +1 -0
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +5 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/dataset/clip_dataset.py +2 -1
  9. fusion_bench/dataset/gpt2_glue.py +9 -9
  10. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  11. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  12. fusion_bench/dataset/image_dataset.py +1 -1
  13. fusion_bench/dataset/nyuv2.py +2 -2
  14. fusion_bench/method/__init__.py +16 -3
  15. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  16. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  17. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  18. fusion_bench/method/base_algorithm.py +195 -12
  19. fusion_bench/method/bitdelta/__init__.py +4 -0
  20. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  21. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  25. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  26. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  27. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  28. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  29. fusion_bench/method/ensemble.py +12 -12
  30. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  31. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
  32. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  33. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  34. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  35. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  36. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  37. fusion_bench/method/linear/expo.py +2 -1
  38. fusion_bench/method/linear/linear_interpolation.py +6 -4
  39. fusion_bench/method/linear/simple_average_for_llama.py +2 -3
  40. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  41. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  42. fusion_bench/method/model_recombination.py +2 -5
  43. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  44. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  45. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  46. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  47. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  48. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  49. fusion_bench/method/randes/modelsoup.py +1 -3
  50. fusion_bench/method/regmean/clip_regmean.py +2 -2
  51. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  52. fusion_bench/method/regmean/regmean.py +2 -11
  53. fusion_bench/method/regmean_plusplus/__init__.py +1 -1
  54. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
  55. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
  56. fusion_bench/method/simple_average.py +5 -9
  57. fusion_bench/method/slerp/slerp.py +5 -2
  58. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  59. fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
  60. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
  61. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  62. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  63. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  64. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  65. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  66. fusion_bench/method/we_moe/we_moe.py +6 -6
  67. fusion_bench/method/weighted_average/llama.py +4 -16
  68. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  69. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  70. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  71. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  72. fusion_bench/mixins/__init__.py +10 -2
  73. fusion_bench/mixins/clip_classification.py +4 -3
  74. fusion_bench/mixins/hydra_config.py +105 -7
  75. fusion_bench/mixins/lightning_fabric.py +2 -0
  76. fusion_bench/mixins/serialization.py +265 -48
  77. fusion_bench/modelpool/__init__.py +2 -2
  78. fusion_bench/modelpool/base_pool.py +29 -9
  79. fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
  80. fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
  81. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  82. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  83. fusion_bench/models/__init__.py +2 -1
  84. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  85. fusion_bench/models/hf_utils.py +182 -0
  86. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  87. fusion_bench/models/linearized/vision_model.py +1 -1
  88. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  89. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  90. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  91. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  92. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  93. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  94. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  95. fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
  96. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  97. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
  98. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  99. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  100. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  101. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
  102. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  103. fusion_bench/models/parameter_dict.py +1 -1
  104. fusion_bench/models/sparse_we_moe.py +1 -53
  105. fusion_bench/models/utils.py +26 -0
  106. fusion_bench/models/we_moe.py +1 -53
  107. fusion_bench/models/wrappers/ensemble.py +6 -4
  108. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  109. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  110. fusion_bench/programs/base_program.py +81 -2
  111. fusion_bench/programs/fabric_fusion_program.py +24 -8
  112. fusion_bench/scripts/cli.py +5 -5
  113. fusion_bench/taskpool/base_pool.py +4 -3
  114. fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
  115. fusion_bench/taskpool/dummy.py +1 -1
  116. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  117. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  118. fusion_bench/utils/__init__.py +6 -1
  119. fusion_bench/utils/devices.py +14 -4
  120. fusion_bench/utils/instantiate_utils.py +3 -1
  121. fusion_bench/utils/modelscope.py +127 -8
  122. fusion_bench/utils/parameters.py +2 -2
  123. fusion_bench/utils/rich_utils.py +3 -0
  124. fusion_bench/utils/state_dict_arithmetic.py +25 -23
  125. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +24 -25
  126. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +165 -134
  127. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  128. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  129. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  130. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  131. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  132. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  133. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  134. fusion_bench_config/hydra/default.yaml +6 -2
  135. fusion_bench_config/llama_full_finetune.yaml +1 -0
  136. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  137. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  138. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  139. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  140. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  141. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
  142. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  143. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  144. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
  145. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
  148. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
  149. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  150. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  151. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  152. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  153. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  154. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  155. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  156. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  157. fusion_bench_config/nyuv2_config.yaml +3 -1
  158. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  159. fusion_bench_config/path/default.yaml +28 -0
  160. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  161. fusion_bench_config/method/adamerging.yaml +0 -23
  162. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  163. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  164. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  165. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
  166. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
  167. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
  168. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
  169. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
@@ -1,13 +1,13 @@
1
1
  import logging
2
2
  from copy import deepcopy
3
- from typing import Dict, List, Optional, Union
3
+ from typing import Dict, Generator, List, Optional, Tuple, Union
4
4
 
5
5
  import torch
6
6
  from omegaconf import DictConfig
7
7
  from torch import nn
8
8
  from torch.utils.data import Dataset
9
9
 
10
- from fusion_bench.mixins import BaseYAMLSerializableModel, HydraConfigMixin
10
+ from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
11
11
  from fusion_bench.utils import instantiate, timeit_context
12
12
 
13
13
  __all__ = ["BaseModelPool"]
@@ -15,7 +15,10 @@ __all__ = ["BaseModelPool"]
15
15
  log = logging.getLogger(__name__)
16
16
 
17
17
 
18
- class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
18
+ class BaseModelPool(
19
+ HydraConfigMixin,
20
+ BaseYAMLSerializable,
21
+ ):
19
22
  """
20
23
  A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for the specifications, a vision model pool may contain image preprocessor, and a language model pool may contain a tokenizer.
21
24
 
@@ -31,7 +34,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
31
34
  _program = None
32
35
  _config_key = "modelpool"
33
36
  _models: Union[DictConfig, Dict[str, nn.Module]]
34
- _config_mapping = BaseYAMLSerializableModel._config_mapping | {
37
+ _config_mapping = BaseYAMLSerializable._config_mapping | {
35
38
  "_models": "models",
36
39
  "_train_datasets": "train_datasets",
37
40
  "_val_datasets": "val_datasets",
@@ -56,7 +59,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
56
59
  super().__init__(**kwargs)
57
60
 
58
61
  @property
59
- def has_pretrained(self):
62
+ def has_pretrained(self) -> bool:
60
63
  """
61
64
  Check if the model pool contains a pretrained model.
62
65
 
@@ -125,7 +128,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
125
128
  return len(self.model_names)
126
129
 
127
130
  @staticmethod
128
- def is_special_model(model_name: str):
131
+ def is_special_model(model_name: str) -> bool:
129
132
  """
130
133
  Determine if a model is special based on its name.
131
134
 
@@ -152,6 +155,23 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
152
155
  model_config = deepcopy(model_config)
153
156
  return model_config
154
157
 
158
+ def get_model_path(self, model_name: str) -> str:
159
+ """
160
+ Get the path for the specified model.
161
+
162
+ Args:
163
+ model_name (str): The name of the model.
164
+
165
+ Returns:
166
+ str: The path for the specified model.
167
+ """
168
+ if isinstance(self._models[model_name], str):
169
+ return self._models[model_name]
170
+ else:
171
+ raise ValueError(
172
+ "Model path is not a string. Try to override this method in derived modelpool class."
173
+ )
174
+
155
175
  def load_model(
156
176
  self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
157
177
  ) -> nn.Module:
@@ -159,7 +179,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
159
179
  Load a model from the pool based on the provided configuration.
160
180
 
161
181
  Args:
162
- model (Union[str, DictConfig]): The model name or configuration.
182
+ model_name_or_config (Union[str, DictConfig]): The model name or configuration.
163
183
 
164
184
  Returns:
165
185
  nn.Module: The instantiated model.
@@ -201,11 +221,11 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
201
221
  model = self.load_model(self.model_names[0], *args, **kwargs)
202
222
  return model
203
223
 
204
- def models(self):
224
+ def models(self) -> Generator[nn.Module, None, None]:
205
225
  for model_name in self.model_names:
206
226
  yield self.load_model(model_name)
207
227
 
208
- def named_models(self):
228
+ def named_models(self) -> Generator[Tuple[str, nn.Module], None, None]:
209
229
  for model_name in self.model_names:
210
230
  yield model_name, self.load_model(model_name)
211
231
 
@@ -57,6 +57,15 @@ class CausalLMPool(BaseModelPool):
57
57
  )
58
58
  self.load_lazy = load_lazy
59
59
 
60
+ def get_model_path(self, model_name: str):
61
+ model_name_or_config = self._models[model_name]
62
+ if isinstance(model_name_or_config, str):
63
+ return model_name_or_config
64
+ elif isinstance(model_name_or_config, (DictConfig, dict)):
65
+ return model_name_or_config.get("pretrained_model_name_or_path")
66
+ else:
67
+ raise RuntimeError("Invalid model configuration")
68
+
60
69
  @override
61
70
  def load_model(
62
71
  self,
@@ -11,9 +11,7 @@ from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
11
11
  from typing_extensions import override
12
12
 
13
13
  from fusion_bench.utils import instantiate, timeit_context
14
- from fusion_bench.utils.modelscope import (
15
- resolve_repo_path,
16
- )
14
+ from fusion_bench.utils.modelscope import resolve_repo_path
17
15
 
18
16
  from ..base_pool import BaseModelPool
19
17
 
@@ -1,2 +1,2 @@
1
1
  from .reward_model import create_reward_model_from_pretrained
2
- from .seq_classification_lm import SeqenceClassificationModelPool
2
+ from .seq_classification_lm import SequenceClassificationModelPool
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
17
17
  log = logging.getLogger(__name__)
18
18
 
19
19
 
20
- class SeqenceClassificationModelPool(BaseModelPool):
20
+ class SequenceClassificationModelPool(BaseModelPool):
21
21
 
22
22
  def __init__(
23
23
  self,
@@ -1,4 +1,5 @@
1
1
  # flake8: noqa F401
2
+ from fusion_bench.utils import LazyStateDict
3
+
2
4
  from . import separate_io, utils
3
5
  from .parameter_dict import ParameterDictModel
4
- from fusion_bench.utils import LazyStateDict
@@ -10,6 +10,6 @@ Reference:
10
10
  """
11
11
 
12
12
  from .wrapper import (
13
- PrunableMixtralSparseMoeBlockWrapper,
14
13
  DynamicSkippingMixtralSparseMoeBlockWrapper,
14
+ PrunableMixtralSparseMoeBlockWrapper,
15
15
  )
@@ -0,0 +1,182 @@
1
+ """
2
+ This module contains utilities for working with Hugging Face models.
3
+ """
4
+
5
+ import inspect
6
+ import os
7
+ import shutil
8
+ from typing import Optional, cast
9
+
10
+ from omegaconf import OmegaConf
11
+ from transformers.modeling_utils import PreTrainedModel
12
+
13
+ from fusion_bench import BaseAlgorithm, BaseModelPool
14
+ from fusion_bench.utils.pylogger import getRankZeroLogger
15
+
16
+ log = getRankZeroLogger(__name__)
17
+
18
+ __all__ = [
19
+ "save_pretrained_with_remote_code",
20
+ "generate_readme_head",
21
+ "generate_readme_body",
22
+ "generate_complete_readme",
23
+ ]
24
+
25
+
26
+ def save_pretrained_with_remote_code(
27
+ model: PreTrainedModel,
28
+ auto_map: dict[str, object],
29
+ save_directory,
30
+ **kwargs,
31
+ ):
32
+ """
33
+ Saves a model with custom code to a directory.
34
+
35
+ This function facilitates saving a Hugging Face `PreTrainedModel` along with its
36
+ associated custom code. It inspects the objects provided in the `auto_map`,
37
+ copies their source files to the `save_directory`, and generates an `__init__.py`
38
+ to make them importable. It also updates the model's configuration with an
39
+ `auto_map` attribute, which allows `AutoModel.from_pretrained` to correctly
40
+ instantiate the custom model classes when `trust_remote_code=True`.
41
+
42
+ Args:
43
+ model (PreTrainedModel): The model instance to be saved.
44
+ auto_map (dict[str, object]): A dictionary mapping auto class names
45
+ (e.g., "AutoModelForCausalLM") to the corresponding custom class objects.
46
+ save_directory (str or os.PathLike): The directory where the model and
47
+ custom code files will be saved.
48
+ **kwargs: Additional keyword arguments to be passed to the
49
+ `model.save_pretrained` method.
50
+
51
+ Example:
52
+ ```python
53
+ # Assuming `model` is an instance of `SmileQwen2ForCausalLM`
54
+ # and `SmileQwen2Config`, `SmileQwen2Model`, `SmileQwen2ForCausalLM`
55
+ # are custom classes defined in your project.
56
+
57
+ save_pretrained_with_remote_code(
58
+ model,
59
+ auto_map={
60
+ "AutoConfig": SmileQwen2Config,
61
+ "AutoModel": SmileQwen2Model,
62
+ "AutoModelForCausalLM": SmileQwen2ForCausalLM,
63
+ },
64
+ save_directory="./my-custom-model",
65
+ )
66
+
67
+ # The model can then be loaded with `trust_remote_code=True`:
68
+ # from transformers import AutoModelForCausalLM
69
+ # loaded_model = AutoModelForCausalLM.from_pretrained(
70
+ # "./my-custom-model", trust_remote_code=True
71
+ # )
72
+ ```
73
+ """
74
+ auto_map_files = {}
75
+ auto_map_strs = {}
76
+ for key, obj in auto_map.items():
77
+ auto_map_files[key] = inspect.getfile(obj)
78
+
79
+ for key, obj in auto_map.items():
80
+ auto_map_strs[key] = (
81
+ f"{(inspect.getmodule(obj).__name__).split('.')[-1]}.{obj.__name__}"
82
+ )
83
+
84
+ model.config.auto_map = auto_map_strs
85
+
86
+ # save model to `save_directory`
87
+ model.save_pretrained(save_directory=save_directory, **kwargs)
88
+
89
+ # copy source files to `save_directory`
90
+ for key, file_path in auto_map_files.items():
91
+ shutil.copy(
92
+ src=file_path, dst=os.path.join(save_directory, os.path.basename(file_path))
93
+ )
94
+ # construct `__init__.py`
95
+ init_file = os.path.join(save_directory, "__init__.py")
96
+ with open(init_file, "w") as f:
97
+ for key, file_name in auto_map_files.items():
98
+ base_name = os.path.basename(file_name).split(".")[0]
99
+ f.write(f"from .{base_name} import {auto_map[key].__name__}\n")
100
+
101
+
102
+ def generate_readme_head(
103
+ models: list[str] | BaseModelPool,
104
+ library_name: str = "transformers",
105
+ tags: list[str] = ["fusion-bench", "merge"],
106
+ ):
107
+ text = "---\nbase_model:\n"
108
+ for model_name in models:
109
+ text += f"- {model_name}\n"
110
+ if library_name:
111
+ text += f"library_name: {library_name}\n"
112
+ text += "tags:\n"
113
+ for tag in tags:
114
+ text += f"- {tag}\n"
115
+ text += "---\n"
116
+ return text
117
+
118
+
119
+ def generate_readme_body(
120
+ algorithm: BaseAlgorithm,
121
+ models_or_modelpool: Optional[list[str] | BaseModelPool] = None,
122
+ models: list[str] = None,
123
+ ):
124
+ text = """\
125
+ # Merge
126
+
127
+ This is a merge of pre-trained language models created using [fusion-bench](https://github.com/tanganke/fusion_bench).
128
+
129
+ """
130
+
131
+ if models is not None:
132
+ text += """
133
+ ## Models Merged
134
+
135
+ The following models were included in the merge:
136
+
137
+ """
138
+ for model_name in models:
139
+ text += f"- {model_name}\n"
140
+ text += "\n"
141
+
142
+ try:
143
+ text += f"""\
144
+ ## Configuration
145
+
146
+ The following YAML configuration was used to produce this model:
147
+
148
+ ```yaml
149
+ {OmegaConf.to_yaml(algorithm.config, resolve=True, sort_keys=True)}
150
+ ```
151
+ """
152
+ except Exception as e:
153
+ return (
154
+ text # If the algorithm config cannot be converted to YAML, we skip it.
155
+ )
156
+
157
+ if isinstance(models_or_modelpool, BaseModelPool):
158
+ try:
159
+ text += f"""
160
+ ```yaml
161
+ {OmegaConf.to_yaml(models_or_modelpool.config, resolve=True, sort_keys=True)}
162
+ ```
163
+ """
164
+ except Exception as e:
165
+ pass # If the model pool config cannot be converted to YAML, we skip it.
166
+ return text
167
+
168
+
169
+ def generate_complete_readme(
170
+ algorithm: BaseAlgorithm, modelpool: BaseModelPool, models: list[str]
171
+ ):
172
+ # Generate the complete README content
173
+ text = generate_readme_head(
174
+ [modelpool.get_model_path(m) for m in modelpool.model_names]
175
+ )
176
+ readme_body = generate_readme_body(
177
+ algorithm,
178
+ models_or_modelpool=modelpool,
179
+ models=[modelpool.get_model_path(m) for m in modelpool.model_names],
180
+ )
181
+ complete_readme = text + "\n" + readme_body
182
+ return complete_readme
@@ -1,7 +1,7 @@
1
1
  import logging
2
2
  from collections import OrderedDict
3
3
  from copy import deepcopy
4
- from typing import Optional
4
+ from typing import Any, Dict, Optional, Tuple
5
5
 
6
6
  import torch.nn as nn
7
7
  from torch.func import functional_call, jvp
@@ -9,7 +9,7 @@ from torch.func import functional_call, jvp
9
9
  log = logging.getLogger(__name__)
10
10
 
11
11
 
12
- def dict_params_to_tuple(dict_params: dict):
12
+ def dict_params_to_tuple(dict_params: dict) -> Tuple:
13
13
  return tuple(v for k, v in dict_params.items())
14
14
 
15
15
 
@@ -33,7 +33,7 @@ class LinearizedModelWraper(nn.Module):
33
33
  for p in self.params0_values:
34
34
  p.requires_grad_(False)
35
35
 
36
- def tuple_params_to_dict(self, tuple_params):
36
+ def tuple_params_to_dict(self, tuple_params) -> Dict[str, Any]:
37
37
  """
38
38
  Converts a tuple of parameters to a dictionary with keys corresponding to the parameter names.
39
39
 
@@ -50,7 +50,7 @@ class LinearizedModelWraper(nn.Module):
50
50
  state_dict[k] = p
51
51
  return state_dict
52
52
 
53
- def forward(self, *args, **kwargs):
53
+ def forward(self, *args: Any, **kwargs: Any) -> Any:
54
54
  """
55
55
  Computes the linearized model output using a first-order Taylor decomposition.
56
56
 
@@ -70,7 +70,7 @@ def load_lora_vision_model_hf(
70
70
  peft_name: str,
71
71
  merge_and_unload: bool = False,
72
72
  return_vison_model=True,
73
- ):
73
+ ) -> PeftModel:
74
74
  """
75
75
  Load a LoRA (Low-Rank Adaptation) vision model from Hugging Face.
76
76
 
@@ -4,12 +4,12 @@ This is a direct copy of the DeepSeek-V2-Lite model from HuggingFace https://hug
4
4
 
5
5
  from .configuration_deepseek import DeepseekV2Config
6
6
  from .modeling_deepseek import (
7
+ DeepseekV2DecoderLayer,
7
8
  DeepseekV2ForCausalLM,
8
9
  DeepseekV2ForSequenceClassification,
9
10
  DeepseekV2MLP,
10
11
  DeepseekV2Model,
11
12
  DeepseekV2MoE,
12
- DeepseekV2DecoderLayer,
13
13
  )
14
14
  from .modeling_deepseek import MoEGate as DeepseekV2MoEGate
15
15
  from .tokenization_deepseek_fast import DeepseekTokenizerFast
@@ -17,17 +17,18 @@
17
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
18
  # See the License for the specific language governing permissions and
19
19
  # limitations under the License.
20
- """ PyTorch DeepSeek model."""
20
+ """PyTorch DeepSeek model."""
21
21
  import math
22
22
  import warnings
23
23
  from typing import List, Optional, Tuple, Union
24
24
 
25
+ import numpy as np
25
26
  import torch
27
+ import torch.distributed as dist
26
28
  import torch.nn.functional as F
27
29
  import torch.utils.checkpoint
28
30
  from torch import nn
29
31
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
-
31
32
  from transformers.activations import ACT2FN
32
33
  from transformers.cache_utils import Cache, DynamicCache
33
34
  from transformers.modeling_attn_mask_utils import (
@@ -54,9 +55,8 @@ from transformers.utils import (
54
55
  replace_return_docstrings,
55
56
  )
56
57
  from transformers.utils.import_utils import is_torch_fx_available
58
+
57
59
  from .configuration_deepseek import DeepseekV2Config
58
- import torch.distributed as dist
59
- import numpy as np
60
60
 
61
61
  if is_flash_attn_2_available():
62
62
  from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -1,6 +1,5 @@
1
1
  from typing import List, Optional, Union
2
2
 
3
-
4
3
  from transformers.models.llama import LlamaTokenizerFast
5
4
 
6
5
 
@@ -0,0 +1,9 @@
1
+ from . import register
2
+ from .configuration_smile_gemma2 import SmileGemma2Config
3
+ from .modeling_smile_gemma2 import (
4
+ SmileGemma2ForCausalLM,
5
+ SmileGemma2ForSequenceClassification,
6
+ SmileGemma2ForTokenClassification,
7
+ SmileGemma2Model,
8
+ SmileGemma2PreTrainedModel,
9
+ )
@@ -0,0 +1,20 @@
1
+ from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
2
+
3
+
4
+ class SmileGemma2Config(Gemma2Config):
5
+ model_type = "smile_gemma2"
6
+
7
+ def __init__(
8
+ self,
9
+ num_experts_per_tok: int = 1,
10
+ rank_of_router: int = None,
11
+ rank_of_expert: int = None,
12
+ num_local_experts: int = None,
13
+ **kwargs,
14
+ ):
15
+ self.num_experts_per_tok = num_experts_per_tok
16
+ self.rank_of_router = rank_of_router
17
+ self.rank_of_expert = rank_of_expert
18
+ self.num_local_experts = num_local_experts
19
+
20
+ super().__init__(**kwargs)