fusion-bench 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +4 -0
- fusion_bench/method/fw_merging/__init__.py +2 -0
- fusion_bench/method/fw_merging/fw_hard.py +448 -0
- fusion_bench/method/fw_merging/fw_soft.py +519 -0
- fusion_bench/method/fw_merging/utils.py +331 -0
- fusion_bench/method/moe_pruner/__init__.py +7 -0
- fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
- fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
- fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
- fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
- fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
- fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
- fusion_bench/method/moe_pruner/utils/data.py +154 -0
- fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
- fusion_bench/method/moe_pruner/utils/prune.py +313 -0
- fusion_bench/method/moe_pruner/utils/score.py +41 -0
- fusion_bench/method/pruning/__init__.py +1 -0
- fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
- fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
- fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
- fusion_bench/method/pruning/wanda_utils/data.py +33 -14
- fusion_bench/method/randes/__init__.py +15 -0
- fusion_bench/method/randes/base_algorithm.py +1013 -0
- fusion_bench/method/randes/modelsoup.py +126 -0
- fusion_bench/method/randes/task_arithmetic.py +318 -0
- fusion_bench/method/sparselo/sparselo.py +20 -2
- fusion_bench/method/tall_mask/__init__.py +1 -0
- fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +73 -10
- fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
- fusion_bench/programs/fabric_fusion_program.py +5 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/lazy_state_dict.py +268 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/state_dict_arithmetic.py +74 -2
- fusion_bench/utils/type.py +1 -0
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +86 -22
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
- fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
- fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
- fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
- fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
- fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
- fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
- fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B-Instruct.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B-Instruct.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b-it.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b-it.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor, nn
|
|
6
|
+
from torch.nn import functional as F
|
|
7
|
+
from tqdm.auto import tqdm
|
|
8
|
+
from transformers import MixtralForCausalLM
|
|
9
|
+
from transformers.models.mixtral.modeling_mixtral import (
|
|
10
|
+
MixtralDecoderLayer,
|
|
11
|
+
MixtralSparseMoeBlock,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from fusion_bench import BaseAlgorithm, BaseModelPool
|
|
15
|
+
from fusion_bench.method.pruning.prune_utils import (
|
|
16
|
+
PruningType,
|
|
17
|
+
compute_sparsity,
|
|
18
|
+
find_linear_layers,
|
|
19
|
+
semistructured_magnitude_prune_,
|
|
20
|
+
unstructured_magnitude_prune_,
|
|
21
|
+
)
|
|
22
|
+
from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
|
|
23
|
+
from fusion_bench.modelpool import CausalLMPool
|
|
24
|
+
from fusion_bench.models.modeling_deepseek_v2 import (
|
|
25
|
+
DeepseekV2DecoderLayer,
|
|
26
|
+
DeepseekV2ForCausalLM,
|
|
27
|
+
DeepseekV2MLP,
|
|
28
|
+
DeepseekV2MoE,
|
|
29
|
+
DeepseekV2MoEGate,
|
|
30
|
+
)
|
|
31
|
+
from fusion_bench.utils import timeit_context
|
|
32
|
+
from fusion_bench.utils.cache_utils import cache_to_disk
|
|
33
|
+
from fusion_bench.utils.devices import to_device
|
|
34
|
+
|
|
35
|
+
from .hooks.deepseek_v2 import (
|
|
36
|
+
MoEPrunerHookFnForDeepseekV2Gate,
|
|
37
|
+
MoEPrunerHookFnForDeepseekV2Linear,
|
|
38
|
+
)
|
|
39
|
+
from .hooks.hook import BaseHookFn
|
|
40
|
+
from .hooks.mixtral import (
|
|
41
|
+
MoEPrunerHookFnForMixtralGate,
|
|
42
|
+
MoEPrunerHookFnForMixtralLinear,
|
|
43
|
+
)
|
|
44
|
+
from .utils.data import get_loaders
|
|
45
|
+
from .utils.prune import prepare_calibration_input
|
|
46
|
+
|
|
47
|
+
# Logger shared by everything defined in this module.
log = logging.getLogger(__name__)

# Generic stand-in for any of the MoE causal-LM architectures handled below
# (Mixtral or DeepSeek-V2).
MoEModel = TypeVar("MoEModel", bound=Union[MixtralForCausalLM, DeepseekV2ForCausalLM])
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class MoEPruner(BaseAlgorithm, SimpleProfilerMixin, LightningFabricMixin):
    """
    Calibration-based pruning of the expert linear layers of Mixtral and
    DeepSeek-V2 mixture-of-experts causal language models.

    Decoder layers are processed one at a time. For every MoE layer:

    1. forward hooks are attached to the expert linear layers and the router gate;
    2. the calibration hidden states are run through the layer so the hooks can
       accumulate importance scores;
    3. the expert weights are pruned in place from those scores;
    4. the layer is run again (with pruned weights) to produce the hidden states
       that feed the next layer.

    Dense DeepSeek-V2 MLP layers are not pruned; they only propagate activations.
    """

    def __init__(
        self,
        nsamples: int,
        seed: int,
        device: str,
        prune_type: PruningType,
        sparsity_ratio: float,
        n: int,
        m: int,
        max_seqlen: Optional[int] = None,
    ):
        """
        Args:
            nsamples (int): Number of calibration sequences sampled from C4.
            seed (int): Random seed used when sampling calibration sequences.
            device (str): Device handed to ``prepare_calibration_input`` when
                collecting the inputs of the first decoder layer.
            prune_type (PruningType): ``UNSTRUCTURED`` or ``SEMISTRUCTURED``.
            sparsity_ratio (float): Target sparsity for unstructured pruning.
            n (int): Semi-structured pattern parameter; the expected sparsity
                after semi-structured pruning is ``n / m``.
            m (int): Semi-structured pattern parameter; see ``n``.
            max_seqlen (Optional[int]): Upper bound on the calibration sequence
                length. The effective length is
                ``min(config.max_position_embeddings, max_seqlen)``.
        """
        self.nsamples = nsamples
        self.seed = seed
        self.device = device
        self.max_seqlen = max_seqlen
        self.prune_type = prune_type
        self.sparsity_ratio = sparsity_ratio
        self.n = n
        self.m = m
        super().__init__()

    def run(self, modelpool: CausalLMPool):
        """
        Load the model and tokenizer from ``modelpool``, collect calibration data
        and prune the model's expert layers in place.

        Args:
            modelpool (CausalLMPool): Pool providing the model and tokenizer.

        Returns:
            The pruned model.

        Raises:
            ValueError: If ``max_seqlen`` is not configured and the loaded model
                does not already carry a ``seqlen`` attribute.
        """
        # load pre-trained model or the first model in the pool
        with self.profile("load_model"):
            model: MoEModel = modelpool.load_pretrained_or_first_model()
            if self.max_seqlen is not None:
                model.seqlen = min(
                    model.config.max_position_embeddings,
                    self.max_seqlen,
                )
            elif not hasattr(model, "seqlen"):
                # `seqlen` is not a standard Hugging Face attribute; without it the
                # calibration loader fails later with an opaque AttributeError.
                raise ValueError(
                    "`max_seqlen` must be set: the model has no `seqlen` attribute "
                    "to determine the calibration sequence length."
                )
            tokenizer = modelpool.load_tokenizer()

        inps, outs, attention_mask, position_ids, position_embeddings = (
            self.prepare_calibration_data(model, tokenizer)
        )

        self.prune_using_calibration_data_(
            model,
            inps=inps,
            outs=outs,
            attention_mask=attention_mask,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
        )

        return model

    def prepare_calibration_data(self, model: MoEModel, tokenizer):
        """
        Prepare calibration data for pruning with caching.

        The result is cached on disk under ``outputs/cache/<model name>/``. The
        cache file name encodes ``nsamples``, ``seed`` and the sequence length, so
        changing any of them does not silently reuse calibration data of a
        different size.

        Args:
            model (MoEModel): Model to be pruned.
            tokenizer: Tokenizer for the model.

        Returns:
            Tuple: Calibration data (inputs, outputs, attention mask, position IDs,
            position embeddings).
        """
        model_name = model.config.name_or_path.rstrip("/").split("/")[-1]
        seqlen = getattr(model, "seqlen", None)
        cache_path = (
            f"outputs/cache/{model_name}/"
            f"calibration_data_n{self.nsamples}_seed{self.seed}_len{seqlen}.pkl"
        )

        @cache_to_disk(cache_path)
        def _prepare_calibration_data(model, tokenizer):
            return self._prepare_calibration_data(model, tokenizer)

        return _prepare_calibration_data(model, tokenizer)

    def _prepare_calibration_data(self, model, tokenizer):
        """
        Prepare calibration data for pruning (uncached).

        Args:
            model (MoEModel): Model to be pruned; must have a ``seqlen`` attribute.
            tokenizer: Tokenizer for the model.

        Returns:
            Tuple: Calibration data (inputs, outputs, attention mask, position IDs,
            position embeddings).
        """
        with timeit_context("loading calibration data"):
            dataloader, _ = get_loaders(
                "c4",
                nsamples=self.nsamples,
                seed=self.seed,
                seqlen=model.seqlen,
                tokenizer=tokenizer,
            )

        with torch.no_grad():
            # collect input to the first layer
            inps, outs, attention_mask, position_ids, position_embeddings = (
                prepare_calibration_input(model, dataloader, self.device)
            )
        return inps, outs, attention_mask, position_ids, position_embeddings

    def prune_using_calibration_data_(
        self,
        model: MoEModel,
        *,
        inps,
        outs,
        attention_mask,
        position_ids,
        position_embeddings,
    ):
        """
        Prune the expert linear layers of ``model`` in place, layer by layer.

        ``inps`` holds the hidden states entering the current decoder layer and
        ``outs`` is a buffer that receives the layer's outputs. The two are
        swapped after every layer so the outputs become the next layer's inputs.

        Raises:
            ValueError: On an unsupported decoder layer, MLP, model or pruning type.
        """
        model.eval()
        layers = model.model.layers
        for layer_idx, layer in tqdm(
            enumerate(layers),
            "Pruning Layers",
            total=len(layers),
            dynamic_ncols=True,
        ):
            if (
                hasattr(model, "hf_device_map")
                and f"model.layers.{layer_idx}" in model.hf_device_map
            ):
                # handle the case for large models, when the device map has multiple GPUs;
                dev = model.hf_device_map[f"model.layers.{layer_idx}"]
                inps, outs, attention_mask, position_ids, position_embeddings = (
                    inps.to(dev),
                    outs.to(dev),
                    attention_mask.to(dev) if attention_mask is not None else None,
                    position_ids.to(dev) if position_ids is not None else None,
                    (
                        to_device(position_embeddings, dev)
                        if position_embeddings is not None
                        else None
                    ),
                )

            forward_kwargs = dict(
                attention_mask=attention_mask,
                position_ids=position_ids,
                position_embeddings=position_embeddings,
            )

            if isinstance(layer, MixtralDecoderLayer):
                linear_layers = find_linear_layers(layer.block_sparse_moe.experts)
            elif isinstance(layer, DeepseekV2DecoderLayer):
                if isinstance(layer.mlp, DeepseekV2MoE):
                    linear_layers = find_linear_layers(layer.mlp.experts)
                elif isinstance(layer.mlp, DeepseekV2MLP):
                    # dense layer: nothing to prune, only compute the input to the next layer
                    self._forward_layer(layer, inps, outs, **forward_kwargs)
                    inps, outs = outs, inps
                    continue
                else:
                    # never fall through with `linear_layers` left over from a previous layer
                    raise ValueError(f"Unsupported MLP type: {type(layer.mlp)}")
            else:
                raise ValueError(f"Unsupported layer type: {type(layer)}")

            linear_hooks, handles = self._register_hooks(model, layer, linear_layers)
            try:
                # run the calibration data so the hooks accumulate importance statistics
                self._forward_layer(layer, inps, outs, **forward_kwargs)
                # compute the importance scores
                metrics = {
                    name: hook.compute().detach().cpu()
                    for name, hook in linear_hooks.items()
                }
            finally:
                # always detach the hooks, even if the forward pass failed
                for h in handles:
                    h.remove()

            # prune the weights based on the importance scores
            self._prune_linear_layers_(linear_layers, metrics)

            # compute the input to the next layer using the pruned weights
            self._forward_layer(layer, inps, outs, **forward_kwargs)
            inps, outs = outs, inps

    @torch.no_grad()
    def _forward_layer(
        self,
        layer,
        inps,
        outs,
        *,
        attention_mask,
        position_ids,
        position_embeddings,
    ):
        """Run each of the ``nsamples`` calibration inputs through ``layer``, writing results into ``outs``."""
        for j in range(self.nsamples):
            outs[j] = layer(
                inps[j].unsqueeze(0),
                attention_mask=attention_mask,
                position_ids=position_ids,
                position_embeddings=position_embeddings,
            )[0]

    def _register_hooks(self, model: MoEModel, layer, linear_layers):
        """
        Attach importance-collecting hooks to the expert linears and router gate of ``layer``.

        Returns:
            Tuple[Dict[str, BaseHookFn], List[RemovableHandle]]: The per-linear hook
            objects keyed by layer name, and the handles needed to remove every hook.

        Raises:
            ValueError: If ``model`` is neither Mixtral nor DeepSeek-V2.
        """
        if isinstance(model, MixtralForCausalLM):
            linear_hook_cls = MoEPrunerHookFnForMixtralLinear
        elif isinstance(model, DeepseekV2ForCausalLM):
            linear_hook_cls = MoEPrunerHookFnForDeepseekV2Linear
        else:
            raise ValueError(f"Unsupported model type: {type(model)}")

        linear_hooks: Dict[str, BaseHookFn] = {}
        handles: List[torch.utils.hooks.RemovableHandle] = []
        for name, linear in linear_layers.items():
            hook_fn = linear_hook_cls(linear, name)
            linear_hooks[name] = hook_fn
            handles.append(linear.register_forward_hook(hook_fn))

        # the gate hook receives the linear hooks so it can feed routing info to them
        if isinstance(model, MixtralForCausalLM):
            gate = layer.block_sparse_moe.gate
            gate_hook = MoEPrunerHookFnForMixtralGate(
                gate,
                linear_hooks,
                top_k=layer.block_sparse_moe.top_k,
                num_experts=layer.block_sparse_moe.num_experts,
            )
        else:
            gate = layer.mlp.gate
            gate_hook = MoEPrunerHookFnForDeepseekV2Gate(
                gate,
                linear_hooks,
                top_k=layer.mlp.gate.top_k,
                num_experts=layer.mlp.config.n_routed_experts,
            )
        handles.append(gate.register_forward_hook(gate_hook))
        return linear_hooks, handles

    def _prune_linear_layers_(self, linear_layers, metrics):
        """
        Prune each linear layer's weight in place using its importance scores.

        Raises:
            ValueError: If ``self.prune_type`` is not a supported pruning type.
        """
        if self.prune_type == PruningType.UNSTRUCTURED:
            for name, linear in linear_layers.items():
                log.info(f"Pruning {name}")
                unstructured_magnitude_prune_(
                    linear.weight.data,
                    metrics[name].to(linear.weight.device),
                    sparsity_ratio=self.sparsity_ratio,
                )
                self.check_sparsity(linear.weight)
        elif self.prune_type == PruningType.SEMISTRUCTURED:
            for name, linear in linear_layers.items():
                log.info(f"Pruning {name}")
                semistructured_magnitude_prune_(
                    linear.weight.data,
                    metrics[name].to(linear.weight.device),
                    n=self.n,
                    m=self.m,
                )
                self.check_sparsity(linear.weight)
        else:
            raise ValueError(f"Invalid pruning type: {self.prune_type}")

    @torch.no_grad()
    def check_sparsity(self, weight: Tensor, tol: float = 0.01):
        """
        Check the sparsity of the weight tensor.

        Args:
            weight (Tensor): Weight tensor.
            tol (float): Tolerance for sparsity check.

        Raises:
            ValueError: If the pruning type is invalid.
            AssertionError: If the measured sparsity differs from the target
                (``sparsity_ratio`` or ``n / m``) by ``tol`` or more.
        """
        if self.prune_type == PruningType.UNSTRUCTURED:
            target = self.sparsity_ratio
        elif self.prune_type == PruningType.SEMISTRUCTURED:
            target = self.n / self.m
        else:
            raise ValueError(f"Invalid pruning type: {self.prune_type}")
        # float() accepts both a Python number and a 0-dim tensor
        actual = float(compute_sparsity(weight))
        assert abs(actual - target) < tol, (
            f"Sparsity {actual:.4f} differs from target {target:.4f} by more than {tol}"
        )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .score import layer_load_balance_score
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
# Code adapted from https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
from typing import List, Optional, Tuple, cast # noqa: F401
|
|
5
|
+
import os
|
|
6
|
+
from datasets import load_dataset
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
from tqdm.auto import tqdm
|
|
9
|
+
from transformers import PreTrainedTokenizer
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TokenizerWrapper:
    """
    Lightweight holder that exposes already-tokenized ids via ``.input_ids``.

    Mirrors the attribute name used by tokenizer outputs so callers can treat
    a pre-sliced id tensor the same way as a tokenizer result.
    """

    def __init__(self, input_ids):
        # The value is kept by reference: no copy, no validation.
        self.input_ids = input_ids
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Load and process wikitext2 dataset
def get_wikitext2(
    nsamples: int,
    seed: int,
    seqlen: int,
    tokenizer: PreTrainedTokenizer,
    data_path: str = "wikitext",
):
    """
    Load and preprocess the Wikitext-2 dataset for training and testing.

    Args:
        nsamples (int): Number of samples to generate from the training set.
        seed (int): Random seed for reproducibility. If ``None``, the global
            ``random`` state is used as-is (same convention as ``get_c4``).
        seqlen (int): Length of the sequence to be used for training.
        tokenizer (PreTrainedTokenizer): Tokenizer to encode the text data.
        data_path (str, optional): Path to the dataset. Defaults to "wikitext".

    Returns:
        tuple: ``(trainloader, testenc)``. ``trainloader`` is a list of
        ``(inp, tar)`` pairs, where ``inp`` is a random ``seqlen``-token window
        of the concatenated training split and ``tar`` is a copy of it with every
        position except the last set to ``-100``. ``testenc`` is the tokenizer
        output for the whole test split joined by blank lines.

    Raises:
        ValueError: If ``nsamples > 0`` and the tokenized training split is not
            longer than ``seqlen`` tokens.
    """
    # Load train and test datasets
    traindata = load_dataset(data_path, "wikitext-2-raw-v1", split="train")
    testdata = load_dataset(data_path, "wikitext-2-raw-v1", split="test")

    # Encode datasets
    trainenc = tokenizer(" ".join(traindata["text"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")

    num_tokens = trainenc.input_ids.shape[1]
    if nsamples > 0 and num_tokens <= seqlen:
        # random.randint below would otherwise fail with an opaque "empty range" error
        raise ValueError(
            f"Training split has {num_tokens} tokens, which is not more than "
            f"seqlen={seqlen}; use a smaller seqlen."
        )

    # Generate samples from training set
    if seed is not None:
        random.seed(seed)
    trainloader: List[Tuple[Tensor, Tensor]] = []
    for _ in tqdm(range(nsamples), desc="Generating samples"):
        i = random.randint(0, num_tokens - seqlen - 1)
        j = i + seqlen
        inp: Tensor = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        # only the last position contributes to a loss; -100 is ignored
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _load_c4_shard(data_path: str, split: str, shard: str):
    """Load one C4 shard as ``split``, preferring a local copy under ``.cache/allenai--c4/``."""
    local_file = f".cache/allenai--c4/{shard}"
    if os.path.exists(local_file):
        return load_dataset("json", data_files={split: local_file}, split=split)
    # "allenai--c4" naming: https://github.com/huggingface/datasets/issues/6559
    return load_dataset(data_path, data_files={split: shard}, split=split)


# Load and process c4 dataset
def get_c4(
    nsamples: int,
    seed: int,
    seqlen: int,
    tokenizer,
    data_path: str = "allenai/c4",
    max_attempts_per_sample: int = 10_000,
) -> Tuple[List[Tuple[Tensor, Tensor]], TokenizerWrapper]:
    """
    Load and process the c4 dataset.

    Args:
        nsamples (int): Number of samples to generate from the training set.
        seed (int): Seed for random number generation. If ``None``, the global
            ``random`` state is used as-is.
        seqlen (int): Length of each sequence.
        tokenizer: Tokenizer object for encoding the text.
        data_path (str, optional): Path to the c4 dataset. Defaults to "allenai/c4".
        max_attempts_per_sample (int, optional): Maximum number of random
            documents tried per sample while looking for one longer than
            ``seqlen`` tokens. Defaults to 10000.

    Returns:
        tuple (Tuple[List[Tuple[Tensor, Tensor]], TokenizerWrapper]): Tuple containing
        the training samples and the validation dataset. Each training sample is
        ``(inp, tar)``, where ``tar`` copies ``inp`` with every position except the
        last set to ``-100``. The validation ids are truncated to ``256 * seqlen``
        tokens.

    Raises:
        ValueError: If no sufficiently long document is found within
            ``max_attempts_per_sample`` draws (previously this looped forever).
    """
    # Load train and validation datasets
    traindata = _load_c4_shard(data_path, "train", "en/c4-train.00000-of-01024.json.gz")
    valdata = _load_c4_shard(
        data_path, "validation", "en/c4-validation.00000-of-00008.json.gz"
    )

    # Generate samples from training set
    if seed is not None:
        random.seed(seed)

    trainloader = []
    for _ in tqdm(range(nsamples), desc="Generating samples"):
        # draw random documents until one is strictly longer than seqlen tokens
        for _attempt in range(max_attempts_per_sample):
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
            if trainenc.input_ids.shape[1] > seqlen:
                break
        else:
            raise ValueError(
                f"Could not find a C4 document longer than seqlen={seqlen} tokens "
                f"after {max_attempts_per_sample} attempts; consider a smaller seqlen."
            )
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    # Prepare validation dataset
    valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt")
    valenc = valenc.input_ids[:, : (256 * seqlen)]
    valenc = TokenizerWrapper(valenc)
    return trainloader, valenc
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def get_loaders(
    name: str, nsamples: int = 128, seed: int = 0, seqlen: int = 2048, tokenizer=None
):
    """
    Dispatch to the calibration-data loader matching ``name``.

    Matching is by substring and checked in order: a name containing
    "wikitext2" selects the Wikitext-2 loader; otherwise a name containing
    "c4" selects the C4 loader.

    Args:
        name (str): Dataset identifier ("wikitext2" or "c4", possibly embedded
            in a longer string).
        nsamples (int, optional): Number of samples to draw. Defaults to 128.
        seed (int, optional): Random seed. Defaults to 0.
        seqlen (int, optional): Tokens per sample. Defaults to 2048.
        tokenizer (optional): Tokenizer forwarded to the loader. Defaults to None.

    Returns:
        Whatever the selected loader returns (training samples, eval data).

    Raises:
        ValueError: If ``name`` matches neither dataset.
    """
    loader = None
    if "wikitext2" in name:
        loader = get_wikitext2
    elif "c4" in name:
        loader = get_c4

    if loader is None:
        raise ValueError(f"Unknown dataset: {name}")
    return loader(nsamples, seed, seqlen, tokenizer)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# Define WrappedGPT class
|
|
6
|
+
class WrappedGPT:
|
|
7
|
+
"""
|
|
8
|
+
This class wraps a GPT layer for specific operations.
|
|
9
|
+
|
|
10
|
+
Attributes:
|
|
11
|
+
layer (nn.Linear | nn.Module): The GPT layer to be wrapped.
|
|
12
|
+
dev (torch.device): The device on which the layer's weights are stored.
|
|
13
|
+
rows (int): The number of rows in the layer's weight matrix.
|
|
14
|
+
columns (int): The number of columns in the layer's weight matrix.
|
|
15
|
+
scaler_row (torch.Tensor): A tensor to store the scaler values for each column.
|
|
16
|
+
nsamples (int): The number of samples processed.
|
|
17
|
+
layer_id (int): The ID of the layer.
|
|
18
|
+
layer_name (str): The name of the layer.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
def __init__(self, layer: nn.Linear | nn.Module, layer_id=0, layer_name="none"):
|
|
22
|
+
"""
|
|
23
|
+
Initialize the WrappedGPT class.
|
|
24
|
+
|
|
25
|
+
Args:
|
|
26
|
+
layer (nn.Linear | nn.Module): The GPT layer to be wrapped.
|
|
27
|
+
layer_id (int, optional): The ID of the layer. Defaults to 0.
|
|
28
|
+
layer_name (str, optional): The name of the layer. Defaults to "none".
|
|
29
|
+
"""
|
|
30
|
+
self.layer = layer
|
|
31
|
+
self.dev = self.layer.weight.device
|
|
32
|
+
self.rows = layer.weight.data.shape[0]
|
|
33
|
+
self.columns = layer.weight.data.shape[1]
|
|
34
|
+
|
|
35
|
+
self.scaler_row = torch.zeros((self.columns), device=self.dev)
|
|
36
|
+
self.nsamples = 0
|
|
37
|
+
|
|
38
|
+
self.layer_id = layer_id
|
|
39
|
+
self.layer_name = layer_name
|
|
40
|
+
|
|
41
|
+
def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
|
|
42
|
+
"""
|
|
43
|
+
Add a batch of input and output tensors to the scaler_row.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
inp (torch.Tensor): The input tensor.
|
|
47
|
+
out (torch.Tensor): The output tensor.
|
|
48
|
+
"""
|
|
49
|
+
if len(inp.shape) == 2:
|
|
50
|
+
inp = inp.unsqueeze(0)
|
|
51
|
+
tmp = inp.shape[0]
|
|
52
|
+
if isinstance(self.layer, nn.Linear):
|
|
53
|
+
if len(inp.shape) == 3:
|
|
54
|
+
inp = inp.reshape((-1, inp.shape[-1]))
|
|
55
|
+
inp = inp.t()
|
|
56
|
+
|
|
57
|
+
self.scaler_row *= self.nsamples / (self.nsamples + tmp)
|
|
58
|
+
self.nsamples += tmp
|
|
59
|
+
|
|
60
|
+
inp = inp.type(torch.float32)
|
|
61
|
+
self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples
|