fusion-bench 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +6 -0
- fusion_bench/constants/banner.py +12 -0
- fusion_bench/method/__init__.py +11 -0
- fusion_bench/method/expert_sparsity/__init__.py +10 -0
- fusion_bench/method/expert_sparsity/mixtral/__init__.py +23 -0
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +175 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +159 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +173 -0
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +153 -0
- fusion_bench/method/knots/__init__.py +0 -0
- fusion_bench/method/knots/knots_utils.py +23 -0
- fusion_bench/method/linear/simple_average_for_llama.py +17 -3
- fusion_bench/method/simple_average.py +10 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +1 -0
- fusion_bench/method/task_singular_vector/utils/task_singular_interference.py +41 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +45 -11
- fusion_bench/models/__init__.py +1 -0
- fusion_bench/models/expert_sparsity/__init__.py +0 -0
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +15 -0
- fusion_bench/models/expert_sparsity/mixtral/dataset.py +40 -0
- fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py +207 -0
- fusion_bench/models/expert_sparsity/mixtral/wrapper.py +268 -0
- fusion_bench/programs/fabric_fusion_program.py +12 -8
- fusion_bench/tasks/clip_classification/imagenet.py +1008 -2004
- fusion_bench/utils/__init__.py +3 -2
- fusion_bench/utils/dtype.py +2 -1
- fusion_bench/utils/fabric.py +11 -4
- fusion_bench/utils/lazy_state_dict.py +155 -13
- fusion_bench/utils/misc.py +19 -1
- fusion_bench/utils/pylogger.py +2 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/RECORD +40 -21
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +2 -0
- fusion_bench_config/method/expert_sparsity/README.md +6 -0
- fusion_bench_config/method/expert_sparsity/mixtral.yaml +17 -0
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +11 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/top_level.txt +0 -0
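Judging by the file list, the headline change in 0.2.17 → 0.2.19 is a new `expert_sparsity` method family for Mixtral-style MoE models (layer-wise pruning, progressive pruning, and dynamic skipping), backed by the model wrappers and calibration-data utilities added under `fusion_bench/models/expert_sparsity/`. Smaller changes include a Windows-compatible fallback for the Hydra output symlink in `fabric_fusion_program.py` and a substantially extended `lazy_state_dict` utility. The two hunks shown below cover the new Mixtral MoE wrapper and the symlink fix.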
fusion_bench/models/expert_sparsity/mixtral/wrapper.py
@@ -0,0 +1,268 @@
+import itertools as I
+import logging
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.mixtral.modeling_mixtral import (
+    MixtralBlockSparseTop2MLP,
+    MixtralForCausalLM,
+    MixtralSparseMoeBlock,
+)
+
+from .dataset import CacheDataset
+
+logger = logging.getLogger(__name__)
+
+
+class PrunableMixtralSparseMoeBlockWrapper(torch.nn.Module):
+    """
+    Wrapper of `MixtralSparseMoeBlock` that supports expert pruning.
+    """
+
+    def __init__(
+        self,
+        model: MixtralSparseMoeBlock,
+        r: Optional[int] = None,
+    ):
+        """
+        Args:
+            model: The model to be wrapped.
+            r: The number of experts to keep.
+        """
+        super().__init__()
+        if isinstance(model, MixtralSparseMoeBlock):
+            self.model = model
+        else:
+            self.model = model.model
+        self.r = r
+
+        self.experts_to_drop = None
+        self.cache_space = CacheDataset()
+        self.cache_logits = False
+        self.cache_X = False
+        self.cache_Z = False
+
+    # Forward uses topk
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.model.gate(hidden_states)
+
+        if self.experts_to_drop is not None:
+            for e in self.experts_to_drop:
+                router_logits[:, e] = -float("inf")
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.model.top_k, dim=-1
+        )
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be sollicitated
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.model.num_experts
+        ).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.model.num_experts):
+            expert_layer = self.model.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            if top_x.shape[0] == 0:
+                continue
+
+            # in torch it is faster to index using lists than torch tensors
+            top_x_list = top_x.tolist()
+            idx_list = idx.tolist()
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+            current_hidden_states = (
+                expert_layer(current_state)
+                * routing_weights[top_x_list, idx_list, None]
+            )
+
+            # However `index_add_` only support torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+
+        if self.experts_to_drop is not None and (
+            self.cache_logits or self.cache_X or self.cache_Z
+        ):
+            logger.warn(
+                f"Already dropped {self.experts_to_drop} but still storing activations."
+            )
+        self.cache_space.append(
+            alpha=(router_logits if self.cache_logits else None),
+            X=(hidden_states if self.cache_X else None),
+            Z=(final_hidden_states if self.cache_Z else None),
+        )
+
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
+
+        return final_hidden_states, router_logits
+
+    @torch.no_grad()
+    def enumerate(self):
+        # disable caching
+        self.cache_logits = False
+        self.cache_X = False
+        self.cache_Z = False
+        loss_history = dict()
+
+        with torch.inference_mode():
+            for dropped in I.combinations(
+                range(self.model.num_experts), self.model.num_experts - self.r
+            ):
+                self.experts_to_drop = dropped
+                loss = 0
+
+                for hidden_states, final_hidden_states in zip(
+                    self.cache_space.Xs, self.cache_space.Zs
+                ):
+                    hidden_states = hidden_states.to(
+                        device=self.model.gate.weight.data.device, non_blocking=True
+                    )
+                    final_hidden_states = final_hidden_states.to(
+                        dtype=torch.float64,
+                        device=self.model.gate.weight.data.device,
+                        non_blocking=True,
+                    )
+                    final_hidden_states_e, _ = self.forward(hidden_states.unsqueeze(0))
+                    # compute the |Z - Z_e|_2 L2 loss
+                    loss += torch.norm(
+                        final_hidden_states
+                        - final_hidden_states_e.squeeze(0).to(torch.float64)
+                    ).item()
+                loss_history[dropped] = loss
+
+        self.experts_to_drop = min(loss_history, key=loss_history.get)
+        return loss_history
+
+    @torch.no_grad()
+    def prune(self):
+        assert self.experts_to_drop is not None
+        assert len(self.experts_to_drop) == self.model.num_experts - self.r
+        del self.cache_space
+        self.cache_X = False
+        self.cache_Z = False
+
+        experts_to_reserve = sorted(
+            set(range(self.model.num_experts)) - set(self.experts_to_drop)
+        )
+
+        # create a new gate with the experts to reserve
+        gate_new = torch.nn.Linear(
+            in_features=self.model.gate.in_features,
+            out_features=self.r,
+            bias=False,
+            device=self.model.gate.weight.data.device,
+            dtype=torch.bfloat16,
+        )
+        gate_new.weight.data = self.model.gate.weight.data[list(experts_to_reserve)]
+        self.model.gate = gate_new
+
+        self.model.experts = torch.nn.ModuleList(
+            [self.model.experts[i] for i in experts_to_reserve]
+        )
+        self.model.num_experts = self.r
+
+
+class DynamicSkippingMixtralSparseMoeBlockWrapper(nn.Module):
+    def __init__(self, model: MixtralSparseMoeBlock, beta: float):
+        super().__init__()
+        assert isinstance(model, MixtralSparseMoeBlock)
+        assert model.top_k == 2
+        self.hidden_dim = model.hidden_dim
+        self.ffn_dim = model.ffn_dim
+        self.num_experts = model.num_experts
+        self.top_k = model.top_k
+        self.gate = model.gate
+        self.experts = model.experts
+
+        self.beta = beta
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
+
+        # (batch * sequence_length)
+        mask_top1 = routing_weights[:, 1] < self.beta * routing_weights[:, 0]
+        routing_weights[mask_top1, 1] = 0
+
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be sollicitated
+        # (batch * sequence_length, self.top_k, n_experts)
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        )
+
+        expert_mask[mask_top1, 1, :] = 0
+        expert_mask = expert_mask.permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            top_x, indices = torch.where(expert_mask[expert_idx])
+
+            if indices.shape[0] == 0:
+                continue
+
+            # in torch it is faster to index using lists than torch tensors
+            indices_list = indices.tolist()
+            top_x_list = top_x.tolist()
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, indices_list].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(
+                current_state, routing_weights[indices_list, top_x_list, None]
+            )
+
+            # However `index_add_` only support torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(
+                0, indices, current_hidden_states.to(hidden_states.dtype)
+            )
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
+        return final_hidden_states, router_logits
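The `PrunableMixtralSparseMoeBlockWrapper` above exposes a three-phase API: cache activations on calibration data, enumerate candidate expert subsets, then prune. A minimal sketch of how a caller might drive it, assuming the module path matches the file list; `prune_moe_block`, `calib_batches`, and the choice of `r` are illustrative assumptions, not code from the package:

import torch

from fusion_bench.models.expert_sparsity.mixtral.wrapper import (
    PrunableMixtralSparseMoeBlockWrapper,
)

# Hypothetical driver: `moe_block` is a MixtralSparseMoeBlock and
# `calib_batches` yields tensors of shape (batch, seq_len, hidden_dim).
def prune_moe_block(moe_block, calib_batches, r: int = 6):
    wrapper = PrunableMixtralSparseMoeBlockWrapper(moe_block, r=r)

    # Phase 1: run calibration data through the unpruned block while the
    # wrapper records inputs (X) and outputs (Z) in its CacheDataset.
    wrapper.cache_X = True
    wrapper.cache_Z = True
    with torch.inference_mode():
        for hidden_states in calib_batches:
            wrapper(hidden_states)

    # Phase 2: enumerate() tries every way of dropping (num_experts - r)
    # experts and keeps the combination minimizing ||Z - Z_e||_2 on the cache.
    loss_history = wrapper.enumerate()

    # Phase 3: prune() rebuilds the gate and expert list with only the
    # r reserved experts and sets num_experts = r.
    wrapper.prune()
    return wrapper.model, loss_history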
fusion_bench/programs/fabric_fusion_program.py
@@ -296,13 +296,17 @@ class FabricModelFusionProgram(
         if hydra_output_dir is not None:
             os.makedirs(self.log_dir, exist_ok=True)
             try:
-                os.symlink(
-                    hydra_output_dir,
-                    os.path.join(
-                        self.log_dir,
-                        "hydra_output_" + os.path.basename(hydra_output_dir),
-                    ),
-                    target_is_directory=True,
-                )
+                # if the system is windows, use the `mklink` command in "CMD" to create the symlink
+                if os.name == "nt":
+                    os.system(f"mklink /J {os.path.abspath(os.path.join(self.log_dir, 'hydra_output_' + os.path.basename(hydra_output_dir)))} {os.path.abspath(hydra_output_dir)}")
+                else:
+                    os.symlink(
+                        hydra_output_dir,
+                        os.path.join(
+                            self.log_dir,
+                            "hydra_output_" + os.path.basename(hydra_output_dir),
+                        ),
+                        target_is_directory=True,
+                    )
             except OSError as e:
                 log.warning(f"Failed to create symbolic link: {e}")