fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. fusion_bench/__init__.py +25 -2
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/constants/__init__.py +1 -0
  7. fusion_bench/constants/runtime.py +57 -0
  8. fusion_bench/dataset/gpt2_glue.py +1 -1
  9. fusion_bench/method/__init__.py +12 -4
  10. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  11. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  12. fusion_bench/method/bitdelta/__init__.py +1 -0
  13. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  14. fusion_bench/method/classification/clip_finetune.py +1 -1
  15. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  16. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  17. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  18. fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
  19. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
  20. fusion_bench/method/linear/simple_average_for_llama.py +16 -11
  21. fusion_bench/method/model_stock/__init__.py +1 -0
  22. fusion_bench/method/model_stock/model_stock.py +309 -0
  23. fusion_bench/method/regmean/clip_regmean.py +3 -6
  24. fusion_bench/method/regmean/regmean.py +27 -56
  25. fusion_bench/method/regmean/utils.py +56 -0
  26. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  27. fusion_bench/method/simple_average.py +7 -7
  28. fusion_bench/method/slerp/__init__.py +1 -1
  29. fusion_bench/method/slerp/slerp.py +110 -14
  30. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  31. fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
  32. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  33. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
  34. fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
  35. fusion_bench/method/we_moe/__init__.py +1 -0
  36. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  37. fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
  38. fusion_bench/method/we_moe/utils.py +15 -0
  39. fusion_bench/method/weighted_average/llama.py +1 -1
  40. fusion_bench/mixins/clip_classification.py +37 -48
  41. fusion_bench/mixins/serialization.py +30 -10
  42. fusion_bench/modelpool/base_pool.py +1 -1
  43. fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
  44. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  45. fusion_bench/models/__init__.py +5 -0
  46. fusion_bench/models/hf_utils.py +69 -86
  47. fusion_bench/models/linearized/vision_model.py +6 -6
  48. fusion_bench/models/model_card_templates/default.md +46 -0
  49. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  50. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
  51. fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
  52. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
  53. fusion_bench/models/we_moe.py +8 -8
  54. fusion_bench/programs/fabric_fusion_program.py +29 -60
  55. fusion_bench/scripts/cli.py +34 -1
  56. fusion_bench/taskpool/base_pool.py +99 -17
  57. fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
  58. fusion_bench/taskpool/dummy.py +101 -13
  59. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  60. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  61. fusion_bench/utils/__init__.py +2 -0
  62. fusion_bench/utils/cache_utils.py +101 -1
  63. fusion_bench/utils/data.py +6 -4
  64. fusion_bench/utils/devices.py +7 -4
  65. fusion_bench/utils/dtype.py +3 -2
  66. fusion_bench/utils/fabric.py +2 -2
  67. fusion_bench/utils/lazy_imports.py +23 -0
  68. fusion_bench/utils/lazy_state_dict.py +117 -19
  69. fusion_bench/utils/modelscope.py +3 -3
  70. fusion_bench/utils/packages.py +3 -3
  71. fusion_bench/utils/parameters.py +0 -2
  72. fusion_bench/utils/path.py +56 -0
  73. fusion_bench/utils/pylogger.py +1 -1
  74. fusion_bench/utils/timer.py +92 -10
  75. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
  76. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
  77. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  78. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  79. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  80. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  81. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  82. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  83. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
  84. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  85. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
  86. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
  87. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
  88. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
  89. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
@@ -16,10 +16,11 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
16
16
 
17
17
  from fusion_bench import BaseAlgorithm, BaseModelPool
18
18
  from fusion_bench.compat.modelpool import to_modelpool
19
+ from fusion_bench.constants import RuntimeConstants
19
20
  from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
20
21
  from fusion_bench.modelpool import CausalLMPool
21
22
  from fusion_bench.models.hf_utils import (
22
- generate_complete_readme,
23
+ create_default_model_card,
23
24
  save_pretrained_with_remote_code,
24
25
  )
25
26
  from fusion_bench.models.modeling_smile_qwen2 import (
@@ -41,7 +42,10 @@ log = logging.getLogger(__name__)
41
42
 
42
43
 
43
44
  @auto_register_config
44
- class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
45
+ class SmileQwen2UpscalingAlgorithm(
46
+ SimpleProfilerMixin,
47
+ BaseAlgorithm,
48
+ ):
45
49
  R"""
46
50
  SmileQwen2UpscalingAlgorithm is a model fusion algorithm designed to upscale
47
51
  a pretrained Qwen2 model using a set of fine-tuned expert models. The algorithm
@@ -62,7 +66,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
62
66
  self,
63
67
  device,
64
68
  accelerator,
65
- model_path,
69
+ model_save_path,
66
70
  model_dtype,
67
71
  num_experts_per_tok,
68
72
  rank_of_router,
@@ -71,6 +75,11 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
71
75
  **kwargs,
72
76
  ):
73
77
  super().__init__(**kwargs)
78
+ if not torch.cuda.is_available():
79
+ if "cuda" in self.device:
80
+ self.device = "cpu"
81
+ if "cuda" in self.accelerator:
82
+ self.accelerator = "cpu"
74
83
 
75
84
  @torch.no_grad()
76
85
  def run(self, modelpool) -> SmileQwen2ForCausalLM:
@@ -86,13 +95,6 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
86
95
  self.modelpool = modelpool = to_modelpool(modelpool)
87
96
  config = self.config
88
97
 
89
- # load model from path if provided and return directly
90
- if config.model_path is not None and os.path.exists(config.model_path):
91
- log.info(f"Loading model from {config.model_path}")
92
- model = AutoModelForCausalLM.from_pretrained(config.model_path)
93
- print_parameters(model)
94
- return model
95
-
96
98
  with self.profile("load pretrained model"):
97
99
  pretrained_model = modelpool.load_pretrained_model()
98
100
  with self.profile("load fine-tuned model"):
@@ -100,7 +102,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
100
102
  m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
101
103
  ]
102
104
 
103
- if config.device == "cuda" and torch.cuda.is_available():
105
+ if self.device == "cuda" and torch.cuda.is_available():
104
106
  pretrained_model = pretrained_model.cuda()
105
107
  print("parameter count of pretrained model:")
106
108
  print_parameters(pretrained_model)
@@ -114,17 +116,17 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
114
116
  print_parameters(model)
115
117
  print(model)
116
118
 
117
- if config.model_dtype is not None:
118
- model.to(dtype=parse_dtype(config.model_dtype))
119
+ if self.model_dtype is not None:
120
+ model.to(dtype=parse_dtype(self.model_dtype))
119
121
 
120
- if config.model_path is not None:
121
- if os.path.dirname(config.model_path):
122
- os.makedirs(os.path.dirname(config.model_path), exist_ok=True)
123
- log.info(f"Saving model to {config.model_path}")
122
+ if self.model_save_path is not None:
123
+ if os.path.dirname(self.model_save_path):
124
+ os.makedirs(os.path.dirname(self.model_save_path), exist_ok=True)
125
+ log.info(f"Saving model to {self.model_save_path}")
124
126
  tokenizer = self.modelpool.load_tokenizer()
125
- tokenizer.save_pretrained(config.model_path)
127
+ tokenizer.save_pretrained(self.model_save_path)
126
128
  if not self.save_with_remote_code:
127
- model.save_pretrained(config.model_path)
129
+ model.save_pretrained(self.model_save_path)
128
130
  else:
129
131
  save_pretrained_with_remote_code(
130
132
  model,
@@ -133,17 +135,18 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
133
135
  "AutoModel": SmileQwen2Model,
134
136
  "AutoModelForCausalLM": SmileQwen2ForCausalLM,
135
137
  },
136
- save_directory=config.model_path,
138
+ save_directory=self.model_save_path,
137
139
  )
138
140
 
139
141
  # save readme
140
- complete_readme = generate_complete_readme(
141
- algorithm=self,
142
- modelpool=modelpool,
142
+ model_card_str = create_default_model_card(
143
143
  models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
144
+ description="Merged Qwen model using SMILE Upscaling",
145
+ algorithm_config=self.config,
146
+ modelpool_config=modelpool.config,
144
147
  )
145
- with open(os.path.join(config.model_path, "README.md"), "w") as f:
146
- f.write(complete_readme)
148
+ with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
149
+ f.write(model_card_str)
147
150
 
148
151
  return model
149
152
 
@@ -174,9 +177,9 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
174
177
  )
175
178
  base_config = AutoConfig.from_pretrained(pretrained_path)
176
179
  model_config = SmileQwen2Config(
177
- num_experts_per_tok=config.num_experts_per_tok,
178
- rank_of_router=config.rank_of_router,
179
- rank_of_expert=config.rank_of_expert,
180
+ num_experts_per_tok=self.num_experts_per_tok,
181
+ rank_of_router=self.rank_of_router,
182
+ rank_of_expert=self.rank_of_expert,
180
183
  num_local_experts=len(finetuned_models),
181
184
  **base_config.to_dict(),
182
185
  )
@@ -186,7 +189,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
186
189
 
187
190
  # copy pretrained model weights
188
191
  state_dict = model.state_dict()
189
- pretrained_state_dict = dict(pretrained_model.state_dict())
192
+ pretrained_state_dict = pretrained_model.state_dict()
190
193
  for key in list(pretrained_state_dict.keys()):
191
194
  if key not in state_dict:
192
195
  pretrained_state_dict.pop(key)
@@ -198,6 +201,12 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
198
201
  "Upscaling Modules (layer)",
199
202
  dynamic_ncols=True,
200
203
  ):
204
+ if RuntimeConstants.debug and layer_idx > 0:
205
+ log.info(
206
+ "Debug mode enabled: processing only the first layer, skipping remaining layers"
207
+ )
208
+ break
209
+
201
210
  pretrained_layer: Qwen2DecoderLayer = pretrained_model.model.layers[
202
211
  layer_idx
203
212
  ]
@@ -213,7 +222,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
213
222
  base=getattr(pretrained_layer.self_attn, n),
214
223
  experts=[getattr(m.self_attn, n) for m in finetuned_layers],
215
224
  target=getattr(target_layer.self_attn, n),
216
- accelerator=config.accelerator,
225
+ accelerator=self.accelerator,
217
226
  )
218
227
  except ExpertNotTrainedError:
219
228
  setattr(
@@ -228,7 +237,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
228
237
  base=getattr(pretrained_layer.mlp, n),
229
238
  experts=[getattr(m.mlp, n) for m in finetuned_layers],
230
239
  target=getattr(target_layer.mlp, n),
231
- accelerator=config.accelerator,
240
+ accelerator=self.accelerator,
232
241
  )
233
242
  except ExpertNotTrainedError:
234
243
  setattr(
@@ -20,8 +20,8 @@ from fusion_bench.models.smile_moe.linear_from_module import (
20
20
  SmileMoELinear,
21
21
  )
22
22
  from fusion_bench.models.utils import get_attr, set_attr
23
- from fusion_bench.utils.parameters import print_parameters
24
23
  from fusion_bench.utils.devices import get_device
24
+ from fusion_bench.utils.parameters import print_parameters
25
25
 
26
26
  log = logging.getLogger(__name__)
27
27
 
@@ -1,2 +1,3 @@
1
1
  # flake8: noqa F401
2
2
  from .clip_we_moe import CLIPWeightEnsemblingMoEAlgorithm
3
+ from .flan_t5_we_moe import FlanT5WeightEnsemblingMoEAlgorithm
import torch
from torch import Tensor


def entropy_loss(logits: Tensor, eps: float = 1e-8) -> Tensor:
    """Compute the mean Shannon entropy of a batch of logits.

    Used as a test-time adaptation objective: minimizing this loss
    encourages confident (low-entropy) predictions.

    Args:
        logits (Tensor): A 2-D tensor of shape ``(batch_size, num_classes)``.
        eps (float): A small value added inside the log to avoid ``log(0)``.
            Default is 1e-8.

    Returns:
        Tensor: A scalar tensor holding the entropy averaged over the batch.

    Raises:
        ValueError: If ``logits`` is not 2-dimensional.
    """
    # Validate with an explicit exception rather than `assert`, which is
    # silently stripped when Python runs with the -O flag.
    if logits.dim() != 2:
        raise ValueError(
            f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}"
        )

    # Softmax probabilities over the class dimension.
    probs = torch.softmax(logits, dim=-1)

    # H(p) = -sum_i p_i * log(p_i + eps), averaged over the batch.
    return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
import functools
import logging
import os
from copy import deepcopy
from typing import Any, Dict, List, Mapping, Optional, Union, cast  # noqa: F401

import lightning
import lightning as L
import lightning.fabric.wrappers
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
from transformers import T5ForConditionalGeneration
from transformers.data import default_data_collator

from fusion_bench.method import BaseAlgorithm
from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge
from fusion_bench.mixins import (
    LightningFabricMixin,
    SimpleProfilerMixin,
    auto_register_config,
)
from fusion_bench.modelpool import Seq2SeqLMPool
from fusion_bench.models.we_moe import WeightEnsemblingMoE
from fusion_bench.utils import print_parameters, timeit_context
from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
from fusion_bench.utils.instantiate_utils import instantiate

# NOTE(review): `print_parameters` is already imported from `fusion_bench.utils`
# above; this second import is redundant and could be dropped.
from fusion_bench.utils.parameters import print_parameters

from .entropy_loss import entropy_loss
from .utils import get_memory_usage

log = logging.getLogger(__name__)


@auto_register_config
class FlanT5WeightEnsemblingMoEAlgorithm(
    LightningFabricMixin,
    SimpleProfilerMixin,
    BaseAlgorithm,
):
    """
    Weight-Ensembling Mixture-of-Experts (WEMoE) fusion for Flan-T5 models.

    The algorithm first merges the pretrained model with the expert models via
    task arithmetic, then replaces the MLP (``DenseReluDense``) sub-layers of
    both encoder and decoder with :class:`WeightEnsemblingMoE` modules whose
    routers are trained by entropy minimization on unlabeled test data
    (test-time adaptation).

    Attributes:
        modelpool (Seq2SeqLMPool): The model pool containing the Flan-T5 models.
    """

    # Populated in `run`; declared here for type checkers.
    modelpool: Seq2SeqLMPool = None

    def __init__(
        self,
        checkpoint: bool = False,
        save_checkpoint: bool = False,
        router_hidden_layers: int = 2,
        init_lambda: float = 0.3,
        batch_reduce: bool = True,
        lr: float = 1e-4,
        optimizer: str = "adam",
        devices: int = 1,
        batch_size: int = 16,
        num_workers: int = 0,
        max_steps: int = 1000,
        use_grad_accumulate: bool = True,
        fast_dev_run: bool = False,
        **kwargs,
    ):
        """
        Initialize the algorithm.

        All keyword parameters are registered as config attributes on ``self``
        by the ``@auto_register_config`` decorator (no explicit assignments
        appear in this body).

        Args:
            checkpoint: Path of a checkpoint to load instead of running
                test-time adaptation, or ``False`` to adapt from scratch.
            save_checkpoint: Path to save the adapted model to, or ``False``
                to skip saving.
                NOTE(review): this name shadows any ``save_checkpoint`` method
                inherited from a mixin, yet ``run`` later calls
                ``self.save_checkpoint(...)`` as a method — verify.
            router_hidden_layers: Number of hidden layers in each MoE router.
            init_lambda: Scaling factor for the task-arithmetic base merge and
                router initialization.
            batch_reduce: Whether routers reduce routing weights over the batch.
            lr: Learning rate for the router optimizer.
            optimizer: Optimizer name; only ``"adam"`` is supported.
            devices: Number of devices to use.
            batch_size: Test-loader batch size for adaptation.
            num_workers: Test-loader worker count.
            max_steps: Number of test-time adaptation steps.
            use_grad_accumulate: Whether to accumulate gradients across tasks.
            fast_dev_run: Debug flag for a quick run.
        """
        super().__init__(**kwargs)

    def construct_moe_model(self):
        """
        Construct the Mixture-of-Experts model from the models in the pool.

        The pretrained model is merged with the experts via task arithmetic,
        then every encoder/decoder MLP sub-layer is replaced in place by a
        :class:`WeightEnsemblingMoE` over the corresponding expert MLPs.

        Returns:
            WeightEnsemblingMoE: The constructed MoE model (a T5 model whose
            MLP sub-layers are MoE modules).
        """
        base_model = self.modelpool.load_model("_pretrained_")
        expert_models = [
            self.modelpool.load_model(name) for name in self.modelpool.model_names
        ]

        # Merge the models using task arithmetic.
        moe_model = task_arithmetic_merge(
            # This function modifies the model in place, so we need to pass a deepcopy.
            deepcopy(base_model),
            expert_models,
            scaling_factor=self.init_lambda,
        ).requires_grad_(False)

        print(base_model)

        # Up-scale MLP modules.
        # NOTE(review): the layer count and sub-layer indices are hard-coded
        # for flan-t5-base (12 layers; encoder MLP at layer[1], decoder MLP at
        # layer[2] because the decoder has an extra cross-attention sub-layer)
        # — confirm before using with other T5 sizes.
        num_layer = 12
        encoder_mlp_index = 1
        base_encoder = base_model.encoder
        moe_encoder = moe_model.encoder
        expert_encoders = [m.encoder for m in expert_models]

        for layer_idx in range(num_layer):
            base_mlp = (
                base_encoder.block[layer_idx].layer[encoder_mlp_index].DenseReluDense
            )
            expert_mlps = [
                e.block[layer_idx].layer[encoder_mlp_index].DenseReluDense
                for e in expert_encoders
            ]

            moe_encoder.block[layer_idx].layer[encoder_mlp_index].DenseReluDense = (
                WeightEnsemblingMoE(
                    hidden_size=base_encoder.config.hidden_size,
                    base_model=base_mlp,
                    expert_models=expert_mlps,
                    init_lambda=self.init_lambda,
                    batch_first=True,
                    router_hidden_layers=self.router_hidden_layers,
                    batch_reduce=self.batch_reduce,
                )
            )

        decoder_mlp_index = 2
        base_decoder = base_model.decoder
        moe_decoder = moe_model.decoder
        expert_decoders = [m.decoder for m in expert_models]

        for layer_idx in range(num_layer):
            base_mlp = (
                base_decoder.block[layer_idx].layer[decoder_mlp_index].DenseReluDense
            )
            expert_mlps = [
                e.block[layer_idx].layer[decoder_mlp_index].DenseReluDense
                for e in expert_decoders
            ]

            moe_decoder.block[layer_idx].layer[decoder_mlp_index].DenseReluDense = (
                WeightEnsemblingMoE(
                    hidden_size=base_decoder.config.hidden_size,
                    base_model=base_mlp,
                    expert_models=expert_mlps,
                    init_lambda=self.init_lambda,
                    batch_first=True,
                    router_hidden_layers=self.router_hidden_layers,
                    batch_reduce=self.batch_reduce,
                )
            )

        print(moe_model)
        return moe_model

    # NOTE(review): caching an instance method keeps `self` alive for the
    # lifetime of the cache (ruff B019); a per-instance dict keyed by task
    # would avoid that. Despite the annotation, this returns an *iterator*
    # over an infinite loader, not a DataLoader.
    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Build (once per task) an infinite iterator over the shuffled test
        dataset for test-time adaptation. Labels are not needed.

        Args:
            task (str): The name of the task.

        Returns:
            An iterator yielding batches from an infinite, shuffled loader.
        """
        dataset = self.modelpool.load_test_dataset(task)
        log.info("get_shuffled_test_loader_iter")
        loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=default_data_collator,
        )
        # Let Fabric move batches to the right device when it is configured.
        if self.fabric is not None:
            loader = self.fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    def compute_logits(
        self,
        module: Union[T5ForConditionalGeneration],
        batch,
        task: str,
    ) -> Tensor:
        """
        Compute first-decoding-step logits for a batch of tokenized inputs.

        Args:
            module: The (possibly Fabric-wrapped) seq2seq model.
            batch: A mapping with ``input_ids`` and ``attention_mask`` tensors.
            task (str): The name of the task (currently unused here).

        Returns:
            Tensor: Logits of shape ``(batch_size, vocab_size)`` for the first
            decoder position.
        """
        input_ids: Tensor = batch["input_ids"]
        attention_mask: Tensor = batch["attention_mask"]

        # Trim trailing all-padding columns to shorten the sequence.
        while attention_mask[:, -1].eq(0).all():
            input_ids = input_ids[:, :-1]
            attention_mask = attention_mask[:, :-1]

        # Decode a single step: feed a length-1 decoder input (token id 1,
        # presumably the decoder start token for T5 — verify against tokenizer).
        outputs = module(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=torch.ones(
                input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
            ),
        )
        logits = outputs.logits[:, 0, :]
        return logits

    def test_time_adaptation(self, module):
        """
        Train the MoE routers by entropy minimization on unlabeled test data.

        Args:
            module (WeightEnsemblingMoE): The MoE module to adapt.

        Returns:
            WeightEnsemblingMoE: The adapted MoE module (Fabric-wrapped).
        """
        self.on_test_time_adaptation_start()

        # Configure optimizer over trainable (router) parameters only.
        if self.optimizer == "adam":
            print([name for name, p in module.named_parameters() if p.requires_grad])
            optimizer = torch.optim.Adam(
                [p for p in module.parameters() if p.requires_grad], lr=self.lr
            )
        else:
            raise ValueError(f"Unsupported optimizer: {self.optimizer}")

        module, optimizer = self.fabric.setup(module, optimizer)

        module.train()
        for step_idx in (
            pbar := tqdm(
                range(self.max_steps if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "WEMoE Test-time adaptation",
                dynamic_ncols=True,
            )
        ):
            total_loss = 0
            # One entropy-loss backward pass per task; gradients accumulate
            # across tasks before the single optimizer step below.
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, batch, task)
                    logits = logits.mean(dim=0, keepdim=True)
                    loss = entropy_loss(logits)
                    total_loss += loss
                with self.profile("backward pass"):
                    self.fabric.backward(loss, retain_graph=True)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()

            metrics = {
                "train/loss": total_loss.item(),
            }
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

        log.info(get_memory_usage(f"after adamerging, the memory usage of GPU is:"))
        self.print_profile_summary()
        return module

    def on_test_time_adaptation_start(self):
        """
        Hook called before test-time adaptation starts, e.g. for setting up
        task-specific heads. No-op by default.
        """
        pass

    def run(self, modelpool: Seq2SeqLMPool, **kwargs):
        """
        Fuse the pooled models into a Weight-Ensembling MoE model.

        Constructs the MoE model, then either loads a checkpoint (skipping
        adaptation) or runs test-time adaptation and optionally saves a
        checkpoint. Finally unwraps the Fabric wrapper and enables
        sample-wise routing.

        Args:
            modelpool (Seq2SeqLMPool): The pool of models to be fused.

        Returns:
            WeightEnsemblingMoE: The fused MoE model.
        """
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool

        with timeit_context("upscaling models to a weight-ensembling MoE model"):
            moe_model = self.construct_moe_model()
            print_parameters(moe_model)

        if self.checkpoint != False:
            log.info(
                f"load checkpoint from {self.checkpoint}, test-time adaptation will be skipped."
            )
            self.load_checkpoint(moe_model, self.checkpoint)
        else:
            with self.profile("test-time adaptation"):
                moe_model = self.test_time_adaptation(moe_model)
            if self.save_checkpoint != False:
                log.info(f"save checkpoint to {self.save_checkpoint}")
                # NOTE(review): `self.save_checkpoint` is also the config value
                # registered in __init__; this call assumes a method of the
                # same name from a mixin — confirm it is not shadowed.
                self.save_checkpoint(moe_model, self.save_checkpoint)

        if lightning.fabric.wrappers.is_wrapped(moe_model):
            moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)

        # Enable sample-wise adaptation (disable batch-level reduction).
        moe_model.batch_reduce = False
        self.print_profile_summary()
        return moe_model
import torch


def get_memory_usage(desc: str) -> str:
    """Return a human-readable summary of current GPU memory usage.

    Args:
        desc (str): A description line prepended to the report.

    Returns:
        str: ``desc`` followed by the allocated and cached (reserved)
        CUDA memory in MB. Both values are 0.00 when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**2  # bytes -> MB
        cached = torch.cuda.memory_reserved() / 1024**2  # bytes -> MB
    else:
        # No CUDA device: report zeros instead of querying the CUDA runtime.
        allocated = cached = 0.0
    return (
        f"{desc}\nAllocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
    )
@@ -7,11 +7,11 @@ from transformers import PreTrainedModel
7
7
  from typing_extensions import override
8
8
 
9
9
  from fusion_bench.method import BaseAlgorithm
10
+ from fusion_bench.mixins import auto_register_config
10
11
  from fusion_bench.modelpool import CausalLMPool
11
12
  from fusion_bench.utils import timeit_context
12
13
  from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
13
14
  from fusion_bench.utils.type import StateDictType
14
- from fusion_bench.mixins import auto_register_config
15
15
 
16
16
  log = logging.getLogger(__name__)
17
17