sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -271,12 +271,13 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
         batch,
         dp_size=model_runner.server_args.dp_size,
         attn_tp_size=1,
-        tp_cpu_group=model_runner.tp_group.cpu_group,
+        tp_group=model_runner.tp_group,
         get_idle_batch=None,
         disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
         spec_algorithm=SpeculativeAlgorithm.NONE,
         speculative_num_draft_tokens=None,
         require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
+        disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
     )
 
 
sglang/eval/loogle_eval.py CHANGED
@@ -73,6 +73,8 @@ async def benchmark(args):
 
     tasks: List[asyncio.Task] = []
     for idx, ex in enumerate(dataset):
+        if idx >= args.num_prompts:
+            break
         tasks.append(
             asyncio.create_task(
                 fetch_response(
@@ -103,6 +105,8 @@ def analyse(args):
     hyps: List[str] = []
     refs: List[str] = []
     for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
+        if idx >= args.num_prompts:
+            break
         pkl_file = output_dir / f"response_{idx}.pkl"
         if not pkl_file.exists():
             raise FileNotFoundError(pkl_file)
@@ -150,6 +154,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "--output-dir", default="tmp-output-dir", help="Directory for cached responses"
     )
+    parser.add_argument(
+        "--num-prompts", type=int, default=10000, help="Number of prompts to run"
+    )
     args = parser.parse_args()
 
     asyncio.run(benchmark(args))
sglang/srt/configs/deepseekvl2.py CHANGED
@@ -42,6 +42,9 @@ def select_best_resolution(image_size, candidate_resolutions):
 
 
 class DictOutput(object):
+    def items(self):
+        return self.__dict__.items()
+
     def keys(self):
         return self.__dict__.keys()
 
@@ -59,7 +62,9 @@ class DictOutput(object):
 class VLChatProcessorOutput(DictOutput):
     input_ids: torch.LongTensor
     target_ids: torch.LongTensor
-    images: torch.Tensor
+    pixel_values: (
+        torch.Tensor
+    )  # rename from "images" to "pixel_values" for compatibility
     images_seq_mask: torch.BoolTensor
     images_spatial_crop: torch.LongTensor
 
@@ -312,10 +317,14 @@ class DeepseekVLV2Processor(ProcessorMixin):
         images = torch.stack(images_list, dim=0)
         images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
 
+        images_spatial_crop = torch.stack(
+            [images_spatial_crop], dim=0
+        )  # stack the tensor to make it a batch of 1
+
         prepare = VLChatProcessorOutput(
             input_ids=input_ids,
             target_ids=target_ids,
-            images=images,
+            pixel_values=images,
             images_seq_mask=images_seq_mask,
             images_spatial_crop=images_spatial_crop,
         )
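Note on the `DictOutput.items()` additions (here and in janus_pro.py below): together with `keys()`, they let processor outputs be consumed through the ordinary mapping protocol. A minimal standalone sketch of the pattern; the `Sample` subclass and its fields are hypothetical, added only for illustration:

class DictOutput(object):
    def items(self):
        return self.__dict__.items()

    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, key):
        return self.__dict__[key]


class Sample(DictOutput):  # hypothetical subclass for illustration
    def __init__(self):
        self.input_ids = [101, 102]
        self.pixel_values = None


s = Sample()
print(list(s.keys()))   # ['input_ids', 'pixel_values']
print(dict(s.items()))  # {'input_ids': [101, 102], 'pixel_values': None}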
sglang/srt/configs/internvl.py CHANGED
@@ -9,6 +9,7 @@ from transformers import (
     LlamaConfig,
     PretrainedConfig,
     PreTrainedTokenizer,
+    Qwen2Config,
 )
 
 from sglang.utils import logger
@@ -311,6 +312,8 @@ class InternVLChatConfig(PretrainedConfig):
             self.llm_config = LlamaConfig(**llm_config)
         elif llm_config.get("architectures")[0] == "InternLM2ForCausalLM":
             self.llm_config = InternLM2Config(**llm_config)
+        elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
+            self.llm_config = Qwen2Config(**llm_config)
         else:
             raise ValueError(
                 "Unsupported architecture: {}".format(
sglang/srt/configs/janus_pro.py CHANGED
@@ -284,6 +284,9 @@ class VLMImageProcessor(BaseImageProcessor):
 
 
 class DictOutput(object):
+    def items(self):
+        return self.__dict__.items()
+
     def keys(self):
         return self.__dict__.keys()
 
sglang/srt/configs/model_config.py CHANGED
@@ -53,7 +53,7 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
-        model_override_args: Optional[str] = None,
+        model_override_args: str = "{}",
         is_embedding: Optional[bool] = None,
         enable_multimodal: Optional[bool] = None,
         dtype: str = "auto",
@@ -61,13 +61,13 @@ class ModelConfig:
         override_config_file: Optional[str] = None,
         is_draft_model: bool = False,
         hybrid_kvcache_ratio: Optional[float] = None,
-        impl: Union[str, ModelImpl] = ModelImpl.AUTO,
+        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
     ) -> None:
 
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
-        self.impl = impl
+        self.model_impl = model_impl
 
         # Parse args
         self.maybe_pull_model_tokenizer_from_remote()
@@ -286,7 +286,7 @@ class ModelConfig:
             dtype=server_args.dtype,
             quantization=server_args.quantization,
             hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
-            impl=server_args.impl,
+            model_impl=server_args.model_impl,
             **kwargs,
         )
 
@@ -391,6 +391,7 @@ class ModelConfig:
             "compressed-tensors",
             "fbgemm_fp8",
             "w8a8_fp8",
+            "petit_nvfp4",
         ]
         optimized_quantization_methods = [
             "fp8",
@@ -408,9 +409,11 @@ class ModelConfig:
             "moe_wna16",
             "qoq",
             "w4afp8",
+            "petit_nvfp4",
         ]
         compatible_quantization_methods = {
             "modelopt_fp4": ["modelopt"],
+            "petit_nvfp4": ["modelopt"],
             "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
             "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
         }
@@ -711,7 +714,6 @@ def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int)
             i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
         ]
     else:
-        raise ValueError(
-            "get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
-        )
+        swa_attention_layer_ids = None
+        full_attention_layer_ids = None
     return swa_attention_layer_ids, full_attention_layer_ids
sglang/srt/configs/update_config.py CHANGED
@@ -115,5 +115,7 @@ def adjust_config_with_unaligned_cpu_tp(
         model_config = update_intermediate_size(
             model_config, "intermediate_size", intermediate_padding_size
         )
-
+        model_config = update_intermediate_size(
+            model_config, "intermediate_size_mlp", intermediate_padding_size
+        )
     return model_config
sglang/srt/conversation.py CHANGED
@@ -729,6 +729,7 @@ register_conv_template(
         sep="<|end|>",
         stop_str="<|end|>",
         image_token="<|endoftext10|>",
+        audio_token="<|endoftext11|>",
     )
 )
 
sglang/srt/custom_op.py CHANGED
@@ -29,15 +29,18 @@ class CustomOp(nn.Module):
 
         self._original_forward_method = self._forward_method
         # NOTE: Temporarily workaround MoE
+        # The performance of torch.compile on this layer is not always good when bs > 1,
+        # so we decide to only use torch.compile when bs=1
         if "FusedMoE" in self.__class__.__name__:
             if num_tokens == 1:
                 from sglang.srt.layers.moe.fused_moe_native import (
                     fused_moe_forward_native,
                 )
 
-                # The performance of torch.compile on this layer is not always good when bs > 1,
-                # so we decide to only use torch.compile when bs =1
                 self._forward_method = fused_moe_forward_native
+        elif "TopK" in self.__class__.__name__:
+            if num_tokens == 1:
+                self._forward_method = self.forward_native
         else:
             self._forward_method = self.forward_native
         self.is_torch_compile = True
sglang/srt/disaggregation/decode.py CHANGED
@@ -439,7 +439,15 @@ class DecodePreallocQueue:
             else 0
         )
 
-        allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max(
+        if self.scheduler.model_config.is_hybrid:
+            available_size = min(
+                self.token_to_kv_pool_allocator.full_available_size(),
+                self.token_to_kv_pool_allocator.swa_available_size(),
+            )
+        else:
+            available_size = self.token_to_kv_pool_allocator.available_size()
+
+        allocatable_tokens = available_size - max(
             # preserve some space for future decode
             self.num_reserved_decode_tokens
             * (
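For hybrid models the hunk above caps preallocation by the scarcer of the two KV pools, since a decode token must fit in both the full-attention pool and the SWA pool. A toy sketch of that arithmetic with made-up pool sizes:

def allocatable(is_hybrid, full_free, swa_free, reserved):
    # For hybrid models the smaller free pool is the binding constraint.
    available = min(full_free, swa_free) if is_hybrid else full_free
    return available - reserved

print(allocatable(True, 8192, 2048, 512))   # 1536: bounded by the SWA pool
print(allocatable(False, 8192, 2048, 512))  # 7680: only the full pool matters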
sglang/srt/disaggregation/mooncake/conn.py CHANGED
@@ -321,67 +321,60 @@ class MooncakeKVManager(BaseKVManager):
         This may introduce performance overhead (increased TTFT) for long sequences.
         """
         # Extract configuration
-        local_tp_rank = self.kv_args.engine_rank
         local_tp_size = self.tp_size // self.dp_size
+        local_tp_rank_in_group = self.kv_args.engine_rank % local_tp_size
+        src_kv_item_len = self.kv_args.kv_item_lens[0]
+        dst_tp_rank_in_group = dst_tp_rank % dst_tp_size
         num_kv_heads = self.kv_args.kv_head_num
         num_layers = len(self.kv_args.kv_data_ptrs)
         page_size = self.kv_args.page_size
 
         # Calculate head distribution
-        heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
-        heads_per_prefill_rank = num_kv_heads
-        decode_global_head_start = dst_tp_rank * heads_per_decode_rank
-        prefill_global_head_start = local_tp_rank * heads_per_prefill_rank
-        bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size
-
-        decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
+        src_heads_per_rank = num_kv_heads
+        dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size
+        bytes_per_head_slice_to_send = (
+            dst_kv_item_len // page_size // dst_heads_per_rank
+        )
 
         # Determine slicing parameters based on TP configuration
         if local_tp_size > dst_tp_size:
-            src_head_offset = 0
-            num_heads_to_send = heads_per_prefill_rank
-            dst_head_offset = prefill_global_head_start - decode_global_head_start
+            # Send KVCache from multiple prefill instances to 1 decode instance
+            src_head_start_offset = 0
+            num_heads_to_send = src_heads_per_rank
+            dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
         else:
-            src_head_offset = decode_global_head_start - prefill_global_head_start
-            num_heads_to_send = heads_per_decode_rank
-            dst_head_offset = 0
+            # Send KVCache from 1 prefill instance to multiple decode instances
+            src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank
+            num_heads_to_send = dst_heads_per_rank
+            dst_head_start_offset = 0
 
-        layer_transfer_params = []
+        layers_params = []
         for layer_id in range(num_layers):
-            item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id]
-
-            # Page stride on the target dst decode rank for its slice pages
-            item_len_of_decode_rank_page = decode_rank_item_lens[layer_id]
-
-            if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0:
-                logger.error(
-                    f"Invalid item_len_of_prefill_rank_page or num_kv_heads for layer {layer_id}"
-                )
-                return -1
-
-            # Calculate precise byte offset and length for the sub-slice within the prefill page data
-            src_slice_offset = src_head_offset * bytes_per_head
-            dst_slice_offset = dst_head_offset * bytes_per_head
-            slice_lens_per_page = num_heads_to_send * bytes_per_head
+            # Calculate precise byte offset and length for the sub-slice within the token
+            src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+            dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+            heads_bytes_per_token_to_send = (
+                num_heads_to_send * bytes_per_head_slice_to_send
+            )
 
-            # Sanity check: The data sub-slice to be sent should fit into the decode instance's page.
-            # This means slice_lens_per_page <= item_len_of_decode_rank_page
-            if slice_lens_per_page > item_len_of_decode_rank_page:
+            # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
+            # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
+            if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
                 logger.error(
                     f"[{mooncake_session_id}] Layer {layer_id}: "
-                    f"slice size ({slice_lens_per_page}) exceeds "
-                    f"target page size ({item_len_of_decode_rank_page})"
+                    f"slice size ({heads_bytes_per_token_to_send}) exceeds "
+                    f"target token slot size ({dst_kv_item_len // page_size})"
                 )
                 return -1
-            layer_transfer_params.append(
+            layers_params.append(
                 (
                     self.kv_args.kv_data_ptrs[layer_id],
                     dst_kv_ptrs[layer_id],
-                    item_len_of_prefill_rank_page,
-                    item_len_of_decode_rank_page,
-                    src_slice_offset,
-                    dst_slice_offset,
-                    slice_lens_per_page,
+                    src_kv_item_len,
+                    dst_kv_item_len,
+                    src_head_slice_offset,
+                    dst_head_slice_offset,
+                    heads_bytes_per_token_to_send,
                 )
             )
 
@@ -391,9 +384,9 @@ class MooncakeKVManager(BaseKVManager):
                 dst_ptr,
                 src_item_len,
                 dst_item_len,
-                src_offset,
-                dst_offset,
-                slice_lens_per_page,
+                src_head_slice_offset,
+                dst_head_slice_offset,
+                heads_bytes_per_token_to_send,
             ) = layer_params
             src_addr_list = []
             dst_addr_list = []
@@ -424,17 +417,12 @@ class MooncakeKVManager(BaseKVManager):
                 )
 
                 # Calculate final src and dst addresses by applying head-slice offsets
-                src_slice_addr = src_token_slot_start_addr + src_offset
-                dst_slice_addr = dst_token_slot_start_addr + dst_offset
+                src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
+                dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
 
                 src_addr_list.append(src_slice_addr)
                 dst_addr_list.append(dst_slice_addr)
-                length_list.append(slice_lens_per_page)
-
-                logger.debug(
-                    f"SYNC: sid={mooncake_session_id}, "
-                    f"src={src_slice_addr}, dst={dst_slice_addr}, len={slice_lens_per_page}"
-                )
+                length_list.append(heads_bytes_per_token_to_send)
 
             return self.engine.batch_transfer_sync(
                 mooncake_session_id, src_addr_list, dst_addr_list, length_list
@@ -445,7 +433,7 @@ class MooncakeKVManager(BaseKVManager):
                 process_layer_tp_aware,
                 layer_params,
             )
-            for layer_params in layer_transfer_params
+            for layer_params in layers_params
         ]
 
         for future in concurrent.futures.as_completed(futures):
@@ -533,12 +521,12 @@ class MooncakeKVManager(BaseKVManager):
                     if len(chunked_dst_kv_indice) < len(
                         kv_chunk.prefill_kv_indices
                     ):
-                        kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
-                            : len(chunked_dst_kv_indice)
-                        ]
                         logger.warning(
                             f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
                         )
+                        kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
+                            : len(chunked_dst_kv_indice)
+                        ]
 
                     target_rank_registration_info: KVArgsRegisterInfo = (
                         self.decode_kv_args_table[req.mooncake_session_id]
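The head-slicing logic in the first hunk above depends only on the two TP group sizes and each rank's position within its group. A standalone sketch reproducing that arithmetic outside the manager (the names mirror the diff; the item length and page size are made up):

def head_slice_params(engine_rank, dst_tp_rank, local_tp_size, dst_tp_size,
                      num_kv_heads, dst_kv_item_len, page_size):
    local_tp_rank_in_group = engine_rank % local_tp_size
    dst_tp_rank_in_group = dst_tp_rank % dst_tp_size
    dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size
    bytes_per_head = dst_kv_item_len // page_size // dst_heads_per_rank

    if local_tp_size > dst_tp_size:
        # Many prefill ranks feed one decode rank: each prefill rank sends
        # all of its heads into a rank-dependent slot of the decode page.
        src_off, n_send = 0, num_kv_heads
        dst_off = local_tp_rank_in_group * num_kv_heads
    else:
        # One prefill rank feeds many decode ranks: each decode rank takes
        # a rank-dependent slice of the prefill rank's heads.
        src_off = dst_tp_rank_in_group * dst_heads_per_rank
        n_send, dst_off = dst_heads_per_rank, 0

    return src_off * bytes_per_head, dst_off * bytes_per_head, n_send * bytes_per_head

# Prefill TP=4 feeding decode TP=2, 8 KV heads per prefill rank:
print(head_slice_params(engine_rank=1, dst_tp_rank=0, local_tp_size=4,
                        dst_tp_size=2, num_kv_heads=8,
                        dst_kv_item_len=4096, page_size=1))
# (0, 2048, 2048): rank 1 sends its 8 heads to a 2048-byte offset in the page.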
sglang/srt/distributed/parallel_state.py CHANGED
@@ -1065,8 +1065,23 @@ def init_model_parallel_group(
 
 _TP: Optional[GroupCoordinator] = None
 
+# duplicate GroupCoordinator for prefill in PD-Multiplexing
+_PDMUX_PREFILL_TP_GROUP: Optional[GroupCoordinator] = None
+
+_ENABLE_PDMUX_P_TP: bool = False
+
+
+def set_pdmux_status(enable_prefill_multiplexing: bool):
+    global _ENABLE_PDMUX_P_TP
+    _ENABLE_PDMUX_P_TP = enable_prefill_multiplexing
+
 
 def get_tp_group() -> GroupCoordinator:
+    if _ENABLE_PDMUX_P_TP:
+        assert (
+            _PDMUX_PREFILL_TP_GROUP is not None
+        ), "tensor model parallel group for PD-Multiplexing Prefill is not initialized"
+        return _PDMUX_PREFILL_TP_GROUP
     assert _TP is not None, "tensor model parallel group is not initialized"
     return _TP
 
@@ -1182,6 +1197,7 @@ def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
     backend: Optional[str] = None,
+    duplicate_tp_group: bool = False,
 ) -> None:
     """
     Initialize model parallel groups.
@@ -1239,6 +1255,23 @@ def initialize_model_parallel(
         group_name="tp",
     )
 
+    if duplicate_tp_group:
+        global _PDMUX_PREFILL_TP_GROUP
+        assert (
+            _PDMUX_PREFILL_TP_GROUP is None
+        ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized"
+        _PDMUX_PREFILL_TP_GROUP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            use_message_queue_broadcaster=get_bool_env_var(
+                "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
+            ),
+            group_name="pdmux_prefill_tp",
+        )
+        _TP.pynccl_comm.disabled = False
+        _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
+
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
     global _PP
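With `duplicate_tp_group=True`, two coordinators span the same ranks, and `get_tp_group()` selects between them via the module-level flag. A hedged sketch of the intended call pattern, assuming the distributed environment is already initialized:

from sglang.srt.distributed.parallel_state import (
    get_tp_group,
    initialize_model_parallel,
    set_pdmux_status,
)

# Create both the default "tp" group and the duplicated "pdmux_prefill_tp" group.
initialize_model_parallel(tensor_model_parallel_size=4, duplicate_tp_group=True)

set_pdmux_status(True)        # prefill phase: collectives use the prefill group
prefill_tp = get_tp_group()

set_pdmux_status(False)       # decode phase: back to the default group
decode_tp = get_tp_group()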
sglang/srt/entrypoints/engine.py CHANGED
@@ -46,9 +46,9 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
-    ImageDataItem,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
+    MultimodalDataInputFormat,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
     RpcReqInput,
@@ -148,13 +148,9 @@ class Engine(EngineBase):
         # - List of images (one per request in a batch)
         # - List of lists of images (multiple images per request)
         # See also python/sglang/srt/utils.py:load_image for more details.
-        image_data: Optional[
-            Union[
-                List[List[ImageDataItem]],
-                List[ImageDataItem],
-                ImageDataItem,
-            ]
-        ] = None,
+        image_data: Optional[MultimodalDataInputFormat] = None,
+        audio_data: Optional[MultimodalDataInputFormat] = None,
+        video_data: Optional[MultimodalDataInputFormat] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -187,6 +183,8 @@ class Engine(EngineBase):
             input_ids=input_ids,
             sampling_params=sampling_params,
             image_data=image_data,
+            audio_data=audio_data,
+            video_data=video_data,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
             top_logprobs_num=top_logprobs_num,
@@ -231,13 +229,9 @@ class Engine(EngineBase):
         # - List of images (one per request in a batch)
         # - List of lists of images (multiple images per request)
         # See also python/sglang/srt/utils.py:load_image for more details.
-        image_data: Optional[
-            Union[
-                List[List[ImageDataItem]],
-                List[ImageDataItem],
-                ImageDataItem,
-            ]
-        ] = None,
+        image_data: Optional[MultimodalDataInputFormat] = None,
+        audio_data: Optional[MultimodalDataInputFormat] = None,
+        video_data: Optional[MultimodalDataInputFormat] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -272,6 +266,8 @@ class Engine(EngineBase):
             input_ids=input_ids,
             sampling_params=sampling_params,
             image_data=image_data,
+            audio_data=audio_data,
+            video_data=video_data,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
             top_logprobs_num=top_logprobs_num,
@@ -295,19 +291,20 @@ class Engine(EngineBase):
     def encode(
         self,
         prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-        image_data: Optional[
-            Union[
-                List[List[Union[Image, str]]],
-                List[Union[Image, str]],
-                Union[Image, str],
-            ]
-        ] = None,
+        image_data: Optional[MultimodalDataInputFormat] = None,
+        audio_data: Optional[MultimodalDataInputFormat] = None,
+        video_data: Optional[MultimodalDataInputFormat] = None,
    ) -> Dict:
        """
        The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
        Please refer to `EmbeddingReqInput` for the documentation.
        """
-        obj = EmbeddingReqInput(text=prompt, image_data=image_data)
+        obj = EmbeddingReqInput(
+            text=prompt,
+            image_data=image_data,
+            audio_data=audio_data,
+            video_data=video_data,
+        )
         loop = asyncio.get_event_loop()
         generator = self.tokenizer_manager.generate_request(obj, None)
         ret = loop.run_until_complete(generator.__anext__())
@@ -316,7 +313,9 @@ class Engine(EngineBase):
     async def async_encode(
         self,
         prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-        image_data: Optional[Union[List[str], str]] = None,
+        image_data: Optional[MultimodalDataInputFormat] = None,
+        audio_data: Optional[MultimodalDataInputFormat] = None,
+        video_data: Optional[MultimodalDataInputFormat] = None,
     ) -> Dict:
         """
         Asynchronous version of encode method.
@@ -324,7 +323,12 @@ class Engine(EngineBase):
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
         Please refer to `EmbeddingReqInput` for the documentation.
         """
-        obj = EmbeddingReqInput(text=prompt, image_data=image_data)
+        obj = EmbeddingReqInput(
+            text=prompt,
+            image_data=image_data,
+            audio_data=audio_data,
+            video_data=video_data,
+        )
         generator = self.tokenizer_manager.generate_request(obj, None)
         return await generator.__anext__()
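With the widened signatures, audio and video inputs flow through the same path as images across `generate`, `async_generate`, `encode`, and `async_encode`. A hypothetical call (the model path and media files are placeholders, not from the diff):

import sglang as sgl

engine = sgl.Engine(model_path="org/some-multimodal-model")  # placeholder checkpoint

out = engine.generate(
    prompt="Describe what you see and hear.",
    image_data="scene.png",  # same accepted formats as before
    audio_data="clip.wav",   # newly threaded through the request structs
    video_data="clip.mp4",
)
print(out)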
 
@@ -650,7 +654,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.2.5",
+            "0.2.6.post1",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
sglang/srt/entrypoints/openai/serving_chat.py CHANGED
@@ -113,12 +113,12 @@ class OpenAIServingChat(OpenAIServingBase):
             request.skip_special_tokens = False
             if not isinstance(request.tool_choice, str):
                 tools = [
-                    item.function.model_dump()
+                    item.model_dump()
                     for item in request.tools
                     if item.function.name == request.tool_choice.function.name
                 ]
             else:
-                tools = [item.function.model_dump() for item in request.tools]
+                tools = [item.model_dump() for item in request.tools]
 
             tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
             parser = FunctionCallParser(request.tools, tool_call_parser)
@@ -164,6 +164,25 @@ class OpenAIServingChat(OpenAIServingBase):
                     audio_data,
                     modalities,
                 )
+
+                if "tool_calls" in processed_msg and isinstance(
+                    processed_msg.get("tool_calls"), list
+                ):
+                    for call in processed_msg["tool_calls"]:
+                        try:
+                            if "arguments" in call["function"] and isinstance(
+                                call["function"]["arguments"], str
+                            ):
+                                call["function"]["arguments"] = json.loads(
+                                    call["function"]["arguments"]
+                                )
+                        except json.JSONDecodeError as e:
+                            # Log a warning or error if JSON parsing fails for arguments
+                            logger.warning(
+                                f"Failed to parse tool call arguments as JSON: {e}"
+                            )
+                            # Decide whether to continue or raise the exception based on desired behavior
+                            continue  # Or raise e if strict parsing is required
                 openai_compatible_messages.append(processed_msg)
 
         # Handle assistant prefix for continue_final_message
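The new block above normalizes assistant-history tool calls: OpenAI-style clients send `function.arguments` as a JSON string, while chat templates generally expect a dict. A standalone illustration of the same transformation:

import json

processed_msg = {
    "role": "assistant",
    "tool_calls": [
        {"function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
    ],
}

for call in processed_msg["tool_calls"]:
    args = call["function"].get("arguments")
    if isinstance(args, str):
        try:
            call["function"]["arguments"] = json.loads(args)
        except json.JSONDecodeError:
            pass  # the server only warns and keeps going on malformed arguments

print(processed_msg["tool_calls"][0]["function"]["arguments"])  # {'city': 'Paris'}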
sglang/srt/eplb/expert_location_dispatch.py CHANGED
@@ -66,7 +66,7 @@ def transform_select_experts_inputs(
     info: Optional[ExpertLocationDispatchInfo],
 ):
     if (info is not None) and (info.ep_dispatch_algorithm == "fake"):
-        router_logits = torch.randn_like(router_logits)
+        router_logits.uniform_(5, 10)
     if correction_bias is not None:
         correction_bias = torch.zeros_like(correction_bias)
     return router_logits, correction_bias
sglang/srt/function_call/function_call_parser.py CHANGED
@@ -14,6 +14,7 @@ from sglang.srt.function_call.kimik2_detector import KimiK2Detector
 from sglang.srt.function_call.llama32_detector import Llama32Detector
 from sglang.srt.function_call.mistral_detector import MistralDetector
 from sglang.srt.function_call.pythonic_detector import PythonicDetector
+from sglang.srt.function_call.qwen3_detector import Qwen3XMLDetector
 from sglang.srt.function_call.qwen25_detector import Qwen25Detector
 
 logger = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ class FunctionCallParser:
         "deepseekv3": DeepSeekV3Detector,
         "pythonic": PythonicDetector,
         "kimi_k2": KimiK2Detector,
+        "qwen3": Qwen3XMLDetector,
     }
 
     def __init__(self, tools: List[Tool], tool_call_parser: str):
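With the registration above, the new detector is reachable by name like the existing ones. A hedged usage sketch; the `Tool`/`Function` models are assumed to come from sglang's OpenAI protocol module, and the tool definition is invented for illustration:

from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.function_call_parser import FunctionCallParser

tools = [
    Tool(function=Function(
        name="get_weather",
        description="Look up current weather",
        parameters={"type": "object",
                    "properties": {"city": {"type": "string"}}},
    ))
]

# "qwen3" now resolves to Qwen3XMLDetector through the registry shown above.
parser = FunctionCallParser(tools, "qwen3")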