fusion-bench 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff compares the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (209)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/fer2013.py +1 -0
  6. fusion_bench/dataset/gsm8k.py +2 -2
  7. fusion_bench/method/__init__.py +10 -0
  8. fusion_bench/method/ada_svd/clip_vision.py +4 -1
  9. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  10. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  11. fusion_bench/method/gossip/__init__.py +3 -0
  12. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  13. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  14. fusion_bench/method/gossip/entropy_loss.py +25 -0
  15. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  16. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  17. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  18. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  19. fusion_bench/method/gossip/utils.py +74 -0
  20. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  21. fusion_bench/method/opcm/opcm.py +16 -7
  22. fusion_bench/method/pwe_moe/module.py +1 -1
  23. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  24. fusion_bench/method/regmean/regmean.py +25 -17
  25. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  26. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
  27. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
  28. fusion_bench/method/smile_upscaling/smile_upscaling.py +19 -346
  29. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  30. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  31. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  32. fusion_bench/method/we_moe/we_moe.py +14 -15
  33. fusion_bench/mixins/__init__.py +6 -3
  34. fusion_bench/mixins/hydra_config.py +49 -0
  35. fusion_bench/mixins/openclip_classification.py +11 -0
  36. fusion_bench/mixins/simple_profiler.py +4 -2
  37. fusion_bench/modelpool/__init__.py +3 -1
  38. fusion_bench/modelpool/base_pool.py +2 -2
  39. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  40. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  41. fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
  42. fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
  43. fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
  44. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
  45. fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
  46. fusion_bench/models/open_clip/__init__.py +6 -0
  47. fusion_bench/models/open_clip/modeling.py +176 -0
  48. fusion_bench/models/open_clip/utils.py +311 -0
  49. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  50. fusion_bench/models/parameter_dict.py +54 -13
  51. fusion_bench/models/rankone_moe.py +2 -88
  52. fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
  53. fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
  54. fusion_bench/models/smile_moe/utils/__init__.py +24 -0
  55. fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
  56. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  57. fusion_bench/taskpool/__init__.py +7 -3
  58. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  59. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  60. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  61. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  62. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  63. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  64. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  65. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  66. fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
  67. fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
  68. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  69. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  70. fusion_bench/utils/data.py +12 -0
  71. fusion_bench/utils/devices.py +14 -0
  72. fusion_bench/utils/instantiate.py +12 -0
  73. fusion_bench/utils/misc.py +9 -2
  74. fusion_bench/utils/packages.py +14 -0
  75. fusion_bench/utils/parameters.py +1 -1
  76. fusion_bench/utils/tensorboard.py +1 -1
  77. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
  78. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +209 -157
  79. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
  80. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  81. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  82. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  83. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  84. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  85. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  86. fusion_bench_config/fabric/auto.yaml +0 -1
  87. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  88. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  89. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  90. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  91. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  92. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  93. fusion_bench_config/llama_full_finetune.yaml +0 -2
  94. fusion_bench_config/llama_model_fusion.yaml +0 -2
  95. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  96. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  97. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  98. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  99. fusion_bench_config/method/adamerging.yaml +2 -2
  100. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  101. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  102. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  103. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  104. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  105. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  106. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  107. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  108. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  109. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  110. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  111. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  112. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  113. fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
  114. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  115. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  116. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  117. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  118. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  119. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  120. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  121. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  122. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  123. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  124. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  125. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  126. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  127. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  128. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  129. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  130. fusion_bench_config/method/model_recombination.yaml +0 -1
  131. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  132. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  133. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  134. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  135. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  136. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  137. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  138. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  139. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  140. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  141. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
  142. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
  143. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  144. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  145. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  146. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  147. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  148. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  149. fusion_bench_config/method/ties_merging.yaml +1 -1
  150. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  151. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  152. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  153. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  154. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  155. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  156. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  157. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  158. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  159. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  160. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  161. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  162. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  168. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  169. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  170. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  171. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  172. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  173. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  174. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  175. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  176. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
  177. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  178. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  179. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  180. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  181. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  182. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  183. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  184. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  185. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  186. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  187. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  188. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  189. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
  190. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
  191. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  192. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  193. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  194. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  195. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  196. fusion_bench_config/nyuv2_config.yaml +0 -2
  197. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  198. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  199. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  200. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  201. fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
  202. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  203. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  204. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  205. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  206. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  207. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
  208. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
  209. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0
@@ -8,64 +8,13 @@ from torch import Tensor, nn
 from torch.func import functional_call
 from torch.nn import functional as F
 
+from fusion_bench.models.smile_moe.utils import _is_all_zeros, svd
+from fusion_bench.models.utils import del_attr, get_attr, set_attr
 from fusion_bench.utils.type import StateDictType
 
 log = logging.getLogger(__name__)
 
 
-def join_list(list_of_list: List[List]):
-    ans = []
-    for l in list_of_list:
-        ans.extend(l)
-    return ans
-
-
-def del_attr(obj, names: List[str]):
-    """
-    Deletes an attribute from an object recursively.
-
-    Args:
-        obj (object): Object to delete attribute from.
-        names (list): List of attribute names to delete recursively.
-    """
-    if len(names) == 1:
-        delattr(obj, names[0])
-    else:
-        del_attr(getattr(obj, names[0]), names[1:])
-
-
-def set_attr(obj, names: List[str], val):
-    """
-    Sets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to set attribute of.
-        names (list): List of attribute names to set recursively.
-        val (object): Value to set the attribute to.
-    """
-    if len(names) == 1:
-        setattr(obj, names[0], val)
-    else:
-        set_attr(getattr(obj, names[0]), names[1:], val)
-
-
-def get_attr(obj, names: List[str]):
-    """
-    Gets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to get attribute of.
-        names (list): List of attribute names to get recursively.
-
-    Returns:
-        object: The attribute of the object.
-    """
-    if len(names) == 1:
-        return getattr(obj, names[0])
-    else:
-        return get_attr(getattr(obj, names[0]), names[1:])
-
-
 class Depth_0_Gate(nn.Module):
     def __init__(self, num_experts: int):
         super().__init__()
@@ -132,41 +81,6 @@ class ExpertNotTrainedError(Exception):
     pass
 
 
-def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
-    """
-    Check if a tensor or a list of tensors are all zeros.
-    """
-    if isinstance(tensor, Tensor):
-        return torch.allclose(tensor, torch.zeros_like(tensor))
-    else:
-        return all(_is_all_zeros(t) for t in tensor)
-
-
-def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
-    """
-    Perform Singular Value Decomposition (SVD) on a tensor.
-    """
-    u, s, vh = torch.linalg.svd(
-        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
-    )
-    v = vh.T
-    return u, s, v
-
-
-def svd(
-    w: Tensor, full_matrices=True, accelerator=None
-) -> Tuple[Tensor, Tensor, Tensor]:
-    """
-    Perform SVD on a tensor, optionally using a specified accelerator.
-    """
-    if accelerator is None:
-        return _svd(w, full_matrices=full_matrices)
-    original_device = w.device
-    w = w.to(accelerator)
-    u, s, v = _svd(w)
-    return u.to(original_device), s.to(original_device), v.to(original_device)
-
-
 def fun_joint_svd(
     w_list: List[Tensor], accelerator=None
 ) -> Tuple[Tensor, Tensor, Tensor]:
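The deletions above are a consolidation, not a removal: per the new imports at the top of the file, `_is_all_zeros` and `svd` now live in `fusion_bench.models.smile_moe.utils` (files 54 and 55 in the list), while `del_attr`/`get_attr`/`set_attr` move to `fusion_bench.models.utils`. A minimal usage sketch, assuming the relocated helpers keep the signatures shown in the removed code:

```python
import torch
from torch import nn

# New homes of the helpers, matching the imports added in this diff.
from fusion_bench.models.smile_moe.utils import _is_all_zeros, svd
from fusion_bench.models.utils import get_attr, set_attr

w = torch.randn(64, 32)
u, s, v = svd(w, full_matrices=False)  # returns v, not vh: w ≈ u @ diag(s) @ v.T
assert torch.allclose(u @ torch.diag(s) @ v.T, w, atol=1e-4)

assert _is_all_zeros(torch.zeros(3, 3))           # a single tensor ...
assert not _is_all_zeros([torch.zeros(2, 2), w])  # ... or a list of tensors

# Recursive attribute access over a dotted path, split into a name list.
model = nn.Sequential(nn.Linear(32, 32))
weight = get_attr(model, "0.weight".split("."))
set_attr(model, "0.weight".split("."), nn.Parameter(torch.zeros_like(weight)))
```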
@@ -0,0 +1,373 @@
+from typing import List, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
+
+from .utils import _is_all_zeros
+
+
+class ExpertNotTrainedError(Exception):
+    pass
+
+
+def _svd(w: Tensor, full_matrices=False) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Perform Singular Value Decomposition (SVD) on a tensor.
+
+    Args:
+        w (Tensor): The input tensor.
+        full_matrices (bool, optional): Whether to compute the full-sized U and V matrices. Defaults to False.
+
+    Returns:
+        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
+    """
+    dtype = w.dtype
+    if w.dtype != torch.float32 and w.dtype != torch.float64:
+        w = w.float()
+
+    u, s, vh = torch.linalg.svd(
+        w,
+        full_matrices=full_matrices,
+        # driver="gesvd" if w.is_cuda else None
+    )
+    v = vh.T
+
+    u = u.to(dtype=dtype)
+    s = s.to(dtype=dtype)
+    v = v.to(dtype=dtype)
+    return u, s, v
+
+
+def svd(
+    w: Tensor, full_matrices=True, accelerator=None
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Perform SVD on a tensor with optional acceleration.
+    This differs from `.utils.svd` in that it also handles tensors whose precision is neither float32 nor float64.
+
+    Args:
+        w (Tensor): The input tensor.
+        full_matrices (bool, optional): Whether to compute the full-sized U and V matrices. Defaults to True.
+        accelerator (optional): The device to perform the computation on. Defaults to None.
+
+    Returns:
+        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
+    """
+    if accelerator is None:
+        return _svd(w, full_matrices=full_matrices)
+    original_device = w.device
+    w = w.to(accelerator)
+    u, s, v = _svd(w, full_matrices=full_matrices)
+    return u.to(original_device), s.to(original_device), v.to(original_device)
+
+
+class SmileMoEConfig:
+    """
+    Example PretrainedConfig for SmileMoE.
+
+    Args:
+        num_experts_per_tok: Number of experts per token.
+        rank_of_router: Rank of the router.
+        rank_of_expert: Rank of the expert.
+        num_local_experts: Number of local experts.
+    """
+
+    num_experts_per_tok: int
+    rank_of_router: int
+    rank_of_expert: int
+    num_local_experts: int
+
+
+class SmileGate(nn.Module):
+    __constants__ = ["in_features", "num_experts", "k"]
+    in_features: int
+    num_experts: int
+    k: int
+    weight: nn.Parameter
+
+    def __init__(
+        self,
+        in_features: int,
+        num_experts: int,
+        k: int,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.in_features = in_features
+        self.num_experts = num_experts
+        self.k = k
+
+        self.weight = nn.Parameter(
+            torch.empty(num_experts * k, in_features, **factory_kwargs)
+        )
+
+    def forward(self, x: Tensor):
+        batch_size = x.size(0)
+        if self.num_experts == 1:
+            return torch.ones(batch_size, 1, device=x.device, dtype=x.dtype)
+
+        routing_weights = F.linear(x, self.weight).view(
+            batch_size, self.num_experts, self.k
+        )
+        routing_weights = routing_weights.norm(p=2, dim=2)
+        return routing_weights
+
+
+class SmileLinearExpert(nn.Module):
+    __constants__ = ["in_features", "out_features", "k"]
+    in_features: int
+    out_features: int
+    k: int
+
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        k: int,
+        bias: bool,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.k = k
+        if k > 0:
+            # check that k does not exceed in_features or out_features
+            if k > in_features:
+                raise ValueError(
+                    f"k ({k}) must not be greater than in_features ({in_features})"
+                )
+            if k > out_features:
+                raise ValueError(
+                    f"k ({k}) must not be greater than out_features ({out_features})"
+                )
+
+        self.u = nn.Parameter(torch.empty(out_features, k, **factory_kwargs))
+        self.svh = nn.Parameter(torch.empty(k, in_features, **factory_kwargs))
+
+        if bias:
+            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
+        else:
+            self.register_parameter("bias", None)
+
+    def forward(self, x):
+        x = F.linear(x, self.svh)
+        x = F.linear(x, self.u, self.bias)
+        return x
+
+
+class SmileLinear(nn.Module):
+    __constants__ = [
+        "in_features",
+        "out_features",
+        "num_local_experts",
+        "num_experts_per_tok",
+        "rank_of_expert",
+        "rank_of_router",
+    ]
+
+    in_features: int
+    out_features: int
+    num_local_experts: int
+    num_experts_per_tok: int
+    rank_of_expert: int
+    rank_of_router: int
+
+    @torch.no_grad()
+    def __init__(
+        self,
+        config: SmileMoEConfig,
+        in_features,
+        out_features,
+        bias: bool,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.num_local_experts = config.num_local_experts
+        self.num_experts_per_tok = config.num_experts_per_tok
+        self.rank_of_expert = config.rank_of_expert
+        self.rank_of_router = config.rank_of_router
+        self.in_features = in_features
+        self.out_features = out_features
+
+        # construct the gate network
+        self.gate = SmileGate(
+            in_features=in_features,
+            num_experts=self.num_local_experts,
+            k=self.rank_of_router,
+            **factory_kwargs,
+        )
+
+        # the shared linear
+        self.shared_linear = nn.Linear(
+            in_features, out_features, bias=bias, **factory_kwargs
+        )
+
+        # construct experts
+        if self.rank_of_expert > 0:
+            self.experts = nn.ModuleList(
+                [
+                    SmileLinearExpert(
+                        in_features=in_features,
+                        out_features=out_features,
+                        bias=bias,
+                        k=self.rank_of_expert,
+                        **factory_kwargs,
+                    )
+                    for _ in range(self.num_local_experts)
+                ]
+            )
+        else:
+            self.experts = nn.ModuleList(
+                [
+                    nn.Linear(in_features, out_features, bias=bias, **factory_kwargs)
+                    for _ in range(self.num_local_experts)
+                ]
+            )
+
+    def forward(self, hidden_states: Tensor):
+        pretrained_out = self.shared_linear(hidden_states)
+
+        input_shape = hidden_states.size()
+        hidden_states = hidden_states.view(-1, self.in_features)
+
+        router_logits = self.gate(hidden_states)
+        routing_weights = F.softmax(router_logits, dim=1)
+        # sample the experts according to the routing weights
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.num_experts_per_tok, dim=-1
+        )
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+        final_hidden_states = torch.zeros(
+            (hidden_states.size(0), self.out_features),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One-hot encode the selected experts to create an expert mask;
+        # this is used to easily index which expert is going to be solicited.
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_local_experts
+        ).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_local_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x].reshape(-1, self.in_features)
+            if current_state.numel() == 0:
+                continue
+            current_hidden_states = (
+                expert_layer(current_state) * routing_weights[top_x, idx, None]
+            )
+
+            # However, `index_add_` only supports torch tensors for indexing, so we
+            # use the `top_x` tensor here.
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+        final_hidden_states = final_hidden_states.reshape(
+            *input_shape[:-1], self.out_features
+        )
+        final_hidden_states = pretrained_out + final_hidden_states
+        return final_hidden_states
+
+    @property
+    def weight(self):
+        """
+        Mimic a plain linear layer. Because in some cases users may query the
+        device (or dtype) of the layer's parameters via `linear_layer.weight.device`.
+        """
+        return self.shared_linear.weight
+
+    @property
+    def bias(self):
+        return self.shared_linear.bias
+
+    def __repr__(self):
+        return (
+            f"SmileLinear("
+            f"in_features={self.shared_linear.in_features}, "
+            f"out_features={self.shared_linear.out_features}, "
+            f"num_local_experts={self.num_local_experts}, "
+            f"num_experts_per_tok={self.num_experts_per_tok}, "
+            f"rank_of_router={self.rank_of_router}, "
+            f"rank_of_expert={self.rank_of_expert}"
+            f")"
+        )
+
+
+@torch.no_grad()
+def upscale_to_smile_linear(
+    base: nn.Linear, experts: List[nn.Linear], target: SmileLinear, accelerator=None
+):
+    """
+    Upscale a base linear layer to a SmileLinear layer using expert models.
+
+    Args:
+        base (nn.Linear): The base linear layer.
+        experts (List[nn.Linear]): A list of expert linear layers.
+        target (SmileLinear): The target SmileLinear layer.
+        accelerator (optional): The device to perform the computation on. Defaults to None.
+
+    Returns:
+        SmileLinear: The upscaled SmileLinear layer.
+    """
+    w = base.weight
+    w_ft_list = [e.weight for e in experts]
+    dw_list = [w_ft - w for w_ft in w_ft_list]
+
+    if _is_all_zeros(dw_list):
+        raise ExpertNotTrainedError("Expert models are not trained")
+
+    rank_of_router = target.rank_of_router
+    rank_of_expert = target.rank_of_expert
+    num_local_experts = target.num_local_experts
+    svd_list = [svd(dw, accelerator=accelerator) for dw in dw_list]
+
+    # gate
+    gate_weight = []
+    for u, s, v in svd_list:
+        gate_weight.append(v[:, :rank_of_router].T)
+    gate_weight = (
+        torch.stack(gate_weight, dim=0)
+        .reshape(num_local_experts * rank_of_router, -1)
+        .contiguous()
+    )
+
+    target.gate.load_state_dict({"weight": gate_weight})
+
+    # shared linear
+    target.shared_linear.load_state_dict(base.state_dict())
+
+    # experts
+    if rank_of_expert > 0:
+        for expert_idx, target_expert in enumerate(target.experts):
+            u, s, v = svd_list[expert_idx]
+            u = u[:, :rank_of_expert]
+            s = s[:rank_of_expert]
+            v = v[:, :rank_of_expert]
+            state_dict = {"u": u, "svh": (s * v).T}
+            if experts[expert_idx].bias is not None:
+                state_dict["bias"] = experts[expert_idx].bias.data
+            target_expert.load_state_dict(state_dict)
+    else:
+        for expert_idx, target_expert in enumerate(target.experts):
+            target_expert.load_state_dict(
+                state_dict_sub(experts[expert_idx].state_dict(), base.state_dict())
+            )
+
+    return target
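To make the new module concrete, here is a hedged sketch of the intended flow: build a `SmileLinear` from a config carrying the four `SmileMoEConfig` fields, then populate it from a base layer and its fine-tuned experts via `upscale_to_smile_linear`. The `SimpleNamespace` config and the layer sizes are illustrative stand-ins, not fusion_bench defaults:

```python
from types import SimpleNamespace

import torch
from torch import nn

from fusion_bench.models.smile_moe.linear_from_hf_config import (
    SmileLinear,
    upscale_to_smile_linear,
)

# Stand-in for a PretrainedConfig exposing the four SmileMoEConfig fields.
config = SimpleNamespace(
    num_local_experts=2,    # one expert per fine-tuned model
    num_experts_per_tok=1,  # top-1 routing
    rank_of_router=4,       # singular vectors kept per expert for the gate
    rank_of_expert=8,       # low-rank experts; <= 0 keeps full nn.Linear experts
)

base = nn.Linear(32, 64)
experts = [nn.Linear(32, 64) for _ in range(config.num_local_experts)]
for e in experts:  # perturb so the weight deltas are non-zero
    e.weight.data += 0.01 * torch.randn_like(e.weight)

target = SmileLinear(config, in_features=32, out_features=64, bias=True)
target = upscale_to_smile_linear(base, experts, target)
print(target(torch.randn(5, 32)).shape)  # torch.Size([5, 64])
```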
@@ -1,10 +1,12 @@
 import logging
-from typing import Dict, List, Tuple  # noqa: F401
+from typing import Dict, List, Optional, Tuple, Union  # noqa: F401
 
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 
+from .utils import _is_all_zeros, svd
+
 log = logging.getLogger(__name__)
 
 
@@ -12,50 +14,42 @@ class ExpertNotTrainedError(Exception):
     pass
 
 
-def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
-    if isinstance(tensor, Tensor):
-        return torch.allclose(tensor, torch.zeros_like(tensor))
-    else:
-        return all(_is_all_zeros(t) for t in tensor)
-
-
-def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
-    u, s, vh = torch.linalg.svd(
-        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
-    )
-    v = vh.T
-    return u, s, v
-
-
-def svd(
-    w: Tensor, full_matrices=True, accelerator=None
-) -> Tuple[Tensor, Tensor, Tensor]:
-    if accelerator is None:
-        return _svd(w, full_matrices=full_matrices)
-    original_device = w.device
-    w = w.to(accelerator)
-    u, s, v = _svd(w)
-    return u.to(original_device), s.to(original_device), v.to(original_device)
-
-
 class SmileGate(nn.Module):
+    __constants__ = ["in_features", "num_experts", "k"]
+    in_features: int
+    num_experts: int
+    k: int
+    weight: nn.Parameter
+
     def __init__(
         self,
         input_features: int,
         w_diff_list: List[Tensor],
         k: int,
-        svd_list=None,  # cached `svd_list`, pass it to avoid recomputing
+        svd_cache: List[
+            Tuple[Tensor, Tensor, Tensor]
+        ] = None,  # cached SVD results; pass them to avoid recomputing
         upscaling_accelerator=None,
     ):
+        R"""
+        Constructs the gate weights through SVD decomposition of the weight differences.
+
+        Args:
+            input_features: The dimension of input features.
+            w_diff_list: The list of weight matrices to be decomposed.
+            k: The number of singular values to keep.
+            svd_cache: The cached SVD decomposition results. If not provided, the SVD decomposition is computed on the fly.
+            upscaling_accelerator: The accelerator to use for SVD decomposition.
+        """
         super().__init__()
         self.input_features = input_features
         self.num_experts = len(w_diff_list)
         weights = []
         for i, w_diff in enumerate(w_diff_list):
-            if svd_list is None:
+            if svd_cache is None:
                 u, s, v = svd(w_diff, accelerator=upscaling_accelerator)
             else:
-                u, s, v = svd_list[i]
+                u, s, v = svd_cache[i]
             u = u[:, :k]
             s = s[:k]
             v = v[:, :k]
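The routing rule behind both `SmileGate` variants: each expert's gate rows are the top-`k` right-singular vectors of that expert's weight difference, and the logit is the L2 norm of the input's projection onto that subspace. A standalone re-derivation in plain PyTorch (not a fusion_bench API), for a single expert:

```python
import torch

dw = torch.randn(64, 32)  # one expert's weight difference (task vector)
x = torch.randn(32)       # one input token
k = 4                     # gate_k: singular vectors kept for routing

_, _, vh = torch.linalg.svd(dw, full_matrices=False)

# The gate stores v[:, :k].T as its weight rows, so F.linear followed by
# .norm(p=2, dim=2) reduces to this for a single expert:
logit = (vh[:k] @ x).norm(p=2)
```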
@@ -86,8 +80,38 @@ class SmileGate(nn.Module):
 
 
 class SmileCompressedLinear(nn.Module):
-    def __init__(self, model: nn.Linear, k: int, svd_cache=None):
+    """
+    This module is used to compress a linear layer using SVD decomposition.
+    """
+
+    __constants__ = ["in_features", "out_features", "k"]
+    in_features: int
+    out_features: int
+    k: int
+
+    u: nn.Parameter
+    svh: nn.Parameter
+    bias: Optional[nn.Parameter]
+
+    def __init__(
+        self,
+        model: nn.Linear,
+        k: int,
+        svd_cache: Optional[Tuple[Tensor, Tensor, Tensor]] = None,
+    ):
+        """
+        Initialize the SmileCompressedLinear module.
+
+        Args:
+            model (nn.Linear): The linear model to compress.
+            k (int): The number of singular values to keep.
+            svd_cache (Tuple[Tensor, Tensor, Tensor]): Cached SVD results.
+        """
         super().__init__()
+        self.in_features = model.in_features
+        self.out_features = model.out_features
+        self.k = k
+
         if svd_cache is None:
             u, s, v = svd(model.weight)
         else:
@@ -106,12 +130,36 @@ class SmileCompressedLinear(nn.Module):
             self.register_parameter("bias", None)
 
     def forward(self, x):
+        """
+        Forward pass of the SmileCompressedLinear module.
+
+        Args:
+            x (Tensor): The input tensor.
+
+        Returns:
+            Tensor: The output tensor.
+        """
         x = F.linear(x, self.svh)
         x = F.linear(x, self.u, self.bias)
         return x
 
 
 class SmileMoELinear(nn.Module):
+    __constants__ = [
+        "in_features",
+        "out_features",
+        "num_experts",
+        "top_k",
+        "gate_k",
+        "k",
+    ]
+    in_features: int
+    out_features: int
+    num_experts: int
+    top_k: int
+    gate_k: int
+    k: int
+
     @torch.no_grad()
     def __init__(
         self,
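`SmileCompressedLinear` on its own is plain truncated-SVD compression: y ≈ U_k (S_k V_k^T x) + b. A small sketch of the approximation error, assuming the import path of the renamed module (`linear.py` → `linear_from_module.py`, file 53 above):

```python
import torch
from torch import nn

from fusion_bench.models.smile_moe.linear_from_module import SmileCompressedLinear

layer = nn.Linear(128, 128)
compressed = SmileCompressedLinear(layer, k=16)  # keep the top-16 singular triplets

x = torch.randn(4, 128)
rel_err = (compressed(x) - layer(x)).norm() / layer(x).norm()
print(f"relative error at k=16: {rel_err:.3f}")
```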
@@ -124,6 +172,19 @@ class SmileMoELinear(nn.Module):
         upscaling_accelerator=None,
         routing_use_diff=True,
     ):
+        """
+        Initialize the SmileMoELinear module.
+
+        Args:
+            pretrained_model (nn.Linear): The pretrained linear model.
+            finetuned_models (List[nn.Linear]): A list of fine-tuned linear models.
+            gate_k (int): The number of singular values to keep for the gate.
+            k (int): The number of singular values to keep for the experts.
+            top_k (int): The number of top experts to select.
+            full_matrices (bool): Whether to compute the full-sized U and V matrices.
+            upscaling_accelerator (str): The device to perform the computation on.
+            routing_use_diff (bool): Whether to use weight differences for routing.
+        """
         super().__init__()
         self.num_experts = len(finetuned_models)
         self.top_k = top_k
@@ -149,7 +210,7 @@ class SmileMoELinear(nn.Module):
                 input_features=self.in_features,
                 w_diff_list=w_diff_list,
                 k=gate_k,
-                svd_list=svd_cache_list,
+                svd_cache=svd_cache_list,
                 upscaling_accelerator=upscaling_accelerator,
             )
         else:
@@ -157,7 +218,7 @@ class SmileMoELinear(nn.Module):
                 input_features=self.in_features,
                 w_diff_list=[m.weight for m in finetuned_models],
                 k=gate_k,
-                svd_list=None,
+                svd_cache=None,
                 upscaling_accelerator=upscaling_accelerator,
             )
 
@@ -181,6 +242,15 @@ class SmileMoELinear(nn.Module):
         self.pretrained_model = pretrained_model
 
     def forward(self, hidden_states: Tensor):
+        """
+        Forward pass of the SmileMoELinear module.
+
+        Args:
+            hidden_states (Tensor): The input tensor.
+
+        Returns:
+            Tensor: The output tensor.
+        """
         pretrained_out = self.pretrained_model(hidden_states)
 
         input_shape = hidden_states.size()
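End to end, `SmileMoELinear` is built directly from a pretrained linear layer and its fine-tuned variants. A hedged sketch following the constructor documented above; sizes and ranks are illustrative:

```python
import torch
from torch import nn

from fusion_bench.models.smile_moe.linear_from_module import SmileMoELinear

pretrained = nn.Linear(32, 64)
finetuned = []
for _ in range(4):
    ft = nn.Linear(32, 64)
    ft.load_state_dict(pretrained.state_dict())
    ft.weight.data += 0.01 * torch.randn_like(ft.weight)  # simulate fine-tuning
    finetuned.append(ft)

moe = SmileMoELinear(
    pretrained,
    finetuned,
    gate_k=4,  # singular vectors kept per expert for routing
    k=8,       # rank of each compressed expert
    top_k=1,   # experts activated per token
)
print(moe(torch.randn(5, 32)).shape)  # torch.Size([5, 64])
```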