sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py CHANGED
@@ -25,17 +25,16 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
+import tqdm
 from torch import nn
 from transformers import PretrainedConfig
 
-from sglang.srt import single_batch_overlap
 from sglang.srt.configs.model_config import (
     get_nsa_index_head_dim,
     get_nsa_index_n_heads,
     get_nsa_index_topk,
     is_deepseek_nsa,
 )
-from sglang.srt.debug_utils.dumper import dumper
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
@@ -46,9 +45,11 @@ from sglang.srt.distributed import (
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     use_symmetric_memory,
 )
+from sglang.srt.environ import envs
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
@@ -75,7 +76,6 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import (
-    get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
@@ -83,8 +83,12 @@ from sglang.srt.layers.moe import (
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
-from sglang.srt.layers.quantization import
+from sglang.srt.layers.quantization import CompressedTensorsConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+    CompressedTensorsWNA16AMXEPMoEMethod,
+)
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
@@ -95,7 +99,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
+    quant_weight_ue8m0,
     requant_weight_ue8m0_inplace,
+    transform_scale_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
@@ -107,14 +113,12 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.single_batch_overlap import SboFlags
-from sglang.srt.
-
-    model_forward_maybe_tbo,
-)
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,
@@ -131,6 +135,7 @@ from sglang.srt.utils import (
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
+    is_nvidia_cublas_cu12_version_ge_12_9,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -181,18 +186,31 @@ elif _is_hip:
         awq_dequantize_triton as awq_dequantize,
     )
 elif _is_npu:
-    import custom_ops
-    import sgl_kernel_npu
-    import torch_npu
+    import custom_ops  # noqa: F401
+    import sgl_kernel_npu  # noqa: F401
+    import torch_npu  # noqa: F401
+
+    from sglang.srt.layers.quantization.awq_triton import (
+        awq_dequantize_decomposition as awq_dequantize,
+    )
 else:
     pass
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
-
+_is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
 
 logger = logging.getLogger(__name__)
 
+
+def enable_nextn_moe_bf16_cast_to_fp8(quant_config):
+    return (
+        quant_config is not None
+        and quant_config.get_name() == "modelopt_fp4"
+        and get_moe_a2a_backend().is_deepep()
+    )
+
+
 FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
     "fa3",
     "nsa",
@@ -517,12 +535,13 @@ class DeepseekV2MoE(nn.Module):
         self.n_shared_experts = config.n_shared_experts
         self.num_fused_shared_experts = (
             0
-            if
+            if get_global_server_args().disable_shared_experts_fusion
             else config.n_shared_experts
         )
         self.config = config
         self.layer_id = layer_id
         self.alt_stream = alt_stream
+        self.is_nextn = is_nextn
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -546,7 +565,7 @@ class DeepseekV2MoE(nn.Module):
         self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
-            +
+            + get_global_server_args().ep_num_redundant_experts,
             num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
@@ -589,6 +608,7 @@ class DeepseekV2MoE(nn.Module):
             **(
                 dict(tp_rank=0, tp_size=1)
                 if get_moe_a2a_backend().is_deepep()
+                or get_moe_a2a_backend().is_mooncake()
                 or should_use_flashinfer_cutlass_moe_fp4_allgather()
                 else {}
             ),
@@ -619,12 +639,12 @@ class DeepseekV2MoE(nn.Module):
 
         self.top_k = config.num_experts_per_tok
 
-        if get_moe_a2a_backend().is_deepep():
+        if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
                 config.n_routed_experts
-                +
+                + get_global_server_args().ep_num_redundant_experts
             )
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
@@ -635,20 +655,10 @@ class DeepseekV2MoE(nn.Module):
                 else None
             )
 
-
-
-
-
-                num_experts=self.num_experts,
-                num_local_experts=config.n_routed_experts // self.tp_size,
-                hidden_size=config.hidden_size,
-                params_dtype=config.torch_dtype,
-                deepep_mode=get_deepep_mode(),
-                async_finish=True,
-                return_recv_hook=True,
-            )
-
-        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
+        self._enable_a2a_moe = (
+            get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
+        )
+        self._fuse_shared_experts_inside_sbo = SboFlags.fuse_shared_experts_inside_sbo()
 
     def get_moe_weights(self):
         return [
@@ -665,7 +675,7 @@ class DeepseekV2MoE(nn.Module):
         use_reduce_scatter: bool = False,
         gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
-        if not self.
+        if not self._enable_a2a_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
             if (
                 self.alt_stream is not None
@@ -707,6 +717,10 @@ class DeepseekV2MoE(nn.Module):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
+            if isinstance(
+                self.experts.quant_method, CompressedTensorsWNA16AMXEPMoEMethod
+            ):
+                topk_output.topk_weights.mul_(self.routed_scaling_factor)
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
@@ -740,9 +754,10 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
         if hidden_states.shape[0] > 0:
-
-
-
+            if not self._fuse_shared_experts_inside_sbo:
+                shared_output = self._forward_shared_experts(
+                    hidden_states, gemm_output_zero_allocator
+                )
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
@@ -750,7 +765,27 @@ class DeepseekV2MoE(nn.Module):
             shared_output = None
             topk_output = self.topk.empty_topk_output(hidden_states.device)
 
-
+        if self._fuse_shared_experts_inside_sbo:
+            shared_output = None
+
+            def _forward_shared_experts_and_put_results():
+                nonlocal shared_output
+                shared_output = self._forward_shared_experts(
+                    hidden_states, gemm_output_zero_allocator
+                )
+
+        final_hidden_states = self.experts(
+            hidden_states,
+            topk_output,
+            **(
+                dict(
+                    forward_shared_experts=_forward_shared_experts_and_put_results,
+                    alt_stream=self.alt_stream,
+                )
+                if self._fuse_shared_experts_inside_sbo
+                else {}
+            ),
+        )
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
@@ -834,9 +869,9 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            if not
+            if not self._fuse_shared_experts_inside_sbo:
                 shared_output = self._forward_shared_experts(hidden_states)
-
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
@@ -845,22 +880,29 @@ class DeepseekV2MoE(nn.Module):
                 ),
             )
         else:
-
-
-
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
+
+        if self._fuse_shared_experts_inside_sbo:
+            shared_output = None
+
+            def _forward_shared_experts_and_put_results():
+                nonlocal shared_output
+                shared_output = self._forward_shared_experts(hidden_states)
 
-        final_hidden_states
+        final_hidden_states = self.experts(
             hidden_states=hidden_states,
-
-
-
-
-
-
-
+            topk_output=topk_output,
+            **(
+                dict(
+                    forward_shared_experts=_forward_shared_experts_and_put_results,
+                    alt_stream=self.alt_stream,
+                    # SBO is not yet implemented for NextN
+                    disable_sbo=self.is_nextn,
+                )
+                if self._fuse_shared_experts_inside_sbo
+                else {}
+            ),
         )
-        if sbo_shared_output is not None:
-            shared_output = sbo_shared_output
 
         if shared_output is not None:
             x = shared_output
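
Note: the MoE-forward hunks above all introduce the same callback pattern. When shared-expert computation is fused inside single-batch overlap (SBO), the caller no longer runs the shared experts inline; it hands the expert runner a closure that publishes its result through `nonlocal`, so the runner can execute it where it overlaps best with the routed-expert work. A minimal, self-contained sketch of that pattern follows; `run_experts` and `shared_experts_fn` are hypothetical stand-ins for `self.experts` and `self._forward_shared_experts`, not the real sglang signatures.

def moe_forward_with_sbo(run_experts, shared_experts_fn, hidden_states,
                         topk_output, fuse_inside_sbo, alt_stream=None):
    # Illustrative sketch only; mirrors the diff's control flow, not sglang's API.
    shared_output = None

    if fuse_inside_sbo:
        # The closure writes its result back via `nonlocal`, so the expert
        # runner may call it at whatever point (or on whichever stream)
        # overlaps best with the routed-expert GEMMs.
        def _forward_shared_experts_and_put_results():
            nonlocal shared_output
            shared_output = shared_experts_fn(hidden_states)

        extra_kwargs = dict(
            forward_shared_experts=_forward_shared_experts_and_put_results,
            alt_stream=alt_stream,
        )
    else:
        # Outside SBO, shared experts still run eagerly, as before this change.
        shared_output = shared_experts_fn(hidden_states)
        extra_kwargs = {}

    routed_output = run_experts(hidden_states, topk_output, **extra_kwargs)
    return routed_output, shared_output
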
@@ -911,7 +953,7 @@ class DeepseekV2MoE(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
             ):
-                state.
+                state.topk_output = self.topk(
                     hidden_states=hidden_states,
                     router_logits=router_logits,
                     num_token_non_padded=state.forward_batch.num_token_non_padded,
@@ -920,21 +962,13 @@ class DeepseekV2MoE(nn.Module):
                     ),
                 )
         else:
-            state.
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            state.topk_weights_local = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
+            state.topk_output = self.topk.empty_topk_output(hidden_states.device)
 
     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-            self.experts.
+            self.experts.dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
-
-                topk_idx=state.pop("topk_idx_local"),
-                topk_weights=state.pop("topk_weights_local"),
-                forward_batch=state.forward_batch,
+                topk_output=state.pop("topk_output"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
@@ -943,32 +977,29 @@ class DeepseekV2MoE(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            state.dispatch_output = self.experts.
+            state.dispatch_output = self.experts.dispatcher.dispatch_b(
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts.
+        state.hidden_states_experts_output = self.experts.run_moe_core(
             dispatch_output=state.dispatch_output,
         )
 
     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.experts.
+            self.experts.dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-
+                topk_ids=state.dispatch_output.topk_ids,
                 topk_weights=state.dispatch_output.topk_weights,
-                forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
             state.pop("dispatch_output")
 
     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = (
-
-                tbo_subbatch_index=state.get("tbo_subbatch_index"),
-            )
+            state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
     def op_output(self, state):
@@ -1050,7 +1081,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q_lora_rank,
                 self.num_heads * self.qk_head_dim,
                 bias=False,
-                quant_config=quant_config,
+                quant_config=self._get_q_b_proj_quant_config(quant_config),
                 prefix=add_prefix("q_b_proj", prefix),
                 tp_rank=attn_tp_rank,
                 tp_size=attn_tp_size,
@@ -1122,7 +1153,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 base=rope_theta,
                 rope_scaling=rope_scaling,
                 is_neox_style=False,
-                device=
+                device=get_global_server_args().device,
             )
 
             if rope_scaling:
@@ -1166,12 +1197,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_scale_v = None
         self.use_deep_gemm_bmm = False
 
-        self.flashinfer_mla_disable_ragged =
-
-
-        self.disable_chunked_prefix_cache =
-
-
+        self.flashinfer_mla_disable_ragged = (
+            get_global_server_args().flashinfer_mla_disable_ragged
+        )
+        self.disable_chunked_prefix_cache = (
+            get_global_server_args().disable_chunked_prefix_cache
+        )
 
         self.current_attention_backend = (
             None  # Attention backend used by current forward batch
@@ -1250,18 +1281,18 @@ class DeepseekV2AttentionMLA(nn.Module):
     ) -> AttnForwardMethod:
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
-            attention_backend =
+            attention_backend = get_global_server_args().decode_attention_backend
         elif (
             forward_batch.forward_mode.is_target_verify()
             or forward_batch.forward_mode.is_draft_extend()
         ):
             # Use the specified backend for speculative operations (both verify and draft extend)
-            if
-                attention_backend =
+            if get_global_server_args().speculative_attention_mode == "decode":
+                attention_backend = get_global_server_args().decode_attention_backend
             else:  # default to prefill
-                attention_backend =
+                attention_backend = get_global_server_args().prefill_attention_backend
         else:
-            attention_backend =
+            attention_backend = get_global_server_args().prefill_attention_backend
         self.current_attention_backend = attention_backend
 
         handler = AttentionBackendRegistry.get_handler(attention_backend)
@@ -1351,6 +1382,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             inner_state = self.mla_preprocess.forward(
                 positions, hidden_states, forward_batch, zero_allocator
             )
+            inner_state = (*inner_state, None)  # add a position for topk_indices
         elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
             inner_state = self.forward_npu_sparse_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
@@ -1572,9 +1604,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
+            # fix bmm_fp8 error under cublas12.9 caused by bumpallocator, detail in pr#11612
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
-
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=q_nope.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -1718,7 +1755,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
-
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=attn_output.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
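
Note: both per-tensor-quant hunks above apply the same version-gated workaround: under cuBLAS >= 12.9 the FP8 scale buffer passed to per_tensor_quant_mla_fp8 is a freshly zeroed tensor rather than a slot from the BumpAllocator, sidestepping the bmm_fp8 issue the diff comment attributes to pr#11612. A small sketch of that choice, with an illustrative helper name that does not exist in sglang:

import torch

def pick_scale_buffer(zero_allocator, device, is_cublas_ge_129):
    # Hypothetical helper mirroring the inline conditional in the diff:
    # a fresh zeros tensor under cuBLAS >= 12.9, else a bump-allocated slot.
    if is_cublas_ge_129:
        return torch.zeros((1,), dtype=torch.float32, device=device)
    return zero_allocator.allocate(1)
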
@@ -2335,6 +2376,17 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    @staticmethod
+    def _get_q_b_proj_quant_config(quant_config):
+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            # refer to real DeepSeek V3 quant config
+            return Fp8Config(
+                is_checkpoint_fp8_serialized=True,
+                weight_block_size=[128, 128],
+            )
+        else:
+            return quant_config
+
 
 class DeepseekV2DecoderLayer(nn.Module):
 
@@ -2343,6 +2395,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        moe_quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
@@ -2353,7 +2406,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.speculative_algorithm =
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            get_global_server_args().speculative_algorithm
+        )
         self.layer_id = layer_id
         self.is_nextn = is_nextn
         self.self_attn = DeepseekV2AttentionMLA(
@@ -2390,7 +2445,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.is_layer_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
-                quant_config=quant_config,
+                quant_config=moe_quant_config or quant_config,
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
                 alt_stream=alt_stream,
@@ -2796,6 +2851,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
+            CompressedTensorsConfig.DeepSeekFP8Config = Fp8Config(
+                True, "dynamic", None, [128, 128]
+            )
         self.determine_num_fused_shared_experts()
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
@@ -2805,7 +2864,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
 
@@ -2825,7 +2884,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         self, architecture: str = "DeepseekV3ForCausalLM"
     ):
         self.num_fused_shared_experts = 0
-        if
+        if get_global_server_args().disable_shared_experts_fusion:
             return
 
         # Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -2844,7 +2903,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
 
         if disable_reason is not None:
-
+            get_global_server_args().disable_shared_experts_fusion = True
             self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
@@ -2909,7 +2968,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
-                if _is_cuda or _is_hip:
+                if _is_cuda or _is_hip or _is_npu:
                     w = awq_dequantize(
                         self_attn.kv_b_proj.qweight,
                         self_attn.kv_b_proj.scales,
@@ -2935,11 +2994,13 @@ class DeepseekV2ForCausalLM(nn.Module):
                 torch.float8_e4m3fn,
                 torch.float8_e4m3fnuz,
             ):
-
-
-
-
-                weight_block_size
+                selected_quant_config = getattr(
+                    self.quant_config, "DeepSeekFP8Config", self.quant_config
+                )
+                weight_block_size = getattr(
+                    selected_quant_config, "weight_block_size", None
+                )
+                if weight_block_size is not None:
                     assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                     if _is_fp8_fnuz:
                         weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
@@ -3069,6 +3130,16 @@ class DeepseekV2ForCausalLM(nn.Module):
         ):
             self._weight_requant_ue8m0(is_nextn)
 
+        # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+        if (
+            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN")
+        ):
+            self._transform_scale_ue8m0(is_nextn)
+        if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
+            self._transform_scale_nextn_moe_ue8m0()
+
     def _weight_requant_ue8m0(self, is_nextn=False):
         weight_block_size = self.quant_config.weight_block_size
 
@@ -3134,6 +3205,47 @@ class DeepseekV2ForCausalLM(nn.Module):
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
 
+    # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+    def _transform_scale_ue8m0(self, is_nextn=False):
+        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
+        for layer_id in range(num_hidden_layers):
+            if is_nextn:
+                layer = self.model.decoder
+            else:
+                layer = self.model.layers[layer_id]
+
+            module_list = []
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.q_b_proj)
+
+            for module in module_list:
+                transform_scale_ue8m0_inplace(
+                    module.weight_scale_inv, mn=module.weight.shape[-2]
+                )
+
+    # TODO avoid code dup (currently combine from weight_requant_ue8m0 and transform_scale_ue8m0)
+    def _transform_scale_nextn_moe_ue8m0(self):
+        layer = self.model.decoder
+
+        shared_experts = getattr(layer.mlp, "shared_experts", None)
+        if shared_experts is not None:
+            for module in [
+                shared_experts.gate_up_proj,
+                shared_experts.down_proj,
+            ]:
+                transform_scale_ue8m0_inplace(
+                    module.weight_scale_inv, mn=module.weight.shape[-2]
+                )
+
+        experts = layer.mlp.experts
+        if isinstance(experts, DeepEPMoE):
+            for w in [
+                experts.w13_weight_fp8,
+                experts.w2_weight_fp8,
+            ]:
+                transform_scale_ue8m0_inplace(w[1], mn=w[0].shape[-2])
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
 
         if is_nextn:
@@ -3149,6 +3261,13 @@ class DeepseekV2ForCausalLM(nn.Module):
             else:
                 raise ValueError("num_nextn_predict_layers is not in the config")
 
+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            weights = self._quant_attn_to_fp8_ue8m0(weights, is_nextn=is_nextn)
+        if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
+            weights = self._quant_nextn_moe_to_fp8_ue8m0(
+                weights, nextn_layer_id=nextn_layer_id
+            )
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
@@ -3378,6 +3497,62 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
 
+    def _quant_attn_to_fp8_ue8m0(self, weights, is_nextn):
+        weights_dict = dict(weights)
+
+        # temporarily only support DeepSeek V3/R1
+        weight_block_size = [128, 128]
+
+        for layer_id in tqdm.trange(
+            self.config.num_hidden_layers + int(is_nextn),
+            desc="quant attn to fp8 ue8m0",
+        ):
+            for stem in [
+                # may put tensors like `o_proj` here for DeepSeek FP4 ckpt v1
+                "q_b_proj",
+            ]:
+                partial_name = f"model.layers.{layer_id}.self_attn.{stem}"
+                original_weight = weights_dict[f"{partial_name}.weight"]
+                out_w, out_s = quant_weight_ue8m0(
+                    original_weight, weight_block_size=weight_block_size
+                )
+                weights_dict[f"{partial_name}.weight"] = out_w
+                weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
+
+        return list(weights_dict.items())
+
+    # TODO avoid code dup
+    def _quant_nextn_moe_to_fp8_ue8m0(self, weights, nextn_layer_id: int):
+        weights_dict = dict(weights)
+
+        # temporarily only support DeepSeek V3/R1
+        weight_block_size = [128, 128]
+
+        for layer_id in [nextn_layer_id]:
+            for expert_sub_name in [
+                "shared_experts",
+                *[
+                    f"experts.{expert_id}"
+                    for expert_id in range(self.config.n_routed_experts)
+                ],
+            ]:
+                for stem in [
+                    "gate_proj",
+                    "up_proj",
+                    "down_proj",
+                ]:
+                    partial_name = (
+                        f"model.layers.{layer_id}.mlp.{expert_sub_name}.{stem}"
+                    )
+                    original_weight = weights_dict[f"{partial_name}.weight"]
+                    out_w, out_s = quant_weight_ue8m0(
+                        original_weight, weight_block_size=weight_block_size
+                    )
+                    weights_dict[f"{partial_name}.weight"] = out_w
+                    weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
+
+        return list(weights_dict.items())
+
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
 
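
Note: the two _quant_*_to_fp8_ue8m0 hooks added at the end of the diff follow one rewrite pattern: materialize the (name, tensor) iterable, replace each selected weight with its block-quantized FP8 form plus a UE8M0 scale tensor, and hand the pairs back as a list. A condensed sketch of that pattern follows; `quant_fn` stands in for sglang's quant_weight_ue8m0, and the helper name itself is illustrative, not part of sglang.

def requant_named_weights(weights, target_names, quant_fn, block_size=(128, 128)):
    # Materialize the (name, tensor) iterable so entries can be replaced in place.
    weights_dict = dict(weights)
    for name in target_names:
        out_w, out_s = quant_fn(
            weights_dict[f"{name}.weight"], weight_block_size=list(block_size)
        )
        # Swap in the quantized weight and record its inverse scale alongside it.
        weights_dict[f"{name}.weight"] = out_w
        weights_dict[f"{name}.weight_scale_inv"] = out_s
    return list(weights_dict.items())
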