sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py

@@ -39,6 +39,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import (
     DpPaddingMode,
     get_attention_dp_rank,
@@ -89,12 +90,9 @@ class ForwardMode(IntEnum):
             self == ForwardMode.EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.DRAFT_EXTEND
-            or (
-                self == ForwardMode.DRAFT_EXTEND_V2
-                if include_draft_extend_v2
-                else False
-            )
+            or (include_draft_extend_v2 and self == ForwardMode.DRAFT_EXTEND_V2)
             or self == ForwardMode.TARGET_VERIFY
+            or self == ForwardMode.SPLIT_PREFILL
         )
 
     def is_decode(self):
@@ -113,22 +111,21 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.TARGET_VERIFY
 
     def is_draft_extend(self, include_v2: bool = False):
-        if include_v2:
-            return (
-                self == ForwardMode.DRAFT_EXTEND_V2 or self == ForwardMode.DRAFT_EXTEND
-            )
-        return self == ForwardMode.DRAFT_EXTEND
+        return self == ForwardMode.DRAFT_EXTEND or (
+            include_v2 and self == ForwardMode.DRAFT_EXTEND_V2
+        )
 
     def is_draft_extend_v2(self):
        # For fixed shape logits output in v2 eagle worker
        return self == ForwardMode.DRAFT_EXTEND_V2
 
-    def is_extend_or_draft_extend_or_mixed(self):
+    def is_extend_or_draft_extend_or_mixed(self, include_draft_extend_v2: bool = False):
         return (
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.SPLIT_PREFILL
+            or (include_draft_extend_v2 and self == ForwardMode.DRAFT_EXTEND_V2)
         )
 
     def is_cuda_graph(self):
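The hunks above replace a flag-gated conditional expression with a short-circuit `and`. The sketch below uses a simplified, hypothetical `Mode` enum (not the real `ForwardMode`, which has more members) just to check that the two forms agree for every input:

    from enum import IntEnum, auto

    class Mode(IntEnum):
        EXTEND = auto()
        DRAFT_EXTEND = auto()
        DRAFT_EXTEND_V2 = auto()

    def is_extend_old(mode: Mode, include_v2: bool = False) -> bool:
        # original form: conditional expression gated on the flag
        return mode == Mode.EXTEND or (
            mode == Mode.DRAFT_EXTEND_V2 if include_v2 else False
        )

    def is_extend_new(mode: Mode, include_v2: bool = False) -> bool:
        # refactored form: short-circuit `and`
        return mode == Mode.EXTEND or (include_v2 and mode == Mode.DRAFT_EXTEND_V2)

    for mode in Mode:
        for flag in (False, True):
            assert is_extend_old(mode, flag) == is_extend_new(mode, flag)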
@@ -250,6 +247,8 @@ class ForwardBatch:
     # For MLA chunked prefix cache used in chunked prefill
     # Tell attention backend whether lse needs to be returned
     mha_return_lse: Optional[bool] = None
+    mha_one_shot_kv_indices: Optional[torch.Tensor] = None
+    mha_one_shot: Optional[bool] = None
 
     # For multimodal
     mm_inputs: Optional[List[MultimodalInputs]] = None
@@ -316,6 +315,9 @@ class ForwardBatch:
     tbo_parent_token_range: Optional[Tuple[int, int]] = None
     tbo_children: Optional[List[ForwardBatch]] = None
 
+    # For matryoshka embeddings
+    dimensions: Optional[list[int]] = None
+
     @classmethod
     def init_new(
         cls,
@@ -357,6 +359,7 @@ class ForwardBatch:
             input_embeds=batch.input_embeds,
             token_type_ids=batch.token_type_ids,
             tbo_split_seq_index=batch.tbo_split_seq_index,
+            dimensions=batch.dimensions,
         )
         device = model_runner.device
 
@@ -572,9 +575,15 @@ class ForwardBatch:
                     device=model_runner.device,
                 )
             else:
-                mrope_position_deltas = mm_input.mrope_position_delta.flatten().to(
-                    model_runner.device, non_blocking=True
-                )
+                if mm_input.mrope_position_delta.device.type != model_runner.device:
+                    # transfer mrope_position_delta to device when the first running,
+                    # avoiding successvie host-to-device data transfer
+                    mm_input.mrope_position_delta = (
+                        mm_input.mrope_position_delta.to(
+                            model_runner.device, non_blocking=True
+                        )
+                    )
+                mrope_position_deltas = mm_input.mrope_position_delta.flatten()
                 mrope_positions_list[batch_idx] = (
                     (mrope_position_deltas + self.seq_lens[batch_idx] - 1)
                     .unsqueeze(0)
@@ -863,6 +872,10 @@ class ForwardBatch:
             self.token_to_kv_pool, MLATokenToKVPool
         ), "Currently chunked prefix cache can only be used by Deepseek models"
 
+        if not any(self.extend_prefix_lens_cpu):
+            self.num_prefix_chunks = 0
+            return
+
         if self.prefix_chunk_len is not None:
             # Chunked kv cache info already prepared by prior modules
             return
@@ -917,6 +930,34 @@ class ForwardBatch:
     def can_run_tbo(self):
         return self.tbo_split_seq_index is not None
 
+    def fetch_mha_one_shot_kv_indices(self):
+        if self.mha_one_shot_kv_indices is not None:
+            return self.mha_one_shot_kv_indices
+        batch_size = self.batch_size
+        paged_kernel_lens_sum = sum(self.seq_lens_cpu)
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum,
+            dtype=torch.int32,
+            device=self.req_pool_indices.device,
+        )
+        kv_indptr = torch.zeros(
+            batch_size + 1,
+            dtype=torch.int32,
+            device=self.req_pool_indices.device,
+        )
+        kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            self.seq_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token_pool.req_to_token.shape[1],
+        )
+        self.mha_one_shot_kv_indices = kv_indices
+        return kv_indices
+
 
 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1
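The new `fetch_mha_one_shot_kv_indices` helper builds the flattened KV indices for the whole batch once and caches them on the `ForwardBatch`. A minimal sketch of that memoization pattern, with hypothetical simplified names and plain Python lists standing in for the tensors and the Triton kernel:

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class TinyBatch:
        seq_lens: List[int]
        _kv_indices: Optional[List[int]] = field(default=None, repr=False)

        def fetch_kv_indices(self) -> List[int]:
            if self._kv_indices is not None:
                return self._kv_indices  # reuse the result built earlier
            indices: List[int] = []
            for req_id, seq_len in enumerate(self.seq_lens):
                # stand-in for the Triton kernel: one slot per cached token
                indices.extend(req_id * 10_000 + pos for pos in range(seq_len))
            self._kv_indices = indices
            return indices

    batch = TinyBatch(seq_lens=[3, 2])
    assert batch.fetch_kv_indices() is batch.fetch_kv_indices()  # built only once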
sglang/srt/model_executor/model_runner.py

@@ -29,7 +29,12 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 
-from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
+from sglang.srt.configs import (
+    FalconH1Config,
+    KimiLinearConfig,
+    NemotronHConfig,
+    Qwen3NextConfig,
+)
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import (
@@ -40,6 +45,9 @@ from sglang.srt.configs.model_config import (
 )
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
+from sglang.srt.debug_utils.tensor_dump_forward_hook import (
+    register_forward_hook_for_model,
+)
 from sglang.srt.distributed import (
     get_pp_group,
     get_tp_group,
@@ -77,7 +85,6 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
@@ -131,16 +138,10 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_fa3_default_architecture,
-    is_flashinfer_available,
     is_hip,
-    is_hopper_with_cuda_12_3,
-    is_no_spec_infer_or_topk_one,
     is_npu,
-    is_sm100_supported,
     log_info_on_rank0,
     monkey_patch_p2p_access_check,
-    monkey_patch_vllm_gguf_config,
     set_cuda_arch,
     slow_rank_detector,
     xpu_has_xmx_support,
@@ -355,7 +356,11 @@ class ModelRunner:
 
         if not self.is_draft_worker:
             set_global_expert_location_metadata(
-                compute_initial_expert_location_metadata(server_args, self.model_config)
+                compute_initial_expert_location_metadata(
+                    server_args=server_args,
+                    model_config=self.model_config,
+                    moe_ep_rank=self.moe_ep_rank,
+                )
             )
             if self.tp_rank == 0 and get_bool_env_var(
                 "SGLANG_LOG_EXPERT_LOCATION_METADATA"
@@ -503,121 +508,6 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args
 
-        if (
-            server_args.attention_backend == "intel_amx"
-            and server_args.device == "cpu"
-            and not _is_cpu_amx_available
-        ):
-            logger.info(
-                "The current platform does not support Intel AMX, will fallback to torch_native backend."
-            )
-            server_args.attention_backend = "torch_native"
-
-        if (
-            server_args.attention_backend == "intel_xpu"
-            and server_args.device == "xpu"
-            and not _is_xpu_xmx_available
-        ):
-            logger.info(
-                "The current platform does not support Intel XMX, will fallback to triton backend."
-            )
-            server_args.attention_backend = "triton"
-
-        if server_args.prefill_attention_backend is not None and (
-            server_args.prefill_attention_backend
-            == server_args.decode_attention_backend
-        ):  # override the default attention backend
-            server_args.attention_backend = server_args.prefill_attention_backend
-
-        if (
-            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
-            is not None
-        ):
-            if server_args.attention_backend is None:
-                server_args.attention_backend = "dual_chunk_flash_attn"
-                logger.info("Dual chunk attention is turned on by default.")
-            elif server_args.attention_backend != "dual_chunk_flash_attn":
-                raise ValueError(
-                    "Dual chunk attention is enabled, but attention backend is set to "
-                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
-                )
-
-        if server_args.attention_backend is None:
-            """
-            Auto select the fastest attention backend.
-
-            1. Models with MHA Architecture (e.g: Llama, QWen)
-                1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
-                1.2 In other cases, we will use flashinfer if available, otherwise use triton.
-            2. Models with MLA Architecture and using FA3
-                2.1 We will use FA3 backend on hopper.
-                2.2 We will use Flashinfer backend on blackwell.
-                2.3 Otherwise, we will use triton backend.
-            """
-
-            if not self.use_mla_backend:
-                # MHA architecture
-                if (
-                    is_hopper_with_cuda_12_3()
-                    and is_no_spec_infer_or_topk_one(server_args)
-                    and is_fa3_default_architecture(self.model_config.hf_config)
-                ):
-                    server_args.attention_backend = "fa3"
-                elif _is_hip:
-                    server_args.attention_backend = "aiter"
-                elif _is_npu:
-                    server_args.attention_backend = "ascend"
-                else:
-                    server_args.attention_backend = (
-                        "flashinfer" if is_flashinfer_available() else "triton"
-                    )
-            else:
-                # MLA architecture
-                if is_hopper_with_cuda_12_3():
-                    server_args.attention_backend = "fa3"
-                elif is_sm100_supported():
-                    server_args.attention_backend = "flashinfer"
-                elif _is_hip:
-                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
-                    # TODO current aiter only support head number 16 or 128 head number
-                    if head_num == 128 or head_num == 16:
-                        server_args.attention_backend = "aiter"
-                    else:
-                        server_args.attention_backend = "triton"
-                elif _is_npu:
-                    server_args.attention_backend = "ascend"
-                else:
-                    server_args.attention_backend = "triton"
-            log_info_on_rank0(
-                logger,
-                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default.",
-            )
-        elif self.use_mla_backend:
-            if server_args.device != "cpu":
-                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
-                    logger.info(
-                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                    )
-                else:
-                    raise ValueError(
-                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
-                    )
-            else:
-                if server_args.attention_backend != "intel_amx":
-                    raise ValueError(
-                        "MLA optimization not supported on CPU except for intel_amx backend."
-                    )
-
-        if (
-            server_args.attention_backend == "fa3"
-            and server_args.kv_cache_dtype == "fp8_e5m2"
-        ):
-            logger.warning(
-                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
-                "Setting attention backend to triton."
-            )
-            server_args.attention_backend = "triton"
-
         if server_args.enable_double_sparsity:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -643,37 +533,12 @@ class ModelRunner:
         if not server_args.disable_chunked_prefix_cache:
             log_info_on_rank0(logger, "Chunked prefix cache is turned on.")
 
-        if server_args.attention_backend == "aiter":
-            if self.model_config.context_len > 8192:
-                self.mem_fraction_static *= 0.85
-
-        if (
-            server_args.enable_hierarchical_cache
-            and server_args.hicache_io_backend == "kernel"
-        ):
-            # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
-            if server_args.decode_attention_backend is None:
-                if not self.use_mla_backend:
-                    server_args.decode_attention_backend = (
-                        "flashinfer" if is_flashinfer_available() else "triton"
-                    )
-                else:
-                    server_args.decode_attention_backend = (
-                        "flashinfer" if is_sm100_supported() else "triton"
-                    )
-            elif server_args.decode_attention_backend == "fa3":
-                server_args.hicache_io_backend = "direct"
-                logger.warning(
-                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
-                )
-
         if self.model_config.hf_config.model_type == "qwen3_vl_moe":
             if (
                 quantization_config := getattr(
                     self.model_config.hf_config, "quantization_config", None
                 )
-            ) is not None:
+            ) is not None and "weight_block_size" in quantization_config:
                 weight_block_size_n = quantization_config["weight_block_size"][0]
 
                 if self.tp_size % self.moe_ep_size != 0:
@@ -858,8 +723,6 @@ class ModelRunner:
         self.model_config = adjust_config_with_unaligned_cpu_tp(
             self.model_config, self.load_config, self.tp_size
         )
-        if self.server_args.load_format == "gguf":
-            monkey_patch_vllm_gguf_config()
 
         if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
             if self.tp_rank == 0:
@@ -878,7 +741,6 @@ class ModelRunner:
         # Load the model
         # Remove monkey_patch when linear.py quant remove dependencies with vllm
         monkey_patch_vllm_parallel_state()
-        monkey_patch_isinstance_for_vllm_base_layer()
 
         with self.memory_saver_adapter.region(
             GPU_MEMORY_TYPE_WEIGHTS,
@@ -890,7 +752,6 @@ class ModelRunner:
             device_config=DeviceConfig(self.device, self.gpu_id),
         )
         monkey_patch_vllm_parallel_state(reverse=True)
-        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
 
        get_offloader().post_init()
 
@@ -938,6 +799,15 @@ class ModelRunner:
             f"avail mem={after_avail_memory:.2f} GB, "
             f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )
+        if self.server_args.debug_tensor_dump_output_folder is not None:
+            register_forward_hook_for_model(
+                self.model,
+                self.server_args.debug_tensor_dump_output_folder,
+                self.server_args.debug_tensor_dump_layers,
+                self.tp_size,
+                self.tp_rank,
+                self.pp_rank,
+            )
 
         if self.server_args.elastic_ep_backend == "mooncake":
             # Mooncake does not support `monitored_barrier`
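The hunk above wires `register_forward_hook_for_model` into model loading whenever `debug_tensor_dump_output_folder` is set. The sketch below shows the general idea with PyTorch's standard hook API; the helper name and file layout here are illustrative, not the actual implementation in `tensor_dump_forward_hook.py`:

    import os
    import torch
    import torch.nn as nn

    def register_dump_hooks(model: nn.Module, dump_dir: str) -> None:
        """Save every module's output tensor to dump_dir after each forward pass."""
        os.makedirs(dump_dir, exist_ok=True)

        def make_hook(name: str):
            def hook(module, inputs, output):
                if isinstance(output, torch.Tensor):
                    torch.save(output.detach().cpu(), os.path.join(dump_dir, f"{name}.pt"))
            return hook

        for name, module in model.named_modules():
            module.register_forward_hook(make_hook(name or "root"))

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    register_dump_hooks(model, "/tmp/tensor_dump")
    model(torch.randn(2, 4))  # writes root.pt, 0.pt, 1.pt into /tmp/tensor_dump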
@@ -1493,9 +1363,16 @@ class ModelRunner:
                 return config
         return None
 
+    @property
+    def kimi_linear_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, KimiLinearConfig):
+            return config
+        return None
+
     @property
     def mambaish_config(self):
-        return self.mamba2_config or self.hybrid_gdn_config
+        return self.mamba2_config or self.hybrid_gdn_config or self.kimi_linear_config
 
     def set_num_token_hybrid(self):
         if (
@@ -1806,9 +1683,11 @@ class ModelRunner:
                     get_attention_tp_size()
                 ),
                 head_dim=self.model_config.head_dim,
-                layer_num=self.model_config.num_hidden_layers,
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
         elif self.use_mla_backend and is_nsa_model:
             self.token_to_kv_pool = NSATokenToKVPool(
@@ -1824,7 +1703,7 @@ class ModelRunner:
                 end_layer=self.end_layer,
                 index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
             )
-        elif self.use_mla_backend:
+        elif self.use_mla_backend and not self.mambaish_config:
             assert not is_nsa_model
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
@@ -1868,6 +1747,12 @@ class ModelRunner:
                 device=self.device,
             )
         elif config := self.mambaish_config:
+            extra_args = {}
+            if self.use_mla_backend:
+                extra_args = {
+                    "kv_lora_rank": self.model_config.kv_lora_rank,
+                    "qk_rope_head_dim": self.model_config.qk_rope_head_dim,
+                }
             self.token_to_kv_pool = HybridLinearKVPool(
                 page_size=self.page_size,
                 size=self.max_total_num_tokens,
@@ -1883,6 +1768,8 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
                 mamba_pool=self.req_to_token_pool.mamba_pool,
+                use_mla=self.use_mla_backend,
+                **extra_args,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -1898,6 +1785,7 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
+                enable_alt_stream=not self.server_args.enable_pdmux,
                 enable_kv_cache_copy=(
                     self.server_args.speculative_algorithm is not None
                 ),
@@ -1966,12 +1854,18 @@ class ModelRunner:
 
     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
+        if self.server_args.enable_pdmux:
+            self.attn_backend = self._get_attention_backend(init_new_workspace=True)
+            self.decode_attn_backend_group = []
+            for _ in range(self.server_args.sm_group_num):
+                self.decode_attn_backend_group.append(self._get_attention_backend())
+            self.decode_attn_backend = self.decode_attn_backend_group[0]
+        elif self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
             self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
         else:
             self.attn_backend = self._get_attention_backend()
 
-    def _get_attention_backend(self):
+    def _get_attention_backend(self, init_new_workspace: bool = False):
         """Init attention kernel backend."""
         self.prefill_attention_backend_str, self.decode_attention_backend_str = (
             self.server_args.get_attention_backends()
@@ -1985,10 +1879,12 @@ class ModelRunner:
             attn_backend = HybridAttnBackend(
                 self,
                 decode_backend=self._get_attention_backend_from_str(
-                    self.decode_attention_backend_str
+                    self.decode_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
                 prefill_backend=self._get_attention_backend_from_str(
-                    self.prefill_attention_backend_str
+                    self.prefill_attention_backend_str,
+                    init_new_workspace=init_new_workspace,
                 ),
             )
             logger.info(
@@ -2002,7 +1898,8 @@ class ModelRunner:
             )
         else:
             attn_backend = self._get_attention_backend_from_str(
-                self.server_args.attention_backend
+                self.server_args.attention_backend,
+                init_new_workspace=init_new_workspace,
             )
 
         (
@@ -2011,9 +1908,12 @@ class ModelRunner:
         ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
         return attn_backend
 
-    def _get_attention_backend_from_str(self, backend_str: str):
+    def _get_attention_backend_from_str(
+        self, backend_str: str, init_new_workspace: bool = False
+    ):
         if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
+        self.init_new_workspace = init_new_workspace
         full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
         return attn_backend_wrapper(self, full_attention_backend)
 
@@ -2111,6 +2011,9 @@ class ModelRunner:
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)
 
+    def update_decode_attn_backend(self, stream_idx: int):
+        self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]
+
     def forward_decode(
         self,
         forward_batch: ForwardBatch,
@@ -2118,7 +2021,11 @@ class ModelRunner:
         pp_proxy_tensors=None,
     ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
-            self.attn_backend.init_forward_metadata(forward_batch)
+            if self.server_args.enable_pdmux:
+                self.decode_attn_backend.init_forward_metadata(forward_batch)
+                forward_batch.attn_backend = self.decode_attn_backend
+            else:
+                self.attn_backend.init_forward_metadata(forward_batch)
         # FIXME: add pp_proxy_tensors arg to all models
         kwargs = {}
         if self.support_pp:
@@ -2256,18 +2163,18 @@ class ModelRunner:
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-        elif forward_batch.forward_mode.is_extend():
-            ret = self.forward_extend(
-                forward_batch,
-                skip_attn_backend_init=skip_attn_backend_init,
-                pp_proxy_tensors=pp_proxy_tensors,
-            )
         elif forward_batch.forward_mode.is_split_prefill():
             ret = self.forward_split_prefill(
                 forward_batch,
                 reinit_attn_backend=reinit_attn_backend,
                 forward_count=split_forward_count,
             )
+        elif forward_batch.forward_mode.is_extend():
+            ret = self.forward_extend(
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
         elif forward_batch.forward_mode.is_idle():
             ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
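The pdmux-related changes above keep one prefill attention backend plus a pool of decode backends (one per SM group) and switch between them with `update_decode_attn_backend`. A toy illustration of that selection pattern, using placeholder classes rather than the real sglang backends:

    class DummyBackend:
        def __init__(self, name: str) -> None:
            self.name = name

    class Runner:
        def __init__(self, sm_group_num: int) -> None:
            self.attn_backend = DummyBackend("prefill")
            self.decode_attn_backend_group = [
                DummyBackend(f"decode-{i}") for i in range(sm_group_num)
            ]
            self.decode_attn_backend = self.decode_attn_backend_group[0]

        def update_decode_attn_backend(self, stream_idx: int) -> None:
            # pick the backend that owns the workspace for this SM group
            self.decode_attn_backend = self.decode_attn_backend_group[stream_idx]

    runner = Runner(sm_group_num=2)
    runner.update_decode_attn_backend(1)
    assert runner.decode_attn_backend.name == "decode-1"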
sglang/srt/model_executor/npu_graph_runner.py

@@ -75,9 +75,13 @@ class NPUGraphRunner(CudaGraphRunner):
 
         # Replay
         if not is_deepseek_nsa(self.model_runner.model_config.hf_config):
-            seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
-                self.bs - self.raw_bs
-            )
+            if forward_batch.forward_mode.is_target_verify():
+                seq_lens_cpu = forward_batch.seq_lens.cpu() + self.num_tokens_per_bs
+                seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs)
+            else:
+                seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
+                    self.bs - self.raw_bs
+                )
             thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
             thread.start()
         self.graphs[self.bs].replay()
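For target-verify replays, the hunk above adds `num_tokens_per_bs` to each live request's length before padding the list up to the captured graph batch size. A small numeric illustration (values chosen arbitrarily):

    import torch

    raw_seq_lens = torch.tensor([7, 5])   # live requests in the batch
    num_tokens_per_bs = 4                 # draft tokens verified per request
    graph_bs, raw_bs = 4, 2               # captured graph batch size vs. real batch size

    seq_lens = (raw_seq_lens + num_tokens_per_bs).tolist() + [0] * (graph_bs - raw_bs)
    assert seq_lens == [11, 9, 0, 0]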
sglang/srt/model_executor/piecewise_cuda_graph_runner.py

@@ -32,7 +32,6 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     set_graph_pool_id,
 )
-from sglang.srt.distributed.parallel_state import graph_capture
 from sglang.srt.layers.dp_attention import (
     DpPaddingMode,
     get_attention_tp_rank,
@@ -250,6 +249,9 @@ class PiecewiseCudaGraphRunner:
             lora_ids=None,
         )
 
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
+
         with set_forward_context(forward_batch, self.attention_layers):
             _ = self.model_runner.model.forward(
                 forward_batch.input_ids,
@@ -262,9 +264,14 @@ class PiecewiseCudaGraphRunner:
 
     def can_run(self, forward_batch: ForwardBatch):
         num_tokens = len(forward_batch.input_ids)
-        # TODO(yuwei): support return logprob
+        # TODO(yuwei): support return input_ids' logprob
         if forward_batch.return_logprob:
-            return False
+            for start_len, seq_len in zip(
+                forward_batch.extend_logprob_start_lens_cpu,
+                forward_batch.extend_seq_lens_cpu,
+            ):
+                if start_len is not None and start_len < seq_len:
+                    return False
         if num_tokens <= self.max_num_tokens:
             return True
         return False
@@ -273,10 +280,10 @@ class PiecewiseCudaGraphRunner:
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        with freeze_gc(
-            self.model_runner.server_args.enable_cudagraph_gc
-        ), graph_capture() as graph_capture_context:
-            self.stream = graph_capture_context.stream
+        with freeze_gc(self.model_runner.server_args.enable_cudagraph_gc):
+            if self.model_runner.tp_group.ca_comm is not None:
+                old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
+                self.model_runner.tp_group.ca_comm.disabled = True
             avail_mem = get_available_gpu_memory(
                 self.model_runner.device,
                 self.model_runner.gpu_id,
@@ -304,9 +311,10 @@ class PiecewiseCudaGraphRunner:
 
                 # Save gemlite cache after each capture
                 save_gemlite_cache()
+            if self.model_runner.tp_group.ca_comm is not None:
+                self.model_runner.tp_group.ca_comm.disabled = old_ca_disable
 
     def capture_one_batch_size(self, num_tokens: int):
-        stream = self.stream
         bs = 1
 
         # Graph inputs
@@ -370,9 +378,6 @@ class PiecewiseCudaGraphRunner:
         if lora_ids is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
 
-        # # Attention backend
-        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
-
         # Run and capture
         def run_once():
             # Clean intermediate result cache for DP attention
@@ -438,7 +443,7 @@ class PiecewiseCudaGraphRunner:
             out_cache_loc=out_cache_loc,
             seq_lens_sum=forward_batch.seq_lens_sum,
             encoder_lens=forward_batch.encoder_lens,
-            return_logprob=forward_batch.return_logprob,
+            return_logprob=False,
             extend_seq_lens=forward_batch.extend_seq_lens,
             extend_prefix_lens=forward_batch.extend_prefix_lens,
             extend_start_loc=forward_batch.extend_start_loc,
@@ -474,6 +479,9 @@ class PiecewiseCudaGraphRunner:
         forward_batch: ForwardBatch,
         **kwargs,
     ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        if self.model_runner.tp_group.ca_comm is not None:
+            old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
+            self.model_runner.tp_group.ca_comm.disabled = True
         static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
         # Replay
         with set_forward_context(static_forward_batch, self.attention_layers):
@@ -499,6 +507,8 @@ class PiecewiseCudaGraphRunner:
             raise NotImplementedError(
                 "PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet."
             )
+        if self.model_runner.tp_group.ca_comm is not None:
+            self.model_runner.tp_group.ca_comm.disabled = old_ca_disable
 
     def get_spec_info(self, num_tokens: int):
         spec_info = None
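The capture and replay paths above save `tp_group.ca_comm.disabled`, force it to `True` for the duration of the graph work, and restore the old value afterwards. One way to express that save/restore pattern is a context manager; the sketch below is a generic illustration with a stand-in class, not code from the package:

    from contextlib import contextmanager

    class CustomAllReduce:
        def __init__(self) -> None:
            self.disabled = False

    @contextmanager
    def custom_all_reduce_disabled(ca_comm):
        if ca_comm is None:          # nothing to toggle
            yield
            return
        old_disabled = ca_comm.disabled
        ca_comm.disabled = True
        try:
            yield
        finally:
            ca_comm.disabled = old_disabled  # restore the caller's setting

    ca = CustomAllReduce()
    with custom_all_reduce_disabled(ca):
        assert ca.disabled is True
    assert ca.disabled is False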