sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/layers.py
CHANGED
@@ -66,6 +66,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         lora_backend: BaseLoRABackend,
     ) -> None:
         super().__init__(base_layer, lora_backend)
+        shard_size = self.base_layer.output_partition_sizes[0]
+        self.output_offset = torch.tensor(
+            [
+                0,
+                shard_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )
 
     def set_lora_info(
         self,
@@ -81,6 +90,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         lora_output = self.lora_backend.run_lora_b_sgemm(
             x=lora_a_output,
             weights=self.B_buffer,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
@@ -130,11 +140,23 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.A_buffer_gate_up = A_buffer
         self.B_buffer_gate_up = B_buffer
 
+        shard_size = self.base_layer.output_partition_sizes[0]
+        self.output_offset = torch.tensor(
+            [
+                0,
+                shard_size,
+                2 * shard_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )
+
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         lora_output = self.lora_backend.run_gate_up_lora(
             x=x,
             gate_up_lora_a=self.A_buffer_gate_up,
             gate_up_lora_b=self.B_buffer_gate_up,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
@@ -243,12 +265,22 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
+        output_size = self.base_layer.output_size
+        self.output_offset = torch.tensor(
+            [
+                0,
+                output_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             x=lora_a_output,
             weights=self.B_buffer,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
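Note on the change above: each LoRA-wrapped linear layer now precomputes an output_offset tensor with the cumulative column boundaries of the projections it writes ([0, shard_size] for a single column-parallel shard, [0, shard_size, 2 * shard_size] for the merged gate/up projection, [0, output_size] for the row-parallel output) and forwards it to the backend sgemm calls, presumably so the kernels know where each stacked projection's columns begin. The snippet below only illustrates that offset convention; split_by_output_offset is a hypothetical helper, not part of sglang's backends.

import torch


def split_by_output_offset(lora_output: torch.Tensor, output_offset: torch.Tensor):
    """Slice a fused LoRA output into its stacked projections.

    output_offset holds cumulative column boundaries, e.g. [0, shard_size] for a
    plain column-parallel shard or [0, s, 2 * s] for a merged gate/up projection.
    """
    bounds = output_offset.tolist()
    return [lora_output[:, lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]


# Hypothetical usage with a merged gate/up projection of shard size 4:
fused = torch.randn(2, 8)                             # (num_tokens, 2 * shard_size)
offsets = torch.tensor([0, 4, 8], dtype=torch.int32)  # same layout as in the diff above
gate, up = split_by_output_offset(fused, offsets)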
sglang/srt/lora/lora.py
CHANGED
@@ -28,6 +28,9 @@ from torch import nn
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
+
+# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
 
@@ -156,7 +159,7 @@ class LoRAAdapter(nn.Module):
            gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
            if up_name not in weights:
                weights[up_name] = torch.zeros_like(weights[weight_name])
-           assert self.lora_backend
+           assert isinstance(self.lora_backend, TritonLoRABackend), (
                f"LoRA weight initialization currently only supported for 'triton' backend. "
                f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                f"or consider implementing custom initialization logic for other backends."
sglang/srt/lora/lora_manager.py
CHANGED
@@ -69,7 +69,10 @@ class LoRAManager:
         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
         backend_type = get_backend_from_name(lora_backend)
-        self.lora_backend: BaseLoRABackend = backend_type(
+        self.lora_backend: BaseLoRABackend = backend_type(
+            max_loras_per_batch=max_loras_per_batch,
+            device=self.device,
+        )
 
         # Initialize mutable internal state of the LoRAManager.
         self.init_state(
@@ -82,29 +85,22 @@
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
         with torch.device("cuda"):
             self.cuda_graph_batch_info = LoRABatchInfo(
-                bs=
-
-
-
-                ),
+                bs=max_bs_in_cuda_graph,
+                use_cuda_graph=True,
+                num_segments=None,
+                seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+                seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32),
                 max_len=1,
-                weight_indices=torch.zeros(
-
-                ),
+                weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+                permutation=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
                 lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
                 scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
             )
 
-
-
-
-
-                self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
-                dim=0,
-                out=self.cuda_graph_batch_info.seg_indptr[
-                    1 : self.max_bs_in_cuda_graph + 1
-                ],
-            )
+            self.lora_backend.init_cuda_graph_batch_info(
+                cuda_graph_batch_info=self.cuda_graph_batch_info,
+                max_bs_in_cuda_graph=max_bs_in_cuda_graph,
+            )
 
     def create_lora_update_result(
         self, success: bool, error_message: str = ""
@@ -232,7 +228,6 @@
         return required_slots <= mem_pool_vacancy
 
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
-
         # Load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_ids)
 
@@ -247,102 +242,30 @@
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
-        def transfer_adapter_info(
-            weight_indices_out: torch.Tensor,
-            lora_ranks_out: torch.Tensor,
-            scalings_out: torch.Tensor,
-        ):
-            """
-            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
-            to device (CUDA) asynchronously.
-            """
-            weight_indices = [0] * len(forward_batch.lora_ids)
-            lora_ranks = [0] * self.max_loras_per_batch
-            scalings = [0] * self.max_loras_per_batch
-            for i, uid in enumerate(forward_batch.lora_ids):
-                weight_indices[i] = self.memory_pool.get_buffer_id(uid)
-                if uid is not None:
-                    lora = self.loras[uid]
-                    lora_ranks[weight_indices[i]] = lora.config.r
-                    scalings[weight_indices[i]] = lora.scaling
-
-            # Use pinned memory to avoid synchronizations during host-to-device transfer
-            weight_indices_tensor = torch.tensor(
-                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
-            )
-            lora_ranks_tensor = torch.tensor(
-                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
-            )
-            scalings_tensor = torch.tensor(
-                scalings, dtype=torch.float, pin_memory=True, device="cpu"
-            )
-
-            # Copy to device tensors asynchronously
-            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
-            lora_ranks_out[: self.max_loras_per_batch].copy_(
-                lora_ranks_tensor, non_blocking=True
-            )
-            scalings_out[: self.max_loras_per_batch].copy_(
-                scalings_tensor, non_blocking=True
-            )
-
-        if (
+        use_cuda_graph = (
             hasattr(self, "max_bs_in_cuda_graph")
             and bs <= self.max_bs_in_cuda_graph
             and forward_batch.forward_mode.is_cuda_graph()
-        )
-            # Do in-place updates when CUDA graph is enabled and the batch forward mode
-            # could use CUDA graph.
-
-            transfer_adapter_info(
-                self.cuda_graph_batch_info.weight_indices,
-                self.cuda_graph_batch_info.lora_ranks,
-                self.cuda_graph_batch_info.scalings,
-            )
-
-            self.cuda_graph_batch_info.bs = bs
-            self.cuda_graph_batch_info.max_len = 1
-            batch_info = self.cuda_graph_batch_info
-        else:
-            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
-            lora_ranks = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
-            )
-            scalings = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
-            )
-            transfer_adapter_info(
-                weight_indices,
-                lora_ranks,
-                scalings,
-            )
-
-            seg_lens = (
-                forward_batch.extend_seq_lens
-                if forward_batch.forward_mode.is_extend()
-                else torch.ones(bs, device=self.device)
-            )
-
-            max_len = (
-                # Calculate max_len from the CPU copy to avoid D2H transfer.
-                max(forward_batch.extend_seq_lens_cpu)
-                if forward_batch.forward_mode.is_extend()
-                else 1
-            )
+        )
 
-
-
-
-
-
-
-
-
-                weight_indices=
-
-
-
-
+        weight_indices = [0] * len(forward_batch.lora_ids)
+        lora_ranks = [0] * self.max_loras_per_batch
+        scalings = [0] * self.max_loras_per_batch
+        for i, uid in enumerate(forward_batch.lora_ids):
+            weight_indices[i] = self.memory_pool.get_buffer_id(uid)
+            if uid is not None:
+                lora = self.loras[uid]
+                lora_ranks[weight_indices[i]] = lora.config.r
+                scalings[weight_indices[i]] = lora.scaling
+        # Do in-place updates when CUDA graph is enabled and the batch forward mode
+        # could use CUDA graph.
+        self.lora_backend.prepare_lora_batch(
+            forward_batch=forward_batch,
+            weight_indices=weight_indices,
+            lora_ranks=lora_ranks,
+            scalings=scalings,
+            batch_info=self.cuda_graph_batch_info if use_cuda_graph else None,
+        )
 
     def update_lora_info(self):
         """
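In this refactor the per-batch adapter metadata (weight indices, ranks, scalings) is still assembled on the host by LoRAManager, but the host-to-device transfer and batch-info bookkeeping move behind the backend's prepare_lora_batch hook. Below is a rough sketch of what such a hook has to do, modeled on the removed transfer_adapter_info helper; the function name and signature are illustrative, not the actual TritonLoRABackend API, and the pinned host buffers assume a CUDA build.

from typing import List

import torch


def transfer_adapter_metadata(
    weight_indices: List[int],
    lora_ranks: List[int],
    scalings: List[float],
    weight_indices_out: torch.Tensor,  # device buffer, length >= batch size
    lora_ranks_out: torch.Tensor,      # device buffer, length >= max_loras_per_batch
    scalings_out: torch.Tensor,        # device buffer, length >= max_loras_per_batch
) -> None:
    # Stage the Python lists in pinned host memory so the copies below can be
    # issued asynchronously, mirroring the removed transfer_adapter_info helper.
    idx_host = torch.tensor(weight_indices, dtype=torch.int32, pin_memory=True, device="cpu")
    rank_host = torch.tensor(lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu")
    scale_host = torch.tensor(scalings, dtype=torch.float, pin_memory=True, device="cpu")

    # Non-blocking host-to-device copies into the preallocated device buffers.
    weight_indices_out[: len(weight_indices)].copy_(idx_host, non_blocking=True)
    lora_ranks_out[: len(lora_ranks)].copy_(rank_host, non_blocking=True)
    scalings_out[: len(scalings)].copy_(scale_host, non_blocking=True)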
sglang/srt/lora/mem_pool.py
CHANGED
@@ -104,12 +104,18 @@ class LoRAMemoryPool:
         return all(_can_support(x) for x in config)
 
     def get_lora_A_shape(
-        self,
+        self,
+        module_name: str,
+        base_model: torch.nn.Module,
+        max_lora_dim: int,
+        layer_idx: int,
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
-        input_dim, _ = get_hidden_dim(
+        input_dim, _ = get_hidden_dim(
+            module_name, self.base_hf_config, base_model, layer_idx
+        )
         c = get_stacked_multiply(module_name)
         if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             input_dim = divide(input_dim, self.tp_size)
@@ -120,12 +126,18 @@ class LoRAMemoryPool:
         )
 
     def get_lora_B_shape(
-        self,
+        self,
+        module_name: str,
+        base_model: torch.nn.Module,
+        max_lora_dim: int,
+        layer_idx: int,
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
-        _, output_dim = get_hidden_dim(
+        _, output_dim = get_hidden_dim(
+            module_name, self.base_hf_config, base_model, layer_idx
+        )
         if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             output_dim = divide(output_dim, self.tp_size)
         return (
@@ -140,19 +152,21 @@ class LoRAMemoryPool:
         def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
             target_modules: Set[str],
-            get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
+            get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]],
         ):
             for module_name in target_modules:
-                lora_shape = get_lora_shape_fn(
-                    module_name, base_model, self.max_lora_rank
-                )
                 buffer[module_name] = [
                     torch.empty(
-
+                        get_lora_shape_fn(
+                            module_name,
+                            base_model,
+                            self.max_lora_rank,
+                            idx,
+                        ),
                         dtype=self.dtype,
                         device=device,
                     )
-                    for
+                    for idx in range(self.num_layer)
                 ]
 
         init_buffer(
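The extra layer_idx argument threaded through get_lora_A_shape / get_lora_B_shape (and into get_hidden_dim, see utils.py below) lets the pool size its A/B buffers per layer instead of assuming every layer has the same projection widths. A toy illustration of the model-side hook this enables follows; the class and dimensions are hypothetical, and only the get_hidden_dim(module_name, layer_idx) signature mirrors the diff.

# Hypothetical model-side hook enabled by the new layer_idx plumbing: a model whose
# projection widths differ across layers can report per-layer dims to the LoRA pool.
class ToyHybridModel:
    def __init__(self, hidden_size: int = 4096, kv_dims: tuple = (1024, 512)):
        self.hidden_size = hidden_size
        self.kv_dims = kv_dims  # alternating per-layer kv widths, purely illustrative

    def get_hidden_dim(self, module_name: str, layer_idx: int) -> tuple:
        # Mirrors the signature now expected by the LoRA shape helpers.
        if module_name == "kv_proj":
            return self.hidden_size, self.kv_dims[layer_idx % len(self.kv_dims)]
        return self.hidden_size, self.hidden_size


# Example: layer 0 and layer 1 report different kv projection output dims.
model = ToyHybridModel()
print(model.get_hidden_dim("kv_proj", 0))  # (4096, 1024)
print(model.get_hidden_dim("kv_proj", 1))  # (4096, 512)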
sglang/srt/lora/utils.py
CHANGED
@@ -10,19 +10,19 @@ from sglang.srt.hf_transformers_utils import AutoConfig
 
 @dataclass
 class LoRABatchInfo:
+    # The forward mode is using CUDA Graph.
+    use_cuda_graph: bool
+
     # Batch size
     bs: int
 
-    #
-
+    # Number of segments. For triton backend, it is equal to batch size.
+    num_segments: int
 
-    # Indice pointers of each
+    # Indice pointers of each segment in shape (num_segments + 1, )
     seg_indptr: torch.Tensor
 
-    #
-    max_len: int
-
-    # The index of lora adapter used by each sequence, in shape (bs,)
+    # The index of lora adapter used by each segment, in shape (num_segments,)
     weight_indices: torch.Tensor
 
     # ranks of each lora adapter, in shape (lora_num,)
@@ -31,6 +31,15 @@ class LoRABatchInfo:
     # scaling of each lora adapter, in shape (lora_num,)
     scalings: torch.Tensor
 
+    # Lengths of each segments in shape (num_segments,)
+    seg_lens: Optional[torch.Tensor]
+
+    # Maximum segment length of current batch
+    max_len: Optional[int]
+
+    # The logical (re)ordering of input rows (tokens), in shape (num_tokens,)
+    permutation: Optional[torch.Tensor]
+
 
 class LoRAType(Enum):
     LORA_A = 0
@@ -48,14 +57,14 @@ def get_layer_id(name: str) -> int:
 
 
 def get_hidden_dim(
-    module_name: str, config: AutoConfig, base_model: torch.nn.Module
+    module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int
 ) -> Tuple[int]:
     """
     Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
     """
 
     if hasattr(base_model, "get_hidden_dim"):
-        return base_model.get_hidden_dim(module_name)
+        return base_model.get_hidden_dim(module_name, layer_idx)
     else:
         """
         WARNING: get_hidden_dim() is not defined,
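The reshaped LoRABatchInfo above is organized around segments rather than raw sequences: seg_lens gives each segment's token count, seg_indptr its prefix sum with a leading zero, and permutation a logical reordering of token rows. A small self-contained sketch of that bookkeeping, in plain torch with invented values, purely for illustration:

import torch

# Three segments holding 3, 1, and 2 tokens respectively (values are made up).
seg_lens = torch.tensor([3, 1, 2], dtype=torch.int32)
num_segments = seg_lens.numel()

# seg_indptr has num_segments + 1 entries: [0, 3, 4, 6].
seg_indptr = torch.zeros(num_segments + 1, dtype=torch.int32)
torch.cumsum(seg_lens, dim=0, out=seg_indptr[1:])

# permutation reorders token rows so each segment's tokens are contiguous;
# the identity permutation is the trivial case.
permutation = torch.arange(int(seg_indptr[-1]), dtype=torch.int32)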
|