fusion-bench 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (209)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/fer2013.py +1 -0
  6. fusion_bench/dataset/gsm8k.py +2 -2
  7. fusion_bench/method/__init__.py +10 -0
  8. fusion_bench/method/ada_svd/clip_vision.py +4 -1
  9. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  10. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  11. fusion_bench/method/gossip/__init__.py +3 -0
  12. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  13. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  14. fusion_bench/method/gossip/entropy_loss.py +25 -0
  15. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  16. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  17. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  18. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  19. fusion_bench/method/gossip/utils.py +74 -0
  20. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  21. fusion_bench/method/opcm/opcm.py +16 -7
  22. fusion_bench/method/pwe_moe/module.py +1 -1
  23. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  24. fusion_bench/method/regmean/regmean.py +25 -17
  25. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  26. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
  27. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
  28. fusion_bench/method/smile_upscaling/smile_upscaling.py +19 -346
  29. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  30. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  31. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  32. fusion_bench/method/we_moe/we_moe.py +14 -15
  33. fusion_bench/mixins/__init__.py +6 -3
  34. fusion_bench/mixins/hydra_config.py +49 -0
  35. fusion_bench/mixins/openclip_classification.py +11 -0
  36. fusion_bench/mixins/simple_profiler.py +4 -2
  37. fusion_bench/modelpool/__init__.py +3 -1
  38. fusion_bench/modelpool/base_pool.py +2 -2
  39. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  40. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  41. fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
  42. fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
  43. fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
  44. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
  45. fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
  46. fusion_bench/models/open_clip/__init__.py +6 -0
  47. fusion_bench/models/open_clip/modeling.py +176 -0
  48. fusion_bench/models/open_clip/utils.py +311 -0
  49. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  50. fusion_bench/models/parameter_dict.py +54 -13
  51. fusion_bench/models/rankone_moe.py +2 -88
  52. fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
  53. fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
  54. fusion_bench/models/smile_moe/utils/__init__.py +24 -0
  55. fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
  56. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  57. fusion_bench/taskpool/__init__.py +7 -3
  58. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  59. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  60. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  61. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  62. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  63. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  64. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  65. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  66. fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
  67. fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
  68. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  69. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  70. fusion_bench/utils/data.py +12 -0
  71. fusion_bench/utils/devices.py +14 -0
  72. fusion_bench/utils/instantiate.py +12 -0
  73. fusion_bench/utils/misc.py +9 -2
  74. fusion_bench/utils/packages.py +14 -0
  75. fusion_bench/utils/parameters.py +1 -1
  76. fusion_bench/utils/tensorboard.py +1 -1
  77. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
  78. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +209 -157
  79. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
  80. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  81. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  82. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  83. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  84. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  85. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  86. fusion_bench_config/fabric/auto.yaml +0 -1
  87. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  88. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  89. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  90. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  91. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  92. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  93. fusion_bench_config/llama_full_finetune.yaml +0 -2
  94. fusion_bench_config/llama_model_fusion.yaml +0 -2
  95. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  96. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  97. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  98. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  99. fusion_bench_config/method/adamerging.yaml +2 -2
  100. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  101. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  102. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  103. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  104. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  105. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  106. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  107. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  108. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  109. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  110. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  111. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  112. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  113. fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
  114. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  115. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  116. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  117. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  118. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  119. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  120. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  121. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  122. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  123. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  124. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  125. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  126. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  127. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  128. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  129. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  130. fusion_bench_config/method/model_recombination.yaml +0 -1
  131. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  132. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  133. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  134. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  135. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  136. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  137. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  138. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  139. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  140. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  141. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
  142. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
  143. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  144. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  145. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  146. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  147. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  148. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  149. fusion_bench_config/method/ties_merging.yaml +1 -1
  150. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  151. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  152. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  153. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  154. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  155. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  156. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  157. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  158. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  159. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  160. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  161. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  162. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  168. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  169. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  170. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  171. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  172. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  173. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  174. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  175. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  176. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
  177. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  178. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  179. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  180. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  181. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  182. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  183. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  184. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  185. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  186. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  187. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  188. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  189. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
  190. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
  191. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  192. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  193. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  194. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  195. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  196. fusion_bench_config/nyuv2_config.yaml +0 -2
  197. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  198. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  199. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  200. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  201. fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
  202. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  203. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  204. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  205. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  206. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  207. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
  208. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
  209. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0
fusion_bench/method/smile_upscaling/smile_upscaling.py

@@ -13,348 +13,18 @@ from fusion_bench.method import BaseAlgorithm
 from fusion_bench.method.simple_average import simple_average
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.models.smile_moe.linear_from_module import (
+    ExpertNotTrainedError,
+    SmileCompressedLinear,
+    SmileGate,
+    SmileMoELinear,
+)
 from fusion_bench.models.utils import get_attr, set_attr
 from fusion_bench.utils.parameters import print_parameters
 
 log = logging.getLogger(__name__)
 
 
-class ExpertNotTrainedError(Exception):
-    pass
-
-
-def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
-    """
-    Check if a tensor or a list of tensors are all zeros.
-
-    Args:
-        tensor (Tensor | List[Tensor]): A tensor or a list of tensors.
-
-    Returns:
-        bool: True if all elements are zeros, False otherwise.
-    """
-    if isinstance(tensor, Tensor):
-        return torch.allclose(tensor, torch.zeros_like(tensor))
-    else:
-        return all(_is_all_zeros(t) for t in tensor)
-
-
-def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
-    """
-    Perform Singular Value Decomposition (SVD) on a tensor.
-
-    Args:
-        w (Tensor): The input tensor.
-        full_matrices (bool): Whether to compute the full-sized U and V matrices.
-
-    Returns:
-        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
-    """
-    u, s, vh = torch.linalg.svd(
-        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
-    )
-    v = vh.T
-    return u, s, v
-
-
-def svd(
-    w: Tensor, full_matrices=True, accelerator=None
-) -> Tuple[Tensor, Tensor, Tensor]:
-    """
-    Perform SVD on a tensor, optionally using a specified accelerator.
-
-    Args:
-        w (Tensor): The input tensor.
-        full_matrices (bool): Whether to compute the full-sized U and V matrices.
-        accelerator (str): The device to perform the computation on.
-
-    Returns:
-        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
-    """
-    if accelerator is None:
-        return _svd(w, full_matrices=full_matrices)
-    original_device = w.device
-    w = w.to(accelerator)
-    u, s, v = _svd(w)
-    return u.to(original_device), s.to(original_device), v.to(original_device)
-
-
-class SmileGate(nn.Module):
-    def __init__(
-        self,
-        input_features: int,
-        w_diff_list: List[Tensor],
-        k: int,
-        svd_list=None,  # cached `svd_list`, pass it to avoid recomputing
-        upscaling_accelerator=None,
-    ):
-        """
-        Initialize the SmileGate module.
-
-        Args:
-            input_features (int): The number of input features.
-            w_diff_list (List[Tensor]): A list of weight difference tensors.
-            k (int): The number of singular values to keep.
-            svd_list (List[Tuple[Tensor, Tensor, Tensor]]): Cached SVD results.
-            upscaling_accelerator (str): The device to perform the computation on.
-        """
-        super().__init__()
-        self.input_features = input_features
-        self.num_experts = len(w_diff_list)
-        weights = []
-        for i, w_diff in enumerate(w_diff_list):
-            if svd_list is None:
-                u, s, v = svd(w_diff, accelerator=upscaling_accelerator)
-            else:
-                u, s, v = svd_list[i]
-            u = u[:, :k]
-            s = s[:k]
-            v = v[:, :k]
-
-            # weights.append((s * v).T)
-            weights.append(v.T)
-        self.k = s.size(0)  # k is the actual k after truncation
-
-        weights = (
-            torch.stack(weights, dim=0)
-            .reshape(self.num_experts * self.k, -1)
-            .contiguous()
-        )
-        self.weights = nn.Parameter(
-            weights
-        )  # weights should be a tensor of shape (num_experts * k, n)
-
-    def forward(self, x: Tensor):
-        """
-        Forward pass of the SmileGate module.
-
-        Args:
-            x (Tensor): The input tensor.
-
-        Returns:
-            Tensor: The routing weights.
-        """
-        batch_size = x.size(0)
-        if self.num_experts == 1:
-            return torch.ones(batch_size, 1, device=x.device, dtype=x.dtype)
-
-        routing_weights = F.linear(x, self.weights).view(
-            batch_size, self.num_experts, self.k
-        )
-        routing_weights = routing_weights.norm(p=2, dim=2)
-        return routing_weights
-
-
-class SmileCompressedLinear(nn.Module):
-    def __init__(self, model: nn.Linear, k: int, svd_cache=None):
-        """
-        Initialize the SmileCompressedLinear module.
-
-        Args:
-            model (nn.Linear): The linear model to compress.
-            k (int): The number of singular values to keep.
-            svd_cache (Tuple[Tensor, Tensor, Tensor]): Cached SVD results.
-        """
-        super().__init__()
-        if svd_cache is None:
-            u, s, v = svd(model.weight)
-        else:
-            u, s, v = svd_cache
-        if k > 0:
-            u = u[:, :k]
-            s = s[:k]
-            v = v[:, :k]
-
-        self.u = nn.Parameter(u)
-        self.svh = nn.Parameter((s * v).T)
-
-        if model.bias is not None:
-            self.bias = nn.Parameter(model.bias.data, requires_grad=True)
-        else:
-            self.register_parameter("bias", None)
-
-    def forward(self, x):
-        """
-        Forward pass of the SmileCompressedLinear module.
-
-        Args:
-            x (Tensor): The input tensor.
-
-        Returns:
-            Tensor: The output tensor.
-        """
-        x = F.linear(x, self.svh)
-        x = F.linear(x, self.u, self.bias)
-        return x
-
-
-class SmileMoELinear(nn.Module):
-    @torch.no_grad()
-    def __init__(
-        self,
-        pretrained_model: nn.Linear,
-        finetuned_models: List[nn.Linear],
-        gate_k: int,
-        k: int,
-        top_k: int = 1,
-        full_matrices=True,
-        upscaling_accelerator=None,
-        routing_use_diff=True,
-    ):
-        """
-        Initialize the SmileMoELinear module.
-
-        Args:
-            pretrained_model (nn.Linear): The pretrained linear model.
-            finetuned_models (List[nn.Linear]): A list of fine-tuned linear models.
-            gate_k (int): The number of singular values to keep for the gate.
-            k (int): The number of singular values to keep for the experts.
-            top_k (int): The number of top experts to select.
-            full_matrices (bool): Whether to compute the full-sized U and V matrices.
-            upscaling_accelerator (str): The device to perform the computation on.
-            routing_use_diff (bool): Whether to use weight differences for routing.
-        """
-        super().__init__()
-        self.num_experts = len(finetuned_models)
-        self.top_k = top_k
-        self.k = k
-        self.gate_k = gate_k
-        self.in_features = pretrained_model.in_features
-        self.out_features = pretrained_model.out_features
-
-        w_diff_list = [m.weight - pretrained_model.weight for m in finetuned_models]
-        if _is_all_zeros(w_diff_list):
-            # All fine-tuned models are identical to the pretrained model
-            raise ExpertNotTrainedError()
-
-        if routing_use_diff or k > 0:
-            svd_cache_list = [
-                svd(w, full_matrices=full_matrices, accelerator=upscaling_accelerator)
-                for w in w_diff_list
-            ]  # the svd cache list to avoid recomputing
-
-        # construct the gate network
-        if routing_use_diff:
-            self.gate = SmileGate(
-                input_features=self.in_features,
-                w_diff_list=w_diff_list,
-                k=gate_k,
-                svd_list=svd_cache_list,
-                upscaling_accelerator=upscaling_accelerator,
-            )
-        else:
-            self.gate = SmileGate(
-                input_features=self.in_features,
-                w_diff_list=[m.weight for m in finetuned_models],
-                k=gate_k,
-                svd_list=None,
-                upscaling_accelerator=upscaling_accelerator,
-            )
-
-        # construct experts
-        for m, w_diff in zip(finetuned_models, w_diff_list):
-            m.weight.data = w_diff
-        if k > 0:
-            experts = [
-                SmileCompressedLinear(m, k, svd_cache=svd_cache)
-                for m, svd_cache in zip(finetuned_models, svd_cache_list)
-            ]
-        else:
-            # if k is not set (<0), we use the full fine-tuned model
-            experts = finetuned_models
-        self.experts = nn.ModuleList(experts)
-
-        if pretrained_model.bias is not None:
-            for m in experts:
-                m.bias.data = m.bias.data - pretrained_model.bias
-        # assign the pretrained model (the shared part)
-        self.pretrained_model = pretrained_model
-
-    def forward(self, hidden_states: Tensor):
-        """
-        Forward pass of the SmileMoELinear module.
-
-        Args:
-            hidden_states (Tensor): The input tensor.
-
-        Returns:
-            Tensor: The output tensor.
-        """
-        pretrained_out = self.pretrained_model(hidden_states)
-
-        input_shape = hidden_states.size()
-        hidden_states = hidden_states.view(-1, self.in_features)
-
-        router_logits = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1)
-        # sample the expert according to the routing weights
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.top_k, dim=-1
-        )
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
-        final_hidden_states = torch.zeros(
-            (hidden_states.size(0), self.out_features),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be solicited
-        expert_mask = torch.nn.functional.one_hot(
-            selected_experts, num_classes=self.num_experts
-        ).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, self.in_features)
-            if current_state.numel() == 0:
-                continue
-            current_hidden_states = (
-                expert_layer(current_state) * routing_weights[top_x, idx, None]
-            )
-
-            # However `index_add_` only supports torch tensors for indexing, so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(
-                0, top_x, current_hidden_states.to(hidden_states.dtype)
-            )
-        final_hidden_states = final_hidden_states.reshape(
-            *input_shape[:-1], self.out_features
-        )
-        final_hidden_states = pretrained_out + final_hidden_states
-        return final_hidden_states
-
-    @property
-    def weight(self):
-        """
-        Mimic a linear layer. Because in some cases the user might query the device
-        (or dtype) of the layer's parameters via `linear_layer.weight.device`.
-        """
-        return self.pretrained_model.weight
-
-    @property
-    def bias(self):
-        return self.pretrained_model.bias
-
-    def __repr__(self):
-        return (
-            f"SingularMoELinear("
-            f"in_features={self.pretrained_model.in_features}, "
-            f"out_features={self.pretrained_model.out_features}, "
-            f"num_experts={self.num_experts}, "
-            f"top_k={self.top_k}, "
-            f"gate_k={self.gate_k}, "
-            f"k={self.k}"
-            f")"
-        )
-
-
 class SmileUpscalingAlgorithm(
     SimpleProfilerMixin,
     BaseAlgorithm,
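
The classes removed above were not deleted: they moved to fusion_bench/models/smile_moe/linear_from_module.py (entry 53 in the file list) and are re-imported here, so the new Mistral and Qwen2 upscaling paths can reuse them. The routing idea they implement is compact: each expert is scored by the norm of the input's projection onto the top right-singular vectors of that expert's weight difference. A minimal self-contained illustration of the idea, assuming nothing beyond PyTorch (a sketch with a hypothetical smile_routing_weights helper, not the library code):

import torch
import torch.nn.functional as F


def smile_routing_weights(x, w_diffs, gate_k=4):
    """Score each expert by ||V_k^T x||_2, where V_k holds the top right-singular
    vectors of that expert's weight difference (finetuned - pretrained)."""
    scores = []
    for w_diff in w_diffs:  # each of shape (out_features, in_features)
        _, _, vh = torch.linalg.svd(w_diff, full_matrices=False)
        proj = x @ vh[:gate_k].T  # (batch, gate_k) projection onto top directions
        scores.append(proj.norm(p=2, dim=-1))  # (batch,)
    return F.softmax(torch.stack(scores, dim=-1), dim=-1)  # (batch, num_experts)


with torch.no_grad():
    pretrained = torch.nn.Linear(16, 16)
    finetuned = [torch.nn.Linear(16, 16) for _ in range(3)]
    diffs = [m.weight - pretrained.weight for m in finetuned]
    print(smile_routing_weights(torch.randn(2, 16), diffs))

In the real SmileMoELinear above, the per-expert projections are stacked into a single nn.Parameter so routing is one F.linear call, and the selected experts' low-rank outputs are added to the frozen pretrained layer's output.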
fusion_bench/method/smile_upscaling/smile_upscaling.py

@@ -442,16 +112,19 @@ class SmileUpscalingAlgorithm(
         print_parameters(model)
         return model
 
-        with self.profile("load pretrained model"):
-            pretrained_model = modelpool.load_model("_pretrained_")
-        with self.profile("load fine-tuned model"):
-            finetuned_models = [
-                m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
-            ]
-
-        if self.config.device == "cuda" and torch.cuda.is_available():
-            pretrained_model = pretrained_model.cuda()
-            finetuned_models = [m.cuda() for m in finetuned_models]
+        with self.profile("loading model"):
+            # load models and move to GPU if available
+            with self.profile("load pretrained model"):
+                pretrained_model = modelpool.load_model("_pretrained_")
+            with self.profile("load fine-tuned model"):
+                finetuned_models = [
+                    m
+                    for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
+                ]
+
+            if self.config.device == "cuda" and torch.cuda.is_available():
+                pretrained_model = pretrained_model.cuda()
+                finetuned_models = [m.cuda() for m in finetuned_models]
 
         with self.profile("merge model"):
             model = self.merge(pretrained_model, finetuned_models)
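
The behavioral change in this hunk is small: the two loading scopes are now nested inside an umbrella "loading model" scope, so the profile summary reports the total load time alongside the per-stage times. A standalone sketch of that nesting pattern (a toy profiler, not fusion-bench's SimpleProfilerMixin):

import time
from collections import defaultdict
from contextlib import contextmanager


class TinyProfiler:
    """Toy profiler: accumulates wall time per named scope."""

    def __init__(self):
        self.totals = defaultdict(float)

    @contextmanager
    def profile(self, name: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[name] += time.perf_counter() - start


prof = TinyProfiler()
with prof.profile("loading model"):  # outer scope sees the total
    with prof.profile("load pretrained model"):
        time.sleep(0.01)  # stand-in for modelpool.load_model("_pretrained_")
    with prof.profile("load fine-tuned model"):
        time.sleep(0.02)  # stand-in for iterating modelpool.models()
print(dict(prof.totals))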
fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py

@@ -85,7 +85,14 @@ class CLIPLayerWiseAdaMergingSurgeryAlgorithm(
 
         if self.config.weights is not None:
             # skip the test-time adaptation
+            merge_weight: torch.Tensor = torch.load(self.config.weights)
+            module.merge_weight.data = merge_weight.to(
+                device=module.merge_weight.device
+            )
             merged_model = copy.deepcopy(module.merge_and_unload())
+            # setup the zero-shot classification head
+            self.on_test_time_adaptation_start()
+
         else:
             with self.profile("test-time adaptation"):
                 module = self.test_time_adaptation(module)
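
Before this change, a configured `weights` path skipped test-time adaptation without actually restoring the saved merging coefficients; the added lines load them into `module.merge_weight` first and then set up the classification head. A self-contained sketch of the intended save/restore round-trip (DummyLayerWiseMergedModel is hypothetical; only the `merge_weight` attribute mirrors the diff):

import torch
import torch.nn as nn


class DummyLayerWiseMergedModel(nn.Module):
    """Hypothetical stand-in for the layer-wise merged module in the diff."""

    def __init__(self, num_layers: int, num_models: int):
        super().__init__()
        self.merge_weight = nn.Parameter(torch.full((num_layers, num_models), 0.3))


module = DummyLayerWiseMergedModel(num_layers=12, num_models=8)

# After test-time adaptation: persist the learned coefficients.
torch.save(module.merge_weight.detach().cpu(), "merge_weight.pt")

# A later run configured with this weights path: restore and skip adaptation.
merge_weight: torch.Tensor = torch.load("merge_weight.pt")
module.merge_weight.data = merge_weight.to(device=module.merge_weight.device)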
fusion_bench/method/task_arithmetic/task_arithmetic.py

@@ -6,7 +6,7 @@ http://arxiv.org/abs/2212.04089
 
 import logging
 from copy import deepcopy
-from typing import Dict, List, Mapping, TypeVar, Union  # noqa: F401
+from typing import Dict, List, Mapping, Optional, TypeVar, Union  # noqa: F401
 
 import torch
 from torch import nn

@@ -19,18 +19,18 @@ from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_mul,
     state_dict_sub,
 )
-from fusion_bench.utils.type import StateDictType
+from fusion_bench.utils.type import StateDictType, TorchModelType
 
 log = logging.getLogger(__name__)
 
 
 @torch.no_grad()
 def task_arithmetic_merge(
-    pretrained_model: nn.Module,
-    finetuned_models: List[nn.Module],
+    pretrained_model: TorchModelType,
+    finetuned_models: List[TorchModelType],
     scaling_factor: float,
     inplace: bool = True,
-) -> nn.Module:
+) -> TorchModelType:
     """
     Merges the task vectors from multiple fine-tuned models into a single pre-trained model.
 

@@ -46,15 +46,17 @@ def task_arithmetic_merge(
     """
     if not inplace:
         pretrained_model = deepcopy(pretrained_model)
-    task_vector: StateDictType = None
+    task_vector: Optional[StateDictType] = None
     # Calculate the total task vector
     for model in finetuned_models:
         if task_vector is None:
+            # calculate the task vector for the first model
             task_vector = state_dict_sub(
                 model.state_dict(keep_vars=True),
                 pretrained_model.state_dict(keep_vars=True),
             )
         else:
+            # calculate the task vector for the remaining models
             task_vector = state_dict_add(
                 task_vector,
                 state_dict_sub(
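
For reference, task_arithmetic_merge computes merged = pretrained + scaling_factor * sum_i (finetuned_i - pretrained). The same logic over plain state dicts, without the library's state_dict_add/state_dict_sub helpers (a minimal sketch):

import torch
import torch.nn as nn


@torch.no_grad()
def task_arithmetic(pretrained: nn.Module, finetuned: list, scaling_factor: float) -> nn.Module:
    """merged = pretrained + scaling_factor * sum_i (finetuned_i - pretrained)."""
    base = pretrained.state_dict()
    merged = {k: v.clone() for k, v in base.items()}
    for model in finetuned:
        for k, v in model.state_dict().items():
            merged[k] += scaling_factor * (v - base[k])  # this model's task vector
    pretrained.load_state_dict(merged)
    return pretrained


merged = task_arithmetic(nn.Linear(4, 4), [nn.Linear(4, 4) for _ in range(2)], scaling_factor=0.3)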
fusion_bench/method/ties_merging/ties_merging.py

@@ -16,6 +16,7 @@ from torch import Tensor, nn
 
 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.type import StateDictType
 

@@ -24,7 +25,7 @@ from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_state_dict
 log = logging.getLogger(__name__)
 
 
-class TiesMergingAlgorithm(BaseAlgorithm):
+class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     """
     TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
 

@@ -84,34 +85,38 @@ class TiesMergingAlgorithm(BaseAlgorithm):
         scaling_factor = self.scaling_factor
         threshold = self.threshold
 
-        # Load the pretrained model
-        pretrained_model = modelpool.load_model("_pretrained_")
-
-        # Load the state dicts of the models
-        ft_checks: List[StateDictType] = [
-            modelpool.load_model(model_name).state_dict(keep_vars=True)
-            for model_name in modelpool.model_names
-        ]
-        ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)
-
-        # Compute the task vectors
-        flat_ft: Tensor = torch.vstack(
-            [state_dict_to_vector(check, remove_keys) for check in ft_checks]
-        )
-        flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
-        tv_flat_checks = flat_ft - flat_ptm
-
-        # Perform TIES Merging
-        merged_tv = ties_merging(
-            tv_flat_checks,
-            reset_thresh=threshold,
-            merge_func=merge_func,
-        )
-        merged_check = flat_ptm + scaling_factor * merged_tv
-        merged_state_dict = vector_to_state_dict(
-            merged_check, ptm_check, remove_keys=remove_keys
-        )
-
-        # Load the merged state dict into the pretrained model
-        pretrained_model.load_state_dict(merged_state_dict)
+        with self.profile("loading models"):
+            # Load the pretrained model
+            pretrained_model = modelpool.load_model("_pretrained_")
+
+            # Load the state dicts of the models
+            ft_checks: List[StateDictType] = [
+                modelpool.load_model(model_name).state_dict(keep_vars=True)
+                for model_name in modelpool.model_names
+            ]
+            ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)
+
+        with self.profile("merging models"):
+            # Compute the task vectors
+            flat_ft: Tensor = torch.vstack(
+                [state_dict_to_vector(check, remove_keys) for check in ft_checks]
+            )
+            flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
+            tv_flat_checks = flat_ft - flat_ptm
+
+            # Perform TIES Merging
+            merged_tv = ties_merging(
+                tv_flat_checks,
+                reset_thresh=threshold,
+                merge_func=merge_func,
+            )
+            merged_check = flat_ptm + scaling_factor * merged_tv
+            merged_state_dict = vector_to_state_dict(
+                merged_check, ptm_check, remove_keys=remove_keys
+            )
+
+            # Load the merged state dict into the pretrained model
+            pretrained_model.load_state_dict(merged_state_dict)
+
+        self.print_profile_summary()
         return pretrained_model
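
The profiling scopes do not change the algorithm itself: TIES trims each task vector to its largest-magnitude entries, elects a per-coordinate sign by majority, and averages only the entries that agree with the elected sign. A schematic re-implementation of those three steps on stacked flat task vectors (a sketch, not the library's ties_merging, which exposes more options):

import torch


def ties_merge(task_vectors: torch.Tensor, density: float = 0.2) -> torch.Tensor:
    """task_vectors: (num_models, dim) stacked flattened task vectors."""
    num_params = task_vectors.shape[1]
    k = int(density * num_params)  # entries kept per model; assumes 0 < density < 1
    # 1) Trim: zero everything below each model's magnitude threshold.
    magnitudes = task_vectors.abs()
    thresh = magnitudes.kthvalue(num_params - k, dim=1, keepdim=True).values
    trimmed = torch.where(magnitudes > thresh, task_vectors, torch.zeros_like(task_vectors))
    # 2) Elect sign: per-coordinate majority, weighted by magnitude.
    elected_sign = torch.sign(trimmed.sum(dim=0))
    # 3) Disjoint merge: mean over only the entries that agree with the elected sign.
    agrees = torch.sign(trimmed) == elected_sign
    counts = agrees.sum(dim=0).clamp(min=1)
    return (trimmed * agrees).sum(dim=0) / counts


merged_tv = ties_merge(torch.randn(3, 1000), density=0.2)  # then: pretrained + scaling_factor * merged_tv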
fusion_bench/method/we_moe/we_moe.py

@@ -5,7 +5,6 @@ from typing import cast  # noqa: F401
 import lightning as L
 import lightning.fabric.wrappers
 import torch
-from lightning.pytorch.profilers import SimpleProfiler
 from omegaconf import DictConfig
 from torch import Tensor
 from torch.utils.data import DataLoader

@@ -13,6 +12,7 @@ from tqdm.autonotebook import tqdm
 
 from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm
 from fusion_bench.compat.modelpool import ModelPool
+from fusion_bench.mixins import SimpleProfilerMixin
 from fusion_bench.models.we_moe import WeightEnsemblingMoE
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.parameters import print_parameters

@@ -34,7 +34,10 @@ def entropy_loss(logits: Tensor) -> Tensor:
     return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
 
 
-class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
+class WeightEnsemblingMoEAlgorithm(
+    ModelFusionAlgorithm,
+    SimpleProfilerMixin,
+):
     """
     Algorithm for fusing models using Weight Ensembling Mixture of Experts (MoE).
 

@@ -44,7 +47,6 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
     Attributes:
         _fabric (L.Fabric): The fabric for distributed training.
         modelpool (ModelPool): The pool of models to be fused.
-        profiler (SimpleProfiler): The profiler for measuring performance.
     """
 
     _fabric: L.Fabric = None

@@ -66,9 +68,6 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
             self._fabric.launch()
         else:
             assert "No CUDA device available."
-        self.profiler = SimpleProfiler(
-            self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
-        )
 
     @abstractmethod
     def load_checkpoint(self, model, checkpoint):

@@ -177,9 +176,9 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
         for step_idx in pbar:
             if self.config.use_grad_accumulate:
                 for task in self.modelpool.model_names:
-                    with self.profiler.profile("data time"):
+                    with self.profile("data time"):
                         batch = next(self.get_shuffled_test_loader_iter(task))
-                    with self.profiler.profile("forward pass"):
+                    with self.profile("forward pass"):
                         logits = self.compute_logits(module, batch, task)
                         assert (
                             logits.dim() == 2

@@ -187,23 +186,23 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
                         loss = entropy_loss(logits)
                     # .backward() accumulates when .zero_grad() wasn't called
                     # this can save memory
-                    with self.profiler.profile("backward pass"):
+                    with self.profile("backward pass"):
                         self._fabric.backward(loss, retain_graph=True)
             else:
                 loss = 0
                 for task in self.modelpool.model_names:
-                    with self.profiler.profile("data time"):
+                    with self.profile("data time"):
                         batch = next(self.get_shuffled_test_loader_iter(task))
-                    with self.profiler.profile("forward pass"):
+                    with self.profile("forward pass"):
                         logits = self.compute_logits(module, batch, task)
                         assert (
                             logits.dim() == 2
                         ), f"Expected logits to be 2D, got {logits.dim()}"
                         loss = loss + entropy_loss(logits)
-                with self.profiler.profile("backward pass"):
+                with self.profile("backward pass"):
                     self._fabric.backward(loss, retain_graph=True)
 
-            with self.profiler.profile("optimizer step"):
+            with self.profile("optimizer step"):
                 optimizer.step()
                 optimizer.zero_grad()
 

@@ -232,7 +231,7 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
             )
             self.load_checkpoint(moe_model, self.config.checkpoint)
         else:
-            with self.profiler.profile("test-time adaptation"):
+            with self.profile("test-time adaptation"):
                 moe_model = self.test_time_adaptation(moe_model)
             if self.config.get("save_checkpoint", False):
                 log.info(f"save checkpoint to {self.config.save_checkpoint}")

@@ -243,5 +242,5 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
 
         # enable sample-wise adaptation
         moe_model.batch_reduce = False
-        print(self.profiler.summary())
+        self.print_profile_summary()
         return moe_model
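
The adaptation loop above tunes only the MoE routing weights by minimizing the entropy of the merged model's predictions on unlabeled test batches, so no labels are required. The objective, lifted into a runnable snippet (matching the entropy_loss shown in the hunk context):

import torch


def entropy_loss(logits: torch.Tensor) -> torch.Tensor:
    """Mean Shannon entropy of the softmax predictions; the 1e-8 guards log(0)."""
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()


logits = torch.randn(8, 10, requires_grad=True)
loss = entropy_loss(logits)  # peaked (confident) predictions give lower loss
loss.backward()  # in the algorithm, gradients flow only into the routing weights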
fusion_bench/mixins/__init__.py

@@ -6,20 +6,23 @@ from typing_extensions import TYPE_CHECKING
 from fusion_bench.utils.lazy_imports import LazyImporter
 
 _import_structure = {
+    "clip_classification": ["CLIPClassificationMixin"],
+    "fabric_training": ["FabricTrainingMixin"],
+    "hydra_config": ["HydraConfigMixin"],
     "lightning_fabric": ["LightningFabricMixin"],
+    "openclip_classification": ["OpenCLIPClassificationMixin"],
     "serialization": ["YAMLSerializationMixin", "BaseYAMLSerializableModel"],
     "simple_profiler": ["SimpleProfilerMixin"],
-    "clip_classification": ["CLIPClassificationMixin"],
-    "fabric_training": ["FabricTrainingMixin"],
 }
 
 if TYPE_CHECKING:
     from .clip_classification import CLIPClassificationMixin
     from .fabric_training import FabricTrainingMixin
+    from .hydra_config import HydraConfigMixin
     from .lightning_fabric import LightningFabricMixin
+    from .openclip_classification import OpenCLIPClassificationMixin
     from .serialization import BaseYAMLSerializableModel, YAMLSerializationMixin
     from .simple_profiler import SimpleProfilerMixin
-
 else:
     sys.modules[__name__] = LazyImporter(
         __name__,
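
The pattern in this last hunk: _import_structure maps each submodule to its public names, the TYPE_CHECKING branch keeps static analyzers informed, and at runtime the module is replaced by a LazyImporter so submodules are imported only on first attribute access. A schematic equivalent using PEP 562's module-level __getattr__ (a hypothetical mypackage/__init__.py, not fusion_bench's LazyImporter):

# mypackage/__init__.py (hypothetical package illustrating the pattern)
import importlib
from typing import TYPE_CHECKING

_import_structure = {
    "simple_profiler": ["SimpleProfilerMixin"],
    "hydra_config": ["HydraConfigMixin"],
}
_attr_to_module = {
    attr: mod for mod, attrs in _import_structure.items() for attr in attrs
}

if TYPE_CHECKING:  # static analyzers resolve names eagerly
    from .hydra_config import HydraConfigMixin
    from .simple_profiler import SimpleProfilerMixin


def __getattr__(name: str):
    # PEP 562: import the owning submodule only on first attribute access
    if name in _attr_to_module:
        module = importlib.import_module(f".{_attr_to_module[name]}", __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")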