sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -11,24 +11,23 @@ from sgl_kernel import (
|
|
|
11
11
|
)
|
|
12
12
|
|
|
13
13
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
|
14
|
+
deepep_permute_triton_kernel,
|
|
15
|
+
deepep_post_reorder_triton_kernel,
|
|
16
|
+
deepep_run_moe_deep_preprocess,
|
|
14
17
|
post_reorder_triton_kernel_for_cutlass_moe,
|
|
15
18
|
pre_reorder_triton_kernel_for_cutlass_moe,
|
|
16
|
-
|
|
19
|
+
run_moe_ep_preproess,
|
|
17
20
|
)
|
|
18
21
|
|
|
19
22
|
|
|
20
23
|
def cutlass_w4a8_moe(
|
|
21
|
-
start_expert_id: int,
|
|
22
|
-
end_expert_id: int,
|
|
23
|
-
total_num_experts: int,
|
|
24
24
|
a: torch.Tensor,
|
|
25
25
|
w1_q: torch.Tensor,
|
|
26
26
|
w2_q: torch.Tensor,
|
|
27
27
|
w1_scale: torch.Tensor,
|
|
28
28
|
w2_scale: torch.Tensor,
|
|
29
29
|
topk_weights: torch.Tensor,
|
|
30
|
-
|
|
31
|
-
local_topk_ids: torch.Tensor,
|
|
30
|
+
topk_ids: torch.Tensor,
|
|
32
31
|
a_strides1: torch.Tensor,
|
|
33
32
|
b_strides1: torch.Tensor,
|
|
34
33
|
c_strides1: torch.Tensor,
|
|
@@ -64,6 +63,7 @@ def cutlass_w4a8_moe(
|
|
|
64
63
|
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
|
65
64
|
Shape: [num_experts, N // 512, K * 4]
|
|
66
65
|
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
|
66
|
+
- topk_ids (torch.Tensor): The ids of each token->expert mapping.
|
|
67
67
|
- a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
|
|
68
68
|
- b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
|
|
69
69
|
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
|
|
@@ -83,7 +83,7 @@ def cutlass_w4a8_moe(
|
|
|
83
83
|
Returns:
|
|
84
84
|
- torch.Tensor: The fp8 output tensor after applying the MoE layer.
|
|
85
85
|
"""
|
|
86
|
-
assert topk_weights.shape ==
|
|
86
|
+
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
|
87
87
|
assert w1_q.dtype == torch.int8
|
|
88
88
|
assert w2_q.dtype == torch.int8
|
|
89
89
|
assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
|
|
@@ -96,20 +96,21 @@ def cutlass_w4a8_moe(
|
|
|
96
96
|
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
|
|
97
97
|
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
|
|
98
98
|
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
|
|
99
|
-
|
|
99
|
+
num_local_experts = w1_q.size(0)
|
|
100
100
|
m = a.size(0)
|
|
101
101
|
k = w1_q.size(2) * 2 # w1_q is transposed and packed
|
|
102
102
|
n = w2_q.size(2) * 2 # w2_q is transposed and packed
|
|
103
|
-
topk =
|
|
103
|
+
topk = topk_ids.size(1)
|
|
104
104
|
|
|
105
105
|
if apply_router_weight_on_input:
|
|
106
106
|
assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
|
|
107
107
|
|
|
108
108
|
device = a.device
|
|
109
|
+
topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
|
|
109
110
|
|
|
110
|
-
_, src2dst, _ =
|
|
111
|
-
|
|
112
|
-
|
|
111
|
+
_, src2dst, _ = run_moe_ep_preproess(
|
|
112
|
+
topk_ids,
|
|
113
|
+
num_local_experts,
|
|
113
114
|
)
|
|
114
115
|
|
|
115
116
|
gateup_input = torch.empty(
|
|
@@ -122,9 +123,9 @@ def cutlass_w4a8_moe(
|
|
|
122
123
|
a,
|
|
123
124
|
gateup_input,
|
|
124
125
|
src2dst,
|
|
125
|
-
|
|
126
|
+
topk_ids,
|
|
126
127
|
a1_scale,
|
|
127
|
-
|
|
128
|
+
num_local_experts,
|
|
128
129
|
topk,
|
|
129
130
|
k,
|
|
130
131
|
BLOCK_SIZE=512,
|
|
@@ -133,16 +134,16 @@ def cutlass_w4a8_moe(
|
|
|
133
134
|
# NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
|
|
134
135
|
# they are kept to allow for a quick switch of the permutation logic
|
|
135
136
|
# from the current triton kernel implementation to the cutlass-based one if needed.
|
|
136
|
-
a_map = torch.empty((
|
|
137
|
-
c_map = torch.empty((
|
|
137
|
+
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
|
138
|
+
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
|
138
139
|
get_cutlass_w4a8_moe_mm_data(
|
|
139
|
-
|
|
140
|
+
topk_ids,
|
|
140
141
|
expert_offsets,
|
|
141
142
|
problem_sizes1,
|
|
142
143
|
problem_sizes2,
|
|
143
144
|
a_map,
|
|
144
145
|
c_map,
|
|
145
|
-
|
|
146
|
+
num_local_experts,
|
|
146
147
|
n,
|
|
147
148
|
k,
|
|
148
149
|
)
|
|
@@ -195,12 +196,203 @@ def cutlass_w4a8_moe(
|
|
|
195
196
|
c2,
|
|
196
197
|
output,
|
|
197
198
|
src2dst,
|
|
198
|
-
|
|
199
|
+
topk_ids,
|
|
199
200
|
topk_weights,
|
|
200
|
-
num_experts,
|
|
201
201
|
topk,
|
|
202
|
+
num_local_experts,
|
|
202
203
|
k,
|
|
203
|
-
0,
|
|
204
204
|
BLOCK_SIZE=512,
|
|
205
205
|
)
|
|
206
206
|
return output
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def cutlass_w4a8_moe_deepep_normal(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids_: torch.Tensor,
    a_strides1: torch.Tensor,
    b_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides13: torch.Tensor,
    s_strides2: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and top-k gating
    mechanism, for inputs arriving through the DeepEP "normal" dispatch path.
    The matrix multiplications are implemented with CUTLASS grouped gemm.

    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
        Shape: [num_experts, N * 2, K // 2]
        (the weights are passed transposed and int4-packed)
    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
        Shape: [num_experts, K, N // 2]
        (the weights are passed transposed and int4-packed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts, K // 512, N * 8]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts, N // 512, K * 4]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
        Must have the same shape as topk_ids_.
    - topk_ids_ (torch.Tensor): The ids of each token->expert mapping.
        Entries equal to -1 are remapped to num_experts before being handed
        to the grouped-gemm metadata kernel (presumably -1 marks a slot that
        is not routed to a local expert -- confirm against the dispatcher).
    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
    - expert_offsets (torch.Tensor): Caller-provided buffer passed to
        get_cutlass_w4a8_moe_mm_data; all but its last element are then
        forwarded to both grouped gemms.
    - problem_sizes1 (torch.Tensor): Caller-provided buffer passed to
        get_cutlass_w4a8_moe_mm_data and then to the first grouped gemm.
    - problem_sizes2 (torch.Tensor): Caller-provided buffer passed to
        get_cutlass_w4a8_moe_mm_data and then to the second grouped gemm.
    - a1_scale (Optional[torch.Tensor]): The fp32 scale to quantize a.
        Shape: scalar or [1, K]
        Required in practice: it is always used for the static fp8
        quantization of the input, so None is rejected.
    - a2_scale (Optional[torch.Tensor]): The fp32 scale to
        quantize the intermediate result between the gemms.
        Shape: scalar or [1, N]
        Required in practice, for the same reason as a1_scale.

    Returns:
    - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
        Shape: [src2dst.shape[0] // topk, K]

    Raises:
    - ValueError: If a1_scale or a2_scale is None.
    """
    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
    assert w1_q.dtype == torch.int8
    assert w2_q.dtype == torch.int8
    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"

    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"

    # Both scales are dereferenced unconditionally below; fail early with a
    # clear message instead of an AttributeError on None.
    if a1_scale is None or a2_scale is None:
        raise ValueError(
            "cutlass_w4a8_moe_deepep_normal requires both a1_scale and a2_scale"
        )

    num_experts = w1_q.size(0)
    m = a.size(0)
    k = w1_q.size(2) * 2  # w1_q is transposed and packed
    n = w2_q.size(2) * 2  # w2_q is transposed and packed
    topk = topk_ids_.size(1)
    device = a.device

    # Convert each scale once; it is used by both the quantization kernel
    # and the grouped gemm that consumes the quantized activations.
    a1_scale_f32 = a1_scale.float()
    a2_scale_f32 = a2_scale.float()

    reorder_topk_ids, src2dst, _ = deepep_run_moe_deep_preprocess(
        topk_ids_, num_experts
    )
    num_total_tokens = reorder_topk_ids.numel()

    # Permute the input rows into the order given by src2dst, still in the
    # input dtype, then quantize the permuted buffer to fp8.
    gateup_input_pre_reorder = torch.empty(
        (int(num_total_tokens), a.shape[1]),
        device=device,
        dtype=a.dtype,
    )
    deepep_permute_triton_kernel[(a.shape[0],)](
        a,
        gateup_input_pre_reorder,
        src2dst,
        topk_ids_.to(torch.int64),
        None,
        topk,
        a.shape[1],
        BLOCK_SIZE=512,
    )
    gateup_input = torch.empty(
        gateup_input_pre_reorder.shape, dtype=torch.float8_e4m3fn, device=device
    )
    sgl_per_tensor_quant_fp8(
        gateup_input_pre_reorder, gateup_input, a1_scale_f32, True
    )
    # Drop the unquantized copy before the gemm buffers are allocated.
    del gateup_input_pre_reorder

    # Remap the -1 entries to num_experts for the metadata kernel.
    local_topk_ids = (
        torch.where(topk_ids_ == -1, num_experts, topk_ids_).to(torch.int32)
    ).contiguous()

    a_map = torch.empty(local_topk_ids.numel(), dtype=torch.int32, device=device)
    c_map = torch.empty(local_topk_ids.numel(), dtype=torch.int32, device=device)
    get_cutlass_w4a8_moe_mm_data(
        local_topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a_map,
        c_map,
        num_experts,
        n,
        k,
    )
    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)

    # First grouped gemm: gate/up projection.
    cutlass_w4a8_moe_mm(
        c1,
        gateup_input,
        w1_q,
        a1_scale_f32,
        w1_scale,
        expert_offsets[:-1],
        problem_sizes1,
        a_strides1,
        b_strides1,
        c_strides1,
        s_strides13,
        128,
        topk,
    )
    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
    silu_and_mul(c1, intermediate)

    intermediate_q = torch.empty(
        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
    )
    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale_f32, True)

    # Second grouped gemm: down projection.
    cutlass_w4a8_moe_mm(
        c2,
        intermediate_q,
        w2_q,
        a2_scale_f32,
        w2_scale,
        expert_offsets[:-1],
        problem_sizes2,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides2,
        128,
        topk,
    )

    # Gather the per-expert rows back into one row per token, using the
    # original (un-remapped) topk ids and the routing weights.
    num_tokens = src2dst.shape[0] // topk
    output = torch.empty(
        (num_tokens, c2.shape[1]),
        device=c2.device,
        dtype=torch.bfloat16,
    )
    deepep_post_reorder_triton_kernel[(num_tokens,)](
        c2,
        output,
        src2dst,
        topk_ids_,
        topk_weights,
        topk,
        c2.shape[1],
        BLOCK_SIZE=512,
    )

    return output
|