sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,348 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
import torch
|
4
|
+
|
5
|
+
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
6
|
+
from sglang.srt.lora.triton_ops import (
|
7
|
+
chunked_sgmv_lora_expand_forward,
|
8
|
+
chunked_sgmv_lora_shrink_forward,
|
9
|
+
)
|
10
|
+
from sglang.srt.lora.utils import LoRABatchInfo
|
11
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
12
|
+
from sglang.srt.server_args import ServerArgs
|
13
|
+
|
14
|
+
MIN_CHUNK_SIZE = 16
|
15
|
+
|
16
|
+
|
17
|
+
class ChunkedSgmvLoRABackend(BaseLoRABackend):
|
18
|
+
"""
|
19
|
+
Chunked LoRA backend using segmented matrix-vector multiplication.
|
20
|
+
|
21
|
+
This backend is largely based on the SGMV (Segmented Gather Matrix-Vector multiplication) algorithm
|
22
|
+
introduced in the Punica paper (https://arxiv.org/pdf/2310.18547). One main variation made here is to
|
23
|
+
segment the input sequences into fixed-size chunks, which reduces excessive kernel launches especially
|
24
|
+
when the LoRA distribution is skewed.
|
25
|
+
"""
|
26
|
+
|
27
|
+
name = "csgmv"
|
28
|
+
|
29
|
+
def __init__(
    self,
    max_loras_per_batch: int,
    device: torch.device,
    server_args: ServerArgs,
):
    """Create the chunked SGMV backend.

    Args:
        max_loras_per_batch: Forwarded unchanged to the base backend
            constructor.
        device: Forwarded unchanged to the base backend constructor.
        server_args: Server configuration; only ``max_lora_chunk_size``
            is read from it here.
    """
    super().__init__(max_loras_per_batch, device)
    # Upper bound on the chunk size; consulted by _determine_chunk_size.
    chunk_limit = server_args.max_lora_chunk_size
    self.max_chunk_size = chunk_limit
|
37
|
+
|
38
|
+
def run_lora_a_sgemm(
    self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
    """Apply the LoRA A weights to ``x`` with the chunked shrink kernel.

    Delegates to ``chunked_sgmv_lora_shrink_forward`` using the batch
    info currently stored on the backend and a single slice.

    Args:
        x: Input activations.
        weights: LoRA A weights.
        *args: Accepted but not used.
        **kwargs: Accepted but not used.

    Returns:
        The tensor produced by the shrink kernel.
    """
    shrunk = chunked_sgmv_lora_shrink_forward(
        x=x,
        weights=weights,
        batch_info=self.batch_info,
        num_slices=1,  # plain LoRA A: one slice
    )
    return shrunk
|
47
|
+
|
48
|
+
def run_lora_b_sgemm(
    self,
    x: torch.Tensor,
    weights: torch.Tensor,
    output_offset: torch.Tensor,
    base_output: torch.Tensor = None,
    *args,
    **kwargs
) -> torch.Tensor:
    """Apply the LoRA B weights to ``x`` with the chunked expand kernel.

    Delegates to ``chunked_sgmv_lora_expand_forward`` using the batch
    info currently stored on the backend.

    Args:
        x: Output of the LoRA A stage.
        weights: LoRA B weights; the size of the second-to-last axis is
            taken as the output dimension.
        output_offset: Slice offsets forwarded to the kernel as
            ``slice_offsets``.
        base_output: Optional tensor forwarded to the kernel as
            ``base_output``.
        *args: Accepted but not used.
        **kwargs: Accepted but not used.

    Returns:
        The tensor produced by the expand kernel.
    """
    # For a plain LoRA B the slice offsets are [0, output_dim], so the
    # largest slice spans the whole output dimension.
    return chunked_sgmv_lora_expand_forward(
        x=x,
        weights=weights,
        batch_info=self.batch_info,
        slice_offsets=output_offset,
        max_slice_size=weights.shape[-2],
        base_output=base_output,
    )
|
68
|
+
|
69
|
+
def run_qkv_lora(
    self,
    x: torch.Tensor,
    qkv_lora_a: torch.Tensor,
    qkv_lora_b: torch.Tensor,
    output_offset: torch.Tensor,
    max_qkv_out_dim: int,
    base_output: torch.Tensor = None,
    *args,
    **kwargs
) -> torch.Tensor:
    """Run the fused Q/K/V LoRA: a 3-slice shrink followed by an expand.

    Shapes, as documented by the original implementation:
        x:          (s, input_dim)
        qkv_lora_a: (num_lora, 3 * r, input_dim)
        qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)

    Args:
        x: Input activations.
        qkv_lora_a: Stacked LoRA A weights for Q, K and V.
        qkv_lora_b: Stacked LoRA B weights for Q, K and V; must be a
            single tensor.
        output_offset: Slice offsets forwarded to the expand kernel.
        max_qkv_out_dim: Forwarded to the expand kernel as
            ``max_slice_size``.
        base_output: Optional tensor forwarded to the expand kernel as
            ``base_output``.
        *args: Accepted but not used.
        **kwargs: Accepted but not used.

    Returns:
        The tensor produced by the expand kernel.
    """
    assert isinstance(qkv_lora_b, torch.Tensor)

    # Stage 1: shrink with one slice each for Q, K and V.
    shrunk = chunked_sgmv_lora_shrink_forward(
        x=x,
        weights=qkv_lora_a,
        batch_info=self.batch_info,
        num_slices=3,
    )

    # Stage 2: expand each slice into its region given by output_offset.
    return chunked_sgmv_lora_expand_forward(
        x=shrunk,
        weights=qkv_lora_b,
        batch_info=self.batch_info,
        slice_offsets=output_offset,
        max_slice_size=max_qkv_out_dim,
        base_output=base_output,
    )
|
101
|
+
|
102
|
+
def run_gate_up_lora(
    self,
    x: torch.Tensor,
    gate_up_lora_a: torch.Tensor,
    gate_up_lora_b: torch.Tensor,
    output_offset: torch.Tensor,
    base_output: torch.Tensor = None,
    *args,
    **kwargs
) -> torch.Tensor:
    """Run the fused gate/up LoRA: a 2-slice shrink followed by an expand.

    Shapes, as documented by the original implementation:
        x:              (s, input_dim)
        gate_up_lora_a: (num_lora, 2 * r, input_dim)
        gate_up_lora_b: (num_lora, 2 * output_dim, r)

    Args:
        x: Input activations.
        gate_up_lora_a: Stacked LoRA A weights for gate and up.
        gate_up_lora_b: Stacked LoRA B weights for gate and up; must be
            a single tensor.
        output_offset: Slice offsets forwarded to the expand kernel.
        base_output: Optional tensor forwarded to the expand kernel as
            ``base_output``.
        *args: Accepted but not used.
        **kwargs: Accepted but not used.

    Returns:
        The tensor produced by the expand kernel.
    """
    assert isinstance(gate_up_lora_b, torch.Tensor)

    # gate and up share the second-to-last axis equally, so one slice
    # is half of it.
    per_slice_dim = gate_up_lora_b.shape[-2] // 2

    # Stage 1: shrink -> (s, 2 * r)
    shrunk = chunked_sgmv_lora_shrink_forward(
        x=x,
        weights=gate_up_lora_a,
        batch_info=self.batch_info,
        num_slices=2,
    )

    # Stage 2: expand each slice into its region given by output_offset.
    return chunked_sgmv_lora_expand_forward(
        x=shrunk,
        weights=gate_up_lora_b,
        batch_info=self.batch_info,
        slice_offsets=output_offset,
        max_slice_size=per_slice_dim,
        base_output=base_output,
    )
|
135
|
+
|
136
|
+
def _determine_chunk_size(self, forward_batch: ForwardBatch) -> int:
    """
    Heuristically pick the chunk size from the number of tokens in a batch.

    The token count is ``extend_num_tokens`` in extend mode and
    ``batch_size`` otherwise. The preferred size grows with the token
    count (16, 32 or 128) and is capped by ``self.max_chunk_size``. When
    the configured maximum does not exceed ``MIN_CHUNK_SIZE``,
    ``MIN_CHUNK_SIZE`` is returned without inspecting the batch.

    Args:
        forward_batch (ForwardBatch): The batch whose forward mode and
            token count drive the decision.

    Returns:
        The determined chunk size
    """
    limit = self.max_chunk_size
    if limit <= MIN_CHUNK_SIZE:
        return MIN_CHUNK_SIZE

    if forward_batch.forward_mode.is_extend():
        token_count = forward_batch.extend_num_tokens
    else:
        token_count = forward_batch.batch_size

    # Thresholds, smallest first: <64 -> 16, <256 -> 32, otherwise 128.
    if token_count < 64:
        preferred = 16
    elif token_count < 256:
        preferred = 32
    else:
        preferred = 128

    return min(limit, preferred)
|
162
|
+
|
163
|
+
def prepare_lora_batch(
    self,
    forward_batch: ForwardBatch,
    weight_indices: list[int],
    lora_ranks: list[int],
    scalings: list[float],
    batch_info: Optional[LoRABatchInfo] = None,
):
    """Build (or refresh) the LoRA batch info used by the chunked kernels.

    Steps:
        1. Pick a chunk size for this batch.
        2. Compute the token permutation that groups tokens by adapter.
        3. Split the grouped tokens into segments.
        4. Stage ranks and scalings in pinned host tensors.
        5. Allocate a new ``LoRABatchInfo`` when none is supplied,
           otherwise update ``bs``, ``num_segments`` and ``max_len`` on
           the supplied one and reuse its buffers.
        6. Copy all metadata into the batch-info buffers with
           ``non_blocking=True`` and store the result on
           ``self.batch_info``.

    Args:
        forward_batch: The batch being prepared.
        weight_indices: LoRA adapter index for each sequence.
        lora_ranks: Adapter ranks, copied into the first
            ``max_loras_per_batch`` entries of ``batch_info.lora_ranks``.
        scalings: Adapter scalings, copied into the first
            ``max_loras_per_batch`` entries of ``batch_info.scalings``.
        batch_info: Existing batch info to update in place; a new one is
            created when ``None``.
    """

    def _pinned_host(values, dtype):
        # Pinned CPU staging tensor for the host-to-device copy below.
        return torch.tensor(values, dtype=dtype, pin_memory=True, device="cpu")

    def _device_buffer(length, dtype):
        # Uninitialised buffer on the backend device; filled by copy_ below.
        return torch.empty((length,), dtype=dtype, device=self.device)

    chunk_size = self._determine_chunk_size(forward_batch)

    permutation, reordered_weights = ChunkedSgmvLoRABackend._get_permutation(
        seq_weight_indices=weight_indices,
        forward_batch=forward_batch,
    )

    seg_weight_indices, seg_indptr = self._get_segments_info(
        weights_reordered=reordered_weights,
        chunk_size=chunk_size,
    )
    num_segments = len(seg_weight_indices)

    ranks_host = _pinned_host(lora_ranks, torch.int32)
    scalings_host = _pinned_host(scalings, torch.float)

    max_loras = self.max_loras_per_batch
    if batch_info is not None:
        # Reuse the caller's buffers; only the scalar metadata changes.
        batch_info.bs = forward_batch.batch_size
        batch_info.num_segments = num_segments
        batch_info.max_len = chunk_size
    else:
        batch_info = LoRABatchInfo(
            bs=forward_batch.batch_size,
            num_segments=num_segments,
            max_len=chunk_size,
            use_cuda_graph=False,
            seg_indptr=_device_buffer(num_segments + 1, torch.int32),
            weight_indices=_device_buffer(num_segments, torch.int32),
            lora_ranks=_device_buffer(max_loras, torch.int32),
            scalings=_device_buffer(max_loras, torch.float),
            permutation=_device_buffer(len(permutation), torch.int32),
            # Not used in chunked kernels
            seg_lens=None,
        )

    # Copy to device asynchronously: (destination buffer, length, source).
    transfers = (
        (batch_info.lora_ranks, self.max_loras_per_batch, ranks_host),
        (batch_info.scalings, self.max_loras_per_batch, scalings_host),
        (batch_info.weight_indices, num_segments, seg_weight_indices),
        (batch_info.seg_indptr, num_segments + 1, seg_indptr),
        (batch_info.permutation, len(permutation), permutation),
    )
    for destination, length, source in transfers:
        destination[:length].copy_(source, non_blocking=True)

    self.batch_info = batch_info
|
234
|
+
|
235
|
+
@staticmethod
def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch):
    """
    Build the token permutation that groups tokens by LoRA adapter.

    This is the "gather" step of Chunked Segmented Gather Matrix Vector
    multiplication: every token inherits the adapter index of its
    sequence, and a stable argsort over those per-token indices places
    tokens of the same adapter next to each other while keeping their
    original relative order.

    In extend mode each sequence contributes ``extend_seq_lens_cpu[i]``
    tokens; otherwise each sequence contributes exactly one token.

    Example:
        seq_weight_indices = [0, 1, 0]  # 3 sequences using adapters [0, 1, 0]
        extend_seq_lens = [2, 1, 3]     # sequence lengths [2, 1, 3 tokens]

        # Per-token adapter indices: [0, 0, 1, 0, 0, 0] (6 tokens total)
        # Returns permutation: [0, 1, 3, 4, 5, 2] (groups adapter 0 tokens together)
        #         weights_reordered: [0, 0, 0, 0, 0, 1] (sorted by adapter)

    Args:
        seq_weight_indices: List of LoRA adapter indices for each sequence
        forward_batch (ForwardBatch): Batch information containing sequence lengths

    Returns:
        tuple: (permutation, weights_reordered) where:
            - permutation: Token reordering indices to group by adapter
            - weights_reordered: Sorted adapter indices for each token
    """
    # All tensors here are created on the CPU.
    with torch.device("cpu"):
        adapter_per_seq = torch.tensor(seq_weight_indices, dtype=torch.int32)

        if forward_batch.forward_mode.is_extend():
            tokens_per_seq = torch.tensor(
                forward_batch.extend_seq_lens_cpu,
                dtype=torch.int32,
            )
        else:
            tokens_per_seq = torch.ones(forward_batch.batch_size, dtype=torch.int32)

        # Expand the per-sequence adapter index to one entry per token.
        adapter_per_token = torch.repeat_interleave(adapter_per_seq, tokens_per_seq)

        # The argsort result is written straight into a pinned buffer.
        num_tokens = len(adapter_per_token)
        permutation = torch.empty((num_tokens,), dtype=torch.long, pin_memory=True)
        torch.argsort(adapter_per_token, stable=True, out=permutation)

        weights_reordered = adapter_per_token[permutation]

        return permutation, weights_reordered
|
284
|
+
|
285
|
+
def _get_segments_info(self, weights_reordered: torch.Tensor, chunk_size: int):
|
286
|
+
"""
|
287
|
+
Computes segment information for chunked SGMV operations.
|
288
|
+
|
289
|
+
This function takes the reordered weight indices and creates segments of fixed size
|
290
|
+
(self.segment_size) for efficient kernel execution. Each segment contains tokens
|
291
|
+
that use the same LoRA adapter, enabling vectorized computation.
|
292
|
+
|
293
|
+
The segmentation is necessary because:
|
294
|
+
1. GPU kernels work efficiently on fixed-size blocks
|
295
|
+
2. Large groups of tokens using the same adapter are split into manageable chunks
|
296
|
+
3. Each segment can be processed independently in parallel
|
297
|
+
|
298
|
+
Example:
|
299
|
+
weights_reordered = [0, 0, 0, 0, 0, 1] # 5 tokens with adapter 0, 1 with adapter 1
|
300
|
+
segment_size = 3
|
301
|
+
|
302
|
+
# Creates segments:
|
303
|
+
# Segment 0: tokens 0-2 (adapter 0), length=3
|
304
|
+
# Segment 1: tokens 3-4 (adapter 0), length=2
|
305
|
+
# Segment 2: token 5 (adapter 1), length=1
|
306
|
+
|
307
|
+
# Returns:
|
308
|
+
# weight_indices_list: [0, 0, 1] (adapter for each segment)
|
309
|
+
# seg_indptr: [0, 3, 5, 6] (cumulative segment boundaries)
|
310
|
+
|
311
|
+
Args:
|
312
|
+
weights_reordered (torch.Tensor): Sorted adapter indices for each token
|
313
|
+
chunk_size (int): Fixed size for each segment
|
314
|
+
|
315
|
+
Returns:
|
316
|
+
tuple: (weight_indices_list, seg_indptr) where:
|
317
|
+
- weight_indices_list: LoRA adapter index for each segment
|
318
|
+
- seg_indptr: Cumulative segment boundaries (CSR-style indptr)
|
319
|
+
"""
|
320
|
+
with torch.device("cpu"):
|
321
|
+
unique_weights, counts = torch.unique_consecutive(
|
322
|
+
weights_reordered, return_counts=True
|
323
|
+
)
|
324
|
+
|
325
|
+
weight_indices_list = []
|
326
|
+
seg_lens_list = []
|
327
|
+
|
328
|
+
for weight_idx, group_len in zip(unique_weights, counts):
|
329
|
+
group_len = group_len.item()
|
330
|
+
num_segs = (group_len + chunk_size - 1) // chunk_size
|
331
|
+
|
332
|
+
weight_indices_list.extend([weight_idx.item()] * num_segs)
|
333
|
+
seg_lens_list.extend([chunk_size] * (num_segs - 1))
|
334
|
+
seg_lens_list.append(group_len - (num_segs - 1) * chunk_size)
|
335
|
+
|
336
|
+
seg_lens = torch.tensor(seg_lens_list, dtype=torch.int32)
|
337
|
+
|
338
|
+
weight_indices_list = torch.tensor(
|
339
|
+
weight_indices_list, dtype=torch.int32, pin_memory=True
|
340
|
+
)
|
341
|
+
|
342
|
+
seg_indptr = torch.empty(
|
343
|
+
(len(seg_lens) + 1,), dtype=torch.int32, pin_memory=True
|
344
|
+
)
|
345
|
+
seg_indptr[0] = 0
|
346
|
+
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
|
347
|
+
|
348
|
+
return weight_indices_list, seg_indptr
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
1
3
|
import torch
|
2
4
|
|
3
5
|
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
@@ -8,12 +10,20 @@ from sglang.srt.lora.triton_ops import (
|
|
8
10
|
sgemm_lora_b_fwd,
|
9
11
|
)
|
10
12
|
from sglang.srt.lora.utils import LoRABatchInfo
|
13
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
14
|
+
from sglang.srt.server_args import ServerArgs
|
11
15
|
|
12
16
|
|
13
17
|
class TritonLoRABackend(BaseLoRABackend):
|
18
|
+
name = "triton"
|
14
19
|
|
15
|
-
def __init__(
|
16
|
-
|
20
|
+
def __init__(
|
21
|
+
self,
|
22
|
+
max_loras_per_batch: int,
|
23
|
+
device: torch.device,
|
24
|
+
**kwargs,
|
25
|
+
):
|
26
|
+
super().__init__(max_loras_per_batch, device)
|
17
27
|
|
18
28
|
def run_lora_a_sgemm(
|
19
29
|
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
|
@@ -26,7 +36,7 @@ class TritonLoRABackend(BaseLoRABackend):
|
|
26
36
|
weights: torch.Tensor,
|
27
37
|
base_output: torch.Tensor = None,
|
28
38
|
*args,
|
29
|
-
**kwargs
|
39
|
+
**kwargs,
|
30
40
|
) -> torch.Tensor:
|
31
41
|
return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output)
|
32
42
|
|
@@ -39,7 +49,7 @@ class TritonLoRABackend(BaseLoRABackend):
|
|
39
49
|
max_qkv_out_dim: int,
|
40
50
|
base_output: torch.Tensor = None,
|
41
51
|
*args,
|
42
|
-
**kwargs
|
52
|
+
**kwargs,
|
43
53
|
) -> torch.Tensor:
|
44
54
|
|
45
55
|
# x: (s, input_dim)
|
@@ -65,7 +75,7 @@ class TritonLoRABackend(BaseLoRABackend):
|
|
65
75
|
gate_up_lora_b: torch.Tensor,
|
66
76
|
base_output: torch.Tensor = None,
|
67
77
|
*args,
|
68
|
-
**kwargs
|
78
|
+
**kwargs,
|
69
79
|
) -> torch.Tensor:
|
70
80
|
|
71
81
|
# x: (s, input_dim)
|
@@ -86,3 +96,87 @@ class TritonLoRABackend(BaseLoRABackend):
|
|
86
96
|
base_output,
|
87
97
|
)
|
88
98
|
return lora_output
|
99
|
+
|
100
|
+
def init_cuda_graph_batch_info(
|
101
|
+
self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
|
102
|
+
):
|
103
|
+
# Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
|
104
|
+
# across batches.
|
105
|
+
cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
|
106
|
+
torch.cumsum(
|
107
|
+
cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
|
108
|
+
dim=0,
|
109
|
+
out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
|
110
|
+
)
|
111
|
+
|
112
|
+
def prepare_lora_batch(
|
113
|
+
self,
|
114
|
+
forward_batch: ForwardBatch,
|
115
|
+
weight_indices: list[int],
|
116
|
+
lora_ranks: list[int],
|
117
|
+
scalings: list[float],
|
118
|
+
batch_info: Optional[LoRABatchInfo] = None,
|
119
|
+
):
|
120
|
+
# Use pinned memory to avoid synchronizations during host-to-device transfer
|
121
|
+
weight_indices_tensor = torch.tensor(
|
122
|
+
weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
|
123
|
+
)
|
124
|
+
lora_ranks_tensor = torch.tensor(
|
125
|
+
lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
|
126
|
+
)
|
127
|
+
scalings_tensor = torch.tensor(
|
128
|
+
scalings, dtype=torch.float, pin_memory=True, device="cpu"
|
129
|
+
)
|
130
|
+
|
131
|
+
bs = forward_batch.batch_size
|
132
|
+
|
133
|
+
if batch_info is not None:
|
134
|
+
assert (
|
135
|
+
batch_info.use_cuda_graph
|
136
|
+
), "batch_info.use_cuda_graph must be True when batch_info is provided"
|
137
|
+
batch_info.bs = forward_batch.batch_size
|
138
|
+
batch_info.num_segments = forward_batch.batch_size
|
139
|
+
else:
|
140
|
+
max_len = (
|
141
|
+
# Calculate max_len from the CPU copy to avoid D2H transfer.
|
142
|
+
max(forward_batch.extend_seq_lens_cpu)
|
143
|
+
if forward_batch.forward_mode.is_extend()
|
144
|
+
else 1
|
145
|
+
)
|
146
|
+
seg_lens = (
|
147
|
+
forward_batch.extend_seq_lens
|
148
|
+
if forward_batch.forward_mode.is_extend()
|
149
|
+
else torch.ones(bs, device=self.device)
|
150
|
+
)
|
151
|
+
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
|
152
|
+
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
|
153
|
+
|
154
|
+
batch_info = LoRABatchInfo(
|
155
|
+
bs=forward_batch.batch_size,
|
156
|
+
num_segments=forward_batch.batch_size,
|
157
|
+
max_len=max_len,
|
158
|
+
use_cuda_graph=False,
|
159
|
+
seg_lens=seg_lens,
|
160
|
+
seg_indptr=seg_indptr,
|
161
|
+
weight_indices=torch.empty(
|
162
|
+
(bs,), dtype=torch.int32, device=self.device
|
163
|
+
),
|
164
|
+
lora_ranks=torch.empty(
|
165
|
+
(self.max_loras_per_batch,), dtype=torch.int64, device=self.device
|
166
|
+
),
|
167
|
+
scalings=torch.empty(
|
168
|
+
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
|
169
|
+
),
|
170
|
+
permutation=None,
|
171
|
+
)
|
172
|
+
|
173
|
+
# Copy to device asynchronously
|
174
|
+
batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
|
175
|
+
lora_ranks_tensor, non_blocking=True
|
176
|
+
)
|
177
|
+
batch_info.scalings[: self.max_loras_per_batch].copy_(
|
178
|
+
scalings_tensor, non_blocking=True
|
179
|
+
)
|
180
|
+
batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
|
181
|
+
|
182
|
+
self.batch_info = batch_info
|
sglang/srt/lora/layers.py
CHANGED
@@ -66,6 +66,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|
66
66
|
lora_backend: BaseLoRABackend,
|
67
67
|
) -> None:
|
68
68
|
super().__init__(base_layer, lora_backend)
|
69
|
+
shard_size = self.base_layer.output_partition_sizes[0]
|
70
|
+
self.output_offset = torch.tensor(
|
71
|
+
[
|
72
|
+
0,
|
73
|
+
shard_size,
|
74
|
+
],
|
75
|
+
dtype=torch.int32,
|
76
|
+
device=next(self.base_layer.parameters()).device,
|
77
|
+
)
|
69
78
|
|
70
79
|
def set_lora_info(
|
71
80
|
self,
|
@@ -81,6 +90,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|
81
90
|
lora_output = self.lora_backend.run_lora_b_sgemm(
|
82
91
|
x=lora_a_output,
|
83
92
|
weights=self.B_buffer,
|
93
|
+
output_offset=self.output_offset,
|
84
94
|
base_output=base_output,
|
85
95
|
)
|
86
96
|
return lora_output
|
@@ -130,11 +140,23 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|
130
140
|
self.A_buffer_gate_up = A_buffer
|
131
141
|
self.B_buffer_gate_up = B_buffer
|
132
142
|
|
143
|
+
shard_size = self.base_layer.output_partition_sizes[0]
|
144
|
+
self.output_offset = torch.tensor(
|
145
|
+
[
|
146
|
+
0,
|
147
|
+
shard_size,
|
148
|
+
2 * shard_size,
|
149
|
+
],
|
150
|
+
dtype=torch.int32,
|
151
|
+
device=next(self.base_layer.parameters()).device,
|
152
|
+
)
|
153
|
+
|
133
154
|
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
134
155
|
lora_output = self.lora_backend.run_gate_up_lora(
|
135
156
|
x=x,
|
136
157
|
gate_up_lora_a=self.A_buffer_gate_up,
|
137
158
|
gate_up_lora_b=self.B_buffer_gate_up,
|
159
|
+
output_offset=self.output_offset,
|
138
160
|
base_output=base_output,
|
139
161
|
)
|
140
162
|
return lora_output
|
@@ -243,12 +265,22 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|
243
265
|
self.set_lora = True
|
244
266
|
self.A_buffer = A_buffer
|
245
267
|
self.B_buffer = B_buffer
|
268
|
+
output_size = self.base_layer.output_size
|
269
|
+
self.output_offset = torch.tensor(
|
270
|
+
[
|
271
|
+
0,
|
272
|
+
output_size,
|
273
|
+
],
|
274
|
+
dtype=torch.int32,
|
275
|
+
device=next(self.base_layer.parameters()).device,
|
276
|
+
)
|
246
277
|
|
247
278
|
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
248
279
|
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
|
249
280
|
lora_output = self.lora_backend.run_lora_b_sgemm(
|
250
281
|
x=lora_a_output,
|
251
282
|
weights=self.B_buffer,
|
283
|
+
output_offset=self.output_offset,
|
252
284
|
base_output=base_output,
|
253
285
|
)
|
254
286
|
return lora_output
|
sglang/srt/lora/lora.py
CHANGED
@@ -26,13 +26,17 @@ import torch
|
|
26
26
|
from torch import nn
|
27
27
|
|
28
28
|
from sglang.srt.configs.load_config import LoadConfig
|
29
|
-
from sglang.srt.hf_transformers_utils import AutoConfig
|
30
29
|
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
30
|
+
from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
|
31
|
+
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
|
31
32
|
from sglang.srt.lora.lora_config import LoRAConfig
|
32
33
|
from sglang.srt.model_loader.loader import DefaultModelLoader
|
34
|
+
from sglang.srt.utils.hf_transformers_utils import AutoConfig
|
33
35
|
|
34
36
|
logger = logging.getLogger(__name__)
|
35
37
|
|
38
|
+
SUPPORTED_BACKENDS = (TritonLoRABackend, ChunkedSgmvLoRABackend)
|
39
|
+
|
36
40
|
|
37
41
|
class LoRALayer(nn.Module):
|
38
42
|
def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
|
@@ -45,6 +49,7 @@ class LoRALayer(nn.Module):
|
|
45
49
|
|
46
50
|
|
47
51
|
class LoRAAdapter(nn.Module):
|
52
|
+
|
48
53
|
def __init__(
|
49
54
|
self,
|
50
55
|
uid: str,
|
@@ -156,8 +161,8 @@ class LoRAAdapter(nn.Module):
|
|
156
161
|
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
157
162
|
if up_name not in weights:
|
158
163
|
weights[up_name] = torch.zeros_like(weights[weight_name])
|
159
|
-
assert self.lora_backend
|
160
|
-
f"LoRA weight initialization currently only supported for '
|
164
|
+
assert isinstance(self.lora_backend, SUPPORTED_BACKENDS), (
|
165
|
+
f"LoRA weight initialization currently only supported for LoRA backends: {', '.join(b.name for b in SUPPORTED_BACKENDS)}"
|
161
166
|
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
|
162
167
|
f"or consider implementing custom initialization logic for other backends."
|
163
168
|
)
|