fusion-bench 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/constants/banner.py +12 -0
  3. fusion_bench/method/__init__.py +11 -0
  4. fusion_bench/method/expert_sparsity/__init__.py +10 -0
  5. fusion_bench/method/expert_sparsity/mixtral/__init__.py +23 -0
  6. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +175 -0
  7. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +159 -0
  8. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +173 -0
  9. fusion_bench/method/expert_sparsity/utils/calibration_data.py +153 -0
  10. fusion_bench/method/knots/__init__.py +0 -0
  11. fusion_bench/method/knots/knots_utils.py +23 -0
  12. fusion_bench/method/linear/simple_average_for_llama.py +17 -3
  13. fusion_bench/method/simple_average.py +10 -0
  14. fusion_bench/method/task_singular_vector/utils/__init__.py +1 -0
  15. fusion_bench/method/task_singular_vector/utils/task_singular_interference.py +41 -0
  16. fusion_bench/modelpool/causal_lm/causal_lm.py +45 -11
  17. fusion_bench/models/__init__.py +1 -0
  18. fusion_bench/models/expert_sparsity/__init__.py +0 -0
  19. fusion_bench/models/expert_sparsity/mixtral/__init__.py +15 -0
  20. fusion_bench/models/expert_sparsity/mixtral/dataset.py +40 -0
  21. fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py +207 -0
  22. fusion_bench/models/expert_sparsity/mixtral/wrapper.py +268 -0
  23. fusion_bench/programs/fabric_fusion_program.py +12 -8
  24. fusion_bench/tasks/clip_classification/imagenet.py +1008 -2004
  25. fusion_bench/utils/__init__.py +3 -2
  26. fusion_bench/utils/dtype.py +2 -1
  27. fusion_bench/utils/fabric.py +11 -4
  28. fusion_bench/utils/lazy_state_dict.py +155 -13
  29. fusion_bench/utils/misc.py +19 -1
  30. fusion_bench/utils/pylogger.py +2 -0
  31. {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/METADATA +1 -1
  32. {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/RECORD +40 -21
  33. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +2 -0
  34. fusion_bench_config/method/expert_sparsity/README.md +6 -0
  35. fusion_bench_config/method/expert_sparsity/mixtral.yaml +17 -0
  36. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +11 -0
  37. {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/WHEEL +0 -0
  38. {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/entry_points.txt +0 -0
  39. {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/licenses/LICENSE +0 -0
  40. {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/top_level.txt +0 -0
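Beyond the two hunks reproduced below, the headline change in 0.2.19 is the new `expert_sparsity` method family (layer-wise pruning, progressive pruning, and dynamic skipping for Mixtral-style MoE models), alongside `knots` utilities, a task-singular-interference helper, and an MLflow logger config. Assuming fusion-bench's usual Hydra-style CLI, the new method would be selected with something like `fusion_bench method=expert_sparsity/mixtral modelpool=...`; the override path is inferred from the `fusion_bench_config/method/expert_sparsity/mixtral.yaml` entry above and should be treated as illustrative.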
fusion_bench/models/expert_sparsity/mixtral/wrapper.py
@@ -0,0 +1,268 @@
+import itertools as I
+import logging
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.mixtral.modeling_mixtral import (
+    MixtralBlockSparseTop2MLP,
+    MixtralForCausalLM,
+    MixtralSparseMoeBlock,
+)
+
+from .dataset import CacheDataset
+
+logger = logging.getLogger(__name__)
+
+
+class PrunableMixtralSparseMoeBlockWrapper(torch.nn.Module):
+    """
+    Wrapper of `MixtralSparseMoeBlock` that supports expert pruning.
+    """
+
+    def __init__(
+        self,
+        model: MixtralSparseMoeBlock,
+        r: Optional[int] = None,
+    ):
+        """
+        Args:
+            model: The model to be wrapped.
+            r: The number of experts to keep.
+        """
+        super().__init__()
+        if isinstance(model, MixtralSparseMoeBlock):
+            self.model = model
+        else:
+            self.model = model.model
+        self.r = r
+
+        self.experts_to_drop = None
+        self.cache_space = CacheDataset()
+        self.cache_logits = False
+        self.cache_X = False
+        self.cache_Z = False
+
+    # Forward uses topk
+    def forward(self, hidden_states: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
+        """Route tokens through the top-k experts; returns (hidden_states, router_logits)."""
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.model.gate(hidden_states)
+
+        if self.experts_to_drop is not None:
+            for e in self.experts_to_drop:
+                router_logits[:, e] = -float("inf")
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.model.top_k, dim=-1
+        )
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One-hot encode the selected experts to create an expert mask;
+        # this will be used to easily index which expert is going to be solicited
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.model.num_experts
+        ).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.model.num_experts):
+            expert_layer = self.model.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            if top_x.shape[0] == 0:
+                continue
+
+            # in torch it is faster to index using lists than torch tensors
+            top_x_list = top_x.tolist()
+            idx_list = idx.tolist()
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+            current_hidden_states = (
+                expert_layer(current_state)
+                * routing_weights[top_x_list, idx_list, None]
+            )
+
+            # However `index_add_` only supports torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+
+        if self.experts_to_drop is not None and (
+            self.cache_logits or self.cache_X or self.cache_Z
+        ):
+            logger.warning(
+                f"Already dropped {self.experts_to_drop} but still storing activations."
+            )
+        self.cache_space.append(
+            alpha=(router_logits if self.cache_logits else None),
+            X=(hidden_states if self.cache_X else None),
+            Z=(final_hidden_states if self.cache_Z else None),
+        )
+
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
+
+        return final_hidden_states, router_logits
+
+    @torch.no_grad()
+    def enumerate(self):
+        # disable caching
+        self.cache_logits = False
+        self.cache_X = False
+        self.cache_Z = False
+        loss_history = dict()
+
+        with torch.inference_mode():
+            for dropped in I.combinations(
+                range(self.model.num_experts), self.model.num_experts - self.r
+            ):
+                self.experts_to_drop = dropped
+                loss = 0
+
+                for hidden_states, final_hidden_states in zip(
+                    self.cache_space.Xs, self.cache_space.Zs
+                ):
+                    hidden_states = hidden_states.to(
+                        device=self.model.gate.weight.data.device, non_blocking=True
+                    )
+                    final_hidden_states = final_hidden_states.to(
+                        dtype=torch.float64,
+                        device=self.model.gate.weight.data.device,
+                        non_blocking=True,
+                    )
+                    final_hidden_states_e, _ = self.forward(hidden_states.unsqueeze(0))
+                    # compute the L2 reconstruction loss ||Z - Z_e||_2
+                    loss += torch.norm(
+                        final_hidden_states
+                        - final_hidden_states_e.squeeze(0).to(torch.float64)
+                    ).item()
+                loss_history[dropped] = loss
+
+        self.experts_to_drop = min(loss_history, key=loss_history.get)
+        return loss_history
+
+    @torch.no_grad()
+    def prune(self):
+        assert self.experts_to_drop is not None
+        assert len(self.experts_to_drop) == self.model.num_experts - self.r
+        del self.cache_space
+        self.cache_X = False
+        self.cache_Z = False
+
+        experts_to_reserve = sorted(
+            set(range(self.model.num_experts)) - set(self.experts_to_drop)
+        )
+
+        # create a new gate with the experts to reserve
+        gate_new = torch.nn.Linear(
+            in_features=self.model.gate.in_features,
+            out_features=self.r,
+            bias=False,
+            device=self.model.gate.weight.data.device,
+            dtype=torch.bfloat16,
+        )
+        gate_new.weight.data = self.model.gate.weight.data[list(experts_to_reserve)]
+        self.model.gate = gate_new
+
+        self.model.experts = torch.nn.ModuleList(
+            [self.model.experts[i] for i in experts_to_reserve]
+        )
+        self.model.num_experts = self.r
+
+
+class DynamicSkippingMixtralSparseMoeBlockWrapper(nn.Module):
+    def __init__(self, model: MixtralSparseMoeBlock, beta: float):
+        super().__init__()
+        assert isinstance(model, MixtralSparseMoeBlock)
+        assert model.top_k == 2
+        self.hidden_dim = model.hidden_dim
+        self.ffn_dim = model.ffn_dim
+        self.num_experts = model.num_experts
+        self.top_k = model.top_k
+        self.gate = model.gate
+        self.experts = model.experts
+
+        self.beta = beta
+
+    def forward(self, hidden_states: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
+        """Top-2 routing that skips the second expert when its weight falls below beta times the first."""
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
+
+        # (batch * sequence_length)
+        mask_top1 = routing_weights[:, 1] < self.beta * routing_weights[:, 0]
+        routing_weights[mask_top1, 1] = 0
+
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One-hot encode the selected experts to create an expert mask;
+        # this will be used to easily index which expert is going to be solicited
+        # (batch * sequence_length, self.top_k, n_experts)
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        )
+
+        expert_mask[mask_top1, 1, :] = 0
+        expert_mask = expert_mask.permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            top_x, indices = torch.where(expert_mask[expert_idx])
+
+            if indices.shape[0] == 0:
+                continue
+
+            # in torch it is faster to index using lists than torch tensors
+            indices_list = indices.tolist()
+            top_x_list = top_x.tolist()
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, indices_list].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(
+                current_state, routing_weights[indices_list, top_x_list, None]
+            )
+
+            # However `index_add_` only supports torch tensors for indexing so we'll use
+            # the `indices` tensor here.
+            final_hidden_states.index_add_(
+                0, indices, current_hidden_states.to(hidden_states.dtype)
+            )
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
+        return final_hidden_states, router_logits
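
The wrapper above is the core of the new expert_sparsity method: during calibration, `cache_X`/`cache_Z` record each MoE block's inputs and outputs; `enumerate()` then tries every way to drop `num_experts - r` experts and scores each candidate subset by the summed reconstruction error ||Z - Z_e||_2 over the cached batches; `prune()` finally rebuilds the gate and expert ModuleList around the `r` survivors. A minimal end-to-end sketch follows; the calibration loader and the wiring of the wrappers into the decoder layers are assumptions inferred from this diff, not code shipped in the package.

```python
# Sketch (assumed usage, not the package's documented API): prune every
# MoE layer of a Mixtral checkpoint down to r = 6 experts.
import torch
from transformers import AutoModelForCausalLM

from fusion_bench.models.expert_sparsity.mixtral.wrapper import (
    PrunableMixtralSparseMoeBlockWrapper,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.bfloat16
)

# Wrap each sparse MoE block and enable activation caching.
for layer in model.model.layers:
    layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
        layer.block_sparse_moe, r=6
    )
    layer.block_sparse_moe.cache_X = True
    layer.block_sparse_moe.cache_Z = True

# Fill each wrapper's cache_space with calibration activations.
with torch.inference_mode():
    for batch in calibration_loader:  # assumed: yields tokenized batches
        model(batch["input_ids"])

# Choose and remove the experts to drop, layer by layer.
for layer in model.model.layers:
    moe = layer.block_sparse_moe
    moe.enumerate()  # sets moe.experts_to_drop to the lowest-loss subset
    moe.prune()      # shrinks the gate and expert ModuleList to r experts
```

Note that `enumerate()` is exhaustive: with 8 experts and r = 6 it evaluates C(8, 2) = 28 candidate subsets per layer against the cached calibration data, which is why the activations are cached once rather than recomputed.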
fusion_bench/programs/fabric_fusion_program.py
@@ -296,13 +296,17 @@ class FabricModelFusionProgram(
         if hydra_output_dir is not None:
             os.makedirs(self.log_dir, exist_ok=True)
             try:
-                os.symlink(
-                    hydra_output_dir,
-                    os.path.join(
-                        self.log_dir,
-                        "hydra_output_" + os.path.basename(hydra_output_dir),
-                    ),
-                    target_is_directory=True,
-                )
+                # on Windows, use CMD's `mklink /J` to create a junction instead of a symlink
+                if os.name == "nt":
+                    os.system(f"mklink /J {os.path.abspath(os.path.join(self.log_dir, 'hydra_output_' + os.path.basename(hydra_output_dir)))} {os.path.abspath(hydra_output_dir)}")
+                else:
+                    os.symlink(
+                        hydra_output_dir,
+                        os.path.join(
+                            self.log_dir,
+                            "hydra_output_" + os.path.basename(hydra_output_dir),
+                        ),
+                        target_is_directory=True,
+                    )
             except OSError as e:
                 log.warning(f"Failed to create symbolic link: {e}")