sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +54 -37
- sglang/bench_one_batch_server.py +340 -34
- sglang/bench_serving.py +340 -159
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +9 -2
- sglang/profiler.py +20 -3
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +309 -0
- sglang/srt/configs/load_config.py +33 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +284 -118
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +576 -0
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +6 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -15
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +268 -98
- sglang/srt/disaggregation/decode.py +172 -39
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +203 -555
- sglang/srt/disaggregation/nixl/conn.py +217 -63
- sglang/srt/disaggregation/prefill.py +113 -270
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +203 -97
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +85 -65
- sglang/srt/entrypoints/grpc_server.py +632 -305
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +169 -17
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +327 -34
- sglang/srt/entrypoints/openai/serving_base.py +74 -8
- sglang/srt/entrypoints/openai/serving_chat.py +202 -118
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +20 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +47 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +323 -0
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +21 -16
- sglang/srt/function_call/glm4_moe_detector.py +4 -8
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +61 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +98 -7
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/grpc_request_manager.py +915 -0
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
- sglang/srt/layers/activation.py +11 -7
- sglang/srt/layers/attention/aiter_backend.py +17 -18
- sglang/srt/layers/attention/ascend_backend.py +125 -10
- sglang/srt/layers/attention/attention_registry.py +226 -0
- sglang/srt/layers/attention/base_attn_backend.py +32 -4
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +52 -15
- sglang/srt/layers/attention/flashinfer_backend.py +357 -212
- sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
- sglang/srt/layers/attention/flashmla_backend.py +9 -7
- sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
- sglang/srt/layers/attention/mamba/mamba.py +514 -1
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +23 -0
- sglang/srt/layers/attention/nsa_backend.py +1201 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +249 -42
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
- sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +61 -3
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +19 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +28 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +47 -15
- sglang/srt/layers/linear.py +30 -5
- sglang/srt/layers/logits_processor.py +161 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
- sglang/srt/layers/moe/ep_moe/layer.py +243 -448
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +27 -1
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +86 -20
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +43 -15
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +141 -81
- sglang/srt/layers/quantization/mxfp4.py +17 -34
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -24
- sglang/srt/layers/quantization/w8a8_int8.py +45 -27
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +750 -46
- sglang/srt/layers/sampler.py +84 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +23 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +9 -4
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +33 -7
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +41 -17
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +83 -152
- sglang/srt/managers/data_parallel_controller.py +156 -87
- sglang/srt/managers/detokenizer_manager.py +51 -24
- sglang/srt/managers/io_struct.py +223 -129
- sglang/srt/managers/mm_utils.py +49 -10
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +130 -0
- sglang/srt/managers/schedule_batch.py +340 -529
- sglang/srt/managers/schedule_policy.py +158 -18
- sglang/srt/managers/scheduler.py +665 -620
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
- sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
- sglang/srt/managers/tokenizer_manager.py +462 -226
- sglang/srt/managers/tp_worker.py +217 -156
- sglang/srt/managers/utils.py +79 -47
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +42 -28
- sglang/srt/mem_cache/base_prefix_cache.py +3 -3
- sglang/srt/mem_cache/chunk_cache.py +20 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +38 -0
- sglang/srt/mem_cache/hicache_storage.py +44 -2
- sglang/srt/mem_cache/hiradix_cache.py +134 -34
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +602 -208
- sglang/srt/mem_cache/memory_pool_host.py +134 -183
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +263 -78
- sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +115 -58
- sglang/srt/metrics/collector.py +113 -120
- sglang/srt/metrics/func_timer.py +3 -8
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +81 -36
- sglang/srt/model_executor/forward_batch_info.py +40 -50
- sglang/srt/model_executor/model_runner.py +507 -319
- sglang/srt/model_executor/npu_graph_runner.py +11 -5
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +438 -37
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +200 -27
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +40 -56
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +25 -4
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +793 -235
- sglang/srt/models/dots_ocr.py +171 -0
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +570 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -3
- sglang/srt/models/glm4_moe.py +17 -40
- sglang/srt/models/glm4_moe_nextn.py +4 -4
- sglang/srt/models/glm4v.py +3 -2
- sglang/srt/models/glm4v_moe.py +6 -6
- sglang/srt/models/gpt_oss.py +12 -35
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +4 -2
- sglang/srt/models/llama.py +6 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +6 -23
- sglang/srt/models/longcat_flash_nextn.py +4 -15
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +27 -6
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +5 -5
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +70 -4
- sglang/srt/models/qwen2_vl.py +6 -3
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +50 -38
- sglang/srt/models/qwen3_next.py +43 -21
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +791 -0
- sglang/srt/models/qwen3_vl_moe.py +343 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +268 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +61 -0
- sglang/srt/multimodal/processors/base_processor.py +21 -9
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +2 -4
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +20 -10
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +83 -17
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +36 -23
- sglang/srt/sampling/sampling_params.py +75 -0
- sglang/srt/server_args.py +1300 -338
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +161 -0
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
- sglang/srt/speculative/eagle_info.py +786 -0
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +113 -1270
- sglang/srt/speculative/eagle_worker.py +120 -285
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/ngram_info.py +433 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +49 -0
- sglang/srt/speculative/spec_utils.py +641 -0
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +35 -18
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/{utils.py → utils/common.py} +583 -113
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +120 -11
- sglang/test/runners.py +3 -1
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +8 -2
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +3 -4
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +430 -0
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +93 -1
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +432 -16
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
- sglang/srt/entrypoints/grpc_request_manager.py +0 -580
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py:

@@ -38,17 +38,15 @@ from sglang.srt.layers.dp_attention import (
     get_dp_device,
     get_dp_dtype,
     get_dp_hidden_size,
-    get_global_dp_buffer,
     get_local_attention_dp_size,
-    set_dp_buffer_len,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
 
 logger = logging.getLogger(__name__)
@@ -60,13 +58,14 @@ _is_npu = is_npu()
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logits of the next tokens. shape: [#seq, vocab_size]
-    next_token_logits: torch.Tensor
+    # Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
+    next_token_logits: Optional[torch.Tensor]
     # Used by speculative decoding (EAGLE)
     # The last hidden layers
     hidden_states: Optional[torch.Tensor] = None
 
     ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
-    # The log probs of output tokens, if
+    # The log probs of output tokens. If SGLANG_RETURN_ORIGINAL_LOGPROB = True, these are the log probs before applying temperature; if False, after applying temperature.
     next_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None
@@ -85,7 +84,10 @@ class LogitsProcessorOutput:
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
     # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
-    input_token_ids_logprobs_val: Optional[List] = None
+    # Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
+    input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
+        None
+    )
     input_token_ids_logprobs_idx: Optional[List] = None
 
 
@@ -127,10 +129,16 @@ class LogitsMetadata:
     # for padding
     padded_static_len: int = -1
 
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         if (
-            forward_batch.forward_mode.is_extend()
+            (
+                forward_batch.forward_mode.is_extend()
+                or forward_batch.forward_mode.is_split_prefill()
+            )
             and forward_batch.return_logprob
             and not forward_batch.forward_mode.is_target_verify()
         ):
@@ -169,6 +177,7 @@ class LogitsMetadata:
             token_ids_logprobs=forward_batch.token_ids_logprobs,
             extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
             padded_static_len=forward_batch.padded_static_len,
+            is_prefill_only=forward_batch.is_prefill_only,
             global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
             dp_local_start_pos=forward_batch.dp_local_start_pos,
             dp_local_num_tokens=forward_batch.dp_local_num_tokens,
@@ -219,7 +228,8 @@ class LogitsProcessor(nn.Module):
         super().__init__()
         self.config = config
         self.logit_scale = logit_scale
-        self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head
+        self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head
         if self.use_attn_tp_group:
             self.attn_tp_size = get_attention_tp_size()
             self.do_tensor_parallel_all_gather = (
@@ -242,8 +252,110 @@ class LogitsProcessor(nn.Module):
         ):
             self.final_logit_softcapping = None
 
-        self.debug_tensor_dump_output_folder =
-
+        self.debug_tensor_dump_output_folder = (
+            get_global_server_args().debug_tensor_dump_output_folder
+        )
+
+    def compute_logprobs_for_multi_item_scoring(
+        self,
+        input_ids,
+        hidden_states,
+        lm_head: VocabParallelEmbedding,
+        logits_metadata: Union[LogitsMetadata, ForwardBatch],
+        delimiter_token: int,
+    ):
+        """
+        Compute logprobs for multi-item scoring using delimiter-based token extraction.
+
+        This method is designed for scenarios where you want to score multiple items/candidates
+        against a single query by combining them into one sequence separated by delimiters.
+
+        Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
+        Scoring positions: Extracts logprobs at positions before each <delimiter>
+
+        Args:
+            input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
+                Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
+            hidden_states (torch.Tensor): Hidden states from the model.
+                Shape: [sequence_length, hidden_dim].
+            lm_head (VocabParallelEmbedding): Language model head for computing logits.
+            logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
+                and token ID specifications for logprob extraction.
+            delimiter_token (int): Token ID used as delimiter between query and items.
+
+        Returns:
+            LogitsProcessorOutput: Contains:
+                - next_token_logits: None (not needed for scoring-only requests)
+                - input_token_logprobs: Logprobs of delimiter tokens at scoring positions
+                - input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
+                - input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
+                - input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
+                - input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
+        """
+        multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[
+            0
+        ] - 1
+        # Extract hidden states at delimiter positions for multi-item scoring
+        sliced_hidden = hidden_states[multi_item_indices]
+
+        sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
+        sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)
+
+        # Initialize return values
+        input_token_ids_logprobs_val = []
+        input_token_ids_logprobs_idx = []
+        input_top_logprobs_val = None
+        input_top_logprobs_idx = None
+
+        # Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
+        # Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
+        if (
+            logits_metadata.token_ids_logprobs
+            or logits_metadata.extend_return_top_logprob
+        ):
+            logits_metadata.extend_logprob_pruned_lens_cpu = []
+
+            if logits_metadata.extend_seq_lens_cpu is not None:
+                # Multi-request batch: count delimiters per request
+                input_pt = 0
+                for req_seq_len in logits_metadata.extend_seq_lens_cpu:
+                    req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
+                    delimiter_count = (req_input_ids == delimiter_token).sum().item()
+                    logits_metadata.extend_logprob_pruned_lens_cpu.append(
+                        delimiter_count
+                    )
+                    input_pt += req_seq_len
+            else:
+                # Single request case: one request gets all delimiters
+                total_delimiters = (input_ids == delimiter_token).sum().item()
+                logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]
+
+        # Get the logprobs of specified token ids
+        if logits_metadata.extend_token_ids_logprob:
+            (
+                input_token_ids_logprobs_val,
+                input_token_ids_logprobs_idx,
+            ) = self.get_token_ids_logprobs(
+                sliced_logprobs, logits_metadata, delay_cpu_copy=True
+            )
+
+        # Get the logprob of top-k tokens
+        if logits_metadata.extend_return_top_logprob:
+            (
+                input_top_logprobs_val,
+                input_top_logprobs_idx,
+            ) = self.get_top_logprobs(sliced_logprobs, logits_metadata)
+
+        # For input_token_logprobs, use delimiter token logprobs
+        input_token_logprobs = sliced_logprobs[:, delimiter_token]
+
+        return LogitsProcessorOutput(
+            next_token_logits=None,  # Multi-item scoring doesn't need next token logits
+            input_token_logprobs=input_token_logprobs,
+            input_top_logprobs_val=input_top_logprobs_val,
+            input_top_logprobs_idx=input_top_logprobs_idx,
+            input_token_ids_logprobs_val=input_token_ids_logprobs_val,
+            input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
         )
 
     def forward(
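The core of the new multi-item scoring path is the index arithmetic in its first statement: every position immediately before a delimiter is a scoring position. A minimal standalone sketch of that indexing (toy tensors only; `99` is a stand-in delimiter id, not an SGLang constant):

```python
import torch

# Toy sequence: Query(10, 11) <99> Item1(20) <99> Item2(21, 22) <99>
input_ids = torch.tensor([10, 11, 99, 20, 99, 21, 22, 99])
delimiter_token = 99

# Same indexing as compute_logprobs_for_multi_item_scoring: the positions
# immediately before each delimiter are where each item gets scored.
multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[0] - 1
print(multi_item_indices)  # tensor([1, 3, 6])
```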
@@ -256,10 +368,19 @@ class LogitsProcessor(nn.Module):
     ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
+
+        # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
+        multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
+        if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
+            return self.compute_logprobs_for_multi_item_scoring(
+                input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
+            )
+
         # Get the last hidden states and last logits for the next token prediction
         if (
             logits_metadata.forward_mode.is_decode_or_idle()
             or logits_metadata.forward_mode.is_target_verify()
+            or logits_metadata.forward_mode.is_draft_extend_v2()
         ):
             pruned_states = hidden_states
             if aux_hidden_states is not None:
@@ -268,8 +389,8 @@ class LogitsProcessor(nn.Module):
             input_logprob_indices = None
         elif (
             logits_metadata.forward_mode.is_extend()
-            and not logits_metadata.extend_return_logprob
-        ):
+            or logits_metadata.forward_mode.is_split_prefill()
+        ) and not logits_metadata.extend_return_logprob:
             # Prefill without input logprobs.
             if logits_metadata.padded_static_len < 0:
                 last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
@@ -461,7 +582,11 @@ class LogitsProcessor(nn.Module):
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
-            if use_intel_amx_backend(lm_head):
+            if self.use_fp32_lm_head:
+                logits = torch.matmul(
+                    hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
+                )
+            elif use_intel_amx_backend(lm_head):
                 logits = torch.ops.sgl_kernel.weight_packed_linear(
                     hidden_states.to(lm_head.weight.dtype),
                     lm_head.weight,
@@ -475,7 +600,15 @@ class LogitsProcessor(nn.Module):
         else:
             # GGUF models
             # TODO: use weight_packed_linear for GGUF models
-            logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
+            if self.use_fp32_lm_head:
+                with torch.cuda.amp.autocast(enabled=False):
+                    logits = lm_head.quant_method.apply(
+                        lm_head, hidden_states.to(torch.float32), embedding_bias
+                    )
+            else:
+                logits = lm_head.quant_method.apply(
+                    lm_head, hidden_states, embedding_bias
+                )
 
         if self.logit_scale is not None:
             logits.mul_(self.logit_scale)
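For context on the two `use_fp32_lm_head` branches above, here is a minimal sketch of the numerical difference in plain PyTorch (illustrative shapes, no SGLang dependencies): the flag upcasts the final projection so that the logits, and everything derived from them, are computed in full fp32 rather than the model's bf16/fp16.

```python
import torch

hidden_states = torch.randn(4, 8, dtype=torch.bfloat16)    # [num_tokens, hidden]
lm_head_weight = torch.randn(32, 8, dtype=torch.bfloat16)  # [vocab, hidden]

# Default path: the projection runs in bf16, so logits inherit bf16 rounding.
logits_bf16 = torch.matmul(hidden_states, lm_head_weight.T)

# enable_fp32_lm_head path: upcast both operands, mirroring the torch.matmul
# branch added in the diff above.
logits_fp32 = torch.matmul(
    hidden_states.to(torch.float32), lm_head_weight.to(torch.float32).T
)
print(logits_bf16.dtype, logits_fp32.dtype)  # torch.bfloat16 torch.float32
```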
@@ -571,7 +704,9 @@ class LogitsProcessor(nn.Module):
 
     @staticmethod
     def get_token_ids_logprobs(
-        all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
+        all_logprobs: torch.Tensor,
+        logits_metadata: LogitsMetadata,
+        delay_cpu_copy: bool = False,
     ):
         input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
         pt = 0
@@ -584,9 +719,17 @@ class LogitsProcessor(nn.Module):
                 input_token_ids_logprobs_idx.append([])
                 continue
 
-            input_token_ids_logprobs_val.append(
-                all_logprobs[pt : pt + pruned_len, token_ids].tolist()
-            )
+            position_logprobs = all_logprobs[
+                pt : pt + pruned_len, token_ids
+            ]  # Shape: [pruned_len, num_tokens]
+
+            if delay_cpu_copy:
+                # Keep as tensor to delay GPU-to-CPU transfer
+                input_token_ids_logprobs_val.append(position_logprobs)
+            else:
+                # Convert to list immediately (default behavior)
+                input_token_ids_logprobs_val.append(position_logprobs.tolist())
+
             input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
             pt += pruned_len
 
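The `delay_cpu_copy` flag above exists because `.tolist()` forces an immediate, synchronizing GPU-to-CPU copy per request, while keeping the sliced tensor lets the caller (here, the multi-item scoring path) batch the transfer later. A minimal sketch of the distinction, on toy CPU tensors:

```python
import torch

all_logprobs = torch.log_softmax(torch.randn(6, 100), dim=-1)  # [tokens, vocab]
token_ids, pt, pruned_len = [5, 17], 0, 3

# Same slice as get_token_ids_logprobs: shape [pruned_len, len(token_ids)].
position_logprobs = all_logprobs[pt : pt + pruned_len, token_ids]

eager = position_logprobs.tolist()   # default: device-to-host copy happens now
deferred = position_logprobs         # delay_cpu_copy=True: stays a tensor
print(type(eager), type(deferred))   # <class 'list'> <class 'torch.Tensor'>
```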
sglang/srt/layers/modelopt_utils.py (new file):

@@ -0,0 +1,11 @@
+"""
+ModelOpt related constants
+"""
+
+QUANT_CFG_CHOICES = {
+    "fp8": "FP8_DEFAULT_CFG",
+    "int4_awq": "INT4_AWQ_CFG",  # TODO: add support for int4_awq
+    "w4a8_awq": "W4A8_AWQ_BETA_CFG",  # TODO: add support for w4a8_awq
+    "nvfp4": "NVFP4_DEFAULT_CFG",
+    "nvfp4_awq": "NVFP4_AWQ_LITE_CFG",  # TODO: add support for nvfp4_awq
+}
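A hypothetical usage sketch for the new constants table (the lookup pattern below is an assumption about how callers consume it, not code from this diff; the `getattr` resolution against `modelopt.torch.quantization` is likewise assumed):

```python
# Assumed consumer pattern: map a user-facing quantization choice to a
# ModelOpt config name, then resolve that name against the ModelOpt package.
QUANT_CFG_CHOICES = {
    "fp8": "FP8_DEFAULT_CFG",
    "nvfp4": "NVFP4_DEFAULT_CFG",
}

choice = "fp8"
cfg_name = QUANT_CFG_CHOICES[choice]  # -> "FP8_DEFAULT_CFG"
# quant_cfg = getattr(mtq, cfg_name)  # hypothetical resolution step
print(cfg_name)
```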
sglang/srt/layers/moe/cutlass_w4a8_moe.py:

@@ -11,24 +11,23 @@ from sgl_kernel import (
 )
 
 from sglang.srt.layers.moe.ep_moe.kernels import (
+    deepep_permute_triton_kernel,
+    deepep_post_reorder_triton_kernel,
+    deepep_run_moe_deep_preprocess,
     post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
-    run_cutlass_moe_ep_preproess,
+    run_moe_ep_preproess,
 )
 
 
 def cutlass_w4a8_moe(
-    start_expert_id: int,
-    end_expert_id: int,
-    total_num_experts: int,
     a: torch.Tensor,
     w1_q: torch.Tensor,
     w2_q: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
     topk_weights: torch.Tensor,
-    topk_ids_: torch.Tensor,
-    local_topk_ids: torch.Tensor,
+    topk_ids: torch.Tensor,
     a_strides1: torch.Tensor,
     b_strides1: torch.Tensor,
     c_strides1: torch.Tensor,
|
|
|
64
63
|
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
|
65
64
|
Shape: [num_experts, N // 512, K * 4]
|
|
66
65
|
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
|
66
|
+
- topk_ids (torch.Tensor): The ids of each token->expert mapping.
|
|
67
67
|
- a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
|
|
68
68
|
- b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
|
|
69
69
|
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
|
|
@@ -83,7 +83,7 @@ def cutlass_w4a8_moe(
|
|
|
83
83
|
Returns:
|
|
84
84
|
- torch.Tensor: The fp8 output tensor after applying the MoE layer.
|
|
85
85
|
"""
|
|
86
|
-
assert topk_weights.shape ==
|
|
86
|
+
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
|
87
87
|
assert w1_q.dtype == torch.int8
|
|
88
88
|
assert w2_q.dtype == torch.int8
|
|
89
89
|
assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
|
|
@@ -96,20 +96,21 @@ def cutlass_w4a8_moe(
|
|
|
96
96
|
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
|
|
97
97
|
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
|
|
98
98
|
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
|
|
99
|
-
|
|
99
|
+
num_local_experts = w1_q.size(0)
|
|
100
100
|
m = a.size(0)
|
|
101
101
|
k = w1_q.size(2) * 2 # w1_q is transposed and packed
|
|
102
102
|
n = w2_q.size(2) * 2 # w2_q is transposed and packed
|
|
103
|
-
topk =
|
|
103
|
+
topk = topk_ids.size(1)
|
|
104
104
|
|
|
105
105
|
if apply_router_weight_on_input:
|
|
106
106
|
assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
|
|
107
107
|
|
|
108
108
|
device = a.device
|
|
109
|
+
topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
|
|
109
110
|
|
|
110
|
-
_, src2dst, _ =
|
|
111
|
-
|
|
112
|
-
|
|
111
|
+
_, src2dst, _ = run_moe_ep_preproess(
|
|
112
|
+
topk_ids,
|
|
113
|
+
num_local_experts,
|
|
113
114
|
)
|
|
114
115
|
|
|
115
116
|
gateup_input = torch.empty(
|
|
@@ -122,9 +123,9 @@ def cutlass_w4a8_moe(
|
|
|
122
123
|
a,
|
|
123
124
|
gateup_input,
|
|
124
125
|
src2dst,
|
|
125
|
-
|
|
126
|
+
topk_ids,
|
|
126
127
|
a1_scale,
|
|
127
|
-
|
|
128
|
+
num_local_experts,
|
|
128
129
|
topk,
|
|
129
130
|
k,
|
|
130
131
|
BLOCK_SIZE=512,
|
|
@@ -133,16 +134,16 @@ def cutlass_w4a8_moe(
|
|
|
133
134
|
# NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
|
|
134
135
|
# they are kept to allow for a quick switch of the permutation logic
|
|
135
136
|
# from the current triton kernel implementation to the cutlass-based one if needed.
|
|
136
|
-
a_map = torch.empty((
|
|
137
|
-
c_map = torch.empty((
|
|
137
|
+
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
|
138
|
+
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
|
138
139
|
get_cutlass_w4a8_moe_mm_data(
|
|
139
|
-
|
|
140
|
+
topk_ids,
|
|
140
141
|
expert_offsets,
|
|
141
142
|
problem_sizes1,
|
|
142
143
|
problem_sizes2,
|
|
143
144
|
a_map,
|
|
144
145
|
c_map,
|
|
145
|
-
|
|
146
|
+
num_local_experts,
|
|
146
147
|
n,
|
|
147
148
|
k,
|
|
148
149
|
)
|
|
@@ -195,12 +196,203 @@ def cutlass_w4a8_moe(
|
|
|
195
196
|
c2,
|
|
196
197
|
output,
|
|
197
198
|
src2dst,
|
|
198
|
-
|
|
199
|
+
topk_ids,
|
|
199
200
|
topk_weights,
|
|
200
|
-
num_experts,
|
|
201
201
|
topk,
|
|
202
|
+
num_local_experts,
|
|
202
203
|
k,
|
|
203
|
-
0,
|
|
204
204
|
BLOCK_SIZE=512,
|
|
205
205
|
)
|
|
206
206
|
return output
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def cutlass_w4a8_moe_deepep_normal(
|
|
210
|
+
a: torch.Tensor,
|
|
211
|
+
w1_q: torch.Tensor,
|
|
212
|
+
w2_q: torch.Tensor,
|
|
213
|
+
w1_scale: torch.Tensor,
|
|
214
|
+
w2_scale: torch.Tensor,
|
|
215
|
+
topk_weights: torch.Tensor,
|
|
216
|
+
topk_ids_: torch.Tensor,
|
|
217
|
+
a_strides1: torch.Tensor,
|
|
218
|
+
b_strides1: torch.Tensor,
|
|
219
|
+
c_strides1: torch.Tensor,
|
|
220
|
+
a_strides2: torch.Tensor,
|
|
221
|
+
b_strides2: torch.Tensor,
|
|
222
|
+
c_strides2: torch.Tensor,
|
|
223
|
+
s_strides13: torch.Tensor,
|
|
224
|
+
s_strides2: torch.Tensor,
|
|
225
|
+
expert_offsets: torch.Tensor,
|
|
226
|
+
problem_sizes1: torch.Tensor,
|
|
227
|
+
problem_sizes2: torch.Tensor,
|
|
228
|
+
a1_scale: Optional[torch.Tensor] = None,
|
|
229
|
+
a2_scale: Optional[torch.Tensor] = None,
|
|
230
|
+
) -> torch.Tensor:
|
|
231
|
+
"""
|
|
232
|
+
This function computes a w4a8-quantized Mixture of Experts (MoE) layer
|
|
233
|
+
using two sets of quantized weights, w1_q and w2_q, and top-k gating
|
|
234
|
+
mechanism. The matrix multiplications are implemented with CUTLASS
|
|
235
|
+
grouped gemm.
|
|
236
|
+
|
|
237
|
+
Parameters:
|
|
238
|
+
- a (torch.Tensor): The input tensor to the MoE layer.
|
|
239
|
+
Shape: [M, K]
|
|
240
|
+
- w1_q (torch.Tensor): The first set of int4-quantized expert weights.
|
|
241
|
+
Shape: [num_experts, N * 2, K // 2]
|
|
242
|
+
(the weights are passed transposed and int4-packed)
|
|
243
|
+
- w2_q (torch.Tensor): The second set of int4-quantized expert weights.
|
|
244
|
+
Shape: [num_experts, K, N // 2]
|
|
245
|
+
(the weights are passed transposed and int4-packed)
|
|
246
|
+
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
|
|
247
|
+
Shape: [num_experts, K // 512, N * 8]
|
|
248
|
+
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
|
249
|
+
Shape: [num_experts, N // 512, K * 4]
|
|
250
|
+
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
|
251
|
+
- a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
|
|
252
|
+
- b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
|
|
253
|
+
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
|
|
254
|
+
- a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
|
|
255
|
+
- b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
|
|
256
|
+
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
|
|
257
|
+
- s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
|
|
258
|
+
- s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
|
|
259
|
+
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
|
|
260
|
+
Shape: scalar or [1, K]
|
|
261
|
+
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
|
262
|
+
quantize the intermediate result between the gemms.
|
|
263
|
+
Shape: scalar or [1, N]
|
|
264
|
+
- apply_router_weight_on_input (bool): When true, the topk weights are
|
|
265
|
+
applied directly on the inputs. This is only applicable when topk is 1.
|
|
266
|
+
|
|
267
|
+
Returns:
|
|
268
|
+
- torch.Tensor: The fp8 output tensor after applying the MoE layer.
|
|
269
|
+
"""
|
|
270
|
+
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
|
|
271
|
+
assert w1_q.dtype == torch.int8
|
|
272
|
+
assert w2_q.dtype == torch.int8
|
|
273
|
+
assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
|
|
274
|
+
assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
|
|
275
|
+
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
|
|
276
|
+
assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
|
|
277
|
+
assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
|
|
278
|
+
|
|
279
|
+
assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
|
|
280
|
+
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
|
|
281
|
+
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
|
|
282
|
+
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
|
|
283
|
+
num_experts = w1_q.size(0)
|
|
284
|
+
m = a.size(0)
|
|
285
|
+
k = w1_q.size(2) * 2 # w1_q is transposed and packed
|
|
286
|
+
n = w2_q.size(2) * 2 # w2_q is transposed and packed
|
|
287
|
+
topk = topk_ids_.size(1)
|
|
288
|
+
|
|
289
|
+
num_experts = w1_q.size(0)
|
|
290
|
+
m = a.size(0)
|
|
291
|
+
k = w1_q.size(2) * 2
|
|
292
|
+
n = w2_q.size(2) * 2
|
|
293
|
+
topk = topk_ids_.size(1)
|
|
294
|
+
device = a.device
|
|
295
|
+
|
|
296
|
+
reorder_topk_ids, src2dst, _ = deepep_run_moe_deep_preprocess(
|
|
297
|
+
topk_ids_, num_experts
|
|
298
|
+
)
|
|
299
|
+
num_total_tokens = reorder_topk_ids.numel()
|
|
300
|
+
gateup_input_pre_reorder = torch.empty(
|
|
301
|
+
(int(num_total_tokens), a.shape[1]),
|
|
302
|
+
device=device,
|
|
303
|
+
dtype=a.dtype,
|
|
304
|
+
)
|
|
305
|
+
deepep_permute_triton_kernel[(a.shape[0],)](
|
|
306
|
+
a,
|
|
307
|
+
gateup_input_pre_reorder,
|
|
308
|
+
src2dst,
|
|
309
|
+
topk_ids_.to(torch.int64),
|
|
310
|
+
None,
|
|
311
|
+
topk,
|
|
312
|
+
a.shape[1],
|
|
313
|
+
BLOCK_SIZE=512,
|
|
314
|
+
)
|
|
315
|
+
gateup_input = torch.empty(
|
|
316
|
+
gateup_input_pre_reorder.shape, dtype=torch.float8_e4m3fn, device=device
|
|
317
|
+
)
|
|
318
|
+
sgl_per_tensor_quant_fp8(
|
|
319
|
+
gateup_input_pre_reorder, gateup_input, a1_scale.float(), True
|
|
320
|
+
)
|
|
321
|
+
del gateup_input_pre_reorder
|
|
322
|
+
local_topk_ids = topk_ids_
|
|
323
|
+
local_topk_ids = (
|
|
324
|
+
torch.where(local_topk_ids == -1, num_experts, topk_ids_).to(torch.int32)
|
|
325
|
+
).contiguous()
|
|
326
|
+
|
|
327
|
+
a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
|
|
328
|
+
c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
|
|
329
|
+
get_cutlass_w4a8_moe_mm_data(
|
|
330
|
+
local_topk_ids,
|
|
331
|
+
expert_offsets,
|
|
332
|
+
problem_sizes1,
|
|
333
|
+
problem_sizes2,
|
|
334
|
+
a_map,
|
|
335
|
+
c_map,
|
|
336
|
+
num_experts,
|
|
337
|
+
n,
|
|
338
|
+
k,
|
|
339
|
+
)
|
|
340
|
+
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
|
|
341
|
+
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
|
|
342
|
+
|
|
343
|
+
cutlass_w4a8_moe_mm(
|
|
344
|
+
c1,
|
|
345
|
+
gateup_input,
|
|
346
|
+
w1_q,
|
|
347
|
+
a1_scale.float(),
|
|
348
|
+
w1_scale,
|
|
349
|
+
expert_offsets[:-1],
|
|
350
|
+
problem_sizes1,
|
|
351
|
+
a_strides1,
|
|
352
|
+
b_strides1,
|
|
353
|
+
c_strides1,
|
|
354
|
+
s_strides13,
|
|
355
|
+
128,
|
|
356
|
+
topk,
|
|
357
|
+
)
|
|
358
|
+
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
|
|
359
|
+
silu_and_mul(c1, intermediate)
|
|
360
|
+
|
|
361
|
+
intermediate_q = torch.empty(
|
|
362
|
+
intermediate.shape, dtype=torch.float8_e4m3fn, device=device
|
|
363
|
+
)
|
|
364
|
+
sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)
|
|
365
|
+
|
|
366
|
+
cutlass_w4a8_moe_mm(
|
|
367
|
+
c2,
|
|
368
|
+
intermediate_q,
|
|
369
|
+
w2_q,
|
|
370
|
+
a2_scale.float(),
|
|
371
|
+
w2_scale,
|
|
372
|
+
expert_offsets[:-1],
|
|
373
|
+
problem_sizes2,
|
|
374
|
+
a_strides2,
|
|
375
|
+
b_strides2,
|
|
376
|
+
c_strides2,
|
|
377
|
+
s_strides2,
|
|
378
|
+
128,
|
|
379
|
+
topk,
|
|
380
|
+
)
|
|
381
|
+
num_tokens = src2dst.shape[0] // topk
|
|
382
|
+
output = torch.empty(
|
|
383
|
+
(num_tokens, c2.shape[1]),
|
|
384
|
+
device=c2.device,
|
|
385
|
+
dtype=torch.bfloat16,
|
|
386
|
+
)
|
|
387
|
+
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
|
388
|
+
c2,
|
|
389
|
+
output,
|
|
390
|
+
src2dst,
|
|
391
|
+
topk_ids_,
|
|
392
|
+
topk_weights,
|
|
393
|
+
topk,
|
|
394
|
+
c2.shape[1],
|
|
395
|
+
BLOCK_SIZE=512,
|
|
396
|
+
)
|
|
397
|
+
|
|
398
|
+
return output
|
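Both MoE entry points above remap `-1` routing ids with `torch.where` before any bookkeeping kernel runs. Under expert parallelism, a `-1` conventionally marks a top-k slot whose expert is not local to this rank; mapping it to the one-past-the-end index (`num_local_experts` / `num_experts`) gives those slots a harmless dummy bucket instead of a negative index (the rank-local convention here is inferred from the code, not stated in the diff). A toy illustration:

```python
import torch

num_local_experts = 4
# -1 marks top-k slots routed to experts that are not local to this rank.
topk_ids = torch.tensor([[0, 2], [-1, 3], [1, -1]])

# Same remap as in cutlass_w4a8_moe / cutlass_w4a8_moe_deepep_normal:
# dropped slots land in dummy bucket index 4 (== num_local_experts).
remapped = torch.where(topk_ids == -1, num_local_experts, topk_ids)
print(remapped)
# tensor([[0, 2],
#         [4, 3],
#         [1, 4]])
```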