fusion-bench 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. fusion_bench/method/ada_svd/clip_vision.py +4 -1
  2. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
  3. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
  4. fusion_bench/method/smile_upscaling/smile_upscaling.py +6 -336
  5. fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
  6. fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
  7. fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
  8. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
  9. fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
  10. fusion_bench/models/rankone_moe.py +2 -88
  11. fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
  12. fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
  13. fusion_bench/models/smile_moe/utils/__init__.py +24 -0
  14. fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
  15. fusion_bench/taskpool/__init__.py +2 -0
  16. fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
  17. fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
  18. {fusion_bench-0.2.13.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
  19. {fusion_bench-0.2.13.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +27 -14
  20. {fusion_bench-0.2.13.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
  21. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
  22. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
  23. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
  24. fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
  25. {fusion_bench-0.2.13.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
  26. {fusion_bench-0.2.13.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
  27. {fusion_bench-0.2.13.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0
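
The bulk of this release is a refactor of the SMILE upscaling code: the reusable MoE building blocks move out of the method modules into `fusion_bench/models/smile_moe/` (`linear.py` is renamed to `linear_from_module.py`, and a config-driven `linear_from_hf_config.py` is added), Qwen2 gains the same SMILE model wrappers that Mistral already had, and a new `LMEvalHarnessTaskPool` is introduced. For downstream code the visible effect is the import path; a minimal sketch, with class names as in the first hunk below:

    # The SMILE MoE classes are now imported from the models package rather
    # than defined inside the smile_upscaling method module:
    from fusion_bench.models.smile_moe.linear_from_module import (
        ExpertNotTrainedError,
        SmileCompressedLinear,
        SmileGate,
        SmileMoELinear,
    )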
fusion_bench/method/smile_upscaling/smile_upscaling.py
@@ -13,348 +13,18 @@ from fusion_bench.method import BaseAlgorithm
 from fusion_bench.method.simple_average import simple_average
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.models.smile_moe.linear_from_module import (
+    ExpertNotTrainedError,
+    SmileCompressedLinear,
+    SmileGate,
+    SmileMoELinear,
+)
 from fusion_bench.models.utils import get_attr, set_attr
 from fusion_bench.utils.parameters import print_parameters

 log = logging.getLogger(__name__)


-class ExpertNotTrainedError(Exception):
-    pass
-
-
-def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
-    """
-    Check if a tensor or a list of tensors are all zeros.
-
-    Args:
-        tensor (Tensor | List[Tensor]): A tensor or a list of tensors.
-
-    Returns:
-        bool: True if all elements are zeros, False otherwise.
-    """
-    if isinstance(tensor, Tensor):
-        return torch.allclose(tensor, torch.zeros_like(tensor))
-    else:
-        return all(_is_all_zeros(t) for t in tensor)
-
-
-def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
-    """
-    Perform Singular Value Decomposition (SVD) on a tensor.
-
-    Args:
-        w (Tensor): The input tensor.
-        full_matrices (bool): Whether to compute the full-sized U and V matrices.
-
-    Returns:
-        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
-    """
-    u, s, vh = torch.linalg.svd(
-        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
-    )
-    v = vh.T
-    return u, s, v
-
-
-def svd(
-    w: Tensor, full_matrices=True, accelerator=None
-) -> Tuple[Tensor, Tensor, Tensor]:
-    """
-    Perform SVD on a tensor, optionally using a specified accelerator.
-
-    Args:
-        w (Tensor): The input tensor.
-        full_matrices (bool): Whether to compute the full-sized U and V matrices.
-        accelerator (str): The device to perform the computation on.
-
-    Returns:
-        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
-    """
-    if accelerator is None:
-        return _svd(w, full_matrices=full_matrices)
-    original_device = w.device
-    w = w.to(accelerator)
-    u, s, v = _svd(w)
-    return u.to(original_device), s.to(original_device), v.to(original_device)
-
-
-class SmileGate(nn.Module):
-    def __init__(
-        self,
-        input_features: int,
-        w_diff_list: List[Tensor],
-        k: int,
-        svd_list=None,  # cached `svd_list`, pass it to avoid recomputing
-        upscaling_accelerator=None,
-    ):
-        """
-        Initialize the SmileGate module.
-
-        Args:
-            input_features (int): The number of input features.
-            w_diff_list (List[Tensor]): A list of weight difference tensors.
-            k (int): The number of singular values to keep.
-            svd_list (List[Tuple[Tensor, Tensor, Tensor]]): Cached SVD results.
-            upscaling_accelerator (str): The device to perform the computation on.
-        """
-        super().__init__()
-        self.input_features = input_features
-        self.num_experts = len(w_diff_list)
-        weights = []
-        for i, w_diff in enumerate(w_diff_list):
-            if svd_list is None:
-                u, s, v = svd(w_diff, accelerator=upscaling_accelerator)
-            else:
-                u, s, v = svd_list[i]
-            u = u[:, :k]
-            s = s[:k]
-            v = v[:, :k]
-
-            # weights.append((s * v).T)
-            weights.append(v.T)
-        self.k = s.size(0)  # k is the actual k after truncation
-
-        weights = (
-            torch.stack(weights, dim=0)
-            .reshape(self.num_experts * self.k, -1)
-            .contiguous()
-        )
-        self.weights = nn.Parameter(
-            weights
-        )  # weights should be a tensor of shape (num_experts * k, n)
-
-    def forward(self, x: Tensor):
-        """
-        Forward pass of the SmileGate module.
-
-        Args:
-            x (Tensor): The input tensor.
-
-        Returns:
-            Tensor: The routing weights.
-        """
-        batch_size = x.size(0)
-        if self.num_experts == 1:
-            return torch.ones(batch_size, 1, device=x.device, dtype=x.dtype)
-
-        routing_weights = F.linear(x, self.weights).view(
-            batch_size, self.num_experts, self.k
-        )
-        routing_weights = routing_weights.norm(p=2, dim=2)
-        return routing_weights
-
-
-class SmileCompressedLinear(nn.Module):
-    def __init__(self, model: nn.Linear, k: int, svd_cache=None):
-        """
-        Initialize the SmileCompressedLinear module.
-
-        Args:
-            model (nn.Linear): The linear model to compress.
-            k (int): The number of singular values to keep.
-            svd_cache (Tuple[Tensor, Tensor, Tensor]): Cached SVD results.
-        """
-        super().__init__()
-        if svd_cache is None:
-            u, s, v = svd(model.weight)
-        else:
-            u, s, v = svd_cache
-        if k > 0:
-            u = u[:, :k]
-            s = s[:k]
-            v = v[:, :k]
-
-        self.u = nn.Parameter(u)
-        self.svh = nn.Parameter((s * v).T)
-
-        if model.bias is not None:
-            self.bias = nn.Parameter(model.bias.data, requires_grad=True)
-        else:
-            self.register_parameter("bias", None)
-
-    def forward(self, x):
-        """
-        Forward pass of the SmileCompressedLinear module.
-
-        Args:
-            x (Tensor): The input tensor.
-
-        Returns:
-            Tensor: The output tensor.
-        """
-        x = F.linear(x, self.svh)
-        x = F.linear(x, self.u, self.bias)
-        return x
-
-
-class SmileMoELinear(nn.Module):
-    @torch.no_grad()
-    def __init__(
-        self,
-        pretrained_model: nn.Linear,
-        finetuned_models: List[nn.Linear],
-        gate_k: int,
-        k: int,
-        top_k: int = 1,
-        full_matrices=True,
-        upscaling_accelerator=None,
-        routing_use_diff=True,
-    ):
-        """
-        Initialize the SmileMoELinear module.
-
-        Args:
-            pretrained_model (nn.Linear): The pretrained linear model.
-            finetuned_models (List[nn.Linear]): A list of fine-tuned linear models.
-            gate_k (int): The number of singular values to keep for the gate.
-            k (int): The number of singular values to keep for the experts.
-            top_k (int): The number of top experts to select.
-            full_matrices (bool): Whether to compute the full-sized U and V matrices.
-            upscaling_accelerator (str): The device to perform the computation on.
-            routing_use_diff (bool): Whether to use weight differences for routing.
-        """
-        super().__init__()
-        self.num_experts = len(finetuned_models)
-        self.top_k = top_k
-        self.k = k
-        self.gate_k = gate_k
-        self.in_features = pretrained_model.in_features
-        self.out_features = pretrained_model.out_features
-
-        w_diff_list = [m.weight - pretrained_model.weight for m in finetuned_models]
-        if _is_all_zeros(w_diff_list):
-            # All fine-tuned models are identical to the pretrained model
-            raise ExpertNotTrainedError()
-
-        if routing_use_diff or k > 0:
-            svd_cache_list = [
-                svd(w, full_matrices=full_matrices, accelerator=upscaling_accelerator)
-                for w in w_diff_list
-            ]  # the svd cache list to avoid recomputing
-
-        # construct the gate network
-        if routing_use_diff:
-            self.gate = SmileGate(
-                input_features=self.in_features,
-                w_diff_list=w_diff_list,
-                k=gate_k,
-                svd_list=svd_cache_list,
-                upscaling_accelerator=upscaling_accelerator,
-            )
-        else:
-            self.gate = SmileGate(
-                input_features=self.in_features,
-                w_diff_list=[m.weight for m in finetuned_models],
-                k=gate_k,
-                svd_list=None,
-                upscaling_accelerator=upscaling_accelerator,
-            )
-
-        # construct experts
-        for m, w_diff in zip(finetuned_models, w_diff_list):
-            m.weight.data = w_diff
-        if k > 0:
-            experts = [
-                SmileCompressedLinear(m, k, svd_cache=svd_cache)
-                for m, svd_cache in zip(finetuned_models, svd_cache_list)
-            ]
-        else:
-            # if k is not set (<0), we use the full fine-tuned model
-            experts = finetuned_models
-        self.experts = nn.ModuleList(experts)
-
-        if pretrained_model.bias is not None:
-            for m in experts:
-                m.bias.data = m.bias.data - pretrained_model.bias
-        # assign the pretrained model (the shared part)
-        self.pretrained_model = pretrained_model
-
-    def forward(self, hidden_states: Tensor):
-        """
-        Forward pass of the SmileMoELinear module.
-
-        Args:
-            hidden_states (Tensor): The input tensor.
-
-        Returns:
-            Tensor: The output tensor.
-        """
-        pretrained_out = self.pretrained_model(hidden_states)
-
-        input_shape = hidden_states.size()
-        hidden_states = hidden_states.view(-1, self.in_features)
-
-        router_logits = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1)
-        # sample the expert according to the routing weights
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.top_k, dim=-1
-        )
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
-        final_hidden_states = torch.zeros(
-            (hidden_states.size(0), self.out_features),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(
-            selected_experts, num_classes=self.num_experts
-        ).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, self.in_features)
-            if current_state.numel() == 0:
-                continue
-            current_hidden_states = (
-                expert_layer(current_state) * routing_weights[top_x, idx, None]
-            )
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(
-                0, top_x, current_hidden_states.to(hidden_states.dtype)
-            )
-        final_hidden_states = final_hidden_states.reshape(
-            *input_shape[:-1], self.out_features
-        )
-        final_hidden_states = pretrained_out + final_hidden_states
-        return final_hidden_states
-
-    @property
-    def weight(self):
-        """
-        Mimic linear layer. Bacause in some cases, user might indicate the device (or dtype of parameters) of the linear layer using `linear_layer.weight.device`
-        """
-        return self.pretrained_model.weight
-
-    @property
-    def bias(self):
-        return self.pretrained_model.bias
-
-    def __repr__(self):
-        return (
-            f"SingularMoELinear("
-            f"in_features={self.pretrained_model.in_features}, "
-            f"out_features={self.pretrained_model.out_features}, "
-            f"num_experts={self.num_experts}, "
-            f"top_k={self.top_k}, "
-            f"gate_k={self.gate_k}, "
-            f"k={self.k}"
-            f")"
-        )
-
-
 class SmileUpscalingAlgorithm(
     SimpleProfilerMixin,
     BaseAlgorithm,
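
For reference, a minimal usage sketch of the relocated `SmileMoELinear`, assuming the constructor keeps the signature shown in the removed code above (the layer sizes and ranks are illustrative, not taken from the package):

    import torch
    import torch.nn as nn

    from fusion_bench.models.smile_moe.linear_from_module import SmileMoELinear

    # Illustrative stand-ins; in practice these are the pretrained layer and the
    # corresponding layers from task-specific fine-tuned checkpoints.
    pretrained = nn.Linear(768, 768)
    finetuned = [nn.Linear(768, 768) for _ in range(2)]

    moe_linear = SmileMoELinear(
        pretrained,
        finetuned,
        gate_k=16,  # singular directions kept per expert for routing
        k=32,       # rank of each compressed expert; k <= 0 keeps full experts
        top_k=1,    # number of experts activated per token
    )
    out = moe_linear(torch.randn(4, 768))

Note that if every fine-tuned weight equals the pretrained weight, the constructor raises `ExpertNotTrainedError`, since the routing gate is built from the SVD of the weight differences.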
fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py
@@ -27,6 +27,8 @@ from transformers.models.mistral.modeling_mistral import (
     MistralRotaryEmbedding,
 )

+from fusion_bench.models.smile_moe.linear_from_hf_config import SmileLinear
+
 from .configuration_smile_mistral import SmileMistralConfig

 logger = logging.getLogger(__name__)
@@ -80,209 +82,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-class SmileGate(nn.Module):
-    __constants__ = ["in_features", "num_experts", "k"]
-    in_features: int
-    num_experts: int
-    k: int
-    weight: Tensor
-
-    def __init__(
-        self,
-        in_features: int,
-        num_experts: int,
-        k: int,
-        device=None,
-        dtype=None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.input_features = in_features
-        self.num_experts = num_experts
-        self.k = k
-
-        self.weight = nn.Parameter(
-            torch.empty(num_experts * k, in_features, **factory_kwargs)
-        )
-
-    def forward(self, x: Tensor):
-        batch_size = x.size(0)
-        if self.num_experts == 1:
-            return torch.ones(batch_size, 1, device=x.device, dtype=x.dtype)
-
-        routing_weights = F.linear(x, self.weight).view(
-            batch_size, self.num_experts, self.k
-        )
-        routing_weights = routing_weights.norm(p=2, dim=2)
-        return routing_weights
-
-
-class SmileLinearExpert(nn.Module):
-    __constants__ = ["in_features", "out_features", "k"]
-    in_features: int
-    out_features: int
-    k: int
-
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        k: int,
-        bias: bool,
-        device=None,
-        dtype=None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.k = k
-
-        self.u = nn.Parameter(torch.empty(out_features, k, **factory_kwargs))
-        self.svh = nn.Parameter(torch.empty(k, in_features, **factory_kwargs))
-
-        if bias:
-            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
-        else:
-            self.register_parameter("bias", None)
-
-    def forward(self, x):
-        x = F.linear(x, self.svh)
-        x = F.linear(x, self.u, self.bias)
-        return x
-
-
-class SmileLinear(nn.Module):
-    @torch.no_grad()
-    def __init__(
-        self,
-        config: SmileMistralConfig,
-        in_features,
-        out_features,
-        bias: bool,
-        device=None,
-        dtype=None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.num_local_experts = config.num_local_experts
-        self.num_experts_per_tok = config.num_experts_per_tok
-        self.rank_of_expert = config.rank_of_expert
-        self.rank_of_router = config.rank_of_router
-        self.in_features = in_features
-        self.out_features = out_features
-
-        # construct the gate network
-        self.gate = SmileGate(
-            in_features=in_features,
-            num_experts=self.num_local_experts,
-            k=self.rank_of_router,
-            **factory_kwargs,
-        )
-
-        # the shared linear
-        self.shared_linear = nn.Linear(
-            in_features, out_features, bias=bias, **factory_kwargs
-        )
-
-        # construct experts
-        if self.rank_of_expert > 0:
-            self.experts = nn.ModuleList(
-                [
-                    SmileLinearExpert(
-                        in_features=in_features,
-                        out_features=out_features,
-                        bias=bias,
-                        k=self.rank_of_expert,
-                        **factory_kwargs,
-                    )
-                    for _ in range(self.num_local_experts)
-                ]
-            )
-        else:
-            self.experts = nn.ModuleList(
-                [
-                    nn.Linear(in_features, out_features, bias=bias, **factory_kwargs)
-                    for _ in range(self.num_local_experts)
-                ]
-            )
-
-    def forward(self, hidden_states: Tensor):
-        pretrained_out = self.shared_linear(hidden_states)
-
-        input_shape = hidden_states.size()
-        hidden_states = hidden_states.view(-1, self.in_features)
-
-        router_logits = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1)
-        # sample the expert according to the routing weights
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.num_experts_per_tok, dim=-1
-        )
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
-        final_hidden_states = torch.zeros(
-            (hidden_states.size(0), self.out_features),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(
-            selected_experts, num_classes=self.num_local_experts
-        ).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_local_experts):
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, self.in_features)
-            if current_state.numel() == 0:
-                continue
-            current_hidden_states = (
-                expert_layer(current_state) * routing_weights[top_x, idx, None]
-            )
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(
-                0, top_x, current_hidden_states.to(hidden_states.dtype)
-            )
-        final_hidden_states = final_hidden_states.reshape(
-            *input_shape[:-1], self.out_features
-        )
-        final_hidden_states = pretrained_out + final_hidden_states
-        return final_hidden_states
-
-    @property
-    def weight(self):
-        """
-        Mimic linear layer. Bacause in some cases, user might indicate the device (or dtype of parameters) of the linear layer using `linear_layer.weight.device`
-        """
-        return self.shared_linear.weight
-
-    @property
-    def bias(self):
-        return self.shared_linear.bias
-
-    def __repr__(self):
-        return (
-            f"SingularMoELinear("
-            f"in_features={self.shared_linear.in_features}, "
-            f"out_features={self.shared_linear.out_features}, "
-            f"num_local_experts={self.num_local_experts}, "
-            f"num_experts_per_tok={self.num_experts_per_tok}, "
-            f"rank_of_router={self.rank_of_router}, "
-            f"rank_of_expert={self.rank_of_expert}"
-            f")"
-        )
-
-
 class SmileMistralAttention(nn.Module):
     """
     Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
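
The classes removed above now live in `fusion_bench/models/smile_moe/linear_from_hf_config.py` (see the `SmileLinear` import added in the previous hunk), so the Mistral and Qwen2 model files can share one implementation. Assuming the relocated `SmileLinear` keeps the constructor shown above, a hypothetical construction from a config looks like this (field values illustrative):

    from fusion_bench.models.modeling_smile_mistral.configuration_smile_mistral import (
        SmileMistralConfig,
    )
    from fusion_bench.models.smile_moe.linear_from_hf_config import SmileLinear

    # SMILE-specific fields mirror those of SmileQwen2Config shown later in this
    # diff; that SmileMistralConfig accepts them identically is an assumption.
    config = SmileMistralConfig(
        num_local_experts=2,
        num_experts_per_tok=1,
        rank_of_router=8,
        rank_of_expert=32,
    )
    layer = SmileLinear(config, in_features=4096, out_features=4096, bias=False)

This keeps the routing math in one place; the per-architecture model files presumably only wire `SmileLinear` into their attention and MLP projections.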
fusion_bench/models/modeling_smile_qwen2/__init__.py
@@ -0,0 +1,8 @@
+from . import register
+from .configuration_smile_qwen2 import SmileQwen2Config
+from .modeling_smile_qwen2 import (
+    SmileQwen2ForCausalLM,
+    SmileQwen2ForQuestionAnswering,
+    SmileQwen2ForSequenceClassification,
+    SmileQwen2Model,
+)
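
`from . import register` indicates that importing the package has the side effect of registering the SMILE Qwen2 classes with the `transformers` auto classes; `register.py` itself (+11 lines) is not shown in this diff. A plausible sketch of such a module, using the public registration hooks `transformers` provides:

    from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

    from .configuration_smile_qwen2 import SmileQwen2Config
    from .modeling_smile_qwen2 import SmileQwen2ForCausalLM, SmileQwen2Model

    # Make model_type "smile_qwen2" resolvable through the Auto* factories,
    # e.g. AutoModelForCausalLM.from_pretrained(path_to_upscaled_checkpoint).
    AutoConfig.register("smile_qwen2", SmileQwen2Config)
    AutoModel.register(SmileQwen2Config, SmileQwen2Model)
    AutoModelForCausalLM.register(SmileQwen2Config, SmileQwen2ForCausalLM)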
fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py
@@ -0,0 +1,21 @@
+from transformers import PretrainedConfig
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+
+
+class SmileQwen2Config(Qwen2Config):
+    model_type = "smile_qwen2"
+
+    def __init__(
+        self,
+        num_experts_per_tok: int = 1,
+        rank_of_router: int = None,
+        rank_of_expert: int = None,
+        num_local_experts: int = None,
+        **kwargs,
+    ):
+        self.num_experts_per_tok = num_experts_per_tok
+        self.rank_of_router = rank_of_router
+        self.rank_of_expert = rank_of_expert
+        self.num_local_experts = num_local_experts
+
+        super().__init__(**kwargs)
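
Usage mirrors `Qwen2Config`: the four SMILE fields are consumed here and everything else flows through `**kwargs` to the parent. A quick illustrative example (values arbitrary):

    from fusion_bench.models.modeling_smile_qwen2 import SmileQwen2Config

    config = SmileQwen2Config(
        num_local_experts=2,    # number of upscaled experts
        num_experts_per_tok=1,  # top-k routing
        rank_of_router=8,       # rank of the SVD-based gate
        rank_of_expert=32,      # rank of each compressed expert
        hidden_size=1536,       # ordinary Qwen2Config field, via **kwargs
    )
    assert config.model_type == "smile_qwen2"

Note that `rank_of_router`, `rank_of_expert`, and `num_local_experts` default to `None` despite the `int` annotation, so a bare `SmileQwen2Config()` is only meaningful once these are filled in, e.g. by the upscaling method in `smile_qwen2_upscaling.py`.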