sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
@@ -6,15 +6,50 @@ from typing import TYPE_CHECKING, Optional
 
 import torch
 from sgl_kernel import gelu_and_mul, silu_and_mul
-from triton_kernels.matmul_ogs import matmul_ogs
+from triton_kernels.matmul_ogs import (
+    FlexCtx,
+    FnSpecs,
+    FusedActivation,
+    PrecisionConfig,
+    matmul_ogs,
+)
+from triton_kernels.numerics import InFlexData
 from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
-
-from sglang.srt.utils import direct_register_custom_op
+from triton_kernels.swiglu import swiglu_fn
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
 
 
+def quantize(w, dtype, dev, **opt):
+    if dtype == "bf16":
+        return w.to(torch.bfloat16), InFlexData()
+    elif dtype == "fp8":
+        wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
+        return (
+            wq,
+            InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)),
+            MicroscalingCtx(),
+        )
+    else:
+        assert dtype == "mx4", f"{dtype=}"
+        swizzle_mx_scale = opt["swizzle_mx_scale"]
+        swizzle_axis = 2 if swizzle_mx_scale else None
+        w = w.to(torch.bfloat16)
+        w, mx_scales, weight_scale_shape = downcast_to_mxfp(
+            w, torch.uint8, axis=1, swizzle_axis=swizzle_axis
+        )
+        return (
+            w,
+            InFlexData(),
+            MicroscalingCtx(
+                weight_scale=mx_scales,
+                swizzle_mx=swizzle_mx_scale,
+                actual_weight_scale_shape=weight_scale_shape,
+            ),
+        )
+
+
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
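
The fp8 branch of the new `quantize` helper round-trips the weight through `transpose(-1, -2).contiguous().transpose(-1, -2)` before handing it to the kernel. A minimal torch-only sketch of what that does (toy shapes, no triton_kernels dependency; the real code additionally builds `InFlexData`/`MicroscalingCtx` metadata):

```python
import torch

# Toy stand-in for one expert's weight matrix.
w = torch.randn(128, 256, dtype=torch.bfloat16)
print(w.stride())   # (256, 1): row-major, last dim contiguous

# transpose -> contiguous -> transpose keeps the logical shape but
# re-materializes the data column-major, so each output column is
# contiguous in memory for the matmul kernel.
wq = w.transpose(-1, -2).contiguous().transpose(-1, -2)
print(wq.shape)     # torch.Size([128, 256]): unchanged
print(wq.stride())  # (1, 128): column-major
```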
@@ -146,3 +181,153 @@ def triton_kernel_fused_experts(
     )
 
     return intermediate_cache3
+
+
+def triton_kernel_moe_with_bias_forward(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w1_pcg,
+    b1: torch.Tensor,
+    w2: torch.Tensor,
+    w2_pcg,
+    b2: torch.Tensor,
+    topk_output: TopKOutput,
+    inplace: bool = False,
+    activation: str = "silu",
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[int] = None,
+) -> torch.Tensor:
+    assert topk_output.format.is_triton_kernel()
+    routing_data, gather_idx, scatter_idx = topk_output
+
+    return triton_kernel_fused_experts_with_bias(
+        hidden_states,
+        w1=w1,
+        w1_pcg=w1_pcg,
+        b1=b1,
+        w2=w2,
+        w2_pcg=w2_pcg,
+        b2=b2,
+        routing_data=routing_data,
+        gather_indx=gather_idx,
+        scatter_indx=scatter_idx,
+        inplace=inplace,
+        activation=activation,
+        use_fp8_w8a8=use_fp8_w8a8,
+        per_channel_quant=per_channel_quant,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape,
+        activation_alpha=activation_alpha,
+        swiglu_limit=swiglu_limit,
+    )
+
+
+def triton_kernel_fused_experts_with_bias(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w1_pcg,
+    b1: torch.Tensor,
+    w2: torch.Tensor,
+    w2_pcg,
+    b2: torch.Tensor,
+    routing_data: RoutingData,
+    gather_indx: GatherIndx,
+    scatter_indx: ScatterIndx,
+    inplace: bool = False,
+    activation: str = "silu",
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[int] = None,
+) -> torch.Tensor:
+    # print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
+    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant == False, "per_channel_quant is not supported"
+    assert expert_map == None, "expert_map is not supported"
+    assert w1_scale == None, "w1_scale is not supported"
+    assert w2_scale == None, "w2_scale is not supported"
+    assert a1_scale == None, "a1_scale is not supported"
+    assert a2_scale == None, "a2_scale is not supported"
+    assert block_shape == None, "block_shape is not supported"
+
+    # type check
+    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
+    for w in (w1, w2):
+        # TODO assert bf16 or mxfp4
+        # assert (w.dtype == torch.bfloat16) or check-is-mxfp4, f"w must be bfloat16 or mxfp4 {w1.dtype=}"
+        pass
+
+    # Shape check
+    assert hidden_states.ndim == 2, "hidden_states must be 2D"
+    assert (
+        hidden_states.shape[-1] == w1.shape[-2]
+    ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
+    assert (
+        w2.shape[-1] == w1.shape[1]
+    ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
+
+    # feature check
+    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+
+    E, _, _ = w1.shape
+
+    if global_num_experts == -1:
+        global_num_experts = E
+
+    # TODO maybe completely remove this branch
+    if w1.dtype == torch.bfloat16:
+        device = "cuda"
+        optg = dict()
+        w1, w1_flex = quantize(w1, "bf16", device, **optg)
+        w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))
+
+        w2, w2_flex = quantize(w2, "bf16", device, **optg)
+        w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))
+
+    act = FusedActivation(
+        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
+        (activation_alpha, swiglu_limit),
+        2,
+    )
+
+    intermediate_cache = matmul_ogs(
+        hidden_states,
+        w1,
+        b1,
+        routing_data,
+        gather_indx=gather_indx,
+        precision_config=w1_pcg,
+        gammas=None,
+        fused_activation=act,
+    )
+
+    return matmul_ogs(
+        intermediate_cache,
+        w2,
+        b2,
+        routing_data,
+        scatter_indx=scatter_indx,
+        precision_config=w2_pcg,
+        gammas=routing_data.gate_scal,
+    )
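
For orientation, the shape asserts in `triton_kernel_fused_experts_with_bias` imply a per-expert layout of roughly `w1: [E, hidden, 2 * intermediate]` and `w2: [E, intermediate, hidden]` (the fused swiglu halves the width of the first matmul's output). A torch-only sketch of those invariants with hypothetical toy sizes; this only illustrates the checks and does not call the kernel:

```python
import torch

E, hidden, intermediate, num_tokens = 4, 64, 128, 8

hidden_states = torch.randn(num_tokens, hidden, dtype=torch.bfloat16)
w1 = torch.randn(E, hidden, 2 * intermediate, dtype=torch.bfloat16)
w2 = torch.randn(E, intermediate, hidden, dtype=torch.bfloat16)

# The same invariants the wrapper asserts before dispatching to matmul_ogs.
assert hidden_states.dtype == torch.bfloat16
assert hidden_states.ndim == 2
assert hidden_states.shape[-1] == w1.shape[-2]  # tokens feed w1's input dim
assert w2.shape[-1] == w1.shape[1]              # w2 projects back to hidden
```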
sglang/srt/layers/moe/topk.py
@@ -185,8 +185,9 @@ class TopK(CustomOp):
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
         if self.use_triton_kernels:
+            # renormalize=True is equivalent to sm_first=False
             routing_data, gather_idx, scatter_idx = routing(
-                router_logits, self.top_k, self.renormalize
+                router_logits, self.top_k, sm_first=not self.renormalize
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
         else:
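
The inline comment above records that `renormalize=True` maps to `sm_first=False`: with `sm_first=False`, the softmax is applied only over the selected top-k logits, which is the same as renormalizing the selected weights to sum to 1. A torch-only sketch of the two orderings (this is not the `triton_kernels.routing` call itself):

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
top_k = 2

# sm_first=True: softmax over all experts, then pick top-k.
probs = torch.softmax(logits, dim=-1)
w_sm_first, _ = torch.topk(probs, k=top_k, dim=-1)

# sm_first=False (renormalize=True): pick top-k logits, then softmax over
# just those, so the kept weights sum to 1.
top_logits, _ = torch.topk(logits, k=top_k, dim=-1)
w_renorm = torch.softmax(top_logits, dim=-1)

print(w_sm_first.sum(dim=-1))  # < 1: mass of unselected experts is lost
print(w_renorm.sum(dim=-1))    # 1.0: renormalized over the selected experts
```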
@@ -397,8 +398,12 @@ def grouped_topk_gpu(
         .reshape(num_token, -1)
     )  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    # TODO: NPU can't support directly evaluating a comparison for now
     topk_weights, topk_ids = torch.topk(
-        tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
+        tmp_scores,
+        k=topk,
+        dim=-1,
+        sorted=(True if num_fused_shared_experts > 0 else False),
     )
     if num_fused_shared_experts:
         topk_ids[:, -1] = torch.randint(
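
The `sorted=` rewrite here (and the identical one in `biased_grouped_topk_impl` below) only changes how the flag is spelled, not what it means: `torch.topk(..., sorted=True)` returns the k results in descending order, while `sorted=False` returns the same set in arbitrary order. A toy torch-only example:

```python
import torch

scores = torch.tensor([[0.1, 0.9, 0.4, 0.7]])

vals_sorted, _ = torch.topk(scores, k=3, dim=-1, sorted=True)
vals_any, _ = torch.topk(scores, k=3, dim=-1, sorted=False)

print(vals_sorted)  # tensor([[0.9000, 0.7000, 0.4000]]): descending
# Same set of values either way; ordering is only guaranteed with sorted=True.
print(set(vals_any[0].tolist()) == set(vals_sorted[0].tolist()))  # True
```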
@@ -488,8 +493,12 @@ def biased_grouped_topk_impl(
     tmp_scores = scores_for_choice.masked_fill(
         ~score_mask.bool(), float("-inf")
     )  # [n, e]
+    # TODO: NPU can't support directly evaluating a comparison for now
     _, topk_ids = torch.topk(
-        tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
+        tmp_scores,
+        k=topk,
+        dim=-1,
+        sorted=(True if num_fused_shared_experts > 0 else False),
     )
     topk_weights = scores.gather(1, topk_ids)
 
sglang/srt/layers/moe/utils.py
@@ -1,4 +1,20 @@
+import importlib.util
 from enum import Enum
+from functools import lru_cache
+
+from packaging import version as pkg_version
+
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+
+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+    return result
 
 
 class MoeA2ABackend(Enum):
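
Because of `lru_cache(maxsize=1)`, `should_use_flashinfer_trtllm_moe()` is evaluated once and the result is reused for the life of the process. The version floor uses PEP 440 ordering, where pre-releases sort below the final release, so the "0.2.9rc1" bound also admits "0.2.9" and anything newer. A small sketch of that ordering with `packaging`:

```python
from packaging import version as pkg_version

# Pre-releases compare below the corresponding final release.
assert pkg_version.parse("0.2.9rc1") < pkg_version.parse("0.2.9")

# Anything at or above the rc floor passes the gate's comparison.
assert pkg_version.parse("0.2.9") >= pkg_version.parse("0.2.9rc1")
assert pkg_version.parse("0.3.0") >= pkg_version.parse("0.2.9rc1")
assert not (pkg_version.parse("0.2.8") >= pkg_version.parse("0.2.9rc1"))
```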
sglang/srt/layers/quantization/__init__.py
@@ -47,6 +47,12 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
+from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
+
+is_mxfp_supported = mxfp_supported()
+if is_mxfp_supported:
+    from sglang.srt.layers.quantization.fp4 import MxFp4Config
+
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import (
     GPTQConfig,
@@ -60,6 +66,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp8Config,
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
 from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
 from sglang.srt.layers.quantization.utils import get_linear_quant_method
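
The hunk below wires these configs into the `BASE_QUANTIZATION_METHODS` registry, mapping both the "quark" and "mxfp4" checkpoint formats to a platform-specific config class (CUDA vs. ROCm). A self-contained sketch of that registry pattern with hypothetical stand-in classes (`CudaMxfp4Config` and `RocmMxFp4Config` are placeholders, not the real sglang classes):

```python
from typing import Dict, Type

class QuantConfig: ...                   # stand-in for QuantizationConfig
class CudaMxfp4Config(QuantConfig): ...  # stands in for Mxfp4Config (CUDA)
class RocmMxFp4Config(QuantConfig): ...  # stands in for MxFp4Config (ROCm)

def build_registry(
    is_cuda: bool, is_hip: bool, mxfp_ok: bool
) -> Dict[str, Type[QuantConfig]]:
    methods: Dict[str, Type[QuantConfig]] = {}
    # Both checkpoint format names resolve to the same backend-specific
    # config, chosen once at import time based on the platform.
    if is_cuda:
        methods.update({"quark": CudaMxfp4Config, "mxfp4": CudaMxfp4Config})
    elif mxfp_ok and is_hip:
        methods.update({"quark": RocmMxFp4Config, "mxfp4": RocmMxFp4Config})
    return methods

print(build_registry(is_cuda=True, is_hip=False, mxfp_ok=False))
```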
@@ -85,6 +92,21 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "petit_nvfp4": PetitNvFp4Config,
 }
 
+
+if is_cuda():
+    BASE_QUANTIZATION_METHODS.update(
+        {
+            "quark": Mxfp4Config,
+            "mxfp4": Mxfp4Config,
+        }
+    )
+elif is_mxfp_supported and is_hip():
+    BASE_QUANTIZATION_METHODS.update(
+        {
+            "quark": MxFp4Config,
+            "mxfp4": MxFp4Config,
+        }
+    )
 # VLLM-dependent quantization methods
 VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,