fusion-bench 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. fusion_bench/compat/method/__init__.py +1 -0
  2. fusion_bench/compat/method/base_algorithm.py +0 -1
  3. fusion_bench/compat/modelpool/__init__.py +2 -1
  4. fusion_bench/dataset/arc_agi/__init__.py +6 -1
  5. fusion_bench/dataset/arc_agi/arc.py +21 -7
  6. fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
  7. fusion_bench/dataset/arc_agi/np_cache.py +0 -1
  8. fusion_bench/dataset/arc_agi/preprocess.py +50 -8
  9. fusion_bench/dataset/llama/collate.py +10 -3
  10. fusion_bench/method/__init__.py +3 -0
  11. fusion_bench/method/adamerging/__init__.py +1 -1
  12. fusion_bench/method/lm_finetune/fullfinetune_sft.py +47 -5
  13. fusion_bench/method/lm_finetune/peftfinetune_sft.py +58 -23
  14. fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
  15. fusion_bench/method/rankone_moe/__init__.py +3 -0
  16. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  17. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  18. fusion_bench/method/simple_average.py +1 -1
  19. fusion_bench/mixins/clip_classification.py +2 -7
  20. fusion_bench/mixins/lightning_fabric.py +2 -2
  21. fusion_bench/models/rankone_moe.py +410 -0
  22. fusion_bench/taskpool/__init__.py +10 -2
  23. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  24. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  25. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
  26. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/METADATA +1 -1
  27. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/RECORD +36 -29
  28. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +4 -4
  29. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +13 -7
  30. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  31. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
  32. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  33. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/LICENSE +0 -0
  34. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/WHEEL +0 -0
  35. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/entry_points.txt +0 -0
  36. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.6.dist-info}/top_level.txt +0 -0
fusion_bench/method/lm_finetune/fullfinetune_sft.py
@@ -13,6 +13,7 @@ from omegaconf import DictConfig
 from torch import nn
 from torch.utils.data import ConcatDataset, DataLoader
 from tqdm.auto import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from typing_extensions import TYPE_CHECKING, override
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
@@ -117,6 +118,9 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             self.model.gradient_checkpointing_enable(
                 gradient_checkpointing_kwargs={"use_reentrant": True}
             )
+            self.use_cache = False
+        else:
+            self.use_cache = True
         self.model_dtype = get_dtype(self.model)
 
     def configure_optimizer(self):
@@ -215,7 +219,12 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             # disable gradient synchronization if accumulating gradients across steps for improved performance
             with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                 # use_cache=True is not compatible with gradient checkpointing, so we disable it here
-                output = self.model(**batch, use_cache=False)
+                output = self.model(
+                    input_ids=batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    labels=batch["labels"],
+                    use_cache=self.use_cache,
+                )
                 loss = output["loss"]
 
                 fabric.backward(loss)
@@ -252,7 +261,7 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             ):
                 break
             # break if max_steps is set, and exit training
-            if self.max_steps > 0 and self.global_step_idx >= self.max_steps:
+            if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
                 self.is_training = False
                 break
 
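This is an off-by-one fix: `global_step_idx` appears to be zero-based, so breaking only once it reached `self.max_steps` ran one optimizer step too many. A minimal, self-contained sketch of the corrected counting (the loop body below is illustrative, not the trainer's):

    max_steps = 3
    executed_steps = 0
    for global_step_idx in range(100):  # stand-in for the training loop
        executed_steps += 1             # one optimizer step per iteration
        if max_steps > 0 and global_step_idx >= max_steps - 1:
            break                       # exits after steps 0, 1, 2
    assert executed_steps == max_steps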
@@ -328,14 +337,15 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                         "checkpoints",
                         "latest_model.ckpt",
                     ),
-                    os.path.join(
+                    dst := os.path.join(
                         self.log_dir,
                         "checkpoints",
                         f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
                     ),
+                    target_is_directory=os.path.isdir(dst),
                 )
             except Exception as e:
-                pass
+                log.error(f"Failed to create symlink: {e}")
         else:
             raise ValueError(
                 f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
@@ -364,8 +374,15 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                 }
             )
 
+        trainable_param_names = set(
+            name
+            for name, param in self.model.state_dict(keep_vars=True).items()
+            if param.requires_grad
+        )
         filter = (
-            None if self.save_full_model else {"model": lambda k, p: p.requires_grad}
+            None
+            if self.save_full_model
+            else {"model": lambda k, p: k in trainable_param_names}
         )
 
         fabric.save(path, state=state, filter=filter)
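The filter above now keys on parameter names collected from `state_dict(keep_vars=True)` rather than on `p.requires_grad`. This matters because tensors returned by a plain `state_dict()` are detached and always report `requires_grad=False`, so a `requires_grad`-based filter can silently drop everything when the values handed to it are such copies. A small sketch of the idea outside Fabric (the toy model is an assumption for illustration):

    from torch import nn

    model = nn.Linear(4, 2)
    model.weight.requires_grad_(False)  # pretend only the bias is trainable

    # Plain state_dict() values are detached, so requires_grad is always False here.
    assert all(not v.requires_grad for v in model.state_dict().values())

    # keep_vars=True exposes the real parameters, so trainability is visible.
    trainable_param_names = {
        name
        for name, param in model.state_dict(keep_vars=True).items()
        if param.requires_grad
    }
    assert trainable_param_names == {"bias"}

    # A name-based filter then works no matter how the values were materialized.
    filtered = {k: v for k, v in model.state_dict().items() if k in trainable_param_names}
    assert set(filtered) == {"bias"}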
@@ -401,3 +418,28 @@ def load_checkpoint(
     state = {"model": model}
     state.update(state_components)
     fabric.load(ckpt_path, state=state, strict=strict)
+
+
+if __name__ == "__main__":
+    # convert a checkpoint to hf format
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base_model_path", type=str)
+    parser.add_argument("--ckpt_path", type=str)
+    parser.add_argument("--output_path", type=str)
+
+    args = parser.parse_args()
+
+    fabric = L.Fabric(devices=1, strategy="fsdp")
+    fabric.launch()
+
+    tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
+    tokenizer.save_pretrained(args.output_path)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        args.base_model_path, torch_dtype=torch.bfloat16
+    )
+    model = fabric.setup_module(model)
+    load_checkpoint(fabric, args.ckpt_path, model=model, strict=True)
+    model.save_pretrained(args.output_path)
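The new `__main__` block turns this module into a small utility that loads a Fabric/FSDP checkpoint into the base model and re-exports it in Hugging Face format, tokenizer included. Assuming the wheel is installed, an invocation would look roughly like this, with all three paths being placeholders:

    python -m fusion_bench.method.lm_finetune.fullfinetune_sft \
        --base_model_path <base-model-dir-or-hub-id> \
        --ckpt_path <fabric-checkpoint.ckpt> \
        --output_path <hf-output-dir>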
fusion_bench/method/lm_finetune/peftfinetune_sft.py
@@ -10,10 +10,10 @@ import peft
 import torch
 from lightning.fabric.strategies.fsdp import FSDPStrategy
 from lightning.fabric.utilities import rank_zero_only
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 from peft import PeftModel, get_peft_config, get_peft_model
 from torch import nn
-from torch.utils.data import DataLoader, ConcatDataset
+from torch.utils.data import ConcatDataset, DataLoader
 from tqdm.auto import tqdm
 from typing_extensions import TYPE_CHECKING, override
 
@@ -65,7 +65,9 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         gradient_clip_algorithm: Literal["value", "norm"] = "norm",
         save_optimizer_state: bool = False,
         save_full_model: bool = False,
+        save_ckpt_type: Literal["lightning", "peft"] = "peft",
         ckpt_path: Optional[str] = None,
+        max_length: int = 6150,
         **kwargs,
     ):
         """
@@ -90,6 +92,7 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
             save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
             save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
+            save_ckpt_type(str): Type of checkpoint to save. Available options: 'lightning', 'peft'. If set to 'lightning', the model will be saved using the Lightning checkpointing mechanism. If set to 'peft', the model will be saved using the PEFT checkpointing mechanism.
             ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
         """
         self._optimizer = optimizer
@@ -110,7 +113,9 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         self.gradient_clip_algorithm = gradient_clip_algorithm
         self.save_optimizer_state = save_optimizer_state
         self.save_full_model = save_full_model
+        self.save_ckpt_type = save_ckpt_type
         self.ckpt_path = ckpt_path
+        self.max_length = max_length
         super().__init__(**kwargs)
 
     def run(self, modelpool: CausalLMPool):
@@ -126,7 +131,7 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         model = self.modelpool.load_pretrained_model()
 
         # get the PEFT model
-        peft_config = get_peft_config(self._peft_config)
+        peft_config = instantiate(self._peft_config, _convert_="all")
         peft_model = get_peft_model(model, peft_config, self.adapter_name)
         peft_model.print_trainable_parameters()
 
@@ -139,6 +144,10 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             self.model.gradient_checkpointing_enable(
                 gradient_checkpointing_kwargs={"use_reentrant": True}
             )
+            self.use_cache = False
+        else:
+            self.use_cache = True
+
         self.model_dtype = get_dtype(self.model)
 
     def configure_optimizer(self):
@@ -234,10 +243,22 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
         ):
             is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0
 
+            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
+                log.warning(
+                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
+                )
+                batch["input_ids"] = batch["input_ids"][:, : self.max_length]
+                batch["attention_mask"] = batch["attention_mask"][:, : self.max_length]
+                batch["labels"] = batch["labels"][:, : self.max_length]
             # disable gradient synchronization if accumulating gradients across steps for improved performance
             with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                 # use_cache=True is not compatible with gradient checkpointing, so we disable it here
-                output = self.model(**batch, use_cache=False)
+                output = self.model(
+                    input_ids=batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    labels=batch["labels"],
+                    use_cache=self.use_cache,
+                )
                 loss = output["loss"]
 
                 fabric.backward(loss)
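The new guard truncates over-long batches to `max_length` (default 6150 per the constructor change above) before the forward pass; because `input_ids`, `attention_mask`, and `labels` all have shape `(batch, seq_len)`, slicing the second dimension keeps them aligned token-for-token. A standalone sketch of the same guard, with arbitrary tensor sizes:

    import torch

    def truncate_batch(batch: dict, max_length: int) -> dict:
        # Slice every (batch, seq_len) tensor to the same cap so the three
        # fields stay aligned.
        if max_length > 0 and batch["input_ids"].shape[1] > max_length:
            for key in ("input_ids", "attention_mask", "labels"):
                batch[key] = batch[key][:, :max_length]
        return batch

    batch = {
        "input_ids": torch.randint(0, 100, (2, 8000)),
        "attention_mask": torch.ones(2, 8000, dtype=torch.long),
        "labels": torch.randint(0, 100, (2, 8000)),
    }
    batch = truncate_batch(batch, max_length=6150)
    assert all(v.shape[1] == 6150 for v in batch.values())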
@@ -274,7 +295,7 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             ):
                 break
             # break if max_steps is set, and exit training
-            if self.max_steps > 0 and self.global_step_idx >= self.max_steps:
+            if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
                 self.is_training = False
                 break
 
@@ -350,14 +371,15 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
                         "checkpoints",
                         "latest_model.ckpt",
                     ),
-                    os.path.join(
+                    dst := os.path.join(
                         self.log_dir,
                         "checkpoints",
                         f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
                     ),
+                    target_is_directory=os.path.isdir(dst),
                 )
             except Exception as e:
-                pass
+                log.error(f"Failed to create symlink: {e}")
         else:
             raise ValueError(
                 f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
@@ -373,24 +395,37 @@ class PeftFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
             return log.warning(f"Checkpoint already exists at {path}. Skipping save.")
 
         fabric = self.fabric
-        state = {"model": self.model}
-
-        # save the optimizer and lr_scheduler state if needed
-        if self.save_optimizer_state and save_optimizer_state is not False:
-            state.update(
-                {
-                    "optimizer": self.optimizer,
-                    "lr_scheduler": self.lr_scheduler,
-                    "global_step_idx": self.global_step_idx,
-                    "epoch_idx": self.epoch_idx,
-                }
+        if self.save_ckpt_type == "lightning":
+            state = {"model": self.model}
+
+            # save the optimizer and lr_scheduler state if needed
+            if self.save_optimizer_state and save_optimizer_state is not False:
+                state.update(
+                    {
+                        "optimizer": self.optimizer,
+                        "lr_scheduler": self.lr_scheduler,
+                        "global_step_idx": self.global_step_idx,
+                        "epoch_idx": self.epoch_idx,
+                    }
+                )
+            trainable_param_names = set(
+                name
+                for name, param in self.model.state_dict(keep_vars=True).items()
+                if param.requires_grad
+            )
+            filter = (
+                None
+                if self.save_full_model
+                else {"model": lambda k, p: k in trainable_param_names}
             )
 
-        filter = (
-            None if self.save_full_model else {"model": lambda k, p: p.requires_grad}
-        )
-
-        fabric.save(path, state=state, filter=filter)
+            fabric.save(path, state=state, filter=filter)
+        elif self.save_ckpt_type == "peft":
+            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)
+        else:
+            raise ValueError(
+                f"Unknown save_ckpt_type: {self.save_ckpt_type}. Available options: 'lightning', 'peft'"
+            )
         self._latest_saved_checkpoint_global_step = self.global_step_idx
 
     def load_checkpoint(self, path: Union[str, Path]):
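With the new default `save_ckpt_type="peft"`, checkpoints are written through `save_pretrained`, i.e. as a standard PEFT adapter directory rather than a Lightning/Fabric state file. Such a checkpoint can be reloaded with the usual PEFT API; a rough sketch, assuming a causal-LM base model and placeholder paths:

    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    # Load the frozen base model, then attach the adapter weights that
    # `self.model.save_pretrained(path, ...)` wrote during training.
    base_model = AutoModelForCausalLM.from_pretrained("<base-model-dir-or-hub-id>")
    model = PeftModel.from_pretrained(base_model, "<saved-adapter-dir>")

    # For LoRA-style adapters, the adapter can optionally be folded into the
    # base weights for deployment.
    model = model.merge_and_unload()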
fusion_bench/method/pruning/magnitude_diff_pruning.py
@@ -1,3 +1,4 @@
+import functools
 import logging
 import re
 from copy import deepcopy
@@ -10,7 +11,7 @@ from tqdm.auto import tqdm
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
-import functools
+
 from .prune_utils import unstructured_magnitude_prune_
 
 log = logging.getLogger(__name__)
fusion_bench/method/rankone_moe/__init__.py (new file)
@@ -0,0 +1,3 @@
+# flake8: noqa F401
+from .clip_rankone_moe import CLIPRankOneMoEAlgorithm
+from .rankone_moe import RankOneMoEAlgorithm
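The new subpackage re-exports both algorithm classes, so they can be imported directly from `fusion_bench.method.rankone_moe`:

    from fusion_bench.method.rankone_moe import (
        CLIPRankOneMoEAlgorithm,
        RankOneMoEAlgorithm,
    )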
fusion_bench/method/rankone_moe/clip_rankone_moe.py (new file)
@@ -0,0 +1,160 @@
+import functools
+import logging
+import os
+from copy import deepcopy
+
+import torch
+from torch import Tensor
+from torch.utils.data import DataLoader
+from transformers.models.clip.modeling_clip import CLIPEncoder
+
+from fusion_bench.dataset import CLIPDataset
+from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge
+from fusion_bench.mixins import CLIPClassificationMixin
+from fusion_bench.modelpool import CLIPVisionModelPool
+from fusion_bench.models.rankone_moe import RankOneMoE
+from fusion_bench.utils.data import InfiniteDataLoader
+
+from .rankone_moe import RankOneMoEAlgorithm
+
+log = logging.getLogger(__name__)
+
+
+class CLIPRankOneMoEAlgorithm(
+    RankOneMoEAlgorithm,
+    CLIPClassificationMixin,
+):
+    """
+    CLIPRankOneMoEAlgorithm is a class that implements the RankOneMoEAlgorithm (https://github.com/EnnengYang/RankOne-MoE)
+    for CLIP models. It extends the RankOneMoEAlgorithm and CLIPClassificationMixin classes.
+
+    Attributes:
+        modelpool (CLIPVisionModelPool): The model pool containing the CLIP models.
+    """
+
+    modelpool: CLIPVisionModelPool = None
+
+    def load_checkpoint(self, model, checkpoint):
+        """
+        Load the checkpoint file.
+
+        Args:
+            model: The model to load the checkpoint into.
+            checkpoint: The path to the checkpoint file.
+        """
+        state = {"model": model}
+        self._fabric.load(checkpoint, state)
+
+    def save_checkpoint(self, model, checkpoint):
+        """
+        Save the checkpoint file.
+
+        Args:
+            model: The model to save the checkpoint from.
+            checkpoint: The path to the checkpoint file.
+        """
+        self._fabric.save(checkpoint, {"model": model})
+
+    def construct_moe_model(self) -> RankOneMoE:
+        """
+        Construct the RankOne-MoE model using the models in the model pool.
+
+        Returns:
+            RankOne-MoE: The constructed MoE model.
+        """
+        base_model = self.modelpool.load_model("_pretrained_")
+        expert_models = [
+            self.modelpool.load_model(m) for m in self.modelpool.model_names
+        ]
+
+        # Merge the models using task arithmetic
+        moe_model = task_arithmetic_merge(
+            # This function modifies the model in place, so we need to pass a deepcopy
+            deepcopy(base_model),
+            expert_models,
+            scaling_factor=self.config.init_lambda,
+        ).requires_grad_(False)
+
+        # Up-scale MLP modules
+        base_encoder: CLIPEncoder = base_model.vision_model.encoder
+        moe_encoder: CLIPEncoder = moe_model.vision_model.encoder
+        expert_encoders = [m.vision_model.encoder for m in expert_models]
+
+        num_layers = len(base_encoder.layers)
+        for layer_idx in range(num_layers):
+            base_mlp = base_encoder.layers[layer_idx].mlp
+            expert_mlps = [e.layers[layer_idx].mlp for e in expert_encoders]
+
+            moe_encoder.layers[layer_idx].mlp = RankOneMoE(
+                hidden_size=base_encoder.config.hidden_size,
+                base_model=base_mlp,
+                expert_models=expert_mlps,
+                init_lambda=self.config.init_lambda,
+                batch_first=True,  # For open_clip models this is False
+                router_hidden_layers=self.config.router_hidden_layers,
+                batch_reduce=self.config.batch_reduce,
+                svd_accelerator=self.config.svd_accelerator,
+                rank_k=self.config.rank_k,
+                select_k=self.config.select_k,
+            )
+
+        return moe_model
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(self, tta_dataset: str):
+        """
+        Get an iterator for the shuffled test data loader.
+
+        Args:
+            tta_dataset (str): The name of the test-time adaptation dataset.
+
+        Returns:
+            Iterator: An iterator for the shuffled test data loader.
+        """
+        dataset = self.modelpool.load_test_dataset(tta_dataset)
+        dataset = CLIPDataset(dataset, processor=self.clip_processor)
+        log.info("get_shuffled_test_loader_iter")
+        loader = DataLoader(
+            dataset,
+            batch_size=self.config.batch_size,
+            shuffle=True,
+            num_workers=self.config.num_workers,
+            pin_memory=True,
+        )
+        loader = self.fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def on_test_time_adaptation_start(self):
+        """
+        Load the CLIP processor and construct the zero-shot classification head for each task.
+        """
+        self.setup_zero_shot_classification_head()
+
+    def compute_logits(self, module, batch, task) -> Tensor:
+        """
+        Compute the logits for the given batch and task.
+
+        Args:
+            module: The model module.
+            batch: The input batch.
+            task: The task name.
+
+        Returns:
+            Tensor: The computed logits.
+        """
+        images, _ = batch
+        text_embeds = self.zeroshot_weights[task]
+
+        image_embeds = module(images)[1]
+        image_embeds = self.visual_projection(image_embeds)
+
+        # Normalize embeddings
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        # Cosine similarity
+        logits_per_text = (
+            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
+        )
+        logits_per_image = logits_per_text.t()
+
+        return logits_per_image