fusion-bench 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +6 -0
- fusion_bench/constants/banner.py +12 -0
- fusion_bench/method/__init__.py +11 -0
- fusion_bench/method/expert_sparsity/__init__.py +10 -0
- fusion_bench/method/expert_sparsity/mixtral/__init__.py +23 -0
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +175 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +159 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +173 -0
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +153 -0
- fusion_bench/method/knots/__init__.py +0 -0
- fusion_bench/method/knots/knots_utils.py +23 -0
- fusion_bench/method/linear/simple_average_for_llama.py +17 -3
- fusion_bench/method/simple_average.py +10 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +1 -0
- fusion_bench/method/task_singular_vector/utils/task_singular_interference.py +41 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +45 -11
- fusion_bench/models/__init__.py +1 -0
- fusion_bench/models/expert_sparsity/__init__.py +0 -0
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +15 -0
- fusion_bench/models/expert_sparsity/mixtral/dataset.py +40 -0
- fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py +207 -0
- fusion_bench/models/expert_sparsity/mixtral/wrapper.py +268 -0
- fusion_bench/programs/fabric_fusion_program.py +12 -8
- fusion_bench/tasks/clip_classification/imagenet.py +1008 -2004
- fusion_bench/utils/__init__.py +3 -2
- fusion_bench/utils/dtype.py +2 -1
- fusion_bench/utils/fabric.py +11 -4
- fusion_bench/utils/lazy_state_dict.py +155 -13
- fusion_bench/utils/misc.py +19 -1
- fusion_bench/utils/pylogger.py +2 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/RECORD +40 -21
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +2 -0
- fusion_bench_config/method/expert_sparsity/README.md +6 -0
- fusion_bench_config/method/expert_sparsity/mixtral.yaml +17 -0
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +11 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/top_level.txt +0 -0
fusion_bench/method/expert_sparsity/utils/calibration_data.py (new file)

@@ -0,0 +1,153 @@
+"""
+This module contains the code for loading the calibration data.
+
+- C4
+- Math
+"""
+
+import itertools
+import logging
+import os
+
+import torch
+import transformers
+from datasets import load_dataset
+from transformers import PreTrainedTokenizer, default_data_collator
+from transformers.testing_utils import CaptureLogger
+from huggingface_hub import hf_hub_download
+
+logger = logging.getLogger(__name__)
+
+
+DATASETS = {
+    # C4: Please download first part of the C4 training data `c4-train.00000-of-01024.json` from [allenai/c4](https://huggingface.co/datasets/allenai/c4/blob/main/en/c4-train.00000-of-01024.json.gz).
+    "c4": lambda: load_dataset(
+        "json",
+        data_files={
+            "train": hf_hub_download(
+                "allenai/c4",
+                filename="en/c4-train.00000-of-01024.json.gz",
+                repo_type="dataset",
+            )
+        },
+    ),
+    # MATH: You can use our pre-built calibration set in `./data/math_pretrain_style.json`. To reproduce our construction, please download the training set of [MATH](https://github.com/hendrycks/math) and use our [script](data/math_calib_construction.py).
+    # NOTE: I have uploaded the math_pretrain_style.json to my huggingface repo:
+    # https://huggingface.co/datasets/tanganke/math_pretrain_style/tree/main.
+    "math": lambda: load_dataset(
+        "json",
+        data_files={
+            "train": hf_hub_download(
+                "tanganke/math_pretrain_style",
+                filename="math_pretrain_style.json",
+                repo_type="dataset",
+            )
+        },
+    ),
+}
+
+
+def build_calib_loader(
+    dataset: str,
+    tokenizer: PreTrainedTokenizer,
+    max_block_size: int,
+    n_blocks_for_stat: int,
+    batch_size: int,
+    num_workers: int,
+    seed: int = 42,
+):
+    # dataset can be a string or a dataset object.
+    # If it is a string, it can be the name of the dataset in DATASETS or the path to the dataset (a json file).
+    if isinstance(dataset, str):
+        if dataset in DATASETS:
+            all_set = DATASETS[dataset]()
+        else:
+            assert os.path.exists(dataset), f"Dataset {dataset} not found."
+            all_set = load_dataset("json", data_files={"train": dataset})
+    else:
+        assert dataset is not None, "Dataset is not provided."
+        all_set = dataset
+
+    block_size = tokenizer.model_max_length
+    if block_size > max_block_size:
+        logger.info(
+            "The chosen tokenizer supports a `model_max_length` that is longer than the default `max_block_size` value"
+            f" of {max_block_size}. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
+            " override this default with `--max_block_size xxx`."
+        )
+        block_size = max_block_size
+
+    if n_blocks_for_stat > 0:  # Randomly choose `n_blocks_for_stat` blocks
+        calib_set = (
+            all_set["train"]
+            .shuffle(seed=seed)
+            .select(range(min(n_blocks_for_stat * 16, len(all_set["train"]))))
+        )
+    else:  # Use the whole set
+        logger.warning("n_blocks_for_stat <= 0, using the whole dataset.")
+        calib_set = all_set["train"].shuffle(seed=seed)
+
+    logger.info(f"Calibration dataset: {calib_set}")
+    text_column_name = (
+        "text" if "text" in calib_set.features else list(calib_set.features)[0]
+    )
+
+    tok_logger = transformers.utils.logging.get_logger(
+        "transformers.tokenization_utils_base"
+    )
+
+    def tokenize_function(examples):
+        with CaptureLogger(tok_logger) as cl:
+            output = tokenizer(examples[text_column_name])
+        # clm input could be much much longer than block_size
+        if "Token indices sequence length is longer than the" in cl.out:
+            tok_logger.warning(
+                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
+                " before being passed to the model."
+            )
+        return output
+
+    tokenized_calib_set = calib_set.map(
+        tokenize_function,
+        batched=True,
+        remove_columns=list(calib_set.features),
+    )
+
+    def group_texts(examples):
+        # Concatenate all texts.
+        concatenated_examples = {
+            k: list(itertools.chain(*examples[k])) for k in examples.keys()
+        }
+        total_length = len(concatenated_examples[list(examples.keys())[0]])
+        # We drop the small remainder; we could add padding instead if the model supported it. You can
+        # customize this part to your needs.
+        if total_length >= block_size:
+            total_length = (total_length // block_size) * block_size
+        # Split by chunks of max_len.
+        result = {
+            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+            for k, t in concatenated_examples.items()
+        }
+        result["labels"] = result["input_ids"].copy()
+        return result
+
+    lm_calib_set = tokenized_calib_set.map(
+        group_texts,
+        batched=True,
+    )
+
+    if n_blocks_for_stat > 0:
+        assert len(lm_calib_set) > n_blocks_for_stat
+        lm_calib_set = lm_calib_set.select(range(n_blocks_for_stat))
+
+    calib_loader = torch.utils.data.DataLoader(
+        lm_calib_set,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        pin_memory=True,
+        drop_last=False,
+        shuffle=False,
+        collate_fn=default_data_collator,
+    )
+
+    return calib_loader
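For context, a minimal sketch of how this calibration loader might be driven. The tokenizer name and the block/batch sizes below are illustrative placeholders, not values taken from the package.

```python
from transformers import AutoTokenizer

from fusion_bench.method.expert_sparsity.utils.calibration_data import build_calib_loader

# Hypothetical tokenizer; any causal-LM tokenizer with `model_max_length` set works.
tokenizer = AutoTokenizer.from_pretrained("org/base-model")

# "c4" and "math" are the two built-in calibration sets; a path to a local
# JSON file is also accepted by `build_calib_loader`.
calib_loader = build_calib_loader(
    dataset="c4",
    tokenizer=tokenizer,
    max_block_size=2048,    # tokenized text is grouped into blocks of this length
    n_blocks_for_stat=128,  # number of calibration blocks to keep
    batch_size=4,
    num_workers=4,
)

for batch in calib_loader:
    input_ids = batch["input_ids"]  # shape: (batch_size, block_size)
    break
```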
fusion_bench/method/knots/__init__.py: file without content changes
fusion_bench/method/knots/knots_utils.py (new file)

@@ -0,0 +1,23 @@
+import torch
+
+
+def subspace_alignment(
+    delta_weights: list[torch.Tensor],
+    svd_dtype: torch.dtype | None = torch.float64,
+    eps: float = 1e-4,
+):
+    """
+    Reference: Model merging with SVD to tie the Knots. http://arxiv.org/abs/2410.19735
+    """
+    if svd_dtype is None:
+        svd_dtype = delta_weights[0].dtype
+    original_dtype = delta_weights[0].dtype
+    output_dim, input_dim = delta_weights[0].size()
+    concat_task_vector = torch.cat(delta_weights, dim=1)
+    U, S, Vh = torch.linalg.svd(concat_task_vector.to(svd_dtype), full_matrices=False)
+    # Keep only supported basis components
+    U = U[:, S > eps].to(original_dtype)
+    Vh = Vh[S > eps].to(original_dtype)
+    S = S[S > eps].to(original_dtype)
+    Vhs = torch.split(Vh, input_dim, dim=1)
+    return U, S, Vhs
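A small sketch of how this helper could be called on per-task weight deltas; the shapes and the reconstruction check are illustrative.

```python
import torch

from fusion_bench.method.knots.knots_utils import subspace_alignment

# Two hypothetical task vectors (fine-tuned weight minus pre-trained weight)
# for the same (output_dim x input_dim) linear layer.
delta_a = torch.randn(64, 32)
delta_b = torch.randn(64, 32)

# Shared left singular basis U, singular values S, and one right factor per task.
U, S, Vhs = subspace_alignment([delta_a, delta_b], svd_dtype=torch.float64)

# Each task vector is recovered (up to the eps cutoff) as U @ diag(S) @ Vhs[i].
recon_a = U @ torch.diag(S) @ Vhs[0]
print(torch.allclose(recon_a, delta_a, atol=1e-3))
```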
fusion_bench/method/linear/simple_average_for_llama.py

@@ -1,4 +1,5 @@
-from
+from copy import deepcopy
+from typing import TYPE_CHECKING, Optional
 
 from typing_extensions import override
 
@@ -6,6 +7,11 @@ from fusion_bench import timeit_context
 from fusion_bench.method.base_algorithm import BaseAlgorithm
 from fusion_bench.method.simple_average import SimpleAverageAlgorithm
 from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
+from fusion_bench.utils.pylogger import getRankZeroLogger
+from omegaconf import flag_override
+from fusion_bench.utils import instantiate
+
+log = getRankZeroLogger(__name__)
 
 
 class SimpleAverageForLlama(BaseAlgorithm):
@@ -40,12 +46,20 @@ class SimpleAverageForLlama(BaseAlgorithm):
 
         if self.merge_backbone:
             assert modelpool.has_pretrained
-
+            log.info(
+                "Merging backbone of the model pool, use CausalLMBackbonePool instead of CausalLMPool."
+            )
+            modelpool_config = deepcopy(modelpool.config)
+            with flag_override(modelpool_config, "allow_objects", True):
+                modelpool_config._target_ = (
+                    "fusion_bench.modelpool.causal_lm.CausalLMBackbonePool"
+                )
+            backbone_modelpool = instantiate(modelpool_config)
             model = modelpool.load_model("_pretrained_")
             backbone_model = SimpleAverageAlgorithm().run(backbone_modelpool)
             model.model.layers = backbone_model
         else:
-            model = SimpleAverageAlgorithm().run()
+            model = SimpleAverageAlgorithm().run(modelpool=modelpool)
 
         if self.model_save_path is not None:
             with timeit_context(f"Saving the model to {self.model_save_path}"):
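The backbone-merging branch above swaps the pool class by rewriting `_target_` on a copy of the pool's config and re-instantiating it. A rough, self-contained sketch of that pattern; the config contents here are an assumption for illustration only.

```python
from copy import deepcopy

from omegaconf import OmegaConf, flag_override

from fusion_bench.utils import instantiate

# Hypothetical modelpool config, as it might come from a YAML file.
cfg = OmegaConf.create(
    {
        "_target_": "fusion_bench.modelpool.causal_lm.CausalLMPool",
        "models": {"_pretrained_": "org/base-model", "expert": "org/expert-model"},
        "tokenizer": "org/base-model",
    }
)

backbone_cfg = deepcopy(cfg)
with flag_override(backbone_cfg, "allow_objects", True):
    # Re-target the copy so `instantiate` builds the backbone-only pool instead.
    backbone_cfg._target_ = "fusion_bench.modelpool.causal_lm.CausalLMBackbonePool"
backbone_modelpool = instantiate(backbone_cfg)
```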
fusion_bench/method/simple_average.py

@@ -8,6 +8,7 @@ from torch import nn
 from fusion_bench.method.base_algorithm import BaseAlgorithm
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils import LazyStateDict
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
     state_dict_avg,
@@ -104,6 +105,15 @@ class SimpleAverageAlgorithm(
         # Divide the accumulated state dictionary by the number of models to get the average
         sd = state_dict_div(sd, len(modelpool.model_names))
 
+        if isinstance(forward_model, LazyStateDict):
+            # if the model is a LazyStateDict, convert it to an empty module
+            forward_model = forward_model.meta_module.to_empty(
+                device=(
+                    "cpu"
+                    if forward_model._torch_dtype is None
+                    else forward_model._torch_dtype
+                )
+            )
         forward_model.load_state_dict(sd)
         # print profile report and log the merged models
         self.print_profile_summary()
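The `LazyStateDict` branch above relies on PyTorch's meta-device workflow: a module skeleton is created without real storage, materialized with `to_empty`, and then filled by `load_state_dict`. A minimal, self-contained illustration of that mechanism using a plain `nn.Linear`, not the fusion_bench classes:

```python
import torch
from torch import nn

# Build the module skeleton without allocating real weight storage.
with torch.device("meta"):
    model = nn.Linear(8, 4)

# Materialize uninitialized storage on CPU, then fill it with averaged weights.
model = model.to_empty(device="cpu")
averaged_sd = {"weight": torch.zeros(4, 8), "bias": torch.zeros(4)}
model.load_state_dict(averaged_sd)
print(model.weight.shape)  # torch.Size([4, 8])
```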
fusion_bench/method/task_singular_vector/utils/task_singular_interference.py (new file)

@@ -0,0 +1,41 @@
+from typing import List
+
+import torch
+
+
+def compute_task_singular_interference(weight_differences: List[torch.Tensor]) -> float:
+    R"""
+    Compute the singular interference of a list of weight differences $\{W_i - W_0\}_{i=1}^T$,
+    where $W_0$ is the pre-trained model weight, $W_i$ is the weight of the i-th fine-tuned model
+    and $T$ is the number of fine-tuned models.
+
+    Args:
+        weight_differences (List[torch.Tensor]): A list of weight differences $\{W_i - W_0\}_{i=1}^T$.
+
+    Returns:
+        float: The singular interference of the list of weight differences.
+    """
+    device = weight_differences[0].device
+    dtype = weight_differences[0].dtype
+
+    U = []
+    S = []
+    V = []
+    for delta_w in weight_differences:
+        u, s, vh = torch.linalg.svd(delta_w, full_matrices=False)
+        U.append(u)
+        S.append(s)
+        V.append(vh.t())
+    U = torch.cat(U, dim=0)
+    S = torch.cat(S, dim=0)
+    V = torch.cat(V, dim=0)
+
+    singular_task_interference = torch.linalg.multi_dot(
+        (
+            U.t() @ U - torch.eye(U.shape[1], device=device, dtype=dtype),
+            torch.diag(S),
+            V.t() @ V - torch.eye(V.shape[1], device=device, dtype=dtype),
+        )
+    )
+    singular_task_interference = torch.linalg.norm(singular_task_interference, ord="1")
+    return singular_task_interference
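Roughly, this quantity measures how much the per-task singular directions overlap across tasks: it vanishes when the tasks' left and right singular subspaces are mutually orthogonal. A self-contained sketch of that computation, not the package function above; column-wise concatenation of the per-task factors and an entry-wise L1-style magnitude are assumptions made here for illustration.

```python
import torch


def singular_interference_sketch(deltas: list[torch.Tensor]) -> torch.Tensor:
    """Interference of per-task SVD factors; zero when task subspaces are orthogonal."""
    Us, Ss, Vs = [], [], []
    for delta in deltas:
        u, s, vh = torch.linalg.svd(delta, full_matrices=False)
        Us.append(u)
        Ss.append(s)
        Vs.append(vh.t())
    U = torch.cat(Us, dim=1)  # (out_dim, r*T)
    S = torch.cat(Ss, dim=0)  # (r*T,)
    V = torch.cat(Vs, dim=1)  # (in_dim, r*T)
    eye = torch.eye(U.shape[1], dtype=U.dtype)
    interference = (U.t() @ U - eye) @ torch.diag(S) @ (V.t() @ V - eye)
    return interference.abs().sum()


# Three hypothetical task vectors for a 16x8 layer.
deltas = [torch.randn(16, 8) for _ in range(3)]
print(singular_interference_sketch(deltas))
```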
fusion_bench/modelpool/causal_lm/causal_lm.py

@@ -22,6 +22,8 @@ from typing_extensions import override
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.dtype import parse_dtype
+from fusion_bench.utils.lazy_state_dict import LazyStateDict
+from fusion_bench.utils.packages import import_object
 
 log = logging.getLogger(__name__)
 
@@ -30,6 +32,7 @@ class CausalLMPool(BaseModelPool):
     _config_mapping = BaseModelPool._config_mapping | {
         "_tokenizer": "tokenizer",
         "_model_kwargs": "model_kwargs",
+        "load_lazy": "load_lazy",
     }
 
     def __init__(
@@ -38,6 +41,7 @@ class CausalLMPool(BaseModelPool):
         *,
         tokenizer: Optional[DictConfig],
        model_kwargs: Optional[DictConfig] = None,
+        load_lazy: bool = False,
         **kwargs,
     ):
         super().__init__(models, **kwargs)
@@ -51,6 +55,7 @@ class CausalLMPool(BaseModelPool):
             self._model_kwargs.torch_dtype = parse_dtype(
                 self._model_kwargs.torch_dtype
             )
+        self.load_lazy = load_lazy
 
     @override
     def load_model(
@@ -88,21 +93,41 @@ class CausalLMPool(BaseModelPool):
         model_kwargs.update(kwargs)
 
         if isinstance(model_name_or_config, str):
+            # If model_name_or_config is a string, it is the name or the path of the model
             log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
             if model_name_or_config in self._models.keys():
                 model_config = self._models[model_name_or_config]
                 if isinstance(model_config, str):
                     # model_config is a string
-
-
-
-
-
+                    if not self.load_lazy:
+                        model = AutoModelForCausalLM.from_pretrained(
+                            model_config,
+                            *args,
+                            **model_kwargs,
+                        )
+                    else:
+                        # model_config is a string, but we want to use LazyStateDict
+                        model = LazyStateDict(
+                            checkpoint=model_config,
+                            meta_module_class=AutoModelForCausalLM,
+                            *args,
+                            **model_kwargs,
+                        )
                     return model
         elif isinstance(model_name_or_config, (DictConfig, Dict)):
             model_config = model_name_or_config
 
-
+            if not self.load_lazy:
+                model = instantiate(model_config, *args, **model_kwargs)
+            else:
+                meta_module_class = model_config.pop("_target_")
+                checkpoint = model_config.pop("pretrained_model_name_or_path")
+                model = LazyStateDict(
+                    checkpoint=checkpoint,
+                    meta_module_class=meta_module_class,
+                    *args,
+                    **model_kwargs,
+                )
             return model
 
     def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
@@ -141,6 +166,7 @@ class CausalLMPool(BaseModelPool):
         model_dtype: Optional[str] = None,
         save_tokenizer: bool = False,
         tokenizer_kwargs=None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
         **kwargs,
     ):
         """
@@ -154,11 +180,13 @@ class CausalLMPool(BaseModelPool):
             **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
         """
         path = os.path.expanduser(path)
-        if save_tokenizer
-
-
-
-
+        # NOTE: if tokenizer is provided, it will be saved regardless of `save_tokenizer`
+        if save_tokenizer or tokenizer is not None:
+            if tokenizer is None:
+                if tokenizer_kwargs is None:
+                    tokenizer_kwargs = {}
+                # load the tokenizer
+                tokenizer = self.load_tokenizer(**tokenizer_kwargs)
         tokenizer.save_pretrained(
             path,
             push_to_hub=push_to_hub,
@@ -176,6 +204,12 @@ class CausalLMBackbonePool(CausalLMPool):
     def load_model(
         self, model_name_or_config: str | DictConfig, *args, **kwargs
     ) -> Module:
+        if self.load_lazy:
+            log.warning(
+                "CausalLMBackbonePool does not support lazy loading. "
+                "Falling back to normal loading."
+            )
+            self.load_lazy = False
         model: AutoModelForCausalLM = super().load_model(
             model_name_or_config, *args, **kwargs
         )
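A rough sketch of how the new `load_lazy` switch might be used from Python; the model names are placeholders and the exact pool configuration is an assumption for illustration, not taken from the package's configs.

```python
from fusion_bench.modelpool import CausalLMPool

# Hypothetical pool: model entries given as plain name-or-path strings.
modelpool = CausalLMPool(
    models={
        "_pretrained_": "org/base-model",
        "math": "org/math-expert",
    },
    tokenizer="org/base-model",
    load_lazy=True,  # return LazyStateDict objects instead of fully loaded models
)

# With load_lazy=True, this is expected to return a LazyStateDict that reads
# tensors from the checkpoint on demand instead of materializing the whole model.
lazy_model = modelpool.load_model("math")
```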
fusion_bench/models/__init__.py: CHANGED

fusion_bench/models/expert_sparsity/__init__.py: file without content changes
fusion_bench/models/expert_sparsity/mixtral/__init__.py (new file)

@@ -0,0 +1,15 @@
+R"""
+Copy from https://github.com/Lucky-Lance/Expert_Sparsity/tree/main/model
+
+Original repo: https://github.com/Lucky-Lance/Expert_Sparsity
+
+Reference:
+    Not All Experts are Equal: Efficient Expert Pruning and Skipping for Mixture-of-Experts Large Language Models.
+    ACL 2024.
+    http://arxiv.org/abs/2402.14800
+"""
+
+from .wrapper import (
+    PrunableMixtralSparseMoeBlockWrapper,
+    DynamicSkippingMixtralSparseMoeBlockWrapper,
+)
fusion_bench/models/expert_sparsity/mixtral/dataset.py (new file)

@@ -0,0 +1,40 @@
+import torch
+
+
+class CacheDataset(torch.utils.data.Dataset):
+    def __init__(self):
+        self.alphas = []  # logits
+        self.Xs = []  # input hidden states
+        self.Zs = []  # output hidden states
+        self.prepared = False
+
+    def __len__(self):
+        if not self.prepared:
+            self.prepare_for_loader()
+        return len(self.alphas)
+
+    def __getitem__(self, index):
+        if not self.prepared:
+            self.prepare_for_loader()
+        if isinstance(index, list):
+            return [(self.alphas[idx], self.Xs[idx], self.Zs[idx]) for idx in index]
+        elif isinstance(index, int):
+            return self.alphas[index], self.Xs[index], self.Zs[index]
+
+    def append(self, alpha=None, X=None, Z=None):
+        if alpha is not None:
+            self.alphas.append(alpha.detach().to("cpu", non_blocking=True))
+        if X is not None:
+            self.Xs.append(X.detach().to("cpu", non_blocking=True))
+        if Z is not None:
+            self.Zs.append(Z.detach().to("cpu", non_blocking=True))
+        self.prepared = False
+
+    def prepare_for_loader(self):
+        if self.prepared:
+            return
+        self.prepared = True
+        self.alphas = torch.concat(self.alphas)
+        self.Xs = torch.concat(self.Xs)
+        self.Zs = torch.concat(self.Zs)
+        assert len(self.Xs) == len(self.Zs)
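A small sketch of the intended append-then-read flow for this cache; the tensor shapes are illustrative.

```python
import torch

from fusion_bench.models.expert_sparsity.mixtral.dataset import CacheDataset

cache = CacheDataset()

# Cache router logits, MoE-block inputs and MoE-block outputs for two calibration batches.
for _ in range(2):
    cache.append(
        alpha=torch.randn(4, 8),   # router logits per token
        X=torch.randn(4, 16),      # block input hidden states
        Z=torch.randn(4, 16),      # block output hidden states
    )

print(len(cache))       # 8 cached token rows after concatenation
alpha, x, z = cache[0]  # one cached (logits, input, output) triple
```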
fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py (new file)

@@ -0,0 +1,207 @@
+import warnings
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+from transformers.models.mixtral.modeling_mixtral import (
+    MixtralBlockSparseTop2MLP,
+    MixtralConfig,
+    MixtralRMSNorm,
+    MixtralSparseMoeBlock,
+)
+
+
+class DynamicSkippingMixtralSparseMoeBlock(nn.Module):
+    """
+    This implementation is
+    strictly equivalent to standard MoE with full capacity (no
+    dropped tokens). It's faster since it formulates MoE operations
+    in terms of block-sparse operations to accomodate imbalanced
+    assignments of tokens to experts, whereas standard MoE either
+    (1) drop tokens at the cost of reduced performance or (2) set
+    capacity factor to number of experts and thus waste computation
+    and memory on padding.
+    """
+
+    def __init__(self, config, beta):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+
+        # gating
+        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+        self.experts = nn.ModuleList(
+            [MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]
+        )
+
+        self.beta = beta
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
+
+        onlytop1_mask = (
+            routing_weights[:, 1] < self.beta * routing_weights[:, 0]
+        )  # bz x seqlen
+
+        # routing_weights[tokens_top1, 1].fill_(0)
+        routing_weights[onlytop1_mask, 1] = 0
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be sollicitated
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        )
+        # ipdb.set_trace()
+        # expert_mask[tokens_top1, 1, :].fill_(0)
+        expert_mask[onlytop1_mask, 1, :] = 0
+        expert_mask = expert_mask.permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            if top_x.shape[0] == 0:
+                continue
+
+            # in torch it is faster to index using lists than torch tensors
+            top_x_list = top_x.tolist()
+            idx_list = idx.tolist()
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(
+                current_state, routing_weights[top_x_list, idx_list, None]
+            )
+
+            # However `index_add_` only support torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
+        return final_hidden_states, router_logits
+
+
+class MixtralDecoderLayer(nn.Module):
+    def __init__(self, config: MixtralConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = ALL_ATTENTION_FUNCTIONS[config._attn_implementation](
+            config, layer_idx
+        )
+        if hasattr(config, "betas"):
+            assert (
+                isinstance(config.betas, dict)
+                and len(config.betas) == config.num_hidden_layers
+            )
+            self.block_sparse_moe = DynamicSkippingMixtralSparseMoeBlock(
+                config, config.betas[str(layer_idx)]
+            )
+            warnings.warn(
+                f"Using online drop: {layer_idx}, {config.betas[str(layer_idx)]}, {type(self.block_sparse_moe)}"
+            )
+        else:
+            self.block_sparse_moe = MixtralSparseMoeBlock(config)
+        self.input_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        **kwargs,
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_router_logits (`bool`, *optional*):
+                Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+                should not be returned during inference.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+        """
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        if output_router_logits:
+            outputs += (router_logits,)
+
+        return outputs
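The core of `DynamicSkippingMixtralSparseMoeBlock.forward` is the per-token rule that drops the second expert whenever its routing weight falls below `beta` times the top-1 weight. The routing math in isolation, on a toy router output (the logit values and `beta` are made up for illustration):

```python
import torch
import torch.nn.functional as F

beta = 0.5
router_logits = torch.tensor([[2.0, 1.9, 0.1, 0.0],   # close top-2: keep both experts
                              [3.0, 0.2, 0.1, 0.0]])  # dominant top-1: skip expert 2

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, k=2, dim=-1)

# Skip the second expert for tokens where it contributes less than beta * top-1 weight.
onlytop1_mask = routing_weights[:, 1] < beta * routing_weights[:, 0]
routing_weights[onlytop1_mask, 1] = 0
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

print(onlytop1_mask)    # tensor([False,  True])
print(routing_weights)  # second token now routes all of its weight to the top-1 expert
```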