sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (83)
  1. sglang/api.py +13 -1
  2. sglang/bench_latency.py +10 -5
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/global_config.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +60 -49
  7. sglang/lang/chat_template.py +10 -5
  8. sglang/lang/compiler.py +4 -0
  9. sglang/lang/interpreter.py +5 -2
  10. sglang/lang/ir.py +22 -4
  11. sglang/launch_server.py +8 -1
  12. sglang/srt/constrained/jump_forward.py +13 -2
  13. sglang/srt/conversation.py +50 -1
  14. sglang/srt/hf_transformers_utils.py +22 -23
  15. sglang/srt/layers/activation.py +24 -2
  16. sglang/srt/layers/decode_attention.py +338 -50
  17. sglang/srt/layers/extend_attention.py +3 -1
  18. sglang/srt/layers/fused_moe/__init__.py +1 -0
  19. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  20. sglang/srt/layers/fused_moe/layer.py +587 -0
  21. sglang/srt/layers/layernorm.py +3 -0
  22. sglang/srt/layers/logits_processor.py +64 -27
  23. sglang/srt/layers/radix_attention.py +41 -18
  24. sglang/srt/layers/sampler.py +154 -0
  25. sglang/srt/managers/controller_multi.py +2 -8
  26. sglang/srt/managers/controller_single.py +7 -10
  27. sglang/srt/managers/detokenizer_manager.py +20 -9
  28. sglang/srt/managers/io_struct.py +44 -11
  29. sglang/srt/managers/policy_scheduler.py +5 -2
  30. sglang/srt/managers/schedule_batch.py +59 -179
  31. sglang/srt/managers/tokenizer_manager.py +193 -84
  32. sglang/srt/managers/tp_worker.py +131 -50
  33. sglang/srt/mem_cache/memory_pool.py +82 -8
  34. sglang/srt/mm_utils.py +79 -7
  35. sglang/srt/model_executor/cuda_graph_runner.py +97 -28
  36. sglang/srt/model_executor/forward_batch_info.py +188 -82
  37. sglang/srt/model_executor/model_runner.py +269 -87
  38. sglang/srt/models/chatglm.py +6 -14
  39. sglang/srt/models/commandr.py +6 -2
  40. sglang/srt/models/dbrx.py +5 -1
  41. sglang/srt/models/deepseek.py +7 -3
  42. sglang/srt/models/deepseek_v2.py +12 -7
  43. sglang/srt/models/gemma.py +6 -2
  44. sglang/srt/models/gemma2.py +22 -8
  45. sglang/srt/models/gpt_bigcode.py +5 -1
  46. sglang/srt/models/grok.py +66 -398
  47. sglang/srt/models/internlm2.py +5 -1
  48. sglang/srt/models/llama2.py +7 -3
  49. sglang/srt/models/llama_classification.py +2 -2
  50. sglang/srt/models/llama_embedding.py +4 -0
  51. sglang/srt/models/llava.py +176 -59
  52. sglang/srt/models/minicpm.py +7 -3
  53. sglang/srt/models/mixtral.py +61 -255
  54. sglang/srt/models/mixtral_quant.py +6 -5
  55. sglang/srt/models/qwen.py +7 -4
  56. sglang/srt/models/qwen2.py +15 -5
  57. sglang/srt/models/qwen2_moe.py +7 -16
  58. sglang/srt/models/stablelm.py +6 -2
  59. sglang/srt/openai_api/adapter.py +149 -58
  60. sglang/srt/sampling/sampling_batch_info.py +209 -0
  61. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
  62. sglang/srt/server.py +107 -71
  63. sglang/srt/server_args.py +49 -15
  64. sglang/srt/utils.py +27 -18
  65. sglang/test/runners.py +38 -38
  66. sglang/test/simple_eval_common.py +9 -10
  67. sglang/test/simple_eval_gpqa.py +2 -1
  68. sglang/test/simple_eval_humaneval.py +2 -2
  69. sglang/test/simple_eval_math.py +2 -1
  70. sglang/test/simple_eval_mmlu.py +2 -1
  71. sglang/test/test_activation.py +55 -0
  72. sglang/test/test_programs.py +32 -5
  73. sglang/test/test_utils.py +37 -50
  74. sglang/version.py +1 -1
  75. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
  76. sglang-0.2.14.dist-info/RECORD +114 -0
  77. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  78. sglang/launch_server_llavavid.py +0 -29
  79. sglang/srt/model_loader/model_loader.py +0 -292
  80. sglang/srt/model_loader/utils.py +0 -275
  81. sglang-0.2.12.dist-info/RECORD +0 -112
  82. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  83. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/models/mixtral.py CHANGED
@@ -18,38 +18,30 @@ limitations under the License.
 """Inference-only Mixtral model."""
 from typing import Iterable, Optional, Tuple
 
-import numpy as np
 import torch
-import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
-from vllm import _custom_ops as ops
 from vllm.config import CacheConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce,
-)
-from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import print_warning_once
 
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
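Note on the import changes above: building blocks that previously came from vllm (RMSNorm, the fused-MoE kernel) are now imported from sglang's own sglang.srt.layers package, and a new Sampler layer is wired into the model. For orientation, RMS normalization computes roughly the following; this is an unfused reference sketch for illustration, not the kernel that sglang.srt.layers.layernorm actually dispatches to:

    import torch
    from torch import nn

    class RMSNormReference(nn.Module):
        """Unfused sketch of RMSNorm: scale x by the reciprocal of its RMS."""

        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Mean of squares over the hidden dimension, then rsqrt.
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            return x * torch.rsqrt(variance + self.eps) * self.weight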
 
@@ -69,216 +61,44 @@ class MixtralMoE(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
     ):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
-        self.quant_config = quant_config
-
-        # FIXME(pcmoritz): Make this more general to support different
-        # quantization schemes
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
 
         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(
-            self.hidden_size,
-            self.num_total_experts,
+            hidden_size,
+            num_experts,
             bias=False,
-            params_dtype=self.params_dtype,
+            params_dtype=params_dtype,
             quant_config=None,
+            prefix=f"{prefix}.gate",
         )
 
-        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        self.w13_weight = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                2 * self.intermediate_size,
-                self.hidden_size,
-                dtype=params_dtype,
-            )
-        )
-        self.w2_weight = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.hidden_size,
-                self.intermediate_size,
-                dtype=params_dtype,
-            )
-        )
-
-        set_weight_attrs(
-            self.w13_weight,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2_weight,
-            {
-                "weight_loader": self.weight_loader,
-            },
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            renormalize=True,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            prefix=f"{prefix}.experts",
         )
 
-        # Used for fp8.
-        self.w13_scale = None
-        self.w2_scale = None
-        self.a13_scale = None
-        self.a2_scale = None
-
-        if self.use_fp8:
-            # WEIGHT_SCALE (for fp8)
-            self.w13_scale = nn.Parameter(
-                torch.ones(self.num_total_experts, dtype=torch.float32),
-                requires_grad=False,
-            )
-            self.w2_scale = nn.Parameter(
-                torch.ones(self.num_total_experts, dtype=torch.float32),
-                requires_grad=False,
-            )
-
-            # If loading fp8 checkpoint, pass the weight loaders.
-            # If loading an fp16 checkpoint, do not (we will quantize in
-            # process_weights_after_loading()
-            if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(
-                    self.w13_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-                set_weight_attrs(
-                    self.w2_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-
-            # ACT_SCALE (for fp8)
-            if quant_config.activation_scheme == "static":
-                if not quant_config.is_checkpoint_fp8_serialized:
-                    raise ValueError(
-                        "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8."
-                    )
-                self.a13_scale = nn.Parameter(
-                    torch.zeros(self.num_total_experts, dtype=torch.float32),
-                    requires_grad=False,
-                )
-                self.a2_scale = nn.Parameter(
-                    torch.zeros(self.num_total_experts, dtype=torch.float32),
-                    requires_grad=False,
-                )
-
-                set_weight_attrs(
-                    self.a13_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-                set_weight_attrs(
-                    self.a2_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-
-    def weight_loader(
-        self,
-        param: nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        expert_id: int,
-    ):
-        tp_rank = get_tensor_model_parallel_rank()
-        param_data = param.data
-        shard_size = self.intermediate_size
-        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("w1.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w3.weight"):
-            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
-                shard, :
-            ]
-        if weight_name.endswith("w2.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name or "weight_scale" in weight_name:
-            param_data[expert_id] = loaded_weight
-
-    def process_weights_after_loading(self):
-        # Fp8 is the only case where we need to process after loading.
-        if not self.use_fp8:
-            return
-
-        # If checkpoint is fp16, quantize here.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                self.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
-            for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
-                    self.w13_weight.data[expert, :, :]
-                )
-                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
-                    self.w2_weight.data[expert, :, :]
-                )
-            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
-            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
-
-        # If checkpoint is fp8 + static, cleanup act_scales.
-        # Since state_dict has an act_scale per expert but our kernels
-        # are passed one act_scale shared across all experts.
-        elif self.quant_config.activation_scheme == "static":
-            if self.a13_scale is None or self.a2_scale is None:
-                raise ValueError(
-                    "QuantConfig has static quantization, but found "
-                    "activation scales are None."
-                )
-
-            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
-                print_warning_once(
-                    "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. "
-                )
-
-            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
-            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(
-            hidden_states,
-            self.w13_weight,
-            self.w2_weight,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-            use_fp8=self.use_fp8,
-            w1_scale=self.w13_scale,
-            w2_scale=self.w2_scale,
-            a1_scale=self.a13_scale,
-            a2_scale=self.a2_scale,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        return final_hidden_states.view(orig_shape)
 
 
 class MixtralAttention(nn.Module):
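The hunk above removes the hand-rolled expert weights, fp8 bookkeeping, and per-expert weight_loader in favor of vllm's FusedMoE layer, which owns the expert parameters and handles routing, expert computation, and the tensor-parallel all-reduce (reduce_results=True) internally. Functionally, the fused path computes the same thing as this unfused reference; a sketch only, with w1/w2/w3 following the Mixtral checkpoint convention:

    import torch
    import torch.nn.functional as F

    def moe_reference(hidden_states, router_logits, w1, w2, w3, top_k):
        """Unfused top-k MoE with renormalized routing weights (sketch).

        w1, w3: [num_experts, intermediate, hidden] gate/up projections
        w2:     [num_experts, hidden, intermediate] down projection
        """
        probs = torch.softmax(router_logits, dim=-1)
        weights, experts = torch.topk(probs, top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize=True

        out = torch.zeros_like(hidden_states)
        for token in range(hidden_states.shape[0]):
            for k in range(top_k):
                e = experts[token, k]
                x = hidden_states[token]
                # SwiGLU expert MLP: silu(w1 x) * (w3 x), then down-project.
                h = F.silu(w1[e] @ x) * (w3[e] @ x)
                out[token] += weights[token, k] * (w2[e] @ h)
        return out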
@@ -291,7 +111,7 @@ class MixtralAttention(nn.Module):
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
-        sliding_window: Optional[int] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -314,7 +134,6 @@ class MixtralAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -323,12 +142,14 @@ class MixtralAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -365,6 +186,7 @@ class MixtralDecoderLayer(nn.Module):
         config: MixtralConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -377,8 +199,8 @@ class MixtralDecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             layer_id=layer_id,
             rope_theta=rope_theta,
-            sliding_window=config.sliding_window,
             quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
         )
         self.block_sparse_moe = MixtralMoE(
             num_experts=config.num_local_experts,
@@ -386,6 +208,7 @@ class MixtralDecoderLayer(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.block_sparse_moe",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -422,6 +245,7 @@ class MixtralModel(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -431,10 +255,11 @@ class MixtralModel(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        # config.num_hidden_layers=16
         self.layers = nn.ModuleList(
             [
-                MixtralDecoderLayer(config, i, quant_config=quant_config)
+                MixtralDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"{prefix}.layers"
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -462,6 +287,7 @@ class MixtralModel(nn.Module):
 
 
 class MixtralForCausalLM(nn.Module):
+
     def __init__(
         self,
         config: MixtralConfig,
@@ -471,11 +297,11 @@ class MixtralForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = MixtralModel(config, quant_config=quant_config)
+        self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
-    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -484,9 +310,11 @@ class MixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
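Across all four models in this diff, forward no longer returns logits alone: the LogitsProcessor output is passed through the new sglang Sampler and the pair (sample_output, logits_output) is returned, moving token sampling into the model's forward pass. Roughly, a caller now unpacks both values; a hedged sketch, since the real call site is in the model runner, which this diff does not show:

    # Sketch of the new calling convention (assumed caller, not shown here):
    sample_output, logits_output = model.forward(input_ids, positions, input_metadata)
    # sample_output is produced by sglang.srt.layers.sampler.Sampler from
    # logits_output plus input_metadata.sampling_info; logits_output is what
    # forward used to return on its own.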
@@ -496,40 +324,13 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        expert_params_mapping = (
-            [
-                # These are the weight scales for the experts
-                # (param_name, weight_name, expert_id)
-                (
-                    "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-                    f"experts.{expert_id}.{weight_name}.weight_scale",
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
-            + [
-                # These are the weights for the experts
-                # (param_name, weight_name, expert_id)
-                (
-                    "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-                    f"experts.{expert_id}.{weight_name}.weight",
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
-            + [
-                # These are the activation scales for the experts
-                # (param_name, weight_name, expert_id)
-                (
-                    "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-                    f"experts.{expert_id}.{weight_name}.act_scale",
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts,
         )
 
         params_dict = dict(self.named_parameters())
@@ -544,25 +345,35 @@ class MixtralForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
-                        param, loaded_weight, weight_name, expert_id=expert_id
+                        param,
+                        loaded_weight,
+                        weight_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
                     )
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    if name is None:
+                        continue
+
                     param = params_dict[name]
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader
@@ -570,9 +381,4 @@ class MixtralForCausalLM(nn.Module):
                     weight_loader(param, loaded_weight)
 
 
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 EntryClass = MixtralForCausalLM
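FusedMoE.make_expert_params_mapping replaces the three hand-written list comprehensions and adds a shard_id to each tuple, so the fused layer's loader knows whether a checkpoint tensor is the gate (w1), down (w2), or up (w3) projection. The shape of the mapping, inferred from the deleted code and the new loader loop (the parameter names below are illustrative assumptions, not vllm's exact output):

    # (param_name, weight_name, expert_id, shard_id) -- illustrative values only
    expert_params_mapping = [
        ("experts.w13_weight", "experts.0.w1.weight", 0, "w1"),
        ("experts.w2_weight", "experts.0.w2.weight", 0, "w2"),
        ("experts.w13_weight", "experts.0.w3.weight", 0, "w3"),
        # ...repeated for every expert, plus fp8 weight and activation scales
    ]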
sglang/srt/models/mixtral_quant.py CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
@@ -43,8 +42,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -160,7 +161,6 @@ class MixtralAttention(nn.Module):
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
-        sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -183,7 +183,6 @@ class MixtralAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -246,7 +245,6 @@ class MixtralDecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             layer_id=layer_id,
             rope_theta=rope_theta,
-            sliding_window=config.sliding_window,
             quant_config=quant_config,
         )
         self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
@@ -336,6 +334,7 @@ class QuantMixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -346,9 +345,11 @@ class QuantMixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/qwen.py CHANGED
@@ -22,8 +22,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -37,8 +35,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -251,6 +252,7 @@ class QWenLMHeadModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -260,10 +262,11 @@ class QWenLMHeadModel(nn.Module):
         input_metadata: InputMetadata,
     ):
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        next_tokens = self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        return next_tokens
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/qwen2.py CHANGED
@@ -22,8 +22,6 @@ import torch
 from torch import nn
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -37,8 +35,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 Qwen2Config = None
@@ -275,6 +277,8 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
     @torch.no_grad()
     def forward(
@@ -283,11 +287,17 @@ class Qwen2ForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
+        if not get_embedding:
+            logits_output = self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, input_metadata
+            )
+            sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+            return sample_output, logits_output
+        else:
+            return self.pooler(hidden_states, input_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
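Qwen2 additionally gains an embedding path: with get_embedding=True, forward skips the LM head and returns pooled hidden states from a Pooler configured for last-token pooling with L2 normalization. The pooling step itself is small; a minimal sketch under the assumption of a packed [total_tokens, hidden] layout (the real Pooler reads sequence boundaries from input_metadata):

    import torch
    import torch.nn.functional as F

    def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor):
        """Sketch of PoolingType.LAST with normalize=True over a packed batch."""
        # Index of each sequence's final token in the packed layout.
        last_indices = torch.cumsum(seq_lens, dim=0) - 1
        pooled = hidden_states[last_indices]      # [num_seqs, hidden]
        return F.normalize(pooled, p=2, dim=-1)   # unit-length embeddings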