sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,214 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
import torch
|
4
|
+
import triton
|
5
|
+
import triton.language as tl
|
6
|
+
|
7
|
+
from sglang.srt.lora.utils import LoRABatchInfo
|
8
|
+
from sglang.srt.utils import cached_triton_kernel
|
9
|
+
|
10
|
+
|
11
|
+
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
@triton.jit
def _chunked_lora_expand_kernel(
    # Tensor pointers
    x,
    weights,
    output,
    # Per-segment metadata (row ranges, adapter ids, adapter ranks, row order)
    seg_indptr,
    weight_indices,
    lora_ranks,
    permutation,
    num_segs,
    # Per-adapter scaling factor, fused into the output
    scalings,
    # Column boundaries of each slice along the output dimension
    slice_offsets,
    # Compile-time parameters
    NUM_SLICES: tl.constexpr,
    OUTPUT_DIM: tl.constexpr,
    MAX_RANK: tl.constexpr,  # K = R
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Chunked SGMV kernel for the LoRA expand (B projection) step.

    Launch grid: axis 0 walks BLOCK_N-wide column tiles inside one slice,
    axis 1 selects the slice, axis 2 selects the segment. Each program
    multiplies one (BLOCK_M, rank) tile of `x` by one (rank, BLOCK_N) tile of
    the segment's adapter weights, scales it, and accumulates the result into
    `output` (the existing values of `output` are read and added to).

    Only the first BLOCK_M rows of a segment are processed; there is no loop
    over rows. A segment whose adapter has rank 0 is skipped entirely, which
    matches the PyTorch convention that (m, 0) @ (0, n) is an all-zero (m, n)
    matrix.

    All strides are derived from the constexpr sizes, so `x`, `weights` and
    `output` are addressed as contiguous row-major buffers.

    NOTE(review): the cache key above only covers NUM_SLICES and BLOCK_M,
    while OUTPUT_DIM, MAX_RANK, BLOCK_N and BLOCK_K are constexpr as well.
    Confirm that cached_triton_kernel cannot reuse a specialization that was
    compiled for different values of those parameters.

    Args:
        x (Tensor): Result of the LoRA A projection.
            Shape: (s, num_slices * K), where s is the total number of rows in
            the batch and K is the maximum LoRA rank.
        weights (Tensor): LoRA B weights of all adapters.
            Shape: (num_lora, output_dim, K).
        output (Tensor): Destination that the scaled product is added to.
            Shape: (s, output_dim).
    """
    tl.static_assert(NUM_SLICES <= 3)

    # Row-major strides of x: (s, NUM_SLICES * MAX_RANK)
    x_stride_0: tl.constexpr = NUM_SLICES * MAX_RANK
    x_stride_1: tl.constexpr = 1

    # Row-major strides of weights: (num_lora, OUTPUT_DIM, MAX_RANK)
    w_stride_0: tl.constexpr = OUTPUT_DIM * MAX_RANK
    w_stride_1: tl.constexpr = MAX_RANK
    w_stride_2: tl.constexpr = 1

    # Row-major strides of output: (s, OUTPUT_DIM)
    output_stride_0: tl.constexpr = OUTPUT_DIM
    output_stride_1: tl.constexpr = 1

    # The grid may contain more programs than segments; surplus ones exit here
    # before touching any per-segment metadata.
    seg_id = tl.program_id(axis=2)
    if seg_id >= num_segs:
        return

    # Adapter used by this segment and its rank.
    adapter_id = tl.load(weight_indices + seg_id)
    rank = tl.load(lora_ranks + adapter_id)

    # Rank 0 contributes nothing to the output.
    if rank == 0:
        return

    # The reduction length is the adapter's own rank, capped at MAX_RANK.
    rank = tl.minimum(MAX_RANK, rank)
    scaling = tl.load(scalings + adapter_id)

    # Logical row range [row_begin, row_end) covered by this segment.
    row_begin = tl.load(seg_indptr + seg_id)
    row_end = tl.load(seg_indptr + seg_id + 1)

    # Column range [col_begin, col_end) of the slice handled by this program.
    slice_id = tl.program_id(axis=1)
    col_begin = tl.load(slice_offsets + slice_id)
    col_end = tl.load(slice_offsets + slice_id + 1)

    # Translate logical row indices into physical rows of x / output.
    rows_logical = row_begin + tl.arange(0, BLOCK_M)
    row_mask = rows_logical < row_end
    rows_physical = tl.load(permutation + rows_logical, mask=row_mask)

    tile_n = tl.program_id(axis=0)
    cols = col_begin + tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    col_mask = cols < col_end
    k_lane = tl.arange(0, BLOCK_K)

    # Row / column base pointers; the K offset is added per iteration below.
    # Slice i of x starts at column i * rank for this adapter.
    x_row_ptrs = (
        x + slice_id * rank * x_stride_1 + rows_physical[:, None] * x_stride_0
    )
    w_col_ptrs = weights + adapter_id * w_stride_0 + cols[None, :] * w_stride_1

    # Accumulate the (BLOCK_M, BLOCK_N) tile in float32 along the rank axis.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(rank, BLOCK_K)):
        k_idx = k * BLOCK_K + k_lane
        k_mask = k_idx < rank
        x_tile = tl.load(
            x_row_ptrs + k_idx[None, :] * x_stride_1,
            mask=row_mask[:, None] & k_mask[None, :],
            other=0.0,
        )
        w_tile = tl.load(
            w_col_ptrs + k_idx[:, None] * w_stride_2,
            mask=k_mask[:, None] & col_mask[None, :],
            other=0.0,
        )
        acc += tl.dot(x_tile, w_tile)

    # Scale, cast to the element type of x, then add onto the existing output.
    acc *= scaling
    acc = acc.to(x.dtype.element_ty)
    out_ptrs = output + (
        rows_physical[:, None] * output_stride_0 + cols[None, :] * output_stride_1
    )
    out_mask = row_mask[:, None] & col_mask[None, :]
    acc += tl.load(out_ptrs, mask=out_mask, other=0.0)
    tl.store(out_ptrs, acc, mask=out_mask)
|
143
|
+
|
144
|
+
|
145
|
+
def chunked_sgmv_lora_expand_forward(
    x: torch.Tensor,
    weights: torch.Tensor,
    batch_info: LoRABatchInfo,
    slice_offsets: torch.Tensor,
    max_slice_size: int,
    base_output: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Launch the chunked SGMV expand kernel (LoRA B projection).

    For each slice i the kernel accumulates
        output[:, slice_offsets[i]:slice_offsets[i+1]] +=
            scaling * x[:, i*cur_rank:(i+1)*cur_rank] @ W_slice_i^T
    where cur_rank and scaling belong to the adapter of each segment.

    Args:
        x: Contiguous 2-D tensor of shape (s, num_slices * r).
        weights: Contiguous 3-D tensor of shape (num_lora, output_dim, r).
        batch_info: Segment metadata (seg_indptr, weight_indices, lora_ranks,
            permutation, scalings, max_len, num_segments, bs, use_cuda_graph).
        slice_offsets: Boundaries of the slices along the output dimension;
            its length is num_slices + 1.
        max_slice_size: Sizes grid axis 0 as cdiv(max_slice_size, BLOCK_N).
            Presumably the width of the widest slice -- confirm with callers.
        base_output: Optional tensor to accumulate into in place. It must be
            contiguous with last dimension output_dim. When None, a zero
            tensor of shape (s, output_dim) is allocated.

    Returns:
        `base_output` (updated in place) if it was given, otherwise the newly
        allocated (s, output_dim) tensor.

    Raises:
        AssertionError: If an input violates the layout the kernel assumes.
    """
    assert x.is_contiguous()
    assert weights.is_contiguous()
    assert len(x.shape) == 2
    assert len(weights.shape) == 3

    # Get dims
    M = x.shape[0]
    input_dim = x.shape[1]
    OUTPUT_DIM = weights.shape[1]
    MAX_RANK = weights.shape[2]
    num_slices = len(slice_offsets) - 1
    assert input_dim == num_slices * MAX_RANK

    # TODO (lifuhuang): fine-tune per operation
    # One program covers a whole segment, so the row tile is the longest
    # segment length of the batch.
    BLOCK_M = batch_info.max_len
    BLOCK_K = 16
    BLOCK_N = 64

    num_segments = batch_info.num_segments

    grid = (
        triton.cdiv(max_slice_size, BLOCK_N),
        num_slices,  # number of slices in the input/output
        # With CUDA graphs the segment axis is sized by batch_info.bs; the
        # kernel exits early for program ids >= num_segs.
        batch_info.bs if batch_info.use_cuda_graph else num_segments,
    )

    if base_output is None:
        output = torch.zeros((M, OUTPUT_DIM), device=x.device, dtype=x.dtype)
    else:
        # The kernel addresses the output with hard-coded strides
        # (OUTPUT_DIM, 1). A buffer with any other layout would be silently
        # corrupted, so fail loudly instead.
        assert base_output.is_contiguous()
        assert base_output.shape[-1] == OUTPUT_DIM
        output = base_output

    _chunked_lora_expand_kernel[grid](
        x=x,
        weights=weights,
        output=output,
        seg_indptr=batch_info.seg_indptr,
        weight_indices=batch_info.weight_indices,
        lora_ranks=batch_info.lora_ranks,
        permutation=batch_info.permutation,
        num_segs=num_segments,
        scalings=batch_info.scalings,
        slice_offsets=slice_offsets,
        # constants
        NUM_SLICES=num_slices,
        OUTPUT_DIM=OUTPUT_DIM,
        MAX_RANK=MAX_RANK,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )

    return output
|
@@ -0,0 +1,174 @@
|
|
1
|
+
import torch
|
2
|
+
import triton
|
3
|
+
import triton.language as tl
|
4
|
+
|
5
|
+
from sglang.srt.lora.utils import LoRABatchInfo
|
6
|
+
from sglang.srt.utils import cached_triton_kernel
|
7
|
+
|
8
|
+
|
9
|
+
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
@triton.jit
def _chunked_lora_shrink_kernel(
    # Tensor pointers
    x,
    weights,
    output,
    # Per-segment metadata (row ranges, adapter ids, adapter ranks, row order)
    seg_indptr,
    weight_indices,
    lora_ranks,
    permutation,
    num_segs,
    # Compile-time parameters
    N: tl.constexpr,  # num_slices * r
    K: tl.constexpr,  # input_dim
    NUM_SLICES: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Chunked SGMV kernel for the LoRA shrink (A projection) step.

    Launch grid: axis 0 walks BLOCK_N-wide tiles of the output columns, axis 1
    selects the segment. Each program writes one (BLOCK_M, BLOCK_N) tile of
    `output`, restricted to the segment's rows and to the first
    `rank * NUM_SLICES` columns (capped at N). Columns beyond that, and every
    row of a rank-0 segment, are never written by this kernel.

    Only the first BLOCK_M rows of a segment are processed; there is no loop
    over rows. All strides are derived from the constexpr sizes, so `x`,
    `weights` and `output` are addressed as contiguous row-major buffers.

    NOTE(review): the cache key above only covers NUM_SLICES and BLOCK_M,
    while N, K, BLOCK_N and BLOCK_K are constexpr as well. Confirm that
    cached_triton_kernel cannot reuse a specialization that was compiled for
    different values of those parameters.

    Args:
        x (torch.Tensor): Input activations of shape `(s, K)`, where `s` is the
            total number of rows in the batch.
        weights (torch.Tensor): LoRA A weights of all adapters, of shape
            `(num_lora, N, K)` with N = num_slices * r.
        output (torch.Tensor): Destination tensor of shape `(s, N)`.
    """
    # Row-major strides of x: (s, K)
    x_stride_1: tl.constexpr = 1
    x_stride_0: tl.constexpr = K

    # Row-major strides of weights: (num_lora, N, K)
    w_stride_0: tl.constexpr = N * K
    w_stride_1: tl.constexpr = K
    w_stride_2: tl.constexpr = 1

    # Row-major strides of output: (s, N)
    output_stride_0: tl.constexpr = N
    output_stride_1: tl.constexpr = 1

    # The grid may contain more programs than segments; surplus ones exit here
    # before touching any per-segment metadata.
    seg_id = tl.program_id(1)
    if seg_id >= num_segs:
        return

    tile_n = tl.program_id(0)

    # Adapter used by this segment and its rank.
    adapter_id = tl.load(weight_indices + seg_id)
    rank = tl.load(lora_ranks + adapter_id)

    # With rank 0 there are no output columns to produce.
    if rank == 0:
        return

    # Logical row range [row_begin, row_end) covered by this segment.
    row_begin = tl.load(seg_indptr + seg_id)
    row_end = tl.load(seg_indptr + seg_id + 1)

    # Number of output columns actually populated for this adapter.
    cur_n = tl.minimum(N, rank * NUM_SLICES)

    # Translate logical row indices into physical rows of x / output.
    rows_logical = row_begin + tl.arange(0, BLOCK_M)
    row_mask = rows_logical < row_end
    rows_physical = tl.load(permutation + rows_logical, mask=row_mask)

    cols = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    col_mask = cols < cur_n
    k_lane = tl.arange(0, BLOCK_K)

    # Row / column base pointers; the K offset is added per iteration below.
    x_row_ptrs = x + rows_physical[:, None] * x_stride_0
    w_col_ptrs = weights + adapter_id * w_stride_0 + cols[None, :] * w_stride_1

    # Accumulate the (BLOCK_M, BLOCK_N) tile in float32 along the input dim.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        k_idx = k * BLOCK_K + k_lane
        k_mask = k_idx < K
        x_tile = tl.load(
            x_row_ptrs + k_idx[None, :] * x_stride_1,
            mask=row_mask[:, None] & k_mask[None, :],
            other=0.0,
        )
        w_tile = tl.load(
            w_col_ptrs + k_idx[:, None] * w_stride_2,
            mask=k_mask[:, None] & col_mask[None, :],
            other=0.0,
        )
        acc += tl.dot(x_tile, w_tile)

    # Cast to the element type of x and write the tile (overwrite, no add).
    acc = acc.to(x.dtype.element_ty)
    out_ptrs = output + (
        rows_physical[:, None] * output_stride_0 + cols[None, :] * output_stride_1
    )
    out_mask = row_mask[:, None] & col_mask[None, :]
    tl.store(out_ptrs, acc, mask=out_mask)
|
118
|
+
|
119
|
+
|
120
|
+
def chunked_sgmv_lora_shrink_forward(
    x: torch.Tensor,
    weights: torch.Tensor,
    batch_info: LoRABatchInfo,
    num_slices: int,
) -> torch.Tensor:
    """
    Launch the chunked SGMV shrink kernel (LoRA A projection).

    Args:
        x: Contiguous 2-D activations of shape (s, input_dim).
        weights: Contiguous 3-D LoRA A weights of shape
            (num_lora, num_slices * r, input_dim).
        batch_info: Segment metadata (seg_indptr, weight_indices, lora_ranks,
            permutation, max_len, num_segments, bs, use_cuda_graph).
        num_slices: Number of stacked slices in the weights
            (qkv=3, gate_up=2, others=1).

    Returns:
        A new tensor of shape (s, num_slices * r). It is allocated with
        torch.empty, so entries the kernel does not write (columns past
        rank * num_slices, rows of rank-0 segments) are uninitialized.
    """
    assert x.is_contiguous()
    assert weights.is_contiguous()
    assert len(x.shape) == 2
    assert len(weights.shape) == 3

    # Tile sizes. input_dim is much larger than r, hence the wide K tile.
    # TODO (lifuhuang): experiment with split-k
    block_m = batch_info.max_len
    block_n = 16
    block_k = 256

    num_rows = x.shape[0]
    _, out_cols, in_dim = weights.shape
    assert x.shape[-1] == in_dim

    segment_count = batch_info.num_segments
    # With CUDA graphs the segment axis is sized by batch_info.bs; the kernel
    # exits early for program ids >= num_segs.
    if batch_info.use_cuda_graph:
        launch_segments = batch_info.bs
    else:
        launch_segments = segment_count
    grid = (triton.cdiv(out_cols, block_n), launch_segments)

    result = torch.empty((num_rows, out_cols), device=x.device, dtype=x.dtype)
    _chunked_lora_shrink_kernel[grid](
        x=x,
        weights=weights,
        output=result,
        seg_indptr=batch_info.seg_indptr,
        weight_indices=batch_info.weight_indices,
        lora_ranks=batch_info.lora_ranks,
        permutation=batch_info.permutation,
        num_segs=segment_count,
        # constants
        N=out_cols,
        K=in_dim,
        NUM_SLICES=num_slices,
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        BLOCK_K=block_k,
    )

    return result
|
sglang/srt/lora/utils.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Iterable, Optional, Set, Tuple
|
|
5
5
|
|
6
6
|
import torch
|
7
7
|
|
8
|
-
from sglang.srt.hf_transformers_utils import AutoConfig
|
8
|
+
from sglang.srt.utils.hf_transformers_utils import AutoConfig
|
9
9
|
|
10
10
|
|
11
11
|
@dataclass
|
@@ -19,6 +19,9 @@ class LoRABatchInfo:
|
|
19
19
|
# Number of segments. For triton backend, it is equal to batch size.
|
20
20
|
num_segments: int
|
21
21
|
|
22
|
+
# Maximum segment length of current batch
|
23
|
+
max_len: int
|
24
|
+
|
22
25
|
# Indice pointers of each segment in shape (num_segments + 1, )
|
23
26
|
seg_indptr: torch.Tensor
|
24
27
|
|
@@ -34,9 +37,6 @@ class LoRABatchInfo:
|
|
34
37
|
# Lengths of each segments in shape (num_segments,)
|
35
38
|
seg_lens: Optional[torch.Tensor]
|
36
39
|
|
37
|
-
# Maximum segment length of current batch
|
38
|
-
max_len: Optional[int]
|
39
|
-
|
40
40
|
# The logical (re)ordering of input rows (tokens), in shape (num_tokens,)
|
41
41
|
permutation: Optional[torch.Tensor]
|
42
42
|
|
@@ -98,6 +98,7 @@ def get_normalized_target_modules(
|
|
98
98
|
) -> set[str]:
|
99
99
|
"""
|
100
100
|
Mapping a list of target module name to names of the normalized LoRA weights.
|
101
|
+
Handles both base module names (e.g., "gate_proj") and prefixed module names (e.g., "feed_forward.gate_proj").
|
101
102
|
"""
|
102
103
|
params_mapping = {
|
103
104
|
"q_proj": "qkv_proj",
|
@@ -109,7 +110,8 @@ def get_normalized_target_modules(
|
|
109
110
|
|
110
111
|
result = set()
|
111
112
|
for name in target_modules:
|
112
|
-
|
113
|
+
base_name = name.split(".")[-1]
|
114
|
+
normalized_name = params_mapping.get(base_name, base_name)
|
113
115
|
result.add(normalized_name)
|
114
116
|
return result
|
115
117
|
|