sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import List
+from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -39,6 +39,25 @@ class Sampler(nn.Module):
         if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group
 
+    def _preprocess_logits(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ) -> torch.Tensor:
+        """Apply custom logit processors and handle NaN detection."""
+        # Apply the custom logit processors if registered in the sampling info
+        if sampling_info.has_custom_logit_processor:
+            apply_custom_logit_processor(logits, sampling_info)
+
+        # Detect and handle NaN values in logits
+        if self.use_nan_detection and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
+            )
+            if crash_on_warnings():
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
+
+        return logits
+
     def forward(
         self,
         logits_output: LogitsProcessorOutput,
@@ -46,6 +65,7 @@ class Sampler(nn.Module):
         return_logprob: bool,
         top_logprobs_nums: List[int],
         token_ids_logprobs: List[List[int]],
+        positions: torch.Tensor,
     ):
         """Run a sampler & compute logprobs and update logits_output accordingly.
 
@@ -58,20 +78,13 @@
             batch_next_token_ids: next token IDs. If set, skip sampling and only
                 compute output logprobs It is used for speculative decoding which
                 performs sampling in draft workers.
+            positions: The positions of the tokens in the sequence. Used for deterministic sampling
+                to get the unique seed for each position.
         """
         logits = logits_output.next_token_logits
 
-        # Apply the custom logit processors if registered in the sampling info
-        if sampling_info.has_custom_logit_processor:
-            apply_custom_logit_processor(logits, sampling_info)
-
-        if self.use_nan_detection and torch.any(torch.isnan(logits)):
-            logger.warning("Detected errors during sampling! NaN in the logits.")
-            logits = torch.where(
-                torch.isnan(logits), torch.full_like(logits, -1e5), logits
-            )
-            if crash_on_warnings():
-                raise ValueError("Detected errors during sampling! NaN in the logits.")
+        # Preprocess logits (custom processors and NaN handling)
+        logits = self._preprocess_logits(logits, sampling_info)
 
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
@@ -80,9 +93,9 @@
             logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
 
         else:
-            #
+            # If requested, cache probabilities from original logits before temperature scaling.
             if return_logprob and RETURN_ORIGINAL_LOGPROB:
-                logprobs = torch.softmax(logits, dim=-1)
+                probs_without_temp_scaling = torch.softmax(logits, dim=-1)
 
             # Post process logits
             logits.div_(sampling_info.temperatures)
@@ -114,6 +127,8 @@
                     sampling_info.top_ps,
                     sampling_info.min_ps,
                     sampling_info.need_min_p_sampling,
+                    sampling_info.sampling_seed,
+                    positions,
                 )
             else:
                 raise ValueError(
@@ -123,9 +138,10 @@
         if return_logprob:
            # clamp to avoid -inf
            if RETURN_ORIGINAL_LOGPROB:
-                logprobs = torch.log(logprobs).clamp(
-                    min=torch.finfo(logprobs.dtype).min
+                logprobs = torch.log(probs_without_temp_scaling).clamp(
+                    min=torch.finfo(probs_without_temp_scaling.dtype).min
                 )
+                del probs_without_temp_scaling
            else:
                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
 
@@ -164,6 +180,55 @@
 
         return batch_next_token_ids
 
+    def compute_logprobs_only(
+        self,
+        logits_output: LogitsProcessorOutput,
+        sampling_info: SamplingBatchInfo,
+        return_logprob: bool,
+        top_logprobs_nums: List[int],
+        token_ids_logprobs: List[List[int]],
+    ) -> None:
+        """
+        Compute logprobs for requested token IDs without performing sampling.
+
+        Optimized for prefill-only scoring requests that need token probabilities
+        but don't require next token generation.
+        """
+
+        if logits_output.next_token_logits is None:
+            logger.warning("No logits available for logprob computation")
+            return
+
+        # Check if any requests actually need logprobs computation
+        needs_token_ids_logprobs = any(
+            token_ids is not None and len(token_ids) > 0
+            for token_ids in token_ids_logprobs
+        )
+        needs_top_logprobs = any(x > 0 for x in top_logprobs_nums)
+
+        if not (needs_token_ids_logprobs or needs_top_logprobs):
+            return
+
+        # Preprocess logits (custom processors and NaN handling)
+        logits = self._preprocess_logits(logits_output.next_token_logits, sampling_info)
+
+        # Compute logprobs
+        logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+
+        # Handle top logprobs if requested
+        if needs_top_logprobs:
+            (
+                logits_output.next_token_top_logprobs_val,
+                logits_output.next_token_top_logprobs_idx,
+            ) = get_top_logprobs(logprobs, top_logprobs_nums)
+
+        # Handle token_ids logprobs if requested
+        if needs_token_ids_logprobs:
+            (
+                logits_output.next_token_token_ids_logprobs_val,
+                logits_output.next_token_token_ids_logprobs_idx,
+            ) = get_token_ids_logprobs_batch_optimized(logprobs, token_ids_logprobs)
+
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
@@ -171,8 +236,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     top_ps: torch.Tensor,
     min_ps: torch.Tensor,
     need_min_p_sampling: bool,
+    sampling_seed: Optional[torch.Tensor],
+    positions: torch.Tensor,
 ):
-    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
+    """
+    A top-k, top-p and min-p sampling implementation with native pytorch operations.
+    When sampling_seed is not None, deterministic inference will be enabled, it will sample
+    with the sampling_seed of each request.
+    """
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
     probs_sort[
@@ -184,14 +255,50 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     if need_min_p_sampling:
         min_p_thresholds = probs_sort[:, 0] * min_ps
         probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-
-    sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    if sampling_seed is not None:
+        sampled_index = multinomial_with_seed(probs_sort, sampling_seed, positions)
+    else:
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
     # int32 range is enough to represent the token ids
     probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
 
 
+def multinomial_with_seed(
+    inputs: torch.Tensor, seed: torch.Tensor, positions: torch.Tensor
+) -> torch.Tensor:
+    """
+    Samples n elements from an input tensor `inputs` of shape (n, m) using
+    a unique random seed for each row. This is a deterministic batched alternative to
+    `torch.multinomial`.
+
+    Args:
+        inputs: A float tensor of shape (n, m) representing n categorical
+            distributions with m categories each. The values are treated
+            as weights and do not need to sum to 1.
+        seed: An integer tensor of shape (n,) containing the random seed
+            for each corresponding row in `inputs`.
+        positions: The positions of the tokens in the sequence. Used for deterministic sampling
+            to get the unique seed for each position.
+
+    Returns:
+        A tensor of shape (n,) where the i-th element is an index sampled
+        from the distribution in `inputs[i]` using `seed[i]`.
+    """
+    n, m = inputs.shape
+    col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
+    step_seed = seed * 19349663 ^ positions * 73856093
+    seed_expanded = step_seed.unsqueeze(-1)
+    hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
+    uniform_samples = (hashed % (2**24)).float() / (2**24)
+    epsilon = 1e-9
+    gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
+    log_probs = torch.log(inputs + epsilon)
+    perturbed_log_probs = log_probs + gumbel_noise
+    return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)
+
+
 def sampling_from_probs_torch(probs: torch.Tensor):
     """A sampling implementation with native pytorch operations, without
     top-k, top-p, or min-p filtering."""
@@ -233,10 +340,95 @@ def get_top_logprobs(
     )
 
 
-def get_token_ids_logprobs(
+def get_token_ids_logprobs_batch_optimized(
     logprobs: torch.Tensor,
     token_ids_logprobs: List[List[int]],
-):
+) -> Tuple[List, List]:
+    """
+    Vectorized batch processing for token ID logprobs extraction.
+
+    Uses a single GPU kernel call for the entire batch instead of multiple
+    separate calls, significantly improving performance for large batches.
+
+    Args:
+        logprobs: Log probabilities tensor [batch_size, vocab_size]
+        token_ids_logprobs: List of token IDs to extract logprobs for
+
+    Example:
+        # Input: batch_size=3, vocab_size=5
+        logprobs = torch.tensor([
+            [-1.2, -2.1, -0.8, -3.0, -1.5],  # batch 0
+            [-0.5, -1.8, -2.2, -1.1, -2.7],  # batch 1
+            [-2.0, -0.9, -1.4, -2.8, -1.6],  # batch 2
+        ])
+        token_ids_logprobs = [[1, 3], [2], [0, 2, 4]]
+
+        # Output:
+        # values = [tensor([-2.1, -3.0]), tensor([-2.2]), tensor([-2.0, -1.4, -1.6])]
+        # indices = [[1, 3], [2], [0, 2, 4]]
+    """
+    batch_size = len(token_ids_logprobs)
+    device = logprobs.device
+
+    # Step 1: Calculate lengths for each request, treating None as empty list
+    # Example: [[1, 3], [2], [0, 2, 4]] -> token_lengths = tensor([2, 1, 3])
+    token_lengths = torch.tensor(
+        [len(token_ids or []) for token_ids in token_ids_logprobs], device=device
+    )
+    total_tokens = int(token_lengths.sum().item())  # 2 + 1 + 3 = 6
+
+    # Handle edge case where no tokens are requested
+    if total_tokens == 0:
+        return [logprobs.new_empty(0) for _ in token_ids_logprobs], [
+            [] for _ in token_ids_logprobs
+        ]
+
+    # Step 2: Build flattened indices using torch operations
+    # Example: row_indices = [0, 0, 1, 2, 2, 2] (batch indices repeated by their lengths)
+    row_indices = torch.repeat_interleave(
+        torch.arange(batch_size, device=device), token_lengths
+    )
+    # Example: col_indices = [1, 3, 2, 0, 2, 4] (flattened token IDs from all requests)
+    col_indices = torch.tensor(
+        [
+            token_id
+            for token_ids in token_ids_logprobs
+            for token_id in (token_ids or [])
+        ],
+        device=device,
+        dtype=torch.long,
+    )
+
+    # Step 3: Single vectorized gather operation
+    # Example: logprobs[row_indices, col_indices] -> [-2.1, -3.0, -2.2, -2.0, -1.4, -1.6]
+    gathered_logprobs = logprobs[row_indices, col_indices]
+
+    # Step 4: Split results back per request using torch operations
+    # Example: split tensor [6] into chunks of sizes [2, 1, 3] -> [tensor(2), tensor(1), tensor(3)]
+    split_logprobs = torch.split_with_sizes(
+        gathered_logprobs, token_lengths.tolist(), dim=0
+    )
+
+    # Step 5: Format output to match expected return structure
+    # Example: Convert split tensors back to list format with proper empty handling
+    # i=0: [1,3] -> append split_logprobs[0] and [1,3]
+    # i=1: [2] -> append split_logprobs[1] and [2]
+    # i=2: [0,2,4] -> append split_logprobs[2] and [0,2,4]
+    output_token_ids_logprobs_val = []
+    output_token_ids_logprobs_idx = []
+
+    for i, token_ids in enumerate(token_ids_logprobs):
+        if token_ids is not None and len(token_ids) > 0:
+            output_token_ids_logprobs_val.append(split_logprobs[i])
+            output_token_ids_logprobs_idx.append(token_ids)
+        else:
+            output_token_ids_logprobs_val.append(logprobs.new_empty(0))
+            output_token_ids_logprobs_idx.append([])
+
+    return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+
+
+def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
     output_token_ids_logprobs_val = []
     output_token_ids_logprobs_idx = []
     for i, token_ids in enumerate(token_ids_logprobs):
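Note: the new `multinomial_with_seed` path replaces `torch.multinomial` with a counter-based scheme. Each request's `sampling_seed` is hashed together with the token's position into per-category Gumbel noise, so sampling reduces to a deterministic `argmax` (Gumbel-max trick). The following is a standalone sketch of that behavior; the function body is copied from the hunk above, while the toy inputs are invented for illustration.

import torch


def multinomial_with_seed(inputs, seed, positions):
    # Body copied from the diff above: hash (seed, position, category) into
    # uniforms in [0, 1), turn them into Gumbel noise, and take argmax.
    n, m = inputs.shape
    col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
    step_seed = seed * 19349663 ^ positions * 73856093
    seed_expanded = step_seed.unsqueeze(-1)
    hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
    uniform_samples = (hashed % (2**24)).float() / (2**24)
    epsilon = 1e-9
    gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
    return torch.argmax(torch.log(inputs + epsilon) + gumbel_noise, dim=1, keepdim=True)


probs = torch.rand(4, 32)              # 4 requests, 32-token toy vocabulary
probs[1] = probs[0]                    # rows 0 and 1: identical distributions
seed = torch.tensor([7, 7, 13, 13])    # per-request sampling seeds
positions = torch.tensor([0, 0, 5, 6]) # decode positions within each sequence

first = multinomial_with_seed(probs, seed, positions)
second = multinomial_with_seed(probs, seed, positions)
assert torch.equal(first, second)      # same inputs always reproduce the same tokens
assert first[0] == first[1]            # same distribution + same (seed, position) agree

Because every step is a pure function of `(inputs, seed, position)`, replaying a request with the same seed yields the same tokens at every position, which is what the new `sampling_seed` and `positions` arguments threaded through `top_k_top_p_min_p_sampling_from_probs_torch` enable.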
sglang/srt/layers/utils.py
CHANGED
@@ -15,6 +15,29 @@ def get_layer_id(weight_name):
     return None
 
 
+def pad_or_narrow_weight(
+    loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
+) -> torch.Tensor:
+    # Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
+    valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
+
+    if valid_size > 0:
+        loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
+        pad_shape = list(loaded_weight.shape)
+        pad_shape[input_dim] = shard_size - valid_size
+        pad = torch.zeros(
+            pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
+        )
+        return torch.cat([loaded_slice, pad], dim=input_dim)
+
+    # All padding
+    pad_shape = list(loaded_weight.shape)
+    pad_shape[input_dim] = shard_size
+    return torch.zeros(
+        pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
+    )
+
+
 class PPMissingLayer(torch.nn.Identity):
     # Adapted from
     # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
sglang/srt/lora/backend/base_backend.py
CHANGED
@@ -1,8 +1,9 @@
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 
 from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class BaseLoRABackend:
@@ -10,13 +11,14 @@ class BaseLoRABackend:
     Each backend has its own implementation of Lora kernels.
 
     Args:
-
-
+        max_loras_per_batch: maximum number of different lora weights
+            that can be applied in a single forward batch.
+        device: the device where the backend runs.
     """
 
-    def __init__(self,
-        self.
-        self.
+    def __init__(self, max_loras_per_batch: int, device: torch.device):
+        self.max_loras_per_batch = max_loras_per_batch
+        self.device = device
 
     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -93,8 +95,44 @@ class BaseLoRABackend:
         """
         pass
 
-    def
-        self
+    def init_cuda_graph_batch_info(
+        self,
+        cuda_graph_batch_info: LoRABatchInfo,
+        max_bs_in_cuda_graph: int,
+    ):
+        """Initialize the batch info for CUDA Graph mode.
+
+        This method provides a hook for each backend to conduct its own initialization
+        logic for CUDA Graph mode.
+
+        Args:
+            cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager
+            max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
+        """
+        pass
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],
+        lora_ranks: list[int],
+        scalings: list[float],
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        """Prepare the lora weights and batch info for current forward batch.
+
+        This method provides a hook for each backend to conduct its own preparation
+        logic for each forward batch.
+
+        Args:
+            forward_batch: the ForwardBatch object for current forward pass
+            weight_indices: list of indices of lora weights to be applied for current batch
+            lora_ranks: list of lora ranks corresponding to weight_indices
+            scalings: list of scaling factors corresponding to weight_indices
+            batch_info: optional LoRABatchInfo object, if not provided, the backend should use its own
+                internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
+        """
+        pass
 
 
 def get_backend_from_name(name: str) -> BaseLoRABackend:
@@ -105,6 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
         from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
 
         return TritonLoRABackend
+    elif name == "csgmv":
+        from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+
+        return ChunkedSgmvLoRABackend
     elif name == "flashinfer":
         raise ValueError(
             "FlashInfer LoRA backend has been deprecated, please use `triton` instead."