fusion-bench 0.2.16__py3-none-any.whl → 0.2.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +11 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +1 -1
- fusion_bench/method/base_algorithm.py +1 -0
- fusion_bench/method/dawe/dawe_for_clip.py +1 -1
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +3 -2
- fusion_bench/method/expert_sparsity/__init__.py +10 -0
- fusion_bench/method/expert_sparsity/mixtral/__init__.py +23 -0
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +175 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +159 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +173 -0
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +153 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +1 -1
- fusion_bench/method/knots/__init__.py +0 -0
- fusion_bench/method/knots/knots_utils.py +23 -0
- fusion_bench/method/pwe_moe/module.py +2 -7
- fusion_bench/method/simple_average.py +3 -2
- fusion_bench/method/task_singular_vector/TSVM.py +238 -25
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +52 -20
- fusion_bench/method/task_singular_vector/utils/__init__.py +1 -0
- fusion_bench/method/task_singular_vector/utils/task_singular_interference.py +41 -0
- fusion_bench/mixins/hydra_config.py +1 -1
- fusion_bench/mixins/lightning_fabric.py +25 -1
- fusion_bench/mixins/serialization.py +18 -2
- fusion_bench/modelpool/base_pool.py +1 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +8 -5
- fusion_bench/modelpool/clip_vision/modelpool.py +21 -13
- fusion_bench/models/__init__.py +1 -0
- fusion_bench/models/expert_sparsity/__init__.py +0 -0
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +15 -0
- fusion_bench/models/expert_sparsity/mixtral/dataset.py +40 -0
- fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py +207 -0
- fusion_bench/models/expert_sparsity/mixtral/wrapper.py +268 -0
- fusion_bench/models/parameter_dict.py +6 -1
- fusion_bench/programs/fabric_fusion_program.py +21 -13
- fusion_bench/taskpool/base_pool.py +1 -0
- fusion_bench/taskpool/dummy.py +6 -4
- fusion_bench/utils/__init__.py +4 -3
- fusion_bench/utils/dtype.py +2 -1
- fusion_bench/utils/fabric.py +11 -4
- fusion_bench/utils/{instantiate.py → instantiate_utils.py} +3 -0
- fusion_bench/utils/lazy_state_dict.py +80 -10
- fusion_bench/utils/pylogger.py +30 -0
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/METADATA +3 -1
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/RECORD +59 -38
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/WHEEL +1 -1
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +2 -0
- fusion_bench_config/fabric_model_fusion.yaml +2 -2
- fusion_bench_config/method/expert_sparsity/README.md +6 -0
- fusion_bench_config/method/expert_sparsity/mixtral.yaml +17 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_cars_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +19 -0
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +0 -1
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/top_level.txt +0 -0
fusion_bench/method/__init__.py
CHANGED
|
@@ -111,6 +111,12 @@ _import_structure = {
|
|
|
111
111
|
"SparseLoForLlama",
|
|
112
112
|
"PCPSparseLoForLlama",
|
|
113
113
|
],
|
|
114
|
+
# MoE expert pruning
|
|
115
|
+
"expert_sparsity": [
|
|
116
|
+
"DynamicSkippingPruningForMixtral",
|
|
117
|
+
"LayerWisePruningForMixtral",
|
|
118
|
+
"ProgressivePruningForMixtral",
|
|
119
|
+
],
|
|
114
120
|
}
|
|
115
121
|
|
|
116
122
|
|
|
@@ -142,6 +148,11 @@ if TYPE_CHECKING:
|
|
|
142
148
|
SimpleEnsembleAlgorithm,
|
|
143
149
|
WeightedEnsembleAlgorithm,
|
|
144
150
|
)
|
|
151
|
+
from .expert_sparsity import (
|
|
152
|
+
DynamicSkippingPruningForMixtral,
|
|
153
|
+
LayerWisePruningForMixtral,
|
|
154
|
+
ProgressivePruningForMixtral,
|
|
155
|
+
)
|
|
145
156
|
from .fisher_merging import FisherMergingForCLIPVisionModel
|
|
146
157
|
from .fw_merging import FrankWolfeHardAlgorithm, FrankWolfeSoftAlgorithm
|
|
147
158
|
from .gossip import (
|
|
@@ -29,7 +29,7 @@ from fusion_bench.models.wrappers.layer_wise_fusion import (
|
|
|
29
29
|
get_layer_wise_weights,
|
|
30
30
|
)
|
|
31
31
|
from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
|
|
32
|
-
from fusion_bench.utils.
|
|
32
|
+
from fusion_bench.utils.instantiate_utils import instantiate
|
|
33
33
|
|
|
34
34
|
from .entropy_loss import entropy_loss
|
|
35
35
|
from .min_norm_solvers import MinNormSolver
|
|
@@ -29,7 +29,7 @@ from fusion_bench.models.wrappers.layer_wise_fusion import (
|
|
|
29
29
|
get_layer_wise_weights,
|
|
30
30
|
)
|
|
31
31
|
from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
|
|
32
|
-
from fusion_bench.utils.
|
|
32
|
+
from fusion_bench.utils.instantiate_utils import instantiate
|
|
33
33
|
|
|
34
34
|
from .entropy_loss import entropy_loss
|
|
35
35
|
from .min_norm_solvers import MinNormSolver
|
|
@@ -23,7 +23,7 @@ from fusion_bench.mixins import CLIPClassificationMixin
|
|
|
23
23
|
from fusion_bench.modelpool import CLIPVisionModelPool
|
|
24
24
|
from fusion_bench.utils import timeit_context
|
|
25
25
|
from fusion_bench.utils.data import InfiniteDataLoader
|
|
26
|
-
from fusion_bench.utils.
|
|
26
|
+
from fusion_bench.utils.instantiate_utils import instantiate
|
|
27
27
|
|
|
28
28
|
from .warppers.dawe_model import DataAdaptiveWeightEnsemblingCLIPVisionModel
|
|
29
29
|
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
import os
|
|
2
2
|
from typing import Optional
|
|
3
3
|
|
|
4
|
+
from transformers import PreTrainedModel
|
|
4
5
|
from typing_extensions import override
|
|
5
6
|
|
|
6
|
-
from fusion_bench.modelpool.causal_lm.causal_lm import
|
|
7
|
+
from fusion_bench.modelpool.causal_lm.causal_lm import CausalLMPool
|
|
7
8
|
from fusion_bench.utils import timeit_context
|
|
8
9
|
|
|
9
10
|
from .depth_upscaling import DepthUpscalingAlgorithm
|
|
@@ -46,7 +47,7 @@ class DepthUpscalingForLlama(DepthUpscalingAlgorithm):
|
|
|
46
47
|
if self.model_save_path is not None:
|
|
47
48
|
tokenizer = modelpool.load_tokenizer()
|
|
48
49
|
|
|
49
|
-
model:
|
|
50
|
+
model: PreTrainedModel = modelpool.load_pretrained_or_first_model()
|
|
50
51
|
model.model.layers = super().run(model.model.layers)
|
|
51
52
|
model.config.num_hidden_layers = len(model.model.layers)
|
|
52
53
|
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""
Original repo: https://github.com/Lucky-Lance/Expert_Sparsity

Reference:
    Not All Experts are Equal: Efficient Expert Pruning and Skipping for Mixture-of-Experts Large Language Models.
    ACL 2024.
    http://arxiv.org/abs/2402.14800
"""

# Import the public API explicitly instead of `from .mixtral import *` so the
# exported names are visible to readers and static analysis tools.
from .mixtral import (
    DynamicSkippingPruningForMixtral,
    LayerWisePruningForMixtral,
    ProgressivePruningForMixtral,
)

__all__ = [
    "DynamicSkippingPruningForMixtral",
    "LayerWisePruningForMixtral",
    "ProgressivePruningForMixtral",
]
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
R"""
Usage example:

```bash
fusion_bench \
    modelpool=CausalLMPool/mixtral-8x7b \
    ...
```

If using flash attention 2, pass the following override on the command line:

```bash
+modelpool.models._pretrained_.attn_implementation=flash_attention_2
```
"""

from .dynamic_skipping import DynamicSkippingPruningForMixtral
from .layer_wise_pruning import LayerWisePruningForMixtral
from .progressive_pruning import ProgressivePruningForMixtral

# Public API of this subpackage.
__all__ = [
    "DynamicSkippingPruningForMixtral",
    "LayerWisePruningForMixtral",
    "ProgressivePruningForMixtral",
]
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
R"""
|
|
2
|
+
Example:
|
|
3
|
+
|
|
4
|
+
```bash
|
|
5
|
+
fusion_bench \
|
|
6
|
+
fabric.loggers.name="mixtral_8x7b_expert_pruning/dynamic_skipping" \
|
|
7
|
+
method=expert_sparsity/mixtral \
|
|
8
|
+
method._target_=fusion_bench.method.DynamicSkippingPruningForMixtral \
|
|
9
|
+
modelpool=CausalLMPool/mixtral-8x7b
|
|
10
|
+
```
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import logging
|
|
14
|
+
import os
|
|
15
|
+
|
|
16
|
+
import lightning as L
|
|
17
|
+
import numpy as np
|
|
18
|
+
import torch
|
|
19
|
+
import torch.nn.functional as F
|
|
20
|
+
from torch.utils.data import DataLoader
|
|
21
|
+
from tqdm import tqdm
|
|
22
|
+
from transformers import MixtralForCausalLM
|
|
23
|
+
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
|
|
24
|
+
|
|
25
|
+
import fusion_bench as fb
|
|
26
|
+
from fusion_bench.method.expert_sparsity.utils.calibration_data import (
|
|
27
|
+
build_calib_loader,
|
|
28
|
+
)
|
|
29
|
+
from fusion_bench.models.expert_sparsity.mixtral.wrapper import (
|
|
30
|
+
PrunableMixtralSparseMoeBlockWrapper,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
34
|
+
|
|
35
|
+
logger = logging.getLogger(__name__)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def dynamic_skipping(
    model: MixtralForCausalLM,
    calib_loader: DataLoader,
    batch_size: int,
):
    """Estimate per-layer dynamic-skipping thresholds (betas) for a Mixtral model.

    Every MoE block is wrapped in a :class:`PrunableMixtralSparseMoeBlockWrapper`
    that caches router logits and activations while the calibration set is
    forwarded once.  For each layer, the ratio between the second-largest and
    largest routing weight of every token is collected; the per-layer median of
    these ratios is stored in ``model.config.betas``.

    Args:
        model: the model to analyze. Only ``MixtralForCausalLM`` is supported.
        calib_loader: dataloader over the calibration samples.
        batch_size: batch size used when iterating the cached activations.

    Returns:
        A tuple ``(model, (res_median, res_mean))`` where ``res_median`` and
        ``res_mean`` map layer indices (as strings) to the median / mean ratio.
    """
    assert isinstance(
        model, MixtralForCausalLM
    ), "Currently only `Mixtral` is supported"

    # Wrap each MoE block so router logits and activations are cached during
    # the calibration forward passes below.
    for layer in model.model.layers:
        layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
            layer.block_sparse_moe
        )
        layer.block_sparse_moe.cache_logits = True
        layer.block_sparse_moe.cache_X = True
        layer.block_sparse_moe.cache_Z = True

    # Forward the calibration set once to populate the caches.
    with torch.inference_mode():
        for batch in tqdm(calib_loader, desc="Model forwarding on sample set..."):
            model_inputs = model.prepare_inputs_for_generation(**batch)
            outputs = model(**model_inputs)
            assert outputs is not None

    res_median = {}
    res_mean = {}

    for layer_idx in range(len(model.model.layers)):
        b = model.model.layers[layer_idx].block_sparse_moe
        b.cache_space.prepare_for_loader()
        dataloader = torch.utils.data.DataLoader(
            b.cache_space,
            batch_size=batch_size,
            shuffle=True,
        )
        logger.info("layer %d: %d cached batches", layer_idx, len(dataloader))

        ana_list = []
        for router_logits, _X, _Z in dataloader:
            routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float).view(
                -1, b.model.num_experts
            )
            # Ratio of the second-largest to the largest routing weight for
            # every token.  Sorting the whole batch at once replaces the
            # original per-row Python loop (whose sort indices were unused)
            # while producing identical values.
            sorted_weights, _ = torch.sort(routing_weights, dim=-1, descending=True)
            ana_list.extend((sorted_weights[:, 1] / sorted_weights[:, 0]).tolist())

        median = np.median(ana_list)
        mean = np.mean(ana_list)
        logger.info(f"layer {layer_idx} | mean: {mean}, median: {median}")
        res_median[str(layer_idx)] = median
        res_mean[str(layer_idx)] = mean

    # Unwrap the MoE blocks now that statistics have been collected.
    for layer in model.model.layers:
        layer.block_sparse_moe = layer.block_sparse_moe.model

    model.config.betas = res_median
    return model, (res_median, res_mean)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class DynamicSkippingPruningForMixtral(
    fb.BaseAlgorithm,
    fb.mixins.LightningFabricMixin,
    fb.mixins.SimpleProfilerMixin,
):
    """Dynamic-skipping expert pruning for Mixtral models.

    Loads a Mixtral model and a calibration set, computes per-layer skipping
    thresholds via :func:`dynamic_skipping`, and optionally saves the resulting
    model plus the collected pruning statistics.
    """

    modelpool: fb.modelpool.CausalLMPool

    def __init__(
        self,
        calib_set: str,
        max_block_size: int,
        n_blocks_for_stat: int,
        batch_size: int,
        num_workers: int,
        num_preserved_experts: int,
        seed: int = 42,
        model_save_path: str = R"{log_dir}/pruned_model",
        **kwargs,
    ):
        """
        Args:
            calib_set: name of the calibration dataset passed to ``build_calib_loader``.
            max_block_size: maximum token-block size of calibration samples.
            n_blocks_for_stat: number of calibration blocks used for statistics.
            batch_size: batch size for both the calibration loader and the
                cached-activation loader inside :func:`dynamic_skipping`.
            num_workers: number of dataloader workers.
            num_preserved_experts: accepted for config compatibility with the
                other expert-sparsity algorithms; dynamic skipping does not
                remove experts, so this value is currently unused here
                (NOTE(review): confirm whether it should influence the run).
            seed: random seed; ``None`` skips seeding.
            model_save_path: template for the output directory (``{log_dir}``
                is substituted at run time), or ``None`` to skip saving.
        """
        super().__init__(**kwargs)
        self.model_save_path = model_save_path
        self.calib_set = calib_set
        self.max_block_size = max_block_size
        self.n_blocks_for_stat = n_blocks_for_stat
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.num_preserved_experts = num_preserved_experts

    def run(self, modelpool: fb.modelpool.CausalLMPool):
        """
        Args:
            modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
                Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
        """
        self.modelpool = modelpool
        # set random seed
        if self.seed is not None:
            L.seed_everything(self.seed)
        # Resolve the save-path template.  Guard against `None`, which the
        # save step below explicitly supports but an unconditional `.format()`
        # would crash on.
        if self.model_save_path is not None:
            self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)

        with self.profile("load model"):
            model = modelpool.load_pretrained_or_first_model()
            tokenizer = modelpool.load_tokenizer()

        # Load the calibration data
        with self.profile("load calibration data"):
            calib_loader = build_calib_loader(
                self.calib_set,
                tokenizer=tokenizer,
                max_block_size=self.max_block_size,
                n_blocks_for_stat=self.n_blocks_for_stat,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                seed=self.seed,
            )

        with self.profile("prune model"):
            model, info = dynamic_skipping(
                model,
                calib_loader,
                batch_size=self.batch_size,
            )

        if self.model_save_path is not None:
            with self.profile("save model"):
                modelpool.save_model(
                    model,
                    path=self.model_save_path,
                    tokenizer=tokenizer,
                )
                torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))

        self.print_profile_summary()
        return model
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
R"""
|
|
2
|
+
Example:
|
|
3
|
+
|
|
4
|
+
```bash
|
|
5
|
+
fusion_bench \
|
|
6
|
+
fabric.loggers.name="mixtral_8x7b_expert_pruning/layer_wise_pruning" \
|
|
7
|
+
method=expert_sparsity/mixtral \
|
|
8
|
+
method._target_=fusion_bench.method.LayerWisePruningForMixtral \
|
|
9
|
+
modelpool=CausalLMPool/mixtral-8x7b
|
|
10
|
+
```
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import logging
|
|
14
|
+
import os
|
|
15
|
+
from typing import cast
|
|
16
|
+
|
|
17
|
+
import lightning as L
|
|
18
|
+
import torch
|
|
19
|
+
from torch.utils.data import DataLoader
|
|
20
|
+
from tqdm import tqdm
|
|
21
|
+
from transformers import MixtralForCausalLM
|
|
22
|
+
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
|
|
23
|
+
|
|
24
|
+
import fusion_bench as fb
|
|
25
|
+
from fusion_bench.method.expert_sparsity.utils.calibration_data import (
|
|
26
|
+
build_calib_loader,
|
|
27
|
+
)
|
|
28
|
+
from fusion_bench.models.expert_sparsity.mixtral import (
|
|
29
|
+
PrunableMixtralSparseMoeBlockWrapper,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def layerwise_pruning(
    model: MixtralForCausalLM,
    calib_loader: DataLoader,
    r: int,
):
    """Prune every Mixtral MoE layer down to ``r`` experts, layer-wise.

    Each MoE block is wrapped in a :class:`PrunableMixtralSparseMoeBlockWrapper`
    that caches its input/output activations while the calibration set is
    forwarded once.  Each block then enumerates candidate expert subsets on the
    cached activations, keeps the best ``r`` experts, and is unwrapped again.

    Args:
        model: the model to prune. Only ``MixtralForCausalLM`` is supported.
        calib_loader: dataloader over the calibration samples.
        r: number of experts preserved per layer.

    Returns:
        A tuple ``(model, (global_loss_history,))`` with the pruned model and
        the per-layer loss histories produced by the enumeration.
    """
    assert isinstance(
        model, MixtralForCausalLM
    ), "Currently only `Mixtral` is supported"

    # Wrap each MoE block so its activations are cached during forwarding.
    for layer in model.model.layers:
        layer = cast(MixtralDecoderLayer, layer)
        layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
            layer.block_sparse_moe, r=r
        )
        layer.block_sparse_moe.cache_X = True
        layer.block_sparse_moe.cache_Z = True

    # Forward the calibration set once to populate the caches.
    with torch.inference_mode():
        for batch in tqdm(calib_loader, desc="Model forwarding on sample set..."):
            model_inputs = model.prepare_inputs_for_generation(**batch)
            outputs = model(**model_inputs)
            assert outputs is not None

    global_loss_history = dict()
    for l, layer in tqdm(
        list(enumerate(model.model.layers)), desc="Enumerating loss on sample set..."
    ):
        layer = cast(MixtralDecoderLayer, layer)
        b: PrunableMixtralSparseMoeBlockWrapper = layer.block_sparse_moe
        # Blocks without a cache never saw calibration data; skip them.
        if not hasattr(b, "cache_space"):
            continue
        loss_history = b.enumerate()
        global_loss_history[l] = loss_history
        b.prune()

    logger.info("Merging & saving...")
    # Unwrap the MoE blocks, keeping only the pruned inner modules.
    for layer in model.model.layers:
        layer.block_sparse_moe = layer.block_sparse_moe.model

    model.num_experts = r
    model.config.num_local_experts = r

    return model, (global_loss_history,)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class LayerWisePruningForMixtral(
    fb.BaseAlgorithm,
    fb.mixins.LightningFabricMixin,
    fb.mixins.SimpleProfilerMixin,
):
    """Layer-wise expert pruning for Mixtral models.

    Loads a Mixtral model and a calibration set, prunes each MoE layer to
    ``num_preserved_experts`` experts via :func:`layerwise_pruning`, and
    optionally saves the pruned model plus the pruning statistics.
    """

    modelpool: fb.modelpool.CausalLMPool

    def __init__(
        self,
        calib_set: str,
        max_block_size: int,
        n_blocks_for_stat: int,
        batch_size: int,
        num_workers: int,
        num_preserved_experts: int,
        seed: int = 42,
        model_save_path: str = R"{log_dir}/pruned_model",
        **kwargs,
    ):
        """
        Args:
            calib_set: name of the calibration dataset passed to ``build_calib_loader``.
            max_block_size: maximum token-block size of calibration samples.
            n_blocks_for_stat: number of calibration blocks used for statistics.
            batch_size: batch size of the calibration loader.
            num_workers: number of dataloader workers.
            num_preserved_experts: number of experts kept in each MoE layer.
            seed: random seed; ``None`` skips seeding.
            model_save_path: template for the output directory (``{log_dir}``
                is substituted at run time), or ``None`` to skip saving.
        """
        super().__init__(**kwargs)
        self.model_save_path = model_save_path
        self.calib_set = calib_set
        self.max_block_size = max_block_size
        self.n_blocks_for_stat = n_blocks_for_stat
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.num_preserved_experts = num_preserved_experts

    def run(self, modelpool: fb.modelpool.CausalLMPool):
        """
        Args:
            modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
                Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
        """
        self.modelpool = modelpool
        # set random seed
        if self.seed is not None:
            L.seed_everything(self.seed)
        # Resolve the save-path template.  Guard against `None`, which the
        # save step below explicitly supports but an unconditional `.format()`
        # would crash on.
        if self.model_save_path is not None:
            self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)

        with self.profile("load model"):
            model = modelpool.load_pretrained_or_first_model()
            tokenizer = modelpool.load_tokenizer()

        # Load the calibration data
        with self.profile("load calibration data"):
            calib_loader = build_calib_loader(
                self.calib_set,
                tokenizer=tokenizer,
                max_block_size=self.max_block_size,
                n_blocks_for_stat=self.n_blocks_for_stat,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                seed=self.seed,
            )

        with self.profile("prune model"):
            model, info = layerwise_pruning(
                model,
                calib_loader,
                r=self.num_preserved_experts,
            )

        if self.model_save_path is not None:
            with self.profile("save model"):
                modelpool.save_model(
                    model,
                    path=self.model_save_path,
                    tokenizer=tokenizer,
                )
                torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))

        self.print_profile_summary()
        return model
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
R"""
|
|
2
|
+
Example:
|
|
3
|
+
|
|
4
|
+
```bash
|
|
5
|
+
fusion_bench \
|
|
6
|
+
fabric.loggers.name="mixtral_8x7b_expert_pruning/progressive_pruning" \
|
|
7
|
+
method=expert_sparsity/mixtral \
|
|
8
|
+
method._target_=fusion_bench.method.ProgressivePruningForMixtral \
|
|
9
|
+
modelpool=CausalLMPool/mixtral-8x7b
|
|
10
|
+
```
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import logging
|
|
14
|
+
import os
|
|
15
|
+
|
|
16
|
+
import lightning as L
|
|
17
|
+
import torch
|
|
18
|
+
from torch.utils.data import DataLoader
|
|
19
|
+
from tqdm import tqdm
|
|
20
|
+
from transformers import MixtralForCausalLM
|
|
21
|
+
|
|
22
|
+
import fusion_bench as fb
|
|
23
|
+
from fusion_bench.method.expert_sparsity.utils.calibration_data import (
|
|
24
|
+
build_calib_loader,
|
|
25
|
+
)
|
|
26
|
+
from fusion_bench.models.expert_sparsity.mixtral import (
|
|
27
|
+
PrunableMixtralSparseMoeBlockWrapper,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def progressive_pruning(
    model: MixtralForCausalLM,
    calib_loader: DataLoader,
    r: int,
):
    """Progressively prune each Mixtral MoE layer down to ``r`` experts.

    First forwards the calibration set once while every wrapped MoE block
    caches its output activations (``Z``).  Then, for each layer in turn, the
    calibration set is re-forwarded (with earlier layers already pruned) to
    cache that layer's inputs (``X``); the layer enumerates candidate expert
    subsets, is pruned, and is unwrapped before moving on to the next layer.

    Args:
        model: the model to prune. Only ``MixtralForCausalLM`` is supported.
        calib_loader: dataloader over the calibration samples.
        r: number of experts preserved per layer.

    Returns:
        A tuple ``(model, (global_loss_history,))`` with the pruned model and
        the per-layer loss histories produced by the enumeration.
    """
    assert isinstance(
        model, MixtralForCausalLM
    ), "Currently only `Mixtral` is supported"

    for layer in model.model.layers:
        layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
            layer.block_sparse_moe, r=r
        )
        layer.block_sparse_moe.cache_Z = True

    # Initialize the loop-carried names so the `del` cleanup below does not
    # raise NameError when the calibration loader is empty.
    model_inputs = outputs = None
    with torch.inference_mode():
        for batch in tqdm(calib_loader, desc="Computing Z activations on sample set..."):
            model_inputs = model.prepare_inputs_for_generation(**batch)
            outputs = model(**model_inputs)
            assert outputs is not None

    # Free the last forward's tensors before the per-layer passes.
    del model_inputs
    del outputs
    torch.cuda.empty_cache()

    for layer in model.model.layers:
        layer.block_sparse_moe.cache_Z = False

    # Drop
    global_loss_history = dict()

    for l, layer in tqdm(
        list(enumerate(model.model.layers)), desc="Dropping layers..."
    ):
        b = layer.block_sparse_moe

        # Re-forward to cache this layer's inputs, reflecting the pruning
        # already applied to earlier layers.
        b.cache_X = True
        model_inputs = outputs = None
        with torch.inference_mode():
            for batch in calib_loader:
                model_inputs = model.prepare_inputs_for_generation(**batch)
                outputs = model(**model_inputs)
                assert outputs is not None

        del model_inputs
        del outputs
        torch.cuda.empty_cache()
        b.cache_X = False

        loss_history = b.enumerate()
        global_loss_history[l] = loss_history

        b.prune()
        layer.block_sparse_moe = b.model

    # Prune & save
    model.num_experts = r
    model.config.num_local_experts = r

    return model, (global_loss_history,)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class ProgressivePruningForMixtral(
    fb.BaseAlgorithm,
    fb.mixins.LightningFabricMixin,
    fb.mixins.SimpleProfilerMixin,
):
    """Progressive expert pruning for Mixtral models.

    Loads a Mixtral model and a calibration set, prunes each MoE layer to
    ``num_preserved_experts`` experts via :func:`progressive_pruning`, and
    optionally saves the pruned model plus the pruning statistics.
    """

    modelpool: fb.modelpool.CausalLMPool

    def __init__(
        self,
        calib_set: str,
        max_block_size: int,
        n_blocks_for_stat: int,
        batch_size: int,
        num_workers: int,
        num_preserved_experts: int,
        seed: int = 42,
        model_save_path: str = R"{log_dir}/pruned_model",
        **kwargs,
    ):
        """
        Args:
            calib_set: name of the calibration dataset passed to ``build_calib_loader``.
            max_block_size: maximum token-block size of calibration samples.
            n_blocks_for_stat: number of calibration blocks used for statistics.
            batch_size: batch size of the calibration loader.
            num_workers: number of dataloader workers.
            num_preserved_experts: number of experts kept in each MoE layer.
            seed: random seed; ``None`` skips seeding.
            model_save_path: template for the output directory (``{log_dir}``
                is substituted at run time), or ``None`` to skip saving.
        """
        super().__init__(**kwargs)
        self.model_save_path = model_save_path
        self.calib_set = calib_set
        self.max_block_size = max_block_size
        self.n_blocks_for_stat = n_blocks_for_stat
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.num_preserved_experts = num_preserved_experts

    def run(self, modelpool: fb.modelpool.CausalLMPool):
        """
        Args:
            modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
                Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
        """
        self.modelpool = modelpool
        # set random seed
        if self.seed is not None:
            L.seed_everything(self.seed)
        # Resolve the save-path template.  Guard against `None`, which the
        # save step below explicitly supports but an unconditional `.format()`
        # would crash on.
        if self.model_save_path is not None:
            self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)

        with self.profile("load model"):
            model = modelpool.load_pretrained_or_first_model()
            tokenizer = modelpool.load_tokenizer()

        # Load the calibration data
        with self.profile("load calibration data"):
            calib_loader = build_calib_loader(
                self.calib_set,
                tokenizer=tokenizer,
                max_block_size=self.max_block_size,
                n_blocks_for_stat=self.n_blocks_for_stat,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                seed=self.seed,
            )

        with self.profile("prune model"):
            model, info = progressive_pruning(
                model,
                calib_loader,
                r=self.num_preserved_experts,
            )

        if self.model_save_path is not None:
            with self.profile("save model"):
                modelpool.save_model(
                    model,
                    path=self.model_save_path,
                    tokenizer=tokenizer,
                )
                torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))

        self.print_profile_summary()
        return model
|