fusion-bench 0.2.16__py3-none-any.whl → 0.2.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. fusion_bench/method/__init__.py +11 -0
  2. fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +1 -1
  3. fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +1 -1
  4. fusion_bench/method/base_algorithm.py +1 -0
  5. fusion_bench/method/dawe/dawe_for_clip.py +1 -1
  6. fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +3 -2
  7. fusion_bench/method/expert_sparsity/__init__.py +10 -0
  8. fusion_bench/method/expert_sparsity/mixtral/__init__.py +23 -0
  9. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +175 -0
  10. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +159 -0
  11. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +173 -0
  12. fusion_bench/method/expert_sparsity/utils/calibration_data.py +153 -0
  13. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +1 -1
  14. fusion_bench/method/knots/__init__.py +0 -0
  15. fusion_bench/method/knots/knots_utils.py +23 -0
  16. fusion_bench/method/pwe_moe/module.py +2 -7
  17. fusion_bench/method/simple_average.py +3 -2
  18. fusion_bench/method/task_singular_vector/TSVM.py +238 -25
  19. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +52 -20
  20. fusion_bench/method/task_singular_vector/utils/__init__.py +1 -0
  21. fusion_bench/method/task_singular_vector/utils/task_singular_interference.py +41 -0
  22. fusion_bench/mixins/hydra_config.py +1 -1
  23. fusion_bench/mixins/lightning_fabric.py +25 -1
  24. fusion_bench/mixins/serialization.py +18 -2
  25. fusion_bench/modelpool/base_pool.py +1 -0
  26. fusion_bench/modelpool/causal_lm/causal_lm.py +8 -5
  27. fusion_bench/modelpool/clip_vision/modelpool.py +21 -13
  28. fusion_bench/models/__init__.py +1 -0
  29. fusion_bench/models/expert_sparsity/__init__.py +0 -0
  30. fusion_bench/models/expert_sparsity/mixtral/__init__.py +15 -0
  31. fusion_bench/models/expert_sparsity/mixtral/dataset.py +40 -0
  32. fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py +207 -0
  33. fusion_bench/models/expert_sparsity/mixtral/wrapper.py +268 -0
  34. fusion_bench/models/parameter_dict.py +6 -1
  35. fusion_bench/programs/fabric_fusion_program.py +21 -13
  36. fusion_bench/taskpool/base_pool.py +1 -0
  37. fusion_bench/taskpool/dummy.py +6 -4
  38. fusion_bench/utils/__init__.py +4 -3
  39. fusion_bench/utils/dtype.py +2 -1
  40. fusion_bench/utils/fabric.py +11 -4
  41. fusion_bench/utils/{instantiate.py → instantiate_utils.py} +3 -0
  42. fusion_bench/utils/lazy_state_dict.py +80 -10
  43. fusion_bench/utils/pylogger.py +30 -0
  44. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/METADATA +3 -1
  45. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/RECORD +59 -38
  46. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/WHEEL +1 -1
  47. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +2 -0
  48. fusion_bench_config/fabric_model_fusion.yaml +2 -2
  49. fusion_bench_config/method/expert_sparsity/README.md +6 -0
  50. fusion_bench_config/method/expert_sparsity/mixtral.yaml +17 -0
  51. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -1
  52. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_cars_and_dtd.yaml +16 -0
  53. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +16 -0
  54. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +16 -0
  55. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +19 -0
  56. fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +0 -1
  57. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/entry_points.txt +0 -0
  58. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/licenses/LICENSE +0 -0
  59. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/top_level.txt +0 -0
fusion_bench/method/__init__.py
@@ -111,6 +111,12 @@ _import_structure = {
         "SparseLoForLlama",
         "PCPSparseLoForLlama",
     ],
+    # MoE expert pruning
+    "expert_sparsity": [
+        "DynamicSkippingPruningForMixtral",
+        "LayerWisePruningForMixtral",
+        "ProgressivePruningForMixtral",
+    ],
 }
 
 
@@ -142,6 +148,11 @@ if TYPE_CHECKING:
         SimpleEnsembleAlgorithm,
         WeightedEnsembleAlgorithm,
     )
+    from .expert_sparsity import (
+        DynamicSkippingPruningForMixtral,
+        LayerWisePruningForMixtral,
+        ProgressivePruningForMixtral,
+    )
     from .fisher_merging import FisherMergingForCLIPVisionModel
     from .fw_merging import FrankWolfeHardAlgorithm, FrankWolfeSoftAlgorithm
     from .gossip import (
fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
@@ -29,7 +29,7 @@ from fusion_bench.models.wrappers.layer_wise_fusion import (
     get_layer_wise_weights,
 )
 from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
-from fusion_bench.utils.instantiate import instantiate
+from fusion_bench.utils.instantiate_utils import instantiate
 
 from .entropy_loss import entropy_loss
 from .min_norm_solvers import MinNormSolver
fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
@@ -29,7 +29,7 @@ from fusion_bench.models.wrappers.layer_wise_fusion import (
     get_layer_wise_weights,
 )
 from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
-from fusion_bench.utils.instantiate import instantiate
+from fusion_bench.utils.instantiate_utils import instantiate
 
 from .entropy_loss import entropy_loss
 from .min_norm_solvers import MinNormSolver
fusion_bench/method/base_algorithm.py
@@ -19,6 +19,7 @@ class BaseAlgorithm(BaseYAMLSerializableModel):
     """
 
     _program = None
+    _config_key = "method"
 
     @abstractmethod
     def run(self, modelpool: BaseModelPool):
fusion_bench/method/dawe/dawe_for_clip.py
@@ -23,7 +23,7 @@ from fusion_bench.mixins import CLIPClassificationMixin
 from fusion_bench.modelpool import CLIPVisionModelPool
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.data import InfiniteDataLoader
-from fusion_bench.utils.instantiate import instantiate
+from fusion_bench.utils.instantiate_utils import instantiate
 
 from .warppers.dawe_model import DataAdaptiveWeightEnsemblingCLIPVisionModel
 
fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py
@@ -1,9 +1,10 @@
 import os
 from typing import Optional
 
+from transformers import PreTrainedModel
 from typing_extensions import override
 
-from fusion_bench.modelpool.causal_lm.causal_lm import CausalLM, CausalLMPool
+from fusion_bench.modelpool.causal_lm.causal_lm import CausalLMPool
 from fusion_bench.utils import timeit_context
 
 from .depth_upscaling import DepthUpscalingAlgorithm
@@ -46,7 +47,7 @@ class DepthUpscalingForLlama(DepthUpscalingAlgorithm):
         if self.model_save_path is not None:
             tokenizer = modelpool.load_tokenizer()
 
-        model: CausalLM = modelpool.load_pretrained_or_first_model()
+        model: PreTrainedModel = modelpool.load_pretrained_or_first_model()
         model.model.layers = super().run(model.model.layers)
         model.config.num_hidden_layers = len(model.model.layers)
 
fusion_bench/method/expert_sparsity/__init__.py
@@ -0,0 +1,10 @@
+"""
+Original repo: https://github.com/Lucky-Lance/Expert_Sparsity
+
+Reference:
+    Not All Experts are Equal: Efficient Expert Pruning and Skipping for Mixture-of-Experts Large Language Models.
+    ACL 2024.
+    http://arxiv.org/abs/2402.14800
+"""
+
+from .mixtral import *
fusion_bench/method/expert_sparsity/mixtral/__init__.py
@@ -0,0 +1,23 @@
+R"""
+```bash
+fusion_bench \
+    modelpool=CausalLMPool/mixtral-8x7b \
+    ...
+```
+
+If using flash attention 2, pass the following on the command line:
+
+```bash
++modelpool.models._pretrained_.attn_implementation=flash_attention_2
+```
+"""
+
+from .dynamic_skipping import DynamicSkippingPruningForMixtral
+from .layer_wise_pruning import LayerWisePruningForMixtral
+from .progressive_pruning import ProgressivePruningForMixtral
+
+__all__ = [
+    "DynamicSkippingPruningForMixtral",
+    "LayerWisePruningForMixtral",
+    "ProgressivePruningForMixtral",
+]
fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py
@@ -0,0 +1,175 @@
+R"""
+Example:
+
+```bash
+fusion_bench \
+    fabric.loggers.name="mixtral_8x7b_expert_pruning/dynamic_skipping" \
+    method=expert_sparsity/mixtral \
+    method._target_=fusion_bench.method.DynamicSkippingPruningForMixtral \
+    modelpool=CausalLMPool/mixtral-8x7b
+```
+"""
+
+import logging
+import os
+
+import lightning as L
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import MixtralForCausalLM
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
+
+import fusion_bench as fb
+from fusion_bench.method.expert_sparsity.utils.calibration_data import (
+    build_calib_loader,
+)
+from fusion_bench.models.expert_sparsity.mixtral.wrapper import (
+    PrunableMixtralSparseMoeBlockWrapper,
+)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+logger = logging.getLogger(__name__)
+
+
+def dynamic_skipping(
+    model: MixtralForCausalLM,
+    calib_loader: DataLoader,
+    batch_size: int,
+):
+    assert isinstance(
+        model, MixtralForCausalLM
+    ), "Currently only `Mixtral` is supported"
+
+    for l, layer in enumerate(model.model.layers):
+        layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
+            layer.block_sparse_moe
+        )
+        layer.block_sparse_moe.cache_logits = True
+        layer.block_sparse_moe.cache_X = True
+        layer.block_sparse_moe.cache_Z = True
+
+    with torch.inference_mode():
+        for i, batch in enumerate(
+            tqdm(calib_loader, desc="Model forwarding on sample set...")
+        ):
+            model_inputs = model.prepare_inputs_for_generation(**batch)
+            outputs = model(**model_inputs)
+            assert outputs is not None
+
+    res_median = {}
+    res_mean = {}
+
+    for layer_idx in range(len(model.model.layers)):
+        b = model.model.layers[layer_idx].block_sparse_moe
+        b.cache_space.prepare_for_loader()
+        dataloader = torch.utils.data.DataLoader(
+            b.cache_space,
+            batch_size=batch_size,
+            shuffle=True,
+        )
+        logger.info(len(dataloader))
+
+        ana_list = []
+        for i, (router_logits, X, Z) in enumerate(dataloader):
+            routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float).view(
+                -1, b.model.num_experts
+            )
+            for j in range(len(routing_weights)):
+                sorted_weights, sort_indices = torch.sort(
+                    routing_weights[j], descending=True
+                )
+                ana_list.append(float(sorted_weights[1] / sorted_weights[0]))
+
+        median = np.median(ana_list)
+        mean = np.mean(ana_list)
+        logger.info(f"layer {layer_idx} | mean: {mean}, median: {median}")
+        res_median[str(layer_idx)] = median
+        res_mean[str(layer_idx)] = mean
+
+    for l, layer in enumerate(model.model.layers):
+        layer.block_sparse_moe = layer.block_sparse_moe.model
+
+    model.config.betas = res_median
+    return model, (res_median, res_mean)
+
+
+class DynamicSkippingPruningForMixtral(
+    fb.BaseAlgorithm,
+    fb.mixins.LightningFabricMixin,
+    fb.mixins.SimpleProfilerMixin,
+):
+    modelpool: fb.modelpool.CausalLMPool
+
+    def __init__(
+        self,
+        calib_set: str,
+        max_block_size: int,
+        n_blocks_for_stat: int,
+        batch_size: int,
+        num_workers: int,
+        num_preserved_experts: int,
+        seed: int = 42,
+        model_save_path: str = R"{log_dir}/pruned_model",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.model_save_path = model_save_path
+        self.calib_set = calib_set
+        self.max_block_size = max_block_size
+        self.n_blocks_for_stat = n_blocks_for_stat
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.seed = seed
+        self.num_preserved_experts = num_preserved_experts
+
+    def run(self, modelpool: fb.modelpool.CausalLMPool):
+        """
+        Args:
+            modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
+                Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
+        """
+        self.modelpool = modelpool
+        # set random seed
+        if self.seed is not None:
+            L.seed_everything(self.seed)
+        # parse model_save_path
+        self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)
+
+        with self.profile("load model"):
+            model = modelpool.load_pretrained_or_first_model()
+            tokenizer = modelpool.load_tokenizer()
+
+        # Load the calibration data
+        with self.profile("load calibration data"):
+            calib_loader = build_calib_loader(
+                self.calib_set,
+                tokenizer=tokenizer,
+                max_block_size=self.max_block_size,
+                n_blocks_for_stat=self.n_blocks_for_stat,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                seed=self.seed,
+            )
+
+        with self.profile("prune model"):
+            model, info = dynamic_skipping(
+                model,
+                calib_loader,
+                batch_size=self.batch_size,
+            )
+
+        if self.model_save_path is not None:
+            with self.profile("save model"):
+                modelpool.save_model(
+                    model,
+                    path=self.model_save_path,
+                    tokenizer=tokenizer,
+                )
+            torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))
+
+        self.print_profile_summary()
+        return model
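
The per-layer skipping thresholds ("betas") stored in `model.config.betas` above are the medians, over all calibration tokens, of the ratio between the second-largest and largest softmaxed router weights. A toy sketch of that statistic on made-up router logits (shapes and values are illustrative only, not taken from the diff):

```python
# Toy sketch of the dynamic-skipping statistic: for each token, take the ratio of
# the second-largest to the largest softmaxed router weight, then summarize a
# layer by the median of these ratios. Logits here are random placeholders.
import numpy as np
import torch
import torch.nn.functional as F

num_tokens, num_experts = 4, 8
router_logits = torch.randn(num_tokens, num_experts)  # hypothetical router outputs

routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
ratios = []
for row in routing_weights:
    sorted_weights, _ = torch.sort(row, descending=True)
    ratios.append(float(sorted_weights[1] / sorted_weights[0]))

print("median beta for this layer:", np.median(ratios))
```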
fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py
@@ -0,0 +1,159 @@
+R"""
+Example:
+
+```bash
+fusion_bench \
+    fabric.loggers.name="mixtral_8x7b_expert_pruning/layer_wise_pruning" \
+    method=expert_sparsity/mixtral \
+    method._target_=fusion_bench.method.LayerWisePruningForMixtral \
+    modelpool=CausalLMPool/mixtral-8x7b
+```
+"""
+
+import logging
+import os
+from typing import cast
+
+import lightning as L
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import MixtralForCausalLM
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
+
+import fusion_bench as fb
+from fusion_bench.method.expert_sparsity.utils.calibration_data import (
+    build_calib_loader,
+)
+from fusion_bench.models.expert_sparsity.mixtral import (
+    PrunableMixtralSparseMoeBlockWrapper,
+)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+logger = logging.getLogger(__name__)
+
+
+def layerwise_pruning(
+    model: MixtralForCausalLM,
+    calib_loader: DataLoader,
+    r: int,
+):
+    assert isinstance(
+        model, MixtralForCausalLM
+    ), "Currently only `Mixtral` is supported"
+
+    for l, layer in enumerate(model.model.layers):
+        layer = cast(MixtralDecoderLayer, layer)
+        layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
+            layer.block_sparse_moe, r=r
+        )
+        layer.block_sparse_moe.cache_X = True
+        layer.block_sparse_moe.cache_Z = True
+
+    with torch.inference_mode():
+        for i, batch in enumerate(
+            tqdm(calib_loader, desc="Model forwarding on sample set...")
+        ):
+            model_inputs = model.prepare_inputs_for_generation(**batch)
+            outputs = model(**model_inputs)
+            assert outputs is not None
+
+    global_loss_history = dict()
+    for l, layer in tqdm(
+        list(enumerate(model.model.layers)), desc="Enumerating loss on sample set..."
+    ):
+        layer = cast(MixtralDecoderLayer, layer)
+        b: PrunableMixtralSparseMoeBlockWrapper = layer.block_sparse_moe
+        if not hasattr(b, "cache_space"):
+            continue
+        loss_history = b.enumerate()
+        global_loss_history[l] = loss_history
+        b.prune()
+
+    logger.info("Merging & saving...")
+    for l, layer in enumerate(model.model.layers):
+        layer.block_sparse_moe = layer.block_sparse_moe.model
+
+    model.num_experts = r
+    model.config.num_local_experts = r
+
+    return model, (global_loss_history,)
+
+
+class LayerWisePruningForMixtral(
+    fb.BaseAlgorithm,
+    fb.mixins.LightningFabricMixin,
+    fb.mixins.SimpleProfilerMixin,
+):
+    modelpool: fb.modelpool.CausalLMPool
+
+    def __init__(
+        self,
+        calib_set: str,
+        max_block_size: int,
+        n_blocks_for_stat: int,
+        batch_size: int,
+        num_workers: int,
+        num_preserved_experts: int,
+        seed: int = 42,
+        model_save_path: str = R"{log_dir}/pruned_model",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.model_save_path = model_save_path
+        self.calib_set = calib_set
+        self.max_block_size = max_block_size
+        self.n_blocks_for_stat = n_blocks_for_stat
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.seed = seed
+        self.num_preserved_experts = num_preserved_experts
+
+    def run(self, modelpool: fb.modelpool.CausalLMPool):
+        """
+        Args:
+            modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
+                Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
+        """
+        self.modelpool = modelpool
+        # set random seed
+        if self.seed is not None:
+            L.seed_everything(self.seed)
+        # parse model_save_path
+        self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)
+
+        with self.profile("load model"):
+            model = modelpool.load_pretrained_or_first_model()
+            tokenizer = modelpool.load_tokenizer()
+
+        # Load the calibration data
+        with self.profile("load calibration data"):
+            calib_loader = build_calib_loader(
+                self.calib_set,
+                tokenizer=tokenizer,
+                max_block_size=self.max_block_size,
+                n_blocks_for_stat=self.n_blocks_for_stat,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                seed=self.seed,
+            )
+
+        with self.profile("prune model"):
+            model, info = layerwise_pruning(
+                model,
+                calib_loader,
+                r=self.num_preserved_experts,
+            )
+
+        if self.model_save_path is not None:
+            with self.profile("save model"):
+                modelpool.save_model(
+                    model,
+                    path=self.model_save_path,
+                    tokenizer=tokenizer,
+                )
+            torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))
+
+        self.print_profile_summary()
+        return model
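
Besides the Hydra command line in the module docstring, the class can presumably be constructed directly with the same parameters. The sketch below mirrors the `__init__` signature shown above; the calibration-set name and numeric values are illustrative placeholders, not values taken from fusion_bench_config/method/expert_sparsity/mixtral.yaml:

```python
# Hedged sketch: direct construction of LayerWisePruningForMixtral. Keyword
# arguments mirror the constructor above; "c4" and the numbers are assumptions.
from fusion_bench.method import LayerWisePruningForMixtral

algorithm = LayerWisePruningForMixtral(
    calib_set="c4",            # assumed calibration dataset identifier
    max_block_size=2048,       # assumed max token-block length for calibration samples
    n_blocks_for_stat=128,     # assumed number of calibration blocks
    batch_size=4,
    num_workers=4,
    num_preserved_experts=6,   # r: experts kept in each MoE layer
    seed=42,
)
# `algorithm.run(modelpool)` then expects a fusion_bench CausalLMPool holding the
# Mixtral checkpoint (see config/modelpool/CausalLMPool/mixtral-8x7b.yaml).
```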
fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py
@@ -0,0 +1,173 @@
+R"""
+Example:
+
+```bash
+fusion_bench \
+    fabric.loggers.name="mixtral_8x7b_expert_pruning/progressive_pruning" \
+    method=expert_sparsity/mixtral \
+    method._target_=fusion_bench.method.ProgressivePruningForMixtral \
+    modelpool=CausalLMPool/mixtral-8x7b
+```
+"""
+
+import logging
+import os
+
+import lightning as L
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import MixtralForCausalLM
+
+import fusion_bench as fb
+from fusion_bench.method.expert_sparsity.utils.calibration_data import (
+    build_calib_loader,
+)
+from fusion_bench.models.expert_sparsity.mixtral import (
+    PrunableMixtralSparseMoeBlockWrapper,
+)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+logger = logging.getLogger(__name__)
+
+
+def progressive_pruning(
+    model: MixtralForCausalLM,
+    calib_loader: DataLoader,
+    r: int,
+):
+    assert isinstance(
+        model, MixtralForCausalLM
+    ), "Currently only `Mixtral` is supported"
+
+    for l, layer in enumerate(model.model.layers):
+        layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
+            layer.block_sparse_moe, r=r
+        )
+        layer.block_sparse_moe.cache_Z = True
+
+    with torch.inference_mode():
+        for i, batch in enumerate(
+            tqdm(calib_loader, desc="Computing Z activations on sample set...")
+        ):
+            model_inputs = model.prepare_inputs_for_generation(**batch)
+            outputs = model(**model_inputs)
+            assert outputs is not None
+
+    del model_inputs
+    del outputs
+    torch.cuda.empty_cache()
+
+    for l, layer in enumerate(model.model.layers):
+        layer.block_sparse_moe.cache_Z = False
+
+    # Drop
+    global_loss_history = dict()
+
+    for l, layer in tqdm(
+        list(enumerate(model.model.layers)), desc="Dropping layers..."
+    ):
+        b = layer.block_sparse_moe
+
+        b.cache_X = True
+        with torch.inference_mode():
+            for i, batch in enumerate(calib_loader):
+                model_inputs = model.prepare_inputs_for_generation(**batch)
+                outputs = model(**model_inputs)
+                assert outputs is not None
+
+        del model_inputs
+        del outputs
+        torch.cuda.empty_cache()
+        b.cache_X = False
+
+        loss_history = b.enumerate()
+        global_loss_history[l] = loss_history
+
+        b.prune()
+        layer.block_sparse_moe = b.model
+
+    # Prune & save
+    model.num_experts = r
+    model.config.num_local_experts = r
+
+    return model, (global_loss_history,)
+
+
+class ProgressivePruningForMixtral(
+    fb.BaseAlgorithm,
+    fb.mixins.LightningFabricMixin,
+    fb.mixins.SimpleProfilerMixin,
+):
+    modelpool: fb.modelpool.CausalLMPool
+
+    def __init__(
+        self,
+        calib_set: str,
+        max_block_size: int,
+        n_blocks_for_stat: int,
+        batch_size: int,
+        num_workers: int,
+        num_preserved_experts: int,
+        seed: int = 42,
+        model_save_path: str = R"{log_dir}/pruned_model",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.model_save_path = model_save_path
+        self.calib_set = calib_set
+        self.max_block_size = max_block_size
+        self.n_blocks_for_stat = n_blocks_for_stat
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.seed = seed
+        self.num_preserved_experts = num_preserved_experts
+
+    def run(self, modelpool: fb.modelpool.CausalLMPool):
+        """
+        Args:
+            modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
+                Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
+        """
+        self.modelpool = modelpool
+        # set random seed
+        if self.seed is not None:
+            L.seed_everything(self.seed)
+        # parse model_save_path
+        self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)
+
+        with self.profile("load model"):
+            model = modelpool.load_pretrained_or_first_model()
+            tokenizer = modelpool.load_tokenizer()
+
+        # Load the calibration data
+        with self.profile("load calibration data"):
+            calib_loader = build_calib_loader(
+                self.calib_set,
+                tokenizer=tokenizer,
+                max_block_size=self.max_block_size,
+                n_blocks_for_stat=self.n_blocks_for_stat,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                seed=self.seed,
+            )
+
+        with self.profile("prune model"):
+            model, info = progressive_pruning(
+                model,
+                calib_loader,
+                r=self.num_preserved_experts,
+            )
+
+        if self.model_save_path is not None:
+            with self.profile("save model"):
+                modelpool.save_model(
+                    model,
+                    path=self.model_save_path,
+                    tokenizer=tokenizer,
+                )
+            torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))
+
+        self.print_profile_summary()
+        return model
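
All three algorithms write the pruned model to `model_save_path` (default `{log_dir}/pruned_model`) via `modelpool.save_model`; the layer-wise and progressive variants also reduce `config.num_local_experts` to `num_preserved_experts`, while dynamic skipping keeps all experts and records per-layer `betas` instead. A hedged sketch of reloading such a checkpoint with plain transformers, assuming the save produced a standard Hugging Face model directory and using an assumed output path:

```python
# Hedged sketch: reloading a pruned checkpoint saved by one of the algorithms above.
# The directory is the default "{log_dir}/pruned_model" template resolved against
# an assumed log_dir; adjust it to wherever your run actually wrote the model.
from transformers import AutoTokenizer, MixtralForCausalLM

save_dir = "outputs/mixtral_8x7b_expert_pruning/layer_wise_pruning/pruned_model"  # assumed
model = MixtralForCausalLM.from_pretrained(save_dir, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(save_dir)

# After layer-wise or progressive pruning this should equal num_preserved_experts
# rather than the original 8 experts of Mixtral-8x7B.
print(model.config.num_local_experts)
```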