fusion-bench 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +16 -7
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +5 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +190 -151
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
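The headline additions in 0.2.13 are the new `gossip` merging method (layer-wise and task-wise variants for CLIP and Flan-T5, plus the min-norm solvers they build on) and OpenCLIP support (model wrappers, a modelpool, a taskpool, an OpenCLIP variant of PWE-MoE, and the matching configs). Judging by the shipped config files, the method would be selected through the usual Hydra-style overrides (e.g. `method=gossip/layer_wise_clip`). As orientation for the reconstructed source below, here is a minimal sketch of the options the layer-wise Gossip driver reads; every key appears in the new `layer_wise_gossip.py`, but the values are illustrative guesses, not the defaults shipped in `method/gossip/layer_wise_clip.yaml`:

```python
# Illustrative option sketch for the new Gossip method. Each key is read in
# fusion_bench/method/gossip/layer_wise_gossip.py (shown below); the values
# are examples, not the shipped defaults.
from omegaconf import OmegaConf

gossip_cfg = OmegaConf.create(
    {
        "topo": "ring",                   # communication topology: "ring" or "rotate_<k>"
        "gossip_max_steps": 10,           # outer gossip rounds
        "gossip_skip_adamerging": False,  # if True, merge without test-time adaptation
        "improve_dataset": True,          # restrict each local merge to its neighbors' datasets
        "accuracy_test_interval": 0,      # evaluate every N rounds (0 disables)
        "optimizer": "adam",              # only "adam" is supported
        "lr": 1e-3,
        "max_steps": 1000,                # inner test-time-adaptation steps
        "init_values": 0.3,               # initial layer-wise merging weight
        "weights": None,                  # or a path to a saved layer-wise weight tensor
        "clamp_weights": False,
        "tie_weights": True,
        "strict": False,
    }
)
```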
`fusion_bench/method/gossip/layer_wise_gossip.py` (new file, `@@ -0,0 +1,434 @@`):

```python
import copy
import gc
import logging
import os
from abc import abstractmethod
from typing import Any, Callable, List, Mapping, Union, cast  # noqa: F401

import torch
from lightning.fabric.utilities.rank_zero import rank_zero_only
from omegaconf import DictConfig
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

from fusion_bench.compat.method import ModelFusionAlgorithm
from fusion_bench.compat.modelpool import ModelPool
from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
from fusion_bench.modelpool import (
    CLIPVisionModelPool,
    GPT2ForSequenceClassificationPool,
)
from fusion_bench.models.wrappers.layer_wise_fusion import (
    LayerWiseMergedModel,
    get_layer_wise_weights,
)
from fusion_bench.utils.data import load_tensor_from_file

from .entropy_loss import entropy_loss

log = logging.getLogger(__name__)


# obtain the current GPU memory usage
def get_memory_usage(desc):
    allocated = torch.cuda.memory_allocated() / 1024**2  # convert to MB
    cached = torch.cuda.memory_reserved() / 1024**2  # convert to MB
    return (
        f"{desc}\nAllocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
    )


class ModelScheduler:
    """
    Manages model storage, schedules the order in which models are loaded onto
    the GPU, and transfers data between the CPU and GPU.
    """

    def __init__(
        self,
        config: DictConfig,
        modelpool: ModelPool,
    ):
        self.pretrained_model = modelpool.load_model("_pretrained_")
        self.finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]
        self.num_finetuned_models = len(self.finetuned_models)
        self.new_finetuned_models = copy.deepcopy(self.finetuned_models)
        self.finetuned_models_name = [name for name in modelpool.model_names]

        self.config = config

    @torch.no_grad()  # not sure whether to use this
    def __call__(self, model_id):
        """
        Return the wrapped model and relevant data for one local merging step.
        """
        pretrained_model = copy.deepcopy(self.pretrained_model)
        # select the neighbors of node `model_id` according to the topology
        if self.config.topo == "ring":
            finetuned_models = [
                copy.deepcopy(
                    self.finetuned_models[(model_id + 1) % self.num_finetuned_models]
                ),
                copy.deepcopy(self.finetuned_models[model_id]),
                copy.deepcopy(
                    self.finetuned_models[(model_id - 1) % self.num_finetuned_models]
                ),
            ]
        elif "rotate" in self.config.topo:
            number = self.config.topo.split("_")[1]
            finetuned_models = [copy.deepcopy(self.finetuned_models[model_id])]
            for i in range(0, int(number)):
                finetuned_models.append(
                    copy.deepcopy(
                        self.finetuned_models[
                            (model_id + i + 1) % self.num_finetuned_models
                        ]
                    )
                )
        # initialize layer-wise weights using the provided configuration
        # `init_values`, or load from file if `weights` is provided
        if self.config.weights is None:
            layer_wise_weight = get_layer_wise_weights(
                num_models=len(finetuned_models),
                num_layers=len(
                    tuple(
                        filter(lambda p: p.requires_grad, pretrained_model.parameters())
                    )
                ),
                init_values=self.config.init_values,
            )
        else:
            if isinstance(self.config.weights, str):
                # self.config.weights is a path to a saved tensor
                layer_wise_weight = load_tensor_from_file(self.config.weights)
            else:
                raise ValueError(f"Unsupported weights format: {self.config.weights}")

        module = LayerWiseMergedModel(
            layer_wise_weight=layer_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
        return module

    def store_model(self, new_finetuned_model_dict, model_id):
        """
        Store the new fine-tuned model after each round of AdaMerging.
        """
        self.new_finetuned_models[model_id].load_state_dict(new_finetuned_model_dict)

    def update_models(self):
        self.finetuned_models = copy.deepcopy(self.new_finetuned_models)

    def get_final_models(self, idx=None):
        # need a check
        if idx is not None:
            return copy.deepcopy(self.finetuned_models[idx])

        final_models = [
            {"name": name, "model": model}
            for name, model in zip(self.finetuned_models_name, self.finetuned_models)
        ]
        num_finetuned_models = len(self.finetuned_models)

        # also return the uniform average of all fine-tuned models
        average_model = copy.deepcopy(self.pretrained_model)
        state_dict = average_model.state_dict(keep_vars=True)
        for name, _ in self.finetuned_models[0].named_parameters():
            state_dict[name].data.zero_()
        for model in self.finetuned_models:
            for name, param in model.named_parameters():
                state_dict[name] = state_dict[name] + 1 / num_finetuned_models * param

        average_model.load_state_dict(state_dict)
        final_models += [{"name": "average model", "model": average_model}]

        return final_models

    def move_to(self, device):
        self.pretrained_model.to(device=device)
        for model in self.finetuned_models:
            model.to(device=device)
```
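To make the topology handling in `ModelScheduler.__call__` concrete, here is a small standalone illustration (not part of the diff) of which fine-tuned models node `model_id` receives under the two supported topologies; `num_models = 8` is an arbitrary example value:

```python
# Standalone illustration of ModelScheduler.__call__'s neighbor selection.
# `num_models = 8` is an arbitrary example, not from the diff.
num_models = 8


def ring_neighbors(model_id):
    # "ring": the node's successor, the node itself, and its predecessor
    return [(model_id + 1) % num_models, model_id, (model_id - 1) % num_models]


def rotate_neighbors(model_id, k):
    # "rotate_<k>": the node itself plus the next k nodes around the ring
    return [model_id] + [(model_id + i + 1) % num_models for i in range(k)]


print(ring_neighbors(0))       # [1, 0, 7]
print(rotate_neighbors(0, 2))  # [0, 1, 2]
```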
The same file continues with the algorithm driver:

```python
class LayerWiseGossipAlgorithm(
    ModelFusionAlgorithm,
    LightningFabricMixin,
    SimpleProfilerMixin,
):
    """
    Implements the layer-wise Gossip algorithm: repeated rounds of local
    layer-wise AdaMerging between neighboring models.

    This class merges the layers of a pretrained model with those of several fine-tuned models.
    The merging is controlled by layer-wise weights, which can be initialized based on a provided configuration or loaded from a file.
    """

    def __init__(self, algorithm_config: DictConfig):
        """
        Initialize the LayerWiseGossipAlgorithm with the given configuration.

        Args:
            algorithm_config (DictConfig): The configuration for the algorithm.
        """
        super().__init__(algorithm_config)
        self._program = None

    @rank_zero_only
    def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
        """
        Save the merging weights to a file.

        Args:
            file_path (str): The path to save the merging weights.
            merging_weights (torch.Tensor): The merging weights to save.
        """
        if self.fabric.is_global_zero and self.config.get(
            "save_merging_weights", False
        ):
            if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
                # if the file path is not absolute or relative to the current
                # working directory, save it in the log directory
                save_path = os.path.join(self.log_dir, file_path)
            else:
                save_path = file_path
            log.info(f"saving merging weights to {save_path}.")
            if os.path.dirname(save_path):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(merging_weights.detach().cpu(), save_path)

    def free_gpu_memory(self, module: LayerWiseMergedModel):
        module.pretrained_model.to("cpu")
        for model in module.task_vectors:
            model.to("cpu")
        del module
        gc.collect()
        torch.cuda.empty_cache()
        log.info(get_memory_usage("after freeing memory, the memory usage of GPU is:"))

    def update_datasets(self, datasets):
        """
        For every epoch of local AdaMerging, only use the datasets
        corresponding to the models involved in the fusion.
        """
        num_datasets = len(datasets)
        datasets_copy = datasets.copy()
        if self.config.topo == "ring":
            for i in range(num_datasets):
                datasets[i] = (
                    datasets_copy[i]
                    .union(datasets_copy[(i + 1) % num_datasets])
                    .union(datasets_copy[(i - 1) % num_datasets])
                )
        elif "rotate" in self.config.topo:
            number = self.config.topo.split("_")[1]
            for i in range(num_datasets):
                datasets[i] = datasets_copy[i]
                for j in range(0, int(number)):
                    datasets[i] = datasets[i].union(
                        datasets_copy[(i + j + 1) % num_datasets]
                    )
        return datasets

    def run(self, modelpool: ModelPool):
        """
        Run the layer-wise Gossip algorithm.

        This method constructs the wrapped model and performs test-time adaptation if necessary.

        Args:
            modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.

        Returns:
            The merged model(s) after gossip merging and test-time adaptation.
        """
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool
        self.log_hyperparams(self.config)
        self.num_finetuned_models = len(modelpool.model_names)
        # one singleton set of task names per model
        datasets = [{dataset} for dataset in modelpool.model_names]

        with self.profile("construct the wrapped model"):
            model_scheduler = ModelScheduler(
                modelpool=self.modelpool, config=self.config
            )

        if self.config.weights is not None:
            # skip the test-time adaptation; the scheduler loads the merging
            # weights from the given file when building the wrapped model
            module = model_scheduler(0)
            return module.merge_and_unload()
        else:
            for step_idx in tqdm(
                range(self.config.gossip_max_steps),
                "Gossip merging",
                dynamic_ncols=True,
            ):
                datasets = self.update_datasets(datasets)
                log.info(f"Gossip merging step: {step_idx}")
                for model_id in tqdm(
                    range(self.num_finetuned_models),
                    "local adamerging",
                    dynamic_ncols=True,
                ):
                    if self.config.gossip_skip_adamerging:
                        # skip adamerging, only merge
                        with self.profile("construct the local wrapped model"):
                            module = model_scheduler(model_id)
                        log.info(
                            f"skip adamerging, only merge ({modelpool.model_names[model_id]})"
                        )
                        model_scheduler.store_model(module.merge_weights(), model_id)
                        self.free_gpu_memory(module)
                    else:
                        with self.profile("construct the local wrapped model"):
                            module = model_scheduler(model_id)

                        if self.config.improve_dataset:
                            log.info(
                                f"improved datasets, the datasets used in this local merging are {datasets[model_id]}"
                            )
                        else:
                            log.info(
                                f"unimproved datasets, the datasets used in this local merging are {modelpool.model_names}"
                            )
                        with self.profile("test-time adaptation"):
                            module = self.test_time_adaptation(
                                module, datasets[model_id]
                            )
                        model_scheduler.store_model(module.merge_weights(), model_id)
                        log.info(
                            get_memory_usage(
                                f"after local merging ({modelpool.model_names[model_id]}), the memory usage of GPU is:"
                            )
                        )
                        # simulate distributed GPU memory usage as much as possible
                        self.free_gpu_memory(module)

                model_scheduler.update_models()

                if "rotate" in self.config.topo:
                    number = self.config.topo.split("_")[1]
                    if int(number) == 1 and step_idx >= 20:
                        self._program.evaluate_merged_model(
                            self._program.taskpool, model_scheduler.get_final_models()
                        )
                        model_scheduler.move_to("cpu")
                else:
                    if (
                        self.config.accuracy_test_interval != 0
                        and (step_idx + 1) % self.config.accuracy_test_interval == 0
                    ):
                        self._program.evaluate_merged_model(
                            self._program.taskpool, model_scheduler.get_final_models()
                        )
                        model_scheduler.move_to("cpu")
        return model_scheduler.get_final_models()

    def on_test_time_adaptation_start(self):
        """
        Hook called before test-time adaptation starts, e.g. for setting up
        the task-specific heads.
        """
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Loader of the test dataset for test-time adaptation; labels are not needed.

        Args:
            task (str): The name of the task.

        Returns:
            DataLoader: The data loader for the test dataset.
        """
        pass

    @abstractmethod
    def compute_logits(self, module, images: Tensor, task: str) -> Tensor:
        """
        Compute the logits for the given images and task.

        Args:
            module: The model module.
            images (Tensor): The input images.
            task (str): The name of the task.

        Returns:
            Tensor: The computed logits.
        """
        pass

    def test_time_adaptation(self, module: LayerWiseMergedModel, datasets):
        """
        Perform test-time adaptation on the merged model.

        This method adapts the merging weights during test-time to improve performance.

        Args:
            module (LayerWiseMergedModel): The merged model.
            datasets: The set of task names whose test data is used for adaptation.

        Returns:
            LayerWiseMergedModel: The adapted merged model.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
            print(f"{optimizer=}")
            module, optimizer = self.fabric.setup(module, optimizer)
            log.info(
                get_memory_usage(
                    "after loading models and optimizer, the memory usage of GPU is:"
                )
            )
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        module.train()
        module.merge_weights()
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        ):
            # default behavior for first-order optimizers
            for task in self.modelpool.model_names:
                if self.config.improve_dataset and task not in datasets:
                    continue
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profile("forward pass"):
                    if isinstance(self.modelpool, GPT2ForSequenceClassificationPool):
                        logits = self.compute_logits(module, batch, task)
                    elif isinstance(self.modelpool, CLIPVisionModelPool):
                        logits = self.compute_logits(module, batch[0], task)
                    loss = entropy_loss(logits)
                with self.profile("backward pass"):
                    self.fabric.backward(loss, retain_graph=True)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()
            with self.profile("merging weights"):
                module.merge_weights()

            metrics = {
                "train/loss": loss.item(),
                "train/weight_max": module.merge_weight.max().item(),
                "train/weight_min": module.merge_weight.min().item(),
                "train/weight_mean": module.merge_weight.mean().item(),
            }
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

        self.print_profile_summary()
        del optimizer
        gc.collect()
        torch.cuda.empty_cache()
        return module
```
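The driver leaves `get_shuffled_test_loader_iter` and `compute_logits` abstract; in this release they are filled in by the task-specific subclasses listed above (e.g. `clip_layer_wise_gossip.py`). For orientation, a hypothetical subclass sketch; `test_datasets` and `heads` are assumed attributes used for illustration only, not fusion_bench APIs:

```python
# Hypothetical sketch only: the real CLIP implementation ships in
# fusion_bench/method/gossip/clip_layer_wise_gossip.py. `self.test_datasets`
# and `self.heads` are assumptions for illustration.
import itertools

from torch import Tensor
from torch.utils.data import DataLoader

from fusion_bench.method.gossip.layer_wise_gossip import LayerWiseGossipAlgorithm


class ToyLayerWiseGossip(LayerWiseGossipAlgorithm):
    def get_shuffled_test_loader_iter(self, task: str):
        # The base class calls `next(...)` on the return value, so an
        # infinite iterator over shuffled, unlabeled test batches works.
        loader = DataLoader(self.test_datasets[task], batch_size=16, shuffle=True)
        return itertools.cycle(loader)

    def compute_logits(self, module, images: Tensor, task: str) -> Tensor:
        # Encode images with the merged backbone, then apply the
        # task-specific classification head.
        features = module(images)
        return self.heads[task](features)
```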
`fusion_bench/method/gossip/min_norm_solvers.py` (new file, `@@ -0,0 +1,227 @@`):

```python
# This code is from
# Multi-Task Learning as Multi-Objective Optimization
# Ozan Sener, Vladlen Koltun
# Neural Information Processing Systems (NeurIPS) 2018
# https://github.com/intel-isl/MultiObjectiveOptimization
from typing import Union

import numpy as np
import torch


def np_sum(x: Union[torch.Tensor, np.ndarray]) -> float:
    if isinstance(x, torch.Tensor):
        return x.sum().item()
    return np.sum(x)


def to_numpy(x: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return x


class MinNormSolver:
    MAX_ITER = 250
    STOP_CRIT = 1e-5

    @staticmethod
    def _min_norm_element_from2(v1v1, v1v2, v2v2):
        """
        Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2
        d is the distance (objective) optimized
        v1v1 = <x1,x1>
        v1v2 = <x1,x2>
        v2v2 = <x2,x2>
        """
        if v1v2 >= v1v1:
            # Case: Fig 1, third column
            gamma = 0.999
            cost = v1v1
            return gamma, cost
        if v1v2 >= v2v2:
            # Case: Fig 1, first column
            gamma = 0.001
            cost = v2v2
            return gamma, cost
        # Case: Fig 1, second column
        gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2))
        cost = v2v2 + gamma * (v1v2 - v2v2)
        return gamma, cost

    @staticmethod
    def _min_norm_2d(vecs, dps):
        R"""
        Find the minimum norm solution as a combination of two points.
        This is correct only in 2D,
        i.e. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1, 1 >= c_i >= 0 for all i, c_i + c_j = 1.0 for some i, j
        """
        dmin = 1e8
        for i in range(len(vecs)):
            for j in range(i + 1, len(vecs)):
                if (i, j) not in dps:
                    dps[(i, j)] = 0.0
                    for k in range(len(vecs[i])):
                        dps[(i, j)] += (
                            torch.mul(vecs[i][k], vecs[j][k]).sum().data.cpu()
                        )
                    dps[(j, i)] = dps[(i, j)]
                if (i, i) not in dps:
                    dps[(i, i)] = 0.0
                    for k in range(len(vecs[i])):
                        dps[(i, i)] += (
                            torch.mul(vecs[i][k], vecs[i][k]).sum().data.cpu()
                        )
                if (j, j) not in dps:
                    dps[(j, j)] = 0.0
                    for k in range(len(vecs[j])):
                        dps[(j, j)] += (
                            torch.mul(vecs[j][k], vecs[j][k]).sum().data.cpu()
                        )
                c, d = MinNormSolver._min_norm_element_from2(
                    dps[(i, i)], dps[(i, j)], dps[(j, j)]
                )
                if d < dmin:
                    dmin = d
                    sol = [(i, j), c, d]
        return sol, dps

    @staticmethod
    def _projection2simplex(y):
        R"""
        Given y, it solves argmin_z |y-z|_2 st \sum z = 1, 1 >= z_i >= 0 for all i
        """
        m = len(y)
        sorted_y = np.flip(np.sort(y), axis=0)
        tmpsum = 0.0
        tmax_f = (np.sum(y) - 1.0) / m
        for i in range(m - 1):
            tmpsum += sorted_y[i]
            tmax = (tmpsum - 1) / (i + 1.0)
            if tmax > sorted_y[i + 1]:
                tmax_f = tmax
                break
        return np.maximum(y - tmax_f, np.zeros(y.shape))

    @staticmethod
    def _next_point(cur_val, grad, n):
        proj_grad = grad - (np.sum(grad) / n)
        tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0]
        tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0])

        skippers = np_sum(tm1 < 1e-7) + np_sum(tm2 < 1e-7)  # unused in the original implementation
        t = 1
        if len(tm1[tm1 > 1e-7]) > 0:
            t = np.min(to_numpy(tm1[tm1 > 1e-7]))
        if len(tm2[tm2 > 1e-7]) > 0:
            t = min(t, np.min(to_numpy(tm2[tm2 > 1e-7])))

        next_point = proj_grad * t + to_numpy(cur_val)
        next_point = MinNormSolver._projection2simplex(next_point)
        return next_point

    @staticmethod
    def find_min_norm_element(vecs):
        R"""
        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull,
        i.e. min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j, the solution lies in (0, d_{i,j}).
        Hence, we find the best 2-task solution and then run projected gradient descent until convergence.
        """
        # Solution lying at the combination of two points
        dps = {}
        init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)

        n = len(vecs)
        sol_vec = np.zeros(n)
        sol_vec[init_sol[0][0]] = init_sol[1]
        sol_vec[init_sol[0][1]] = 1 - init_sol[1]

        if n < 3:
            # This is optimal for n=2, so return the solution
            return sol_vec, init_sol[2]

        iter_count = 0

        grad_mat = np.zeros((n, n))
        for i in range(n):
            for j in range(n):
                grad_mat[i, j] = dps[(i, j)]

        while iter_count < MinNormSolver.MAX_ITER:
            grad_dir = -1.0 * np.dot(grad_mat, sol_vec)
            new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
            # Re-compute the inner products for line search
            v1v1 = 0.0
            v1v2 = 0.0
            v2v2 = 0.0
            for i in range(n):
                for j in range(n):
                    v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)]
                    v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)]
                    v2v2 += new_point[i] * new_point[j] * dps[(i, j)]
            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
            new_sol_vec = nc * sol_vec + (1 - nc) * new_point
            change = new_sol_vec - sol_vec
            if np_sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
                return sol_vec, nd
            sol_vec = new_sol_vec
            iter_count += 1  # advance the counter so MAX_ITER can take effect
        return sol_vec, nd  # fall back to the last iterate if MAX_ITER is reached

    @staticmethod
    def find_min_norm_element_FW(vecs):
        R"""
        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull,
        i.e. min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j, the solution lies in (0, d_{i,j}).
        Hence, we find the best 2-task solution and then run Frank-Wolfe until convergence.
        """
        # Solution lying at the combination of two points
        dps = {}
        init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)

        n = len(vecs)
        sol_vec = np.zeros(n)
        sol_vec[init_sol[0][0]] = init_sol[1]
        sol_vec[init_sol[0][1]] = 1 - init_sol[1]

        if n < 3:
            # This is optimal for n=2, so return the solution
            return sol_vec, init_sol[2]

        iter_count = 0

        grad_mat = np.zeros((n, n))
        for i in range(n):
            for j in range(n):
                grad_mat[i, j] = dps[(i, j)]

        while iter_count < MinNormSolver.MAX_ITER:
            t_iter = np.argmin(np.dot(grad_mat, sol_vec))

            v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
            v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
            v2v2 = grad_mat[t_iter, t_iter]

            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
            new_sol_vec = nc * sol_vec
            new_sol_vec[t_iter] += 1 - nc

            change = new_sol_vec - sol_vec
            if np_sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
                return sol_vec, nd
            sol_vec = new_sol_vec
            iter_count += 1  # advance the counter so MAX_ITER can take effect
        return sol_vec, nd  # fall back to the last iterate if MAX_ITER is reached


def gradient_normalizers(grads, losses, normalization_type):
    gn = {}
    if normalization_type == "l2":
        for t in grads:
            gn[t] = np.sqrt(np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]]))
    elif normalization_type == "loss":
        for t in grads:
            gn[t] = losses[t]
    elif normalization_type == "loss+":
        for t in grads:
            gn[t] = losses[t] * np.sqrt(
                np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]])
            )
    elif normalization_type == "none":
        for t in grads:
            gn[t] = 1.0
    else:
        print("ERROR: Invalid Normalization Type")
    return gn
```
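As a sanity check on the analytical two-vector case, a worked example (illustrative, not part of the diff): for x1 = [2.0] and x2 = [-1.0], the objective min_c |c*x1 + (1-c)*x2|^2 from the docstring is minimized at c = 1/3, where the combined vector vanishes:

```python
# Worked example for MinNormSolver (illustrative, not part of the diff).
# For x1 = [2.0], x2 = [-1.0]: v1v1 = 4, v1v2 = -2, v2v2 = 1, so
# gamma = -(v1v2 - v2v2) / (v1v1 + v2v2 - 2*v1v2) = 3/9 = 1/3 and the
# min-norm element is (1/3)*x1 + (2/3)*x2 = [0.0].
import torch

from fusion_bench.method.gossip.min_norm_solvers import MinNormSolver

vecs = [[torch.tensor([2.0])], [torch.tensor([-1.0])]]
sol, cost = MinNormSolver.find_min_norm_element(vecs)
print(sol)   # approximately [0.3333, 0.6667]
print(cost)  # approximately 0.0
```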