sglang 0.4.5.post3__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -9
  3. sglang/compile_deep_gemm.py +45 -4
  4. sglang/srt/code_completion_parser.py +1 -1
  5. sglang/srt/configs/deepseekvl2.py +1 -1
  6. sglang/srt/configs/model_config.py +9 -3
  7. sglang/srt/constrained/llguidance_backend.py +78 -61
  8. sglang/srt/conversation.py +34 -1
  9. sglang/srt/disaggregation/decode.py +59 -11
  10. sglang/srt/disaggregation/mini_lb.py +45 -8
  11. sglang/srt/disaggregation/mooncake/conn.py +198 -31
  12. sglang/srt/disaggregation/prefill.py +24 -9
  13. sglang/srt/entrypoints/http_server.py +8 -2
  14. sglang/srt/function_call_parser.py +77 -5
  15. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  16. sglang/srt/layers/attention/flashattention_backend.py +28 -10
  17. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  18. sglang/srt/layers/attention/vision.py +2 -0
  19. sglang/srt/layers/layernorm.py +38 -16
  20. sglang/srt/layers/logits_processor.py +2 -2
  21. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  25. sglang/srt/layers/pooler.py +6 -0
  26. sglang/srt/layers/quantization/awq.py +5 -1
  27. sglang/srt/layers/quantization/deep_gemm.py +17 -10
  28. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  29. sglang/srt/layers/radix_attention.py +13 -3
  30. sglang/srt/layers/rotary_embedding.py +170 -126
  31. sglang/srt/managers/data_parallel_controller.py +10 -3
  32. sglang/srt/managers/io_struct.py +7 -0
  33. sglang/srt/managers/mm_utils.py +85 -28
  34. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  35. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  36. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  37. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  38. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  40. sglang/srt/managers/schedule_batch.py +29 -12
  41. sglang/srt/managers/scheduler.py +31 -20
  42. sglang/srt/managers/tokenizer_manager.py +5 -1
  43. sglang/srt/mem_cache/memory_pool.py +87 -0
  44. sglang/srt/model_executor/cuda_graph_runner.py +4 -3
  45. sglang/srt/model_executor/forward_batch_info.py +51 -95
  46. sglang/srt/model_executor/model_runner.py +11 -24
  47. sglang/srt/models/deepseek.py +12 -2
  48. sglang/srt/models/deepseek_nextn.py +101 -6
  49. sglang/srt/models/deepseek_v2.py +144 -70
  50. sglang/srt/models/deepseek_vl2.py +9 -4
  51. sglang/srt/models/gemma3_causal.py +1 -1
  52. sglang/srt/models/llama4.py +0 -1
  53. sglang/srt/models/minicpmo.py +5 -1
  54. sglang/srt/models/mllama4.py +2 -2
  55. sglang/srt/models/qwen2_5_vl.py +3 -6
  56. sglang/srt/models/qwen2_vl.py +3 -7
  57. sglang/srt/models/roberta.py +178 -0
  58. sglang/srt/openai_api/adapter.py +18 -8
  59. sglang/srt/server_args.py +15 -22
  60. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  61. sglang/srt/torch_memory_saver_adapter.py +10 -1
  62. sglang/srt/utils.py +2 -1
  63. sglang/test/runners.py +6 -13
  64. sglang/test/test_utils.py +36 -18
  65. sglang/version.py +1 -1
  66. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/METADATA +4 -5
  67. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/RECORD +70 -68
  68. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  69. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  70. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py

@@ -16,6 +16,7 @@
 # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+from typing import List, Union

 import torch

@@ -35,7 +36,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         self.IMAGE_TOKEN = "<image>"

     async def process_mm_data_async(
-        self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs
     ):
         if not image_data:
             return None
@@ -45,7 +52,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):

         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            input_ids,
+            input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
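
Taken together with the gemma3, janus_pro, minicpm and qwen_vl hunks below, this release renames the second positional parameter of every multimodal processor from input_ids/prompt to input_text and switches the caller to keyword arguments (see the tokenizer_manager.py hunk further down). A minimal sketch of the shared call shape; processor and req are stand-ins for the real processor and request objects, and the attribute names on req are illustrative:

    async def preprocess(processor, req, max_req_input_len=4096):
        # Same keyword-only call shape for every processor patched in this release.
        return await processor.process_mm_data_async(
            image_data=req.image_data,            # list of URLs/bytes, may be empty
            input_text=req.text,                  # raw prompt text (or token ids as a fallback)
            request_obj=req,
            max_req_input_len=max_req_input_len,
        )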
sglang/srt/managers/multimodal_processors/gemma3.py

@@ -1,7 +1,5 @@
 from typing import List, Union

-from transformers.utils import logging
-
 from sglang.srt.managers.multimodal_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -13,7 +11,6 @@ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration

 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
 # will be removed in the future
-logger = logging.get_logger(__name__)


 class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
@@ -28,7 +25,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
         *args,
@@ -41,7 +38,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):

         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
sglang/srt/managers/multimodal_processors/janus_pro.py

@@ -17,7 +17,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
         **kwargs,
@@ -31,7 +31,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         processor = self._processor

         base_out = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
                 image_token=processor.image_token
sglang/srt/managers/multimodal_processors/minicpm.py

@@ -51,9 +51,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
+        **kwargs,
     ):
         audio_data = request_obj.audio_data
         if not image_data and not audio_data:
@@ -64,7 +65,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             audio_data = [audio_data]

         base_output = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             max_req_input_len=max_req_input_len,
             audio_data=audio_data,
             image_data=image_data,
@@ -96,7 +97,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         audio_start_id = tokenizer.audio_start_id
         audio_end_id = tokenizer.audio_end_id

-        im_token_id = tokenizer.unk_token_id
+        im_token_id = tokenizer.unk_id
         pixel_values = res["pixel_values"]
         tgt_sizes = res["tgt_sizes"]

sglang/srt/managers/multimodal_processors/qwen_vl.py

@@ -5,6 +5,7 @@ from typing import List, Union
 import torch
 from PIL import Image

+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -27,6 +28,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
         self.image_token_id = hf_config.image_token_id
         self.video_token_id = hf_config.video_token_id
+        self.vision_start_token_id = hf_config.vision_start_token_id
+        self.vision_end_token_id = hf_config.vision_end_token_id
         self.NUM_TOKEN_PER_FRAME = 770
         self.IMAGE_FACTOR = 28
         self.MIN_PIXELS = 4 * 28 * 28
@@ -36,20 +39,18 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        prompt,
+        input_text,
         request_obj,
         max_req_input_len,
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
         if isinstance(image_data, str):
             image_data = [image_data]

         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            prompt=prompt,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
@@ -116,29 +117,53 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         async def resize_image_async(image):
             return resize_image(image)

-        resize_tasks = [resize_image_async(image) for image in base_output.images]
-        resized_images = await asyncio.gather(*resize_tasks)
+        if base_output.images:
+            resize_tasks = [resize_image_async(image) for image in base_output.images]
+            base_output.images = await asyncio.gather(*resize_tasks)

         ret = self.process_mm_data(
             input_text=base_output.input_text,
-            images=resized_images,
+            images=base_output.images,
         )

-        image_grid_thws = torch.concat([ret["image_grid_thw"]])
-        return {
-            "input_ids": ret["input_ids"].flatten().tolist(),
-            "mm_items": [
+        items = []
+
+        input_ids = ret["input_ids"].flatten().tolist()
+        if "pixel_values" in ret:
+            items += [
                 MultimodalDataItem(
                     pixel_values=ret["pixel_values"],
-                    image_grid_thws=image_grid_thws,
+                    image_grid_thws=torch.concat([ret["image_grid_thw"]]),
                     # TODO
                     video_grid_thws=None,
                     second_per_grid_ts=ret.get("second_per_grid_ts", None),
                     modality=Modality.IMAGE,
                 )
-            ],
+            ]
+
+        mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
+            spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
+            image_token_id=self.image_token_id,
+            video_token_id=self.video_token_id,
+            vision_start_token_id=self.vision_start_token_id,
+            model_type=self.hf_config.model_type,
+            tokens_per_second=getattr(
+                self.hf_config.vision_config, "tokens_per_second", None
+            ),
+            input_ids=torch.tensor(input_ids).unsqueeze(0),
+            image_grid_thw=ret.get("image_grid_thw", None),
+            video_grid_thw=ret.get("video_grid_thw", None),
+            second_per_grid_ts=ret.get("second_per_grid_ts", None),
+        )
+        mrope_positions = mrope_positions.squeeze(1)
+
+        return {
+            "input_ids": input_ids,
+            "mm_items": items,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
             "im_token_id": self.image_token_id,
             "video_token_id": self.video_token_id,
+            "mrope_positions": mrope_positions,
+            "mrope_position_delta": mrope_position_delta,
         }
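
The qwen_vl change moves M-RoPE bookkeeping out of the model forward: MRotaryEmbedding.get_rope_index now runs once in the processor and its outputs travel with the request as mrope_positions and mrope_position_delta. A small illustrative sketch (not taken from the diff) of how the three position components are laid out for a short text prefix followed by a single 2x2 vision patch grid; in sglang the actual offsets come from get_rope_index:

    import torch

    text_len = 4
    t, h, w = 1, 2, 2  # one frame, 2x2 patches after spatial merging (example values)

    # Text tokens share the same index on all three axes.
    text_pos = torch.arange(text_len).repeat(3, 1)                  # (3, 4)

    # Vision tokens index the (t, h, w) grid, offset to follow the text.
    grid_t, grid_h, grid_w = torch.meshgrid(
        torch.arange(t), torch.arange(h), torch.arange(w), indexing="ij"
    )
    vision_pos = torch.stack(
        [grid_t.flatten(), grid_h.flatten(), grid_w.flatten()]
    ) + text_len                                                    # (3, 4)

    mrope_positions = torch.cat([text_pos, vision_pos], dim=1)      # (3, 8)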
sglang/srt/managers/schedule_batch.py

@@ -285,6 +285,7 @@ class MultimodalInputs:
     num_image_tokens: Optional[int] = None

     # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
     mrope_position_delta: Optional[torch.Tensor] = None

     # image
@@ -310,16 +311,12 @@ class MultimodalInputs:
         assert isinstance(ret.mm_items, list)
         ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

-        assert len(ret.mm_items) != 0
-
-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.
         for item in ret.mm_items:
             item.set_pad_value()

         optional_args = [
+            "mrope_positions",
+            "mrope_position_delta",
             "im_token_id",
             "im_start_id",
             "im_end_id",
@@ -350,11 +347,6 @@ class MultimodalInputs:
         merge image inputs when requests are being merged
         """

-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.
-
         # args needed to be merged
         optional_args = [
             "mm_items",
@@ -364,6 +356,30 @@ class MultimodalInputs:
             self_arg = getattr(self, arg, None)
             if self_arg is not None:
                 setattr(self, arg, self_arg + getattr(other, arg))
+
+        mrope_positions = self.mrope_positions
+        if mrope_positions is not None:
+            if other.mrope_positions is None:
+                self.mrope_positions = mrope_positions
+            else:
+                self.mrope_positions = torch.cat(
+                    [self.mrope_positions, other.mrope_positions], dim=1
+                )
+
+        mrope_position_delta = self.mrope_position_delta
+        if mrope_position_delta is not None:
+            if other.mrope_position_delta is None:
+                self.mrope_position_delta = mrope_position_delta
+            else:
+                self.mrope_position_delta = torch.cat(
+                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
+                )
+
+        for key, val in other.__dict__.items():
+            if "_id" in key:
+                # set token_ids
+                if getattr(self, key, None) is None:
+                    setattr(self, key, getattr(other, key, None))
         # other args would be kept intact


@@ -388,6 +404,7 @@ class Req:
         return_hidden_states: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
         bootstrap_host: Optional[str] = None,
+        bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
     ):
         # Input and output info
@@ -523,6 +540,7 @@ class Req:

         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
+        self.bootstrap_port: Optional[int] = bootstrap_port
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None

@@ -1450,7 +1468,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.model_config.is_encoder_decoder:
             self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
             self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
-
         self.req_pool_indices = torch.cat(
             [self.req_pool_indices, other.req_pool_indices]
         )
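
The new merge logic in MultimodalInputs treats the two tensors differently when requests are batched together: positions concatenate along the sequence dimension (dim=1) while the per-request deltas stack along the batch dimension (dim=0). A toy sketch of the shapes this implies, assuming positions come out of the processor as (3, seq_len) after the squeeze:

    import torch

    pos_a = torch.zeros(3, 10, dtype=torch.long)           # request A, 10 tokens
    pos_b = torch.zeros(3, 6, dtype=torch.long)            # request B, 6 tokens
    delta_a = torch.tensor([[2]])                          # one delta per request
    delta_b = torch.tensor([[0]])

    merged_pos = torch.cat([pos_a, pos_b], dim=1)          # (3, 16)
    merged_delta = torch.cat([delta_a, delta_b], dim=0)    # (2, 1)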
sglang/srt/managers/scheduler.py

@@ -578,6 +578,10 @@ class Scheduler(
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                 transfer_backend=self.transfer_backend,
             )
+
+            # Metric for pre-allocation
+            self.num_tokens_pre_allocated = 0
+
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             # *2 for the headroom.
             buffer_size = self.max_running_requests * 2
@@ -593,7 +597,7 @@ class Scheduler(
             )
             metadata_buffers = [output_id_buffer]

-            self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
+            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=metadata_buffers,
@@ -787,6 +791,7 @@ class Scheduler(
             return_hidden_states=recv_req.return_hidden_states,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
+            bootstrap_port=recv_req.bootstrap_port,
             bootstrap_room=recv_req.bootstrap_room,
         )
         req.tokenizer = self.tokenizer
@@ -901,7 +906,7 @@ class Scheduler(
     def _add_request_to_queue(self, req: Req):
         req.queue_time_start = time.time()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.disagg_prefill_pending_queue.add(req)
+            self.disagg_prefill_bootstrap_queue.add(req)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.disagg_decode_prealloc_queue.add(req)
         else:
@@ -991,8 +996,15 @@ class Scheduler(
             f"#cached-token: {adder.log_hit_tokens}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue)}, "
         )
+
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+            f += f"#queue-req: {len(self.waiting_queue)}, "
+            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} "
+        else:
+            f += f"#queue-req: {len(self.waiting_queue)}"
+
         logger.info(f)

         if self.enable_metrics:
@@ -1028,15 +1040,14 @@ class Scheduler(
             gap_latency / self.server_args.decode_log_interval
         )

+        msg = (
+            f"Decode batch. "
+            f"#running-req: {num_running_reqs}, "
+            f"#token: {num_used}, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+        )
+
         if self.spec_algorithm.is_none():
-            msg = (
-                f"Decode batch. "
-                f"#running-req: {num_running_reqs}, "
-                f"#token: {num_used}, "
-                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"#queue-req: {len(self.waiting_queue)}, "
-            )
             spec_accept_length = 0
         else:
             spec_accept_length = (
@@ -1045,15 +1056,15 @@ class Scheduler(
             self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
             self.cum_spec_accept_count += self.spec_num_total_forward_ct
             self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
-            msg = (
-                f"Decode batch. "
-                f"#running-req: {num_running_reqs}, "
-                f"#token: {num_used}, "
-                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                f"accept len: {spec_accept_length:.2f}, "
-                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"#queue-req: {len(self.waiting_queue)}, "
-            )
+            msg += f"accept len: {spec_accept_length:.2f}, "
+
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+
+        msg += (
+            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
+            f"#queue-req: {len(self.waiting_queue)}"
+        )

         logger.info(msg)
         if self.enable_metrics:
sglang/srt/managers/tokenizer_manager.py

@@ -419,7 +419,10 @@ class TokenizerManager:
             input_ids = self.tokenizer.encode(input_text)

         image_inputs: Dict = await self.mm_processor.process_mm_data_async(
-            obj.image_data, input_text or input_ids, obj, self.max_req_input_len
+            image_data=obj.image_data,
+            input_text=input_text or input_ids,
+            request_obj=obj,
+            max_req_input_len=self.max_req_input_len,
         )
         if image_inputs and "input_ids" in image_inputs:
             input_ids = image_inputs["input_ids"]
@@ -495,6 +498,7 @@ class TokenizerManager:
             token_ids_logprob,
             obj.stream,
             bootstrap_host=obj.bootstrap_host,
+            bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
             lora_path=obj.lora_path,
             input_embeds=input_embeds,
sglang/srt/mem_cache/memory_pool.py

@@ -34,6 +34,8 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import psutil
 import torch
+import triton
+import triton.language as tl

 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import debug_timing, get_compiler_backend
@@ -405,6 +407,72 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)


+@triton.jit
+def set_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    pid_blk = tl.program_id(1)
+
+    base = pid_blk * BLOCK
+    offs = base + tl.arange(0, BLOCK)
+    total_dim = nope_dim + rope_dim
+    mask = offs < total_dim
+
+    loc = tl.load(loc_ptr + pid_loc)
+    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs
+
+    if base + BLOCK <= nope_dim:
+        src = tl.load(
+            cache_k_nope_ptr + pid_loc * nope_stride + offs,
+            mask=mask,
+        )
+    else:
+        offs_rope = offs - nope_dim
+        src = tl.load(
+            cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
+            mask=mask,
+        )
+
+    tl.store(dst_ptr, src, mask=mask)
+
+
+def set_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    nope_dim = cache_k_nope.shape[-1]
+    rope_dim = cache_k_rope.shape[-1]
+    total_dim = nope_dim + rope_dim
+    BLOCK = 128
+    n_loc = loc.numel()
+    grid = (n_loc, triton.cdiv(total_dim, BLOCK))
+
+    set_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+        BLOCK=BLOCK,
+    )
+
+
 class MLATokenToKVPool(KVCache):
     def __init__(
         self,
@@ -504,6 +572,25 @@ class MLATokenToKVPool(KVCache):
         else:
             self.kv_buffer[layer_id][loc] = cache_k

+    def set_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k_nope: torch.Tensor,
+        cache_k_rope: torch.Tensor,
+    ):
+        layer_id = layer.layer_id
+        if cache_k_nope.dtype != self.dtype:
+            cache_k_nope = cache_k_nope.to(self.dtype)
+            cache_k_rope = cache_k_rope.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            cache_k_nope = cache_k_nope.view(self.store_dtype)
+            cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+        set_mla_kv_buffer_triton(
+            self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
+        )
+
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
         return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
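
The new set_mla_kv_buffer path writes the nope and rope halves of an MLA KV entry into their destination slots with a single Triton launch instead of two Python-level slice assignments. A hedged usage sketch, runnable only on a CUDA device with triton installed; the 512/64 split mirrors DeepSeek-style MLA dimensions and is only an example:

    import torch
    from sglang.srt.mem_cache.memory_pool import set_mla_kv_buffer_triton  # added in this hunk

    num_slots, nope_dim, rope_dim = 1024, 512, 64
    kv_buffer = torch.zeros(num_slots, nope_dim + rope_dim, device="cuda", dtype=torch.bfloat16)
    loc = torch.tensor([3, 7, 42], device="cuda")          # destination token slots
    cache_k_nope = torch.randn(3, nope_dim, device="cuda", dtype=torch.bfloat16)
    cache_k_rope = torch.randn(3, rope_dim, device="cuda", dtype=torch.bfloat16)

    # One launch copies both halves into kv_buffer[loc]; the kernel only relies on the
    # leading stride, so this 2-D view behaves like the pool's (slots, 1, dim) layout.
    set_mla_kv_buffer_triton(kv_buffer, loc, cache_k_nope, cache_k_rope)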
sglang/srt/model_executor/cuda_graph_runner.py

@@ -134,7 +134,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
    )

    gpu_mem = get_device_memory_capacity()
-    if gpu_mem is not None and gpu_mem > 81920:
+    # Batch size of each rank will not become so large when DP is on
+    if gpu_mem is not None and gpu_mem > 81920 and server_args.dp_size == 1:
        capture_bs += list(range(160, 257, 8))

    if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -278,9 +279,9 @@ class CudaGraphRunner:
                f"Capture cuda graph failed: {e}\n"
                "Possible solutions:\n"
                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "2. set --cuda-graph-max-bs to a smaller value (e.g., 32)\n"
+                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
                "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable cuda graph by --disable-cuda-graph\n"
+                "4. disable cuda graph by --disable-cuda-graph. (Not recommonded. Huge perf loss)\n"
                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
            )
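
For reference, the newly guarded branch only extends the capture list on GPUs with more than roughly 80 GB of device memory and with data parallelism off; a quick check of what it adds:

    extra_bs = list(range(160, 257, 8))              # only when gpu_mem > 81920 MB and dp_size == 1
    print(extra_bs[0], extra_bs[-1], len(extra_bs))  # 160 256 13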