sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -199,7 +199,6 @@ class GPTQConfig(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[LinearMethodBase]:
         # Delay the import to avoid circular dependency
-        from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, FusedMoE):
@@ -12,7 +12,15 @@ from sglang.srt.utils import get_device_name, is_cuda
 
 _is_cuda = is_cuda()
 if _is_cuda:
-    from sgl_kernel import sgl_per_token_group_quant_int8
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+
+        enable_sgl_per_token_group_quant_8bit = True
+    except ImportError:
+        from sgl_kernel import sgl_per_token_group_quant_int8
+
+        enable_sgl_per_token_group_quant_8bit = False
 
 logger = logging.getLogger(__name__)
 
@@ -187,6 +195,7 @@ def sglang_per_token_group_quant_int8(
     group_size: int,
     eps: float = 1e-10,
     dtype: torch.dtype = torch.int8,
+    enable_v2: Optional[bool] = None,
 ):
     assert (
         x.shape[-1] % group_size == 0
@@ -204,7 +213,14 @@ def sglang_per_token_group_quant_int8(
         dtype=torch.float32,
     )
 
-    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    # Temporary
+    if enable_sgl_per_token_group_quant_8bit:
+        sgl_per_token_group_quant_8bit(
+            x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
+        )
+    else:
+        assert not enable_v2
+        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
 
     return x_q, x_s
 
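The new code path above feature-detects `sgl_per_token_group_quant_8bit` at import time and falls back to the older int8-only kernel when it is missing. A minimal, self-contained sketch of that fallback-import pattern, using made-up `fast_kernels`/`fast_quant` names rather than the real sgl_kernel symbols:

```python
# Illustrative only: feature-detect an optional fast kernel at import time
# and remember the result in a module-level flag (hypothetical names).
try:
    from fast_kernels import fast_quant  # may not exist in older builds
    HAS_FAST_QUANT = True
except ImportError:
    fast_quant = None
    HAS_FAST_QUANT = False


def quantize(values):
    # Dispatch once per call based on the flag set at import time.
    if HAS_FAST_QUANT:
        return fast_quant(values)
    return [round(v * 127) for v in values]  # plain-Python fallback


if __name__ == "__main__":
    print(quantize([0.1, -0.5, 0.9]))
```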
@@ -4,6 +4,7 @@
 from __future__ import annotations
 
 import logging
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional
 
 import numpy
@@ -57,6 +58,17 @@ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 USE_FP32_REDUCE_DEFAULT = True
 
 
+@dataclass
+class MarlinLinearLayerConfig:
+    full_weight_shape: tuple[int, int]  # [in, out]
+    partition_weight_shape: tuple[int, int]
+    weight_type: ScalarType
+    act_type: torch.dtype
+    group_size: int
+    zero_points: bool
+    has_g_idx: bool
+
+
 # For binary size and compile time, we don't support the same types for with and
 # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
 # TODO: we may want to move this into the C++ so its closer to the actual impl
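`MarlinLinearLayerConfig` gathers the per-layer facts the Marlin path needs (full and partitioned weight shapes, weight and activation types, group size, zero-point and g_idx flags) into one record instead of loose arguments. A standalone sketch of the same idea, with a plain string standing in for sglang's `ScalarType` and illustrative shape values:

```python
from dataclasses import dataclass


@dataclass
class LinearLayerConfig:
    # Shapes are (in_features, out_features); the partition shape is what
    # this tensor-parallel rank actually holds.
    full_weight_shape: tuple[int, int]
    partition_weight_shape: tuple[int, int]
    weight_type: str  # stand-in for sglang's ScalarType
    group_size: int
    zero_points: bool
    has_g_idx: bool


cfg = LinearLayerConfig(
    full_weight_shape=(4096, 11008),
    partition_weight_shape=(4096, 11008 // 2),  # e.g. TP=2 column split
    weight_type="uint4b8",
    group_size=128,
    zero_points=False,
    has_g_idx=False,
)
print(cfg)
```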
@@ -79,7 +79,7 @@ CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
     "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
 )
 USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
-    "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
+    "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM", "true"
 )
 # TODO make it true by default when the DeepEP PR is merged
 CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
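The only change in this hunk is the default: `SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM` now reads as true when the variable is unset. A rough sketch of how such a boolean env-var helper typically behaves (this is not sglang's exact implementation, and the accepted truthy spellings are an assumption):

```python
import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Unset -> fall back to `default`; "1"/"true"/"yes" (any case) -> True.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")


# With the new default of "true", leaving the variable unset enables the
# CUTLASS FP4 GEMM backend; exporting SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM=0
# turns it back off.
print(get_bool_env_var("SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM", "true"))
```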
@@ -90,7 +90,50 @@ CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
 ACTIVATION_SCHEMES = ["static"]
 
 
-class ModelOptFp8Config(QuantizationConfig):
+class ModelOptQuantConfig(QuantizationConfig):
+    def __init__(
+        self,
+        kv_cache_quant_algo: Optional[str],
+        exclude_modules: Optional[List[str]],
+        packed_modules_mapping: Optional[Dict[str, List[str]]],
+    ):
+        super().__init__()
+        self.packed_modules_mapping = packed_modules_mapping
+        self.exclude_modules = exclude_modules or []
+        self.kv_cache_quant_algo = kv_cache_quant_algo
+
+    def _get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+        *,
+        Linear: type[LinearMethodBase],
+        Moe: type[FusedMoEMethodBase],
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(
+                prefix, self.exclude_modules, self.packed_modules_mapping
+            ) or self.is_layer_excluded(prefix):
+                return UnquantizedLinearMethod()
+            return Linear(self)
+        elif self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
+            return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return Moe(self)
+        return None
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ModelOptFp8Config(ModelOptQuantConfig):
     """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
 
     def __init__(
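The new `ModelOptQuantConfig` base owns the `isinstance`-based routing and the exclude-list check, so each concrete config only names the linear and MoE method classes it wants instantiated. A self-contained miniature of that dispatch pattern (all class names here are stand-ins, not the sglang classes):

```python
# Miniature of the dispatch pattern introduced by ModelOptQuantConfig: the base
# class owns the isinstance() routing and the exclude-list check; subclasses
# only name their Linear/Moe method classes.
class LinearBase: ...
class FusedMoE: ...


class LinearMethodA:
    def __init__(self, cfg):
        self.cfg = cfg


class MoeMethodA:
    def __init__(self, cfg):
        self.cfg = cfg


class QuantConfigBase:
    def __init__(self, exclude_modules=None):
        self.exclude_modules = exclude_modules or []

    def _get_quant_method(self, layer, prefix, *, Linear, Moe):
        if any(m in prefix for m in self.exclude_modules):
            return None  # excluded layers stay unquantized
        if isinstance(layer, LinearBase):
            return Linear(self)
        if isinstance(layer, FusedMoE):
            return Moe(self)
        return None


class VariantAConfig(QuantConfigBase):
    # The subclass shrinks to a one-line dispatch, mirroring the diff.
    def get_quant_method(self, layer, prefix):
        return self._get_quant_method(
            layer, prefix, Linear=LinearMethodA, Moe=MoeMethodA
        )


print(type(VariantAConfig().get_quant_method(LinearBase(), "model.layers.0.q_proj")))
```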
@@ -98,22 +141,27 @@ class ModelOptFp8Config(QuantizationConfig):
         is_checkpoint_fp8_serialized: bool = False,
         kv_cache_quant_method: Optional[str] = None,
         exclude_modules: Optional[List[str]] = None,
+        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ) -> None:
         """
         Args:
             is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
         """
+        super().__init__(kv_cache_quant_method, exclude_modules, packed_modules_mapping)
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
-        self.kv_cache_quant_method = kv_cache_quant_method
-        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
             )
 
+    @classmethod
+    def override_quantization_method(cls, hf_quant_config, user_quant):
+        """Override quantization method based on the model's config."""
+        return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)
+
     @classmethod
     def get_name(cls) -> str:
-        return "modelopt"
+        return "modelopt_fp8"
 
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -123,10 +171,6 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_min_capability(cls) -> int:
         return 89  # Minimum hardware capability (e.g., Hopper GPUs).
 
-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return ["hf_quant_config.json"]
-
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
         # Handle two different config formats:
@@ -181,37 +225,27 @@ class ModelOptFp8Config(QuantizationConfig):
             is_checkpoint_fp8_serialized=True,
             kv_cache_quant_method=kv_cache_quant_method,
             exclude_modules=exclude_modules,
+            packed_modules_mapping=config.get("packed_modules_mapping"),
         )
 
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[QuantizeMethodBase]:
-
-        from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-
-        if self.exclude_modules and any(
+    def is_layer_excluded(self, prefix: str) -> bool:
+        if len(self.exclude_modules) == 0:
+            return False
+        return any(
             module in prefix
             or (
                 prefix.startswith("language_model.")
                 and module in prefix.removeprefix("language_model.")
             )
             for module in self.exclude_modules
-        ):
-            return None
-
-        if isinstance(layer, LinearBase):
-            return ModelOptFp8LinearMethod(self)
-        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
-            return ModelOptFp8KVCacheMethod(self)
-
-        if isinstance(layer, FusedMoE):
-            return ModelOptFp8MoEMethod(self)
-
-        return None
+        )
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[QuantizeMethodBase]:
+        return self._get_quant_method(
+            layer, prefix, Linear=ModelOptFp8LinearMethod, Moe=ModelOptFp8MoEMethod
+        )
 
 
 class ModelOptFp8LinearMethod(LinearMethodBase):
@@ -507,7 +541,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         return self.runner.run(dispatch_output, quant_info)
 
 
-class ModelOptFp4Config(QuantizationConfig):
+class ModelOptFp4Config(ModelOptQuantConfig):
     """Config class for FP4."""
 
     def __init__(
@@ -516,7 +550,9 @@ class ModelOptFp4Config(QuantizationConfig):
         kv_cache_quant_algo: str = None,
         group_size: int = None,
         exclude_modules: List[str] = None,
+        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ) -> None:
+        super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping)
         self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
         if is_checkpoint_nvfp4_serialized:
             logger.warning(
@@ -524,8 +560,11 @@ class ModelOptFp4Config(QuantizationConfig):
                 "format is experimental and subject to change."
             )
         self.group_size = group_size
-        self.kv_cache_quant_algo = kv_cache_quant_algo
-        self.exclude_modules = exclude_modules
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_config, user_quant):
+        """Override quantization method based on the model's config."""
+        return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)
 
     @classmethod
     def get_name(cls) -> str:
@@ -539,10 +578,6 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_min_capability(cls) -> int:
         return 100
 
-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return ["hf_quant_config.json"]
-
     @staticmethod
     def common_group_size(cfg: dict) -> int:
         """Return the unique group_size across the config; raise if missing/mismatched."""
@@ -608,7 +643,16 @@ class ModelOptFp4Config(QuantizationConfig):
             else:
                 kv_cache_quant_algo = "auto"
 
-            group_size =
+            group_size = config.get("group_size")
+            # If group_size is not at top level, try to extract from config_groups
+            if group_size is None:
+                config_groups = config.get("config_groups", {})
+                if config_groups:
+                    # Get group_size from the first group's weights config
+                    first_group = next(iter(config_groups.values()), {})
+                    weights_config = first_group.get("weights", {})
+                    group_size = weights_config.get("group_size")
+
             exclude_modules = config.get("ignore", [])
         else:
             # Fall back to nested format (hf_quant_config.json - legacy format)
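The added fallback reads `group_size` from the first entry of `config_groups[...]["weights"]` when it is missing at the top level of the quantization config. With a hypothetical compressed-tensors-style fragment, the lookup resolves like this:

```python
# Hypothetical quantization_config fragment without a top-level group_size.
config = {
    "config_groups": {
        "group_0": {"weights": {"num_bits": 4, "group_size": 16}},
    },
    "ignore": ["lm_head"],
}

group_size = config.get("group_size")
if group_size is None:
    config_groups = config.get("config_groups", {})
    if config_groups:
        # Take the group_size from the first group's weights section.
        first_group = next(iter(config_groups.values()), {})
        group_size = first_group.get("weights", {}).get("group_size")

print(group_size)  # -> 16
```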
@@ -634,29 +678,30 @@ class ModelOptFp4Config(QuantizationConfig):
             )
             is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
 
-            if
+            if group_size is None or exclude_modules is None:
                 logger.warning(
                     f"group_size: {group_size},"
                     f"kv_cache_quant_algo: {kv_cache_quant_algo},"
                     f"exclude_modules: {exclude_modules}"
                 )
                 raise ValueError(
-                    "NVFP4 quantization requires
-                    "
+                    "NVFP4 quantization requires group_size and exclude_modules "
+                    "specified in the quantization config"
                 )
             return cls(
                 is_checkpoint_nvfp4_serialized,
                 kv_cache_quant_algo,
                 group_size,
                 exclude_modules,
+                config.get("packed_modules_mapping"),
             )
 
-    def is_layer_excluded(self, prefix: str
+    def is_layer_excluded(self, prefix: str):
         import regex as re
 
         fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
         prefix_split = prefix.split(".")
-        for pattern in exclude_modules:
+        for pattern in self.exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
             pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
@@ -672,30 +717,13 @@ class ModelOptFp4Config(QuantizationConfig):
                 return True
         return False
 
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-
-
-
-
-
-        if isinstance(layer, LinearBase):
-            if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
-                prefix, self.exclude_modules
-            ):
-                return UnquantizedLinearMethod()
-            return ModelOptFp4LinearMethod(self)
-        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
-            return ModelOptFp8KVCacheMethod(self)
-        elif isinstance(layer, FlashInferFP4MoE):
-            # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
-            return ModelOptNvFp4FusedMoEMethod(self)
-        elif isinstance(layer, FusedMoE):
-            return ModelOptNvFp4FusedMoEMethod(self)
-        return None
-
-    def get_scaled_act_names(self) -> List[str]:
-        return []
+    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
+        return self._get_quant_method(
+            layer,
+            prefix,
+            Linear=ModelOptFp4LinearMethod,
+            Moe=ModelOptNvFp4FusedMoEMethod,  # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
+        )
 
 
 class ModelOptFp4LinearMethod(LinearMethodBase):
@@ -852,25 +880,15 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         if enable_flashinfer_fp4_gemm:
             w = layer.weight.T
             w_scale_interleaved = layer.weight_scale_interleaved.T
-
-
-
-
-
-
-
-
-
-            )
-        else:
-            out = fp4_gemm(
-                x_fp4,
-                w,
-                x_scale_interleaved,
-                w_scale_interleaved,
-                layer.alpha,
-                output_dtype,
-            )
+        out = fp4_gemm(
+            x_fp4,
+            w,
+            x_scale_interleaved,
+            w_scale_interleaved,
+            layer.alpha,
+            output_dtype,
+            **(dict(backend="cutlass") if USE_CUTLASS_BACKEND_FOR_FP4_GEMM else dict()),
+        )
         if bias is not None:
             out = out + bias
         return out.view(*output_shape)
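Rather than keeping two near-identical call sites, the FP4 GEMM call now splats an optional `backend="cutlass"` keyword only when `USE_CUTLASS_BACKEND_FOR_FP4_GEMM` is set. The `**(dict(...) if flag else dict())` idiom in isolation, with a dummy `gemm` stand-in:

```python
def gemm(a, b, backend="auto"):
    # Stand-in for the real FP4 GEMM entry point; only the kwarg plumbing matters here.
    return f"gemm({a}, {b}, backend={backend})"


USE_CUTLASS = True

# The keyword is only passed at all when the flag is on, so the callee's own
# default ("auto") still applies when it is off.
out = gemm("x_fp4", "w", **(dict(backend="cutlass") if USE_CUTLASS else dict()))
print(out)
```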
@@ -1069,19 +1087,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         intermediate_size,
         num_experts,
     ):
-        from flashinfer import (
-            RoutingMethodType,
-            e2m1_and_ufp8sf_scale_to_float,
-            fp4_quantize,
-            next_positive_power_of_2,
-            nvfp4_block_scale_interleave,
-            reorder_rows_for_gated_act_gemm,
-            shuffle_matrix_a,
-            shuffle_matrix_sf_a,
-        )
+        from flashinfer import nvfp4_block_scale_interleave
         from flashinfer.fused_moe.core import (
-            _maybe_get_cached_w2_permute_indices,
             _maybe_get_cached_w3_w1_permute_indices,
+            get_w2_permute_indices_with_cache,
         )
 
         """Prepare quantized weights for kernel (done offline with weights)."""
@@ -1142,7 +1151,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 )
             )
 
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 gemm2_weights_fp4[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -1153,7 +1162,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 .contiguous()
             )
 
-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 gemm2_scales_linear_fp4[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -1263,6 +1272,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )
 
+        layer.dispatcher.set_quant_config(
+            {"input_global_scale": layer.w13_input_scale_quant}
+        )
+
         # Validate weight scales
         for name, weight_scale in [
             ("w13", layer.w13_weight_scale),
@@ -1366,6 +1379,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         self,
         layer: FusedMoE,
         dispatch_output: StandardDispatchOutput,
+        forward_shared_experts=None,
+        alt_stream=None,
     ) -> CombineInput:
 
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
@@ -1437,9 +1452,19 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             )[0]
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
+
+                if forward_shared_experts is not None:
+                    alt_stream.wait_stream(torch.cuda.current_stream())
+                    with torch.cuda.stream(alt_stream):
+                        forward_shared_experts()
+
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
+
+                if forward_shared_experts is not None:
+                    torch.cuda.current_stream().wait_stream(alt_stream)
+
             return StandardCombineInput(hidden_states=output)
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
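The new `forward_shared_experts` and `alt_stream` hooks let the shared-expert forward run on a secondary CUDA stream while the main stream performs the reduce-scatter, fenced with `wait_stream` on both sides. A minimal sketch of that overlap pattern with generic work functions (guarded so it also runs on CPU-only machines):

```python
import torch


def overlap_on_alt_stream(main_work, side_work):
    if not torch.cuda.is_available():
        # Fallback: just run sequentially on CPU-only machines.
        side_work()
        return main_work()

    alt_stream = torch.cuda.Stream()
    # Make the side stream see everything already queued on the main stream ...
    alt_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(alt_stream):
        side_work()  # e.g. shared-expert forward
    out = main_work()  # e.g. reduce-scatter on the main stream
    # ... and make the main stream wait for the side work before using its results.
    torch.cuda.current_stream().wait_stream(alt_stream)
    return out


print(overlap_on_alt_stream(lambda: "combined", lambda: None))
```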
@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     direct_register_custom_op,
     is_cuda,
@@ -41,7 +41,6 @@ from sglang.srt.utils import (
     is_triton_kernels_available,
     log_info_on_rank0,
     mxfp_supported,
-    next_power_of_2,
     round_up,
     set_weight_attrs,
 )
@@ -265,9 +264,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
-        self.flashinfer_mxfp4_moe_precision =
-
-
+        self.flashinfer_mxfp4_moe_precision = (
+            get_global_server_args().flashinfer_mxfp4_moe_precision
+        )
 
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
@@ -597,30 +596,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
         torch.cuda.empty_cache()
 
-    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
-        # Number of tokens in the input tensor.
-        num_tokens = x.shape[0]
-        # Factor to account for the imbalance of the experts.
-        # factor equals to the
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-        # - 1.0 means perfect expert distribution.
-        # - > 1.0 means some experts have more
-        # tokens than the perfect distribution.
-        # - < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert
-        # assuming perfect distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-        # Cap to 8-64 tokens per CTA tile
-        # as it's the range supported by the kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-
-        return tile_tokens_dim
-
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
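For reference, the deleted `_get_tile_tokens_dim` helper estimated a per-CTA tile size from the average tokens per expert, inflated by an imbalance factor, padded to a power of two, and clamped to the kernel's 8-64 range; the call site below now passes `None`, apparently leaving the choice to the kernel. The same arithmetic as a standalone function:

```python
def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int,
                    imbalance_factor: float = 1.3) -> int:
    # Tokens per expert under a perfect distribution, inflated by an
    # imbalance factor, then padded to the next power of two.
    per_expert = int((num_tokens * top_k) // num_experts * imbalance_factor)
    dim = 1
    while dim < per_expert:
        dim *= 2
    # Clamp to the 8-64 tile range supported by the kernel.
    return min(max(dim, 8), 64)


print(tile_tokens_dim(num_tokens=512, top_k=4, num_experts=128))  # -> 32
```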
@@ -696,7 +671,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                 layer.num_local_experts,  # local num experts
                 None,
-
+                None,  # tile_tokens_dim
                 1,  # routing_method_type, renormalize
                 True,  # do finalize
             )[0]
@@ -65,7 +65,9 @@ class QuarkConfig(QuantizationConfig):
         if should_ignore_layer(
             prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
         ):
-
+            if isinstance(layer, LinearBase):
+                return UnquantizedLinearMethod()
+            return None
 
         if isinstance(layer, LinearBase):
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
@@ -3,16 +3,16 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any
 
 import torch
-from aiter import ActivationType, QuantType
+from aiter import ActivationType, QuantType
 from aiter.fused_moe import fused_moe
 from aiter.utility.fp4_utils import e8m0_shuffle
 
 from sglang.srt.layers.moe import MoeRunnerConfig
 from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
-from sglang.srt.utils import
+from sglang.srt.utils import is_hip, set_weight_attrs
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
@@ -2,20 +2,13 @@
 
 from typing import Any, Callable, Optional
 
-import aiter
 import torch
-import torch.nn.functional as F
-from aiter.ops.gemm_op_a4w4 import gemm_a4w4
-from aiter.ops.shuffle import shuffle_weight
 from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
 from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
 from aiter.ops.triton.quant import dynamic_mxfp4_quant
-from aiter.utility import dtypes
-from aiter.utility.fp4_utils import e8m0_shuffle
 
 from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
 from sglang.srt.layers.quantization.quark.schemes import QuarkScheme
-from sglang.srt.utils import get_bool_env_var
 
 __all__ = ["QuarkW4A4MXFP4"]
 
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import importlib.util
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -31,8 +30,6 @@ if TYPE_CHECKING:
         StandardDispatchOutput,
     )
 
-has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
-
 
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_hip = is_hip()
@@ -143,7 +140,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
-        if torch.cuda.is_available() and has_triton_kernels:
+        if torch.cuda.is_available() and use_triton_kernels:
             from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
                 triton_kernel_moe_forward as _tk_forward,
             )
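The removed module-level probe used `importlib.util.find_spec` to decide whether `triton_kernels` was importable; the gate above is now the explicit `use_triton_kernels` setting instead. The old-style probe on its own, for comparison:

```python
import importlib.util

# A module counts as available if it can be located on sys.path without importing it.
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
print("triton_kernels available:", has_triton_kernels)
```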