sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py

```diff
@@ -1,9 +1,8 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
 from __future__ import annotations
 
-import importlib.util
 import logging
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
@@ -42,11 +41,7 @@ if is_cuda():
 
 try:
     from flashinfer import mm_fp4 as fp4_gemm
-    from flashinfer import (
-        reorder_rows_for_gated_act_gemm,
-        shuffle_matrix_a,
-        shuffle_matrix_sf_a,
-    )
+    from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_sf_a
 
     enable_flashinfer_fp4_gemm = True
 except ImportError:
@@ -682,9 +677,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
         padded_scales = padded_scales.contiguous().cuda()
         padded_scales = (
-            padded_scales.reshape(
+            padded_scales.reshape(M_padded, K_padded)
             if scale_ndim == 2
-            else padded_scales.reshape(B,
+            else padded_scales.reshape(B, M_padded, K_padded)
         )
         layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)
 
@@ -742,6 +737,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " above."
             )
         self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
+        self._cache_permute_indices = {}
 
     @property
     def enable_flashinfer_cutlass_moe(self) -> bool:
@@ -883,9 +879,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
         swizzled_scale = swizzled_scale.contiguous().cuda()
         return (
-            swizzled_scale.reshape(
+            swizzled_scale.reshape(M_padded, K_padded)
             if scale_ndim == 2
-            else swizzled_scale.reshape(B,
+            else swizzled_scale.reshape(B, M_padded, K_padded)
         )
 
     def prepare_static_weights_for_kernel(
@@ -905,10 +901,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             e2m1_and_ufp8sf_scale_to_float,
             fp4_quantize,
             next_positive_power_of_2,
+            nvfp4_block_scale_interleave,
             reorder_rows_for_gated_act_gemm,
            shuffle_matrix_a,
             shuffle_matrix_sf_a,
         )
+        from flashinfer.fused_moe.core import (
+            _maybe_get_cached_w2_permute_indices,
+            _maybe_get_cached_w3_w1_permute_indices,
+        )
 
         """Prepare quantized weights for kernel (done offline with weights)."""
         epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
@@ -932,50 +933,66 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             num_experts, hidden_size, intermediate_size // 16
         )  # fp8 scaling factors
 
-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
-            )
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
-            )
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
         gemm1_weights_fp4_shuffled = []
         gemm1_scales_fp4_shuffled = []
         gemm2_weights_fp4_shuffled = []
         gemm2_scales_fp4_shuffled = []
         for i in range(num_experts):
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            # for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm1_weights_fp4_shuffled.append(
-
-
-            )
+                gemm1_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
+                .contiguous()
+            )
+
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
             )
             gemm1_scales_fp4_shuffled.append(
-
-
+                nvfp4_block_scale_interleave(
+                    gemm1_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
                 )
             )
 
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm2_weights_fp4_shuffled.append(
-
-
-            )
+                gemm2_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
+                .contiguous()
+            )
+
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
            )
             gemm2_scales_fp4_shuffled.append(
-
-            gemm2_scales_linear_fp4[i]
+                nvfp4_block_scale_interleave(
+                    gemm2_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
                 )
             )
 
```
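The largest change above swaps the per-expert `reorder_rows_for_gated_act_gemm` and shuffle calls for FlashInfer's cached permute indices: the permutation depends only on the weight shape and `epilogue_tile_m`, so it is computed once, stored in `self._cache_permute_indices`, and reused for every expert (and every layer that shares the shape). A minimal sketch of that cache-by-shape idea in plain PyTorch (`make_permutation` and the cache key here are illustrative assumptions, not the FlashInfer API):

```python
from typing import Dict, Tuple

import torch


def make_permutation(num_rows: int, tile_m: int) -> torch.Tensor:
    # Hypothetical stand-in for the real index computation (row reorder for the
    # gated activation plus a tile shuffle); only the caching pattern matters here.
    perm = torch.arange(num_rows)
    half = num_rows // 2
    return torch.stack([perm[:half], perm[half:]], dim=1).flatten()


_cache: Dict[Tuple[int, int], torch.Tensor] = {}


def get_cached_permute_indices(weight: torch.Tensor, tile_m: int) -> torch.Tensor:
    # Same shape + tile size -> same permutation, so compute it at most once.
    key = (weight.shape[0], tile_m)
    if key not in _cache:
        _cache[key] = make_permutation(weight.shape[0], tile_m)
    return _cache[key]


# All experts share one weight shape, so the permutation is built a single time.
experts = [torch.randn(256, 64) for _ in range(8)]
shuffled = [
    w[get_cached_permute_indices(w, tile_m=128)].contiguous() for w in experts
]
assert len(_cache) == 1
```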
sglang/srt/layers/quantization/mxfp4.py

```diff
@@ -1,5 +1,18 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
-#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/mxfp4.py
 
 from __future__ import annotations
 
@@ -209,6 +222,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         super().__init__()
 
+        self.prefix = prefix
         self.topk_indices_dtype = None
         self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
         self.with_bias = False
@@ -332,7 +346,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if self.use_flashinfer:
             log_info_on_rank0(
                 logger,
-                "Shuffling MoE weights for FlashInfer MXFP4 moe kernel, it might take a while...",
+                f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
             )
             layer.gemm1_alpha = Parameter(
                 torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
@@ -570,8 +584,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         if self.use_flashinfer:
             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
-            x_quant, x_scale = mxfp8_quantize(
+            x_quant, x_scale = mxfp8_quantize(
+                x, False, alignment=self.hidden_size
+            )  # to mxfp8
             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            assert x_quant.shape[-1] == self.hidden_size
 
             top_k, router_logits = topk_output
 
```
sglang/srt/layers/quantization/utils.py

```diff
@@ -11,13 +11,39 @@ import numpy
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import is_cuda
 
 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
 
 
+def get_scalar_types():
+    """
+    Returns:
+        tuple: (ScalarType, scalar_types)
+    """
+    try:
+        from sgl_kernel.scalar_type import ScalarType, scalar_types
+
+        return ScalarType, scalar_types
+    except ImportError:
+
+        class MockScalarType:
+            pass
+
+        class MockScalarTypes:
+            uint4b8 = "uint4b8"
+            uint8b128 = "uint8b128"
+
+            def __getattr__(self, name):
+                return f"mock_{name}"
+
+        return MockScalarType, MockScalarTypes()
+
+
+ScalarType, scalar_types = get_scalar_types()
+
+
 def is_layer_skipped(
     prefix: str,
     ignored_layers: List[str],
```
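The `get_scalar_types` helper added above turns `sgl_kernel` into a soft dependency: when the compiled kernel package cannot be imported, mock stand-ins keep module import and attribute access working, and only actual kernel execution paths would fail. A hedged usage sketch (it assumes an sglang 0.5.0rc2 install; the attribute names printed are just examples):

```python
# Works with or without sgl_kernel installed.
from sglang.srt.layers.quantization.utils import get_scalar_types

ScalarType, scalar_types = get_scalar_types()

# With sgl_kernel these are real ScalarType objects; without it they are the
# fallback strings "uint4b8" / "uint8b128" from MockScalarTypes.
print(scalar_types.uint4b8, scalar_types.uint8b128)

# On the mock, unknown attributes resolve to "mock_<name>" instead of raising.
print(getattr(scalar_types, "some_other_type"))
```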
```diff
@@ -295,6 +321,30 @@ def pack_cols(
     return q_res
 
 
+def pack_rows(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_k % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[i::pack_factor, :] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    return q_res
+
+
 def unpack_cols(
     packed_q_w: torch.Tensor,
     num_bits: int,
```
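`pack_rows` is the row-wise counterpart of the existing `pack_cols`: it folds every group of `32 // num_bits` consecutive rows into one row of 32-bit words, with row `i` of a group occupying bits `[i * num_bits, (i + 1) * num_bits)`. A small round-trip sketch of that bit layout (the 4-bit shapes and the manual unpacking below are illustrative, not part of the module):

```python
import torch

num_bits = 4
pack_factor = 32 // num_bits  # 8 four-bit values per 32-bit word
q_w = torch.randint(0, 16, (pack_factor, 5), dtype=torch.int64)

# Pack: same shift/or scheme as pack_rows, on plain int64 tensors.
packed = torch.zeros(1, 5, dtype=torch.int64)
for i in range(pack_factor):
    packed |= q_w[i::pack_factor, :] << (num_bits * i)

# Unpack row i by shifting it back down and masking off 4 bits.
for i in range(pack_factor):
    recovered = (packed >> (num_bits * i)) & 0xF
    assert torch.equal(recovered, q_w[i::pack_factor, :])
```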
sglang/srt/layers/quantization/w4afp8.py

```diff
@@ -116,6 +116,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
         assert "weight_loader" in extra_weight_attrs
 
         # Fused gate_up_proj (column parallel)
@@ -144,6 +146,9 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
+        )
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,
@@ -274,8 +279,11 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: EPMoE,
-
+        x: torch.Tensor,
         topk_output: TopKOutput,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        routed_scaling_factor: Optional[float] = None,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -284,19 +292,17 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
-
-
-
-
-
-
-
-
-        return cutlass_w4a8_moe(
+        local_topk_ids = torch.where(
+            topk_ids == -1,
+            layer.num_experts,
+            topk_ids,
+        )
+
+        output = cutlass_w4a8_moe(
             layer.start_expert_id,
             layer.end_expert_id,
             layer.num_experts,
-
+            x,
             layer.w13_weight,
             layer.w2_weight,
             layer.w13_weight_scale_inv,
@@ -318,3 +324,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
+        if routed_scaling_factor is not None:
+            output *= routed_scaling_factor
+        return output
```
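`W4AFp8MoEMethod.apply` now takes the input tensor and routing options explicitly, remaps `topk_ids == -1` (typically top-k choices whose experts are not local to this rank) to the sentinel index `layer.num_experts` before calling `cutlass_w4a8_moe`, and applies `routed_scaling_factor` to the output. A toy sketch of the remapping step only (the tensor values are made up):

```python
import torch

num_local_experts = 4  # index 4 acts as a "null expert" slot for dropped tokens

# -1 marks top-k choices this rank does not own.
topk_ids = torch.tensor([[0, 2], [-1, 3], [1, -1]])

# Same torch.where pattern as the hunk: replace -1 with num_local_experts so
# the downstream kernel only ever sees non-negative expert indices.
local_topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
assert local_topk_ids.tolist() == [[0, 2], [4, 3], [1, 4]]

# The optional post-scale mirrors the new routed_scaling_factor handling.
output = torch.ones(3, 8)
routed_scaling_factor = 2.5
if routed_scaling_factor is not None:
    output *= routed_scaling_factor
```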
sglang/srt/layers/quantization/w8a8_int8.py

```diff
@@ -3,7 +3,18 @@ from __future__ import annotations
 import importlib
 import sys
 from types import MappingProxyType
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import torch
 from torch.nn.parameter import Parameter
@@ -79,22 +90,16 @@ def npu_wrapper_rmsnorm_forward(func):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
             x = x.contiguous()
-        original_dtype = x.dtype
-        x = x.to(torch.float32)
         if residual is not None:
-
-
-
-
-
-                x, self.weight.to(torch.float32), self.variance_epsilon
-            )[0]
-            + self.bias
-        )
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            out = out + self.bias
+            return out.to(x.dtype), residual_out
 
-
-
-        return
+        out = torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+        out = out + self.bias
+        return out.to(x.dtype)
 
     return _rmsnorm_forward_oot
 
@@ -250,17 +255,23 @@ class W8A8Int8Config(QuantizationConfig):
 
         if _is_npu:
             if isinstance(layer, LinearBase):
+                key = "model"
+                if "vision_model" in prefix:
+                    key = "vision_model"
+                elif "visual" in prefix:
+                    key = "visual"
+                packed_modules_mapping_subset = self.packed_modules_mapping.get(key, {})
                 prefix_in_quant_config = prefix
                 proj_name = prefix.split(".")[-1]
-                if proj_name in
+                if proj_name in packed_modules_mapping_subset:
                     prefix_in_quant_config = prefix.replace(
-                        proj_name,
+                        proj_name, packed_modules_mapping_subset[proj_name][0]
                     )
                 self.is_dynamic = (
                     self.quant_description[prefix_in_quant_config + ".weight"]
                     == "W8A8_DYNAMIC"
                 )
-                if self.is_layer_skipped(prefix,
+                if self.is_layer_skipped(prefix, packed_modules_mapping_subset):
                     return UnquantizedLinearMethod()
                 return (
                     NPU_W8A8DynamicLinearMethod(self)
@@ -571,8 +582,10 @@ class NPU_W8A8LinearMethodImpl:
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
+        # To prevent import loops
+        from sglang.srt.layers.linear import RowParallelLinear
+
         original_dtype = x.dtype
         if original_dtype != torch.int8:
             x = torch_npu.npu_quantize(
@@ -583,8 +596,12 @@ class NPU_W8A8LinearMethodImpl:
                 -1,
                 True,
             )
-
-
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in Attention TP>1 case)
+        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
+            quant_bias = None
+        else:
+            quant_bias = layer.quant_bias
         return torch_npu.npu_quant_matmul(
             x,
             layer.weight,
@@ -651,13 +668,21 @@ class NPU_W8A8LinearMethodMTImpl:
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
+        # To prevent import loops
+        from sglang.srt.layers.linear import RowParallelLinear
+
         original_dtype = x.dtype
         if original_dtype != torch.int8:
             x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
 
-
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in Attention TP>1 case)
+        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
+            quant_bias = None
+        else:
+            quant_bias = layer.quant_bias
+
         return ops.quant_matmul(
             x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
         )
@@ -737,11 +762,6 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.linear import RowParallelLinear
-
-        if isinstance(layer, RowParallelLinear):
-            tp_rank = get_tensor_model_parallel_rank()
-            return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)
 
 
@@ -780,7 +800,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
         original_dtype = x.dtype
-        # use ATB quantize
         quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
         return torch_npu.npu_quant_matmul(
             quant_out,
@@ -863,11 +882,6 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.linear import RowParallelLinear
-
-        if isinstance(layer, RowParallelLinear):
-            tp_rank = get_tensor_model_parallel_rank()
-            return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)
 
 
```
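The recurring pattern in the NPU linear methods above: for a row-parallel linear layer, each tensor-parallel rank computes a partial matmul and the partials are summed by an all-reduce, so the bias may be fused into the GEMM on rank 0 only; otherwise it would be added `tp_size` times. A minimal single-process sketch of that invariant (plain `torch.matmul` stands in for `npu_quant_matmul`, and the manual shard loop stands in for the all-reduce):

```python
import torch


def row_parallel_forward(x_shards, w_shards, bias, tp_ranks):
    """Each rank multiplies its input/weight shard; only rank 0 adds the bias,
    so the all-reduce (sum) applies the bias exactly once."""
    partials = []
    for rank, (x, w) in zip(tp_ranks, zip(x_shards, w_shards)):
        quant_bias = bias if rank == 0 else None  # bias fused on rank 0 only
        y = x @ w
        if quant_bias is not None:
            y = y + quant_bias
        partials.append(y)
    return sum(partials)  # stands in for the tensor-parallel all-reduce


torch.manual_seed(0)
x = torch.randn(2, 8)
w = torch.randn(8, 3)
bias = torch.randn(3)

# Split the contraction dimension across two "ranks".
out = row_parallel_forward(
    [x[:, :4], x[:, 4:]], [w[:4], w[4:]], bias, tp_ranks=[0, 1]
)
assert torch.allclose(out, x @ w + bias, atol=1e-5)
```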