fusion-bench 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries, and is provided for informational purposes only.
Files changed (36)
  1. fusion_bench/compat/method/__init__.py +1 -0
  2. fusion_bench/compat/method/base_algorithm.py +0 -1
  3. fusion_bench/compat/modelpool/__init__.py +2 -1
  4. fusion_bench/dataset/arc_agi/__init__.py +6 -1
  5. fusion_bench/dataset/arc_agi/arc.py +21 -7
  6. fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
  7. fusion_bench/dataset/arc_agi/np_cache.py +0 -1
  8. fusion_bench/dataset/arc_agi/preprocess.py +50 -8
  9. fusion_bench/dataset/llama/collate.py +10 -3
  10. fusion_bench/method/__init__.py +3 -0
  11. fusion_bench/method/adamerging/__init__.py +1 -1
  12. fusion_bench/method/lm_finetune/fullfinetune_sft.py +47 -5
  13. fusion_bench/method/lm_finetune/peftfinetune_sft.py +58 -23
  14. fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
  15. fusion_bench/method/rankone_moe/__init__.py +3 -0
  16. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  17. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  18. fusion_bench/method/simple_average.py +1 -1
  19. fusion_bench/mixins/clip_classification.py +2 -7
  20. fusion_bench/mixins/lightning_fabric.py +2 -2
  21. fusion_bench/models/rankone_moe.py +410 -0
  22. fusion_bench/taskpool/__init__.py +10 -2
  23. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  24. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  25. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
  26. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/METADATA +1 -1
  27. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/RECORD +36 -29
  28. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +4 -4
  29. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +13 -7
  30. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  31. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
  32. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  33. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/LICENSE +0 -0
  34. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/WHEEL +0 -0
  35. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/entry_points.txt +0 -0
  36. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,249 @@
+ import logging
+ from abc import abstractmethod
+ from typing import cast # noqa: F401
+
+ import lightning as L
+ import lightning.fabric.wrappers
+ import torch
+ from lightning.pytorch.profilers import SimpleProfiler
+ from omegaconf import DictConfig
+ from torch import Tensor
+ from torch.utils.data import DataLoader
+ from tqdm.autonotebook import tqdm
+
+ from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm
+ from fusion_bench.compat.modelpool import ModelPool
+ from fusion_bench.models.rankone_moe import RankOneMoE
+ from fusion_bench.utils import timeit_context
+ from fusion_bench.utils.parameters import print_parameters
+
+ log = logging.getLogger(__name__)
+
+
+ def entropy_loss(logits: Tensor) -> Tensor:
+     """
+     Compute the entropy loss of a set of logits.
+
+     Args:
+         logits (Tensor): The logits to compute the entropy loss of.
+
+     Returns:
+         Tensor: The entropy loss of the logits.
+     """
+     probs = torch.softmax(logits, dim=-1)
+     return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
+
+
+ class RankOneMoEAlgorithm(ModelFusionAlgorithm):
+     """
+     Algorithm for fusing models using RankOne-MoE (https://github.com/EnnengYang/RankOne-MoE).
+
+     This class provides methods for constructing the MoE model, performing test-time adaptation,
+     and running the fusion process.
+
+     Attributes:
+         _fabric (L.Fabric): The fabric for distributed training.
+         modelpool (ModelPool): The pool of models to be fused.
+         profiler (SimpleProfiler): The profiler for measuring performance.
+     """
+
+     _fabric: L.Fabric = None
+     modelpool: ModelPool = None
+
+     def __init__(self, algorithm_config: DictConfig):
+         """
+         Initialize the RankOneMoEAlgorithm with the given configuration.
+
+         Args:
+             algorithm_config (DictConfig): The configuration for the algorithm.
+         """
+         super().__init__(algorithm_config)
+
+         if self._fabric is None and torch.cuda.is_available():
+             self._fabric = L.Fabric(
+                 devices=self.config.get("devices", 1),
+             )
+             self._fabric.launch()
+         else:
+             assert "No CUDA device available."
+         self.profiler = SimpleProfiler(
+             self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
+         )
+
+     @abstractmethod
+     def load_checkpoint(self, model, checkpoint):
+         """
+         Load the checkpoint file.
+
+         Args:
+             model: The model to load the checkpoint into.
+             checkpoint: The checkpoint file to load.
+         """
+         pass
+
+     @abstractmethod
+     def save_checkpoint(self, model, checkpoint):
+         """
+         Save the checkpoint file.
+
+         Args:
+             model: The model to save the checkpoint from.
+             checkpoint: The checkpoint file to save.
+         """
+         pass
+
+     @abstractmethod
+     def construct_moe_model(self) -> RankOneMoE:
+         """
+         Construct the Mixture of Experts model using the models in the model pool.
+
+         Returns:
+             RankOne-MoE: The constructed MoE model.
+         """
+         pass
+
+     def on_test_time_adaptation_start(self):
+         """
+         Hook method called at the start of test-time adaptation.
+         """
+         pass
+
+     @abstractmethod
+     def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+         """
+         Get an iterator for the shuffled test data loader for a specific task.
+
+         Args:
+             task (str): The task for which to get the test data loader.
+
+         Returns:
+             DataLoader: The shuffled test data loader iterator.
+         """
+         pass
+
+     @abstractmethod
+     def compute_logits(self, module, batch, task) -> Tensor:
+         """
+         Compute the logits for a given batch and task.
+
+         Args:
+             module: The model module to use for computing logits.
+             batch: The batch of data.
+             task: The task for which to compute logits.
+
+         Returns:
+             Tensor: The computed logits.
+         """
+         pass
+
+     def test_time_adaptation(self, module: RankOneMoE):
+         """
+         Perform test-time adaptation for the given module.
+
+         Args:
+             module (RankOne-MoE): The MoE module to adapt.
+
+         Returns:
+             RankOne-MoE: The adapted MoE module.
+         """
+         self.on_test_time_adaptation_start()
+
+         # configure optimizer
+         if self.config.optimizer == "adam":
+             optimizer = torch.optim.Adam(
+                 [p for p in module.parameters() if p.requires_grad], lr=self.config.lr
+             )
+         else:
+             raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
+
+         if self._fabric is not None:
+             module, optimizer = self._fabric.setup(module, optimizer)
+
+         module.train()
+
+         if self.config.get("fast_dev_run", False):
+             log.info("Running fast_dev_run, only one step")
+             pbar = tqdm(
+                 range(1),
+                 "Test-time adaptation",
+                 dynamic_ncols=True,
+             )
+         else:
+             pbar = tqdm(
+                 range(self.config.max_steps),
+                 "Test-time adaptation",
+                 dynamic_ncols=True,
+             )
+         for step_idx in pbar:
+             if self.config.use_grad_accumulate:
+                 for task in self.modelpool.model_names:
+                     with self.profiler.profile("data time"):
+                         batch = next(self.get_shuffled_test_loader_iter(task))
+                     with self.profiler.profile("forward pass"):
+                         logits = self.compute_logits(module, batch, task)
+                         assert (
+                             logits.dim() == 2
+                         ), f"Expected logits to be 2D, got {logits.dim()}"
+                         loss = entropy_loss(logits)
+                     # .backward() accumulates when .zero_grad() wasn't called
+                     # this can save memory
+                     with self.profiler.profile("backward pass"):
+                         self._fabric.backward(loss, retain_graph=True)
+             else:
+                 loss = 0
+                 for task in self.modelpool.model_names:
+                     with self.profiler.profile("data time"):
+                         batch = next(self.get_shuffled_test_loader_iter(task))
+                     with self.profiler.profile("forward pass"):
+                         logits = self.compute_logits(module, batch, task)
+                         assert (
+                             logits.dim() == 2
+                         ), f"Expected logits to be 2D, got {logits.dim()}"
+                         loss = loss + entropy_loss(logits)
+                 with self.profiler.profile("backward pass"):
+                     self._fabric.backward(loss, retain_graph=True)
+
+             with self.profiler.profile("optimizer step"):
+                 optimizer.step()
+                 optimizer.zero_grad()
+
+             # print([m for m in module.parameters() if m.requires_grad][0])
+
+         return module
+
+     def run(self, modelpool: ModelPool):
+         """
+         Run the RankOneMoEAlgorithm to fuse models using RankOne-MoE.
+
+         Args:
+             modelpool (ModelPool): The pool of models to be fused.
+
+         Returns:
+             RankOne-MoE: The fused RankOne MoE model.
+         """
+         log.info("Fusing models using RankOne-MoE modules.")
+         self.modelpool = modelpool
+
+         with timeit_context("upscaling models to a RankOne-MoE model"):
+             moe_model = self.construct_moe_model()
+             print_parameters(moe_model)
+
+         if self.config.get("checkpoint", False):
+             log.info(
+                 f"load checkpoint from {self.config.checkpoint}, test-time adaptation will be skipped."
+             )
+             self.load_checkpoint(moe_model, self.config.checkpoint)
+         else:
+             with self.profiler.profile("test-time adaptation"):
+                 moe_model = self.test_time_adaptation(moe_model)
+             if self.config.get("save_checkpoint", False):
+                 log.info(f"save checkpoint to {self.config.save_checkpoint}")
+                 self.save_checkpoint(moe_model, self.config.save_checkpoint)
+
+         if lightning.fabric.wrappers.is_wrapped(moe_model):
+             moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)
+
+         # enable sample-wise adaptation
+         moe_model.batch_reduce = False
+         print(self.profiler.summary())
+         return moe_model
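The new RankOneMoEAlgorithm added above upscales the pooled models into a rank-one mixture-of-experts and then tunes it at test time by minimizing the Shannon entropy of its predictions on unlabeled test batches (the entropy_loss helper at the top of the hunk). The short sketch below is a standalone illustration, not code from the package: it re-states that loss and checks its two extremes. Uniform logits score about log(num_classes), while confident logits score near zero, which is the quantity the adaptation loop drives down.

import math

import torch


def entropy_loss(logits: torch.Tensor) -> torch.Tensor:
    # Same formula as in the diff above: mean Shannon entropy of the softmax.
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()


uniform = torch.zeros(4, 10)             # 4 samples, 10 classes, no preference
confident = torch.full((4, 10), -10.0)
confident[:, 0] = 10.0                   # nearly all probability mass on class 0

print(entropy_loss(uniform))             # maximal entropy for 10 classes
print(math.log(10))                      # reference value, about 2.30
print(entropy_loss(confident))           # close to 0.0, minimal entropy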
@@ -11,8 +11,8 @@ from fusion_bench.modelpool import BaseModelPool
  from fusion_bench.utils.state_dict_arithmetic import (
      state_dict_add,
      state_dict_avg,
-     state_dict_mul,
      state_dict_div,
+     state_dict_mul,
  )
  from fusion_bench.utils.type import StateDictType

@@ -5,11 +5,11 @@ from copy import deepcopy
  from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, cast # noqa: F401

  import torch
+ from omegaconf import DictConfig
  from torch import nn
  from torch.utils.data import DataLoader
  from tqdm.auto import tqdm
  from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
- from omegaconf import DictConfig

  from fusion_bench.dataset.clip_dataset import CLIPDataset
  from fusion_bench.mixins import LightningFabricMixin
@@ -43,12 +43,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
      # a dict of zeroshot weights for each task, each key is the task name
      zeroshot_weights_cache_dir: str = "outputs/cache/clip_zeroshot_weights"
      zeroshot_weights: Dict[str, torch.Tensor] = {}
-
-     def __init__(self, algorithm_config: DictConfig) -> None:
-         super().__init__(algorithm_config)
-         self.whether_setup_zero_shot_classification_head = (
-             False # We want to only do this once
-         )
+     whether_setup_zero_shot_classification_head = False

      @property
      def clip_processor(self):
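The CLIPClassificationMixin change above swaps a custom __init__ for a plain class-level default, so the mixin no longer needs to sit in the constructor chain; the whether_setup_zero_shot_classification_head flag still acts as a run-once guard, it is simply read from the class until the setup code assigns it on the instance. A rough, hypothetical illustration of that pattern (placeholder names, not the package's actual head-building code):

class SetupOnceMixin:
    # Class-level default: no __init__ required, so the mixin composes freely.
    _is_set_up: bool = False

    def setup_once(self) -> None:
        if self._is_set_up:      # first call reads the class attribute (False)
            return
        print("expensive one-time setup")
        self._is_set_up = True   # assignment shadows it with an instance attribute


class Worker(SetupOnceMixin):
    def run(self) -> None:
        self.setup_once()


w = Worker()
w.run()  # performs setup
w.run()  # skipped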
@@ -1,6 +1,6 @@
  import logging
  import os
- from typing import TYPE_CHECKING, Any, Optional, TypeVar, List
+ from typing import TYPE_CHECKING, Any, List, Optional, TypeVar

  import lightning as L
  import torch
@@ -8,8 +8,8 @@ from lightning.fabric.loggers import TensorBoardLogger
  from lightning.fabric.utilities.rank_zero import rank_zero_only
  from omegaconf import DictConfig, OmegaConf

- from fusion_bench.utils.instantiate import instantiate
  from fusion_bench.utils import import_object
+ from fusion_bench.utils.instantiate import instantiate

  if TYPE_CHECKING:
      import lightning.fabric.loggers.tensorboard