sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
|
|
17
17
|
get_tp_group,
|
18
18
|
tensor_model_parallel_all_reduce,
|
19
19
|
)
|
20
|
+
from sglang.srt.utils import get_bool_env_var, is_hip
|
20
21
|
|
21
22
|
if TYPE_CHECKING:
|
22
23
|
from sglang.srt.configs.model_config import ModelConfig
|
@@ -36,6 +37,9 @@ _LOCAL_ATTN_DP_SIZE: Optional[int] = None
|
|
36
37
|
_LOCAL_ATTN_DP_RANK: Optional[int] = None
|
37
38
|
_ENABLE_DP_ATTENTION_FLAG: bool = False
|
38
39
|
|
40
|
+
_is_hip = is_hip()
|
41
|
+
_USE_ROCM700A_WA = _is_hip and get_bool_env_var("SGLANG_USE_ROCM700A")
|
42
|
+
|
39
43
|
|
40
44
|
class DpPaddingMode(IntEnum):
|
41
45
|
|
@@ -51,7 +55,12 @@ class DpPaddingMode(IntEnum):
|
|
51
55
|
return self == DpPaddingMode.SUM_LEN
|
52
56
|
|
53
57
|
@classmethod
|
54
|
-
def get_dp_padding_mode(
|
58
|
+
def get_dp_padding_mode(
|
59
|
+
cls, is_extend_in_batch, global_num_tokens: List[int]
|
60
|
+
) -> DpPaddingMode:
|
61
|
+
if is_extend_in_batch:
|
62
|
+
return DpPaddingMode.SUM_LEN
|
63
|
+
|
55
64
|
# we choose the mode that minimizes the communication cost
|
56
65
|
max_len = max(global_num_tokens)
|
57
66
|
sum_len = sum(global_num_tokens)
|
@@ -62,7 +71,12 @@ class DpPaddingMode(IntEnum):
|
|
62
71
|
|
63
72
|
@classmethod
|
64
73
|
def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
|
65
|
-
|
74
|
+
# TODO(kkhuang-amd): noqa, temporary work-around for rocm 7.0.0 alpha
|
75
|
+
# it can be safely removed later, once RCCL fixed
|
76
|
+
if _USE_ROCM700A_WA:
|
77
|
+
return cls.SUM_LEN
|
78
|
+
else:
|
79
|
+
return cls.MAX_LEN
|
66
80
|
|
67
81
|
|
68
82
|
class _DpGatheredBufferWrapper:
|
@@ -119,6 +133,18 @@ class _DpGatheredBufferWrapper:
|
|
119
133
|
def get_dp_global_num_tokens(cls) -> List[int]:
|
120
134
|
return cls._global_num_tokens
|
121
135
|
|
136
|
+
@classmethod
|
137
|
+
def get_dp_hidden_size(cls) -> int:
|
138
|
+
return cls._hidden_size
|
139
|
+
|
140
|
+
@classmethod
|
141
|
+
def get_dp_dtype(cls) -> torch.dtype:
|
142
|
+
return cls._dtype
|
143
|
+
|
144
|
+
@classmethod
|
145
|
+
def get_dp_device(cls) -> torch.device:
|
146
|
+
return cls._device
|
147
|
+
|
122
148
|
|
123
149
|
def set_dp_buffer_len(
|
124
150
|
global_dp_buffer_len: int,
|
@@ -150,6 +176,18 @@ def get_dp_global_num_tokens() -> List[int]:
|
|
150
176
|
return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
|
151
177
|
|
152
178
|
|
179
|
+
def get_dp_hidden_size() -> int:
|
180
|
+
return _DpGatheredBufferWrapper.get_dp_hidden_size()
|
181
|
+
|
182
|
+
|
183
|
+
def get_dp_dtype() -> torch.dtype:
|
184
|
+
return _DpGatheredBufferWrapper.get_dp_dtype()
|
185
|
+
|
186
|
+
|
187
|
+
def get_dp_device() -> torch.device:
|
188
|
+
return _DpGatheredBufferWrapper.get_dp_device()
|
189
|
+
|
190
|
+
|
153
191
|
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
|
154
192
|
if not enable_dp_attention:
|
155
193
|
return tp_rank, tp_size, 0
|
@@ -225,6 +263,7 @@ def initialize_dp_attention(
|
|
225
263
|
use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
|
226
264
|
use_pymscclpp=False,
|
227
265
|
use_custom_allreduce=False,
|
266
|
+
use_torch_symm_mem=False,
|
228
267
|
use_hpu_communicator=False,
|
229
268
|
use_xpu_communicator=False,
|
230
269
|
use_npu_communicator=False,
|
sglang/srt/layers/elementwise.py
CHANGED
@@ -187,7 +187,9 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
|
|
187
187
|
|
188
188
|
def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
|
189
189
|
assert len(x.shape) == 2
|
190
|
-
assert
|
190
|
+
assert (
|
191
|
+
x.shape == residual.shape and x.dtype == residual.dtype
|
192
|
+
), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
|
191
193
|
output, mid = torch.empty_like(x), torch.empty_like(x)
|
192
194
|
bs, hidden_dim = x.shape
|
193
195
|
if autotune:
|
sglang/srt/layers/layernorm.py
CHANGED
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
|
|
18
18
|
|
19
19
|
import torch
|
20
20
|
import torch.nn as nn
|
21
|
+
from packaging.version import Version
|
21
22
|
|
22
23
|
from sglang.srt.custom_op import CustomOp
|
23
24
|
from sglang.srt.utils import (
|
@@ -25,32 +26,38 @@ from sglang.srt.utils import (
|
|
25
26
|
get_bool_env_var,
|
26
27
|
is_cpu,
|
27
28
|
is_cuda,
|
29
|
+
is_flashinfer_available,
|
28
30
|
is_hip,
|
29
31
|
is_npu,
|
32
|
+
is_xpu,
|
30
33
|
supports_custom_op,
|
31
34
|
)
|
32
35
|
|
33
36
|
_is_cuda = is_cuda()
|
37
|
+
_is_flashinfer_available = is_flashinfer_available()
|
34
38
|
_is_hip = is_hip()
|
35
39
|
_is_npu = is_npu()
|
36
40
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
37
41
|
_is_cpu_amx_available = cpu_has_amx_support()
|
38
42
|
_is_cpu = is_cpu()
|
43
|
+
_is_xpu = is_xpu()
|
39
44
|
|
40
45
|
if _is_cuda:
|
41
|
-
|
42
|
-
fused_add_rmsnorm
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
)
|
46
|
+
if _is_flashinfer_available:
|
47
|
+
from flashinfer.norm import fused_add_rmsnorm
|
48
|
+
else:
|
49
|
+
from sgl_kernel import fused_add_rmsnorm
|
50
|
+
from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
|
47
51
|
|
48
52
|
if _use_aiter:
|
49
53
|
from aiter import rmsnorm2d_fwd as rms_norm
|
50
54
|
from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
|
51
55
|
elif _is_hip:
|
56
|
+
import vllm
|
52
57
|
from vllm._custom_ops import fused_add_rms_norm, rms_norm
|
53
58
|
|
59
|
+
_vllm_version = Version(vllm.__version__)
|
60
|
+
|
54
61
|
logger = logging.getLogger(__name__)
|
55
62
|
|
56
63
|
if _is_npu:
|
@@ -73,6 +80,8 @@ class RMSNorm(CustomOp):
|
|
73
80
|
)
|
74
81
|
if _use_aiter:
|
75
82
|
self._forward_method = self.forward_aiter
|
83
|
+
if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
|
84
|
+
self._forward_method = self.forward_native
|
76
85
|
|
77
86
|
def forward_cuda(
|
78
87
|
self,
|
@@ -127,8 +136,21 @@ class RMSNorm(CustomOp):
|
|
127
136
|
# NOTE: Remove this if aiter kernel supports discontinuous input
|
128
137
|
x = x.contiguous()
|
129
138
|
if residual is not None:
|
130
|
-
|
131
|
-
|
139
|
+
if _vllm_version < Version("0.9"):
|
140
|
+
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
141
|
+
return x, residual
|
142
|
+
else:
|
143
|
+
residual_out = torch.empty_like(x)
|
144
|
+
output = torch.empty_like(x)
|
145
|
+
fused_add_rms_norm(
|
146
|
+
output,
|
147
|
+
x,
|
148
|
+
residual_out,
|
149
|
+
residual,
|
150
|
+
self.weight.data,
|
151
|
+
self.variance_epsilon,
|
152
|
+
)
|
153
|
+
return output, residual_out
|
132
154
|
out = torch.empty_like(x)
|
133
155
|
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
134
156
|
return out
|
@@ -271,16 +293,11 @@ class GemmaRMSNorm(CustomOp):
|
|
271
293
|
x: torch.Tensor,
|
272
294
|
residual: Optional[torch.Tensor] = None,
|
273
295
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
274
|
-
orig_dtype = x.dtype
|
275
296
|
if residual is not None:
|
276
297
|
x = x + residual
|
277
298
|
residual = x
|
278
299
|
|
279
|
-
x = x.
|
280
|
-
variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
|
281
|
-
x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
|
282
|
-
x = x * (1.0 + self.weight.float())
|
283
|
-
x = x.to(orig_dtype)
|
300
|
+
x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
|
284
301
|
return x if residual is None else (x, residual)
|
285
302
|
|
286
303
|
|
@@ -312,7 +329,9 @@ class Gemma3RMSNorm(CustomOp):
|
|
312
329
|
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
313
330
|
|
314
331
|
|
315
|
-
if not (
|
332
|
+
if not (
|
333
|
+
_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu
|
334
|
+
):
|
316
335
|
logger.info(
|
317
336
|
"sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
|
318
337
|
)
|
sglang/srt/layers/linear.py
CHANGED
@@ -31,6 +31,7 @@ from sglang.srt.layers.parameter import (
|
|
31
31
|
_ColumnvLLMParameter,
|
32
32
|
)
|
33
33
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
34
|
+
from sglang.srt.layers.utils import pad_or_narrow_weight
|
34
35
|
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
|
35
36
|
|
36
37
|
if TYPE_CHECKING:
|
@@ -235,9 +236,8 @@ class ReplicatedLinear(LinearBase):
|
|
235
236
|
loaded_weight = loaded_weight[:1]
|
236
237
|
else:
|
237
238
|
raise ValueError(f"{loaded_weight} are not all equal")
|
238
|
-
|
239
|
-
|
240
|
-
), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
|
239
|
+
|
240
|
+
assert param.size() == loaded_weight.size()
|
241
241
|
param.data.copy_(loaded_weight)
|
242
242
|
|
243
243
|
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
@@ -626,9 +626,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|
626
626
|
# bitsandbytes loads the weights of the specific portion
|
627
627
|
# no need to narrow here
|
628
628
|
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
|
629
|
-
|
630
|
-
|
631
|
-
|
629
|
+
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
630
|
+
end_idx = start_idx + shard_size
|
631
|
+
if end_idx > loaded_weight.shape[output_dim]:
|
632
|
+
loaded_weight = pad_or_narrow_weight(
|
633
|
+
loaded_weight, output_dim, start_idx, shard_size
|
634
|
+
)
|
635
|
+
else:
|
636
|
+
loaded_weight = loaded_weight.narrow(
|
637
|
+
output_dim, start_idx, shard_size
|
638
|
+
)
|
632
639
|
|
633
640
|
# Special case for AQLM codebooks.
|
634
641
|
elif is_metadata:
|
@@ -894,6 +901,35 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|
894
901
|
)
|
895
902
|
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
|
896
903
|
|
904
|
+
def _load_qkv_block_scale(
|
905
|
+
self, param: BasevLLMParameter, loaded_weight: torch.Tensor
|
906
|
+
):
|
907
|
+
block_n, _ = self.quant_method.quant_config.weight_block_size
|
908
|
+
q_size = self.total_num_heads * self.head_size // block_n
|
909
|
+
k_size = self.total_num_kv_heads * self.head_size // block_n
|
910
|
+
v_size = self.total_num_kv_heads * self.head_size // block_n
|
911
|
+
shard_offsets = [
|
912
|
+
# (shard_id, shard_offset, shard_size)
|
913
|
+
("q", 0, q_size),
|
914
|
+
("k", q_size, k_size),
|
915
|
+
("v", q_size + k_size, v_size),
|
916
|
+
]
|
917
|
+
for shard_id, shard_offset, shard_size in shard_offsets:
|
918
|
+
loaded_weight_shard = loaded_weight.narrow(
|
919
|
+
param.output_dim, shard_offset, shard_size
|
920
|
+
)
|
921
|
+
rank_shard_offset = self._get_shard_offset_mapping(shard_id) // block_n
|
922
|
+
rank_shard_size = self._get_shard_size_mapping(shard_id) // block_n
|
923
|
+
param.load_qkv_weight(
|
924
|
+
loaded_weight=loaded_weight_shard,
|
925
|
+
num_heads=self.num_kv_head_replicas,
|
926
|
+
shard_id=shard_id,
|
927
|
+
shard_offset=rank_shard_offset,
|
928
|
+
shard_size=rank_shard_size,
|
929
|
+
tp_rank=self.tp_rank,
|
930
|
+
use_presharded_weights=self.use_presharded_weights,
|
931
|
+
)
|
932
|
+
|
897
933
|
def weight_loader_v2(
|
898
934
|
self,
|
899
935
|
param: BasevLLMParameter,
|
@@ -907,6 +943,9 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|
907
943
|
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
|
908
944
|
param.load_qkv_weight(loaded_weight=loaded_weight)
|
909
945
|
return
|
946
|
+
elif isinstance(param, BlockQuantScaleParameter):
|
947
|
+
self._load_qkv_block_scale(param, loaded_weight)
|
948
|
+
return
|
910
949
|
# TODO: @dsikka - move to parameter.py
|
911
950
|
self._load_fused_module_from_checkpoint(param, loaded_weight)
|
912
951
|
return
|
@@ -1271,7 +1310,16 @@ class RowParallelLinear(LinearBase):
|
|
1271
1310
|
shard_size,
|
1272
1311
|
)
|
1273
1312
|
else:
|
1274
|
-
|
1313
|
+
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
1314
|
+
end_idx = start_idx + shard_size
|
1315
|
+
if end_idx > loaded_weight.shape[input_dim]:
|
1316
|
+
loaded_weight = pad_or_narrow_weight(
|
1317
|
+
loaded_weight, input_dim, start_idx, shard_size
|
1318
|
+
)
|
1319
|
+
else:
|
1320
|
+
loaded_weight = loaded_weight.narrow(
|
1321
|
+
input_dim, start_idx, shard_size
|
1322
|
+
)
|
1275
1323
|
|
1276
1324
|
# Special case for loading scales off disk, which often do not
|
1277
1325
|
# have a shape (such as in the case of AutoFP8).
|
@@ -35,6 +35,9 @@ from sglang.srt.layers.dp_attention import (
|
|
35
35
|
get_attention_dp_rank,
|
36
36
|
get_attention_dp_size,
|
37
37
|
get_attention_tp_size,
|
38
|
+
get_dp_device,
|
39
|
+
get_dp_dtype,
|
40
|
+
get_dp_hidden_size,
|
38
41
|
get_global_dp_buffer,
|
39
42
|
get_local_attention_dp_size,
|
40
43
|
set_dp_buffer_len,
|
@@ -46,10 +49,12 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|
46
49
|
ForwardBatch,
|
47
50
|
ForwardMode,
|
48
51
|
)
|
49
|
-
from sglang.srt.utils import dump_to_file, use_intel_amx_backend
|
52
|
+
from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
|
50
53
|
|
51
54
|
logger = logging.getLogger(__name__)
|
52
55
|
|
56
|
+
_is_npu = is_npu()
|
57
|
+
|
53
58
|
|
54
59
|
@dataclasses.dataclass
|
55
60
|
class LogitsProcessorOutput:
|
@@ -67,7 +72,10 @@ class LogitsProcessorOutput:
|
|
67
72
|
next_token_top_logprobs_val: Optional[List] = None
|
68
73
|
next_token_top_logprobs_idx: Optional[List] = None
|
69
74
|
# The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
|
70
|
-
|
75
|
+
# Can contain either lists or GPU tensors (for delayed copy optimization in prefill-only requests)
|
76
|
+
next_token_token_ids_logprobs_val: Optional[
|
77
|
+
List[Union[List[float], torch.Tensor]]
|
78
|
+
] = None
|
71
79
|
next_token_token_ids_logprobs_idx: Optional[List] = None
|
72
80
|
|
73
81
|
## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
|
@@ -180,10 +188,13 @@ class LogitsMetadata:
|
|
180
188
|
)
|
181
189
|
else:
|
182
190
|
dp_local_start_pos = cumtokens[dp_rank - 1]
|
183
|
-
dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
|
184
191
|
|
185
192
|
self.dp_local_start_pos = dp_local_start_pos
|
186
|
-
self.dp_local_num_tokens =
|
193
|
+
self.dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
|
194
|
+
|
195
|
+
hidden_size = get_dp_hidden_size()
|
196
|
+
dtype = get_dp_dtype()
|
197
|
+
device = get_dp_device()
|
187
198
|
|
188
199
|
if self.global_num_tokens_for_logprob_cpu is not None:
|
189
200
|
# create a smaller buffer to reduce peak memory usage
|
@@ -191,10 +202,13 @@ class LogitsMetadata:
|
|
191
202
|
else:
|
192
203
|
self.global_dp_buffer_len = self.global_dp_buffer_len
|
193
204
|
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
205
|
+
self.gathered_buffer = torch.empty(
|
206
|
+
(
|
207
|
+
self.global_dp_buffer_len,
|
208
|
+
hidden_size,
|
209
|
+
),
|
210
|
+
dtype=dtype,
|
211
|
+
device=device,
|
198
212
|
)
|
199
213
|
|
200
214
|
|
@@ -206,6 +220,7 @@ class LogitsProcessor(nn.Module):
|
|
206
220
|
self.config = config
|
207
221
|
self.logit_scale = logit_scale
|
208
222
|
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
|
223
|
+
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
|
209
224
|
if self.use_attn_tp_group:
|
210
225
|
self.attn_tp_size = get_attention_tp_size()
|
211
226
|
self.do_tensor_parallel_all_gather = (
|
@@ -441,13 +456,17 @@ class LogitsProcessor(nn.Module):
|
|
441
456
|
if self.do_tensor_parallel_all_gather_dp_attn:
|
442
457
|
logits_metadata.compute_dp_attention_metadata()
|
443
458
|
hidden_states, local_hidden_states = (
|
444
|
-
|
459
|
+
logits_metadata.gathered_buffer,
|
445
460
|
hidden_states,
|
446
461
|
)
|
447
462
|
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
|
448
463
|
|
449
464
|
if hasattr(lm_head, "weight"):
|
450
|
-
if
|
465
|
+
if self.use_fp32_lm_head:
|
466
|
+
logits = torch.matmul(
|
467
|
+
hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
|
468
|
+
)
|
469
|
+
elif use_intel_amx_backend(lm_head):
|
451
470
|
logits = torch.ops.sgl_kernel.weight_packed_linear(
|
452
471
|
hidden_states.to(lm_head.weight.dtype),
|
453
472
|
lm_head.weight,
|
@@ -461,7 +480,15 @@ class LogitsProcessor(nn.Module):
|
|
461
480
|
else:
|
462
481
|
# GGUF models
|
463
482
|
# TODO: use weight_packed_linear for GGUF models
|
464
|
-
|
483
|
+
if self.use_fp32_lm_head:
|
484
|
+
with torch.cuda.amp.autocast(enabled=False):
|
485
|
+
logits = lm_head.quant_method.apply(
|
486
|
+
lm_head, hidden_states.to(torch.float32), embedding_bias
|
487
|
+
)
|
488
|
+
else:
|
489
|
+
logits = lm_head.quant_method.apply(
|
490
|
+
lm_head, hidden_states, embedding_bias
|
491
|
+
)
|
465
492
|
|
466
493
|
if self.logit_scale is not None:
|
467
494
|
logits.mul_(self.logit_scale)
|
@@ -517,7 +544,12 @@ class LogitsProcessor(nn.Module):
|
|
517
544
|
logits = logits[:, : self.config.vocab_size].float()
|
518
545
|
|
519
546
|
if self.final_logit_softcapping:
|
520
|
-
|
547
|
+
if not _is_npu:
|
548
|
+
fused_softcap(logits, self.final_logit_softcapping)
|
549
|
+
else:
|
550
|
+
logits = self.final_logit_softcapping * torch.tanh(
|
551
|
+
logits / self.final_logit_softcapping
|
552
|
+
)
|
521
553
|
|
522
554
|
return logits
|
523
555
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
1
|
+
from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
|
2
2
|
from sglang.srt.layers.moe.utils import (
|
3
3
|
DeepEPMode,
|
4
4
|
MoeA2ABackend,
|
@@ -17,6 +17,7 @@ from sglang.srt.layers.moe.utils import (
|
|
17
17
|
__all__ = [
|
18
18
|
"DeepEPMode",
|
19
19
|
"MoeA2ABackend",
|
20
|
+
"MoeRunner",
|
20
21
|
"MoeRunnerConfig",
|
21
22
|
"MoeRunnerBackend",
|
22
23
|
"initialize_moe_config",
|
@@ -147,8 +147,8 @@ def cutlass_w4a8_moe(
|
|
147
147
|
k,
|
148
148
|
)
|
149
149
|
|
150
|
-
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.
|
151
|
-
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.
|
150
|
+
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
|
151
|
+
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
|
152
152
|
|
153
153
|
cutlass_w4a8_moe_mm(
|
154
154
|
c1,
|
@@ -166,7 +166,7 @@ def cutlass_w4a8_moe(
|
|
166
166
|
topk,
|
167
167
|
)
|
168
168
|
|
169
|
-
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.
|
169
|
+
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
|
170
170
|
silu_and_mul(c1, intermediate)
|
171
171
|
|
172
172
|
intermediate_q = torch.empty(
|
@@ -1104,10 +1104,10 @@ def ep_gather(
|
|
1104
1104
|
input_index: torch.Tensor,
|
1105
1105
|
output_tensor: torch.Tensor,
|
1106
1106
|
):
|
1107
|
-
BLOCK_D = 1024 if not is_in_ci() else 128 # block size of quantization
|
1108
1107
|
num_warps = 2
|
1109
1108
|
num_tokens = output_tensor.shape[0]
|
1110
1109
|
hidden_size = input_tensor.shape[1]
|
1110
|
+
BLOCK_D = 128 if hidden_size % 1024 != 0 else 1024 # block size of quantization
|
1111
1111
|
assert hidden_size % BLOCK_D == 0
|
1112
1112
|
grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
|
1113
1113
|
_fwd_kernel_ep_gather[grid](
|
@@ -1416,7 +1416,7 @@ def zero_experts_compute_triton(
|
|
1416
1416
|
zero_expert_scales[zero_expert_mask] = 0.0
|
1417
1417
|
|
1418
1418
|
normal_expert_mask = expert_indices >= num_experts
|
1419
|
-
expert_indices[normal_expert_mask] =
|
1419
|
+
expert_indices[normal_expert_mask] = -1
|
1420
1420
|
expert_scales[normal_expert_mask] = 0.0
|
1421
1421
|
|
1422
1422
|
output = torch.zeros_like(hidden_states).to(hidden_states.device)
|