sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
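As a rough illustration (not part of the package), the sketch below shows one way to build a similar file-level summary locally from two downloaded wheels using only the Python standard library. It compares uncompressed file sizes rather than the line counts shown in the listing below, and the wheel paths in the usage comment are placeholders.

```python
# Hypothetical helper, not part of sglang: summarize which files were added,
# removed, or changed between two locally downloaded wheels.
import zipfile


def wheel_file_sizes(wheel_path: str) -> dict:
    """Map each file inside the wheel to its uncompressed size in bytes."""
    with zipfile.ZipFile(wheel_path) as zf:
        return {info.filename: info.file_size for info in zf.infolist()}


def summarize_wheel_diff(old_wheel: str, new_wheel: str) -> None:
    old, new = wheel_file_sizes(old_wheel), wheel_file_sizes(new_wheel)
    for name in sorted(set(old) | set(new)):
        if name not in old:
            print(f"added    {name}")
        elif name not in new:
            print(f"removed  {name}")
        elif old[name] != new[name]:
            print(f"changed  {name} ({old[name]} -> {new[name]} bytes)")


# Example usage (paths are placeholders for wheels fetched with `pip download --no-deps`):
# summarize_wheel_diff("sglang-0.5.2rc2-py3-none-any.whl", "sglang-0.5.3.post1-py3-none-any.whl")
```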
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,11 +1,10 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import torch
 
-from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
@@ -15,13 +14,10 @@ from sglang.srt.layers.moe import (
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
-    moe_ep_deepgemm_preprocess,
-    post_reorder_triton_kernel,
     silu_and_mul_masked_post_quant_fwd,
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
-from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
@@ -29,13 +25,17 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
 )
-from sglang.srt.
+from sglang.srt.layers.quantization.modelopt_quant import (
+    CUTEDSL_MOE_NVFP4_DISPATCH,
+    ModelOptNvFp4FusedMoEMethod,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.offloader import get_offloader
+from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
 from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
-        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -56,29 +56,13 @@ if _use_aiter:
 logger = logging.getLogger(__name__)
 
 
-
-# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
-@torch.compile
-def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
-    temp = x.to(torch.float32).view(torch.int32)
-    exp = torch.bitwise_right_shift(temp, 23)
-    mant = torch.bitwise_and(temp, 0x7FFFFF)
-    is_ru = torch.logical_and(
-        torch.logical_and((mant > 0), (exp != 0xFE)),
-        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
-    )
-    exp = torch.where(is_ru, exp + 1, exp)
-    new_x = exp.to(torch.uint8).view(torch.int)
-    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
-
-
-class EPMoE(FusedMoE):
+class DeepEPMoE(FusedMoE):
     """
-    MoE Expert Parallel Impl
-
-
+    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
     """
 
+    _has_printed = False
+
     def __init__(
         self,
         num_experts: int,
@@ -92,272 +76,29 @@ class EPMoE(FusedMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        gemm1_alpha: Optional[float] = None,
-        gemm1_clamp_limit: Optional[float] = None,
-        with_bias: bool = False,
     ):
         super().__init__(
             num_experts=num_experts,
+            top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
-            num_fused_shared_experts=num_fused_shared_experts,
             layer_id=layer_id,
-
+            num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
             activation=activation,
-            # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
-            gemm1_alpha=gemm1_alpha,
-            gemm1_clamp_limit=gemm1_clamp_limit,
-            with_bias=with_bias,
         )
 
-        self.intermediate_size = intermediate_size
-
         if isinstance(quant_config, Fp8Config):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
-            self.block_shape = (
-                self.quant_method.quant_config.weight_block_size
-                if self.use_block_quant
-                else None
-            )
             self.use_fp8_w8a8 = True
             self.fp8_dtype = torch.float8_e4m3fn
-            self.activation_scheme = quant_config.activation_scheme
         else:
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
-            self.block_shape = None
-            self.activation_scheme = None
-
-    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
-            return self.forward_deepgemm(hidden_states, topk_output)
-        else:
-            return super().forward(hidden_states, topk_output)
-
-    def forward_deepgemm(
-        self,
-        hidden_states: torch.Tensor,
-        topk_output: TopKOutput,
-    ):
 
-        self.w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        self.w2_weight_fp8 = (
-            self.w2_weight,
-            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
-        )
-
-        assert self.quant_method is not None
-        assert self.moe_runner_config.activation == "silu"
-
-        hidden_states_shape = hidden_states.shape
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-
-        topk_weights, topk_ids, _ = topk_output
-
-        if not self.use_block_quant:
-            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
-            scale_block_size = 128
-            w13_weight_scale_n = 2 * (
-                (self.intermediate_size + scale_block_size - 1) // scale_block_size
-            )
-            w13_weight_scale_k = (
-                hidden_states_shape[-1] + scale_block_size - 1
-            ) // scale_block_size
-            w13_weight_scale = (
-                self.w13_weight_scale.unsqueeze(1)
-                .repeat_interleave(w13_weight_scale_n, dim=1)
-                .unsqueeze(2)
-                .repeat_interleave(w13_weight_scale_k, dim=2)
-            )
-            self.w13_weight_fp8 = (
-                self.w13_weight,
-                w13_weight_scale,
-            )
-            w2_weight_scale_n = (
-                hidden_states_shape[-1] + scale_block_size - 1
-            ) // scale_block_size
-            w2_weight_scale_k = (
-                self.intermediate_size + scale_block_size - 1
-            ) // scale_block_size
-            w2_weight_scale = (
-                self.w2_weight_scale.unsqueeze(1)
-                .repeat_interleave(w2_weight_scale_n, dim=1)
-                .unsqueeze(2)
-                .repeat_interleave(w2_weight_scale_k, dim=2)
-            )
-            self.w2_weight_fp8 = (
-                self.w2_weight,
-                w2_weight_scale,
-            )
-
-        # PreReorder
-        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
-            moe_ep_deepgemm_preprocess(
-                topk_ids,
-                self.num_experts,
-                hidden_states,
-                self.top_k,
-                self.start_expert_id,
-                self.end_expert_id,
-                self.block_shape,
-            )
-        )
-
-        dispose_tensor(hidden_states)
-
-        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            b, s_mn, s_k = gateup_input_scale.shape
-            assert (
-                s_mn % 4 == 0 and s_k % 4 == 0
-            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
-
-        # GroupGemm-0
-        gateup_input_fp8 = (
-            gateup_input,
-            (
-                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
-                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
-                    gateup_input_scale
-                )
-            ),
-        )
-        num_groups, m, k = gateup_input_fp8[0].size()
-        n = self.w13_weight.size(1)
-        gateup_output = torch.empty(
-            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
-        )
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            gateup_input_fp8,
-            self.w13_weight_fp8,
-            gateup_output,
-            masked_m,
-            expected_m,
-        )
-        del gateup_input
-        del gateup_input_fp8
-
-        # Act
-        down_input = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2,
-            ),
-            device=hidden_states_device,
-            dtype=self.fp8_dtype,
-        )
-        scale_block_size = 128
-        down_input_scale = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2 // scale_block_size,
-            ),
-            device=hidden_states_device,
-            dtype=torch.float32,
-        )
-        silu_and_mul_masked_post_quant_fwd(
-            gateup_output,
-            down_input,
-            down_input_scale,
-            scale_block_size,
-            masked_m,
-            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-        )
-        del gateup_output
-
-        # GroupGemm-1
-        n = self.w2_weight.size(1)
-        down_input_fp8 = (
-            down_input,
-            (
-                down_input_scale
-                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
-            ),
-        )
-        down_output = torch.empty(
-            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
-        )
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            down_input_fp8,
-            self.w2_weight_fp8,
-            down_output,
-            masked_m,
-            expected_m,
-        )
-        del down_input
-        del down_input_fp8
-
-        # PostReorder
-        output = torch.empty(
-            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
-        )
-        post_reorder_triton_kernel[(hidden_states_shape[0],)](
-            down_output,
-            output,
-            src2dst,
-            topk_ids,
-            topk_weights,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states_shape[1],
-            m_max * self.start_expert_id,
-            BLOCK_SIZE=512,
-        )
-        if self.moe_runner_config.routed_scaling_factor is not None:
-            output *= self.moe_runner_config.routed_scaling_factor
-        return output
-
-
-class DeepEPMoE(EPMoE):
-    """
-    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
-    """
-
-    _has_printed = False
-
-    def __init__(
-        self,
-        num_experts: int,
-        top_k: int,
-        hidden_size: int,
-        intermediate_size: int,
-        layer_id: int,
-        num_fused_shared_experts: int = 0,
-        params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        activation: str = "silu",
-        routed_scaling_factor: Optional[float] = None,
-    ):
-        super().__init__(
-            num_experts=num_experts,
-            top_k=top_k,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            layer_id=layer_id,
-            num_fused_shared_experts=num_fused_shared_experts,
-            params_dtype=params_dtype,
-            quant_config=quant_config,
-            prefix=prefix,
-            activation=activation,
-            routed_scaling_factor=routed_scaling_factor,
-        )
         self.deepep_mode = get_deepep_mode()
 
         # TODO: move to the beginning of the file
@@ -444,9 +185,20 @@ class DeepEPMoE(EPMoE):
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
+            input_global_scale=(
+                self.w13_input_scale_quant
+                if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
+                and self.quant_method.enable_flashinfer_cutedsl_moe
+                and CUTEDSL_MOE_NVFP4_DISPATCH
+                else None
+            ),
         )
 
-    def moe_impl(
+    def moe_impl(
+        self,
+        dispatch_output: DispatchOutput,
+        down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
+    ):
         from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
 
         if _use_aiter:
@@ -454,12 +206,16 @@ class DeepEPMoE(EPMoE):
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
         if _is_npu:
-            assert DispatchOutputChecker.
+            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             return self.forward_npu(dispatch_output)
         if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
+            if get_moe_runner_backend().is_flashinfer_cutedsl():
+                return self.forward_flashinfer_cutedsl(
+                    dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
+                )
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_masked(dispatch_output)
         else:
@@ -473,12 +229,14 @@ class DeepEPMoE(EPMoE):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_batch: ForwardBatch,
+        overlap_args: Optional[Dict[str, Any]] = None,
     ):
         return self.deepep_dispatcher.combine(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
+            overlap_args=overlap_args,
         )
 
     def forward_aiter(
@@ -534,6 +292,23 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128
 
+        w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        w2_weight_fp8 = (
+            self.w2_weight,
+            (
+                self.w2_weight_scale_inv
+                if self.use_block_quant
+                else self.w2_weight_scale
+            ),
+        )
+
         hidden_states_fp8_shape = hidden_states_fp8.shape
         hidden_states_fp8_device = hidden_states_fp8.device
         hidden_states_fp8_dtype = hidden_states_fp8.dtype
@@ -564,12 +339,17 @@ class DeepEPMoE(EPMoE):
         )
         output_index = torch.empty_like(topk_idx)
 
-
-
-
-
-
-
+        if get_offloader().forbid_copy_engine_usage:
+            num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
+                num_recv_tokens_per_expert
+            )
+        else:
+            num_recv_tokens_per_expert_gpu = torch.tensor(
+                num_recv_tokens_per_expert,
+                dtype=torch.int32,
+                pin_memory=True,
+                device="cpu",
+            ).cuda(non_blocking=True)
         expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
 
         ep_scatter(
@@ -594,7 +374,7 @@ class DeepEPMoE(EPMoE):
         if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
             input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
-            input_tensor,
+            input_tensor, w13_weight_fp8, gateup_output, m_indices
         )
         del input_tensor
         down_input = torch.empty(
@@ -624,7 +404,7 @@ class DeepEPMoE(EPMoE):
             down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
-
+            w2_weight_fp8,
             down_output,
             m_indices,
         )
@@ -639,6 +419,24 @@ class DeepEPMoE(EPMoE):
 
         return gather_out
 
+    def forward_flashinfer_cutedsl(
+        self,
+        dispatch_output: DeepEPLLOutput,
+        down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
+    ):
+        hidden_states, _, _, masked_m, _ = dispatch_output
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+
+        output = self.quant_method.apply_without_routing_weights(
+            layer=self,
+            x=hidden_states,
+            masked_m=masked_m,
+            moe_runner_config=self.moe_runner_config,
+            down_gemm_overlap_args=down_gemm_overlap_args,
+        )
+        return output
+
     def forward_deepgemm_masked(
         self,
         dispatch_output: DeepEPLLOutput,
@@ -718,66 +516,176 @@ class DeepEPMoE(EPMoE):
 
     def forward_npu(
         self,
-        dispatch_output: DeepEPLLOutput,
+        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
     ):
-        if TYPE_CHECKING:
-            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
-        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
 
+        import torch_npu
+
+        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
+
         # NOTE: Ascend's Dispatch & Combine does not support FP16
         output_dtype = torch.bfloat16
+        group_list_type = 1
 
-
-
+        def _forward_normal(dispatch_output: DeepEPNormalOutput):
+            if TYPE_CHECKING:
+                assert isinstance(dispatch_output, DeepEPNormalOutput)
+            hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
 
-
-
+            if isinstance(hidden_states, tuple):
+                per_token_scale = hidden_states[1]
+                hidden_states = hidden_states[0]
 
-
+            group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
+                hidden_states.device
+            )
+            if self.w13_weight.dtype != torch.int8:
+                # gmm1: gate_up_proj
+                hidden_states = torch_npu.npu_grouped_matmul(
+                    x=[hidden_states],
+                    weight=[self.w13_weight.permute(0, 2, 1)],
+                    # per_token_scale=[per_token_scale],
+                    split_item=2,
+                    group_list_type=group_list_type,
+                    group_type=0,
+                    group_list=group_list,
+                    output_dtype=output_dtype,
+                )[0]
+                hidden_states = torch_npu.npu_swiglu(hidden_states)
+                # gmm2: down_proj
+                hidden_states = torch_npu.npu_grouped_matmul(
+                    x=[hidden_states],
+                    weight=[self.w2_weight.permute(0, 2, 1)],
+                    split_item=2,
+                    group_list_type=group_list_type,
+                    group_type=0,
+                    group_list=group_list,
+                    output_dtype=output_dtype,
+                )[0]
+            else:
+                if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
+                    hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
+                        hidden_states
+                    )
+                # gmm1: gate_up_proj
+                hidden_states = torch_npu.npu_grouped_matmul(
+                    x=[hidden_states],
+                    weight=[self.w13_weight],
+                    scale=[self.w13_weight_scale.to(output_dtype)],
+                    per_token_scale=[per_token_scale],
+                    split_item=2,
+                    group_list_type=group_list_type,
+                    group_type=0,
+                    group_list=group_list,
+                    output_dtype=output_dtype,
+                )[0]
+
+                # act_fn: swiglu
+                hidden_states = torch_npu.npu_swiglu(hidden_states)
+                hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
+                    hidden_states
+                )
 
-
-
-
-
-
-
-
-
-
-
-
-
-        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
-            x=hidden_states,
-            weight_scale=self.w13_weight_scale.to(torch.float32),
-            activation_scale=pertoken_scale,
-            bias=None,
-            quant_scale=None,
-            quant_offset=None,
-            group_index=seg_indptr,
-            activate_left=True,
-            quant_mode=1,
-        )
+                # gmm2: down_proj
+                hidden_states = torch_npu.npu_grouped_matmul(
+                    x=[hidden_states],
+                    weight=[self.w2_weight],
+                    scale=[self.w2_weight_scale.to(output_dtype)],
+                    per_token_scale=[swiglu_out_scale],
+                    split_item=2,
+                    group_list_type=group_list_type,
+                    group_type=0,
+                    group_list=group_list,
+                    output_dtype=output_dtype,
+                )[0]
 
-
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w2_weight],
-            scale=[self.w2_weight_scale.to(output_dtype)],
-            per_token_scale=[swiglu_out_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=seg_indptr,
-            output_dtype=output_dtype,
-        )[0]
+            return hidden_states
 
-
+        def _forward_ll(dispatch_output: DeepEPLLOutput):
+            if TYPE_CHECKING:
+                assert isinstance(dispatch_output, DeepEPLLOutput)
+            hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
+
+            if isinstance(hidden_states, tuple):
+                per_token_scale = hidden_states[1]
+                hidden_states = hidden_states[0]
+
+            group_list = group_list.to(torch.int64)
+
+            if self.w13_weight.dtype != torch.int8:
+                # gmm1: gate_up_proj
+                hidden_states = torch_npu.npu_grouped_matmul(
+                    x=[hidden_states],
+                    weight=[self.w13_weight.permute(0, 2, 1)],
+                    # per_token_scale=[per_token_scale],
+                    split_item=2,
+                    group_list_type=group_list_type,
+                    group_type=0,
+                    group_list=group_list,
+                    output_dtype=output_dtype,
+                )[0]
+                hidden_states = torch_npu.npu_swiglu(hidden_states)
+                # gmm2: down_proj
+                hidden_states = torch_npu.npu_grouped_matmul(
+                    x=[hidden_states],
+                    weight=[self.w2_weight.permute(0, 2, 1)],
+                    split_item=2,
+                    group_list_type=group_list_type,
+                    group_type=0,
+                    group_list=group_list,
+                    output_dtype=output_dtype,
+                )[0]
+            else:
+                # gmm1: gate_up_proj
+                hidden_states = torch_npu.npu_grouped_matmul(
+                    x=[hidden_states],
+                    weight=[self.w13_weight],
+                    split_item=2,
+                    group_list_type=group_list_type,
+                    group_type=0,
+                    group_list=group_list,
+                    output_dtype=torch.int32,
+                )[0]
+
+                # act_fn: swiglu
+                hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+                    x=hidden_states,
+                    weight_scale=self.w13_weight_scale.to(torch.float32),
+                    activation_scale=per_token_scale,
+                    bias=None,
+                    quant_scale=None,
+                    quant_offset=None,
+                    group_index=group_list,
+                    activate_left=True,
+                    quant_mode=1,
+                )
 
+                # gmm2: down_proj
+                hidden_states = torch_npu.npu_grouped_matmul(
+                    x=[hidden_states],
+                    weight=[self.w2_weight],
+                    scale=[self.w2_weight_scale.to(output_dtype)],
+                    per_token_scale=[swiglu_out_scale],
+                    split_item=2,
+                    group_list_type=group_list_type,
+                    group_type=0,
+                    group_list=group_list,
+                    output_dtype=output_dtype,
+                )[0]
 
-
+            return hidden_states
+
+        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
+            return _forward_normal(dispatch_output)
+        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
+            return _forward_ll(dispatch_output)
+        else:
+            raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
+
+
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
     if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE
 
@@ -790,8 +698,7 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
         return FusedMoE
     try:
         # Check the quantization argument directly
-
-        if quantization == "modelopt_fp4":
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
             from sglang.srt.layers.moe.fused_moe_triton.layer import (
                 FlashInferFP4MoE,
             )
@@ -800,10 +707,18 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
     except:
         pass
 
-    if should_use_flashinfer_trtllm_moe():
+    if should_use_flashinfer_trtllm_moe() and quant_config is not None:
+        # FIXME: FlashInferFusedMoE only supports fp8 quant now
         return FlashInferFusedMoE
     if get_moe_runner_backend().is_flashinfer_cutlass():
         return FusedMoE
-    if get_moe_expert_parallel_world_size() > 1:
-        return EPMoE
     return FusedMoE
+
+
+def copy_list_to_gpu_no_ce(arr: List[int]):
+    from sgl_kernel.elementwise import copy_to_gpu_no_ce
+
+    tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
+    tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
+    copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
+    return tensor_gpu