fusion-bench 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/method/base_algorithm.py +7 -2
- fusion_bench/compat/modelpool/__init__.py +3 -2
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/__init__.py +6 -1
- fusion_bench/dataset/arc_agi/arc.py +26 -7
- fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
- fusion_bench/dataset/arc_agi/np_cache.py +0 -1
- fusion_bench/dataset/arc_agi/preprocess.py +51 -9
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +72 -5
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +4 -1
- fusion_bench/method/adamerging/__init__.py +1 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/simple_average.py +1 -1
- fusion_bench/method/surgery/__init__.py +3 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +60 -12
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +11 -2
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +5 -1
- fusion_bench/taskpool/__init__.py +10 -2
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +5 -3
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,249 @@
+import logging
+from abc import abstractmethod
+from typing import cast  # noqa: F401
+
+import lightning as L
+import lightning.fabric.wrappers
+import torch
+from lightning.pytorch.profilers import SimpleProfiler
+from omegaconf import DictConfig
+from torch import Tensor
+from torch.utils.data import DataLoader
+from tqdm.autonotebook import tqdm
+
+from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm
+from fusion_bench.compat.modelpool import ModelPool
+from fusion_bench.models.rankone_moe import RankOneMoE
+from fusion_bench.utils import timeit_context
+from fusion_bench.utils.parameters import print_parameters
+
+log = logging.getLogger(__name__)
+
+
+def entropy_loss(logits: Tensor) -> Tensor:
+    """
+    Compute the entropy loss of a set of logits.
+
+    Args:
+        logits (Tensor): The logits to compute the entropy loss of.
+
+    Returns:
+        Tensor: The entropy loss of the logits.
+    """
+    probs = torch.softmax(logits, dim=-1)
+    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
+
+
+class RankOneMoEAlgorithm(ModelFusionAlgorithm):
+    """
+    Algorithm for fusing models using RankOne-MoE (https://github.com/EnnengYang/RankOne-MoE).
+
+    This class provides methods for constructing the MoE model, performing test-time adaptation,
+    and running the fusion process.
+
+    Attributes:
+        _fabric (L.Fabric): The fabric for distributed training.
+        modelpool (ModelPool): The pool of models to be fused.
+        profiler (SimpleProfiler): The profiler for measuring performance.
+    """
+
+    _fabric: L.Fabric = None
+    modelpool: ModelPool = None
+
+    def __init__(self, algorithm_config: DictConfig):
+        """
+        Initialize the RankOneMoEAlgorithm with the given configuration.
+
+        Args:
+            algorithm_config (DictConfig): The configuration for the algorithm.
+        """
+        super().__init__(algorithm_config)
+
+        if self._fabric is None and torch.cuda.is_available():
+            self._fabric = L.Fabric(
+                devices=self.config.get("devices", 1),
+            )
+            self._fabric.launch()
+        else:
+            assert "No CUDA device available."
+        self.profiler = SimpleProfiler(
+            self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
+        )
+
+    @abstractmethod
+    def load_checkpoint(self, model, checkpoint):
+        """
+        Load the checkpoint file.
+
+        Args:
+            model: The model to load the checkpoint into.
+            checkpoint: The checkpoint file to load.
+        """
+        pass
+
+    @abstractmethod
+    def save_checkpoint(self, model, checkpoint):
+        """
+        Save the checkpoint file.
+
+        Args:
+            model: The model to save the checkpoint from.
+            checkpoint: The checkpoint file to save.
+        """
+        pass
+
+    @abstractmethod
+    def construct_moe_model(self) -> RankOneMoE:
+        """
+        Construct the Mixture of Experts model using the models in the model pool.
+
+        Returns:
+            RankOne-MoE: The constructed MoE model.
+        """
+        pass
+
+    def on_test_time_adaptation_start(self):
+        """
+        Hook method called at the start of test-time adaptation.
+        """
+        pass
+
+    @abstractmethod
+    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+        """
+        Get an iterator for the shuffled test data loader for a specific task.
+
+        Args:
+            task (str): The task for which to get the test data loader.
+
+        Returns:
+            DataLoader: The shuffled test data loader iterator.
+        """
+        pass
+
+    @abstractmethod
+    def compute_logits(self, module, batch, task) -> Tensor:
+        """
+        Compute the logits for a given batch and task.
+
+        Args:
+            module: The model module to use for computing logits.
+            batch: The batch of data.
+            task: The task for which to compute logits.
+
+        Returns:
+            Tensor: The computed logits.
+        """
+        pass
+
+    def test_time_adaptation(self, module: RankOneMoE):
+        """
+        Perform test-time adaptation for the given module.
+
+        Args:
+            module (RankOne-MoE): The MoE module to adapt.
+
+        Returns:
+            RankOne-MoE: The adapted MoE module.
+        """
+        self.on_test_time_adaptation_start()
+
+        # configure optimizer
+        if self.config.optimizer == "adam":
+            optimizer = torch.optim.Adam(
+                [p for p in module.parameters() if p.requires_grad], lr=self.config.lr
+            )
+        else:
+            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
+
+        if self._fabric is not None:
+            module, optimizer = self._fabric.setup(module, optimizer)
+
+        module.train()
+
+        if self.config.get("fast_dev_run", False):
+            log.info("Running fast_dev_run, only one step")
+            pbar = tqdm(
+                range(1),
+                "Test-time adaptation",
+                dynamic_ncols=True,
+            )
+        else:
+            pbar = tqdm(
+                range(self.config.max_steps),
+                "Test-time adaptation",
+                dynamic_ncols=True,
+            )
+        for step_idx in pbar:
+            if self.config.use_grad_accumulate:
+                for task in self.modelpool.model_names:
+                    with self.profiler.profile("data time"):
+                        batch = next(self.get_shuffled_test_loader_iter(task))
+                    with self.profiler.profile("forward pass"):
+                        logits = self.compute_logits(module, batch, task)
+                        assert (
+                            logits.dim() == 2
+                        ), f"Expected logits to be 2D, got {logits.dim()}"
+                        loss = entropy_loss(logits)
+                    # .backward() accumulates when .zero_grad() wasn't called
+                    # this can save memory
+                    with self.profiler.profile("backward pass"):
+                        self._fabric.backward(loss, retain_graph=True)
+            else:
+                loss = 0
+                for task in self.modelpool.model_names:
+                    with self.profiler.profile("data time"):
+                        batch = next(self.get_shuffled_test_loader_iter(task))
+                    with self.profiler.profile("forward pass"):
+                        logits = self.compute_logits(module, batch, task)
+                        assert (
+                            logits.dim() == 2
+                        ), f"Expected logits to be 2D, got {logits.dim()}"
+                        loss = loss + entropy_loss(logits)
+                with self.profiler.profile("backward pass"):
+                    self._fabric.backward(loss, retain_graph=True)
+
+            with self.profiler.profile("optimizer step"):
+                optimizer.step()
+                optimizer.zero_grad()
+
+            # print([m for m in module.parameters() if m.requires_grad][0])
+
+        return module
+
+    def run(self, modelpool: ModelPool):
+        """
+        Run the RankOneMoEAlgorithm to fuse models using RankOne-MoE.
+
+        Args:
+            modelpool (ModelPool): The pool of models to be fused.
+
+        Returns:
+            RankOne-MoE: The fused RankOne MoE model.
+        """
+        log.info("Fusing models using RankOne-MoE modules.")
+        self.modelpool = modelpool
+
+        with timeit_context("upscaling models to a RankOne-MoE model"):
+            moe_model = self.construct_moe_model()
+            print_parameters(moe_model)
+
+        if self.config.get("checkpoint", False):
+            log.info(
+                f"load checkpoint from {self.config.checkpoint}, test-time adaptation will be skipped."
+            )
+            self.load_checkpoint(moe_model, self.config.checkpoint)
+        else:
+            with self.profiler.profile("test-time adaptation"):
+                moe_model = self.test_time_adaptation(moe_model)
+            if self.config.get("save_checkpoint", False):
+                log.info(f"save checkpoint to {self.config.save_checkpoint}")
+                self.save_checkpoint(moe_model, self.config.save_checkpoint)
+
+        if lightning.fabric.wrappers.is_wrapped(moe_model):
+            moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)
+
+        # enable sample-wise adaptation
+        moe_model.batch_reduce = False
+        print(self.profiler.summary())
+        return moe_model
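The `test_time_adaptation` loop above tunes only the trainable MoE routing parameters by minimizing the entropy of the merged model's predictions on unlabeled, shuffled test batches. A minimal, self-contained sketch of that objective (illustrative only, not part of this diff; the actual loop goes through `self.compute_logits` and Lightning Fabric):

```python
import torch

def entropy_loss(logits: torch.Tensor) -> torch.Tensor:
    # same formula as the helper added above: mean Shannon entropy over the batch
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()

uniform_logits = torch.zeros(4, 10)            # maximally uncertain: entropy ~= ln(10)
confident_logits = torch.full((4, 10), -10.0)
confident_logits[:, 0] = 10.0                  # nearly all probability mass on one class

assert entropy_loss(confident_logits) < entropy_loss(uniform_logits)
```

Driving this quantity down pushes the learned routing toward expert mixtures that give confident predictions on each task's test distribution, which is why no labels are required.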
@@ -11,8 +11,8 @@ from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
     state_dict_avg,
-    state_dict_mul,
     state_dict_div,
+    state_dict_mul,
 )
 from fusion_bench.utils.type import StateDictType
 
@@ -0,0 +1,157 @@
+"""
+Implementation of the Layer-Wise AdaMerging+Surgery Algorithm.
+
+For more details, please refer to:
+
+- (ICLR 2024) Yang, et.al. AdaMerging: Adaptive Model Merging for Multi-Task Learning. http://arxiv.org/abs/2310.02575
+- (ICML 2024) Yang, et.al. Representation Surgery for Multi-Task Model Merging. https://arxiv.org/abs/2402.02705
+
+Basic Example:
+
+```shell
+fusion_bench \
+    method=surgery/adamerging_surgery \
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
+    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
+```
+"""
+
+import copy
+import functools
+import gc
+import logging
+from typing import TYPE_CHECKING, cast
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import CLIPVisionModel
+
+from fusion_bench.dataset.clip_dataset import CLIPDataset
+from fusion_bench.method.adamerging.layer_wise_adamerging import (
+    LayerWiseAdaMergingAlgorithm,
+)
+from fusion_bench.method.adamerging.utils import get_memory_usage
+from fusion_bench.mixins import CLIPClassificationMixin
+from fusion_bench.modelpool import CLIPVisionModelPool
+from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
+from fusion_bench.models.wrappers.layer_wise_fusion import LayerWiseMergedModel
+
+log = logging.getLogger(__name__)
+
+
+class CLIPLayerWiseAdaMergingSurgeryAlgorithm(
+    CLIPClassificationMixin,
+    LayerWiseAdaMergingAlgorithm,
+):
+
+    def on_test_time_adaptation_start(self):
+        """
+        Here we load the CLIP processor and construct the zero-shot classification head for each task.
+        """
+        self.setup_zero_shot_classification_head()
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(self, task: str):
+        return super().get_shuffled_test_loader_iter(
+            task,
+            batch_size=self.config.batch_size,
+            num_workers=self.config.num_workers,
+        )
+
+    def run(self, modelpool: CLIPVisionModelPool, **kwargs):
+        """
+        Run the Layer-Wise AdaMerging+Surgery Algorithm.
+
+        This method constructs the wrapped model and performs test-time adaptation if necessary. Then, it will perform surgery.
+
+        Args:
+            modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.
+
+        Returns:
+            LayerWiseMergedModel: The merged model after test-time adaptation.
+        """
+        log.info("Fusing models using layer-wise adaptive merging.")
+        self.modelpool = modelpool
+        self.log_hyperparams(self.config)
+
+        # === Start of the AdaMerging Algorithm ===
+        with self.profile("construct the wrapped model"):
+            module = cast(
+                LayerWiseMergedModel[CLIPVisionModel],
+                self.construct_layer_wise_merged_model(modelpool),
+            )
+
+        if self.config.weights is not None:
+            # skip the test-time adaptation
+            merged_model = copy.deepcopy(module.merge_and_unload())
+        else:
+            with self.profile("test-time adaptation"):
+                module = self.test_time_adaptation(module)
+            if self.config.get("save_merging_weights", False):
+                self.save_merging_weights(
+                    self.config.save_merging_weights, module.merge_weight
+                )
+            merged_model = copy.deepcopy(module.merge_and_unload())
+
+        # free memory
+        del module
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # === Start of the Surgery Algorithm ===
+        log.info("start performing Surgery")
+        alpha_model = SurgeryModelWrapper(
+            merged_model,
+            modelpool.model_names,
+            projection_dim=merged_model.config.projection_dim,
+        )
+        alpha_model = self.fabric.setup(alpha_model)
+        log.info(get_memory_usage("after freeing memory, the memory usage of GPU is:"))
+
+        optimizer = torch.optim.Adam(
+            alpha_model.collect_trainable_params(),
+            lr=1e-3,
+            betas=(0.9, 0.999),
+            weight_decay=0.0,
+        )
+
+        finetuned_models = {
+            model_name: modelpool.load_model(model_name)
+            for model_name in modelpool.model_names
+        }
+        for name, model in finetuned_models.items():
+            model.requires_grad_(False)
+            model = self.fabric.to_device(model)
+            model.eval()
+
+        for iteration in tqdm(
+            range(self.config.surgery_steps),
+            "surgery",
+            dynamic_ncols=True,
+        ):
+            for dataset_name in modelpool.model_names:
+                batch = next(self.get_shuffled_test_loader_iter(dataset_name))
+                finetuned_feature = self.compute_features(
+                    finetuned_models[dataset_name], batch[0]
+                )
+                features, _, _ = alpha_model.compute_surgery_features(
+                    lambda model: self.compute_features(model, batch[0]),
+                    dataset_name,
+                )
+
+                loss = F.l1_loss(features, finetuned_feature)
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+            if ((iteration + 1) % self.config.eval_iterations) == 0:
+                # print(list(alpha_model.collect_trainable_params()))
+                # Evaluate try to use the test module in fusion bench
+                log.info(f"iteration: {iteration+1}")
+                self._program.evaluate_merged_model(self._program.taskpool, alpha_model)
+
+        log.info("test the result of Adamerging")
+        return merged_model
fusion_bench/mixins/__init__.py CHANGED
@@ -10,10 +10,12 @@ _import_structure = {
     "serialization": ["YAMLSerializationMixin", "BaseYAMLSerializableModel"],
     "simple_profiler": ["SimpleProfilerMixin"],
     "clip_classification": ["CLIPClassificationMixin"],
+    "fabric_training": ["FabricTrainingMixin"],
 }
 
 if TYPE_CHECKING:
     from .clip_classification import CLIPClassificationMixin
+    from .fabric_training import FabricTrainingMixin
     from .lightning_fabric import LightningFabricMixin
     from .serialization import BaseYAMLSerializableModel, YAMLSerializationMixin
     from .simple_profiler import SimpleProfilerMixin
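With the `fabric_training` entry registered in `_import_structure` (and mirrored under `TYPE_CHECKING`), the new mixin is exposed lazily from the package root. A small usage sketch (illustrative; `MyAlgorithm` is a hypothetical subclass, not something added in this release):

```python
# Resolved lazily on first attribute access, so importing fusion_bench.mixins stays cheap.
from fusion_bench.mixins import FabricTrainingMixin

class MyAlgorithm(FabricTrainingMixin):
    """Hypothetical algorithm class that reuses the Fabric-based training helpers."""
```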
@@ -2,14 +2,24 @@ import functools
 import logging
 import os
 from copy import deepcopy
-from typing import
+from typing import (  # noqa: F401
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import torch
+from omegaconf import DictConfig
 from torch import nn
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
-from omegaconf import DictConfig
 
 from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import LightningFabricMixin
@@ -18,10 +28,12 @@ from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.tasks.clip_classification import get_classnames_and_templates
 from fusion_bench.utils.data import InfiniteDataLoader
 
-
+if TYPE_CHECKING:
+    from transformers.models.clip.modeling_clip import CLIPVisionTransformer
 
-
+log = logging.getLogger(__name__)
 
+# disable tokenizers parallelism by default to avoid deadlocks
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
@@ -43,12 +55,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
     # a dict of zeroshot weights for each task, each key is the task name
     zeroshot_weights_cache_dir: str = "outputs/cache/clip_zeroshot_weights"
     zeroshot_weights: Dict[str, torch.Tensor] = {}
-
-    def __init__(self, algorithm_config: DictConfig) -> None:
-        super().__init__(algorithm_config)
-        self.whether_setup_zero_shot_classification_head = (
-            False  # We want to only do this once
-        )
+    whether_setup_zero_shot_classification_head = False
 
     @property
     def clip_processor(self):
@@ -180,13 +187,30 @@
 
     def compute_logits(
         self,
-        module: Union[nn.Module, CLIPVisionModel],
+        module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
         images: torch.Tensor,
         task: str,
+        image_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        """
+        Compute the logits of the images for a given task.
+
+        Args:
+            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The module to compute the logits.
+            images (torch.Tensor): The images to compute the logits.
+            task (str): The task to compute the logits.
+            image_embeds (Optional[torch.Tensor]): The precomputed image embeddings. If None, the image embeddings will be computed.
+
+        Returns:
+            torch.Tensor: The logits of the images.
+        """
         text_embeds = self.zeroshot_weights[task]
 
-        image_embeds
+        if image_embeds is None:
+            image_embeds = module(images)[1]
+        assert isinstance(
+            image_embeds, torch.Tensor
+        ), f"`image_embeds` must be a tensor, but got {type(image_embeds)}"
         image_embeds = self.visual_projection(image_embeds)
 
         # normalize embeddings
@@ -199,3 +223,27 @@
         logits_per_image = logits_per_text.t()
 
         return logits_per_image
+
+    def compute_features(
+        self,
+        module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
+        images: torch.Tensor,
+        normalize: bool = True,
+    ) -> torch.Tensor:
+        """
+        Extracts image features using CLIP's vision encoder and visual projection.
+
+        Args:
+            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The CLIP vision encoder module.
+            images (torch.Tensor): Input image batch to process.
+            normalize (bool): Whether to normalize the image embeddings.
+
+        Returns:
+            torch.Tensor: Normalized image embeddings with dimension matching CLIP's projection space (`projection_dim` in model config).
+        """
+        image_embeds = module(images)[1]
+        image_embeds = self.visual_projection(image_embeds)
+
+        if normalize:
+            image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+        return image_embeds
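The new optional `image_embeds` argument to `compute_logits`, together with the added `compute_features` helper, lets callers run the vision encoder once and score the same batch against several tasks' zero-shot heads. A rough usage sketch (assumptions: `mixin`, `vision_model`, `images`, and `tasks` are placeholder names, not identifiers from this diff):

```python
import torch
from torch import Tensor

def score_against_tasks(mixin, vision_model, images: Tensor, tasks: list) -> dict:
    """Reuse a single vision forward pass across tasks (illustrative placeholder names)."""
    with torch.no_grad():
        pooled = vision_model(images)[1]  # pooled output, computed once
        return {
            task: mixin.compute_logits(vision_model, images, task, image_embeds=pooled)
            for task in tasks
        }
```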