fusion-bench 0.2.15__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. fusion_bench/method/__init__.py +4 -0
  2. fusion_bench/method/fw_merging/__init__.py +2 -0
  3. fusion_bench/method/fw_merging/fw_hard.py +448 -0
  4. fusion_bench/method/fw_merging/fw_soft.py +519 -0
  5. fusion_bench/method/fw_merging/utils.py +331 -0
  6. fusion_bench/method/moe_pruner/__init__.py +7 -0
  7. fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
  8. fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
  9. fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
  10. fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
  11. fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
  12. fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
  13. fusion_bench/method/moe_pruner/utils/data.py +154 -0
  14. fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
  15. fusion_bench/method/moe_pruner/utils/prune.py +313 -0
  16. fusion_bench/method/moe_pruner/utils/score.py +41 -0
  17. fusion_bench/method/pruning/__init__.py +1 -0
  18. fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
  19. fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
  20. fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
  21. fusion_bench/method/pruning/wanda_utils/data.py +33 -14
  22. fusion_bench/method/randes/__init__.py +15 -0
  23. fusion_bench/method/randes/base_algorithm.py +1013 -0
  24. fusion_bench/method/randes/modelsoup.py +126 -0
  25. fusion_bench/method/randes/task_arithmetic.py +318 -0
  26. fusion_bench/method/sparselo/sparselo.py +20 -2
  27. fusion_bench/method/tall_mask/__init__.py +1 -0
  28. fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
  29. fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
  30. fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
  31. fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
  32. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
  33. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
  34. fusion_bench/programs/fabric_fusion_program.py +5 -0
  35. fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
  36. fusion_bench/utils/__init__.py +1 -0
  37. fusion_bench/utils/data.py +1 -1
  38. fusion_bench/utils/lazy_state_dict.py +268 -0
  39. fusion_bench/utils/parameters.py +33 -0
  40. fusion_bench/utils/state_dict_arithmetic.py +74 -2
  41. fusion_bench/utils/type.py +1 -0
  42. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +6 -2
  43. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +77 -21
  44. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
  45. fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
  46. fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
  47. fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
  48. fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
  49. fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
  50. fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
  51. fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
  52. fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
  53. fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
  54. fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
  55. fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
  56. fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
  57. fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
  58. fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
  59. fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
  60. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
  61. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  62. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  63. fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
  64. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
  65. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
  66. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
  67. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
  68. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
  69. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
  70. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
  71. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
  72. fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
  73. fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
  74. fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
  75. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
  76. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
  77. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
fusion_bench/method/fw_merging/utils.py
@@ -0,0 +1,331 @@
+ """
+ This is modified based on https://github.com/EnnengYang/AdaMerging/blob/main/src/ties_merging_utils.py
+ """
+
+ import copy
+ from collections import OrderedDict
+ from typing import List
+
+ import torch
+ from torch import Tensor, nn
+
+ from fusion_bench.utils.type import StateDictType
+
+
+ # Model conversion utils
+ def state_dict_to_vector(state_dict, remove_keys=[]):
+     """
+     Convert a state dictionary to a vector, removing specified keys.
+
+     Args:
+         state_dict (dict): The state dictionary to convert.
+         remove_keys (list): List of keys to remove from the state dictionary.
+
+     Returns:
+         Tensor: A vector representation of the state dictionary.
+     """
+     shared_state_dict = copy.deepcopy(state_dict)
+     for key in remove_keys:
+         if key in shared_state_dict:
+             del shared_state_dict[key]
+     sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
+     return nn.utils.parameters_to_vector(
+         [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
+     )
+
+
+ def vector_to_state_dict(vector, state_dict, remove_keys=[]):
+     """
+     Convert a vector back to a state dictionary, removing specified keys.
+
+     Args:
+         vector (Tensor): The vector to convert.
+         state_dict (dict): The reference state dictionary.
+         remove_keys (list): List of keys to remove from the state dictionary.
+
+     Returns:
+         dict: A state dictionary representation of the vector.
+     """
+     # create a reference dict to define the order of the vector
+     reference_dict = copy.deepcopy(state_dict)
+     for key in remove_keys:
+         if key in reference_dict:
+             del reference_dict[key]
+     sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))
+
+     # create a shared state dict using the reference dict
+     nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())
+
+     # add back the encoder and decoder embedding weights.
+     if "transformer.shared.weight" in sorted_reference_dict:
+         for key in remove_keys:
+             sorted_reference_dict[key] = sorted_reference_dict[
+                 "transformer.shared.weight"
+             ]
+     return sorted_reference_dict
+
+
+ def add_ptm_to_tv(tv_dict, ptm_dict):
+     """
+     Add the values of one state dictionary to another.
+
+     Args:
+         tv_dict (dict): The target state dictionary.
+         ptm_dict (dict): The state dictionary to add.
+
+     Returns:
+         dict: The resulting state dictionary after addition.
+     """
+     assert set(tv_dict.keys()) == set(
+         ptm_dict.keys()
+     ), "Differing parameter names in models."
+     final_dict = copy.deepcopy(tv_dict)
+     for k, v in ptm_dict.items():
+         final_dict[k] = tv_dict[k] + v
+     return final_dict
+
+
+ def check_parameterNamesMatch(checkpoints: List[StateDictType]) -> None:
+     """
+     Check if the parameter names match across multiple checkpoints.
+
+     Args:
+         checkpoints (list): List of state dictionaries to check.
+
+     Raises:
+         ValueError: If the parameter names do not match.
+     """
+     parameter_names = set(checkpoints[0].keys())
+
+     if len(checkpoints) >= 2:
+         # raise ValueError("Number of models is less than 2.")
+         for checkpoint in checkpoints[1:]:
+             current_parameterNames = set(checkpoint.keys())
+             if current_parameterNames != parameter_names:
+                 raise ValueError(
+                     "Differing parameter names in models. "
+                     f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
+                 )
+
+
+ def check_state_dicts_equal(
+     state_dict1: StateDictType, state_dict2: StateDictType
+ ) -> bool:
+     """
+     Check if two state dictionaries are equal.
+
+     Args:
+         state_dict1 (dict): The first state dictionary.
+         state_dict2 (dict): The second state dictionary.
+
+     Returns:
+         bool: True if the state dictionaries are equal, False otherwise.
+     """
+     if set(state_dict1.keys()) != set(state_dict2.keys()):
+         return False
+
+     for key in state_dict1.keys():
+         if not torch.equal(state_dict1[key], state_dict2[key]):
+             return False
+
+     return True
+
+
+ # TIES MERGING UTILS
+
+
+ def topk_values_mask(M, K=0.7, return_mask=False):
+     """
+     Mask the top K values in a tensor.
+
+     Args:
+         M (Tensor): The input tensor.
+         K (float): The proportion of top values to keep.
+         return_mask (bool): Whether to return the mask tensor.
+
+     Returns:
+         tuple: The masked tensor, the mean of the mask, and optionally the mask tensor.
+     """
+     if K > 1:
+         K /= 100
+
+     original_shape = M.shape
+     if M.dim() == 1:
+         M = M.unsqueeze(0)
+
+     n, d = M.shape
+     k = int(d * K)
+     k = d - k  # Keep top k elements instead of bottom k elements
+
+     # Find the k-th smallest element by magnitude for each row
+     kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
+     # Create a mask tensor with True for the top k elements in each row
+     mask = M.abs() >= kth_values
+     final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask
+
+     if return_mask:
+         return M * final_mask, final_mask.float().mean(dim=1), final_mask
+     return M * final_mask, final_mask.float().mean(dim=1)
+
+
+ def resolve_zero_signs(sign_to_mult, method="majority"):
+     """
+     Resolve zero signs in a tensor by majority or minority rule.
+
+     Args:
+         sign_to_mult (Tensor): The tensor with signs to resolve.
+         method (str): The method to use for resolving zero signs ("majority" or "minority").
+
+     Returns:
+         Tensor: The tensor with resolved signs.
+     """
+     majority_sign = torch.sign(sign_to_mult.sum())
+
+     if method == "majority":
+         sign_to_mult[sign_to_mult == 0] = majority_sign
+     elif method == "minority":
+         sign_to_mult[sign_to_mult == 0] = -1 * majority_sign
+     return sign_to_mult
+
+
+ def resolve_sign(v: Tensor):
+     """
+     Resolve the sign of a tensor by majority rule.
+
+     Args:
+         v (Tensor): The input tensor.
+
+     Returns:
+         Tensor: The tensor with resolved signs.
+     """
+     sign_to_mult = torch.sign(v.sum(dim=0))
+     sign_to_mult = resolve_zero_signs(sign_to_mult, "majority")
+     return sign_to_mult
+
+
+ def disjoint_merge(v: Tensor, merge_func: str, sign_to_mult):
+     """
+     Perform disjoint merging of a tensor using a specified merge function.
+
+     Args:
+         v (Tensor): The input tensor.
+         merge_func (str): The merge function to use ("mean", "sum", or "max").
+         sign_to_mult (Tensor): The tensor with signs to use for merging.
+
+     Returns:
+         Tensor: The merged tensor.
+     """
+     merge_func = merge_func.split("-")[-1]
+
+     # If sign is provided then we select the corresponding entries and aggregate.
+     if sign_to_mult is not None:
+         rows_to_keep = torch.where(sign_to_mult.unsqueeze(0) > 0, v > 0, v < 0)
+         selected_entries = v * rows_to_keep
+     # Else we select all non-zero entries and aggregate.
+     else:
+         rows_to_keep = v != 0
+         selected_entries = v * rows_to_keep
+
+     if merge_func == "mean":
+         non_zero_counts = (selected_entries != 0).sum(dim=0).float()
+         disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(
+             non_zero_counts, min=1
+         )
+     elif merge_func == "sum":
+         disjoint_aggs = torch.sum(selected_entries, dim=0)
+     elif merge_func == "max":
+         disjoint_aggs = selected_entries.abs().max(dim=0)[0]
+         disjoint_aggs *= sign_to_mult
+     else:
+         raise ValueError(f"Merge method {merge_func} is not defined.")
+
+     return disjoint_aggs
+
+
+ def ties_merging(
+     flat_task_checks,
+     reset_thresh=None,
+     merge_func="",
+ ):
+     """
+     Perform TIES merging on a tensor.
+
+     Args:
+         flat_task_checks (Tensor): The input tensor.
+         reset_thresh (float): The threshold for resetting values.
+         merge_func (str): The merge function to use.
+
+     Returns:
+         Tensor: The merged tensor.
+     """
+     all_checks = flat_task_checks.clone()
+     updated_checks, *_ = topk_values_mask(all_checks, K=reset_thresh, return_mask=False)
+     print("RESOLVING SIGN")
+     final_signs = resolve_sign(updated_checks)
+     assert final_signs is not None
+
+     print(f"Disjoint AGGREGATION: {merge_func}")
+     merged_tv = disjoint_merge(updated_checks, merge_func, final_signs)
+
+     return merged_tv
+
+
+ def disjoint_merge_split(v: Tensor, merge_func: str, sign_to_mult):
+     """
+     Perform disjoint merging of a tensor using a specified merge function and return selected entries.
+
+     Args:
+         v (Tensor): The input tensor.
+         merge_func (str): The merge function to use ("sum").
+         sign_to_mult (Tensor): The tensor with signs to use for merging.
+
+     Returns:
+         tuple: The selected entries and the merged tensor.
+     """
+     merge_func = merge_func.split("-")[-1]
+
+     # If sign is provided then we select the corresponding entries and aggregate.
+     if sign_to_mult is not None:
+         rows_to_keep = torch.where(sign_to_mult.unsqueeze(0) > 0, v > 0, v < 0)
+         selected_entries = v * rows_to_keep
+     # Else we select all non-zero entries and aggregate.
+     else:
+         rows_to_keep = v != 0
+         selected_entries = v * rows_to_keep
+
+     if merge_func == "sum":
+         disjoint_aggs = torch.sum(selected_entries, dim=0)
+     else:
+         raise ValueError(f"Merge method {merge_func} is not defined.")
+
+     return selected_entries, disjoint_aggs
+
+
+ def ties_merging_split(
+     flat_task_checks,
+     reset_thresh=None,
+     merge_func: str = "",
+ ):
+     """
+     Perform TIES merging on a tensor and return selected entries.
+
+     Args:
+         flat_task_checks (Tensor): The input tensor.
+         reset_thresh (float): The threshold for resetting values.
+         merge_func (str): The merge function to use.
+
+     Returns:
+         tuple: The selected entries and the merged tensor.
+     """
+     all_checks = flat_task_checks.clone()
+     updated_checks, *_ = topk_values_mask(all_checks, K=reset_thresh, return_mask=False)
+     print("RESOLVING SIGN")
+     final_signs = resolve_sign(updated_checks)
+     assert final_signs is not None
+
+     print(f"Disjoint AGGREGATION: {merge_func}")
+     selected_entries, merged_tv = disjoint_merge_split(
+         updated_checks, merge_func, final_signs
+     )
+
+     return selected_entries, merged_tv
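Taken together, these helpers implement the TIES-Merging pipeline: flatten each task vector into a row of a matrix, trim every row to its largest-magnitude entries, elect a per-coordinate sign by majority, aggregate only the entries agreeing with that sign, and unflatten the result. The end-to-end sketch below is illustrative rather than package code: the toy `nn.Linear` checkpoints, the 20% trim ratio, and the `"dis-mean"` merge function are assumptions, while the helper calls are exactly the ones added in this file.

```python
import torch
from torch import nn

from fusion_bench.method.fw_merging.utils import (
    state_dict_to_vector,
    ties_merging,
    vector_to_state_dict,
)

# Toy stand-ins for a pretrained checkpoint and three fine-tuned checkpoints.
base = nn.Linear(4, 4)
finetuned = [nn.Linear(4, 4) for _ in range(3)]

# Flatten task vectors (fine-tuned minus pretrained), one per row.
base_vec = state_dict_to_vector(base.state_dict())
task_vecs = torch.stack(
    [state_dict_to_vector(m.state_dict()) - base_vec for m in finetuned]
)

# Trim to the top 20% of entries by magnitude, elect signs per coordinate,
# and average the entries whose sign agrees with the elected one.
merged_tv = ties_merging(task_vecs, reset_thresh=0.2, merge_func="dis-mean")

# Add the merged task vector back onto the base and unflatten.
merged_state_dict = vector_to_state_dict(base_vec + merged_tv, base.state_dict())
```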
fusion_bench/method/moe_pruner/__init__.py
@@ -0,0 +1,7 @@
+ """
+ Implementation of MoE-Pruner
+
+ MoE-Pruner: Pruning Mixture-of-Experts Large Language Model using the Hints from Its Router
+ """
+
+ from .moe_pruner import MoEPruner
fusion_bench/method/moe_pruner/hooks/__init__.py
@@ -0,0 +1,6 @@
+ from .hook import BaseHookFn
+ from .deepseek_v2 import (
+     MoEPrunerHookFnForDeepseekV2Gate,
+     MoEPrunerHookFnForDeepseekV2Linear,
+ )
+
fusion_bench/method/moe_pruner/hooks/deepseek_v2.py
@@ -0,0 +1,85 @@
+ from typing import Dict, Tuple
+
+ import torch
+ from torch import Tensor, nn
+
+ from fusion_bench.models.modeling_deepseek_v2 import DeepseekV2MoEGate
+
+ from .hook import BaseHookFn
+
+
+ class MoEPrunerHookFnForDeepseekV2Linear(BaseHookFn):
+     _routing_weights = None  # set by gate hook
+
+     def __init__(self, linear: nn.Linear, name: str):
+         super().__init__(linear)
+         self.linear = linear
+         self.scalar_row = torch.zeros(
+             (linear.weight.size(1),), device=linear.weight.device
+         )
+         self.nsamples = 0
+         self.name = name
+
+     def __call__(self, linear, inps: Tuple[Tensor], out: Tensor):
+         assert len(inps) == 1
+         inp = inps[0]
+         if len(inp.shape) == 2:
+             inp = inp.unsqueeze(0)
+
+         batch_size = inp.shape[0]
+         if len(inp.shape) == 3:
+             inp = inp.reshape((-1, inp.shape[-1]))
+         # (NxL, C) -> (C, NxL)
+         inp = inp.t()
+
+         self.scalar_row *= self.nsamples / (self.nsamples + batch_size)
+         self.nsamples += batch_size
+
+         inp = inp.type(torch.float32)
+         routing_weights = self._routing_weights.t()
+         self.scalar_row += (
+             torch.norm(inp * routing_weights, p=2, dim=1) ** 2 / self.nsamples
+         )
+
+     def compute(self):
+         return torch.abs(self.linear.weight) * torch.sqrt(
+             self.scalar_row.reshape(1, -1)
+         )
+
+
+ class MoEPrunerHookFnForDeepseekV2Gate(BaseHookFn):
+     def __init__(
+         self,
+         router: DeepseekV2MoEGate,
+         linear_layer_hooks: Dict[str, MoEPrunerHookFnForDeepseekV2Linear],
+         top_k: int,
+         num_experts: int,
+     ):
+         super().__init__(router)
+         self.router = router
+         self.linear_layer_hooks = linear_layer_hooks
+         self.top_k = top_k
+         self.num_experts = num_experts
+
+     def __call__(self, router, inps: Tuple[Tensor], out: Tuple[Tensor, Tensor, Tensor]):
+         assert len(inps) == 1
+
+         x = inps[0]
+         x = x.view(-1, x.shape[-1])
+         topk_ids, topk_weight, aux_loss = out
+
+         # One-hot encode the selected experts to create an expert mask;
+         # this will be used to easily index which expert is going to be solicited
+         expert_mask = torch.nn.functional.one_hot(
+             topk_ids, num_classes=self.num_experts
+         ).permute(2, 1, 0)
+
+         for expert_idx in range(self.num_experts):
+             idx, top_x = torch.where(expert_mask[expert_idx])
+             for name, hook in self.linear_layer_hooks.items():
+                 if not name.startswith(f"{expert_idx}."):
+                     continue
+                 hook._routing_weights = topk_weight[top_x, idx, None]
+
+     def compute(self):
+         pass
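The two hook classes cooperate through forward-hook ordering: the router runs before the experts, so the gate hook can stash each token's routing weights on the matching expert's linear hooks before those linears fire and accumulate their routing-weighted channel norms. The registration itself happens in `moe_pruner.py` (listed above but not shown in this diff); the helper below is a hypothetical reconstruction, and `moe_layer` with its `.gate` and `.experts` attributes is an assumed interface, not package API.

```python
from typing import Dict, List, Tuple

from torch import nn
from torch.utils.hooks import RemovableHandle

from fusion_bench.method.moe_pruner.hooks import (
    MoEPrunerHookFnForDeepseekV2Gate,
    MoEPrunerHookFnForDeepseekV2Linear,
)


def register_moe_pruner_hooks(
    moe_layer: nn.Module, top_k: int, num_experts: int
) -> Tuple[Dict[str, MoEPrunerHookFnForDeepseekV2Linear], List[RemovableHandle]]:
    """Hypothetical helper: attach MoE-Pruner hooks to one DeepseekV2 MoE block."""
    linear_hooks: Dict[str, MoEPrunerHookFnForDeepseekV2Linear] = {}
    handles: List[RemovableHandle] = []
    for expert_idx, expert in enumerate(moe_layer.experts):
        for name, module in expert.named_modules():
            if isinstance(module, nn.Linear):
                # The "{expert_idx}." key prefix is what the gate hook's
                # startswith() check matches against.
                key = f"{expert_idx}.{name}"
                linear_hooks[key] = MoEPrunerHookFnForDeepseekV2Linear(module, key)
                handles.append(module.register_forward_hook(linear_hooks[key]))

    # Register on the router so the gate hook fires before the expert linears
    # and can hand each one its per-token routing weights.
    gate_hook = MoEPrunerHookFnForDeepseekV2Gate(
        moe_layer.gate, linear_hooks, top_k=top_k, num_experts=num_experts
    )
    handles.append(moe_layer.gate.register_forward_hook(gate_hook))
    return linear_hooks, handles
```

After forwarding calibration batches through the model, each linear hook's `compute()` yields its routing-aware importance scores, and the returned handles can then be removed.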
fusion_bench/method/moe_pruner/hooks/hook.py
@@ -0,0 +1,23 @@
+ from abc import abstractmethod
+ from typing import Tuple
+
+ from torch import Tensor, nn
+
+
+ class BaseHookFn:
+     def __init__(self, module: nn.Module):
+         self.module = module
+
+     @abstractmethod
+     def compute(self) -> Tensor:
+         """
+         Compute the importance scores.
+         """
+         pass
+
+     @abstractmethod
+     def __call__(self, router, inps: Tuple[Tensor], out: Tensor):
+         """
+         Hook function to be called during the forward pass.
+         """
+         pass
fusion_bench/method/moe_pruner/hooks/mixtral.py
@@ -0,0 +1,93 @@
+ from typing import Dict, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+ from .hook import BaseHookFn
+
+
+ class MoEPrunerHookFnForMixtralLinear(BaseHookFn):
+     _routing_weights = None  # set by gate hook
+
+     def __init__(
+         self,
+         linear: nn.Linear,
+         name: str,
+     ):
+         super().__init__(linear)
+         self.linear = linear
+         self.scalar_row = torch.zeros(
+             (linear.weight.size(1),), device=linear.weight.device
+         )
+         self.nsamples = 0
+         self.name = name
+
+     def compute(self):
+         return torch.abs(self.linear.weight) * torch.sqrt(
+             self.scalar_row.reshape(1, -1)
+         )
+
+     def __call__(self, linear: nn.Linear, inps: Tuple[Tensor], out: Tensor):
+         assert len(inps) == 1
+         inp = inps[0]
+         if len(inp.shape) == 2:
+             inp = inp.unsqueeze(0)
+
+         batch_size = inp.shape[0]
+         if len(inp.shape) == 3:
+             inp = inp.reshape((-1, inp.shape[-1]))
+         # (NxL, C) -> (C, NxL)
+         inp = inp.t()
+
+         self.scalar_row *= self.nsamples / (self.nsamples + batch_size)
+         self.nsamples += batch_size
+
+         inp = inp.type(torch.float32)
+         routing_weights = self._routing_weights.t()
+         self.scalar_row += (
+             torch.norm(inp * routing_weights, p=2, dim=1) ** 2 / self.nsamples
+         )
+
+
+ class MoEPrunerHookFnForMixtralGate(BaseHookFn):
+
+     def __init__(
+         self,
+         router: nn.Module,
+         linear_layer_hooks: Dict[str, MoEPrunerHookFnForMixtralLinear],
+         top_k: int,
+         num_experts: int,
+     ):
+         self.nsamples = 0
+         self.linear_layer_hooks = linear_layer_hooks
+         self.top_k = top_k
+         self.num_experts = num_experts
+         super().__init__(router)
+
+     def __call__(self, router, inps: Tuple[Tensor], out: Tensor):
+         assert len(inps) == 1
+         inp = inps[0]
+
+         router_logits = out
+         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+         routing_weights, selected_experts = torch.topk(
+             routing_weights, self.top_k, dim=-1
+         )
+         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+         # One-hot encode the selected experts to create an expert mask;
+         # this will be used to easily index which expert is going to be solicited
+         expert_mask = torch.nn.functional.one_hot(
+             selected_experts, num_classes=self.num_experts
+         ).permute(2, 1, 0)
+
+         for expert_idx in range(self.num_experts):
+             idx, top_x = torch.where(expert_mask[expert_idx])
+             for name, hook in self.linear_layer_hooks.items():
+                 if not name.startswith(f"{expert_idx}."):
+                     continue
+                 hook._routing_weights = routing_weights[top_x, idx, None]
+
+     def compute(self):
+         pass
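In both the DeepseekV2 and Mixtral variants, `scalar_row` maintains a streaming mean of the routing-weighted squared L2 norm of each input channel, so `compute()` returns a Wanda-style importance score (|W| scaled by the square root of that statistic) that folds the router's hints into weight importance. How these scores actually drive pruning lives in `fusion_bench/method/moe_pruner/utils/prune.py` (listed above, not shown in this diff); the sketch below only illustrates one plausible consumer, and `prune_by_score` is a hypothetical helper, not package API.

```python
import torch


def prune_by_score(
    weight: torch.Tensor, score: torch.Tensor, sparsity: float = 0.5
) -> torch.Tensor:
    """Hypothetical consumer: zero the lowest-scoring fraction of each output row."""
    k = int(weight.shape[1] * sparsity)
    if k == 0:
        return weight
    # Per row, find the indices of the k smallest importance scores and mask them.
    _, prune_idx = torch.topk(score, k, dim=1, largest=False)
    mask = torch.ones_like(weight, dtype=torch.bool)
    mask.scatter_(1, prune_idx, False)
    return weight * mask


w = torch.randn(8, 16)
s = torch.rand(8, 16)  # stand-in for hook.compute() on this layer
w_pruned = prune_by_score(w, s, sparsity=0.5)
assert (w_pruned == 0).float().mean().item() >= 0.5
```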