sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,12 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import logging
|
4
|
-
from
|
4
|
+
from contextlib import nullcontext
|
5
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
5
6
|
|
6
7
|
import torch
|
8
|
+
import triton
|
9
|
+
import triton.language as tl
|
7
10
|
|
8
11
|
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
9
12
|
from sglang.srt.layers.moe import (
|
@@ -29,13 +32,26 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|
29
32
|
is_fp8_fnuz,
|
30
33
|
sglang_per_token_group_quant_fp8,
|
31
34
|
)
|
35
|
+
from sglang.srt.layers.quantization.modelopt_quant import (
|
36
|
+
CUTEDSL_MOE_NVFP4_DISPATCH,
|
37
|
+
ModelOptNvFp4FusedMoEMethod,
|
38
|
+
)
|
32
39
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
33
40
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
34
|
-
from sglang.srt.
|
41
|
+
from sglang.srt.offloader import get_offloader
|
42
|
+
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
43
|
+
from sglang.srt.utils import (
|
44
|
+
ceil_div,
|
45
|
+
dispose_tensor,
|
46
|
+
get_bool_env_var,
|
47
|
+
get_int_env_var,
|
48
|
+
is_cuda,
|
49
|
+
is_hip,
|
50
|
+
is_npu,
|
51
|
+
)
|
35
52
|
|
36
53
|
if TYPE_CHECKING:
|
37
54
|
from sglang.srt.layers.moe.token_dispatcher import (
|
38
|
-
AscendDeepEPLLOutput,
|
39
55
|
DeepEPLLOutput,
|
40
56
|
DeepEPNormalOutput,
|
41
57
|
DispatchOutput,
|
@@ -444,9 +460,20 @@ class DeepEPMoE(EPMoE):
|
|
444
460
|
topk_idx=topk_idx,
|
445
461
|
topk_weights=topk_weights,
|
446
462
|
forward_batch=forward_batch,
|
463
|
+
input_global_scale=(
|
464
|
+
self.w13_input_scale_quant
|
465
|
+
if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
|
466
|
+
and self.quant_method.enable_flashinfer_cutedsl_moe
|
467
|
+
and CUTEDSL_MOE_NVFP4_DISPATCH
|
468
|
+
else None
|
469
|
+
),
|
447
470
|
)
|
448
471
|
|
449
|
-
def moe_impl(
|
472
|
+
def moe_impl(
|
473
|
+
self,
|
474
|
+
dispatch_output: DispatchOutput,
|
475
|
+
down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
|
476
|
+
):
|
450
477
|
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
|
451
478
|
|
452
479
|
if _use_aiter:
|
@@ -454,12 +481,16 @@ class DeepEPMoE(EPMoE):
|
|
454
481
|
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
455
482
|
return self.forward_aiter(dispatch_output)
|
456
483
|
if _is_npu:
|
457
|
-
assert DispatchOutputChecker.
|
484
|
+
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
458
485
|
return self.forward_npu(dispatch_output)
|
459
486
|
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
460
487
|
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
461
488
|
return self.forward_deepgemm_contiguous(dispatch_output)
|
462
489
|
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
490
|
+
if get_moe_runner_backend().is_flashinfer_cutedsl():
|
491
|
+
return self.forward_flashinfer_cutedsl(
|
492
|
+
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
|
493
|
+
)
|
463
494
|
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
464
495
|
return self.forward_deepgemm_masked(dispatch_output)
|
465
496
|
else:
|
@@ -473,12 +504,14 @@ class DeepEPMoE(EPMoE):
|
|
473
504
|
topk_idx: torch.Tensor,
|
474
505
|
topk_weights: torch.Tensor,
|
475
506
|
forward_batch: ForwardBatch,
|
507
|
+
overlap_args: Optional[Dict[str, Any]] = None,
|
476
508
|
):
|
477
509
|
return self.deepep_dispatcher.combine(
|
478
510
|
hidden_states=hidden_states,
|
479
511
|
topk_idx=topk_idx,
|
480
512
|
topk_weights=topk_weights,
|
481
513
|
forward_batch=forward_batch,
|
514
|
+
overlap_args=overlap_args,
|
482
515
|
)
|
483
516
|
|
484
517
|
def forward_aiter(
|
@@ -534,6 +567,24 @@ class DeepEPMoE(EPMoE):
|
|
534
567
|
N = self.w13_weight.size(1)
|
535
568
|
scale_block_size = 128
|
536
569
|
|
570
|
+
# TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
|
571
|
+
w13_weight_fp8 = (
|
572
|
+
self.w13_weight,
|
573
|
+
(
|
574
|
+
self.w13_weight_scale_inv
|
575
|
+
if self.use_block_quant
|
576
|
+
else self.w13_weight_scale
|
577
|
+
),
|
578
|
+
)
|
579
|
+
w2_weight_fp8 = (
|
580
|
+
self.w2_weight,
|
581
|
+
(
|
582
|
+
self.w2_weight_scale_inv
|
583
|
+
if self.use_block_quant
|
584
|
+
else self.w2_weight_scale
|
585
|
+
),
|
586
|
+
)
|
587
|
+
|
537
588
|
hidden_states_fp8_shape = hidden_states_fp8.shape
|
538
589
|
hidden_states_fp8_device = hidden_states_fp8.device
|
539
590
|
hidden_states_fp8_dtype = hidden_states_fp8.dtype
|
@@ -564,12 +615,17 @@ class DeepEPMoE(EPMoE):
|
|
564
615
|
)
|
565
616
|
output_index = torch.empty_like(topk_idx)
|
566
617
|
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
618
|
+
if get_offloader().forbid_copy_engine_usage:
|
619
|
+
num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
|
620
|
+
num_recv_tokens_per_expert
|
621
|
+
)
|
622
|
+
else:
|
623
|
+
num_recv_tokens_per_expert_gpu = torch.tensor(
|
624
|
+
num_recv_tokens_per_expert,
|
625
|
+
dtype=torch.int32,
|
626
|
+
pin_memory=True,
|
627
|
+
device="cpu",
|
628
|
+
).cuda(non_blocking=True)
|
573
629
|
expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
|
574
630
|
|
575
631
|
ep_scatter(
|
@@ -594,7 +650,7 @@ class DeepEPMoE(EPMoE):
|
|
594
650
|
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
|
595
651
|
input_tensor[1] = tma_align_input_scale(input_tensor[1])
|
596
652
|
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
|
597
|
-
input_tensor,
|
653
|
+
input_tensor, w13_weight_fp8, gateup_output, m_indices
|
598
654
|
)
|
599
655
|
del input_tensor
|
600
656
|
down_input = torch.empty(
|
@@ -624,7 +680,7 @@ class DeepEPMoE(EPMoE):
|
|
624
680
|
down_input_scale = tma_align_input_scale(down_input_scale)
|
625
681
|
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
|
626
682
|
(down_input_fp8, down_input_scale),
|
627
|
-
|
683
|
+
w2_weight_fp8,
|
628
684
|
down_output,
|
629
685
|
m_indices,
|
630
686
|
)
|
@@ -639,6 +695,24 @@ class DeepEPMoE(EPMoE):
|
|
639
695
|
|
640
696
|
return gather_out
|
641
697
|
|
698
|
+
def forward_flashinfer_cutedsl(
|
699
|
+
self,
|
700
|
+
dispatch_output: DeepEPLLOutput,
|
701
|
+
down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
|
702
|
+
):
|
703
|
+
hidden_states, _, _, masked_m, _ = dispatch_output
|
704
|
+
assert self.quant_method is not None
|
705
|
+
assert self.moe_runner_config.activation == "silu"
|
706
|
+
|
707
|
+
output = self.quant_method.apply_without_routing_weights(
|
708
|
+
layer=self,
|
709
|
+
x=hidden_states,
|
710
|
+
masked_m=masked_m,
|
711
|
+
moe_runner_config=self.moe_runner_config,
|
712
|
+
down_gemm_overlap_args=down_gemm_overlap_args,
|
713
|
+
)
|
714
|
+
return output
|
715
|
+
|
642
716
|
def forward_deepgemm_masked(
|
643
717
|
self,
|
644
718
|
dispatch_output: DeepEPLLOutput,
|
@@ -718,66 +792,176 @@ class DeepEPMoE(EPMoE):
|
|
718
792
|
|
719
793
|
def forward_npu(
|
720
794
|
self,
|
721
|
-
dispatch_output: DeepEPLLOutput,
|
795
|
+
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
|
722
796
|
):
|
723
|
-
if TYPE_CHECKING:
|
724
|
-
assert isinstance(dispatch_output, AscendDeepEPLLOutput)
|
725
|
-
hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
|
726
797
|
assert self.quant_method is not None
|
727
798
|
assert self.moe_runner_config.activation == "silu"
|
728
799
|
|
800
|
+
import torch_npu
|
801
|
+
|
802
|
+
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
|
803
|
+
|
729
804
|
# NOTE: Ascend's Dispatch & Combine does not support FP16
|
730
805
|
output_dtype = torch.bfloat16
|
806
|
+
group_list_type = 1
|
731
807
|
|
732
|
-
|
733
|
-
|
808
|
+
def _forward_normal(dispatch_output: DeepEPNormalOutput):
|
809
|
+
if TYPE_CHECKING:
|
810
|
+
assert isinstance(dispatch_output, DeepEPNormalOutput)
|
811
|
+
hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
|
734
812
|
|
735
|
-
|
736
|
-
|
813
|
+
if isinstance(hidden_states, tuple):
|
814
|
+
per_token_scale = hidden_states[1]
|
815
|
+
hidden_states = hidden_states[0]
|
737
816
|
|
738
|
-
|
817
|
+
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
|
818
|
+
hidden_states.device
|
819
|
+
)
|
820
|
+
if self.w13_weight.dtype != torch.int8:
|
821
|
+
# gmm1: gate_up_proj
|
822
|
+
hidden_states = torch_npu.npu_grouped_matmul(
|
823
|
+
x=[hidden_states],
|
824
|
+
weight=[self.w13_weight.permute(0, 2, 1)],
|
825
|
+
# per_token_scale=[per_token_scale],
|
826
|
+
split_item=2,
|
827
|
+
group_list_type=group_list_type,
|
828
|
+
group_type=0,
|
829
|
+
group_list=group_list,
|
830
|
+
output_dtype=output_dtype,
|
831
|
+
)[0]
|
832
|
+
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
833
|
+
# gmm2: down_proj
|
834
|
+
hidden_states = torch_npu.npu_grouped_matmul(
|
835
|
+
x=[hidden_states],
|
836
|
+
weight=[self.w2_weight.permute(0, 2, 1)],
|
837
|
+
split_item=2,
|
838
|
+
group_list_type=group_list_type,
|
839
|
+
group_type=0,
|
840
|
+
group_list=group_list,
|
841
|
+
output_dtype=output_dtype,
|
842
|
+
)[0]
|
843
|
+
else:
|
844
|
+
if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
|
845
|
+
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
|
846
|
+
hidden_states
|
847
|
+
)
|
848
|
+
# gmm1: gate_up_proj
|
849
|
+
hidden_states = torch_npu.npu_grouped_matmul(
|
850
|
+
x=[hidden_states],
|
851
|
+
weight=[self.w13_weight],
|
852
|
+
scale=[self.w13_weight_scale.to(output_dtype)],
|
853
|
+
per_token_scale=[per_token_scale],
|
854
|
+
split_item=2,
|
855
|
+
group_list_type=group_list_type,
|
856
|
+
group_type=0,
|
857
|
+
group_list=group_list,
|
858
|
+
output_dtype=output_dtype,
|
859
|
+
)[0]
|
860
|
+
|
861
|
+
# act_fn: swiglu
|
862
|
+
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
863
|
+
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
864
|
+
hidden_states
|
865
|
+
)
|
739
866
|
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
753
|
-
x=hidden_states,
|
754
|
-
weight_scale=self.w13_weight_scale.to(torch.float32),
|
755
|
-
activation_scale=pertoken_scale,
|
756
|
-
bias=None,
|
757
|
-
quant_scale=None,
|
758
|
-
quant_offset=None,
|
759
|
-
group_index=seg_indptr,
|
760
|
-
activate_left=True,
|
761
|
-
quant_mode=1,
|
762
|
-
)
|
763
|
-
|
764
|
-
# gmm2: down_proj
|
765
|
-
hidden_states = torch_npu.npu_grouped_matmul(
|
766
|
-
x=[hidden_states],
|
767
|
-
weight=[self.w2_weight],
|
768
|
-
scale=[self.w2_weight_scale.to(output_dtype)],
|
769
|
-
per_token_scale=[swiglu_out_scale],
|
770
|
-
split_item=2,
|
771
|
-
group_list_type=group_list_type,
|
772
|
-
group_type=0,
|
773
|
-
group_list=seg_indptr,
|
774
|
-
output_dtype=output_dtype,
|
775
|
-
)[0]
|
867
|
+
# gmm2: down_proj
|
868
|
+
hidden_states = torch_npu.npu_grouped_matmul(
|
869
|
+
x=[hidden_states],
|
870
|
+
weight=[self.w2_weight],
|
871
|
+
scale=[self.w2_weight_scale.to(output_dtype)],
|
872
|
+
per_token_scale=[swiglu_out_scale],
|
873
|
+
split_item=2,
|
874
|
+
group_list_type=group_list_type,
|
875
|
+
group_type=0,
|
876
|
+
group_list=group_list,
|
877
|
+
output_dtype=output_dtype,
|
878
|
+
)[0]
|
776
879
|
|
777
|
-
|
880
|
+
return hidden_states
|
881
|
+
|
882
|
+
def _forward_ll(dispatch_output: DeepEPLLOutput):
    """Run the expert MLP (gate_up -> swiglu -> down) on a DeepEP LL dispatch output.

    ``dispatch_output`` is unpacked into five fields; only the first
    (``hidden_states``) and the fourth (``group_list``) are used here.
    ``hidden_states`` may be a plain tensor or a tuple whose first two
    items are the activations and their per-token scale.

    ``self``, ``group_list_type`` and ``output_dtype`` are taken from the
    enclosing scope.

    Returns:
        The output of the second grouped matmul (down projection).

    Raises:
        ValueError: If the expert weights are int8 but ``hidden_states``
            was not dispatched together with a per-token scale.
    """
    if TYPE_CHECKING:
        assert isinstance(dispatch_output, DeepEPLLOutput)
    hidden_states, _, _, group_list, _ = dispatch_output

    # Bind the scale up front so the int8 path can detect a missing scale
    # instead of failing with an UnboundLocalError.
    per_token_scale = None
    if isinstance(hidden_states, tuple):
        per_token_scale = hidden_states[1]
        hidden_states = hidden_states[0]

    group_list = group_list.to(torch.int64)

    if self.w13_weight.dtype != torch.int8:
        # NOTE(review): per_token_scale is not forwarded on this path, even
        # when hidden_states arrived as a tuple -- confirm this is intended.
        # gmm1: gate_up_proj
        hidden_states = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[self.w13_weight.permute(0, 2, 1)],
            split_item=2,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=output_dtype,
        )[0]
        hidden_states = torch_npu.npu_swiglu(hidden_states)
        # gmm2: down_proj
        hidden_states = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[self.w2_weight.permute(0, 2, 1)],
            split_item=2,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=output_dtype,
        )[0]
        return hidden_states

    if per_token_scale is None:
        raise ValueError(
            "int8 expert weights require hidden_states to be dispatched as a "
            "(tensor, per_token_scale) tuple"
        )

    # gmm1: gate_up_proj (int32 output, dequantized by the fused op below)
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[self.w13_weight],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
        output_dtype=torch.int32,
    )[0]

    # act_fn: swiglu
    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
        x=hidden_states,
        weight_scale=self.w13_weight_scale.to(torch.float32),
        activation_scale=per_token_scale,
        bias=None,
        quant_scale=None,
        quant_offset=None,
        group_index=group_list,
        activate_left=True,
        quant_mode=1,
    )

    # gmm2: down_proj
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[self.w2_weight],
        scale=[self.w2_weight_scale.to(output_dtype)],
        per_token_scale=[swiglu_out_scale],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
        output_dtype=output_dtype,
    )[0]

    return hidden_states
|
779
955
|
|
780
|
-
|
956
|
+
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
957
|
+
return _forward_normal(dispatch_output)
|
958
|
+
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
959
|
+
return _forward_ll(dispatch_output)
|
960
|
+
else:
|
961
|
+
raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
|
962
|
+
|
963
|
+
|
964
|
+
def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
|
781
965
|
if get_moe_a2a_backend().is_deepep():
|
782
966
|
return DeepEPMoE
|
783
967
|
|
@@ -790,8 +974,7 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
|
|
790
974
|
return FusedMoE
|
791
975
|
try:
|
792
976
|
# Check the quantization argument directly
|
793
|
-
|
794
|
-
if quantization == "modelopt_fp4":
|
977
|
+
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
|
795
978
|
from sglang.srt.layers.moe.fused_moe_triton.layer import (
|
796
979
|
FlashInferFP4MoE,
|
797
980
|
)
|
@@ -800,10 +983,20 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
|
|
800
983
|
except:
|
801
984
|
pass
|
802
985
|
|
803
|
-
if should_use_flashinfer_trtllm_moe():
|
986
|
+
if should_use_flashinfer_trtllm_moe() and quant_config is not None:
|
987
|
+
# FIXME: FlashInferFusedMoE only supports fp8 quant now
|
804
988
|
return FlashInferFusedMoE
|
805
989
|
if get_moe_runner_backend().is_flashinfer_cutlass():
|
806
990
|
return FusedMoE
|
807
991
|
if get_moe_expert_parallel_world_size() > 1:
|
808
992
|
return EPMoE
|
809
993
|
return FusedMoE
|
994
|
+
|
995
|
+
|
996
|
+
def copy_list_to_gpu_no_ce(arr: List[int]):
    """Materialize ``arr`` as an int32 tensor on the ``"cuda"`` device.

    The values are first placed in a CPU int32 tensor, then transferred into
    a freshly allocated CUDA tensor of the same shape by sgl_kernel's
    ``copy_to_gpu_no_ce``.

    Args:
        arr: Integers to upload.

    Returns:
        The CUDA tensor that received the copy.
    """
    # Resolved at call time, so sgl_kernel is only needed when this is used.
    from sgl_kernel.elementwise import copy_to_gpu_no_ce

    host_values = torch.tensor(arr, dtype=torch.int32, device="cpu")
    device_values = torch.empty_like(host_values, device="cuda")
    copy_to_gpu_no_ce(host_values, device_values)
    return device_values
|
@@ -0,0 +1,183 @@
|
|
1
|
+
from typing import Any, Dict, Optional, Union
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
|
5
|
+
from sgl_kernel.gemm import (
|
6
|
+
scaled_fp4_grouped_quant,
|
7
|
+
silu_and_mul_scaled_fp4_grouped_quant,
|
8
|
+
)
|
9
|
+
|
10
|
+
|
11
|
+
def get_cute_dtype(input: torch.Tensor) -> str:
    """Return the dtype name string for ``input``'s floating-point dtype.

    Supported dtypes are bfloat16, float16 and float32; each maps to the
    string of the same name.

    Raises:
        ValueError: If ``input.dtype`` is any other dtype.
    """
    names_by_dtype = {
        torch.bfloat16: "bfloat16",
        torch.float16: "float16",
        torch.float32: "float32",
    }
    dtype_name = names_by_dtype.get(input.dtype)
    if dtype_name is None:
        raise ValueError(f"Unsupported cute dtype {input.dtype}")
    return dtype_name
|
20
|
+
|
21
|
+
|
22
|
+
def flashinfer_cutedsl_moe_masked(
|
23
|
+
hidden_states: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
24
|
+
input_global_scale: torch.Tensor,
|
25
|
+
w1: torch.Tensor,
|
26
|
+
w1_blockscale: torch.Tensor,
|
27
|
+
w1_alpha,
|
28
|
+
w2: torch.Tensor,
|
29
|
+
a2_global_scale: torch.Tensor,
|
30
|
+
w2_blockscale: torch.Tensor,
|
31
|
+
w2_alpha,
|
32
|
+
masked_m: torch.Tensor,
|
33
|
+
down_sm_count: Optional[int] = None,
|
34
|
+
down_signals: Optional[torch.Tensor] = None,
|
35
|
+
down_start_event: Optional[torch.cuda.Event] = None,
|
36
|
+
):
|
37
|
+
"""
|
38
|
+
Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
|
39
|
+
kernels.
|
40
|
+
|
41
|
+
Args:
|
42
|
+
hidden_states: Either of the following case
|
43
|
+
* torch.Tensor: [num_experts, m, k], bf16
|
44
|
+
* tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2], uint8, [num_experts, m, k // 16], float8_e4m3fn
|
45
|
+
input_global_scale (torch.Tensor): (l,)
|
46
|
+
w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
|
47
|
+
w1_blockscale (torch.Tensor): blockscale factors, e4m3,
|
48
|
+
w1_alpha (torch.Tensor): (l,)
|
49
|
+
w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
|
50
|
+
a2_global_scale (torch.Tensor): (l,)
|
51
|
+
w2_blockscale (torch.Tensor): blockscale factors, e4m3,
|
52
|
+
w2_alpha (torch.Tensor): (l,)
|
53
|
+
masked_m (torch.Tensor): Masked dimension indices
|
54
|
+
|
55
|
+
Notes:
|
56
|
+
- Assumes max(masked_m) == m.
|
57
|
+
"""
|
58
|
+
|
59
|
+
# === Assertions on dtypes ===
|
60
|
+
assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
|
61
|
+
assert (
|
62
|
+
w1_blockscale.dtype == torch.float8_e4m3fn
|
63
|
+
), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
|
64
|
+
assert (
|
65
|
+
w1_alpha.dtype == torch.float32
|
66
|
+
), f"w1_alpha must be float32, got {w1_alpha.dtype}"
|
67
|
+
assert w2.dtype == torch.uint8, f"w2 must be uint8 (fp4 packed), got {w2.dtype}"
|
68
|
+
assert (
|
69
|
+
a2_global_scale.dtype == torch.float32
|
70
|
+
), f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
|
71
|
+
assert (
|
72
|
+
w2_blockscale.dtype == torch.float8_e4m3fn
|
73
|
+
), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
|
74
|
+
assert (
|
75
|
+
w2_alpha.dtype == torch.float32
|
76
|
+
), f"w2_alpha must be float32, got {w2_alpha.dtype}"
|
77
|
+
|
78
|
+
# === Assertions on shapes ===
|
79
|
+
n = w2.shape[-1] * 2 # intermediate dimension
|
80
|
+
|
81
|
+
if isinstance(hidden_states, tuple):
|
82
|
+
assert (
|
83
|
+
input_global_scale is None
|
84
|
+
), "input_global_scale is needed when input needs quant"
|
85
|
+
|
86
|
+
a_q = hidden_states[0].view(torch.uint8)
|
87
|
+
a_q_sf = hidden_states[1].view(torch.float8_e4m3fn)
|
88
|
+
m, k_by_2, num_experts = a_q.shape
|
89
|
+
k = k_by_2 * 2
|
90
|
+
else:
|
91
|
+
num_experts, m, k = hidden_states.shape
|
92
|
+
|
93
|
+
assert (
|
94
|
+
input_global_scale.dtype == torch.float32
|
95
|
+
), f"input_global_scale must be float32, got {input_global_scale.dtype}"
|
96
|
+
assert input_global_scale.shape == (
|
97
|
+
num_experts,
|
98
|
+
), f"input_global_scale must be (l,), got {input_global_scale.shape}"
|
99
|
+
|
100
|
+
a_q, a_q_sf = scaled_fp4_grouped_quant(
|
101
|
+
hidden_states,
|
102
|
+
input_global_scale,
|
103
|
+
masked_m,
|
104
|
+
)
|
105
|
+
|
106
|
+
assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
|
107
|
+
assert (
|
108
|
+
w1.shape[-1] * 2 == k
|
109
|
+
), f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
|
110
|
+
assert w2.shape[-2:] == (
|
111
|
+
k,
|
112
|
+
n // 2,
|
113
|
+
), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n//2)}"
|
114
|
+
assert w1_alpha.shape == (
|
115
|
+
num_experts,
|
116
|
+
), f"w1_alpha must be (l,), got {w1_alpha.shape}"
|
117
|
+
assert a2_global_scale.shape == (
|
118
|
+
num_experts,
|
119
|
+
), f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
|
120
|
+
assert w2_alpha.shape == (
|
121
|
+
num_experts,
|
122
|
+
), f"w2_alpha must be (l,), got {w2_alpha.shape}"
|
123
|
+
|
124
|
+
# TODO(kaixih@nvidia): dtype should be based on inputs.
|
125
|
+
gateup_output = torch.empty(
|
126
|
+
(num_experts, m, n * 2), dtype=torch.bfloat16, device=a_q.device
|
127
|
+
)
|
128
|
+
gateup_output = gateup_output.permute(1, 2, 0) # requirement of kernel
|
129
|
+
sf_vec_size = 16
|
130
|
+
assert a_q_sf.dtype == torch.float8_e4m3fn
|
131
|
+
assert a_q.dtype == torch.uint8
|
132
|
+
ab_dtype = "float4_e2m1fn"
|
133
|
+
sf_dtype = "float8_e4m3fn"
|
134
|
+
c_dtype = "bfloat16"
|
135
|
+
|
136
|
+
# Gemm1
|
137
|
+
grouped_gemm_nt_masked(
|
138
|
+
(a_q, a_q_sf),
|
139
|
+
(w1.permute(1, 2, 0), w1_blockscale),
|
140
|
+
gateup_output,
|
141
|
+
masked_m,
|
142
|
+
ab_dtype=ab_dtype,
|
143
|
+
sf_dtype=sf_dtype,
|
144
|
+
c_dtype=c_dtype,
|
145
|
+
sf_vec_size=sf_vec_size,
|
146
|
+
alpha=w1_alpha.view(1, 1, num_experts),
|
147
|
+
alpha_dtype=get_cute_dtype(w1_alpha),
|
148
|
+
) # in logical [m, n, l]
|
149
|
+
|
150
|
+
# SILU and quantization
|
151
|
+
diq, diq_sf = silu_and_mul_scaled_fp4_grouped_quant(
|
152
|
+
gateup_output.permute(2, 0, 1),
|
153
|
+
a2_global_scale,
|
154
|
+
masked_m,
|
155
|
+
)
|
156
|
+
|
157
|
+
if down_start_event is not None:
|
158
|
+
down_start_event.record()
|
159
|
+
|
160
|
+
# Gemm2
|
161
|
+
out = torch.empty((num_experts, m, k), dtype=torch.bfloat16, device=a_q.device)
|
162
|
+
out = out.permute(1, 2, 0) # requirement of kernel
|
163
|
+
grouped_gemm_nt_masked(
|
164
|
+
(diq, diq_sf),
|
165
|
+
(w2.permute(1, 2, 0), w2_blockscale),
|
166
|
+
out,
|
167
|
+
masked_m,
|
168
|
+
ab_dtype=ab_dtype,
|
169
|
+
sf_dtype=sf_dtype,
|
170
|
+
c_dtype=c_dtype,
|
171
|
+
sf_vec_size=sf_vec_size,
|
172
|
+
alpha=w2_alpha.view(1, 1, num_experts),
|
173
|
+
alpha_dtype=get_cute_dtype(w2_alpha),
|
174
|
+
**(
|
175
|
+
dict(
|
176
|
+
sm_count=down_sm_count,
|
177
|
+
dst_signals=down_signals,
|
178
|
+
)
|
179
|
+
if down_sm_count is not None or down_signals is not None
|
180
|
+
else {}
|
181
|
+
),
|
182
|
+
) # in logical [m, k, l]
|
183
|
+
return out.permute(2, 0, 1)
|
@@ -8,16 +8,18 @@ from torch.nn import functional as F
|
|
8
8
|
|
9
9
|
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
|
10
10
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
11
|
+
from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput
|
11
12
|
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
12
13
|
|
13
14
|
|
14
15
|
def fused_moe_forward_native(
|
15
16
|
layer: torch.nn.Module,
|
16
|
-
|
17
|
-
topk_output: StandardTopKOutput,
|
18
|
-
moe_runner_config: MoeRunnerConfig,
|
17
|
+
dispatch_output: StandardDispatchOutput,
|
19
18
|
) -> torch.Tensor:
|
20
19
|
|
20
|
+
x, topk_output = dispatch_output
|
21
|
+
moe_runner_config = layer.moe_runner_config
|
22
|
+
|
21
23
|
if moe_runner_config.apply_router_weight_on_input:
|
22
24
|
raise NotImplementedError()
|
23
25
|
|