fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. fusion_bench/__init__.py +1 -0
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +5 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/dataset/clip_dataset.py +2 -1
  9. fusion_bench/dataset/gpt2_glue.py +9 -9
  10. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  11. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  12. fusion_bench/dataset/image_dataset.py +1 -1
  13. fusion_bench/dataset/nyuv2.py +2 -2
  14. fusion_bench/method/__init__.py +16 -1
  15. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  16. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  17. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  18. fusion_bench/method/base_algorithm.py +195 -12
  19. fusion_bench/method/bitdelta/__init__.py +4 -0
  20. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  21. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  25. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  26. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  27. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  28. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  29. fusion_bench/method/ensemble.py +12 -12
  30. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  31. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
  32. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  33. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  34. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  35. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  36. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  37. fusion_bench/method/linear/expo.py +2 -1
  38. fusion_bench/method/linear/linear_interpolation.py +6 -4
  39. fusion_bench/method/linear/simple_average_for_llama.py +16 -6
  40. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  41. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  42. fusion_bench/method/model_recombination.py +2 -5
  43. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  44. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  45. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  46. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  47. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  48. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  49. fusion_bench/method/randes/modelsoup.py +1 -3
  50. fusion_bench/method/regmean/clip_regmean.py +2 -2
  51. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  52. fusion_bench/method/regmean/regmean.py +2 -11
  53. fusion_bench/method/regmean_plusplus/__init__.py +3 -0
  54. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
  55. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
  56. fusion_bench/method/simple_average.py +16 -4
  57. fusion_bench/method/slerp/slerp.py +5 -2
  58. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  59. fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
  60. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
  61. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  62. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  63. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  64. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  65. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  66. fusion_bench/method/we_moe/we_moe.py +6 -6
  67. fusion_bench/method/weighted_average/llama.py +4 -16
  68. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  69. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  70. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  71. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  72. fusion_bench/mixins/__init__.py +10 -2
  73. fusion_bench/mixins/clip_classification.py +4 -3
  74. fusion_bench/mixins/hydra_config.py +105 -7
  75. fusion_bench/mixins/lightning_fabric.py +2 -0
  76. fusion_bench/mixins/serialization.py +265 -48
  77. fusion_bench/modelpool/__init__.py +2 -2
  78. fusion_bench/modelpool/base_pool.py +29 -9
  79. fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
  80. fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
  81. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  82. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  83. fusion_bench/models/__init__.py +2 -1
  84. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  85. fusion_bench/models/hf_utils.py +182 -0
  86. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  87. fusion_bench/models/linearized/vision_model.py +1 -1
  88. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  89. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  90. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  91. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  92. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  93. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  94. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  95. fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
  96. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  97. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
  98. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  99. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  100. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  101. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
  102. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  103. fusion_bench/models/parameter_dict.py +1 -1
  104. fusion_bench/models/sparse_we_moe.py +1 -53
  105. fusion_bench/models/utils.py +26 -0
  106. fusion_bench/models/we_moe.py +1 -53
  107. fusion_bench/models/wrappers/ensemble.py +6 -4
  108. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  109. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  110. fusion_bench/programs/base_program.py +81 -2
  111. fusion_bench/programs/fabric_fusion_program.py +24 -8
  112. fusion_bench/scripts/cli.py +6 -6
  113. fusion_bench/taskpool/base_pool.py +4 -3
  114. fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
  115. fusion_bench/taskpool/dummy.py +1 -1
  116. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  117. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  118. fusion_bench/utils/__init__.py +6 -1
  119. fusion_bench/utils/devices.py +14 -4
  120. fusion_bench/utils/instantiate_utils.py +3 -1
  121. fusion_bench/utils/misc.py +48 -2
  122. fusion_bench/utils/modelscope.py +265 -0
  123. fusion_bench/utils/parameters.py +2 -2
  124. fusion_bench/utils/rich_utils.py +3 -0
  125. fusion_bench/utils/state_dict_arithmetic.py +34 -27
  126. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
  127. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
  128. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  129. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  130. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  131. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  132. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  133. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  134. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  135. fusion_bench_config/hydra/default.yaml +6 -2
  136. fusion_bench_config/llama_full_finetune.yaml +1 -0
  137. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  138. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  139. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  140. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
  141. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
  142. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  143. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  144. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
  145. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
  168. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
  169. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
  170. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
  171. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
  172. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
  173. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  174. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  175. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  176. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  177. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  178. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  179. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  180. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  181. fusion_bench_config/nyuv2_config.yaml +3 -1
  182. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  183. fusion_bench_config/path/default.yaml +28 -0
  184. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  185. fusion_bench_config/method/adamerging.yaml +0 -23
  186. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  187. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  188. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  189. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
  190. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
  191. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
  192. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
  193. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
@@ -0,0 +1,182 @@
1
+ """
2
+ This module contains utilities for working with Hugging Face models.
3
+ """
4
+
5
+ import inspect
6
+ import os
7
+ import shutil
8
+ from typing import Optional, cast
9
+
10
+ from omegaconf import OmegaConf
11
+ from transformers.modeling_utils import PreTrainedModel
12
+
13
+ from fusion_bench import BaseAlgorithm, BaseModelPool
14
+ from fusion_bench.utils.pylogger import getRankZeroLogger
15
+
16
+ log = getRankZeroLogger(__name__)
17
+
18
+ __all__ = [
19
+ "save_pretrained_with_remote_code",
20
+ "generate_readme_head",
21
+ "generate_readme_body",
22
+ "generate_complete_readme",
23
+ ]
24
+
25
+
26
def save_pretrained_with_remote_code(
    model: PreTrainedModel,
    auto_map: dict[str, object],
    save_directory,
    **kwargs,
):
    """
    Saves a model with custom code to a directory.

    This function facilitates saving a Hugging Face `PreTrainedModel` along with its
    associated custom code. It inspects the objects provided in the `auto_map`,
    copies their source files to the `save_directory`, and generates an `__init__.py`
    to make them importable. It also updates the model's configuration with an
    `auto_map` attribute, which allows `AutoModel.from_pretrained` to correctly
    instantiate the custom model classes when `trust_remote_code=True`.

    Each `auto_map` entry is written as ``"<file stem>.<class name>"``, where the
    file stem is the name (without extension) of the source file that gets copied
    into `save_directory`. Deriving it from the file rather than from the defining
    module's ``__name__`` keeps the entry consistent with the copied file, e.g.
    when the class is defined in a script executed as ``__main__``.

    Args:
        model (PreTrainedModel): The model instance to be saved.
        auto_map (dict[str, object]): A dictionary mapping auto class names
            (e.g., "AutoModelForCausalLM") to the corresponding custom class objects.
        save_directory (str or os.PathLike): The directory where the model and
            custom code files will be saved.
        **kwargs: Additional keyword arguments to be passed to the
            `model.save_pretrained` method.

    Raises:
        TypeError: Propagated from `inspect.getfile` when an object in `auto_map`
            has no source file (e.g. a built-in class).

    Example:
        ```python
        # Assuming `model` is an instance of `SmileQwen2ForCausalLM`
        # and `SmileQwen2Config`, `SmileQwen2Model`, `SmileQwen2ForCausalLM`
        # are custom classes defined in your project.

        save_pretrained_with_remote_code(
            model,
            auto_map={
                "AutoConfig": SmileQwen2Config,
                "AutoModel": SmileQwen2Model,
                "AutoModelForCausalLM": SmileQwen2ForCausalLM,
            },
            save_directory="./my-custom-model",
        )

        # The model can then be loaded with `trust_remote_code=True`:
        # from transformers import AutoModelForCausalLM
        # loaded_model = AutoModelForCausalLM.from_pretrained(
        #     "./my-custom-model", trust_remote_code=True
        # )
        ```
    """
    # For every auto class, locate its source file and the module name that
    # file will have once copied into `save_directory`.
    auto_map_files: dict[str, str] = {}
    auto_map_modules: dict[str, str] = {}
    for key, obj in auto_map.items():
        file_path = inspect.getfile(obj)
        auto_map_files[key] = file_path
        auto_map_modules[key] = os.path.splitext(os.path.basename(file_path))[0]

    model.config.auto_map = {
        key: f"{auto_map_modules[key]}.{obj.__name__}" for key, obj in auto_map.items()
    }

    # save model to `save_directory`
    model.save_pretrained(save_directory=save_directory, **kwargs)

    # Copy each distinct source file once; several auto classes often share a file.
    for file_path in dict.fromkeys(auto_map_files.values()):
        shutil.copy(
            src=file_path,
            dst=os.path.join(save_directory, os.path.basename(file_path)),
        )

    # construct `__init__.py` re-exporting every custom class
    init_file = os.path.join(save_directory, "__init__.py")
    with open(init_file, "w", encoding="utf-8") as f:
        for key, obj in auto_map.items():
            f.write(f"from .{auto_map_modules[key]} import {obj.__name__}\n")
100
+
101
+
102
def generate_readme_head(
    models: list[str] | BaseModelPool,
    library_name: str = "transformers",
    tags: Optional[list[str]] = None,
):
    """
    Build the YAML front-matter header of a model card (README).

    Args:
        models: Items listed under ``base_model``, one ``- <item>`` line each.
            NOTE(review): a ``BaseModelPool`` is iterated directly; confirm what
            its iteration yields before relying on that form.
        library_name: Value written as ``library_name``; the line is omitted
            when this is empty or otherwise falsy.
        tags: Tags listed under ``tags``. Defaults to ``["fusion-bench", "merge"]``
            when not given (``None`` is used as the sentinel to avoid a shared
            mutable default).

    Returns:
        The front matter delimited by ``---`` lines, ending with a newline.
    """
    if tags is None:
        tags = ["fusion-bench", "merge"]

    lines = ["---", "base_model:"]
    lines.extend(f"- {model_name}" for model_name in models)
    if library_name:
        lines.append(f"library_name: {library_name}")
    lines.append("tags:")
    lines.extend(f"- {tag}" for tag in tags)
    lines.append("---")
    return "\n".join(lines) + "\n"
117
+
118
+
119
def generate_readme_body(
    algorithm: BaseAlgorithm,
    models_or_modelpool: Optional[list[str] | BaseModelPool] = None,
    models: Optional[list[str]] = None,
):
    """
    Build the Markdown body of a model card for a merged model.

    Sections, in order:
      1. A "# Merge" title and a short description.
      2. "## Models Merged" listing the model names (only when a list is known).
      3. "## Configuration" with ``algorithm.config`` rendered as YAML.
      4. When ``models_or_modelpool`` is a ``BaseModelPool``, a second YAML block
         with the pool's ``config``.

    Args:
        algorithm: The fusion algorithm whose ``config`` is rendered via
            ``OmegaConf.to_yaml(..., resolve=True, sort_keys=True)``.
        models_or_modelpool: Either a model pool (its config is appended) or a
            list of model names (used for the "Models Merged" section when
            ``models`` is not given).
        models: Model names for the "Models Merged" section; takes precedence
            over a list passed as ``models_or_modelpool``.

    Returns:
        The Markdown text. Rendering is best effort: if the algorithm config
        cannot be rendered, a warning is logged and the text built so far is
        returned (no configuration sections); if only the pool config fails,
        a warning is logged and that block is omitted.
    """
    # A plain list in `models_or_modelpool` is a list of model names.
    if models is None and isinstance(models_or_modelpool, (list, tuple)):
        models = list(models_or_modelpool)

    text = (
        "# Merge\n"
        "\n"
        "This is a merge of pre-trained language models created using "
        "[fusion-bench](https://github.com/tanganke/fusion_bench).\n"
        "\n"
    )

    if models is not None:
        text += (
            "\n"
            "## Models Merged\n"
            "\n"
            "The following models were included in the merge:\n"
            "\n"
        )
        text += "".join(f"- {model_name}\n" for model_name in models)
        text += "\n"

    try:
        algorithm_yaml = OmegaConf.to_yaml(
            algorithm.config, resolve=True, sort_keys=True
        )
    except Exception as exc:
        # Best effort: a config that cannot be rendered must not block writing the card.
        log.warning(
            "Could not render the algorithm config as YAML; "
            "omitting the configuration section: %s",
            exc,
        )
        return text

    text += (
        "## Configuration\n"
        "\n"
        "The following YAML configuration was used to produce this model:\n"
        "\n"
        f"```yaml\n{algorithm_yaml}\n```\n"
    )

    if isinstance(models_or_modelpool, BaseModelPool):
        try:
            modelpool_yaml = OmegaConf.to_yaml(
                models_or_modelpool.config, resolve=True, sort_keys=True
            )
        except Exception as exc:
            log.warning(
                "Could not render the model pool config as YAML; skipping it: %s",
                exc,
            )
        else:
            text += f"\n```yaml\n{modelpool_yaml}\n```\n"
    return text
167
+
168
+
169
def generate_complete_readme(
    algorithm: BaseAlgorithm, modelpool: BaseModelPool, models: list[str]
):
    """
    Assemble a complete model card: YAML front matter, a blank line, then the body.

    The model list is built once from ``modelpool.get_model_path(name)`` for every
    name in ``modelpool.model_names``. It is used both as the ``base_model`` list
    in the header and as the "Models Merged" section of the body. The pool itself
    is passed to the body so its config is embedded.

    Args:
        algorithm: The fusion algorithm whose config is embedded in the card.
        modelpool: The model pool providing model paths and its own config.
        models: NOTE(review): currently unused — the model list is always
            derived from ``modelpool``. Kept for API compatibility.

    Returns:
        The full README text.
    """
    # Resolve paths once; previously they were computed separately for head and body.
    model_paths = [modelpool.get_model_path(name) for name in modelpool.model_names]

    head = generate_readme_head(model_paths)
    body = generate_readme_body(
        algorithm,
        models_or_modelpool=modelpool,
        models=model_paths,
    )
    return head + "\n" + body
@@ -1,7 +1,7 @@
1
1
  import logging
2
2
  from collections import OrderedDict
3
3
  from copy import deepcopy
4
- from typing import Optional
4
+ from typing import Any, Dict, Optional, Tuple
5
5
 
6
6
  import torch.nn as nn
7
7
  from torch.func import functional_call, jvp
@@ -9,7 +9,7 @@ from torch.func import functional_call, jvp
9
9
  log = logging.getLogger(__name__)
10
10
 
11
11
 
12
def dict_params_to_tuple(dict_params: dict) -> Tuple:
    """Return the values of ``dict_params`` as a tuple, in the mapping's iteration order; keys are dropped."""
    return tuple(dict_params.values())
14
14
 
15
15
 
@@ -33,7 +33,7 @@ class LinearizedModelWraper(nn.Module):
33
33
  for p in self.params0_values:
34
34
  p.requires_grad_(False)
35
35
 
36
- def tuple_params_to_dict(self, tuple_params):
36
+ def tuple_params_to_dict(self, tuple_params) -> Dict[str, Any]:
37
37
  """
38
38
  Converts a tuple of parameters to a dictionary with keys corresponding to the parameter names.
39
39
 
@@ -50,7 +50,7 @@ class LinearizedModelWraper(nn.Module):
50
50
  state_dict[k] = p
51
51
  return state_dict
52
52
 
53
- def forward(self, *args, **kwargs):
53
+ def forward(self, *args: Any, **kwargs: Any) -> Any:
54
54
  """
55
55
  Computes the linearized model output using a first-order Taylor decomposition.
56
56
 
@@ -70,7 +70,7 @@ def load_lora_vision_model_hf(
70
70
  peft_name: str,
71
71
  merge_and_unload: bool = False,
72
72
  return_vison_model=True,
73
- ):
73
+ ) -> PeftModel:
74
74
  """
75
75
  Load a LoRA (Low-Rank Adaptation) vision model from Hugging Face.
76
76
 
@@ -4,12 +4,12 @@ This is a direct copy of the DeepSeek-V2-Lite model from HuggingFace https://hug
4
4
 
5
5
  from .configuration_deepseek import DeepseekV2Config
6
6
  from .modeling_deepseek import (
7
+ DeepseekV2DecoderLayer,
7
8
  DeepseekV2ForCausalLM,
8
9
  DeepseekV2ForSequenceClassification,
9
10
  DeepseekV2MLP,
10
11
  DeepseekV2Model,
11
12
  DeepseekV2MoE,
12
- DeepseekV2DecoderLayer,
13
13
  )
14
14
  from .modeling_deepseek import MoEGate as DeepseekV2MoEGate
15
15
  from .tokenization_deepseek_fast import DeepseekTokenizerFast
@@ -17,17 +17,18 @@
17
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
18
  # See the License for the specific language governing permissions and
19
19
  # limitations under the License.
20
- """ PyTorch DeepSeek model."""
20
+ """PyTorch DeepSeek model."""
21
21
  import math
22
22
  import warnings
23
23
  from typing import List, Optional, Tuple, Union
24
24
 
25
+ import numpy as np
25
26
  import torch
27
+ import torch.distributed as dist
26
28
  import torch.nn.functional as F
27
29
  import torch.utils.checkpoint
28
30
  from torch import nn
29
31
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
-
31
32
  from transformers.activations import ACT2FN
32
33
  from transformers.cache_utils import Cache, DynamicCache
33
34
  from transformers.modeling_attn_mask_utils import (
@@ -54,9 +55,8 @@ from transformers.utils import (
54
55
  replace_return_docstrings,
55
56
  )
56
57
  from transformers.utils.import_utils import is_torch_fx_available
58
+
57
59
  from .configuration_deepseek import DeepseekV2Config
58
- import torch.distributed as dist
59
- import numpy as np
60
60
 
61
61
  if is_flash_attn_2_available():
62
62
  from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -1,6 +1,5 @@
1
1
  from typing import List, Optional, Union
2
2
 
3
-
4
3
  from transformers.models.llama import LlamaTokenizerFast
5
4
 
6
5
 
@@ -0,0 +1,9 @@
1
+ from . import register
2
+ from .configuration_smile_gemma2 import SmileGemma2Config
3
+ from .modeling_smile_gemma2 import (
4
+ SmileGemma2ForCausalLM,
5
+ SmileGemma2ForSequenceClassification,
6
+ SmileGemma2ForTokenClassification,
7
+ SmileGemma2Model,
8
+ SmileGemma2PreTrainedModel,
9
+ )
@@ -0,0 +1,20 @@
1
+ from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
2
+
3
+
4
class SmileGemma2Config(Gemma2Config):
    """
    Configuration class registered under ``model_type = "smile_gemma2"``.

    Extends ``Gemma2Config`` with four extra attributes. They are stored on the
    instance before ``Gemma2Config.__init__`` runs; every other keyword argument
    is forwarded unchanged to the parent initializer.

    Args:
        num_experts_per_tok: Stored as ``self.num_experts_per_tok`` (default 1).
            Presumably the number of experts routed per token — confirm in the
            modeling code.
        rank_of_router: Stored as ``self.rank_of_router`` (default ``None``).
        rank_of_expert: Stored as ``self.rank_of_expert`` (default ``None``).
        num_local_experts: Stored as ``self.num_local_experts`` (default ``None``).
        **kwargs: Passed through to ``Gemma2Config.__init__``.
    """

    model_type = "smile_gemma2"

    def __init__(
        self,
        num_experts_per_tok: int = 1,
        rank_of_router: int = None,
        rank_of_expert: int = None,
        num_local_experts: int = None,
        **kwargs,
    ):
        # SMILE-specific settings, assigned in this order before the parent init.
        smile_settings = {
            "num_experts_per_tok": num_experts_per_tok,
            "rank_of_router": rank_of_router,
            "rank_of_expert": rank_of_expert,
            "num_local_experts": num_local_experts,
        }
        for attr_name, attr_value in smile_settings.items():
            setattr(self, attr_name, attr_value)

        super().__init__(**kwargs)