sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_utils.py

@@ -45,7 +45,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _use_aiter:
     import aiter
-    from aiter import gemm_a8w8_blockscale, get_hip_quant
+    from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant

     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)

@@ -248,11 +248,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
     )

-    # NOTE(alcanderian): Useless when scale is packed to int32
-    # if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
-    #     _check_ue8m0("x_scale", x_scale)
-    #     _check_ue8m0("weight_scale", ws)
-
     output = w8a8_block_fp8_matmul_deepgemm(
         q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
     )
@@ -261,11 +256,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
     return output.to(dtype=output_dtype).view(*output_shape)


-def _check_ue8m0(name, x):
-    x_ceil = ceil_to_ue8m0(x)
-    assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
-
-
 def aiter_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -652,25 +642,49 @@ def apply_fp8_linear(
         use_per_token_if_dynamic
         and not per_tensor_weights
         and not per_tensor_activations
-        and USE_ROWWISE_TORCH_SCALED_MM
+        and (USE_ROWWISE_TORCH_SCALED_MM or _use_aiter)
     ):
-        #
-        #
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # into this sector means use dynamic per-token-per-channel quant
+        # per-token scale quant for input matrix, every row(one token) have one scale factor
+        # per-channel scale quant for weight matrix, every col(one channel) have one scale factor
+        if _use_aiter:
+            # gemm_a8w8_bpreshuffle(XQ, WQ, x_scale, w_scale, dtype)
+            # XQ -> input tensor, shape = (m, k)
+            # WQ -> weight tensor, shape = (n, k), with preshuffe get better perf
+            # x_scale -> input scale tensor, shape = (m, 1)
+            # w_scale -> weight scale tensor, shape = (n ,1)
+            # dtype -> output dtype
+            output = gemm_a8w8_bpreshuffle(
+                XQ=qinput,
+                WQ=weight,
+                x_scale=x_scale,
+                w_scale=weight_scale,
+                dtype=input.dtype,
+            )
+            if bias is not None:
+                output += bias
+            return _process_scaled_mm_output(
+                output, input_2d.shape, [*input.shape[:-1], weight.shape[0]]
+            )
+        else:
+            # For now validated on ROCm platform
+            # fp8 rowwise scaling in torch._scaled_mm is introduced in
+            # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+            # and ROCm 6.3, which only exists in torch 2.7 and above.
+            # For CUDA platform please validate if the
+            # torch._scaled_mm support rowwise scaled GEMM
+            # Fused GEMM_DQ Rowwise GEMM
+            output = torch._scaled_mm(
+                qinput,
+                weight,
+                out_dtype=input.dtype,
+                scale_a=x_scale,
+                scale_b=weight_scale.t(),
+                bias=bias,
+            )
+            return _process_scaled_mm_output(
+                output, input_2d.shape, output_shape
+            )
     else:
         # Fallback for channelwise case, where we use unfused DQ
         # due to limitations with scaled_mm
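The comments in this hunk spell out the shape contract for the new dynamic per-token/per-channel FP8 path. Below is a minimal sketch of that contract with made-up sizes and ordinary float tensors standing in for the real FP8 data; the actual code dispatches to aiter's `gemm_a8w8_bpreshuffle` on ROCm or to the rowwise `torch._scaled_mm` path.

```python
# Illustrative shape sketch only; sizes and tensors are made up.
import torch

m, k, n = 4, 128, 256
qinput = torch.randn(m, k)        # stand-in for FP8-quantized activations, shape (m, k)
weight = torch.randn(n, k)        # stand-in for FP8 weight, shape (n, k)
x_scale = torch.rand(m, 1)        # per-token scale: one factor per input row
weight_scale = torch.rand(n, 1)   # per-channel scale: one factor per output channel

# Reference math equivalent to the fused kernels used above
# (aiter gemm_a8w8_bpreshuffle / rowwise torch._scaled_mm).
out = (qinput * x_scale) @ (weight * weight_scale).t()
assert out.shape == (m, n)
```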
sglang/srt/layers/quantization/gptq.py

@@ -45,7 +45,10 @@ from sglang.srt.layers.quantization.utils import (

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )

 from sglang.srt.utils import is_cuda

@@ -838,19 +841,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.linear import set_weight_attrs
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

-
-
-        self.is_k_full = (not self.quant_config.desc_act) or (
-            intermediate_size_per_partition == intermediate_size
-        )
+        self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1

         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-
-
-
-
-            )
+            if self.quant_config.desc_act:
+                w2_scales_size = intermediate_size_per_partition
+            else:
+                w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
             scales_size2 = w2_scales_size // self.quant_config.group_size
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
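A small worked example of the new scale sizing; the dimensions, TP degree, and `desc_act` setting below are hypothetical.

```python
# Worked example (made-up sizes) of the GPTQ-Marlin MoE scale sizing above.
hidden_size = 4096
intermediate_size_per_partition = 1408   # per-TP-rank slice of the expert MLP
moe_tp_size = 2
group_size = 128
desc_act = False

is_k_full = (not desc_act) or moe_tp_size == 1          # True
scales_size13 = hidden_size // group_size               # 32 groups along the w13 input dim
w2_scales_size = (
    intermediate_size_per_partition
    if desc_act
    else intermediate_size_per_partition * moe_tp_size  # 2816: w2 scales cover the full K
)
scales_size2 = w2_scales_size // group_size             # 22
print(is_k_full, scales_size13, scales_size2)           # True 32 22
```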
@@ -1052,17 +1050,26 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_scales", marlin_w2_scales)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         # Delay the import to avoid circular dependency

         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."

         # The input must currently be float16
@@ -1071,7 +1078,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):

         topk_weights, topk_ids, router_logits = topk_output

-
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -1087,3 +1094,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             num_bits=self.quant_config.weight_bits,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
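Taken together, these hunks show the interface that quantized MoE methods move to in 0.5.2: `create_moe_runner()` stores the runner config, and `apply()` consumes a `StandardDispatchOutput` and returns a `CombineInput`. Below is a bare-bones sketch of that shape; `MyMoEMethod` is hypothetical and stands in for a real `FusedMoEMethodBase` subclass.

```python
# Minimal sketch of the new quant-method call shape (hypothetical class, no real kernels).
import torch
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.token_dispatcher import (
    StandardCombineInput,
    StandardDispatchOutput,
)


class MyMoEMethod:  # illustrative only; real methods subclass FusedMoEMethodBase
    def create_moe_runner(self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig):
        # Called once at setup time; apply() reads the stored config later.
        self.moe_runner_config = moe_runner_config

    def apply(self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput):
        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output  # routing result, unused in this stub
        output = x  # placeholder for the real quantized expert computation
        return StandardCombineInput(hidden_states=output)
```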
sglang/srt/layers/quantization/modelopt_quant.py

@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -39,8 +43,10 @@ from sglang.srt.utils import is_cuda, next_power_of_2

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -322,7 +328,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -338,7 +344,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):

         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -348,7 +357,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):

         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
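For reference, the resulting per-expert weight shapes with illustrative sizes (the expert count, hidden size, and per-partition intermediate size below are made up):

```python
# Shape sketch for the expert weights created above (illustrative sizes only).
import torch

num_experts, hidden_size, intermediate_size_per_partition = 8, 1024, 512
w13 = torch.empty(num_experts, 2 * intermediate_size_per_partition, hidden_size)  # gate + up proj
w2 = torch.empty(num_experts, hidden_size, intermediate_size_per_partition)       # down proj
print(w13.shape, w2.shape)  # torch.Size([8, 1024, 1024]) torch.Size([8, 1024, 512])
```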
@@ -414,28 +426,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values

         # Requantize each expert's weights using the combined scale
-        # w13_weight has shape (num_experts, 2 *
-        # where the first
-
+        # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # where the first intermediate_size_per_partition rows are w1, the next are w3
+        intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
         for expert_id in range(layer.w13_weight.shape[0]):
             start = 0
             for shard_id in range(2):  # w1 and w3
                 # Dequantize using the original scale for this shard
                 dq_weight = per_tensor_dequantize(
                     layer.w13_weight[expert_id][
-                        start : start +
+                        start : start + intermediate_size_per_partition, :
                     ],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
                 # Requantize using the combined max scale
                 (
                     layer.w13_weight[expert_id][
-                        start : start +
+                        start : start + intermediate_size_per_partition, :
                     ],
                     _,
                 ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

-                start +=
+                start += intermediate_size_per_partition

         # Update the scale parameter to be per-expert instead of per-shard
         layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
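The loop above folds the two per-shard (w1/w3) scales into a single per-expert scale. A toy numeric sketch of the same dequantize-then-requantize step follows; the values are made up, and the FP8 clamping done by `scaled_fp8_quant` in the real code is omitted.

```python
# Numeric sketch (made-up values) of combining two per-shard scales into one.
import torch

w1_q, w1_scale = torch.tensor([10.0, -20.0]), 0.5   # shard dequantizes to 5.0, -10.0
w3_q, w3_scale = torch.tensor([8.0, 16.0]), 0.25    # shard dequantizes to 2.0, 4.0

max_scale = max(w1_scale, w3_scale)                  # 0.5: the combined per-expert scale
w1_req = (w1_q * w1_scale) / max_scale               # 10.0, -20.0 (same scale, unchanged)
w3_req = (w3_q * w3_scale) / max_scale               # 4.0, 8.0 (re-expressed in the new scale)

# Both shards now share one scale and still dequantize to the original values.
assert torch.allclose(w1_req * max_scale, w1_q * w1_scale)
assert torch.allclose(w3_req * max_scale, w3_q * w3_scale)
```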
@@ -457,29 +469,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
-
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            per_channel_quant=False,
-
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )

+        return self.runner.run(dispatch_output, quant_info)
+

 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
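The same pattern, extracted into a standalone sketch: the runner is built once from the `MoeRunnerConfig`, and each forward packs the layer's weights and scales into a `TritonMoeQuantInfo` before calling `run()`. The `layer` and `dispatch_output` arguments are assumed to come from a FusedMoE forward pass.

```python
# Sketch of the Triton MoE runner path used by the FP8 method above (assumed inputs).
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo


def run_fp8_experts(layer, dispatch_output, moe_runner_config: MoeRunnerConfig):
    # In the real method the runner is created once in create_moe_runner() and cached.
    runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
    quant_info = TritonMoeQuantInfo(
        w13_weight=layer.w13_weight,
        w2_weight=layer.w2_weight,
        use_fp8_w8a8=True,
        per_channel_quant=False,
        w13_scale=layer.w13_weight_scale,
        w2_scale=layer.w2_weight_scale,
        a13_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
    )
    return runner.run(dispatch_output, quant_info)
```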
@@ -517,6 +531,39 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return ["hf_quant_config.json"]

+    @staticmethod
+    def common_group_size(cfg: dict) -> int:
+        """Return the unique group_size across the config; raise if missing/mismatched."""
+        sizes = set()
+
+        # Top-level and 'quantization' block
+        v = cfg.get("group_size")
+        if isinstance(v, int):
+            sizes.add(v)
+        q = cfg.get("quantization")
+        if isinstance(q, dict):
+            v = q.get("group_size")
+            if isinstance(v, int):
+                sizes.add(v)
+
+        # config_groups: accept group-level or nested dicts (e.g., weights/input_activations)
+        for g in (cfg.get("config_groups") or {}).values():
+            if isinstance(g, dict):
+                v = g.get("group_size")
+                if isinstance(v, int):
+                    sizes.add(v)
+                for sub in g.values():
+                    if isinstance(sub, dict):
+                        v = sub.get("group_size")
+                        if isinstance(v, int):
+                            sizes.add(v)
+
+        if not sizes:
+            raise ValueError("No group_size found in config.")
+        if len(sizes) > 1:
+            raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
+        return next(iter(sizes))
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
         # Handle two different config formats:
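Example of the new group-size resolution on a hand-written config dict; the dict below is illustrative, not a real checkpoint config.

```python
# Usage sketch for the helper added above (sample config is made up).
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config

cfg = {
    "quantization": {"quant_algo": "NVFP4", "group_size": 16},
    "config_groups": {
        "group_0": {
            "weights": {"num_bits": 4, "group_size": 16},
            "input_activations": {"num_bits": 4, "group_size": 16},
        }
    },
}
print(ModelOptFp4Config.common_group_size(cfg))  # 16

# A mismatch (e.g. 16 in one block, 32 in another) raises
# ValueError("Inconsistent group_size values: ..."); no group_size at all raises too.
```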
@@ -549,7 +596,7 @@ class ModelOptFp4Config(QuantizationConfig):
             else:
                 kv_cache_quant_algo = "auto"

-            group_size =
+            group_size = ModelOptFp4Config.common_group_size(config)
             exclude_modules = config.get("ignore", [])
         else:
             # Fall back to nested format (hf_quant_config.json - legacy format)
@@ -559,7 +606,7 @@ class ModelOptFp4Config(QuantizationConfig):
                 kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
                 if not kv_cache_quant_algo:
                     kv_cache_quant_algo = "auto"
-                group_size =
+                group_size = ModelOptFp4Config.common_group_size(config)
                 exclude_modules = quant_config.get("exclude_modules", [])
             except (ValueError, KeyError):
                 raise ValueError(
@@ -595,10 +642,22 @@ class ModelOptFp4Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str, exclude_modules: list):
         import regex as re

+        fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
+        prefix_split = prefix.split(".")
         for pattern in exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
                 return True
+            elif (
+                pattern_split[-1] in fused_patterns
+                and pattern_split[-1] in prefix_split[-1]
+            ):
+                # Check if the last part of the excluded pattern is contained in the last part of the prefix
+                # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+                # e.g., model.layers.{i}.self_attn.{fused_weight_name}
+                assert len(prefix_split) == 5 and len(pattern_split) == 5
+                return True
         return False

     def get_quant_method(
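A simplified restatement of the new fused-module exclusion rule, using the stdlib `re` module and made-up module names:

```python
# Illustrative check only; the pattern and prefix strings are hypothetical examples.
import re

fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
pattern = "model.layers.*.self_attn.q_a_proj"                   # entry from the exclude list
prefix = "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa"   # fused module in the checkpoint

regex_str = pattern.replace(".", r"\.").replace("*", r".*")
direct_hit = re.fullmatch(regex_str, prefix) is not None        # False: module names differ
fused_hit = (
    pattern.split(".")[-1] in fused_patterns
    and pattern.split(".")[-1] in prefix.split(".")[-1]          # "q_a_proj" occurs inside the fused name
)
print(direct_hit, fused_hit)  # False True -> the fused projection is still treated as excluded
```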
@@ -1203,8 +1262,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 layer.w13_weight_scale,
             )

-            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
-
         else:
             # CUTLASS processing - handle w13 and w2 separately

@@ -1221,7 +1278,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

             # Both flashinfer cutlass and regular cutlass use same processing for w2
-            logger.info_once("Applied weight processing for both w13 and w2")

             # Set up CUTLASS MoE parameters
             device = layer.w13_weight.device
@@ -1238,21 +1294,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
         return self.enable_flashinfer_cutlass_moe

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: FusedMoE,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."

+        moe_runner_config = self.moe_runner_config
+
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))

         if self.enable_flashinfer_cutlass_moe:
             assert (
@@ -1305,13 +1372,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
-            return output
+            return StandardCombineInput(hidden_states=output)

         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

@@ -1332,4 +1398,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
-
+
+        return StandardCombineInput(hidden_states=output)

sglang/srt/layers/quantization/moe_wna16.py

@@ -9,6 +9,8 @@ import torch

 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )


 def get_weight_perm(num_bits: int):
@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.register_parameter(key, param)
             set_weight_attrs(param, extra_weight_attrs)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # avoid circular import
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."

         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp

-
-
-            layer.
-            layer.w2_qweight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_qweight,
+            w2_weight=layer.w2_qweight,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
-
+            w13_scale=layer.w13_scales,
             w2_scale=layer.w2_scales,
-
+            w13_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
         )
+        return self.runner.run(dispatch_output, quant_info)

     @staticmethod
def get_weight_loader(layer, weight_loader):
|