sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. The information is provided for informational purposes only.
Files changed (68)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +7 -7
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +158 -11
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/bench_latency.py +299 -0
  8. sglang/global_config.py +12 -2
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +114 -67
  11. sglang/lang/ir.py +28 -3
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +13 -6
  15. sglang/srt/constrained/fsm_cache.py +8 -2
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +3 -1
  19. sglang/srt/hf_transformers_utils.py +130 -1
  20. sglang/srt/layers/extend_attention.py +17 -0
  21. sglang/srt/layers/fused_moe.py +582 -0
  22. sglang/srt/layers/logits_processor.py +65 -32
  23. sglang/srt/layers/radix_attention.py +41 -7
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/controller/dp_worker.py +113 -0
  26. sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
  27. sglang/srt/managers/controller/manager_multi.py +191 -0
  28. sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
  29. sglang/srt/managers/{router → controller}/model_runner.py +262 -158
  30. sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
  31. sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
  32. sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
  33. sglang/srt/managers/detokenizer_manager.py +42 -46
  34. sglang/srt/managers/io_struct.py +22 -12
  35. sglang/srt/managers/tokenizer_manager.py +151 -87
  36. sglang/srt/model_config.py +83 -5
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +10 -13
  39. sglang/srt/models/dbrx.py +9 -15
  40. sglang/srt/models/gemma.py +12 -15
  41. sglang/srt/models/grok.py +738 -0
  42. sglang/srt/models/llama2.py +26 -15
  43. sglang/srt/models/llama_classification.py +104 -0
  44. sglang/srt/models/llava.py +86 -19
  45. sglang/srt/models/llavavid.py +11 -20
  46. sglang/srt/models/mixtral.py +282 -103
  47. sglang/srt/models/mixtral_quant.py +372 -0
  48. sglang/srt/models/qwen.py +9 -13
  49. sglang/srt/models/qwen2.py +11 -13
  50. sglang/srt/models/stablelm.py +9 -15
  51. sglang/srt/models/yivl.py +17 -22
  52. sglang/srt/openai_api_adapter.py +150 -95
  53. sglang/srt/openai_protocol.py +11 -2
  54. sglang/srt/server.py +124 -48
  55. sglang/srt/server_args.py +128 -48
  56. sglang/srt/utils.py +234 -67
  57. sglang/test/test_programs.py +65 -3
  58. sglang/test/test_utils.py +32 -1
  59. sglang/utils.py +23 -4
  60. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
  61. sglang-0.1.18.dist-info/RECORD +78 -0
  62. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -417
  66. sglang-0.1.16.dist-info/RECORD +0 -72
  67. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
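
The largest change below is in sglang/srt/models/mixtral.py (item 46): the per-expert MixtralMLP modules and the Python routing loop are replaced by stacked expert weights (w13_weight, w2_weight) passed to vLLM's fused_moe kernel. As orientation for that diff, here is a minimal, unfused reference sketch of the routing math the fused kernel implements (softmax over router logits, top-k selection, renormalization, SwiGLU experts, weighted sum). The function name ref_moe_forward and the dense per-expert loop are illustrative only, not code from either package.

```python
import torch
import torch.nn.functional as F


def ref_moe_forward(hidden_states, w13, w2, gate_weight, top_k):
    """Unfused reference of top-k MoE routing (illustrative only).

    hidden_states: (num_tokens, hidden_size)
    w13:           (num_experts, 2 * intermediate_size, hidden_size)
    w2:            (num_experts, hidden_size, intermediate_size)
    gate_weight:   (num_experts, hidden_size)
    """
    router_logits = hidden_states @ gate_weight.t()              # (tokens, experts)
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, selected = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # renormalize=True

    out = torch.zeros_like(hidden_states)
    inter = w13.shape[1] // 2
    for e in range(w13.shape[0]):
        # Tokens (and their top-k slot) routed to expert e.
        token_idx, k_idx = (selected == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        x = hidden_states[token_idx]
        gate_up = x @ w13[e].t()                                  # (n, 2 * inter)
        act = F.silu(gate_up[:, :inter]) * gate_up[:, inter:]     # SwiGLU
        y = act @ w2[e].t()                                       # (n, hidden)
        weight = routing_weights[token_idx, k_idx].unsqueeze(1).to(y.dtype)
        out[token_idx] += weight * y
    return out
```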
sglang/srt/models/mixtral.py (changed, +282 -103)

@@ -1,18 +1,21 @@
 # Adapted from
-# https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
-from typing import Optional
+from typing import Iterable, Optional, Tuple

 import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
+from vllm import _custom_ops as ops
+from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
@@ -20,118 +23,247 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import print_warning_once

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+from sglang.srt.managers.controller.model_runner import InputMetadata


-class MixtralMLP(nn.Module):
+class MixtralMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Mixtral that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
     def __init__(
         self,
         num_experts: int,
+        top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
+    ):
         super().__init__()
-        self.num_experts = num_experts
-        self.ffn_dim = intermediate_size
-        self.hidden_dim = hidden_size
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
+        self.num_total_experts = num_experts
+        self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size // self.tp_size
+        self.quant_config = quant_config
+
+        # FIXME(pcmoritz): Make this more general to support different
+        # quantization schemes
+        self.use_fp8 = isinstance(quant_config, Fp8Config)

-        self.w1 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        # Gate always runs at half / full precision for now.
+        self.gate = ReplicatedLinear(
+            self.hidden_size,
+            self.num_total_experts,
+            bias=False,
+            params_dtype=self.params_dtype,
+            quant_config=None,
         )
-        self.w2 = ReplicatedLinear(
-            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
+
+        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        self.w13_weight = nn.Parameter(
+            torch.empty(
+                self.num_total_experts,
+                2 * self.intermediate_size,
+                self.hidden_size,
+                dtype=params_dtype,
+            )
         )
-        self.w3 = ReplicatedLinear(
-            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
+        self.w2_weight = nn.Parameter(
+            torch.empty(
+                self.num_total_experts,
+                self.hidden_size,
+                self.intermediate_size,
+                dtype=params_dtype,
+            )
         )

-        # TODO: Use vllm's SiluAndMul
-        self.act_fn = nn.SiLU()
+        set_weight_attrs(
+            self.w13_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+        set_weight_attrs(
+            self.w2_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        w1_out, _ = self.w1(hidden_states)
-        w1_out = self.act_fn(w1_out)
-        w3_out, _ = self.w3(hidden_states)
-        current_hidden_states = w1_out * w3_out
-        current_hidden_states, _ = self.w2(current_hidden_states)
-        return current_hidden_states
+        # Used for fp8.
+        self.w13_scale = None
+        self.w2_scale = None
+        self.a13_scale = None
+        self.a2_scale = None
+
+        if self.use_fp8:
+            # WEIGHT_SCALE (for fp8)
+            self.w13_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )
+            self.w2_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )

+            # If loading fp8 checkpoint, pass the weight loaders.
+            # If loading an fp16 checkpoint, do not (we will quantize in
+            # process_weights_after_loading()
+            if quant_config.is_checkpoint_fp8_serialized:
+                set_weight_attrs(
+                    self.w13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.w2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )

-class MixtralMoE(nn.Module):
-    def __init__(
+            # ACT_SCALE (for fp8)
+            if quant_config.activation_scheme == "static":
+                if not quant_config.is_checkpoint_fp8_serialized:
+                    raise ValueError(
+                        "Found static activation scheme for checkpoint that "
+                        "was not serialized fp8."
+                    )
+                self.a13_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                self.a2_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+
+                set_weight_attrs(
+                    self.a13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.a2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+
+    def weight_loader(
         self,
-        config: MixtralConfig,
-        quant_config: Optional[QuantizationConfig] = None,
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        expert_id: int,
     ):
-        super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.num_total_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.num_total_experts}."
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1.weight"):
+            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w3.weight"):
+            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
+                shard, :
+            ]
+        if weight_name.endswith("w2.weight"):
+            param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if "act_scale" in weight_name or "weight_scale" in weight_name:
+            param_data[expert_id] = loaded_weight
+
+    def process_weights_after_loading(self):
+        # Fp8 is the only case where we need to process after loading.
+        if not self.use_fp8:
+            return
+
+        # If checkpoint is fp16, quantize here.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(
+                self.w13_weight.data, dtype=torch.float8_e4m3fn
             )
-        # Split experts equally between ranks
-        self.expert_indicies = np.array_split(
-            range(self.num_total_experts), self.tp_size
-        )[self.rank].tolist()
-        if not self.expert_indicies:
-            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
-
-        self.experts = nn.ModuleList(
-            [
-                (
-                    MixtralMLP(
-                        self.num_total_experts,
-                        config.hidden_size,
-                        config.intermediate_size,
-                        quant_config=quant_config,
-                    )
-                    if idx in self.expert_indicies
-                    else None
+            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
+            for expert in range(self.num_total_experts):
+                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
+                    self.w13_weight.data[expert, :, :]
+                )
+                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
+                    self.w2_weight.data[expert, :, :]
+                )
+            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
+            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
+
+        # If checkpoint is fp8 + static, cleanup act_scales.
+        # Since state_dict has an act_scale per expert but our kernels
+        # are passed one act_scale shared across all experts.
+        elif self.quant_config.activation_scheme == "static":
+            if self.a13_scale is None or self.a2_scale is None:
+                raise ValueError(
+                    "QuantConfig has static quantization, but found "
+                    "activation scales are None."
+                )
+
+            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
+                print_warning_once(
+                    "Found act_scales that are not equal for fp8 MoE layer. "
+                    "Using the maximum across experts for each layer. "
                 )
-                for idx in range(self.num_total_experts)
-            ]
-        )
-        self.gate = ReplicatedLinear(
-            config.hidden_size, self.num_total_experts, bias=False, linear_method=None
-        )
+
+            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
+            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.top_k, dim=-1
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+            use_fp8=self.use_fp8,
+            w1_scale=self.w13_scale,
+            w2_scale=self.w2_scale,
+            a1_scale=self.a13_scale,
+            a2_scale=self.a2_scale,
         )
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

-        final_hidden_states = None
-        for expert_idx in self.expert_indicies:
-            expert_layer = self.experts[expert_idx]
-            expert_mask = selected_experts == expert_idx
-            expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

-            current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
-            if final_hidden_states is None:
-                final_hidden_states = current_hidden_states
-            else:
-                final_hidden_states.add_(current_hidden_states)
-
-        return tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states.view(num_tokens, hidden_size)


 class MixtralAttention(nn.Module):
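
The weight_loader added above shards each expert along the intermediate dimension: every rank keeps rows [tp_rank * shard_size, (tp_rank + 1) * shard_size) of w1 and w3 (stacked into w13_weight) and the matching columns of w2. Below is a small worked example of those shapes, assuming Mixtral-8x7B-like sizes (hidden_size=4096, intermediate_size=14336, 8 experts) and tp_size=2; the concrete numbers are illustrative and not taken from the diff.

```python
import torch

num_experts, hidden, intermediate, tp_size = 8, 4096, 14336, 2
shard_size = intermediate // tp_size          # 7168 rows/columns per rank
tp_rank = 0
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

# Stacked per-expert parameters held by each rank:
w13_weight = torch.empty(num_experts, 2 * shard_size, hidden)   # gate (w1) + up (w3)
w2_weight = torch.empty(num_experts, hidden, shard_size)        # down projection

# Full, unsharded checkpoint tensors for one expert:
w1 = torch.randn(intermediate, hidden)   # experts.{e}.w1.weight
w3 = torch.randn(intermediate, hidden)   # experts.{e}.w3.weight
w2 = torch.randn(hidden, intermediate)   # experts.{e}.w2.weight

e = 0
w13_weight[e, 0:shard_size, :] = w1[shard, :]                   # row-shard w1
w13_weight[e, shard_size : 2 * shard_size, :] = w3[shard, :]    # row-shard w3
w2_weight[e, :, :] = w2[:, shard]                               # column-shard w2
```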
@@ -233,7 +365,13 @@ class MixtralDecoderLayer(nn.Module):
             sliding_window=config.sliding_window,
             quant_config=quant_config,
         )
-        self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
+        self.block_sparse_moe = MixtralMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
+        )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -313,6 +451,7 @@ class MixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -333,13 +472,7 @@ class MixtralForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -347,16 +480,47 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]

+        expert_params_mapping = (
+            [
+                # These are the weight scales for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+                    f"experts.{expert_id}.{weight_name}.weight_scale",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+            + [
+                # These are the weights for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
+                    f"experts.{expert_id}.{weight_name}.weight",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+            + [
+                # These are the activation scales for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
+                    f"experts.{expert_id}.{weight_name}.act_scale",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+        )
+
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path,
-            cache_dir,
-            load_format,
-            revision,
-            fall_back_to_pt=False,
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -369,15 +533,30 @@ class MixtralForCausalLM(nn.Module):
                     weight_loader(param, loaded_weight, shard_id)
                     break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if "block_sparse_moe.experts." in name and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param, loaded_weight, weight_name, expert_id=expert_id
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))


 EntryClass = MixtralForCausalLM
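
The load_weights change above replaces the old path-based loader, which pulled weights through the removed sglang.srt.weight_utils helpers, with vLLM's convention of passing an iterator of (name, tensor) pairs. Below is a minimal sketch of how a caller could feed that signature from safetensors shards; iterate_safetensors is a hypothetical helper for illustration, and the real loading path in sglang/vLLM goes through their model-loader utilities instead.

```python
# Hypothetical illustration of the new load_weights() calling convention:
# instead of a model path, the caller streams (name, tensor) pairs.
import glob
from typing import Iterable, Tuple

import torch
from safetensors import safe_open


def iterate_safetensors(checkpoint_dir: str) -> Iterable[Tuple[str, torch.Tensor]]:
    """Yield (parameter_name, tensor) pairs from *.safetensors shards."""
    for shard in sorted(glob.glob(f"{checkpoint_dir}/*.safetensors")):
        with safe_open(shard, framework="pt", device="cpu") as f:
            for name in f.keys():
                yield name, f.get_tensor(name)


# Usage sketch (checkpoint path is a placeholder):
# model = MixtralForCausalLM(config, quant_config)
# model.load_weights(iterate_safetensors("/path/to/mixtral-checkpoint"))
```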