fusion-bench 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +3 -1
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +12 -2
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/doge_ta/__init__.py +2 -0
- fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +102 -84
- fusion_bench/method/opcm/task_arithmetic.py +35 -21
- fusion_bench/method/opcm/ties_merging.py +71 -52
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
- fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +5 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +15 -2
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +198 -158
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -10
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +66 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- fusion_bench/method/DOGE_TA/__init__.py +0 -2
- /fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info/licenses}/LICENSE +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
|
@@ -20,7 +20,7 @@ from fusion_bench.method.ties_merging.ties_merging_utils import (
|
|
|
20
20
|
ties_merging,
|
|
21
21
|
vector_to_state_dict,
|
|
22
22
|
)
|
|
23
|
-
from fusion_bench.mixins import LightningFabricMixin
|
|
23
|
+
from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
|
|
24
24
|
from fusion_bench.taskpool import CLIPVisionModelTaskPool
|
|
25
25
|
from fusion_bench.utils.json import load_from_json, save_to_json
|
|
26
26
|
from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
|
|
@@ -29,7 +29,11 @@ if TYPE_CHECKING:
|
|
|
29
29
|
from torch.utils.tensorboard import SummaryWriter
|
|
30
30
|
|
|
31
31
|
|
|
32
|
-
class ContinualTiesMergingForCLIP(
|
|
32
|
+
class ContinualTiesMergingForCLIP(
|
|
33
|
+
BaseAlgorithm,
|
|
34
|
+
LightningFabricMixin,
|
|
35
|
+
SimpleProfilerMixin,
|
|
36
|
+
):
|
|
33
37
|
def __init__(
|
|
34
38
|
self,
|
|
35
39
|
scaling_factor: float,
|
|
@@ -84,68 +88,83 @@ class ContinualTiesMergingForCLIP(BaseAlgorithm, LightningFabricMixin):
|
|
|
84
88
|
)
|
|
85
89
|
|
|
86
90
|
# get the average model
|
|
87
|
-
|
|
91
|
+
with self.profile("loading model"):
|
|
92
|
+
pretrained_model = modelpool.load_pretrained_model()
|
|
88
93
|
merged_model = deepcopy(pretrained_model)
|
|
89
94
|
|
|
90
95
|
for model_idx, model_name in tqdm(
|
|
91
96
|
enumerate(model_names), desc="Processing models"
|
|
92
97
|
):
|
|
93
|
-
|
|
98
|
+
with self.profile("loading model"):
|
|
99
|
+
task_model = modelpool.load_model(model_name)
|
|
94
100
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
)
|
|
99
|
-
if model_idx == 0:
|
|
100
|
-
# if is the first model, the merged task vector is equal to the task vector
|
|
101
|
-
ties_merging_state_dict = task_vector
|
|
102
|
-
else:
|
|
103
|
-
# if is not the first model, we need to merge the task vector with the previous merged task vector
|
|
104
|
-
merged_tv = state_dict_sub(
|
|
105
|
-
merged_model.state_dict(),
|
|
101
|
+
with self.profile("merging model"):
|
|
102
|
+
task_vector = state_dict_sub(
|
|
103
|
+
task_model.state_dict(),
|
|
106
104
|
pretrained_model.state_dict(),
|
|
107
105
|
)
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
106
|
+
if model_idx == 0:
|
|
107
|
+
# if is the first model, the merged task vector is equal to the task vector
|
|
108
|
+
ties_merging_state_dict = task_vector
|
|
109
|
+
else:
|
|
110
|
+
# if is not the first model, we need to merge the task vector with the previous merged task vector
|
|
111
|
+
merged_tv = state_dict_sub(
|
|
112
|
+
merged_model.state_dict(),
|
|
113
|
+
pretrained_model.state_dict(),
|
|
114
|
+
)
|
|
115
|
+
tv_flat_checks = torch.vstack(
|
|
116
|
+
[
|
|
117
|
+
state_dict_to_vector(
|
|
118
|
+
merged_tv, remove_keys=self.remove_keys
|
|
119
|
+
),
|
|
120
|
+
state_dict_to_vector(
|
|
121
|
+
task_vector, remove_keys=self.remove_keys
|
|
122
|
+
),
|
|
123
|
+
]
|
|
124
|
+
)
|
|
125
|
+
# perform the TIES merging
|
|
126
|
+
ties_merging_tv = ties_merging(
|
|
127
|
+
tv_flat_checks,
|
|
128
|
+
reset_thresh=self.threshold,
|
|
129
|
+
merge_func=self.merge_func,
|
|
130
|
+
)
|
|
131
|
+
# convert the merged task vector back to a state dict
|
|
132
|
+
ties_merging_state_dict = vector_to_state_dict(
|
|
133
|
+
ties_merging_tv,
|
|
134
|
+
merged_model.state_dict(),
|
|
135
|
+
remove_keys=self.remove_keys,
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
for param_name, param in task_model.named_parameters():
|
|
139
|
+
if not param.requires_grad:
|
|
140
|
+
continue
|
|
141
|
+
|
|
142
|
+
merged_param = merged_model.get_parameter(param_name)
|
|
143
|
+
new_param = (
|
|
144
|
+
merged_param
|
|
145
|
+
+ self.scaling_factor * ties_merging_state_dict[param_name]
|
|
146
|
+
)
|
|
147
|
+
merged_model.get_parameter(param_name).data = new_param
|
|
137
148
|
|
|
138
149
|
if self.save_on_every_step:
|
|
139
|
-
self.
|
|
150
|
+
with self.profile("saving model"):
|
|
151
|
+
self.save_merged_model(merged_model, model_idx)
|
|
140
152
|
|
|
141
153
|
if self.evaluate_on_every_step:
|
|
142
|
-
self.
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
154
|
+
with self.profile("evaluating model"):
|
|
155
|
+
self.taskpool._is_setup = False
|
|
156
|
+
self.taskpool._test_datasets = DictConfig(
|
|
157
|
+
{
|
|
158
|
+
n: self._test_datasets[n]
|
|
159
|
+
for n in model_names[: model_idx + 1]
|
|
160
|
+
}
|
|
161
|
+
)
|
|
162
|
+
report = self.taskpool.evaluate(deepcopy(merged_model))
|
|
163
|
+
save_to_json(
|
|
164
|
+
report, Path(self.log_dir) / f"report_{model_idx}.json"
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
self.print_profile_summary()
|
|
149
168
|
return merged_model
|
|
150
169
|
|
|
151
170
|
def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
|
|
@@ -0,0 +1,476 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
|
|
7
|
+
|
|
8
|
+
import lightning.fabric
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn.functional as F
|
|
13
|
+
from omegaconf import DictConfig, OmegaConf
|
|
14
|
+
from open_clip.model import ResidualAttentionBlock
|
|
15
|
+
from torch import Tensor, nn
|
|
16
|
+
from tqdm.auto import tqdm
|
|
17
|
+
|
|
18
|
+
from fusion_bench import BaseAlgorithm
|
|
19
|
+
from fusion_bench.dataset.clip_dataset import CLIPDataset
|
|
20
|
+
from fusion_bench.method.task_arithmetic import task_arithmetic_merge
|
|
21
|
+
from fusion_bench.mixins import OpenCLIPClassificationMixin, SimpleProfilerMixin
|
|
22
|
+
from fusion_bench.modelpool import OpenCLIPVisionModelPool
|
|
23
|
+
from fusion_bench.models.open_clip import ClassificationHead, ImageEncoder
|
|
24
|
+
from fusion_bench.utils import print_parameters, timeit_context
|
|
25
|
+
from fusion_bench.utils.data import InfiniteDataLoader
|
|
26
|
+
|
|
27
|
+
from .module import ParetoWeightEnsemblingModule
|
|
28
|
+
from .phn.solvers import EPOSolver
|
|
29
|
+
from .utils import generate_simplex_grid
|
|
30
|
+
|
|
31
|
+
log = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class PWEMoEAlgorithmForOpenCLIP(
|
|
35
|
+
BaseAlgorithm,
|
|
36
|
+
SimpleProfilerMixin,
|
|
37
|
+
OpenCLIPClassificationMixin,
|
|
38
|
+
):
|
|
39
|
+
modelpool: OpenCLIPVisionModelPool
|
|
40
|
+
|
|
41
|
+
#! === Training & Validation Data ===
|
|
42
|
+
# setup the datasets and loaders by calling `load_datasets`
|
|
43
|
+
train_datasets: Dict[str, CLIPDataset]
|
|
44
|
+
train_loaders: Dict[str, torch.utils.data.DataLoader]
|
|
45
|
+
train_loader_iters: Dict[str, Iterator[Tuple[torch.Tensor, torch.Tensor]]]
|
|
46
|
+
|
|
47
|
+
test_datasets: Dict[str, CLIPDataset]
|
|
48
|
+
test_loaders: Dict[str, torch.utils.data.DataLoader]
|
|
49
|
+
|
|
50
|
+
def __init__(
|
|
51
|
+
self,
|
|
52
|
+
*,
|
|
53
|
+
#! === Model Architecture Arguments ===
|
|
54
|
+
partial: bool,
|
|
55
|
+
init_lambda: float,
|
|
56
|
+
router_hidden_layers: int,
|
|
57
|
+
checkpoint_path: str,
|
|
58
|
+
#! === Training Arguments ===
|
|
59
|
+
run_train: bool,
|
|
60
|
+
num_steps: int,
|
|
61
|
+
save_interval: int,
|
|
62
|
+
lr: float,
|
|
63
|
+
alpha: float,
|
|
64
|
+
dataloader_kwargs: DictConfig,
|
|
65
|
+
#! === Evaluation Arguments ===
|
|
66
|
+
run_eval: bool,
|
|
67
|
+
num_evaluation_samples: Union[str, int],
|
|
68
|
+
quick_evaluation: bool,
|
|
69
|
+
**kwargs,
|
|
70
|
+
):
|
|
71
|
+
super().__init__(**kwargs)
|
|
72
|
+
self.partial = partial
|
|
73
|
+
self.init_lambda = init_lambda
|
|
74
|
+
self.router_hidden_layers = router_hidden_layers
|
|
75
|
+
self.lr = lr
|
|
76
|
+
self.num_steps = num_steps
|
|
77
|
+
self.save_interval = save_interval
|
|
78
|
+
self.alpha = alpha
|
|
79
|
+
self.checkpoint_path = checkpoint_path
|
|
80
|
+
self._dataloader_kwargs = dataloader_kwargs
|
|
81
|
+
self.run_train = run_train
|
|
82
|
+
self.run_eval = run_eval
|
|
83
|
+
self.num_evaluation_samples = num_evaluation_samples
|
|
84
|
+
self.quick_evaluation = quick_evaluation
|
|
85
|
+
|
|
86
|
+
def run(self, modelpool: OpenCLIPVisionModelPool):
|
|
87
|
+
self.modelpool = modelpool
|
|
88
|
+
|
|
89
|
+
# setup the MoE model
|
|
90
|
+
model = self.load_model()
|
|
91
|
+
if self.checkpoint_path is not None:
|
|
92
|
+
self.fabric.load(self.checkpoint_path, {"model": model})
|
|
93
|
+
|
|
94
|
+
# setup dataloaders
|
|
95
|
+
self.load_datasets()
|
|
96
|
+
|
|
97
|
+
if self.run_train:
|
|
98
|
+
model = self.train()
|
|
99
|
+
if self.run_eval:
|
|
100
|
+
self.evaluate(model)
|
|
101
|
+
return model
|
|
102
|
+
|
|
103
|
+
@torch.no_grad()
|
|
104
|
+
def load_model(self):
|
|
105
|
+
modelpool = self.modelpool
|
|
106
|
+
|
|
107
|
+
# load models and classification heads
|
|
108
|
+
pretrained_model: ImageEncoder = self.modelpool.load_pretrained_model()
|
|
109
|
+
log.info("pretrained model statistics:")
|
|
110
|
+
print_parameters(pretrained_model, print_fn=log.info)
|
|
111
|
+
|
|
112
|
+
finetuned_models: Dict[str, ImageEncoder] = {}
|
|
113
|
+
for model_name in self.modelpool.model_names:
|
|
114
|
+
finetuned_models[model_name] = modelpool.load_model(model_name)
|
|
115
|
+
|
|
116
|
+
classification_heads: Dict[str, ClassificationHead] = {}
|
|
117
|
+
for model_name in self.modelpool.model_names:
|
|
118
|
+
classification_heads[model_name] = modelpool.load_classification_head(
|
|
119
|
+
model_name
|
|
120
|
+
)
|
|
121
|
+
self.classification_heads = classification_heads
|
|
122
|
+
|
|
123
|
+
self.train_processor = modelpool.train_processor
|
|
124
|
+
self.test_processor = modelpool.test_processor
|
|
125
|
+
|
|
126
|
+
with timeit_context("Building the MoE model"):
|
|
127
|
+
model = deepcopy(pretrained_model)
|
|
128
|
+
|
|
129
|
+
if self.partial:
|
|
130
|
+
log.info("Weight ensembling only the MLPs")
|
|
131
|
+
# weight ensembling only the MLPs, merge the remaining layers using task arithmetic
|
|
132
|
+
model = task_arithmetic_merge(
|
|
133
|
+
pretrained_model=model,
|
|
134
|
+
finetuned_models=list(finetuned_models.values()),
|
|
135
|
+
scaling_factor=self.init_lambda,
|
|
136
|
+
inplace=True,
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
# fix all parameters
|
|
140
|
+
model.requires_grad_(False)
|
|
141
|
+
|
|
142
|
+
for layer_idx in tqdm(
|
|
143
|
+
range(model.model.visual.transformer.layers), desc="Upscaling MLPs"
|
|
144
|
+
):
|
|
145
|
+
resblock: ResidualAttentionBlock = (
|
|
146
|
+
model.model.visual.transformer.resblocks[layer_idx]
|
|
147
|
+
)
|
|
148
|
+
resblock.mlp = ParetoWeightEnsemblingModule(
|
|
149
|
+
base_model=cast(
|
|
150
|
+
ResidualAttentionBlock,
|
|
151
|
+
pretrained_model.model.visual.transformer.resblocks[
|
|
152
|
+
layer_idx
|
|
153
|
+
],
|
|
154
|
+
).mlp,
|
|
155
|
+
expert_models=[
|
|
156
|
+
cast(
|
|
157
|
+
ResidualAttentionBlock,
|
|
158
|
+
m.model.visual.transformer.resblocks[layer_idx],
|
|
159
|
+
).mlp
|
|
160
|
+
for m in finetuned_models.values()
|
|
161
|
+
],
|
|
162
|
+
init_lambda=self.init_lambda,
|
|
163
|
+
fix_base_model_and_experts=True,
|
|
164
|
+
router_hidden_layers=self.router_hidden_layers,
|
|
165
|
+
)
|
|
166
|
+
else:
|
|
167
|
+
log.info("Weight ensembling all the layers")
|
|
168
|
+
# weight ensembling all the layers, merge the remaining layers using task arithmetic
|
|
169
|
+
model = task_arithmetic_merge(
|
|
170
|
+
pretrained_model=model,
|
|
171
|
+
finetuned_models=list(finetuned_models.values()),
|
|
172
|
+
scaling_factor=self.init_lambda,
|
|
173
|
+
inplace=True,
|
|
174
|
+
)
|
|
175
|
+
# fix all parameters
|
|
176
|
+
model.requires_grad_(False)
|
|
177
|
+
|
|
178
|
+
for name in [
|
|
179
|
+
"conv1",
|
|
180
|
+
"ln_pre",
|
|
181
|
+
"ln_post",
|
|
182
|
+
# "class_embedding",
|
|
183
|
+
# "positional_embedding",
|
|
184
|
+
]:
|
|
185
|
+
setattr(
|
|
186
|
+
model.model.visual,
|
|
187
|
+
name,
|
|
188
|
+
ParetoWeightEnsemblingModule(
|
|
189
|
+
base_model=getattr(pretrained_model.model.visual, name),
|
|
190
|
+
expert_models=[
|
|
191
|
+
getattr(m.model.visual, name)
|
|
192
|
+
for m in finetuned_models.values()
|
|
193
|
+
],
|
|
194
|
+
init_lambda=self.init_lambda,
|
|
195
|
+
fix_base_model_and_experts=True,
|
|
196
|
+
router_hidden_layers=self.router_hidden_layers,
|
|
197
|
+
),
|
|
198
|
+
)
|
|
199
|
+
for layer_idx in tqdm(
|
|
200
|
+
range(model.model.visual.transformer.layers),
|
|
201
|
+
desc="Upscaling the transformer layers",
|
|
202
|
+
):
|
|
203
|
+
for name in ["ln_1", "attn", "ln_attn", "ln_2", "mlp"]:
|
|
204
|
+
setattr(
|
|
205
|
+
model.model.visual.transformer.resblocks[layer_idx],
|
|
206
|
+
name,
|
|
207
|
+
ParetoWeightEnsemblingModule(
|
|
208
|
+
base_model=getattr(
|
|
209
|
+
cast(
|
|
210
|
+
ResidualAttentionBlock,
|
|
211
|
+
pretrained_model.model.visual.transformer.resblocks[
|
|
212
|
+
layer_idx
|
|
213
|
+
],
|
|
214
|
+
),
|
|
215
|
+
name,
|
|
216
|
+
),
|
|
217
|
+
expert_models=[
|
|
218
|
+
getattr(
|
|
219
|
+
cast(
|
|
220
|
+
ResidualAttentionBlock,
|
|
221
|
+
m.model.visual.transformer.resblocks[
|
|
222
|
+
layer_idx
|
|
223
|
+
],
|
|
224
|
+
),
|
|
225
|
+
name,
|
|
226
|
+
)
|
|
227
|
+
for m in finetuned_models.values()
|
|
228
|
+
],
|
|
229
|
+
init_lambda=self.init_lambda,
|
|
230
|
+
fix_base_model_and_experts=True,
|
|
231
|
+
router_hidden_layers=self.router_hidden_layers,
|
|
232
|
+
),
|
|
233
|
+
)
|
|
234
|
+
for name in ["token_embedding", "ln_final"]:
|
|
235
|
+
setattr(
|
|
236
|
+
model.model,
|
|
237
|
+
name,
|
|
238
|
+
ParetoWeightEnsemblingModule(
|
|
239
|
+
base_model=getattr(pretrained_model.model, name),
|
|
240
|
+
expert_models=[
|
|
241
|
+
getattr(m.model, name)
|
|
242
|
+
for m in finetuned_models.values()
|
|
243
|
+
],
|
|
244
|
+
init_lambda=self.init_lambda,
|
|
245
|
+
fix_base_model_and_experts=True,
|
|
246
|
+
router_hidden_layers=self.router_hidden_layers,
|
|
247
|
+
),
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
self.model = model
|
|
251
|
+
print_parameters(model, print_fn=log.info)
|
|
252
|
+
return model
|
|
253
|
+
|
|
254
|
+
def load_datasets(self):
|
|
255
|
+
modelpool = self.modelpool
|
|
256
|
+
|
|
257
|
+
# setup the train datasets and loaders
|
|
258
|
+
train_datasets = {}
|
|
259
|
+
train_loaders = {}
|
|
260
|
+
train_loader_iters = {}
|
|
261
|
+
for dataset_name in modelpool.train_dataset_names:
|
|
262
|
+
train_datasets[dataset_name] = modelpool.load_train_dataset(dataset_name)
|
|
263
|
+
train_datasets[dataset_name] = CLIPDataset(
|
|
264
|
+
train_datasets[dataset_name], self.train_processor
|
|
265
|
+
)
|
|
266
|
+
# sanity check
|
|
267
|
+
assert isinstance(train_datasets[dataset_name][0], tuple)
|
|
268
|
+
|
|
269
|
+
# setup the train loaders
|
|
270
|
+
train_loaders[dataset_name] = torch.utils.data.DataLoader(
|
|
271
|
+
train_datasets[dataset_name],
|
|
272
|
+
shuffle=True,
|
|
273
|
+
drop_last=True,
|
|
274
|
+
**self._dataloader_kwargs,
|
|
275
|
+
)
|
|
276
|
+
train_loaders[dataset_name] = self.fabric.setup_dataloaders(
|
|
277
|
+
train_loaders[dataset_name]
|
|
278
|
+
)
|
|
279
|
+
train_loaders[dataset_name] = InfiniteDataLoader(
|
|
280
|
+
train_loaders[dataset_name]
|
|
281
|
+
)
|
|
282
|
+
|
|
283
|
+
# setup the train loader iterators
|
|
284
|
+
train_loader_iters[dataset_name] = iter(train_loaders[dataset_name])
|
|
285
|
+
|
|
286
|
+
self.train_datasets = train_datasets
|
|
287
|
+
self.train_loaders = train_loaders
|
|
288
|
+
self.train_loader_iters = train_loader_iters
|
|
289
|
+
|
|
290
|
+
# setup the test datasets and loaders
|
|
291
|
+
test_datasets = {}
|
|
292
|
+
test_loaders = {}
|
|
293
|
+
for dataset_name in modelpool.test_dataset_names:
|
|
294
|
+
test_datasets[dataset_name] = modelpool.load_test_dataset(dataset_name)
|
|
295
|
+
test_datasets[dataset_name] = CLIPDataset(
|
|
296
|
+
test_datasets[dataset_name], self.test_processor
|
|
297
|
+
)
|
|
298
|
+
test_loaders[dataset_name] = torch.utils.data.DataLoader(
|
|
299
|
+
test_datasets[dataset_name],
|
|
300
|
+
shuffle=False,
|
|
301
|
+
**self._dataloader_kwargs,
|
|
302
|
+
)
|
|
303
|
+
test_loaders[dataset_name] = self.fabric.setup_dataloaders(
|
|
304
|
+
test_loaders[dataset_name]
|
|
305
|
+
)
|
|
306
|
+
|
|
307
|
+
self.test_datasets = test_datasets
|
|
308
|
+
self.test_loaders = test_loaders
|
|
309
|
+
|
|
310
|
+
def compute_loss(self, model: ImageEncoder, ray: Tensor):
|
|
311
|
+
losses = []
|
|
312
|
+
for dataset_idx, dataset_name in enumerate(self.train_datasets):
|
|
313
|
+
batch = next(self.train_loader_iters[dataset_name])
|
|
314
|
+
x, y = batch
|
|
315
|
+
|
|
316
|
+
features = model(x)
|
|
317
|
+
logits = self.classification_heads[dataset_name](features)
|
|
318
|
+
|
|
319
|
+
_loss = F.cross_entropy(logits, y)
|
|
320
|
+
losses.append(_loss)
|
|
321
|
+
|
|
322
|
+
loss = self.aggregate_loss(model, ray, losses)
|
|
323
|
+
return loss
|
|
324
|
+
|
|
325
|
+
@abstractmethod
|
|
326
|
+
def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
|
|
327
|
+
pass
|
|
328
|
+
|
|
329
|
+
def train(self):
    """Train the Pareto weight-ensembling model under random preference rays.

    A deep copy of ``self.model`` is optimized with Adam. Only parameters with
    ``requires_grad`` are optimized. The learning rate is cosine-annealed from
    ``self.lr`` down to ``self.lr * 0.1`` over ``self.num_steps`` steps.

    At every step:
        1. A preference ray is drawn from a symmetric Dirichlet(``self.alpha``)
           distribution and installed into the model.
        2. The loss from ``self.compute_loss`` is back-propagated through Fabric.

    A checkpoint ``{"model": model}`` is written under
    ``<log_dir>/checkpoints`` every ``self.save_interval`` steps and at the
    final step.

    Returns:
        The trained model as returned by ``self.fabric.setup``.
    """
    n_tasks = len(self.modelpool.model_names)
    trained = deepcopy(self.model)
    self.classification_heads = {
        task: head.to(self.fabric.device)
        for task, head in self.classification_heads.items()
    }

    # optimizer over the trainable parameters only, then the LR schedule
    trainable_params = (p for p in trained.parameters() if p.requires_grad)
    opt = torch.optim.Adam(trainable_params, lr=self.lr)
    trained, opt = self.fabric.setup(trained, opt)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=opt, T_max=self.num_steps, eta_min=self.lr * 0.1
    )

    trained.train()
    device = self.fabric.device
    concentration = (self.alpha,) * n_tasks
    for step in tqdm(range(1, self.num_steps + 1), "training", dynamic_ncols=True):
        # one preference ray per step, as a 1-D float32 tensor on the device
        sample = np.random.dirichlet(concentration, 1)
        ray = torch.from_numpy(sample.astype(np.float32).reshape(-1)).to(device)
        ParetoWeightEnsemblingModule.set_preferenece_vector(trained, ray)

        loss = self.compute_loss(trained, ray)

        opt.zero_grad()
        self.fabric.backward(loss)
        opt.step()
        scheduler.step()

        self.fabric.log("loss", loss.item(), step=step)

        if step % self.save_interval == 0 or step == self.num_steps:
            ckpt_dir = Path(self.log_dir) / "checkpoints"
            ckpt_dir.mkdir(parents=True, exist_ok=True)
            self.fabric.save(
                ckpt_dir / f"model_step={step}.ckpt",
                {"model": trained},
            )
    return trained
|
|
378
|
+
|
|
379
|
+
@torch.no_grad()
def evaluate(self, model):
    """Evaluate the model over a grid of preference rays.

    For each preference ray:
        1. The ray is installed with
           ``ParetoWeightEnsemblingModule.set_preferenece_vector``.
        2. Top-1 accuracy is measured on every test dataset in
           ``self.modelpool.test_dataset_names``.

    Results are accumulated into one row per ray. Each row holds
    ``ray_idx``, ``ray_<i>``, one column per dataset, and ``average`` unless a
    dataset is itself named "average". The rows are written to
    ``<log_dir>/result.csv`` and logged after every ray, so partial results
    persist.

    The whole method runs under ``torch.no_grad()``. ``model.eval()`` only
    switches layer behaviour; it does not stop autograd from recording the
    forward passes.

    Args:
        model: the model to evaluate. It is wrapped with
            ``self.fabric.setup_module`` if Fabric has not wrapped it yet.

    Notes:
        - If ``self.num_evaluation_samples == "equal_weight"``, only the uniform
          ray is evaluated. Otherwise the value is forwarded to
          ``generate_simplex_grid``.
        - With ``self.quick_evaluation``, each dataset stops after its first
          22 batches.
        - A dataset whose loader yields no batches is reported with NaN
          accuracy.
    """
    results = defaultdict(list)

    num_objectives = len(self.modelpool.model_names)
    device = self.fabric.device
    self.classification_heads = {
        t: h.to(device) for t, h in self.classification_heads.items()
    }

    if not lightning.fabric.is_wrapped(model):
        model = self.fabric.setup_module(model)
    model.eval()

    if self.num_evaluation_samples == "equal_weight":
        uniform_grid = np.array(
            [[1 / num_objectives] * num_objectives], dtype=np.float32
        )
    else:
        uniform_grid = generate_simplex_grid(
            num_objectives, self.num_evaluation_samples
        )

    for ray_idx, ray in tqdm(enumerate(uniform_grid), "evaluating samples"):
        results["ray_idx"].append(ray_idx)
        for i, weight in enumerate(ray):
            results[f"ray_{i}"].append(weight)
        ray = torch.from_numpy(ray).to(device)
        ParetoWeightEnsemblingModule.set_preferenece_vector(model, ray)

        accs = []
        for dataset_name in tqdm(
            self.modelpool.test_dataset_names,
            "evaluating datasets",
            leave=False,
        ):
            test_loader = self.test_loaders[dataset_name]
            total_correct = 0
            total_count = 0
            for batch_idx, (x, y) in enumerate(
                pbar := tqdm(
                    test_loader,
                    f"evaluate {dataset_name}",
                    leave=False,
                )
            ):
                features = model(x)
                logits = self.classification_heads[dataset_name](features)
                preds = logits.argmax(-1)

                total_correct += (preds == y).sum().item()
                total_count += len(y)
                pbar.set_postfix_str(f"acc={total_correct / total_count:.2f}")

                if self.quick_evaluation and batch_idx > 20:
                    break

            # An empty loader must not silently reuse the previous dataset's
            # accuracy (or hit an undefined name); report NaN instead.
            acc = total_correct / total_count if total_count else float("nan")
            results[dataset_name].append(acc)
            accs.append(acc)

        # compute the average accuracy
        if "average" not in self.modelpool.test_dataset_names:
            results["average"].append(np.mean(accs))

        (df := pd.DataFrame(results)).to_csv(
            Path(self.log_dir) / "result.csv", index=False
        )
        log.info(df)
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
class PWEMoELinearScalarizationForOpenCLIP(PWEMoEAlgorithmForOpenCLIP):
    """PWE-MoE trained with linear scalarization of the per-task losses.

    For a preference ray ``r``, the training objective is
    ``sum_i r[i] * loss_i``.
    """

    def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
        """Return the preference-weighted sum of the per-task losses.

        Args:
            model: unused; accepted to satisfy the ``aggregate_loss`` interface.
            ray: preference weights, one per objective.
            losses: per-task losses, positionally aligned with ``ray``.

        Returns:
            ``sum(ray[i] * losses[i])``.

        Raises:
            ValueError: if ``ray`` and ``losses`` differ in length. ``zip``
                would otherwise silently drop the unmatched objectives.
        """
        if len(ray) != len(losses):
            raise ValueError(
                f"Expected one loss per preference weight, got {len(ray)} "
                f"weights and {len(losses)} losses."
            )
        loss = 0
        for weight, task_loss in zip(ray, losses):
            loss += weight * task_loss
        return loss
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
class PWEMoEExactParetoOptimalForOpenCLIP(PWEMoEAlgorithmForOpenCLIP):
    """PWE-MoE whose per-task losses are weighted by an ``EPOSolver``.

    The solver is built lazily on the first ``aggregate_loss`` call, sized to
    the number of models in the model pool, and reused afterwards.
    """

    # lazily created solver; None until the first aggregate_loss call
    epo_solver: Optional[EPOSolver] = None

    def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
        """Weight the per-task losses with the EPO solver.

        Args:
            model: the model being trained. Its parameters with
                ``requires_grad`` are passed to the solver.
            ray: the preference vector for the current step.
            losses: one loss per task, stacked into a single tensor here.

        Returns:
            The loss returned by ``EPOSolver.get_weighted_loss``.
        """
        solver = self.epo_solver
        if solver is None:
            n_objectives = len(self.modelpool.model_names)
            solver = EPOSolver(n_tasks=n_objectives, n_params=None)
            self.epo_solver = solver

        stacked_losses = torch.stack(losses)
        trainable = tuple(p for p in model.parameters() if p.requires_grad)
        return solver.get_weighted_loss(stacked_losses, ray, trainable)
|