sglang-0.5.0rc1-py3-none-any.whl → sglang-0.5.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. sglang/bench_one_batch.py +0 -1
  2. sglang/srt/configs/model_config.py +1 -0
  3. sglang/srt/disaggregation/decode.py +0 -1
  4. sglang/srt/entrypoints/engine.py +2 -2
  5. sglang/srt/entrypoints/http_server.py +64 -0
  6. sglang/srt/entrypoints/openai/protocol.py +2 -0
  7. sglang/srt/entrypoints/openai/serving_chat.py +1 -0
  8. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  9. sglang/srt/layers/attention/flashinfer_backend.py +3 -0
  10. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  11. sglang/srt/layers/attention/triton_backend.py +24 -27
  12. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  13. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -3
  14. sglang/srt/layers/communicator.py +7 -7
  15. sglang/srt/layers/dp_attention.py +118 -27
  16. sglang/srt/layers/logits_processor.py +12 -18
  17. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/multimodal.py +156 -40
  29. sglang/srt/layers/quantization/__init__.py +5 -32
  30. sglang/srt/layers/quantization/awq.py +15 -16
  31. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  32. sglang/srt/layers/quantization/gptq.py +12 -17
  33. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  34. sglang/srt/layers/quantization/modelopt_quant.py +52 -30
  35. sglang/srt/layers/quantization/mxfp4.py +16 -2
  36. sglang/srt/layers/quantization/utils.py +52 -2
  37. sglang/srt/layers/sampler.py +5 -2
  38. sglang/srt/lora/layers.py +6 -2
  39. sglang/srt/managers/cache_controller.py +4 -1
  40. sglang/srt/managers/io_struct.py +14 -0
  41. sglang/srt/managers/schedule_batch.py +18 -39
  42. sglang/srt/managers/scheduler.py +3 -4
  43. sglang/srt/managers/tokenizer_manager.py +28 -18
  44. sglang/srt/mem_cache/allocator.py +8 -157
  45. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  46. sglang/srt/mem_cache/chunk_cache.py +1 -1
  47. sglang/srt/model_executor/cuda_graph_runner.py +8 -21
  48. sglang/srt/model_executor/forward_batch_info.py +8 -10
  49. sglang/srt/model_executor/model_runner.py +57 -53
  50. sglang/srt/models/deepseek_nextn.py +2 -1
  51. sglang/srt/models/deepseek_v2.py +5 -3
  52. sglang/srt/models/glm4_moe.py +2 -2
  53. sglang/srt/models/glm4_moe_nextn.py +2 -1
  54. sglang/srt/models/gpt_oss.py +7 -2
  55. sglang/srt/models/llama.py +10 -2
  56. sglang/srt/models/llama4.py +18 -5
  57. sglang/srt/models/qwen2.py +2 -2
  58. sglang/srt/models/qwen2_moe.py +20 -5
  59. sglang/srt/models/qwen3_classification.py +78 -0
  60. sglang/srt/models/qwen3_moe.py +18 -5
  61. sglang/srt/models/step3_vl.py +6 -2
  62. sglang/srt/operations.py +17 -2
  63. sglang/srt/sampling/sampling_batch_info.py +7 -4
  64. sglang/srt/server_args.py +33 -7
  65. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  66. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  67. sglang/srt/two_batch_overlap.py +4 -8
  68. sglang/test/test_marlin_moe.py +1 -1
  69. sglang/test/test_marlin_utils.py +1 -1
  70. sglang/version.py +1 -1
  71. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +5 -5
  72. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +75 -63
  73. sglang/srt/layers/quantization/scalar_type.py +0 -352
  74. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  75. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  76. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/__init__.py CHANGED
@@ -55,13 +55,7 @@ if is_mxfp_supported:
     from sglang.srt.layers.quantization.fp4 import MxFp4Config

 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import (
-    GPTQConfig,
-    GPTQLinearMethod,
-    GPTQMarlinConfig,
-    GPTQMarlinLinearMethod,
-    GPTQMarlinMoEMethod,
-)
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
     ModelOptFp8Config,
@@ -70,7 +64,6 @@ from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
 from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
-from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -86,6 +79,10 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
+    "awq": AWQConfig,
+    "awq_marlin": AWQMarlinConfig,
+    "gptq": GPTQConfig,
+    "gptq_marlin": GPTQMarlinConfig,
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
@@ -111,19 +108,15 @@ elif is_mxfp_supported and is_hip():
 # VLLM-dependent quantization methods
 VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
-    "awq": AWQConfig,
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
     "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
-    "awq_marlin": AWQMarlinConfig,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "gptq": GPTQConfig,
 }

 QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
@@ -145,23 +138,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]


-def gptq_get_quant_method(self, layer, prefix):
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
-    if isinstance(layer, FusedMoE):
-        return GPTQMarlinMoEMethod(self)
-
-    if isinstance(self, GPTQConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
-        )
-    elif isinstance(self, GPTQMarlinConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-        )
-    return None
-
-
 original_isinstance = builtins.isinstance


@@ -239,10 +215,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):

 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
-    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)

-    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
     monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)

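With the AWQ and GPTQ families registered directly in BASE_QUANTIZATION_METHODS, their config classes resolve without the vLLM-backed registry. A minimal lookup sketch (illustrative only; it just exercises the get_quantization_config helper defined in this module):

    # Illustrative sketch: "awq"/"gptq" and their Marlin variants now come from
    # BASE_QUANTIZATION_METHODS, so this lookup no longer requires vLLM to be installed.
    from sglang.srt.layers.quantization import get_quantization_config

    for name in ("awq", "awq_marlin", "gptq", "gptq_marlin"):
        cfg_cls = get_quantization_config(name)
        print(name, "->", cfg_cls.__name__)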
sglang/srt/layers/quantization/awq.py CHANGED
@@ -29,29 +29,25 @@ from sglang.srt.layers.quantization.marlin_utils import (
     verify_marlin_supported,
     verify_marlin_supports_shape,
 )
-from sglang.srt.layers.quantization.scalar_type import scalar_types
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
-from sglang.srt.layers.quantization.utils import replace_parameter
+from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput

-try:
-    from vllm import _custom_ops as ops
-
-    warnings.warn(
-        f"Using kernels directly from vllm. This might lead to performance degradation or "
-        f"missing functionalities as certain kernels may not be optimized. "
-    )
-except ImportError:
-    ops = None
-
 from sglang.srt.utils import is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()
 if _is_cuda:
-    from sgl_kernel import awq_dequantize, fused_marlin_moe
+    from sgl_kernel import (
+        awq_dequantize,
+        awq_marlin_moe_repack,
+        awq_marlin_repack,
+        fused_marlin_moe,
+    )
+
+
 elif _is_hip:
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
@@ -64,6 +60,9 @@ else:
 logger = logging.getLogger(__name__)


+ScalarType, scalar_types = get_scalar_types()
+
+
 def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
     return any(module_name in prefix for module_name in modules_to_not_convert)

@@ -516,7 +515,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
         layer.workspace = marlin_make_workspace(device)

         # Repack weights from AWQ format to marlin format.
-        marlin_qweight = ops.awq_marlin_repack(
+        marlin_qweight = awq_marlin_repack(
             layer.qweight,
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
@@ -684,7 +683,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
             requires_grad=False,
         )

-        marlin_w13_qweight = ops.awq_marlin_moe_repack(
+        marlin_w13_qweight = awq_marlin_moe_repack(
             layer.w13_qweight,
             layer.w13_g_idx_sort_indices,
             size_k=layer.w13_qweight.shape[1],
@@ -693,7 +692,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

-        marlin_w2_qweight = ops.awq_marlin_moe_repack(
+        marlin_w2_qweight = awq_marlin_moe_repack(
             layer.w2_qweight,
             layer.w2_g_idx_sort_indices,
             size_k=layer.w2_qweight.shape[1],
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py CHANGED
@@ -16,7 +16,6 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
-    cpu_has_amx_support,
     per_tensor_dequantize,
     replace_parameter,
 )
sglang/srt/layers/quantization/gptq.py CHANGED
@@ -36,9 +36,9 @@ from sglang.srt.layers.quantization.marlin_utils import (
     marlin_zero_points,
     verify_marlin_supported,
 )
-from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
 from sglang.srt.layers.quantization.utils import (
     get_linear_quant_method,
+    get_scalar_types,
     replace_parameter,
     unpack_cols,
 )
@@ -46,20 +46,16 @@ from sglang.srt.layers.quantization.utils import (
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput

-try:
-    from vllm import _custom_ops as ops
-except ImportError:
-    ops = None
-
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()

 if _is_cuda:
-    from sgl_kernel import fused_marlin_moe
+    from sgl_kernel import fused_marlin_moe, gptq_gemm, gptq_marlin_repack, gptq_shuffle


 logger = logging.getLogger(__name__)
+ScalarType, scalar_types = get_scalar_types()


 def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
@@ -85,9 +81,7 @@ def gptq_marlin_moe_repack(
         dtype=b_q_weight.dtype,
     )
     for e in range(num_experts):
-        output[e] = torch.ops.sgl_kernel.gptq_marlin_repack(
-            b_q_weight[e], perm[e], size_k, size_n, num_bits
-        )
+        output[e] = gptq_marlin_repack(b_q_weight[e], perm[e], size_k, size_n, num_bits)
     return output


@@ -204,11 +198,12 @@ class GPTQConfig(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

-        if isinstance(layer, LinearBase):
-            return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
-        elif isinstance(layer, FusedMoE):
+        if isinstance(layer, FusedMoE):
             raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
-        return None
+        else:
+            return get_linear_quant_method(
+                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+            )


 class GPTQMarlinConfig(QuantizationConfig):
@@ -530,7 +525,7 @@ class GPTQLinearMethod(LinearMethodBase):
             layer.g_idx.data = torch.empty(
                 (0,), dtype=torch.int, device=layer.g_idx.device
             )
-        ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
+        gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)

     def apply(
         self,
@@ -541,7 +536,7 @@ class GPTQLinearMethod(LinearMethodBase):
         out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
         reshaped_x = x.reshape(-1, x.shape[-1])

-        output = ops.gptq_gemm(
+        output = gptq_gemm(
             reshaped_x,
             layer.qweight,
             layer.qzeros,
@@ -726,7 +721,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
-            x.data = torch.ops.sgl_kernel.gptq_marlin_repack(
+            x.data = gptq_marlin_repack(
                 x.data.contiguous(),
                 perm=layer.g_idx_sort_indices,
                 size_k=c.partition_weight_shape[0],
sglang/srt/layers/quantization/marlin_utils.py CHANGED
@@ -19,9 +19,12 @@ from sglang.srt.layers.quantization.base_config import (
     LinearMethodBase,
     QuantizationConfig,
 )
-from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
-from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
-from sglang.srt.utils import get_device_capability
+from sglang.srt.layers.quantization.utils import (
+    get_scalar_types,
+    pack_cols,
+    unpack_cols,
+)
+from sglang.srt.utils import get_device_capability, is_cuda

 if TYPE_CHECKING:
     from sglang.srt.layers.linear import LinearBase
@@ -31,8 +34,15 @@ try:
 except ImportError:
     ops = None

+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import gptq_marlin_gemm
+
 logger = logging.getLogger(__name__)

+ScalarType, scalar_types = get_scalar_types()
+
 GPTQ_MARLIN_TILE = 16
 GPTQ_MARLIN_MIN_THREAD_N = 64
 GPTQ_MARLIN_MIN_THREAD_K = 128
@@ -453,7 +463,7 @@ def apply_gptq_marlin_linear(
         dtype=input.dtype,
     )

-    output = ops.gptq_marlin_gemm(
+    output = gptq_marlin_gemm(
         reshaped_x,
         None,
         weight,
@@ -504,7 +514,7 @@ def apply_awq_marlin_linear(
         dtype=input.dtype,
     )

-    output = ops.gptq_marlin_gemm(
+    output = gptq_marlin_gemm(
         reshaped_x,
         None,
         weight,
sglang/srt/layers/quantization/modelopt_quant.py CHANGED
@@ -737,6 +737,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " above."
             )
         self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
+        self._cache_permute_indices = {}

     @property
     def enable_flashinfer_cutlass_moe(self) -> bool:
@@ -900,10 +901,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             e2m1_and_ufp8sf_scale_to_float,
             fp4_quantize,
             next_positive_power_of_2,
+            nvfp4_block_scale_interleave,
             reorder_rows_for_gated_act_gemm,
             shuffle_matrix_a,
             shuffle_matrix_sf_a,
         )
+        from flashinfer.fused_moe.core import (
+            _maybe_get_cached_w2_permute_indices,
+            _maybe_get_cached_w3_w1_permute_indices,
+        )

         """Prepare quantized weights for kernel (done offline with weights)."""
         epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
@@ -927,50 +933,66 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             num_experts, hidden_size, intermediate_size // 16
         )  # fp8 scaling factors

-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
-            )
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
-            )
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
         gemm1_weights_fp4_shuffled = []
         gemm1_scales_fp4_shuffled = []
         gemm2_weights_fp4_shuffled = []
         gemm2_scales_fp4_shuffled = []
         for i in range(num_experts):
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            # for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm1_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
-                )
+                gemm1_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
+                .contiguous()
+            )
+
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
             )
             gemm1_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+                nvfp4_block_scale_interleave(
+                    gemm1_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
                 )
             )

+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm2_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
-                )
+                gemm2_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
+                .contiguous()
+            )
+
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
            )
             gemm2_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
+                nvfp4_block_scale_interleave(
+                    gemm2_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
                 )
             )

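The rewritten shuffle loop above avoids building interleaved copies of every expert's weights; instead it fetches permute indices from self._cache_permute_indices through flashinfer's _maybe_get_cached_*_permute_indices helpers and applies them as a fancy index. A rough, self-contained sketch of the memoization pattern (hypothetical helper, not the flashinfer implementation; the identity permutation is only a placeholder):

    from typing import Dict, Tuple

    import torch

    def get_cached_permute_indices(
        cache: Dict[Tuple[torch.Size, int], torch.Tensor],
        weight: torch.Tensor,
        epilogue_tile_m: int,
    ) -> torch.Tensor:
        # Indices depend only on the weight's shape and tile size, so experts and
        # layers with identical shapes reuse one cached tensor instead of recomputing.
        key = (weight.shape, epilogue_tile_m)
        if key not in cache:
            cache[key] = torch.arange(weight.shape[0])  # placeholder permutation
        return cache[key]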
sglang/srt/layers/quantization/mxfp4.py CHANGED
@@ -1,5 +1,18 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/mxfp4.py

 from __future__ import annotations

@@ -209,6 +222,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):

         super().__init__()

+        self.prefix = prefix
         self.topk_indices_dtype = None
         self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
         self.with_bias = False
@@ -332,7 +346,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if self.use_flashinfer:
             log_info_on_rank0(
                 logger,
-                "Shuffling MoE weights for FlashInfer MXFP4 moe kernel, it might take a while...",
+                f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
             )
             layer.gemm1_alpha = Parameter(
                 torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
sglang/srt/layers/quantization/utils.py CHANGED
@@ -11,13 +11,39 @@ import numpy
 import torch

 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import is_cuda

 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig


+def get_scalar_types():
+    """
+    Returns:
+        tuple: (ScalarType, scalar_types)
+    """
+    try:
+        from sgl_kernel.scalar_type import ScalarType, scalar_types
+
+        return ScalarType, scalar_types
+    except ImportError:
+
+        class MockScalarType:
+            pass
+
+        class MockScalarTypes:
+            uint4b8 = "uint4b8"
+            uint8b128 = "uint8b128"
+
+            def __getattr__(self, name):
+                return f"mock_{name}"
+
+        return MockScalarType, MockScalarTypes()
+
+
+ScalarType, scalar_types = get_scalar_types()
+
+
 def is_layer_skipped(
     prefix: str,
     ignored_layers: List[str],
@@ -295,6 +321,30 @@ def pack_cols(
     return q_res


+def pack_rows(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_k % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[i::pack_factor, :] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    return q_res
+
+
 def unpack_cols(
     packed_q_w: torch.Tensor,
     num_bits: int,
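For reference, the new pack_rows mirrors the existing pack_cols but packs along the k dimension: with num_bits=4, every 32 // 4 = 8 rows collapse into one int32 row. A small usage sketch (assumes get_pack_factor(4) == 8, as in the pack_cols path):

    import torch

    from sglang.srt.layers.quantization.utils import pack_rows

    size_k, size_n, num_bits = 64, 16, 4
    q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

    packed = pack_rows(q_w, num_bits=num_bits, size_k=size_k, size_n=size_n)
    # 8 four-bit values per uint32 along k: (64, 16) -> (8, 16)
    assert packed.shape == (size_k // 8, size_n)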
sglang/srt/layers/sampler.py CHANGED
@@ -6,7 +6,10 @@ import torch.distributed as dist
 from torch import nn

 from sglang.srt.distributed import get_tp_group
-from sglang.srt.layers.dp_attention import get_attention_tp_group
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    is_dp_attention_enabled,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -32,7 +35,7 @@ class Sampler(nn.Module):
         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
         self.tp_sync_group = get_tp_group().device_group

-        if global_server_args_dict["enable_dp_attention"]:
+        if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group

     def forward(
sglang/srt/lora/layers.py CHANGED
@@ -253,7 +253,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return lora_output

-    def forward(self, input_: torch.Tensor):
+    def forward(self, input_: torch.Tensor, skip_all_reduce=False):
         # duplicate the logic in RowParallelLinear
         if self.base_layer.input_is_parallel:
             input_parallel = input_
@@ -270,7 +270,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         if self.set_lora:
             output_parallel = self.apply_lora(output_parallel, input_parallel)

-        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
+        if (
+            self.base_layer.reduce_results
+            and self.base_layer.tp_size > 1
+            and not skip_all_reduce
+        ):
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel
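The new skip_all_reduce flag keeps each tensor-parallel rank's partial sum so the caller can defer or fuse the reduction. A hedged calling sketch (lora_linear and x are hypothetical; the return convention follows the wrapped RowParallelLinear, which yields (output, output_bias)):

    from sglang.srt.distributed import tensor_model_parallel_all_reduce

    # Keep the per-rank partial output and reduce it once, later.
    partial_out, output_bias = lora_linear.forward(x, skip_all_reduce=True)
    # ... overlap other per-rank work here ...
    output = tensor_model_parallel_all_reduce(partial_out)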
sglang/srt/managers/cache_controller.py CHANGED
@@ -296,6 +296,9 @@ class HiCacheController:
             self.prefetch_tp_group = torch.distributed.new_group(
                 group_ranks, backend="gloo"
             )
+            self.prefetch_io_tp_group = torch.distributed.new_group(
+                group_ranks, backend="gloo"
+            )
             self.backup_tp_group = torch.distributed.new_group(
                 group_ranks, backend="gloo"
             )
@@ -602,7 +605,7 @@ class HiCacheController:

         if self.tp_world_size > 1:
             # to ensure all TP workers release the host memory at the same time
-            torch.distributed.barrier(group=self.prefetch_tp_group)
+            torch.distributed.barrier(group=self.prefetch_io_tp_group)
             # operation terminated by controller, release pre-allocated memory
             self.mem_pool_host.free(
                 operation.host_indices[operation.completed_tokens :]
sglang/srt/managers/io_struct.py CHANGED
@@ -798,6 +798,8 @@ class UpdateWeightFromDiskReqInput:
     load_format: Optional[str] = None
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None


 @dataclass
@@ -819,6 +821,8 @@ class UpdateWeightsFromDistributedReqInput:
     flush_cache: bool = True
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None


 @dataclass
@@ -842,6 +846,8 @@ class UpdateWeightsFromTensorReqInput:
     flush_cache: bool = True
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None


 @dataclass
@@ -872,6 +878,14 @@ class InitWeightsUpdateGroupReqOutput:
     message: str


+@dataclass
+class UpdateWeightVersionReqInput:
+    # The new weight version
+    new_version: str
+    # Whether to abort all running requests before updating
+    abort_all_requests: bool = True
+
+
 @dataclass
 class GetWeightsByNameReqInput:
     name: str
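The new UpdateWeightVersionReqInput lets the weight version be bumped on its own (presumably by the endpoints added in http_server.py), while the weight_version field above piggybacks a version string onto the existing weight-update requests. A minimal construction sketch (the version string is illustrative):

    from sglang.srt.managers.io_struct import UpdateWeightVersionReqInput

    # Bump only the served weight version, without reloading any weights.
    req = UpdateWeightVersionReqInput(new_version="2025-08-15-v2")
    assert req.abort_all_requests  # defaults to True for a standalone version bump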