sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
CHANGED
@@ -25,30 +25,6 @@ if TYPE_CHECKING:
 def quantize(w, dtype, dev, **opt):
     if dtype == "bf16":
         return w.to(torch.bfloat16), InFlexData()
-    elif dtype == "fp8":
-        wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
-        return (
-            wq,
-            InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)),
-            MicroscalingCtx(),
-        )
-    else:
-        assert dtype == "mx4", f"{dtype=}"
-        swizzle_mx_scale = opt["swizzle_mx_scale"]
-        swizzle_axis = 2 if swizzle_mx_scale else None
-        w = w.to(torch.bfloat16)
-        w, mx_scales, weight_scale_shape = downcast_to_mxfp(
-            w, torch.uint8, axis=1, swizzle_axis=swizzle_axis
-        )
-        return (
-            w,
-            InFlexData(),
-            MicroscalingCtx(
-                weight_scale=mx_scales,
-                swizzle_mx=swizzle_mx_scale,
-                actual_weight_scale_shape=weight_scale_shape,
-            ),
-        )


 def triton_kernel_moe_forward(
@@ -119,14 +95,14 @@ def triton_kernel_fused_experts(
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:

-    assert use_fp8_w8a8
-    assert per_channel_quant
-    assert expert_map
-    assert w1_scale
-    assert w2_scale
-    assert a1_scale
-    assert a2_scale
-    assert block_shape
+    assert use_fp8_w8a8 is False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant is False, "per_channel_quant is not supported"
+    assert expert_map is None, "expert_map is not supported"
+    assert w1_scale is None, "w1_scale is not supported"
+    assert w2_scale is None, "w2_scale is not supported"
+    assert a1_scale is None, "a1_scale is not supported"
+    assert a2_scale is None, "a2_scale is not supported"
+    assert block_shape is None, "block_shape is not supported"

     # type check
     assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
@@ -143,7 +119,7 @@ def triton_kernel_fused_experts(
     ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"

     # feature check
-    assert inplace
+    assert inplace is False, "Inplace is not supported in new triton MoE kernel"

     M, K = hidden_states.shape
     E, _, N = w1.shape
@@ -264,14 +240,14 @@ def triton_kernel_fused_experts_with_bias(
     gemm1_alpha: Optional[float] = None,
     gemm1_clamp_limit: Optional[float] = None,
 ) -> torch.Tensor:
-    assert use_fp8_w8a8
-    assert per_channel_quant
-    assert expert_map
-    assert w1_scale
-    assert w2_scale
-    assert a1_scale
-    assert a2_scale
-    assert block_shape
+    assert use_fp8_w8a8 is False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant is False, "per_channel_quant is not supported"
+    assert expert_map is None, "expert_map is not supported"
+    assert w1_scale is None, "w1_scale is not supported"
+    assert w2_scale is None, "w2_scale is not supported"
+    assert a1_scale is None, "a1_scale is not supported"
+    assert a2_scale is None, "a2_scale is not supported"
+    assert block_shape is None, "block_shape is not supported"

     # type check
     assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
@@ -290,7 +266,7 @@ def triton_kernel_fused_experts_with_bias(
     ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"

     # feature check
-    assert inplace
+    assert inplace is False, "Inplace is not supported in new triton MoE kernel"

     E, _, _ = w1.shape

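The assertion changes above flip the guards from requiring these arguments to rejecting them, and attach an explanatory message to each. A minimal sketch of the new guard style (the function name and defaults below are illustrative, not the sglang API):

# Minimal sketch: unsupported features must stay at their defaults, and each
# assert explains itself instead of requiring the opposite condition.
def check_unsupported_features(use_fp8_w8a8=False, expert_map=None,
                               block_shape=None, inplace=False):
    assert use_fp8_w8a8 is False, "use_fp8_w8a8 is not supported"
    assert expert_map is None, "expert_map is not supported"
    assert block_shape is None, "block_shape is not supported"
    assert inplace is False, "Inplace is not supported in new triton MoE kernel"

check_unsupported_features()                   # passes with defaults
# check_unsupported_features(inplace=True)     # raises with a clear message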
sglang/srt/layers/moe/moe_runner/deep_gemm.py
ADDED
@@ -0,0 +1,304 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    MoeQuantInfo,
+    MoeRunnerConfig,
+    MoeRunnerCore,
+    RunnerInput,
+    RunnerOutput,
+    register_post_permute,
+    register_pre_permute,
+)
+from sglang.srt.layers.moe.utils import MoeRunnerBackend
+from sglang.srt.utils import dispose_tensor
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher.standard import (
+        StandardCombineInput,
+        StandardDispatchOutput,
+    )
+
+
+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
+@dataclass
+class DeepGemmRunnerInput(RunnerInput):
+    hidden_states: torch.Tensor
+    hidden_states_scale: torch.Tensor
+    masked_m: torch.Tensor
+    expected_m: int
+    use_masked_gemm: bool
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@dataclass
+class DeepGemmRunnerOutput(RunnerOutput):
+    hidden_states: torch.Tensor
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@dataclass
+class DeepGemmMoeQuantInfo(MoeQuantInfo):
+    w13_weight: torch.Tensor
+    w2_weight: torch.Tensor
+    use_fp8: bool
+    w13_scale: Optional[torch.Tensor] = None
+    w2_scale: Optional[torch.Tensor] = None
+    block_shape: Optional[List[int]] = None
+
+
+class DeepGemmRunnerCore(MoeRunnerCore):
+    def __init__(self, config: MoeRunnerConfig):
+        super().__init__(config)
+        assert self.config.activation == "silu"
+
+    def run(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> DeepGemmRunnerOutput:
+
+        if runner_input.use_masked_gemm:
+            hidden_states = self._run_masked_gemm(
+                runner_input,
+                quant_info,
+                running_state,
+            )
+        else:
+            hidden_states = self._run_contiguous_gemm(
+                runner_input,
+                quant_info,
+                running_state,
+            )
+        return DeepGemmRunnerOutput(hidden_states=hidden_states)
+
+    def _run_masked_gemm(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> torch.Tensor:
+
+        from sglang.srt.layers import deep_gemm_wrapper
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            silu_and_mul_masked_post_quant_fwd,
+        )
+
+        hidden_states = runner_input.hidden_states
+        hidden_states_scale = runner_input.hidden_states_scale
+        masked_m = runner_input.masked_m
+        expected_m = runner_input.expected_m
+
+        w13_weight = quant_info.w13_weight
+        w2_weight = quant_info.w2_weight
+        w13_scale = quant_info.w13_scale
+        w2_scale = quant_info.w2_scale
+
+        hidden_states_device = running_state["hidden_states_device"]
+
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = hidden_states_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
+        # GroupGemm-0
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
+        else:
+            hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
+                hidden_states_scale
+            )
+
+        num_groups, m, k = hidden_states.shape
+        n = w13_weight.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            (hidden_states, hidden_states_scale),
+            (w13_weight, w13_scale),
+            gateup_output,
+            masked_m,
+            expected_m,
+        )
+        dispose_tensor(hidden_states)
+
+        # Act
+        down_input = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float8_e4m3fn,
+        )
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        )
+        del gateup_output
+
+        # GroupGemm-1
+        n = w2_weight.shape[1]
+
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
+                down_input_scale
+            )
+
+        down_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            (down_input, down_input_scale),
+            (w2_weight, w2_scale),
+            down_output,
+            masked_m,
+            expected_m,
+        )
+        del down_input
+
+        return down_output
+
+    def _run_contiguous_gemm(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> torch.Tensor:
+        pass
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@register_pre_permute("standard", "deep_gemm")
+def pre_permute_standard_to_deep_gemm(
+    dispatch_output: StandardDispatchOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> DeepGemmRunnerInput:
+    from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess
+
+    hidden_states, topk_output = dispatch_output
+    topk_weights, topk_ids, _ = topk_output
+
+    hidden_states_shape = hidden_states.shape
+    hidden_states_dtype = hidden_states.dtype
+    hidden_states_device = hidden_states.device
+    hidden_states_ref = hidden_states
+
+    topk_weights, topk_ids = topk_weights, topk_ids
+
+    # PreReorder
+    masked_m, expected_m, src2dst, hidden_states, hidden_states_scale = (
+        moe_ep_deepgemm_preprocess(
+            topk_ids,
+            runner_config.num_local_experts,
+            hidden_states,
+            runner_config.top_k,
+            quant_info.block_shape,
+        )
+    )
+
+    dispose_tensor(hidden_states_ref)
+
+    running_state["topk_ids"] = topk_ids
+    running_state["topk_weights"] = topk_weights
+    running_state["hidden_states_shape"] = hidden_states_shape
+    running_state["hidden_states_dtype"] = hidden_states_dtype
+    running_state["hidden_states_device"] = hidden_states_device
+    running_state["src2dst"] = src2dst
+
+    return DeepGemmRunnerInput(
+        hidden_states=hidden_states,
+        hidden_states_scale=hidden_states_scale,
+        masked_m=masked_m,
+        expected_m=expected_m,
+        use_masked_gemm=True,
+    )
+
+
+@register_post_permute("deep_gemm", "standard")
+def post_permute_deep_gemm_to_standard(
+    runner_output: DeepGemmRunnerOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> StandardCombineInput:
+    from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
+    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
+    hidden_states_shape = running_state["hidden_states_shape"]
+    hidden_states_dtype = running_state["hidden_states_dtype"]
+    hidden_states_device = running_state["hidden_states_device"]
+    src2dst = running_state["src2dst"]
+    topk_ids = running_state["topk_ids"]
+    topk_weights = running_state["topk_weights"]
+
+    output = torch.empty(
+        hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+    )
+    post_reorder_triton_kernel[(hidden_states_shape[0],)](
+        runner_output.hidden_states,
+        output,
+        src2dst,
+        topk_ids,
+        topk_weights,
+        runner_config.top_k,
+        hidden_states_shape[1],
+        BLOCK_SIZE=512,
+    )
+
+    dispose_tensor(runner_output.hidden_states)
+
+    if runner_config.routed_scaling_factor is not None:
+        output *= runner_config.routed_scaling_factor
+
+    return StandardCombineInput(
+        hidden_states=output,
+    )
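The new runner executes the masked flow as GroupGemm-0 (gate/up projection), a fused SiLU-and-mul plus fp8 requantization, then GroupGemm-1 (down projection). Its `_cast_to_e8m0_with_rounding_up` helper packs each float32 block scale into an 8-bit power-of-two exponent, rounding up whenever the mantissa is non-zero, presumably so the e8m0 scale never underestimates the original. A simplified, CPU-only illustration of that rounding rule (it reimplements the idea for demonstration and skips the subnormal/overflow edge cases and the layout transposes the kernel handles):

import torch

def e8m0_round_up_demo(x: torch.Tensor) -> torch.Tensor:
    # View the float32 bits, split into exponent and mantissa fields.
    bits = x.to(torch.float32).view(torch.int32)
    exp = torch.bitwise_right_shift(bits, 23)
    mant = torch.bitwise_and(bits, 0x7FFFFF)
    # Round up to the next power of two whenever the mantissa is non-zero
    # (ignoring the subnormal/overflow special cases of the real kernel).
    round_up = (mant > 0) & (exp != 0xFE)
    exp = torch.where(round_up, exp + 1, exp)
    return exp.to(torch.uint8)

scales = torch.tensor([0.75, 1.0, 3.0])       # 0.75 -> 1.0, 1.0 -> 1.0, 3.0 -> 4.0
e8m0 = e8m0_round_up_demo(scales)
decoded = torch.ldexp(torch.ones_like(scales), e8m0.to(torch.int32) - 127)
print(decoded)                                 # tensor([1., 1., 4.])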
sglang/srt/layers/moe/moe_runner/runner.py
CHANGED
@@ -9,6 +9,7 @@ from sglang.srt.layers.moe.moe_runner.base import (
     MoeRunnerConfig,
     PermuteMethodPool,
 )
+from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
 from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
 from sglang.srt.layers.moe.utils import get_moe_a2a_backend

@@ -30,6 +31,8 @@ class MoeRunner:

         if runner_backend.is_triton():
             self.runner_core = TritonRunnerCore(config)
+        elif runner_backend.is_deep_gemm():
+            self.runner_core = DeepGemmRunnerCore(config)
         else:
             raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")

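The new backend plugs into `MoeRunner` through the branch above and through the `register_pre_permute("standard", "deep_gemm")` / `register_post_permute("deep_gemm", "standard")` decorators in the new module. A generic sketch of that kind of string-keyed registry, for illustration only (not the actual `PermuteMethodPool` implementation):

from typing import Callable, Dict, Tuple

# (dispatch format, runner backend) -> conversion function
_PRE_PERMUTE: Dict[Tuple[str, str], Callable] = {}

def register_pre_permute(src_format: str, backend: str):
    # Decorator that files a converter under its (format, backend) key.
    def decorator(fn: Callable) -> Callable:
        _PRE_PERMUTE[(src_format, backend)] = fn
        return fn
    return decorator

@register_pre_permute("standard", "deep_gemm")
def standard_to_deep_gemm(dispatch_output):
    return f"converted {dispatch_output} for deep_gemm"

# A runner looks up the converter for the active combination at call time.
print(_PRE_PERMUTE[("standard", "deep_gemm")]("tokens"))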
sglang/srt/layers/moe/router.py
CHANGED
@@ -11,7 +11,7 @@ _is_hip = is_hip()


 @triton.jit
-def fused_moe_router_kernel(
+def fused_moe_router_cudacore_kernel(
     input_ptr,  # input (bs, hidden_dim)
     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
@@ -114,7 +114,7 @@ def fused_moe_router_kernel(
     # assert not moe_renormalize, "moe weight renormalization not implemented"


-def fused_moe_router_impl(
+def fused_moe_router_cudacore(
     x: torch.Tensor,
     router_weight: torch.Tensor,
     topk: int,
@@ -138,7 +138,7 @@ def fused_moe_router_impl(
         ),
     }

-    fused_moe_router_kernel[(bs,)](
+    fused_moe_router_cudacore_kernel[(bs,)](
         x,
         router_weight,
         topk_weights,
@@ -157,7 +157,7 @@ def fused_moe_router_impl(


 @triton.jit
-def fused_moe_router_large_bs_kernel(
+def fused_moe_router_tensorcore_kernel(
     a_ptr,  # input (bs, hidden_dim)
     b_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
@@ -167,12 +167,15 @@ def fused_moe_router_large_bs_kernel(
     topk: tl.constexpr,  # only support topk <= 2
     moe_softcapping: tl.constexpr,
     moe_renormalize: tl.constexpr,  # not supported
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
     K: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     stride_am: tl.constexpr,
     stride_bn: tl.constexpr,
+    dp_attn_workaround_flag: tl.constexpr,
 ):

     # 1. get block id
@@ -217,6 +220,20 @@ def fused_moe_router_large_bs_kernel(
     exped = tl.exp(2 * logits_scaled)
     logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping

+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(
+            correction_bias_ptr + tl.arange(0, BLOCK_SIZE_N)[None, :],
+            mask=expert_mask.T,
+            other=0.0,
+        )
+        logits_softcapped = logits_softcapped + bias
+
+    if dp_attn_workaround_flag:
+        logits_softcapped = tl.where(
+            logits_softcapped != logits_softcapped, -1e9, logits_softcapped
+        )
+
     # 5. top1
     arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
     cond_top1 = arange_block_size_n < num_experts
@@ -266,7 +283,7 @@ def fused_moe_router_large_bs_kernel(
     )


-def fused_moe_router_large_bs_impl(
+def fused_moe_router_tensorcore(
     x: torch.Tensor,
     router_weight: torch.Tensor,
     topk: int,
@@ -274,6 +291,7 @@ def fused_moe_router_large_bs_impl(
     BLOCK_SIZE_M: int,
     BLOCK_SIZE_N: int,
     BLOCK_SIZE_K: int,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
     bs, hidden_dim = x.shape
@@ -285,10 +303,17 @@ def fused_moe_router_large_bs_impl(

     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+    is_correction_bias = correction_bias is not None

     grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)

-    fused_moe_router_large_bs_kernel[grid](
+    # TODO(ch-wan): temporary workaround for dp attention. We should support masked
+    # router to skip padded tokens.
+    from sglang.srt.layers.dp_attention import is_dp_attention_enabled
+
+    dp_attn_workaround_flag = is_dp_attention_enabled()
+
+    fused_moe_router_tensorcore_kernel[grid](
         a_ptr=x,
         b_ptr=router_weight,
         topk_weights_ptr=topk_weights,
@@ -299,11 +324,14 @@ def fused_moe_router_large_bs_impl(
         moe_softcapping=moe_softcapping,
         moe_renormalize=False,
         K=hidden_dim,
+        correction_bias_ptr=correction_bias,
+        is_correction_bias=is_correction_bias,
         BLOCK_SIZE_M=BLOCK_SIZE_M,
         BLOCK_SIZE_N=BLOCK_SIZE_N,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         stride_am=hidden_dim,
         stride_bn=hidden_dim,
+        dp_attn_workaround_flag=dp_attn_workaround_flag,
     )

     return topk_weights, topk_ids
@@ -316,6 +344,7 @@ def fused_moe_router_shim(
     topk,
     renormalize,
     correction_bias: Optional[torch.Tensor] = None,
+    enable_deterministic_inference: bool = False,
 ):
     assert not renormalize
     assert (
@@ -324,16 +353,22 @@ def fused_moe_router_shim(
     )
     bs, hidden_dim = hidden_states.shape
     num_experts = gating_output.shape[0]
+
     BLOCK_SIZE_M = 32
-
-
+
+    BLOCK_SIZE_N = max(num_experts, 16)
+    BLOCK_SIZE_K = (
+        256 if num_experts < 256 else 64
+    )  # if experts are large, need to use smaller k block or shared memory OOM
+
     if (
-        bs >= 512
-        and topk <= 2
-        and num_experts <= BLOCK_SIZE_N
+        (bs >= 512 or num_experts > 8)
         and hidden_dim % BLOCK_SIZE_K == 0
+        # we keep using single kernel to avoid non-deterministic behavior
+        and not enable_deterministic_inference
     ):
-        return fused_moe_router_large_bs_impl(
+        # if large batch size or large expert, use kernel that uses tensorcore in matmul
+        return fused_moe_router_tensorcore(
             x=hidden_states,
             router_weight=gating_output,
             topk=topk,
@@ -341,9 +376,11 @@ def fused_moe_router_shim(
             BLOCK_SIZE_M=BLOCK_SIZE_M,
             BLOCK_SIZE_N=BLOCK_SIZE_N,
             BLOCK_SIZE_K=BLOCK_SIZE_K,
+            correction_bias=correction_bias,
         )
     else:
-        return fused_moe_router_impl(
+        # if smaller, use kernel that does not use tensorcore in matmul
+        return fused_moe_router_cudacore(
             x=hidden_states,
             router_weight=gating_output,
             topk=topk,
@@ -380,11 +417,10 @@ class FusedMoeRouter:
             renormalize=False,
         )

-    def
+    def forward_torch(
         self,
         x: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # g, _ = self.router_linear.forward(x)
         g = x.float() @ self.router_linear.weight.T.float()

         g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
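For reference, the math these fused router kernels implement (matmul, tanh softcapping, the newly added correction bias, the NaN guard for DP-attention padding, then top-k) can be written in plain PyTorch. This is an illustrative re-derivation mirroring the steps visible in the diff, not code from the package, and it ignores renormalization, which the shim asserts is disabled:

import torch

def moe_router_reference(x, router_weight, topk, moe_softcapping,
                         correction_bias=None):
    # matmul -> tanh softcapping -> optional bias -> NaN guard -> top-k
    logits = x.float() @ router_weight.float().T            # (bs, num_experts)
    scores = torch.tanh(logits / moe_softcapping) * moe_softcapping
    if correction_bias is not None:
        scores = scores + correction_bias                    # added after softcapping
    scores = torch.where(torch.isnan(scores),
                         torch.full_like(scores, -1e9), scores)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    return topk_weights, topk_ids

# 4 tokens, hidden_dim=8, 4 experts, top-2
x = torch.randn(4, 8)
w = torch.randn(4, 8)
weights, ids = moe_router_reference(x, w, topk=2, moe_softcapping=30.0)

The shim now selects the tensor-core kernel when the batch is large (bs >= 512) or there are more than 8 experts, and stays on the single CUDA-core kernel when deterministic inference is requested.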
sglang/srt/layers/moe/token_dispatcher/__init__.py
CHANGED
@@ -16,8 +16,14 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
     DeepEPNormalCombineInput,
     DeepEPNormalOutput,
 )
+from sglang.srt.layers.moe.token_dispatcher.mooncake import (
+    MooncakeCombineInput,
+    MooncakeDispatchOutput,
+    MooncakeEPDispatcher,
+)
 from sglang.srt.layers.moe.token_dispatcher.standard import (
     StandardCombineInput,
+    StandardDispatcher,
     StandardDispatchOutput,
 )

@@ -30,6 +36,10 @@ __all__ = [
     "DispatchOutput",
     "DispatchOutputFormat",
     "DispatchOutputChecker",
+    "MooncakeCombineInput",
+    "MooncakeDispatchOutput",
+    "MooncakeEPDispatcher",
+    "StandardDispatcher",
     "StandardDispatchOutput",
     "StandardCombineInput",
     "DeepEPConfig",
sglang/srt/layers/moe/token_dispatcher/base.py
CHANGED
@@ -73,7 +73,7 @@ class DispatchOutputFormat(Enum):
 class DispatchOutput(Protocol):
     """Protocol for dispatch outputs in different formats."""

-
+    hidden_states: torch.Tensor

     @property
     def format(self) -> DispatchOutputFormat: ...
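Adding `hidden_states: torch.Tensor` to the `DispatchOutput` protocol means every dispatch-output format is now expected to expose that attribute structurally. A small illustration of how such a protocol member is satisfied (the class names below are made up for the example):

from dataclasses import dataclass
from typing import Protocol
import torch

class DispatchOutputLike(Protocol):
    # Illustrative stand-in for the DispatchOutput protocol: any dispatch
    # output must expose a hidden_states tensor attribute.
    hidden_states: torch.Tensor

@dataclass
class DummyDispatchOutput:
    hidden_states: torch.Tensor

def first_token(out: DispatchOutputLike) -> torch.Tensor:
    # Callers can rely on hidden_states being present on every format.
    return out.hidden_states[0]

print(first_token(DummyDispatchOutput(hidden_states=torch.zeros(2, 4))))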