sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/models/llama_eagle3.py
+++ b/sglang/srt/models/llama_eagle3.py
@@ -19,6 +19,7 @@ from sglang.srt.utils import add_prefix
 # https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
 """Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
 
+import copy
 from typing import Iterable, Optional, Tuple
 
 import torch
@@ -161,6 +162,10 @@ class LlamaModel(nn.Module):
         if hidden_states.shape[-1] != embeds.shape[-1]:
             hidden_states = self.fc(hidden_states)
 
+        # idle batch
+        if hidden_states.shape[0] == 0:
+            return hidden_states, [hidden_states]
+
         residual = None
         hidden_states, residual = self.midlayer(
             positions,
@@ -212,7 +217,12 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
                 prefix=add_prefix("lm_head", prefix),
             )
 
-        self.logits_processor = LogitsProcessor(config)
+        config_ = copy.deepcopy(config)
+        config_.vocab_size = (
+            config_.draft_vocab_size
+        )  # draft logits processor has it's own vocab size
+        self.logits_processor = LogitsProcessor(config_)
+
         self.capture_aux_hidden_states = True
         self.hot_token_id = None
 
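The last hunk gives the EAGLE3 draft head a logits processor sized by draft_vocab_size instead of the target model's vocab_size. A minimal sketch of that pattern (illustrative numbers, with SimpleNamespace standing in for the HuggingFace config; not sglang code):

```python
# Illustrative sketch of the config override above; not sglang code.
import copy
from types import SimpleNamespace

target_config = SimpleNamespace(vocab_size=152064, draft_vocab_size=32000)

draft_config = copy.deepcopy(target_config)
draft_config.vocab_size = draft_config.draft_vocab_size

# A logits processor built from draft_config now matches a draft lm_head of
# shape [draft_vocab_size, hidden_size]; sizing it from the shared target
# config would not.
assert draft_config.vocab_size == 32000
assert target_config.vocab_size == 152064  # the target model's config is untouched
```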
--- a/sglang/srt/models/longcat_flash.py
+++ b/sglang/srt/models/longcat_flash.py
@@ -821,8 +821,8 @@ class LongcatFlashForCausalLM(nn.Module):
             experts = layer.mlp.experts
             if isinstance(experts, DeepEPMoE):
                 for w in [
-                    experts.w13_weight_fp8,
-                    experts.w2_weight_fp8,
+                    (experts.w13_weight, experts.w13_weight_scale_inv),
+                    (experts.w2_weight, experts.w2_weight_scale_inv),
                 ]:
                     requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
 
--- a/sglang/srt/models/minimax_m2.py
+++ b/sglang/srt/models/minimax_m2.py
@@ -122,7 +122,7 @@ class MiniMaxM2RMSNormTP(nn.Module):
 
         # Normalize and apply local weight shard
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = x.to(orig_dtype) * self.weight
+        x = (x * self.weight).to(orig_dtype)
 
         return x
 
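The minimax_m2.py change above applies the RMSNorm weight before downcasting to the original dtype. A rough numeric illustration of why the ordering matters (made-up tensors, not the MiniMaxM2RMSNormTP module):

```python
# Made-up tensors; shows only that the two orderings round differently.
import torch

x = torch.randn(4, 8, dtype=torch.float32)   # normalized activations in fp32
w = torch.randn(8, dtype=torch.bfloat16)     # learned RMSNorm weight

old = x.to(torch.bfloat16) * w      # pre-fix: cast first, scale in bf16
new = (x * w).to(torch.bfloat16)    # post-fix: scale in fp32, cast once at the end

# Typically nonzero: the early cast throws away precision before the scale.
print((old.float() - new.float()).abs().max())
```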
--- a/sglang/srt/models/qwen2.py
+++ b/sglang/srt/models/qwen2.py
@@ -462,7 +462,7 @@ class Qwen2ForCausalLM(nn.Module):
                 self.pp_group.send(
                     self.model.embed_tokens.weight, dst=self.pp_group.last_rank
                 )
-            else:
+            elif self.pp_group.is_last_rank:
                 emb_token_weight = self.pp_group.recv(
                     size=(config.vocab_size, config.hidden_size),
                     dtype=next(self.model.parameters()).dtype,
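The qwen2.py hunk (and the matching qwen3.py hunk further down) narrows the tied-embedding exchange so that only the last pipeline rank receives what the first rank sends, while intermediate ranks do neither. A toy sketch of that intent, assuming the enclosing branch sends from the first PP rank, as the dst=self.pp_group.last_rank send above suggests:

```python
# Toy illustration only; real ranks come from the PP process group.
num_pp_ranks = 4
for rank in range(num_pp_ranks):
    is_first = rank == 0
    is_last = rank == num_pp_ranks - 1
    # Before the fix, every non-first rank fell into the recv branch;
    # after it, the middle ranks (1 and 2 here) simply skip the exchange.
    action = "send" if is_first else ("recv" if is_last else "skip")
    print(rank, action)
```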
--- a/sglang/srt/models/qwen2_moe.py
+++ b/sglang/srt/models/qwen2_moe.py
@@ -473,10 +473,16 @@ class Qwen2MoeDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
-        hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+        hidden_states, residual = (
+            self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+                hidden_states,
+                residual,
+                forward_batch,
+                captured_last_layer_outputs=captured_last_layer_outputs,
+            )
         )
 
         if hidden_states.shape[0] != 0:
@@ -553,6 +559,11 @@ class Qwen2MoeModel(nn.Module):
         # For EAGLE3 support
         self.layers_to_capture = []
 
+    def set_eagle3_layers_to_capture(self, layers_to_capture: List[int]):
+        self.layers_to_capture = layers_to_capture
+        for layer_id in self.layers_to_capture:
+            setattr(self.layers[layer_id], "_is_layer_to_capture", True)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -585,12 +596,6 @@
             )
         else:
             for i in range(self.start_layer, self.end_layer):
-                if i in self.layers_to_capture:
-                    aux_hidden_states.append(
-                        hidden_states + residual
-                        if residual is not None
-                        else hidden_states
-                    )
                 ctx = (
                     nullcontext()
                     if get_global_server_args().enable_piecewise_cuda_graph
@@ -599,7 +604,15 @@
                 with ctx:
                     layer = self.layers[i]
                     hidden_states, residual = layer(
-                        positions, hidden_states, forward_batch, residual
+                        positions,
+                        hidden_states,
+                        forward_batch,
+                        residual,
+                        captured_last_layer_outputs=(
+                            aux_hidden_states
+                            if getattr(layer, "_is_layer_to_capture", False)
+                            else None
+                        ),
                     )
         if not self.pp_group.is_last_rank:
             return PPProxyTensors(
@@ -830,13 +843,15 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.capture_aux_hidden_states = True
         if layer_ids is None:
             num_layers = self.config.num_hidden_layers
-            self.model.layers_to_capture = [
-                2,
-                num_layers // 2,
-                num_layers - 3,
-            ]  # Specific layers for EAGLE3 support
+            self.model.set_eagle3_layers_to_capture(
+                [
+                    2,
+                    num_layers // 2,
+                    num_layers - 3,
+                ]
+            )  # Specific layers for EAGLE3 support
         else:
-            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+            self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])
 
 
 EntryClass = Qwen2MoeForCausalLM
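For orientation, the default capture layers chosen above, worked out for a hypothetical 48-layer target model; set_eagle3_layers_to_capture then tags those decoder layers so their outputs are appended to aux_hidden_states during the forward pass:

```python
# Hypothetical layer count; the real value comes from config.num_hidden_layers.
num_layers = 48
default_capture = [2, num_layers // 2, num_layers - 3]
print(default_capture)  # [2, 24, 45]

# When layer_ids are supplied explicitly, they are shifted by one before tagging:
layer_ids = [1, 23, 44]
print([val + 1 for val in layer_ids])  # [2, 24, 45]
```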
--- a/sglang/srt/models/qwen3.py
+++ b/sglang/srt/models/qwen3.py
@@ -361,7 +361,7 @@ class Qwen3ForCausalLM(nn.Module):
                 self.pp_group.send(
                     self.model.embed_tokens.weight, dst=self.pp_group.last_rank
                 )
-            else:
+            elif self.pp_group.is_last_rank:
                 emb_token_weight = self.pp_group.recv(
                     size=(config.vocab_size, config.hidden_size),
                     dtype=next(self.model.parameters()).dtype,
--- a/sglang/srt/models/qwen3_moe.py
+++ b/sglang/srt/models/qwen3_moe.py
@@ -537,10 +537,16 @@ class Qwen3MoeDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
-        hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+        hidden_states, residual = (
+            self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+                hidden_states,
+                residual,
+                forward_batch,
+                captured_last_layer_outputs=captured_last_layer_outputs,
+            )
         )
 
         if hidden_states.shape[0] != 0:
@@ -772,13 +778,15 @@ class Qwen3MoeForCausalLM(nn.Module):
         self.capture_aux_hidden_states = True
         if layer_ids is None:
             num_layers = self.config.num_hidden_layers
-            self.model.layers_to_capture = [
-                2,
-                num_layers // 2,
-                num_layers - 3,
-            ]  # Specific layers for EAGLE3 support
+            self.model.set_eagle3_layers_to_capture(
+                [
+                    2,
+                    num_layers // 2,
+                    num_layers - 3,
+                ]
+            )  # Specific layers for EAGLE3 support
         else:
-            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+            self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
--- a/sglang/srt/models/qwen3_next.py
+++ b/sglang/srt/models/qwen3_next.py
@@ -478,6 +478,13 @@ class Qwen3GatedDeltaNet(nn.Module):
         # reshape input data into 2D tensor
         core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
         z = z.reshape(-1, z.shape[-1])
+
+        # Add padding for DP-Attn
+        if is_dp_attention_enabled():
+            core_attn_out_pad = torch.zeros_like(z)
+            core_attn_out_pad[: core_attn_out.shape[0], :] = core_attn_out
+            core_attn_out = core_attn_out_pad
+
         core_attn_out = self.norm(core_attn_out, z)
         core_attn_out = core_attn_out.reshape(z_shape_og)
         core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1)
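A shape-only sketch of the qwen3_next.py padding above (made-up sizes): z already carries the padded per-rank row count under DP attention, so core_attn_out is zero-filled up to that length before the gated norm sees both tensors.

```python
# Made-up sizes, illustrating the zero-padding only.
import torch

padded_len, dim = 8, 16   # rows of z after DP-attention padding
real_len = 5              # rows actually produced by the linear-attention kernel

z = torch.randn(padded_len, dim)
core_attn_out = torch.randn(real_len, dim)

core_attn_out_pad = torch.zeros_like(z)
core_attn_out_pad[: core_attn_out.shape[0], :] = core_attn_out
core_attn_out = core_attn_out_pad

assert core_attn_out.shape == z.shape  # shapes now agree for self.norm(core_attn_out, z)
```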
--- /dev/null
+++ b/sglang/srt/multimodal/customized_mm_processor_utils.py
@@ -0,0 +1,35 @@
+from typing import Dict, Type
+
+from transformers import PretrainedConfig, ProcessorMixin
+
+# Useful for registering a custom processor different from Hugging Face's default.
+_CUSTOMIZED_MM_PROCESSOR: Dict[str, Type[ProcessorMixin]] = dict()
+
+
+def register_customized_processor(
+    processor_class: Type[ProcessorMixin],
+):
+    """Class decorator that maps a config class's model_type field to a customized processor class.
+
+    Args:
+        processor_class: A processor class that inherits from ProcessorMixin
+
+    Example:
+        ```python
+        @register_customized_processor(MyCustomProcessor)
+        class MyModelConfig(PretrainedConfig):
+            model_type = "my_model"
+
+        ```
+    """
+
+    def decorator(config_class: PretrainedConfig):
+        if not hasattr(config_class, "model_type"):
+            raise ValueError(
+                f"Class {config_class.__name__} with register_customized_processor should "
+                f"have a 'model_type' class attribute."
+            )
+        _CUSTOMIZED_MM_PROCESSOR[config_class.model_type] = processor_class
+        return config_class
+
+    return decorator
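A small usage sketch for the new registry; MyProcessor and MyConfig are placeholders rather than real classes, and the import path assumes the installed 0.5.4.post2 layout:

```python
from transformers import PretrainedConfig, ProcessorMixin

from sglang.srt.multimodal.customized_mm_processor_utils import (
    _CUSTOMIZED_MM_PROCESSOR,
    register_customized_processor,
)


class MyProcessor(ProcessorMixin):
    # Placeholder processor; a real one would declare tokenizer/image_processor attributes.
    attributes = []


@register_customized_processor(MyProcessor)
class MyConfig(PretrainedConfig):
    model_type = "my_model"


# The decorator keyed the registry by the config's model_type.
assert _CUSTOMIZED_MM_PROCESSOR["my_model"] is MyProcessor
```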
--- /dev/null
+++ b/sglang/srt/multiplex/multiplexing_mixin.py
@@ -0,0 +1,209 @@
+"""
+Mixin class providing multiplexing scheduling logic
+"""
+
+import logging
+
+import torch
+import torch.distributed as dist
+from torch.cuda.streams import ExternalStream
+
+from sglang.srt.distributed.parallel_state import set_pdmux_status
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.multiplex.pdmux_context import (
+    get_current_stream_idx,
+    get_sm_counts,
+    get_stream_groups,
+    initialize_stream_groups,
+    load_pdmux_config,
+    set_current_stream_idx,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class SchedulerMultiplexMixin:
+
+    def init_pdmux(self):
+        # for pd_multiplexing, Init stream_groups, exclude normal stream for prefill only and decode only
+        self.pdmux_config = load_pdmux_config(self.server_args.pdmux_config_path)
+        initialize_stream_groups(self.gpu_id, self.pdmux_config)
+        self.stream_groups = get_stream_groups()
+        self.sm_counts = get_sm_counts()
+        self.real_sm_group_num = len(self.stream_groups)
+        logger.info(
+            f"PD-Multiplexing enabled with {self.real_sm_group_num} stream groups, sm_counts (prefill_sm, decode_sm): {self.sm_counts}"
+        )
+
+    # TODO(jason-fxz): This is a temporary demo
+    def adjust_stream_groups(self) -> tuple[int, tuple[ExternalStream, ExternalStream]]:
+        if not self.running_batch.is_empty() and self.split_prefill_batch:
+            decode_bs = self.running_batch.batch_size()
+            manual_divisions = self.pdmux_config.manual_divisions
+            if manual_divisions:
+                for i in range(len(manual_divisions)):
+                    _, _, threshold = manual_divisions[i]
+                    if decode_bs >= threshold:
+                        stream_idx = i + 1
+            else:
+                stream_idx = max(
+                    1,
+                    min(
+                        self.real_sm_group_num - 2,
+                        decode_bs
+                        * (self.real_sm_group_num - 2)
+                        // self.pdmux_config.decode_bs_divisor,
+                    ),
+                )
+            set_current_stream_idx(stream_idx)
+        elif not self.running_batch.is_empty():
+            set_current_stream_idx(self.real_sm_group_num - 1)
+        else:
+            set_current_stream_idx(0)
+
+        stream_idx = get_current_stream_idx()
+
+        self.tp_worker.model_runner.update_decode_attn_backend(stream_idx)
+        return stream_idx, self.stream_groups[stream_idx]
+
+    def update_split_prefill_batch(self, sm_count: int) -> bool:
+        if self.split_prefill_batch:
+            return False
+
+        # add new request
+        batch = self.get_new_batch_prefill()
+        if batch and not batch.is_empty():
+            batch.forward_mode = (
+                ForwardMode.SPLIT_PREFILL
+            )  # Set forward mode for split prefill
+            self.split_prefill_batch = batch
+            return True
+        return False
+
+    @torch.inference_mode()
+    def event_loop_pdmux(self):
+        """A scheduler loop for pd multiplexing."""
+        decode_done = False
+        prefill_done = False
+        wait_prefill_kernel_done = False
+        adjust_stream_group = False
+        stream_idx = get_current_stream_idx()
+        stream_group = self.stream_groups[stream_idx]
+        prefill_stream = stream_group[0]
+        decode_stream = stream_group[1]
+        torch.cuda.empty_cache()
+
+        logger.debug("Starting event loop for pd multiplexing...")
+
+        while True:
+            with torch.cuda.stream(decode_stream):
+                set_pdmux_status(False)
+                recv_reqs = self.recv_requests()
+                self.process_input_requests(recv_reqs)
+
+            with torch.cuda.stream(prefill_stream):
+                set_pdmux_status(True)
+                sm_count = self.sm_counts[stream_idx][0]
+                if not wait_prefill_kernel_done:
+                    adjust_stream_group = (
+                        self.update_split_prefill_batch(sm_count) or adjust_stream_group
+                    )
+
+            with torch.cuda.stream(decode_stream):
+                set_pdmux_status(False)
+                self.running_batch = self.update_running_batch(self.running_batch)
+                adjust_stream_group = adjust_stream_group or (
+                    stream_idx > 0 and self.running_batch.is_empty()
+                )
+                if self.running_batch.is_empty() and self.split_prefill_batch is None:
+                    self.check_memory()
+                    self.check_tree_cache()
+                    self.new_token_ratio = self.init_new_token_ratio
+                    self.maybe_sleep_on_idle()
+
+            if adjust_stream_group:
+                prefill_stream.synchronize()
+                decode_stream.synchronize()
+                stream_idx, stream_group = self.adjust_stream_groups()
+                prefill_stream = stream_group[0]
+                decode_stream = stream_group[1]
+                adjust_stream_group = False
+                logger.debug(
+                    f"Adjusting stream groups: {stream_idx}, prefill sm: {self.sm_counts[stream_idx][0]}, decode sm: {self.sm_counts[stream_idx][1]}"
+                )
+
+            with torch.cuda.stream(decode_stream):
+                set_pdmux_status(False)
+                # process decode batch
+                if self.running_batch and not self.running_batch.is_empty():
+                    decode_result = self.run_batch(self.running_batch)
+                    decode_done = True
+                else:
+                    decode_done = False
+            with torch.cuda.stream(prefill_stream):
+                set_pdmux_status(True)
+                if (
+                    self.split_prefill_batch
+                    and not self.split_prefill_batch.is_empty()
+                    and not wait_prefill_kernel_done
+                ):
+                    prefill_done = True
+                    forward_count = (
+                        max(
+                            1,
+                            self.pdmux_config.split_forward_token_budget
+                            // self.split_prefill_batch.extend_num_tokens,
+                        )
+                        if self.split_prefill_batch.extend_num_tokens > 0
+                        else self.model_config.num_hidden_layers
+                    )
+                    next_split_index = min(
+                        self.split_prefill_batch.split_index + forward_count,
+                        self.model_config.num_hidden_layers,
+                    )
+                    forward_count = (
+                        next_split_index - self.split_prefill_batch.split_index
+                    )
+
+                    self.split_prefill_batch.split_forward_count = forward_count
+                    prefill_result = self.run_batch(self.split_prefill_batch)
+                    if next_split_index == self.model_config.num_hidden_layers:
+                        self.split_prefill_batch.split_prefill_finished = True
+                        prefill_exe_done = prefill_stream.record_event()
+                    self.split_prefill_batch.split_index = next_split_index
+
+                elif wait_prefill_kernel_done:
+                    prefill_done = True
+                else:
+                    prefill_done = False
+
+            with torch.cuda.stream(decode_stream):
+                set_pdmux_status(False)
+                decode_stream.synchronize()
+                if decode_done:
+                    self.process_batch_result(self.running_batch, decode_result)
+
+            with torch.cuda.stream(prefill_stream):
+                set_pdmux_status(True)
+                if prefill_done and self.split_prefill_batch.split_prefill_finished:
+                    wait_prefill_kernel_done = True
+                    prefill_exe_done_flag = prefill_exe_done.query()
+                    flags = (
+                        torch.ones(1, device="cpu", dtype=torch.int32)
+                        if prefill_exe_done_flag
+                        else torch.zeros(1, device="cpu", dtype=torch.int32)
+                    )
+
+                    self.tp_cpu_group.allreduce(flags, dist.ReduceOp.SUM).wait()
+                    if flags.item() == self.tp_size:
+                        self.process_batch_result(
+                            self.split_prefill_batch, prefill_result
+                        )
+                        if self.running_batch and not self.running_batch.is_empty():
+                            self.running_batch.merge_batch(self.split_prefill_batch)
+                        else:
+                            self.running_batch = self.split_prefill_batch
+
+                        self.split_prefill_batch = None
+                        wait_prefill_kernel_done = False
+                        adjust_stream_group = True
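A worked example (illustrative numbers) of the automatic partition choice in adjust_stream_groups above when no manual_divisions are configured: stream group 0 gives all SMs to prefill, the last group gives them all to decode, and a larger decode batch pushes the index toward the decode-heavy groups.

```python
# Illustrative values; real ones come from PDMuxConfig and the built stream groups.
real_sm_group_num = 8     # as built by initialize_stream_groups with sm_group_num = 8
decode_bs_divisor = 36    # PDMuxConfig default

for decode_bs in (1, 12, 24, 48, 96):
    stream_idx = max(
        1,
        min(
            real_sm_group_num - 2,
            decode_bs * (real_sm_group_num - 2) // decode_bs_divisor,
        ),
    )
    print(decode_bs, stream_idx)  # 1 -> 1, 12 -> 2, 24 -> 4, 48 -> 6, 96 -> 6
```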
--- /dev/null
+++ b/sglang/srt/multiplex/pdmux_context.py
@@ -0,0 +1,164 @@
+from dataclasses import dataclass, field
+from typing import List
+
+import torch
+import yaml
+
+STREAM_GROUPS = []
+SM_COUNTS = []
+SM_GROUP_NUM = 8  # Default number of SM groups
+CURRENT_STREAM_IDX = 0
+CURRENT_STREAM_GROUP = None
+
+
+@dataclass
+class PDMuxConfig:
+    sm_group_num: int = 8
+    manual_divisions: List[List[int]] = field(
+        default_factory=list
+    )  # [prefill_sm, decode_sm, decode_bs_threshold]
+    split_forward_token_budget: int = 65536
+    decode_bs_divisor: int = 36
+
+
+def load_pdmux_config(config_path: str) -> PDMuxConfig:
+    """Load pdmux configuration from YAML file into a dataclass."""
+    if not config_path:
+        return PDMuxConfig()
+
+    with open(config_path, "r") as f:
+        raw = yaml.safe_load(f)
+
+    if "sm_group_num" not in raw:
+        raise ValueError("Missing required field: sm_group_num")
+
+    if raw["sm_group_num"] < 3:
+        raise ValueError("sm_group_num must greater than 3")
+
+    manual_divisions = raw.get("manual_divisions", [])
+
+    expected = raw["sm_group_num"] - 2
+    if manual_divisions and len(manual_divisions) != expected:
+        raise ValueError(
+            f"manual_divisions must have {expected} entries, "
+            f"but got {len(manual_divisions)}"
+        )
+
+    return PDMuxConfig(
+        sm_group_num=raw["sm_group_num"],
+        manual_divisions=manual_divisions,
+        split_forward_token_budget=raw.get("split_forward_token_budget", 65536),
+        decode_bs_divisor=raw.get("decode_bs_divisor", 36),
+    )
+
+
+def get_arch_constraints(compute_capability):
+    major, minor = compute_capability
+    # green context constraints for different architectures
+    if major == 6:
+        return 1, 1  # min_per_part, multiple
+    elif major == 7:
+        return 2, 2
+    elif major == 8:
+        return 4, 2
+    elif major == 9 and minor >= 0:
+        return 8, 8
+    else:
+        raise ValueError(f"Unsupported compute capability: {major}.{minor}")
+
+
+def divide_sm(total_sms, compute_capability, groups):
+    """
+    :param total_sms: total sm count on a single GPU
+    :param compute_capability: (major, minor)
+    :return: SM partition group(prefill sm, decode sm)
+    """
+    min_per_part, multiple = get_arch_constraints(compute_capability)
+    possible_values = [
+        x
+        for x in range(min_per_part, total_sms - min_per_part + 1, multiple)
+        if x >= total_sms - x and total_sms - x >= 16
+    ]
+    if not possible_values:
+        raise ValueError(
+            f"No valid partitions found for total SMs {total_sms} "
+            f"with constraints (min per part: {min_per_part}, multiple: {multiple})"
+        )
+
+    if len(possible_values) >= groups:
+        step = max(1, len(possible_values) // groups)
+        selected_values = possible_values[::step][:groups]
+    else:
+        selected_values = possible_values
+
+    divisions = []
+    for part1 in selected_values:
+        part2 = total_sms - part1
+        divisions.append((part1, part2))
+
+    divisions.reverse()  # Reverse to have larger prefill SM first
+
+    return divisions
+
+
+def initialize_stream_groups(gpu_id: int, config: PDMuxConfig):
+    from sgl_kernel import spatial
+
+    global STREAM_GROUPS, SM_COUNTS, SM_GROUP_NUM, CURRENT_STREAM_IDX, CURRENT_STREAM_GROUP
+    # for pd_multiplexing, Init stream_groups
+    device = torch.cuda.current_device()
+    total_sm_count = spatial.get_sm_available(gpu_id)
+    # (prefill_sm_count, decode_sm_count)
+    if config.manual_divisions:
+        divisions = [
+            (prefill_sm, decode_sm)
+            for prefill_sm, decode_sm, _ in config.manual_divisions
+        ]
+    else:
+        divisions = divide_sm(
+            total_sm_count,
+            torch.cuda.get_device_capability(device),
+            config.sm_group_num - 2,
+        )
+
+    SM_COUNTS = []
+    SM_COUNTS.append((total_sm_count, 0))  # Normal stream for prefill
+    SM_COUNTS.extend(divisions)  # Add the divided SM counts
+    SM_COUNTS.append((0, total_sm_count))  # Normal stream for decode
+    STREAM_GROUPS = []
+    STREAM_GROUPS.append(
+        (torch.cuda.Stream(gpu_id), torch.cuda.Stream(gpu_id))
+    )  # Normal stream for prefill
+    for prefill_sm, decode_sm in divisions:
+        STREAM_GROUPS.append(
+            (spatial.create_greenctx_stream_by_value(prefill_sm, decode_sm, gpu_id))
+        )
+    STREAM_GROUPS.append(
+        (torch.cuda.Stream(gpu_id), torch.cuda.Stream(gpu_id))
+    )  # Normal stream for decode
+
+    CURRENT_STREAM_IDX = 0
+    CURRENT_STREAM_GROUP = STREAM_GROUPS[CURRENT_STREAM_IDX]
+
+
+def set_current_stream_idx(idx: int):
+    global CURRENT_STREAM_IDX, CURRENT_STREAM_GROUP
+    if idx < 0 or idx >= len(STREAM_GROUPS):
+        raise ValueError(f"Invalid stream index: {idx}")
+    CURRENT_STREAM_IDX = idx
+    CURRENT_STREAM_GROUP = STREAM_GROUPS[CURRENT_STREAM_IDX]
+
+
+def get_stream_groups() -> list[tuple[torch.cuda.Stream, torch.cuda.Stream]]:
+    """Get the stream groups."""
+    return STREAM_GROUPS
+
+
+def get_sm_counts() -> list[tuple[int, int]]:
+    """Get the SM counts."""
+    return SM_COUNTS
+
+
+def get_current_stream_idx() -> int:
+    """Get the current stream index."""
+    return CURRENT_STREAM_IDX
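For reference, an illustrative config file for load_pdmux_config above. The field names come from the loader; the numbers are made up, and manual_divisions, if present, must contain sm_group_num - 2 rows of [prefill_sm, decode_sm, decode_bs_threshold]. The scheduler reads the path from server_args.pdmux_config_path.

```python
# Illustrative only; values are not tuned recommendations.
import textwrap

example_pdmux_yaml = textwrap.dedent(
    """\
    sm_group_num: 5
    manual_divisions:          # optional; omit to let divide_sm() pick the splits
      - [100, 32, 4]           # [prefill_sm, decode_sm, decode_bs_threshold]
      - [84, 48, 16]
      - [68, 64, 32]
    split_forward_token_budget: 65536
    decode_bs_divisor: 36
    """
)

with open("/tmp/pdmux_config.yaml", "w") as f:
    f.write(example_pdmux_yaml)
# load_pdmux_config("/tmp/pdmux_config.yaml") returns a PDMuxConfig with these fields.
```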
--- a/sglang/srt/parser/conversation.py
+++ b/sglang/srt/parser/conversation.py
@@ -101,6 +101,7 @@ class Conversation:
     stop_token_ids: Optional[int] = None
 
     audio_data: Optional[List[str]] = None
+    image_token_at_prefix: bool = False
 
     def get_prompt(self) -> str:
         """Get the prompt for generation."""
@@ -445,6 +446,7 @@ class Conversation:
             image_token=self.image_token,
             video_token=self.video_token,
             audio_token=self.audio_token,
+            image_token_at_prefix=self.image_token_at_prefix,
         )
 
     def dict(self):
@@ -512,6 +514,7 @@ def generate_embedding_convs(
            image_token=conv_template.image_token,
            video_token=conv_template.video_token,
            audio_token=conv_template.audio_token,
+           image_token_at_prefix=conv_template.image_token_at_prefix,
        )
        real_content = ""
 
@@ -578,6 +581,7 @@ def generate_chat_conv(
         image_token=conv.image_token,
         audio_token=conv.audio_token,
         video_token=conv.video_token,
+        image_token_at_prefix=conv.image_token_at_prefix,
     )
 
     if isinstance(request.messages, str):
@@ -627,7 +631,7 @@
                     real_content += content.text
                 elif content.type == "image_url":
                     # NOTE: works for llava and intervl2_5
-                    if conv.name in ["internvl-2-5"]:
+                    if conv.image_token_at_prefix:
                         real_content = image_token + real_content
                     else:
                         real_content += image_token
@@ -820,6 +824,7 @@ register_conv_template(
         sep="<|im_end|>\n",
         stop_str=["<|im_end|>", "<|action_end|>"],
         image_token="<IMG_CONTEXT>",
+        image_token_at_prefix=True,
     )
 )
 
@@ -848,6 +853,7 @@ register_conv_template(
         sep_style=SeparatorStyle.NO_COLON_SINGLE,
         stop_str=["<|end▁of▁sentence|>"],
         image_token="<image>",
+        image_token_at_prefix=True,
     )
 )
 
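Finally, a toy sketch of what the new image_token_at_prefix flag controls (not the Conversation class itself): it replaces the hard-coded name check against "internvl-2-5" with a per-template switch that prepends the image token instead of appending it.

```python
# Toy illustration of the placement switch only.
def place_image_token(text: str, image_token: str, image_token_at_prefix: bool) -> str:
    if image_token_at_prefix:
        return image_token + text
    return text + image_token

print(place_image_token("Describe the photo.", "<IMG_CONTEXT>", True))   # prefix style
print(place_image_token("Describe the photo.", "<image>", False))        # default append
```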