sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff shows the changes between two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/attention/base_attn_backend.py
+++ b/sglang/srt/layers/attention/base_attn_backend.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional

 import torch

@@ -55,6 +55,25 @@ class AttentionBackend(ABC):
         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
         raise NotImplementedError()

+    def get_verify_buffers_to_fill_after_draft(self):
+        """
+        Return buffers of verify attention kernels that needs to be filled after draft.
+
+        Typically, these are tree mask and position buffers.
+        """
+        return [None, None]
+
+    def update_verify_buffers_to_fill_after_draft(
+        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
+    ):
+        """
+        Update the buffers returned by get_verify_fill_after_draft_buffers if needed.
+
+        Here, we need to redo the computation of all metadata of the attention backend
+        that depends on tree mask and position buffers.
+        """
+        raise NotImplementedError()
+
     def forward(
         self,
         q: torch.Tensor,
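
The two hooks added to `AttentionBackend` give speculative decoding a way to hand verify-kernel buffers (typically the tree mask and positions) back to the backend after the draft pass. A minimal sketch of an override is below; the `tree_mask_buf` / `positions_buf` attributes and the `_rebuild_verify_metadata` helper are illustrative names rather than parts of any actual sglang backend, and the base class's other abstract methods are omitted.

```python
from typing import Optional

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.speculative.spec_info import SpecInput


class SketchBackend(AttentionBackend):
    """Illustrative subclass only; a real backend implements the full interface."""

    def get_verify_buffers_to_fill_after_draft(self):
        # Expose the buffers the verify kernel reads but which are only
        # known once the draft pass has produced the speculation tree.
        # (tree_mask_buf / positions_buf are hypothetical attributes.)
        return [self.tree_mask_buf, self.positions_buf]

    def update_verify_buffers_to_fill_after_draft(
        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
    ):
        # Recompute whatever backend metadata depends on the tree mask and
        # position buffers (hypothetical helper name).
        self._rebuild_verify_metadata(spec_info, cuda_graph_bs)
```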
--- a/sglang/srt/layers/attention/double_sparsity_backend.py
+++ b/sglang/srt/layers/attention/double_sparsity_backend.py
@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING
 import torch

 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import get_global_server_args

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -42,7 +42,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         # TODO: Change the hard-coded block_seq_num
         self.BLOCK_SEQ = 128

-        if
+        if get_global_server_args().triton_attention_reduce_in_fp32:
             self.reduce_dtype = torch.float32
         else:
             self.reduce_dtype = torch.float16
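
Several backends in this release drop the module-level `global_server_args_dict` import from `sglang.srt.managers.schedule_batch` in favour of the `get_global_server_args()` accessor. A minimal before/after sketch is below; the removed line is truncated in this diff, so the 0.5.3rc2 dict-style lookup shown in the comment is an assumption, and the call assumes it runs inside a server process where the global args are already initialized.

```python
# 0.5.3rc2 (assumed dict-style access; the removed line is truncated above):
#   from sglang.srt.managers.schedule_batch import global_server_args_dict
#   reduce_in_fp32 = global_server_args_dict["triton_attention_reduce_in_fp32"]

# 0.5.4, as shown in the hunk above:
from sglang.srt.server_args import get_global_server_args

reduce_in_fp32 = get_global_server_args().triton_attention_reduce_in_fp32
```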
--- a/sglang/srt/layers/attention/fla/layernorm_gated.py
+++ b/sglang/srt/layers/attention/fla/layernorm_gated.py
@@ -5,7 +5,6 @@
 # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
 # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

-import math

 import torch
 import torch.nn.functional as F
@@ -13,6 +12,8 @@ import triton
 import triton.language as tl
 from einops import rearrange

+from sglang.srt.utils import device_context
+

 def rms_norm_ref(
     x,
@@ -158,7 +159,7 @@ def _layer_norm_fwd(
     # heuristics for number of warps
     num_warps = min(max(BLOCK_N // 256, 1), 8)
     grid = (M, ngroups)
-    with
+    with device_context(x.device):
         _layer_norm_fwd_1pass_kernel[grid](
             x,
             out,
@@ -181,6 +182,45 @@ def _layer_norm_fwd(
     return out, mean, rstd


+def rms_norm_gated(
+    *,
+    x,
+    weight,
+    bias,
+    z=None,
+    eps=1e-6,
+    group_size=None,
+    norm_before_gate=True,
+    is_rms_norm=False,
+):
+    """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
+
+    x_shape_og = x.shape
+    # reshape input data into 2D tensor
+    x = x.reshape(-1, x.shape[-1])
+    if x.stride(-1) != 1:
+        x = x.contiguous()
+    if z is not None:
+        assert z.shape == x_shape_og
+        z = z.reshape(-1, z.shape[-1])
+        if z.stride(-1) != 1:
+            z = z.contiguous()
+    weight = weight.contiguous()
+    if bias is not None:
+        bias = bias.contiguous()
+    y, mean, rstd = _layer_norm_fwd(
+        x,
+        weight,
+        bias,
+        eps,
+        z=z,
+        group_size=group_size,
+        norm_before_gate=norm_before_gate,
+        is_rms_norm=is_rms_norm,
+    )
+    return y.reshape(x_shape_og)
+
+
 class LayerNormFn(torch.autograd.Function):

     @staticmethod
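
A hedged usage sketch for the new keyword-only `rms_norm_gated` helper shown above; it assumes sglang 0.5.4 is installed with a CUDA device available, and the shapes are illustrative.

```python
import torch

from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated

hidden = 512
x = torch.randn(4, 1024, hidden, device="cuda", dtype=torch.float32)
z = torch.randn_like(x)  # gate input, passed through SiLU inside the kernel
weight = torch.ones(hidden, device="cuda", dtype=torch.float32)

# norm(x) * silu(z) with RMS normalization, per the docstring in the hunk above.
y = rms_norm_gated(x=x, weight=weight, bias=None, z=z, eps=1e-6, is_rms_norm=True)
assert y.shape == x.shape
```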
--- a/sglang/srt/layers/attention/fla/layernorm_gated.py
+++ b/sglang/srt/layers/attention/fla/layernorm_gated.py
@@ -195,32 +235,16 @@ class LayerNormFn(torch.autograd.Function):
         norm_before_gate=True,
         is_rms_norm=False,
     ):
-
-
-
-
-
-        if x.stride(-1) != 1:
-            x = x.contiguous()
-        if z is not None:
-            assert z.shape == x_shape_og
-            z = z.reshape(-1, z.shape[-1])
-            if z.stride(-1) != 1:
-                z = z.contiguous()
-        weight = weight.contiguous()
-        if bias is not None:
-            bias = bias.contiguous()
-        y, mean, rstd = _layer_norm_fwd(
-            x,
-            weight,
-            bias,
-            eps,
+        return rms_norm_gated(
+            x=x,
+            weight=weight,
+            bias=bias,
+            eps=eps,
             z=z,
             group_size=group_size,
             norm_before_gate=norm_before_gate,
             is_rms_norm=is_rms_norm,
         )
-        return y.reshape(x_shape_og)


 def layernorm_fn(
@@ -238,14 +262,6 @@ def layernorm_fn(
     )


-def rmsnorm_fn(
-    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
-):
-    return LayerNormFn.apply(
-        x, weight, bias, z, eps, group_size, norm_before_gate, True
-    )
-
-
 class LayerNorm(torch.nn.Module):

     def __init__(
@@ -284,6 +300,7 @@ class LayerNorm(torch.nn.Module):
             group_size=self.group_size,
             eps=self.eps,
             norm_before_gate=self.norm_before_gate,
+            is_rms_norm=False,
         )


@@ -315,7 +332,7 @@ class RMSNorm(torch.nn.Module):

     def forward(self, x, z=None):
         """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
-        return
+        return layernorm_fn(
             x,
             self.weight,
             self.bias,
@@ -323,4 +340,5 @@ class RMSNorm(torch.nn.Module):
             eps=self.eps,
             group_size=self.group_size,
             norm_before_gate=self.norm_before_gate,
+            is_rms_norm=True,
         )
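
For readers who want the semantics without the Triton kernel, here is a plain-PyTorch toy reference for the gated norm the docstrings above describe (norm(x) * silu(z) when norm_before_gate, else norm(x * silu(z))). It ignores group_size and bias and is not the code sglang runs.

```python
import torch
import torch.nn.functional as F


def gated_rms_norm_reference(x, weight, z=None, eps=1e-6, norm_before_gate=True):
    """Toy reference only; the real path is the fused Triton kernel above."""

    def rms_norm(t):
        return t * torch.rsqrt(t.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

    if z is None:
        return rms_norm(x)
    if norm_before_gate:
        return rms_norm(x) * F.silu(z)
    return rms_norm(x * F.silu(z))
```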
@@ -9,8 +9,6 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
-from sglang.srt.layers.attention.fla.op import safe_exp
-from sglang.srt.layers.attention.fla.utils import check_shared_mem


 @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
--- a/sglang/srt/layers/attention/flashattention_backend.py
+++ b/sglang/srt/layers/attention/flashattention_backend.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional

 import numpy as np
 import torch
@@ -10,8 +10,9 @@ import triton.language as tl

 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.spec_info import SpecInput

 if TYPE_CHECKING:
@@ -705,7 +706,9 @@ class FlashAttentionBackend(AttentionBackend):
         q = q.to(self.kv_cache_dtype)
         q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
         k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
-        causal =
+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False

         # Check if we should use local attention
         use_local_attn = (
@@ -754,7 +757,6 @@ class FlashAttentionBackend(AttentionBackend):

         # Use Flash Attention for prefill
         if not self.use_mla:
-            assert self.fa_impl_ver in [3], "Only FA3 support here"
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
@@ -828,7 +830,7 @@ class FlashAttentionBackend(AttentionBackend):
         ):
             # Do multi-head attention with chunked prefix cache
             if forward_batch.attn_attend_prefix_cache:
-                assert not
+                assert not get_global_server_args().disable_chunked_prefix_cache
                 # MHA for chunked prefix kv cache when running model with MLA
                 assert forward_batch.prefix_chunk_idx is not None
                 assert forward_batch.prefix_chunk_cu_seq_lens is not None
@@ -1006,7 +1008,9 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None and layer.sliding_window_size > -1
             else (-1, -1)
         )
-        causal =
+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False

         # For fa3 interface version compatibility, we put new fields into conditional keyword args
         kwargs = {}
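
In both causal hunks above, causal masking now defaults to on and is switched off for cross-attention and encoder-only layers. A minimal standalone sketch of that decision follows; `use_causal_mask` is an illustrative helper, not a function in the backend, while `AttentionType`, `layer.is_cross_attention`, and `layer.attn_type` are taken from the hunks.

```python
from sglang.srt.layers.radix_attention import AttentionType


def use_causal_mask(layer) -> bool:
    # Cross-attention and encoder-only layers attend bidirectionally,
    # so the causal mask is disabled for them; everything else stays causal.
    if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
        return False
    return True
```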
--- a/sglang/srt/layers/attention/flashattention_backend.py
+++ b/sglang/srt/layers/attention/flashattention_backend.py
@@ -2316,7 +2320,7 @@ class FlashAttentionMultiStepBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashAttentionBackend(
                     model_runner,
@@ -2331,7 +2335,7 @@ class FlashAttentionMultiStepBackend:
         self.attn_backends[i].init_forward_metadata(forward_batch)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(