sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py

@@ -1,14 +1,12 @@
 from __future__ import annotations

 import logging
-from
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 import torch
-import triton
-import triton.language as tl

-from sglang.srt
+from sglang.srt import single_batch_overlap
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
@@ -18,37 +16,21 @@ from sglang.srt.layers.moe import (
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
-    moe_ep_deepgemm_preprocess,
-    post_reorder_triton_kernel,
     silu_and_mul_masked_post_quant_fwd,
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
-from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
 )
-from sglang.srt.layers.quantization.
-    CUTEDSL_MOE_NVFP4_DISPATCH,
-    ModelOptNvFp4FusedMoEMethod,
-)
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.offloader import get_offloader
+from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
-from sglang.srt.utils import
-
-    dispose_tensor,
-    get_bool_env_var,
-    get_int_env_var,
-    is_cuda,
-    is_hip,
-    is_npu,
-)
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+from sglang.srt.utils.offloader import get_offloader

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
@@ -72,29 +54,14 @@ if _use_aiter:
 logger = logging.getLogger(__name__)


-
-# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
-@torch.compile
-def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
-    temp = x.to(torch.float32).view(torch.int32)
-    exp = torch.bitwise_right_shift(temp, 23)
-    mant = torch.bitwise_and(temp, 0x7FFFFF)
-    is_ru = torch.logical_and(
-        torch.logical_and((mant > 0), (exp != 0xFE)),
-        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
-    )
-    exp = torch.where(is_ru, exp + 1, exp)
-    new_x = exp.to(torch.uint8).view(torch.int)
-    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
-
-
-class EPMoE(FusedMoE):
+class DeepEPMoE(FusedMoE):
     """
-    MoE Expert Parallel Impl
-
-
+    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
+    Mooncake EP shares the same class, as they expose the same interface.
     """

+    _has_printed = False
+
     def __init__(
         self,
         num_experts: int,
@@ -108,291 +75,37 @@ class EPMoE(FusedMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        gemm1_alpha: Optional[float] = None,
-        gemm1_clamp_limit: Optional[float] = None,
-        with_bias: bool = False,
     ):
         super().__init__(
             num_experts=num_experts,
+            top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
-            num_fused_shared_experts=num_fused_shared_experts,
             layer_id=layer_id,
-
+            num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
             activation=activation,
-            # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
-            gemm1_alpha=gemm1_alpha,
-            gemm1_clamp_limit=gemm1_clamp_limit,
-            with_bias=with_bias,
         )

-        self.intermediate_size = intermediate_size
-
         if isinstance(quant_config, Fp8Config):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
-            self.block_shape = (
-                self.quant_method.quant_config.weight_block_size
-                if self.use_block_quant
-                else None
-            )
             self.use_fp8_w8a8 = True
             self.fp8_dtype = torch.float8_e4m3fn
-            self.
-
+            self.use_w4afp8 = False
+        elif isinstance(quant_config, W4AFp8Config):
+            self.use_w4afp8 = True
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
-            self.block_shape = None
-            self.activation_scheme = None
-
-    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
-            return self.forward_deepgemm(hidden_states, topk_output)
         else:
-
-
-
-        self,
-        hidden_states: torch.Tensor,
-        topk_output: TopKOutput,
-    ):
-
-        self.w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        self.w2_weight_fp8 = (
-            self.w2_weight,
-            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
-        )
-
-        assert self.quant_method is not None
-        assert self.moe_runner_config.activation == "silu"
-
-        hidden_states_shape = hidden_states.shape
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-
-        topk_weights, topk_ids, _ = topk_output
-
-        if not self.use_block_quant:
-            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
-            scale_block_size = 128
-            w13_weight_scale_n = 2 * (
-                (self.intermediate_size + scale_block_size - 1) // scale_block_size
-            )
-            w13_weight_scale_k = (
-                hidden_states_shape[-1] + scale_block_size - 1
-            ) // scale_block_size
-            w13_weight_scale = (
-                self.w13_weight_scale.unsqueeze(1)
-                .repeat_interleave(w13_weight_scale_n, dim=1)
-                .unsqueeze(2)
-                .repeat_interleave(w13_weight_scale_k, dim=2)
-            )
-            self.w13_weight_fp8 = (
-                self.w13_weight,
-                w13_weight_scale,
-            )
-            w2_weight_scale_n = (
-                hidden_states_shape[-1] + scale_block_size - 1
-            ) // scale_block_size
-            w2_weight_scale_k = (
-                self.intermediate_size + scale_block_size - 1
-            ) // scale_block_size
-            w2_weight_scale = (
-                self.w2_weight_scale.unsqueeze(1)
-                .repeat_interleave(w2_weight_scale_n, dim=1)
-                .unsqueeze(2)
-                .repeat_interleave(w2_weight_scale_k, dim=2)
-            )
-            self.w2_weight_fp8 = (
-                self.w2_weight,
-                w2_weight_scale,
-            )
-
-        # PreReorder
-        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
-            moe_ep_deepgemm_preprocess(
-                topk_ids,
-                self.num_experts,
-                hidden_states,
-                self.top_k,
-                self.start_expert_id,
-                self.end_expert_id,
-                self.block_shape,
-            )
-        )
-
-        dispose_tensor(hidden_states)
-
-        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            b, s_mn, s_k = gateup_input_scale.shape
-            assert (
-                s_mn % 4 == 0 and s_k % 4 == 0
-            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
-
-        # GroupGemm-0
-        gateup_input_fp8 = (
-            gateup_input,
-            (
-                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
-                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
-                    gateup_input_scale
-                )
-            ),
-        )
-        num_groups, m, k = gateup_input_fp8[0].size()
-        n = self.w13_weight.size(1)
-        gateup_output = torch.empty(
-            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
-        )
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            gateup_input_fp8,
-            self.w13_weight_fp8,
-            gateup_output,
-            masked_m,
-            expected_m,
-        )
-        del gateup_input
-        del gateup_input_fp8
-
-        # Act
-        down_input = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2,
-            ),
-            device=hidden_states_device,
-            dtype=self.fp8_dtype,
-        )
-        scale_block_size = 128
-        down_input_scale = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2 // scale_block_size,
-            ),
-            device=hidden_states_device,
-            dtype=torch.float32,
-        )
-        silu_and_mul_masked_post_quant_fwd(
-            gateup_output,
-            down_input,
-            down_input_scale,
-            scale_block_size,
-            masked_m,
-            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-        )
-        del gateup_output
-
-        # GroupGemm-1
-        n = self.w2_weight.size(1)
-        down_input_fp8 = (
-            down_input,
-            (
-                down_input_scale
-                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
-            ),
-        )
-        down_output = torch.empty(
-            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
-        )
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            down_input_fp8,
-            self.w2_weight_fp8,
-            down_output,
-            masked_m,
-            expected_m,
-        )
-        del down_input
-        del down_input_fp8
-
-        # PostReorder
-        output = torch.empty(
-            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
-        )
-        post_reorder_triton_kernel[(hidden_states_shape[0],)](
-            down_output,
-            output,
-            src2dst,
-            topk_ids,
-            topk_weights,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states_shape[1],
-            m_max * self.start_expert_id,
-            BLOCK_SIZE=512,
-        )
-        if self.moe_runner_config.routed_scaling_factor is not None:
-            output *= self.moe_runner_config.routed_scaling_factor
-        return output
-
-
-class DeepEPMoE(EPMoE):
-    """
-    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
-    """
-
-    _has_printed = False
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.use_w4afp8 = False

-    def __init__(
-        self,
-        num_experts: int,
-        top_k: int,
-        hidden_size: int,
-        intermediate_size: int,
-        layer_id: int,
-        num_fused_shared_experts: int = 0,
-        params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        activation: str = "silu",
-        routed_scaling_factor: Optional[float] = None,
-    ):
-        super().__init__(
-            num_experts=num_experts,
-            top_k=top_k,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            layer_id=layer_id,
-            num_fused_shared_experts=num_fused_shared_experts,
-            params_dtype=params_dtype,
-            quant_config=quant_config,
-            prefix=prefix,
-            activation=activation,
-            routed_scaling_factor=routed_scaling_factor,
-        )
         self.deepep_mode = get_deepep_mode()

-        # TODO: move to the beginning of the file
-        from sglang.srt.distributed.parallel_state import get_tp_group
-        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
-
-        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-            group=get_tp_group().device_group,
-            router_topk=self.top_k,
-            permute_fusion=True,
-            num_experts=self.num_experts,
-            num_local_experts=self.num_local_experts,
-            hidden_size=hidden_size,
-            params_dtype=params_dtype,
-            deepep_mode=self.deepep_mode,
-            async_finish=True,  # TODO
-            return_recv_hook=True,
-        )
-
         if self.deepep_mode.enable_low_latency() and not _is_npu:
             # NPU supports low_latency deepep without deepgemm
             assert (
@@ -416,7 +129,7 @@ class DeepEPMoE(EPMoE):
             self.w13_weight,
             (
                 self.w13_weight_scale_inv
-                if self.use_block_quant
+                if self.use_block_quant or self.use_w4afp8
                 else self.w13_weight_scale
             ),
         )
@@ -424,7 +137,7 @@ class DeepEPMoE(EPMoE):
             self.w2_weight,
             (
                 self.w2_weight_scale_inv
-                if self.use_block_quant
+                if self.use_block_quant or self.use_w4afp8
                 else self.w2_weight_scale
             ),
         )
@@ -432,44 +145,34 @@ class DeepEPMoE(EPMoE):
     def forward(
         self,
         hidden_states: torch.Tensor,
-
-
-
+        topk_output: TopKOutput,
+        forward_shared_experts=None,
+        alt_stream=None,
+        disable_sbo=False,
     ):
-
-
-
-
-
-
-
-
-
+
+        # We have to call SBO inside MoE to be compatible with hooks used in offloading
+        return single_batch_overlap.execute_sbo(
+            hidden_states=hidden_states,
+            topk_output=topk_output,
+            # SBO args
+            experts=self,
+            forward_shared_experts=forward_shared_experts,
+            alt_stream=alt_stream,
+            disable_sbo=disable_sbo,
         )
-        return hidden_states

     def dispatch(
         self,
         hidden_states: torch.Tensor,
-
-        topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
+        topk_output: TopKOutput,
     ):
-        return self.
+        return self.dispatcher.dispatch(
             hidden_states=hidden_states,
-
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
-            input_global_scale=(
-                self.w13_input_scale_quant
-                if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
-                and self.quant_method.enable_flashinfer_cutedsl_moe
-                and CUTEDSL_MOE_NVFP4_DISPATCH
-                else None
-            ),
+            topk_output=topk_output,
         )

-    def
+    def run_moe_core(
         self,
         dispatch_output: DispatchOutput,
         down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
@@ -484,14 +187,20 @@ class DeepEPMoE(EPMoE):
             assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             return self.forward_npu(dispatch_output)
         if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
+            if self.use_w4afp8:
+                return self.forward_cutlass_w4afp8(dispatch_output)
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
-            if
+            if (
+                get_moe_runner_backend().is_flashinfer_cutedsl()
+                and self.quant_config.get_name() == "modelopt_fp4"
+            ):
                 return self.forward_flashinfer_cutedsl(
                     dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
                 )
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+            assert down_gemm_overlap_args is None
             return self.forward_deepgemm_masked(dispatch_output)
         else:
             raise ValueError(
@@ -501,16 +210,14 @@ class DeepEPMoE(EPMoE):
     def combine(
         self,
         hidden_states: torch.Tensor,
-
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
         overlap_args: Optional[Dict[str, Any]] = None,
     ):
-        return self.
+        return self.dispatcher.combine(
            hidden_states=hidden_states,
-
+            topk_ids=topk_ids,
            topk_weights=topk_weights,
-            forward_batch=forward_batch,
            overlap_args=overlap_args,
        )

@@ -518,9 +225,9 @@ class DeepEPMoE(EPMoE):
         self,
         dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
     ):
-        hidden_states,
+        hidden_states, topk_ids, topk_weights = (
             dispatch_output.hidden_states,
-            dispatch_output.
+            dispatch_output.topk_ids,
             dispatch_output.topk_weights,
         )
         if hidden_states.shape[0] == 0:
@@ -528,15 +235,15 @@ class DeepEPMoE(EPMoE):
         # in original deepep, idx == -1 meaning invalid and will not be processed.
         # aiter does not accept -1, we use a expert mask to make these idx invalid
         # (idx == num_local_experts) meaning not used in aiter fused_moe
-
-
+        topk_ids_copy = topk_ids.to(torch.int32)
+        topk_ids_copy[topk_ids_copy == -1] = self.num_local_experts

         return fused_moe(
             hidden_states,
             self.w13_weight,
             self.w2_weight,
             topk_weights,
-
+            topk_ids_copy,
             w1_scale=self.w13_weight_scale_inv,
             w2_scale=self.w2_weight_scale_inv,
             quant_type=QuantType.per_128x128,
@@ -552,22 +259,24 @@ class DeepEPMoE(EPMoE):
         self,
         dispatch_output: DeepEPNormalOutput,
     ):
-
-
-
-
+        (
+            hidden_states,
+            hidden_states_scale,
+            topk_ids,
+            topk_weights,
+            num_recv_tokens_per_expert,
+        ) = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
         if num_recv_tokens_per_expert is None:
-            return
+            return hidden_states.bfloat16()
         all_tokens = sum(num_recv_tokens_per_expert)
         if all_tokens <= 0:
-            return
-        M, K =
+            return hidden_states.bfloat16()
+        M, K = hidden_states.size()
         N = self.w13_weight.size(1)
         scale_block_size = 128

-        # TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
         w13_weight_fp8 = (
             self.w13_weight,
             (
@@ -585,35 +294,35 @@ class DeepEPMoE(EPMoE):
             ),
         )

-
-
-
+        hidden_states_shape = hidden_states.shape
+        hidden_states_device = hidden_states.device
+        hidden_states_dtype = hidden_states.dtype

         input_tensor = [
             torch.empty(
                 (all_tokens, K),
-                device=
-                dtype=
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
             ),
             (
                 # TODO check whether need `zeros`
                 torch.zeros(
                     (ceil_div(K // 128, 4), all_tokens),
-                    device=
+                    device=hidden_states.device,
                     dtype=torch.int,
                 ).transpose(0, 1)
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                 else torch.empty(
                     (all_tokens, K // 128),
-                    device=
+                    device=hidden_states.device,
                     dtype=torch.float32,
                 )
             ),
         ]
         m_indices = torch.empty(
-            all_tokens, device=
+            all_tokens, device=hidden_states.device, dtype=torch.int32
         )
-        output_index = torch.empty_like(
+        output_index = torch.empty_like(topk_ids)

         if get_offloader().forbid_copy_engine_usage:
             num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
@@ -629,9 +338,9 @@ class DeepEPMoE(EPMoE):
         expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

         ep_scatter(
-
+            hidden_states,
             hidden_states_scale,
-
+            topk_ids,
             num_recv_tokens_per_expert_gpu,
             expert_start_loc,
             input_tensor[0],
@@ -640,11 +349,11 @@ class DeepEPMoE(EPMoE):
             output_index,
             scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
-        dispose_tensor(
+        dispose_tensor(hidden_states)

         gateup_output = torch.empty(
             (all_tokens, N),
-            device=
+            device=hidden_states_device,
             dtype=torch.bfloat16,
         )
         if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
@@ -665,7 +374,7 @@ class DeepEPMoE(EPMoE):
         del gateup_output
         down_output = torch.empty(
             (all_tokens, K),
-            device=
+            device=hidden_states_device,
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
@@ -687,11 +396,11 @@ class DeepEPMoE(EPMoE):
         del down_input_fp8, down_input_scale

         gather_out = torch.empty(
-
-            device=
+            hidden_states_shape,
+            device=hidden_states_device,
             dtype=torch.bfloat16,
         )
-        ep_gather(down_output,
+        ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)

         return gather_out

@@ -700,42 +409,56 @@ class DeepEPMoE(EPMoE):
         dispatch_output: DeepEPLLOutput,
         down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
     ):
-        hidden_states, _, _, masked_m, _ = dispatch_output
+        hidden_states, hidden_states_scale, _, _, masked_m, _ = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"

         output = self.quant_method.apply_without_routing_weights(
             layer=self,
-            x=hidden_states,
+            x=(hidden_states, hidden_states_scale),
             masked_m=masked_m,
             moe_runner_config=self.moe_runner_config,
             down_gemm_overlap_args=down_gemm_overlap_args,
         )
         return output

+    def forward_cutlass_w4afp8(
+        self,
+        dispatch_output: DeepEPNormalOutput,
+    ):
+        assert self.moe_runner_config.activation == "silu"
+        assert isinstance(self.quant_method, W4AFp8MoEMethod)
+        return self.quant_method.apply_deepep_normal(
+            layer=self,
+            dispatch_output=dispatch_output,
+        )
+
     def forward_deepgemm_masked(
         self,
         dispatch_output: DeepEPLLOutput,
     ):
-
+        hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
+        assert (
+            hidden_states_scale.dtype == torch.float32
+        ), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}"

         # GroupGemm-0
-        num_groups, m, k =
+        num_groups, m, k = hidden_states.size()
         n = self.w13_weight.size(1)
         expected_m = min(expected_m, m)
         gateup_output = torch.empty(
-            (num_groups, m, n), device=
+            (num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-
+            (hidden_states, hidden_states_scale),
             self.w13_weight_fp8,
             gateup_output,
             masked_m,
             expected_m,
         )
-        dispose_tensor(
+        dispose_tensor(hidden_states)

         # Act
         down_input = torch.empty(
@@ -808,11 +531,9 @@ class DeepEPMoE(EPMoE):
         def _forward_normal(dispatch_output: DeepEPNormalOutput):
             if TYPE_CHECKING:
                 assert isinstance(dispatch_output, DeepEPNormalOutput)
-            hidden_states, _, _, num_recv_tokens_per_expert =
-
-
-            per_token_scale = hidden_states[1]
-            hidden_states = hidden_states[0]
+            hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
+                dispatch_output
+            )

             group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
                 hidden_states.device
@@ -822,7 +543,7 @@ class DeepEPMoE(EPMoE):
                 hidden_states = torch_npu.npu_grouped_matmul(
                     x=[hidden_states],
                     weight=[self.w13_weight.permute(0, 2, 1)],
-                    # per_token_scale=[
+                    # per_token_scale=[hidden_states_scale],
                     split_item=2,
                     group_list_type=group_list_type,
                     group_type=0,
@@ -842,7 +563,7 @@ class DeepEPMoE(EPMoE):
                 )[0]
             else:
                 if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
-                    hidden_states,
+                    hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant(
                         hidden_states
                     )
                 # gmm1: gate_up_proj
@@ -850,7 +571,7 @@ class DeepEPMoE(EPMoE):
                     x=[hidden_states],
                     weight=[self.w13_weight],
                     scale=[self.w13_weight_scale.to(output_dtype)],
-                    per_token_scale=[
+                    per_token_scale=[hidden_states_scale],
                     split_item=2,
                     group_list_type=group_list_type,
                     group_type=0,
@@ -882,11 +603,14 @@ class DeepEPMoE(EPMoE):
         def _forward_ll(dispatch_output: DeepEPLLOutput):
             if TYPE_CHECKING:
                 assert isinstance(dispatch_output, DeepEPLLOutput)
-
-
-
-
-
+            (
+                hidden_states,
+                hidden_states_scale,
+                topk_ids,
+                topk_weights,
+                group_list,
+                _,
+            ) = dispatch_output

             group_list = group_list.to(torch.int64)

@@ -895,7 +619,7 @@ class DeepEPMoE(EPMoE):
                 hidden_states = torch_npu.npu_grouped_matmul(
                     x=[hidden_states],
                     weight=[self.w13_weight.permute(0, 2, 1)],
-                    # per_token_scale=[
+                    # per_token_scale=[hidden_states_scale],
                     split_item=2,
                     group_list_type=group_list_type,
                     group_type=0,
@@ -929,7 +653,7 @@ class DeepEPMoE(EPMoE):
                 hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
                     x=hidden_states,
                     weight_scale=self.w13_weight_scale.to(torch.float32),
-                    activation_scale=
+                    activation_scale=hidden_states_scale,
                     bias=None,
                     quant_scale=None,
                     quant_offset=None,
@@ -962,7 +686,7 @@ class DeepEPMoE(EPMoE):


 def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
-    if get_moe_a2a_backend().is_deepep():
+    if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
         return DeepEPMoE

     # NEW: Direct FP4 detection (bypasses EP requirements)
@@ -988,8 +712,6 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
         return FlashInferFusedMoE
     if get_moe_runner_backend().is_flashinfer_cutlass():
         return FusedMoE
-    if get_moe_expert_parallel_world_size() > 1:
-        return EPMoE
     return FusedMoE

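The hunks above replace the old monolithic `EPMoE.forward_deepgemm` path with a `DeepEPMoE` whose forward work is split into explicit `dispatch`, `run_moe_core`, and `combine` steps. The sketch below is a minimal, framework-free illustration of that three-phase pattern only; it is not sglang code, and the toy names (`ToyMoE`, `ToyDispatch`, the scaling stand-in for the expert GEMMs) are assumptions made for the example.

```python
# Minimal, framework-free sketch of the dispatch -> run_moe_core -> combine split
# used by the refactored DeepEPMoE above. Toy code only; ToyMoE / ToyDispatch are
# hypothetical names, and each expert "GEMM" is replaced by a simple scaling.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyDispatch:
    per_expert: Dict[int, List[float]]  # tokens grouped by the expert they were routed to
    origin: Dict[int, List[int]]        # original position of each grouped token


class ToyMoE:
    def __init__(self, num_experts: int):
        self.num_experts = num_experts

    def dispatch(self, tokens: List[float], expert_ids: List[int]) -> ToyDispatch:
        # Phase 1: group tokens by expert (the real layer does this across EP ranks).
        per_expert: Dict[int, List[float]] = {e: [] for e in range(self.num_experts)}
        origin: Dict[int, List[int]] = {e: [] for e in range(self.num_experts)}
        for pos, (tok, eid) in enumerate(zip(tokens, expert_ids)):
            per_expert[eid].append(tok)
            origin[eid].append(pos)
        return ToyDispatch(per_expert, origin)

    def run_moe_core(self, dispatched: ToyDispatch) -> Dict[int, List[float]]:
        # Phase 2: run each expert on its own tokens (stand-in for the grouped GEMMs).
        return {
            eid: [tok * (eid + 1) for tok in toks]
            for eid, toks in dispatched.per_expert.items()
        }

    def combine(
        self, expert_out: Dict[int, List[float]], dispatched: ToyDispatch, num_tokens: int
    ) -> List[float]:
        # Phase 3: scatter expert outputs back to the original token order.
        out = [0.0] * num_tokens
        for eid, positions in dispatched.origin.items():
            for val, pos in zip(expert_out[eid], positions):
                out[pos] = val
        return out


if __name__ == "__main__":
    moe = ToyMoE(num_experts=2)
    tokens = [1.0, 2.0, 3.0, 4.0]
    expert_ids = [0, 1, 0, 1]  # stand-in for top-1 routing decisions
    dispatched = moe.dispatch(tokens, expert_ids)
    expert_out = moe.run_moe_core(dispatched)
    print(moe.combine(expert_out, dispatched, len(tokens)))  # [1.0, 4.0, 3.0, 8.0]
```

In the real layer the dispatch and combine phases go through the DeepEP (or Mooncake) token dispatcher across expert-parallel ranks, and `run_moe_core` selects between the grouped FP8 DeepGEMM, cutlass W4AFP8, and FlashInfer CuteDSL paths shown in the diff.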