fusion-bench 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +4 -0
- fusion_bench/method/fw_merging/__init__.py +2 -0
- fusion_bench/method/fw_merging/fw_hard.py +448 -0
- fusion_bench/method/fw_merging/fw_soft.py +519 -0
- fusion_bench/method/fw_merging/utils.py +331 -0
- fusion_bench/method/moe_pruner/__init__.py +7 -0
- fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
- fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
- fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
- fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
- fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
- fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
- fusion_bench/method/moe_pruner/utils/data.py +154 -0
- fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
- fusion_bench/method/moe_pruner/utils/prune.py +313 -0
- fusion_bench/method/moe_pruner/utils/score.py +41 -0
- fusion_bench/method/pruning/__init__.py +1 -0
- fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
- fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
- fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
- fusion_bench/method/pruning/wanda_utils/data.py +33 -14
- fusion_bench/method/randes/__init__.py +15 -0
- fusion_bench/method/randes/base_algorithm.py +1013 -0
- fusion_bench/method/randes/modelsoup.py +126 -0
- fusion_bench/method/randes/task_arithmetic.py +318 -0
- fusion_bench/method/sparselo/sparselo.py +20 -2
- fusion_bench/method/tall_mask/__init__.py +1 -0
- fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +73 -10
- fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
- fusion_bench/programs/fabric_fusion_program.py +5 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/lazy_state_dict.py +268 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/state_dict_arithmetic.py +74 -2
- fusion_bench/utils/type.py +1 -0
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +86 -22
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
- fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
- fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
- fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
- fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
- fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
- fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
- fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B-Instruct.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B-Instruct.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b-it.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b-it.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
import math
import time

import torch
import torch.nn as nn
import transformers

# Run the pruning math in full fp32: disable TF32 for matmul and cuDNN so the
# Hessian accumulation and Cholesky factorizations below stay numerically precise.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


## SparseGPT: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
|
|
13
|
+
class SparseGPT:
    """Layer-wise one-shot pruner from "SparseGPT: Massive Language Models Can
    Be Accurately Pruned in One-Shot" (arXiv:2301.00774).

    Usage: construct with a layer, feed calibration activations through
    :meth:`add_batch` (typically via a forward hook), call :meth:`fasterprune`
    to sparsify the layer's weight in place, then :meth:`free` to release the
    Hessian buffer.

    Adapted from https://github.com/IST-DASLab/sparsegpt.
    """

    def __init__(self, layer):
        """Set up Hessian accumulation for ``layer``.

        Args:
            layer: an ``nn.Linear``, ``nn.Conv2d`` or ``transformers.Conv1D``
                module whose ``weight`` will be pruned in place.
        """
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            # Treat each output channel's kernel as one flat row.
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            # HF Conv1D stores weights transposed w.r.t. nn.Linear.
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        # Running input-Hessian estimate H = (2 / nsamples) * sum_i x_i x_i^T.
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0

    def add_batch(self, inp, out):
        """Accumulate one calibration batch into the running Hessian estimate.

        Args:
            inp: layer input activations, 2D ``(tokens, features)`` or
                3D ``(batch, seq, features)``.
            out: layer output; unused, kept for the forward-hook signature.
        """
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(
            self.layer, transformers.Conv1D
        ):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        # Downweight the old estimate so all samples end up averaged equally.
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())

    def fasterprune(self, sparsity, prune_n=0, prune_m=0, blocksize=128, percdamp=0.01):
        """Prune the layer weight in place to the requested sparsity.

        Args:
            sparsity: target fraction of zeroed weights (unstructured mode).
            prune_n: for n:m semi-structured sparsity, the number of zeros per
                group of ``prune_m`` weights; 0 selects unstructured pruning.
            prune_m: group size for n:m sparsity.
            blocksize: number of weight columns processed per OBS block.
            percdamp: Hessian dampening as a fraction of the mean diagonal.
        """
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        H = self.H
        del self.H
        # Columns that never saw activations: make H invertible there and zero
        # the corresponding weights outright.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        Losses = torch.zeros(self.rows, device=self.dev)

        # Dampen the diagonal, then build the upper Cholesky factor of H^-1.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        mask = None

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            if prune_n == 0:
                if mask is not None:
                    mask1 = mask[:, i1:i2]
                else:
                    # Saliency w^2 / [H^-1]_jj^2: prune the lowest fraction.
                    tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                    thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                    mask1 = tmp <= thresh
            else:
                # n:m mode: the mask is filled group by group in the loop below.
                mask1 = torch.zeros_like(W1) == 1

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if prune_n != 0 and i % prune_m == 0:
                    # Mark the prune_n least salient weights of this m-group.
                    tmp = (
                        W1[:, i : (i + prune_m)] ** 2
                        / (torch.diag(Hinv1)[i : (i + prune_m)].reshape((1, -1))) ** 2
                    )
                    mask1.scatter_(
                        1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True
                    )

                q = w.clone()
                q[mask1[:, i]] = 0

                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d**2

                # OBS update: spread this column's pruning error over the
                # remaining columns of the block.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            W[:, i1:i2] = Q1
            Losses += torch.sum(Losses1, 1) / 2

            # Propagate the block's accumulated error to all later columns.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        # Fix: only synchronize when CUDA is actually available; the original
        # unconditional torch.cuda.synchronize() crashes CPU-only runs.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.layer.weight.data = W.reshape(self.layer.weight.shape).to(
            self.layer.weight.data.dtype
        )

    def free(self):
        """Drop the Hessian buffer and release cached CUDA memory."""
        self.H = None
        torch.cuda.empty_cache()
|
|
@@ -1,13 +1,15 @@
|
|
|
1
1
|
# Code adapted from https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
|
|
2
2
|
|
|
3
|
+
import os
|
|
3
4
|
import random
|
|
4
5
|
from typing import List, Optional, Tuple, cast # noqa: F401
|
|
5
6
|
|
|
6
|
-
from datasets import load_dataset
|
|
7
7
|
from torch import Tensor
|
|
8
8
|
from tqdm.auto import tqdm
|
|
9
9
|
from transformers import PreTrainedTokenizer
|
|
10
10
|
|
|
11
|
+
from datasets import load_dataset
|
|
12
|
+
|
|
11
13
|
|
|
12
14
|
# Wrapper for tokenized input IDs
|
|
13
15
|
class TokenizerWrapper:
|
|
@@ -61,6 +63,7 @@ def get_c4(
|
|
|
61
63
|
seqlen: int,
|
|
62
64
|
tokenizer,
|
|
63
65
|
data_path: str = "allenai/c4",
|
|
66
|
+
cache_dir: str = ".cache/allenai--c4",
|
|
64
67
|
) -> Tuple[List[Tuple[Tensor, Tensor]], TokenizerWrapper]:
|
|
65
68
|
"""
|
|
66
69
|
Load and process the c4 dataset.
|
|
@@ -76,19 +79,35 @@ def get_c4(
|
|
|
76
79
|
tuple (Tuple[List[Tuple[Tensor, Tensor]], TokenizerWrapper]): Tuple containing the training samples and the validation dataset.
|
|
77
80
|
"""
|
|
78
81
|
# Load train and validation datasets
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
82
|
+
if os.path.exists(f"{cache_dir}/en/c4-train.00000-of-01024.json.gz"):
|
|
83
|
+
traindata = load_dataset(
|
|
84
|
+
"json",
|
|
85
|
+
data_files={"train": f"{cache_dir}/en/c4-train.00000-of-01024.json.gz"},
|
|
86
|
+
split="train",
|
|
87
|
+
)
|
|
88
|
+
else:
|
|
89
|
+
traindata = load_dataset(
|
|
90
|
+
data_path,
|
|
91
|
+
# "allenai--c4", # https://github.com/huggingface/datasets/issues/6559
|
|
92
|
+
data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
|
|
93
|
+
split="train",
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
if os.path.exists(f"{cache_dir}/en/c4-validation.00000-of-00008.json.gz"):
|
|
97
|
+
valdata = load_dataset(
|
|
98
|
+
"json",
|
|
99
|
+
data_files={
|
|
100
|
+
"validation": f"{cache_dir}/en/c4-validation.00000-of-00008.json.gz",
|
|
101
|
+
},
|
|
102
|
+
split="validation",
|
|
103
|
+
)
|
|
104
|
+
else:
|
|
105
|
+
valdata = load_dataset(
|
|
106
|
+
data_path,
|
|
107
|
+
# "allenai--c4",
|
|
108
|
+
data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
|
|
109
|
+
split="validation",
|
|
110
|
+
)
|
|
92
111
|
# Generate samples from training set
|
|
93
112
|
if seed is not None:
|
|
94
113
|
random.seed(seed)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
R"""
|
|
2
|
+
RanDeS: Randomized Delta Superposition
|
|
3
|
+
|
|
4
|
+
Implementation of "RanDeS: Randomized Delta Superposition for Multi-Model Compression"
|
|
5
|
+
paper link: http://arxiv.org/abs/2505.11204
|
|
6
|
+
|
|
7
|
+
Modified from https://github.com/Zhou-Hangyu/randes
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from .base_algorithm import SuperposedAlgorithmBase
|
|
11
|
+
from .modelsoup import SuperposedModelSoupAlgorithm
|
|
12
|
+
from .task_arithmetic import (
|
|
13
|
+
SuperposedTaskArithmeticAlgorithm,
|
|
14
|
+
SuperposedTaskArithmeticLoRAAlgorithm,
|
|
15
|
+
)
|