fusion-bench 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +1 -0
- fusion_bench/compat/method/base_algorithm.py +0 -1
- fusion_bench/compat/modelpool/__init__.py +2 -1
- fusion_bench/dataset/arc_agi/__init__.py +6 -1
- fusion_bench/dataset/arc_agi/arc.py +21 -7
- fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
- fusion_bench/dataset/arc_agi/np_cache.py +0 -1
- fusion_bench/dataset/arc_agi/preprocess.py +50 -8
- fusion_bench/dataset/llama/collate.py +10 -3
- fusion_bench/method/__init__.py +3 -0
- fusion_bench/method/adamerging/__init__.py +1 -1
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +47 -5
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +58 -23
- fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/simple_average.py +1 -1
- fusion_bench/mixins/clip_classification.py +2 -7
- fusion_bench/mixins/lightning_fabric.py +2 -2
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/taskpool/__init__.py +10 -2
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/RECORD +36 -29
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +4 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +13 -7
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from typing import cast # noqa: F401
|
|
4
|
+
|
|
5
|
+
import lightning as L
|
|
6
|
+
import lightning.fabric.wrappers
|
|
7
|
+
import torch
|
|
8
|
+
from lightning.pytorch.profilers import SimpleProfiler
|
|
9
|
+
from omegaconf import DictConfig
|
|
10
|
+
from torch import Tensor
|
|
11
|
+
from torch.utils.data import DataLoader
|
|
12
|
+
from tqdm.autonotebook import tqdm
|
|
13
|
+
|
|
14
|
+
from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm
|
|
15
|
+
from fusion_bench.compat.modelpool import ModelPool
|
|
16
|
+
from fusion_bench.models.rankone_moe import RankOneMoE
|
|
17
|
+
from fusion_bench.utils import timeit_context
|
|
18
|
+
from fusion_bench.utils.parameters import print_parameters
|
|
19
|
+
|
|
20
|
+
log = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def entropy_loss(logits: Tensor) -> Tensor:
    """
    Return the mean Shannon entropy of the softmax distribution over the last axis.

    Each row along the final dimension is turned into a probability vector with a
    softmax; its entropy ``-sum(p * log(p + 1e-8))`` is computed, and the result is
    averaged over all remaining positions. The small ``1e-8`` keeps ``log`` finite
    for probabilities that underflow to zero.

    Args:
        logits (Tensor): Unnormalized scores; the class axis is the last dimension.

    Returns:
        Tensor: A scalar tensor holding the average entropy.
    """
    p = logits.softmax(dim=-1)
    # per-row sum of p * log(p), i.e. the negative entropy of each distribution
    neg_entropy = (p * (p + 1e-8).log()).sum(dim=-1)
    return neg_entropy.neg().mean()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class RankOneMoEAlgorithm(ModelFusionAlgorithm):
    """
    Algorithm for fusing models using RankOne-MoE (https://github.com/EnnengYang/RankOne-MoE).

    This class provides methods for constructing the MoE model, performing test-time adaptation,
    and running the fusion process. Subclasses implement the abstract hooks
    (`load_checkpoint`, `save_checkpoint`, `construct_moe_model`,
    `get_shuffled_test_loader_iter`, `compute_logits`).

    Attributes:
        _fabric (L.Fabric): The fabric for distributed training. Stays ``None`` when no
            CUDA device is available; training then runs with plain PyTorch autograd.
        modelpool (ModelPool): The pool of models to be fused (assigned in `run`).
        profiler (SimpleProfiler): The profiler for measuring performance.
    """

    _fabric: L.Fabric = None
    modelpool: ModelPool = None

    def __init__(self, algorithm_config: DictConfig):
        """
        Initialize the RankOneMoEAlgorithm with the given configuration.

        A Lightning Fabric is created and launched only if none exists yet and a CUDA
        device is available. Otherwise a warning is logged and `_fabric` stays ``None``.

        Args:
            algorithm_config (DictConfig): The configuration for the algorithm. Optional
                keys read here: ``devices`` (default 1) and ``cache_dir``
                (default ``"outputs"``).
        """
        super().__init__(algorithm_config)

        if self._fabric is None:
            if torch.cuda.is_available():
                self._fabric = L.Fabric(
                    devices=self.config.get("devices", 1),
                )
                self._fabric.launch()
            else:
                # Continue without Fabric; `_backward` falls back to plain autograd.
                log.warning("No CUDA device available.")

        # NOTE(review): the file name looks inherited from the WeMoE implementation;
        # kept unchanged so existing output paths stay stable.
        self.profiler = SimpleProfiler(
            self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
        )

    @abstractmethod
    def load_checkpoint(self, model, checkpoint):
        """
        Load the checkpoint file.

        Called by `run` instead of test-time adaptation when ``config.checkpoint`` is set.

        Args:
            model: The model to load the checkpoint into.
            checkpoint: The checkpoint file to load.
        """
        pass

    @abstractmethod
    def save_checkpoint(self, model, checkpoint):
        """
        Save the checkpoint file.

        Called by `run` after test-time adaptation when ``config.save_checkpoint`` is set.

        Args:
            model: The model to save the checkpoint from.
            checkpoint: The checkpoint file to save.
        """
        pass

    @abstractmethod
    def construct_moe_model(self) -> RankOneMoE:
        """
        Construct the Mixture of Experts model using the models in the model pool.

        Returns:
            RankOneMoE: The constructed MoE model.
        """
        pass

    def on_test_time_adaptation_start(self):
        """
        Hook method called at the start of test-time adaptation.

        The default implementation does nothing; subclasses may override it.
        """
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Get an iterator over the shuffled test data loader for a specific task.

        `test_time_adaptation` calls ``next()`` on the returned object once per task per
        adaptation step. Implementations therefore likely want to cache the iterator
        rather than rebuild it on every call.

        Args:
            task (str): The task for which to get the test data loader.

        Returns:
            An iterator yielding batches of the shuffled test data for `task`.
        """
        pass

    @abstractmethod
    def compute_logits(self, module, batch, task) -> Tensor:
        """
        Compute the logits for a given batch and task.

        Args:
            module: The model module to use for computing logits.
            batch: The batch of data.
            task: The task for which to compute logits.

        Returns:
            Tensor: The computed logits; `test_time_adaptation` requires a 2D tensor.
        """
        pass

    def _compute_task_loss(self, module, task: str) -> Tensor:
        """Draw one batch for `task`, run the forward pass and return its entropy loss."""
        with self.profiler.profile("data time"):
            batch = next(self.get_shuffled_test_loader_iter(task))
        with self.profiler.profile("forward pass"):
            logits = self.compute_logits(module, batch, task)
            assert (
                logits.dim() == 2
            ), f"Expected logits to be 2D, got {logits.dim()}"
            loss = entropy_loss(logits)
        return loss

    def _backward(self, loss: Tensor) -> None:
        """Backpropagate `loss` through Fabric when available, else via plain autograd."""
        # NOTE(review): retain_graph=True is kept from the original implementation;
        # confirm whether the RankOne-MoE graph really must survive the backward call.
        if self._fabric is not None:
            self._fabric.backward(loss, retain_graph=True)
        else:
            loss.backward(retain_graph=True)

    def test_time_adaptation(self, module: RankOneMoE):
        """
        Perform test-time adaptation for the given module.

        Minimizes the entropy of the model's predictions on unlabeled test batches of
        every task in the model pool. Only parameters with ``requires_grad=True`` are
        optimized.

        Args:
            module (RankOneMoE): The MoE module to adapt.

        Returns:
            RankOneMoE: The adapted MoE module. It is wrapped by Fabric if a Fabric is
            in use.

        Raises:
            ValueError: If ``config.optimizer`` is not ``"adam"``.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer over the trainable parameters only
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam(
                [p for p in module.parameters() if p.requires_grad], lr=self.config.lr
            )
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        if self._fabric is not None:
            module, optimizer = self._fabric.setup(module, optimizer)

        module.train()

        if self.config.get("fast_dev_run", False):
            log.info("Running fast_dev_run, only one step")
            num_steps = 1
        else:
            num_steps = self.config.max_steps
        pbar = tqdm(range(num_steps), "Test-time adaptation", dynamic_ncols=True)

        for _ in pbar:
            if self.config.use_grad_accumulate:
                # Backpropagate each task's loss immediately. Gradients accumulate
                # because .zero_grad() is only called after the optimizer step, and
                # each task's graph can be released before the next forward pass.
                for task in self.modelpool.model_names:
                    loss = self._compute_task_loss(module, task)
                    with self.profiler.profile("backward pass"):
                        self._backward(loss)
            else:
                # Sum the losses of all tasks and backpropagate once.
                loss = 0
                for task in self.modelpool.model_names:
                    loss = loss + self._compute_task_loss(module, task)
                with self.profiler.profile("backward pass"):
                    self._backward(loss)

            with self.profiler.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()

        return module

    def run(self, modelpool: ModelPool):
        """
        Run the RankOneMoEAlgorithm to fuse models using RankOne-MoE.

        If ``config.checkpoint`` is set, the weights are loaded from it and test-time
        adaptation is skipped. Otherwise the model is adapted and optionally saved to
        ``config.save_checkpoint``.

        Args:
            modelpool (ModelPool): The pool of models to be fused.

        Returns:
            RankOneMoE: The fused RankOne MoE model. It is unwrapped from Fabric and has
            ``batch_reduce`` set to ``False``.
        """
        log.info("Fusing models using RankOne-MoE modules.")
        self.modelpool = modelpool

        with timeit_context("upscaling models to a RankOne-MoE model"):
            moe_model = self.construct_moe_model()
            print_parameters(moe_model)

        if self.config.get("checkpoint", False):
            log.info(
                f"load checkpoint from {self.config.checkpoint}, test-time adaptation will be skipped."
            )
            self.load_checkpoint(moe_model, self.config.checkpoint)
        else:
            with self.profiler.profile("test-time adaptation"):
                moe_model = self.test_time_adaptation(moe_model)
            if self.config.get("save_checkpoint", False):
                log.info(f"save checkpoint to {self.config.save_checkpoint}")
                self.save_checkpoint(moe_model, self.config.save_checkpoint)

            # test_time_adaptation may return a Fabric-wrapped module
            if lightning.fabric.wrappers.is_wrapped(moe_model):
                moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)

        # enable sample-wise adaptation
        moe_model.batch_reduce = False
        print(self.profiler.summary())
        return moe_model
|
|
@@ -11,8 +11,8 @@ from fusion_bench.modelpool import BaseModelPool
|
|
|
11
11
|
from fusion_bench.utils.state_dict_arithmetic import (
|
|
12
12
|
state_dict_add,
|
|
13
13
|
state_dict_avg,
|
|
14
|
-
state_dict_mul,
|
|
15
14
|
state_dict_div,
|
|
15
|
+
state_dict_mul,
|
|
16
16
|
)
|
|
17
17
|
from fusion_bench.utils.type import StateDictType
|
|
18
18
|
|
|
@@ -5,11 +5,11 @@ from copy import deepcopy
|
|
|
5
5
|
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, cast # noqa: F401
|
|
6
6
|
|
|
7
7
|
import torch
|
|
8
|
+
from omegaconf import DictConfig
|
|
8
9
|
from torch import nn
|
|
9
10
|
from torch.utils.data import DataLoader
|
|
10
11
|
from tqdm.auto import tqdm
|
|
11
12
|
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
|
|
12
|
-
from omegaconf import DictConfig
|
|
13
13
|
|
|
14
14
|
from fusion_bench.dataset.clip_dataset import CLIPDataset
|
|
15
15
|
from fusion_bench.mixins import LightningFabricMixin
|
|
@@ -43,12 +43,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
|
|
|
43
43
|
# a dict of zeroshot weights for each task, each key is the task name
|
|
44
44
|
zeroshot_weights_cache_dir: str = "outputs/cache/clip_zeroshot_weights"
|
|
45
45
|
zeroshot_weights: Dict[str, torch.Tensor] = {}
|
|
46
|
-
|
|
47
|
-
def __init__(self, algorithm_config: DictConfig) -> None:
|
|
48
|
-
super().__init__(algorithm_config)
|
|
49
|
-
self.whether_setup_zero_shot_classification_head = (
|
|
50
|
-
False # We want to only do this once
|
|
51
|
-
)
|
|
46
|
+
whether_setup_zero_shot_classification_head = False
|
|
52
47
|
|
|
53
48
|
@property
|
|
54
49
|
def clip_processor(self):
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
|
-
from typing import TYPE_CHECKING, Any, Optional, TypeVar
|
|
3
|
+
from typing import TYPE_CHECKING, Any, List, Optional, TypeVar
|
|
4
4
|
|
|
5
5
|
import lightning as L
|
|
6
6
|
import torch
|
|
@@ -8,8 +8,8 @@ from lightning.fabric.loggers import TensorBoardLogger
|
|
|
8
8
|
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
|
9
9
|
from omegaconf import DictConfig, OmegaConf
|
|
10
10
|
|
|
11
|
-
from fusion_bench.utils.instantiate import instantiate
|
|
12
11
|
from fusion_bench.utils import import_object
|
|
12
|
+
from fusion_bench.utils.instantiate import instantiate
|
|
13
13
|
|
|
14
14
|
if TYPE_CHECKING:
|
|
15
15
|
import lightning.fabric.loggers.tensorboard
|