fusion-bench 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/ada_svd/clip_vision.py +4 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +16 -7
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +19 -346
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
- fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
- fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
- fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/rankone_moe.py +2 -88
- fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
- fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
- fusion_bench/models/smile_moe/utils/__init__.py +24 -0
- fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +7 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +209 -157
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0

--- /dev/null
+++ b/fusion_bench/models/smile_moe/utils/__init__.py
@@ -0,0 +1,24 @@
+from typing import List
+
+import torch
+from torch import Tensor
+
+from .svd_utils import svd
+
+__all__ = ["svd_utils", "_is_all_zeros"]
+
+
+def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
+    """
+    Check if a tensor or a list of tensors are all zeros.
+
+    Args:
+        tensor (Tensor | List[Tensor]): A tensor or a list of tensors.
+
+    Returns:
+        bool: True if all elements are zeros, False otherwise.
+    """
+    if isinstance(tensor, Tensor):
+        return torch.allclose(tensor, torch.zeros_like(tensor))
+    else:
+        return all(_is_all_zeros(t) for t in tensor)
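
For orientation, a minimal usage sketch of the new helper (assuming fusion-bench 0.2.14 is installed); it accepts either a single tensor or a list of tensors:

import torch

from fusion_bench.models.smile_moe.utils import _is_all_zeros

# A single all-zero tensor and a mixed list of tensors.
print(_is_all_zeros(torch.zeros(3, 3)))                # True
print(_is_all_zeros([torch.zeros(2), torch.ones(2)]))  # False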

--- /dev/null
+++ b/fusion_bench/models/smile_moe/utils/svd_utils.py
@@ -0,0 +1,46 @@
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+
+
+def _svd(w: Tensor, full_matrices: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Perform Singular Value Decomposition (SVD) on a tensor.
+
+    Args:
+        w (Tensor): The input tensor.
+        full_matrices (bool): Whether to compute the full-sized U and V matrices.
+
+    Returns:
+        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
+    """
+    u, s, vh = torch.linalg.svd(
+        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
+    )
+    v = vh.T
+    return u, s, v
+
+
+def svd(
+    w: Tensor,
+    full_matrices: bool = True,
+    accelerator: Optional[Union[torch.device, str]] = None,
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Perform SVD on a tensor, optionally using a specified accelerator.
+
+    Args:
+        w (Tensor): The input tensor.
+        full_matrices (bool): Whether to compute the full-sized U and V matrices.
+        accelerator (Optional[Union[torch.device, str]]): The device to perform the computation on.
+
+    Returns:
+        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
+    """
+    if accelerator is None:
+        return _svd(w, full_matrices=full_matrices)
+    original_device = w.device
+    w = w.to(accelerator)
+    u, s, v = _svd(w)
+    return u.to(original_device), s.to(original_device), v.to(original_device)
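
A minimal sketch of how the new svd wrapper might be called (assuming the package is installed); when an accelerator is given, the decomposition runs on that device and the factors are moved back to the input tensor's original device:

import torch

from fusion_bench.models.smile_moe.utils import svd

w = torch.randn(512, 256)
# Offload the decomposition to a GPU when one is available.
accelerator = "cuda" if torch.cuda.is_available() else None
u, s, v = svd(w, accelerator=accelerator)
print(u.shape, s.shape, v.shape, u.device)  # factors are returned on w's original device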

--- a/fusion_bench/taskpool/__init__.py
+++ b/fusion_bench/taskpool/__init__.py
@@ -10,12 +10,14 @@ _import_structure = {
     "clip_vision": [
         "CLIPVisionModelTaskPool",
         "SparseWEMoECLIPVisionModelTaskPool",
-        "
+        "RankoneMoECLIPVisionModelTaskPool",
     ],
     "dummy": ["DummyTaskPool"],
     "gpt2_text_classification": ["GPT2TextClassificationTaskPool"],
-    "nyuv2_taskpool": ["NYUv2TaskPool"],
     "llama": ["LlamaTestGenerationTaskPool"],
+    "lm_eval_harness": ["LMEvalHarnessTaskPool"],
+    "nyuv2_taskpool": ["NYUv2TaskPool"],
+    "openclip_vision": ["OpenCLIPVisionModelTaskPool"],
 }
 
 
@@ -23,13 +25,15 @@ if TYPE_CHECKING:
     from .base_pool import BaseTaskPool
     from .clip_vision import (
         CLIPVisionModelTaskPool,
-
+        RankoneMoECLIPVisionModelTaskPool,
         SparseWEMoECLIPVisionModelTaskPool,
     )
     from .dummy import DummyTaskPool
     from .gpt2_text_classification import GPT2TextClassificationTaskPool
     from .llama import LlamaTestGenerationTaskPool
+    from .lm_eval_harness import LMEvalHarnessTaskPool
     from .nyuv2_taskpool import NYUv2TaskPool
+    from .openclip_vision import OpenCLIPVisionModelTaskPool
 
 else:
     sys.modules[__name__] = LazyImporter(
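
The practical effect of these registry changes is that the new task pools resolve lazily from the package root, e.g.:

# These names are resolved through the _import_structure entries added above.
from fusion_bench.taskpool import (
    LMEvalHarnessTaskPool,
    OpenCLIPVisionModelTaskPool,
    RankoneMoECLIPVisionModelTaskPool,
)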

--- a/fusion_bench/taskpool/clip_vision/__init__.py
+++ b/fusion_bench/taskpool/clip_vision/__init__.py
@@ -1,4 +1,5 @@
 # flake8: noqa F401
 from .clip_rankone_moe_taskpool import RankoneMoECLIPVisionModelTaskPool
+from .clip_smile_taskpool import SmileCLIPVisionModelTaskPool
 from .clip_sparse_wemoe_taskpool import SparseWEMoECLIPVisionModelTaskPool
 from .taskpool import CLIPVisionModelTaskPool

--- a/fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py
+++ b/fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py
@@ -12,36 +12,7 @@ from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.models.rankone_moe import RankOneMoE
 
 from .taskpool import CLIPVisionModelTaskPool
-
-
-class LayerWiseRoutingWeightSaver:
-    def __init__(self, save_path: Path, max_num: Optional[int] = None):
-        self.save_path = save_path
-        self.max_num = max_num
-        self.routing_weights = []
-
-    def __call__(self, module, input: Tuple[Tensor], output: Tensor):
-        assert isinstance(output, Tensor), "Output is expected to be a Tensor"
-        # (batch_size, num_tokens, num_experts)
-        routing_weights = output.detach().cpu()
-        if self.max_num is not None and self.max_num > 0:
-            if len(self.routing_weights) > self.max_num:
-                return
-            elif routing_weights.size(0) + len(self.routing_weights) > self.max_num:
-                self.routing_weights.append(
-                    routing_weights[: self.max_num - len(self.routing_weights)]
-                )
-            else:
-                self.routing_weights.append(routing_weights)
-        else:
-            self.routing_weights.append(routing_weights)
-
-    def save_routing_weights(self):
-        routing_weights = torch.cat(self.routing_weights, dim=0)
-        if self.save_path is not None:
-            self.save_path.parent.mkdir(parents=True, exist_ok=True)
-            print(f"Saving routing weights to {self.save_path}")
-            torch.save(routing_weights, self.save_path)
+from .utils.routing_analysis_utils import LayerWiseRoutingWeightSaver
 
 
 class RankoneMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
@@ -109,4 +80,5 @@ class RankoneMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
         # remove hooks for saving layer-wise routing weights
         for i, handle in self._layer_wise_routing_weights_save_hook_handles.items():
             self._layer_wise_routing_weights_save_hooks[i].save_routing_weights()
+            self._layer_wise_routing_weights_save_hook_handles.pop(i)
             handle.remove()

--- /dev/null
+++ b/fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py
@@ -0,0 +1,102 @@
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.utils.hooks import RemovableHandle
+from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
+from transformers.models.clip.modeling_clip import CLIPVisionTransformer
+
+from fusion_bench.method.smile_upscaling import SmileMoELinear
+from fusion_bench.models.hf_clip import HFCLIPClassifier
+
+from .taskpool import CLIPVisionModelTaskPool
+from .utils.routing_analysis_utils import LayerWiseRoutingWeightSaver
+
+
+class SmileCLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
+
+    # hooks and handles for saving layer-wise routing weights
+    _layer_wise_routing_weights_save_hooks: Dict[Any, LayerWiseRoutingWeightSaver] = {}
+    _layer_wise_routing_weights_save_hook_handles: Dict[Any, RemovableHandle] = {}
+
+    def __init__(
+        self,
+        linear_module_names: Union[List[str], str],
+        layer_wise_routing_weights_save_path: Optional[str],
+        layer_wise_routing_weights_max_num: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Initialize the SMILECLIPVisionModelTaskPool.
+
+        Args:
+            linear_module_names (Union[List[str], str]): The names of the linear modules to save the layer-wise routing weights for.
+            layer_wise_routing_weights_save_path (Optional[str]): The path to save the layer-wise routing weights.
+            layer_wise_routing_weights_max_num (Optional[int]): The maximum number of layer-wise routing weights to save.
+        """
+        # linear module names
+        assert linear_module_names is not None, "linear_module_names must be provided"
+        self.linear_module_names = (
+            [linear_module_names]
+            if isinstance(linear_module_names, str)
+            else list(linear_module_names)
+        )
+        # save path for layer-wise routing weights
+        self._layer_wise_routing_weights_save_path = (
+            layer_wise_routing_weights_save_path
+        )
+        self.layer_wise_routing_weights_save_path = (
+            Path(layer_wise_routing_weights_save_path)
+            if layer_wise_routing_weights_save_path is not None
+            else None
+        )
+        self.layer_wise_routing_weights_max_num = layer_wise_routing_weights_max_num
+        super().__init__(**kwargs)
+
+    def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
+        super().on_task_evaluation_begin(classifier, task_name)
+        if self.layer_wise_routing_weights_save_path is not None:
+            # setup hooks for saving layer-wise routing weights
+            assert isinstance(
+                classifier.clip_model.vision_model,
+                (CLIPVisionTransformer, CLIPVisionModel),
+            ), "Vision model is expected to be a CLIPVisionTransformer"
+            vision_model = classifier.clip_model.vision_model
+            if isinstance(vision_model, CLIPVisionModel):
+                vision_model = vision_model.vision_model
+            # assign forward hooks for each layer
+
+            for i, layer in enumerate(vision_model.encoder.layers):
+                for linear_module_name in self.linear_module_names:
+                    linear_module = layer.get_submodule(linear_module_name)
+                    assert isinstance(
+                        linear_module,
+                        (SmileMoELinear),
+                    ), f"Linear module is expected to be a SmileMoELinear, but got {type(linear_module)}"
+                    # layer-wise routing weights
+                    hook = LayerWiseRoutingWeightSaver(
+                        self.layer_wise_routing_weights_save_path
+                        / task_name
+                        / f"layer_{i}_{linear_module_name}.pt",
+                        max_num=self.layer_wise_routing_weights_max_num,
+                    )
+                    self._layer_wise_routing_weights_save_hooks[
+                        (i, linear_module_name)
+                    ] = hook
+                    self._layer_wise_routing_weights_save_hook_handles[
+                        (i, linear_module_name)
+                    ] = linear_module.gate.register_forward_hook(hook)
+
+    def on_task_evaluation_end(self):
+        super().on_task_evaluation_end()
+        if self.layer_wise_routing_weights_save_path is not None:
+            # remove hooks for saving layer-wise routing weights
+            for (
+                key,
+                handle,
+            ) in self._layer_wise_routing_weights_save_hook_handles.items():
+                self._layer_wise_routing_weights_save_hooks[key].save_routing_weights()
+                self._layer_wise_routing_weights_save_hook_handles.pop(key)
+                handle.remove()

--- a/fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py
+++ b/fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py
@@ -15,36 +15,7 @@ from fusion_bench.models.sparse_we_moe import (
 )
 
 from .taskpool import CLIPVisionModelTaskPool
-
-
-class LayerWiseRoutingWeightSaver:
-    def __init__(self, save_path: Path, max_num: Optional[int] = None):
-        self.save_path = save_path
-        self.max_num = max_num
-        self.routing_weights = []
-
-    def __call__(self, module, input: Tuple[Tensor], output: Tensor):
-        assert isinstance(output, Tensor), "Output is expected to be a Tensor"
-        # (batch_size, num_tokens, num_experts)
-        routing_weights = output.detach().cpu()
-        if self.max_num is not None and self.max_num > 0:
-            if len(self.routing_weights) > self.max_num:
-                return
-            elif routing_weights.size(0) + len(self.routing_weights) > self.max_num:
-                self.routing_weights.append(
-                    routing_weights[: self.max_num - len(self.routing_weights)]
-                )
-            else:
-                self.routing_weights.append(routing_weights)
-        else:
-            self.routing_weights.append(routing_weights)
-
-    def save_routing_weights(self):
-        routing_weights = torch.cat(self.routing_weights, dim=0)
-        if self.save_path is not None:
-            self.save_path.parent.mkdir(parents=True, exist_ok=True)
-            print(f"Saving routing weights to {self.save_path}")
-            torch.save(routing_weights, self.save_path)
+from .utils.routing_analysis_utils import LayerWiseRoutingWeightSaver
 
 
 class SparseWEMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
@@ -117,4 +88,5 @@ class SparseWEMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
         # remove hooks for saving layer-wise routing weights
         for i, handle in self._layer_wise_routing_weights_save_hook_handles.items():
             self._layer_wise_routing_weights_save_hooks[i].save_routing_weights()
+            self._layer_wise_routing_weights_save_hook_handles.pop(i)
             handle.remove()

--- a/fusion_bench/taskpool/clip_vision/taskpool.py
+++ b/fusion_bench/taskpool/clip_vision/taskpool.py
@@ -32,8 +32,7 @@ from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.taskpool import BaseTaskPool
 from fusion_bench.tasks.clip_classification import get_classnames_and_templates
-from fusion_bench.utils import instantiate
-from fusion_bench.utils.parameters import count_parameters
+from fusion_bench.utils import count_parameters, instantiate
 
 if TYPE_CHECKING:
     from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper

fusion_bench/taskpool/clip_vision/utils/__init__.py
File without changes

--- /dev/null
+++ b/fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py
@@ -0,0 +1,65 @@
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+import torch
+from torch import Tensor
+
+
+def _number_of_samples(routing_weights: List[Tensor]):
+    count = 0
+    for routing_weight in routing_weights:
+        count += routing_weight.size(0)
+    return count
+
+
+class LayerWiseRoutingWeightSaver:
+    """
+    A hook for saving layer-wise routing weights.
+    """
+
+    save_path: Path
+    "The path to save the layer-wise routing weights."
+    max_num: Optional[int]
+    "The maximum number of layer-wise routing weights to save. If None, all routing weights will be saved."
+    routing_weights: List[Tensor]
+    "The list of layer-wise routing weights."
+
+    def __init__(self, save_path: Path, max_num: Optional[int] = None):
+        """
+        Args:
+            save_path (Path): The path to save the layer-wise routing weights.
+            max_num (Optional[int]): The maximum number of layer-wise routing weights to save. If None, all routing weights will be saved.
+        """
+        self.save_path = save_path
+        self.max_num = max_num
+        self.routing_weights = []
+
+    def __call__(self, module, input: Tuple[Tensor], output: Tensor):
+        assert isinstance(output, Tensor), "Output is expected to be a Tensor"
+        # (batch_size, num_tokens, num_experts)
+        routing_weights = output.detach().cpu()
+        if self.max_num is not None and self.max_num > 0:
+            if _number_of_samples(self.routing_weights) > self.max_num:
+                return
+            elif (
+                routing_weights.size(0) + _number_of_samples(self.routing_weights)
+                > self.max_num
+            ):
+                self.routing_weights.append(
+                    routing_weights[
+                        : self.max_num - _number_of_samples(self.routing_weights)
+                    ]
+                )
+            else:
+                self.routing_weights.append(routing_weights)
+        else:
+            self.routing_weights.append(routing_weights)
+
+    def save_routing_weights(self):
+        routing_weights = torch.cat(self.routing_weights, dim=0)
+        if self.save_path is not None:
+            self.save_path.parent.mkdir(parents=True, exist_ok=True)
+            print(
+                f"Saving routing weights to {self.save_path}. Size: {routing_weights.size()}"
+            )
+            torch.save(routing_weights, self.save_path)
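
A rough, self-contained sketch of how the relocated hook is used. In the task pools above it is attached to the gate of a SmileMoELinear or sparse WE-MoE layer; here a plain nn.Linear stands in for the gate, and the output path is an arbitrary choice:

from pathlib import Path

import torch
from torch import nn

from fusion_bench.taskpool.clip_vision.utils.routing_analysis_utils import (
    LayerWiseRoutingWeightSaver,
)

gate = nn.Linear(16, 4)  # stand-in for a routing gate
saver = LayerWiseRoutingWeightSaver(Path("outputs/routing/layer_0.pt"), max_num=100)
handle = gate.register_forward_hook(saver)

with torch.no_grad():
    gate(torch.randn(8, 16))  # the hook records a detached CPU copy of the gate output

handle.remove()
saver.save_routing_weights()  # concatenates the recorded weights and writes the .pt file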

--- a/fusion_bench/taskpool/gpt2_text_classification.py
+++ b/fusion_bench/taskpool/gpt2_text_classification.py
@@ -139,11 +139,40 @@ class GPT2TextClassificationTaskPool(BaseTaskPool, LightningFabricMixin):
         return dataloader
 
     @override
-    def evaluate(self, model: GPT2Model):
+    def evaluate(self, model: GPT2Model, name: str = None):
+        """Evaluate the model on the test datasets.
+
+        Args:
+            model (GPT2Model): The model to evaluate.
+            name (str, optional): The name of the model. Defaults to None. This is used to identify the model in the report.
+
+        Returns:
+            dict: A dictionary containing the evaluation results for each task.
+        """
         report = {}
+        if name is not None:
+            report["name"] = name
         for task_name in (pbar := tqdm(self._test_datasets, desc="Evaluating tasks")):
             pbar.set_description(f"Evaluating task {task_name}")
             dataloader = self.get_test_dataloader(task_name)
             result = self.evaluate_single_task(task_name, model, dataloader)
             report[task_name] = result
+
+        # calculate the average accuracy and loss
+        if "average" not in report:
+            report["average"] = {}
+        accuracies = [
+            value["accuracy"]
+            for key, value in report.items()
+            if isinstance(value, dict) and "accuracy" in value
+        ]
+        if len(accuracies) > 0:
+            average_accuracy = sum(accuracies) / len(accuracies)
+            report["average"]["accuracy"] = average_accuracy
+        losses = [value["loss"] for key, value in report.items() if "loss" in value]
+        if len(losses) > 0:
+            average_loss = sum(losses) / len(losses)
+            report["average"]["loss"] = average_loss
+
+        log.info(f"Evaluation Result: {report}")
         return report
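
To illustrate the new averaging step and the shape of the returned report, a toy reproduction of the same logic (task names and numbers are placeholders, not measured results):

report = {
    "name": "merged-gpt2",  # only present when `name` is passed to evaluate()
    "cola": {"accuracy": 0.70, "loss": 0.60},
    "sst2": {"accuracy": 0.90, "loss": 0.40},
}
accuracies = [v["accuracy"] for v in report.values() if isinstance(v, dict) and "accuracy" in v]
losses = [v["loss"] for v in report.values() if isinstance(v, dict) and "loss" in v]
report["average"] = {
    "accuracy": sum(accuracies) / len(accuracies),  # 0.80
    "loss": sum(losses) / len(losses),              # 0.50
}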

--- /dev/null
+++ b/fusion_bench/taskpool/lm_eval_harness/taskpool.py
@@ -0,0 +1,87 @@
+import logging
+import os
+from typing import List, Literal, Optional, Union, TYPE_CHECKING
+
+import lightning.fabric
+import lm_eval
+import lm_eval.models
+from lm_eval.__main__ import check_argument_types, cli_evaluate, setup_parser
+from omegaconf import DictConfig, ListConfig
+
+from fusion_bench import BaseTaskPool
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.utils.strenum import _version
+
+
+log = logging.getLogger(__name__)
+
+
+class LMEvalHarnessTaskPool(BaseTaskPool, LightningFabricMixin):
+    def __init__(
+        self,
+        tasks: Union[str, List[str]],
+        apply_chat_template: bool = False,
+        include_path: Optional[str] = None,
+        batch_size: int = 1,
+        metadata: Optional[DictConfig] = None,
+        verbosity: Optional[
+            Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
+        ] = None,
+        output_path: Optional[str] = None,
+        log_samples: bool = False,
+        _usage_: Optional[str] = None,
+        _version_: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(_usage_=_usage_, _version_=_version_)
+        self.tasks = tasks
+        self.include_path = include_path
+        self.batch_size = batch_size
+        self.metadata = metadata
+        self.apply_chat_template = apply_chat_template
+        self.verbosity = verbosity
+        self.kwargs = kwargs
+        self.output_path = output_path
+        self.log_samples = log_samples
+
+    def evaluate(self, model, *command_line_args, **kwargs):
+        command_line_args = []
+        if self.include_path is not None:
+            command_line_args.extend(["--include_path", self.include_path])
+        if isinstance(self.tasks, (list, ListConfig)):
+            command_line_args.extend(["--tasks", ",".join(self.tasks)])
+        elif isinstance(self.tasks, str):
+            command_line_args.extend(["--tasks", self.tasks])
+        if self.apply_chat_template:
+            command_line_args.extend(
+                ["--apply_chat_template", str(self.apply_chat_template)]
+            )
+        if self.batch_size is not None:
+            command_line_args.extend(["--batch_size", str(self.batch_size)])
+        if self.verbosity is not None:
+            command_line_args.extend(["--verbosity", str(self.verbosity)])
+        if self.metadata is not None:
+            command_line_args.extend(["--metadata", str(self.metadata)])
+        if self.output_path is None:
+            command_line_args.extend(
+                [
+                    "--output_path",
+                    os.path.join(self.log_dir, "lm_eval_results"),
+                ]
+            )
+        else:
+            command_line_args.extend(["--output_path", self.output_path])
+        if self.log_samples:
+            command_line_args.extend(["--log_samples"])
+        for key, value in kwargs.items():
+            command_line_args.extend([f"--{key}", str(value)])
+
+        parser = setup_parser()
+        check_argument_types(parser)
+        args = parser.parse_args(args=command_line_args)
+        log.info("LM-Eval Harness arguments:\n%s", args)
+
+        if not lightning.fabric.is_wrapped(model):
+            model = self.fabric.setup(model)
+        args.model = lm_eval.models.huggingface.HFLM(pretrained=model)
+        cli_evaluate(args)
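
A rough usage sketch for the new task pool. The task name and model are placeholders, and this assumes the LightningFabricMixin can set up a default Fabric instance when the pool is used outside the fusion_bench CLI; normally it is driven by a Hydra config such as taskpool/LMEvalHarnessTaskPool/lm_eval.yaml listed above:

from transformers import AutoModelForCausalLM

from fusion_bench.taskpool import LMEvalHarnessTaskPool

taskpool = LMEvalHarnessTaskPool(
    tasks=["hellaswag"],                    # any lm-eval-harness task names
    batch_size=8,
    output_path="outputs/lm_eval_results",  # avoids relying on the mixin's log_dir
)
model = AutoModelForCausalLM.from_pretrained("gpt2")
taskpool.evaluate(model)  # wraps the model in lm_eval's HFLM and calls cli_evaluate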

--- /dev/null
+++ b/fusion_bench/taskpool/openclip_vision/__init__.py
@@ -0,0 +1 @@
+from .openclip_taskpool import OpenCLIPVisionModelTaskPool