fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +22 -2
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +6 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +24 -5
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +5 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +17 -13
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +1 -1
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
- fusion_bench/method/simple_average.py +12 -16
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +15 -45
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +275 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
- fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +7 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +160 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +46 -61
- fusion_bench/scripts/cli.py +38 -5
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +7 -1
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +38 -3
- fusion_bench/utils/modelscope.py +127 -8
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +25 -23
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/method/we_moe/flan_t5_we_moe.py
ADDED
@@ -0,0 +1,331 @@
+import functools
+import logging
+import os
+from copy import deepcopy
+from typing import Any, Dict, List, Mapping, Optional, Union, cast  # noqa: F401
+
+import lightning
+import lightning as L
+import lightning.fabric.wrappers
+import torch
+from torch import Tensor
+from torch.utils.data import DataLoader
+from tqdm.autonotebook import tqdm
+from transformers import T5ForConditionalGeneration
+from transformers.data import default_data_collator
+
+from fusion_bench.method import BaseAlgorithm
+from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge
+from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
+from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+from fusion_bench.modelpool import Seq2SeqLMPool
+from fusion_bench.models.we_moe import WeightEnsemblingMoE
+from fusion_bench.utils import timeit_context
+from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
+from fusion_bench.utils.instantiate_utils import instantiate
+from fusion_bench.utils.parameters import print_parameters
+
+from .entropy_loss import entropy_loss
+from .utils import get_memory_usage
+
+log = logging.getLogger(__name__)
+
+
+class FlanT5WeightEnsemblingMoEAlgorithm(
+    BaseAlgorithm,
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+):
+    """
+    FlanT5WeightEnsemblingMoEAlgorithm is a class that implements the WeightEnsemblingMoEAlgorithm
+    for FlanT5 models. It extends the WeightEnsemblingMoEAlgorithm and CLIPClassificationMixin classes.
+
+    Attributes:
+        modelpool (Seq2SeqLMPool): The model pool containing the FlanT5 models.
+    """
+
+    modelpool: Seq2SeqLMPool = None
+
+    def __init__(
+        self,
+        checkpoint: bool = False,
+        save_checkpoint: bool = False,
+        router_hidden_layers: int = 2,
+        init_lambda: float = 0.3,
+        batch_reduce: bool = True,
+        lr: float = 1e-4,
+        optimizer: str = "adam",
+        devices: int = 1,
+        batch_size: int = 16,
+        num_workers: int = 0,
+        max_steps: int = 1000,
+        use_grad_accumulate: bool = True,
+        cache_dir: bool = "outputs",
+        fast_dev_run: bool = False,
+        **kwargs,
+    ):
+        """
+        Initialize the WeightEnsemblingMoEAlgorithm with the given configuration.
+
+        Args:
+            algorithm_config (DictConfig): The configuration for the algorithm.
+        """
+        self.checkpoint = checkpoint
+        self.save_checkpoint = save_checkpoint
+        self.router_hidden_layers = router_hidden_layers
+        self.init_lambda = init_lambda
+        self.batch_reduce = batch_reduce
+        self.lr = lr
+        self.optimizer = optimizer
+        self.devices = devices
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.max_steps = max_steps
+        self.use_grad_accumulate = use_grad_accumulate
+        self.cache_dir = cache_dir
+        self.fast_dev_run = fast_dev_run
+        super().__init__(**kwargs)
+
+    def construct_moe_model(self) -> WeightEnsemblingMoE:
+        """
+        Construct the Mixture of Experts (MoE) model using the models in the model pool.
+
+        Returns:
+            WeightEnsemblingMoE: The constructed MoE model.
+        """
+        base_model = self.modelpool.load_model("_pretrained_")
+        expert_models = [
+            self.modelpool.load_model(name) for name in self.modelpool.model_names
+        ]
+
+        # Merge the models using task arithmetic
+        moe_model = task_arithmetic_merge(
+            # This function modifies the model in place, so we need to pass a deepcopy
+            deepcopy(base_model),
+            expert_models,
+            scaling_factor=self.init_lambda,
+        ).requires_grad_(False)
+
+        print(base_model)
+
+        # Up-scale MLP modules
+        num_layer = 12
+        encoder_mlp_index = 1
+        base_encoder = base_model.encoder
+        moe_encoder = moe_model.encoder
+        expert_encoders = [m.encoder for m in expert_models]
+
+        for layer_idx in range(num_layer):
+            base_mlp = (
+                base_encoder.block[layer_idx].layer[encoder_mlp_index].DenseReluDense
+            )
+            expert_mlps = [
+                e.block[layer_idx].layer[encoder_mlp_index].DenseReluDense
+                for e in expert_encoders
+            ]
+
+            moe_encoder.block[layer_idx].layer[encoder_mlp_index].DenseReluDense = (
+                WeightEnsemblingMoE(
+                    hidden_size=base_encoder.config.hidden_size,
+                    base_model=base_mlp,
+                    expert_models=expert_mlps,
+                    init_lambda=self.init_lambda,
+                    batch_first=True,
+                    router_hidden_layers=self.router_hidden_layers,
+                    batch_reduce=self.batch_reduce,
+                )
+            )
+
+        decoder_mlp_index = 2
+        base_decoder = base_model.decoder
+        moe_decoder = moe_model.decoder
+        expert_decoders = [m.decoder for m in expert_models]
+
+        for layer_idx in range(num_layer):
+            base_mlp = (
+                base_decoder.block[layer_idx].layer[decoder_mlp_index].DenseReluDense
+            )
+            expert_mlps = [
+                e.block[layer_idx].layer[decoder_mlp_index].DenseReluDense
+                for e in expert_decoders
+            ]
+
+            moe_decoder.block[layer_idx].layer[decoder_mlp_index].DenseReluDense = (
+                WeightEnsemblingMoE(
+                    hidden_size=base_decoder.config.hidden_size,
+                    base_model=base_mlp,
+                    expert_models=expert_mlps,
+                    init_lambda=self.init_lambda,
+                    batch_first=True,
+                    router_hidden_layers=self.router_hidden_layers,
+                    batch_reduce=self.batch_reduce,
+                )
+            )
+
+        print(moe_model)
+        return moe_model
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+        """
+        Loader of test dataset for test-time adaptation. labels are not needed.
+
+        Args:
+            task (str): The name of the task.
+
+        Returns:
+            DataLoader: The data loader for the test dataset.
+        """
+        # dataloader_kwargs = dict(self.dataloader_kwargs)
+        # dataloader_kwargs.update(dict(shuffle=True, collate_fn=default_data_collator))
+
+        dataset = self.modelpool.load_test_dataset(task)
+        log.info("get_shuffled_test_loader_iter")
+        loader = DataLoader(
+            dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+            collate_fn=default_data_collator,
+        )
+        # loader = DataLoader(dataset, **dataloader_kwargs)
+        if self.fabric is not None:
+            loader = self.fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def compute_logits(
+        self,
+        module: Union[T5ForConditionalGeneration],
+        batch,
+        task: str,
+    ) -> Tensor:
+        """
+        Compute the logits for the given images and task.
+
+        Args:
+            module: The model module.
+            images (Tensor): The input images.
+            task (str): The name of the task.
+
+        Returns:
+            Tensor: The computed logits.
+        """
+        input_ids: Tensor = batch["input_ids"]
+        attention_mask: Tensor = batch["attention_mask"]
+
+        # remove padding tokens from the input
+        while attention_mask[:, -1].eq(0).all():
+            input_ids = input_ids[:, :-1]
+            attention_mask = attention_mask[:, :-1]
+
+        outputs = module(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=torch.ones(
+                input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
+            ),
+        )
+        logits = outputs.logits[:, 0, :]
+        return logits
+
+    def test_time_adaptation(self, module):
+        """
+        Perform test-time adaptation for the given module.
+
+        Args:
+            module (WeightEnsemblingMoE): The MoE module to adapt.
+
+        Returns:
+            WeightEnsemblingMoE: The adapted MoE module.
+        """
+        self.on_test_time_adaptation_start()
+
+        # configure optimizer
+        if self.optimizer == "adam":
+            print([name for name, p in module.named_parameters() if p.requires_grad])
+            optimizer = torch.optim.Adam(
+                [p for p in module.parameters() if p.requires_grad], lr=self.lr
+            )
+        else:
+            raise ValueError(f"Unsupported optimizer: {self.optimizer}")
+
+        module, optimizer = self.fabric.setup(module, optimizer)
+
+        module.train()
+        # module.merge_weights()
+        for step_idx in (
+            pbar := tqdm(
+                range(self.max_steps if not self.is_debug_mode else 1),
+                ("[DEBUG MODE] " if self.is_debug_mode else "")
+                + "WEMoE Test-time adaptation",
+                dynamic_ncols=True,
+            )
+        ):
+            total_loss = 0
+            for task in self.modelpool.model_names:
+                with self.profile("data loading"):
+                    batch = next(self.get_shuffled_test_loader_iter(task))
+                with self.profile("forward pass"):
+                    logits = self.compute_logits(module, batch, task)
+                    logits = logits.mean(dim=0, keepdim=True)
+                    loss = entropy_loss(logits)
+                    total_loss += loss
+                with self.profile("backward pass"):
+                    self.fabric.backward(loss, retain_graph=True)
+
+            with self.profile("optimizer step"):
+                optimizer.step()
+                optimizer.zero_grad()
+
+            metrics = {
+                "train/loss": total_loss.item(),
+            }
+            self.fabric.log_dict(metrics, step=step_idx)
+            pbar.set_postfix(metrics)
+
+        log.info(get_memory_usage(f"after adamerging, the memory usage of GPU is:"))
+        self.print_profile_summary()
+        return module
+
+    def on_test_time_adaptation_start(self):
+        """
+        Something to do before the test-time adaptation starts. Such as setting up the task-specific heads.
+        """
+        pass
+
+    def run(self, modelpool: Seq2SeqLMPool, **kwargs):
+        """
+        Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.
+
+        Args:
+            modelpool (ModelPool): The pool of models to be fused.
+
+        Returns:
+            WeightEnsemblingMoE: The fused MoE model.
+        """
+        log.info("Fusing models using layer-wise adaptive merging.")
+        self.modelpool = modelpool
+
+        with timeit_context("upscaling models to a weight-ensembling MoE model"):
+            moe_model = self.construct_moe_model()
+            print_parameters(moe_model)
+
+        if self.checkpoint != False:
+            log.info(
+                f"load checkpoint from {self.checkpoint}, test-time adaptation will be skipped."
+            )
+            self.load_checkpoint(moe_model, self.checkpoint)
+        else:
+            with self.profile("test-time adaptation"):
+                moe_model = self.test_time_adaptation(moe_model)
+            if self.save_checkpoint != False:
+                log.info(f"save checkpoint to {self.save_checkpoint}")
+                self.save_checkpoint(moe_model, self.save_checkpoint)
+
+        if lightning.fabric.wrappers.is_wrapped(moe_model):
+            moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)
+
+        # enable sample-wise adaptation
+        moe_model.batch_reduce = False
+        self.print_profile_summary()
+        return moe_model
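For orientation, below is a minimal usage sketch of the new algorithm, not code from the package: modelpool_config is a hypothetical Hydra config node, and in practice both objects are built by the fusion_bench CLI (see the fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml config added in this release).

from fusion_bench.method.we_moe.flan_t5_we_moe import FlanT5WeightEnsemblingMoEAlgorithm
from fusion_bench.utils.instantiate_utils import instantiate

# Hypothetical: a Seq2SeqLMPool config with a "_pretrained_" Flan-T5 base model
# and one fine-tuned expert per task.
modelpool = instantiate(modelpool_config)

algorithm = FlanT5WeightEnsemblingMoEAlgorithm(
    init_lambda=0.3,         # scaling factor for the task-arithmetic base merge
    router_hidden_layers=2,  # depth of each WeightEnsemblingMoE router
    lr=1e-4,                 # Adam learning rate for test-time adaptation
    max_steps=1000,          # number of entropy-minimization steps
)
moe_model = algorithm.run(modelpool)  # upscale MLPs, adapt routers, return the MoE model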
fusion_bench/method/we_moe/utils.py
ADDED
@@ -0,0 +1,15 @@
+import torch
+
+
+def get_memory_usage(desc):
+    """
+    Obtain the current GPU memory usage.
+
+    Returns:
+        str: A string containing the allocated and cached memory in MB.
+    """
+    allocated = torch.cuda.memory_allocated() / 1024**2  # convert to MB
+    cached = torch.cuda.memory_reserved() / 1024**2  # convert to MB
+    return (
+        f"{desc}\nAllocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
+    )
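This helper is what flan_t5_we_moe.py logs after adaptation; with made-up numbers, a call renders as:

print(get_memory_usage("after merging"))
# after merging
# Allocated Memory: 1234.56 MB
# Cached Memory: 2048.00 MB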
fusion_bench/method/we_moe/we_moe.py
CHANGED
@@ -1,6 +1,6 @@
 import logging
 from abc import abstractmethod
-from typing import cast  # noqa: F401
+from typing import Any, cast  # noqa: F401
 
 import lightning as L
 import lightning.fabric.wrappers
@@ -70,7 +70,7 @@ class WeightEnsemblingMoEAlgorithm(
         assert "No CUDA device available."
 
     @abstractmethod
-    def load_checkpoint(self, model, checkpoint):
+    def load_checkpoint(self, model: Any, checkpoint: Any):
         """
         Load the checkpoint file.
 
@@ -81,7 +81,7 @@ class WeightEnsemblingMoEAlgorithm(
        pass
 
    @abstractmethod
-    def save_checkpoint(self, model, checkpoint):
+    def save_checkpoint(self, model: Any, checkpoint: Any):
         """
         Save the checkpoint file.
 
@@ -121,7 +121,7 @@ class WeightEnsemblingMoEAlgorithm(
        pass
 
    @abstractmethod
-    def compute_logits(self, module, batch, task) -> Tensor:
+    def compute_logits(self, module: Any, batch: Any, task: Any) -> Tensor:
         """
         Compute the logits for a given batch and task.
 
@@ -135,7 +135,7 @@ class WeightEnsemblingMoEAlgorithm(
         """
        pass
 
-    def test_time_adaptation(self, module: WeightEnsemblingMoE):
+    def test_time_adaptation(self, module: WeightEnsemblingMoE) -> WeightEnsemblingMoE:
         """
         Perform test-time adaptation for the given module.
 
@@ -208,7 +208,7 @@ class WeightEnsemblingMoEAlgorithm(
 
         return module
 
-    def run(self, modelpool: ModelPool):
+    def run(self, modelpool: ModelPool) -> WeightEnsemblingMoE:
         """
         Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.
 
fusion_bench/method/weighted_average/llama.py
CHANGED
@@ -3,9 +3,11 @@ from typing import List, Mapping, Union  # noqa: F401
 
 import numpy as np
 import torch
+from transformers import PreTrainedModel
 from typing_extensions import override
 
 from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import auto_register_config
 from fusion_bench.modelpool import CausalLMPool
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
@@ -14,20 +16,12 @@ from fusion_bench.utils.type import StateDictType
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class WeightedAverageForLLama(BaseAlgorithm):
     """
     A class to perform weighted averaging of LlaMa/Mistral models.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "normalize": "normalize",
-        "weights": "weights",
-        "backbone_only": "backbone_only",
-        "merged_model_save_path": "merged_model_save_path",
-        "save_tokenizer": "save_tokenizer",
-        "push_to_hub": "push_to_hub",
-    }
-
     def __init__(
         self,
         normalize: bool,
@@ -49,17 +43,11 @@ class WeightedAverageForLLama(BaseAlgorithm):
             save_tokenizer (bool): Whether to save the tokenizer.
             push_to_hub (bool): Whether to push the model to the hub.
         """
-        self.normalize = normalize
-        self.weights = weights
-        self.backbone_only = backbone_only
-        self.merged_model_save_path = merged_model_save_path
-        self.save_tokenizer = save_tokenizer
-        self.push_to_hub = push_to_hub
         super().__init__(**kwargs)
 
     @override
     @torch.no_grad()
-    def run(self, modelpool: CausalLMPool):
+    def run(self, modelpool: CausalLMPool) -> PreTrainedModel:
         """
         Executes the weighted averaging of models in the provided model pool.
 
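The @auto_register_config decorator absorbs both removed blocks here: the hand-written _config_mapping and the self.x = x assignments. Its actual implementation lives in fusion_bench/mixins/serialization.py (also changed in this release); the following is only a sketch of the idea, assuming it works from the __init__ signature:

import inspect
from functools import wraps

def auto_register_config_sketch(cls):
    """Illustrative approximation (not the real fusion_bench code): derive the
    config mapping from the __init__ signature and assign each argument onto
    self, replacing the boilerplate removed above."""
    init = cls.__init__
    sig = inspect.signature(init)
    names = [
        p.name
        for p in sig.parameters.values()
        if p.name != "self" and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    ]
    cls._config_mapping = getattr(cls, "_config_mapping", {}) | {n: n for n in names}

    @wraps(init)
    def wrapped_init(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()
        for n in names:
            setattr(self, n, bound.arguments[n])  # what `self.x = x` used to do
        init(self, *args, **kwargs)

    cls.__init__ = wrapped_init
    return cls

Deriving the mapping from the signature keeps the YAML-serializable config in lockstep with the constructor, which the removed boilerplate had to maintain by hand.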
fusion_bench/metrics/continual_learning/__init__.py
ADDED
@@ -0,0 +1 @@
+from .backward_transfer import compute_backward_transfer
fusion_bench/metrics/continual_learning/backward_transfer.py
CHANGED
@@ -10,7 +10,7 @@ def compute_backward_transfer(
     Compute the backward transfer (BWT) of a model on a set of tasks.
 
     Equation:
-        BWT = \frac{1}{n} \sum_{k=1}^{n} (acc_{
+        $BWT = \frac{1}{n} \sum_{k=1}^{n} (acc_{T,i}[k] - acc_{i,i}[k])$
 
     Returns:
         float: The backward transfer of the model.
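To make the docstring's equation concrete, here is a worked example under the usual reading that acc_{T,i}[k] is the accuracy on task k after the final task T and acc_{i,i}[k] is the accuracy on task k right after it was learned; the numbers are made up:

# Hypothetical accuracies over n = 3 tasks.
acc_final = [0.80, 0.75, 0.90]         # accuracy on each task after the last task is learned
acc_when_learned = [0.85, 0.80, 0.90]  # accuracy on each task right after learning it

n = len(acc_final)
bwt = sum(final - learned for final, learned in zip(acc_final, acc_when_learned)) / n
print(f"{bwt:.4f}")  # -0.0333: negative BWT indicates forgetting of earlier tasks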
fusion_bench/metrics/nyuv2/__init__.py
CHANGED
@@ -1,10 +1,10 @@
 from .depth import DepthMetric
 from .noise import NoiseMetric
 from .normal import NormalMetric
-from .segmentation import
+from .segmentation import SegmentationMetric
 
 metric_classes = {
-    "segmentation":
+    "segmentation": SegmentationMetric,
     "depth": DepthMetric,
     "normal": NormalMetric,
     "noise": NoiseMetric,
fusion_bench/mixins/__init__.py
CHANGED
@@ -11,7 +11,11 @@ _import_structure = {
     "hydra_config": ["HydraConfigMixin"],
     "lightning_fabric": ["LightningFabricMixin"],
     "openclip_classification": ["OpenCLIPClassificationMixin"],
-    "serialization": [
+    "serialization": [
+        "BaseYAMLSerializable",
+        "YAMLSerializationMixin",
+        "auto_register_config",
+    ],
     "simple_profiler": ["SimpleProfilerMixin"],
 }
 
@@ -21,7 +25,11 @@ if TYPE_CHECKING:
     from .hydra_config import HydraConfigMixin
     from .lightning_fabric import LightningFabricMixin
     from .openclip_classification import OpenCLIPClassificationMixin
-    from .serialization import
+    from .serialization import (
+        BaseYAMLSerializable,
+        YAMLSerializationMixin,
+        auto_register_config,
+    )
     from .simple_profiler import SimpleProfilerMixin
 else:
     sys.modules[__name__] = LazyImporter(
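These names are resolved through the _import_structure / LazyImporter pattern visible in the context lines, so importing fusion_bench.mixins stays cheap. A minimal sketch of how such a lazy importer can work (fusion_bench's own LazyImporter, in fusion_bench/utils/lazy_imports.py per the file list above, may differ):

import importlib
from types import ModuleType

class LazySketch(ModuleType):
    """Illustrative only: resolve attributes to submodule members on first access."""

    def __init__(self, name: str, import_structure: dict):
        super().__init__(name)
        # invert {"submodule": ["Attr", ...]} into {"Attr": "submodule"}
        self._attr_to_submodule = {
            attr: submodule
            for submodule, attrs in import_structure.items()
            for attr in attrs
        }

    def __getattr__(self, attr: str):
        if attr not in self._attr_to_submodule:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        submodule = importlib.import_module(
            "." + self._attr_to_submodule[attr], self.__name__
        )
        value = getattr(submodule, attr)
        setattr(self, attr, value)  # cache so __getattr__ is not hit again
        return value

# Used as in the diff's else-branch:
#     sys.modules[__name__] = LazySketch(__name__, _import_structure)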
fusion_bench/mixins/clip_classification.py
CHANGED
@@ -6,6 +6,7 @@ from typing import (  # noqa: F401
     TYPE_CHECKING,
     Any,
     Dict,
+    Iterator,
     List,
     Optional,
     Tuple,
@@ -21,6 +22,7 @@ from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 
+from fusion_bench import cache_with_joblib
 from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.modelpool import CLIPVisionModelPool
@@ -45,15 +47,13 @@ class CLIPClassificationMixin(LightningFabricMixin):
 
     - `_dataloader_kwargs` (Dict[str, Any]): Keyword arguments for the dataloader.
     - `modelpool` (CLIPVisionModelPool): The model pool containing the CLIP models.
-    - `zeroshot_weights_cache_dir` (Optional[str]): The directory to cache the zero-shot weights.
     """
 
-
+    dataloader_kwargs: Dict[str, Any] = {}
     # the modelpool is set by inheriting class
     modelpool: CLIPVisionModelPool = None
     _clip_processor: CLIPProcessor = None
     # a dict of zeroshot weights for each task, each key is the task name
-    zeroshot_weights_cache_dir: str = "outputs/cache/clip_zeroshot_weights"
     zeroshot_weights: Dict[str, torch.Tensor] = {}
     whether_setup_zero_shot_classification_head = False
 
@@ -71,7 +71,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
         batch_size: Optional[int] = None,
         num_workers: Optional[int] = None,
         **loader_kwargs,
-    ):
+    ) -> Iterator:
         """
         Get an iterator for a shuffled test DataLoader.
 
@@ -89,7 +89,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
             Iterator: An iterator over the shuffled test DataLoader.
         """
         # get dataloader kwargs
-        dataloader_kwargs = self.
+        dataloader_kwargs = self.dataloader_kwargs.copy()
         dataloader_kwargs["shuffle"] = True
         if batch_size is not None:
             dataloader_kwargs["batch_size"] = batch_size
@@ -130,26 +130,16 @@ class CLIPClassificationMixin(LightningFabricMixin):
         self.visual_projection = self.fabric.to_device(self.visual_projection)
         self.logit_scale_exp = self.fabric.to_device(self.logit_scale_exp)
 
-
-
-
-
-
-
-
-
-
-        cache_dir = os.path.join(
-            self.zeroshot_weights_cache_dir,
-            os.path.normpath(model_name.split("/")[-1]),
-        )
-        if not os.path.exists(cache_dir):
-            log.info(
-                f"Creating cache directory for zero-shot classification head at {cache_dir}"
-            )
-            os.makedirs(cache_dir)
+        @cache_with_joblib()
+        def construct_classification_head(task: str):
+            nonlocal clip_classifier
+
+            classnames, templates = get_classnames_and_templates(task)
+            clip_classifier.set_classification_task(classnames, templates)
+            zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()
+
+            return zeroshot_weights
 
-        log.info(f"cache directory for zero-shot classification head: {cache_dir}")
         for task in tqdm(
             self.modelpool.model_names if task_names is None else task_names,
             "Setting up zero-shot classification head",
@@ -157,27 +147,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
         ):
             zeroshot_weights = None
             if self.fabric.is_global_zero:
-                cache_file = os.path.join(
-                    cache_dir, os.path.normpath(f"{task}_zeroshot_weights.pt")
-                )
-                if os.path.exists(cache_file):
-                    zeroshot_weights = torch.load(
-                        cache_file,
-                        map_location="cpu",
-                        weights_only=True,
-                    ).detach()
-                    log.info(
-                        f"Loadded cached zeroshot weights for task: {task}, shape: {zeroshot_weights.shape}"
-                    )
-                else:
-                    log.info(
-                        f"Construct zero shot classification head for task: {task}"
-                    )
-                    classnames, templates = get_classnames_and_templates(task)
-                    clip_classifier.set_classification_task(classnames, templates)
-                    zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()
-                    log.info(f"save zeroshot weights to {cache_file}")
-                    torch.save(zeroshot_weights, cache_file)
+                zeroshot_weights = construct_classification_head(task)
 
             self.fabric.barrier()
             self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
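cache_with_joblib, exported from the top-level fusion_bench package per the import above, replaces the manual torch.save/torch.load cache. A plausible shape for such a decorator, assuming it is backed by joblib.Memory (the cache location and signature here are guesses, not the package's implementation):

from functools import wraps

from joblib import Memory

def cache_with_joblib_sketch(cache_dir: str = "outputs/cache"):
    """Illustrative only: disk-memoize a function on its arguments via joblib,
    the way zero-shot classification heads are now cached per task."""
    memory = Memory(location=cache_dir, verbose=0)

    def decorator(fn):
        cached_fn = memory.cache(fn)

        @wraps(fn)
        def wrapper(*args, **kwargs):
            return cached_fn(*args, **kwargs)

        return wrapper

    return decorator

One design note, if the decorator is indeed joblib-backed: construct_classification_head closes over clip_classifier via nonlocal, so only the task argument would participate in the cache key, whereas the removed code also keyed its cache directory on model_name.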