sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py
CHANGED
@@ -21,7 +21,6 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple
 import torch
 
 from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
@@ -35,9 +34,11 @@ from sglang.srt.lora.utils import (
     get_normalized_target_modules,
     get_target_module_name,
 )
-from sglang.srt.managers.io_struct import
+from sglang.srt.managers.io_struct import LoRAUpdateOutput
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import replace_submodule
+from sglang.srt.utils.hf_transformers_utils import AutoConfig
 
 logger = logging.getLogger(__name__)
 
@@ -56,6 +57,7 @@ class LoRAManager:
         max_lora_rank: Optional[int] = None,
         target_modules: Optional[Iterable[str]] = None,
         lora_paths: Optional[List[LoRARef]] = None,
+        server_args: Optional[ServerArgs] = None,
     ):
         self.base_model: torch.nn.Module = base_model
         self.base_hf_config: AutoConfig = base_hf_config
@@ -69,7 +71,11 @@ class LoRAManager:
         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
         backend_type = get_backend_from_name(lora_backend)
-        self.lora_backend: BaseLoRABackend = backend_type(
+        self.lora_backend: BaseLoRABackend = backend_type(
+            max_loras_per_batch=max_loras_per_batch,
+            device=self.device,
+            server_args=server_args,
+        )
 
         # Initialize mutable internal state of the LoRAManager.
         self.init_state(
@@ -82,34 +88,27 @@ class LoRAManager:
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
         with torch.device("cuda"):
             self.cuda_graph_batch_info = LoRABatchInfo(
-                bs=
-
-
-
-                ),
+                bs=max_bs_in_cuda_graph,
+                use_cuda_graph=True,
+                num_segments=None,
+                seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+                seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32),
                 max_len=1,
-                weight_indices=torch.zeros(
-
-                ),
+                weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+                permutation=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
                 lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
                 scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
             )
 
-
-
-
-
-                self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
-                dim=0,
-                out=self.cuda_graph_batch_info.seg_indptr[
-                    1 : self.max_bs_in_cuda_graph + 1
-                ],
-            )
+            self.lora_backend.init_cuda_graph_batch_info(
+                cuda_graph_batch_info=self.cuda_graph_batch_info,
+                max_bs_in_cuda_graph=max_bs_in_cuda_graph,
+            )
 
     def create_lora_update_result(
         self, success: bool, error_message: str = ""
-    ) ->
-        return
+    ) -> LoRAUpdateOutput:
+        return LoRAUpdateOutput(
             success=success,
             error_message=error_message,
             loaded_adapters={
@@ -118,7 +117,7 @@ class LoRAManager:
             },
         )
 
-    def load_lora_adapter(self, lora_ref: LoRARef) ->
+    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
         """
         Load a single LoRA adapter from the specified path.
 
@@ -175,7 +174,7 @@ class LoRAManager:
                 "`--max-loras-per-batch` or load it as unpinned LoRA adapters."
             )
 
-    def unload_lora_adapter(self, lora_ref: LoRARef) ->
+    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
         delete the corresponding LoRA modules.
@@ -232,7 +231,6 @@ class LoRAManager:
         return required_slots <= mem_pool_vacancy
 
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
-
         # Load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_ids)
 
@@ -247,102 +245,30 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
-
-            weight_indices_out: torch.Tensor,
-            lora_ranks_out: torch.Tensor,
-            scalings_out: torch.Tensor,
-        ):
-            """
-            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
-            to device (CUDA) asynchronously.
-            """
-            weight_indices = [0] * len(forward_batch.lora_ids)
-            lora_ranks = [0] * self.max_loras_per_batch
-            scalings = [0] * self.max_loras_per_batch
-            for i, uid in enumerate(forward_batch.lora_ids):
-                weight_indices[i] = self.memory_pool.get_buffer_id(uid)
-                if uid is not None:
-                    lora = self.loras[uid]
-                    lora_ranks[weight_indices[i]] = lora.config.r
-                    scalings[weight_indices[i]] = lora.scaling
-
-            # Use pinned memory to avoid synchronizations during host-to-device transfer
-            weight_indices_tensor = torch.tensor(
-                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
-            )
-            lora_ranks_tensor = torch.tensor(
-                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
-            )
-            scalings_tensor = torch.tensor(
-                scalings, dtype=torch.float, pin_memory=True, device="cpu"
-            )
-
-            # Copy to device tensors asynchronously
-            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
-            lora_ranks_out[: self.max_loras_per_batch].copy_(
-                lora_ranks_tensor, non_blocking=True
-            )
-            scalings_out[: self.max_loras_per_batch].copy_(
-                scalings_tensor, non_blocking=True
-            )
-
-        if (
+        use_cuda_graph = (
             hasattr(self, "max_bs_in_cuda_graph")
             and bs <= self.max_bs_in_cuda_graph
             and forward_batch.forward_mode.is_cuda_graph()
-        )
-            # Do in-place updates when CUDA graph is enabled and the batch forward mode
-            # could use CUDA graph.
-
-            transfer_adapter_info(
-                self.cuda_graph_batch_info.weight_indices,
-                self.cuda_graph_batch_info.lora_ranks,
-                self.cuda_graph_batch_info.scalings,
-            )
-
-            self.cuda_graph_batch_info.bs = bs
-            self.cuda_graph_batch_info.max_len = 1
-            batch_info = self.cuda_graph_batch_info
-        else:
-            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
-            lora_ranks = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
-            )
-            scalings = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
-            )
-            transfer_adapter_info(
-                weight_indices,
-                lora_ranks,
-                scalings,
-            )
-
-            seg_lens = (
-                forward_batch.extend_seq_lens
-                if forward_batch.forward_mode.is_extend()
-                else torch.ones(bs, device=self.device)
-            )
-
-            max_len = (
-                # Calculate max_len from the CPU copy to avoid D2H transfer.
-                max(forward_batch.extend_seq_lens_cpu)
-                if forward_batch.forward_mode.is_extend()
-                else 1
-            )
+        )
 
-
-
-
-
-
-
-
-
-            weight_indices=
-
-
-
-
+        weight_indices = [0] * len(forward_batch.lora_ids)
+        lora_ranks = [0] * self.max_loras_per_batch
+        scalings = [0] * self.max_loras_per_batch
+        for i, uid in enumerate(forward_batch.lora_ids):
+            weight_indices[i] = self.memory_pool.get_buffer_id(uid)
+            if uid is not None:
+                lora = self.loras[uid]
+                lora_ranks[weight_indices[i]] = lora.config.r
+                scalings[weight_indices[i]] = lora.scaling
+        # Do in-place updates when CUDA graph is enabled and the batch forward mode
+        # could use CUDA graph.
+        self.lora_backend.prepare_lora_batch(
+            forward_batch=forward_batch,
+            weight_indices=weight_indices,
+            lora_ranks=lora_ranks,
+            scalings=scalings,
+            batch_info=self.cuda_graph_batch_info if use_cuda_graph else None,
+        )
 
     def update_lora_info(self):
         """
@@ -492,6 +418,10 @@ class LoRAManager:
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
+    def should_skip_lora_for_vision_model(self, module_name):
+        # TODO: support different vision models
+        return module_name.find("vision_model.model") != -1
+
     def init_lora_modules(self):
         # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
         self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
@@ -509,6 +439,10 @@ class LoRAManager:
             ) and not self.base_model.should_apply_lora(module_name):
                 continue
 
+            # Skip vision model
+            if self.should_skip_lora_for_vision_model(module_name):
+                continue
+
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in self.target_modules:
                 layer_id = get_layer_id(module_name)
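The hunks above move per-batch LoRA metadata handling out of LoRAManager and into the backend object: the backend is now constructed with max_loras_per_batch, device, and server_args, and LoRAManager delegates to init_cuda_graph_batch_info and prepare_lora_batch. Below is a minimal sketch of the hook surface implied by these call sites; the class and method bodies are illustrative assumptions based only on this diff, not the actual sglang backend code.

import torch


class SketchLoRABackend:
    """Illustrative stand-in for a BaseLoRABackend subclass under the new interface."""

    def __init__(self, max_loras_per_batch, device, server_args=None):
        self.max_loras_per_batch = max_loras_per_batch
        self.device = device
        self.server_args = server_args

    def init_cuda_graph_batch_info(self, cuda_graph_batch_info, max_bs_in_cuda_graph):
        # One-time setup of tensors captured by CUDA graphs, e.g. the running
        # seg_indptr that LoRAManager previously computed itself.
        torch.cumsum(
            cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
            dim=0,
            out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
        )

    def prepare_lora_batch(
        self, forward_batch, weight_indices, lora_ranks, scalings, batch_info=None
    ):
        # Move the host-side adapter metadata to the device; when batch_info is the
        # shared CUDA-graph LoRABatchInfo, update its buffers in place.
        w = torch.tensor(weight_indices, dtype=torch.int32, device=self.device)
        r = torch.tensor(lora_ranks, dtype=torch.int32, device=self.device)
        s = torch.tensor(scalings, dtype=torch.float, device=self.device)
        if batch_info is not None:
            batch_info.weight_indices[: w.numel()].copy_(w)
            batch_info.lora_ranks[: r.numel()].copy_(r)
            batch_info.scalings[: s.numel()].copy_(s)
            batch_info.bs = forward_batch.batch_size
        # (The non-graph path would build a fresh LoRABatchInfo here.)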
sglang/srt/lora/mem_pool.py
CHANGED
@@ -4,7 +4,6 @@ from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
 import torch
 
 from sglang.srt.distributed import divide
-from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
@@ -17,6 +16,7 @@ from sglang.srt.lora.utils import (
     get_stacked_multiply,
     get_target_module_name,
 )
+from sglang.srt.utils.hf_transformers_utils import AutoConfig
 
 logger = logging.getLogger(__name__)
 
@@ -104,12 +104,18 @@ class LoRAMemoryPool:
         return all(_can_support(x) for x in config)
 
     def get_lora_A_shape(
-        self,
+        self,
+        module_name: str,
+        base_model: torch.nn.Module,
+        max_lora_dim: int,
+        layer_idx: int,
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
-        input_dim, _ = get_hidden_dim(
+        input_dim, _ = get_hidden_dim(
+            module_name, self.base_hf_config, base_model, layer_idx
+        )
         c = get_stacked_multiply(module_name)
         if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             input_dim = divide(input_dim, self.tp_size)
@@ -120,12 +126,18 @@ class LoRAMemoryPool:
         )
 
     def get_lora_B_shape(
-        self,
+        self,
+        module_name: str,
+        base_model: torch.nn.Module,
+        max_lora_dim: int,
+        layer_idx: int,
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
-        _, output_dim = get_hidden_dim(
+        _, output_dim = get_hidden_dim(
+            module_name, self.base_hf_config, base_model, layer_idx
+        )
         if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             output_dim = divide(output_dim, self.tp_size)
         return (
@@ -140,19 +152,21 @@ class LoRAMemoryPool:
         def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
             target_modules: Set[str],
-            get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
+            get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]],
         ):
            for module_name in target_modules:
-                lora_shape = get_lora_shape_fn(
-                    module_name, base_model, self.max_lora_rank
-                )
                buffer[module_name] = [
                    torch.empty(
-
+                        get_lora_shape_fn(
+                            module_name,
+                            base_model,
+                            self.max_lora_rank,
+                            idx,
+                        ),
                        dtype=self.dtype,
                        device=device,
                    )
-                    for
+                    for idx in range(self.num_layer)
                ]
 
        init_buffer(
sglang/srt/lora/triton_ops/__init__.py
CHANGED
@@ -1,3 +1,5 @@
+from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward
+from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward
 from .gate_up_lora_b import gate_up_lora_b_fwd
 from .qkv_lora_b import qkv_lora_b_fwd
 from .sgemm_lora_a import sgemm_lora_a_fwd
@@ -8,4 +10,6 @@ __all__ = [
     "qkv_lora_b_fwd",
     "sgemm_lora_a_fwd",
     "sgemm_lora_b_fwd",
+    "chunked_sgmv_lora_shrink_forward",
+    "chunked_sgmv_lora_expand_forward",
 ]
sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
ADDED
@@ -0,0 +1,214 @@
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.utils import cached_triton_kernel
+
+
+@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
+@triton.jit
+def _chunked_lora_expand_kernel(
+    # Pointers to matrices
+    x,
+    weights,
+    output,
+    # Information on sequence lengths and weight id
+    seg_indptr,
+    weight_indices,
+    lora_ranks,
+    permutation,
+    num_segs,
+    # For fused output scaling
+    scalings,
+    # Offsets of q/k/v slice on output dimension
+    slice_offsets,
+    # Meta parameters
+    NUM_SLICES: tl.constexpr,
+    OUTPUT_DIM: tl.constexpr,
+    MAX_RANK: tl.constexpr,  # K = R
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    """
+    Computes a chunked SGMV for LoRA expand operations.
+
+    When a sequence's rank is 0, the kernel is essentially a no-op, following
+    the convention in pytorch where the product of two matrices of shape (m, 0)
+    and (0, n) is an all-zero matrix of shape (m, n).
+
+    Args:
+        x (Tensor): The input tensor, which is the result of the LoRA A projection.
+            Shape: (s, num_slices * K), where s is the sum of all sequence lengths in the
+            batch and K is the maximum LoRA rank.
+        weights (Tensor): The LoRA B weights for all adapters.
+            Shape: (num_lora, output_dim, K).
+        output (Tensor): The output tensor where the result is stored.
+            Shape: (s, output_dim).
+    """
+    tl.static_assert(NUM_SLICES <= 3)
+
+    x_stride_0: tl.constexpr = NUM_SLICES * MAX_RANK
+    x_stride_1: tl.constexpr = 1
+
+    w_stride_0: tl.constexpr = OUTPUT_DIM * MAX_RANK
+    w_stride_1: tl.constexpr = MAX_RANK
+    w_stride_2: tl.constexpr = 1
+
+    output_stride_0: tl.constexpr = OUTPUT_DIM
+    output_stride_1: tl.constexpr = 1
+
+    pid_s = tl.program_id(axis=2)
+    if pid_s >= num_segs:
+        return
+
+    # Current block computes sequence with batch_id,
+    # which starts from row seg_start of x with length seg_len.
+    # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
+    w_index = tl.load(weight_indices + pid_s)
+    cur_rank = tl.load(lora_ranks + w_index)
+
+    # If rank is 0, this kernel is a no-op.
+    if cur_rank == 0:
+        return
+
+    seg_start = tl.load(seg_indptr + pid_s)
+    seg_end = tl.load(seg_indptr + pid_s + 1)
+
+    slice_id = tl.program_id(axis=1)
+    slice_start = tl.load(slice_offsets + slice_id)
+    slice_end = tl.load(slice_offsets + slice_id + 1)
+
+    scaling = tl.load(scalings + w_index)
+    # Adjust K (rank) according to the specific LoRA adapter
+    cur_rank = tl.minimum(MAX_RANK, cur_rank)
+
+    # Map logical sequence index to physical index
+    s_offset_logical = tl.arange(0, BLOCK_M) + seg_start
+    s_offset_physical = tl.load(
+        permutation + s_offset_logical, mask=s_offset_logical < seg_end
+    )
+
+    # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:]
+    # The pointers will be advanced as we move in the K direction
+    # and accumulate
+    pid_n = tl.program_id(axis=0)
+    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_start
+    k_offset = tl.arange(0, BLOCK_K)
+
+    x_ptrs = (
+        x
+        + slice_id * cur_rank * x_stride_1
+        + (s_offset_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1)
+    )
+    w_ptrs = (weights + w_index * w_stride_0) + (
+        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+    )
+
+    # Iterate to compute the block in output matrix
+    partial_sum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(cur_rank, BLOCK_K)):
+        x_tile = tl.load(
+            x_ptrs,
+            mask=(s_offset_logical[:, None] < seg_end)
+            & (k_offset[None, :] < cur_rank - k * BLOCK_K),
+            other=0.0,
+        )
+        w_tile = tl.load(
+            w_ptrs,
+            mask=(k_offset[:, None] < cur_rank - k * BLOCK_K)
+            & (n_offset[None, :] < slice_end),
+            other=0.0,
+        )
+        partial_sum += tl.dot(x_tile, w_tile)
+
+        x_ptrs += BLOCK_K * x_stride_1
+        w_ptrs += BLOCK_K * w_stride_2
+
+    # Store result to output matrix
+    partial_sum *= scaling
+    partial_sum = partial_sum.to(x.dtype.element_ty)
+    output_ptr = output + (
+        s_offset_physical[:, None] * output_stride_0
+        + n_offset[None, :] * output_stride_1
+    )
+    output_mask = (s_offset_logical[:, None] < seg_end) & (
+        n_offset[None, :] < slice_end
+    )
+    partial_sum += tl.load(output_ptr, mask=output_mask, other=0.0)
+    tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+def chunked_sgmv_lora_expand_forward(
+    x: torch.Tensor,
+    weights: torch.Tensor,
+    batch_info: LoRABatchInfo,
+    slice_offsets: torch.Tensor,
+    max_slice_size: int,
+    base_output: Optional[torch.Tensor],
+) -> torch.Tensor:
+
+    # x: (s, slice_num * r)
+    # weights: (num_lora, output_dim, r)
+    # slice_offsets: boundaries for different slices in the output dimension
+    # output: (s, output_dim)
+
+    # Compute lora_output with shape (s, output_dim) as follows:
+    # For each slice i, accumulates:
+    # lora_output[:, slice_offsets[i]:slice_offsets[i+1]] += scaling * sgemm(x[:, i*cur_rank:(i+1)*cur_rank], weights[:, slice_offsets[i]:slice_offsets[i+1], :])
+
+    assert x.is_contiguous()
+    assert weights.is_contiguous()
+    assert len(x.shape) == 2
+    assert len(weights.shape) == 3
+
+    # Get dims
+    M = x.shape[0]
+    input_dim = x.shape[1]
+    OUTPUT_DIM = weights.shape[1]
+    MAX_RANK = weights.shape[2]
+    num_slices = len(slice_offsets) - 1
+    assert input_dim == num_slices * MAX_RANK
+
+    # TODO (lifuhuang): fine-tune per operation
+    BLOCK_M = batch_info.max_len
+    BLOCK_K = 16
+    BLOCK_N = 64
+
+    num_segments = batch_info.num_segments
+
+    grid = (
+        triton.cdiv(max_slice_size, BLOCK_N),
+        num_slices,  # number of slices in the input/output
+        batch_info.bs if batch_info.use_cuda_graph else num_segments,
+    )
+
+    if base_output is None:
+        output = torch.zeros((M, OUTPUT_DIM), device=x.device, dtype=x.dtype)
+    else:
+        output = base_output
+
+    _chunked_lora_expand_kernel[grid](
+        x=x,
+        weights=weights,
+        output=output,
+        seg_indptr=batch_info.seg_indptr,
+        weight_indices=batch_info.weight_indices,
+        lora_ranks=batch_info.lora_ranks,
+        permutation=batch_info.permutation,
+        num_segs=num_segments,
+        scalings=batch_info.scalings,
+        slice_offsets=slice_offsets,
+        # constants
+        NUM_SLICES=num_slices,
+        OUTPUT_DIM=OUTPUT_DIM,
+        MAX_RANK=MAX_RANK,
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=BLOCK_N,
+        BLOCK_K=BLOCK_K,
+    )
+
+    return output
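As a reading aid for the kernel above, the following is a rough pure-PyTorch restatement of the semantics spelled out in the comments of chunked_sgmv_lora_expand_forward: per segment and per slice, output[rows, n0:n1] += scaling * x[rows, i*rank:(i+1)*rank] @ weights[adapter, n0:n1, :rank].T. It is a sketch only; the permutation (logical-to-physical row mapping) and CUDA-graph paths are ignored, and the helper name is made up for illustration.

from typing import Optional

import torch


def chunked_sgmv_expand_reference(
    x: torch.Tensor,               # (s, num_slices * max_rank): LoRA A projection output
    weights: torch.Tensor,         # (num_lora, output_dim, max_rank): LoRA B weights
    seg_indptr: torch.Tensor,      # (num_segments + 1,): row boundaries of each segment
    weight_indices: torch.Tensor,  # (num_segments,): adapter index per segment
    lora_ranks: torch.Tensor,      # (num_lora,): per-adapter rank (0 means no-op)
    scalings: torch.Tensor,        # (num_lora,): per-adapter LoRA scaling
    slice_offsets: torch.Tensor,   # (num_slices + 1,): output-dim boundaries (e.g. q/k/v)
    base_output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    s, output_dim = x.shape[0], weights.shape[1]
    num_slices = len(slice_offsets) - 1
    out = (
        torch.zeros(s, output_dim, dtype=x.dtype, device=x.device)
        if base_output is None
        else base_output
    )
    for seg in range(len(seg_indptr) - 1):
        rows = slice(int(seg_indptr[seg]), int(seg_indptr[seg + 1]))
        adapter = int(weight_indices[seg])
        rank = int(lora_ranks[adapter])
        if rank == 0:
            # Rank-0 adapter: (m, 0) @ (0, n) contributes nothing.
            continue
        scale = float(scalings[adapter])
        for i in range(num_slices):
            n0, n1 = int(slice_offsets[i]), int(slice_offsets[i + 1])
            x_slice = x[rows, i * rank : (i + 1) * rank]   # (m, rank)
            w_slice = weights[adapter, n0:n1, :rank]       # (n1 - n0, rank)
            out[rows, n0:n1] += scale * (x_slice @ w_slice.transpose(0, 1))
    return out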