sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -132,6 +132,9 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.IDLE
         )
 
+    def is_cpu_graph(self):
+        return self == ForwardMode.DECODE
+
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
 
@@ -441,7 +444,13 @@ class ForwardBatch:
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
 
         if model_runner.model_is_mrope:
-            ret._compute_mrope_positions(model_runner, batch)
+            if (
+                ret.spec_info is not None
+                and getattr(ret.spec_info, "positions", None) is not None
+            ):
+                ret._compute_spec_mrope_positions(model_runner, batch)
+            else:
+                ret._compute_mrope_positions(model_runner, batch)
 
         # Init lora information
         if model_runner.server_args.enable_lora:
@@ -507,6 +516,52 @@ class ForwardBatch:
             or self.contains_image_inputs()
         )
 
+    def _compute_spec_mrope_positions(
+        self, model_runner: ModelRunner, batch: ModelWorkerBatch
+    ):
+        # TODO support batched deltas
+        batch_size = self.seq_lens.shape[0]
+        device = model_runner.device
+        mm_inputs = batch.multimodal_inputs
+
+        if batch.forward_mode.is_draft_extend():  # draft_extend_after_decode
+            mrope_deltas = []
+            extend_lens = []
+            for batch_idx in range(batch_size):
+                extend_seq_len = batch.extend_seq_lens[batch_idx]
+                extend_lens.append(extend_seq_len)
+                mrope_delta = (
+                    torch.zeros(1, dtype=torch.int64)
+                    if mm_inputs[batch_idx] is None
+                    else mm_inputs[batch_idx].mrope_position_delta.squeeze(0)
+                )
+                mrope_deltas.append(mrope_delta.to(device=device))
+            position_chunks = torch.split(batch.spec_info.positions, extend_lens)
+            mrope_positions_list = [
+                pos_chunk + delta
+                for pos_chunk, delta in zip(position_chunks, mrope_deltas)
+            ]
+            next_input_positions = (
+                torch.cat(mrope_positions_list, dim=0).unsqueeze(0).repeat(3, 1)
+            )
+
+        else:  # target_verify or draft_decode
+            seq_positions = batch.spec_info.positions.view(batch_size, -1)
+            mrope_deltas = [
+                (
+                    torch.tensor([0], dtype=torch.int64)
+                    if mm_inputs[i] is None
+                    else mm_inputs[i].mrope_position_delta.squeeze(0)
+                )
+                for i in range(batch_size)
+            ]
+            mrope_delta_tensor = torch.stack(mrope_deltas, dim=0).to(device=device)
+            next_input_positions = (
+                (seq_positions + mrope_delta_tensor).flatten().unsqueeze(0).repeat(3, 1)
+            )
+
+        self.mrope_positions = next_input_positions
+
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
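The draft-extend branch above splits the flattened speculative positions per request and shifts each chunk by that request's mrope delta before broadcasting to the three rotary axes. A minimal standalone sketch of that arithmetic with plain torch (the tensor values are made up; the real inputs come from spec_info.positions and multimodal_inputs):

    import torch

    # Hypothetical inputs: two requests with 3 and 2 draft-extend tokens each.
    positions = torch.tensor([10, 11, 12, 7, 8])     # flattened spec_info.positions
    extend_lens = [3, 2]
    deltas = [torch.tensor([4]), torch.tensor([0])]   # per-request mrope deltas

    chunks = torch.split(positions, extend_lens)
    shifted = [chunk + delta for chunk, delta in zip(chunks, deltas)]
    # Concatenate back and repeat across the 3 rotary sections (t, h, w), as in the diff.
    mrope_positions = torch.cat(shifted, dim=0).unsqueeze(0).repeat(3, 1)
    print(mrope_positions.shape)  # torch.Size([3, 5])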
@@ -516,24 +571,23 @@ class ForwardBatch:
         for batch_idx in range(batch_size):
             mm_input = batch.multimodal_inputs[batch_idx]
             if self.forward_mode.is_decode():
-                mrope_position_deltas = (
-                    [0]
-                    if mm_input is None
-                    else flatten_nested_list(mm_input.mrope_position_delta.tolist())
-                )
-                next_input_positions = []
-                for mrope_position_delta in mrope_position_deltas:
-                    # batched deltas needs to be processed separately
-                    # Convert list of lists to tensor with shape [3, seq_len]
-                    next_input_positions += [
-                        MRotaryEmbedding.get_next_input_positions(
-                            mrope_position_delta,
-                            int(self.seq_lens[batch_idx]) - 1,
-                            int(self.seq_lens[batch_idx]),
-                        )
-                    ]
                 # 3 * N
-                mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
+                if mm_input is None:
+                    mrope_positions_list[batch_idx] = torch.full(
+                        (3, 1),
+                        self.seq_lens[batch_idx] - 1,
+                        dtype=torch.int64,
+                        device=model_runner.device,
+                    )
+                else:
+                    mrope_position_deltas = mm_input.mrope_position_delta.flatten().to(
+                        model_runner.device, non_blocking=True
+                    )
+                    mrope_positions_list[batch_idx] = (
+                        (mrope_position_deltas + self.seq_lens[batch_idx] - 1)
+                        .unsqueeze(0)
+                        .repeat(3, 1)
+                    )
             elif self.forward_mode.is_extend():
                 extend_seq_len, extend_prefix_len = (
                     batch.extend_seq_lens[batch_idx],
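In the new decode path, each request's next mrope position is just its stored position delta plus seq_len - 1, broadcast to the three rotary sections, which drops the per-delta Python loop and the MRotaryEmbedding helper call of the old code. A small sketch with made-up values:

    import torch

    seq_len = torch.tensor(42)                   # current sequence length of one request
    mrope_position_delta = torch.tensor([[-5]])  # stored per-request delta, shape [1, 1]

    next_positions = (
        (mrope_position_delta.flatten() + seq_len - 1)
        .unsqueeze(0)
        .repeat(3, 1)
    )
    print(next_positions)  # tensor([[36], [36], [36]])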
@@ -20,6 +20,7 @@ import json
 import logging
 import os
 import time
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
@@ -32,6 +33,7 @@ from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tp_group,
     get_world_group,
     init_distributed_environment,
@@ -83,11 +85,14 @@ from sglang.srt.mem_cache.memory_pool import (
     AscendMLAPagedTokenToKVPool,
     AscendTokenToKVPool,
     DoubleSparseTokenToKVPool,
+    HybridLinearKVPool,
+    HybridReqToTokenPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
@@ -300,6 +305,26 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True
 
+        if self.is_hybrid_gdn:
+            logger.warning("Hybrid GDN model detected, disable radix cache")
+            self.server_args.disable_radix_cache = True
+            self.server_args.attention_backend = "hybrid_linear_attn"
+            if self.server_args.max_mamba_cache_size is None:
+                if self.server_args.max_running_requests is not None:
+                    self.server_args.max_mamba_cache_size = (
+                        self.server_args.max_running_requests
+                    )
+                else:
+                    self.server_args.max_mamba_cache_size = 512
+            self.server_args.max_mamba_cache_size = (
+                self.server_args.max_mamba_cache_size
+                // (
+                    self.server_args.dp_size
+                    if self.server_args.enable_dp_attention
+                    else 1
+                )
+            )
+
         # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
         # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
         # determine the number of layers.
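The Mamba cache sizing above falls back from max_running_requests to 512 and is then divided across DP ranks when DP attention is enabled. A small sketch of that sizing rule (the function name and defaults here are illustrative, not part of the sglang API):

    def resolve_mamba_cache_size(max_mamba_cache_size, max_running_requests,
                                 dp_size, enable_dp_attention, default=512):
        # Fallback chain: explicit value -> max_running_requests -> default.
        if max_mamba_cache_size is None:
            max_mamba_cache_size = (
                max_running_requests if max_running_requests is not None else default
            )
        # Split the budget across data-parallel ranks when DP attention is on.
        return max_mamba_cache_size // (dp_size if enable_dp_attention else 1)

    print(resolve_mamba_cache_size(None, None, dp_size=4, enable_dp_attention=True))   # 128
    print(resolve_mamba_cache_size(None, 256, dp_size=4, enable_dp_attention=False))   # 256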
@@ -307,7 +332,10 @@ class ModelRunner:
         model_num_layers = (
             self.model_config.num_nextn_predict_layers
             if self.is_draft_worker and model_has_mtp_layers
-            else self.model_config.num_hidden_layers
+            else max(
+                self.model_config.num_hidden_layers,
+                self.model_config.num_attention_layers,
+            )
         )
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(self.model, "end_layer", model_num_layers)
@@ -338,6 +366,14 @@ class ModelRunner:
         if server_args.enable_lora:
             self.init_lora_manager()
 
+        # Init Double Sparsity
+        if server_args.enable_double_sparsity:
+            if server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
+
         # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
@@ -348,12 +384,12 @@ class ModelRunner:
             self.init_cublas()
             self.init_attention_backend()
             self.init_device_graphs()
-        elif self.device == "npu":
+        elif self.device in ["npu", "cpu"]:
             self.init_attention_backend()
             self.init_device_graphs()
         else:
             self.graph_runner = None
-            self.cuda_graph_mem_usage = 0
+            self.graph_mem_usage = 0
             self.init_attention_backend()
 
         # auxiliary hidden capture mode. TODO: expose this to server args?
@@ -503,11 +539,6 @@ class ModelRunner:
             )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
-            if server_args.ds_heavy_channel_type is None:
-                raise ValueError(
-                    "Please specify the heavy channel type for double sparsity optimization."
-                )
-            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
 
         if self.is_multimodal:
             if not self.is_multimodal_chunked_prefill_supported:
@@ -519,6 +550,17 @@ class ModelRunner:
 
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
+        # TODO(kaixih@nvidia): remove this once we have a better solution for DP attention.
+        # For more details, see: https://github.com/sgl-project/sglang/issues/8616
+        elif (
+            self.dp_size > 1
+            and is_sm100_supported()
+            and server_args.attention_backend != "triton"
+        ):
+            logger.info(
+                "Disable chunked prefix cache when dp size > 1 and attention backend is not triton."
+            )
+            server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
             logger.info("Chunked prefix cache is turned on.")
@@ -590,6 +632,11 @@ class ModelRunner:
             # Set local size to hint SGLang to use shared memory based AllReduce
             os.environ["LOCAL_SIZE"] = str(self.tp_size)
             torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+
+            @torch.library.register_fake("sgl_kernel::shm_allgather")
+            def _(data, dim):
+                return torch.cat([data] * self.tp_size, dim=dim)
+
         else:
             logger.warning(
                 "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
@@ -622,6 +669,7 @@ class ModelRunner:
             cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
+        self.pp_group = get_pp_group()
         self.attention_tp_group = get_attention_tp_group()
 
         # Check memory for tensor parallelism
@@ -1054,6 +1102,8 @@ class ModelRunner:
                 "num_nextn_predict_layers",
                 self.num_effective_layers,
             )
+        elif self.is_hybrid_gdn:
+            num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
@@ -1073,9 +1123,22 @@ class ModelRunner:
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
+        if self.is_hybrid_gdn:
+            rest_memory -= (
+                self.server_args.max_mamba_cache_size
+                * self.model_config.hf_config.mamba_cache_per_req
+                / (1 << 30)
+            )
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token
 
+    @property
+    def is_hybrid_gdn(self):
+        return self.model_config.hf_config.architectures[0] in [
+            "Qwen3NextForCausalLM",
+            "Qwen3NextForCausalLMMTP",
+        ]
+
     def set_num_token_hybrid(self):
         if (
             "Llama4ForConditionalGeneration"
@@ -1196,6 +1259,8 @@ class ModelRunner:
             ),
             4096,
         )
+        if self.is_hybrid_gdn:
+            max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
 
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
@@ -1234,6 +1299,16 @@ class ModelRunner:
             // self.server_args.page_size
             * self.server_args.page_size
         )
+        # different pp rank may have different num of layers, so we need to reduce the max_total_num_tokens
+        if self.pp_size > 1:
+            tensor = torch.tensor(self.max_total_num_tokens, dtype=torch.int64)
+            torch.distributed.all_reduce(
+                tensor,
+                op=torch.distributed.ReduceOp.MIN,
+                group=get_world_group().cpu_group,
+            )
+            self.max_total_num_tokens = tensor.item()
+
         # create token size for hybrid cache
         if self.is_hybrid:
             self.set_num_token_hybrid()
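With pipeline parallelism, ranks that hold different numbers of layers can compute different token capacities, so the hunk above takes the group-wide minimum. A self-contained sketch of the same all_reduce(MIN) call using a single-process gloo group (the init_method address is arbitrary for this example; real runs span all pp ranks):

    import torch
    import torch.distributed as dist

    # One-process group, just to exercise the collective.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
    )
    max_total_num_tokens = torch.tensor(123_456, dtype=torch.int64)
    dist.all_reduce(max_total_num_tokens, op=dist.ReduceOp.MIN)
    print(max_total_num_tokens.item())  # every rank now agrees on the smallest budget
    dist.destroy_process_group()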
@@ -1264,6 +1339,28 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 pre_alloc_size=pre_alloc_size,
             )
+        elif self.is_hybrid_gdn:
+            config = self.model_config.hf_config
+            (
+                conv_state_shape,
+                temporal_state_shape,
+                conv_dtype,
+                ssm_dtype,
+                mamba_layers,
+            ) = config.hybrid_gdn_params
+            self.req_to_token_pool = HybridReqToTokenPool(
+                size=max_num_reqs,
+                max_context_len=self.model_config.context_len
+                + extra_max_context_len,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+                conv_state_shape=conv_state_shape,
+                temporal_state_shape=temporal_state_shape,
+                conv_dtype=conv_dtype,
+                ssm_dtype=ssm_dtype,
+                mamba_layers=mamba_layers,
+                speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+            )
         else:
             self.req_to_token_pool = ReqToTokenPool(
                 size=max_num_reqs,
@@ -1346,6 +1443,23 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
             )
+        elif self.is_hybrid_gdn:
+            self.token_to_kv_pool = HybridLinearKVPool(
+                size=self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                head_num=self.model_config.get_num_kv_heads(
+                    get_attention_tp_size()
+                ),
+                head_dim=self.model_config.head_dim,
+                # if draft worker, we only need 1 attention layer's kv pool
+                full_attention_layer_ids=(
+                    [0]
+                    if self.is_draft_worker
+                    else self.model_config.hf_config.full_attention_layer_ids
+                ),
+                enable_kvcache_transpose=False,
+                device=self.device,
+            )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -1440,14 +1554,12 @@ class ModelRunner:
             else self.server_args.attention_backend
         )
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
-            assert (
-                self.server_args.speculative_algorithm is None
-            ), "Currently HybridAttentionBackend does not support speculative decoding."
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
             )
 
             attn_backend = HybridAttnBackend(
+                self,
                 decode_backend=self._get_attention_backend_from_str(
                     self.decode_attention_backend_str
                 ),
@@ -1581,6 +1693,24 @@ class ModelRunner:
             )
 
             return DualChunkFlashAttentionBackend(self)
+        elif backend_str == "hybrid_linear_attn":
+            assert (
+                self.is_hybrid_gdn
+            ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionBackend,
+            )
+            from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+                HybridLinearAttnBackend,
+                MambaAttnBackend,
+            )
+
+            full_attn_backend = FlashAttentionBackend(self)
+            linear_attn_backend = MambaAttnBackend(self)
+            full_attn_layers = self.model_config.hf_config.full_attention_layer_ids
+            return HybridLinearAttnBackend(
+                full_attn_backend, linear_attn_backend, full_attn_layers
+            )
         else:
             raise ValueError(f"Invalid attention backend: {backend_str}")
 
@@ -1602,38 +1732,46 @@ class ModelRunner:
         )
 
     def init_device_graphs(self):
-        """Capture cuda graphs."""
+        """Capture device graphs."""
         self.graph_runner = None
-        self.cuda_graph_mem_usage = 0
+        self.graph_mem_usage = 0
 
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return
 
-        if self.server_args.disable_cuda_graph:
+        if self.device != "cpu" and self.server_args.disable_cuda_graph:
+            return
+
+        if self.device == "cpu" and not self.server_args.enable_torch_compile:
             return
 
         tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-        self.graph_runner = (
-            CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
+        graph_runners = defaultdict(
+            lambda: CudaGraphRunner,
+            {
+                "cpu": CPUGraphRunner,
+                "npu": NPUGraphRunner,
+            },
         )
+        self.graph_runner = graph_runners[self.device](self)
+
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
-        self.cuda_graph_mem_usage = before_mem - after_mem
+        self.graph_mem_usage = before_mem - after_mem
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
+            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
        )
 
     def init_threads_binding(self):
         omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
+        cpu_ids_by_node = get_cpu_ids_by_node()
+        n_numa_node = len(cpu_ids_by_node)
         if omp_cpuids == "all":
-            cpu_ids_by_node = get_cpu_ids_by_node()
-            n_numa_node = len(cpu_ids_by_node)
-
             assert self.tp_size <= n_numa_node, (
                 f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
                 f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
@@ -1650,11 +1788,22 @@ class ModelRunner:
             )
             self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
         else:
-            self.local_omp_cpuid = omp_cpuids.split("|")[self.tp_rank]
+            threads_bind_list = omp_cpuids.split("|")
+            assert self.tp_size == len(threads_bind_list), (
+                f"SGLANG_CPU_OMP_THREADS_BIND setting must be aligned with TP size parameter ({self.tp_size}). "
+                f"Please double check your settings."
+            )
+            self.local_omp_cpuid = threads_bind_list[self.tp_rank]
+            if self.tp_size > n_numa_node:
+                logger.warning(
+                    f"TP size ({self.tp_size})is larger than numa node number ({n_numa_node}), "
+                    f"in this case the available memory amount of each rank cannot be determined in prior. "
+                    f"Please set proper `--max-total-tokens` to avoid the out-of-memory error."
+                )
 
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
-        from sglang.srt.model_parallel import tensor_parallel
+        from sglang.srt.layers.model_parallel import tensor_parallel
 
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)
@@ -1770,18 +1919,24 @@ class ModelRunner:
         reinit_attn_backend: bool = False,
         split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
-        can_run_cuda_graph = bool(
-            forward_batch.forward_mode.is_cuda_graph()
+        mode_check = (
+            forward_batch.forward_mode.is_cpu_graph
+            if self.device == "cpu"
+            else forward_batch.forward_mode.is_cuda_graph
+        )
+        can_run_graph = bool(
+            mode_check()
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
-        if can_run_cuda_graph:
+
+        if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return ret, can_run_cuda_graph
+            return ret, can_run_graph
 
         # For MLP sync
         if forward_batch.global_num_tokens_cpu is not None:
@@ -1810,10 +1965,13 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
-        if forward_batch.global_num_tokens_cpu is not None:
+        if (
+            forward_batch.global_num_tokens_cpu is not None
+            and self.pp_group.is_last_rank
+        ):
             forward_batch.post_forward_mlp_sync_batch(ret)
 
-        return ret, can_run_cuda_graph
+        return ret, can_run_graph
 
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
@@ -1,16 +1,22 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
 
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from torch import nn
 
-from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import (
     get_architecture_class_name,
     get_model_architecture,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.configs.device_config import DeviceConfig
+    from sglang.srt.configs.load_config import LoadConfig
+    from sglang.srt.configs.model_config import ModelConfig
+
 
 def get_model(
     *,
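The model_loader change defers the config imports to type-checking time: with `from __future__ import annotations`, annotations stay as strings at runtime, so the modules are only imported by type checkers and any runtime import cost or cycle disappears. A minimal sketch of the pattern with a stand-in module:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only type checkers import this; "decimal" stands in for the heavy config modules.
        from decimal import Decimal

    def twice(value: Decimal) -> Decimal:
        # The annotation is never evaluated at runtime, so no import is needed here.
        return value + value

    print(twice.__annotations__)  # {'value': 'Decimal', 'return': 'Decimal'}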