fusion-bench 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
Files changed (190)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/fer2013.py +1 -0
  6. fusion_bench/dataset/gsm8k.py +2 -2
  7. fusion_bench/method/__init__.py +10 -0
  8. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  9. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  10. fusion_bench/method/gossip/__init__.py +3 -0
  11. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  12. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  13. fusion_bench/method/gossip/entropy_loss.py +25 -0
  14. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  15. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  16. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  17. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  18. fusion_bench/method/gossip/utils.py +74 -0
  19. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  20. fusion_bench/method/opcm/opcm.py +16 -7
  21. fusion_bench/method/pwe_moe/module.py +1 -1
  22. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  23. fusion_bench/method/regmean/regmean.py +25 -17
  24. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  25. fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
  26. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  27. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  28. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  29. fusion_bench/method/we_moe/we_moe.py +14 -15
  30. fusion_bench/mixins/__init__.py +6 -3
  31. fusion_bench/mixins/hydra_config.py +49 -0
  32. fusion_bench/mixins/openclip_classification.py +11 -0
  33. fusion_bench/mixins/simple_profiler.py +4 -2
  34. fusion_bench/modelpool/__init__.py +3 -1
  35. fusion_bench/modelpool/base_pool.py +2 -2
  36. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  37. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  38. fusion_bench/models/open_clip/__init__.py +6 -0
  39. fusion_bench/models/open_clip/modeling.py +176 -0
  40. fusion_bench/models/open_clip/utils.py +311 -0
  41. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  42. fusion_bench/models/parameter_dict.py +54 -13
  43. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  44. fusion_bench/taskpool/__init__.py +5 -3
  45. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  46. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  47. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  48. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  49. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  50. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  51. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  52. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  53. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  54. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  55. fusion_bench/utils/data.py +12 -0
  56. fusion_bench/utils/devices.py +14 -0
  57. fusion_bench/utils/instantiate.py +12 -0
  58. fusion_bench/utils/misc.py +9 -2
  59. fusion_bench/utils/packages.py +14 -0
  60. fusion_bench/utils/parameters.py +1 -1
  61. fusion_bench/utils/tensorboard.py +1 -1
  62. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +1 -1
  63. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +190 -151
  64. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
  65. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  66. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  67. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  68. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  69. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  70. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  71. fusion_bench_config/fabric/auto.yaml +0 -1
  72. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  73. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  74. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  75. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  76. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  77. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  78. fusion_bench_config/llama_full_finetune.yaml +0 -2
  79. fusion_bench_config/llama_model_fusion.yaml +0 -2
  80. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  81. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  82. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  83. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  84. fusion_bench_config/method/adamerging.yaml +2 -2
  85. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  86. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  87. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  88. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  89. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  90. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  91. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  92. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  93. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  94. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  95. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  96. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  97. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  98. fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
  99. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  100. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  101. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  102. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  103. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  104. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  105. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  106. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  107. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  108. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  109. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  110. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  111. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  112. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  113. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  114. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  115. fusion_bench_config/method/model_recombination.yaml +0 -1
  116. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  117. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  118. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  119. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  120. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  121. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  122. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  123. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  124. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  125. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  126. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  127. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  128. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  129. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  130. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  131. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  132. fusion_bench_config/method/ties_merging.yaml +1 -1
  133. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  134. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  135. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  136. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  137. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  138. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  139. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  140. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  141. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  142. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  154. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  155. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  156. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  157. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  158. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  159. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  160. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  161. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  162. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  163. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  164. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  165. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  166. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  167. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  168. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  169. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  170. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  171. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
  172. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
  173. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  174. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  175. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  176. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  177. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  178. fusion_bench_config/nyuv2_config.yaml +0 -2
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  180. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  181. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  182. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  183. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  184. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  185. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  186. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  187. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  188. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
  189. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/licenses/LICENSE +0 -0
  190. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
@@ -12,36 +12,7 @@ from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.models.rankone_moe import RankOneMoE
 
 from .taskpool import CLIPVisionModelTaskPool
-
-
-class LayerWiseRoutingWeightSaver:
-    def __init__(self, save_path: Path, max_num: Optional[int] = None):
-        self.save_path = save_path
-        self.max_num = max_num
-        self.routing_weights = []
-
-    def __call__(self, module, input: Tuple[Tensor], output: Tensor):
-        assert isinstance(output, Tensor), "Output is expected to be a Tensor"
-        # (batch_size, num_tokens, num_experts)
-        routing_weights = output.detach().cpu()
-        if self.max_num is not None and self.max_num > 0:
-            if len(self.routing_weights) > self.max_num:
-                return
-            elif routing_weights.size(0) + len(self.routing_weights) > self.max_num:
-                self.routing_weights.append(
-                    routing_weights[: self.max_num - len(self.routing_weights)]
-                )
-            else:
-                self.routing_weights.append(routing_weights)
-        else:
-            self.routing_weights.append(routing_weights)
-
-    def save_routing_weights(self):
-        routing_weights = torch.cat(self.routing_weights, dim=0)
-        if self.save_path is not None:
-            self.save_path.parent.mkdir(parents=True, exist_ok=True)
-            print(f"Saving routing weights to {self.save_path}")
-            torch.save(routing_weights, self.save_path)
+from .utils.routing_analysis_utils import LayerWiseRoutingWeightSaver
 
 
 class RankoneMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
@@ -109,4 +80,5 @@ class RankoneMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
             # remove hooks for saving layer-wise routing weights
             for i, handle in self._layer_wise_routing_weights_save_hook_handles.items():
                 self._layer_wise_routing_weights_save_hooks[i].save_routing_weights()
+                self._layer_wise_routing_weights_save_hook_handles.pop(i)
                 handle.remove()
@@ -0,0 +1,102 @@
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.utils.hooks import RemovableHandle
+from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
+from transformers.models.clip.modeling_clip import CLIPVisionTransformer
+
+from fusion_bench.method.smile_upscaling import SmileMoELinear
+from fusion_bench.models.hf_clip import HFCLIPClassifier
+
+from .taskpool import CLIPVisionModelTaskPool
+from .utils.routing_analysis_utils import LayerWiseRoutingWeightSaver
+
+
+class SmileCLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
+
+    # hooks and handles for saving layer-wise routing weights
+    _layer_wise_routing_weights_save_hooks: Dict[Any, LayerWiseRoutingWeightSaver] = {}
+    _layer_wise_routing_weights_save_hook_handles: Dict[Any, RemovableHandle] = {}
+
+    def __init__(
+        self,
+        linear_module_names: Union[List[str], str],
+        layer_wise_routing_weights_save_path: Optional[str],
+        layer_wise_routing_weights_max_num: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Initialize the SMILECLIPVisionModelTaskPool.
+
+        Args:
+            linear_module_names (Union[List[str], str]): The names of the linear modules to save the layer-wise routing weights for.
+            layer_wise_routing_weights_save_path (Optional[str]): The path to save the layer-wise routing weights.
+            layer_wise_routing_weights_max_num (Optional[int]): The maximum number of layer-wise routing weights to save.
+        """
+        # linear module names
+        assert linear_module_names is not None, "linear_module_names must be provided"
+        self.linear_module_names = (
+            [linear_module_names]
+            if isinstance(linear_module_names, str)
+            else list(linear_module_names)
+        )
+        # save path for layer-wise routing weights
+        self._layer_wise_routing_weights_save_path = (
+            layer_wise_routing_weights_save_path
+        )
+        self.layer_wise_routing_weights_save_path = (
+            Path(layer_wise_routing_weights_save_path)
+            if layer_wise_routing_weights_save_path is not None
+            else None
+        )
+        self.layer_wise_routing_weights_max_num = layer_wise_routing_weights_max_num
+        super().__init__(**kwargs)
+
+    def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
+        super().on_task_evaluation_begin(classifier, task_name)
+        if self.layer_wise_routing_weights_save_path is not None:
+            # setup hooks for saving layer-wise routing weights
+            assert isinstance(
+                classifier.clip_model.vision_model,
+                (CLIPVisionTransformer, CLIPVisionModel),
+            ), "Vision model is expected to be a CLIPVisionTransformer"
+            vision_model = classifier.clip_model.vision_model
+            if isinstance(vision_model, CLIPVisionModel):
+                vision_model = vision_model.vision_model
+            # assign forward hooks for each layer
+
+            for i, layer in enumerate(vision_model.encoder.layers):
+                for linear_module_name in self.linear_module_names:
+                    linear_module = layer.get_submodule(linear_module_name)
+                    assert isinstance(
+                        linear_module,
+                        (SmileMoELinear),
+                    ), f"Linear module is expected to be a SmileMoELinear, but got {type(linear_module)}"
+                    # layer-wise routing weights
+                    hook = LayerWiseRoutingWeightSaver(
+                        self.layer_wise_routing_weights_save_path
+                        / task_name
+                        / f"layer_{i}_{linear_module_name}.pt",
+                        max_num=self.layer_wise_routing_weights_max_num,
+                    )
+                    self._layer_wise_routing_weights_save_hooks[
+                        (i, linear_module_name)
+                    ] = hook
+                    self._layer_wise_routing_weights_save_hook_handles[
+                        (i, linear_module_name)
+                    ] = linear_module.gate.register_forward_hook(hook)
+
+    def on_task_evaluation_end(self):
+        super().on_task_evaluation_end()
+        if self.layer_wise_routing_weights_save_path is not None:
+            # remove hooks for saving layer-wise routing weights
+            for (
+                key,
+                handle,
+            ) in self._layer_wise_routing_weights_save_hook_handles.items():
+                self._layer_wise_routing_weights_save_hooks[key].save_routing_weights()
+                self._layer_wise_routing_weights_save_hook_handles.pop(key)
+                handle.remove()
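The new task pool registers a forward hook on the `gate` of every `SmileMoELinear` it finds and writes one tensor per layer and linear module under `<save_path>/<task_name>/layer_<i>_<module_name>.pt`. As a minimal sketch of how a saved tensor can be inspected afterwards (the concrete path and module name below are hypothetical):

    import torch

    # per-batch tensors of shape (batch_size, num_tokens, num_experts) are
    # concatenated along the first dimension before being saved
    routing_weights = torch.load("outputs/routing_weights/sun397/layer_0_mlp.fc1.pt")
    print(routing_weights.shape)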
@@ -15,36 +15,7 @@ from fusion_bench.models.sparse_we_moe import (
 )
 
 from .taskpool import CLIPVisionModelTaskPool
-
-
-class LayerWiseRoutingWeightSaver:
-    def __init__(self, save_path: Path, max_num: Optional[int] = None):
-        self.save_path = save_path
-        self.max_num = max_num
-        self.routing_weights = []
-
-    def __call__(self, module, input: Tuple[Tensor], output: Tensor):
-        assert isinstance(output, Tensor), "Output is expected to be a Tensor"
-        # (batch_size, num_tokens, num_experts)
-        routing_weights = output.detach().cpu()
-        if self.max_num is not None and self.max_num > 0:
-            if len(self.routing_weights) > self.max_num:
-                return
-            elif routing_weights.size(0) + len(self.routing_weights) > self.max_num:
-                self.routing_weights.append(
-                    routing_weights[: self.max_num - len(self.routing_weights)]
-                )
-            else:
-                self.routing_weights.append(routing_weights)
-        else:
-            self.routing_weights.append(routing_weights)
-
-    def save_routing_weights(self):
-        routing_weights = torch.cat(self.routing_weights, dim=0)
-        if self.save_path is not None:
-            self.save_path.parent.mkdir(parents=True, exist_ok=True)
-            print(f"Saving routing weights to {self.save_path}")
-            torch.save(routing_weights, self.save_path)
+from .utils.routing_analysis_utils import LayerWiseRoutingWeightSaver
 
 
 class SparseWEMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
@@ -117,4 +88,5 @@ class SparseWEMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
             # remove hooks for saving layer-wise routing weights
             for i, handle in self._layer_wise_routing_weights_save_hook_handles.items():
                 self._layer_wise_routing_weights_save_hooks[i].save_routing_weights()
+                self._layer_wise_routing_weights_save_hook_handles.pop(i)
                 handle.remove()
@@ -32,8 +32,7 @@ from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.taskpool import BaseTaskPool
 from fusion_bench.tasks.clip_classification import get_classnames_and_templates
-from fusion_bench.utils import instantiate
-from fusion_bench.utils.parameters import count_parameters
+from fusion_bench.utils import count_parameters, instantiate
 
 if TYPE_CHECKING:
     from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
File without changes
@@ -0,0 +1,65 @@
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+import torch
+from torch import Tensor
+
+
+def _number_of_samples(routing_weights: List[Tensor]):
+    count = 0
+    for routing_weight in routing_weights:
+        count += routing_weight.size(0)
+    return count
+
+
+class LayerWiseRoutingWeightSaver:
+    """
+    A hook for saving layer-wise routing weights.
+    """
+
+    save_path: Path
+    "The path to save the layer-wise routing weights."
+    max_num: Optional[int]
+    "The maximum number of layer-wise routing weights to save. If None, all routing weights will be saved."
+    routing_weights: List[Tensor]
+    "The list of layer-wise routing weights."
+
+    def __init__(self, save_path: Path, max_num: Optional[int] = None):
+        """
+        Args:
+            save_path (Path): The path to save the layer-wise routing weights.
+            max_num (Optional[int]): The maximum number of layer-wise routing weights to save. If None, all routing weights will be saved.
+        """
+        self.save_path = save_path
+        self.max_num = max_num
+        self.routing_weights = []
+
+    def __call__(self, module, input: Tuple[Tensor], output: Tensor):
+        assert isinstance(output, Tensor), "Output is expected to be a Tensor"
+        # (batch_size, num_tokens, num_experts)
+        routing_weights = output.detach().cpu()
+        if self.max_num is not None and self.max_num > 0:
+            if _number_of_samples(self.routing_weights) > self.max_num:
+                return
+            elif (
+                routing_weights.size(0) + _number_of_samples(self.routing_weights)
+                > self.max_num
+            ):
+                self.routing_weights.append(
+                    routing_weights[
+                        : self.max_num - _number_of_samples(self.routing_weights)
+                    ]
+                )
+            else:
+                self.routing_weights.append(routing_weights)
+        else:
+            self.routing_weights.append(routing_weights)
+
+    def save_routing_weights(self):
+        routing_weights = torch.cat(self.routing_weights, dim=0)
+        if self.save_path is not None:
+            self.save_path.parent.mkdir(parents=True, exist_ok=True)
+            print(
+                f"Saving routing weights to {self.save_path}. Size: {routing_weights.size()}"
+            )
+            torch.save(routing_weights, self.save_path)
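As a minimal sketch of how this hook can be used outside the task pools, assuming any gating module whose forward output is the routing-weight tensor (the `nn.Linear` stand-in and the output path below are illustrative only):

    from pathlib import Path

    import torch
    from torch import nn

    from fusion_bench.taskpool.clip_vision.utils.routing_analysis_utils import (
        LayerWiseRoutingWeightSaver,
    )

    gate = nn.Linear(16, 4)  # stand-in for a gating module
    saver = LayerWiseRoutingWeightSaver(Path("outputs/layer_0_gate.pt"), max_num=128)
    handle = gate.register_forward_hook(saver)

    for _ in range(3):
        gate(torch.randn(8, 10, 16))  # each forward pass records an (8, 10, 4) tensor

    handle.remove()
    saver.save_routing_weights()  # concatenates the recorded batches and writes the file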
@@ -139,11 +139,40 @@ class GPT2TextClassificationTaskPool(BaseTaskPool, LightningFabricMixin):
         return dataloader
 
     @override
-    def evaluate(self, model: GPT2Model):
+    def evaluate(self, model: GPT2Model, name: str = None):
+        """Evaluate the model on the test datasets.
+
+        Args:
+            model (GPT2Model): The model to evaluate.
+            name (str, optional): The name of the model. Defaults to None. This is used to identify the model in the report.
+
+        Returns:
+            dict: A dictionary containing the evaluation results for each task.
+        """
         report = {}
+        if name is not None:
+            report["name"] = name
         for task_name in (pbar := tqdm(self._test_datasets, desc="Evaluating tasks")):
             pbar.set_description(f"Evaluating task {task_name}")
             dataloader = self.get_test_dataloader(task_name)
             result = self.evaluate_single_task(task_name, model, dataloader)
             report[task_name] = result
+
+        # calculate the average accuracy and loss
+        if "average" not in report:
+            report["average"] = {}
+        accuracies = [
+            value["accuracy"]
+            for key, value in report.items()
+            if isinstance(value, dict) and "accuracy" in value
+        ]
+        if len(accuracies) > 0:
+            average_accuracy = sum(accuracies) / len(accuracies)
+            report["average"]["accuracy"] = average_accuracy
+        losses = [value["loss"] for key, value in report.items() if "loss" in value]
+        if len(losses) > 0:
+            average_loss = sum(losses) / len(losses)
+            report["average"]["loss"] = average_loss
+
+        log.info(f"Evaluation Result: {report}")
         return report
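With this change, `evaluate` can tag the report with a model name and always appends an `"average"` entry computed over the per-task results. A sketch of the resulting structure (task names and numbers below are made up):

    report = {
        "name": "ties_merging",
        "cola": {"accuracy": 0.71, "loss": 0.83},
        "sst2": {"accuracy": 0.90, "loss": 0.31},
        "average": {"accuracy": 0.805, "loss": 0.57},
    }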
@@ -0,0 +1 @@
+from .openclip_taskpool import OpenCLIPVisionModelTaskPool
@@ -0,0 +1,196 @@
+import itertools
+import json
+import logging
+import os
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
+
+import lightning.fabric
+import open_clip
+import torch
+from omegaconf import DictConfig
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+from torchmetrics import Accuracy, MeanMetric
+from torchmetrics.classification.accuracy import MulticlassAccuracy
+from tqdm.auto import tqdm
+
+from fusion_bench import BaseTaskPool
+from fusion_bench.dataset.clip_dataset import CLIPDataset
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.modelpool.openclip_vision.modelpool import load_classifier_head
+from fusion_bench.models.open_clip import (
+    ClassificationHead,
+    ImageClassifier,
+    ImageEncoder,
+)
+from fusion_bench.models.open_clip.variables_and_paths import OPENCLIP_CACHEDIR
+from fusion_bench.utils import count_parameters, instantiate
+
+if TYPE_CHECKING:
+    from fusion_bench.modelpool import OpenCLIPVisionModelPool
+    from fusion_bench.programs import FabricModelFusionProgram
+
+log = logging.getLogger(__name__)
+
+
+class OpenCLIPVisionModelTaskPool(
+    BaseTaskPool,
+    LightningFabricMixin,
+):
+    _is_setup = False
+
+    _program: "FabricModelFusionProgram"
+
+    processor: Optional[Callable] = None
+    test_datasets: Dict[str, CLIPDataset]
+
+    def __init__(
+        self,
+        test_datasets: Union[DictConfig, Dict[str, Dataset]],
+        classification_heads: Union[DictConfig, Dict[str, ClassificationHead]],
+        dataloader_kwargs: DictConfig,
+        model_name: Optional[str] = None,
+        fast_dev_run: bool = False,
+        **kwargs,
+    ):
+        self._test_datasets = test_datasets
+        self._classifier_heads = classification_heads
+        self._dataloader_kwargs = dataloader_kwargs
+        self._model_name = model_name
+        self.fast_dev_run = fast_dev_run
+        super().__init__(**kwargs)
+
+    def setup(self):
+        # setup the processor
+        if self._program is not None and self._program.modelpool is not None:
+            modelpool: "OpenCLIPVisionModelPool" = self._program.modelpool
+            self.processor = modelpool.test_processor
+        elif self._model_name is not None:
+            _, _, self.processor = open_clip.create_model_and_transforms(
+                self._model_name,
+                pretrained="openai",
+                cache_dir=OPENCLIP_CACHEDIR,
+            )
+        else:
+            raise ValueError("Modelpool or model_name is not set")
+
+        # setup the test datasets
+        self.test_datasets = {
+            name: instantiate(dataset) if isinstance(dataset, DictConfig) else dataset
+            for name, dataset in self._test_datasets.items()
+        }
+        self.test_datasets = {
+            name: CLIPDataset(dataset, self.processor)
+            for name, dataset in self.test_datasets.items()
+        }
+        self.test_dataloaders = {
+            name: self.fabric.setup_dataloaders(
+                DataLoader(dataset, **self._dataloader_kwargs)
+            )
+            for name, dataset in self.test_datasets.items()
+        }
+
+        # setup classifier heads
+        self.classifier_heads = {
+            name: load_classifier_head(head).to(self.fabric.device)
+            for name, head in self._classifier_heads.items()
+        }
+        self._is_setup = True
+
+    @torch.no_grad()
+    def _evaluate(
+        self,
+        classifier: ImageClassifier,
+        test_loader: DataLoader,
+        num_classes: int,
+        task_name: str,
+    ):
+        accuracy: MulticlassAccuracy = Accuracy(
+            task="multiclass", num_classes=num_classes
+        )
+        classifier.eval()
+        loss_metric = MeanMetric()
+        # if fast_dev_run is set, we only evaluate on a batch of the data
+        if self.fast_dev_run:
+            log.info("Running under fast_dev_run mode, evaluating on a single batch.")
+            test_loader = itertools.islice(test_loader, 1)
+        else:
+            test_loader = test_loader
+
+        pbar = tqdm(
+            test_loader,
+            desc=f"Evaluating {task_name}",
+            leave=False,
+            dynamic_ncols=True,
+        )
+        for batch in pbar:
+            inputs, targets = batch
+            logits = classifier(inputs)
+            loss = F.cross_entropy(logits, targets)
+            loss_metric.update(loss.detach().cpu())
+            acc = accuracy(logits.detach().cpu(), targets.detach().cpu())
+            pbar.set_postfix(
+                {
+                    "accuracy": accuracy.compute().item(),
+                    "loss": loss_metric.compute().item(),
+                }
+            )
+
+        acc = accuracy.compute().item()
+        loss = loss_metric.compute().item()
+        results = {"accuracy": acc, "loss": loss}
+        return results
+
+    def evaluate(self, model: ImageEncoder, **kwargs):
+        if not self._is_setup:
+            self.setup()
+
+        report = {}
+        # collect basic model information
+        training_params, all_params = count_parameters(model)
+        report["model_info"] = {
+            "trainable_params": training_params,
+            "all_params": all_params,
+            "trainable_percentage": training_params / all_params,
+        }
+
+        if not lightning.fabric.is_wrapped(model):
+            model = self.fabric.setup_module(model)
+
+        pbar = tqdm(
+            self.test_dataloaders.items(),
+            desc="Evaluating tasks",
+            total=len(self.test_dataloaders),
+        )
+        for task_name, test_dataloader in pbar:
+            classifier = ImageClassifier(model, self.classifier_heads[task_name])
+            num_classes = self.classifier_heads[task_name].weight.size(0)
+            result = self._evaluate(
+                classifier,
+                test_dataloader,
+                num_classes=num_classes,
+                task_name=task_name,
+            )
+            report[task_name] = result
+
+        # calculate the average accuracy and loss
+        if "average" not in report:
+            report["average"] = {}
+        accuracies = [
+            value["accuracy"]
+            for key, value in report.items()
+            if "accuracy" in value
+        ]
+        if len(accuracies) > 0:
+            average_accuracy = sum(accuracies) / len(accuracies)
+            report["average"]["accuracy"] = average_accuracy
+        losses = [value["loss"] for key, value in report.items() if "loss" in value]
+        if len(losses) > 0:
+            average_loss = sum(losses) / len(losses)
+            report["average"]["loss"] = average_loss
+
+        log.info(f"Evaluation Result: {report}")
+        if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
+            with open(os.path.join(self.log_dir, "report.json"), "w") as fp:
+                json.dump(report, fp)
+        return report
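The pool evaluates a single shared encoder by pairing it with a task-specific classification head for each dataset, so one merged `ImageEncoder` is scored on every task. A minimal sketch of that pairing, mirroring the loop in `evaluate` (the variable names below are hypothetical):

    classifier = ImageClassifier(merged_encoder, classifier_heads["cars"])
    num_classes = classifier_heads["cars"].weight.size(0)  # head rows define the class count
    logits = classifier(images)  # (batch_size, num_classes)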
@@ -9,6 +9,18 @@ from torch.utils.data import DataLoader, Dataset
 
 
 class InfiniteDataLoader:
+    """
+    A wrapper class for DataLoader to create an infinite data loader.
+    This is useful in case we are only interested in the number of steps and not the number of epochs.
+
+    This class wraps a DataLoader and provides an iterator that resets
+    when the end of the dataset is reached, creating an infinite loop.
+
+    Attributes:
+        data_loader (DataLoader): The DataLoader to wrap.
+        data_iter (iterator): An iterator over the DataLoader.
+    """
+
     def __init__(self, data_loader: DataLoader):
         self.data_loader = data_loader
         self.data_iter = iter(data_loader)
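A minimal sketch of the step-based training pattern this wrapper targets, assuming it implements the iterator protocol and restarts the underlying loader on exhaustion as the docstring describes (the dataset and step count below are placeholders):

    loader = InfiniteDataLoader(DataLoader(train_dataset, batch_size=32, shuffle=True))
    data_iter = iter(loader)
    for step in range(max_steps):
        batch = next(data_iter)  # wraps around instead of raising StopIteration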
@@ -229,3 +229,17 @@ def cleanup_cuda():
     gc.collect()
     torch.cuda.empty_cache()
     torch.cuda.reset_peak_memory_stats()
+
+
+def print_memory_usage(print_fn=print):
+    """
+    Print the current GPU memory usage.
+
+    Returns:
+        str: A string containing the allocated and cached memory in MB.
+    """
+    allocated = torch.cuda.memory_allocated() / 1024**2  # convert to MB
+    cached = torch.cuda.memory_reserved() / 1024**2  # convert to MB
+    print_str = f"Allocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
+    print_fn(print_str)
+    return print_str
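The `print_fn` argument lets callers route the summary through a logger instead of stdout, and the formatted string is also returned. A small sketch (requires a CUDA-enabled torch build):

    import logging

    log = logging.getLogger(__name__)
    summary = print_memory_usage(print_fn=log.info)  # logs and returns the two-line summary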
@@ -2,6 +2,7 @@
 # Modified from Hydra
 import copy
 import functools
+from contextlib import contextmanager
 from enum import Enum
 from textwrap import dedent
 from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
@@ -30,6 +31,17 @@ Function to be used for printing function calls.
 CATCH_EXCEPTION = True
 
 
+@contextmanager
+def set_print_function_call(value: bool):
+    global PRINT_FUNCTION_CALL
+    old_value = PRINT_FUNCTION_CALL
+    PRINT_FUNCTION_CALL = value
+    try:
+        yield
+    finally:
+        PRINT_FUNCTION_CALL = old_value
+
+
 def is_instantiable(config: Union[DictConfig, Any]) -> bool:
     if OmegaConf.is_dict(config):
         return "_target_" in config
@@ -1,6 +1,6 @@
-from typing import Iterable
+from typing import Iterable, List
 
-__all__ = ["first", "has_length"]
+__all__ = ["first", "has_length", "join_list"]
 
 
 def first(iterable: Iterable):
@@ -16,3 +16,10 @@ def has_length(dataset):
     except TypeError:
         # TypeError: len() of unsized object
         return False
+
+
+def join_list(list_of_list: List[List]):
+    ans = []
+    for item in list_of_list:
+        ans.extend(item)
+    return ans
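`join_list` simply flattens one level of nesting, for example:

    join_list([[1, 2], [3], []])  # -> [1, 2, 3]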
@@ -82,3 +82,17 @@ def import_object(abs_obj_name: str):
     module_name, obj_name = abs_obj_name.rsplit(".", 1)
     module = importlib.import_module(module_name)
     return getattr(module, obj_name)
+
+
+def compare_versions(v1, v2):
+    """Compare two version strings.
+    Returns -1 if v1 < v2, 0 if v1 == v2, 1 if v1 > v2"""
+
+    v1 = version.parse(v1)
+    v2 = version.parse(v2)
+    if v1 < v2:
+        return -1
+    elif v1 > v2:
+        return 1
+    else:
+        return 0
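For example (assuming `version` here refers to `packaging.version`, imported elsewhere in the module):

    compare_versions("0.2.12", "0.2.13")  # -> -1
    compare_versions("1.0.0", "1.0.0")    # -> 0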
@@ -252,7 +252,7 @@ def print_parameters(
 
 
 def check_parameters_all_equal(
-    list_of_param_names: List[Union[StateDictType, nn.Module, List[str]]]
+    list_of_param_names: List[Union[StateDictType, nn.Module, List[str]]],
 ) -> None:
     """
     Checks if all models have the same parameters.
@@ -1,5 +1,5 @@
 """
-functions deal with tensorboard logs.
+functions deal with tensorboard logs.
 """
 
 from typing import Dict, Iterable, List
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fusion_bench
-Version: 0.2.12
+Version: 0.2.13
 Summary: A Comprehensive Benchmark of Deep Model Fusion
 Author-email: Anke Tang <tang.anke@foxmail.com>
 License: MIT License