sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective registries. It is provided for informational purposes only.
Files changed (85)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +9 -6
  3. sglang/bench_serving.py +46 -22
  4. sglang/global_config.py +1 -1
  5. sglang/lang/backend/runtime_endpoint.py +60 -49
  6. sglang/lang/compiler.py +2 -2
  7. sglang/lang/interpreter.py +4 -2
  8. sglang/lang/ir.py +16 -7
  9. sglang/srt/constrained/base_tool_cache.py +1 -1
  10. sglang/srt/constrained/fsm_cache.py +12 -2
  11. sglang/srt/constrained/jump_forward.py +13 -2
  12. sglang/srt/layers/activation.py +32 -0
  13. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  14. sglang/srt/layers/extend_attention.py +9 -2
  15. sglang/srt/layers/fused_moe/__init__.py +1 -0
  16. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  17. sglang/srt/layers/fused_moe/layer.py +587 -0
  18. sglang/srt/layers/layernorm.py +65 -0
  19. sglang/srt/layers/logits_processor.py +7 -2
  20. sglang/srt/layers/pooler.py +50 -0
  21. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  22. sglang/srt/layers/radix_attention.py +40 -16
  23. sglang/srt/managers/detokenizer_manager.py +31 -9
  24. sglang/srt/managers/io_struct.py +63 -0
  25. sglang/srt/managers/policy_scheduler.py +173 -25
  26. sglang/srt/managers/schedule_batch.py +115 -97
  27. sglang/srt/managers/tokenizer_manager.py +194 -112
  28. sglang/srt/managers/tp_worker.py +290 -359
  29. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  30. sglang/srt/mem_cache/chunk_cache.py +43 -20
  31. sglang/srt/mem_cache/memory_pool.py +2 -2
  32. sglang/srt/mem_cache/radix_cache.py +74 -40
  33. sglang/srt/model_executor/cuda_graph_runner.py +71 -25
  34. sglang/srt/model_executor/forward_batch_info.py +293 -156
  35. sglang/srt/model_executor/model_runner.py +77 -57
  36. sglang/srt/models/chatglm.py +2 -2
  37. sglang/srt/models/commandr.py +1 -1
  38. sglang/srt/models/deepseek.py +2 -2
  39. sglang/srt/models/deepseek_v2.py +7 -6
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +11 -6
  42. sglang/srt/models/grok.py +50 -396
  43. sglang/srt/models/internlm2.py +2 -7
  44. sglang/srt/models/llama2.py +4 -4
  45. sglang/srt/models/llama_embedding.py +88 -0
  46. sglang/srt/models/minicpm.py +2 -2
  47. sglang/srt/models/mixtral.py +56 -254
  48. sglang/srt/models/mixtral_quant.py +1 -4
  49. sglang/srt/models/qwen.py +2 -2
  50. sglang/srt/models/qwen2.py +2 -2
  51. sglang/srt/models/qwen2_moe.py +2 -13
  52. sglang/srt/models/stablelm.py +1 -1
  53. sglang/srt/openai_api/adapter.py +187 -48
  54. sglang/srt/openai_api/protocol.py +37 -1
  55. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  56. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  57. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  58. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  59. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  60. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  61. sglang/srt/sampling_params.py +31 -8
  62. sglang/srt/server.py +91 -29
  63. sglang/srt/server_args.py +32 -19
  64. sglang/srt/utils.py +32 -15
  65. sglang/test/run_eval.py +10 -1
  66. sglang/test/runners.py +81 -73
  67. sglang/test/simple_eval_humaneval.py +2 -8
  68. sglang/test/simple_eval_mgsm.py +203 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  70. sglang/test/test_layernorm.py +60 -0
  71. sglang/test/test_programs.py +36 -7
  72. sglang/test/test_utils.py +24 -2
  73. sglang/utils.py +0 -1
  74. sglang/version.py +1 -1
  75. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
  76. sglang-0.2.13.dist-info/RECORD +112 -0
  77. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  78. sglang/srt/layers/linear.py +0 -884
  79. sglang/srt/layers/quantization/__init__.py +0 -64
  80. sglang/srt/layers/quantization/fp8.py +0 -677
  81. sglang/srt/model_loader/model_loader.py +0 -292
  82. sglang/srt/model_loader/utils.py +0 -275
  83. sglang-0.2.11.dist-info/RECORD +0 -102
  84. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  85. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
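
Two changes dominate this release: core layer ops (activation, layernorm, fused MoE) move from vllm imports into sglang's own `sglang/srt/layers/` package, and a new `sglang/srt/sampling/penaltylib/` package adds batched penalizers for frequency, presence, and repetition penalties plus a min-new-tokens constraint (which presumably masks stop tokens until a minimum output length is reached). Below is a minimal sketch of the standard per-token penalty math those penalizers implement; the function is illustrative, not the actual penaltylib API (which batches this across requests through the orchestrator):

```python
import torch

def penalize_logits(
    logits: torch.Tensor,   # (vocab,) next-token logits for one request
    counts: torch.Tensor,   # (vocab,) how often each token was generated so far
    frequency_penalty: float,
    presence_penalty: float,
    repetition_penalty: float,
) -> torch.Tensor:
    # OpenAI-style additive penalties: scale with count, or mere presence.
    logits = logits - frequency_penalty * counts
    logits = logits - presence_penalty * (counts > 0).float()
    # CTRL-style multiplicative penalty on previously generated tokens.
    seen = counts > 0
    shrunk = torch.where(
        logits > 0, logits / repetition_penalty, logits * repetition_penalty
    )
    return torch.where(seen, shrunk, logits)
```

Selected per-file diffs follow.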
sglang/srt/models/llama_embedding.py ADDED
@@ -0,0 +1,88 @@
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import LlamaConfig
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+ from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel
+
+
+ class LlamaEmbeddingModel(nn.Module):
+     def __init__(
+         self,
+         config: LlamaConfig,
+         quant_config=None,
+         cache_config=None,
+         efficient_weight_load=False,
+     ) -> None:
+         super().__init__()
+         self.model = LlamaModel(config, quant_config=quant_config)
+         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> EmbeddingPoolerOutput:
+         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+         return self.pooler(hidden_states, input_metadata)
+
+     def load_weights(
+         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+     ):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         params_dict = dict(self.model.named_parameters())
+
+         def load_weights_per_param(name, loaded_weight):
+             if "rotary_emb.inv_freq" in name or "projector" in name:
+                 return
+             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                 # Models trained using ColossalAI may include these tensors in
+                 # the checkpoint. Skip them.
+                 return
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     return
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     return
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+         if name is None or loaded_weight is None:
+             for name, loaded_weight in weights:
+                 load_weights_per_param(name, loaded_weight)
+         else:
+             load_weights_per_param(name, loaded_weight)
+
+
+ EntryClass = LlamaEmbeddingModel
+ # compat: e5-mistral model.config class == MistralModel
+ EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
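
The pooler line above is what turns the Llama decoder into an embedding model: `Pooler(pooling_type=PoolingType.LAST, normalize=True)` (from the new `sglang/srt/layers/pooler.py`) takes each sequence's final hidden state and L2-normalizes it. A minimal sketch of that step, assuming a packed `(total_tokens, hidden_size)` layout and precomputed last-token offsets (the real Pooler derives them from `InputMetadata`):

```python
import torch
import torch.nn.functional as F

def last_token_pool(
    hidden_states: torch.Tensor,       # (total_tokens, hidden_size), sequences packed
    last_token_indices: torch.Tensor,  # (num_seqs,) index of each sequence's last token
) -> torch.Tensor:
    embeddings = hidden_states[last_token_indices]  # (num_seqs, hidden_size)
    return F.normalize(embeddings, p=2, dim=-1)     # unit-length embeddings
```

The `EntryClassRemapping` entry exists because e5-mistral checkpoints report `MistralModel` as their architecture; the remap routes them through this same class.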
sglang/srt/models/minicpm.py CHANGED
@@ -22,8 +22,6 @@ import torch
  from torch import nn
  from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.activation import SiluAndMul
- from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
@@ -37,6 +35,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
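
The moved imports are drop-in replacements: `SiluAndMul` and `RMSNorm` now come from the new `sglang.srt.layers.activation` and `sglang.srt.layers.layernorm` modules instead of vllm. `SiluAndMul` is the fused SwiGLU activation applied to the merged `gate_up_proj` output; a reference sketch of its semantics (the shipped op presumably wraps a fused kernel, matching the vllm op it replaces):

```python
import torch
import torch.nn.functional as F

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # x is the MergedColumnParallelLinear output: [gate_proj | up_proj]
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up
```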
sglang/srt/models/mixtral.py CHANGED
@@ -18,36 +18,27 @@ limitations under the License.
  """Inference-only Mixtral model."""
  from typing import Iterable, Optional, Tuple

- import numpy as np
  import torch
- import torch.nn.functional as F
  from torch import nn
  from transformers import MixtralConfig
- from vllm import _custom_ops as ops
  from vllm.config import CacheConfig
- from vllm.distributed import (
-     get_tensor_model_parallel_rank,
-     get_tensor_model_parallel_world_size,
-     tensor_model_parallel_all_reduce,
- )
- from vllm.model_executor.layers.fused_moe import fused_moe
- from vllm.model_executor.layers.layernorm import RMSNorm
+ from vllm.distributed import get_tensor_model_parallel_world_size
+ from vllm.model_executor.layers.fused_moe import FusedMoE
  from vllm.model_executor.layers.linear import (
      QKVParallelLinear,
      ReplicatedLinear,
      RowParallelLinear,
  )
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
- from vllm.model_executor.layers.quantization.fp8 import Fp8Config
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
+     DEFAULT_VOCAB_PADDING_SIZE,
      ParallelLMHead,
      VocabParallelEmbedding,
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
- from vllm.model_executor.utils import set_weight_attrs
- from vllm.utils import print_warning_once

+ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -69,216 +60,44 @@ class MixtralMoE(nn.Module):
          hidden_size: int,
          intermediate_size: int,
          params_dtype: Optional[torch.dtype] = None,
-         tp_size: Optional[int] = None,
          quant_config: Optional[QuantizationConfig] = None,
+         tp_size: Optional[int] = None,
+         prefix: str = "",
      ):
          super().__init__()
-         self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-         self.num_total_experts = num_experts
-         self.top_k = top_k
          self.hidden_size = hidden_size
-         self.intermediate_size = intermediate_size // self.tp_size
-         self.quant_config = quant_config
-
-         # FIXME(pcmoritz): Make this more general to support different
-         # quantization schemes
-         self.use_fp8 = isinstance(quant_config, Fp8Config)
-
-         if params_dtype is None:
-             params_dtype = torch.get_default_dtype()
-         self.params_dtype = params_dtype

          # Gate always runs at half / full precision for now.
          self.gate = ReplicatedLinear(
-             self.hidden_size,
-             self.num_total_experts,
+             hidden_size,
+             num_experts,
              bias=False,
-             params_dtype=self.params_dtype,
+             params_dtype=params_dtype,
              quant_config=None,
+             prefix=f"{prefix}.gate",
          )

-         if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
-             params_dtype = torch.float8_e4m3fn
-
-         self.w13_weight = nn.Parameter(
-             torch.empty(
-                 self.num_total_experts,
-                 2 * self.intermediate_size,
-                 self.hidden_size,
-                 dtype=params_dtype,
-             )
-         )
-         self.w2_weight = nn.Parameter(
-             torch.empty(
-                 self.num_total_experts,
-                 self.hidden_size,
-                 self.intermediate_size,
-                 dtype=params_dtype,
-             )
-         )
-
-         set_weight_attrs(
-             self.w13_weight,
-             {
-                 "weight_loader": self.weight_loader,
-             },
-         )
-         set_weight_attrs(
-             self.w2_weight,
-             {
-                 "weight_loader": self.weight_loader,
-             },
+         self.experts = FusedMoE(
+             num_experts=num_experts,
+             top_k=top_k,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             params_dtype=params_dtype,
+             reduce_results=True,
+             renormalize=True,
+             quant_config=quant_config,
+             tp_size=tp_size,
+             prefix=f"{prefix}.experts",
          )

-         # Used for fp8.
-         self.w13_scale = None
-         self.w2_scale = None
-         self.a13_scale = None
-         self.a2_scale = None
-
-         if self.use_fp8:
-             # WEIGHT_SCALE (for fp8)
-             self.w13_scale = nn.Parameter(
-                 torch.ones(self.num_total_experts, dtype=torch.float32),
-                 requires_grad=False,
-             )
-             self.w2_scale = nn.Parameter(
-                 torch.ones(self.num_total_experts, dtype=torch.float32),
-                 requires_grad=False,
-             )
-
-             # If loading fp8 checkpoint, pass the weight loaders.
-             # If loading an fp16 checkpoint, do not (we will quantize in
-             # process_weights_after_loading()
-             if quant_config.is_checkpoint_fp8_serialized:
-                 set_weight_attrs(
-                     self.w13_scale,
-                     {
-                         "weight_loader": self.weight_loader,
-                     },
-                 )
-                 set_weight_attrs(
-                     self.w2_scale,
-                     {
-                         "weight_loader": self.weight_loader,
-                     },
-                 )
-
-             # ACT_SCALE (for fp8)
-             if quant_config.activation_scheme == "static":
-                 if not quant_config.is_checkpoint_fp8_serialized:
-                     raise ValueError(
-                         "Found static activation scheme for checkpoint that "
-                         "was not serialized fp8."
-                     )
-                 self.a13_scale = nn.Parameter(
-                     torch.zeros(self.num_total_experts, dtype=torch.float32),
-                     requires_grad=False,
-                 )
-                 self.a2_scale = nn.Parameter(
-                     torch.zeros(self.num_total_experts, dtype=torch.float32),
-                     requires_grad=False,
-                 )
-
-                 set_weight_attrs(
-                     self.a13_scale,
-                     {
-                         "weight_loader": self.weight_loader,
-                     },
-                 )
-                 set_weight_attrs(
-                     self.a2_scale,
-                     {
-                         "weight_loader": self.weight_loader,
-                     },
-                 )
-
-     def weight_loader(
-         self,
-         param: nn.Parameter,
-         loaded_weight: torch.Tensor,
-         weight_name: str,
-         expert_id: int,
-     ):
-         tp_rank = get_tensor_model_parallel_rank()
-         param_data = param.data
-         shard_size = self.intermediate_size
-         shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-         if weight_name.endswith("w1.weight"):
-             param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-         if weight_name.endswith("w3.weight"):
-             param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
-                 shard, :
-             ]
-         if weight_name.endswith("w2.weight"):
-             param_data[expert_id, :, :] = loaded_weight[:, shard]
-         if "act_scale" in weight_name or "weight_scale" in weight_name:
-             param_data[expert_id] = loaded_weight
-
-     def process_weights_after_loading(self):
-         # Fp8 is the only case where we need to process after loading.
-         if not self.use_fp8:
-             return
-
-         # If checkpoint is fp16, quantize here.
-         if not self.quant_config.is_checkpoint_fp8_serialized:
-             w13_weight = torch.empty_like(
-                 self.w13_weight.data, dtype=torch.float8_e4m3fn
-             )
-             w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
-             for expert in range(self.num_total_experts):
-                 w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
-                     self.w13_weight.data[expert, :, :]
-                 )
-                 w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
-                     self.w2_weight.data[expert, :, :]
-                 )
-             self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
-             self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
-
-         # If checkpoint is fp8 + static, cleanup act_scales.
-         # Since state_dict has an act_scale per expert but our kernels
-         # are passed one act_scale shared across all experts.
-         elif self.quant_config.activation_scheme == "static":
-             if self.a13_scale is None or self.a2_scale is None:
-                 raise ValueError(
-                     "QuantConfig has static quantization, but found "
-                     "activation scales are None."
-                 )
-
-             if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
-                 print_warning_once(
-                     "Found act_scales that are not equal for fp8 MoE layer. "
-                     "Using the maximum across experts for each layer. "
-                 )
-
-             self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
-             self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
-
      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         num_tokens, hidden_size = hidden_states.shape
+         # NOTE: hidden_states can have either 1D or 2D shape.
+         orig_shape = hidden_states.shape
          hidden_states = hidden_states.view(-1, self.hidden_size)
          # router_logits: (num_tokens, n_experts)
          router_logits, _ = self.gate(hidden_states)
-         final_hidden_states = fused_moe(
-             hidden_states,
-             self.w13_weight,
-             self.w2_weight,
-             router_logits,
-             self.top_k,
-             renormalize=True,
-             inplace=True,
-             use_fp8=self.use_fp8,
-             w1_scale=self.w13_scale,
-             w2_scale=self.w2_scale,
-             a1_scale=self.a13_scale,
-             a2_scale=self.a2_scale,
-         )
-
-         if self.tp_size > 1:
-             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-
-         return final_hidden_states.view(num_tokens, hidden_size)
+         final_hidden_states = self.experts(hidden_states, router_logits)
+         return final_hidden_states.view(orig_shape)


  class MixtralAttention(nn.Module):
@@ -291,7 +110,7 @@ class MixtralAttention(nn.Module):
          max_position: int = 4096 * 32,
          rope_theta: float = 10000,
          quant_config: Optional[QuantizationConfig] = None,
-         sliding_window: Optional[int] = None,
+         prefix: str = "",
      ) -> None:
          super().__init__()
          self.hidden_size = hidden_size
@@ -314,7 +133,6 @@ class MixtralAttention(nn.Module):
          self.kv_size = self.num_kv_heads * self.head_dim
          self.scaling = self.head_dim**-0.5
          self.rope_theta = rope_theta
-         self.sliding_window = sliding_window

          self.qkv_proj = QKVParallelLinear(
              hidden_size,
@@ -323,12 +141,14 @@ class MixtralAttention(nn.Module):
              self.total_num_kv_heads,
              bias=False,
              quant_config=quant_config,
+             prefix=f"{prefix}.qkv_proj",
          )
          self.o_proj = RowParallelLinear(
              self.total_num_heads * self.head_dim,
              hidden_size,
              bias=False,
              quant_config=quant_config,
+             prefix=f"{prefix}.o_proj",
          )
          self.rotary_emb = get_rope(
              self.head_dim,
@@ -365,6 +185,7 @@ class MixtralDecoderLayer(nn.Module):
          config: MixtralConfig,
          layer_id: int = 0,
          quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
      ) -> None:
          super().__init__()
          self.hidden_size = config.hidden_size
@@ -377,8 +198,8 @@ class MixtralDecoderLayer(nn.Module):
              num_kv_heads=config.num_key_value_heads,
              layer_id=layer_id,
              rope_theta=rope_theta,
-             sliding_window=config.sliding_window,
              quant_config=quant_config,
+             prefix=f"{prefix}.self_attn",
          )
          self.block_sparse_moe = MixtralMoE(
              num_experts=config.num_local_experts,
@@ -386,6 +207,7 @@ class MixtralDecoderLayer(nn.Module):
              hidden_size=config.hidden_size,
              intermediate_size=config.intermediate_size,
              quant_config=quant_config,
+             prefix=f"{prefix}.block_sparse_moe",
          )
          self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_attention_layernorm = RMSNorm(
@@ -422,6 +244,7 @@ class MixtralModel(nn.Module):
          self,
          config: MixtralConfig,
          quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
      ) -> None:
          super().__init__()
          self.padding_idx = config.pad_token_id
@@ -431,10 +254,11 @@ class MixtralModel(nn.Module):
              config.vocab_size,
              config.hidden_size,
          )
-         # config.num_hidden_layers=16
          self.layers = nn.ModuleList(
              [
-                 MixtralDecoderLayer(config, i, quant_config=quant_config)
+                 MixtralDecoderLayer(
+                     config, i, quant_config=quant_config, prefix=f"{prefix}.layers"
+                 )
                  for i in range(config.num_hidden_layers)
              ]
          )
@@ -462,6 +286,7 @@ class MixtralModel(nn.Module):


  class MixtralForCausalLM(nn.Module):
+
      def __init__(
          self,
          config: MixtralConfig,
@@ -471,11 +296,10 @@ class MixtralForCausalLM(nn.Module):
          super().__init__()
          self.config = config
          self.quant_config = quant_config
-         self.model = MixtralModel(config, quant_config=quant_config)
+         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)

-     @torch.no_grad()
      def forward(
          self,
          input_ids: torch.Tensor,
@@ -496,40 +320,13 @@ class MixtralForCausalLM(nn.Module):
              ("qkv_proj", "v_proj", "v"),
          ]

-         expert_params_mapping = (
-             [
-                 # These are the weight scales for the experts
-                 # (param_name, weight_name, expert_id)
-                 (
-                     "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-                     f"experts.{expert_id}.{weight_name}.weight_scale",
-                     expert_id,
-                 )
-                 for expert_id in range(self.config.num_local_experts)
-                 for weight_name in ["w1", "w2", "w3"]
-             ]
-             + [
-                 # These are the weights for the experts
-                 # (param_name, weight_name, expert_id)
-                 (
-                     "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-                     f"experts.{expert_id}.{weight_name}.weight",
-                     expert_id,
-                 )
-                 for expert_id in range(self.config.num_local_experts)
-                 for weight_name in ["w1", "w2", "w3"]
-             ]
-             + [
-                 # These are the activation scales for the experts
-                 # (param_name, weight_name, expert_id)
-                 (
-                     "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-                     f"experts.{expert_id}.{weight_name}.act_scale",
-                     expert_id,
-                 )
-                 for expert_id in range(self.config.num_local_experts)
-                 for weight_name in ["w1", "w2", "w3"]
-             ]
+         # Params for weights, fp8 weight scales, fp8 activation scales
+         # (param_name, weight_name, expert_id, shard_id)
+         expert_params_mapping = FusedMoE.make_expert_params_mapping(
+             ckpt_gate_proj_name="w1",
+             ckpt_down_proj_name="w2",
+             ckpt_up_proj_name="w3",
+             num_experts=self.config.num_local_experts,
          )

          params_dict = dict(self.named_parameters())
@@ -544,25 +341,35 @@ class MixtralForCausalLM(nn.Module):
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
+
                  param = params_dict[name]
                  weight_loader = param.weight_loader
                  weight_loader(param, loaded_weight, shard_id)
                  break
              else:
-                 for param_name, weight_name, expert_id in expert_params_mapping:
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
                      if weight_name not in name:
                          continue
                      name = name.replace(weight_name, param_name)
+
                      param = params_dict[name]
                      weight_loader = param.weight_loader
                      weight_loader(
-                         param, loaded_weight, weight_name, expert_id=expert_id
+                         param,
+                         loaded_weight,
+                         weight_name,
+                         shard_id=shard_id,
+                         expert_id=expert_id,
                      )
                      break
                  else:
                      # Skip loading extra bias for GPTQ models.
                      if name.endswith(".bias") and name not in params_dict:
                          continue
+                     if name is None:
+                         continue
+
                      param = params_dict[name]
                      weight_loader = getattr(
                          param, "weight_loader", default_weight_loader
@@ -570,9 +377,4 @@ class MixtralForCausalLM(nn.Module):
                      weight_loader(param, loaded_weight)


- def all_close_1d(x: torch.Tensor) -> bool:
-     assert len(x.shape) == 1
-     return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
  EntryClass = MixtralForCausalLM
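
The net effect of this refactor: `MixtralMoE` no longer owns expert weights, fp8 scales, or a hand-written `weight_loader`; `FusedMoE` encapsulates top-k routing, the fused expert kernels, quantization, and (via `reduce_results=True`) the tensor-parallel all-reduce the old forward issued manually. The three inline list comprehensions collapse into `FusedMoE.make_expert_params_mapping`, whose entries now carry a `shard_id` (hence the new 4-tuple unpacking in `load_weights`). A rough sketch of the mapping it must produce, inferred from the removed code; the exact names are illustrative:

```python
# (param_name, weight_name, expert_id, shard_id) per checkpoint tensor.
# shard_id tells FusedMoE's weight_loader whether a tensor is the gate (w1)
# or up (w3) half of the fused w13 parameter, or the down projection (w2).
num_experts = 8  # Mixtral-8x7B has 8 local experts per layer
expert_params_mapping = [
    (
        "experts.w13_weight" if w in ("w1", "w3") else "experts.w2_weight",
        f"experts.{expert_id}.{w}.weight",
        expert_id,
        shard_id,
    )
    for expert_id in range(num_experts)
    for shard_id, w in enumerate(("w1", "w2", "w3"))
]
# Per the comment in the diff, the real helper also emits analogous entries
# for the fp8 weight_scale and act_scale tensors.
```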
sglang/srt/models/mixtral_quant.py CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
- from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
      QKVParallelLinear,
      ReplicatedLinear,
@@ -43,6 +42,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -160,7 +160,6 @@ class MixtralAttention(nn.Module):
          max_position: int = 4096 * 32,
          rope_theta: float = 10000,
          quant_config: Optional[QuantizationConfig] = None,
-         sliding_window: Optional[int] = None,
      ) -> None:
          super().__init__()
          self.hidden_size = hidden_size
@@ -183,7 +182,6 @@ class MixtralAttention(nn.Module):
          self.kv_size = self.num_kv_heads * self.head_dim
          self.scaling = self.head_dim**-0.5
          self.rope_theta = rope_theta
-         self.sliding_window = sliding_window

          self.qkv_proj = QKVParallelLinear(
              hidden_size,
@@ -246,7 +244,6 @@ class MixtralDecoderLayer(nn.Module):
              num_kv_heads=config.num_key_value_heads,
              layer_id=layer_id,
              rope_theta=rope_theta,
-             sliding_window=config.sliding_window,
              quant_config=quant_config,
          )
          self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
sglang/srt/models/qwen.py CHANGED
@@ -22,8 +22,6 @@ from torch import nn
  from transformers import PretrainedConfig
  from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.activation import SiluAndMul
- from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
@@ -37,6 +35,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
sglang/srt/models/qwen2.py CHANGED
@@ -22,8 +22,6 @@ import torch
  from torch import nn
  from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.activation import SiluAndMul
- from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
@@ -37,6 +35,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
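
All of the model files above now take `RMSNorm` from the new `sglang.srt.layers.layernorm` module (covered by the new `sglang/test/test_layernorm.py`), which presumably wraps a fused kernel. For reference, the RMSNorm computation itself as a plain eager-mode sketch:

```python
import torch
from torch import nn

class RMSNormRef(nn.Module):
    """Eager reference for RMSNorm: scale by the reciprocal root-mean-square,
    with a learned per-channel weight and no mean-centering."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + self.eps) * self.weight
```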