sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -40,8 +40,9 @@ from sglang.srt.layers.moe import (
|
|
|
40
40
|
get_moe_a2a_backend,
|
|
41
41
|
should_use_flashinfer_cutlass_moe_fp4_allgather,
|
|
42
42
|
)
|
|
43
|
-
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
44
43
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
44
|
+
from sglang.srt.server_args import get_global_server_args
|
|
45
|
+
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
|
45
46
|
from sglang.srt.utils import (
|
|
46
47
|
get_bool_env_var,
|
|
47
48
|
is_cuda,
|
|
@@ -168,7 +169,7 @@ class LayerScatterModes:
|
|
|
168
169
|
|
|
169
170
|
|
|
170
171
|
def enable_moe_dense_fully_dp():
|
|
171
|
-
return
|
|
172
|
+
return get_global_server_args().moe_dense_tp_size == 1
|
|
172
173
|
|
|
173
174
|
|
|
174
175
|
class LayerCommunicator:
|
|
@@ -211,6 +212,10 @@ class LayerCommunicator:
|
|
|
211
212
|
)
|
|
212
213
|
)
|
|
213
214
|
|
|
215
|
+
self._speculative_algo = SpeculativeAlgorithm.from_string(
|
|
216
|
+
get_global_server_args().speculative_algorithm
|
|
217
|
+
)
|
|
218
|
+
|
|
214
219
|
def prepare_attn(
|
|
215
220
|
self,
|
|
216
221
|
hidden_states: torch.Tensor,
|
|
@@ -314,11 +319,10 @@ class LayerCommunicator:
|
|
|
314
319
|
def should_fuse_mlp_allreduce_with_next_layer(
|
|
315
320
|
self, forward_batch: ForwardBatch
|
|
316
321
|
) -> bool:
|
|
317
|
-
speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
|
|
318
322
|
if (
|
|
319
323
|
is_dp_attention_enabled()
|
|
320
|
-
and
|
|
321
|
-
and
|
|
324
|
+
and self._speculative_algo is not None
|
|
325
|
+
and self._speculative_algo.is_eagle()
|
|
322
326
|
):
|
|
323
327
|
return False
|
|
324
328
|
|
|
@@ -333,7 +337,7 @@ class LayerCommunicator:
|
|
|
333
337
|
static_conditions_met = (
|
|
334
338
|
(not self.is_last_layer)
|
|
335
339
|
and (self._context.tp_size > 1)
|
|
336
|
-
and
|
|
340
|
+
and get_global_server_args().enable_flashinfer_allreduce_fusion
|
|
337
341
|
and _is_flashinfer_available
|
|
338
342
|
)
|
|
339
343
|
|
|
@@ -531,7 +535,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
|
|
531
535
|
(_is_sm100_supported or _is_sm90_supported)
|
|
532
536
|
and _is_flashinfer_available
|
|
533
537
|
and hasattr(layernorm, "forward_with_allreduce_fusion")
|
|
534
|
-
and
|
|
538
|
+
and get_global_server_args().enable_flashinfer_allreduce_fusion
|
|
535
539
|
and hidden_states.shape[0] <= 4096
|
|
536
540
|
):
|
|
537
541
|
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
|
|
@@ -7,11 +7,10 @@ from typing import Dict, List, Tuple
|
|
|
7
7
|
import torch
|
|
8
8
|
from tqdm import tqdm
|
|
9
9
|
|
|
10
|
-
from sglang.srt.
|
|
11
|
-
|
|
12
|
-
)
|
|
10
|
+
from sglang.srt.environ import envs
|
|
11
|
+
from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
|
|
13
12
|
from sglang.srt.server_args import ServerArgs
|
|
14
|
-
from sglang.srt.utils import ceil_div, get_bool_env_var
|
|
13
|
+
from sglang.srt.utils import ceil_div, get_bool_env_var
|
|
15
14
|
|
|
16
15
|
logger = logging.getLogger(__name__)
|
|
17
16
|
|
|
@@ -20,12 +19,9 @@ if ENABLE_JIT_DEEPGEMM:
|
|
|
20
19
|
|
|
21
20
|
|
|
22
21
|
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
|
|
23
|
-
_ENABLE_JIT_DEEPGEMM_PRECOMPILE =
|
|
24
|
-
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
|
|
25
|
-
)
|
|
22
|
+
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get()
|
|
26
23
|
_DO_COMPILE_ALL = True
|
|
27
24
|
_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
|
|
28
|
-
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
|
|
29
25
|
_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
|
|
30
26
|
|
|
31
27
|
# Force redirect deep_gemm cache_dir
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
|
|
3
|
-
from sglang.srt.
|
|
3
|
+
from sglang.srt.environ import envs
|
|
4
|
+
from sglang.srt.utils import get_device_sm, is_blackwell
|
|
4
5
|
|
|
5
6
|
logger = logging.getLogger(__name__)
|
|
6
7
|
|
|
@@ -11,11 +12,11 @@ def _compute_enable_deep_gemm():
|
|
|
11
12
|
return False
|
|
12
13
|
|
|
13
14
|
try:
|
|
14
|
-
import deep_gemm
|
|
15
|
+
import deep_gemm # noqa: F401
|
|
15
16
|
except ImportError:
|
|
16
17
|
return False
|
|
17
18
|
|
|
18
|
-
return
|
|
19
|
+
return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get()
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
|
|
@@ -4,8 +4,8 @@ from typing import Tuple
|
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
|
|
7
|
-
from sglang.srt.layers.
|
|
8
|
-
from sglang.srt.layers.
|
|
7
|
+
from sglang.srt.layers.deep_gemm_wrapper import compile_utils
|
|
8
|
+
from sglang.srt.layers.deep_gemm_wrapper.configurer import ( # noqa: F401
|
|
9
9
|
DEEPGEMM_BLACKWELL,
|
|
10
10
|
DEEPGEMM_SCALE_UE8M0,
|
|
11
11
|
ENABLE_JIT_DEEPGEMM,
|
|
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
|
|
|
17
17
|
|
|
18
18
|
if ENABLE_JIT_DEEPGEMM:
|
|
19
19
|
import deep_gemm
|
|
20
|
-
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
|
20
|
+
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor # noqa: F401
|
|
21
21
|
|
|
22
22
|
_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
|
|
23
23
|
|
|
@@ -87,6 +87,7 @@ class _DpGatheredBufferWrapper:
|
|
|
87
87
|
_global_dp_buffer_len: int
|
|
88
88
|
_local_dp_buffer_len: int
|
|
89
89
|
_global_num_tokens: Optional[List[int]]
|
|
90
|
+
_is_extend_in_batch: bool
|
|
90
91
|
|
|
91
92
|
@classmethod
|
|
92
93
|
def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
|
|
@@ -145,6 +146,14 @@ class _DpGatheredBufferWrapper:
|
|
|
145
146
|
def get_dp_device(cls) -> torch.device:
|
|
146
147
|
return cls._device
|
|
147
148
|
|
|
149
|
+
@classmethod
|
|
150
|
+
def set_is_extend_in_batch(cls, is_extend_in_batch: bool):
|
|
151
|
+
cls._is_extend_in_batch = is_extend_in_batch
|
|
152
|
+
|
|
153
|
+
@classmethod
|
|
154
|
+
def get_is_extend_in_batch(cls) -> bool:
|
|
155
|
+
return cls._is_extend_in_batch
|
|
156
|
+
|
|
148
157
|
|
|
149
158
|
def set_dp_buffer_len(
|
|
150
159
|
global_dp_buffer_len: int,
|
|
@@ -188,6 +197,14 @@ def get_dp_device() -> torch.device:
|
|
|
188
197
|
return _DpGatheredBufferWrapper.get_dp_device()
|
|
189
198
|
|
|
190
199
|
|
|
200
|
+
def set_is_extend_in_batch(is_extend_in_batch: bool):
|
|
201
|
+
_DpGatheredBufferWrapper.set_is_extend_in_batch(is_extend_in_batch)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def get_is_extend_in_batch() -> bool:
|
|
205
|
+
return _DpGatheredBufferWrapper.get_is_extend_in_batch()
|
|
206
|
+
|
|
207
|
+
|
|
191
208
|
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
|
|
192
209
|
if not enable_dp_attention:
|
|
193
210
|
return tp_rank, tp_size, 0
|
sglang/srt/layers/layernorm.py
CHANGED
|
@@ -42,13 +42,16 @@ _is_cpu_amx_available = cpu_has_amx_support()
|
|
|
42
42
|
_is_cpu = is_cpu()
|
|
43
43
|
_is_xpu = is_xpu()
|
|
44
44
|
|
|
45
|
-
if _is_cuda:
|
|
46
|
-
if _is_flashinfer_available:
|
|
47
|
-
|
|
48
|
-
else:
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
45
|
+
if _is_cuda or _is_xpu:
|
|
46
|
+
# if _is_flashinfer_available:
|
|
47
|
+
# from flashinfer.norm import fused_add_rmsnorm
|
|
48
|
+
# else:
|
|
49
|
+
from sgl_kernel import (
|
|
50
|
+
fused_add_rmsnorm,
|
|
51
|
+
gemma_fused_add_rmsnorm,
|
|
52
|
+
gemma_rmsnorm,
|
|
53
|
+
rmsnorm,
|
|
54
|
+
)
|
|
52
55
|
if _use_aiter:
|
|
53
56
|
from aiter import rmsnorm2d_fwd as rms_norm
|
|
54
57
|
from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
|
|
@@ -211,6 +214,19 @@ class RMSNorm(CustomOp):
|
|
|
211
214
|
else:
|
|
212
215
|
return self.forward_native(x, residual)
|
|
213
216
|
|
|
217
|
+
def forward_xpu(
|
|
218
|
+
self,
|
|
219
|
+
x: torch.Tensor,
|
|
220
|
+
residual: Optional[torch.Tensor] = None,
|
|
221
|
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
|
222
|
+
if self.variance_size_override is not None:
|
|
223
|
+
return self.forward_native(x, residual)
|
|
224
|
+
if residual is not None:
|
|
225
|
+
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
|
|
226
|
+
return x, residual
|
|
227
|
+
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
|
|
228
|
+
return out
|
|
229
|
+
|
|
214
230
|
def forward_with_allreduce_fusion(
|
|
215
231
|
self,
|
|
216
232
|
x: torch.Tensor,
|
|
@@ -258,6 +274,19 @@ class GemmaRMSNorm(CustomOp):
|
|
|
258
274
|
if _is_hip:
|
|
259
275
|
self._forward_method = self.forward_native
|
|
260
276
|
|
|
277
|
+
def _forward_impl(
|
|
278
|
+
self,
|
|
279
|
+
x: torch.Tensor,
|
|
280
|
+
residual: Optional[torch.Tensor] = None,
|
|
281
|
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
|
282
|
+
if residual is not None:
|
|
283
|
+
gemma_fused_add_rmsnorm(
|
|
284
|
+
x, residual, self.weight.data, self.variance_epsilon
|
|
285
|
+
)
|
|
286
|
+
return x, residual
|
|
287
|
+
out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
|
|
288
|
+
return out
|
|
289
|
+
|
|
261
290
|
def forward_native(
|
|
262
291
|
self,
|
|
263
292
|
x: torch.Tensor,
|
|
@@ -280,13 +309,7 @@ class GemmaRMSNorm(CustomOp):
|
|
|
280
309
|
x: torch.Tensor,
|
|
281
310
|
residual: Optional[torch.Tensor] = None,
|
|
282
311
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
|
283
|
-
|
|
284
|
-
gemma_fused_add_rmsnorm(
|
|
285
|
-
x, residual, self.weight.data, self.variance_epsilon
|
|
286
|
-
)
|
|
287
|
-
return x, residual
|
|
288
|
-
out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
|
|
289
|
-
return out
|
|
312
|
+
return self._forward_impl(x, residual)
|
|
290
313
|
|
|
291
314
|
def forward_npu(
|
|
292
315
|
self,
|
|
@@ -300,6 +323,13 @@ class GemmaRMSNorm(CustomOp):
|
|
|
300
323
|
x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
|
|
301
324
|
return x if residual is None else (x, residual)
|
|
302
325
|
|
|
326
|
+
def forward_xpu(
|
|
327
|
+
self,
|
|
328
|
+
x: torch.Tensor,
|
|
329
|
+
residual: Optional[torch.Tensor] = None,
|
|
330
|
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
|
331
|
+
return self._forward_impl(x, residual)
|
|
332
|
+
|
|
303
333
|
|
|
304
334
|
class Gemma3RMSNorm(CustomOp):
|
|
305
335
|
def __init__(self, dim: int, eps: float = 1e-6):
|
|
@@ -335,4 +365,4 @@ if not (
|
|
|
335
365
|
logger.info(
|
|
336
366
|
"sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
|
|
337
367
|
)
|
|
338
|
-
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
|
368
|
+
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm # noqa: F401
|
sglang/srt/layers/linear.py
CHANGED
|
@@ -32,7 +32,7 @@ from sglang.srt.layers.parameter import (
|
|
|
32
32
|
)
|
|
33
33
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
|
34
34
|
from sglang.srt.layers.utils import pad_or_narrow_weight
|
|
35
|
-
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
|
|
35
|
+
from sglang.srt.utils import get_bool_env_var, is_cpu, is_hip, is_npu, set_weight_attrs
|
|
36
36
|
|
|
37
37
|
if TYPE_CHECKING:
|
|
38
38
|
from sglang.srt.layers.quantization.base_config import (
|
|
@@ -40,12 +40,18 @@ if TYPE_CHECKING:
|
|
|
40
40
|
QuantizeMethodBase,
|
|
41
41
|
)
|
|
42
42
|
|
|
43
|
+
_is_hip = is_hip()
|
|
44
|
+
_disable_hip_linear_quant = _is_hip and get_bool_env_var(
|
|
45
|
+
"SGLANG_ROCM_DISABLE_LINEARQUANT"
|
|
46
|
+
)
|
|
47
|
+
|
|
43
48
|
logger = logging.getLogger(__name__)
|
|
44
49
|
|
|
45
50
|
WEIGHT_LOADER_V2_SUPPORTED = [
|
|
46
51
|
"CompressedTensorsLinearMethod",
|
|
47
52
|
"AWQMarlinLinearMethod",
|
|
48
53
|
"AWQLinearMethod",
|
|
54
|
+
"AWQLinearAscendMethod",
|
|
49
55
|
"GPTQMarlinLinearMethod",
|
|
50
56
|
"Fp8LinearMethod",
|
|
51
57
|
"BlockInt8LinearMethod",
|
|
@@ -824,6 +830,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|
|
824
830
|
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
|
825
831
|
]
|
|
826
832
|
self.use_presharded_weights = load_presharded_attn
|
|
833
|
+
quant_config = None if _disable_hip_linear_quant else quant_config
|
|
827
834
|
|
|
828
835
|
super().__init__(
|
|
829
836
|
input_size=input_size,
|
|
@@ -1225,6 +1232,7 @@ class RowParallelLinear(LinearBase):
|
|
|
1225
1232
|
tp_size: Optional[int] = None,
|
|
1226
1233
|
use_presharded_weights: bool = False,
|
|
1227
1234
|
):
|
|
1235
|
+
quant_config = None if _disable_hip_linear_quant else quant_config
|
|
1228
1236
|
super().__init__(
|
|
1229
1237
|
input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
|
|
1230
1238
|
)
|
|
@@ -38,17 +38,15 @@ from sglang.srt.layers.dp_attention import (
|
|
|
38
38
|
get_dp_device,
|
|
39
39
|
get_dp_dtype,
|
|
40
40
|
get_dp_hidden_size,
|
|
41
|
-
get_global_dp_buffer,
|
|
42
41
|
get_local_attention_dp_size,
|
|
43
|
-
set_dp_buffer_len,
|
|
44
42
|
)
|
|
45
43
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
|
46
|
-
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
47
44
|
from sglang.srt.model_executor.forward_batch_info import (
|
|
48
45
|
CaptureHiddenMode,
|
|
49
46
|
ForwardBatch,
|
|
50
47
|
ForwardMode,
|
|
51
48
|
)
|
|
49
|
+
from sglang.srt.server_args import get_global_server_args
|
|
52
50
|
from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
|
|
53
51
|
|
|
54
52
|
logger = logging.getLogger(__name__)
|
|
@@ -60,13 +58,14 @@ _is_npu = is_npu()
|
|
|
60
58
|
class LogitsProcessorOutput:
|
|
61
59
|
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
|
|
62
60
|
# The logits of the next tokens. shape: [#seq, vocab_size]
|
|
63
|
-
|
|
61
|
+
# Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
|
|
62
|
+
next_token_logits: Optional[torch.Tensor]
|
|
64
63
|
# Used by speculative decoding (EAGLE)
|
|
65
64
|
# The last hidden layers
|
|
66
65
|
hidden_states: Optional[torch.Tensor] = None
|
|
67
66
|
|
|
68
67
|
## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
|
|
69
|
-
# he log probs of output tokens, if
|
|
68
|
+
# The log probs of output tokens. If SGLANG_RETURN_ORIGINAL_LOGPROB = True, these are the log probs before applying temperature; if False, the log probs after applying temperature.
|
|
70
69
|
next_token_logprobs: Optional[torch.Tensor] = None
|
|
71
70
|
# The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
|
|
72
71
|
next_token_top_logprobs_val: Optional[List] = None
|
|
@@ -85,7 +84,10 @@ class LogitsProcessorOutput:
|
|
|
85
84
|
input_top_logprobs_val: List = None
|
|
86
85
|
input_top_logprobs_idx: List = None
|
|
87
86
|
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
|
|
88
|
-
|
|
87
|
+
# Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
|
|
88
|
+
input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
|
|
89
|
+
None
|
|
90
|
+
)
|
|
89
91
|
input_token_ids_logprobs_idx: Optional[List] = None
|
|
90
92
|
|
|
91
93
|
|
|
@@ -127,10 +129,16 @@ class LogitsMetadata:
|
|
|
127
129
|
# for padding
|
|
128
130
|
padded_static_len: int = -1
|
|
129
131
|
|
|
132
|
+
# Whether this batch is prefill-only (no token generation needed)
|
|
133
|
+
is_prefill_only: bool = False
|
|
134
|
+
|
|
130
135
|
@classmethod
|
|
131
136
|
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
|
132
137
|
if (
|
|
133
|
-
|
|
138
|
+
(
|
|
139
|
+
forward_batch.forward_mode.is_extend()
|
|
140
|
+
or forward_batch.forward_mode.is_split_prefill()
|
|
141
|
+
)
|
|
134
142
|
and forward_batch.return_logprob
|
|
135
143
|
and not forward_batch.forward_mode.is_target_verify()
|
|
136
144
|
):
|
|
@@ -169,6 +177,7 @@ class LogitsMetadata:
|
|
|
169
177
|
token_ids_logprobs=forward_batch.token_ids_logprobs,
|
|
170
178
|
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
|
|
171
179
|
padded_static_len=forward_batch.padded_static_len,
|
|
180
|
+
is_prefill_only=forward_batch.is_prefill_only,
|
|
172
181
|
global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
|
|
173
182
|
dp_local_start_pos=forward_batch.dp_local_start_pos,
|
|
174
183
|
dp_local_num_tokens=forward_batch.dp_local_num_tokens,
|
|
@@ -219,8 +228,8 @@ class LogitsProcessor(nn.Module):
|
|
|
219
228
|
super().__init__()
|
|
220
229
|
self.config = config
|
|
221
230
|
self.logit_scale = logit_scale
|
|
222
|
-
self.use_attn_tp_group =
|
|
223
|
-
self.use_fp32_lm_head =
|
|
231
|
+
self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head
|
|
232
|
+
self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head
|
|
224
233
|
if self.use_attn_tp_group:
|
|
225
234
|
self.attn_tp_size = get_attention_tp_size()
|
|
226
235
|
self.do_tensor_parallel_all_gather = (
|
|
@@ -243,8 +252,110 @@ class LogitsProcessor(nn.Module):
|
|
|
243
252
|
):
|
|
244
253
|
self.final_logit_softcapping = None
|
|
245
254
|
|
|
246
|
-
self.debug_tensor_dump_output_folder =
|
|
247
|
-
|
|
255
|
+
self.debug_tensor_dump_output_folder = (
|
|
256
|
+
get_global_server_args().debug_tensor_dump_output_folder
|
|
257
|
+
)
|
|
258
|
+
|
|
259
|
+
def compute_logprobs_for_multi_item_scoring(
|
|
260
|
+
self,
|
|
261
|
+
input_ids,
|
|
262
|
+
hidden_states,
|
|
263
|
+
lm_head: VocabParallelEmbedding,
|
|
264
|
+
logits_metadata: Union[LogitsMetadata, ForwardBatch],
|
|
265
|
+
delimiter_token: int,
|
|
266
|
+
):
|
|
267
|
+
"""
|
|
268
|
+
Compute logprobs for multi-item scoring using delimiter-based token extraction.
|
|
269
|
+
|
|
270
|
+
This method is designed for scenarios where you want to score multiple items/candidates
|
|
271
|
+
against a single query by combining them into one sequence separated by delimiters.
|
|
272
|
+
|
|
273
|
+
Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
|
|
274
|
+
Scoring positions: Extracts logprobs at positions before each <delimiter>
|
|
275
|
+
|
|
276
|
+
Args:
|
|
277
|
+
input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
|
|
278
|
+
Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
|
|
279
|
+
hidden_states (torch.Tensor): Hidden states from the model.
|
|
280
|
+
Shape: [sequence_length, hidden_dim].
|
|
281
|
+
lm_head (VocabParallelEmbedding): Language model head for computing logits.
|
|
282
|
+
logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
|
|
283
|
+
and token ID specifications for logprob extraction.
|
|
284
|
+
delimiter_token (int): Token ID used as delimiter between query and items.
|
|
285
|
+
|
|
286
|
+
Returns:
|
|
287
|
+
LogitsProcessorOutput: Contains:
|
|
288
|
+
- next_token_logits: None (not needed for scoring-only requests)
|
|
289
|
+
- input_token_logprobs: Logprobs of delimiter tokens at scoring positions
|
|
290
|
+
- input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
|
|
291
|
+
- input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
|
|
292
|
+
- input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
|
|
293
|
+
- input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
|
|
294
|
+
"""
|
|
295
|
+
multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[
|
|
296
|
+
0
|
|
297
|
+
] - 1
|
|
298
|
+
# Extract hidden states at the position immediately before each delimiter for multi-item scoring
|
|
299
|
+
sliced_hidden = hidden_states[multi_item_indices]
|
|
300
|
+
|
|
301
|
+
sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
|
|
302
|
+
sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)
|
|
303
|
+
|
|
304
|
+
# Initialize return values
|
|
305
|
+
input_token_ids_logprobs_val = []
|
|
306
|
+
input_token_ids_logprobs_idx = []
|
|
307
|
+
input_top_logprobs_val = None
|
|
308
|
+
input_top_logprobs_idx = None
|
|
309
|
+
|
|
310
|
+
# Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
|
|
311
|
+
# Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
|
|
312
|
+
if (
|
|
313
|
+
logits_metadata.token_ids_logprobs
|
|
314
|
+
or logits_metadata.extend_return_top_logprob
|
|
315
|
+
):
|
|
316
|
+
logits_metadata.extend_logprob_pruned_lens_cpu = []
|
|
317
|
+
|
|
318
|
+
if logits_metadata.extend_seq_lens_cpu is not None:
|
|
319
|
+
# Multi-request batch: count delimiters per request
|
|
320
|
+
input_pt = 0
|
|
321
|
+
for req_seq_len in logits_metadata.extend_seq_lens_cpu:
|
|
322
|
+
req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
|
|
323
|
+
delimiter_count = (req_input_ids == delimiter_token).sum().item()
|
|
324
|
+
logits_metadata.extend_logprob_pruned_lens_cpu.append(
|
|
325
|
+
delimiter_count
|
|
326
|
+
)
|
|
327
|
+
input_pt += req_seq_len
|
|
328
|
+
else:
|
|
329
|
+
# Single request case: one request gets all delimiters
|
|
330
|
+
total_delimiters = (input_ids == delimiter_token).sum().item()
|
|
331
|
+
logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]
|
|
332
|
+
|
|
333
|
+
# Get the logprobs of specified token ids
|
|
334
|
+
if logits_metadata.extend_token_ids_logprob:
|
|
335
|
+
(
|
|
336
|
+
input_token_ids_logprobs_val,
|
|
337
|
+
input_token_ids_logprobs_idx,
|
|
338
|
+
) = self.get_token_ids_logprobs(
|
|
339
|
+
sliced_logprobs, logits_metadata, delay_cpu_copy=True
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
# Get the logprob of top-k tokens
|
|
343
|
+
if logits_metadata.extend_return_top_logprob:
|
|
344
|
+
(
|
|
345
|
+
input_top_logprobs_val,
|
|
346
|
+
input_top_logprobs_idx,
|
|
347
|
+
) = self.get_top_logprobs(sliced_logprobs, logits_metadata)
|
|
348
|
+
|
|
349
|
+
# For input_token_logprobs, use delimiter token logprobs
|
|
350
|
+
input_token_logprobs = sliced_logprobs[:, delimiter_token]
|
|
351
|
+
|
|
352
|
+
return LogitsProcessorOutput(
|
|
353
|
+
next_token_logits=None, # Multi-item scoring doesn't need next token logits
|
|
354
|
+
input_token_logprobs=input_token_logprobs,
|
|
355
|
+
input_top_logprobs_val=input_top_logprobs_val,
|
|
356
|
+
input_top_logprobs_idx=input_top_logprobs_idx,
|
|
357
|
+
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
|
|
358
|
+
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
|
|
248
359
|
)
|
|
249
360
|
|
|
250
361
|
def forward(
|
|
@@ -257,10 +368,19 @@ class LogitsProcessor(nn.Module):
|
|
|
257
368
|
) -> LogitsProcessorOutput:
|
|
258
369
|
if isinstance(logits_metadata, ForwardBatch):
|
|
259
370
|
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
|
|
371
|
+
|
|
372
|
+
# Check if multi-item scoring is enabled via server args (only for prefill-only requests)
|
|
373
|
+
multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
|
|
374
|
+
if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
|
|
375
|
+
return self.compute_logprobs_for_multi_item_scoring(
|
|
376
|
+
input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
|
|
377
|
+
)
|
|
378
|
+
|
|
260
379
|
# Get the last hidden states and last logits for the next token prediction
|
|
261
380
|
if (
|
|
262
381
|
logits_metadata.forward_mode.is_decode_or_idle()
|
|
263
382
|
or logits_metadata.forward_mode.is_target_verify()
|
|
383
|
+
or logits_metadata.forward_mode.is_draft_extend_v2()
|
|
264
384
|
):
|
|
265
385
|
pruned_states = hidden_states
|
|
266
386
|
if aux_hidden_states is not None:
|
|
@@ -269,8 +389,8 @@ class LogitsProcessor(nn.Module):
|
|
|
269
389
|
input_logprob_indices = None
|
|
270
390
|
elif (
|
|
271
391
|
logits_metadata.forward_mode.is_extend()
|
|
272
|
-
|
|
273
|
-
):
|
|
392
|
+
or logits_metadata.forward_mode.is_split_prefill()
|
|
393
|
+
) and not logits_metadata.extend_return_logprob:
|
|
274
394
|
# Prefill without input logprobs.
|
|
275
395
|
if logits_metadata.padded_static_len < 0:
|
|
276
396
|
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
|
@@ -584,7 +704,9 @@ class LogitsProcessor(nn.Module):
|
|
|
584
704
|
|
|
585
705
|
@staticmethod
|
|
586
706
|
def get_token_ids_logprobs(
|
|
587
|
-
all_logprobs: torch.Tensor,
|
|
707
|
+
all_logprobs: torch.Tensor,
|
|
708
|
+
logits_metadata: LogitsMetadata,
|
|
709
|
+
delay_cpu_copy: bool = False,
|
|
588
710
|
):
|
|
589
711
|
input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
|
|
590
712
|
pt = 0
|
|
@@ -597,9 +719,17 @@ class LogitsProcessor(nn.Module):
|
|
|
597
719
|
input_token_ids_logprobs_idx.append([])
|
|
598
720
|
continue
|
|
599
721
|
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
722
|
+
position_logprobs = all_logprobs[
|
|
723
|
+
pt : pt + pruned_len, token_ids
|
|
724
|
+
] # Shape: [pruned_len, num_tokens]
|
|
725
|
+
|
|
726
|
+
if delay_cpu_copy:
|
|
727
|
+
# Keep as tensor to delay GPU-to-CPU transfer
|
|
728
|
+
input_token_ids_logprobs_val.append(position_logprobs)
|
|
729
|
+
else:
|
|
730
|
+
# Convert to list immediately (default behavior)
|
|
731
|
+
input_token_ids_logprobs_val.append(position_logprobs.tolist())
|
|
732
|
+
|
|
603
733
|
input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
|
|
604
734
|
pt += pruned_len
|
|
605
735
|
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ModelOpt related constants
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
QUANT_CFG_CHOICES = {
|
|
6
|
+
"fp8": "FP8_DEFAULT_CFG",
|
|
7
|
+
"int4_awq": "INT4_AWQ_CFG", # TODO: add support for int4_awq
|
|
8
|
+
"w4a8_awq": "W4A8_AWQ_BETA_CFG", # TODO: add support for w4a8_awq
|
|
9
|
+
"nvfp4": "NVFP4_DEFAULT_CFG",
|
|
10
|
+
"nvfp4_awq": "NVFP4_AWQ_LITE_CFG", # TODO: add support for nvfp4_awq
|
|
11
|
+
}
|