fusion-bench 0.2.21__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fusion_bench/__init__.py +21 -2
  2. fusion_bench/constants/__init__.py +1 -0
  3. fusion_bench/constants/runtime.py +57 -0
  4. fusion_bench/method/__init__.py +8 -2
  5. fusion_bench/method/bitdelta/__init__.py +1 -0
  6. fusion_bench/method/classification/clip_finetune.py +1 -1
  7. fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
  8. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
  9. fusion_bench/method/linear/simple_average_for_llama.py +16 -11
  10. fusion_bench/method/simple_average.py +7 -7
  11. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  12. fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
  13. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  14. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
  15. fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
  16. fusion_bench/method/we_moe/__init__.py +1 -0
  17. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  18. fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
  19. fusion_bench/method/we_moe/utils.py +15 -0
  20. fusion_bench/method/weighted_average/llama.py +1 -1
  21. fusion_bench/mixins/clip_classification.py +11 -42
  22. fusion_bench/mixins/serialization.py +18 -8
  23. fusion_bench/modelpool/causal_lm/causal_lm.py +32 -33
  24. fusion_bench/models/__init__.py +5 -0
  25. fusion_bench/models/hf_utils.py +65 -87
  26. fusion_bench/models/model_card_templates/default.md +46 -0
  27. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  28. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
  29. fusion_bench/models/modeling_smile_mistral/__init__.py +1 -1
  30. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
  31. fusion_bench/programs/fabric_fusion_program.py +29 -60
  32. fusion_bench/scripts/cli.py +34 -1
  33. fusion_bench/taskpool/clip_vision/taskpool.py +9 -4
  34. fusion_bench/utils/__init__.py +1 -0
  35. fusion_bench/utils/cache_utils.py +101 -1
  36. fusion_bench/utils/fabric.py +2 -2
  37. fusion_bench/utils/lazy_imports.py +23 -0
  38. fusion_bench/utils/lazy_state_dict.py +38 -3
  39. fusion_bench/utils/modelscope.py +3 -3
  40. fusion_bench/utils/path.py +56 -0
  41. fusion_bench/utils/pylogger.py +1 -1
  42. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +1 -23
  43. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +53 -45
  44. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  45. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  46. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  47. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
  48. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  49. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
  50. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
  51. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
  52. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
  53. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py
@@ -16,10 +16,11 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.compat.modelpool import to_modelpool
+from fusion_bench.constants import RuntimeConstants
 from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import CausalLMPool
 from fusion_bench.models.hf_utils import (
-    generate_complete_readme,
+    create_default_model_card,
     save_pretrained_with_remote_code,
 )
 from fusion_bench.models.modeling_smile_qwen2 import (
@@ -41,7 +42,10 @@ log = logging.getLogger(__name__)
 
 
 @auto_register_config
-class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
+class SmileQwen2UpscalingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     R"""
     SmileQwen2UpscalingAlgorithm is a model fusion algorithm designed to upscale
     a pretrained Qwen2 model using a set of fine-tuned expert models. The algorithm
@@ -62,7 +66,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         self,
         device,
         accelerator,
-        model_path,
+        model_save_path,
         model_dtype,
         num_experts_per_tok,
         rank_of_router,
@@ -71,6 +75,11 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         **kwargs,
     ):
         super().__init__(**kwargs)
+        if not torch.cuda.is_available():
+            if "cuda" in self.device:
+                self.device = "cpu"
+            if "cuda" in self.accelerator:
+                self.accelerator = "cpu"
 
     @torch.no_grad()
     def run(self, modelpool) -> SmileQwen2ForCausalLM:
@@ -86,13 +95,6 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         self.modelpool = modelpool = to_modelpool(modelpool)
         config = self.config
 
-        # load model from path if provided and return directly
-        if config.model_path is not None and os.path.exists(config.model_path):
-            log.info(f"Loading model from {config.model_path}")
-            model = AutoModelForCausalLM.from_pretrained(config.model_path)
-            print_parameters(model)
-            return model
-
         with self.profile("load pretrained model"):
             pretrained_model = modelpool.load_pretrained_model()
         with self.profile("load fine-tuned model"):
@@ -100,7 +102,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
             m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
         ]
 
-        if config.device == "cuda" and torch.cuda.is_available():
+        if self.device == "cuda" and torch.cuda.is_available():
             pretrained_model = pretrained_model.cuda()
         print("parameter count of pretrained model:")
         print_parameters(pretrained_model)
@@ -114,17 +116,17 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         print_parameters(model)
         print(model)
 
-        if config.model_dtype is not None:
-            model.to(dtype=parse_dtype(config.model_dtype))
+        if self.model_dtype is not None:
+            model.to(dtype=parse_dtype(self.model_dtype))
 
-        if config.model_path is not None:
-            if os.path.dirname(config.model_path):
-                os.makedirs(os.path.dirname(config.model_path), exist_ok=True)
-            log.info(f"Saving model to {config.model_path}")
+        if self.model_save_path is not None:
+            if os.path.dirname(self.model_save_path):
+                os.makedirs(os.path.dirname(self.model_save_path), exist_ok=True)
+            log.info(f"Saving model to {self.model_save_path}")
             tokenizer = self.modelpool.load_tokenizer()
-            tokenizer.save_pretrained(config.model_path)
+            tokenizer.save_pretrained(self.model_save_path)
             if not self.save_with_remote_code:
-                model.save_pretrained(config.model_path)
+                model.save_pretrained(self.model_save_path)
             else:
                 save_pretrained_with_remote_code(
                     model,
@@ -133,17 +135,18 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
                        "AutoModel": SmileQwen2Model,
                        "AutoModelForCausalLM": SmileQwen2ForCausalLM,
                    },
-                    save_directory=config.model_path,
+                    save_directory=self.model_save_path,
                )
 
            # save readme
-            complete_readme = generate_complete_readme(
-                algorithm=self,
-                modelpool=modelpool,
+            model_card_str = create_default_model_card(
                models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
+                description="Merged Qwen model using SMILE Upscaling",
+                algorithm_config=self.config,
+                modelpool_config=modelpool.config,
            )
-            with open(os.path.join(config.model_path, "README.md"), "w") as f:
-                f.write(complete_readme)
+            with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
+                f.write(model_card_str)
 
        return model
 
@@ -174,9 +177,9 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
        )
        base_config = AutoConfig.from_pretrained(pretrained_path)
        model_config = SmileQwen2Config(
-            num_experts_per_tok=config.num_experts_per_tok,
-            rank_of_router=config.rank_of_router,
-            rank_of_expert=config.rank_of_expert,
+            num_experts_per_tok=self.num_experts_per_tok,
+            rank_of_router=self.rank_of_router,
+            rank_of_expert=self.rank_of_expert,
            num_local_experts=len(finetuned_models),
            **base_config.to_dict(),
        )
@@ -186,7 +189,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
 
        # copy pretrained model weights
        state_dict = model.state_dict()
-        pretrained_state_dict = dict(pretrained_model.state_dict())
+        pretrained_state_dict = pretrained_model.state_dict()
        for key in list(pretrained_state_dict.keys()):
            if key not in state_dict:
                pretrained_state_dict.pop(key)
@@ -198,6 +201,12 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
            "Upscaling Modules (layer)",
            dynamic_ncols=True,
        ):
+            if RuntimeConstants.debug and layer_idx > 0:
+                log.info(
+                    "Debug mode enabled: processing only the first layer, skipping remaining layers"
+                )
+                break
+
            pretrained_layer: Qwen2DecoderLayer = pretrained_model.model.layers[
                layer_idx
            ]
@@ -213,7 +222,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
                        base=getattr(pretrained_layer.self_attn, n),
                        experts=[getattr(m.self_attn, n) for m in finetuned_layers],
                        target=getattr(target_layer.self_attn, n),
-                        accelerator=config.accelerator,
+                        accelerator=self.accelerator,
                    )
                except ExpertNotTrainedError:
                    setattr(
@@ -228,7 +237,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
                        base=getattr(pretrained_layer.mlp, n),
                        experts=[getattr(m.mlp, n) for m in finetuned_layers],
                        target=getattr(target_layer.mlp, n),
-                        accelerator=config.accelerator,
+                        accelerator=self.accelerator,
                    )
                except ExpertNotTrainedError:
                    setattr(
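Taken together, the smile_qwen2_upscaling.py changes rename the `model_path` option to `model_save_path` (the load-from-path shortcut is removed), fall back to CPU when CUDA is unavailable, and honor `RuntimeConstants.debug` by processing only the first decoder layer. A minimal sketch of invoking the updated algorithm directly; the argument values and the `modelpool` construction are illustrative, not taken from this diff:

from fusion_bench.method.smile_upscaling.smile_qwen2_upscaling import (
    SmileQwen2UpscalingAlgorithm,
)

algorithm = SmileQwen2UpscalingAlgorithm(
    device="cuda",  # silently downgraded to "cpu" if torch.cuda.is_available() is False
    accelerator="cuda",  # same fallback applies
    model_save_path="outputs/smile_qwen2",  # renamed from model_path in 0.2.22
    model_dtype="bfloat16",
    num_experts_per_tok=1,
    rank_of_router=8,
    rank_of_expert=32,
)
merged_model = algorithm.run(modelpool)  # modelpool: a CausalLMPool of Qwen2 experts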
fusion_bench/method/smile_upscaling/smile_upscaling.py
@@ -20,8 +20,8 @@ from fusion_bench.models.smile_moe.linear_from_module import (
    SmileMoELinear,
 )
 from fusion_bench.models.utils import get_attr, set_attr
-from fusion_bench.utils.parameters import print_parameters
 from fusion_bench.utils.devices import get_device
+from fusion_bench.utils.parameters import print_parameters
 
 log = logging.getLogger(__name__)
 
fusion_bench/method/we_moe/__init__.py
@@ -1,2 +1,3 @@
 # flake8: noqa F401
 from .clip_we_moe import CLIPWeightEnsemblingMoEAlgorithm
+from .flan_t5_we_moe import FlanT5WeightEnsemblingMoEAlgorithm
fusion_bench/method/we_moe/entropy_loss.py
@@ -0,0 +1,25 @@
+import torch
+from torch import Tensor
+
+
+def entropy_loss(logits: Tensor, eps: float = 1e-8) -> Tensor:
+    """
+    Compute the entropy loss of a set of logits.
+
+    Args:
+        logits (Tensor): The logits to compute the entropy loss of.
+        eps (float): A small value to avoid log(0). Default is 1e-8.
+
+    Returns:
+        Tensor: The entropy loss of the logits.
+    """
+    # Ensure the logits tensor has 2 dimensions
+    assert (
+        logits.dim() == 2
+    ), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}"
+
+    # Compute the softmax probabilities
+    probs = torch.softmax(logits, dim=-1)
+
+    # Compute the entropy loss
+    return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
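The helper returns the mean Shannon entropy of the per-row softmax distributions, so minimizing it pushes predictions toward confidence. A quick, self-contained usage sketch (shapes are illustrative):

import torch

from fusion_bench.method.we_moe.entropy_loss import entropy_loss

logits = torch.randn(4, 10, requires_grad=True)  # (batch_size, num_classes)
loss = entropy_loss(logits)  # scalar tensor
loss.backward()  # differentiable, as required for test-time adaptation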
fusion_bench/method/we_moe/flan_t5_we_moe.py
@@ -0,0 +1,331 @@
+import functools
+import logging
+import os
+from copy import deepcopy
+from typing import Any, Dict, List, Mapping, Optional, Union, cast  # noqa: F401
+
+import lightning
+import lightning as L
+import lightning.fabric.wrappers
+import torch
+from torch import Tensor
+from torch.utils.data import DataLoader
+from tqdm.autonotebook import tqdm
+from transformers import T5ForConditionalGeneration
+from transformers.data import default_data_collator
+
+from fusion_bench.method import BaseAlgorithm
+from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge
+from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
+from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+from fusion_bench.modelpool import Seq2SeqLMPool
+from fusion_bench.models.we_moe import WeightEnsemblingMoE
+from fusion_bench.utils import timeit_context
+from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
+from fusion_bench.utils.instantiate_utils import instantiate
+from fusion_bench.utils.parameters import print_parameters
+
+from .entropy_loss import entropy_loss
+from .utils import get_memory_usage
+
+log = logging.getLogger(__name__)
+
+
+class FlanT5WeightEnsemblingMoEAlgorithm(
+    BaseAlgorithm,
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+):
+    """
+    FlanT5WeightEnsemblingMoEAlgorithm implements the weight-ensembling MoE fusion
+    algorithm for FlanT5 models, extending BaseAlgorithm with the LightningFabricMixin and SimpleProfilerMixin mixins.
+
+    Attributes:
+        modelpool (Seq2SeqLMPool): The model pool containing the FlanT5 models.
+    """
+
+    modelpool: Seq2SeqLMPool = None
+
+    def __init__(
+        self,
+        checkpoint: bool = False,
+        save_checkpoint: bool = False,
+        router_hidden_layers: int = 2,
+        init_lambda: float = 0.3,
+        batch_reduce: bool = True,
+        lr: float = 1e-4,
+        optimizer: str = "adam",
+        devices: int = 1,
+        batch_size: int = 16,
+        num_workers: int = 0,
+        max_steps: int = 1000,
+        use_grad_accumulate: bool = True,
+        cache_dir: str = "outputs",
+        fast_dev_run: bool = False,
+        **kwargs,
+    ):
+        """
+        Initialize the FlanT5WeightEnsemblingMoEAlgorithm with the given configuration.
+
+        Args:
+            **kwargs: Additional keyword arguments forwarded to BaseAlgorithm.
+        """
+        self.checkpoint = checkpoint
+        self.save_checkpoint = save_checkpoint
+        self.router_hidden_layers = router_hidden_layers
+        self.init_lambda = init_lambda
+        self.batch_reduce = batch_reduce
+        self.lr = lr
+        self.optimizer = optimizer
+        self.devices = devices
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.max_steps = max_steps
+        self.use_grad_accumulate = use_grad_accumulate
+        self.cache_dir = cache_dir
+        self.fast_dev_run = fast_dev_run
+        super().__init__(**kwargs)
+
+    def construct_moe_model(self) -> WeightEnsemblingMoE:
+        """
+        Construct the Mixture of Experts (MoE) model using the models in the model pool.
+
+        Returns:
+            WeightEnsemblingMoE: The constructed MoE model.
+        """
+        base_model = self.modelpool.load_model("_pretrained_")
+        expert_models = [
+            self.modelpool.load_model(name) for name in self.modelpool.model_names
+        ]
+
+        # Merge the models using task arithmetic
+        moe_model = task_arithmetic_merge(
+            # This function modifies the model in place, so we need to pass a deepcopy
+            deepcopy(base_model),
+            expert_models,
+            scaling_factor=self.init_lambda,
+        ).requires_grad_(False)
+
+        print(base_model)
+
+        # Up-scale MLP modules
+        num_layer = 12
+        encoder_mlp_index = 1
+        base_encoder = base_model.encoder
+        moe_encoder = moe_model.encoder
+        expert_encoders = [m.encoder for m in expert_models]
+
+        for layer_idx in range(num_layer):
+            base_mlp = (
+                base_encoder.block[layer_idx].layer[encoder_mlp_index].DenseReluDense
+            )
+            expert_mlps = [
+                e.block[layer_idx].layer[encoder_mlp_index].DenseReluDense
+                for e in expert_encoders
+            ]
+
+            moe_encoder.block[layer_idx].layer[encoder_mlp_index].DenseReluDense = (
+                WeightEnsemblingMoE(
+                    hidden_size=base_encoder.config.hidden_size,
+                    base_model=base_mlp,
+                    expert_models=expert_mlps,
+                    init_lambda=self.init_lambda,
+                    batch_first=True,
+                    router_hidden_layers=self.router_hidden_layers,
+                    batch_reduce=self.batch_reduce,
+                )
+            )
+
+        decoder_mlp_index = 2
+        base_decoder = base_model.decoder
+        moe_decoder = moe_model.decoder
+        expert_decoders = [m.decoder for m in expert_models]
+
+        for layer_idx in range(num_layer):
+            base_mlp = (
+                base_decoder.block[layer_idx].layer[decoder_mlp_index].DenseReluDense
+            )
+            expert_mlps = [
+                e.block[layer_idx].layer[decoder_mlp_index].DenseReluDense
+                for e in expert_decoders
+            ]
+
+            moe_decoder.block[layer_idx].layer[decoder_mlp_index].DenseReluDense = (
+                WeightEnsemblingMoE(
+                    hidden_size=base_decoder.config.hidden_size,
+                    base_model=base_mlp,
+                    expert_models=expert_mlps,
+                    init_lambda=self.init_lambda,
+                    batch_first=True,
+                    router_hidden_layers=self.router_hidden_layers,
+                    batch_reduce=self.batch_reduce,
+                )
+            )
+
+        print(moe_model)
+        return moe_model
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+        """
+        Loader of test dataset for test-time adaptation. Labels are not needed.
+
+        Args:
+            task (str): The name of the task.
+
+        Returns:
+            DataLoader: An infinite iterator over the shuffled test data loader.
+        """
+        # dataloader_kwargs = dict(self.dataloader_kwargs)
+        # dataloader_kwargs.update(dict(shuffle=True, collate_fn=default_data_collator))
+
+        dataset = self.modelpool.load_test_dataset(task)
+        log.info("get_shuffled_test_loader_iter")
+        loader = DataLoader(
+            dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+            collate_fn=default_data_collator,
+        )
+        # loader = DataLoader(dataset, **dataloader_kwargs)
+        if self.fabric is not None:
+            loader = self.fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def compute_logits(
+        self,
+        module: T5ForConditionalGeneration,
+        batch,
+        task: str,
+    ) -> Tensor:
+        """
+        Compute the logits for the given batch and task.
+
+        Args:
+            module: The model module.
+            batch: The input batch of tokenized text.
+            task (str): The name of the task.
+
+        Returns:
+            Tensor: The computed logits.
+        """
+        input_ids: Tensor = batch["input_ids"]
+        attention_mask: Tensor = batch["attention_mask"]
+
+        # remove padding tokens from the input
+        while attention_mask[:, -1].eq(0).all():
+            input_ids = input_ids[:, :-1]
+            attention_mask = attention_mask[:, :-1]
+
+        outputs = module(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=torch.ones(
+                input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
+            ),
+        )
+        logits = outputs.logits[:, 0, :]
+        return logits
+
+    def test_time_adaptation(self, module):
+        """
+        Perform test-time adaptation for the given module.
+
+        Args:
+            module (WeightEnsemblingMoE): The MoE module to adapt.
+
+        Returns:
+            WeightEnsemblingMoE: The adapted MoE module.
+        """
+        self.on_test_time_adaptation_start()
+
+        # configure optimizer
+        if self.optimizer == "adam":
+            print([name for name, p in module.named_parameters() if p.requires_grad])
+            optimizer = torch.optim.Adam(
+                [p for p in module.parameters() if p.requires_grad], lr=self.lr
+            )
+        else:
+            raise ValueError(f"Unsupported optimizer: {self.optimizer}")
+
+        module, optimizer = self.fabric.setup(module, optimizer)
+
+        module.train()
+        # module.merge_weights()
+        for step_idx in (
+            pbar := tqdm(
+                range(self.max_steps if not self.is_debug_mode else 1),
+                ("[DEBUG MODE] " if self.is_debug_mode else "")
+                + "WEMoE Test-time adaptation",
+                dynamic_ncols=True,
+            )
+        ):
+            total_loss = 0
+            for task in self.modelpool.model_names:
+                with self.profile("data loading"):
+                    batch = next(self.get_shuffled_test_loader_iter(task))
+                with self.profile("forward pass"):
+                    logits = self.compute_logits(module, batch, task)
+                    logits = logits.mean(dim=0, keepdim=True)
+                    loss = entropy_loss(logits)
+                    total_loss += loss
+                with self.profile("backward pass"):
+                    self.fabric.backward(loss, retain_graph=True)
+
+            with self.profile("optimizer step"):
+                optimizer.step()
+                optimizer.zero_grad()
+
+            metrics = {
+                "train/loss": total_loss.item(),
+            }
+            self.fabric.log_dict(metrics, step=step_idx)
+            pbar.set_postfix(metrics)
+
+        log.info(get_memory_usage("after test-time adaptation, the GPU memory usage is:"))
+        self.print_profile_summary()
+        return module
+
+    def on_test_time_adaptation_start(self):
+        """
+        Something to do before the test-time adaptation starts, such as setting up the task-specific heads.
+        """
+        pass
+
+    def run(self, modelpool: Seq2SeqLMPool, **kwargs):
+        """
+        Run the FlanT5WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.
+
+        Args:
+            modelpool (Seq2SeqLMPool): The pool of models to be fused.
+
+        Returns:
+            WeightEnsemblingMoE: The fused MoE model.
+        """
+        log.info("Fusing models using layer-wise adaptive merging.")
+        self.modelpool = modelpool
+
+        with timeit_context("upscaling models to a weight-ensembling MoE model"):
+            moe_model = self.construct_moe_model()
+            print_parameters(moe_model)
+
+        if self.checkpoint:
+            log.info(
+                f"load checkpoint from {self.checkpoint}, test-time adaptation will be skipped."
+            )
+            self.load_checkpoint(moe_model, self.checkpoint)
+        else:
+            with self.profile("test-time adaptation"):
+                moe_model = self.test_time_adaptation(moe_model)
+            if self.save_checkpoint:
+                log.info(f"save checkpoint to {self.save_checkpoint}")
+                self.save_checkpoint(moe_model, self.save_checkpoint)
+
+        if lightning.fabric.wrappers.is_wrapped(moe_model):
+            moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)
+
+        # enable sample-wise adaptation
+        moe_model.batch_reduce = False
+        self.print_profile_summary()
+        return moe_model
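This new algorithm mirrors the CLIP weight-ensembling MoE variant: it task-arithmetic-merges the experts into a base FlanT5 model, swaps each encoder/decoder MLP for a routed WeightEnsemblingMoE, then tunes the routers by entropy minimization on unlabeled test batches. A hedged sketch of driving it directly; the Seq2SeqLMPool setup is assumed here (normally it is wired up from the new fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml config):

from fusion_bench.method.we_moe.flan_t5_we_moe import FlanT5WeightEnsemblingMoEAlgorithm

algorithm = FlanT5WeightEnsemblingMoEAlgorithm(
    init_lambda=0.3,  # task-arithmetic scaling used to initialize the merge
    router_hidden_layers=2,
    lr=1e-4,
    max_steps=1000,
    batch_size=16,
)
# modelpool: a Seq2SeqLMPool containing "_pretrained_" plus the expert models
moe_model = algorithm.run(modelpool)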
fusion_bench/method/we_moe/utils.py
@@ -0,0 +1,15 @@
+import torch
+
+
+def get_memory_usage(desc):
+    """
+    Obtain the current GPU memory usage.
+
+    Returns:
+        str: A string containing the allocated and cached memory in MB.
+    """
+    allocated = torch.cuda.memory_allocated() / 1024**2  # convert to MB
+    cached = torch.cuda.memory_reserved() / 1024**2  # convert to MB
+    return (
+        f"{desc}\nAllocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
+    )
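get_memory_usage only queries the torch.cuda counters, so it is meaningful on CUDA machines; a small usage sketch:

import logging

from fusion_bench.method.we_moe.utils import get_memory_usage

log = logging.getLogger(__name__)
log.info(get_memory_usage("after constructing the MoE model:"))
# -> "...\nAllocated Memory: 1234.56 MB\nCached Memory: 2345.67 MB"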
fusion_bench/method/weighted_average/llama.py
@@ -7,11 +7,11 @@ from transformers import PreTrainedModel
 from typing_extensions import override
 
 from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import auto_register_config
 from fusion_bench.modelpool import CausalLMPool
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
 from fusion_bench.utils.type import StateDictType
-from fusion_bench.mixins import auto_register_config
 
 log = logging.getLogger(__name__)
 
fusion_bench/mixins/clip_classification.py
@@ -22,6 +22,7 @@ from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 
+from fusion_bench import cache_with_joblib
 from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.modelpool import CLIPVisionModelPool
@@ -46,7 +47,6 @@ class CLIPClassificationMixin(LightningFabricMixin):
 
    - `_dataloader_kwargs` (Dict[str, Any]): Keyword arguments for the dataloader.
    - `modelpool` (CLIPVisionModelPool): The model pool containing the CLIP models.
-    - `zeroshot_weights_cache_dir` (Optional[str]): The directory to cache the zero-shot weights.
    """
 
    dataloader_kwargs: Dict[str, Any] = {}
@@ -54,7 +54,6 @@ class CLIPClassificationMixin(LightningFabricMixin):
    modelpool: CLIPVisionModelPool = None
    _clip_processor: CLIPProcessor = None
    # a dict of zeroshot weights for each task, each key is the task name
-    zeroshot_weights_cache_dir: str = "outputs/cache/clip_zeroshot_weights"
    zeroshot_weights: Dict[str, torch.Tensor] = {}
    whether_setup_zero_shot_classification_head = False
 
@@ -131,26 +130,16 @@ class CLIPClassificationMixin(LightningFabricMixin):
        self.visual_projection = self.fabric.to_device(self.visual_projection)
        self.logit_scale_exp = self.fabric.to_device(self.logit_scale_exp)
 
-        # get cache directory
-        if self.modelpool.has_pretrained:
-            model_name = self.modelpool.get_model_config("_pretrained_")
-            if not isinstance(model_name, str):
-                model_name = model_name.pretrained_model_name_or_path
-        else:
-            model_name = self.modelpool.get_model_config(self.modelpool.model_names[0])
-            if not isinstance(model_name, str):
-                model_name = model_name.pretrained_model_name_or_path
-        cache_dir = os.path.join(
-            self.zeroshot_weights_cache_dir,
-            os.path.normpath(model_name.split("/")[-1]),
-        )
-        if not os.path.exists(cache_dir):
-            log.info(
-                f"Creating cache directory for zero-shot classification head at {cache_dir}"
-            )
-            os.makedirs(cache_dir)
+        @cache_with_joblib()
+        def construct_classification_head(task: str):
+            nonlocal clip_classifier
+
+            classnames, templates = get_classnames_and_templates(task)
+            clip_classifier.set_classification_task(classnames, templates)
+            zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()
+
+            return zeroshot_weights
 
-        log.info(f"cache directory for zero-shot classification head: {cache_dir}")
        for task in tqdm(
            self.modelpool.model_names if task_names is None else task_names,
            "Setting up zero-shot classification head",
@@ -158,27 +147,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
        ):
            zeroshot_weights = None
            if self.fabric.is_global_zero:
-                cache_file = os.path.join(
-                    cache_dir, os.path.normpath(f"{task}_zeroshot_weights.pt")
-                )
-                if os.path.exists(cache_file):
-                    zeroshot_weights = torch.load(
-                        cache_file,
-                        map_location="cpu",
-                        weights_only=True,
-                    ).detach()
-                    log.info(
-                        f"Loadded cached zeroshot weights for task: {task}, shape: {zeroshot_weights.shape}"
-                    )
-                else:
-                    log.info(
-                        f"Construct zero shot classification head for task: {task}"
-                    )
-                    classnames, templates = get_classnames_and_templates(task)
-                    clip_classifier.set_classification_task(classnames, templates)
-                    zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()
-                    log.info(f"save zeroshot weights to {cache_file}")
-                    torch.save(zeroshot_weights, cache_file)
+                zeroshot_weights = construct_classification_head(task)
 
            self.fabric.barrier()
            self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
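This refactor replaces the hand-rolled cache-directory bookkeeping and explicit torch.save/torch.load calls with the new cache_with_joblib decorator (re-exported from fusion_bench; the implementation lives in fusion_bench/utils/cache_utils.py per the file list above). A minimal sketch of the same memoization pattern using joblib directly; cache_with_joblib's exact signature and cache location may differ:

from joblib import Memory

memory = Memory("outputs/cache", verbose=0)

@memory.cache  # results are keyed by the function arguments and persisted on disk
def construct_classification_head(task: str):
    # expensive step: encode all class-name prompts with the CLIP text encoder
    ...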