fusion-bench 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -2
  3. fusion_bench/compat/modelpool/__init__.py +3 -2
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/__init__.py +6 -1
  6. fusion_bench/dataset/arc_agi/arc.py +26 -7
  7. fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
  8. fusion_bench/dataset/arc_agi/np_cache.py +0 -1
  9. fusion_bench/dataset/arc_agi/preprocess.py +51 -9
  10. fusion_bench/dataset/llama/__init__.py +1 -0
  11. fusion_bench/dataset/llama/alpaca.py +93 -3
  12. fusion_bench/dataset/llama/collate.py +72 -5
  13. fusion_bench/dataset/llama/metamathqa.py +50 -0
  14. fusion_bench/dataset/llama/preference_700k.py +70 -0
  15. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  16. fusion_bench/dataset/llama/ultrachat.py +58 -0
  17. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  18. fusion_bench/method/__init__.py +4 -1
  19. fusion_bench/method/adamerging/__init__.py +1 -1
  20. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  21. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  22. fusion_bench/method/linear/expo.py +39 -0
  23. fusion_bench/method/lm_finetune/__init__.py +1 -0
  24. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  25. fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
  26. fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
  27. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  28. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  29. fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
  30. fusion_bench/method/rankone_moe/__init__.py +3 -0
  31. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  32. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  33. fusion_bench/method/simple_average.py +1 -1
  34. fusion_bench/method/surgery/__init__.py +3 -0
  35. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  36. fusion_bench/mixins/__init__.py +2 -0
  37. fusion_bench/mixins/clip_classification.py +60 -12
  38. fusion_bench/mixins/fabric_training.py +320 -0
  39. fusion_bench/mixins/lightning_fabric.py +11 -2
  40. fusion_bench/modelpool/__init__.py +2 -0
  41. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  42. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  43. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  44. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  45. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  46. fusion_bench/models/chat_templates/__init__.py +1 -0
  47. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  48. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  49. fusion_bench/models/hf_clip.py +50 -9
  50. fusion_bench/models/rankone_moe.py +410 -0
  51. fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
  52. fusion_bench/models/utils.py +8 -0
  53. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  54. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  55. fusion_bench/optim/__init__.py +2 -0
  56. fusion_bench/optim/exception.py +47 -0
  57. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  58. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  59. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  60. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  61. fusion_bench/optim/mezo.py +0 -2
  62. fusion_bench/programs/fabric_fusion_program.py +5 -1
  63. fusion_bench/taskpool/__init__.py +10 -2
  64. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  65. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  66. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  67. fusion_bench/taskpool/llama/reward_model.py +157 -0
  68. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  69. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
  70. fusion_bench/utils/hydra_utils.py +22 -0
  71. fusion_bench/utils/plot/__init__.py +0 -0
  72. fusion_bench/utils/plot/token.py +52 -0
  73. fusion_bench/utils/plot/token_notebook.py +127 -0
  74. fusion_bench/utils/type.py +5 -3
  75. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
  76. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
  77. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  78. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  79. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  80. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  81. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  84. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  85. fusion_bench_config/llama_full_finetune.yaml +19 -0
  86. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  87. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
  88. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
  89. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  90. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
  91. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  92. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  93. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  94. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  95. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  96. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  97. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  98. fusion_bench_config/nyuv2_config.yaml +5 -1
  99. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  100. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  101. fusion_bench_config/llama_weighted_average.yaml +0 -26
  102. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
  103. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
  104. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
  105. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
fusion_bench/method/lm_finetune/peftfinetune_sft.py
@@ -1,3 +1,4 @@
+import functools
 import itertools
 import logging
 import os
@@ -10,16 +11,16 @@ import peft
 import torch
 from lightning.fabric.strategies.fsdp import FSDPStrategy
 from lightning.fabric.utilities import rank_zero_only
-from omegaconf import DictConfig
-from peft import PeftModel, get_peft_config, get_peft_model
+from omegaconf import DictConfig, OmegaConf
+from peft import LoraConfig, PeftModel, get_peft_config, get_peft_model
 from torch import nn
-from torch.utils.data import DataLoader, ConcatDataset
+from torch.utils.data import ConcatDataset, DataLoader
 from tqdm.auto import tqdm
 from typing_extensions import TYPE_CHECKING, override

 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.dataset.llama.collate import padded_collate_sft
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import FabricTrainingMixin
 from fusion_bench.modelpool import CausalLMPool
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.dtype import get_dtype
@@ -35,7 +36,7 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)


-class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
+class PeftFinetuneSFT(BaseAlgorithm, FabricTrainingMixin):

     model: Union[
         nn.Module, "_FabricModule", "LlamaForCausalLM", PeftModel, peft.LoraModel
@@ -65,7 +66,9 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         gradient_clip_algorithm: Literal["value", "norm"] = "norm",
         save_optimizer_state: bool = False,
         save_full_model: bool = False,
+        save_ckpt_type: Literal["lightning", "peft"] = "peft",
         ckpt_path: Optional[str] = None,
+        max_length: int = 6144,
         **kwargs,
     ):
         """
@@ -90,6 +93,7 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
             save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
             save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
+            save_ckpt_type(str): Type of checkpoint to save. Available options: 'lightning', 'peft'. If set to 'lightning', the model will be saved using the Lightning checkpointing mechanism. If set to 'peft', the model will be saved using the PEFT checkpointing mechanism.
             ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
         """
         self._optimizer = optimizer
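For readers upgrading: the two `save_ckpt_type` formats are restored differently. Below is a minimal sketch (the helper names and paths are hypothetical, not from this diff) of how each would typically be loaded back, assuming the standard Lightning Fabric and PEFT APIs:

import torch.nn as nn
from lightning.fabric import Fabric
from peft import PeftModel
from transformers import AutoModelForCausalLM

def load_lightning_ckpt(fabric: Fabric, model: nn.Module, path: str) -> nn.Module:
    # 'lightning' checkpoints are a single state file restored in place via fabric.load
    state = {"model": model}
    fabric.load(path, state)
    return model

def load_peft_ckpt(base_model_name: str, adapter_dir: str) -> PeftModel:
    # 'peft' checkpoints are an adapter directory restored via PeftModel.from_pretrained
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
    return PeftModel.from_pretrained(base_model, adapter_dir)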
@@ -110,23 +114,31 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         self.gradient_clip_algorithm = gradient_clip_algorithm
         self.save_optimizer_state = save_optimizer_state
         self.save_full_model = save_full_model
+        self.save_ckpt_type = save_ckpt_type
         self.ckpt_path = ckpt_path
+        self.max_length = max_length
         super().__init__(**kwargs)

     def run(self, modelpool: CausalLMPool):
         self.modelpool = modelpool
         self.setup()
-        self.train()
+        self.train(self.model, self.optimizer, self.lr_scheduler)

         if self.merge_and_unload:
             self.model = self.model.merge_and_unload()
         return self.model

     def setup_model(self):
+        # https://github.com/Lightning-AI/litgpt/blob/main/litgpt/finetune/lora.py
+        self.tokenizer = self.modelpool.load_tokenizer()
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
         model = self.modelpool.load_pretrained_model()

         # get the PEFT model
-        peft_config = get_peft_config(self._peft_config)
+        peft_config = instantiate(self._peft_config, _convert_="all")
+        peft_config.save_pretrained(os.path.join(self.log_dir, "peft_config"))
         peft_model = get_peft_model(model, peft_config, self.adapter_name)
         peft_model.print_trainable_parameters()

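The switch from `get_peft_config` to `instantiate(self._peft_config, _convert_="all")`, together with the new `LoraConfig` import, implies the PEFT config is now declared as a Hydra-style `_target_` node rather than a plain dict. A minimal sketch of that pattern, using `hydra.utils.instantiate` directly; the YAML keys and hyperparameters below are illustrative, not fusion-bench's actual config:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# a config node whose _target_ points at a peft config class
peft_config_node = OmegaConf.create(
    {
        "_target_": "peft.LoraConfig",
        "task_type": "CAUSAL_LM",
        "r": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.05,
        "target_modules": ["q_proj", "v_proj"],
    }
)
# _convert_="all" converts OmegaConf containers (e.g. target_modules)
# into plain Python lists/dicts before calling peft.LoraConfig(...)
lora_config = instantiate(peft_config_node, _convert_="all")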
@@ -139,21 +151,16 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             self.model.gradient_checkpointing_enable(
                 gradient_checkpointing_kwargs={"use_reentrant": True}
             )
+            self.use_cache = False
+        else:
+            self.use_cache = True
+
         self.model_dtype = get_dtype(self.model)
+        self.model = self.model.to(dtype=self.model_dtype)

     def configure_optimizer(self):
         # compute expected total steps
-        self.expected_total_steps = []
-        if self.max_steps > 0:
-            self.expected_total_steps.append(self.max_steps)
-        if self.max_steps_per_epoch > 0 and self.max_epochs > 0:
-            self.expected_total_steps.append(self.max_steps_per_epoch * self.max_epochs)
-        if self.max_epochs > 0:
-            self.expected_total_steps.append(
-                len(self.train_dataloader) * self.max_epochs
-            )
-        self.expected_total_steps = min(self.expected_total_steps)
-        log.info(f"Expected total steps: {self.expected_total_steps}")
+        self.compute_expected_total_steps(self.train_dataloader)

         optimizer = instantiate(self._optimizer, self.model.parameters())
         if self._lr_scheduler is not None:
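`compute_expected_total_steps` (from the new `FabricTrainingMixin`) replaces the inline computation deleted above. For reference, a standalone sketch of the same bound logic the removed block implemented; the mixin presumably centralizes something equivalent (the function name and signature here are illustrative):

from torch.utils.data import DataLoader

def expected_total_steps(
    train_dataloader: DataLoader,
    max_steps: int,
    max_steps_per_epoch: int,
    max_epochs: int,
) -> int:
    bounds = []
    if max_steps > 0:
        bounds.append(max_steps)
    if max_steps_per_epoch > 0 and max_epochs > 0:
        bounds.append(max_steps_per_epoch * max_epochs)
    if max_epochs > 0:
        bounds.append(len(train_dataloader) * max_epochs)
    # the tightest configured limit wins (assumes at least one limit is set,
    # as the removed code did)
    return min(bounds)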
@@ -192,7 +199,9 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             train_dataset,
             **self.dataloader_kwargs,
             shuffle=True,
-            collate_fn=padded_collate_sft,
+            collate_fn=functools.partial(
+                padded_collate_sft, pad_token_id=self.tokenizer.pad_token_id
+            ),
         )
         self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)

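Passing `pad_token_id` through `functools.partial` makes the collate function respect the tokenizer's padding token (set to `eos_token_id` in `setup_model` when missing). For context, a sketch of what such a padded SFT collate does; this is illustrative, not fusion-bench's actual `padded_collate_sft`:

from typing import Dict, List

import torch
from torch.nn.utils.rnn import pad_sequence

def padded_collate(
    samples: List[Dict[str, List[int]]], pad_token_id: int
) -> Dict[str, torch.Tensor]:
    # right-pad input_ids with pad_token_id, attention_mask with 0, and
    # labels with -100 so padding is ignored by the cross-entropy loss
    input_ids = pad_sequence(
        [torch.tensor(s["input_ids"]) for s in samples],
        batch_first=True,
        padding_value=pad_token_id,
    )
    attention_mask = pad_sequence(
        [torch.tensor(s["attention_mask"]) for s in samples],
        batch_first=True,
        padding_value=0,
    )
    labels = pad_sequence(
        [torch.tensor(s["labels"]) for s in samples],
        batch_first=True,
        padding_value=-100,  # HF causal LM loss ignore index
    )
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}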
@@ -205,28 +214,19 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         optimizer = self.configure_optimizer()
         optimizer, lr_scheduler = optimizer["optimizer"], optimizer["lr_scheduler"]

-        self.model, self.optimizer = fabric.setup(self.model, optimizer)
+        self.model = self.fabric.setup_module(self.model)
+        self.optimizer = self.fabric.setup_optimizers(optimizer)
         self.lr_scheduler = lr_scheduler

-    def _clip_gradients_if_needed(self):
+    @override
+    def train_epoch(self, *args, **kwargs):
         fabric = self.fabric

-        if self.gradient_clip_val is not None:
-            if self.gradient_clip_algorithm == "value":
-                fabric.clip_gradients(self.model, clip_val=self.gradient_clip_val)
-            elif self.gradient_clip_algorithm == "norm":
-                fabric.clip_gradients(self.model, max_norm=self.gradient_clip_val)
-            else:
-                raise ValueError(
-                    f"Unknown gradient clip algorithm: {self.gradient_clip_algorithm}. Available options: 'value', 'norm'"
-                )
-
-    def train_epoch(self):
-        fabric = self.fabric
+        accumulated_loss = 0
         for step_idx, batch in enumerate(
             pbar := tqdm(
                 self.train_dataloader,
-                desc="Training Steps",
+                desc="Training Batches",
                 dynamic_ncols=True,
                 leave=False,
                 disable=not fabric.is_global_zero,
@@ -234,24 +234,30 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         ):
             is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0

+            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
+                log.warning(
+                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
+                )
+                batch["input_ids"] = batch["input_ids"][:, : self.max_length]
+                batch["attention_mask"] = batch["attention_mask"][:, : self.max_length]
+                batch["labels"] = batch["labels"][:, : self.max_length]
+
            # disable gradient synchronization if accumulating gradients across steps for improved performance
             with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                 # use_cache=True is not compatible with gradient checkpointing, so we disable it here
-                output = self.model(**batch, use_cache=False)
-                loss = output["loss"]
+                output = self.model(
+                    input_ids=batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    labels=batch["labels"],
+                    use_cache=self.use_cache,
+                )
+                loss = output["loss"] / self.accumulate_grad_batches

                 fabric.backward(loss)
-
-            metrics = {
-                "train/loss": loss.item(),
-                "train/epoch_idx": self.epoch_idx,
-                "train/lr": self.optimizer.param_groups[0]["lr"],
-            }
-            fabric.log_dict(metrics, step=self.global_step_idx)
-            pbar.set_postfix(metrics)
+            accumulated_loss += loss.item()

             if not is_accumulating:
-                self._clip_gradients_if_needed()
+                self.clip_gradients_if_needed(self.model, self.optimizer)

                 # run lr_scheduler at the end of the step if interval is set to "step"
                 if (
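The rewritten loop divides each micro-batch loss by `accumulate_grad_batches` so the gradients summed across micro-batches match a single large-batch step, skips the costly gradient all-reduce on non-boundary steps via `no_backward_sync`, and logs the accumulated loss once per optimizer step rather than per batch. A self-contained toy sketch of this pattern with Lightning Fabric (a stand-in linear model, not the SFT trainer):

import torch
from lightning.fabric import Fabric
from torch.utils.data import DataLoader, TensorDataset

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

# toy regression model stands in for the causal LM
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
model, optimizer = fabric.setup(model, optimizer)

dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
loader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=4))

accumulate_grad_batches = 4
for step_idx, (x, y) in enumerate(loader):
    is_accumulating = (step_idx + 1) % accumulate_grad_batches != 0
    # skip gradient sync on non-boundary micro-batches (matters under DDP/FSDP)
    with fabric.no_backward_sync(model, enabled=is_accumulating):
        # scale so the summed gradient equals one large-batch gradient
        loss = torch.nn.functional.mse_loss(model(x), y) / accumulate_grad_batches
        fabric.backward(loss)
    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()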
@@ -264,104 +270,30 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                 self.optimizer.step()
                 self.optimizer.zero_grad()

-                # save the model at the end of the step if interval is set to "step" and frequency is met
-                self._try_save_checkpoint(stage="end_of_step")
+                metrics = {
+                    "train/loss": accumulated_loss,
+                    "train/epoch_idx": self.epoch_idx,
+                    "train/lr": self.optimizer.param_groups[0]["lr"],
+                }
+                fabric.log_dict(metrics, step=self.global_step_idx)
+                pbar.set_postfix(metrics)

-            # break if max_steps_per_epoch is set, and exit epoch
-            if (
-                self.max_steps_per_epoch > 0
-                and step_idx + 1 >= self.max_steps_per_epoch
-            ):
-                break
-            # break if max_steps is set, and exit training
-            if self.max_steps > 0 and self.global_step_idx >= self.max_steps:
-                self.is_training = False
-                break
+                # save the model at the end of the step if interval is set to "step" and frequency is met
+                self.conditional_checkpoint_save(stage="end_of_step")

-            self.global_step_idx += 1
+                # break if max_steps_per_epoch is set, and exit epoch
+                if (
+                    self.max_steps_per_epoch > 0
+                    and step_idx + 1 >= self.max_steps_per_epoch
+                ):
+                    break
+                # break if max_steps is set, and exit training
+                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
+                    self.is_training = False
+                    break

-    def train(self):
-        fabric = self.fabric
-        self.is_training = True
-        self.global_step_idx = 0
-        self.model.train()
-        for epoch_idx in tqdm(
-            range(self.max_epochs) if self.max_epochs > 0 else itertools.count(0),
-            "Training Epoch",
-            dynamic_ncols=True,
-            leave=False,
-            disable=not fabric.is_global_zero,
-        ):
-            self.epoch_idx = epoch_idx
-            self.train_epoch()
-            # run lr_scheduler at the end of the epoch if interval is set to "epoch"
-            if (
-                self.lr_scheduler_interval == "epoch"
-                and (epoch_idx + 1) % self.lr_scheduler_frequency == 0
-            ):
-                self.lr_scheduler.step()
-
-            # save the model at the end of the epoch if interval is set to "epoch" and frequency is met
-            self._try_save_checkpoint(stage="end_of_epoch")
-
-            if not self.is_training:
-                break
-
-        # save the model at the end of training
-        self._try_save_checkpoint(stage="end_of_training")
-
-    def _try_save_checkpoint(
-        self, stage: Literal["end_of_step", "end_of_epoch", "end_of_training"]
-    ):
-        if stage == "end_of_step":
-            if (
-                self.checkpoint_save_interval == "step"
-                and (self.global_step_idx + 1) % self.checkpoint_save_frequency == 0
-            ):
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir, "checkpoints", f"step={self.global_step_idx}.ckpt"
-                    )
-                )
-        elif stage == "end_of_epoch":
-            if (
-                self.checkpoint_save_interval == "epoch"
-                and (self.epoch_idx + 1) % self.checkpoint_save_frequency == 0
-            ):
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir, "checkpoints", f"epoch={self.epoch_idx}.ckpt"
-                    )
-                )
-        elif stage == "end_of_training":
-            # if the checkpoint has not been saved yet, save it
-            if self.global_step_idx > self._latest_saved_checkpoint_global_step:
-                self.save_checkpoint(
-                    os.path.join(
-                        self.log_dir,
-                        "checkpoints",
-                        f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
-                    )
-                )
-            try:
-                os.symlink(
-                    os.path.join(
-                        self.log_dir,
-                        "checkpoints",
-                        "latest_model.ckpt",
-                    ),
-                    os.path.join(
-                        self.log_dir,
-                        "checkpoints",
-                        f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
-                    ),
-                )
-            except Exception as e:
-                pass
-        else:
-            raise ValueError(
-                f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
-            )
+                self.global_step_idx += 1
+                accumulated_loss = 0

     def save_checkpoint(
         self,
@@ -373,24 +305,37 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             return log.warning(f"Checkpoint already exists at {path}. Skipping save.")

         fabric = self.fabric
-        state = {"model": self.model}
-
-        # save the optimizer and lr_scheduler state if needed
-        if self.save_optimizer_state and save_optimizer_state is not False:
-            state.update(
-                {
-                    "optimizer": self.optimizer,
-                    "lr_scheduler": self.lr_scheduler,
-                    "global_step_idx": self.global_step_idx,
-                    "epoch_idx": self.epoch_idx,
-                }
+        if self.save_ckpt_type == "lightning":
+            state = {"model": self.model}
+
+            # save the optimizer and lr_scheduler state if needed
+            if self.save_optimizer_state and save_optimizer_state is not False:
+                state.update(
+                    {
+                        "optimizer": self.optimizer,
+                        "lr_scheduler": self.lr_scheduler,
+                        "global_step_idx": self.global_step_idx,
+                        "epoch_idx": self.epoch_idx,
+                    }
+                )
+            trainable_param_names = set(
+                name
+                for name, param in self.model.state_dict(keep_vars=True).items()
+                if param.requires_grad
+            )
+            filter = (
+                None
+                if self.save_full_model
+                else {"model": lambda k, p: k in trainable_param_names}
+            )
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+            fabric.save(path, state=state, filter=filter)
+        elif self.save_ckpt_type == "peft":
+            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)
+        else:
+            raise ValueError(
+                f"Unknown save_ckpt_type: {self.save_ckpt_type}. Available options: 'lightning', 'peft'"
             )
-
-        filter = (
-            None if self.save_full_model else {"model": lambda k, p: p.requires_grad}
-        )
-
-        fabric.save(path, state=state, filter=filter)

         self._latest_saved_checkpoint_global_step = self.global_step_idx
     def load_checkpoint(self, path: Union[str, Path]):
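Note the subtle fix in the 'lightning' branch above: the old filter tested `p.requires_grad` on `state_dict()` values, which are detached by default (so `requires_grad` is always `False`), while the new code precomputes the trainable names via `state_dict(keep_vars=True)`. A runnable sketch of that filter on a toy module:

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

model = torch.nn.Linear(4, 2)
model.bias.requires_grad_(False)  # pretend only the weight is trainable

# keep_vars=True preserves the live parameters, so requires_grad is reliable
trainable_param_names = {
    name
    for name, param in model.state_dict(keep_vars=True).items()
    if param.requires_grad
}
state = {"model": model}
# fabric.save keeps only the entries the filter callable accepts
fabric.save(
    "trainable_only.ckpt",
    state=state,
    filter={"model": lambda k, p: k in trainable_param_names},
)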
fusion_bench/method/pruning/llama_magnitude_prune.py
@@ -1,7 +1,7 @@
-from typing import Literal, Optional, Union
+from typing import Dict, Literal, Optional, Union

 import torch
-from torch import Dict, nn
+from torch import nn
 from tqdm.auto import tqdm
 from transformers import LlamaForCausalLM, LlamaModel

fusion_bench/method/pruning/llama_random_prune.py
@@ -1,7 +1,7 @@
-from typing import Literal, Optional, Union  # noqa: F401
+from typing import Dict, Literal, Optional, Union  # noqa: F401

 import torch
-from torch import Dict, nn
+from torch import nn
 from tqdm.auto import tqdm
 from transformers import LlamaForCausalLM, LlamaModel

fusion_bench/method/pruning/magnitude_diff_pruning.py
@@ -1,3 +1,4 @@
+import functools
 import logging
 import re
 from copy import deepcopy
@@ -10,7 +11,7 @@ from tqdm.auto import tqdm
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
-import functools
+
 from .prune_utils import unstructured_magnitude_prune_

 log = logging.getLogger(__name__)
fusion_bench/method/rankone_moe/__init__.py
@@ -0,0 +1,3 @@
+# flake8: noqa F401
+from .clip_rankone_moe import CLIPRankOneMoEAlgorithm
+from .rankone_moe import RankOneMoEAlgorithm
fusion_bench/method/rankone_moe/clip_rankone_moe.py
@@ -0,0 +1,160 @@
+import functools
+import logging
+import os
+from copy import deepcopy
+
+import torch
+from torch import Tensor
+from torch.utils.data import DataLoader
+from transformers.models.clip.modeling_clip import CLIPEncoder
+
+from fusion_bench.dataset import CLIPDataset
+from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge
+from fusion_bench.mixins import CLIPClassificationMixin
+from fusion_bench.modelpool import CLIPVisionModelPool
+from fusion_bench.models.rankone_moe import RankOneMoE
+from fusion_bench.utils.data import InfiniteDataLoader
+
+from .rankone_moe import RankOneMoEAlgorithm
+
+log = logging.getLogger(__name__)
+
+
+class CLIPRankOneMoEAlgorithm(
+    RankOneMoEAlgorithm,
+    CLIPClassificationMixin,
+):
+    """
+    CLIPRankOneMoEAlgorithm is a class that implements the RankOneMoEAlgorithm (https://github.com/EnnengYang/RankOne-MoE)
+    for CLIP models. It extends the RankOneMoEAlgorithm and CLIPClassificationMixin classes.
+
+    Attributes:
+        modelpool (CLIPVisionModelPool): The model pool containing the CLIP models.
+    """
+
+    modelpool: CLIPVisionModelPool = None
+
+    def load_checkpoint(self, model, checkpoint):
+        """
+        Load the checkpoint file.
+
+        Args:
+            model: The model to load the checkpoint into.
+            checkpoint: The path to the checkpoint file.
+        """
+        state = {"model": model}
+        self._fabric.load(checkpoint, state)
+
+    def save_checkpoint(self, model, checkpoint):
+        """
+        Save the checkpoint file.
+
+        Args:
+            model: The model to save the checkpoint from.
+            checkpoint: The path to the checkpoint file.
+        """
+        self._fabric.save(checkpoint, {"model": model})
+
+    def construct_moe_model(self) -> RankOneMoE:
+        """
+        Construct the RankOne-MoE model using the models in the model pool.
+
+        Returns:
+            RankOne-MoE: The constructed MoE model.
+        """
+        base_model = self.modelpool.load_model("_pretrained_")
+        expert_models = [
+            self.modelpool.load_model(m) for m in self.modelpool.model_names
+        ]
+
+        # Merge the models using task arithmetic
+        moe_model = task_arithmetic_merge(
+            # This function modifies the model in place, so we need to pass a deepcopy
+            deepcopy(base_model),
+            expert_models,
+            scaling_factor=self.config.init_lambda,
+        ).requires_grad_(False)
+
+        # Up-scale MLP modules
+        base_encoder: CLIPEncoder = base_model.vision_model.encoder
+        moe_encoder: CLIPEncoder = moe_model.vision_model.encoder
+        expert_encoders = [m.vision_model.encoder for m in expert_models]
+
+        num_layers = len(base_encoder.layers)
+        for layer_idx in range(num_layers):
+            base_mlp = base_encoder.layers[layer_idx].mlp
+            expert_mlps = [e.layers[layer_idx].mlp for e in expert_encoders]
+
+            moe_encoder.layers[layer_idx].mlp = RankOneMoE(
+                hidden_size=base_encoder.config.hidden_size,
+                base_model=base_mlp,
+                expert_models=expert_mlps,
+                init_lambda=self.config.init_lambda,
+                batch_first=True,  # For open_clip models this is False
+                router_hidden_layers=self.config.router_hidden_layers,
+                batch_reduce=self.config.batch_reduce,
+                svd_accelerator=self.config.svd_accelerator,
+                rank_k=self.config.rank_k,
+                select_k=self.config.select_k,
+            )
+
+        return moe_model
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(self, tta_dataset: str):
+        """
+        Get an iterator for the shuffled test data loader.
+
+        Args:
+            tta_dataset (str): The name of the test-time adaptation dataset.
+
+        Returns:
+            Iterator: An iterator for the shuffled test data loader.
+        """
+        dataset = self.modelpool.load_test_dataset(tta_dataset)
+        dataset = CLIPDataset(dataset, processor=self.clip_processor)
+        log.info("get_shuffled_test_loader_iter")
+        loader = DataLoader(
+            dataset,
+            batch_size=self.config.batch_size,
+            shuffle=True,
+            num_workers=self.config.num_workers,
+            pin_memory=True,
+        )
+        loader = self.fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def on_test_time_adaptation_start(self):
+        """
+        Load the CLIP processor and construct the zero-shot classification head for each task.
+        """
+        self.setup_zero_shot_classification_head()
+
+    def compute_logits(self, module, batch, task) -> Tensor:
+        """
+        Compute the logits for the given batch and task.
+
+        Args:
+            module: The model module.
+            batch: The input batch.
+            task: The task name.
+
+        Returns:
+            Tensor: The computed logits.
+        """
+        images, _ = batch
+        text_embeds = self.zeroshot_weights[task]
+
+        image_embeds = module(images)[1]
+        image_embeds = self.visual_projection(image_embeds)
+
+        # Normalize embeddings
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        # Cosine similarity
+        logits_per_text = (
+            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
+        )
+        logits_per_image = logits_per_text.t()
+
+        return logits_per_image
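`compute_logits` follows the standard CLIP zero-shot recipe: project and L2-normalize the image embeddings, take the text-image dot product as a cosine similarity, and scale by the exponentiated logit scale. A self-contained sketch with random stand-in embeddings:

import torch

batch_size, num_classes, embed_dim = 8, 10, 512
image_embeds = torch.randn(batch_size, embed_dim)
text_embeds = torch.randn(num_classes, embed_dim)  # one row per class prompt
logit_scale_exp = torch.tensor(100.0)  # exp(logit_scale), as in CLIP

# L2-normalize so the dot product is a cosine similarity
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale_exp
logits_per_image = logits_per_text.t()  # shape: (batch_size, num_classes)
predictions = logits_per_image.argmax(dim=-1)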