fusion-bench 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/ada_svd/clip_vision.py +4 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +16 -7
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +19 -346
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
- fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
- fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
- fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/rankone_moe.py +2 -88
- fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
- fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
- fusion_bench/models/smile_moe/utils/__init__.py +24 -0
- fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +7 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +209 -157
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0
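The largest additions in this release are the Gossip merging methods, SMILE upscaling for Qwen2 models, an OpenCLIP vision model pool and task pool, an LM-Eval-Harness task pool, and PWE-MoE support for OpenCLIP backbones. As orientation for the PWE-MoE addition (whose full diff follows), here is a hypothetical driver sketch: the class and argument names come from the diff below, but the concrete values and the model pool construction are illustrative assumptions, and in normal use the fusion_bench program attaches the Lightning Fabric and log directory before `run` is called.

from omegaconf import OmegaConf

from fusion_bench.method.pwe_moe.openclip_pwe_moe import (
    PWEMoELinearScalarizationForOpenCLIP,
)

# Argument names mirror PWEMoEAlgorithmForOpenCLIP.__init__ in the diff below;
# the values here are illustrative, not defaults shipped with the package.
algorithm = PWEMoELinearScalarizationForOpenCLIP(
    partial=True,  # upscale only the MLPs; merge the rest via task arithmetic
    init_lambda=0.3,
    router_hidden_layers=2,
    checkpoint_path=None,  # or a path to resume from a saved checkpoint
    run_train=True,
    num_steps=2000,
    save_interval=500,
    lr=1e-3,
    alpha=0.2,  # Dirichlet concentration for preference-ray sampling
    dataloader_kwargs=OmegaConf.create({"batch_size": 16, "num_workers": 4}),
    run_eval=True,
    num_evaluation_samples=11,  # or "equal_weight" for a single uniform ray
    quick_evaluation=True,
)
# `modelpool` would be an OpenCLIPVisionModelPool, e.g. instantiated from
# fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml:
# merged_model = algorithm.run(modelpool)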
fusion_bench/method/pwe_moe/openclip_pwe_moe.py
@@ -0,0 +1,476 @@
+import logging
+from abc import abstractmethod
+from collections import defaultdict
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
+
+import lightning.fabric
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+from omegaconf import DictConfig, OmegaConf
+from open_clip.model import ResidualAttentionBlock
+from torch import Tensor, nn
+from tqdm.auto import tqdm
+
+from fusion_bench import BaseAlgorithm
+from fusion_bench.dataset.clip_dataset import CLIPDataset
+from fusion_bench.method.task_arithmetic import task_arithmetic_merge
+from fusion_bench.mixins import OpenCLIPClassificationMixin, SimpleProfilerMixin
+from fusion_bench.modelpool import OpenCLIPVisionModelPool
+from fusion_bench.models.open_clip import ClassificationHead, ImageEncoder
+from fusion_bench.utils import print_parameters, timeit_context
+from fusion_bench.utils.data import InfiniteDataLoader
+
+from .module import ParetoWeightEnsemblingModule
+from .phn.solvers import EPOSolver
+from .utils import generate_simplex_grid
+
+log = logging.getLogger(__name__)
+
+
+class PWEMoEAlgorithmForOpenCLIP(
+    BaseAlgorithm,
+    SimpleProfilerMixin,
+    OpenCLIPClassificationMixin,
+):
+    modelpool: OpenCLIPVisionModelPool
+
+    #! === Training & Validation Data ===
+    # setup the datasets and loaders by calling `load_datasets`
+    train_datasets: Dict[str, CLIPDataset]
+    train_loaders: Dict[str, torch.utils.data.DataLoader]
+    train_loader_iters: Dict[str, Iterator[Tuple[torch.Tensor, torch.Tensor]]]
+
+    test_datasets: Dict[str, CLIPDataset]
+    test_loaders: Dict[str, torch.utils.data.DataLoader]
+
+    def __init__(
+        self,
+        *,
+        #! === Model Architecture Arguments ===
+        partial: bool,
+        init_lambda: float,
+        router_hidden_layers: int,
+        checkpoint_path: str,
+        #! === Training Arguments ===
+        run_train: bool,
+        num_steps: int,
+        save_interval: int,
+        lr: float,
+        alpha: float,
+        dataloader_kwargs: DictConfig,
+        #! === Evaluation Arguments ===
+        run_eval: bool,
+        num_evaluation_samples: Union[str, int],
+        quick_evaluation: bool,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.partial = partial
+        self.init_lambda = init_lambda
+        self.router_hidden_layers = router_hidden_layers
+        self.lr = lr
+        self.num_steps = num_steps
+        self.save_interval = save_interval
+        self.alpha = alpha
+        self.checkpoint_path = checkpoint_path
+        self._dataloader_kwargs = dataloader_kwargs
+        self.run_train = run_train
+        self.run_eval = run_eval
+        self.num_evaluation_samples = num_evaluation_samples
+        self.quick_evaluation = quick_evaluation
+
+    def run(self, modelpool: OpenCLIPVisionModelPool):
+        self.modelpool = modelpool
+
+        # setup the MoE model
+        model = self.load_model()
+        if self.checkpoint_path is not None:
+            self.fabric.load(self.checkpoint_path, {"model": model})
+
+        # setup dataloaders
+        self.load_datasets()
+
+        if self.run_train:
+            model = self.train()
+        if self.run_eval:
+            self.evaluate(model)
+        return model
+
+    @torch.no_grad()
+    def load_model(self):
+        modelpool = self.modelpool
+
+        # load models and classification heads
+        pretrained_model: ImageEncoder = self.modelpool.load_pretrained_model()
+        log.info("pretrained model statistics:")
+        print_parameters(pretrained_model, print_fn=log.info)
+
+        finetuned_models: Dict[str, ImageEncoder] = {}
+        for model_name in self.modelpool.model_names:
+            finetuned_models[model_name] = modelpool.load_model(model_name)
+
+        classification_heads: Dict[str, ClassificationHead] = {}
+        for model_name in self.modelpool.model_names:
+            classification_heads[model_name] = modelpool.load_classification_head(
+                model_name
+            )
+        self.classification_heads = classification_heads
+
+        self.train_processor = modelpool.train_processor
+        self.test_processor = modelpool.test_processor
+
+        with timeit_context("Building the MoE model"):
+            model = deepcopy(pretrained_model)
+
+            if self.partial:
+                log.info("Weight ensembling only the MLPs")
+                # weight ensembling only the MLPs, merge the remaining layers using task arithmetic
+                model = task_arithmetic_merge(
+                    pretrained_model=model,
+                    finetuned_models=list(finetuned_models.values()),
+                    scaling_factor=self.init_lambda,
+                    inplace=True,
+                )
+
+                # fix all parameters
+                model.requires_grad_(False)
+
+                for layer_idx in tqdm(
+                    range(model.model.visual.transformer.layers), desc="Upscaling MLPs"
+                ):
+                    resblock: ResidualAttentionBlock = (
+                        model.model.visual.transformer.resblocks[layer_idx]
+                    )
+                    resblock.mlp = ParetoWeightEnsemblingModule(
+                        base_model=cast(
+                            ResidualAttentionBlock,
+                            pretrained_model.model.visual.transformer.resblocks[
+                                layer_idx
+                            ],
+                        ).mlp,
+                        expert_models=[
+                            cast(
+                                ResidualAttentionBlock,
+                                m.model.visual.transformer.resblocks[layer_idx],
+                            ).mlp
+                            for m in finetuned_models.values()
+                        ],
+                        init_lambda=self.init_lambda,
+                        fix_base_model_and_experts=True,
+                        router_hidden_layers=self.router_hidden_layers,
+                    )
+            else:
+                log.info("Weight ensembling all the layers")
+                # weight ensembling all the layers, merge the remaining layers using task arithmetic
+                model = task_arithmetic_merge(
+                    pretrained_model=model,
+                    finetuned_models=list(finetuned_models.values()),
+                    scaling_factor=self.init_lambda,
+                    inplace=True,
+                )
+                # fix all parameters
+                model.requires_grad_(False)
+
+                for name in [
+                    "conv1",
+                    "ln_pre",
+                    "ln_post",
+                    # "class_embedding",
+                    # "positional_embedding",
+                ]:
+                    setattr(
+                        model.model.visual,
+                        name,
+                        ParetoWeightEnsemblingModule(
+                            base_model=getattr(pretrained_model.model.visual, name),
+                            expert_models=[
+                                getattr(m.model.visual, name)
+                                for m in finetuned_models.values()
+                            ],
+                            init_lambda=self.init_lambda,
+                            fix_base_model_and_experts=True,
+                            router_hidden_layers=self.router_hidden_layers,
+                        ),
+                    )
+                for layer_idx in tqdm(
+                    range(model.model.visual.transformer.layers),
+                    desc="Upscaling the transformer layers",
+                ):
+                    for name in ["ln_1", "attn", "ln_attn", "ln_2", "mlp"]:
+                        setattr(
+                            model.model.visual.transformer.resblocks[layer_idx],
+                            name,
+                            ParetoWeightEnsemblingModule(
+                                base_model=getattr(
+                                    cast(
+                                        ResidualAttentionBlock,
+                                        pretrained_model.model.visual.transformer.resblocks[
+                                            layer_idx
+                                        ],
+                                    ),
+                                    name,
+                                ),
+                                expert_models=[
+                                    getattr(
+                                        cast(
+                                            ResidualAttentionBlock,
+                                            m.model.visual.transformer.resblocks[
+                                                layer_idx
+                                            ],
+                                        ),
+                                        name,
+                                    )
+                                    for m in finetuned_models.values()
+                                ],
+                                init_lambda=self.init_lambda,
+                                fix_base_model_and_experts=True,
+                                router_hidden_layers=self.router_hidden_layers,
+                            ),
+                        )
+                for name in ["token_embedding", "ln_final"]:
+                    setattr(
+                        model.model,
+                        name,
+                        ParetoWeightEnsemblingModule(
+                            base_model=getattr(pretrained_model.model, name),
+                            expert_models=[
+                                getattr(m.model, name)
+                                for m in finetuned_models.values()
+                            ],
+                            init_lambda=self.init_lambda,
+                            fix_base_model_and_experts=True,
+                            router_hidden_layers=self.router_hidden_layers,
+                        ),
+                    )
+
+        self.model = model
+        print_parameters(model, print_fn=log.info)
+        return model
+
+    def load_datasets(self):
+        modelpool = self.modelpool
+
+        # setup the train datasets and loaders
+        train_datasets = {}
+        train_loaders = {}
+        train_loader_iters = {}
+        for dataset_name in modelpool.train_dataset_names:
+            train_datasets[dataset_name] = modelpool.load_train_dataset(dataset_name)
+            train_datasets[dataset_name] = CLIPDataset(
+                train_datasets[dataset_name], self.train_processor
+            )
+            # sanity check
+            assert isinstance(train_datasets[dataset_name][0], tuple)
+
+            # setup the train loaders
+            train_loaders[dataset_name] = torch.utils.data.DataLoader(
+                train_datasets[dataset_name],
+                shuffle=True,
+                drop_last=True,
+                **self._dataloader_kwargs,
+            )
+            train_loaders[dataset_name] = self.fabric.setup_dataloaders(
+                train_loaders[dataset_name]
+            )
+            train_loaders[dataset_name] = InfiniteDataLoader(
+                train_loaders[dataset_name]
+            )
+
+            # setup the train loader iterators
+            train_loader_iters[dataset_name] = iter(train_loaders[dataset_name])
+
+        self.train_datasets = train_datasets
+        self.train_loaders = train_loaders
+        self.train_loader_iters = train_loader_iters
+
+        # setup the test datasets and loaders
+        test_datasets = {}
+        test_loaders = {}
+        for dataset_name in modelpool.test_dataset_names:
+            test_datasets[dataset_name] = modelpool.load_test_dataset(dataset_name)
+            test_datasets[dataset_name] = CLIPDataset(
+                test_datasets[dataset_name], self.test_processor
+            )
+            test_loaders[dataset_name] = torch.utils.data.DataLoader(
+                test_datasets[dataset_name],
+                shuffle=False,
+                **self._dataloader_kwargs,
+            )
+            test_loaders[dataset_name] = self.fabric.setup_dataloaders(
+                test_loaders[dataset_name]
+            )
+
+        self.test_datasets = test_datasets
+        self.test_loaders = test_loaders
+
+    def compute_loss(self, model: ImageEncoder, ray: Tensor):
+        losses = []
+        for dataset_idx, dataset_name in enumerate(self.train_datasets):
+            batch = next(self.train_loader_iters[dataset_name])
+            x, y = batch
+
+            features = model(x)
+            logits = self.classification_heads[dataset_name](features)
+
+            _loss = F.cross_entropy(logits, y)
+            losses.append(_loss)
+
+        loss = self.aggregate_loss(model, ray, losses)
+        return loss
+
+    @abstractmethod
+    def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
+        pass
+
+    def train(self):
+        # setup the model
+        num_objectives = len(self.modelpool.model_names)
+        model = deepcopy(self.model)
+        self.classification_heads = {
+            t: h.to(self.fabric.device) for t, h in self.classification_heads.items()
+        }
+
+        # set up the optimizer and learning rate scheduler
+        optimizer = torch.optim.Adam(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=self.lr,
+        )
+        model, optimizer = self.fabric.setup(model, optimizer)
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer=optimizer, T_max=self.num_steps, eta_min=self.lr * 0.1
+        )
+
+        model.train()
+        device = self.fabric.device
+        for step_idx in tqdm(
+            range(1, 1 + self.num_steps), "training", dynamic_ncols=True
+        ):
+            # sample a preference ray
+            ray = torch.from_numpy(
+                np.random.dirichlet((self.alpha,) * num_objectives, 1)
+                .astype(np.float32)
+                .flatten()
+            ).to(device)
+            ParetoWeightEnsemblingModule.set_preferenece_vector(model, ray)
+
+            loss = self.compute_loss(model, ray)
+
+            optimizer.zero_grad()
+            self.fabric.backward(loss)
+            optimizer.step()
+
+            lr_scheduler.step()
+
+            self.fabric.log("loss", loss.item(), step=step_idx)
+
+            if step_idx % self.save_interval == 0 or step_idx == self.num_steps:
+                ckpt_dir = Path(self.log_dir) / "checkpoints"
+                ckpt_dir.mkdir(exist_ok=True, parents=True)
+                self.fabric.save(
+                    ckpt_dir / f"model_step={step_idx}.ckpt",
+                    {"model": model},
+                )
+        return model
+
+    def evaluate(self, model):
+        results = defaultdict(list)
+
+        num_objectives = len(self.modelpool.model_names)
+        device = self.fabric.device
+        self.classification_heads = {
+            t: h.to(self.fabric.device) for t, h in self.classification_heads.items()
+        }
+
+        if not lightning.fabric.is_wrapped(model):
+            model = self.fabric.setup_module(model)
+        model.eval()
+
+        if self.num_evaluation_samples == "equal_weight":
+            uniform_grid = np.array(
+                [[1 / num_objectives] * num_objectives], dtype=np.float32
+            )
+        else:
+            uniform_grid = generate_simplex_grid(
+                num_objectives, self.num_evaluation_samples
+            )
+        for ray_idx, ray in tqdm(enumerate(uniform_grid), "evaluating samples"):
+            results["ray_idx"].append(ray_idx)
+            # sample a preference ray
+            for i in range(len(ray)):
+                results[f"ray_{i}"].append(ray[i])
+            ray = torch.from_numpy(ray).to(device)
+            ParetoWeightEnsemblingModule.set_preferenece_vector(model, ray)
+
+            accs = []
+            for dataset_idx, dataset_name in enumerate(
+                tqdm(
+                    self.modelpool.test_dataset_names,
+                    "evaluating datasets",
+                    leave=False,
+                )
+            ):
+                test_loader = self.test_loaders[dataset_name]
+                TOTAL_CORRECT = 0
+                TOTAL_COUNT = 0
+                for batch_idx, batch in enumerate(
+                    pbar := tqdm(
+                        test_loader,
+                        f"evaluate {dataset_name}",
+                        leave=False,
+                    )
+                ):
+                    x, y = batch
+
+                    features = model(x)
+                    logits = self.classification_heads[dataset_name](features)
+                    preds = logits.argmax(-1)
+
+                    correct = (preds == y).sum().item()
+                    TOTAL_CORRECT += correct
+                    TOTAL_COUNT += len(y)
+                    acc = TOTAL_CORRECT / TOTAL_COUNT
+                    pbar.set_postfix_str(f"acc={acc:.2f}")
+
+                    if self.quick_evaluation and batch_idx > 20:
+                        break
+                results[dataset_name].append(acc)
+                accs.append(acc)
+
+            # compute the average accuracy
+            if "average" not in self.modelpool.test_dataset_names:
+                results["average"].append(np.mean(accs))
+
+        (df := pd.DataFrame(results)).to_csv(
+            Path(self.log_dir) / "result.csv", index=False
+        )
+        log.info(df)
+
+
+class PWEMoELinearScalarizationForOpenCLIP(PWEMoEAlgorithmForOpenCLIP):
+    def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
+        loss = 0
+        for r, l in zip(ray, losses):
+            loss += r * l
+        return loss
+
+
+class PWEMoEExactParetoOptimalForOpenCLIP(PWEMoEAlgorithmForOpenCLIP):
+    epo_solver: Optional[EPOSolver] = None
+
+    def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
+        if self.epo_solver is None:
+            num_objectives = len(self.modelpool.model_names)
+            self.epo_solver = EPOSolver(n_tasks=num_objectives, n_params=None)
+        epo_solver = self.epo_solver
+
+        losses = torch.stack(losses)
+        loss = epo_solver.get_weighted_loss(
+            losses,
+            ray,
+            tuple(filter(lambda p: p.requires_grad, model.parameters())),
+        )
+        return loss
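To summarize the training loop just shown: each step samples a preference ray from a symmetric Dirichlet distribution, conditions every ParetoWeightEnsemblingModule router on it via `set_preferenece_vector`, and then scalarizes the per-task losses through `aggregate_loss`. Written out (with k objectives, one per model in the pool), the linear-scalarization subclass minimizes the ray-weighted sum:

    r \sim \mathrm{Dirichlet}(\alpha, \ldots, \alpha), \qquad
    \mathcal{L}_{\mathrm{LS}}(\theta; r) = \sum_{i=1}^{k} r_i \, \mathcal{L}_i(\theta)

The exact-Pareto-optimal subclass instead hands the stacked losses, the ray, and the trainable parameters to `EPOSolver.get_weighted_loss`, which reweights the objectives so that training tracks the Pareto-front point matching the sampled preference.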
fusion_bench/method/regmean/regmean.py
@@ -13,6 +13,7 @@ from torch import Tensor, nn
 from tqdm.autonotebook import tqdm
 
 from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
 
 log = logging.getLogger(__name__)
@@ -279,7 +280,7 @@ def regmean_merging(
     return merged_params
 
 
-class RegMeanAlgorithm(BaseAlgorithm):
+class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     _include_module_type = [nn.Linear]
     _config_mapping = {
         "num_regmean_examples": "num_regmean_examples",
@@ -342,24 +343,31 @@ class RegMeanAlgorithm(BaseAlgorithm):
             )
             assert len(linear_modules_to_merge) > 0, "No linear modules to merge"
 
-            regmean_weights = self.get_regmean_weights(
-                name,
-                model,
-                train_dataset=modelpool.load_train_dataset(name),
-                linear_modules_to_merge=linear_modules_to_merge,
-            )
-            models_to_merge_regmean_weights_list.append(regmean_weights)
+            with (
+                self.profile("merging models"),
+                self.profile("computing regmean weights"),
+            ):
+                regmean_weights = self.get_regmean_weights(
+                    name,
+                    model,
+                    train_dataset=modelpool.load_train_dataset(name),
+                    linear_modules_to_merge=linear_modules_to_merge,
+                )
+                models_to_merge_regmean_weights_list.append(regmean_weights)
+
+        with self.profile("merging models"):
+            # merging with regmean weights
+            merged_params = merging_with_regmean_weights(
+                models_to_merge_param_dict=models_to_merge_param_dict,
+                models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
+                reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
+                weight_transpose=self.config.get("weight_transpose", True),
+            )
 
-        # merging with regmean weights
-        merged_params = merging_with_regmean_weights(
-            models_to_merge_param_dict=models_to_merge_param_dict,
-            models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
-            reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
-            weight_transpose=self.config.get("weight_transpose", True),
-        )
+            merged_model = modelpool.load_model("_pretrained_")
+            merged_model.load_state_dict(merged_params, strict=False)
 
-        merged_model = modelpool.load_model("_pretrained_")
-        merged_model.load_state_dict(merged_params, strict=False)
+        self.print_profile_summary()
         return merged_model
 
     def on_regmean_start(self):
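The regmean changes above wrap the weight computation and the merge in SimpleProfilerMixin timing scopes and print a summary at the end; note the parenthesized multi-item `with (...)` form requires Python 3.10+. As a rough mental model only (an illustrative stand-in, not fusion_bench's actual SimpleProfilerMixin), the pattern behaves like:

import time
from collections import defaultdict
from contextlib import contextmanager


class ProfilerMixinSketch:
    """Illustrative stand-in for a profiling mixin with named timing scopes."""

    def __init__(self):
        self._totals = defaultdict(float)

    @contextmanager
    def profile(self, name: str):
        # Accumulate wall-clock time under `name`; scopes may nest or repeat,
        # so re-entering the same name keeps adding to its running total.
        start = time.perf_counter()
        try:
            yield
        finally:
            self._totals[name] += time.perf_counter() - start

    def print_profile_summary(self):
        for name, seconds in sorted(self._totals.items()):
            print(f"{name}: {seconds:.3f}s")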