sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_latency.py +17 -8
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +5 -17
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -4
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +33 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/activation.py +12 -0
  15. sglang/srt/layers/attention_backend.py +480 -0
  16. sglang/srt/layers/flashinfer_utils.py +235 -0
  17. sglang/srt/layers/fused_moe/layer.py +27 -7
  18. sglang/srt/layers/layernorm.py +12 -0
  19. sglang/srt/layers/logits_processor.py +64 -77
  20. sglang/srt/layers/radix_attention.py +11 -161
  21. sglang/srt/layers/sampler.py +38 -122
  22. sglang/srt/layers/torchao_utils.py +75 -0
  23. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  24. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  25. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  26. sglang/srt/lora/lora.py +403 -0
  27. sglang/srt/lora/lora_config.py +43 -0
  28. sglang/srt/lora/lora_manager.py +259 -0
  29. sglang/srt/managers/controller_multi.py +1 -5
  30. sglang/srt/managers/controller_single.py +0 -5
  31. sglang/srt/managers/io_struct.py +16 -1
  32. sglang/srt/managers/policy_scheduler.py +122 -5
  33. sglang/srt/managers/schedule_batch.py +105 -71
  34. sglang/srt/managers/tokenizer_manager.py +17 -8
  35. sglang/srt/managers/tp_worker.py +188 -121
  36. sglang/srt/model_executor/cuda_graph_runner.py +69 -133
  37. sglang/srt/model_executor/forward_batch_info.py +35 -312
  38. sglang/srt/model_executor/model_runner.py +123 -154
  39. sglang/srt/models/baichuan.py +416 -0
  40. sglang/srt/models/chatglm.py +1 -5
  41. sglang/srt/models/commandr.py +1 -5
  42. sglang/srt/models/dbrx.py +1 -5
  43. sglang/srt/models/deepseek.py +1 -5
  44. sglang/srt/models/deepseek_v2.py +7 -6
  45. sglang/srt/models/exaone.py +1 -5
  46. sglang/srt/models/gemma.py +1 -5
  47. sglang/srt/models/gemma2.py +1 -5
  48. sglang/srt/models/gpt_bigcode.py +1 -5
  49. sglang/srt/models/grok.py +1 -5
  50. sglang/srt/models/internlm2.py +1 -5
  51. sglang/srt/models/llama.py +51 -5
  52. sglang/srt/models/llama_classification.py +1 -20
  53. sglang/srt/models/llava.py +30 -5
  54. sglang/srt/models/llavavid.py +2 -2
  55. sglang/srt/models/minicpm.py +1 -5
  56. sglang/srt/models/minicpm3.py +669 -0
  57. sglang/srt/models/mixtral.py +6 -5
  58. sglang/srt/models/mixtral_quant.py +1 -5
  59. sglang/srt/models/olmoe.py +415 -0
  60. sglang/srt/models/qwen.py +1 -5
  61. sglang/srt/models/qwen2.py +1 -5
  62. sglang/srt/models/qwen2_moe.py +6 -5
  63. sglang/srt/models/stablelm.py +1 -5
  64. sglang/srt/models/xverse.py +375 -0
  65. sglang/srt/models/xverse_moe.py +445 -0
  66. sglang/srt/openai_api/adapter.py +65 -46
  67. sglang/srt/openai_api/protocol.py +11 -3
  68. sglang/srt/sampling/sampling_batch_info.py +46 -80
  69. sglang/srt/server.py +30 -15
  70. sglang/srt/server_args.py +163 -28
  71. sglang/srt/utils.py +19 -51
  72. sglang/test/few_shot_gsm8k.py +132 -0
  73. sglang/test/runners.py +114 -22
  74. sglang/test/test_programs.py +7 -5
  75. sglang/test/test_utils.py +85 -2
  76. sglang/utils.py +32 -37
  77. sglang/version.py +1 -1
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
  79. sglang-0.3.1.post1.dist-info/RECORD +130 -0
  80. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
  81. sglang-0.3.0.dist-info/RECORD +0 -118
  82. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
  83. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0

sglang/srt/models/mixtral.py CHANGED
@@ -41,7 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
+ from sglang.srt.layers.torchao_utils import apply_torchao_config_
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -297,10 +298,10 @@ class MixtralForCausalLM(nn.Module):
          super().__init__()
          self.config = config
          self.quant_config = quant_config
+         self.torchao_config = global_server_args_dict["torchao_config"]
          self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      def forward(
          self,
@@ -310,11 +311,9 @@ class MixtralForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
@@ -380,5 +379,7 @@ class MixtralForCausalLM(nn.Module):
                      )
                      weight_loader(param, loaded_weight)

+         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+

  EntryClass = MixtralForCausalLM
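
Besides dropping the per-model Sampler, mixtral.py (and qwen2_moe.py below) hooks the new torchao quantization path: the config string is read from global_server_args_dict["torchao_config"] and apply_torchao_config_(self, params_dict, set(["proj.weight"])) runs once after all weights are loaded. The helper itself lives in the new sglang/srt/layers/torchao_utils.py (+75 lines), which is not part of these hunks; the sketch below is only a guess at its shape, built on torchao's quantize_/int8_weight_only API, and the *_sketch name, the "int8wo" config value, and the suffix-filter behavior are illustrative assumptions rather than code from the package.

    # Hypothetical sketch, not the packaged implementation (torchao_utils.py is not shown in this diff).
    import torch
    from torchao.quantization import int8_weight_only, quantize_  # assumes torchao is installed

    def apply_torchao_config_sketch(model: torch.nn.Module, params_dict, filter_suffixes):
        # params_dict is accepted only to mirror the real call shape.
        torchao_config = getattr(model, "torchao_config", "") or ""
        if not torchao_config:
            return  # an empty config string means "leave the weights untouched"

        def filter_fn(module, fqn):
            # Quantize only modules whose weight name matches the suffix filter,
            # e.g. "proj.weight" selects the linear projection layers.
            return any(f"{fqn}.weight".endswith(suffix) for suffix in filter_suffixes)

        if torchao_config == "int8wo":  # illustrative config value
            quantize_(model, int8_weight_only(), filter_fn=filter_fn)
        # other config strings (int4 weight-only, fp8, ...) would dispatch similarly

The matching server-side switch is presumably added in server_args.py (+163 -28), which is where the "torchao_config" entry of global_server_args_dict would be populated.
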
sglang/srt/models/mixtral_quant.py CHANGED
@@ -45,7 +45,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -334,7 +333,6 @@ class QuantMixtralForCausalLM(nn.Module):
          self.model = MixtralModel(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -345,11 +343,9 @@ class QuantMixtralForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
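
The two mixtral hunks above show the release-wide pattern that repeats, with only the class names changing, in qwen.py, qwen2.py, qwen2_moe.py, and stablelm.py below (and in the other one-line model diffs listed at the top): the per-model Sampler is deleted and forward() returns the LogitsProcessor output directly instead of a (sample_output, logits_output) tuple, so sampling is presumably handled centrally by the runtime now (note the companion changes to sampler.py, tp_worker.py, and model_runner.py in the file list) rather than inside each model. A condensed before/after sketch of the forward contract, not any single model verbatim:

    # before (0.3.0): the model samples as part of forward()
    def forward_old(self, input_ids, positions, input_metadata):
        hidden_states = self.model(input_ids, positions, input_metadata)
        logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
        return sample_output, logits_output

    # after (0.3.1.post1): the model only produces logits
    def forward_new(self, input_ids, positions, input_metadata):
        hidden_states = self.model(input_ids, positions, input_metadata)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
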
sglang/srt/models/olmoe.py ADDED
@@ -0,0 +1,415 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ # Adapted from:
+ # https://github.com/vllm-project/vllm/pull/7922
+
+ """Inference-only OLMoE model compatible with HuggingFace weights."""
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from transformers import PretrainedConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import (
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_reduce,
+ )
+ from vllm.model_executor.layers.fused_moe import FusedMoE
+ from vllm.model_executor.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+ from vllm.utils import print_warning_once
+
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata
+
+
+ class OlmoeMoE(nn.Module):
+     """A tensor-parallel MoE implementation for Olmoe that shards each expert
+     across all ranks.
+
+     Each expert's weights are sharded across all ranks and a fused MoE
+     kernel is used for the forward pass, and finally we reduce the outputs
+     across ranks.
+     """
+
+     def __init__(
+         self,
+         num_experts: int,
+         top_k: int,
+         hidden_size: int,
+         intermediate_size: int,
+         params_dtype: Optional[torch.dtype] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         tp_size: Optional[int] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.hidden_size = hidden_size
+
+         # Gate always runs at half / full precision for now.
+         self.gate = ReplicatedLinear(
+             hidden_size, num_experts, bias=False, quant_config=None
+         )
+
+         self.experts = FusedMoE(
+             num_experts=num_experts,
+             top_k=top_k,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             reduce_results=True,
+             renormalize=False,
+             quant_config=quant_config,
+             tp_size=tp_size,
+         )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         # NOTE: hidden_states can have either 1D or 2D shape.
+         orig_shape = hidden_states.shape
+         hidden_states = hidden_states.view(-1, self.hidden_size)
+         # router_logits: (num_tokens, n_experts)
+         router_logits, _ = self.gate(hidden_states)
+         final_hidden_states = self.experts(
+             hidden_states=hidden_states, router_logits=router_logits
+         )
+         return final_hidden_states.view(orig_shape)
+
+
+ class OlmoeAttention(nn.Module):
+
+     def __init__(
+         self,
+         layer_id: int,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         max_position_embeddings: int = 4096,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         self.head_dim = hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+         )
+         self.q_norm = RMSNorm(hidden_size, eps=1e-5)
+         self.k_norm = RMSNorm(hidden_size, eps=1e-5)
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+             is_neox_style=True,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             layer_id=layer_id,
+             num_kv_heads=self.num_kv_heads,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class OlmoeDecoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 10000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
+
+         self.self_attn = OlmoeAttention(
+             layer_id,
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             max_position_embeddings=max_position_embeddings,
+             quant_config=quant_config,
+         )
+
+         self.mlp = OlmoeMoE(
+             num_experts=config.num_experts,
+             top_k=config.num_experts_per_tok,
+             hidden_size=config.hidden_size,
+             intermediate_size=config.intermediate_size,
+             quant_config=quant_config,
+         )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+         residual: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
+
+
+ class OlmoeModel(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.layers = nn.ModuleList(
+             [
+                 OlmoeDecoderLayer(config, layer_id, quant_config=quant_config)
+                 for layer_id in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions, hidden_states, input_metadata, residual
+             )
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+
+
+ class OlmoeForCausalLM(nn.Module):
+
+     fall_back_to_pt_during_load = False
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = OlmoeModel(config, quant_config)
+         self.lm_head = ParallelLMHead(
+             config.vocab_size, config.hidden_size, quant_config=quant_config
+         )
+         self.logits_processor = LogitsProcessor(config)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         # Params for weights, fp8 weight scales, fp8 activation scales
+         # (param_name, weight_name, expert_id, shard_id)
+         expert_params_mapping = FusedMoE.make_expert_params_mapping(
+             ckpt_gate_proj_name="gate_proj",
+             ckpt_down_proj_name="down_proj",
+             ckpt_up_proj_name="up_proj",
+             num_experts=self.config.num_experts,
+         )
+
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 # Skip non-stacked layers and experts (experts handled below).
+                 if weight_name not in name:
+                     continue
+                 # We have mlp.experts[0].gate_proj in the checkpoint.
+                 # Since we handle the experts below in expert_params_mapping,
+                 # we need to skip here BEFORE we update the name, otherwise
+                 # name will be updated to mlp.experts[0].gate_up_proj, which
+                 # will then be updated below in expert_params_mapping
+                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                 if "mlp.experts" in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name not in params_dict:
+                     continue
+
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+                     param = params_dict[name]
+                     weight_loader = param.weight_loader
+                     weight_loader(
+                         param,
+                         loaded_weight,
+                         name,
+                         shard_id=shard_id,
+                         expert_id=expert_id,
+                     )
+                     break
+                 else:
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+                     # Remapping the name of FP8 kv-scale.
+                     if name.endswith("kv_scale"):
+                         remapped_kv_scale_name = name.replace(
+                             ".kv_scale", ".attn.kv_scale"
+                         )
+                         if remapped_kv_scale_name not in params_dict:
+                             print_warning_once(
+                                 "Found kv scale in the checkpoint "
+                                 f"(e.g. {name}), but not found the expected "
+                                 f"name in the model "
+                                 f"(e.g. {remapped_kv_scale_name}). "
+                                 "kv-scale is not loaded."
+                             )
+                             continue
+                         else:
+                             name = remapped_kv_scale_name
+
+                     param = params_dict[name]
+                     weight_loader = getattr(
+                         param, "weight_loader", default_weight_loader
+                     )
+                     weight_loader(param, loaded_weight)
+
+
+ EntryClass = OlmoeForCausalLM
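
OlmoeAttention.__init__ above documents how attention heads are split under tensor parallelism: query heads are always partitioned evenly across ranks, while KV heads are partitioned when there are at least as many of them as ranks and replicated otherwise. A standalone restatement of that arithmetic, with made-up example shapes (the numbers are illustrative, not taken from any particular OLMoE checkpoint):

    # Mirrors the head bookkeeping in OlmoeAttention.__init__ (example numbers only).
    def partition_heads(total_num_heads: int, total_num_kv_heads: int, tp_size: int):
        assert total_num_heads % tp_size == 0
        num_heads = total_num_heads // tp_size          # query heads per rank
        if total_num_kv_heads >= tp_size:
            # Enough KV heads: partition them across ranks.
            assert total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: each rank keeps a replica.
            assert tp_size % total_num_kv_heads == 0
        num_kv_heads = max(1, total_num_kv_heads // tp_size)
        return num_heads, num_kv_heads

    print(partition_heads(16, 16, tp_size=4))  # (4, 4): both partitioned
    print(partition_heads(32, 4, tp_size=8))   # (4, 1): KV heads replicated per rank
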
sglang/srt/models/qwen.py CHANGED
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -252,7 +251,6 @@ class QWenLMHeadModel(nn.Module):
          vocab_size = ((config.vocab_size + 63) // 64) * 64
          self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -262,11 +260,9 @@ class QWenLMHeadModel(nn.Module):
          input_metadata: InputMetadata,
      ):
          hidden_states = self.transformer(input_ids, positions, input_metadata)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [

sglang/srt/models/qwen2.py CHANGED
@@ -40,7 +40,6 @@ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

  Qwen2Config = None
@@ -277,7 +276,6 @@ class Qwen2ForCausalLM(nn.Module):
          self.model = Qwen2Model(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()
          self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

      @torch.no_grad()
@@ -291,11 +289,9 @@ class Qwen2ForCausalLM(nn.Module):
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
          if not get_embedding:
-             logits_output = self.logits_processor(
+             return self.logits_processor(
                  input_ids, hidden_states, self.lm_head.weight, input_metadata
              )
-             sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-             return sample_output, logits_output
          else:
              return self.pooler(hidden_states, input_metadata)

sglang/srt/models/qwen2_moe.py CHANGED
@@ -47,7 +47,8 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
+ from sglang.srt.layers.torchao_utils import apply_torchao_config_
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -360,12 +361,12 @@ class Qwen2MoeForCausalLM(nn.Module):
          super().__init__()
          self.config = config
          self.quant_config = quant_config
+         self.torchao_config = global_server_args_dict["torchao_config"]
          self.model = Qwen2MoeModel(config, cache_config, quant_config)
          self.lm_head = ParallelLMHead(
              config.vocab_size, config.hidden_size, quant_config=quant_config
          )
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -376,11 +377,9 @@ class Qwen2MoeForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
@@ -455,5 +454,7 @@ class Qwen2MoeForCausalLM(nn.Module):
                      )
                      weight_loader(param, loaded_weight)

+         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+

  EntryClass = Qwen2MoeForCausalLM

sglang/srt/models/stablelm.py CHANGED
@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -250,7 +249,6 @@ class StableLmForCausalLM(nn.Module):
          self.model = StableLMEpochModel(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -261,11 +259,9 @@ class StableLmForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [