sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/__init__.py +5 -1
  2. sglang/api.py +8 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +11 -1
  8. sglang/lang/chat_template.py +9 -2
  9. sglang/lang/interpreter.py +161 -81
  10. sglang/lang/ir.py +29 -11
  11. sglang/lang/tracer.py +1 -1
  12. sglang/launch_server.py +1 -2
  13. sglang/launch_server_llavavid.py +31 -0
  14. sglang/srt/constrained/fsm_cache.py +3 -0
  15. sglang/srt/flush_cache.py +16 -0
  16. sglang/srt/hf_transformers_utils.py +83 -2
  17. sglang/srt/layers/extend_attention.py +17 -0
  18. sglang/srt/layers/fused_moe.py +485 -0
  19. sglang/srt/layers/logits_processor.py +12 -7
  20. sglang/srt/layers/radix_attention.py +10 -3
  21. sglang/srt/layers/token_attention.py +16 -1
  22. sglang/srt/managers/controller/dp_worker.py +110 -0
  23. sglang/srt/managers/controller/infer_batch.py +619 -0
  24. sglang/srt/managers/controller/manager_multi.py +191 -0
  25. sglang/srt/managers/controller/manager_single.py +97 -0
  26. sglang/srt/managers/controller/model_runner.py +462 -0
  27. sglang/srt/managers/controller/radix_cache.py +267 -0
  28. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  29. sglang/srt/managers/controller/tp_worker.py +791 -0
  30. sglang/srt/managers/detokenizer_manager.py +45 -45
  31. sglang/srt/managers/io_struct.py +26 -10
  32. sglang/srt/managers/router/infer_batch.py +130 -74
  33. sglang/srt/managers/router/manager.py +7 -9
  34. sglang/srt/managers/router/model_rpc.py +224 -135
  35. sglang/srt/managers/router/model_runner.py +94 -107
  36. sglang/srt/managers/router/radix_cache.py +54 -18
  37. sglang/srt/managers/router/scheduler.py +23 -34
  38. sglang/srt/managers/tokenizer_manager.py +183 -88
  39. sglang/srt/model_config.py +5 -2
  40. sglang/srt/models/commandr.py +15 -22
  41. sglang/srt/models/dbrx.py +22 -29
  42. sglang/srt/models/gemma.py +14 -24
  43. sglang/srt/models/grok.py +671 -0
  44. sglang/srt/models/llama2.py +24 -23
  45. sglang/srt/models/llava.py +85 -25
  46. sglang/srt/models/llavavid.py +298 -0
  47. sglang/srt/models/mixtral.py +254 -130
  48. sglang/srt/models/mixtral_quant.py +373 -0
  49. sglang/srt/models/qwen.py +28 -25
  50. sglang/srt/models/qwen2.py +17 -22
  51. sglang/srt/models/stablelm.py +21 -26
  52. sglang/srt/models/yivl.py +17 -25
  53. sglang/srt/openai_api_adapter.py +140 -95
  54. sglang/srt/openai_protocol.py +10 -1
  55. sglang/srt/server.py +101 -52
  56. sglang/srt/server_args.py +59 -11
  57. sglang/srt/utils.py +242 -75
  58. sglang/test/test_programs.py +44 -0
  59. sglang/test/test_utils.py +32 -1
  60. sglang/utils.py +95 -26
  61. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
  62. sglang-0.1.17.dist-info/RECORD +81 -0
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -402
  66. sglang-0.1.15.dist-info/RECORD +0 -69
  67. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  69. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
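
The biggest single change in this diff is the rewrite of sglang/srt/models/mixtral.py (shown in full below): the per-expert `MixtralMLP` modules are replaced by a `MixtralMoE` layer that keeps all experts in two stacked tensors (`w13_weight`, `w2_weight`), shards the intermediate dimension across tensor-parallel ranks, and dispatches tokens through vLLM's `fused_moe` kernel, with optional fp8 weight and activation scales. The snippet below is only an illustrative sketch of that sharding arithmetic, assuming plain dense `w1`/`w2`/`w3` tensors; `shard_expert_weights` is a hypothetical helper, not part of sglang or vLLM.

```python
import torch

def shard_expert_weights(w1, w2, w3, tp_rank, tp_size):
    # Illustrative sketch only: reproduce the slices that the new
    # MixtralMoE.weight_loader copies into w13_weight / w2_weight.
    # w1, w3: (intermediate_size, hidden_size); w2: (hidden_size, intermediate_size)
    intermediate_size = w1.shape[0]
    shard_size = intermediate_size // tp_size
    shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

    # w1 and w3 are stacked along the intermediate dimension into one
    # "w13" tensor; each rank keeps only its slice of that dimension.
    w13_shard = torch.cat([w1[shard, :], w3[shard, :]], dim=0)
    # w2 is sliced along its input (intermediate) dimension instead.
    w2_shard = w2[:, shard]
    return w13_shard, w2_shard

# Toy example: 2-way tensor parallelism for a single expert.
w1 = torch.randn(8, 4)
w2 = torch.randn(4, 8)
w3 = torch.randn(8, 4)
w13_rank0, w2_rank0 = shard_expert_weights(w1, w2, w3, tp_rank=0, tp_size=2)
assert w13_rank0.shape == (8, 4) and w2_rank0.shape == (4, 4)
```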
sglang/srt/models/mixtral.py
@@ -1,143 +1,234 @@
 # Adapted from
-# https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
-from typing import Optional
+from typing import Iterable, Optional, Tuple
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
+from vllm import _custom_ops as ops
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    tensor_model_parallel_all_reduce,
-)
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import print_warning_once
+
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
-
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
-class MixtralMLP(nn.Module):
-    def __init__(
-        self,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size: int,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-        self.num_experts = num_experts
-        self.ffn_dim = intermediate_size
-        self.hidden_dim = hidden_size
 
-        self.w1 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
-        )
-        self.w2 = ReplicatedLinear(
-            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
-        )
-        self.w3 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
-        )
-
-        # TODO: Use vllm's SiluAndMul
-        self.act_fn = nn.SiLU()
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        w1_out, _ = self.w1(hidden_states)
-        w1_out = self.act_fn(w1_out)
-        w3_out, _ = self.w3(hidden_states)
-        current_hidden_states = w1_out * w3_out
-        current_hidden_states, _ = self.w2(current_hidden_states)
-        return current_hidden_states
+class MixtralMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Mixtral that shards each expert
+    across all ranks.
 
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
 
-class MixtralMoE(nn.Module):
     def __init__(
         self,
-        config: MixtralConfig,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.num_total_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.num_total_experts}."
-            )
-        # Split experts equally between ranks
-        self.expert_indicies = np.array_split(
-            range(self.num_total_experts), self.tp_size
-        )[self.rank].tolist()
-        if not self.expert_indicies:
-            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
-
-        self.experts = nn.ModuleList(
-            [
-                (
-                    MixtralMLP(
-                        self.num_total_experts,
-                        config.hidden_size,
-                        config.intermediate_size,
-                        quant_config=quant_config,
-                    )
-                    if idx in self.expert_indicies
-                    else None
-                )
-                for idx in range(self.num_total_experts)
-            ]
-        )
-        self.gate = ReplicatedLinear(
-            config.hidden_size, self.num_total_experts, bias=False, linear_method=None
-        )
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
+        self.num_total_experts = num_experts
+        self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size // self.tp_size
+        self.quant_config = quant_config
+
+        # FIXME(pcmoritz): Make this more general to support different
+        # quantization schemes
+        self.use_fp8 = isinstance(quant_config, Fp8Config)
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        # Gate always runs at half / full precision for now.
+        self.gate = ReplicatedLinear(self.hidden_size,
+                                     self.num_total_experts,
+                                     bias=False,
+                                     params_dtype=self.params_dtype,
+                                     quant_config=None)
+
+        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        self.w13_weight = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        2 * self.intermediate_size,
+                        self.hidden_size,
+                        dtype=params_dtype))
+        self.w2_weight = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        self.hidden_size,
+                        self.intermediate_size,
+                        dtype=params_dtype))
+
+        set_weight_attrs(self.w13_weight, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2_weight, {
+            "weight_loader": self.weight_loader,
+        })
+
+        # Used for fp8.
+        self.w13_scale = None
+        self.w2_scale = None
+        self.a13_scale = None
+        self.a2_scale = None
+
+        if self.use_fp8:
+            # WEIGHT_SCALE (for fp8)
+            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                     dtype=torch.float32),
+                                          requires_grad=False)
+            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                    dtype=torch.float32),
+                                         requires_grad=False)
+
+            # If loading fp8 checkpoint, pass the weight loaders.
+            # If loading an fp16 checkpoint, do not (we will quantize in
+            # process_weights_after_loading()
+            if quant_config.is_checkpoint_fp8_serialized:
+                set_weight_attrs(self.w13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.w2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+
+            # ACT_SCALE (for fp8)
+            if quant_config.activation_scheme == "static":
+                if not quant_config.is_checkpoint_fp8_serialized:
+                    raise ValueError(
+                        "Found static activation scheme for checkpoint that "
+                        "was not serialized fp8.")
+                self.a13_scale = nn.Parameter(torch.zeros(
+                    self.num_total_experts, dtype=torch.float32),
+                                              requires_grad=False)
+                self.a2_scale = nn.Parameter(torch.zeros(
+                    self.num_total_experts, dtype=torch.float32),
+                                             requires_grad=False)
+
+                set_weight_attrs(self.a13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.a2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
+                      weight_name: str, expert_id: int):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1.weight"):
+            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w3.weight"):
+            param_data[expert_id,
+                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w2.weight"):
+            param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if "act_scale" in weight_name or "weight_scale" in weight_name:
+            param_data[expert_id] = loaded_weight
+
+    def process_weights_after_loading(self):
+        # Fp8 is the only case where we need to process after loading.
+        if not self.use_fp8:
+            return
+
+        # If checkpoint is fp16, quantize here.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(self.w13_weight.data,
+                                          dtype=torch.float8_e4m3fn)
+            w2_weight = torch.empty_like(self.w2_weight.data,
+                                         dtype=torch.float8_e4m3fn)
+            for expert in range(self.num_total_experts):
+                w13_weight[expert, :, :], self.w13_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        self.w13_weight.data[expert, :, :])
+                w2_weight[expert, :, :], self.w2_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        self.w2_weight.data[expert, :, :])
+            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
+            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
+
+        # If checkpoint is fp8 + static, cleanup act_scales.
+        # Since state_dict has an act_scale per expert but our kernels
+        # are passed one act_scale shared across all experts.
+        elif self.quant_config.activation_scheme == "static":
+            if self.a13_scale is None or self.a2_scale is None:
+                raise ValueError(
+                    "QuantConfig has static quantization, but found "
+                    "activation scales are None.")
+
+            if (not all_close_1d(self.a13_scale)
+                    or not all_close_1d(self.a2_scale)):
+                print_warning_once(
+                    "Found act_scales that are not equal for fp8 MoE layer. "
+                    "Using the maximum across experts for each layer. ")
+
+            self.a13_scale = nn.Parameter(self.a13_scale.max(),
+                                          requires_grad=False)
+            self.a2_scale = nn.Parameter(self.a2_scale.max(),
+                                         requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.top_k, dim=-1
-        )
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
-        final_hidden_states = None
-        for expert_idx in self.expert_indicies:
-            expert_layer = self.experts[expert_idx]
-            expert_mask = selected_experts == expert_idx
-            expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
-
-            current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
-            if final_hidden_states is None:
-                final_hidden_states = current_hidden_states
-            else:
-                final_hidden_states.add_(current_hidden_states)
-
-        return tensor_model_parallel_all_reduce(final_hidden_states)
+        final_hidden_states = fused_moe(hidden_states,
+                                        self.w13_weight,
+                                        self.w2_weight,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=True,
+                                        inplace=True,
+                                        use_fp8=self.use_fp8,
+                                        w1_scale=self.w13_scale,
+                                        w2_scale=self.w2_scale,
+                                        a1_scale=self.a13_scale,
+                                        a2_scale=self.a2_scale)
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
+
+        return final_hidden_states.view(num_tokens, hidden_size)
 
 
 class MixtralAttention(nn.Module):
@@ -239,7 +330,12 @@ class MixtralDecoderLayer(nn.Module):
             sliding_window=config.sliding_window,
             quant_config=quant_config,
         )
-        self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
+        self.block_sparse_moe = MixtralMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -319,6 +415,7 @@ class MixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -339,13 +436,7 @@ class MixtralForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -353,17 +444,35 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
 
+        expert_params_mapping = [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ] + [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id)
+            ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ] + [
+            # These are the activation scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
+             f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ]
+
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path,
-            cache_dir,
-            load_format,
-            revision,
-            fall_back_to_pt=False,
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -375,15 +484,30 @@ class MixtralForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if "block_sparse_moe.experts." in name and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
 
 
 EntryClass = MixtralForCausalLM
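
The checkpoint-loading interface changes as well: `load_weights` no longer reads files itself via `hf_model_weights_iterator`; it now consumes an iterator of `(name, tensor)` pairs and uses `expert_params_mapping` to remap per-expert checkpoint entries (`experts.{i}.w1|w2|w3.*`) onto the stacked `w13_weight`/`w2_weight` parameters and their fp8 scales. A minimal sketch of that name remapping follows; `remap_names` and the two-expert checkpoint fragment are hypothetical and only mirror the mapping built in the diff, not the actual loader.

```python
from typing import Iterable, Iterator, Tuple
import torch

def expert_params_mapping(num_experts: int):
    # Mirrors the plain-weight portion of expert_params_mapping in the diff:
    # (stacked param name, checkpoint name fragment, expert id). The fp8
    # weight_scale / act_scale entries follow the same pattern.
    return [
        ("w13_weight" if w in ("w1", "w3") else "w2_weight",
         f"experts.{expert_id}.{w}.weight", expert_id)
        for expert_id in range(num_experts)
        for w in ("w1", "w2", "w3")
    ]

def remap_names(weights: Iterable[Tuple[str, torch.Tensor]],
                num_experts: int) -> Iterator[Tuple[str, int, torch.Tensor]]:
    # Yield (stacked_param_name, expert_id, tensor) for expert weights,
    # the same way load_weights matches names before calling weight_loader.
    mapping = expert_params_mapping(num_experts)
    for name, tensor in weights:
        for param_name, weight_name, expert_id in mapping:
            if weight_name in name:
                yield name.replace(weight_name, param_name), expert_id, tensor
                break

# Hypothetical two-expert checkpoint fragment.
ckpt = [(f"model.layers.0.block_sparse_moe.experts.{i}.{w}.weight",
         torch.zeros(1))
        for i in range(2) for w in ("w1", "w2", "w3")]
for new_name, expert_id, _ in remap_names(ckpt, num_experts=2):
    print(expert_id, new_name)
```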