fusion-bench 0.2.16__py3-none-any.whl → 0.2.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +11 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +1 -1
- fusion_bench/method/base_algorithm.py +1 -0
- fusion_bench/method/dawe/dawe_for_clip.py +1 -1
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +3 -2
- fusion_bench/method/expert_sparsity/__init__.py +10 -0
- fusion_bench/method/expert_sparsity/mixtral/__init__.py +23 -0
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +175 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +159 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +173 -0
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +153 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +1 -1
- fusion_bench/method/knots/__init__.py +0 -0
- fusion_bench/method/knots/knots_utils.py +23 -0
- fusion_bench/method/pwe_moe/module.py +2 -7
- fusion_bench/method/simple_average.py +3 -2
- fusion_bench/method/task_singular_vector/TSVM.py +238 -25
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +52 -20
- fusion_bench/method/task_singular_vector/utils/__init__.py +1 -0
- fusion_bench/method/task_singular_vector/utils/task_singular_interference.py +41 -0
- fusion_bench/mixins/hydra_config.py +1 -1
- fusion_bench/mixins/lightning_fabric.py +25 -1
- fusion_bench/mixins/serialization.py +18 -2
- fusion_bench/modelpool/base_pool.py +1 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +8 -5
- fusion_bench/modelpool/clip_vision/modelpool.py +21 -13
- fusion_bench/models/__init__.py +1 -0
- fusion_bench/models/expert_sparsity/__init__.py +0 -0
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +15 -0
- fusion_bench/models/expert_sparsity/mixtral/dataset.py +40 -0
- fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py +207 -0
- fusion_bench/models/expert_sparsity/mixtral/wrapper.py +268 -0
- fusion_bench/models/parameter_dict.py +6 -1
- fusion_bench/programs/fabric_fusion_program.py +21 -13
- fusion_bench/taskpool/base_pool.py +1 -0
- fusion_bench/taskpool/dummy.py +6 -4
- fusion_bench/utils/__init__.py +4 -3
- fusion_bench/utils/dtype.py +2 -1
- fusion_bench/utils/fabric.py +11 -4
- fusion_bench/utils/{instantiate.py → instantiate_utils.py} +3 -0
- fusion_bench/utils/lazy_state_dict.py +80 -10
- fusion_bench/utils/pylogger.py +30 -0
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/METADATA +3 -1
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/RECORD +59 -38
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/WHEEL +1 -1
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +2 -0
- fusion_bench_config/fabric_model_fusion.yaml +2 -2
- fusion_bench_config/method/expert_sparsity/README.md +6 -0
- fusion_bench_config/method/expert_sparsity/mixtral.yaml +17 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_cars_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +19 -0
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +0 -1
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/top_level.txt +0 -0
fusion_bench/method/task_singular_vector/utils/task_singular_interference.py ADDED

```diff
@@ -0,0 +1,41 @@
+from typing import List
+
+import torch
+
+
+def compute_task_singular_interference(weight_differences: List[torch.Tensor]) -> float:
+    R"""
+    Compute the singular interference of a list of weight differences $\{W_i - W_0\}_{i=1}^T$,
+    where $W_0$ is the pre-trained model weight, $W_i$ is the weight of the i-th fine-tuned model
+    and $T$ is the number of fine-tuned models.
+
+    Args:
+        weight_differences (List[torch.Tensor]): A list of weight differences $\{W_i - W_0\}_{i=1}^T$.
+
+    Returns:
+        float: The singular interference of the list of weight differences.
+    """
+    device = weight_differences[0].device
+    dtype = weight_differences[0].dtype
+
+    U = []
+    S = []
+    V = []
+    for delta_w in weight_differences:
+        u, s, vh = torch.linalg.svd(delta_w, full_matrices=False)
+        U.append(u)
+        S.append(s)
+        V.append(vh.t())
+    U = torch.cat(U, dim=1)
+    S = torch.cat(S, dim=0)
+    V = torch.cat(V, dim=1)
+
+    singular_task_interference = torch.linalg.multi_dot(
+        (
+            U.t() @ U - torch.eye(U.shape[1], device=device, dtype=dtype),
+            torch.diag(S),
+            V.t() @ V - torch.eye(V.shape[1], device=device, dtype=dtype),
+        )
+    )
+    singular_task_interference = torch.linalg.norm(singular_task_interference, ord=1)
+    return singular_task_interference
```
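To make the new helper concrete, here is a minimal usage sketch importing it by its full module path. The random matrices are illustrative stand-ins for actual pre-trained and fine-tuned checkpoints:

```python
import torch

from fusion_bench.method.task_singular_vector.utils.task_singular_interference import (
    compute_task_singular_interference,
)

# W_0 stands in for the pre-trained weight; the list holds "fine-tuned" variants.
w0 = torch.randn(64, 32)
finetuned = [w0 + 0.01 * torch.randn_like(w0) for _ in range(2)]

# The helper consumes the task vectors, i.e. the differences W_i - W_0.
interference = compute_task_singular_interference([w - w0 for w in finetuned])
print(interference)  # scalar tensor; smaller values mean less task interference
```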
fusion_bench/mixins/hydra_config.py CHANGED

```diff
@@ -9,7 +9,7 @@ from hydra import compose, initialize
 from omegaconf import DictConfig, OmegaConf
 
 from fusion_bench.utils import import_object, instantiate
-from fusion_bench.utils.instantiate import set_print_function_call
+from fusion_bench.utils.instantiate_utils import set_print_function_call
 
 log = logging.getLogger(__name__)
 
```
fusion_bench/mixins/lightning_fabric.py CHANGED

```diff
@@ -11,7 +11,7 @@ from lightning.fabric.utilities.rank_zero import rank_zero_only
 from omegaconf import DictConfig, OmegaConf
 
 from fusion_bench.utils import import_object
-from fusion_bench.utils.instantiate import instantiate
+from fusion_bench.utils.instantiate_utils import instantiate
 
 if TYPE_CHECKING:
     import lightning.fabric.loggers.tensorboard
```
```diff
@@ -172,3 +172,27 @@ class LightningFabricMixin:
             return True
         else:
             return False
+
+    def log(self, name: str, value: Any, step: Optional[int] = None):
+        """
+        Logs the metric to the fabric's logger.
+        """
+        self.fabric.log(name, value, step=step)
+
+    def log_dict(self, metrics: dict, step: Optional[int] = None):
+        """
+        Logs the metrics to the fabric's logger.
+        """
+        self.fabric.log_dict(metrics, step=step)
+
+    def log_optimizer_lr(
+        self,
+        optimizer: torch.optim.Optimizer,
+        step: Optional[int] = None,
+        name_template: str = "train/lr_group_{0}",
+    ):
+        """
+        Logs the learning rate of the optimizer to the fabric's logger.
+        """
+        for i, param_group in enumerate(optimizer.param_groups):
+            self.fabric.log(name_template.format(i), param_group["lr"], step=step)
```
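The three new helpers are thin wrappers around `fabric.log` and `fabric.log_dict`. A sketch of how an algorithm mixing in `LightningFabricMixin` might use them; the class and method here are hypothetical, and the mixin is imported by its module path:

```python
import torch

from fusion_bench.mixins.lightning_fabric import LightningFabricMixin


class MyMergingAlgorithm(LightningFabricMixin):
    """Hypothetical algorithm used only to illustrate the new logging helpers."""

    def on_step_end(
        self, loss: torch.Tensor, optimizer: torch.optim.Optimizer, step: int
    ):
        # a single scalar metric
        self.log("train/loss", loss.item(), step=step)
        # several metrics at once
        self.log_dict({"train/loss": loss.item()}, step=step)
        # one entry per parameter group: train/lr_group_0, train/lr_group_1, ...
        self.log_optimizer_lr(optimizer, step=step)
```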
fusion_bench/mixins/serialization.py CHANGED

```diff
@@ -4,13 +4,14 @@ from typing import Dict, Optional, Union
 
 from omegaconf import OmegaConf
 
-from fusion_bench.utils import instantiate
+from fusion_bench.utils import import_object, instantiate
 
 log = logging.getLogger(__name__)
 
 
 class YAMLSerializationMixin:
     _recursive_: bool = False
+    _config_key: Optional[str] = None
     _config_mapping: Dict[str, str] = {
         "_recursive_": "_recursive_",
     }
```
```diff
@@ -99,7 +100,22 @@ class YAMLSerializationMixin:
             BaseModelPool: The loaded model pool.
         """
         config = OmegaConf.load(path)
-        return instantiate(config, _recursive_=cls._recursive_)
+        if cls._config_key is not None and cls._config_key in config:
+            config = config[cls._config_key]
+        target_cls = import_object(config["_target_"])
+        if target_cls != cls:
+            log.warning(
+                f"The class {target_cls.__name__} is not the same as the class {cls.__name__}. "
+                f"Instantiating the class {target_cls.__name__} instead."
+            )
+        return instantiate(
+            config,
+            _recursive_=(
+                cls._recursive_
+                if config.get("_recursive_") is None
+                else config.get("_recursive_")
+            ),
+        )
 
     def to_config(self):
         """
```
fusion_bench/modelpool/base_pool.py CHANGED

```diff
@@ -29,6 +29,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
     """
 
     _program = None
+    _config_key = "modelpool"
     _models: Union[DictConfig, Dict[str, nn.Module]]
     _config_mapping = BaseYAMLSerializableModel._config_mapping | {
         "_models": "models",
```
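Together, the `_config_key` hook on the mixin and `_config_key = "modelpool"` here mean `from_yaml` can be pointed at a whole experiment config and will descend into its `modelpool:` sub-tree before instantiating, warning if the `_target_` class differs from the class it was called on. A sketch under that assumption; the YAML content and `models` argument are illustrative:

```python
from omegaconf import OmegaConf

from fusion_bench.modelpool.base_pool import BaseModelPool

# A full experiment config; only the `modelpool` sub-tree defines the pool.
OmegaConf.save(
    OmegaConf.create(
        {
            "modelpool": {
                "_target_": "fusion_bench.modelpool.base_pool.BaseModelPool",
                "models": {},
            }
        }
    ),
    "experiment.yaml",
)

# from_yaml finds the `modelpool` key (cls._config_key) and instantiates
# the `_target_` class found there.
pool = BaseModelPool.from_yaml("experiment.yaml")
```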
fusion_bench/modelpool/causal_lm/causal_lm.py CHANGED

```diff
@@ -141,6 +141,7 @@ class CausalLMPool(BaseModelPool):
         model_dtype: Optional[str] = None,
         save_tokenizer: bool = False,
         tokenizer_kwargs=None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
         **kwargs,
     ):
         """
```
```diff
@@ -154,11 +155,13 @@ class CausalLMPool(BaseModelPool):
             **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
         """
         path = os.path.expanduser(path)
-        if save_tokenizer:
-            if tokenizer_kwargs is None:
-                tokenizer_kwargs = {}
-            # load the tokenizer
-            tokenizer = self.load_tokenizer(**tokenizer_kwargs)
+        # NOTE: if tokenizer is provided, it will be saved regardless of `save_tokenizer`
+        if save_tokenizer or tokenizer is not None:
+            if tokenizer is None:
+                if tokenizer_kwargs is None:
+                    tokenizer_kwargs = {}
+                # load the tokenizer
+                tokenizer = self.load_tokenizer(**tokenizer_kwargs)
         tokenizer.save_pretrained(
             path,
             push_to_hub=push_to_hub,
```
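The practical effect: callers can hand an already-loaded tokenizer to the save routine instead of asking the pool to reload one. A sketch, assuming the method shown is `CausalLMPool`'s `save_model`; `pool`, `model`, and the checkpoint name are hypothetical placeholders:

```python
from transformers import AutoTokenizer

# pool = CausalLMPool(...)  # configured elsewhere
# model = ...               # merged causal LM produced by a fusion algorithm
tokenizer = AutoTokenizer.from_pretrained("path/to/base-model")

# Passing `tokenizer` saves it alongside the model even though
# `save_tokenizer` defaults to False.
pool.save_model(model, path="outputs/merged-model", tokenizer=tokenizer)

# Without an explicit tokenizer, `save_tokenizer=True` makes the pool load
# one via `self.load_tokenizer(**tokenizer_kwargs)` before saving.
pool.save_model(model, path="outputs/merged-model", save_tokenizer=True)
```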
fusion_bench/modelpool/clip_vision/modelpool.py CHANGED

```diff
@@ -3,6 +3,7 @@ from copy import deepcopy
 from typing import Optional, Union
 
 from datasets import load_dataset
+from lightning.fabric.utilities import rank_zero_only
 from omegaconf import DictConfig, open_dict
 from torch import nn
 from torch.utils.data import Dataset
```
```diff
@@ -40,7 +41,8 @@ class CLIPVisionModelPool(BaseModelPool):
     def load_processor(self, *args, **kwargs) -> CLIPProcessor:
         assert self._processor is not None, "Processor is not defined in the config"
         if isinstance(self._processor, str):
-            log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
+            if rank_zero_only.rank == 0:
+                log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
             processor = CLIPProcessor.from_pretrained(self._processor)
         else:
             processor = instantiate(self._processor, *args, **kwargs)
```
```diff
@@ -50,7 +52,8 @@ class CLIPVisionModelPool(BaseModelPool):
         model_config = self._models[model_name]
 
         if isinstance(model_config, str):
-            log.info(f"Loading `transformers.CLIPModel`: {model_config}")
+            if rank_zero_only.rank == 0:
+                log.info(f"Loading `transformers.CLIPModel`: {model_config}")
             clip_model = CLIPModel.from_pretrained(model_config, *args, **kwargs)
             return clip_model
         else:
```
```diff
@@ -102,10 +105,12 @@ class CLIPVisionModelPool(BaseModelPool):
     ):
         model = self._models[model_name_or_config]
         if isinstance(model, str):
-            log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
+            if rank_zero_only.rank == 0:
+                log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
             return CLIPVisionModel.from_pretrained(model, *args, **kwargs)
         if isinstance(model, nn.Module):
-            log.info(f"Returning existing model: {model}")
+            if rank_zero_only.rank == 0:
+                log.info(f"Returning existing model: {model}")
             return model
 
         # If the model is not a string, we use the default load_model method
```
```diff
@@ -114,9 +119,10 @@ class CLIPVisionModelPool(BaseModelPool):
     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
         dataset_config = self._train_datasets[dataset_name]
         if isinstance(dataset_config, str):
-            log.info(
-                f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
-            )
+            if rank_zero_only.rank == 0:
+                log.info(
+                    f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
+                )
             dataset = load_dataset(dataset_config, split="train")
         else:
             dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
```
```diff
@@ -125,9 +131,10 @@ class CLIPVisionModelPool(BaseModelPool):
     def load_val_dataset(self, dataset_name: str, *args, **kwargs):
         dataset_config = self._val_datasets[dataset_name]
         if isinstance(dataset_config, str):
-            log.info(
-                f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
-            )
+            if rank_zero_only.rank == 0:
+                log.info(
+                    f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
+                )
             dataset = load_dataset(dataset_config, split="validation")
         else:
             dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
```
```diff
@@ -136,9 +143,10 @@ class CLIPVisionModelPool(BaseModelPool):
     def load_test_dataset(self, dataset_name: str, *args, **kwargs):
         dataset_config = self._test_datasets[dataset_name]
         if isinstance(dataset_config, str):
-            log.info(
-                f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
-            )
+            if rank_zero_only.rank == 0:
+                log.info(
+                    f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
+                )
             dataset = load_dataset(dataset_config, split="test")
         else:
             dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
```
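All of the `CLIPVisionModelPool` changes above follow one pattern: each informational message is wrapped in a rank check so distributed runs log it once rather than once per process. The guard in isolation:

```python
import logging

from lightning.fabric.utilities import rank_zero_only

log = logging.getLogger(__name__)

if rank_zero_only.rank == 0:
    # In a distributed run only the rank-0 process enters this branch;
    # in a single-process run the rank is 0 and the message prints once.
    log.info("Loading `transformers.CLIPVisionModel`: openai/clip-vit-base-patch32")
```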
fusion_bench/models/__init__.py CHANGED

fusion_bench/models/expert_sparsity/mixtral/__init__.py ADDED
```diff
@@ -0,0 +1,15 @@
+R"""
+Copy from https://github.com/Lucky-Lance/Expert_Sparsity/tree/main/model
+
+Original repo: https://github.com/Lucky-Lance/Expert_Sparsity
+
+Reference:
+    Not All Experts are Equal: Efficient Expert Pruning and Skipping for Mixture-of-Experts Large Language Models.
+    ACL 2024.
+    http://arxiv.org/abs/2402.14800
+"""
+
+from .wrapper import (
+    PrunableMixtralSparseMoeBlockWrapper,
+    DynamicSkippingMixtralSparseMoeBlockWrapper,
+)
```
fusion_bench/models/expert_sparsity/mixtral/dataset.py ADDED

```diff
@@ -0,0 +1,40 @@
+import torch
+
+
+class CacheDataset(torch.utils.data.Dataset):
+    def __init__(self):
+        self.alphas = []  # logits
+        self.Xs = []  # input hidden states
+        self.Zs = []  # output hidden states
+        self.prepared = False
+
+    def __len__(self):
+        if not self.prepared:
+            self.prepare_for_loader()
+        return len(self.alphas)
+
+    def __getitem__(self, index):
+        if not self.prepared:
+            self.prepare_for_loader()
+        if isinstance(index, list):
+            return [(self.alphas[idx], self.Xs[idx], self.Zs[idx]) for idx in index]
+        elif isinstance(index, int):
+            return self.alphas[index], self.Xs[index], self.Zs[index]
+
+    def append(self, alpha=None, X=None, Z=None):
+        if alpha is not None:
+            self.alphas.append(alpha.detach().to("cpu", non_blocking=True))
+        if X is not None:
+            self.Xs.append(X.detach().to("cpu", non_blocking=True))
+        if Z is not None:
+            self.Zs.append(Z.detach().to("cpu", non_blocking=True))
+        self.prepared = False
+
+    def prepare_for_loader(self):
+        if self.prepared:
+            return
+        self.prepared = True
+        self.alphas = torch.concat(self.alphas)
+        self.Xs = torch.concat(self.Xs)
+        self.Zs = torch.concat(self.Zs)
+        assert len(self.Xs) == len(self.Zs)
```
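A minimal sketch of how `CacheDataset` is meant to be used: append per-batch router logits and hidden states during calibration, then read the cache back, for example through a `DataLoader`. The random tensors stand in for captured MoE activations:

```python
import torch

from fusion_bench.models.expert_sparsity.mixtral.dataset import CacheDataset

cache = CacheDataset()

# Append per-batch router logits (alpha), block inputs (X) and outputs (Z);
# each tensor is detached and moved to CPU on append.
for _ in range(4):
    cache.append(
        alpha=torch.randn(8, 8),  # router logits
        X=torch.randn(8, 16),     # input hidden states
        Z=torch.randn(8, 16),     # output hidden states
    )

# First access triggers prepare_for_loader(), which concatenates the
# cached batches into single tensors along dim 0.
alpha, x, z = cache[0]
loader = torch.utils.data.DataLoader(cache, batch_size=2, shuffle=False)
```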
fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py ADDED

```diff
@@ -0,0 +1,207 @@
+import warnings
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+from transformers.models.mixtral.modeling_mixtral import (
+    MixtralBlockSparseTop2MLP,
+    MixtralConfig,
+    MixtralRMSNorm,
+    MixtralSparseMoeBlock,
+)
+
+
+class DynamicSkippingMixtralSparseMoeBlock(nn.Module):
+    """
+    This implementation is
+    strictly equivalent to standard MoE with full capacity (no
+    dropped tokens). It's faster since it formulates MoE operations
+    in terms of block-sparse operations to accomodate imbalanced
+    assignments of tokens to experts, whereas standard MoE either
+    (1) drop tokens at the cost of reduced performance or (2) set
+    capacity factor to number of experts and thus waste computation
+    and memory on padding.
+    """
+
+    def __init__(self, config, beta):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+
+        # gating
+        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+        self.experts = nn.ModuleList(
+            [MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]
+        )
+
+        self.beta = beta
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
+
+        onlytop1_mask = (
+            routing_weights[:, 1] < self.beta * routing_weights[:, 0]
+        )  # bz x seqlen
+
+        # routing_weights[tokens_top1, 1].fill_(0)
+        routing_weights[onlytop1_mask, 1] = 0
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be sollicitated
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        )
+        # ipdb.set_trace()
+        # expert_mask[tokens_top1, 1, :].fill_(0)
+        expert_mask[onlytop1_mask, 1, :] = 0
+        expert_mask = expert_mask.permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            if top_x.shape[0] == 0:
+                continue
+
+            # in torch it is faster to index using lists than torch tensors
+            top_x_list = top_x.tolist()
+            idx_list = idx.tolist()
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(
+                current_state, routing_weights[top_x_list, idx_list, None]
+            )
+
+            # However `index_add_` only support torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
+        return final_hidden_states, router_logits
+
+
+class MixtralDecoderLayer(nn.Module):
+    def __init__(self, config: MixtralConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = ALL_ATTENTION_FUNCTIONS[config._attn_implementation](
+            config, layer_idx
+        )
+        if hasattr(config, "betas"):
+            assert (
+                isinstance(config.betas, dict)
+                and len(config.betas) == config.num_hidden_layers
+            )
+            self.block_sparse_moe = DynamicSkippingMixtralSparseMoeBlock(
+                config, config.betas[str(layer_idx)]
+            )
+            warnings.warn(
+                f"Using online drop: {layer_idx}, {config.betas[str(layer_idx)]}, {type(self.block_sparse_moe)}"
+            )
+        else:
+            self.block_sparse_moe = MixtralSparseMoeBlock(config)
+        self.input_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = MixtralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        **kwargs,
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_router_logits (`bool`, *optional*):
+                Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+                should not be returned during inference.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+        """
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        if output_router_logits:
+            outputs += (router_logits,)
+
+        return outputs
```
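The heart of `DynamicSkippingMixtralSparseMoeBlock` is the `beta` rule: after top-2 routing, a token's second expert is skipped whenever its routing weight falls below `beta` times the top-1 weight, and the remaining weights are renormalized. The rule in isolation, on dummy router logits:

```python
import torch
import torch.nn.functional as F

beta = 0.5
router_logits = torch.randn(6, 8)  # 6 tokens, 8 experts

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, k=2, dim=-1)

# Drop the second expert where it contributes less than beta * top-1 weight.
onlytop1_mask = routing_weights[:, 1] < beta * routing_weights[:, 0]
routing_weights[onlytop1_mask, 1] = 0
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

print(onlytop1_mask)    # tokens that will run only their top-1 expert
print(routing_weights)  # renormalized per-token mixture weights
```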