sglang 0.3.1__py3-none-any.whl → 0.3.1.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. sglang/bench_latency.py +10 -3
  2. sglang/bench_server_latency.py +187 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/global_config.py +5 -13
  5. sglang/lang/interpreter.py +0 -3
  6. sglang/srt/constrained/fsm_cache.py +5 -1
  7. sglang/srt/layers/activation.py +16 -1
  8. sglang/srt/layers/attention_backend.py +12 -12
  9. sglang/srt/layers/fused_moe/layer.py +27 -7
  10. sglang/srt/layers/layernorm.py +21 -6
  11. sglang/srt/layers/sampler.py +40 -98
  12. sglang/srt/lora/lora_manager.py +11 -8
  13. sglang/srt/managers/io_struct.py +3 -0
  14. sglang/srt/managers/policy_scheduler.py +49 -93
  15. sglang/srt/managers/schedule_batch.py +2 -1
  16. sglang/srt/managers/tp_worker.py +19 -13
  17. sglang/srt/model_executor/cuda_graph_runner.py +25 -13
  18. sglang/srt/model_executor/model_runner.py +37 -46
  19. sglang/srt/models/deepseek_v2.py +8 -3
  20. sglang/srt/models/llama.py +1 -3
  21. sglang/srt/models/llama_classification.py +2 -3
  22. sglang/srt/models/minicpm3.py +7 -3
  23. sglang/srt/models/olmoe.py +415 -0
  24. sglang/srt/models/xverse.py +1 -3
  25. sglang/srt/models/xverse_moe.py +1 -4
  26. sglang/srt/sampling/sampling_batch_info.py +3 -50
  27. sglang/srt/server.py +6 -1
  28. sglang/srt/server_args.py +39 -10
  29. sglang/srt/utils.py +7 -51
  30. sglang/test/few_shot_gsm8k.py +8 -2
  31. sglang/test/test_utils.py +1 -1
  32. sglang/version.py +1 -1
  33. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/METADATA +4 -5
  34. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/RECORD +37 -35
  35. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/WHEEL +1 -1
  36. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/LICENSE +0 -0
  37. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py CHANGED
@@ -19,7 +19,6 @@ limitations under the License.
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -48,6 +47,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import bmm_fp8
 
 
 class DeepseekV2MLP(nn.Module):
@@ -503,7 +507,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        if global_server_args_dict["enable_mla"]:
+        if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
                 hidden_size=self.hidden_size,
@@ -649,6 +653,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config)
 
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -727,7 +732,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)
 
-        if global_server_args_dict["enable_mla"]:
+        if not global_server_args_dict["disable_mla"]:
            for layer_id in range(self.config.num_hidden_layers):
                self_attn = self.model.layers[layer_id].self_attn
                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
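
Two patterns above recur in this release: flashinfer's `bmm_fp8` kernel is only imported on CUDA builds, and the MLA toggle is inverted so that `global_server_args_dict["disable_mla"]` replaces `enable_mla` (the MLA attention path is now used unless explicitly disabled). A minimal, self-contained sketch of the guarded-import idea; the `_bmm_fp8_or_fail` helper is illustrative and not part of sglang:

# Sketch only: guard a CUDA-only dependency so ROCm builds can still import the module.
from sglang.srt.utils import is_hip

if not is_hip():
    from flashinfer import bmm_fp8  # CUDA builds: real flashinfer kernel
else:
    bmm_fp8 = None  # ROCm builds: flashinfer support expected later


def _bmm_fp8_or_fail(*args, **kwargs):
    # Illustrative helper (not an sglang API): fail with a clear message if the
    # FP8 batched-matmul path is reached on a build without flashinfer.
    if bmm_fp8 is None:
        raise RuntimeError("bmm_fp8 needs flashinfer, which is unavailable on ROCm builds")
    return bmm_fp8(*args, **kwargs)
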
sglang/srt/models/llama.py CHANGED
@@ -305,8 +305,6 @@ class LlamaForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
-        self.param_dict = dict(self.named_parameters())
-
     @torch.no_grad()
     def forward(
         self,
@@ -374,7 +372,7 @@ class LlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        params_dict = self.param_dict
+        params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
sglang/srt/models/llama_classification.py CHANGED
@@ -36,6 +36,7 @@ class LlamaForClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+        self.torchao_config = None
         self.quant_config = quant_config
         self.model = LlamaModel(config, quant_config=quant_config)
 
@@ -44,8 +45,6 @@ class LlamaForClassification(nn.Module):
         )
         self.eos_token_id = config.eos_token_id
 
-        self.param_dict = dict(self.named_parameters())
-
     @torch.no_grad()
     def forward(
         self,
@@ -77,7 +76,7 @@ class LlamaForClassification(nn.Module):
         return logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        params_dict = self.param_dict
+        params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
             if "classification_head" in name:
sglang/srt/models/minicpm3.py CHANGED
@@ -19,7 +19,6 @@ import math
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -44,6 +43,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import bmm_fp8
 
 
 class MiniCPM3MLP(nn.Module):
@@ -415,7 +419,7 @@ class MiniCPM3DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        if global_server_args_dict["enable_mla"]:
+        if not global_server_args_dict["disable_mla"]:
             self.self_attn = MiniCPM3AttentionMLA(
                 config=config,
                 hidden_size=self.hidden_size,
@@ -649,7 +653,7 @@ class MiniCPM3ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)
 
-        if global_server_args_dict["enable_mla"]:
+        if not global_server_args_dict["disable_mla"]:
            for layer_id in range(self.config.num_hidden_layers):
                self_attn = self.model.layers[layer_id].self_attn
                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
sglang/srt/models/olmoe.py ADDED
@@ -0,0 +1,415 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+# Adapted from:
+# https://github.com/vllm-project/vllm/pull/7922
+
+"""Inference-only OLMoE model compatible with HuggingFace weights."""
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.utils import print_warning_once
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
+
+
+class OlmoeMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Olmoe that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+
+        # Gate always runs at half / full precision for now.
+        self.gate = ReplicatedLinear(
+            hidden_size, num_experts, bias=False, quant_config=None
+        )
+
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            reduce_results=True,
+            renormalize=False,
+            quant_config=quant_config,
+            tp_size=tp_size,
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+        return final_hidden_states.view(orig_shape)
+
+
+class OlmoeAttention(nn.Module):
+
+    def __init__(
+        self,
+        layer_id: int,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 4096,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.q_norm = RMSNorm(hidden_size, eps=1e-5)
+        self.k_norm = RMSNorm(hidden_size, eps=1e-5)
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=True,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            layer_id=layer_id,
+            num_kv_heads=self.num_kv_heads,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class OlmoeDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
+
+        self.self_attn = OlmoeAttention(
+            layer_id,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+        )
+
+        self.mlp = OlmoeMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class OlmoeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                OlmoeDecoderLayer(config, layer_id, quant_config=quant_config)
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions, hidden_states, input_metadata, residual
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class OlmoeForCausalLM(nn.Module):
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = OlmoeModel(config, quant_config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, quant_config=quant_config
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # Remapping the name of FP8 kv-scale.
+                    if name.endswith("kv_scale"):
+                        remapped_kv_scale_name = name.replace(
+                            ".kv_scale", ".attn.kv_scale"
+                        )
+                        if remapped_kv_scale_name not in params_dict:
+                            print_warning_once(
+                                "Found kv scale in the checkpoint "
+                                f"(e.g. {name}), but not found the expected "
+                                f"name in the model "
+                                f"(e.g. {remapped_kv_scale_name}). "
+                                "kv-scale is not loaded."
+                            )
+                            continue
+                        else:
+                            name = remapped_kv_scale_name
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = OlmoeForCausalLM
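
`EntryClass = OlmoeForCausalLM` is how sglang's model loader discovers the new architecture, so OLMoE checkpoints can be served like any other supported model. A hedged usage sketch of the frontend Runtime API follows; the checkpoint id `allenai/OLMoE-1B-7B-0924`, the prompt, and the sampling settings are illustrative assumptions, not taken from this diff:

# Sketch only: serve an OLMoE checkpoint through sglang's offline Runtime API.
import sglang as sgl

# Assumed checkpoint id; any OLMoE-architecture model should load the same way.
runtime = sgl.Runtime(model_path="allenai/OLMoE-1B-7B-0924")
sgl.set_default_backend(runtime)


@sgl.function
def ask(s, question):
    s += "Q: " + question + "\nA:"
    s += sgl.gen("answer", max_tokens=64)


state = ask.run(question="What is a mixture-of-experts model?")
print(state["answer"])
runtime.shutdown()
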
sglang/srt/models/xverse.py CHANGED
@@ -307,8 +307,6 @@ class XverseForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
-        self.param_dict = dict(self.named_parameters())
-
     @torch.no_grad()
     def forward(
         self,
@@ -333,7 +331,7 @@ class XverseForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        params_dict = self.param_dict
+        params_dict = dict(self.named_parameters())
 
         def load_weights_per_param(name, loaded_weight):
             if "rotary_emb.inv_freq" in name or "projector" in name:
sglang/srt/models/xverse_moe.py CHANGED
@@ -383,8 +383,6 @@ class XverseMoeForCausalLM(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config)
 
-        self.param_dict = dict(self.named_parameters())
-
     @torch.no_grad()
     def forward(
         self,
@@ -406,8 +404,7 @@ class XverseMoeForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-
-        params_dict = self.param_dict
+        params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -34,56 +34,6 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None
 
-    def __len__(self):
-        return len(self.temperatures)
-
-    def can_run_in_cuda_graph(self):
-        # Vocab bias and min_ps are not supported in CUDA graph
-        return (
-            self.logit_bias is None
-            and self.linear_penalties is None
-            and self.scaling_penalties is None
-            and not self.need_min_p_sampling
-        )
-
-    @classmethod
-    def dummy_one(cls, max_bs: int, vocab_size: int):
-        ret = cls(vocab_size=vocab_size)
-        with torch.device("cuda"):
-            ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float)
-            ret.top_ps = torch.ones((max_bs,), dtype=torch.float)
-            ret.top_ks = torch.ones((max_bs,), dtype=torch.int)
-            ret.vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)
-        return ret
-
-    def __getitem__(self, key):
-        if isinstance(key, slice):
-            # NOTE:This method is only used in CUDA graph
-            assert self.can_run_in_cuda_graph()
-            return SamplingBatchInfo(
-                vocab_size=self.vocab_size,
-                temperatures=self.temperatures[key],
-                top_ps=self.top_ps[key],
-                top_ks=self.top_ks[key],
-                vocab_mask=self.vocab_mask[key],
-            )
-        else:
-            raise NotImplementedError
-
-    def inplace_assign(self, bs: int, other: SamplingBatchInfo):
-        # NOTE:This method is only used in CUDA graph
-        assert self.can_run_in_cuda_graph()
-
-        self.vocab_size = other.vocab_size
-        self.temperatures[:bs] = other.temperatures
-        self.top_ps[:bs] = other.top_ps
-        self.top_ks[:bs] = other.top_ks
-
-        if other.vocab_mask is None:
-            self.vocab_mask[:bs].fill_(False)
-        else:
-            self.vocab_mask[:bs] = other.vocab_mask
-
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
@@ -130,6 +80,9 @@ class SamplingBatchInfo:
 
         return ret
 
+    def __len__(self):
+        return len(self.temperatures)
+
     def update_penalties(self):
         self.scaling_penalties = None
         self.linear_penalties = None
sglang/srt/server.py CHANGED
@@ -78,6 +78,7 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     enable_show_time_cost,
+    is_hip,
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model,
@@ -152,7 +153,7 @@ async def flush_cache():
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
 
     success, message = await tokenizer_manager.update_weights(obj, request)
-    content = {"message": message, "success": str(success)}
+    content = {"success": success, "message": message}
     if success:
         return JSONResponse(
             content,
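
With this change the weight-update response carries `success` as a JSON boolean rather than `str(success)`, which was truthy for clients even when it held "False". A minimal client-side sketch, assuming the `/update_weights` route on the default port and a `model_path` request field (both assumptions, not confirmed by this diff):

# Sketch only: check the corrected boolean "success" field from the update endpoint.
import requests

resp = requests.post(
    "http://localhost:30000/update_weights",              # assumed default host/port
    json={"model_path": "/path/to/updated/checkpoint"},   # assumed request field
)
payload = resp.json()
if payload["success"]:  # a real boolean now, not the string "True"/"False"
    print("weights updated:", payload["message"])
else:
    print("update failed:", payload["message"])
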
@@ -434,6 +435,10 @@ def _set_envs_and_config(server_args: ServerArgs):
             "at https://docs.flashinfer.ai/installation.html.",
         )
 
+    if is_hip():
+        # to figure out a better method of not using fork later
+        mp.set_start_method("spawn", force=True)
+
 
 def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
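
Per its in-line comment, the ROCm branch forces the `spawn` start method to avoid fork-based workers for now; forked children inheriting an already-initialized GPU runtime is a common source of hangs. A standalone sketch of the same guard, using only the standard `multiprocessing` API (the function name is illustrative); note that `spawn` requires the usual `if __name__ == "__main__":` entry point:

# Sketch only: force the "spawn" start method on ROCm builds before creating workers.
import multiprocessing as mp

from sglang.srt.utils import is_hip


def configure_worker_start_method():
    if is_hip():
        # Spawned children start with a fresh interpreter instead of inheriting
        # the parent's GPU runtime state, which fork can leave inconsistent.
        mp.set_start_method("spawn", force=True)


if __name__ == "__main__":
    configure_worker_start_method()
    # ... create worker processes here ...
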