fusion-bench 0.2.17__py3-none-any.whl → 0.2.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +11 -0
- fusion_bench/method/expert_sparsity/__init__.py +10 -0
- fusion_bench/method/expert_sparsity/mixtral/__init__.py +23 -0
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +175 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +159 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +173 -0
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +153 -0
- fusion_bench/method/knots/__init__.py +0 -0
- fusion_bench/method/knots/knots_utils.py +23 -0
- fusion_bench/method/task_singular_vector/utils/__init__.py +1 -0
- fusion_bench/method/task_singular_vector/utils/task_singular_interference.py +41 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +8 -5
- fusion_bench/models/__init__.py +1 -0
- fusion_bench/models/expert_sparsity/__init__.py +0 -0
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +15 -0
- fusion_bench/models/expert_sparsity/mixtral/dataset.py +40 -0
- fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py +207 -0
- fusion_bench/models/expert_sparsity/mixtral/wrapper.py +268 -0
- fusion_bench/programs/fabric_fusion_program.py +12 -8
- fusion_bench/utils/__init__.py +3 -2
- fusion_bench/utils/dtype.py +2 -1
- fusion_bench/utils/fabric.py +11 -4
- fusion_bench/utils/lazy_state_dict.py +80 -10
- fusion_bench/utils/pylogger.py +2 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.18.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.18.dist-info}/RECORD +33 -16
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +2 -0
- fusion_bench_config/method/expert_sparsity/README.md +6 -0
- fusion_bench_config/method/expert_sparsity/mixtral.yaml +17 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.18.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.18.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.18.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.18.dist-info}/top_level.txt +0 -0
fusion_bench/method/expert_sparsity/utils/calibration_data.py
@@ -0,0 +1,153 @@
+"""
+This module contains the code for loading the calibration data.
+
+- C4
+- Math
+"""
+
+import itertools
+import logging
+import os
+
+import torch
+import transformers
+from datasets import load_dataset
+from transformers import PreTrainedTokenizer, default_data_collator
+from transformers.testing_utils import CaptureLogger
+from huggingface_hub import hf_hub_download
+
+logger = logging.getLogger(__name__)
+
+
+DATASETS = {
+    # C4: Please download first part of the C4 training data `c4-train.00000-of-01024.json` from [allenai/c4](https://huggingface.co/datasets/allenai/c4/blob/main/en/c4-train.00000-of-01024.json.gz).
+    "c4": lambda: load_dataset(
+        "json",
+        data_files={
+            "train": hf_hub_download(
+                "allenai/c4",
+                filename="en/c4-train.00000-of-01024.json.gz",
+                repo_type="dataset",
+            )
+        },
+    ),
+    # MATH: You can use our pre-built calibration set in `./data/math_pretrain_style.json`. To reproduce our construction, please download the training set of [MATH](https://github.com/hendrycks/math) and use our [script](data/math_calib_construction.py).
+    # NOTE: I have uploaded the math_pretrain_style.json to my huggingface repo:
+    # https://huggingface.co/datasets/tanganke/math_pretrain_style/tree/main.
+    "math": lambda: load_dataset(
+        "json",
+        data_files={
+            "train": hf_hub_download(
+                "tanganke/math_pretrain_style",
+                filename="math_pretrain_style.json",
+                repo_type="dataset",
+            )
+        },
+    ),
+}
+
+
+def build_calib_loader(
+    dataset: str,
+    tokenizer: PreTrainedTokenizer,
+    max_block_size: int,
+    n_blocks_for_stat: int,
+    batch_size: int,
+    num_workers: int,
+    seed: int = 42,
+):
+    # dataset can be a string or a dataset object.
+    # If it is a string, it can be the name of the dataset in DATASETS or the path to the dataset (a json file).
+    if isinstance(dataset, str):
+        if dataset in DATASETS:
+            all_set = DATASETS[dataset]()
+        else:
+            assert os.path.exists(dataset), f"Dataset {dataset} not found."
+            all_set = load_dataset("json", data_files={"train": dataset})
+    else:
+        assert dataset is not None, "Dataset is not provided."
+        all_set = dataset
+
+    block_size = tokenizer.model_max_length
+    if block_size > max_block_size:
+        logger.info(
+            "The chosen tokenizer supports a `model_max_length` that is longer than the default `max_block_size` value"
+            f" of {max_block_size}. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
+            " override this default with `--max_block_size xxx`."
+        )
+        block_size = max_block_size
+
+    if n_blocks_for_stat > 0:  # Random choose `n_blocks_for_stat` blocks
+        calib_set = (
+            all_set["train"]
+            .shuffle(seed=seed)
+            .select(range(min(n_blocks_for_stat * 16, len(all_set["train"]))))
+        )
+    else:  # Use the whole set
+        logger.warning("n_blocks_for_stat <= 0, using the whole dataset.")
+        calib_set = all_set["train"].shuffle(seed=seed)
+
+    logger.info(f"Calibration dataset: {calib_set}")
+    text_column_name = (
+        "text" if "text" in calib_set.features else list(calib_set.features)[0]
+    )
+
+    tok_logger = transformers.utils.logging.get_logger(
+        "transformers.tokenization_utils_base"
+    )
+
+    def tokenize_function(examples):
+        with CaptureLogger(tok_logger) as cl:
+            output = tokenizer(examples[text_column_name])
+        # clm input could be much much longer than block_size
+        if "Token indices sequence length is longer than the" in cl.out:
+            tok_logger.warning(
+                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
+                " before being passed to the model."
+            )
+        return output
+
+    tokenized_calib_set = calib_set.map(
+        tokenize_function,
+        batched=True,
+        remove_columns=list(calib_set.features),
+    )
+
+    def group_texts(examples):
+        # Concatenate all texts.
+        concatenated_examples = {
+            k: list(itertools.chain(*examples[k])) for k in examples.keys()
+        }
+        total_length = len(concatenated_examples[list(examples.keys())[0]])
+        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
+        # customize this part to your needs.
+        if total_length >= block_size:
+            total_length = (total_length // block_size) * block_size
+        # Split by chunks of max_len.
+        result = {
+            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+            for k, t in concatenated_examples.items()
+        }
+        result["labels"] = result["input_ids"].copy()
+        return result
+
+    lm_calib_set = tokenized_calib_set.map(
+        group_texts,
+        batched=True,
+    )
+
+    if n_blocks_for_stat > 0:
+        assert len(lm_calib_set) > n_blocks_for_stat
+        lm_calib_set = lm_calib_set.select(range(n_blocks_for_stat))
+
+    calib_loader = torch.utils.data.DataLoader(
+        lm_calib_set,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        pin_memory=True,
+        drop_last=False,
+        shuffle=False,
+        collate_fn=default_data_collator,
+    )
+
+    return calib_loader
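For orientation, a minimal usage sketch of the calibration loader defined above. The import path follows the file listing; the tokenizer checkpoint and all parameter values are illustrative assumptions, not taken from the diff.

```python
# Hedged usage sketch: checkpoint name and parameter values are assumptions for illustration.
from transformers import AutoTokenizer

from fusion_bench.method.expert_sparsity.utils.calibration_data import build_calib_loader

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

calib_loader = build_calib_loader(
    dataset="c4",           # a key of DATASETS, or a path to a local JSON file
    tokenizer=tokenizer,
    max_block_size=2048,    # blocks are capped at min(tokenizer.model_max_length, max_block_size)
    n_blocks_for_stat=128,  # number of calibration blocks kept after grouping
    batch_size=4,
    num_workers=4,
    seed=42,
)

for batch in calib_loader:
    print(batch["input_ids"].shape)  # (batch_size, block_size)
    break
```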
fusion_bench/method/knots/__init__.py
File without changes
fusion_bench/method/knots/knots_utils.py
@@ -0,0 +1,23 @@
+import torch
+
+
+def subspace_alignment(
+    delta_weights: list[torch.Tensor],
+    svd_dtype: torch.dtype | None = torch.float64,
+    eps: float = 1e-4,
+):
+    """
+    Reference: Model merging with SVD to tie the Knots. http://arxiv.org/abs/2410.19735
+    """
+    if svd_dtype is None:
+        svd_dtype = delta_weights[0].dtype
+    original_dtype = delta_weights[0].dtype
+    output_dim, input_dim = delta_weights[0].size()
+    concat_task_vector = torch.cat(delta_weights, dim=1)
+    U, S, Vh = torch.linalg.svd(concat_task_vector.to(svd_dtype), full_matrices=False)
+    # Keep only supported basis components
+    U = U[:, S > eps].to(original_dtype)
+    Vh = Vh[S > eps].to(original_dtype)
+    S = S[S > eps].to(original_dtype)
+    Vhs = torch.split(Vh, input_dim, dim=1)
+    return U, S, Vhs
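A small synthetic sketch of how `subspace_alignment` can be exercised (the import path follows the file listing; shapes are arbitrary). Because `Vh` is split per task, each task vector is recovered as `U @ diag(S) @ Vhs[i]` up to the `eps` truncation.

```python
# Synthetic check: the shared basis (U, S) and per-task Vh blocks reconstruct each delta weight.
import torch

from fusion_bench.method.knots.knots_utils import subspace_alignment

torch.manual_seed(0)
# Three task vectors (fine-tuned weight minus pre-trained weight) for one linear layer.
delta_weights = [torch.randn(64, 32) for _ in range(3)]

U, S, Vhs = subspace_alignment(delta_weights)
print(U.shape, S.shape, len(Vhs))  # shared left basis, singular values, one Vh block per task

for dw, vh in zip(delta_weights, Vhs):
    recon = U @ torch.diag(S) @ vh
    print((recon - dw).abs().max())  # close to float32 round-off
```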
fusion_bench/method/task_singular_vector/utils/task_singular_interference.py
@@ -0,0 +1,41 @@
+from typing import List
+
+import torch
+
+
+def compute_task_singular_interference(weight_differences: List[torch.Tensor]) -> float:
+    R"""
+    Compute the singular interference of a list of weight differences $\{W_i - W_0\}_{i=1}^T$,
+    where $W_0$ is the pre-trained model weight, $W_i$ is the weight of the i-th fine-tuned model
+    and $T$ is the number of fine-tuned models.
+
+    Args:
+        weight_differences (List[torch.Tensor]): A list of weight differences $\{W_i - W_0\}_{i=1}^T$.
+
+    Returns:
+        float: The singular interference of the list of weight differences.
+    """
+    device = weight_differences[0].device
+    dtype = weight_differences[0].dtype
+
+    U = []
+    S = []
+    V = []
+    for delta_w in weight_differences:
+        u, s, vh = torch.linalg.svd(delta_w, full_matrices=False)
+        U.append(u)
+        S.append(s)
+        V.append(vh.t())
+    U = torch.cat(U, dim=0)
+    S = torch.cat(S, dim=0)
+    V = torch.cat(V, dim=0)
+
+    singular_task_interference = torch.linalg.multi_dot(
+        (
+            U.t() @ U - torch.eye(U.shape[1], device=device, dtype=dtype),
+            torch.diag(S),
+            V.t() @ V - torch.eye(V.shape[1], device=device, dtype=dtype),
+        )
+    )
+    singular_task_interference = torch.linalg.norm(singular_task_interference, ord="1")
+    return singular_task_interference
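In the notation of the docstring, the quantity assembled above is, reading the code, built from the per-task SVDs $W_i - W_0 = U_i \Sigma_i V_i^\top$ gathered into stacked factors $\hat U$, $\hat\Sigma = \operatorname{diag}(\Sigma_1, \dots, \Sigma_T)$, and $\hat V$:

$$\mathrm{STI}\bigl(\{W_i - W_0\}_{i=1}^{T}\bigr) = \bigl\lVert (\hat U^{\top}\hat U - I)\,\hat\Sigma\,(\hat V^{\top}\hat V - I)\bigr\rVert_{1},$$

so the interference vanishes when the singular vectors coming from different task vectors are mutually orthogonal.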
fusion_bench/modelpool/causal_lm/causal_lm.py
@@ -141,6 +141,7 @@ class CausalLMPool(BaseModelPool):
         model_dtype: Optional[str] = None,
         save_tokenizer: bool = False,
         tokenizer_kwargs=None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
         **kwargs,
     ):
         """
@@ -154,11 +155,13 @@ class CausalLMPool(BaseModelPool):
             **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
         """
         path = os.path.expanduser(path)
-        if save_tokenizer:
-            if tokenizer_kwargs is None:
-                tokenizer_kwargs = {}
-            # load the tokenizer
-            tokenizer = self.load_tokenizer(**tokenizer_kwargs)
+        # NOTE: if tokenizer is provided, it will be saved regardless of `save_tokenizer`
+        if save_tokenizer or tokenizer is not None:
+            if tokenizer is None:
+                if tokenizer_kwargs is None:
+                    tokenizer_kwargs = {}
+                # load the tokenizer
+                tokenizer = self.load_tokenizer(**tokenizer_kwargs)
         tokenizer.save_pretrained(
             path,
             push_to_hub=push_to_hub,
fusion_bench/models/__init__.py
CHANGED

fusion_bench/models/expert_sparsity/__init__.py
File without changes
fusion_bench/models/expert_sparsity/mixtral/__init__.py
@@ -0,0 +1,15 @@
+R"""
+Copy from https://github.com/Lucky-Lance/Expert_Sparsity/tree/main/model
+
+Original repo: https://github.com/Lucky-Lance/Expert_Sparsity
+
+Reference:
+    Not All Experts are Equal: Efficient Expert Pruning and Skipping for Mixture-of-Experts Large Language Models.
+    ACL 2024.
+    http://arxiv.org/abs/2402.14800
+"""
+
+from .wrapper import (
+    PrunableMixtralSparseMoeBlockWrapper,
+    DynamicSkippingMixtralSparseMoeBlockWrapper,
+)
fusion_bench/models/expert_sparsity/mixtral/dataset.py
@@ -0,0 +1,40 @@
+import torch
+
+
+class CacheDataset(torch.utils.data.Dataset):
+    def __init__(self):
+        self.alphas = []  # logits
+        self.Xs = []  # input hidden states
+        self.Zs = []  # output hidden states
+        self.prepared = False
+
+    def __len__(self):
+        if not self.prepared:
+            self.prepare_for_loader()
+        return len(self.alphas)
+
+    def __getitem__(self, index):
+        if not self.prepared:
+            self.prepare_for_loader()
+        if isinstance(index, list):
+            return [(self.alphas[idx], self.Xs[idx], self.Zs[idx]) for idx in index]
+        elif isinstance(index, int):
+            return self.alphas[index], self.Xs[index], self.Zs[index]
+
+    def append(self, alpha=None, X=None, Z=None):
+        if alpha is not None:
+            self.alphas.append(alpha.detach().to("cpu", non_blocking=True))
+        if X is not None:
+            self.Xs.append(X.detach().to("cpu", non_blocking=True))
+        if Z is not None:
+            self.Zs.append(Z.detach().to("cpu", non_blocking=True))
+        self.prepared = False
+
+    def prepare_for_loader(self):
+        if self.prepared:
+            return
+        self.prepared = True
+        self.alphas = torch.concat(self.alphas)
+        self.Xs = torch.concat(self.Xs)
+        self.Zs = torch.concat(self.Zs)
+        assert len(self.Xs) == len(self.Zs)
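A minimal sketch of the caching pattern this dataset supports (the import path follows the file listing; tensor shapes are placeholders): batches of router logits and hidden states are appended during calibration, then the cache is consumed through a regular `DataLoader`.

```python
# Hedged sketch: tensor shapes are arbitrary placeholders chosen for illustration.
import torch

from fusion_bench.models.expert_sparsity.mixtral.dataset import CacheDataset

cache = CacheDataset()
for _ in range(2):  # e.g. two calibration batches of 8 tokens each
    alpha = torch.randn(8, 4)   # router logits, (tokens, num_experts)
    X = torch.randn(8, 16)      # MoE block inputs, (tokens, hidden_dim)
    Z = torch.randn(8, 16)      # MoE block outputs, (tokens, hidden_dim)
    cache.append(alpha=alpha, X=X, Z=Z)

loader = torch.utils.data.DataLoader(cache, batch_size=4)
for alphas, xs, zs in loader:
    print(alphas.shape, xs.shape, zs.shape)  # [4, 4], [4, 16], [4, 16]
```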
fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py
@@ -0,0 +1,207 @@
+import warnings
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+from transformers.models.mixtral.modeling_mixtral import (
+    MixtralBlockSparseTop2MLP,
+    MixtralConfig,
+    MixtralRMSNorm,
+    MixtralSparseMoeBlock,
+)
+
+
+class DynamicSkippingMixtralSparseMoeBlock(nn.Module):
+    """
+    This implementation is
+    strictly equivalent to standard MoE with full capacity (no
+    dropped tokens). It's faster since it formulates MoE operations
+    in terms of block-sparse operations to accomodate imbalanced
+    assignments of tokens to experts, whereas standard MoE either
+    (1) drop tokens at the cost of reduced performance or (2) set
+    capacity factor to number of experts and thus waste computation
+    and memory on padding.
+    """
+
+    def __init__(self, config, beta):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+
+        # gating
+        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+        self.experts = nn.ModuleList(
+            [MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]
+        )
+
+        self.beta = beta
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
+
+        onlytop1_mask = (
+            routing_weights[:, 1] < self.beta * routing_weights[:, 0]
+        )  # bz x seqlen
+
+        # routing_weights[tokens_top1, 1].fill_(0)
+        routing_weights[onlytop1_mask, 1] = 0
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be sollicitated
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        )
+        # ipdb.set_trace()
+        # expert_mask[tokens_top1, 1, :].fill_(0)
+        expert_mask[onlytop1_mask, 1, :] = 0
+        expert_mask = expert_mask.permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            if top_x.shape[0] == 0:
+                continue
+
+            # in torch it is faster to index using lists than torch tensors
+            top_x_list = top_x.tolist()
+            idx_list = idx.tolist()
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(
+                current_state, routing_weights[top_x_list, idx_list, None]
+            )
+
+            # However `index_add_` only support torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
+        return final_hidden_states, router_logits
+
+
+class MixtralDecoderLayer(nn.Module):
+    def __init__(self, config: MixtralConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = ALL_ATTENTION_FUNCTIONS[config._attn_implementation](
+            config, layer_idx
+        )
+        if hasattr(config, "betas"):
+            assert (
+                isinstance(config.betas, dict)
+                and len(config.betas) == config.num_hidden_layers
+            )
+            self.block_sparse_moe = DynamicSkippingMixtralSparseMoeBlock(
+                config, config.betas[str(layer_idx)]
+            )
+            warnings.warn(
+                f"Using online drop: {layer_idx}, {config.betas[str(layer_idx)]}, {type(self.block_sparse_moe)}"
+            )
+        else:
+            self.block_sparse_moe = MixtralSparseMoeBlock(config)
+        self.input_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        **kwargs,
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_router_logits (`bool`, *optional*):
+                Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+                should not be returned during inference.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+        """
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        if output_router_logits:
+            outputs += (router_logits,)
+
+        return outputs
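To make the skipping rule in `DynamicSkippingMixtralSparseMoeBlock.forward` concrete, here is a standalone sketch of the routing step on random logits (values are illustrative): a token keeps only its top-1 expert whenever the second routing weight falls below `beta` times the first, and the remaining weights are renormalized.

```python
# Standalone illustration of the dynamic-skipping criterion; numbers are arbitrary.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
beta, num_experts, top_k = 0.7, 8, 2

router_logits = torch.randn(5, num_experts)  # 5 tokens
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

# Skip the second expert when its weight is small relative to the top-1 weight.
onlytop1_mask = routing_weights[:, 1] < beta * routing_weights[:, 0]
routing_weights[onlytop1_mask, 1] = 0
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

print(onlytop1_mask)     # which tokens run only their top-1 expert
print(routing_weights)   # renormalized per-token expert weights
```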