fusion-bench 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff shows the changes between the two publicly released versions of the package as they appear in the public registry.
- fusion_bench/__init__.py +6 -0
- fusion_bench/constants/banner.py +12 -0
- fusion_bench/method/__init__.py +11 -0
- fusion_bench/method/expert_sparsity/__init__.py +10 -0
- fusion_bench/method/expert_sparsity/mixtral/__init__.py +23 -0
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +175 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +159 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +173 -0
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +153 -0
- fusion_bench/method/knots/__init__.py +0 -0
- fusion_bench/method/knots/knots_utils.py +23 -0
- fusion_bench/method/linear/simple_average_for_llama.py +17 -3
- fusion_bench/method/simple_average.py +10 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +1 -0
- fusion_bench/method/task_singular_vector/utils/task_singular_interference.py +41 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +45 -11
- fusion_bench/models/__init__.py +1 -0
- fusion_bench/models/expert_sparsity/__init__.py +0 -0
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +15 -0
- fusion_bench/models/expert_sparsity/mixtral/dataset.py +40 -0
- fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py +207 -0
- fusion_bench/models/expert_sparsity/mixtral/wrapper.py +268 -0
- fusion_bench/programs/fabric_fusion_program.py +12 -8
- fusion_bench/tasks/clip_classification/imagenet.py +1008 -2004
- fusion_bench/utils/__init__.py +3 -2
- fusion_bench/utils/dtype.py +2 -1
- fusion_bench/utils/fabric.py +11 -4
- fusion_bench/utils/lazy_state_dict.py +155 -13
- fusion_bench/utils/misc.py +19 -1
- fusion_bench/utils/pylogger.py +2 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/RECORD +40 -21
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +2 -0
- fusion_bench_config/method/expert_sparsity/README.md +6 -0
- fusion_bench_config/method/expert_sparsity/mixtral.yaml +17 -0
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +11 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py
CHANGED
@@ -1,3 +1,9 @@
+# ███████╗██╗ ██╗███████╗██╗ ██████╗ ███╗ ██╗ ██████╗ ███████╗███╗ ██╗ ██████╗██╗ ██╗
+# ██╔════╝██║ ██║██╔════╝██║██╔═══██╗████╗ ██║ ██╔══██╗██╔════╝████╗ ██║██╔════╝██║ ██║
+# █████╗ ██║ ██║███████╗██║██║ ██║██╔██╗ ██║█████╗██████╔╝█████╗ ██╔██╗ ██║██║ ███████║
+# ██╔══╝ ██║ ██║╚════██║██║██║ ██║██║╚██╗██║╚════╝██╔══██╗██╔══╝ ██║╚██╗██║██║ ██╔══██║
+# ██║ ╚██████╔╝███████║██║╚██████╔╝██║ ╚████║ ██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║
+# ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝
 # flake8: noqa: F401
 from . import (
     constants,
fusion_bench/constants/banner.py
ADDED
@@ -0,0 +1,12 @@
+FUSION_BENCH_BANNER = (
+    ""
+    + "███████╗██╗ ██╗███████╗██╗ ██████╗ ███╗ ██╗ ██████╗ ███████╗███╗ ██╗ ██████╗██╗ ██╗\n"
+    + "██╔════╝██║ ██║██╔════╝██║██╔═══██╗████╗ ██║ ██╔══██╗██╔════╝████╗ ██║██╔════╝██║ ██║\n"
+    + "█████╗ ██║ ██║███████╗██║██║ ██║██╔██╗ ██║█████╗██████╔╝█████╗ ██╔██╗ ██║██║ ███████║\n"
+    + "██╔══╝ ██║ ██║╚════██║██║██║ ██║██║╚██╗██║╚════╝██╔══██╗██╔══╝ ██║╚██╗██║██║ ██╔══██║\n"
+    + "██║ ╚██████╔╝███████║██║╚██████╔╝██║ ╚████║ ██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║\n"
+    + "╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝\n"
+)
+
+if __name__ == "__main__":
+    print(FUSION_BENCH_BANNER)
fusion_bench/method/__init__.py
CHANGED
@@ -111,6 +111,12 @@ _import_structure = {
         "SparseLoForLlama",
         "PCPSparseLoForLlama",
     ],
+    # MoE expert pruning
+    "expert_sparsity": [
+        "DynamicSkippingPruningForMixtral",
+        "LayerWisePruningForMixtral",
+        "ProgressivePruningForMixtral",
+    ],
 }


@@ -142,6 +148,11 @@ if TYPE_CHECKING:
         SimpleEnsembleAlgorithm,
         WeightedEnsembleAlgorithm,
     )
+    from .expert_sparsity import (
+        DynamicSkippingPruningForMixtral,
+        LayerWisePruningForMixtral,
+        ProgressivePruningForMixtral,
+    )
     from .fisher_merging import FisherMergingForCLIPVisionModel
     from .fw_merging import FrankWolfeHardAlgorithm, FrankWolfeSoftAlgorithm
     from .gossip import (
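For readers unfamiliar with this `__init__.py` layout: `_import_structure` maps each submodule to the names it exports, and the classes are imported eagerly only under `TYPE_CHECKING`; at runtime the submodule is loaded on first attribute access. FusionBench's actual resolver is not part of this diff, but a minimal sketch of how such a mapping is typically consumed via a PEP 562 module-level `__getattr__` looks like this (illustrative, not the project's implementation):

```python
import importlib

_import_structure = {
    "expert_sparsity": [
        "DynamicSkippingPruningForMixtral",
        "LayerWisePruningForMixtral",
        "ProgressivePruningForMixtral",
    ],
}

# Invert the mapping: exported name -> defining submodule.
_name_to_module = {
    name: module for module, names in _import_structure.items() for name in names
}


def __getattr__(name: str):
    # Called only when `name` is not found by normal lookup (PEP 562),
    # so the heavy submodule import is deferred until first use.
    if name in _name_to_module:
        module = importlib.import_module(f".{_name_to_module[name]}", __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```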
fusion_bench/method/expert_sparsity/__init__.py
ADDED
@@ -0,0 +1,10 @@
+"""
+Original repo: https://github.com/Lucky-Lance/Expert_Sparsity
+
+Reference:
+    Not All Experts are Equal: Efficient Expert Pruning and Skipping for Mixture-of-Experts Large Language Models.
+    ACL 2024.
+    http://arxiv.org/abs/2402.14800
+"""
+
+from .mixtral import *
fusion_bench/method/expert_sparsity/mixtral/__init__.py
ADDED
@@ -0,0 +1,23 @@
+R"""
+```bash
+fusion_bench \
+    modelpool=CausalLMPool/mixtral-8x7b \
+    ...
+```
+
+If using flash attention 2, pass the following on the command line:
+
+```bash
++modelpool.models._pretrained_.attn_implementation=flash_attention_2
+```
+"""
+
+from .dynamic_skipping import DynamicSkippingPruningForMixtral
+from .layer_wise_pruning import LayerWisePruningForMixtral
+from .progressive_pruning import ProgressivePruningForMixtral
+
+__all__ = [
+    "DynamicSkippingPruningForMixtral",
+    "LayerWisePruningForMixtral",
+    "ProgressivePruningForMixtral",
+]
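A note on the override syntax in that docstring: in Hydra, a leading `+` adds a key that is not present in the composed config. Combining it with one of the pruning invocations documented below, a full command might look like this (illustrative; the method and logger names are taken from the docstrings in this diff):

```bash
fusion_bench \
    fabric.loggers.name="mixtral_8x7b_expert_pruning/layer_wise_pruning" \
    method=expert_sparsity/mixtral \
    method._target_=fusion_bench.method.LayerWisePruningForMixtral \
    modelpool=CausalLMPool/mixtral-8x7b \
    +modelpool.models._pretrained_.attn_implementation=flash_attention_2
```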
fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py
ADDED
@@ -0,0 +1,175 @@
+R"""
+Example:
+
+```bash
+fusion_bench \
+    fabric.loggers.name="mixtral_8x7b_expert_pruning/dynamic_skipping" \
+    method=expert_sparsity/mixtral \
+    method._target_=fusion_bench.method.DynamicSkippingPruningForMixtral \
+    modelpool=CausalLMPool/mixtral-8x7b
+```
+"""
+
+import logging
+import os
+
+import lightning as L
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import MixtralForCausalLM
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
+
+import fusion_bench as fb
+from fusion_bench.method.expert_sparsity.utils.calibration_data import (
+    build_calib_loader,
+)
+from fusion_bench.models.expert_sparsity.mixtral.wrapper import (
+    PrunableMixtralSparseMoeBlockWrapper,
+)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+logger = logging.getLogger(__name__)
+
+
+def dynamic_skipping(
+    model: MixtralForCausalLM,
+    calib_loader: DataLoader,
+    batch_size: int,
+):
+    assert isinstance(
+        model, MixtralForCausalLM
+    ), "Currently only `Mixtral` is supported"
+
+    for l, layer in enumerate(model.model.layers):
+        layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
+            layer.block_sparse_moe
+        )
+        layer.block_sparse_moe.cache_logits = True
+        layer.block_sparse_moe.cache_X = True
+        layer.block_sparse_moe.cache_Z = True
+
+    with torch.inference_mode():
+        for i, batch in enumerate(
+            tqdm(calib_loader, desc="Model forwarding on sample set...")
+        ):
+            model_inputs = model.prepare_inputs_for_generation(**batch)
+            outputs = model(**model_inputs)
+            assert outputs is not None
+
+    res_median = {}
+    res_mean = {}
+
+    for layer_idx in range(len(model.model.layers)):
+        b = model.model.layers[layer_idx].block_sparse_moe
+        b.cache_space.prepare_for_loader()
+        dataloader = torch.utils.data.DataLoader(
+            b.cache_space,
+            batch_size=batch_size,
+            shuffle=True,
+        )
+        logger.info(len(dataloader))
+
+        ana_list = []
+        for i, (router_logits, X, Z) in enumerate(dataloader):
+            routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float).view(
+                -1, b.model.num_experts
+            )
+            for j in range(len(routing_weights)):
+                sorted_weights, sort_indices = torch.sort(
+                    routing_weights[j], descending=True
+                )
+                ana_list.append(float(sorted_weights[1] / sorted_weights[0]))
+
+        median = np.median(ana_list)
+        mean = np.mean(ana_list)
+        logger.info(f"layer {layer_idx} | mean: {mean}, median: {median}")
+        res_median[str(layer_idx)] = median
+        res_mean[str(layer_idx)] = mean
+
+    for l, layer in enumerate(model.model.layers):
+        layer.block_sparse_moe = layer.block_sparse_moe.model
+
+    model.config.betas = res_median
+    return model, (res_median, res_mean)
+
+
+class DynamicSkippingPruningForMixtral(
+    fb.BaseAlgorithm,
+    fb.mixins.LightningFabricMixin,
+    fb.mixins.SimpleProfilerMixin,
+):
+    modelpool: fb.modelpool.CausalLMPool
+
+    def __init__(
+        self,
+        calib_set: str,
+        max_block_size: int,
+        n_blocks_for_stat: int,
+        batch_size: int,
+        num_workers: int,
+        num_preserved_experts: int,
+        seed: int = 42,
+        model_save_path: str = R"{log_dir}/pruned_model",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.model_save_path = model_save_path
+        self.calib_set = calib_set
+        self.max_block_size = max_block_size
+        self.n_blocks_for_stat = n_blocks_for_stat
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.seed = seed
+        self.num_preserved_experts = num_preserved_experts
+
+    def run(self, modelpool: fb.modelpool.CausalLMPool):
+        """
+        Args:
+            modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
+                Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
+        """
+        self.modelpool = modelpool
+        # set random seed
+        if self.seed is not None:
+            L.seed_everything(self.seed)
+        # parse model_save_path
+        self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)
+
+        with self.profile("load model"):
+            model = modelpool.load_pretrained_or_first_model()
+            tokenizer = modelpool.load_tokenizer()
+
+        # Load the calibration data
+        with self.profile("load calibration data"):
+            calib_loader = build_calib_loader(
+                self.calib_set,
+                tokenizer=tokenizer,
+                max_block_size=self.max_block_size,
+                n_blocks_for_stat=self.n_blocks_for_stat,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                seed=self.seed,
+            )
+
+        with self.profile("prune model"):
+            model, info = dynamic_skipping(
+                model,
+                calib_loader,
+                batch_size=self.batch_size,
+            )
+
+        if self.model_save_path is not None:
+            with self.profile("save model"):
+                modelpool.save_model(
+                    model,
+                    path=self.model_save_path,
+                    tokenizer=tokenizer,
+                )
+            torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))
+
+        self.print_profile_summary()
+        return model
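For context, `dynamic_skipping` itself only gathers statistics: for every token it records the ratio of the second-largest to the largest routing weight, then stores each layer's median ratio in `model.config.betas`. Under the paper's dynamic skipping scheme, inference drops the second expert whenever that ratio falls below the layer's beta. A minimal sketch of that decision rule (illustrative; the actual inference-time logic lives in the wrapper/modeling code added elsewhere in this release, not in this file):

```python
import torch
import torch.nn.functional as F


def top2_route_with_skipping(router_logits: torch.Tensor, beta: float):
    """Top-2 routing that skips the second expert for tokens where
    w2 / w1 < beta (beta being a per-layer median from calibration).

    router_logits: (num_tokens, num_experts)
    Returns per-token expert indices and renormalized routing weights.
    """
    weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    top_w, top_idx = torch.topk(weights, k=2, dim=-1)
    # Keep the second expert only when its weight is close enough to the top one.
    keep_second = (top_w[:, 1] / top_w[:, 0]) >= beta
    top_w = torch.where(
        keep_second.unsqueeze(-1),
        top_w,
        torch.stack([top_w[:, 0], torch.zeros_like(top_w[:, 1])], dim=-1),
    )
    top_w = top_w / top_w.sum(dim=-1, keepdim=True)  # renormalize kept weights
    return top_idx, top_w
```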
fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py
ADDED
@@ -0,0 +1,159 @@
+R"""
+Example:
+
+```bash
+fusion_bench \
+    fabric.loggers.name="mixtral_8x7b_expert_pruning/layer_wise_pruning" \
+    method=expert_sparsity/mixtral \
+    method._target_=fusion_bench.method.LayerWisePruningForMixtral \
+    modelpool=CausalLMPool/mixtral-8x7b
+```
+"""
+
+import logging
+import os
+from typing import cast
+
+import lightning as L
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import MixtralForCausalLM
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
+
+import fusion_bench as fb
+from fusion_bench.method.expert_sparsity.utils.calibration_data import (
+    build_calib_loader,
+)
+from fusion_bench.models.expert_sparsity.mixtral import (
+    PrunableMixtralSparseMoeBlockWrapper,
+)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+logger = logging.getLogger(__name__)
+
+
+def layerwise_pruning(
+    model: MixtralForCausalLM,
+    calib_loader: DataLoader,
+    r: int,
+):
+    assert isinstance(
+        model, MixtralForCausalLM
+    ), "Currently only `Mixtral` is supported"
+
+    for l, layer in enumerate(model.model.layers):
+        layer = cast(MixtralDecoderLayer, layer)
+        layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
+            layer.block_sparse_moe, r=r
+        )
+        layer.block_sparse_moe.cache_X = True
+        layer.block_sparse_moe.cache_Z = True
+
+    with torch.inference_mode():
+        for i, batch in enumerate(
+            tqdm(calib_loader, desc="Model forwarding on sample set...")
+        ):
+            model_inputs = model.prepare_inputs_for_generation(**batch)
+            outputs = model(**model_inputs)
+            assert outputs is not None
+
+    global_loss_history = dict()
+    for l, layer in tqdm(
+        list(enumerate(model.model.layers)), desc="Enumerating loss on sample set..."
+    ):
+        layer = cast(MixtralDecoderLayer, layer)
+        b: PrunableMixtralSparseMoeBlockWrapper = layer.block_sparse_moe
+        if not hasattr(b, "cache_space"):
+            continue
+        loss_history = b.enumerate()
+        global_loss_history[l] = loss_history
+        b.prune()
+
+    logger.info("Merging & saving...")
+    for l, layer in enumerate(model.model.layers):
+        layer.block_sparse_moe = layer.block_sparse_moe.model
+
+    model.num_experts = r
+    model.config.num_local_experts = r
+
+    return model, (global_loss_history,)
+
+
+class LayerWisePruningForMixtral(
+    fb.BaseAlgorithm,
+    fb.mixins.LightningFabricMixin,
+    fb.mixins.SimpleProfilerMixin,
+):
+    modelpool: fb.modelpool.CausalLMPool
+
+    def __init__(
+        self,
+        calib_set: str,
+        max_block_size: int,
+        n_blocks_for_stat: int,
+        batch_size: int,
+        num_workers: int,
+        num_preserved_experts: int,
+        seed: int = 42,
+        model_save_path: str = R"{log_dir}/pruned_model",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.model_save_path = model_save_path
+        self.calib_set = calib_set
+        self.max_block_size = max_block_size
+        self.n_blocks_for_stat = n_blocks_for_stat
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.seed = seed
+        self.num_preserved_experts = num_preserved_experts
+
+    def run(self, modelpool: fb.modelpool.CausalLMPool):
+        """
+        Args:
+            modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
+                Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
+        """
+        self.modelpool = modelpool
+        # set random seed
+        if self.seed is not None:
+            L.seed_everything(self.seed)
+        # parse model_save_path
+        self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)
+
+        with self.profile("load model"):
+            model = modelpool.load_pretrained_or_first_model()
+            tokenizer = modelpool.load_tokenizer()
+
+        # Load the calibration data
+        with self.profile("load calibration data"):
+            calib_loader = build_calib_loader(
+                self.calib_set,
+                tokenizer=tokenizer,
+                max_block_size=self.max_block_size,
+                n_blocks_for_stat=self.n_blocks_for_stat,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                seed=self.seed,
+            )
+
+        with self.profile("prune model"):
+            model, info = layerwise_pruning(
+                model,
+                calib_loader,
+                r=self.num_preserved_experts,
+            )
+
+        if self.model_save_path is not None:
+            with self.profile("save model"):
+                modelpool.save_model(
+                    model,
+                    path=self.model_save_path,
+                    tokenizer=tokenizer,
+                )
+            torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))
+
+        self.print_profile_summary()
+        return model
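`b.enumerate()` and `b.prune()` are provided by `PrunableMixtralSparseMoeBlockWrapper` (added under `fusion_bench/models/expert_sparsity/mixtral/wrapper.py` in this release). Conceptually, layer-wise pruning scores every `r`-sized subset of experts by how well it reconstructs the block's cached outputs `Z` from its cached inputs `X`, then keeps the best subset. A minimal sketch of that enumeration, assuming a hypothetical `run_moe_subset(X, keep)` helper that re-evaluates the block with routing restricted to the experts in `keep`:

```python
from itertools import combinations

import torch


def enumerate_expert_subsets(X, Z, run_moe_subset, num_experts=8, r=4):
    """Score every r-of-num_experts expert subset by reconstruction error.

    X, Z: inputs/outputs cached during the calibration forward pass.
    run_moe_subset: hypothetical helper, (X, keep) -> MoE output using
        only the experts listed in `keep`.
    """
    loss_history = {}
    for keep in combinations(range(num_experts), r):
        Z_hat = run_moe_subset(X, keep)
        loss_history[keep] = torch.norm(Z - Z_hat, p=2).item()
    # The subset with the smallest loss is the one the wrapper keeps.
    best_subset = min(loss_history, key=loss_history.get)
    return best_subset, loss_history
```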
fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py
ADDED
@@ -0,0 +1,173 @@
+R"""
+Example:
+
+```bash
+fusion_bench \
+    fabric.loggers.name="mixtral_8x7b_expert_pruning/progressive_pruning" \
+    method=expert_sparsity/mixtral \
+    method._target_=fusion_bench.method.ProgressivePruningForMixtral \
+    modelpool=CausalLMPool/mixtral-8x7b
+```
+"""
+
+import logging
+import os
+
+import lightning as L
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import MixtralForCausalLM
+
+import fusion_bench as fb
+from fusion_bench.method.expert_sparsity.utils.calibration_data import (
+    build_calib_loader,
+)
+from fusion_bench.models.expert_sparsity.mixtral import (
+    PrunableMixtralSparseMoeBlockWrapper,
+)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+logger = logging.getLogger(__name__)
+
+
+def progressive_pruning(
+    model: MixtralForCausalLM,
+    calib_loader: DataLoader,
+    r: int,
+):
+    assert isinstance(
+        model, MixtralForCausalLM
+    ), "Currently only `Mixtral` is supported"
+
+    for l, layer in enumerate(model.model.layers):
+        layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
+            layer.block_sparse_moe, r=r
+        )
+        layer.block_sparse_moe.cache_Z = True
+
+    with torch.inference_mode():
+        for i, batch in enumerate(
+            tqdm(calib_loader, desc="Computing Z activations on sample set...")
+        ):
+            model_inputs = model.prepare_inputs_for_generation(**batch)
+            outputs = model(**model_inputs)
+            assert outputs is not None
+
+            del model_inputs
+            del outputs
+            torch.cuda.empty_cache()
+
+    for l, layer in enumerate(model.model.layers):
+        layer.block_sparse_moe.cache_Z = False
+
+    # Drop
+    global_loss_history = dict()
+
+    for l, layer in tqdm(
+        list(enumerate(model.model.layers)), desc="Dropping layers..."
+    ):
+        b = layer.block_sparse_moe
+
+        b.cache_X = True
+        with torch.inference_mode():
+            for i, batch in enumerate(calib_loader):
+                model_inputs = model.prepare_inputs_for_generation(**batch)
+                outputs = model(**model_inputs)
+                assert outputs is not None
+
+                del model_inputs
+                del outputs
+                torch.cuda.empty_cache()
+        b.cache_X = False
+
+        loss_history = b.enumerate()
+        global_loss_history[l] = loss_history
+
+        b.prune()
+        layer.block_sparse_moe = b.model
+
+    # Prune & save
+    model.num_experts = r
+    model.config.num_local_experts = r
+
+    return model, (global_loss_history,)
+
+
+class ProgressivePruningForMixtral(
+    fb.BaseAlgorithm,
+    fb.mixins.LightningFabricMixin,
+    fb.mixins.SimpleProfilerMixin,
+):
+    modelpool: fb.modelpool.CausalLMPool
+
+    def __init__(
+        self,
+        calib_set: str,
+        max_block_size: int,
+        n_blocks_for_stat: int,
+        batch_size: int,
+        num_workers: int,
+        num_preserved_experts: int,
+        seed: int = 42,
+        model_save_path: str = R"{log_dir}/pruned_model",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.model_save_path = model_save_path
+        self.calib_set = calib_set
+        self.max_block_size = max_block_size
+        self.n_blocks_for_stat = n_blocks_for_stat
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.seed = seed
+        self.num_preserved_experts = num_preserved_experts
+
+    def run(self, modelpool: fb.modelpool.CausalLMPool):
+        """
+        Args:
+            modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
+                Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
+        """
+        self.modelpool = modelpool
+        # set random seed
+        if self.seed is not None:
+            L.seed_everything(self.seed)
+        # parse model_save_path
+        self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)
+
+        with self.profile("load model"):
+            model = modelpool.load_pretrained_or_first_model()
+            tokenizer = modelpool.load_tokenizer()
+
+        # Load the calibration data
+        with self.profile("load calibration data"):
+            calib_loader = build_calib_loader(
+                self.calib_set,
+                tokenizer=tokenizer,
+                max_block_size=self.max_block_size,
+                n_blocks_for_stat=self.n_blocks_for_stat,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                seed=self.seed,
+            )
+
+        with self.profile("prune model"):
+            model, info = progressive_pruning(
+                model,
+                calib_loader,
+                r=self.num_preserved_experts,
+            )
+
+        if self.model_save_path is not None:
+            with self.profile("save model"):
+                modelpool.save_model(
+                    model,
+                    path=self.model_save_path,
+                    tokenizer=tokenizer,
+                )
+            torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))
+
+        self.print_profile_summary()
+        return model
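All three algorithms save the resulting checkpoint via `modelpool.save_model` (default path `{log_dir}/pruned_model`) plus a `pruning_info.pt` with the collected statistics. The two pruning algorithms update `config.num_local_experts` to `r` before saving, so the checkpoint should load as a regular, smaller Mixtral; dynamic skipping instead records its per-layer thresholds in `config.betas`. A hedged sketch of loading the artifacts back, with a placeholder output directory:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder path: by default the model lands in "{log_dir}/pruned_model".
save_dir = "outputs/mixtral_8x7b_expert_pruning/pruned_model"

model = AutoModelForCausalLM.from_pretrained(save_dir)
tokenizer = AutoTokenizer.from_pretrained(save_dir)

# The per-layer statistics saved next to the logs (path also a placeholder).
info = torch.load("outputs/mixtral_8x7b_expert_pruning/pruning_info.pt")
```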