sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/check_env.py +3 -3
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/kimi_vl.py +38 -0
  5. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  6. sglang/srt/configs/model_config.py +15 -0
  7. sglang/srt/conversation.py +122 -1
  8. sglang/srt/disaggregation/decode.py +8 -2
  9. sglang/srt/disaggregation/fake/__init__.py +1 -0
  10. sglang/srt/disaggregation/fake/conn.py +88 -0
  11. sglang/srt/disaggregation/prefill.py +12 -3
  12. sglang/srt/disaggregation/utils.py +16 -2
  13. sglang/srt/entrypoints/engine.py +52 -21
  14. sglang/srt/entrypoints/http_server.py +27 -2
  15. sglang/srt/function_call_parser.py +97 -0
  16. sglang/srt/hf_transformers_utils.py +2 -0
  17. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  18. sglang/srt/layers/attention/flashinfer_backend.py +107 -82
  19. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
  20. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  21. sglang/srt/layers/attention/utils.py +1 -1
  22. sglang/srt/layers/dp_attention.py +5 -2
  23. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
  41. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  42. sglang/srt/layers/quantization/__init__.py +2 -2
  43. sglang/srt/layers/quantization/deep_gemm.py +1 -1
  44. sglang/srt/layers/quantization/fp8.py +20 -22
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/utils.py +35 -0
  47. sglang/srt/lora/layers.py +35 -9
  48. sglang/srt/lora/lora_manager.py +84 -35
  49. sglang/srt/managers/data_parallel_controller.py +52 -34
  50. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  51. sglang/srt/managers/schedule_batch.py +34 -15
  52. sglang/srt/managers/scheduler.py +273 -67
  53. sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
  54. sglang/srt/managers/tp_worker.py +52 -17
  55. sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
  56. sglang/srt/mem_cache/memory_pool.py +70 -36
  57. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  58. sglang/srt/model_executor/forward_batch_info.py +31 -1
  59. sglang/srt/model_executor/model_runner.py +123 -58
  60. sglang/srt/models/deepseek_nextn.py +1 -257
  61. sglang/srt/models/deepseek_v2.py +78 -18
  62. sglang/srt/models/kimi_vl.py +308 -0
  63. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  64. sglang/srt/models/llama.py +92 -30
  65. sglang/srt/models/llama4.py +2 -1
  66. sglang/srt/models/llama_eagle.py +4 -1
  67. sglang/srt/models/llama_eagle3.py +4 -1
  68. sglang/srt/models/qwen2_moe.py +8 -3
  69. sglang/srt/models/qwen2_vl.py +0 -12
  70. sglang/srt/models/qwen3_moe.py +8 -3
  71. sglang/srt/openai_api/adapter.py +49 -8
  72. sglang/srt/openai_api/protocol.py +13 -1
  73. sglang/srt/reasoning_parser.py +25 -1
  74. sglang/srt/server_args.py +83 -24
  75. sglang/srt/speculative/eagle_worker.py +3 -2
  76. sglang/srt/utils.py +91 -9
  77. sglang/test/runners.py +4 -0
  78. sglang/test/send_one.py +84 -28
  79. sglang/test/test_utils.py +67 -0
  80. sglang/version.py +1 -1
  81. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
  82. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
  83. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
  84. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
  85. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py CHANGED
@@ -72,8 +72,8 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()

 if _is_hip:
-    from aiter import ActivationType
-    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight

 if not _is_cuda:
@@ -484,7 +484,7 @@ class Fp8MoEMethod:
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
                 torch.uint32
-                if get_bool_env_var("USE_INT4_WEIGHT")
+                if get_bool_env_var("SGLANG_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
         tp_size = get_tensor_model_parallel_world_size()
@@ -511,7 +511,7 @@ class Fp8MoEMethod:
         )

         # WEIGHTS
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -585,7 +585,7 @@ class Fp8MoEMethod:

         if (
             _is_hip
-        ):  # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
+        ):  # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -612,7 +612,7 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -644,7 +644,7 @@ class Fp8MoEMethod:
         layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
             self.process_weights_hip_int4(layer)
             return

@@ -675,7 +675,7 @@ class Fp8MoEMethod:
             )
             layer.w2_input_scale = None

-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 # Pre-shuffle weights
                 layer.w13_weight.data = shuffle_weight(
                     layer.w13_weight.contiguous(), (16, 16)
@@ -798,17 +798,15 @@ class Fp8MoEMethod:
         return

     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
-            # permute_weight(layer.w13_weight.data),
             shuffle_weight(layer.w13_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
         layer.w2_weight = torch.nn.Parameter(
-            # permute_weight(layer.w2_weight.data),
             shuffle_weight(layer.w2_weight.data, (16, 16)),
             requires_grad=False,
         )
@@ -847,23 +845,21 @@ class Fp8MoEMethod:
                 padding_size,  # Avoid circular import
             )

-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 layer.w13_weight = torch.nn.Parameter(
-                    # permute_weight(layer.w13_weight.data),
                     shuffle_weight(layer.w13_weight.data, (16, 16)),
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
                 layer.w2_weight = torch.nn.Parameter(
-                    # permute_weight(layer.w2_weight.data),
                     shuffle_weight(layer.w2_weight.data, (16, 16)),
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
-                # ROCm (CK_MOE): using column-wise scaling
+                # ROCm (SGLANG_AITER_MOE): using column-wise scaling
                 layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
                 layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-            elif get_bool_env_var("MOE_PADDING"):
+            elif get_bool_env_var("SGLANG_MOE_PADDING"):
                 # If ROCm, apply weight padding (min. Mem channel contention) only if set
                 layer.w13_weight = torch.nn.Parameter(
                     F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
@@ -912,15 +908,16 @@ class Fp8MoEMethod:
         )

         if _is_hip:
-            if get_bool_env_var("USE_INT4_WEIGHT"):
-                # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
                 assert not no_combine, f"{no_combine=} is not supported."
-                return ck_moe_2stages_win4(
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
+                    QuantType.per_Token,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
                     activation=(
@@ -930,13 +927,13 @@ class Fp8MoEMethod:
                     ),
                 )

-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                 assert not no_combine, f"{no_combine=} is not supported."
                 if self.block_quant:
-                    # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                    # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
                     assert (
                         activation == "silu"
-                    ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+                    ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
                     return asm_moe(
                         x,
                         layer.w13_weight,
@@ -955,6 +952,7 @@ class Fp8MoEMethod:
                         layer.w2_weight,
                         topk_weights,
                         topk_ids,
+                        QuantType.per_Token,
                         layer.w13_weight_scale1,
                         layer.w2_weight_scale1,
                         activation=(
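
Across the fp8.py hunks the substantive change is twofold: the ROCm/AITER feature flags are renamed to SGLANG_-prefixed environment variables (CK_MOE → SGLANG_AITER_MOE, USE_INT4_WEIGHT → SGLANG_INT4_WEIGHT, MOE_PADDING → SGLANG_MOE_PADDING), and the aiter MoE kernels are now invoked through ck_moe_2stages/asm_moe with an explicit QuantType.per_Token argument instead of the ck_moe_2stages_win4 entry point. A minimal sketch of how such a boolean env-var gate behaves; env_flag below is an illustrative stand-in, not sglang's actual get_bool_env_var:

import os

# Hypothetical helper mirroring the usual semantics of a boolean env-var gate.
def env_flag(name: str, default: str = "false") -> bool:
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

# After this release the ROCm MoE code paths read the SGLANG_-prefixed names:
use_aiter_moe = env_flag("SGLANG_AITER_MOE")      # previously CK_MOE
use_int4_weight = env_flag("SGLANG_INT4_WEIGHT")  # previously USE_INT4_WEIGHT
use_moe_padding = env_flag("SGLANG_MOE_PADDING")  # previously MOE_PADDING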
sglang/srt/layers/quantization/fp8_utils.py CHANGED
@@ -31,7 +31,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()

-if _is_hip and get_bool_env_var("CK_MOE"):
+if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
    from aiter import gemm_a8w8_blockscale

 if _is_cuda:
@@ -132,7 +132,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and get_bool_env_var("CK_MOE"):
+    elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d, block_size[1], column_major_scales=False
        )
sglang/srt/layers/utils.py ADDED
@@ -0,0 +1,35 @@
+import logging
+import re
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def get_layer_id(weight_name):
+    # example weight name: model.layers.10.self_attn.qkv_proj.weight
+    match = re.search(r"layers\.(\d+)\.", weight_name)
+    if match:
+        return int(match.group(1))
+    return None
+
+
+class PPMissingLayer(torch.nn.Identity):
+    # Adapted from
+    # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.return_tuple = kwargs.get("return_tuple", False)
+
+    def forward(self, *args, **kwargs):
+        """
+        Return the first arg from args or the first value from kwargs.
+
+        Wraps the input in a tuple if `self.return_tuple` is True.
+        """
+        input = args[0] if args else next(iter(kwargs.values()))
+        return (input,) if self.return_tuple else input
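
A short usage sketch of the two helpers added above, assuming they import from sglang.srt.layers.utils as the file list indicates:

import torch

from sglang.srt.layers.utils import PPMissingLayer, get_layer_id

# get_layer_id() extracts the layer index from a weight name, or returns None.
assert get_layer_id("model.layers.10.self_attn.qkv_proj.weight") == 10
assert get_layer_id("lm_head.weight") is None  # no "layers.N." segment

# PPMissingLayer stands in for layers a pipeline-parallel rank does not own:
# it passes its input through, optionally wrapped in a tuple.
hidden = torch.zeros(1, 4)
out = PPMissingLayer(return_tuple=True)(hidden)
assert isinstance(out, tuple) and out[0] is hidden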
sglang/srt/lora/layers.py CHANGED
@@ -136,11 +136,19 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.set_lora = True
         self.A_buffer_gate_up = A_buffer
         if self.lora_backend.fuse_stacked_lora_b:
-            # TODO: avoid using contiguous() in GPU.
             # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
-            self.B_buffer_gate_up = torch.cat(
-                (B_buffer[0], B_buffer[1]), dim=-2
-            ).contiguous()
+            if not hasattr(self, "B_buffer_gate_up") or self.B_buffer_gate_up is None:
+                self.B_buffer_gate_up = torch.empty(
+                    (
+                        B_buffer[0].shape[0],
+                        2 * B_buffer[0].shape[1],
+                        B_buffer[0].shape[2],
+                    ),
+                    dtype=B_buffer[0].dtype,
+                    device=B_buffer[0].device,
+                )
+            self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
+            self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
         else:
             self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])

@@ -171,7 +179,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):


 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def init__(
+    def __init__(
         self,
         base_layer: QKVParallelLinear,
         lora_backend: BaseLoRABackend,
@@ -194,12 +202,30 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]

             # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-            self.B_buffer_qkv = torch.cat(
-                (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
-            ).contiguous()
+            if not hasattr(self, "B_buffer_qkv") or self.B_buffer_qkv is None:
+                self.B_buffer_qkv = torch.empty(
+                    (
+                        B_buffer_q[0].shape[0],
+                        output_dim_q + 2 * output_dim_kv,
+                        B_buffer_q[0].shape[2],
+                    ),
+                    dtype=B_buffer_q[0].dtype,
+                    device=B_buffer_q[0].device,
+                )
+            self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
+            self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
+                B_buffer_kv[0]
+            )
+            self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
+                B_buffer_kv[1]
+            )

             # Offsets of q/k/v in output dimension
-            self.output_offset = torch.tensor(
+            if not hasattr(self, "output_offset") or self.output_offset is None:
+                self.output_offset = torch.empty(
+                    4, dtype=torch.int32, device=B_buffer_q.device
+                )
+            self.output_offset[:4] = torch.tensor(
                 [
                     0,
                     output_dim_q,
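
Both LoRA layer changes replace torch.cat(...).contiguous(), which allocates a new tensor on every set_lora_info call, with a buffer that is allocated once and then refreshed in place via copy_(), so the fused weight keeps a stable storage address across calls (a prerequisite for CUDA graph capture and replay). A generic sketch of the pattern; FusedBufferHolder is illustrative, not the sglang class:

import torch

class FusedBufferHolder:
    def __init__(self):
        self.fused = None

    def update(self, part_a: torch.Tensor, part_b: torch.Tensor) -> torch.Tensor:
        # Allocate the fused buffer once, then only copy into it afterwards.
        if self.fused is None:
            self.fused = torch.empty(
                (part_a.shape[0], part_a.shape[1] + part_b.shape[1], part_a.shape[2]),
                dtype=part_a.dtype,
                device=part_a.device,
            )
        self.fused[:, : part_a.shape[1], :].copy_(part_a)
        self.fused[:, part_a.shape[1] :, :].copy_(part_b)
        return self.fused

holder = FusedBufferHolder()
a, b = torch.randn(2, 3, 4), torch.randn(2, 5, 4)
first = holder.update(a, b)
second = holder.update(a + 1, b + 1)
assert first.data_ptr() == second.data_ptr()  # same storage, refreshed contents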
sglang/srt/lora/lora_manager.py CHANGED
@@ -72,6 +72,23 @@ class LoRAManager:
         self.init_loras()
         self.init_lora_memory_pool()

+    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
+        self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
+        with torch.device("cuda"):
+            self.cuda_graph_batch_info = LoRABatchInfo(
+                bs=self.max_bs_in_cuda_graph,
+                seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
+                seg_indptr=torch.zeros(
+                    self.max_bs_in_cuda_graph + 1, dtype=torch.int32
+                ),
+                max_len=0,
+                weight_indices=torch.zeros(
+                    self.max_bs_in_cuda_graph, dtype=torch.int32
+                ),
+                lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
+                scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
+            )
+
     def init_loras(self):
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
@@ -136,43 +153,75 @@ class LoRAManager:
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)

-        # FIXME: Handle lora uid with None more safely
-        if cur_uids == set([None]):
-            return
-
-        # set up batch info shared by all lora moruldes
+        # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
-        seg_lens = (
-            forward_batch.extend_seq_lens
-            if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs, device=self.device)
-        )
-        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
-        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-        max_len = int(torch.max(seg_lens))
-        weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)

-        lora_ranks = torch.empty(
-            (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
-        )
-        scalings = torch.empty(
-            (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
-        )
-        for i, lora_path in enumerate(forward_batch.lora_paths):
-            weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-            lora = self.loras[lora_path]
-            lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
-            scalings[weight_indices[i]] = lora.scaling
-
-        batch_info = LoRABatchInfo(
-            bs=bs,
-            seg_lens=seg_lens,
-            seg_indptr=seg_indptr,
-            max_len=max_len,
-            weight_indices=weight_indices,
-            lora_ranks=lora_ranks,
-            scalings=scalings,
-        )
+        if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph:
+            # Do in-place updates when CUDA graph is enabled. Note that
+            # if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph
+            # will also use these preallocated buffers, no matter whether
+            # the batch can use CUDA graph or not.
+            self.cuda_graph_batch_info.bs = bs
+            if forward_batch.forward_mode.is_extend():
+                self.cuda_graph_batch_info.seg_lens[:bs].copy_(
+                    forward_batch.extend_seq_lens
+                )
+            else:
+                self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
+            torch.cumsum(
+                self.cuda_graph_batch_info.seg_lens[:bs],
+                dim=0,
+                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
+            )
+            self.cuda_graph_batch_info.max_len = int(
+                torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
+            )
+
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                self.cuda_graph_batch_info.weight_indices[i] = (
+                    self.memory_pool.get_buffer_id(lora_path)
+                )
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    self.cuda_graph_batch_info.lora_ranks[
+                        self.cuda_graph_batch_info.weight_indices[i]
+                    ] = lora.config.hf_config["r"]
+                    self.cuda_graph_batch_info.scalings[
+                        self.cuda_graph_batch_info.weight_indices[i]
+                    ] = lora.scaling
+            batch_info = self.cuda_graph_batch_info
+        else:
+            seg_lens = (
+                forward_batch.extend_seq_lens
+                if forward_batch.forward_mode.is_extend()
+                else torch.ones(bs, device=self.device)
+            )
+            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+            max_len = int(torch.max(seg_lens))
+            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+
+            lora_ranks = torch.empty(
+                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
+            )
+            scalings = torch.empty(
+                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
+            )
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                    scalings[weight_indices[i]] = lora.scaling
+            batch_info = LoRABatchInfo(
+                bs=bs,
+                seg_lens=seg_lens,
+                seg_indptr=seg_indptr,
+                max_len=max_len,
+                weight_indices=weight_indices,
+                lora_ranks=lora_ranks,
+                scalings=scalings,
+            )
         self.lora_backend.set_batch_info(batch_info)

         # call set_lora_info for each lora modules
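
For reference, a small worked example (illustrative) of the seg_lens to seg_indptr bookkeeping that both branches above perform:

import torch

# An extend batch with per-request sequence lengths [3, 1, 4].
seg_lens = torch.tensor([3, 1, 4], dtype=torch.int32)
seg_indptr = torch.zeros(len(seg_lens) + 1, dtype=torch.int32)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)

print(seg_indptr.tolist())       # [0, 3, 4, 8] -> start offset of each request's tokens
print(int(torch.max(seg_lens)))  # 4 -> max_len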
sglang/srt/managers/data_parallel_controller.py CHANGED
@@ -181,44 +181,62 @@ class DataParallelController:
             enable=server_args.enable_memory_saver
         )

-        # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
-        tp_size_per_node = server_args.tp_size // server_args.nnodes
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
+        )
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
         )
-        for tp_rank in tp_rank_range:
-            rank_port_args = port_args
-
-            if server_args.enable_dp_attention:
-                # dp attention has different sharding logic
-                _, _, dp_rank = compute_dp_attention_world_info(
-                    server_args.enable_dp_attention,
-                    tp_rank,
-                    server_args.tp_size,
-                    server_args.dp_size,
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                rank_port_args = port_args
+
+                if server_args.enable_dp_attention:
+                    # dp attention has different sharding logic
+                    _, _, dp_rank = compute_dp_attention_world_info(
+                        server_args.enable_dp_attention,
+                        tp_rank,
+                        server_args.tp_size,
+                        server_args.dp_size,
+                    )
+                    # compute zmq ports for this dp rank
+                    rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                    # Data parallelism resues the tensor parallelism group,
+                    # so all dp ranks should use the same nccl port.
+                    rank_port_args.nccl_port = port_args.nccl_port
+
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                 )
-                # compute zmq ports for this dp rank
-                rank_port_args = PortArgs.init_new(server_args, dp_rank)
-                # Data parallelism resues the tensor parallelism group,
-                # so all dp ranks should use the same nccl port.
-                rank_port_args.nccl_port = port_args.nccl_port
-
-            reader, writer = mp.Pipe(duplex=False)
-            gpu_id = (
-                server_args.base_gpu_id
-                + base_gpu_id
-                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-            )
-            proc = mp.Process(
-                target=run_scheduler_process,
-                args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
-            )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            self.scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        rank_port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        dp_rank,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                self.scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)

         # Wait for model to finish loading
         scheduler_info = []
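
A worked example (illustrative) of the rank layout the rewritten loop produces, using the formulas above with tp_size=4, pp_size=2, nnodes=2, gpu_id_step=1, and both base GPU offsets set to 0: each node hosts one full TP group for one PP stage, and each node uses its local GPUs 0 through 3.

# Assumed toy configuration; this only mirrors the arithmetic shown in the diff above.
tp_size, pp_size, nnodes, gpu_id_step = 4, 2, 2, 1
for node_rank in range(nnodes):
    nnodes_per_tp_group = max(nnodes // pp_size, 1)        # 1
    tp_size_per_node = tp_size // nnodes_per_tp_group      # 4
    tp_rank_range = range(
        tp_size_per_node * (node_rank % nnodes_per_tp_group),
        tp_size_per_node * (node_rank % nnodes_per_tp_group + 1),
    )                                                      # 0..3 on both nodes
    pp_size_per_node = max(pp_size // nnodes, 1)           # 1
    pp_rank_range = range(
        pp_size_per_node * (node_rank // nnodes_per_tp_group),
        pp_size_per_node * (node_rank // nnodes_per_tp_group + 1),
    )                                                      # [0] on node 0, [1] on node 1
    for pp_rank in pp_rank_range:
        for tp_rank in tp_rank_range:
            gpu_id = ((pp_rank % pp_size_per_node) * tp_size_per_node) + (
                tp_rank % tp_size_per_node
            ) * gpu_id_step                                # local GPUs 0..3 on each node
            print(f"node={node_rank} pp_rank={pp_rank} tp_rank={tp_rank} gpu_id={gpu_id}")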
sglang/srt/managers/multimodal_processors/kimi_vl.py ADDED
@@ -0,0 +1,73 @@
+import asyncio
+import math
+from typing import List, Union
+
+import torch
+from PIL import Image
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor as SGLangBaseProcessor,
+)
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.kimi_vl import KimiVLForConditionalGeneration
+
+
+# Compatible with KimiVLForConditionalGeneration
+class KimiVLImageProcessor(SGLangBaseProcessor):
+    models = [KimiVLForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "<|media_pad|>"
+        self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
+
+        self.im_start = "<|media_start|>"
+        self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)
+
+        self.im_end = "<|media_end|>"
+        self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)
+
+        self.im_content = "<|media_content|>"
+        self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
+            max_req_input_len=max_req_input_len,
+        )
+        ret = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=base_output.images,
+        )
+        return {
+            "input_ids": ret["input_ids"].flatten().tolist(),
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=ret["pixel_values"],
+                    image_grid_thws=ret["image_grid_hws"],
+                    modality=Modality.IMAGE,
+                )
+            ],
+            "im_token_id": self.im_token_id,
+            "im_start_id": self.im_start_id,
+            "im_end_id": self.im_end_id,
+            "im_content_id": self.im_content_id,
+        }