sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +0 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  12. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  13. sglang/srt/constrained/xgrammar_backend.py +26 -4
  14. sglang/srt/custom_op.py +0 -62
  15. sglang/srt/disaggregation/decode.py +62 -6
  16. sglang/srt/disaggregation/mini_lb.py +5 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +32 -62
  18. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  19. sglang/srt/disaggregation/prefill.py +40 -4
  20. sglang/srt/disaggregation/utils.py +15 -0
  21. sglang/srt/entrypoints/verl_engine.py +7 -5
  22. sglang/srt/layers/activation.py +6 -8
  23. sglang/srt/layers/attention/flashattention_backend.py +114 -71
  24. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  25. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  26. sglang/srt/layers/attention/triton_backend.py +6 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  28. sglang/srt/layers/layernorm.py +1 -1
  29. sglang/srt/layers/linear.py +17 -3
  30. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  31. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  34. sglang/srt/layers/moe/topk.py +27 -30
  35. sglang/srt/layers/parameter.py +0 -2
  36. sglang/srt/layers/quantization/__init__.py +1 -0
  37. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  38. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
  39. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  40. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  41. sglang/srt/layers/quantization/fp8.py +115 -132
  42. sglang/srt/layers/quantization/fp8_kernel.py +213 -57
  43. sglang/srt/layers/quantization/fp8_utils.py +187 -262
  44. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  45. sglang/srt/layers/quantization/utils.py +5 -11
  46. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  47. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  48. sglang/srt/layers/radix_attention.py +15 -0
  49. sglang/srt/layers/rotary_embedding.py +3 -2
  50. sglang/srt/layers/sampler.py +5 -10
  51. sglang/srt/lora/backend/base_backend.py +18 -2
  52. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  53. sglang/srt/lora/backend/triton_backend.py +1 -1
  54. sglang/srt/lora/layers.py +1 -1
  55. sglang/srt/lora/lora.py +1 -1
  56. sglang/srt/lora/lora_manager.py +1 -1
  57. sglang/srt/managers/detokenizer_manager.py +0 -1
  58. sglang/srt/managers/io_struct.py +1 -0
  59. sglang/srt/managers/mm_utils.py +4 -3
  60. sglang/srt/managers/multimodal_processor.py +0 -2
  61. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  62. sglang/srt/managers/schedule_batch.py +2 -4
  63. sglang/srt/managers/scheduler.py +12 -71
  64. sglang/srt/managers/tokenizer_manager.py +1 -0
  65. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  66. sglang/srt/mem_cache/memory_pool.py +7 -2
  67. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  68. sglang/srt/model_executor/model_runner.py +20 -27
  69. sglang/srt/models/bert.py +398 -0
  70. sglang/srt/models/deepseek.py +1 -1
  71. sglang/srt/models/deepseek_nextn.py +74 -70
  72. sglang/srt/models/deepseek_v2.py +289 -348
  73. sglang/srt/models/llama.py +5 -5
  74. sglang/srt/models/minicpm3.py +29 -201
  75. sglang/srt/models/qwen2.py +4 -1
  76. sglang/srt/models/qwen2_moe.py +14 -13
  77. sglang/srt/models/qwen3.py +335 -0
  78. sglang/srt/models/qwen3_moe.py +423 -0
  79. sglang/srt/reasoning_parser.py +0 -1
  80. sglang/srt/sampling/sampling_batch_info.py +2 -3
  81. sglang/srt/server_args.py +34 -32
  82. sglang/srt/speculative/eagle_worker.py +4 -7
  83. sglang/srt/utils.py +16 -1
  84. sglang/test/runners.py +5 -1
  85. sglang/test/test_block_fp8.py +167 -0
  86. sglang/test/test_custom_ops.py +1 -1
  87. sglang/version.py +1 -1
  88. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
  89. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
  90. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  91. sglang/lang/__init__.py +0 -0
  92. sglang/srt/lora/backend/__init__.py +0 -25
  93. sglang/srt/server.py +0 -18
  94. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  95. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_moe.py ADDED
@@ -0,0 +1,423 @@
+ # Adapted from qwen2_moe.py
+
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
+
+ from functools import partial
+ from typing import Any, Dict, Iterable, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from sglang.srt.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     split_tensor_along_last_dim,
+     tensor_model_parallel_all_gather,
+     tensor_model_parallel_all_reduce,
+ )
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
+ from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+ from sglang.srt.utils import add_prefix
+
+ Qwen3MoeConfig = None
+
+
+ class Qwen3MoeSparseMoeBlock(nn.Module):
+     def __init__(
+         self,
+         config: Qwen3MoeConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.tp_size = get_tensor_model_parallel_world_size()
+
+         if self.tp_size > config.num_experts:
+             raise ValueError(
+                 f"Tensor parallel size {self.tp_size} is greater than "
+                 f"the number of experts {config.num_experts}."
+             )
+
+         self.experts = FusedMoE(
+             num_experts=config.num_experts,
+             top_k=config.num_experts_per_tok,
+             hidden_size=config.hidden_size,
+             intermediate_size=config.moe_intermediate_size,
+             reduce_results=False,
+             renormalize=config.norm_topk_prob,
+             quant_config=quant_config,
+             prefix=add_prefix("experts", prefix),
+         )
+
+         self.gate = ReplicatedLinear(
+             config.hidden_size,
+             config.num_experts,
+             bias=False,
+             quant_config=None,
+             prefix=add_prefix("gate", prefix),
+         )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         num_tokens, hidden_dim = hidden_states.shape
+         hidden_states = hidden_states.view(-1, hidden_dim)
+
+         # router_logits: (num_tokens, n_experts)
+         router_logits, _ = self.gate(hidden_states)
+         final_hidden_states = self.experts(
+             hidden_states=hidden_states, router_logits=router_logits
+         )
+         if self.tp_size > 1:
+             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+         return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+ class Qwen3MoeAttention(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         layer_id: int = 0,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         max_position_embeddings: int = 8192,
+         head_dim: Optional[int] = None,
+         rms_norm_eps: float = 1e-06,
+         attention_bias: bool = False,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % self.tp_size == 0
+         self.num_heads = self.total_num_heads // self.tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= self.tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % self.tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+         self.head_dim = head_dim or hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+         self.tp_rank = get_tensor_model_parallel_rank()
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=attention_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("qkv_proj", prefix),
+         )
+
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=attention_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("o_proj", prefix),
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             prefix=add_prefix("attn", prefix),
+         )
+
+         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+     def _apply_qk_norm(
+         self, q: torch.Tensor, k: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         q_by_head = q.reshape(-1, self.head_dim)
+         q_by_head = self.q_norm(q_by_head)
+         q = q_by_head.view(q.shape)
+         k_by_head = k.reshape(-1, self.head_dim)
+         k_by_head = self.k_norm(k_by_head)
+         k = k_by_head.view(k.shape)
+         return q, k
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self._apply_qk_norm(q, k)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, forward_batch)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class Qwen3MoeDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: Qwen3MoeConfig,
+         layer_id: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 10000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         head_dim = getattr(
+             config, "head_dim", config.hidden_size // config.num_attention_heads
+         )
+         rms_norm_eps = config.rms_norm_eps
+         attention_bias = config.attention_bias
+         self.self_attn = Qwen3MoeAttention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             layer_id=layer_id,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             max_position_embeddings=max_position_embeddings,
+             head_dim=head_dim,
+             rms_norm_eps=rms_norm_eps,
+             attention_bias=attention_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("self_attn", prefix),
+         )
+
+         # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
+         # `mlp_only_layers` in the config.
+         mlp_only_layers = (
+             [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
+         )
+         if (layer_id not in mlp_only_layers) and (
+             config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
+         ):
+             self.mlp = Qwen3MoeSparseMoeBlock(
+                 config=config,
+                 quant_config=quant_config,
+                 prefix=add_prefix("mlp", prefix),
+             )
+         else:
+             self.mlp = Qwen3MoeMLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=config.intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+                 prefix=add_prefix("mlp", prefix),
+             )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+         residual: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
+
+
+ class Qwen3MoeModel(Qwen2MoeModel):
+     def __init__(
+         self,
+         config: Qwen3MoeConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__(
+             config=config,
+             quant_config=quant_config,
+             prefix=prefix,
+             decoder_layer_type=Qwen3MoeDecoderLayer,
+         )
+
+
+ class Qwen3MoeForCausalLM(nn.Module):
+
+     fall_back_to_pt_during_load = False
+
+     def __init__(
+         self,
+         config: Qwen3MoeConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = Qwen3MoeModel(
+             config, quant_config, prefix=add_prefix("model", prefix)
+         )
+         self.lm_head = ParallelLMHead(
+             config.vocab_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             prefix=add_prefix("lm_head", prefix),
+         )
+         self.logits_processor = LogitsProcessor(config)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head, forward_batch
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         expert_params_mapping = FusedMoE.make_expert_params_mapping(
+             ckpt_gate_proj_name="gate_proj",
+             ckpt_down_proj_name="down_proj",
+             ckpt_up_proj_name="up_proj",
+             num_experts=self.config.num_experts,
+         )
+
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 # Skip non-stacked layers and experts (experts handled below).
+                 if weight_name not in name:
+                     continue
+                 # We have mlp.experts[0].gate_proj in the checkpoint.
+                 # Since we handle the experts below in expert_params_mapping,
+                 # we need to skip here BEFORE we update the name, otherwise
+                 # name will be updated to mlp.experts[0].gate_up_proj, which
+                 # will then be updated below in expert_params_mapping
+                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                 if "mlp.experts" in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name not in params_dict:
+                     continue
+
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+                     param = params_dict[name]
+                     weight_loader = param.weight_loader
+                     weight_loader(
+                         param,
+                         loaded_weight,
+                         name,
+                         shard_id=shard_id,
+                         expert_id=expert_id,
+                     )
+                     break
+                 else:
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+                     if name not in params_dict:
+                         continue
+
+                     param = params_dict[name]
+                     weight_loader = getattr(
+                         param, "weight_loader", default_weight_loader
+                     )
+                     weight_loader(param, loaded_weight)
+
+
+ EntryClass = Qwen3MoeForCausalLM
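The `EntryClass` line above is what registers the new architecture with sglang's model loader. As a minimal, hypothetical usage sketch (the checkpoint id below is a placeholder; only the `Engine` import path is confirmed by this diff, via the `sglang/test/runners.py` change further down), a checkpoint whose config resolves to `Qwen3MoeForCausalLM` would be served through the offline engine like this:

from sglang.srt.entrypoints.engine import Engine

if __name__ == "__main__":
    # Placeholder checkpoint id; any config that maps to Qwen3MoeForCausalLM
    # is dispatched to the new model file above.
    engine = Engine(model_path="org/qwen3-moe-checkpoint", tp_size=2)
    print(engine.generate("Describe mixture-of-experts routing in one sentence."))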
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -1,4 +1,3 @@
- import re
  from typing import Dict, Tuple
 
 
@@ -10,12 +10,11 @@ import torch
  import sglang.srt.sampling.penaltylib as penaltylib
  from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 
- logger = logging.getLogger(__name__)
-
-
  if TYPE_CHECKING:
      from sglang.srt.managers.schedule_batch import ScheduleBatch
 
+ logger = logging.getLogger(__name__)
+
 
  @dataclasses.dataclass
  class SamplingBatchInfo:
sglang/srt/server_args.py CHANGED
@@ -155,7 +155,6 @@ class ServerArgs:
      enable_nccl_nvls: bool = False
      disable_outlines_disk_cache: bool = False
      disable_custom_all_reduce: bool = False
-     disable_mla: bool = False
      enable_llama4_multimodal: Optional[bool] = None
      disable_overlap_schedule: bool = False
      enable_mixed_chunk: bool = False
@@ -180,13 +179,12 @@
      tool_call_parser: Optional[str] = None
      enable_hierarchical_cache: bool = False
      hicache_ratio: float = 2.0
-     enable_flashinfer_mla: bool = False  # TODO: remove this argument
-     enable_flashmla: bool = False
      flashinfer_mla_disable_ragged: bool = False
      warmups: Optional[str] = None
+     moe_dense_tp_size: Optional[int] = None
      n_share_experts_fusion: int = 0
-     disable_shared_experts_fusion: bool = False
      disable_chunked_prefix_cache: bool = False
+     disable_fast_image_processor: bool = False
 
      # Debug tensor dumps
      debug_tensor_dump_output_folder: Optional[str] = None
@@ -197,9 +195,7 @@
      disaggregation_mode: str = "null"
      disaggregation_bootstrap_port: int = 8998
      disaggregation_transfer_backend: str = "mooncake"
-
-     # multimodal
-     disable_fast_image_processor: bool = False
+     disaggregation_ib_device: Optional[str] = None
 
      def __post_init__(self):
          # Expert parallelism
@@ -232,9 +228,6 @@
              # GPU memory is not known yet or no GPU is available.
              gpu_mem = None
 
-         if is_hip():
-             self.disable_shared_experts_fusion = True
-
          # Set mem fraction static, which depends on the tensor parallelism size
          if self.mem_fraction_static is None:
              if self.tp_size >= 16:
@@ -257,7 +250,12 @@
 
          assert self.chunked_prefill_size % self.page_size == 0
 
-         if self.enable_flashmla is True:
+         assert self.moe_dense_tp_size in {
+             1,
+             None,
+         }, f"moe_dense_tp_size only support 1 and None currently"
+
+         if self.attention_backend == "flashmla":
              logger.warning(
                  "FlashMLA only supports a page_size of 64, change page_size to 64."
              )
@@ -394,6 +392,10 @@
          os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
              "1" if self.enable_torch_compile else "0"
          )
+         # Set env var before grammar backends init
+         os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
+             "1" if self.disable_outlines_disk_cache else "0"
+         )
 
      @staticmethod
      def add_cli_args(parser: argparse.ArgumentParser):
@@ -826,7 +828,7 @@
          parser.add_argument(
              "--attention-backend",
              type=str,
-             choices=["flashinfer", "triton", "torch_native", "fa3"],
+             choices=["flashinfer", "triton", "torch_native", "fa3", "flashmla"],
              default=ServerArgs.attention_backend,
              help="Choose the kernels for attention layers.",
          )
@@ -846,13 +848,13 @@
          )
          parser.add_argument(
              "--enable-flashinfer-mla",
-             action="store_true",
-             help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!",
+             action=DeprecatedAction,
+             help="--enable-flashinfer-mla is deprecated. Please use '--attention-backend flashinfer' instead.",
          )
          parser.add_argument(
              "--enable-flashmla",
-             action="store_true",
-             help="Enable FlashMLA decode optimization",
+             action=DeprecatedAction,
+             help="--enable-flashmla is deprecated. Please use '--attention-backend flashmla' instead.",
          )
          parser.add_argument(
              "--flashinfer-mla-disable-ragged",
@@ -977,11 +979,6 @@
              action="store_true",
              help="Disable the custom all-reduce kernel and fall back to NCCL.",
          )
-         parser.add_argument(
-             "--disable-mla",
-             action="store_true",
-             help="Disable Multi-head Latent Attention (MLA) for DeepSeek V2/V3/R1 series models.",
-         )
          parser.add_argument(
              "--enable-llama4-multimodal",
              default=ServerArgs.enable_llama4_multimodal,
@@ -1111,6 +1108,12 @@
              action="store_true",
              help="Enabling DeepEP MoE implementation for EP MoE.",
          )
+         parser.add_argument(
+             "--moe-dense-tp-size",
+             type=int,
+             default=ServerArgs.moe_dense_tp_size,
+             help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
+         )
          parser.add_argument(
              "--deepep-mode",
              type=str,
@@ -1123,18 +1126,18 @@
              "--n-share-experts-fusion",
              type=int,
              default=0,
-             help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 "
-             "we use tp_size by default.",
+             help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
+             "set it to tp_size can get best optimized performace.",
          )
          parser.add_argument(
-             "--disable-shared-experts-fusion",
+             "--disable-chunked-prefix-cache",
              action="store_true",
-             help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
+             help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
          )
          parser.add_argument(
-             "--disable-chunked-prefix-cache",
+             "--disable-fast-image-processor",
              action="store_true",
-             help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
+             help="Adopt base image processor instead of fast image processor.",
          )
 
          # Server warmups
@@ -1186,12 +1189,11 @@
              default=ServerArgs.disaggregation_transfer_backend,
              help="The backend for disaggregation transfer. Default is mooncake.",
          )
-
-         # Multimodal
          parser.add_argument(
-             "--disable-fast-image-processor",
-             action="store_true",
-             help="Adopt base image processor instead of fast image processor.",
+             "--disaggregation-ib-device",
+             type=str,
+             default=ServerArgs.disaggregation_ib_device,
+             help="The ib device for disaggregation transfer. Default is None, it will be detected automatically if using the mooncake backend.",
          )
 
      @classmethod
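Taken together, the server_args.py changes fold the old boolean MLA switches into `--attention-backend` and introduce `--moe-dense-tp-size` and `--disaggregation-ib-device`. A rough sketch of the new surface through the Python dataclass (the model path is a placeholder, and whether FlashMLA is actually usable still depends on the installed kernels):

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="org/placeholder-model",   # placeholder checkpoint id
    attention_backend="flashmla",         # replaces the deprecated --enable-flashmla
    moe_dense_tp_size=1,                  # new field; only 1 or None pass the assert above
    disaggregation_ib_device=None,        # new field; auto-detected for the mooncake backend
)
# Per the warning shown in the diff, __post_init__ moves page_size to 64
# when the flashmla backend is selected.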
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -271,14 +271,11 @@ class EAGLEWorker(TpModelWorker):
              )
          elif batch.forward_mode.is_idle():
              model_worker_batch = batch.get_model_worker_batch()
-             logits_output, next_token_ids, _ = (
-                 self.target_worker.forward_batch_generation(
-                     ForwardBatch.init_new(
-                         model_worker_batch, self.target_worker.model_runner
-                     )
-                 )
+             logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+                 model_worker_batch
              )
-             return logits_output, next_token_ids, model_worker_batch.bid, 0, False
+
+             return logits_output, next_token_ids, model_worker_batch.bid, 0
          else:
              logits_output, next_token_ids, bid = self.forward_target_extend(batch)
              with self.draft_tp_context(self.draft_model_runner.tp_group):
sglang/srt/utils.py CHANGED
@@ -55,7 +55,6 @@ import torch.distributed
  import torch.distributed as dist
  import triton
  import zmq
- from decord import VideoReader, cpu
  from fastapi.responses import ORJSONResponse
  from packaging import version as pkg_version
  from PIL import Image
@@ -545,6 +544,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
 
 
  def encode_video(video_path, frame_count_limit=None):
+     # Lazy import because decord is not available on some arm platforms.
+     from decord import VideoReader, cpu
+
      if not os.path.exists(video_path):
          logger.error(f"Video {video_path} does not exist")
          return []
@@ -1930,3 +1932,16 @@ def is_fa3_default_architecture(hf_config):
          "MistralForCausalLM",
      }
      return architectures[0] in default_archs
+
+
+ # Can be more general if it is used in multiple places (keep it simple and thus not general now)
+ class BumpAllocator:
+     def __init__(self, buffer_size: int, dtype, device):
+         self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
+         self._pointer = 0
+
+     def allocate(self, size: int):
+         assert self._pointer + size <= len(self._buffer)
+         output = self._buffer[self._pointer : self._pointer + size]
+         self._pointer += size
+         return output
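A short usage sketch (not part of the diff) for the new `BumpAllocator` helper added above: it hands out consecutive views of a single preallocated, zero-initialized buffer and never frees them individually.

import torch

from sglang.srt.utils import BumpAllocator

allocator = BumpAllocator(buffer_size=8, dtype=torch.int64, device="cpu")
a = allocator.allocate(3)  # view over elements 0..2 of the shared buffer
b = allocator.allocate(5)  # view over elements 3..7
assert a.numel() == 3 and b.numel() == 5
# A further allocate() call would exceed buffer_size and trip the assert.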
sglang/test/runners.py CHANGED
@@ -26,8 +26,8 @@ from transformers import (
      AutoProcessor,
  )
 
+ from sglang.srt.entrypoints.engine import Engine
  from sglang.srt.hf_transformers_utils import get_tokenizer
- from sglang.srt.server import Engine
  from sglang.srt.utils import load_image
  from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
 
@@ -51,6 +51,8 @@ NUM_TOP_LOGPROBS = 5
  def get_dtype_str(torch_dtype):
      if torch_dtype is torch.float16:
          return "float16"
+     if torch_dtype is torch.float32:
+         return "float32"
      else:
          raise NotImplementedError()
 
@@ -447,6 +449,7 @@
          port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
          lora_paths: List[str] = None,
          max_loras_per_batch: int = 4,
+         attention_backend: Optional[str] = None,
          lora_backend: str = "triton",
          disable_cuda_graph: bool = False,
          disable_radix_cache: bool = False,
@@ -487,6 +490,7 @@
              lora_paths=lora_paths,
              max_loras_per_batch=max_loras_per_batch,
              lora_backend=lora_backend,
+             attention_backend=attention_backend,
              disable_cuda_graph=disable_cuda_graph,
              disable_radix_cache=disable_radix_cache,
              chunked_prefill_size=chunked_prefill_size,
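The runners.py change threads a new `attention_backend` keyword from `SRTRunner` into the engine so a test can pin the kernel choice. A hedged sketch of how that keyword would be used (the model id is a placeholder and actually constructing the runner launches an engine, which needs a working GPU environment):

import torch

from sglang.test.runners import SRTRunner

runner = SRTRunner(
    "org/placeholder-model",     # placeholder checkpoint id
    torch_dtype=torch.float16,
    model_type="generation",
    attention_backend="triton",  # new keyword added in this diff
)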