fusion-bench: fusion_bench-0.2.15-py3-none-any.whl → fusion_bench-0.2.16-py3-none-any.whl
This diff compares the contents of two package versions publicly released to one of the supported registries, as they appear in those registries. It is provided for informational purposes only.
- fusion_bench/method/__init__.py +4 -0
- fusion_bench/method/fw_merging/__init__.py +2 -0
- fusion_bench/method/fw_merging/fw_hard.py +448 -0
- fusion_bench/method/fw_merging/fw_soft.py +519 -0
- fusion_bench/method/fw_merging/utils.py +331 -0
- fusion_bench/method/moe_pruner/__init__.py +7 -0
- fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
- fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
- fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
- fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
- fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
- fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
- fusion_bench/method/moe_pruner/utils/data.py +154 -0
- fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
- fusion_bench/method/moe_pruner/utils/prune.py +313 -0
- fusion_bench/method/moe_pruner/utils/score.py +41 -0
- fusion_bench/method/pruning/__init__.py +1 -0
- fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
- fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
- fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
- fusion_bench/method/pruning/wanda_utils/data.py +33 -14
- fusion_bench/method/randes/__init__.py +15 -0
- fusion_bench/method/randes/base_algorithm.py +1013 -0
- fusion_bench/method/randes/modelsoup.py +126 -0
- fusion_bench/method/randes/task_arithmetic.py +318 -0
- fusion_bench/method/sparselo/sparselo.py +20 -2
- fusion_bench/method/tall_mask/__init__.py +1 -0
- fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
- fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
- fusion_bench/programs/fabric_fusion_program.py +5 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/lazy_state_dict.py +268 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/state_dict_arithmetic.py +74 -2
- fusion_bench/utils/type.py +1 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +6 -2
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +77 -21
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
- fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
- fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
- fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
- fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
- fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
- fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
- fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
fusion_bench/method/fw_merging/utils.py
@@ -0,0 +1,331 @@
+"""
+This is modified based on https://github.com/EnnengYang/AdaMerging/blob/main/src/ties_merging_utils.py
+"""
+
+import copy
+from collections import OrderedDict
+from typing import List
+
+import torch
+from torch import Tensor, nn
+
+from fusion_bench.utils.type import StateDictType
+
+
+# Model conversion utils
+def state_dict_to_vector(state_dict, remove_keys=[]):
+    """
+    Convert a state dictionary to a vector, removing specified keys.
+
+    Args:
+        state_dict (dict): The state dictionary to convert.
+        remove_keys (list): List of keys to remove from the state dictionary.
+
+    Returns:
+        Tensor: A vector representation of the state dictionary.
+    """
+    shared_state_dict = copy.deepcopy(state_dict)
+    for key in remove_keys:
+        if key in shared_state_dict:
+            del shared_state_dict[key]
+    sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
+    return nn.utils.parameters_to_vector(
+        [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
+    )
+
+
+def vector_to_state_dict(vector, state_dict, remove_keys=[]):
+    """
+    Convert a vector back to a state dictionary, removing specified keys.
+
+    Args:
+        vector (Tensor): The vector to convert.
+        state_dict (dict): The reference state dictionary.
+        remove_keys (list): List of keys to remove from the state dictionary.
+
+    Returns:
+        dict: A state dictionary representation of the vector.
+    """
+    # create a reference dict to define the order of the vector
+    reference_dict = copy.deepcopy(state_dict)
+    for key in remove_keys:
+        if key in reference_dict:
+            del reference_dict[key]
+    sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))
+
+    # create a shared state dict using the reference dict
+    nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())
+
+    # add back the encoder and decoder embedding weights.
+    if "transformer.shared.weight" in sorted_reference_dict:
+        for key in remove_keys:
+            sorted_reference_dict[key] = sorted_reference_dict[
+                "transformer.shared.weight"
+            ]
+    return sorted_reference_dict
+
+
+def add_ptm_to_tv(tv_dict, ptm_dict):
+    """
+    Add the values of one state dictionary to another.
+
+    Args:
+        tv_dict (dict): The target state dictionary.
+        ptm_dict (dict): The state dictionary to add.
+
+    Returns:
+        dict: The resulting state dictionary after addition.
+    """
+    assert set(tv_dict.keys()) == set(
+        ptm_dict.keys()
+    ), "Differing parameter names in models."
+    final_dict = copy.deepcopy(tv_dict)
+    for k, v in ptm_dict.items():
+        final_dict[k] = tv_dict[k] + v
+    return final_dict
+
+
+def check_parameterNamesMatch(checkpoints: List[StateDictType]) -> None:
+    """
+    Check if the parameter names match across multiple checkpoints.
+
+    Args:
+        checkpoints (list): List of state dictionaries to check.
+
+    Raises:
+        ValueError: If the parameter names do not match.
+    """
+    parameter_names = set(checkpoints[0].keys())
+
+    if len(checkpoints) >= 2:
+        # raise ValueError("Number of models is less than 2.")
+        for checkpoint in checkpoints[1:]:
+            current_parameterNames = set(checkpoint.keys())
+            if current_parameterNames != parameter_names:
+                raise ValueError(
+                    "Differing parameter names in models. "
+                    f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
+                )
+
+
+def check_state_dicts_equal(
+    state_dict1: StateDictType, state_dict2: StateDictType
+) -> bool:
+    """
+    Check if two state dictionaries are equal.
+
+    Args:
+        state_dict1 (dict): The first state dictionary.
+        state_dict2 (dict): The second state dictionary.
+
+    Returns:
+        bool: True if the state dictionaries are equal, False otherwise.
+    """
+    if set(state_dict1.keys()) != set(state_dict2.keys()):
+        return False
+
+    for key in state_dict1.keys():
+        if not torch.equal(state_dict1[key], state_dict2[key]):
+            return False
+
+    return True
+
+
+# TIES MERGING UTILS
+
+
+def topk_values_mask(M, K=0.7, return_mask=False):
+    """
+    Mask the top K values in a tensor.
+
+    Args:
+        M (Tensor): The input tensor.
+        K (float): The proportion of top values to keep.
+        return_mask (bool): Whether to return the mask tensor.
+
+    Returns:
+        tuple: The masked tensor, the mean of the mask, and optionally the mask tensor.
+    """
+    if K > 1:
+        K /= 100
+
+    original_shape = M.shape
+    if M.dim() == 1:
+        M = M.unsqueeze(0)
+
+    n, d = M.shape
+    k = int(d * K)
+    k = d - k  # Keep top k elements instead of bottom k elements
+
+    # Find the k-th smallest element by magnitude for each row
+    kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
+    # Create a mask tensor with True for the top k elements in each row
+    mask = M.abs() >= kth_values
+    final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask
+
+    if return_mask:
+        return M * final_mask, final_mask.float().mean(dim=1), final_mask
+    return M * final_mask, final_mask.float().mean(dim=1)
+
+
+def resolve_zero_signs(sign_to_mult, method="majority"):
+    """
+    Resolve zero signs in a tensor by majority or minority rule.
+
+    Args:
+        sign_to_mult (Tensor): The tensor with signs to resolve.
+        method (str): The method to use for resolving zero signs ("majority" or "minority").
+
+    Returns:
+        Tensor: The tensor with resolved signs.
+    """
+    majority_sign = torch.sign(sign_to_mult.sum())
+
+    if method == "majority":
+        sign_to_mult[sign_to_mult == 0] = majority_sign
+    elif method == "minority":
+        sign_to_mult[sign_to_mult == 0] = -1 * majority_sign
+    return sign_to_mult
+
+
+def resolve_sign(v: Tensor):
+    """
+    Resolve the sign of a tensor by majority rule.
+
+    Args:
+        v (Tensor): The input tensor.
+
+    Returns:
+        Tensor: The tensor with resolved signs.
+    """
+    sign_to_mult = torch.sign(v.sum(dim=0))
+    sign_to_mult = resolve_zero_signs(sign_to_mult, "majority")
+    return sign_to_mult
+
+
+def disjoint_merge(v: Tensor, merge_func: str, sign_to_mult):
+    """
+    Perform disjoint merging of a tensor using a specified merge function.
+
+    Args:
+        v (Tensor): The input tensor.
+        merge_func (str): The merge function to use ("mean", "sum", or "max").
+        sign_to_mult (Tensor): The tensor with signs to use for merging.
+
+    Returns:
+        Tensor: The merged tensor.
+    """
+    merge_func = merge_func.split("-")[-1]
+
+    # If sign is provided then we select the corresponding entries and aggregate.
+    if sign_to_mult is not None:
+        rows_to_keep = torch.where(sign_to_mult.unsqueeze(0) > 0, v > 0, v < 0)
+        selected_entries = v * rows_to_keep
+    # Else we select all non-zero entries and aggregate.
+    else:
+        rows_to_keep = v != 0
+        selected_entries = v * rows_to_keep
+
+    if merge_func == "mean":
+        non_zero_counts = (selected_entries != 0).sum(dim=0).float()
+        disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(
+            non_zero_counts, min=1
+        )
+    elif merge_func == "sum":
+        disjoint_aggs = torch.sum(selected_entries, dim=0)
+    elif merge_func == "max":
+        disjoint_aggs = selected_entries.abs().max(dim=0)[0]
+        disjoint_aggs *= sign_to_mult
+    else:
+        raise ValueError(f"Merge method {merge_func} is not defined.")
+
+    return disjoint_aggs
+
+
+def ties_merging(
+    flat_task_checks,
+    reset_thresh=None,
+    merge_func="",
+):
+    """
+    Perform TIES merging on a tensor.
+
+    Args:
+        flat_task_checks (Tensor): The input tensor.
+        reset_thresh (float): The threshold for resetting values.
+        merge_func (str): The merge function to use.
+
+    Returns:
+        Tensor: The merged tensor.
+    """
+    all_checks = flat_task_checks.clone()
+    updated_checks, *_ = topk_values_mask(all_checks, K=reset_thresh, return_mask=False)
+    print("RESOLVING SIGN")
+    final_signs = resolve_sign(updated_checks)
+    assert final_signs is not None
+
+    print(f"Disjoint AGGREGATION: {merge_func}")
+    merged_tv = disjoint_merge(updated_checks, merge_func, final_signs)
+
+    return merged_tv
+
+
+def disjoint_merge_split(v: Tensor, merge_func: str, sign_to_mult):
+    """
+    Perform disjoint merging of a tensor using a specified merge function and return selected entries.
+
+    Args:
+        v (Tensor): The input tensor.
+        merge_func (str): The merge function to use ("sum").
+        sign_to_mult (Tensor): The tensor with signs to use for merging.
+
+    Returns:
+        tuple: The selected entries and the merged tensor.
+    """
+    merge_func = merge_func.split("-")[-1]
+
+    # If sign is provided then we select the corresponding entries and aggregate.
+    if sign_to_mult is not None:
+        rows_to_keep = torch.where(sign_to_mult.unsqueeze(0) > 0, v > 0, v < 0)
+        selected_entries = v * rows_to_keep
+    # Else we select all non-zero entries and aggregate.
+    else:
+        rows_to_keep = v != 0
+        selected_entries = v * rows_to_keep
+
+    if merge_func == "sum":
+        disjoint_aggs = torch.sum(selected_entries, dim=0)
+    else:
+        raise ValueError(f"Merge method {merge_func} is not defined.")
+
+    return selected_entries, disjoint_aggs
+
+
+def ties_merging_split(
+    flat_task_checks,
+    reset_thresh=None,
+    merge_func: str = "",
+):
+    """
+    Perform TIES merging on a tensor and return selected entries.
+
+    Args:
+        flat_task_checks (Tensor): The input tensor.
+        reset_thresh (float): The threshold for resetting values.
+        merge_func (str): The merge function to use.
+
+    Returns:
+        tuple: The selected entries and the merged tensor.
+    """
+    all_checks = flat_task_checks.clone()
+    updated_checks, *_ = topk_values_mask(all_checks, K=reset_thresh, return_mask=False)
+    print("RESOLVING SIGN")
+    final_signs = resolve_sign(updated_checks)
+    assert final_signs is not None

+    print(f"Disjoint AGGREGATION: {merge_func}")
+    selected_entries, merged_tv = disjoint_merge_split(
+        updated_checks, merge_func, final_signs
+    )
+
+    return selected_entries, merged_tv
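Note: the file above is the complete TIES-merging utility module added in this release. For orientation, here is a minimal sketch (not part of the package) of how the helpers compose into a merge; `ties_merge_sketch`, `pretrained_sd`, `finetuned_sds`, `reset_thresh`, and `scaling` are illustrative names, and the checkpoints are assumed to be plain `state_dict()` mappings:

import torch

from fusion_bench.method.fw_merging.utils import (
    check_parameterNamesMatch,
    state_dict_to_vector,
    ties_merging,
    vector_to_state_dict,
)

def ties_merge_sketch(pretrained_sd, finetuned_sds, reset_thresh=20, scaling=0.3):
    # Sanity-check that all checkpoints share the same parameter names.
    check_parameterNamesMatch([pretrained_sd] + finetuned_sds)
    ptm_vec = state_dict_to_vector(pretrained_sd)
    # Task vectors: each fine-tuned model minus the pretrained base.
    task_vecs = torch.vstack(
        [state_dict_to_vector(sd) - ptm_vec for sd in finetuned_sds]
    )
    # Trim to the top 20% by magnitude, elect signs, then disjoint-merge ("mean").
    merged_tv = ties_merging(task_vecs, reset_thresh=reset_thresh, merge_func="mean")
    # Move the pretrained model a scaled step along the merged task vector.
    return vector_to_state_dict(ptm_vec + scaling * merged_tv, pretrained_sd)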
fusion_bench/method/moe_pruner/hooks/deepseek_v2.py
@@ -0,0 +1,85 @@
+from typing import Dict, Tuple
+
+import torch
+from torch import Tensor, nn
+
+from fusion_bench.models.modeling_deepseek_v2 import DeepseekV2MoEGate
+
+from .hook import BaseHookFn
+
+
+class MoEPrunerHookFnForDeepseekV2Linear(BaseHookFn):
+    _routing_weights = None  # set by gate hook
+
+    def __init__(self, linear: nn.Linear, name: str):
+        super().__init__(linear)
+        self.linear = linear
+        self.scalar_row = torch.zeros(
+            (linear.weight.size(1),), device=linear.weight.device
+        )
+        self.nsamples = 0
+        self.name = name
+
+    def __call__(self, linear, inps: Tuple[Tensor], out: Tensor):
+        assert len(inps) == 1
+        inp = inps[0]
+        if len(inp.shape) == 2:
+            inp = inp.unsqueeze(0)
+
+        batch_size = inp.shape[0]
+        if len(inp.shape) == 3:
+            inp = inp.reshape((-1, inp.shape[-1]))
+        # (NxL, C) -> (C, NxL)
+        inp = inp.t()
+
+        self.scalar_row *= self.nsamples / (self.nsamples + batch_size)
+        self.nsamples += batch_size
+
+        inp = inp.type(torch.float32)
+        routing_weights = self._routing_weights.t()
+        self.scalar_row += (
+            torch.norm(inp * routing_weights, p=2, dim=1) ** 2 / self.nsamples
+        )
+
+    def compute(self):
+        return torch.abs(self.linear.weight) * torch.sqrt(
+            self.scalar_row.reshape(1, -1)
+        )
+
+
+class MoEPrunerHookFnForDeepseekV2Gate(BaseHookFn):
+    def __init__(
+        self,
+        router: DeepseekV2MoEGate,
+        linear_layer_hooks: Dict[str, MoEPrunerHookFnForDeepseekV2Linear],
+        top_k: int,
+        num_experts: int,
+    ):
+        super().__init__(router)
+        self.router = router
+        self.linear_layer_hooks = linear_layer_hooks
+        self.top_k = top_k
+        self.num_experts = num_experts
+
+    def __call__(self, router, inps: Tuple[Tensor], out: Tuple[Tensor, Tensor, Tensor]):
+        assert len(inps) == 1
+
+        x = inps[0]
+        x = x.view(-1, x.shape[-1])
+        topk_ids, topk_weight, aux_loss = out
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be sollicitated
+        expert_mask = torch.nn.functional.one_hot(
+            topk_ids, num_classes=self.num_experts
+        ).permute(2, 1, 0)
+
+        for expert_idx in range(self.num_experts):
+            idx, top_x = torch.where(expert_mask[expert_idx])
+            for name, hook in self.linear_layer_hooks.items():
+                if not name.startswith(f"{expert_idx}."):
+                    continue
+                hook._routing_weights = topk_weight[top_x, idx, None]
+
+    def compute(self):
+        pass
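Note: the linear hook above accumulates a routing-weighted, Wanda-style importance statistic: a per-input-channel running average of the squared L2 norm of routing-weighted activations, with `compute()` scoring each weight as |W| · sqrt(scalar_row). A condensed single-batch restatement (not part of the package; `W`, `X`, and `r` are hypothetical placeholders, and the real hook additionally maintains a running average over calibration batches):

import torch

def moe_wanda_score(W: torch.Tensor, X: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
    """Single-batch restatement of the hook's statistic.

    W: (out_features, in_features) expert weight.
    X: (in_features, n_tokens) inputs routed to this expert.
    r: (1, n_tokens) routing weights, broadcast over input channels.
    """
    # Per-input-channel energy of the routing-weighted activations.
    scalar_row = torch.norm(X * r, p=2, dim=1) ** 2
    # Importance of weight W[o, c] is |W[o, c]| * sqrt(energy of channel c).
    return W.abs() * scalar_row.sqrt().reshape(1, -1)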
fusion_bench/method/moe_pruner/hooks/hook.py
@@ -0,0 +1,23 @@
+from abc import abstractmethod
+from typing import Tuple
+
+from torch import Tensor, nn
+
+
+class BaseHookFn:
+    def __init__(self, module: nn.Module):
+        self.module = module
+
+    @abstractmethod
+    def compute(self) -> Tensor:
+        """
+        Compute the importance scores.
+        """
+        pass
+
+    @abstractmethod
+    def __call__(self, router, inps: Tuple[Tensor], out: Tensor):
+        """
+        Hook function to be called during the forward pass.
+        """
+        pass
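Note: `BaseHookFn` subclasses are callables with the `(module, inputs, output)` signature that `torch.nn.Module.register_forward_hook` expects, so the intended lifecycle is presumably: register, run calibration data through the model, read `compute()`, detach. A toy sketch under that assumption (`NormHookFn` is illustrative, not part of the package):

import torch
from torch import nn

from fusion_bench.method.moe_pruner.hooks.hook import BaseHookFn

class NormHookFn(BaseHookFn):
    """Toy hook recording the mean L2 norm of the module's outputs."""

    def __init__(self, module: nn.Module):
        super().__init__(module)
        self.total, self.count = 0.0, 0

    def __call__(self, module, inps, out):
        self.total += out.norm(p=2).item()
        self.count += 1

    def compute(self) -> torch.Tensor:
        return torch.tensor(self.total / max(self.count, 1))

layer = nn.Linear(8, 8)
hook = NormHookFn(layer)
handle = layer.register_forward_hook(hook)
with torch.no_grad():
    layer(torch.randn(4, 8))  # calibration forward pass
score = hook.compute()
handle.remove()  # always detach hooks when calibration is done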
fusion_bench/method/moe_pruner/hooks/mixtral.py
@@ -0,0 +1,93 @@
+from typing import Dict, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from .hook import BaseHookFn
+
+
+class MoEPrunerHookFnForMixtralLinear(BaseHookFn):
+    _routing_weights = None  # set by gate hook
+
+    def __init__(
+        self,
+        linear: nn.Linear,
+        name: str,
+    ):
+        super().__init__(linear)
+        self.linear = linear
+        self.scalar_row = torch.zeros(
+            (linear.weight.size(1),), device=linear.weight.device
+        )
+        self.nsamples = 0
+        self.name = name
+
+    def compute(self):
+        return torch.abs(self.linear.weight) * torch.sqrt(
+            self.scalar_row.reshape(1, -1)
+        )
+
+    def __call__(self, linear: nn.Linear, inps: Tuple[Tensor], out: Tensor):
+        assert len(inps) == 1
+        inp = inps[0]
+        if len(inp.shape) == 2:
+            inp = inp.unsqueeze(0)
+
+        batch_size = inp.shape[0]
+        if len(inp.shape) == 3:
+            inp = inp.reshape((-1, inp.shape[-1]))
+        # (NxL, C) -> (C, NxL)
+        inp = inp.t()
+
+        self.scalar_row *= self.nsamples / (self.nsamples + batch_size)
+        self.nsamples += batch_size
+
+        inp = inp.type(torch.float32)
+        routing_weights = self._routing_weights.t()
+        self.scalar_row += (
+            torch.norm(inp * routing_weights, p=2, dim=1) ** 2 / self.nsamples
+        )
+
+
+class MoEPrunerHookFnForMixtralGate(BaseHookFn):
+
+    def __init__(
+        self,
+        router: nn.Module,
+        linear_layer_hooks: Dict[str, MoEPrunerHookFnForMixtralLinear],
+        top_k: int,
+        num_experts: int,
+    ):
+        self.nsamples = 0
+        self.linear_layer_hooks = linear_layer_hooks
+        self.top_k = top_k
+        self.num_experts = num_experts
+        super().__init__(router)
+
+    def __call__(self, router, inps: Tuple[Tensor], out: Tensor):
+        assert len(inps) == 1
+        inp = inps[0]
+
+        router_logits = out
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be sollicitated
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        ).permute(2, 1, 0)
+
+        for expert_idx in range(self.num_experts):
+            idx, top_x = torch.where(expert_mask[expert_idx])
+            for name, hook in self.linear_layer_hooks.items():
+                if not name.startswith(f"{expert_idx}."):
+                    continue
+                hook._routing_weights = routing_weights[top_x, idx, None]
+
+    def compute(self):
+        pass
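Note: the gate hook matches linear hooks by the `"{expert_idx}."` prefix of their names, so the wiring presumably pairs one hook per expert projection with a single gate hook per MoE block. A hypothetical sketch (not part of the package; `moe_block` is a placeholder assumed to be a transformers `MixtralSparseMoeBlock` with `gate`, `experts`, per-expert `w1`/`w2`/`w3`, `top_k`, and `num_experts`):

from fusion_bench.method.moe_pruner.hooks.mixtral import (
    MoEPrunerHookFnForMixtralGate,
    MoEPrunerHookFnForMixtralLinear,
)

linear_hooks, handles = {}, []
for expert_idx, expert in enumerate(moe_block.experts):
    for proj in ("w1", "w2", "w3"):
        linear = getattr(expert, proj)
        name = f"{expert_idx}.{proj}"  # prefix must match the gate hook's check
        linear_hooks[name] = MoEPrunerHookFnForMixtralLinear(linear, name)
        handles.append(linear.register_forward_hook(linear_hooks[name]))

# The gate hook fires before the expert linears within each forward pass,
# so routing weights are in place by the time the linear hooks run.
gate_hook = MoEPrunerHookFnForMixtralGate(
    moe_block.gate,
    linear_hooks,
    top_k=moe_block.top_k,
    num_experts=moe_block.num_experts,
)
handles.append(moe_block.gate.register_forward_hook(gate_hook))

# ...run calibration batches through the model, then collect scores:
scores = {name: hook.compute() for name, hook in linear_hooks.items()}
for handle in handles:
    handle.remove()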