sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,8 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
 from __future__ import annotations
 
-import importlib.util
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
@@ -42,11 +41,7 @@ if is_cuda():
 
 try:
     from flashinfer import mm_fp4 as fp4_gemm
-    from flashinfer import (
-        reorder_rows_for_gated_act_gemm,
-        shuffle_matrix_a,
-        shuffle_matrix_sf_a,
-    )
+    from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_sf_a
 
     enable_flashinfer_fp4_gemm = True
 except ImportError:
@@ -682,9 +677,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
         padded_scales = padded_scales.contiguous().cuda()
         padded_scales = (
-            padded_scales.reshape(M, K)
+            padded_scales.reshape(M_padded, K_padded)
             if scale_ndim == 2
-            else padded_scales.reshape(B, M, K)
+            else padded_scales.reshape(B, M_padded, K_padded)
         )
         layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)
 
@@ -742,6 +737,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " above."
             )
         self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
+        self._cache_permute_indices = {}
 
     @property
     def enable_flashinfer_cutlass_moe(self) -> bool:
@@ -883,9 +879,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
         swizzled_scale = swizzled_scale.contiguous().cuda()
         return (
-            swizzled_scale.reshape(M, K)
+            swizzled_scale.reshape(M_padded, K_padded)
             if scale_ndim == 2
-            else swizzled_scale.reshape(B, M, K)
+            else swizzled_scale.reshape(B, M_padded, K_padded)
         )
 
     def prepare_static_weights_for_kernel(
@@ -905,10 +901,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             e2m1_and_ufp8sf_scale_to_float,
             fp4_quantize,
             next_positive_power_of_2,
+            nvfp4_block_scale_interleave,
             reorder_rows_for_gated_act_gemm,
             shuffle_matrix_a,
             shuffle_matrix_sf_a,
         )
+        from flashinfer.fused_moe.core import (
+            _maybe_get_cached_w2_permute_indices,
+            _maybe_get_cached_w3_w1_permute_indices,
+        )
 
         """Prepare quantized weights for kernel (done offline with weights)."""
         epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
@@ -932,50 +933,66 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             num_experts, hidden_size, intermediate_size // 16
         )  # fp8 scaling factors
 
-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
-            )
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
-            )
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
         gemm1_weights_fp4_shuffled = []
         gemm1_scales_fp4_shuffled = []
         gemm2_weights_fp4_shuffled = []
         gemm2_scales_fp4_shuffled = []
         for i in range(num_experts):
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            # for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm1_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
-                )
+                gemm1_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
+                .contiguous()
+            )
+
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
             )
             gemm1_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+                nvfp4_block_scale_interleave(
+                    gemm1_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
                 )
             )
 
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm2_weights_fp4_shuffled.append(
-                shuffle_matrix_a(
-                    gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
-                )
+                gemm2_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
+                .contiguous()
+            )
+
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
             )
             gemm2_scales_fp4_shuffled.append(
-                shuffle_matrix_sf_a(
-                    gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
+                nvfp4_block_scale_interleave(
+                    gemm2_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
                 )
             )
 
@@ -1,5 +1,18 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/mxfp4.py
 
 from __future__ import annotations
 
@@ -209,6 +222,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         super().__init__()
 
+        self.prefix = prefix
         self.topk_indices_dtype = None
         self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
         self.with_bias = False
@@ -332,7 +346,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if self.use_flashinfer:
             log_info_on_rank0(
                 logger,
-                "Shuffling MoE weights for FlashInfer MXFP4 moe kernel, it might take a while...",
+                f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
             )
             layer.gemm1_alpha = Parameter(
                 torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
@@ -570,8 +584,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         if self.use_flashinfer:
             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
-            x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
+            x_quant, x_scale = mxfp8_quantize(
+                x, False, alignment=self.hidden_size
+            )  # to mxfp8
             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            assert x_quant.shape[-1] == self.hidden_size
 
             top_k, router_logits = topk_output
 
@@ -11,13 +11,39 @@ import numpy
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import is_cuda
 
 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
 
 
+def get_scalar_types():
+    """
+    Returns:
+        tuple: (ScalarType, scalar_types)
+    """
+    try:
+        from sgl_kernel.scalar_type import ScalarType, scalar_types
+
+        return ScalarType, scalar_types
+    except ImportError:
+
+        class MockScalarType:
+            pass
+
+        class MockScalarTypes:
+            uint4b8 = "uint4b8"
+            uint8b128 = "uint8b128"
+
+            def __getattr__(self, name):
+                return f"mock_{name}"
+
+        return MockScalarType, MockScalarTypes()
+
+
+ScalarType, scalar_types = get_scalar_types()
+
+
 def is_layer_skipped(
     prefix: str,
     ignored_layers: List[str],
@@ -295,6 +321,30 @@ def pack_cols(
     return q_res
 
 
+def pack_rows(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_k % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[i::pack_factor, :] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    return q_res
+
+
 def unpack_cols(
     packed_q_w: torch.Tensor,
     num_bits: int,
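
The pack_rows helper added above (sglang/srt/layers/quantization/utils.py) mirrors the existing pack_cols but packs along the k dimension: with num_bits=4, eight consecutive values in a column share one uint32, value i occupying bits 4*i..4*i+3. A quick self-contained check of that layout (the packing loop is re-implemented inline so it runs without sglang; get_pack_factor(num_bits) is 32 // num_bits):

import numpy
import torch

num_bits, size_k, size_n = 4, 8, 2
pack_factor = 32 // num_bits  # what get_pack_factor(num_bits) returns

q_w = torch.arange(size_k * size_n, dtype=torch.int32).reshape(size_k, size_n) % 16

# Same packing loop as pack_rows above.
packed = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
w = q_w.numpy().astype(numpy.uint32)
for i in range(pack_factor):
    packed |= w[i::pack_factor, :] << num_bits * i

# Unpack column 0 and confirm the original 4-bit values round-trip.
word = int(packed[0, 0])
recovered = [(word >> (num_bits * i)) & 0xF for i in range(pack_factor)]
assert recovered == [int(v) for v in q_w[:, 0]]
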
@@ -116,6 +116,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
         assert "weight_loader" in extra_weight_attrs
 
         # Fused gate_up_proj (column parallel)
@@ -144,6 +146,9 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
+        )
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,
@@ -274,8 +279,11 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: EPMoE,
-        hidden_states: torch.Tensor,
+        x: torch.Tensor,
         topk_output: TopKOutput,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        routed_scaling_factor: Optional[float] = None,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -284,19 +292,17 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
-        if layer.expert_map is not None:
-            "Translate info from expert_map to topk_ids"
-            local_topk_ids = torch.where(
-                layer.expert_map[topk_ids] != layer.num_experts,
-                layer.expert_map[topk_ids],
-                layer.num_experts,
-            )
-
-        return cutlass_w4a8_moe(
+        local_topk_ids = torch.where(
+            topk_ids == -1,
+            layer.num_experts,
+            topk_ids,
+        )
+
+        output = cutlass_w4a8_moe(
             layer.start_expert_id,
             layer.end_expert_id,
             layer.num_experts,
-            hidden_states,
+            x,
             layer.w13_weight,
             layer.w2_weight,
             layer.w13_weight_scale_inv,
@@ -318,3 +324,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
+        if routed_scaling_factor is not None:
+            output *= routed_scaling_factor
+        return output
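
In the apply() hunks above (sglang/srt/layers/quantization/w4afp8.py), expert locality is now signalled directly through the topk ids: entries equal to -1 (token not routed to an expert owned by this rank) are remapped to layer.num_experts before calling cutlass_w4a8_moe, replacing the old expert_map translation, and the kernel output is optionally scaled by routed_scaling_factor. A toy illustration of the remap, with invented values:

import torch

num_experts = 4
topk_ids = torch.tensor([[0, 2], [-1, 3], [1, -1]])

# -1 means "not handled locally"; num_experts serves as the out-of-range
# sentinel for experts that are not local to this rank.
local_topk_ids = torch.where(topk_ids == -1, num_experts, topk_ids)
print(local_topk_ids)  # tensor([[0, 2], [4, 3], [1, 4]])

routed_scaling_factor = 0.7
output = torch.randn(3, 16)
if routed_scaling_factor is not None:
    output *= routed_scaling_factor
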
@@ -3,7 +3,18 @@ from __future__ import annotations
 import importlib
 import sys
 from types import MappingProxyType
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import torch
 from torch.nn.parameter import Parameter
@@ -79,22 +90,16 @@ def npu_wrapper_rmsnorm_forward(func):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
             x = x.contiguous()
-        original_dtype = x.dtype
-        x = x.to(torch.float32)
         if residual is not None:
-            x = x + residual.to(torch.float32)
-            residual = x.to(original_dtype)
-
-        x = (
-            torch_npu.npu_rms_norm(
-                x, self.weight.to(torch.float32), self.variance_epsilon
-            )[0]
-            + self.bias
-        )
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            out = out + self.bias
+            return out.to(x.dtype), residual_out
 
-        if residual is None:
-            return x.to(original_dtype)
-        return x.to(original_dtype), residual
+        out = torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+        out = out + self.bias
+        return out.to(x.dtype)
 
     return _rmsnorm_forward_oot
 
@@ -250,17 +255,23 @@ class W8A8Int8Config(QuantizationConfig):
 
         if _is_npu:
             if isinstance(layer, LinearBase):
+                key = "model"
+                if "vision_model" in prefix:
+                    key = "vision_model"
+                elif "visual" in prefix:
+                    key = "visual"
+                packed_modules_mapping_subset = self.packed_modules_mapping.get(key, {})
                 prefix_in_quant_config = prefix
                 proj_name = prefix.split(".")[-1]
-                if proj_name in self.packed_modules_mapping:
+                if proj_name in packed_modules_mapping_subset:
                     prefix_in_quant_config = prefix.replace(
-                        proj_name, self.packed_modules_mapping[proj_name][0]
+                        proj_name, packed_modules_mapping_subset[proj_name][0]
                     )
                 self.is_dynamic = (
                     self.quant_description[prefix_in_quant_config + ".weight"]
                     == "W8A8_DYNAMIC"
                 )
-                if self.is_layer_skipped(prefix, self.packed_modules_mapping):
+                if self.is_layer_skipped(prefix, packed_modules_mapping_subset):
                     return UnquantizedLinearMethod()
                 return (
                     NPU_W8A8DynamicLinearMethod(self)
@@ -571,8 +582,10 @@ class NPU_W8A8LinearMethodImpl:
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
+        # To prevent import loops
+        from sglang.srt.layers.linear import RowParallelLinear
+
         original_dtype = x.dtype
         if original_dtype != torch.int8:
             x = torch_npu.npu_quantize(
@@ -583,8 +596,12 @@ class NPU_W8A8LinearMethodImpl:
                 -1,
                 True,
             )
-
-        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in Attention TP>1 case)
+        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
+            quant_bias = None
+        else:
+            quant_bias = layer.quant_bias
         return torch_npu.npu_quant_matmul(
             x,
             layer.weight,
@@ -651,13 +668,21 @@ class NPU_W8A8LinearMethodMTImpl:
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
+        # To prevent import loops
+        from sglang.srt.layers.linear import RowParallelLinear
+
         original_dtype = x.dtype
         if original_dtype != torch.int8:
             x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
 
-        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in Attention TP>1 case)
+        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
+            quant_bias = None
+        else:
+            quant_bias = layer.quant_bias
+
         return ops.quant_matmul(
             x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
         )
@@ -737,11 +762,6 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.linear import RowParallelLinear
-
-        if isinstance(layer, RowParallelLinear):
-            tp_rank = get_tensor_model_parallel_rank()
-            return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)
 
 
@@ -780,7 +800,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
         original_dtype = x.dtype
-        # use ATB quantize
         quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
         return torch_npu.npu_quant_matmul(
             quant_out,
@@ -863,11 +882,6 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.linear import RowParallelLinear
-
-        if isinstance(layer, RowParallelLinear):
-            tp_rank = get_tensor_model_parallel_rank()
-            return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)
 
 
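
The w8a8_int8.py hunks above move the tensor-parallel bias handling out of the wrapper classes and into the NPU quant implementations: instead of the wrapper passing tp_rank into apply(), each impl checks isinstance(layer, RowParallelLinear) and layer.tp_rank itself and fuses quant_bias into the GEMM only on rank 0. The reason is that a row-parallel linear sums partial results across ranks, so a bias fused on every rank would be added tp_size times. A small illustration of that failure mode (no real collectives, just two simulated ranks):

import torch

x = torch.randn(3, 8)
w = torch.randn(8, 5)
bias = torch.randn(5)

# Row-parallel split: each "rank" holds half of the reduction dimension.
partials = [x[:, :4] @ w[:4], x[:, 4:] @ w[4:]]

wrong = sum(p + bias for p in partials)  # bias fused on every rank
right = sum(partials) + bias             # bias fused once (rank 0 only)

assert torch.allclose(wrong, right + bias)  # off by exactly one extra bias
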