fusion-bench 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
- fusion_bench/compat/method/__init__.py +3 -1
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +12 -2
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/doge_ta/__init__.py +2 -0
- fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +102 -84
- fusion_bench/method/opcm/task_arithmetic.py +35 -21
- fusion_bench/method/opcm/ties_merging.py +71 -52
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
- fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +5 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +15 -2
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +198 -158
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -10
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +66 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- fusion_bench/method/DOGE_TA/__init__.py +0 -2
- /fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info/licenses}/LICENSE +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
fusion_bench/models/open_clip/utils.py (new file)
@@ -0,0 +1,311 @@
+import copy
+import os
+import pickle
+from collections import OrderedDict
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+def compute_l1_norm(
+    model1: nn.Module, model2: nn.Module
+) -> Tuple[torch.Tensor, Dict[str, float]]:
+    """
+    Computes the L1 norm between the parameters of two models.
+
+    Args:
+        model1 (nn.Module): The first model.
+        model2 (nn.Module): The second model.
+
+    Returns:
+        Tuple[torch.Tensor, Dict[str, float]]: A tuple containing the total L1 norm and a dictionary
+        with the L1 norm for each layer.
+
+    """
+    norms = dict()
+    l1_norm = 0.0
+    for (n, p1), p2 in zip(model1.named_parameters(), model2.parameters()):
+        layer_l1_norm = torch.norm(p1 - p2, 1)
+        l1_norm += layer_l1_norm
+        norms[n] = layer_l1_norm.item()
+
+    return l1_norm, norms
+
+
+def assign_learning_rate(param_group, new_lr):
+    param_group["lr"] = new_lr
+
+
+def _warmup_lr(base_lr, warmup_length, step):
+    return base_lr * (step + 1) / warmup_length
+
+
+def cosine_lr(optimizer, base_lrs, warmup_length, steps):
+    if not isinstance(base_lrs, list):
+        base_lrs = [base_lrs for _ in optimizer.param_groups]
+    assert len(base_lrs) == len(optimizer.param_groups)
+
+    def _lr_adjuster(step):
+        for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
+            if step < warmup_length:
+                lr = _warmup_lr(base_lr, warmup_length, step)
+            else:
+                e = step - warmup_length
+                es = steps - warmup_length
+                lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
+            assign_learning_rate(param_group, lr)
+
+    return _lr_adjuster
+
+
+def accuracy(output: torch.Tensor, target: torch.Tensor, topk: List[int] = (1,)):
+    pred = output.topk(max(topk), 1, True, True)[1].t()
+    correct = pred.eq(target.view(1, -1).expand_as(pred))
+    return [
+        float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
+        for k in topk
+    ]
+
+
+def torch_load_old(save_path: str, device=None):
+    with open(save_path, "rb") as f:
+        classifier = pickle.load(f)
+    if device is not None:
+        classifier = classifier.to(device)
+    return classifier
+
+
+def torch_save(model, save_path, save_state_dict=True):
+    # TODO: hacky way to save state dict
+    if save_state_dict and isinstance(model, torch.nn.Module):
+        model = model.state_dict()
+    if os.path.dirname(save_path) != "":
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+    torch.save(model, save_path)
+
+
+def torch_load(save_path, device=None):
+    model = torch.load(save_path, map_location="cpu")
+    if device is not None:
+        model = model.to(device)
+    return model
+
+
+def get_logits(inputs, classifier):
+    assert callable(classifier)
+    if hasattr(classifier, "to"):
+        classifier = classifier.to(inputs.device)
+    return classifier(inputs)
+
+
+def get_probs(inputs, classifier):
+    if hasattr(classifier, "predict_proba"):
+        probs = classifier.predict_proba(inputs.detach().cpu().numpy())
+        return torch.from_numpy(probs)
+    logits = get_logits(inputs, classifier)
+    return logits.softmax(dim=1)
+
+
+class LabelSmoothing(torch.nn.Module):
+    def __init__(self, smoothing=0.0):
+        super(LabelSmoothing, self).__init__()
+        self.confidence = 1.0 - smoothing
+        self.smoothing = smoothing
+
+    def forward(self, x, target):
+        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
+
+        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
+        nll_loss = nll_loss.squeeze(1)
+        smooth_loss = -logprobs.mean(dim=-1)
+        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
+        return loss.mean()
+
+
+class DotDict(dict):
+    """dot.notation access to dictionary attributes"""
+
+    __getattr__ = dict.get
+    __setattr__ = dict.__setitem__
+    __delattr__ = dict.__delitem__
+
+
+def find_optimal_coef(
+    results: Dict[str, Any],
+    metric: str = "avg_normalized_top1",
+    minimize: bool = False,
+    control_metric: Optional[str] = None,
+    control_metric_threshold: float = 0.0,
+) -> float:
+    """
+    Finds the optimal coefficient based on the given results and metric.
+
+    Args:
+        results (Dict[str, Any]): A dictionary containing the results for different scaling coefficients.
+        metric (str, optional): The metric to optimize. Defaults to "avg_normalized_top1".
+        minimize (bool, optional): Whether to minimize the metric. Defaults to False.
+        control_metric (str, optional): The control metric to check against. Defaults to None.
+        control_metric_threshold (float, optional): The threshold value for the control metric. Defaults to 0.0.
+
+    Returns:
+        The optimal coefficient based on the given results and metric.
+    """
+    best_coef = None
+    if minimize:
+        best_metric = 1
+    else:
+        best_metric = 0
+    for scaling_coef in results.keys():
+        if control_metric is not None:
+            if results[scaling_coef][control_metric] < control_metric_threshold:
+                print(f"Control metric fell below {control_metric_threshold} threshold")
+                continue
+        if minimize:
+            if results[scaling_coef][metric] < best_metric:
+                best_metric = results[scaling_coef][metric]
+                best_coef = scaling_coef
+        else:
+            if results[scaling_coef][metric] > best_metric:
+                best_metric = results[scaling_coef][metric]
+                best_coef = scaling_coef
+    return best_coef
+
+
+def nonlinear_advantage(nonlinear_acc, linear_acc, num_classes):
+    """Computes the normalized non-linear advantage of a finetuned model.
+
+    The nonlinear_advantage is defined as:
+        error_rate(linear_model) - error_rate(nonlinear_model) / (1 - 1 / num_classes)
+    and takes values between [-1, 1]. A value of 0 indicates that the nonlinear
+    model is no better than the linear one. Meanwhile, a value of 1 indicates
+    that the nonlinear model is perfect and the linear trivial, and a value of
+    -1 indicates the opposite.
+    """
+    return (nonlinear_acc - linear_acc) / (1.0 - 1.0 / num_classes)
+
+
+def to_cuda(input_dict):
+    cuda_dict = {}
+    for key, value in input_dict.items():
+        cuda_dict[key] = value.to("cuda")
+    return cuda_dict
+
+
+def state_dict_to_vector(state_dict, remove_keys=[]):
+    shared_state_dict = copy.deepcopy(state_dict)
+    for key in remove_keys:
+        if key in shared_state_dict:
+            del shared_state_dict[key]
+    sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
+    return torch.nn.utils.parameters_to_vector(
+        [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
+    )
+
+
+def vector_to_state_dict(vector, state_dict, remove_keys=[]):
+    # create a reference dict to define the order of the vector
+    reference_dict = copy.deepcopy(state_dict)
+    for key in remove_keys:
+        if key in reference_dict:
+            del reference_dict[key]
+    sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))
+
+    # create a shared state dict using the reference dict
+    torch.nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())
+
+    # add back the encoder and decoder embedding weights.
+    if "transformer.shared.weight" in sorted_reference_dict:
+        for key in remove_keys:
+            sorted_reference_dict[key] = sorted_reference_dict[
+                "transformer.shared.weight"
+            ]
+    return sorted_reference_dict
+
+
+def add_ptm_to_tv(tv_dict, ptm_dict):
+    assert set(tv_dict.keys()) == set(
+        ptm_dict.keys()
+    ), "Differing parameter names in models."
+    final_dict = copy.deepcopy(tv_dict)
+    for k, v in ptm_dict.items():
+        final_dict[k] = tv_dict[k] + v
+    return final_dict
+
+
+def check_parameterNamesMatch(checkpoints):
+    parameter_names = set(checkpoints[0].keys())
+
+    if len(checkpoints) >= 2:
+        # raise ValueError("Number of models is less than 2.")
+        for checkpoint in checkpoints[1:]:
+            current_parameterNames = set(checkpoint.keys())
+            if current_parameterNames != parameter_names:
+                raise ValueError(
+                    "Differing parameter names in models. "
+                    f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
+                )
+
+
+def check_state_dicts_equal(state_dict1, state_dict2):
+    if set(state_dict1.keys()) != set(state_dict2.keys()):
+        return False
+
+    for key in state_dict1.keys():
+        if not torch.equal(state_dict1[key], state_dict2[key]):
+            return False
+
+    return True
+
+
+def topk_values_mask(M, K=0.7, return_mask=False, reshape_mask=False):
+    if K == 100:
+        # print("Not applying mask")
+        if return_mask:
+            return M, torch.ones_like(M), None
+        else:
+            return M, torch.ones_like(M)
+
+    if K >= 1:
+        K /= 100
+
+    original_shape = M.shape
+    if M.dim() == 1:
+        M = M.unsqueeze(0)
+
+    n, d = M.shape
+    k = int(d * K)
+    k = d - k  # Keep top k elements instead of bottom k elements
+
+    # Find the k-th smallest element by magnitude for each row
+    kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
+    # Create a mask tensor with True for the top k elements in each row
+    mask = M.abs() >= kth_values
+    final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask
+
+    if reshape_mask:
+        final_mask = final_mask.reshape(M.shape)
+
+    if return_mask:
+        return M * final_mask, final_mask.float().mean(dim=1), final_mask
+    else:
+        return M * final_mask, final_mask.float().mean(dim=1)
+
+
+def cleanup_linear(state_dict):
+    # The linear model also has keys for the reference point $\theta_0$ in the state dict with the prefix `params0`.
+    state_dict = {k: v for k, v in state_dict.items() if "params." in k}
+    return state_dict
+
+
+def get_ptm_linear(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    # rename keys so that they match afterwards
+    state_dict_new = {
+        k.replace("params0", "params"): v
+        for k, v in state_dict.items()
+        if "params0." in k
+    }
+    state_dict_remaining = {k: v for k, v in state_dict.items() if "params." not in k}
+
+    return state_dict_new, state_dict_remaining
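These helpers mirror the utilities used by task-arithmetic and TIES-style merging. Below is a minimal usage sketch of how the state-dict flattening and top-k trimming compose; the import path is inferred from the `fusion_bench/models/open_clip/utils.py +311` entry in the file list above, so treat it as an assumption:

```python
import torch
from torch import nn

# Import path inferred from the file list above; treat as an assumption.
from fusion_bench.models.open_clip.utils import (
    state_dict_to_vector,
    topk_values_mask,
    vector_to_state_dict,
)

base, finetuned = nn.Linear(4, 2), nn.Linear(4, 2)
base_vec = state_dict_to_vector(base.state_dict())
task_vec = state_dict_to_vector(finetuned.state_dict()) - base_vec

# topk_values_mask works row-wise on a (num_tasks, num_params) matrix.
task_vectors = torch.stack([task_vec, -task_vec])
trimmed, kept_fraction = topk_values_mask(task_vectors, K=0.7)
print(kept_fraction)  # fraction of entries kept per row

# Map a merged vector back onto the original (sorted) state-dict keys.
merged = vector_to_state_dict(base_vec + trimmed.mean(dim=0), base.state_dict())
print(list(merged.keys()))  # ['bias', 'weight']
```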
fusion_bench/models/open_clip/variables_and_paths.py (new file)
@@ -0,0 +1,56 @@
+from pathlib import Path
+from typing import Literal
+
+TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}{bar:-10b}"
+MODELS = ["ViT-B-32", "ViT-B-16", "ViT-L-14"]
+OPENCLIP_CACHEDIR = Path(Path.home(), "openclip-cachedir", "open_clip").as_posix()
+CACHEDIR = None
+
+ALL_DATASETS = [
+    "Cars",
+    "DTD",
+    "EuroSAT",
+    "GTSRB",
+    "MNIST",
+    "RESISC45",
+    "SVHN",
+    "SUN397",
+    "STL10",
+    "OxfordIIITPet",
+    "Flowers102",
+    "CIFAR100",
+    "PCAM",
+    "FER2013",
+    "CIFAR10",
+    "Food101",
+    "FashionMNIST",
+    "RenderedSST2",
+    "EMNIST",
+    "KMNIST",
+]
+
+DATASETS_8 = ALL_DATASETS[:8]
+DATASETS_14 = ALL_DATASETS[:14]
+DATASETS_20 = ALL_DATASETS[:20]
+
+
+def cleanup_dataset_name(dataset_name: str):
+    return dataset_name.replace("Val", "") + "Val"
+
+
+def get_zeroshot_path(root, dataset, model):
+    return Path(
+        root, model, cleanup_dataset_name(dataset), f"nonlinear_zeroshot.pt"
+    ).as_posix()
+
+
+def get_finetuned_path(root, dataset, model):
+    return Path(
+        root, model, cleanup_dataset_name(dataset), f"nonlinear_finetuned.pt"
+    ).as_posix()
+
+
+def get_single_task_accuracies_path(model):
+    return Path(
+        "results/single_task", model, f"nonlinear_ft_accuracies.json"
+    ).as_posix()
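A short sketch of the path helpers above; the import path is inferred from the file list, and the printed paths follow directly from the definitions:

```python
# Import path inferred from the file list above; treat as an assumption.
from fusion_bench.models.open_clip.variables_and_paths import (
    DATASETS_8,
    cleanup_dataset_name,
    get_finetuned_path,
    get_zeroshot_path,
)

# cleanup_dataset_name is idempotent: both spellings yield the "...Val" form.
assert cleanup_dataset_name("MNIST") == cleanup_dataset_name("MNISTVal") == "MNISTVal"

print(DATASETS_8)  # the first 8 tasks: Cars, DTD, EuroSAT, ..., SUN397
print(get_zeroshot_path("checkpoints", "MNIST", "ViT-B-32"))
# checkpoints/ViT-B-32/MNISTVal/nonlinear_zeroshot.pt
print(get_finetuned_path("checkpoints", "MNIST", "ViT-B-32"))
# checkpoints/ViT-B-32/MNISTVal/nonlinear_finetuned.pt
```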
fusion_bench/models/parameter_dict.py
@@ -1,4 +1,4 @@
-from typing import List, Mapping
+from typing import List, Mapping, Optional, Tuple
 
 import torch
 from torch import nn
@@ -6,7 +6,13 @@ from torch import nn
 __all__ = "ParamterDictModel"
 
 
-def set_attr(obj, names: List[str], val, check_parent: bool = False):
+def _set_attr(
+    obj,
+    names: List[str],
+    val,
+    check_parent: bool = False,
+    parent_builder=nn.Module,
+):
     """
     Sets an attribute of an object recursively.
 
@@ -20,8 +26,14 @@ def set_attr(obj, names: List[str], val, check_parent: bool = False):
         setattr(obj, names[0], val)
     else:
         if check_parent and not hasattr(obj, names[0]):
-            setattr(obj, names[0], nn.Module())
-        set_attr(getattr(obj, names[0]), names[1:], val, check_parent=check_parent)
+            setattr(obj, names[0], parent_builder())
+        _set_attr(
+            getattr(obj, names[0]),
+            names[1:],
+            val,
+            check_parent=check_parent,
+            parent_builder=parent_builder,
+        )
 
 
 def has_attr(obj, names: List[str]):
@@ -49,17 +61,19 @@ class ParameterDictModel(nn.Module):
 
     def __init__(
        self,
-        parameters: Mapping[str, nn.Parameter],
+        parameters: Optional[Mapping[str, nn.Parameter]] = None,
     ):
         super().__init__()
-        for name, param in parameters.items():
-            assert isinstance(param, nn.Parameter), f"{name} is not a nn.Parameter"
-            set_attr(
-                self,
-                name.split("."),
-                param,
-                check_parent=True,
-            )
+        if parameters is not None:
+            for name, param in parameters.items():
+                assert isinstance(param, nn.Parameter), f"{name} is not a nn.Parameter"
+                _set_attr(
+                    self,
+                    name.split("."),
+                    param,
+                    check_parent=True,
+                    parent_builder=self.__class__,
+                )
 
     def __repr__(self):
         """
@@ -73,3 +87,30 @@ class ParameterDictModel(nn.Module):
             param_repr = f"{name}: {param.size()}"
             param_reprs.append(param_repr)
         return f"{self.__class__.__name__}({', '.join(param_reprs)})"
+
+    def __getitem__(self, key: str):
+        if not has_attr(self, key.split(".")):
+            raise KeyError(f"Key {key} not found in {self}")
+        key = key.split(".")
+        obj = self
+        for k in key:
+            obj = getattr(obj, k)
+        return obj
+
+    def __setitem__(self, key: str, value: nn.Parameter):
+        if not has_attr(self, key.split(".")):
+            _set_attr(self, key.split("."), value, check_parent=True)
+        else:
+            _set_attr(self, key.split("."), value, check_parent=False)
+
+    def __contains__(self, key: str):
+        return has_attr(self, key.split("."))
+
+    def keys(self):
+        return [name for name, _ in self.named_parameters()]
+
+    def items(self) -> List[Tuple[str, nn.Parameter]]:
+        return [(name, self[name]) for name in self.keys()]
+
+    def values(self) -> List[nn.Parameter]:
+        return [self[name] for name in self.keys()]
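The new dunder methods let `ParameterDictModel` act as a mapping over dotted parameter names, with missing parent modules created on demand via `parent_builder`. A minimal sketch, assuming the class is importable from `fusion_bench.models.parameter_dict` as the file list suggests:

```python
import torch
from torch import nn

# Import path inferred from the file list above.
from fusion_bench.models.parameter_dict import ParameterDictModel

model = ParameterDictModel(
    {"encoder.layer0.weight": nn.Parameter(torch.zeros(2, 2))}
)

# __setitem__ creates missing parents; __contains__ delegates to has_attr.
model["encoder.layer1.bias"] = nn.Parameter(torch.zeros(2))
assert "encoder.layer1.bias" in model

print(model.keys())  # ['encoder.layer0.weight', 'encoder.layer1.bias']
for name, param in model.items():
    print(name, tuple(param.shape))  # e.g. encoder.layer0.weight (2, 2)
```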
fusion_bench/models/wrappers/layer_wise_fusion.py
@@ -16,6 +16,7 @@ import torch
 from torch import Tensor, nn
 from torch.func import functional_call
 
+from fusion_bench.models.utils import del_attr, get_attr, set_attr
 from fusion_bench.utils.type import StateDictType, TorchModelType
 
 __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
@@ -23,52 +24,6 @@ __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
 log = logging.getLogger(__name__)
 
 
-def del_attr(obj, names: List[str]):
-    """
-    Deletes an attribute from an object recursively.
-
-    Args:
-        obj (object): Object to delete attribute from.
-        names (list): List of attribute names to delete recursively.
-    """
-    if len(names) == 1:
-        delattr(obj, names[0])
-    else:
-        del_attr(getattr(obj, names[0]), names[1:])
-
-
-def set_attr(obj, names: List[str], val):
-    """
-    Sets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to set attribute of.
-        names (list): List of attribute names to set recursively.
-        val (object): Value to set the attribute to.
-    """
-    if len(names) == 1:
-        setattr(obj, names[0], val)
-    else:
-        set_attr(getattr(obj, names[0]), names[1:], val)
-
-
-def get_attr(obj, names: List[str]):
-    """
-    Gets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to get attribute of.
-        names (list): List of attribute names to get recursively.
-
-    Returns:
-        object: The attribute of the object.
-    """
-    if len(names) == 1:
-        return getattr(obj, names[0])
-    else:
-        return get_attr(getattr(obj, names[0]), names[1:])
-
-
 def get_layer_wise_weights(
     num_models: int,
     num_layers: int,
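`get_layer_wise_weights` and `fuse_weights` remain exported here (and are now imported by the DOGE-TA wrapper below). A toy sketch of their contract, following the docstrings visible in the removed copies in the next file:

```python
import torch
from torch import nn

from fusion_bench.models.wrappers.layer_wise_fusion import (
    fuse_weights,
    get_layer_wise_weights,
)

state_dicts = [nn.Linear(4, 2).state_dict() for _ in range(3)]

# One coefficient per (model, layer); a Linear state dict has two entries
# ('weight' and 'bias'), so num_layers=2. Defaults fill with 1/num_models.
weights = get_layer_wise_weights(num_models=3, num_layers=2)
merged = fuse_weights(weights, state_dicts)  # {'weight': ..., 'bias': ...}

assert weights.shape == (3, 2)
# With uniform 1/3 coefficients the fused tensor equals the elementwise mean.
assert torch.allclose(
    merged["weight"], torch.stack([sd["weight"] for sd in state_dicts]).mean(dim=0)
)
```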
fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py
@@ -10,132 +10,17 @@ import torch
 from torch import Tensor, nn
 from torch.func import functional_call
 
+from fusion_bench.models.utils import del_attr, get_attr, set_attr
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add
 from fusion_bench.utils.type import StateDictType
 
+from .layer_wise_fusion import fuse_weights, get_layer_wise_weights
+
 __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
 
 log = logging.getLogger(__name__)
 
 
-def del_attr(obj, names: List[str]):
-    """
-    Deletes an attribute from an object recursively.
-
-    Args:
-        obj (object): Object to delete attribute from.
-        names (list): List of attribute names to delete recursively.
-    """
-    if len(names) == 1:
-        delattr(obj, names[0])
-    else:
-        del_attr(getattr(obj, names[0]), names[1:])
-
-
-def set_attr(obj, names: List[str], val):
-    """
-    Sets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to set attribute of.
-        names (list): List of attribute names to set recursively.
-        val (object): Value to set the attribute to.
-    """
-    if len(names) == 1:
-        setattr(obj, names[0], val)
-    else:
-        set_attr(getattr(obj, names[0]), names[1:], val)
-
-
-def get_attr(obj, names: List[str]):
-    """
-    Gets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to get attribute of.
-        names (list): List of attribute names to get recursively.
-
-    Returns:
-        object: The attribute of the object.
-    """
-    if len(names) == 1:
-        return getattr(obj, names[0])
-    else:
-        return get_attr(getattr(obj, names[0]), names[1:])
-
-
-def get_layer_wise_weights(
-    num_models: int,
-    num_layers: int,
-    init_values: float = None,
-    dtype: torch.dtype = torch.float32,
-):
-    """
-    Return a tensor of layer-wise weights for the given number of models and layers.
-
-    Args:
-        num_models (int): The number of models to fuse.
-        num_layers (int): The number of layers in each model.
-        init_values (float, optional): The initial value for each weight. Defaults to 1.0 / num_models.
-        dtype (torch.dtype): dtype of weights. This should be the same with model dtype.
-
-    Returns:
-        Tensor: A tensor of shape (num_models, num_layers) containing the layer-wise weights.
-    """
-    assert num_models >= 1, f"num_models must be >= 1, got {num_models}"
-    assert num_layers >= 1, f"num_layers must be >= 1, got {num_layers}"
-    if init_values is None:
-        init_values = 1.0 / num_models
-    return torch.full((num_models, num_layers), init_values, dtype=dtype)
-
-
-def _fuse_weights(layer_wise_weight: Tensor, tensors: List[Tensor]):
-    """
-    Fuse the layer-wise weights with the given state dictionaries.
-
-    Args:
-        layer_wise_weight (Tensor): A tensor of shape (num_models,) containing the layer-wise weights.
-        state_dicts (List[Tensor]): A list of state dictionaries, each containing the weights for a single layer.
-
-    Returns:
-        Tensor: A tensor of shape (num_params,) containing the fused weights.
-    """
-    assert len(layer_wise_weight) == len(
-        tensors
-    ), f"layer_wise_weight.shape={layer_wise_weight.shape}, len(tensors)={len(tensors)}"
-    return sum(
-        layer_wise_weight[i] * w.to(layer_wise_weight.device)
-        for i, w in enumerate(tensors)
-    )
-
-
-def fuse_weights(
-    layer_wise_weight: Tensor, state_dicts: List[StateDictType]
-) -> StateDictType:
-    """
-    Fuse the weights of multiple models using layer-wise fusion.
-
-    Args:
-        layer_wise_weight (Tensor): A tensor of shape (num_models, num_layers) representing the weight of each layer for each model.
-        state_dicts (List[StateDict]): A list of state dictionaries, one for each model.
-
-    Returns:
-        A dictionary mapping each weight tensor key to the fused weight tensor.
-    """
-    num_models = len(state_dicts)
-    num_layers = len(state_dicts[0])
-    assert layer_wise_weight.shape == (
-        num_models,
-        num_layers,
-    ), f"layer_wise_weight.shape={layer_wise_weight.shape}, expected (num_models, num_layers): ({num_models}, {num_layers})"
-    return {
-        k: _fuse_weights(
-            layer_wise_weight[:, i], [state_dict[k] for state_dict in state_dicts]
-        )
-        for i, k in enumerate(state_dicts[0].keys())
-    }
-
-
 class LayerWiseMergedModel(nn.Module):
     _merged_state_dict: StateDictType = None
 
@@ -390,7 +275,7 @@ class LayerWiseMergedModel(nn.Module):
         layer_vectors_scale = layer_vectors * layer_lamdas.view(-1, 1, 1)
         sum_over_num_vectors = layer_vectors_scale.sum(dim=0)
 
-        layer_delta_scale = layer_delta
+        layer_delta_scale = layer_delta * layer_lamdas.view(-1, 1, 1)
         sum_over_delta = layer_delta_scale.sum(dim=0)
 
         # Iterate through each vector and calculate the loss one by one