sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/check_env.py +3 -3
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +15 -0
- sglang/srt/conversation.py +122 -1
- sglang/srt/disaggregation/decode.py +8 -2
- sglang/srt/disaggregation/fake/__init__.py +1 -0
- sglang/srt/disaggregation/fake/conn.py +88 -0
- sglang/srt/disaggregation/prefill.py +12 -3
- sglang/srt/disaggregation/utils.py +16 -2
- sglang/srt/entrypoints/engine.py +52 -21
- sglang/srt/entrypoints/http_server.py +27 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
- sglang/srt/layers/attention/flashinfer_backend.py +107 -82
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +1 -1
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
- sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +1 -1
- sglang/srt/layers/quantization/fp8.py +20 -22
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +84 -35
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +34 -15
- sglang/srt/managers/scheduler.py +273 -67
- sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
- sglang/srt/managers/tp_worker.py +52 -17
- sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +123 -58
- sglang/srt/models/deepseek_nextn.py +1 -257
- sglang/srt/models/deepseek_v2.py +78 -18
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +92 -30
- sglang/srt/models/llama4.py +2 -1
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +0 -12
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/openai_api/adapter.py +49 -8
- sglang/srt/openai_api/protocol.py +13 -1
- sglang/srt/reasoning_parser.py +25 -1
- sglang/srt/server_args.py +83 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +91 -9
- sglang/test/runners.py +4 -0
- sglang/test/send_one.py +84 -28
- sglang/test/test_utils.py +67 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -72,8 +72,8 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 
 if _is_hip:
-    from aiter import ActivationType
-    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
 if not _is_cuda:
@@ -484,7 +484,7 @@ class Fp8MoEMethod:
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
                 torch.uint32
-                if get_bool_env_var("
+                if get_bool_env_var("SGLANG_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
         tp_size = get_tensor_model_parallel_world_size()
@@ -511,7 +511,7 @@ class Fp8MoEMethod:
         )
 
         # WEIGHTS
-        if _is_hip and get_bool_env_var("
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -585,7 +585,7 @@ class Fp8MoEMethod:
 
         if (
             _is_hip
-        ):  # and get_bool_env_var("
+        ):  # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -612,7 +612,7 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        if _is_hip and get_bool_env_var("
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -644,7 +644,7 @@ class Fp8MoEMethod:
         layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
            self.process_weights_hip_int4(layer)
            return
 
@@ -675,7 +675,7 @@ class Fp8MoEMethod:
             )
             layer.w2_input_scale = None
 
-            if get_bool_env_var("
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 # Pre-shuffle weights
                 layer.w13_weight.data = shuffle_weight(
                     layer.w13_weight.contiguous(), (16, 16)
@@ -798,17 +798,15 @@ class Fp8MoEMethod:
         return
 
     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("
+        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
-            # permute_weight(layer.w13_weight.data),
             shuffle_weight(layer.w13_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
         layer.w2_weight = torch.nn.Parameter(
-            # permute_weight(layer.w2_weight.data),
             shuffle_weight(layer.w2_weight.data, (16, 16)),
             requires_grad=False,
         )
@@ -847,23 +845,21 @@ class Fp8MoEMethod:
             padding_size,  # Avoid circular import
         )
 
-        if get_bool_env_var("
+        if get_bool_env_var("SGLANG_AITER_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                # permute_weight(layer.w13_weight.data),
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                # permute_weight(layer.w2_weight.data),
                 shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
-            # ROCm (
+            # ROCm (SGLANG_AITER_MOE): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-        elif get_bool_env_var("
+        elif get_bool_env_var("SGLANG_MOE_PADDING"):
             # If ROCm, apply weight padding (min. Mem channel contention) only if set
             layer.w13_weight = torch.nn.Parameter(
                 F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
@@ -912,15 +908,16 @@ class Fp8MoEMethod:
         )
 
         if _is_hip:
-            if get_bool_env_var("
-                # TODO: add triton kernel and add check get_bool_env_var("
+            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
                 assert not no_combine, f"{no_combine=} is not supported."
-                return
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
+                    QuantType.per_Token,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
                     activation=(
@@ -930,13 +927,13 @@ class Fp8MoEMethod:
                 ),
             )
 
-            if get_bool_env_var("
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 assert not no_combine, f"{no_combine=} is not supported."
                 if self.block_quant:
-                    # TODO(
+                    # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
                     assert (
                         activation == "silu"
-                    ), f"
+                    ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
                     return asm_moe(
                         x,
                         layer.w13_weight,
@@ -955,6 +952,7 @@ class Fp8MoEMethod:
                         layer.w2_weight,
                         topk_weights,
                         topk_ids,
+                        QuantType.per_Token,
                         layer.w13_weight_scale1,
                         layer.w2_weight_scale1,
                         activation=(
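
The fp8.py hunks above gate the ROCm/aiter MoE paths on SGLANG_-prefixed environment flags (SGLANG_AITER_MOE, SGLANG_INT4_WEIGHT, SGLANG_MOE_PADDING). As a minimal sketch, and not sglang's own get_bool_env_var implementation, a boolean flag of this kind is typically read like this:

import os

def read_bool_flag(name: str, default: str = "false") -> bool:
    # Accept common truthy spellings; anything else counts as False.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "on")

if __name__ == "__main__":
    os.environ.setdefault("SGLANG_AITER_MOE", "1")  # e.g. exported before launching the server
    print(read_bool_flag("SGLANG_AITER_MOE"))       # True
    print(read_bool_flag("SGLANG_INT4_WEIGHT"))     # False unless set

In practice the flags are exported in the shell before starting the server process, so the quantization code sees them at import time.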
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -31,7 +31,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 
-if _is_hip and get_bool_env_var("
+if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
     from aiter import gemm_a8w8_blockscale
 
 if _is_cuda:
@@ -132,7 +132,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and get_bool_env_var("
+    elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d, block_size[1], column_major_scales=False
        )
sglang/srt/layers/utils.py
ADDED
@@ -0,0 +1,35 @@
+import logging
+import re
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def get_layer_id(weight_name):
+    # example weight name: model.layers.10.self_attn.qkv_proj.weight
+    match = re.search(r"layers\.(\d+)\.", weight_name)
+    if match:
+        return int(match.group(1))
+    return None
+
+
+class PPMissingLayer(torch.nn.Identity):
+    # Adapted from
+    # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.return_tuple = kwargs.get("return_tuple", False)
+
+    def forward(self, *args, **kwargs):
+        """
+        Return the first arg from args or the first value from kwargs.
+
+        Wraps the input in a tuple if `self.return_tuple` is True.
+        """
+        input = args[0] if args else next(iter(kwargs.values()))
+        return (input,) if self.return_tuple else input
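
For reference, a short usage sketch of the two helpers added in sglang/srt/layers/utils.py (expected results shown as comments):

import torch

from sglang.srt.layers.utils import PPMissingLayer, get_layer_id

print(get_layer_id("model.layers.10.self_attn.qkv_proj.weight"))  # 10
print(get_layer_id("lm_head.weight"))                              # None (no "layers.N." segment)

# On a pipeline-parallel rank that does not own a given layer, PPMissingLayer
# keeps the module tree shape and simply passes the input through.
layer = PPMissingLayer(return_tuple=True)
hidden = torch.zeros(2, 4)
out = layer(hidden)  # returns (hidden,) because return_tuple=True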
sglang/srt/lora/layers.py
CHANGED
@@ -136,11 +136,19 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.set_lora = True
         self.A_buffer_gate_up = A_buffer
         if self.lora_backend.fuse_stacked_lora_b:
-            # TODO: avoid using contiguous() in GPU.
             # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
-            self.B_buffer_gate_up
+            if not hasattr(self, "B_buffer_gate_up") or self.B_buffer_gate_up is None:
+                self.B_buffer_gate_up = torch.empty(
+                    (
+                        B_buffer[0].shape[0],
+                        2 * B_buffer[0].shape[1],
+                        B_buffer[0].shape[2],
+                    ),
+                    dtype=B_buffer[0].dtype,
+                    device=B_buffer[0].device,
+                )
+            self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
+            self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
         else:
             self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
 
@@ -171,7 +179,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
 
 
 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def
+    def __init__(
         self,
         base_layer: QKVParallelLinear,
         lora_backend: BaseLoRABackend,
@@ -194,12 +202,30 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
 
         # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-        self.B_buffer_qkv
+        if not hasattr(self, "B_buffer_qkv") or self.B_buffer_qkv is None:
+            self.B_buffer_qkv = torch.empty(
+                (
+                    B_buffer_q[0].shape[0],
+                    output_dim_q + 2 * output_dim_kv,
+                    B_buffer_q[0].shape[2],
+                ),
+                dtype=B_buffer_q[0].dtype,
+                device=B_buffer_q[0].device,
+            )
+        self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
+        self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
+            B_buffer_kv[0]
+        )
+        self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
+            B_buffer_kv[1]
+        )
 
         # Offsets of q/k/v in output dimension
-        self.output_offset
+        if not hasattr(self, "output_offset") or self.output_offset is None:
+            self.output_offset = torch.empty(
+                4, dtype=torch.int32, device=B_buffer_q.device
+            )
+        self.output_offset[:4] = torch.tensor(
             [
                 0,
                 output_dim_q,
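
The lora/layers.py change above stops rebuilding the fused LoRA B buffers on every call: it preallocates one stacked tensor and copies the per-projection B weights into slices of it in place. A standalone sketch of that copy layout, with illustrative shapes only:

import torch

num_lora, output_dim, r = 3, 8, 4
b_gate = torch.randn(num_lora, output_dim, r)  # LoRA B for the gate projection
b_up = torch.randn(num_lora, output_dim, r)    # LoRA B for the up projection

# Preallocate once, then copy in place on later batches (same pattern as
# B_buffer_gate_up in MergedColumnParallelLinearWithLoRA above).
fused = torch.empty(num_lora, 2 * output_dim, r)
fused[:, :output_dim, :].copy_(b_gate)
fused[:, output_dim:, :].copy_(b_up)

assert torch.equal(fused[:, :output_dim], b_gate)
assert torch.equal(fused[:, output_dim:], b_up)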
sglang/srt/lora/lora_manager.py
CHANGED
@@ -72,6 +72,23 @@ class LoRAManager:
         self.init_loras()
         self.init_lora_memory_pool()
 
+    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
+        self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
+        with torch.device("cuda"):
+            self.cuda_graph_batch_info = LoRABatchInfo(
+                bs=self.max_bs_in_cuda_graph,
+                seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
+                seg_indptr=torch.zeros(
+                    self.max_bs_in_cuda_graph + 1, dtype=torch.int32
+                ),
+                max_len=0,
+                weight_indices=torch.zeros(
+                    self.max_bs_in_cuda_graph, dtype=torch.int32
+                ),
+                lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
+                scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
+            )
+
     def init_loras(self):
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
@@ -136,43 +153,75 @@ class LoRAManager:
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
 
-        #
-        if cur_uids == set([None]):
-            return
-
-        # set up batch info shared by all lora moruldes
+        # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
-        seg_lens = (
-            forward_batch.extend_seq_lens
-            if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs, device=self.device)
-        )
-        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
-        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-        max_len = int(torch.max(seg_lens))
-        weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
-        [old lines 155-175 also removed; their content is truncated in the diff view]
+        if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph:
+            # Do in-place updates when CUDA graph is enabled. Note that
+            # if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph
+            # will also use these preallocated buffers, no matter whether
+            # the batch can use CUDA graph or not.
+            self.cuda_graph_batch_info.bs = bs
+            if forward_batch.forward_mode.is_extend():
+                self.cuda_graph_batch_info.seg_lens[:bs].copy_(
+                    forward_batch.extend_seq_lens
+                )
+            else:
+                self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
+            torch.cumsum(
+                self.cuda_graph_batch_info.seg_lens[:bs],
+                dim=0,
+                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
+            )
+            self.cuda_graph_batch_info.max_len = int(
+                torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
+            )
+
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                self.cuda_graph_batch_info.weight_indices[i] = (
+                    self.memory_pool.get_buffer_id(lora_path)
+                )
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    self.cuda_graph_batch_info.lora_ranks[
+                        self.cuda_graph_batch_info.weight_indices[i]
+                    ] = lora.config.hf_config["r"]
+                    self.cuda_graph_batch_info.scalings[
+                        self.cuda_graph_batch_info.weight_indices[i]
+                    ] = lora.scaling
+            batch_info = self.cuda_graph_batch_info
+        else:
+            seg_lens = (
+                forward_batch.extend_seq_lens
+                if forward_batch.forward_mode.is_extend()
+                else torch.ones(bs, device=self.device)
+            )
+            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+            max_len = int(torch.max(seg_lens))
+            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+
+            lora_ranks = torch.empty(
+                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
+            )
+            scalings = torch.empty(
+                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
+            )
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                    scalings[weight_indices[i]] = lora.scaling
+            batch_info = LoRABatchInfo(
+                bs=bs,
+                seg_lens=seg_lens,
+                seg_indptr=seg_indptr,
+                max_len=max_len,
+                weight_indices=weight_indices,
+                lora_ranks=lora_ranks,
+                scalings=scalings,
+            )
         self.lora_backend.set_batch_info(batch_info)
 
         # call set_lora_info for each lora modules
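
The lora_manager.py change keeps one preallocated LoRABatchInfo for CUDA-graph batches and updates it in place, since tensors captured by a CUDA graph must keep the same storage across replays. A minimal sketch of that in-place update pattern, using toy buffers rather than sglang's LoRABatchInfo:

import torch

max_bs = 8
seg_lens = torch.zeros(max_bs, dtype=torch.int32)        # preallocated once
seg_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)  # preallocated once

def update_batch(extend_seq_lens: torch.Tensor) -> None:
    bs = extend_seq_lens.numel()
    seg_lens[:bs].copy_(extend_seq_lens)                             # write in place
    torch.cumsum(seg_lens[:bs], dim=0, out=seg_indptr[1 : bs + 1])   # write in place

update_batch(torch.tensor([3, 1, 5], dtype=torch.int32))
print(seg_indptr[:4].tolist())  # [0, 3, 4, 9]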
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -181,44 +181,62 @@ class DataParallelController:
             enable=server_args.enable_memory_saver
         )
 
-        # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
-
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
+        )
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
         )
-        [old lines 191-200 also removed; their content is truncated in the diff view]
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                rank_port_args = port_args
+
+                if server_args.enable_dp_attention:
+                    # dp attention has different sharding logic
+                    _, _, dp_rank = compute_dp_attention_world_info(
+                        server_args.enable_dp_attention,
+                        tp_rank,
+                        server_args.tp_size,
+                        server_args.dp_size,
+                    )
+                    # compute zmq ports for this dp rank
+                    rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                    # Data parallelism resues the tensor parallelism group,
+                    # so all dp ranks should use the same nccl port.
+                    rank_port_args.nccl_port = port_args.nccl_port
+
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                 )
-        [old lines 202-217 also removed; their content is truncated in the diff view]
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            self.scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        rank_port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        dp_rank,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                self.scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)
 
         # Wait for model to finish loading
         scheduler_info = []
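
To make the new pp/tp rank layout concrete, here is a small worked example of the range and GPU-id arithmetic above, using hypothetical values (nnodes=2, tp_size=8, pp_size=2, node_rank=1); the real code also adds server_args.base_gpu_id and the controller's base_gpu_id offset:

nnodes, tp_size, pp_size, node_rank, gpu_id_step = 2, 8, 2, 1, 1

nnodes_per_tp_group = max(nnodes // pp_size, 1)    # 1: each TP group fits on one node
tp_size_per_node = tp_size // nnodes_per_tp_group  # 8
tp_rank_range = range(
    tp_size_per_node * (node_rank % nnodes_per_tp_group),
    tp_size_per_node * (node_rank % nnodes_per_tp_group + 1),
)                                                  # range(0, 8)

pp_size_per_node = max(pp_size // nnodes, 1)       # 1: one pipeline stage per node
pp_rank_range = range(
    pp_size_per_node * (node_rank // nnodes_per_tp_group),
    pp_size_per_node * (node_rank // nnodes_per_tp_group + 1),
)                                                  # range(1, 2): node 1 hosts pp_rank 1

for pp_rank in pp_rank_range:
    for tp_rank in tp_rank_range:
        gpu_id = ((pp_rank % pp_size_per_node) * tp_size_per_node) + (
            tp_rank % tp_size_per_node
        ) * gpu_id_step
        # pp_rank 1 with tp_ranks 0..7 lands on local GPUs 0..7 of node 1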
sglang/srt/managers/multimodal_processors/kimi_vl.py
ADDED
@@ -0,0 +1,73 @@
+import asyncio
+import math
+from typing import List, Union
+
+import torch
+from PIL import Image
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor as SGLangBaseProcessor,
+)
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.kimi_vl import KimiVLForConditionalGeneration
+
+
+# Compatible with KimiVLForConditionalGeneration
+class KimiVLImageProcessor(SGLangBaseProcessor):
+    models = [KimiVLForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "<|media_pad|>"
+        self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
+
+        self.im_start = "<|media_start|>"
+        self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)
+
+        self.im_end = "<|media_end|>"
+        self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)
+
+        self.im_content = "<|media_content|>"
+        self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
+            max_req_input_len=max_req_input_len,
+        )
+        ret = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=base_output.images,
+        )
+        return {
+            "input_ids": ret["input_ids"].flatten().tolist(),
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=ret["pixel_values"],
+                    image_grid_thws=ret["image_grid_hws"],
+                    modality=Modality.IMAGE,
+                )
+            ],
+            "im_token_id": self.im_token_id,
+            "im_start_id": self.im_start_id,
+            "im_end_id": self.im_end_id,
+            "im_content_id": self.im_content_id,
+        }