fusion-bench 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/constants/banner.py +12 -0
  3. fusion_bench/method/__init__.py +11 -0
  4. fusion_bench/method/expert_sparsity/__init__.py +10 -0
  5. fusion_bench/method/expert_sparsity/mixtral/__init__.py +23 -0
  6. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +175 -0
  7. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +159 -0
  8. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +173 -0
  9. fusion_bench/method/expert_sparsity/utils/calibration_data.py +153 -0
  10. fusion_bench/method/knots/__init__.py +0 -0
  11. fusion_bench/method/knots/knots_utils.py +23 -0
  12. fusion_bench/method/linear/simple_average_for_llama.py +17 -3
  13. fusion_bench/method/simple_average.py +10 -0
  14. fusion_bench/method/task_singular_vector/utils/__init__.py +1 -0
  15. fusion_bench/method/task_singular_vector/utils/task_singular_interference.py +41 -0
  16. fusion_bench/modelpool/causal_lm/causal_lm.py +45 -11
  17. fusion_bench/models/__init__.py +1 -0
  18. fusion_bench/models/expert_sparsity/__init__.py +0 -0
  19. fusion_bench/models/expert_sparsity/mixtral/__init__.py +15 -0
  20. fusion_bench/models/expert_sparsity/mixtral/dataset.py +40 -0
  21. fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py +207 -0
  22. fusion_bench/models/expert_sparsity/mixtral/wrapper.py +268 -0
  23. fusion_bench/programs/fabric_fusion_program.py +12 -8
  24. fusion_bench/tasks/clip_classification/imagenet.py +1008 -2004
  25. fusion_bench/utils/__init__.py +3 -2
  26. fusion_bench/utils/dtype.py +2 -1
  27. fusion_bench/utils/fabric.py +11 -4
  28. fusion_bench/utils/lazy_state_dict.py +155 -13
  29. fusion_bench/utils/misc.py +19 -1
  30. fusion_bench/utils/pylogger.py +2 -0
  31. {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/METADATA +1 -1
  32. {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/RECORD +40 -21
  33. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +2 -0
  34. fusion_bench_config/method/expert_sparsity/README.md +6 -0
  35. fusion_bench_config/method/expert_sparsity/mixtral.yaml +17 -0
  36. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +11 -0
  37. {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/WHEEL +0 -0
  38. {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/entry_points.txt +0 -0
  39. {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/licenses/LICENSE +0 -0
  40. {fusion_bench-0.2.17.dist-info → fusion_bench-0.2.19.dist-info}/top_level.txt +0 -0
fusion_bench/method/expert_sparsity/utils/calibration_data.py
@@ -0,0 +1,153 @@
+ """
+ This module contains the code for loading the calibration data.
+
+ - C4
+ - Math
+ """
+
+ import itertools
+ import logging
+ import os
+
+ import torch
+ import transformers
+ from datasets import load_dataset
+ from transformers import PreTrainedTokenizer, default_data_collator
+ from transformers.testing_utils import CaptureLogger
+ from huggingface_hub import hf_hub_download
+
+ logger = logging.getLogger(__name__)
+
+
+ DATASETS = {
+     # C4: Please download first part of the C4 training data `c4-train.00000-of-01024.json` from [allenai/c4](https://huggingface.co/datasets/allenai/c4/blob/main/en/c4-train.00000-of-01024.json.gz).
+     "c4": lambda: load_dataset(
+         "json",
+         data_files={
+             "train": hf_hub_download(
+                 "allenai/c4",
+                 filename="en/c4-train.00000-of-01024.json.gz",
+                 repo_type="dataset",
+             )
+         },
+     ),
+     # MATH: You can use our pre-built calibration set in `./data/math_pretrain_style.json`. To reproduce our construction, please download the training set of [MATH](https://github.com/hendrycks/math) and use our [script](data/math_calib_construction.py).
+     # NOTE: I have uploaded the math_pretrain_style.json to my huggingface repo:
+     # https://huggingface.co/datasets/tanganke/math_pretrain_style/tree/main.
+     "math": lambda: load_dataset(
+         "json",
+         data_files={
+             "train": hf_hub_download(
+                 "tanganke/math_pretrain_style",
+                 filename="math_pretrain_style.json",
+                 repo_type="dataset",
+             )
+         },
+     ),
+ }
+
+
+ def build_calib_loader(
+     dataset: str,
+     tokenizer: PreTrainedTokenizer,
+     max_block_size: int,
+     n_blocks_for_stat: int,
+     batch_size: int,
+     num_workers: int,
+     seed: int = 42,
+ ):
+     # dataset can be a string or a dataset object.
+     # If it is a string, it can be the name of the dataset in DATASETS or the path to the dataset (a json file).
+     if isinstance(dataset, str):
+         if dataset in DATASETS:
+             all_set = DATASETS[dataset]()
+         else:
+             assert os.path.exists(dataset), f"Dataset {dataset} not found."
+             all_set = load_dataset("json", data_files={"train": dataset})
+     else:
+         assert dataset is not None, "Dataset is not provided."
+         all_set = dataset
+
+     block_size = tokenizer.model_max_length
+     if block_size > max_block_size:
+         logger.info(
+             "The chosen tokenizer supports a `model_max_length` that is longer than the default `max_block_size` value"
+             f" of {max_block_size}. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
+             " override this default with `--max_block_size xxx`."
+         )
+         block_size = max_block_size
+
+     if n_blocks_for_stat > 0:  # Random choose `n_blocks_for_stat` blocks
+         calib_set = (
+             all_set["train"]
+             .shuffle(seed=seed)
+             .select(range(min(n_blocks_for_stat * 16, len(all_set["train"]))))
+         )
+     else:  # Use the whole set
+         logger.warning("n_blocks_for_stat <= 0, using the whole dataset.")
+         calib_set = all_set["train"].shuffle(seed=seed)
+
+     logger.info(f"Calibration dataset: {calib_set}")
+     text_column_name = (
+         "text" if "text" in calib_set.features else list(calib_set.features)[0]
+     )
+
+     tok_logger = transformers.utils.logging.get_logger(
+         "transformers.tokenization_utils_base"
+     )
+
+     def tokenize_function(examples):
+         with CaptureLogger(tok_logger) as cl:
+             output = tokenizer(examples[text_column_name])
+         # clm input could be much much longer than block_size
+         if "Token indices sequence length is longer than the" in cl.out:
+             tok_logger.warning(
+                 "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
+                 " before being passed to the model."
+             )
+         return output
+
+     tokenized_calib_set = calib_set.map(
+         tokenize_function,
+         batched=True,
+         remove_columns=list(calib_set.features),
+     )
+
+     def group_texts(examples):
+         # Concatenate all texts.
+         concatenated_examples = {
+             k: list(itertools.chain(*examples[k])) for k in examples.keys()
+         }
+         total_length = len(concatenated_examples[list(examples.keys())[0]])
+         # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
+         # customize this part to your needs.
+         if total_length >= block_size:
+             total_length = (total_length // block_size) * block_size
+         # Split by chunks of max_len.
+         result = {
+             k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+             for k, t in concatenated_examples.items()
+         }
+         result["labels"] = result["input_ids"].copy()
+         return result
+
+     lm_calib_set = tokenized_calib_set.map(
+         group_texts,
+         batched=True,
+     )
+
+     if n_blocks_for_stat > 0:
+         assert len(lm_calib_set) > n_blocks_for_stat
+         lm_calib_set = lm_calib_set.select(range(n_blocks_for_stat))
+
+     calib_loader = torch.utils.data.DataLoader(
+         lm_calib_set,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         pin_memory=True,
+         drop_last=False,
+         shuffle=False,
+         collate_fn=default_data_collator,
+     )
+
+     return calib_loader
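The new `build_calib_loader` returns a DataLoader of fixed-length token blocks used to collect expert statistics. Below is a minimal usage sketch, not part of the package: the module path is inferred from the file list above, and the tokenizer name, block size, and batch settings are placeholders.

from transformers import AutoTokenizer

from fusion_bench.method.expert_sparsity.utils.calibration_data import (
    build_calib_loader,
)

# Placeholder tokenizer; any causal LM tokenizer is handled the same way here.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

calib_loader = build_calib_loader(
    dataset="math",         # a key of DATASETS ("c4" or "math"), a json path, or a dataset object
    tokenizer=tokenizer,
    max_block_size=2048,    # blocks use min(tokenizer.model_max_length, max_block_size)
    n_blocks_for_stat=128,  # number of fixed-length blocks kept for statistics
    batch_size=4,
    num_workers=2,
)

for batch in calib_loader:
    # default_data_collator yields dicts with "input_ids", "attention_mask", and "labels"
    print(batch["input_ids"].shape)  # (batch_size, block_size)
    break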
fusion_bench/method/knots/__init__.py (file without changes)
fusion_bench/method/knots/knots_utils.py
@@ -0,0 +1,23 @@
+ import torch
+
+
+ def subspace_alignment(
+     delta_weights: list[torch.Tensor],
+     svd_dtype: torch.dtype | None = torch.float64,
+     eps: float = 1e-4,
+ ):
+     """
+     Reference: Model merging with SVD to tie the Knots. http://arxiv.org/abs/2410.19735
+     """
+     if svd_dtype is None:
+         svd_dtype = delta_weights[0].dtype
+     original_dtype = delta_weights[0].dtype
+     output_dim, input_dim = delta_weights[0].size()
+     concat_task_vector = torch.cat(delta_weights, dim=1)
+     U, S, Vh = torch.linalg.svd(concat_task_vector.to(svd_dtype), full_matrices=False)
+     # Keep only supported basis components
+     U = U[:, S > eps].to(original_dtype)
+     Vh = Vh[S > eps].to(original_dtype)
+     S = S[S > eps].to(original_dtype)
+     Vhs = torch.split(Vh, input_dim, dim=1)
+     return U, S, Vhs
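For context, `subspace_alignment` concatenates the per-task weight deltas along the column dimension, takes a joint SVD, and drops directions whose singular values fall below `eps`. A minimal calling sketch, with made-up layer shapes:

import torch

from fusion_bench.method.knots.knots_utils import subspace_alignment

# Two hypothetical task vectors (W_i - W_0) for the same 768 x 768 linear layer.
delta_weights = [torch.randn(768, 768) for _ in range(2)]

U, S, Vhs = subspace_alignment(delta_weights, svd_dtype=torch.float64, eps=1e-4)

# U is the shared left basis (768 x k), S holds the k retained singular values,
# and Vhs contains one (k x 768) right-factor block per task vector.
print(U.shape, S.shape, len(Vhs), Vhs[0].shape)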
fusion_bench/method/linear/simple_average_for_llama.py
@@ -1,4 +1,5 @@
- from typing import Optional
+ from copy import deepcopy
+ from typing import TYPE_CHECKING, Optional

  from typing_extensions import override

@@ -6,6 +7,11 @@ from fusion_bench import timeit_context
  from fusion_bench.method.base_algorithm import BaseAlgorithm
  from fusion_bench.method.simple_average import SimpleAverageAlgorithm
  from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
+ from fusion_bench.utils.pylogger import getRankZeroLogger
+ from omegaconf import flag_override
+ from fusion_bench.utils import instantiate
+
+ log = getRankZeroLogger(__name__)


  class SimpleAverageForLlama(BaseAlgorithm):
@@ -40,12 +46,20 @@ class SimpleAverageForLlama(BaseAlgorithm):

          if self.merge_backbone:
              assert modelpool.has_pretrained
-             backbone_modelpool = CausalLMBackbonePool(**modelpool.config)
+             log.info(
+                 "Merging backbone of the model pool, use CausalLMBackbonePool instead of CausalLMPool."
+             )
+             modelpool_config = deepcopy(modelpool.config)
+             with flag_override(modelpool_config, "allow_objects", True):
+                 modelpool_config._target_ = (
+                     "fusion_bench.modelpool.causal_lm.CausalLMBackbonePool"
+                 )
+             backbone_modelpool = instantiate(modelpool_config)
              model = modelpool.load_model("_pretrained_")
              backbone_model = SimpleAverageAlgorithm().run(backbone_modelpool)
              model.model.layers = backbone_model
          else:
-             model = SimpleAverageAlgorithm().run()
+             model = SimpleAverageAlgorithm().run(modelpool=modelpool)

          if self.model_save_path is not None:
              with timeit_context(f"Saving the model to {self.model_save_path}"):
fusion_bench/method/simple_average.py
@@ -8,6 +8,7 @@ from torch import nn
  from fusion_bench.method.base_algorithm import BaseAlgorithm
  from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
  from fusion_bench.modelpool import BaseModelPool
+ from fusion_bench.utils import LazyStateDict
  from fusion_bench.utils.state_dict_arithmetic import (
      state_dict_add,
      state_dict_avg,
@@ -104,6 +105,15 @@ class SimpleAverageAlgorithm(
          # Divide the accumulated state dictionary by the number of models to get the average
          sd = state_dict_div(sd, len(modelpool.model_names))

+         if isinstance(forward_model, LazyStateDict):
+             # if the model is a LazyStateDict, convert it to an empty module
+             forward_model = forward_model.meta_module.to_empty(
+                 device=(
+                     "cpu"
+                     if forward_model._torch_dtype is None
+                     else forward_model._torch_dtype
+                 )
+             )
          forward_model.load_state_dict(sd)
          # print profile report and log the merged models
          self.print_profile_summary()
fusion_bench/method/task_singular_vector/utils/__init__.py
@@ -5,3 +5,4 @@ from fusion_bench.method.ties_merging.ties_merging_utils import (
  from fusion_bench.utils import state_dict_to_vector, vector_to_state_dict

  from . import TSVC_utils, TSVM_utils
+ from .task_singular_interference import compute_task_singular_interference
fusion_bench/method/task_singular_vector/utils/task_singular_interference.py
@@ -0,0 +1,41 @@
+ from typing import List
+
+ import torch
+
+
+ def compute_task_singular_interference(weight_differences: List[torch.Tensor]) -> float:
+     R"""
+     Compute the singular interference of a list of weight differences $\{W_i - W_0\}_{i=1}^T$,
+     where $W_0$ is the pre-trained model weight, $W_i$ is the weight of the i-th fine-tuned model
+     and $T$ is the number of fine-tuned models.
+
+     Args:
+         weight_differences (List[torch.Tensor]): A list of weight differences $\{W_i - W_0\}_{i=1}^T$.
+
+     Returns:
+         float: The singular interference of the list of weight differences.
+     """
+     device = weight_differences[0].device
+     dtype = weight_differences[0].dtype
+
+     U = []
+     S = []
+     V = []
+     for delta_w in weight_differences:
+         u, s, vh = torch.linalg.svd(delta_w, full_matrices=False)
+         U.append(u)
+         S.append(s)
+         V.append(vh.t())
+     # Concatenate the per-task singular factors along the column dimension
+     U = torch.cat(U, dim=1)
+     S = torch.cat(S, dim=0)
+     V = torch.cat(V, dim=1)
+
+     singular_task_interference = torch.linalg.multi_dot(
+         (
+             U.t() @ U - torch.eye(U.shape[1], device=device, dtype=dtype),
+             torch.diag(S),
+             V.t() @ V - torch.eye(V.shape[1], device=device, dtype=dtype),
+         )
+     )
+     singular_task_interference = torch.linalg.norm(singular_task_interference, ord=1)
+     return singular_task_interference
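A small sketch of calling the new helper through its re-export in `fusion_bench.method.task_singular_vector.utils` (see the preceding `__init__.py` hunk); the layer shape and deltas below are placeholders, not package data.

import torch

from fusion_bench.method.task_singular_vector.utils import (
    compute_task_singular_interference,
)

# Placeholder weight differences W_i - W_0 for three fine-tuned variants of one layer.
deltas = [0.01 * torch.randn(512, 512) for _ in range(3)]

# Larger values indicate stronger overlap between the per-task singular subspaces.
interference = compute_task_singular_interference(deltas)
print(float(interference))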
fusion_bench/modelpool/causal_lm/causal_lm.py
@@ -22,6 +22,8 @@ from typing_extensions import override
  from fusion_bench.modelpool import BaseModelPool
  from fusion_bench.utils import instantiate
  from fusion_bench.utils.dtype import parse_dtype
+ from fusion_bench.utils.lazy_state_dict import LazyStateDict
+ from fusion_bench.utils.packages import import_object

  log = logging.getLogger(__name__)

@@ -30,6 +32,7 @@ class CausalLMPool(BaseModelPool):
      _config_mapping = BaseModelPool._config_mapping | {
          "_tokenizer": "tokenizer",
          "_model_kwargs": "model_kwargs",
+         "load_lazy": "load_lazy",
      }

      def __init__(
@@ -38,6 +41,7 @@
          *,
          tokenizer: Optional[DictConfig],
          model_kwargs: Optional[DictConfig] = None,
+         load_lazy: bool = False,
          **kwargs,
      ):
          super().__init__(models, **kwargs)
@@ -51,6 +55,7 @@
              self._model_kwargs.torch_dtype = parse_dtype(
                  self._model_kwargs.torch_dtype
              )
+         self.load_lazy = load_lazy

      @override
      def load_model(
@@ -88,21 +93,41 @@
          model_kwargs.update(kwargs)

          if isinstance(model_name_or_config, str):
+             # If model_name_or_config is a string, it is the name or the path of the model
              log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
              if model_name_or_config in self._models.keys():
                  model_config = self._models[model_name_or_config]
                  if isinstance(model_config, str):
                      # model_config is a string
-                     model = AutoModelForCausalLM.from_pretrained(
-                         model_config,
-                         *args,
-                         **model_kwargs,
-                     )
+                     if not self.load_lazy:
+                         model = AutoModelForCausalLM.from_pretrained(
+                             model_config,
+                             *args,
+                             **model_kwargs,
+                         )
+                     else:
+                         # model_config is a string, but we want to use LazyStateDict
+                         model = LazyStateDict(
+                             checkpoint=model_config,
+                             meta_module_class=AutoModelForCausalLM,
+                             *args,
+                             **model_kwargs,
+                         )
              return model
          elif isinstance(model_name_or_config, (DictConfig, Dict)):
              model_config = model_name_or_config

-             model = instantiate(model_config, *args, **model_kwargs)
+             if not self.load_lazy:
+                 model = instantiate(model_config, *args, **model_kwargs)
+             else:
+                 meta_module_class = model_config.pop("_target_")
+                 checkpoint = model_config.pop("pretrained_model_name_or_path")
+                 model = LazyStateDict(
+                     checkpoint=checkpoint,
+                     meta_module_class=meta_module_class,
+                     *args,
+                     **model_kwargs,
+                 )
              return model

      def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
@@ -141,6 +166,7 @@
          model_dtype: Optional[str] = None,
          save_tokenizer: bool = False,
          tokenizer_kwargs=None,
+         tokenizer: Optional[PreTrainedTokenizer] = None,
          **kwargs,
      ):
          """
@@ -154,11 +180,13 @@
              **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
          """
          path = os.path.expanduser(path)
-         if save_tokenizer:
-             if tokenizer_kwargs is None:
-                 tokenizer_kwargs = {}
-             # load the tokenizer
-             tokenizer = self.load_tokenizer(**tokenizer_kwargs)
+         # NOTE: if tokenizer is provided, it will be saved regardless of `save_tokenizer`
+         if save_tokenizer or tokenizer is not None:
+             if tokenizer is None:
+                 if tokenizer_kwargs is None:
+                     tokenizer_kwargs = {}
+                 # load the tokenizer
+                 tokenizer = self.load_tokenizer(**tokenizer_kwargs)
              tokenizer.save_pretrained(
                  path,
                  push_to_hub=push_to_hub,
@@ -176,6 +204,12 @@ class CausalLMBackbonePool(CausalLMPool):
      def load_model(
          self, model_name_or_config: str | DictConfig, *args, **kwargs
      ) -> Module:
+         if self.load_lazy:
+             log.warning(
+                 "CausalLMBackbonePool does not support lazy loading. "
+                 "Falling back to normal loading."
+             )
+             self.load_lazy = False
          model: AutoModelForCausalLM = super().load_model(
              model_name_or_config, *args, **kwargs
          )
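Taken together, the `load_lazy` changes let `CausalLMPool.load_model` hand back a `LazyStateDict` wrapper instead of materializing the full model up front. The sketch below is only an illustration: the model names, paths, and tokenizer config are placeholders, and the constructor arguments follow the type hints visible in the diff rather than documented usage.

from omegaconf import DictConfig

from fusion_bench.modelpool import CausalLMPool

# Hypothetical pool; model names/paths and the tokenizer entry are placeholders.
pool = CausalLMPool(
    models=DictConfig(
        {
            "_pretrained_": "Qwen/Qwen2.5-1.5B",
            "expert_math": "path/to/math-finetuned-model",
        }
    ),
    tokenizer=DictConfig(
        {
            "_target_": "transformers.AutoTokenizer.from_pretrained",
            "pretrained_model_name_or_path": "Qwen/Qwen2.5-1.5B",
        }
    ),
    load_lazy=True,
)

# With load_lazy=True, a string model entry is wrapped in LazyStateDict
# (weights resolved on demand) instead of calling
# AutoModelForCausalLM.from_pretrained immediately.
lazy_model = pool.load_model("expert_math")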
fusion_bench/models/__init__.py
@@ -1,3 +1,4 @@
  # flake8: noqa F401
  from . import separate_io, utils
  from .parameter_dict import ParameterDictModel
+ from fusion_bench.utils import LazyStateDict
fusion_bench/models/expert_sparsity/__init__.py (file without changes)
fusion_bench/models/expert_sparsity/mixtral/__init__.py
@@ -0,0 +1,15 @@
+ R"""
+ Copy from https://github.com/Lucky-Lance/Expert_Sparsity/tree/main/model
+
+ Original repo: https://github.com/Lucky-Lance/Expert_Sparsity
+
+ Reference:
+ Not All Experts are Equal: Efficient Expert Pruning and Skipping for Mixture-of-Experts Large Language Models.
+ ACL 2024.
+ http://arxiv.org/abs/2402.14800
+ """
+
+ from .wrapper import (
+     PrunableMixtralSparseMoeBlockWrapper,
+     DynamicSkippingMixtralSparseMoeBlockWrapper,
+ )
fusion_bench/models/expert_sparsity/mixtral/dataset.py
@@ -0,0 +1,40 @@
+ import torch
+
+
+ class CacheDataset(torch.utils.data.Dataset):
+     def __init__(self):
+         self.alphas = []  # logits
+         self.Xs = []  # input hidden states
+         self.Zs = []  # output hidden states
+         self.prepared = False
+
+     def __len__(self):
+         if not self.prepared:
+             self.prepare_for_loader()
+         return len(self.alphas)
+
+     def __getitem__(self, index):
+         if not self.prepared:
+             self.prepare_for_loader()
+         if isinstance(index, list):
+             return [(self.alphas[idx], self.Xs[idx], self.Zs[idx]) for idx in index]
+         elif isinstance(index, int):
+             return self.alphas[index], self.Xs[index], self.Zs[index]
+
+     def append(self, alpha=None, X=None, Z=None):
+         if alpha is not None:
+             self.alphas.append(alpha.detach().to("cpu", non_blocking=True))
+         if X is not None:
+             self.Xs.append(X.detach().to("cpu", non_blocking=True))
+         if Z is not None:
+             self.Zs.append(Z.detach().to("cpu", non_blocking=True))
+         self.prepared = False
+
+     def prepare_for_loader(self):
+         if self.prepared:
+             return
+         self.prepared = True
+         self.alphas = torch.concat(self.alphas)
+         self.Xs = torch.concat(self.Xs)
+         self.Zs = torch.concat(self.Zs)
+         assert len(self.Xs) == len(self.Zs)
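`CacheDataset` simply buffers router logits and MoE inputs/outputs on CPU and concatenates them on first access. A toy sketch with dummy tensors (all shapes below are illustrative only):

import torch

from fusion_bench.models.expert_sparsity.mixtral.dataset import CacheDataset

cache = CacheDataset()

# Two dummy calibration batches: router logits (alpha), MoE inputs (X) and outputs (Z).
for _ in range(2):
    cache.append(
        alpha=torch.randn(4, 8),   # e.g. (tokens, num_experts)
        X=torch.randn(4, 32),      # e.g. (tokens, hidden_dim)
        Z=torch.randn(4, 32),
    )

print(len(cache))        # 8: batches are concatenated along the first dimension
alpha, x, z = cache[0]   # one (logits, input, output) triple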
fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py
@@ -0,0 +1,207 @@
+ import warnings
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+ from transformers.models.mixtral.modeling_mixtral import (
+     MixtralBlockSparseTop2MLP,
+     MixtralConfig,
+     MixtralRMSNorm,
+     MixtralSparseMoeBlock,
+ )
+
+
+ class DynamicSkippingMixtralSparseMoeBlock(nn.Module):
+     """
+     This implementation is
+     strictly equivalent to standard MoE with full capacity (no
+     dropped tokens). It's faster since it formulates MoE operations
+     in terms of block-sparse operations to accomodate imbalanced
+     assignments of tokens to experts, whereas standard MoE either
+     (1) drop tokens at the cost of reduced performance or (2) set
+     capacity factor to number of experts and thus waste computation
+     and memory on padding.
+     """
+
+     def __init__(self, config, beta):
+         super().__init__()
+         self.hidden_dim = config.hidden_size
+         self.ffn_dim = config.intermediate_size
+         self.num_experts = config.num_local_experts
+         self.top_k = config.num_experts_per_tok
+
+         # gating
+         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+         self.experts = nn.ModuleList(
+             [MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]
+         )
+
+         self.beta = beta
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         """ """
+         batch_size, sequence_length, hidden_dim = hidden_states.shape
+         hidden_states = hidden_states.view(-1, hidden_dim)
+         # router_logits: (batch * sequence_length, n_experts)
+         router_logits = self.gate(hidden_states)
+
+         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+         routing_weights, selected_experts = torch.topk(
+             routing_weights, self.top_k, dim=-1
+         )
+
+         onlytop1_mask = (
+             routing_weights[:, 1] < self.beta * routing_weights[:, 0]
+         )  # bz x seqlen
+
+         # routing_weights[tokens_top1, 1].fill_(0)
+         routing_weights[onlytop1_mask, 1] = 0
+         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+         # we cast back to the input dtype
+         routing_weights = routing_weights.to(hidden_states.dtype)
+
+         final_hidden_states = torch.zeros(
+             (batch_size * sequence_length, hidden_dim),
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+
+         # One hot encode the selected experts to create an expert mask
+         # this will be used to easily index which expert is going to be sollicitated
+         expert_mask = torch.nn.functional.one_hot(
+             selected_experts, num_classes=self.num_experts
+         )
+         # ipdb.set_trace()
+         # expert_mask[tokens_top1, 1, :].fill_(0)
+         expert_mask[onlytop1_mask, 1, :] = 0
+         expert_mask = expert_mask.permute(2, 1, 0)
+
+         # Loop over all available experts in the model and perform the computation on each expert
+         for expert_idx in range(self.num_experts):
+             expert_layer = self.experts[expert_idx]
+             idx, top_x = torch.where(expert_mask[expert_idx])
+
+             if top_x.shape[0] == 0:
+                 continue
+
+             # in torch it is faster to index using lists than torch tensors
+             top_x_list = top_x.tolist()
+             idx_list = idx.tolist()
+
+             # Index the correct hidden states and compute the expert hidden state for
+             # the current expert. We need to make sure to multiply the output hidden
+             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+             current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+             current_hidden_states = expert_layer(
+                 current_state, routing_weights[top_x_list, idx_list, None]
+             )
+
+             # However `index_add_` only support torch tensors for indexing so we'll use
+             # the `top_x` tensor here.
+             final_hidden_states.index_add_(
+                 0, top_x, current_hidden_states.to(hidden_states.dtype)
+             )
+         final_hidden_states = final_hidden_states.reshape(
+             batch_size, sequence_length, hidden_dim
+         )
+         return final_hidden_states, router_logits
+
+
+ class MixtralDecoderLayer(nn.Module):
+     def __init__(self, config: MixtralConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+
+         self.self_attn = ALL_ATTENTION_FUNCTIONS[config._attn_implementation](
+             config, layer_idx
+         )
+         if hasattr(config, "betas"):
+             assert (
+                 isinstance(config.betas, dict)
+                 and len(config.betas) == config.num_hidden_layers
+             )
+             self.block_sparse_moe = DynamicSkippingMixtralSparseMoeBlock(
+                 config, config.betas[str(layer_idx)]
+             )
+             warnings.warn(
+                 f"Using online drop: {layer_idx}, {config.betas[str(layer_idx)]}, {type(self.block_sparse_moe)}"
+             )
+         else:
+             self.block_sparse_moe = MixtralSparseMoeBlock(config)
+         self.input_layernorm = MixtralRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.post_attention_layernorm = MixtralRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         output_router_logits: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         **kwargs,
+     ) -> Tuple[
+         torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+     ]:
+         if "padding_mask" in kwargs:
+             warnings.warn(
+                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+             )
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                 `(batch, sequence_length)` where padding elements are indicated by 0.
+             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             output_router_logits (`bool`, *optional*):
+                 Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+                 should not be returned during inference.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+         """
+
+         residual = hidden_states
+
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         if use_cache:
+             outputs += (present_key_value,)
+
+         if output_router_logits:
+             outputs += (router_logits,)
+
+         return outputs
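The core of `DynamicSkippingMixtralSparseMoeBlock.forward` is the per-token skipping rule: the second expert is dropped whenever its routing weight falls below `beta` times the top-1 weight. A standalone toy illustration of that rule (random logits, arbitrary `beta`; not code from the package):

import torch
import torch.nn.functional as F

# Toy router logits for 4 tokens over 8 experts; values are arbitrary.
router_logits = torch.randn(4, 8)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, k=2, dim=-1)

beta = 0.5
# A token falls back to its top-1 expert when the second weight is below beta * top-1 weight.
onlytop1_mask = routing_weights[:, 1] < beta * routing_weights[:, 0]
routing_weights[onlytop1_mask, 1] = 0
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

print(onlytop1_mask)     # tokens that skip their second expert
print(routing_weights)   # renormalized per-token expert weights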