fusion-bench 0.2.14-py3-none-any.whl → 0.2.16-py3-none-any.whl

This diff shows the content of publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (86)
  1. fusion_bench/method/__init__.py +4 -0
  2. fusion_bench/method/fw_merging/__init__.py +2 -0
  3. fusion_bench/method/fw_merging/fw_hard.py +448 -0
  4. fusion_bench/method/fw_merging/fw_soft.py +519 -0
  5. fusion_bench/method/fw_merging/utils.py +331 -0
  6. fusion_bench/method/moe_pruner/__init__.py +7 -0
  7. fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
  8. fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
  9. fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
  10. fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
  11. fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
  12. fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
  13. fusion_bench/method/moe_pruner/utils/data.py +154 -0
  14. fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
  15. fusion_bench/method/moe_pruner/utils/prune.py +313 -0
  16. fusion_bench/method/moe_pruner/utils/score.py +41 -0
  17. fusion_bench/method/pruning/__init__.py +1 -0
  18. fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
  19. fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
  20. fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
  21. fusion_bench/method/pruning/wanda_utils/data.py +33 -14
  22. fusion_bench/method/randes/__init__.py +15 -0
  23. fusion_bench/method/randes/base_algorithm.py +1013 -0
  24. fusion_bench/method/randes/modelsoup.py +126 -0
  25. fusion_bench/method/randes/task_arithmetic.py +318 -0
  26. fusion_bench/method/sparselo/sparselo.py +20 -2
  27. fusion_bench/method/tall_mask/__init__.py +1 -0
  28. fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
  29. fusion_bench/modelpool/causal_lm/causal_lm.py +73 -10
  30. fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
  31. fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
  32. fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
  33. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
  34. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
  35. fusion_bench/programs/fabric_fusion_program.py +5 -0
  36. fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
  37. fusion_bench/utils/__init__.py +1 -0
  38. fusion_bench/utils/data.py +1 -1
  39. fusion_bench/utils/lazy_state_dict.py +268 -0
  40. fusion_bench/utils/parameters.py +33 -0
  41. fusion_bench/utils/state_dict_arithmetic.py +74 -2
  42. fusion_bench/utils/type.py +1 -0
  43. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +10 -3
  44. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +86 -22
  45. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
  46. fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
  47. fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
  48. fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
  49. fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
  50. fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
  51. fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
  52. fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
  53. fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
  54. fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
  55. fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
  56. fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
  57. fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
  58. fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
  59. fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
  60. fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
  61. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
  62. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  63. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  64. fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
  65. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
  66. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
  67. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
  68. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
  69. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
  70. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
  71. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
  72. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
  73. fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
  74. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B-Instruct.yaml +11 -0
  75. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B.yaml +11 -0
  76. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B-Instruct.yaml +11 -0
  77. fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B.yaml +11 -0
  78. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b-it.yaml +11 -0
  79. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b.yaml +11 -0
  80. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b-it.yaml +11 -0
  81. fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b.yaml +11 -0
  82. fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
  83. fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
  84. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
  85. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
  86. {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
fusion_bench/method/moe_pruner/moe_pruner.py
@@ -0,0 +1,304 @@
+ import logging
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, TypeVar, Union
+
+ import torch
+ from torch import Tensor, nn
+ from torch.nn import functional as F
+ from tqdm.auto import tqdm
+ from transformers import MixtralForCausalLM
+ from transformers.models.mixtral.modeling_mixtral import (
+     MixtralDecoderLayer,
+     MixtralSparseMoeBlock,
+ )
+
+ from fusion_bench import BaseAlgorithm, BaseModelPool
+ from fusion_bench.method.pruning.prune_utils import (
+     PruningType,
+     compute_sparsity,
+     find_linear_layers,
+     semistructured_magnitude_prune_,
+     unstructured_magnitude_prune_,
+ )
+ from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
+ from fusion_bench.modelpool import CausalLMPool
+ from fusion_bench.models.modeling_deepseek_v2 import (
+     DeepseekV2DecoderLayer,
+     DeepseekV2ForCausalLM,
+     DeepseekV2MLP,
+     DeepseekV2MoE,
+     DeepseekV2MoEGate,
+ )
+ from fusion_bench.utils import timeit_context
+ from fusion_bench.utils.cache_utils import cache_to_disk
+ from fusion_bench.utils.devices import to_device
+
+ from .hooks.deepseek_v2 import (
+     MoEPrunerHookFnForDeepseekV2Gate,
+     MoEPrunerHookFnForDeepseekV2Linear,
+ )
+ from .hooks.hook import BaseHookFn
+ from .hooks.mixtral import (
+     MoEPrunerHookFnForMixtralGate,
+     MoEPrunerHookFnForMixtralLinear,
+ )
+ from .utils.data import get_loaders
+ from .utils.prune import prepare_calibration_input
+
+ MoEModel = TypeVar("MoEModel", bound=Union[MixtralForCausalLM, DeepseekV2ForCausalLM])
+
+ log = logging.getLogger(__name__)
+
+
+ class MoEPruner(BaseAlgorithm, SimpleProfilerMixin, LightningFabricMixin):
+
+     def __init__(
+         self,
+         nsamples: int,
+         seed: int,
+         device: str,
+         prune_type: PruningType,
+         sparsity_ratio: float,
+         n: int,
+         m: int,
+         max_seqlen: Optional[int] = None,
+     ):
+         self.nsamples = nsamples
+         self.seed = seed
+         self.device = device
+         self.max_seqlen = max_seqlen
+         self.prune_type = prune_type
+         self.sparsity_ratio = sparsity_ratio
+         self.n = n
+         self.m = m
+         super().__init__()
+
+     def run(self, modelpool: CausalLMPool):
+         # load pre-trained model or the first model in the pool
+         with self.profile("load_model"):
+             model: MoEModel = modelpool.load_pretrained_or_first_model()
+             if self.max_seqlen is not None:
+                 model.seqlen = min(
+                     model.config.max_position_embeddings,
+                     self.max_seqlen,
+                 )
+             tokenizer = modelpool.load_tokenizer()
+
+         inps, outs, attention_mask, position_ids, position_embeddings = (
+             self.prepare_calibration_data(model, tokenizer)
+         )
+
+         self.prune_using_calibration_data_(
+             model,
+             inps=inps,
+             outs=outs,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             position_embeddings=position_embeddings,
+         )
+
+         return model
+
+     def prepare_calibration_data(self, model: MoEModel, tokenizer):
+         """
+         Prepare calibration data for pruning with caching.
+
+         Args:
+             model (LlamaForCausalLM): Model to be pruned.
+             tokenizer: Tokenizer for the model.
+
+         Returns:
+             Tuple: Calibration data (inputs, outputs, attention mask, position IDs).
+         """
+
+         @cache_to_disk(
+             f"outputs/cache/{model.config.name_or_path.split('/')[-1]}/calibration_data.pkl"
+         )
+         def _prepare_calibration_data(model, tokenizer):
+             return self._prepare_calibration_data(model, tokenizer)
+
+         return _prepare_calibration_data(model, tokenizer)
+
+     def _prepare_calibration_data(self, model, tokenizer):
+         """
+         Prepare calibration data for pruning.
+
+         Args:
+             model (LlamaForCausalLM): Model to be pruned.
+             tokenizer: Tokenizer for the model.
+
+         Returns:
+             Tuple: Calibration data (inputs, outputs, attention mask, position IDs).
+         """
+         with timeit_context("loading calibration data"):
+             dataloader, _ = get_loaders(
+                 "c4",
+                 nsamples=self.nsamples,
+                 seed=self.seed,
+                 seqlen=model.seqlen,
+                 tokenizer=tokenizer,
+             )
+
+         with torch.no_grad():
+             # collect input to the first layer
+             inps, outs, attention_mask, position_ids, position_embeddings = (
+                 prepare_calibration_input(model, dataloader, self.device)
+             )
+         return inps, outs, attention_mask, position_ids, position_embeddings
+
+     def prune_using_calibration_data_(
+         self,
+         model: MoEModel,
+         *,
+         inps,
+         outs,
+         attention_mask,
+         position_ids,
+         position_embeddings,
+     ):
+         model.eval()
+         layers = model.model.layers
+         for layer_idx, layer in tqdm(
+             enumerate(layers),
+             "Pruning Layers",
+             total=len(layers),
+             dynamic_ncols=True,
+         ):
+             if (
+                 hasattr(model, "hf_device_map")
+                 and f"model.layers.{layer_idx}" in model.hf_device_map
+             ):
+                 # handle the case for large models, when the device map has multiple GPUs;
+                 dev = model.hf_device_map[f"model.layers.{layer_idx}"]
+                 inps, outs, attention_mask, position_ids, position_embeddings = (
+                     inps.to(dev),
+                     outs.to(dev),
+                     attention_mask.to(dev) if attention_mask is not None else None,
+                     position_ids.to(dev) if position_ids is not None else None,
+                     (
+                         to_device(position_embeddings, dev)
+                         if position_embeddings is not None
+                         else None
+                     ),
+                 )
+
+             if isinstance(layer, MixtralDecoderLayer):
+                 linear_layers = find_linear_layers(layer.block_sparse_moe.experts)
+             elif isinstance(layer, DeepseekV2DecoderLayer):
+                 if isinstance(layer.mlp, DeepseekV2MoE):
+                     linear_layers = find_linear_layers(layer.mlp.experts)
+                 elif isinstance(layer.mlp, DeepseekV2MLP):
+                     # compute the input to the next layer
+                     with torch.no_grad():
+                         for j in range(self.nsamples):
+                             outs[j] = layer(
+                                 inps[j].unsqueeze(0),
+                                 attention_mask=attention_mask,
+                                 position_ids=position_ids,
+                                 position_embeddings=position_embeddings,
+                             )[0]
+                     inps, outs = outs, inps
+                     continue
+             else:
+                 raise ValueError(f"Unsupported layer type: {type(layer)}")
+
+             linear_hooks: Dict[str, BaseHookFn] = {}
+             handles: List[torch.utils.hooks.RemovableHandle] = []
+             for name, linear in linear_layers.items():
+                 if isinstance(model, MixtralForCausalLM):
+                     hook_fn = MoEPrunerHookFnForMixtralLinear(linear, name)
+                 elif isinstance(model, DeepseekV2ForCausalLM):
+                     hook_fn = MoEPrunerHookFnForDeepseekV2Linear(linear, name)
+                 else:
+                     raise ValueError(f"Unsupported model type: {type(model)}")
+                 linear_hooks[name] = hook_fn
+                 handles.append(linear.register_forward_hook(hook_fn))
+
+             if isinstance(model, MixtralForCausalLM):
+                 gate_hook = MoEPrunerHookFnForMixtralGate(
+                     layer.block_sparse_moe.gate,
+                     linear_hooks,
+                     top_k=layer.block_sparse_moe.top_k,
+                     num_experts=layer.block_sparse_moe.num_experts,
+                 )
+                 handles.append(
+                     layer.block_sparse_moe.gate.register_forward_hook(gate_hook)
+                 )
+             elif isinstance(model, DeepseekV2ForCausalLM):
+                 gate_hook = MoEPrunerHookFnForDeepseekV2Gate(
+                     layer.mlp.gate,
+                     linear_hooks,
+                     top_k=layer.mlp.gate.top_k,
+                     num_experts=layer.mlp.config.n_routed_experts,
+                 )
+                 handles.append(layer.mlp.gate.register_forward_hook(gate_hook))
+             else:
+                 raise ValueError(f"Unsupported model type: {type(model)}")
+
+             with torch.no_grad():
+                 for j in range(self.nsamples):
+                     outs[j] = layer(
+                         inps[j].unsqueeze(0),
+                         attention_mask=attention_mask,
+                         position_ids=position_ids,
+                         position_embeddings=position_embeddings,
+                     )[0]
+
+             # compute the importance scores and remove the hooks
+             metrics = {}
+             for name, hook in linear_hooks.items():
+                 metrics[name] = hook.compute().detach().cpu()
+             for h in handles:
+                 h.remove()
+
+             # prune the weights based on the importance scores
+             if self.prune_type == PruningType.UNSTRUCTURED:
+                 for name, linear in linear_layers.items():
+                     log.info(f"Pruning {name}")
+                     unstructured_magnitude_prune_(
+                         linear.weight.data,
+                         metrics[name].to(linear.weight.device),
+                         sparsity_ratio=self.sparsity_ratio,
+                     )
+                     self.check_sparsity(linear.weight)
+             elif self.prune_type == PruningType.SEMISTRUCTURED:
+                 for name, linear in linear_layers.items():
+                     log.info(f"Pruning {name}")
+                     semistructured_magnitude_prune_(
+                         linear.weight.data,
+                         metrics[name].to(linear.weight.device),
+                         n=self.n,
+                         m=self.m,
+                     )
+                     self.check_sparsity(linear.weight)
+             else:
+                 raise ValueError(f"Invalid pruning type: {self.prune_type}")
+
+             # compute the input to the next layer
+             with torch.no_grad():
+                 for j in range(self.nsamples):
+                     outs[j] = layer(
+                         inps[j].unsqueeze(0),
+                         attention_mask=attention_mask,
+                         position_ids=position_ids,
+                         position_embeddings=position_embeddings,
+                     )[0]
+             inps, outs = outs, inps
+
+     @torch.no_grad()
+     def check_sparsity(self, weight: Tensor, tol: float = 0.01):
+         """
+         Check the sparsity of the weight tensor.
+
+         Args:
+             weight (Tensor): Weight tensor.
+             tol (float): Tolerance for sparsity check.
+
+         Raises:
+             ValueError: If the pruning type is invalid.
+         """
+         if self.prune_type == PruningType.UNSTRUCTURED:
+             assert (compute_sparsity(weight) - self.sparsity_ratio).abs() < tol
+         elif self.prune_type == PruningType.SEMISTRUCTURED:
+             assert (compute_sparsity(weight) - self.n / self.m).abs() < tol
+         else:
+             raise ValueError(f"Invalid pruning type: {self.prune_type}")
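
Note: the following is a minimal usage sketch of the new MoEPruner algorithm, inferred only from the constructor and run() signature shown above; it is not shipped with the package. The modelpool is assumed to be constructed elsewhere (for example from one of the CausalLMPool configs added in this release), and the concrete argument values are illustrative.

from fusion_bench.method.moe_pruner.moe_pruner import MoEPruner
from fusion_bench.method.pruning.prune_utils import PruningType

# `modelpool` is assumed to be an already-constructed CausalLMPool, e.g. built
# from fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml.
pruner = MoEPruner(
    nsamples=128,                           # number of C4 calibration sequences
    seed=0,                                 # seed for calibration sampling
    device="cuda",                          # device used for calibration forward passes
    prune_type=PruningType.SEMISTRUCTURED,  # or PruningType.UNSTRUCTURED
    sparsity_ratio=0.5,                     # used only by unstructured pruning
    n=2,                                    # n:m pattern, e.g. 2:4 semi-structured sparsity
    m=4,
    max_seqlen=2048,                        # caps model.seqlen for calibration
)
pruned_model = pruner.run(modelpool)        # returns the pruned MoE model
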
fusion_bench/method/moe_pruner/utils/__init__.py
@@ -0,0 +1 @@
+ from .score import layer_load_balance_score
fusion_bench/method/moe_pruner/utils/data.py
@@ -0,0 +1,154 @@
+ # Code adapted from https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
+
+ import random
+ from typing import List, Optional, Tuple, cast  # noqa: F401
+ import os
+ from datasets import load_dataset
+ from torch import Tensor
+ from tqdm.auto import tqdm
+ from transformers import PreTrainedTokenizer
+
+
+ # Wrapper for tokenized input IDs
+ class TokenizerWrapper:
+     def __init__(self, input_ids):
+         self.input_ids = input_ids
+
+
+ # Load and process wikitext2 dataset
+ def get_wikitext2(
+     nsamples: int,
+     seed: int,
+     seqlen: int,
+     tokenizer: PreTrainedTokenizer,
+     data_path: str = "wikitext",
+ ):
+     """
+     Load and preprocess the Wikitext-2 dataset for training and testing.
+
+     Args:
+         nsamples (int): Number of samples to generate from the training set.
+         seed (int): Random seed for reproducibility.
+         seqlen (int): Length of the sequence to be used for training.
+         tokenizer (PreTrainedTokenizer): Tokenizer to encode the text data.
+         data_path (str, optional): Path to the dataset. Defaults to "wikitext".
+     """
+     # Load train and test datasets
+     traindata = load_dataset(data_path, "wikitext-2-raw-v1", split="train")
+     testdata = load_dataset(data_path, "wikitext-2-raw-v1", split="test")
+
+     # Encode datasets
+     trainenc = tokenizer(" ".join(traindata["text"]), return_tensors="pt")
+     testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
+
+     # Generate samples from training set
+     random.seed(seed)
+     trainloader: List[Tuple[Tensor, Tensor]] = []
+     for _ in tqdm(range(nsamples), desc="Generating samples"):
+         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+         j = i + seqlen
+         inp: Tensor = trainenc.input_ids[:, i:j]
+         tar = inp.clone()
+         tar[:, :-1] = -100
+         trainloader.append((inp, tar))
+     return trainloader, testenc
+
+
+ # Load and process c4 dataset
+ def get_c4(
+     nsamples: int,
+     seed: int,
+     seqlen: int,
+     tokenizer,
+     data_path: str = "allenai/c4",
+ ) -> Tuple[List[Tuple[Tensor, Tensor]], TokenizerWrapper]:
+     """
+     Load and process the c4 dataset.
+
+     Args:
+         nsamples (int): Number of samples to generate from the training set.
+         seed (int): Seed for random number generation.
+         seqlen (int): Length of each sequence.
+         tokenizer: Tokenizer object for encoding the text.
+         data_path (str, optional): Path to the c4 dataset. Defaults to "allenai/c4".
+
+     Returns:
+         tuple (Tuple[List[Tuple[Tensor, Tensor]], TokenizerWrapper]): Tuple containing the training samples and the validation dataset.
+     """
+     # Load train and validation datasets
+     if os.path.exists(".cache/allenai--c4/en/c4-train.00000-of-01024.json.gz"):
+         traindata = load_dataset(
+             "json",
+             data_files={
+                 "train": ".cache/allenai--c4/en/c4-train.00000-of-01024.json.gz"
+             },
+             split="train",
+         )
+     else:
+         traindata = load_dataset(
+             data_path,
+             # "allenai--c4", # https://github.com/huggingface/datasets/issues/6559
+             data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
+             split="train",
+         )
+
+     if os.path.exists(".cache/allenai--c4/en/c4-validation.00000-of-00008.json.gz"):
+         valdata = load_dataset(
+             "json",
+             data_files={
+                 "validation": ".cache/allenai--c4/en/c4-validation.00000-of-00008.json.gz",
+             },
+             split="validation",
+         )
+     else:
+         valdata = load_dataset(
+             data_path,
+             # "allenai--c4",
+             data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
+             split="validation",
+         )
+
+     # Generate samples from training set
+     if seed is not None:
+         random.seed(seed)
+
+     trainloader = []
+     for _ in tqdm(range(nsamples), desc="Generating samples"):
+         while True:
+             i = random.randint(0, len(traindata) - 1)
+             trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
+             if trainenc.input_ids.shape[1] > seqlen:
+                 break
+         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+         j = i + seqlen
+         inp = trainenc.input_ids[:, i:j]
+         tar = inp.clone()
+         tar[:, :-1] = -100
+         trainloader.append((inp, tar))
+
+     # Prepare validation dataset
+     valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt")
+     valenc = valenc.input_ids[:, : (256 * seqlen)]
+     valenc = TokenizerWrapper(valenc)
+     return trainloader, valenc
+
+
+ # Function to select the appropriate loader based on dataset name
+ def get_loaders(
+     name: str, nsamples: int = 128, seed: int = 0, seqlen: int = 2048, tokenizer=None
+ ):
+     """
+     Get the data loaders for the specified dataset.
+
+     Args:
+         name (str): The name of the dataset. Supported values are "wikitext2" and "c4".
+         nsamples (int, optional): Number of samples to generate from the dataset. Defaults to 128.
+         seed (int, optional): Random seed for reproducibility. Defaults to 0.
+         seqlen (int, optional): Length of the sequence to be used for training. Defaults to 2048.
+         tokenizer (optional): Tokenizer to encode the text data. Defaults to None.
+     """
+     if "wikitext2" in name:
+         return get_wikitext2(nsamples, seed, seqlen, tokenizer)
+     if "c4" in name:
+         return get_c4(nsamples, seed, seqlen, tokenizer)
+     raise ValueError(f"Unknown dataset: {name}")
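
As a quick illustration of the helper above (not part of the package), the snippet below draws a few C4 calibration samples with get_loaders; the tokenizer checkpoint is an arbitrary choice for the example.

from transformers import AutoTokenizer

from fusion_bench.method.moe_pruner.utils.data import get_loaders

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")  # assumed checkpoint
trainloader, valenc = get_loaders("c4", nsamples=8, seed=0, seqlen=2048, tokenizer=tokenizer)

inp, tar = trainloader[0]  # each sample is an (input_ids, labels) pair of shape [1, seqlen]
# labels are masked with -100 everywhere except the last position
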
fusion_bench/method/moe_pruner/utils/layerwrapper.py
@@ -0,0 +1,61 @@
+ import torch
+ import torch.nn as nn
+
+
+ # Define WrappedGPT class
+ class WrappedGPT:
+     """
+     This class wraps a GPT layer for specific operations.
+
+     Attributes:
+         layer (nn.Linear | nn.Module): The GPT layer to be wrapped.
+         dev (torch.device): The device on which the layer's weights are stored.
+         rows (int): The number of rows in the layer's weight matrix.
+         columns (int): The number of columns in the layer's weight matrix.
+         scaler_row (torch.Tensor): A tensor to store the scaler values for each column.
+         nsamples (int): The number of samples processed.
+         layer_id (int): The ID of the layer.
+         layer_name (str): The name of the layer.
+     """
+
+     def __init__(self, layer: nn.Linear | nn.Module, layer_id=0, layer_name="none"):
+         """
+         Initialize the WrappedGPT class.
+
+         Args:
+             layer (nn.Linear | nn.Module): The GPT layer to be wrapped.
+             layer_id (int, optional): The ID of the layer. Defaults to 0.
+             layer_name (str, optional): The name of the layer. Defaults to "none".
+         """
+         self.layer = layer
+         self.dev = self.layer.weight.device
+         self.rows = layer.weight.data.shape[0]
+         self.columns = layer.weight.data.shape[1]
+
+         self.scaler_row = torch.zeros((self.columns), device=self.dev)
+         self.nsamples = 0
+
+         self.layer_id = layer_id
+         self.layer_name = layer_name
+
+     def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
+         """
+         Add a batch of input and output tensors to the scaler_row.
+
+         Args:
+             inp (torch.Tensor): The input tensor.
+             out (torch.Tensor): The output tensor.
+         """
+         if len(inp.shape) == 2:
+             inp = inp.unsqueeze(0)
+         tmp = inp.shape[0]
+         if isinstance(self.layer, nn.Linear):
+             if len(inp.shape) == 3:
+                 inp = inp.reshape((-1, inp.shape[-1]))
+             inp = inp.t()
+
+         self.scaler_row *= self.nsamples / (self.nsamples + tmp)
+         self.nsamples += tmp
+
+         inp = inp.type(torch.float32)
+         self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples
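
To make the role of scaler_row concrete, here is a small self-contained sketch (not from the package) that feeds one batch through a toy linear layer via a forward hook; it relies only on the add_batch behaviour shown above.

import torch
from torch import nn

from fusion_bench.method.moe_pruner.utils.layerwrapper import WrappedGPT

layer = nn.Linear(16, 32)
wrapped = WrappedGPT(layer, layer_id=0, layer_name="demo")

def hook(module, inputs, output):
    # forward hooks receive (module, inputs, output); add_batch only uses the input
    wrapped.add_batch(inputs[0].detach(), output.detach())

handle = layer.register_forward_hook(hook)
with torch.no_grad():
    layer(torch.randn(4, 10, 16))  # a batch of token activations: (batch, seq, features)
handle.remove()

# scaler_row now holds a per-input-feature activation statistic (one value per column
# of the weight matrix), the kind of statistic Wanda-style pruning combines with |W|.
print(wrapped.scaler_row.shape)  # torch.Size([16])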