fusion-bench 0.2.16__py3-none-any.whl → 0.2.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. fusion_bench/method/__init__.py +11 -0
  2. fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +1 -1
  3. fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +1 -1
  4. fusion_bench/method/base_algorithm.py +1 -0
  5. fusion_bench/method/dawe/dawe_for_clip.py +1 -1
  6. fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +3 -2
  7. fusion_bench/method/expert_sparsity/__init__.py +10 -0
  8. fusion_bench/method/expert_sparsity/mixtral/__init__.py +23 -0
  9. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +175 -0
  10. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +159 -0
  11. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +173 -0
  12. fusion_bench/method/expert_sparsity/utils/calibration_data.py +153 -0
  13. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +1 -1
  14. fusion_bench/method/knots/__init__.py +0 -0
  15. fusion_bench/method/knots/knots_utils.py +23 -0
  16. fusion_bench/method/pwe_moe/module.py +2 -7
  17. fusion_bench/method/simple_average.py +3 -2
  18. fusion_bench/method/task_singular_vector/TSVM.py +238 -25
  19. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +52 -20
  20. fusion_bench/method/task_singular_vector/utils/__init__.py +1 -0
  21. fusion_bench/method/task_singular_vector/utils/task_singular_interference.py +41 -0
  22. fusion_bench/mixins/hydra_config.py +1 -1
  23. fusion_bench/mixins/lightning_fabric.py +25 -1
  24. fusion_bench/mixins/serialization.py +18 -2
  25. fusion_bench/modelpool/base_pool.py +1 -0
  26. fusion_bench/modelpool/causal_lm/causal_lm.py +8 -5
  27. fusion_bench/modelpool/clip_vision/modelpool.py +21 -13
  28. fusion_bench/models/__init__.py +1 -0
  29. fusion_bench/models/expert_sparsity/__init__.py +0 -0
  30. fusion_bench/models/expert_sparsity/mixtral/__init__.py +15 -0
  31. fusion_bench/models/expert_sparsity/mixtral/dataset.py +40 -0
  32. fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py +207 -0
  33. fusion_bench/models/expert_sparsity/mixtral/wrapper.py +268 -0
  34. fusion_bench/models/parameter_dict.py +6 -1
  35. fusion_bench/programs/fabric_fusion_program.py +21 -13
  36. fusion_bench/taskpool/base_pool.py +1 -0
  37. fusion_bench/taskpool/dummy.py +6 -4
  38. fusion_bench/utils/__init__.py +4 -3
  39. fusion_bench/utils/dtype.py +2 -1
  40. fusion_bench/utils/fabric.py +11 -4
  41. fusion_bench/utils/{instantiate.py → instantiate_utils.py} +3 -0
  42. fusion_bench/utils/lazy_state_dict.py +80 -10
  43. fusion_bench/utils/pylogger.py +30 -0
  44. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/METADATA +3 -1
  45. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/RECORD +59 -38
  46. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/WHEEL +1 -1
  47. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +2 -0
  48. fusion_bench_config/fabric_model_fusion.yaml +2 -2
  49. fusion_bench_config/method/expert_sparsity/README.md +6 -0
  50. fusion_bench_config/method/expert_sparsity/mixtral.yaml +17 -0
  51. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -1
  52. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_cars_and_dtd.yaml +16 -0
  53. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +16 -0
  54. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +16 -0
  55. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +19 -0
  56. fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +0 -1
  57. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/entry_points.txt +0 -0
  58. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/licenses/LICENSE +0 -0
  59. {fusion_bench-0.2.16.dist-info → fusion_bench-0.2.18.dist-info}/top_level.txt +0 -0

fusion_bench/method/task_singular_vector/utils/task_singular_interference.py
@@ -0,0 +1,41 @@
+ from typing import List
+
+ import torch
+
+
+ def compute_task_singular_interference(weight_differences: List[torch.Tensor]) -> float:
+     R"""
+     Compute the singular interference of a list of weight differences $\{W_i - W_0\}_{i=1}^T$,
+     where $W_0$ is the pre-trained model weight, $W_i$ is the weight of the i-th fine-tuned model
+     and $T$ is the number of fine-tuned models.
+
+     Args:
+         weight_differences (List[torch.Tensor]): A list of weight differences $\{W_i - W_0\}_{i=1}^T$.
+
+     Returns:
+         float: The singular interference of the list of weight differences.
+     """
+     device = weight_differences[0].device
+     dtype = weight_differences[0].dtype
+
+     U = []
+     S = []
+     V = []
+     for delta_w in weight_differences:
+         u, s, vh = torch.linalg.svd(delta_w, full_matrices=False)
+         U.append(u)
+         S.append(s)
+         V.append(vh.t())
+     U = torch.cat(U, dim=0)
+     S = torch.cat(S, dim=0)
+     V = torch.cat(V, dim=0)
+
+     singular_task_interference = torch.linalg.multi_dot(
+         (
+             U.t() @ U - torch.eye(U.shape[1], device=device, dtype=dtype),
+             torch.diag(S),
+             V.t() @ V - torch.eye(V.shape[1], device=device, dtype=dtype),
+         )
+     )
+     singular_task_interference = torch.linalg.norm(singular_task_interference, ord="1")
+     return singular_task_interference
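
For reference, the quantity computed by the new helper (reading directly off the code above) can be written as

    \mathrm{interference} = \left\| \left(U^\top U - I\right)\, \operatorname{diag}(S)\, \left(V^\top V - I\right) \right\|_{1}

where $U$, $S$ and $V$ collect the left singular vectors, singular values and right singular vectors of the task vectors $\Delta W_i = W_i - W_0$, obtained via `torch.linalg.svd`.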

fusion_bench/mixins/hydra_config.py
@@ -9,7 +9,7 @@ from hydra import compose, initialize
  from omegaconf import DictConfig, OmegaConf

  from fusion_bench.utils import import_object, instantiate
- from fusion_bench.utils.instantiate import set_print_function_call
+ from fusion_bench.utils.instantiate_utils import set_print_function_call

  log = logging.getLogger(__name__)


fusion_bench/mixins/lightning_fabric.py
@@ -11,7 +11,7 @@ from lightning.fabric.utilities.rank_zero import rank_zero_only
  from omegaconf import DictConfig, OmegaConf

  from fusion_bench.utils import import_object
- from fusion_bench.utils.instantiate import instantiate
+ from fusion_bench.utils.instantiate_utils import instantiate

  if TYPE_CHECKING:
      import lightning.fabric.loggers.tensorboard
@@ -172,3 +172,27 @@ class LightningFabricMixin:
              return True
          else:
              return False
+
+     def log(self, name: str, value: Any, step: Optional[int] = None):
+         """
+         Logs the metric to the fabric's logger.
+         """
+         self.fabric.log(name, value, step=step)
+
+     def log_dict(self, metrics: dict, step: Optional[int] = None):
+         """
+         Logs the metrics to the fabric's logger.
+         """
+         self.fabric.log_dict(metrics, step=step)
+
+     def log_optimizer_lr(
+         self,
+         optimizer: torch.optim.Optimizer,
+         step: Optional[int] = None,
+         name_template: str = "train/lr_group_{0}",
+     ):
+         """
+         Logs the learning rate of the optimizer to the fabric's logger.
+         """
+         for i, param_group in enumerate(optimizer.param_groups):
+             self.fabric.log(name_template.format(i), param_group["lr"], step=step)
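
These helpers are thin wrappers around `fabric.log` and `fabric.log_dict`. A hedged sketch of how an algorithm built on the mixin might call them from a training loop; the class name, method name, metric names and the `loss`/`optimizer` objects are illustrative, not part of the diff:

    # Hedged sketch: using the new logging helpers from a class that mixes in
    # LightningFabricMixin. Names below are placeholders for illustration.
    import torch

    from fusion_bench.mixins.lightning_fabric import LightningFabricMixin


    class MyFabricAlgorithm(LightningFabricMixin):
        def on_train_step_end(self, loss: torch.Tensor, optimizer: torch.optim.Optimizer, step: int):
            # a single scalar metric
            self.log("train/loss", loss.item(), step=step)
            # several metrics at once
            self.log_dict({"train/loss": loss.item()}, step=step)
            # one entry per parameter group: train/lr_group_0, train/lr_group_1, ...
            self.log_optimizer_lr(optimizer, step=step)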

fusion_bench/mixins/serialization.py
@@ -4,13 +4,14 @@ from typing import Dict, Optional, Union

  from omegaconf import OmegaConf

- from fusion_bench.utils import instantiate
+ from fusion_bench.utils import import_object, instantiate

  log = logging.getLogger(__name__)


  class YAMLSerializationMixin:
      _recursive_: bool = False
+     _config_key: Optional[str] = None
      _config_mapping: Dict[str, str] = {
          "_recursive_": "_recursive_",
      }
@@ -99,7 +100,22 @@ class YAMLSerializationMixin:
              BaseModelPool: The loaded model pool.
          """
          config = OmegaConf.load(path)
-         return instantiate(config, _recursive_=cls._recursive_)
+         if cls._config_key is not None and cls._config_key in config:
+             config = config[cls._config_key]
+         target_cls = import_object(config["_target_"])
+         if target_cls != cls:
+             log.warning(
+                 f"The class {target_cls.__name__} is not the same as the class {cls.__name__}. "
+                 f"Instantiating the class {target_cls.__name__} instead."
+             )
+         return instantiate(
+             config,
+             _recursive_=(
+                 cls._recursive_
+                 if config.get("_recursive_") is None
+                 else config.get("_recursive_")
+             ),
+         )

      def to_config(self):
          """

fusion_bench/modelpool/base_pool.py
@@ -29,6 +29,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
      """

      _program = None
+     _config_key = "modelpool"
      _models: Union[DictConfig, Dict[str, nn.Module]]
      _config_mapping = BaseYAMLSerializableModel._config_mapping | {
          "_models": "models",

fusion_bench/modelpool/causal_lm/causal_lm.py
@@ -141,6 +141,7 @@ class CausalLMPool(BaseModelPool):
          model_dtype: Optional[str] = None,
          save_tokenizer: bool = False,
          tokenizer_kwargs=None,
+         tokenizer: Optional[PreTrainedTokenizer] = None,
          **kwargs,
      ):
          """
@@ -154,11 +155,13 @@
              **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
          """
          path = os.path.expanduser(path)
-         if save_tokenizer:
-             if tokenizer_kwargs is None:
-                 tokenizer_kwargs = {}
-             # load the tokenizer
-             tokenizer = self.load_tokenizer(**tokenizer_kwargs)
+         # NOTE: if tokenizer is provided, it will be saved regardless of `save_tokenizer`
+         if save_tokenizer or tokenizer is not None:
+             if tokenizer is None:
+                 if tokenizer_kwargs is None:
+                     tokenizer_kwargs = {}
+                 # load the tokenizer
+                 tokenizer = self.load_tokenizer(**tokenizer_kwargs)
          tokenizer.save_pretrained(
              path,
              push_to_hub=push_to_hub,
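
In practice this means a caller that already holds a tokenizer in memory can pass it to the save call directly, and it is written out even when `save_tokenizer` is left at its default. A hedged sketch, assuming the enclosing method in the hunk is the pool's `save_model` entry point (its name and full signature are not visible in the diff); paths and model names are placeholders:

    # Hedged sketch: saving a model together with an in-memory tokenizer.
    # `save_model`, the config path and the model name are illustrative assumptions;
    # only the `tokenizer=` behaviour comes from the diff above.
    from transformers import AutoTokenizer

    from fusion_bench.modelpool.causal_lm.causal_lm import CausalLMPool

    pool = CausalLMPool.from_yaml("path/to/causal_lm_pool.yaml")
    model = pool.load_model("expert_a")
    tokenizer = AutoTokenizer.from_pretrained("some-org/base-model")

    # When `tokenizer` is passed, it is saved regardless of `save_tokenizer`,
    # and `load_tokenizer(**tokenizer_kwargs)` is skipped.
    pool.save_model(model, "outputs/merged-model", tokenizer=tokenizer)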

fusion_bench/modelpool/clip_vision/modelpool.py
@@ -3,6 +3,7 @@ from copy import deepcopy
  from typing import Optional, Union

  from datasets import load_dataset
+ from lightning.fabric.utilities import rank_zero_only
  from omegaconf import DictConfig, open_dict
  from torch import nn
  from torch.utils.data import Dataset
@@ -40,7 +41,8 @@ class CLIPVisionModelPool(BaseModelPool):
      def load_processor(self, *args, **kwargs) -> CLIPProcessor:
          assert self._processor is not None, "Processor is not defined in the config"
          if isinstance(self._processor, str):
-             log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
+             if rank_zero_only.rank == 0:
+                 log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
              processor = CLIPProcessor.from_pretrained(self._processor)
          else:
              processor = instantiate(self._processor, *args, **kwargs)
@@ -50,7 +52,8 @@
          model_config = self._models[model_name]

          if isinstance(model_config, str):
-             log.info(f"Loading `transformers.CLIPModel`: {model_config}")
+             if rank_zero_only.rank == 0:
+                 log.info(f"Loading `transformers.CLIPModel`: {model_config}")
              clip_model = CLIPModel.from_pretrained(model_config, *args, **kwargs)
              return clip_model
          else:
@@ -102,10 +105,12 @@
      ):
          model = self._models[model_name_or_config]
          if isinstance(model, str):
-             log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
+             if rank_zero_only.rank == 0:
+                 log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
              return CLIPVisionModel.from_pretrained(model, *args, **kwargs)
          if isinstance(model, nn.Module):
-             log.info(f"Returning existing model: {model}")
+             if rank_zero_only.rank == 0:
+                 log.info(f"Returning existing model: {model}")
              return model

          # If the model is not a string, we use the default load_model method
@@ -114,9 +119,10 @@
      def load_train_dataset(self, dataset_name: str, *args, **kwargs):
          dataset_config = self._train_datasets[dataset_name]
          if isinstance(dataset_config, str):
-             log.info(
-                 f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
-             )
+             if rank_zero_only.rank == 0:
+                 log.info(
+                     f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
+                 )
              dataset = load_dataset(dataset_config, split="train")
          else:
              dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
@@ -125,9 +131,10 @@
      def load_val_dataset(self, dataset_name: str, *args, **kwargs):
          dataset_config = self._val_datasets[dataset_name]
          if isinstance(dataset_config, str):
-             log.info(
-                 f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
-             )
+             if rank_zero_only.rank == 0:
+                 log.info(
+                     f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
+                 )
              dataset = load_dataset(dataset_config, split="validation")
          else:
              dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
@@ -136,9 +143,10 @@
      def load_test_dataset(self, dataset_name: str, *args, **kwargs):
          dataset_config = self._test_datasets[dataset_name]
          if isinstance(dataset_config, str):
-             log.info(
-                 f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
-             )
+             if rank_zero_only.rank == 0:
+                 log.info(
+                     f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
+                 )
              dataset = load_dataset(dataset_config, split="test")
          else:
              dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
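
All of the hunks above apply the same change: each `log.info` call is wrapped in a rank-zero guard so that multi-process Fabric launches print every "Loading ..." message only once. A minimal standalone sketch of the pattern outside fusion_bench; the function and its argument are illustrative:

    # Minimal sketch of the rank-zero logging guard used in the hunks above.
    import logging

    from lightning.fabric.utilities import rank_zero_only

    log = logging.getLogger(__name__)


    def load_something(name: str):
        # Only the process with global rank 0 logs; other ranks stay silent,
        # which avoids duplicated log lines under multi-process launches.
        if rank_zero_only.rank == 0:
            log.info(f"Loading: {name}")
        # ... the actual loading still happens on every rank ...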

fusion_bench/models/__init__.py
@@ -1,3 +1,4 @@
  # flake8: noqa F401
  from . import separate_io, utils
  from .parameter_dict import ParameterDictModel
+ from fusion_bench.utils import LazyStateDict

fusion_bench/models/expert_sparsity/__init__.py
File without changes

fusion_bench/models/expert_sparsity/mixtral/__init__.py
@@ -0,0 +1,15 @@
+ R"""
+ Copy from https://github.com/Lucky-Lance/Expert_Sparsity/tree/main/model
+
+ Original repo: https://github.com/Lucky-Lance/Expert_Sparsity
+
+ Reference:
+     Not All Experts are Equal: Efficient Expert Pruning and Skipping for Mixture-of-Experts Large Language Models.
+     ACL 2024.
+     http://arxiv.org/abs/2402.14800
+ """
+
+ from .wrapper import (
+     PrunableMixtralSparseMoeBlockWrapper,
+     DynamicSkippingMixtralSparseMoeBlockWrapper,
+ )

fusion_bench/models/expert_sparsity/mixtral/dataset.py
@@ -0,0 +1,40 @@
+ import torch
+
+
+ class CacheDataset(torch.utils.data.Dataset):
+     def __init__(self):
+         self.alphas = []  # logits
+         self.Xs = []  # input hidden states
+         self.Zs = []  # output hidden states
+         self.prepared = False
+
+     def __len__(self):
+         if not self.prepared:
+             self.prepare_for_loader()
+         return len(self.alphas)
+
+     def __getitem__(self, index):
+         if not self.prepared:
+             self.prepare_for_loader()
+         if isinstance(index, list):
+             return [(self.alphas[idx], self.Xs[idx], self.Zs[idx]) for idx in index]
+         elif isinstance(index, int):
+             return self.alphas[index], self.Xs[index], self.Zs[index]
+
+     def append(self, alpha=None, X=None, Z=None):
+         if alpha is not None:
+             self.alphas.append(alpha.detach().to("cpu", non_blocking=True))
+         if X is not None:
+             self.Xs.append(X.detach().to("cpu", non_blocking=True))
+         if Z is not None:
+             self.Zs.append(Z.detach().to("cpu", non_blocking=True))
+         self.prepared = False
+
+     def prepare_for_loader(self):
+         if self.prepared:
+             return
+         self.prepared = True
+         self.alphas = torch.concat(self.alphas)
+         self.Xs = torch.concat(self.Xs)
+         self.Zs = torch.concat(self.Zs)
+         assert len(self.Xs) == len(self.Zs)
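
`CacheDataset` buffers (router logits, input hidden states, output hidden states) triples on the CPU while calibration data flows through a MoE block, then concatenates the buffers lazily on first access so the result can be read back through a `DataLoader`. A hedged sketch of that flow with placeholder shapes:

    # Hedged sketch: filling a CacheDataset and reading it back via a DataLoader.
    # Shapes and the number of "batches" are placeholders; in fusion_bench the
    # triples would come from hooks on a Mixtral sparse-MoE block during calibration.
    import torch
    from torch.utils.data import DataLoader

    from fusion_bench.models.expert_sparsity.mixtral.dataset import CacheDataset

    cache = CacheDataset()
    for _ in range(4):
        alpha = torch.randn(8, 8)   # router logits for 8 tokens, 8 experts
        X = torch.randn(8, 32)      # block inputs (hidden states)
        Z = torch.randn(8, 32)      # block outputs (hidden states)
        cache.append(alpha=alpha, X=X, Z=Z)

    # __len__/__getitem__ trigger prepare_for_loader(), which torch.concat's the
    # buffered lists into single tensors along the token dimension.
    loader = DataLoader(cache, batch_size=16, shuffle=False)
    for alphas, xs, zs in loader:
        print(alphas.shape, xs.shape, zs.shape)  # (16, 8), (16, 32), (16, 32)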

fusion_bench/models/expert_sparsity/mixtral/modeling_mixtral.py
@@ -0,0 +1,207 @@
+ import warnings
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+ from transformers.models.mixtral.modeling_mixtral import (
+     MixtralBlockSparseTop2MLP,
+     MixtralConfig,
+     MixtralRMSNorm,
+     MixtralSparseMoeBlock,
+ )
+
+
+ class DynamicSkippingMixtralSparseMoeBlock(nn.Module):
+     """
+     This implementation is
+     strictly equivalent to standard MoE with full capacity (no
+     dropped tokens). It's faster since it formulates MoE operations
+     in terms of block-sparse operations to accomodate imbalanced
+     assignments of tokens to experts, whereas standard MoE either
+     (1) drop tokens at the cost of reduced performance or (2) set
+     capacity factor to number of experts and thus waste computation
+     and memory on padding.
+     """
+
+     def __init__(self, config, beta):
+         super().__init__()
+         self.hidden_dim = config.hidden_size
+         self.ffn_dim = config.intermediate_size
+         self.num_experts = config.num_local_experts
+         self.top_k = config.num_experts_per_tok
+
+         # gating
+         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+         self.experts = nn.ModuleList(
+             [MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]
+         )
+
+         self.beta = beta
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         """ """
+         batch_size, sequence_length, hidden_dim = hidden_states.shape
+         hidden_states = hidden_states.view(-1, hidden_dim)
+         # router_logits: (batch * sequence_length, n_experts)
+         router_logits = self.gate(hidden_states)
+
+         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+         routing_weights, selected_experts = torch.topk(
+             routing_weights, self.top_k, dim=-1
+         )
+
+         onlytop1_mask = (
+             routing_weights[:, 1] < self.beta * routing_weights[:, 0]
+         )  # bz x seqlen
+
+         # routing_weights[tokens_top1, 1].fill_(0)
+         routing_weights[onlytop1_mask, 1] = 0
+         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+         # we cast back to the input dtype
+         routing_weights = routing_weights.to(hidden_states.dtype)
+
+         final_hidden_states = torch.zeros(
+             (batch_size * sequence_length, hidden_dim),
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+
+         # One hot encode the selected experts to create an expert mask
+         # this will be used to easily index which expert is going to be sollicitated
+         expert_mask = torch.nn.functional.one_hot(
+             selected_experts, num_classes=self.num_experts
+         )
+         # ipdb.set_trace()
+         # expert_mask[tokens_top1, 1, :].fill_(0)
+         expert_mask[onlytop1_mask, 1, :] = 0
+         expert_mask = expert_mask.permute(2, 1, 0)
+
+         # Loop over all available experts in the model and perform the computation on each expert
+         for expert_idx in range(self.num_experts):
+             expert_layer = self.experts[expert_idx]
+             idx, top_x = torch.where(expert_mask[expert_idx])
+
+             if top_x.shape[0] == 0:
+                 continue
+
+             # in torch it is faster to index using lists than torch tensors
+             top_x_list = top_x.tolist()
+             idx_list = idx.tolist()
+
+             # Index the correct hidden states and compute the expert hidden state for
+             # the current expert. We need to make sure to multiply the output hidden
+             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+             current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+             current_hidden_states = expert_layer(
+                 current_state, routing_weights[top_x_list, idx_list, None]
+             )
+
+             # However `index_add_` only support torch tensors for indexing so we'll use
+             # the `top_x` tensor here.
+             final_hidden_states.index_add_(
+                 0, top_x, current_hidden_states.to(hidden_states.dtype)
+             )
+         final_hidden_states = final_hidden_states.reshape(
+             batch_size, sequence_length, hidden_dim
+         )
+         return final_hidden_states, router_logits
+
+
+ class MixtralDecoderLayer(nn.Module):
+     def __init__(self, config: MixtralConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+
+         self.self_attn = ALL_ATTENTION_FUNCTIONS[config._attn_implementation](
+             config, layer_idx
+         )
+         if hasattr(config, "betas"):
+             assert (
+                 isinstance(config.betas, dict)
+                 and len(config.betas) == config.num_hidden_layers
+             )
+             self.block_sparse_moe = DynamicSkippingMixtralSparseMoeBlock(
+                 config, config.betas[str(layer_idx)]
+             )
+             warnings.warn(
+                 f"Using online drop: {layer_idx}, {config.betas[str(layer_idx)]}, {type(self.block_sparse_moe)}"
+             )
+         else:
+             self.block_sparse_moe = MixtralSparseMoeBlock(config)
+         self.input_layernorm = MixtralRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.post_attention_layernorm = MixtralRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         output_router_logits: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         **kwargs,
+     ) -> Tuple[
+         torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+     ]:
+         if "padding_mask" in kwargs:
+             warnings.warn(
+                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+             )
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                 `(batch, sequence_length)` where padding elements are indicated by 0.
+             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             output_router_logits (`bool`, *optional*):
+                 Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+                 should not be returned during inference.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+         """
+
+         residual = hidden_states
+
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         if use_cache:
+             outputs += (present_key_value,)
+
+         if output_router_logits:
+             outputs += (router_logits,)
+
+         return outputs
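
The difference from the stock `MixtralSparseMoeBlock` is the `beta` threshold in `forward`: after top-2 routing, a token whose second routing weight falls below `beta` times its top weight keeps only its top-1 expert; the second weight is zeroed and the remaining weights are renormalised. A small numeric sketch of just that skipping rule, with made-up routing weights:

    # Hedged sketch of the dynamic-skipping rule in isolation (not the full MoE
    # forward pass). The routing weights and the beta value are made-up numbers.
    import torch

    beta = 0.5
    # Top-2 routing weights per token, already sorted: column 0 is the best expert.
    routing_weights = torch.tensor(
        [
            [0.60, 0.40],  # 0.40 >= 0.5 * 0.60 -> keep both experts
            [0.80, 0.20],  # 0.20 <  0.5 * 0.80 -> skip the second expert
        ]
    )

    onlytop1_mask = routing_weights[:, 1] < beta * routing_weights[:, 0]
    routing_weights[onlytop1_mask, 1] = 0
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

    print(onlytop1_mask)    # tensor([False,  True])
    print(routing_weights)  # tensor([[0.6000, 0.4000], [1.0000, 0.0000]])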