sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
|
|
17
17
|
get_tp_group,
|
18
18
|
tensor_model_parallel_all_reduce,
|
19
19
|
)
|
20
|
+
from sglang.srt.utils import get_bool_env_var, is_hip
|
20
21
|
|
21
22
|
if TYPE_CHECKING:
|
22
23
|
from sglang.srt.configs.model_config import ModelConfig
|
@@ -36,6 +37,9 @@ _LOCAL_ATTN_DP_SIZE: Optional[int] = None
|
|
36
37
|
_LOCAL_ATTN_DP_RANK: Optional[int] = None
|
37
38
|
_ENABLE_DP_ATTENTION_FLAG: bool = False
|
38
39
|
|
40
|
+
_is_hip = is_hip()
|
41
|
+
_USE_ROCM700A_WA = _is_hip and get_bool_env_var("SGLANG_USE_ROCM700A")
|
42
|
+
|
39
43
|
|
40
44
|
class DpPaddingMode(IntEnum):
|
41
45
|
|
@@ -51,7 +55,12 @@ class DpPaddingMode(IntEnum):
|
|
51
55
|
return self == DpPaddingMode.SUM_LEN
|
52
56
|
|
53
57
|
@classmethod
|
54
|
-
def get_dp_padding_mode(
|
58
|
+
def get_dp_padding_mode(
|
59
|
+
cls, is_extend_in_batch, global_num_tokens: List[int]
|
60
|
+
) -> DpPaddingMode:
|
61
|
+
if is_extend_in_batch:
|
62
|
+
return DpPaddingMode.SUM_LEN
|
63
|
+
|
55
64
|
# we choose the mode that minimizes the communication cost
|
56
65
|
max_len = max(global_num_tokens)
|
57
66
|
sum_len = sum(global_num_tokens)
|
@@ -62,7 +71,12 @@ class DpPaddingMode(IntEnum):
|
|
62
71
|
|
63
72
|
@classmethod
|
64
73
|
def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
|
65
|
-
|
74
|
+
# TODO(kkhuang-amd): noqa, temporary work-around for rocm 7.0.0 alpha
|
75
|
+
# it can be safely removed later, once RCCL fixed
|
76
|
+
if _USE_ROCM700A_WA:
|
77
|
+
return cls.SUM_LEN
|
78
|
+
else:
|
79
|
+
return cls.MAX_LEN
|
66
80
|
|
67
81
|
|
68
82
|
class _DpGatheredBufferWrapper:
|
@@ -119,6 +133,18 @@ class _DpGatheredBufferWrapper:
|
|
119
133
|
def get_dp_global_num_tokens(cls) -> List[int]:
|
120
134
|
return cls._global_num_tokens
|
121
135
|
|
136
|
+
@classmethod
|
137
|
+
def get_dp_hidden_size(cls) -> int:
|
138
|
+
return cls._hidden_size
|
139
|
+
|
140
|
+
@classmethod
|
141
|
+
def get_dp_dtype(cls) -> torch.dtype:
|
142
|
+
return cls._dtype
|
143
|
+
|
144
|
+
@classmethod
|
145
|
+
def get_dp_device(cls) -> torch.device:
|
146
|
+
return cls._device
|
147
|
+
|
122
148
|
|
123
149
|
def set_dp_buffer_len(
|
124
150
|
global_dp_buffer_len: int,
|
@@ -150,6 +176,18 @@ def get_dp_global_num_tokens() -> List[int]:
|
|
150
176
|
return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
|
151
177
|
|
152
178
|
|
179
|
+
def get_dp_hidden_size() -> int:
|
180
|
+
return _DpGatheredBufferWrapper.get_dp_hidden_size()
|
181
|
+
|
182
|
+
|
183
|
+
def get_dp_dtype() -> torch.dtype:
|
184
|
+
return _DpGatheredBufferWrapper.get_dp_dtype()
|
185
|
+
|
186
|
+
|
187
|
+
def get_dp_device() -> torch.device:
|
188
|
+
return _DpGatheredBufferWrapper.get_dp_device()
|
189
|
+
|
190
|
+
|
153
191
|
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
|
154
192
|
if not enable_dp_attention:
|
155
193
|
return tp_rank, tp_size, 0
|
@@ -225,6 +263,7 @@ def initialize_dp_attention(
|
|
225
263
|
use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
|
226
264
|
use_pymscclpp=False,
|
227
265
|
use_custom_allreduce=False,
|
266
|
+
use_torch_symm_mem=False,
|
228
267
|
use_hpu_communicator=False,
|
229
268
|
use_xpu_communicator=False,
|
230
269
|
use_npu_communicator=False,
|
sglang/srt/layers/elementwise.py
CHANGED
@@ -187,7 +187,9 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
|
|
187
187
|
|
188
188
|
def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
|
189
189
|
assert len(x.shape) == 2
|
190
|
-
assert
|
190
|
+
assert (
|
191
|
+
x.shape == residual.shape and x.dtype == residual.dtype
|
192
|
+
), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
|
191
193
|
output, mid = torch.empty_like(x), torch.empty_like(x)
|
192
194
|
bs, hidden_dim = x.shape
|
193
195
|
if autotune:
|
sglang/srt/layers/layernorm.py
CHANGED
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
|
|
18
18
|
|
19
19
|
import torch
|
20
20
|
import torch.nn as nn
|
21
|
+
from packaging.version import Version
|
21
22
|
|
22
23
|
from sglang.srt.custom_op import CustomOp
|
23
24
|
from sglang.srt.utils import (
|
@@ -25,32 +26,38 @@ from sglang.srt.utils import (
|
|
25
26
|
get_bool_env_var,
|
26
27
|
is_cpu,
|
27
28
|
is_cuda,
|
29
|
+
is_flashinfer_available,
|
28
30
|
is_hip,
|
29
31
|
is_npu,
|
32
|
+
is_xpu,
|
30
33
|
supports_custom_op,
|
31
34
|
)
|
32
35
|
|
33
36
|
_is_cuda = is_cuda()
|
37
|
+
_is_flashinfer_available = is_flashinfer_available()
|
34
38
|
_is_hip = is_hip()
|
35
39
|
_is_npu = is_npu()
|
36
40
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
37
41
|
_is_cpu_amx_available = cpu_has_amx_support()
|
38
42
|
_is_cpu = is_cpu()
|
43
|
+
_is_xpu = is_xpu()
|
39
44
|
|
40
45
|
if _is_cuda:
|
41
|
-
|
42
|
-
fused_add_rmsnorm
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
)
|
46
|
+
if _is_flashinfer_available:
|
47
|
+
from flashinfer.norm import fused_add_rmsnorm
|
48
|
+
else:
|
49
|
+
from sgl_kernel import fused_add_rmsnorm
|
50
|
+
from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
|
47
51
|
|
48
52
|
if _use_aiter:
|
49
53
|
from aiter import rmsnorm2d_fwd as rms_norm
|
50
54
|
from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
|
51
55
|
elif _is_hip:
|
56
|
+
import vllm
|
52
57
|
from vllm._custom_ops import fused_add_rms_norm, rms_norm
|
53
58
|
|
59
|
+
_vllm_version = Version(vllm.__version__)
|
60
|
+
|
54
61
|
logger = logging.getLogger(__name__)
|
55
62
|
|
56
63
|
if _is_npu:
|
@@ -73,6 +80,8 @@ class RMSNorm(CustomOp):
|
|
73
80
|
)
|
74
81
|
if _use_aiter:
|
75
82
|
self._forward_method = self.forward_aiter
|
83
|
+
if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
|
84
|
+
self._forward_method = self.forward_native
|
76
85
|
|
77
86
|
def forward_cuda(
|
78
87
|
self,
|
@@ -127,8 +136,21 @@ class RMSNorm(CustomOp):
|
|
127
136
|
# NOTE: Remove this if aiter kernel supports discontinuous input
|
128
137
|
x = x.contiguous()
|
129
138
|
if residual is not None:
|
130
|
-
|
131
|
-
|
139
|
+
if _vllm_version < Version("0.9"):
|
140
|
+
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
141
|
+
return x, residual
|
142
|
+
else:
|
143
|
+
residual_out = torch.empty_like(x)
|
144
|
+
output = torch.empty_like(x)
|
145
|
+
fused_add_rms_norm(
|
146
|
+
output,
|
147
|
+
x,
|
148
|
+
residual_out,
|
149
|
+
residual,
|
150
|
+
self.weight.data,
|
151
|
+
self.variance_epsilon,
|
152
|
+
)
|
153
|
+
return output, residual_out
|
132
154
|
out = torch.empty_like(x)
|
133
155
|
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
134
156
|
return out
|
@@ -271,16 +293,11 @@ class GemmaRMSNorm(CustomOp):
|
|
271
293
|
x: torch.Tensor,
|
272
294
|
residual: Optional[torch.Tensor] = None,
|
273
295
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
274
|
-
orig_dtype = x.dtype
|
275
296
|
if residual is not None:
|
276
297
|
x = x + residual
|
277
298
|
residual = x
|
278
299
|
|
279
|
-
x = x.
|
280
|
-
variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
|
281
|
-
x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
|
282
|
-
x = x * (1.0 + self.weight.float())
|
283
|
-
x = x.to(orig_dtype)
|
300
|
+
x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
|
284
301
|
return x if residual is None else (x, residual)
|
285
302
|
|
286
303
|
|
@@ -312,7 +329,9 @@ class Gemma3RMSNorm(CustomOp):
|
|
312
329
|
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
313
330
|
|
314
331
|
|
315
|
-
if not (
|
332
|
+
if not (
|
333
|
+
_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu
|
334
|
+
):
|
316
335
|
logger.info(
|
317
336
|
"sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
|
318
337
|
)
|
sglang/srt/layers/linear.py
CHANGED
@@ -31,6 +31,7 @@ from sglang.srt.layers.parameter import (
|
|
31
31
|
_ColumnvLLMParameter,
|
32
32
|
)
|
33
33
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
34
|
+
from sglang.srt.layers.utils import pad_or_narrow_weight
|
34
35
|
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
|
35
36
|
|
36
37
|
if TYPE_CHECKING:
|
@@ -235,9 +236,8 @@ class ReplicatedLinear(LinearBase):
|
|
235
236
|
loaded_weight = loaded_weight[:1]
|
236
237
|
else:
|
237
238
|
raise ValueError(f"{loaded_weight} are not all equal")
|
238
|
-
|
239
|
-
|
240
|
-
), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
|
239
|
+
|
240
|
+
assert param.size() == loaded_weight.size()
|
241
241
|
param.data.copy_(loaded_weight)
|
242
242
|
|
243
243
|
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
@@ -626,9 +626,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|
626
626
|
# bitsandbytes loads the weights of the specific portion
|
627
627
|
# no need to narrow here
|
628
628
|
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
|
629
|
-
|
630
|
-
|
631
|
-
|
629
|
+
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
630
|
+
end_idx = start_idx + shard_size
|
631
|
+
if end_idx > loaded_weight.shape[output_dim]:
|
632
|
+
loaded_weight = pad_or_narrow_weight(
|
633
|
+
loaded_weight, output_dim, start_idx, shard_size
|
634
|
+
)
|
635
|
+
else:
|
636
|
+
loaded_weight = loaded_weight.narrow(
|
637
|
+
output_dim, start_idx, shard_size
|
638
|
+
)
|
632
639
|
|
633
640
|
# Special case for AQLM codebooks.
|
634
641
|
elif is_metadata:
|
@@ -894,6 +901,35 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|
894
901
|
)
|
895
902
|
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
|
896
903
|
|
904
|
+
def _load_qkv_block_scale(
|
905
|
+
self, param: BasevLLMParameter, loaded_weight: torch.Tensor
|
906
|
+
):
|
907
|
+
block_n, _ = self.quant_method.quant_config.weight_block_size
|
908
|
+
q_size = self.total_num_heads * self.head_size // block_n
|
909
|
+
k_size = self.total_num_kv_heads * self.head_size // block_n
|
910
|
+
v_size = self.total_num_kv_heads * self.head_size // block_n
|
911
|
+
shard_offsets = [
|
912
|
+
# (shard_id, shard_offset, shard_size)
|
913
|
+
("q", 0, q_size),
|
914
|
+
("k", q_size, k_size),
|
915
|
+
("v", q_size + k_size, v_size),
|
916
|
+
]
|
917
|
+
for shard_id, shard_offset, shard_size in shard_offsets:
|
918
|
+
loaded_weight_shard = loaded_weight.narrow(
|
919
|
+
param.output_dim, shard_offset, shard_size
|
920
|
+
)
|
921
|
+
rank_shard_offset = self._get_shard_offset_mapping(shard_id) // block_n
|
922
|
+
rank_shard_size = self._get_shard_size_mapping(shard_id) // block_n
|
923
|
+
param.load_qkv_weight(
|
924
|
+
loaded_weight=loaded_weight_shard,
|
925
|
+
num_heads=self.num_kv_head_replicas,
|
926
|
+
shard_id=shard_id,
|
927
|
+
shard_offset=rank_shard_offset,
|
928
|
+
shard_size=rank_shard_size,
|
929
|
+
tp_rank=self.tp_rank,
|
930
|
+
use_presharded_weights=self.use_presharded_weights,
|
931
|
+
)
|
932
|
+
|
897
933
|
def weight_loader_v2(
|
898
934
|
self,
|
899
935
|
param: BasevLLMParameter,
|
@@ -907,6 +943,9 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|
907
943
|
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
|
908
944
|
param.load_qkv_weight(loaded_weight=loaded_weight)
|
909
945
|
return
|
946
|
+
elif isinstance(param, BlockQuantScaleParameter):
|
947
|
+
self._load_qkv_block_scale(param, loaded_weight)
|
948
|
+
return
|
910
949
|
# TODO: @dsikka - move to parameter.py
|
911
950
|
self._load_fused_module_from_checkpoint(param, loaded_weight)
|
912
951
|
return
|
@@ -1271,7 +1310,16 @@ class RowParallelLinear(LinearBase):
|
|
1271
1310
|
shard_size,
|
1272
1311
|
)
|
1273
1312
|
else:
|
1274
|
-
|
1313
|
+
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
1314
|
+
end_idx = start_idx + shard_size
|
1315
|
+
if end_idx > loaded_weight.shape[input_dim]:
|
1316
|
+
loaded_weight = pad_or_narrow_weight(
|
1317
|
+
loaded_weight, input_dim, start_idx, shard_size
|
1318
|
+
)
|
1319
|
+
else:
|
1320
|
+
loaded_weight = loaded_weight.narrow(
|
1321
|
+
input_dim, start_idx, shard_size
|
1322
|
+
)
|
1275
1323
|
|
1276
1324
|
# Special case for loading scales off disk, which often do not
|
1277
1325
|
# have a shape (such as in the case of AutoFP8).
|
@@ -35,6 +35,9 @@ from sglang.srt.layers.dp_attention import (
|
|
35
35
|
get_attention_dp_rank,
|
36
36
|
get_attention_dp_size,
|
37
37
|
get_attention_tp_size,
|
38
|
+
get_dp_device,
|
39
|
+
get_dp_dtype,
|
40
|
+
get_dp_hidden_size,
|
38
41
|
get_global_dp_buffer,
|
39
42
|
get_local_attention_dp_size,
|
40
43
|
set_dp_buffer_len,
|
@@ -46,16 +49,19 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|
46
49
|
ForwardBatch,
|
47
50
|
ForwardMode,
|
48
51
|
)
|
49
|
-
from sglang.srt.utils import dump_to_file, use_intel_amx_backend
|
52
|
+
from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
|
50
53
|
|
51
54
|
logger = logging.getLogger(__name__)
|
52
55
|
|
56
|
+
_is_npu = is_npu()
|
57
|
+
|
53
58
|
|
54
59
|
@dataclasses.dataclass
|
55
60
|
class LogitsProcessorOutput:
|
56
61
|
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
|
57
62
|
# The logits of the next tokens. shape: [#seq, vocab_size]
|
58
|
-
|
63
|
+
# Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
|
64
|
+
next_token_logits: Optional[torch.Tensor]
|
59
65
|
# Used by speculative decoding (EAGLE)
|
60
66
|
# The last hidden layers
|
61
67
|
hidden_states: Optional[torch.Tensor] = None
|
@@ -67,7 +73,10 @@ class LogitsProcessorOutput:
|
|
67
73
|
next_token_top_logprobs_val: Optional[List] = None
|
68
74
|
next_token_top_logprobs_idx: Optional[List] = None
|
69
75
|
# The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
|
70
|
-
|
76
|
+
# Can contain either lists or GPU tensors (for delayed copy optimization in prefill-only requests)
|
77
|
+
next_token_token_ids_logprobs_val: Optional[
|
78
|
+
List[Union[List[float], torch.Tensor]]
|
79
|
+
] = None
|
71
80
|
next_token_token_ids_logprobs_idx: Optional[List] = None
|
72
81
|
|
73
82
|
## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
|
@@ -77,7 +86,10 @@ class LogitsProcessorOutput:
|
|
77
86
|
input_top_logprobs_val: List = None
|
78
87
|
input_top_logprobs_idx: List = None
|
79
88
|
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
|
80
|
-
|
89
|
+
# Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
|
90
|
+
input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
|
91
|
+
None
|
92
|
+
)
|
81
93
|
input_token_ids_logprobs_idx: Optional[List] = None
|
82
94
|
|
83
95
|
|
@@ -119,6 +131,9 @@ class LogitsMetadata:
|
|
119
131
|
# for padding
|
120
132
|
padded_static_len: int = -1
|
121
133
|
|
134
|
+
# Whether this batch is prefill-only (no token generation needed)
|
135
|
+
is_prefill_only: bool = False
|
136
|
+
|
122
137
|
@classmethod
|
123
138
|
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
124
139
|
if (
|
@@ -161,6 +176,7 @@ class LogitsMetadata:
|
|
161
176
|
token_ids_logprobs=forward_batch.token_ids_logprobs,
|
162
177
|
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
|
163
178
|
padded_static_len=forward_batch.padded_static_len,
|
179
|
+
is_prefill_only=forward_batch.is_prefill_only,
|
164
180
|
global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
|
165
181
|
dp_local_start_pos=forward_batch.dp_local_start_pos,
|
166
182
|
dp_local_num_tokens=forward_batch.dp_local_num_tokens,
|
@@ -180,10 +196,13 @@ class LogitsMetadata:
|
|
180
196
|
)
|
181
197
|
else:
|
182
198
|
dp_local_start_pos = cumtokens[dp_rank - 1]
|
183
|
-
dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
|
184
199
|
|
185
200
|
self.dp_local_start_pos = dp_local_start_pos
|
186
|
-
self.dp_local_num_tokens =
|
201
|
+
self.dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
|
202
|
+
|
203
|
+
hidden_size = get_dp_hidden_size()
|
204
|
+
dtype = get_dp_dtype()
|
205
|
+
device = get_dp_device()
|
187
206
|
|
188
207
|
if self.global_num_tokens_for_logprob_cpu is not None:
|
189
208
|
# create a smaller buffer to reduce peak memory usage
|
@@ -191,10 +210,13 @@ class LogitsMetadata:
|
|
191
210
|
else:
|
192
211
|
self.global_dp_buffer_len = self.global_dp_buffer_len
|
193
212
|
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
213
|
+
self.gathered_buffer = torch.empty(
|
214
|
+
(
|
215
|
+
self.global_dp_buffer_len,
|
216
|
+
hidden_size,
|
217
|
+
),
|
218
|
+
dtype=dtype,
|
219
|
+
device=device,
|
198
220
|
)
|
199
221
|
|
200
222
|
|
@@ -206,6 +228,7 @@ class LogitsProcessor(nn.Module):
|
|
206
228
|
self.config = config
|
207
229
|
self.logit_scale = logit_scale
|
208
230
|
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
|
231
|
+
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
|
209
232
|
if self.use_attn_tp_group:
|
210
233
|
self.attn_tp_size = get_attention_tp_size()
|
211
234
|
self.do_tensor_parallel_all_gather = (
|
@@ -232,6 +255,108 @@ class LogitsProcessor(nn.Module):
|
|
232
255
|
"debug_tensor_dump_output_folder", None
|
233
256
|
)
|
234
257
|
|
258
|
+
def compute_logprobs_for_multi_item_scoring(
|
259
|
+
self,
|
260
|
+
input_ids,
|
261
|
+
hidden_states,
|
262
|
+
lm_head: VocabParallelEmbedding,
|
263
|
+
logits_metadata: Union[LogitsMetadata, ForwardBatch],
|
264
|
+
delimiter_token: int,
|
265
|
+
):
|
266
|
+
"""
|
267
|
+
Compute logprobs for multi-item scoring using delimiter-based token extraction.
|
268
|
+
|
269
|
+
This method is designed for scenarios where you want to score multiple items/candidates
|
270
|
+
against a single query by combining them into one sequence separated by delimiters.
|
271
|
+
|
272
|
+
Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
|
273
|
+
Scoring positions: Extracts logprobs at positions before each <delimiter>
|
274
|
+
|
275
|
+
Args:
|
276
|
+
input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
|
277
|
+
Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
|
278
|
+
hidden_states (torch.Tensor): Hidden states from the model.
|
279
|
+
Shape: [sequence_length, hidden_dim].
|
280
|
+
lm_head (VocabParallelEmbedding): Language model head for computing logits.
|
281
|
+
logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
|
282
|
+
and token ID specifications for logprob extraction.
|
283
|
+
delimiter_token (int): Token ID used as delimiter between query and items.
|
284
|
+
|
285
|
+
Returns:
|
286
|
+
LogitsProcessorOutput: Contains:
|
287
|
+
- next_token_logits: None (not needed for scoring-only requests)
|
288
|
+
- input_token_logprobs: Logprobs of delimiter tokens at scoring positions
|
289
|
+
- input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
|
290
|
+
- input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
|
291
|
+
- input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
|
292
|
+
- input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
|
293
|
+
"""
|
294
|
+
multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[
|
295
|
+
0
|
296
|
+
] - 1
|
297
|
+
# Extract hidden states at delimiter positions for multi-item scoring
|
298
|
+
sliced_hidden = hidden_states[multi_item_indices]
|
299
|
+
|
300
|
+
sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
|
301
|
+
sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)
|
302
|
+
|
303
|
+
# Initialize return values
|
304
|
+
input_token_ids_logprobs_val = []
|
305
|
+
input_token_ids_logprobs_idx = []
|
306
|
+
input_top_logprobs_val = None
|
307
|
+
input_top_logprobs_idx = None
|
308
|
+
|
309
|
+
# Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
|
310
|
+
# Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
|
311
|
+
if (
|
312
|
+
logits_metadata.token_ids_logprobs
|
313
|
+
or logits_metadata.extend_return_top_logprob
|
314
|
+
):
|
315
|
+
logits_metadata.extend_logprob_pruned_lens_cpu = []
|
316
|
+
|
317
|
+
if logits_metadata.extend_seq_lens_cpu is not None:
|
318
|
+
# Multi-request batch: count delimiters per request
|
319
|
+
input_pt = 0
|
320
|
+
for req_seq_len in logits_metadata.extend_seq_lens_cpu:
|
321
|
+
req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
|
322
|
+
delimiter_count = (req_input_ids == delimiter_token).sum().item()
|
323
|
+
logits_metadata.extend_logprob_pruned_lens_cpu.append(
|
324
|
+
delimiter_count
|
325
|
+
)
|
326
|
+
input_pt += req_seq_len
|
327
|
+
else:
|
328
|
+
# Single request case: one request gets all delimiters
|
329
|
+
total_delimiters = (input_ids == delimiter_token).sum().item()
|
330
|
+
logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]
|
331
|
+
|
332
|
+
# Get the logprobs of specified token ids
|
333
|
+
if logits_metadata.extend_token_ids_logprob:
|
334
|
+
(
|
335
|
+
input_token_ids_logprobs_val,
|
336
|
+
input_token_ids_logprobs_idx,
|
337
|
+
) = self.get_token_ids_logprobs(
|
338
|
+
sliced_logprobs, logits_metadata, delay_cpu_copy=True
|
339
|
+
)
|
340
|
+
|
341
|
+
# Get the logprob of top-k tokens
|
342
|
+
if logits_metadata.extend_return_top_logprob:
|
343
|
+
(
|
344
|
+
input_top_logprobs_val,
|
345
|
+
input_top_logprobs_idx,
|
346
|
+
) = self.get_top_logprobs(sliced_logprobs, logits_metadata)
|
347
|
+
|
348
|
+
# For input_token_logprobs, use delimiter token logprobs
|
349
|
+
input_token_logprobs = sliced_logprobs[:, delimiter_token]
|
350
|
+
|
351
|
+
return LogitsProcessorOutput(
|
352
|
+
next_token_logits=None, # Multi-item scoring doesn't need next token logits
|
353
|
+
input_token_logprobs=input_token_logprobs,
|
354
|
+
input_top_logprobs_val=input_top_logprobs_val,
|
355
|
+
input_top_logprobs_idx=input_top_logprobs_idx,
|
356
|
+
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
|
357
|
+
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
|
358
|
+
)
|
359
|
+
|
235
360
|
def forward(
|
236
361
|
self,
|
237
362
|
input_ids,
|
@@ -242,6 +367,16 @@ class LogitsProcessor(nn.Module):
|
|
242
367
|
) -> LogitsProcessorOutput:
|
243
368
|
if isinstance(logits_metadata, ForwardBatch):
|
244
369
|
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
|
370
|
+
|
371
|
+
# Check if multi-item scoring is enabled via server args (only for prefill-only requests)
|
372
|
+
multi_item_delimiter = global_server_args_dict.get(
|
373
|
+
"multi_item_scoring_delimiter"
|
374
|
+
)
|
375
|
+
if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
|
376
|
+
return self.compute_logprobs_for_multi_item_scoring(
|
377
|
+
input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
|
378
|
+
)
|
379
|
+
|
245
380
|
# Get the last hidden states and last logits for the next token prediction
|
246
381
|
if (
|
247
382
|
logits_metadata.forward_mode.is_decode_or_idle()
|
@@ -441,13 +576,17 @@ class LogitsProcessor(nn.Module):
|
|
441
576
|
if self.do_tensor_parallel_all_gather_dp_attn:
|
442
577
|
logits_metadata.compute_dp_attention_metadata()
|
443
578
|
hidden_states, local_hidden_states = (
|
444
|
-
|
579
|
+
logits_metadata.gathered_buffer,
|
445
580
|
hidden_states,
|
446
581
|
)
|
447
582
|
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
|
448
583
|
|
449
584
|
if hasattr(lm_head, "weight"):
|
450
|
-
if
|
585
|
+
if self.use_fp32_lm_head:
|
586
|
+
logits = torch.matmul(
|
587
|
+
hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
|
588
|
+
)
|
589
|
+
elif use_intel_amx_backend(lm_head):
|
451
590
|
logits = torch.ops.sgl_kernel.weight_packed_linear(
|
452
591
|
hidden_states.to(lm_head.weight.dtype),
|
453
592
|
lm_head.weight,
|
@@ -461,7 +600,15 @@ class LogitsProcessor(nn.Module):
|
|
461
600
|
else:
|
462
601
|
# GGUF models
|
463
602
|
# TODO: use weight_packed_linear for GGUF models
|
464
|
-
|
603
|
+
if self.use_fp32_lm_head:
|
604
|
+
with torch.cuda.amp.autocast(enabled=False):
|
605
|
+
logits = lm_head.quant_method.apply(
|
606
|
+
lm_head, hidden_states.to(torch.float32), embedding_bias
|
607
|
+
)
|
608
|
+
else:
|
609
|
+
logits = lm_head.quant_method.apply(
|
610
|
+
lm_head, hidden_states, embedding_bias
|
611
|
+
)
|
465
612
|
|
466
613
|
if self.logit_scale is not None:
|
467
614
|
logits.mul_(self.logit_scale)
|
@@ -517,7 +664,12 @@ class LogitsProcessor(nn.Module):
|
|
517
664
|
logits = logits[:, : self.config.vocab_size].float()
|
518
665
|
|
519
666
|
if self.final_logit_softcapping:
|
520
|
-
|
667
|
+
if not _is_npu:
|
668
|
+
fused_softcap(logits, self.final_logit_softcapping)
|
669
|
+
else:
|
670
|
+
logits = self.final_logit_softcapping * torch.tanh(
|
671
|
+
logits / self.final_logit_softcapping
|
672
|
+
)
|
521
673
|
|
522
674
|
return logits
|
523
675
|
|
@@ -552,7 +704,9 @@ class LogitsProcessor(nn.Module):
|
|
552
704
|
|
553
705
|
@staticmethod
|
554
706
|
def get_token_ids_logprobs(
|
555
|
-
all_logprobs: torch.Tensor,
|
707
|
+
all_logprobs: torch.Tensor,
|
708
|
+
logits_metadata: LogitsMetadata,
|
709
|
+
delay_cpu_copy: bool = False,
|
556
710
|
):
|
557
711
|
input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
|
558
712
|
pt = 0
|
@@ -565,9 +719,17 @@ class LogitsProcessor(nn.Module):
|
|
565
719
|
input_token_ids_logprobs_idx.append([])
|
566
720
|
continue
|
567
721
|
|
568
|
-
|
569
|
-
|
570
|
-
|
722
|
+
position_logprobs = all_logprobs[
|
723
|
+
pt : pt + pruned_len, token_ids
|
724
|
+
] # Shape: [pruned_len, num_tokens]
|
725
|
+
|
726
|
+
if delay_cpu_copy:
|
727
|
+
# Keep as tensor to delay GPU-to-CPU transfer
|
728
|
+
input_token_ids_logprobs_val.append(position_logprobs)
|
729
|
+
else:
|
730
|
+
# Convert to list immediately (default behavior)
|
731
|
+
input_token_ids_logprobs_val.append(position_logprobs.tolist())
|
732
|
+
|
571
733
|
input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
|
572
734
|
pt += pruned_len
|
573
735
|
|
@@ -0,0 +1,11 @@
|
|
1
|
+
"""
|
2
|
+
ModelOpt related constants
|
3
|
+
"""
|
4
|
+
|
5
|
+
QUANT_CFG_CHOICES = {
|
6
|
+
"fp8": "FP8_DEFAULT_CFG",
|
7
|
+
"int4_awq": "INT4_AWQ_CFG", # TODO: add support for int4_awq
|
8
|
+
"w4a8_awq": "W4A8_AWQ_BETA_CFG", # TODO: add support for w4a8_awq
|
9
|
+
"nvfp4": "NVFP4_DEFAULT_CFG",
|
10
|
+
"nvfp4_awq": "NVFP4_AWQ_LITE_CFG", # TODO: add support for nvfp4_awq
|
11
|
+
}
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
1
|
+
from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
|
2
2
|
from sglang.srt.layers.moe.utils import (
|
3
3
|
DeepEPMode,
|
4
4
|
MoeA2ABackend,
|
@@ -17,6 +17,7 @@ from sglang.srt.layers.moe.utils import (
|
|
17
17
|
__all__ = [
|
18
18
|
"DeepEPMode",
|
19
19
|
"MoeA2ABackend",
|
20
|
+
"MoeRunner",
|
20
21
|
"MoeRunnerConfig",
|
21
22
|
"MoeRunnerBackend",
|
22
23
|
"initialize_moe_config",
|