fusion-bench 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/ada_svd/clip_vision.py +4 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +16 -7
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +19 -346
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
- fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
- fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
- fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/rankone_moe.py +2 -88
- fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
- fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
- fusion_bench/models/smile_moe/utils/__init__.py +24 -0
- fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +7 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +209 -157
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0
fusion_bench/method/smile_upscaling/smile_upscaling.py
CHANGED
@@ -13,348 +13,18 @@ from fusion_bench.method import BaseAlgorithm
 from fusion_bench.method.simple_average import simple_average
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.models.smile_moe.linear_from_module import (
+    ExpertNotTrainedError,
+    SmileCompressedLinear,
+    SmileGate,
+    SmileMoELinear,
+)
 from fusion_bench.models.utils import get_attr, set_attr
 from fusion_bench.utils.parameters import print_parameters
 
 log = logging.getLogger(__name__)
 
 
-class ExpertNotTrainedError(Exception):
-    pass
-
-
-def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
-    """
-    Check if a tensor or a list of tensors are all zeros.
-
-    Args:
-        tensor (Tensor | List[Tensor]): A tensor or a list of tensors.
-
-    Returns:
-        bool: True if all elements are zeros, False otherwise.
-    """
-    if isinstance(tensor, Tensor):
-        return torch.allclose(tensor, torch.zeros_like(tensor))
-    else:
-        return all(_is_all_zeros(t) for t in tensor)
-
-
-def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
-    """
-    Perform Singular Value Decomposition (SVD) on a tensor.
-
-    Args:
-        w (Tensor): The input tensor.
-        full_matrices (bool): Whether to compute the full-sized U and V matrices.
-
-    Returns:
-        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
-    """
-    u, s, vh = torch.linalg.svd(
-        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
-    )
-    v = vh.T
-    return u, s, v
-
-
-def svd(
-    w: Tensor, full_matrices=True, accelerator=None
-) -> Tuple[Tensor, Tensor, Tensor]:
-    """
-    Perform SVD on a tensor, optionally using a specified accelerator.
-
-    Args:
-        w (Tensor): The input tensor.
-        full_matrices (bool): Whether to compute the full-sized U and V matrices.
-        accelerator (str): The device to perform the computation on.
-
-    Returns:
-        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
-    """
-    if accelerator is None:
-        return _svd(w, full_matrices=full_matrices)
-    original_device = w.device
-    w = w.to(accelerator)
-    u, s, v = _svd(w)
-    return u.to(original_device), s.to(original_device), v.to(original_device)
-
-
-class SmileGate(nn.Module):
-    def __init__(
-        self,
-        input_features: int,
-        w_diff_list: List[Tensor],
-        k: int,
-        svd_list=None,  # cached `svd_list`, pass it to avoid recomputing
-        upscaling_accelerator=None,
-    ):
-        """
-        Initialize the SmileGate module.
-
-        Args:
-            input_features (int): The number of input features.
-            w_diff_list (List[Tensor]): A list of weight difference tensors.
-            k (int): The number of singular values to keep.
-            svd_list (List[Tuple[Tensor, Tensor, Tensor]]): Cached SVD results.
-            upscaling_accelerator (str): The device to perform the computation on.
-        """
-        super().__init__()
-        self.input_features = input_features
-        self.num_experts = len(w_diff_list)
-        weights = []
-        for i, w_diff in enumerate(w_diff_list):
-            if svd_list is None:
-                u, s, v = svd(w_diff, accelerator=upscaling_accelerator)
-            else:
-                u, s, v = svd_list[i]
-            u = u[:, :k]
-            s = s[:k]
-            v = v[:, :k]
-
-            # weights.append((s * v).T)
-            weights.append(v.T)
-        self.k = s.size(0)  # k is the actual k after truncation
-
-        weights = (
-            torch.stack(weights, dim=0)
-            .reshape(self.num_experts * self.k, -1)
-            .contiguous()
-        )
-        self.weights = nn.Parameter(
-            weights
-        )  # weights should be a tensor of shape (num_experts * k, n)
-
-    def forward(self, x: Tensor):
-        """
-        Forward pass of the SmileGate module.
-
-        Args:
-            x (Tensor): The input tensor.
-
-        Returns:
-            Tensor: The routing weights.
-        """
-        batch_size = x.size(0)
-        if self.num_experts == 1:
-            return torch.ones(batch_size, 1, device=x.device, dtype=x.dtype)
-
-        routing_weights = F.linear(x, self.weights).view(
-            batch_size, self.num_experts, self.k
-        )
-        routing_weights = routing_weights.norm(p=2, dim=2)
-        return routing_weights
-
-
-class SmileCompressedLinear(nn.Module):
-    def __init__(self, model: nn.Linear, k: int, svd_cache=None):
-        """
-        Initialize the SmileCompressedLinear module.
-
-        Args:
-            model (nn.Linear): The linear model to compress.
-            k (int): The number of singular values to keep.
-            svd_cache (Tuple[Tensor, Tensor, Tensor]): Cached SVD results.
-        """
-        super().__init__()
-        if svd_cache is None:
-            u, s, v = svd(model.weight)
-        else:
-            u, s, v = svd_cache
-        if k > 0:
-            u = u[:, :k]
-            s = s[:k]
-            v = v[:, :k]
-
-        self.u = nn.Parameter(u)
-        self.svh = nn.Parameter((s * v).T)
-
-        if model.bias is not None:
-            self.bias = nn.Parameter(model.bias.data, requires_grad=True)
-        else:
-            self.register_parameter("bias", None)
-
-    def forward(self, x):
-        """
-        Forward pass of the SmileCompressedLinear module.
-
-        Args:
-            x (Tensor): The input tensor.
-
-        Returns:
-            Tensor: The output tensor.
-        """
-        x = F.linear(x, self.svh)
-        x = F.linear(x, self.u, self.bias)
-        return x
-
-
-class SmileMoELinear(nn.Module):
-    @torch.no_grad()
-    def __init__(
-        self,
-        pretrained_model: nn.Linear,
-        finetuned_models: List[nn.Linear],
-        gate_k: int,
-        k: int,
-        top_k: int = 1,
-        full_matrices=True,
-        upscaling_accelerator=None,
-        routing_use_diff=True,
-    ):
-        """
-        Initialize the SmileMoELinear module.
-
-        Args:
-            pretrained_model (nn.Linear): The pretrained linear model.
-            finetuned_models (List[nn.Linear]): A list of fine-tuned linear models.
-            gate_k (int): The number of singular values to keep for the gate.
-            k (int): The number of singular values to keep for the experts.
-            top_k (int): The number of top experts to select.
-            full_matrices (bool): Whether to compute the full-sized U and V matrices.
-            upscaling_accelerator (str): The device to perform the computation on.
-            routing_use_diff (bool): Whether to use weight differences for routing.
-        """
-        super().__init__()
-        self.num_experts = len(finetuned_models)
-        self.top_k = top_k
-        self.k = k
-        self.gate_k = gate_k
-        self.in_features = pretrained_model.in_features
-        self.out_features = pretrained_model.out_features
-
-        w_diff_list = [m.weight - pretrained_model.weight for m in finetuned_models]
-        if _is_all_zeros(w_diff_list):
-            # All fine-tuned models are identical to the pretrained model
-            raise ExpertNotTrainedError()
-
-        if routing_use_diff or k > 0:
-            svd_cache_list = [
-                svd(w, full_matrices=full_matrices, accelerator=upscaling_accelerator)
-                for w in w_diff_list
-            ]  # the svd cache list to avoid recomputing
-
-        # construct the gate network
-        if routing_use_diff:
-            self.gate = SmileGate(
-                input_features=self.in_features,
-                w_diff_list=w_diff_list,
-                k=gate_k,
-                svd_list=svd_cache_list,
-                upscaling_accelerator=upscaling_accelerator,
-            )
-        else:
-            self.gate = SmileGate(
-                input_features=self.in_features,
-                w_diff_list=[m.weight for m in finetuned_models],
-                k=gate_k,
-                svd_list=None,
-                upscaling_accelerator=upscaling_accelerator,
-            )
-
-        # construct experts
-        for m, w_diff in zip(finetuned_models, w_diff_list):
-            m.weight.data = w_diff
-        if k > 0:
-            experts = [
-                SmileCompressedLinear(m, k, svd_cache=svd_cache)
-                for m, svd_cache in zip(finetuned_models, svd_cache_list)
-            ]
-        else:
-            # if k is not set (<0), we use the full fine-tuned model
-            experts = finetuned_models
-        self.experts = nn.ModuleList(experts)
-
-        if pretrained_model.bias is not None:
-            for m in experts:
-                m.bias.data = m.bias.data - pretrained_model.bias
-        # assign the pretrained model (the shared part)
-        self.pretrained_model = pretrained_model
-
-    def forward(self, hidden_states: Tensor):
-        """
-        Forward pass of the SmileMoELinear module.
-
-        Args:
-            hidden_states (Tensor): The input tensor.
-
-        Returns:
-            Tensor: The output tensor.
-        """
-        pretrained_out = self.pretrained_model(hidden_states)
-
-        input_shape = hidden_states.size()
-        hidden_states = hidden_states.view(-1, self.in_features)
-
-        router_logits = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1)
-        # sample the expert according to the routing weights
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.top_k, dim=-1
-        )
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
-        final_hidden_states = torch.zeros(
-            (hidden_states.size(0), self.out_features),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(
-            selected_experts, num_classes=self.num_experts
-        ).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, self.in_features)
-            if current_state.numel() == 0:
-                continue
-            current_hidden_states = (
-                expert_layer(current_state) * routing_weights[top_x, idx, None]
-            )
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(
-                0, top_x, current_hidden_states.to(hidden_states.dtype)
-            )
-        final_hidden_states = final_hidden_states.reshape(
-            *input_shape[:-1], self.out_features
-        )
-        final_hidden_states = pretrained_out + final_hidden_states
-        return final_hidden_states
-
-    @property
-    def weight(self):
-        """
-        Mimic linear layer. Bacause in some cases, user might indicate the device (or dtype of parameters) of the linear layer using `linear_layer.weight.device`
-        """
-        return self.pretrained_model.weight
-
-    @property
-    def bias(self):
-        return self.pretrained_model.bias
-
-    def __repr__(self):
-        return (
-            f"SingularMoELinear("
-            f"in_features={self.pretrained_model.in_features}, "
-            f"out_features={self.pretrained_model.out_features}, "
-            f"num_experts={self.num_experts}, "
-            f"top_k={self.top_k}, "
-            f"gate_k={self.gate_k}, "
-            f"k={self.k}"
-            f")"
-        )
-
-
 class SmileUpscalingAlgorithm(
     SimpleProfilerMixin,
     BaseAlgorithm,
@@ -442,16 +112,19 @@ class SmileUpscalingAlgorithm(
             print_parameters(model)
             return model
 
-        with self.profile("
-
-
-
-
-
-
-
-
+        with self.profile("loading model"):
+            # load models and move to GPU if available
+            with self.profile("load pretrained model"):
+                pretrained_model = modelpool.load_model("_pretrained_")
+            with self.profile("load fine-tuned model"):
+                finetuned_models = [
+                    m
+                    for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
+                ]
+
+        if self.config.device == "cuda" and torch.cuda.is_available():
+            pretrained_model = pretrained_model.cuda()
+            finetuned_models = [m.cuda() for m in finetuned_models]
 
         with self.profile("merge model"):
             model = self.merge(pretrained_model, finetuned_models)
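The two hunks above strip roughly 330 lines from smile_upscaling.py: the SMILE building blocks (ExpertNotTrainedError, SmileGate, SmileCompressedLinear, SmileMoELinear) now live in fusion_bench/models/smile_moe/linear_from_module.py (renamed from linear.py, per the file list), and model loading gains per-stage profiling. As a minimal sketch of how the relocated SmileMoELinear upscales a single layer, assuming the constructor keeps the signature visible in the removed code; the dimensions, ranks, and random perturbation here are illustrative:

import torch
from torch import nn

from fusion_bench.models.smile_moe.linear_from_module import SmileMoELinear

pretrained = nn.Linear(16, 16)

# Simulate fine-tuned experts by perturbing the pretrained weights; identical
# weights would raise ExpertNotTrainedError (all-zero task vectors).
finetuned = []
for _ in range(2):
    expert = nn.Linear(16, 16)
    expert.load_state_dict(pretrained.state_dict())
    expert.weight.data += 0.01 * torch.randn_like(expert.weight)
    finetuned.append(expert)

# gate_k: rank of the SVD-based router; k: rank of each compressed expert.
moe_linear = SmileMoELinear(pretrained, finetuned, gate_k=4, k=8, top_k=1)
out = moe_linear(torch.randn(5, 16))  # shape: (5, 16)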
fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py
CHANGED
@@ -85,7 +85,14 @@ class CLIPLayerWiseAdaMergingSurgeryAlgorithm(
 
         if self.config.weights is not None:
             # skip the test-time adaptation
+            merge_weight: torch.Tensor = torch.load(self.config.weights)
+            module.merge_weight.data = merge_weight.to(
+                device=module.merge_weight.device
+            )
             merged_model = copy.deepcopy(module.merge_and_unload())
+            # setup the zero-shot classification head
+            self.on_test_time_adaptation_start()
+
         else:
             with self.profile("test-time adaptation"):
                 module = self.test_time_adaptation(module)
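The added lines fix the weights shortcut: previously this branch deep-copied the merged model without ever loading the saved merging weights or building the classification head. A sketch of the round trip this enables; ToyMergingModule and the file name are hypothetical stand-ins for the real layer-wise AdaMerging module:

import torch
from torch import nn

class ToyMergingModule(nn.Module):
    # Stand-in for the merging wrapper; only `merge_weight` matters here.
    def __init__(self, num_models: int = 2, num_layers: int = 4):
        super().__init__()
        self.merge_weight = nn.Parameter(torch.full((num_models, num_layers), 0.3))

module = ToyMergingModule()

# After a test-time adaptation run: persist the learned merging weights.
torch.save(module.merge_weight.data, "merge_weight.pt")

# On a later run, point `config.weights` at that file; the added lines then
# restore the weights and skip test-time adaptation entirely.
merge_weight: torch.Tensor = torch.load("merge_weight.pt")
module.merge_weight.data = merge_weight.to(device=module.merge_weight.device)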
fusion_bench/method/task_arithmetic/task_arithmetic.py
CHANGED
@@ -6,7 +6,7 @@ http://arxiv.org/abs/2212.04089
 
 import logging
 from copy import deepcopy
-from typing import Dict, List, Mapping, TypeVar, Union  # noqa: F401
+from typing import Dict, List, Mapping, Optional, TypeVar, Union  # noqa: F401
 
 import torch
 from torch import nn
@@ -19,18 +19,18 @@ from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_mul,
     state_dict_sub,
 )
-from fusion_bench.utils.type import StateDictType
+from fusion_bench.utils.type import StateDictType, TorchModelType
 
 log = logging.getLogger(__name__)
 
 
 @torch.no_grad()
 def task_arithmetic_merge(
-    pretrained_model:
-    finetuned_models: List[
+    pretrained_model: TorchModelType,
+    finetuned_models: List[TorchModelType],
     scaling_factor: float,
     inplace: bool = True,
-) ->
+) -> TorchModelType:
     """
     Merges the task vectors from multiple fine-tuned models into a single pre-trained model.
 
@@ -46,15 +46,17 @@ def task_arithmetic_merge(
     """
     if not inplace:
         pretrained_model = deepcopy(pretrained_model)
-    task_vector: StateDictType = None
+    task_vector: Optional[StateDictType] = None
     # Calculate the total task vector
    for model in finetuned_models:
         if task_vector is None:
+            # calculate the task vector for the first model
             task_vector = state_dict_sub(
                 model.state_dict(keep_vars=True),
                 pretrained_model.state_dict(keep_vars=True),
             )
         else:
+            # calculate the task vector for the remaining models
             task_vector = state_dict_add(
                 task_vector,
                 state_dict_sub(
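Besides tightening the annotations with TorchModelType and Optional, the new comments spell out the accumulation loop: task_arithmetic_merge computes theta_merged = theta_pre + lambda * sum_i (theta_i - theta_pre). A self-contained sketch with toy models, using the same state-dict helpers this module imports (lambda = 0.3 and the shapes are arbitrary, and state_dict_mul is assumed to scale every entry by the scalar):

from torch import nn

from fusion_bench.utils.state_dict_arithmetic import (
    state_dict_add,
    state_dict_mul,
    state_dict_sub,
)

pretrained = nn.Linear(4, 4)
finetuned = [nn.Linear(4, 4) for _ in range(2)]

# Accumulate the total task vector: sum_i (theta_i - theta_pre).
task_vector = None
for model in finetuned:
    diff = state_dict_sub(model.state_dict(), pretrained.state_dict())
    task_vector = diff if task_vector is None else state_dict_add(task_vector, diff)

# theta_merged = theta_pre + 0.3 * task_vector
merged = state_dict_add(pretrained.state_dict(), state_dict_mul(task_vector, 0.3))
pretrained.load_state_dict(merged)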
fusion_bench/method/ties_merging/ties_merging.py
CHANGED
@@ -16,6 +16,7 @@ from torch import Tensor, nn
 
 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.type import StateDictType
 
@@ -24,7 +25,7 @@ from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_st
 log = logging.getLogger(__name__)
 
 
-class TiesMergingAlgorithm(BaseAlgorithm):
+class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     """
     TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
 
@@ -84,34 +85,38 @@ class TiesMergingAlgorithm(BaseAlgorithm):
         scaling_factor = self.scaling_factor
         threshold = self.threshold
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        merged_check
-
-
-
+        with self.profile("loading models"):
+            # Load the pretrained model
+            pretrained_model = modelpool.load_model("_pretrained_")
+
+            # Load the state dicts of the models
+            ft_checks: List[StateDictType] = [
+                modelpool.load_model(model_name).state_dict(keep_vars=True)
+                for model_name in modelpool.model_names
+            ]
+            ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)
+
+        with self.profile("merging models"):
+            # Compute the task vectors
+            flat_ft: Tensor = torch.vstack(
+                [state_dict_to_vector(check, remove_keys) for check in ft_checks]
+            )
+            flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
+            tv_flat_checks = flat_ft - flat_ptm
+
+            # Perform TIES Merging
+            merged_tv = ties_merging(
+                tv_flat_checks,
+                reset_thresh=threshold,
+                merge_func=merge_func,
+            )
+            merged_check = flat_ptm + scaling_factor * merged_tv
+            merged_state_dict = vector_to_state_dict(
+                merged_check, ptm_check, remove_keys=remove_keys
+            )
+
+        # Load the merged state dict into the pretrained model
+        pretrained_model.load_state_dict(merged_state_dict)
+
+        self.print_profile_summary()
         return pretrained_model
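TiesMergingAlgorithm now inherits SimpleProfilerMixin, and the restored run body wraps loading and merging in named profiling scopes. The mixin pattern, sketched here with a stub algorithm (the scope names are arbitrary):

from fusion_bench.mixins import SimpleProfilerMixin

class MyMergeAlgorithm(SimpleProfilerMixin):
    def run(self):
        # `profile` is a context manager that accumulates wall-clock time
        # under the given key across repeated entries.
        with self.profile("loading models"):
            pass  # load checkpoints here
        with self.profile("merging models"):
            pass  # do the actual merging here
        # Report the accumulated timings once at the end.
        self.print_profile_summary()

MyMergeAlgorithm().run()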
fusion_bench/method/we_moe/we_moe.py
CHANGED
@@ -5,7 +5,6 @@ from typing import cast  # noqa: F401
 import lightning as L
 import lightning.fabric.wrappers
 import torch
-from lightning.pytorch.profilers import SimpleProfiler
 from omegaconf import DictConfig
 from torch import Tensor
 from torch.utils.data import DataLoader
@@ -13,6 +12,7 @@ from tqdm.autonotebook import tqdm
 
 from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm
 from fusion_bench.compat.modelpool import ModelPool
+from fusion_bench.mixins import SimpleProfilerMixin
 from fusion_bench.models.we_moe import WeightEnsemblingMoE
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.parameters import print_parameters
@@ -34,7 +34,10 @@ def entropy_loss(logits: Tensor) -> Tensor:
     return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
 
 
-class WeightEnsemblingMoEAlgorithm(
+class WeightEnsemblingMoEAlgorithm(
+    ModelFusionAlgorithm,
+    SimpleProfilerMixin,
+):
     """
     Algorithm for fusing models using Weight Ensembling Mixture of Experts (MoE).
 
@@ -44,7 +47,6 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
     Attributes:
         _fabric (L.Fabric): The fabric for distributed training.
         modelpool (ModelPool): The pool of models to be fused.
-        profiler (SimpleProfiler): The profiler for measuring performance.
     """
 
     _fabric: L.Fabric = None
@@ -66,9 +68,6 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
             self._fabric.launch()
         else:
             assert "No CUDA device available."
-        self.profiler = SimpleProfiler(
-            self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
-        )
 
     @abstractmethod
     def load_checkpoint(self, model, checkpoint):
@@ -177,9 +176,9 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
         for step_idx in pbar:
             if self.config.use_grad_accumulate:
                 for task in self.modelpool.model_names:
-                    with self.
+                    with self.profile("data time"):
                         batch = next(self.get_shuffled_test_loader_iter(task))
-                    with self.
+                    with self.profile("forward pass"):
                         logits = self.compute_logits(module, batch, task)
                     assert (
                         logits.dim() == 2
@@ -187,23 +186,23 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
                     loss = entropy_loss(logits)
                     # .backward() accumulates when .zero_grad() wasn't called
                     # this can save memory
-                    with self.
+                    with self.profile("backward pass"):
                         self._fabric.backward(loss, retain_graph=True)
             else:
                 loss = 0
                 for task in self.modelpool.model_names:
-                    with self.
+                    with self.profile("data time"):
                         batch = next(self.get_shuffled_test_loader_iter(task))
-                    with self.
+                    with self.profile("forward pass"):
                         logits = self.compute_logits(module, batch, task)
                     assert (
                         logits.dim() == 2
                     ), f"Expected logits to be 2D, got {logits.dim()}"
                     loss = loss + entropy_loss(logits)
-                with self.
+                with self.profile("backward pass"):
                     self._fabric.backward(loss, retain_graph=True)
 
-            with self.
+            with self.profile("optimizer step"):
                 optimizer.step()
                 optimizer.zero_grad()
 
@@ -232,7 +231,7 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
             )
             self.load_checkpoint(moe_model, self.config.checkpoint)
         else:
-            with self.
+            with self.profile("test-time adaptation"):
                 moe_model = self.test_time_adaptation(moe_model)
             if self.config.get("save_checkpoint", False):
                 log.info(f"save checkpoint to {self.config.save_checkpoint}")
@@ -243,5 +242,5 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
 
         # enable sample-wise adaptation
         moe_model.batch_reduce = False
-
+        self.print_profile_summary()
         return moe_model
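Across we_moe.py, the hand-managed Lightning SimpleProfiler is replaced by the same SimpleProfilerMixin, with each phase of the test-time adaptation loop labeled (data time, forward pass, backward pass, optimizer step). The objective that loop minimizes is the entropy_loss shown in the hunk context; a standalone sketch, with the softmax line inferred from the return statement:

import torch
from torch import Tensor

def entropy_loss(logits: Tensor) -> Tensor:
    # Lower entropy means more confident predictions on unlabeled test batches.
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()

logits = torch.randn(8, 10, requires_grad=True)  # (batch, num_classes)
loss = entropy_loss(logits)
loss.backward()  # in the algorithm, gradients flow to the routing weights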
fusion_bench/mixins/__init__.py
CHANGED
@@ -6,20 +6,23 @@ from typing_extensions import TYPE_CHECKING
 from fusion_bench.utils.lazy_imports import LazyImporter
 
 _import_structure = {
+    "clip_classification": ["CLIPClassificationMixin"],
+    "fabric_training": ["FabricTrainingMixin"],
+    "hydra_config": ["HydraConfigMixin"],
     "lightning_fabric": ["LightningFabricMixin"],
+    "openclip_classification": ["OpenCLIPClassificationMixin"],
     "serialization": ["YAMLSerializationMixin", "BaseYAMLSerializableModel"],
     "simple_profiler": ["SimpleProfilerMixin"],
-    "clip_classification": ["CLIPClassificationMixin"],
-    "fabric_training": ["FabricTrainingMixin"],
 }
 
 if TYPE_CHECKING:
     from .clip_classification import CLIPClassificationMixin
     from .fabric_training import FabricTrainingMixin
+    from .hydra_config import HydraConfigMixin
     from .lightning_fabric import LightningFabricMixin
+    from .openclip_classification import OpenCLIPClassificationMixin
     from .serialization import BaseYAMLSerializableModel, YAMLSerializationMixin
     from .simple_profiler import SimpleProfilerMixin
-
 else:
     sys.modules[__name__] = LazyImporter(
         __name__,