fusion-bench 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. fusion_bench/__init__.py +1 -0
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +5 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/dataset/clip_dataset.py +2 -1
  9. fusion_bench/dataset/gpt2_glue.py +9 -9
  10. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  11. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  12. fusion_bench/dataset/image_dataset.py +1 -1
  13. fusion_bench/dataset/nyuv2.py +2 -2
  14. fusion_bench/method/__init__.py +16 -3
  15. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  16. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  17. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  18. fusion_bench/method/base_algorithm.py +195 -12
  19. fusion_bench/method/bitdelta/__init__.py +4 -0
  20. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  21. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  25. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  26. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  27. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  28. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  29. fusion_bench/method/ensemble.py +12 -12
  30. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  31. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
  32. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  33. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  34. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  35. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  36. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  37. fusion_bench/method/linear/expo.py +2 -1
  38. fusion_bench/method/linear/linear_interpolation.py +6 -4
  39. fusion_bench/method/linear/simple_average_for_llama.py +2 -3
  40. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  41. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  42. fusion_bench/method/model_recombination.py +2 -5
  43. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  44. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  45. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  46. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  47. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  48. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  49. fusion_bench/method/randes/modelsoup.py +1 -3
  50. fusion_bench/method/regmean/clip_regmean.py +2 -2
  51. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  52. fusion_bench/method/regmean/regmean.py +2 -11
  53. fusion_bench/method/regmean_plusplus/__init__.py +1 -1
  54. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
  55. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
  56. fusion_bench/method/simple_average.py +5 -9
  57. fusion_bench/method/slerp/slerp.py +5 -2
  58. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  59. fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
  60. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
  61. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  62. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  63. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  64. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  65. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  66. fusion_bench/method/we_moe/we_moe.py +6 -6
  67. fusion_bench/method/weighted_average/llama.py +4 -16
  68. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  69. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  70. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  71. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  72. fusion_bench/mixins/__init__.py +10 -2
  73. fusion_bench/mixins/clip_classification.py +4 -3
  74. fusion_bench/mixins/hydra_config.py +105 -7
  75. fusion_bench/mixins/lightning_fabric.py +2 -0
  76. fusion_bench/mixins/serialization.py +265 -48
  77. fusion_bench/modelpool/__init__.py +2 -2
  78. fusion_bench/modelpool/base_pool.py +29 -9
  79. fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
  80. fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
  81. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  82. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  83. fusion_bench/models/__init__.py +2 -1
  84. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  85. fusion_bench/models/hf_utils.py +182 -0
  86. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  87. fusion_bench/models/linearized/vision_model.py +1 -1
  88. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  89. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  90. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  91. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  92. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  93. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  94. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  95. fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
  96. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  97. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
  98. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  99. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  100. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  101. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
  102. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  103. fusion_bench/models/parameter_dict.py +1 -1
  104. fusion_bench/models/sparse_we_moe.py +1 -53
  105. fusion_bench/models/utils.py +26 -0
  106. fusion_bench/models/we_moe.py +1 -53
  107. fusion_bench/models/wrappers/ensemble.py +6 -4
  108. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  109. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  110. fusion_bench/programs/base_program.py +81 -2
  111. fusion_bench/programs/fabric_fusion_program.py +24 -8
  112. fusion_bench/scripts/cli.py +5 -5
  113. fusion_bench/taskpool/base_pool.py +4 -3
  114. fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
  115. fusion_bench/taskpool/dummy.py +1 -1
  116. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  117. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  118. fusion_bench/utils/__init__.py +6 -1
  119. fusion_bench/utils/devices.py +14 -4
  120. fusion_bench/utils/instantiate_utils.py +3 -1
  121. fusion_bench/utils/modelscope.py +127 -8
  122. fusion_bench/utils/parameters.py +2 -2
  123. fusion_bench/utils/rich_utils.py +3 -0
  124. fusion_bench/utils/state_dict_arithmetic.py +25 -23
  125. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +24 -25
  126. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +165 -134
  127. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  128. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  129. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  130. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  131. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  132. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  133. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  134. fusion_bench_config/hydra/default.yaml +6 -2
  135. fusion_bench_config/llama_full_finetune.yaml +1 -0
  136. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  137. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  138. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  139. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  140. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  141. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
  142. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  143. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  144. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
  145. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
  148. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
  149. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  150. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  151. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  152. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  153. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  154. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  155. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  156. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  157. fusion_bench_config/nyuv2_config.yaml +3 -1
  158. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  159. fusion_bench_config/path/default.yaml +28 -0
  160. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  161. fusion_bench_config/method/adamerging.yaml +0 -23
  162. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  163. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  164. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  165. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
  166. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
  167. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
  168. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
  169. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/method/ensemble.py
@@ -1,11 +1,12 @@
 import logging
-from typing import List, Mapping, Union # noqa: F401
+from typing import List, Mapping, Optional, Union # noqa: F401

 import numpy as np
 import torch
 from torch import nn

 from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.models.wrappers.ensemble import (
     EnsembleModule,
@@ -18,7 +19,7 @@ log = logging.getLogger(__name__)

 class SimpleEnsembleAlgorithm(BaseAlgorithm):
     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool | List[nn.Module]):
+    def run(self, modelpool: BaseModelPool | List[nn.Module]) -> EnsembleModule:
         """
         Run the simple ensemble algorithm on the given model pool.

@@ -35,20 +36,19 @@ class SimpleEnsembleAlgorithm(BaseAlgorithm):
         return ensemble


+@auto_register_config
 class WeightedEnsembleAlgorithm(BaseAlgorithm):

-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "normalize": "normalize",
-        "weights": "weights",
-    }
-
-    def __init__(self, normalize: bool, weights: List[float], **kwargs):
-        self.normalize = normalize
-        self.weights = weights
+    def __init__(
+        self,
+        normalize: bool = True,
+        weights: Optional[List[float]] = None,
+        **kwargs,
+    ):
         super().__init__(**kwargs)

     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool | List[nn.Module]):
+    def run(self, modelpool: BaseModelPool | List[nn.Module]) -> WeightedEnsembleModule:
         """
         Run the weighted ensemble algorithm on the given model pool.

@@ -78,7 +78,7 @@ class WeightedEnsembleAlgorithm(BaseAlgorithm):

 class MaxModelPredictorAlgorithm(BaseAlgorithm):
     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool | List[nn.Module]):
+    def run(self, modelpool: BaseModelPool | List[nn.Module]) -> MaxModelPredictor:
         """
         Run the max model predictor algorithm on the given model pool.

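Reviewer note: the pattern above repeats throughout this release: hand-written `_config_mapping` tables and `self.x = x` boilerplate give way to the `auto_register_config` decorator from `fusion_bench.mixins` (its actual implementation lives in `fusion_bench/mixins/serialization.py`, +265 -48 above). As a rough mental model only, here is a minimal sketch of such a decorator, assuming it introspects the `__init__` signature; the real one may differ:

```python
import inspect


def auto_register_config(cls):
    """Hypothetical sketch: bind each named __init__ parameter to an
    attribute of the same name and extend _config_mapping to match."""
    original_init = cls.__init__
    sig = inspect.signature(original_init)
    names = [
        name
        for name, p in sig.parameters.items()
        if name != "self" and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    ]

    def __init__(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()
        for name in names:
            # mirrors the deleted `self.normalize = normalize`-style lines
            setattr(self, name, bound.arguments[name])
        original_init(self, *args, **kwargs)

    cls.__init__ = __init__
    cls._config_mapping = {
        **getattr(cls, "_config_mapping", {}),
        **{name: name for name in names},
    }
    return cls
```

Under that assumption, `WeightedEnsembleAlgorithm(normalize=False)` gets `self.normalize == False` with no assignment in the constructor body, which is what lets the hunks above delete both the mapping and the boilerplate.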
fusion_bench/method/expert_sparsity/utils/calibration_data.py
@@ -12,9 +12,9 @@ import os
 import torch
 import transformers
 from datasets import load_dataset
+from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizer, default_data_collator
 from transformers.testing_utils import CaptureLogger
-from huggingface_hub import hf_hub_download

 logger = logging.getLogger(__name__)

fusion_bench/method/fisher_merging/clip_fisher_merging.py
@@ -65,7 +65,7 @@ class FisherMergingForCLIPVisionModel(
             minimal_fisher_weight=minimal_fisher_weight,
             num_fisher_examples=num_fisher_examples,
         )
-        self._dataloader_kwargs = dataloader_kwargs
+        self.dataloader_kwargs = dataloader_kwargs
         self.zeroshot_weights_cache_dir = zeroshot_weights_cache_dir
         for key, value in kwargs.items():
             log.warning(f"Unused argument: {key}={value}")
@@ -127,7 +127,7 @@ class FisherMergingForCLIPVisionModel(
         """
         # setup dataloader
         train_dataset = CLIPDataset(train_dataset, self.clip_processor)
-        train_dataloader = DataLoader(train_dataset, **self._dataloader_kwargs)
+        train_dataloader = DataLoader(train_dataset, **self.dataloader_kwargs)
         if self.fabric is not None:
             train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
             model = self.fabric.setup(model)
fusion_bench/method/fisher_merging/fisher_merging.py
@@ -5,14 +5,14 @@ This implementation is largely based on the implementation from https://github.
 import logging
 import re
 from collections import defaultdict
-from typing import Dict, List
+from typing import Any, Dict, List

 import torch
 from torch import Tensor, nn
 from tqdm.autonotebook import tqdm

 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool

 log = logging.getLogger(__name__)
@@ -353,6 +353,7 @@ def filter_state_dict(
     return filtered_state_dict


+@auto_register_config
 class FisherMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     """
     Implements the Fisher Merging Algorithm.
@@ -365,13 +366,6 @@ class FisherMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         Executes the Fisher merging process on the model pool and returns the merged model.
     """

-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "exclude_param_names_regex": "exclude_param_names_regex",
-        "normalize_fisher_weight": "normalize_fisher_weight",
-        "minimal_fisher_weight": "minimal_fisher_weight",
-        "num_fisher_examples": "num_fisher_examples",
-    }
-
     def __init__(
         self,
         *,
@@ -379,12 +373,9 @@ class FisherMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         normalize_fisher_weight: bool,
         minimal_fisher_weight: float,
         num_fisher_examples: int,
+        **kwargs,
     ):
-        super().__init__()
-        self.exclude_param_names_regex = exclude_param_names_regex
-        self.normalize_fisher_weight = normalize_fisher_weight
-        self.minimal_fisher_weight = minimal_fisher_weight
-        self.num_fisher_examples = num_fisher_examples
+        super().__init__(**kwargs)

     def run(self, modelpool: BaseModelPool) -> nn.Module:
         """
@@ -469,7 +460,7 @@ class FisherMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         self,
         model_name: str,
         model: nn.Module,
-        train_dataset,
+        train_dataset: Any,
         param_names_to_merge: List[str],
     ) -> Dict[str, Tensor]:
         """
fusion_bench/method/fisher_merging/gpt2_fisher_merging.py
@@ -18,13 +18,14 @@ from transformers.models.gpt2.modeling_gpt2 import Conv1D
 from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.modelpool import GPT2ForSequenceClassificationPool
 from fusion_bench.utils import timeit_context
-
+from fusion_bench.mixins import auto_register_config
 from .fisher_merging import FisherMergingAlgorithm, get_param_squared_gradients


+@auto_register_config
 class FisherMergingAlgorithmForGPT2(
-    FisherMergingAlgorithm,
     LightningFabricMixin,
+    FisherMergingAlgorithm,
 ):
     """
     Implements the Fisher Merging Algorithm for GPT-2 models on text classification tasks.
@@ -42,11 +43,6 @@ class FisherMergingAlgorithmForGPT2(

     classifiers = {}
     modelpool: GPT2ForSequenceClassificationPool = None
-    _config_mapping = FisherMergingAlgorithm._config_mapping | {
-        "cache_dir": "cache_dir",
-        "batch_size": "batch_size",
-        "num_workers": "num_workers",
-    }

     def __init__(
         self,
@@ -64,9 +60,6 @@ class FisherMergingAlgorithmForGPT2(
             num_workers (int): Number of workers for data loading.
             **kwargs: Additional keyword arguments.
         """
-        self.cache_dir = cache_dir
-        self.batch_size = batch_size
-        self.num_workers = num_workers
         super().__init__(**kwargs)

     def on_fisher_merging_start(self):
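The base-class swap above (`LightningFabricMixin` now listed before `FisherMergingAlgorithm`) changes Python's method resolution order, so the mixin's methods win attribute lookup and its cooperative `super().__init__` runs earlier. A toy illustration with hypothetical classes, unrelated to fusion_bench's actual ones:

```python
class Algorithm:
    def setup(self):
        return "algorithm setup"


class FabricMixin:
    def setup(self):
        return "fabric setup"


class MixinFirst(FabricMixin, Algorithm):
    pass


class AlgorithmFirst(Algorithm, FabricMixin):
    pass


assert MixinFirst().setup() == "fabric setup"         # mixin shadows the base
assert AlgorithmFirst().setup() == "algorithm setup"  # base shadows the mixin
```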
fusion_bench/method/fw_merging/fw_hard.py
@@ -223,7 +223,7 @@ class FrankWolfeHardAlgorithm(
     def get_shuffled_loader_iter(self, task: str):
         if self.loss_fn == "cross_entropy":
             # get dataloader kwargs
-            dataloader_kwargs = self._dataloader_kwargs.copy()
+            dataloader_kwargs = self.dataloader_kwargs.copy()
             dataloader_kwargs["shuffle"] = True
             dataloader_kwargs["batch_size"] = 1

fusion_bench/method/fw_merging/fw_soft.py
@@ -193,7 +193,7 @@ class FrankWolfeSoftAlgorithm(
     @functools.cache
     def get_shuffled_train_loader_iter(self, task: str, batch_size: int = 1):
         # get dataloader kwargs
-        dataloader_kwargs = self._dataloader_kwargs.copy()
+        dataloader_kwargs = self.dataloader_kwargs.copy()
         dataloader_kwargs["shuffle"] = True
         dataloader_kwargs["batch_size"] = batch_size

fusion_bench/method/gossip/clip_layer_wise_gossip.py
@@ -3,13 +3,12 @@ Example Usage:

 ```bash
 fusion_bench \
-    method=adamerging \
+    path.log_dir=outputs/ViT-B-32/gossip_layer_wise_adamerging_adam \
+    method=adamerging/clip \
     method.name=clip_layer_wise_adamerging \
     method.save_merging_weights=merging_weights.pt \
-    modelpool=clip-vit-base-patch32_TA8 \
-    taskpool=clip-vit-classification_TA8 \
-    fabric_logger.root_dir=outputs/logs/ViT-B-32 \
-    fabric_logger.name=clip_layer_wise_adamerging_adam
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
+    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
 ```
 """

fusion_bench/method/linear/expo.py
@@ -7,6 +7,7 @@ Reference:

 import logging
 from copy import deepcopy
+from typing import Union

 import torch
 from torch import nn
@@ -79,7 +80,7 @@ class ExPOAlgorithm(BaseAlgorithm):
         self.extrapolation_factor = extrapolation_factor
         super().__init__(**kwargs)

-    def run(self, modelpool: BaseModelPool):
+    def run(self, modelpool: Union[BaseModelPool, list]) -> nn.Module:
         """
         Run the ExPO merge algorithm.

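`ExPOAlgorithm` implements model extrapolation: given a weaker (e.g. SFT) checkpoint and a stronger aligned checkpoint, it moves past the aligned model along the weak-to-strong direction. A state-dict sketch of my reading of that update, with `extrapolation_factor` as in the constructor above; fusion_bench's wrappers and edge handling are omitted:

```python
import torch


def expo_extrapolate(weak_state, aligned_state, extrapolation_factor: float):
    # theta = theta_aligned + alpha * (theta_aligned - theta_weak)
    # (alpha = extrapolation_factor; alpha = 0 returns the aligned model)
    return {
        name: aligned_state[name]
        + extrapolation_factor * (aligned_state[name] - weak_state[name])
        for name in aligned_state
    }
```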
fusion_bench/method/linear/linear_interpolation.py
@@ -1,6 +1,8 @@
 import logging
+from typing import Any

 import torch
+from torch import nn

 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.utils.state_dict_arithmetic import state_dict_weighted_sum
@@ -10,7 +12,7 @@ log = logging.getLogger(__name__)

 class LinearInterpolationAlgorithm(BaseAlgorithm):
     R"""
-    LinearInterpolationAlgorithm performs linear interpolation between two models.
+    `LinearInterpolationAlgorithm` performs linear interpolation between two models.
     Returns a model with the state dict that is a linear interpolation of the state dicts of the two models.
     $\theta = (1-t) \theta_1 + t \theta_2$
     """
@@ -19,9 +21,9 @@ class LinearInterpolationAlgorithm(BaseAlgorithm):
         "t": "t",
     }

-    def __init__(self, t: float, **kwargs):
+    def __init__(self, t: float, **kwargs: Any):
         """
-        Initialize the LinearInterpolationAlgorithm with the given interpolation parameter.
+        Initialize the `LinearInterpolationAlgorithm` with the given interpolation parameter.

         Args:
             t (float): The interpolation parameter, should be in the range [0, 1].
@@ -31,7 +33,7 @@ class LinearInterpolationAlgorithm(BaseAlgorithm):
         self.t = t
         super().__init__(**kwargs)

-    def run(self, modelpool: BaseModelPool):
+    def run(self, modelpool: BaseModelPool) -> nn.Module:
         """
         Run the linear interpolation algorithm on the given model pool.

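The docstring's formula $\theta = (1-t) \theta_1 + t \theta_2$ maps directly onto `state_dict_weighted_sum` with weights `[1 - t, t]`. An equivalent two-model sketch, assuming identical state-dict keys:

```python
import torch
from torch import nn


def lerp_state_dicts(model_1: nn.Module, model_2: nn.Module, t: float) -> dict:
    sd_1, sd_2 = model_1.state_dict(), model_2.state_dict()
    # torch.lerp(a, b, t) computes a + t * (b - a) == (1 - t) * a + t * b;
    # .float() guards integer buffers, which lerp rejects (a production
    # version would skip or round such buffers instead)
    return {name: torch.lerp(sd_1[name].float(), sd_2[name].float(), t) for name in sd_1}
```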
fusion_bench/method/linear/simple_average_for_llama.py
@@ -1,15 +1,15 @@
 from copy import deepcopy
 from typing import TYPE_CHECKING, Optional

+from omegaconf import flag_override
 from typing_extensions import override

 from fusion_bench import timeit_context
 from fusion_bench.method.base_algorithm import BaseAlgorithm
 from fusion_bench.method.simple_average import SimpleAverageAlgorithm
 from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
-from fusion_bench.utils.pylogger import getRankZeroLogger
-from omegaconf import flag_override
 from fusion_bench.utils import instantiate
+from fusion_bench.utils.pylogger import getRankZeroLogger

 log = getRankZeroLogger(__name__)

@@ -19,7 +19,6 @@ class SimpleAverageForLlama(BaseAlgorithm):
     A simple averaging algorithm for LLama models. If `merge_backbone` is set to `True`, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.

     Examples:
-
     The following example demonstrates how to use the `SimpleAverageForLlama` algorithm to merge Mistral models.

     ```bash
fusion_bench/method/lm_finetune/bradley_terry_rm.py
@@ -31,7 +31,7 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.mixins import FabricTrainingMixin
-from fusion_bench.modelpool import SeqenceClassificationModelPool
+from fusion_bench.modelpool import SequenceClassificationModelPool
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.dtype import get_dtype

@@ -121,7 +121,7 @@ class BradleyTerryRewardModeling(BaseAlgorithm, FabricTrainingMixin):
         self.fix_token_embedding = fix_token_embedding
         super().__init__(**kwargs)

-    def run(self, modelpool: SeqenceClassificationModelPool):
+    def run(self, modelpool: SequenceClassificationModelPool):
         self.modelpool = modelpool
         self.setup()
         self.train(self.model, self.optimizer, self.lr_scheduler)
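`BradleyTerryRewardModeling` fits a reward model on pairwise preference data: under the Bradley-Terry model, P(chosen beats rejected) = sigmoid(r_chosen - r_rejected), so training minimizes the negative log of that probability. The loss in isolation, separate from the training loop above:

```python
import torch
import torch.nn.functional as F


def bradley_terry_loss(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor):
    # -log sigmoid(r_chosen - r_rejected), averaged over the batch
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
```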
fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Optional
+from typing import Any, Optional

 import torch
 from tqdm.autonotebook import tqdm
@@ -23,8 +23,7 @@ from transformers.models.mixtral.modeling_mixtral import (
 )
 from transformers.utils import ContextManagers

-from fusion_bench.method import BaseAlgorithm
-from fusion_bench.modelpool import BaseModelPool
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config

 log = logging.getLogger(__name__)

@@ -114,7 +113,7 @@ def _upscale_decoder_layer(

 def upscale_to_mixtral_model(
     input_model: LlamaModel | MistralModel, output_model: MixtralModel
-):
+) -> None:
     """
     A helper function.

@@ -140,7 +139,7 @@ def upscale_to_mixtral_model(

 def upscale_to_mixtral_for_causal_lm(
     input_model: LlamaForCausalLM | MistralForCausalLM, output_model: MixtralForCausalLM
-):
+) -> None:
     """
     A helper function.

@@ -157,24 +156,19 @@ def upscale_to_mixtral_for_causal_lm(
     upscale_to_mixtral_model(input_model.model, output_model.model)


+@auto_register_config
 class MixtralUpscalingAlgorithm(BaseAlgorithm):
     """
     This class is responsible for upscaling a model to a MixtralModel.
     It inherits from the ModelFusionAlgorithm class.
     """

-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "num_experts": "num_experts",
-        "experts_per_token": "experts_per_token",
-        "save_checkpoint": "save_checkpoint",
-    }
-
     def __init__(
         self,
         num_experts: int,
         experts_per_token: int,
         save_checkpoint: str,
-        **kwargs,
+        **kwargs: Any,
     ):
         """
         Initialize the MixtralUpscalingAlgorithm.
@@ -185,9 +179,6 @@ class MixtralUpscalingAlgorithm(BaseAlgorithm):
             save_checkpoint (str): The path to save the checkpoint.
             **kwargs: Additional keyword arguments.
         """
-        self.num_experts = num_experts
-        self.experts_per_token = experts_per_token
-        self.save_checkpoint = save_checkpoint
         super().__init__(**kwargs)

     @torch.no_grad()
@@ -242,24 +233,19 @@ class MixtralUpscalingAlgorithm(BaseAlgorithm):
         return mixtral_model


+@auto_register_config
 class MixtralForCausalLMUpscalingAlgorithm(BaseAlgorithm):
     """
     This class is responsible for upscaling a model to a MixtralForCausalLM.
     It inherits from the ModelFusionAlgorithm class.
     """

-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "num_experts": "num_experts",
-        "experts_per_token": "experts_per_token",
-        "save_checkpoint": "save_checkpoint",
-    }
-
     def __init__(
         self,
         num_experts: int,
         experts_per_token: int,
         save_checkpoint: str,
-        **kwargs,
+        **kwargs: Any,
     ):
         """
         Initialize the MixtralForCausalLMUpscalingAlgorithm.
@@ -270,9 +256,6 @@ class MixtralForCausalLMUpscalingAlgorithm(BaseAlgorithm):
             save_checkpoint (str): The path to save the checkpoint.
             **kwargs: Additional keyword arguments.
         """
-        self.num_experts = num_experts
-        self.experts_per_token = experts_per_token
-        self.save_checkpoint = save_checkpoint
         super().__init__(**kwargs)

     @torch.no_grad()
@@ -302,7 +285,7 @@ class MixtralForCausalLMUpscalingAlgorithm(BaseAlgorithm):
             self.config.experts_per_token,
         )

-        with ContextManagers([no_init_weights(True)]):
+        with ContextManagers([no_init_weights()]):
             for _ in tqdm(range(1), desc="Initializing Mixtral model"):
                 mixtral_model = MixtralForCausalLM(mixtral_config)
         upscale_to_mixtral_for_causal_lm(pretrained_model, mixtral_model)
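The upscaling here follows the sparse-upcycling recipe: attention, embeddings, and norms are copied once from the dense Llama/Mistral checkpoint, while the dense MLP is duplicated into each of the `num_experts` experts (the router is newly initialized). A schematic of the per-layer expert copy, using a generic module rather than the exact HF attribute paths:

```python
import copy

from torch import nn


def upcycle_mlp_to_experts(dense_mlp: nn.Module, num_experts: int) -> nn.ModuleList:
    # every expert starts as an identical copy of the dense MLP; with a fresh
    # router in front, the experts only diverge once training resumes
    return nn.ModuleList(copy.deepcopy(dense_mlp) for _ in range(num_experts))
```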
fusion_bench/method/model_recombination.py
@@ -5,6 +5,7 @@ from typing import List, Mapping, Union # noqa: F401
 import torch
 from torch import nn

+from fusion_bench import auto_register_config
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.modelpool import BaseModelPool

@@ -52,17 +53,13 @@ def recombine_state_dict(models: List[nn.Module]):
     return models


+@auto_register_config
 class ModelRecombinationAlgorithm(BaseAlgorithm):
     """
     Model recombination recombinates the layers of the given models, to create a new set of models.
     """

-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "return_modelpool": "return_modelpool",
-    }
-
     def __init__(self, return_modelpool: bool, **kwargs):
-        self.return_modelpool = return_modelpool
         super().__init__(**kwargs)

     @torch.no_grad()
fusion_bench/method/moe_pruner/hooks/__init__.py
@@ -1,6 +1,5 @@
-from .hook import BaseHookFn
 from .deepseek_v2 import (
     MoEPrunerHookFnForDeepseekV2Gate,
     MoEPrunerHookFnForDeepseekV2Linear,
 )
-
+from .hook import BaseHookFn
fusion_bench/method/moe_pruner/utils/data.py
@@ -1,8 +1,9 @@
 # Code adapted from https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py

+import os
 import random
 from typing import List, Optional, Tuple, cast # noqa: F401
-import os
+
 from datasets import load_dataset
 from torch import Tensor
 from tqdm.auto import tqdm
fusion_bench/method/moe_pruner/utils/prune.py
@@ -107,7 +107,12 @@ def prepare_calibration_input(
         device=device,
         requires_grad=False,
     )
-    cache = {"i": 0, "attention_mask": None, "position_ids": None, 'position_embeddings': None}
+    cache = {
+        "i": 0,
+        "attention_mask": None,
+        "position_ids": None,
+        "position_embeddings": None,
+    }

     class Catcher(nn.Module):
         def __init__(self, module):
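The `cache` dict feeds the `Catcher` wrapper that begins right after this hunk (a pattern inherited from the SparseGPT/Wanda calibration code): the first decoder layer is wrapped so one forward pass over the calibration set records its inputs, then raises to abort the rest of the model. A condensed sketch of the pattern from that codebase, not this file's exact contents:

```python
import torch
from torch import nn


class Catcher(nn.Module):
    def __init__(self, module: nn.Module, cache: dict, inputs: torch.Tensor):
        super().__init__()
        self.module = module  # the wrapped first decoder layer
        self.cache = cache    # {"i": 0, "attention_mask": None, ...}
        self.inputs = inputs  # preallocated buffer for hidden states

    def forward(self, hidden_states, **kwargs):
        # record this layer's inputs and side arguments ...
        self.inputs[self.cache["i"]] = hidden_states
        self.cache["i"] += 1
        for key in ("attention_mask", "position_ids", "position_embeddings"):
            self.cache[key] = kwargs.get(key)
        # ... then abort the forward pass; the caller catches this exception
        raise ValueError
```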
fusion_bench/method/pruning/llama_magnitude_prune.py
@@ -167,7 +167,7 @@ class MagnitudePruningForLlama(BaseAlgorithm, SimpleProfilerMixin):
         super().__init__(**kwargs)

     @torch.no_grad()
-    def run(self, modelpool: CausalLMPool):
+    def run(self, modelpool: CausalLMPool) -> LlamaForCausalLM:
         """
         Execute the pruning process on the first model from the given model pool.

fusion_bench/method/pruning/wanda_utils/data.py
@@ -4,12 +4,11 @@ import os
 import random
 from typing import List, Optional, Tuple, cast # noqa: F401

+from datasets import load_dataset
 from torch import Tensor
 from tqdm.auto import tqdm
 from transformers import PreTrainedTokenizer

-from datasets import load_dataset
-

 # Wrapper for tokenized input IDs
 class TokenizerWrapper:
fusion_bench/method/pwe_moe/clip_pwe_moe.py
@@ -16,14 +16,18 @@ from transformers import CLIPVisionModel
 from transformers.models.clip.modeling_clip import CLIPEncoderLayer
 from typing_extensions import override

-from fusion_bench.method.base_algorithm import BaseAlgorithm
+from fusion_bench import (
+    BaseAlgorithm,
+    auto_register_config,
+    print_parameters,
+    timeit_context,
+)
+from fusion_bench.dataset import CLIPDataset
 from fusion_bench.method.task_arithmetic import task_arithmetic_merge
 from fusion_bench.mixins.clip_classification import CLIPClassificationMixin
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import CLIPVisionModelPool
-from fusion_bench.utils import timeit_context
 from fusion_bench.utils.data import InfiniteDataLoader
-from fusion_bench.utils.parameters import print_parameters

 from .module import ParetoWeightEnsemblingModule
 from .utils import generate_simplex_grid
@@ -31,27 +35,13 @@ from .utils import generate_simplex_grid
 log = logging.getLogger(__name__)


+@auto_register_config
 class PWEMoEAlgorithmForCLIP(
     BaseAlgorithm,
     SimpleProfilerMixin,
     CLIPClassificationMixin,
 ):
     modelpool: CLIPVisionModelPool = None
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "upscale_mlp": "upscale_mlp",
-        "upscale_attn": "upscale_attn",
-        "init_lambda": "init_lambda",
-        "router_hidden_layers": "router_hidden_layers",
-        "lr": "lr",
-        "num_steps": "num_steps",
-        "save_interval": "save_interval",
-        "alpha": "alpha",
-        "checkpoint_path": "checkpoint_path",
-        "eval_grid": "eval_grid",
-        "eval_grid_n": "eval_grid_n",
-        "eval_grid_m": "eval_grid_m",
-        "_dataloader_kwargs": "dataloader_kwargs",
-    }

     def __init__(
         self,
@@ -72,19 +62,6 @@ class PWEMoEAlgorithmForCLIP(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.upscale_mlp = upscale_mlp
-        self.upscale_attn = upscale_attn
-        self.init_lambda = init_lambda
-        self.router_hidden_layers = router_hidden_layers
-        self.lr = lr
-        self.num_steps = num_steps
-        self.save_interval = save_interval
-        self.alpha = alpha
-        self.checkpoint_path = checkpoint_path
-        self.eval_grid = eval_grid
-        self.eval_grid_n = eval_grid_n
-        self.eval_grid_m = eval_grid_m
-        self._dataloader_kwargs = dataloader_kwargs

     @override
     def run(self, modelpool: CLIPVisionModelPool):
@@ -193,13 +170,14 @@ class PWEMoEAlgorithmForCLIP(
         Loads the datasets specified in the configuration.
         """
         train_datasets = {
-            dataset_name: self.modelpool.load_train_dataset(
-                dataset_name, self.clip_processor
+            dataset_name: CLIPDataset(
+                self.modelpool.load_train_dataset(dataset_name),
+                processor=self.clip_processor,
             )
             for dataset_name in self.modelpool.model_names
         }
         train_loaders = {
-            dataset_name: DataLoader(dataset, shuffle=True, **self._dataloader_kwargs)
+            dataset_name: DataLoader(dataset, shuffle=True, **self.dataloader_kwargs)
             for dataset_name, dataset in train_datasets.items()
         }
         train_loaders = {
fusion_bench/method/randes/modelsoup.py
@@ -5,9 +5,7 @@ import torch

 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.parameters import count_parameters
-from fusion_bench.utils.state_dict_arithmetic import (
-    state_dict_mul,
-)
+from fusion_bench.utils.state_dict_arithmetic import state_dict_mul

 from .base_algorithm import SuperposedAlgorithmBase, compare_models

fusion_bench/method/regmean/clip_regmean.py
@@ -27,7 +27,7 @@ class RegMeanAlgorithmForCLIP(

     def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
         super().__init__(**kwargs)
-        self._dataloader_kwargs = dataloader_kwargs
+        self.dataloader_kwargs = dataloader_kwargs

     def on_regmean_start(self):
         self.setup_zero_shot_classification_head()
@@ -60,7 +60,7 @@ class RegMeanAlgorithmForCLIP(
         # setup dataloader
         train_dataset = CLIPDataset(train_dataset, self.clip_processor)
         train_dataloader = DataLoader(
-            train_dataset, shuffle=True, **self._dataloader_kwargs
+            train_dataset, shuffle=True, **self.dataloader_kwargs
         )
         train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
         model = self.fabric.setup(model)
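Context for the RegMean changes: RegMean merges each linear layer in closed form. With Gram matrices $G_i = X_i^\top X_i$ accumulated from the activations gathered by the dataloaders above, the merged weight solves $W^\top = (\sum_i G_i)^{-1} \sum_i G_i W_i^\top$. A per-layer sketch, not fusion_bench's exact implementation:

```python
import torch


def regmean_merge_linear(weights, grams, reg=1e-6):
    """weights: list of (out, in) matrices W_i; grams: list of (in, in) G_i = X_i^T X_i."""
    gram_sum = sum(grams)
    rhs = sum(g @ w.T for g, w in zip(grams, weights))             # (in, out)
    eye = reg * torch.eye(gram_sum.shape[0], dtype=gram_sum.dtype)  # ridge term for stability
    merged_wt = torch.linalg.solve(gram_sum + eye, rhs)             # W^T, shape (in, out)
    return merged_wt.T                                              # back to (out, in)
```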