sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,348 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
import torch
|
4
|
+
|
5
|
+
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
6
|
+
from sglang.srt.lora.triton_ops import (
|
7
|
+
chunked_sgmv_lora_expand_forward,
|
8
|
+
chunked_sgmv_lora_shrink_forward,
|
9
|
+
)
|
10
|
+
from sglang.srt.lora.utils import LoRABatchInfo
|
11
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
12
|
+
from sglang.srt.server_args import ServerArgs
|
13
|
+
|
14
|
+
# Lower bound for the chunk size: `_determine_chunk_size` returns this value
# whenever the configured maximum chunk size is at or below it.
MIN_CHUNK_SIZE = 16
|
15
|
+
|
16
|
+
|
17
|
+
class ChunkedSgmvLoRABackend(BaseLoRABackend):
    """
    Chunked LoRA backend using segmented matrix-vector multiplication.

    This backend is largely based on the SGMV (Segmented Gather Matrix-Vector multiplication) algorithm
    introduced in the Punica paper (https://arxiv.org/pdf/2310.18547). One main variation made here is to
    segment the input sequences into fixed-size chunks, which reduces excessive kernel launches especially
    when the LoRA distribution is skewed.

    The ``run_*`` methods read ``self.batch_info``, which is assigned by
    ``prepare_lora_batch``.
    """

    name = "csgmv"

    def __init__(
        self,
        max_loras_per_batch: int,
        device: torch.device,
        server_args: ServerArgs,
    ):
        """
        Args:
            max_loras_per_batch: Forwarded to the base backend; also the length of the
                per-adapter rank/scaling buffers allocated in ``prepare_lora_batch``.
            device: Device on which the batch-info tensors are allocated.
            server_args: Only ``max_lora_chunk_size`` is read from it here.
        """
        super().__init__(max_loras_per_batch, device)
        self.max_chunk_size = server_args.max_lora_chunk_size

    def run_lora_a_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """Run the chunked SGMV shrink kernel on ``x`` with a single slice.

        Extra positional/keyword arguments are accepted and ignored.
        """
        return chunked_sgmv_lora_shrink_forward(
            x=x,
            weights=weights,
            batch_info=self.batch_info,
            num_slices=1,
        )

    def run_lora_b_sgemm(
        self,
        x: torch.Tensor,
        weights: torch.Tensor,
        output_offset: torch.Tensor,
        base_output: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run the chunked SGMV expand kernel on ``x``.

        Args:
            x: Input to the expand kernel.
            weights: LoRA B weights; ``weights.shape[-2]`` is used as the maximum
                slice size.
            output_offset: Slice offsets forwarded to the kernel. For simple LoRA B
                the caller is expected to pass ``[0, output_dim]``.
            base_output: Optional tensor forwarded to the kernel as ``base_output``.

        Extra positional/keyword arguments are accepted and ignored.
        """
        # Simple LoRA B has a single slice, so the largest slice is the full output dim.
        output_dim = weights.shape[-2]
        return chunked_sgmv_lora_expand_forward(
            x=x,
            weights=weights,
            batch_info=self.batch_info,
            slice_offsets=output_offset,
            max_slice_size=output_dim,
            base_output=base_output,
        )

    def run_qkv_lora(
        self,
        x: torch.Tensor,
        qkv_lora_a: torch.Tensor,
        qkv_lora_b: torch.Tensor,
        output_offset: torch.Tensor,
        max_qkv_out_dim: int,
        base_output: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the fused q/k/v LoRA: a 3-slice shrink followed by an expand.

        Args:
            x: (s, input_dim)
            qkv_lora_a: (num_lora, 3 * r, input_dim)
            qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
            output_offset: Slice offsets forwarded to the expand kernel.
            max_qkv_out_dim: Forwarded to the expand kernel as ``max_slice_size``.
            base_output: Optional tensor forwarded to the expand kernel.

        Raises:
            AssertionError: If ``qkv_lora_b`` is not a single ``torch.Tensor``.
        """
        assert isinstance(qkv_lora_b, torch.Tensor)

        lora_a_output = chunked_sgmv_lora_shrink_forward(
            x=x,
            weights=qkv_lora_a,
            batch_info=self.batch_info,
            num_slices=3,
        )
        return chunked_sgmv_lora_expand_forward(
            x=lora_a_output,
            weights=qkv_lora_b,
            batch_info=self.batch_info,
            slice_offsets=output_offset,
            max_slice_size=max_qkv_out_dim,
            base_output=base_output,
        )

    def run_gate_up_lora(
        self,
        x: torch.Tensor,
        gate_up_lora_a: torch.Tensor,
        gate_up_lora_b: torch.Tensor,
        output_offset: torch.Tensor,
        base_output: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the fused gate/up LoRA: a 2-slice shrink followed by an expand.

        Args:
            x: (s, input_dim)
            gate_up_lora_a: (num_lora, 2 * r, input_dim)
            gate_up_lora_b: (num_lora, 2 * output_dim, r)
            output_offset: Slice offsets forwarded to the expand kernel.
            base_output: Optional tensor forwarded to the expand kernel.

        Raises:
            AssertionError: If ``gate_up_lora_b`` is not a single ``torch.Tensor``.
        """
        assert isinstance(gate_up_lora_b, torch.Tensor)
        # gate and up are stacked along dim -2, so each slice is half of it.
        output_dim = gate_up_lora_b.shape[-2] // 2

        # lora_a_output: (s, 2 * r)
        lora_a_output = chunked_sgmv_lora_shrink_forward(
            x=x,
            weights=gate_up_lora_a,
            batch_info=self.batch_info,
            num_slices=2,
        )
        return chunked_sgmv_lora_expand_forward(
            x=lora_a_output,
            weights=gate_up_lora_b,
            batch_info=self.batch_info,
            slice_offsets=output_offset,
            max_slice_size=output_dim,
            base_output=base_output,
        )

    def _determine_chunk_size(self, forward_batch: ForwardBatch) -> int:
        """
        Heuristically determine the chunk size based on the number of tokens in a batch.

        The token count is ``extend_num_tokens`` in extend mode and ``batch_size``
        otherwise. The heuristic picks 128 for >= 256 tokens, 32 for >= 64 tokens and
        16 otherwise, capped at ``self.max_chunk_size``. If ``self.max_chunk_size`` is
        at or below ``MIN_CHUNK_SIZE``, ``MIN_CHUNK_SIZE`` is returned unconditionally.

        Args:
            forward_batch (ForwardBatch): The batch information containing sequence lengths.

        Returns:
            The determined chunk size
        """
        if self.max_chunk_size <= MIN_CHUNK_SIZE:
            return MIN_CHUNK_SIZE

        num_tokens = (
            forward_batch.extend_num_tokens
            if forward_batch.forward_mode.is_extend()
            else forward_batch.batch_size
        )
        if num_tokens >= 256:
            chunk_size = 128
        elif num_tokens >= 64:
            chunk_size = 32
        else:  # num_tokens < 64
            chunk_size = 16
        return min(self.max_chunk_size, chunk_size)

    def prepare_lora_batch(
        self,
        forward_batch: ForwardBatch,
        weight_indices: list[int],
        lora_ranks: list[int],
        scalings: list[float],
        batch_info: Optional[LoRABatchInfo] = None,
    ):
        """
        Build the per-batch LoRA metadata and store it in ``self.batch_info``.

        Tokens are grouped by adapter, split into segments of at most the chosen
        chunk size, and the resulting segment metadata is copied into the batch
        info tensors.

        Args:
            forward_batch: Batch whose forward mode and sequence lengths drive the
                chunk size and the token permutation.
            weight_indices: LoRA adapter index for each sequence in the batch.
            lora_ranks: Per-adapter ranks, copied into ``batch_info.lora_ranks``.
            scalings: Per-adapter scalings, copied into ``batch_info.scalings``.
            batch_info: Optional existing ``LoRABatchInfo`` to reuse. When given, only
                its ``bs``/``num_segments``/``max_len`` fields are updated and its
                tensors are overwritten in place (prefix slices), so they must be
                large enough for this batch. When ``None``, a new one is allocated on
                ``self.device``.
        """
        chunk_size = self._determine_chunk_size(forward_batch)

        permutation, weight_indices_reordered = ChunkedSgmvLoRABackend._get_permutation(
            seq_weight_indices=weight_indices,
            forward_batch=forward_batch,
        )

        seg_weight_indices, seg_indptr = self._get_segments_info(
            weights_reordered=weight_indices_reordered,
            chunk_size=chunk_size,
        )
        num_segments = len(seg_weight_indices)

        # Staged in pinned host memory so the copies below can be non-blocking.
        lora_ranks_tensor = torch.tensor(
            lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
        )
        scalings_tensor = torch.tensor(
            scalings, dtype=torch.float, pin_memory=True, device="cpu"
        )

        if batch_info is None:
            batch_info = LoRABatchInfo(
                bs=forward_batch.batch_size,
                num_segments=num_segments,
                max_len=chunk_size,
                use_cuda_graph=False,
                seg_indptr=torch.empty(
                    (num_segments + 1,), dtype=torch.int32, device=self.device
                ),
                weight_indices=torch.empty(
                    (num_segments,), dtype=torch.int32, device=self.device
                ),
                lora_ranks=torch.empty(
                    (self.max_loras_per_batch,), dtype=torch.int32, device=self.device
                ),
                scalings=torch.empty(
                    (self.max_loras_per_batch,), dtype=torch.float, device=self.device
                ),
                permutation=torch.empty(
                    (len(permutation),), dtype=torch.int32, device=self.device
                ),
                # Not used in chunked kernels
                seg_lens=None,
            )
        else:
            batch_info.bs = forward_batch.batch_size
            batch_info.num_segments = num_segments
            batch_info.max_len = chunk_size

        # Copy to device asynchronously
        batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
            lora_ranks_tensor, non_blocking=True
        )
        batch_info.scalings[: self.max_loras_per_batch].copy_(
            scalings_tensor, non_blocking=True
        )
        batch_info.weight_indices[:num_segments].copy_(
            seg_weight_indices, non_blocking=True
        )
        batch_info.seg_indptr[: num_segments + 1].copy_(seg_indptr, non_blocking=True)
        batch_info.permutation[: len(permutation)].copy_(permutation, non_blocking=True)

        self.batch_info = batch_info

    @staticmethod
    def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch):
        """
        Computes permutation indices for reordering tokens by their LoRA adapter assignments.

        This function implements the "gather" step in Chunked Segmented Gather Matrix Vector
        multiplication by creating a permutation that groups tokens by their LoRA adapter.
        Tokens using the same LoRA adapter are placed together to enable efficient batched
        computation. The sort is stable, so tokens of the same adapter keep their
        original relative order.

        Example:
            seq_weight_indices = [0, 1, 0]  # 3 sequences using adapters [0, 1, 0]
            extend_seq_lens = [2, 1, 3]     # sequence lengths [2, 1, 3 tokens]

            # Creates row_weight_indices: [0, 0, 1, 0, 0, 0] (6 tokens total)
            # Returns permutation: [0, 1, 3, 4, 5, 2] (groups adapter 0 tokens together)
            # weights_reordered: [0, 0, 0, 0, 0, 1] (sorted by adapter)

        Args:
            seq_weight_indices: List of LoRA adapter indices for each sequence
            forward_batch (ForwardBatch): Batch information containing sequence lengths.
                In extend mode ``extend_seq_lens_cpu`` is used; otherwise every
                sequence contributes exactly one token.

        Returns:
            tuple: (permutation, weights_reordered) where:
                - permutation: Token reordering indices to group by adapter
                  (int64 CPU tensor in pinned memory)
                - weights_reordered: Sorted adapter indices for each token
                  (int32 CPU tensor)
        """
        with torch.device("cpu"):
            seq_weight_indices = torch.tensor(seq_weight_indices, dtype=torch.int32)

            seg_lens_cpu = (
                torch.tensor(
                    forward_batch.extend_seq_lens_cpu,
                    dtype=torch.int32,
                )
                if forward_batch.forward_mode.is_extend()
                else torch.ones(forward_batch.batch_size, dtype=torch.int32)
            )

            # Expand per-sequence adapter ids to per-token adapter ids.
            row_weight_indices = torch.repeat_interleave(
                seq_weight_indices, seg_lens_cpu
            )
            # Sort straight into a pinned buffer so the caller can copy it to the
            # device with non_blocking=True.
            permutation = torch.empty(
                (len(row_weight_indices),), dtype=torch.long, pin_memory=True
            )
            torch.argsort(row_weight_indices, stable=True, out=permutation)
            weights_reordered = row_weight_indices[permutation]

            return permutation, weights_reordered

    def _get_segments_info(self, weights_reordered: torch.Tensor, chunk_size: int):
        """
        Computes segment information for chunked SGMV operations.

        This function takes the reordered weight indices and creates segments of at most
        ``chunk_size`` tokens for efficient kernel execution. Each segment contains tokens
        that use the same LoRA adapter, enabling vectorized computation.

        The segmentation is necessary because:
        1. GPU kernels work efficiently on fixed-size blocks
        2. Large groups of tokens using the same adapter are split into manageable chunks
        3. Each segment can be processed independently in parallel

        Example:
            weights_reordered = [0, 0, 0, 0, 0, 1]  # 5 tokens with adapter 0, 1 with adapter 1
            chunk_size = 3

            # Creates segments:
            # Segment 0: tokens 0-2 (adapter 0), length=3
            # Segment 1: tokens 3-4 (adapter 0), length=2
            # Segment 2: token 5 (adapter 1), length=1

            # Returns:
            # seg_weight_indices: [0, 0, 1]  (adapter for each segment)
            # seg_indptr: [0, 3, 5, 6]       (cumulative segment boundaries)

        Args:
            weights_reordered (torch.Tensor): Sorted adapter indices for each token
            chunk_size (int): Maximum number of tokens in each segment

        Returns:
            tuple: (seg_weight_indices, seg_indptr), both int32 CPU tensors in pinned
            memory, where:
                - seg_weight_indices: LoRA adapter index for each segment
                - seg_indptr: Cumulative segment boundaries (CSR-style indptr)
        """
        with torch.device("cpu"):
            unique_weights, counts = torch.unique_consecutive(
                weights_reordered, return_counts=True
            )

            seg_weight_indices = []
            seg_lens = []

            # Convert to Python ints once instead of calling .item() on a 0-d
            # tensor for every adapter group.
            for weight_idx, group_len in zip(
                unique_weights.tolist(), counts.tolist()
            ):
                num_segs = (group_len + chunk_size - 1) // chunk_size

                seg_weight_indices.extend([weight_idx] * num_segs)
                # All segments of a group are full except possibly the last one.
                seg_lens.extend([chunk_size] * (num_segs - 1))
                seg_lens.append(group_len - (num_segs - 1) * chunk_size)

            seg_lens_tensor = torch.tensor(seg_lens, dtype=torch.int32)

            seg_weight_indices_tensor = torch.tensor(
                seg_weight_indices, dtype=torch.int32, pin_memory=True
            )

            seg_indptr = torch.empty(
                (len(seg_lens) + 1,), dtype=torch.int32, pin_memory=True
            )
            seg_indptr[0] = 0
            seg_indptr[1:] = torch.cumsum(seg_lens_tensor, dim=0)

            return seg_weight_indices_tensor, seg_indptr
|
@@ -11,12 +11,18 @@ from sglang.srt.lora.triton_ops import (
|
|
11
11
|
)
|
12
12
|
from sglang.srt.lora.utils import LoRABatchInfo
|
13
13
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
14
|
+
from sglang.srt.server_args import ServerArgs
|
14
15
|
|
15
16
|
|
16
17
|
class TritonLoRABackend(BaseLoRABackend):
|
17
18
|
name = "triton"
|
18
19
|
|
19
|
-
def __init__(
|
20
|
+
def __init__(
|
21
|
+
self,
|
22
|
+
max_loras_per_batch: int,
|
23
|
+
device: torch.device,
|
24
|
+
**kwargs,
|
25
|
+
):
|
20
26
|
super().__init__(max_loras_per_batch, device)
|
21
27
|
|
22
28
|
def run_lora_a_sgemm(
|
@@ -30,7 +36,7 @@ class TritonLoRABackend(BaseLoRABackend):
|
|
30
36
|
weights: torch.Tensor,
|
31
37
|
base_output: torch.Tensor = None,
|
32
38
|
*args,
|
33
|
-
**kwargs
|
39
|
+
**kwargs,
|
34
40
|
) -> torch.Tensor:
|
35
41
|
return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output)
|
36
42
|
|
@@ -43,7 +49,7 @@ class TritonLoRABackend(BaseLoRABackend):
|
|
43
49
|
max_qkv_out_dim: int,
|
44
50
|
base_output: torch.Tensor = None,
|
45
51
|
*args,
|
46
|
-
**kwargs
|
52
|
+
**kwargs,
|
47
53
|
) -> torch.Tensor:
|
48
54
|
|
49
55
|
# x: (s, input_dim)
|
@@ -69,7 +75,7 @@ class TritonLoRABackend(BaseLoRABackend):
|
|
69
75
|
gate_up_lora_b: torch.Tensor,
|
70
76
|
base_output: torch.Tensor = None,
|
71
77
|
*args,
|
72
|
-
**kwargs
|
78
|
+
**kwargs,
|
73
79
|
) -> torch.Tensor:
|
74
80
|
|
75
81
|
# x: (s, input_dim)
|
sglang/srt/lora/lora.py
CHANGED
@@ -26,16 +26,17 @@ import torch
|
|
26
26
|
from torch import nn
|
27
27
|
|
28
28
|
from sglang.srt.configs.load_config import LoadConfig
|
29
|
-
from sglang.srt.hf_transformers_utils import AutoConfig
|
30
29
|
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
31
|
-
|
32
|
-
# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
|
30
|
+
from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
|
33
31
|
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
|
34
32
|
from sglang.srt.lora.lora_config import LoRAConfig
|
35
33
|
from sglang.srt.model_loader.loader import DefaultModelLoader
|
34
|
+
from sglang.srt.utils.hf_transformers_utils import AutoConfig
|
36
35
|
|
37
36
|
logger = logging.getLogger(__name__)
|
38
37
|
|
38
|
+
SUPPORTED_BACKENDS = (TritonLoRABackend, ChunkedSgmvLoRABackend)
|
39
|
+
|
39
40
|
|
40
41
|
class LoRALayer(nn.Module):
|
41
42
|
def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
|
@@ -48,6 +49,7 @@ class LoRALayer(nn.Module):
|
|
48
49
|
|
49
50
|
|
50
51
|
class LoRAAdapter(nn.Module):
|
52
|
+
|
51
53
|
def __init__(
|
52
54
|
self,
|
53
55
|
uid: str,
|
@@ -159,8 +161,8 @@ class LoRAAdapter(nn.Module):
|
|
159
161
|
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
160
162
|
if up_name not in weights:
|
161
163
|
weights[up_name] = torch.zeros_like(weights[weight_name])
|
162
|
-
assert isinstance(self.lora_backend,
|
163
|
-
f"LoRA weight initialization currently only supported for '
|
164
|
+
assert isinstance(self.lora_backend, SUPPORTED_BACKENDS), (
|
165
|
+
f"LoRA weight initialization currently only supported for LoRA backends: {', '.join(b.name for b in SUPPORTED_BACKENDS)}"
|
164
166
|
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
|
165
167
|
f"or consider implementing custom initialization logic for other backends."
|
166
168
|
)
|
sglang/srt/lora/lora_manager.py
CHANGED
@@ -21,7 +21,6 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple
|
|
21
21
|
import torch
|
22
22
|
|
23
23
|
from sglang.srt.configs.load_config import LoadConfig
|
24
|
-
from sglang.srt.hf_transformers_utils import AutoConfig
|
25
24
|
from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
|
26
25
|
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
|
27
26
|
from sglang.srt.lora.lora import LoRAAdapter
|
@@ -35,9 +34,11 @@ from sglang.srt.lora.utils import (
|
|
35
34
|
get_normalized_target_modules,
|
36
35
|
get_target_module_name,
|
37
36
|
)
|
38
|
-
from sglang.srt.managers.io_struct import
|
37
|
+
from sglang.srt.managers.io_struct import LoRAUpdateOutput
|
39
38
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
39
|
+
from sglang.srt.server_args import ServerArgs
|
40
40
|
from sglang.srt.utils import replace_submodule
|
41
|
+
from sglang.srt.utils.hf_transformers_utils import AutoConfig
|
41
42
|
|
42
43
|
logger = logging.getLogger(__name__)
|
43
44
|
|
@@ -56,6 +57,7 @@ class LoRAManager:
|
|
56
57
|
max_lora_rank: Optional[int] = None,
|
57
58
|
target_modules: Optional[Iterable[str]] = None,
|
58
59
|
lora_paths: Optional[List[LoRARef]] = None,
|
60
|
+
server_args: Optional[ServerArgs] = None,
|
59
61
|
):
|
60
62
|
self.base_model: torch.nn.Module = base_model
|
61
63
|
self.base_hf_config: AutoConfig = base_hf_config
|
@@ -72,6 +74,7 @@ class LoRAManager:
|
|
72
74
|
self.lora_backend: BaseLoRABackend = backend_type(
|
73
75
|
max_loras_per_batch=max_loras_per_batch,
|
74
76
|
device=self.device,
|
77
|
+
server_args=server_args,
|
75
78
|
)
|
76
79
|
|
77
80
|
# Initialize mutable internal state of the LoRAManager.
|
@@ -104,8 +107,8 @@ class LoRAManager:
|
|
104
107
|
|
105
108
|
def create_lora_update_result(
|
106
109
|
self, success: bool, error_message: str = ""
|
107
|
-
) ->
|
108
|
-
return
|
110
|
+
) -> LoRAUpdateOutput:
|
111
|
+
return LoRAUpdateOutput(
|
109
112
|
success=success,
|
110
113
|
error_message=error_message,
|
111
114
|
loaded_adapters={
|
@@ -114,7 +117,7 @@ class LoRAManager:
|
|
114
117
|
},
|
115
118
|
)
|
116
119
|
|
117
|
-
def load_lora_adapter(self, lora_ref: LoRARef) ->
|
120
|
+
def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
|
118
121
|
"""
|
119
122
|
Load a single LoRA adapter from the specified path.
|
120
123
|
|
@@ -171,7 +174,7 @@ class LoRAManager:
|
|
171
174
|
"`--max-loras-per-batch` or load it as unpinned LoRA adapters."
|
172
175
|
)
|
173
176
|
|
174
|
-
def unload_lora_adapter(self, lora_ref: LoRARef) ->
|
177
|
+
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
|
175
178
|
"""
|
176
179
|
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
|
177
180
|
delete the corresponding LoRA modules.
|
@@ -415,6 +418,10 @@ class LoRAManager:
|
|
415
418
|
replace_submodule(self.base_model, module_name, lora_module)
|
416
419
|
return lora_module
|
417
420
|
|
421
|
+
def should_skip_lora_for_vision_model(self, module_name):
|
422
|
+
# TODO: support different vision models
|
423
|
+
return module_name.find("vision_model.model") != -1
|
424
|
+
|
418
425
|
def init_lora_modules(self):
|
419
426
|
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
|
420
427
|
self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
|
@@ -432,6 +439,10 @@ class LoRAManager:
|
|
432
439
|
) and not self.base_model.should_apply_lora(module_name):
|
433
440
|
continue
|
434
441
|
|
442
|
+
# Skip vision model
|
443
|
+
if self.should_skip_lora_for_vision_model(module_name):
|
444
|
+
continue
|
445
|
+
|
435
446
|
# The module should be converted if it is included in target_names
|
436
447
|
if module_name.split(".")[-1] in self.target_modules:
|
437
448
|
layer_id = get_layer_id(module_name)
|
sglang/srt/lora/mem_pool.py
CHANGED
@@ -4,7 +4,6 @@ from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
|
|
4
4
|
import torch
|
5
5
|
|
6
6
|
from sglang.srt.distributed import divide
|
7
|
-
from sglang.srt.hf_transformers_utils import AutoConfig
|
8
7
|
from sglang.srt.lora.layers import BaseLayerWithLoRA
|
9
8
|
from sglang.srt.lora.lora import LoRAAdapter
|
10
9
|
from sglang.srt.lora.lora_config import LoRAConfig
|
@@ -17,6 +16,7 @@ from sglang.srt.lora.utils import (
|
|
17
16
|
get_stacked_multiply,
|
18
17
|
get_target_module_name,
|
19
18
|
)
|
19
|
+
from sglang.srt.utils.hf_transformers_utils import AutoConfig
|
20
20
|
|
21
21
|
logger = logging.getLogger(__name__)
|
22
22
|
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward
|
2
|
+
from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward
|
1
3
|
from .gate_up_lora_b import gate_up_lora_b_fwd
|
2
4
|
from .qkv_lora_b import qkv_lora_b_fwd
|
3
5
|
from .sgemm_lora_a import sgemm_lora_a_fwd
|
@@ -8,4 +10,6 @@ __all__ = [
|
|
8
10
|
"qkv_lora_b_fwd",
|
9
11
|
"sgemm_lora_a_fwd",
|
10
12
|
"sgemm_lora_b_fwd",
|
13
|
+
"chunked_sgmv_lora_shrink_forward",
|
14
|
+
"chunked_sgmv_lora_expand_forward",
|
11
15
|
]
|