sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/test/attention/test_trtllm_mla_backend.py CHANGED
@@ -41,6 +41,10 @@ DEFAULT_CONFIG = {
41
41
  "v_head_dim": 512,
42
42
  "num_kv_heads": 1,
43
43
  "layer_id": 0,
44
+ "tp_q_head_num": 128,
45
+ "tp_k_head_num": 128,
46
+ "prefill_head_dim": 192,
47
+ "prefill_v_head_dim": 128,
44
48
  }
45
49
 
46
50
  ROPE_BASE = 10000
@@ -92,7 +96,7 @@ TEST_CASES = {
92
96
  "description": "Medium-scale batch",
93
97
  },
94
98
  ],
95
- "decode_output_match": [
99
+ "output_match": [
96
100
  {
97
101
  "name": "single_fp16",
98
102
  "batch_size": 1,
@@ -322,7 +326,7 @@ class TestTRTLLMMLA(CustomTestCase):
322
326
  config.update(test_case)
323
327
  return config
324
328
 
325
- def _create_model_components(self, config):
329
+ def _create_model_components(self, config, is_prefill=False):
326
330
  """Create model runners, backends, and layer for testing."""
327
331
  # Create model runners
328
332
  model_runner_trtllm = MockModelRunner(config)
@@ -332,14 +336,23 @@ class TestTRTLLMMLA(CustomTestCase):
332
336
  trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
333
337
  reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
334
338
 
339
+ head_dim = (
340
+ config["kv_lora_rank"] + config["qk_rope_head_dim"]
341
+ if not is_prefill
342
+ else config["prefill_head_dim"]
343
+ )
344
+ v_head_dim = (
345
+ config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
346
+ )
347
+
335
348
  # Create RadixAttention layer
336
349
  layer = RadixAttention(
337
350
  num_heads=config["num_attention_heads"],
338
- head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
351
+ head_dim=head_dim,
339
352
  scaling=model_runner_trtllm.model_config.scaling,
340
353
  num_kv_heads=config["num_kv_heads"],
341
354
  layer_id=config["layer_id"],
342
- v_head_dim=config["v_head_dim"],
355
+ v_head_dim=v_head_dim,
343
356
  prefix="attn_mqa",
344
357
  )
345
358
 
@@ -524,7 +537,7 @@ class TestTRTLLMMLA(CustomTestCase):
524
537
  """Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
525
538
  print(f"\nRunning decode output matching tests...")
526
539
 
527
- for test_case in TEST_CASES["decode_output_match"]:
540
+ for test_case in TEST_CASES["output_match"]:
528
541
  with self.subTest(test_case=test_case["name"]):
529
542
  print(f" Testing {test_case['name']}: {test_case['description']}")
530
543
 
@@ -1099,6 +1112,157 @@ class TestTRTLLMMLA(CustomTestCase):
1099
1112
  self.assertIsNotNone(metadata_3.block_kv_indices)
1100
1113
  self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
1101
1114
 
1115
+ def test_prefill_output_match_self_attention(self):
1116
+ """Test prefill (forward) behavior of TRTLLM MLA backend vs reference."""
1117
+ print(f"\nRunning prefill output tests...")
1118
+
1119
+ for test_case in TEST_CASES["output_match"][:2]: # Just a subset for speed
1120
+ with self.subTest(test_case=test_case["name"]):
1121
+ print(
1122
+ f"Prefill Testing {test_case['name']}: {test_case['description']}"
1123
+ )
1124
+
1125
+ config = self._merge_config(test_case)
1126
+ batch_size = config["batch_size"]
1127
+ max_seq_len = config["max_seq_len"]
1128
+
1129
+ # Create components
1130
+ (
1131
+ model_runner_trtllm,
1132
+ model_runner_reference,
1133
+ trtllm_backend,
1134
+ reference_backend,
1135
+ layer,
1136
+ ) = self._create_model_components(config, is_prefill=True)
1137
+
1138
+ # Prefill uses full sequences
1139
+ seq_lens = torch.full(
1140
+ (batch_size,), max_seq_len, device=config["device"]
1141
+ )
1142
+
1143
+ def _create_forward_batch_prefill(
1144
+ batch_size,
1145
+ seq_lens,
1146
+ extend_prefix_lens,
1147
+ backend,
1148
+ model_runner,
1149
+ config,
1150
+ ):
1151
+ """Create a forward batch for the given backend."""
1152
+
1153
+ fb = ForwardBatch(
1154
+ batch_size=batch_size,
1155
+ input_ids=torch.randint(
1156
+ 0, 100, (batch_size, 1), device=config["device"]
1157
+ ),
1158
+ out_cache_loc=torch.arange(batch_size, device=config["device"]),
1159
+ seq_lens_sum=int(seq_lens.sum().item()),
1160
+ extend_prefix_lens=extend_prefix_lens,
1161
+ extend_prefix_lens_cpu=extend_prefix_lens.cpu().int().tolist(),
1162
+ extend_seq_lens_cpu=(seq_lens - extend_prefix_lens)
1163
+ .cpu()
1164
+ .int()
1165
+ .tolist(),
1166
+ forward_mode=ForwardMode.EXTEND,
1167
+ req_pool_indices=torch.arange(
1168
+ batch_size, device=config["device"]
1169
+ ),
1170
+ seq_lens=seq_lens,
1171
+ seq_lens_cpu=seq_lens.cpu(),
1172
+ attn_attend_prefix_cache=False,
1173
+ mha_return_lse=False,
1174
+ attn_backend=backend,
1175
+ )
1176
+ fb.req_to_token_pool = model_runner.req_to_token_pool
1177
+ fb.token_to_kv_pool = model_runner.token_to_kv_pool
1178
+
1179
+ # Add position information for RoPE
1180
+ fb.positions = torch.arange(batch_size, device=config["device"])
1181
+
1182
+ return fb
1183
+
1184
+ # Create forward batches
1185
+ fb_trtllm = _create_forward_batch_prefill(
1186
+ batch_size,
1187
+ seq_lens.clone(),
1188
+ torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
1189
+ trtllm_backend,
1190
+ model_runner_trtllm,
1191
+ config,
1192
+ )
1193
+ fb_reference = _create_forward_batch_prefill(
1194
+ batch_size,
1195
+ seq_lens.clone(),
1196
+ torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
1197
+ reference_backend,
1198
+ model_runner_reference,
1199
+ config,
1200
+ )
1201
+
1202
+ # Initialize metadata for both backends
1203
+ trtllm_backend.init_forward_metadata(fb_trtllm)
1204
+ reference_backend.init_forward_metadata(fb_reference)
1205
+
1206
+ # Create Q, K, V tensors for prefill
1207
+ torch.manual_seed(config["seed_qkv"])
1208
+
1209
+ def _create_qkv_tensors_prefill(
1210
+ batch_size, seq_len, config, dtype_override=None
1211
+ ):
1212
+ """Create Q, K, V tensors for prefill, using config for head_num and head_dim."""
1213
+ device = config["device"]
1214
+ dtype = dtype_override or config["dtype"]
1215
+
1216
+ total_tokens = batch_size * seq_len
1217
+
1218
+ tp_q_head_num = config["tp_q_head_num"]
1219
+ tp_k_head_num = config["tp_k_head_num"]
1220
+ head_dim = config["prefill_head_dim"]
1221
+ v_head_dim = config["prefill_v_head_dim"]
1222
+
1223
+ q = torch.randn(
1224
+ (total_tokens, tp_q_head_num * head_dim),
1225
+ dtype=dtype,
1226
+ device=device,
1227
+ )
1228
+ k = torch.randn(
1229
+ (total_tokens, tp_k_head_num * head_dim),
1230
+ dtype=dtype,
1231
+ device=device,
1232
+ )
1233
+ v = torch.randn(
1234
+ (total_tokens, tp_k_head_num * v_head_dim),
1235
+ dtype=dtype,
1236
+ device=device,
1237
+ )
1238
+
1239
+ # Reshape as requested
1240
+ q = q.view(-1, tp_q_head_num, head_dim)
1241
+ k = k.view(-1, tp_k_head_num, head_dim)
1242
+ v = v.view(-1, tp_k_head_num, v_head_dim)
1243
+
1244
+ return q, k, v
1245
+
1246
+ q, k, v = _create_qkv_tensors_prefill(batch_size, max_seq_len, config)
1247
+ # Run prefill on both backends
1248
+ out_trtllm = trtllm_backend.forward_extend(
1249
+ q, k, v, layer, fb_trtllm, False
1250
+ ).view(-1, layer.tp_q_head_num * layer.v_head_dim)
1251
+ out_reference = reference_backend.forward_extend(
1252
+ q, k, v, layer, fb_reference, False
1253
+ )
1254
+
1255
+ tolerance = config.get("tolerance", 1e-2)
1256
+ comparison_passed = compare_outputs(
1257
+ out_trtllm, out_reference, tolerance=tolerance
1258
+ )
1259
+ self.assertTrue(
1260
+ comparison_passed,
1261
+ f"TRTLLM and Reference prefill outputs differ beyond tolerance. "
1262
+ f"Config: {test_case['name']}, "
1263
+ f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
1264
+ )
1265
+
1102
1266
 
1103
1267
  if __name__ == "__main__":
1104
1268
  unittest.main()
sglang/test/runners.py CHANGED
@@ -505,6 +505,7 @@ class SRTRunner:
505
505
  mem_fraction_static: float = 0.65,
506
506
  trust_remote_code: bool = False,
507
507
  speculative_draft_model_path: Optional[str] = None,
508
+ speculative_draft_model_revision: Optional[str] = None,
508
509
  speculative_algorithm: Optional[str] = None,
509
510
  speculative_num_steps: Optional[int] = None,
510
511
  speculative_eagle_topk: Optional[int] = None,
@@ -526,6 +527,9 @@ class SRTRunner:
526
527
  spec_kwargs = {}
527
528
  if speculative_draft_model_path:
528
529
  spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
530
+ spec_kwargs["speculative_draft_model_revision"] = (
531
+ speculative_draft_model_revision
532
+ )
529
533
  spec_kwargs["speculative_algorithm"] = speculative_algorithm
530
534
  spec_kwargs["speculative_num_steps"] = speculative_num_steps
531
535
  spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
sglang/test/test_cutlass_moe.py CHANGED
@@ -9,6 +9,7 @@ from transformers import AutoConfig
9
9
  from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
10
10
  from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
11
11
  from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
12
+ from sglang.srt.layers.moe.topk import StandardTopKOutput
12
13
 
13
14
 
14
15
  # Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
@@ -21,7 +22,7 @@ def calc_diff(x, y):
21
22
 
22
23
  def get_model_config(tp_size: int):
23
24
  config = AutoConfig.from_pretrained(
24
- "deepseek-ai/deepseek-R1", trust_remote_code=True
25
+ "deepseek-ai/Deepseek-R1", trust_remote_code=True
25
26
  )
26
27
  E = config.n_routed_experts
27
28
  topk = config.num_experts_per_tok
@@ -152,14 +153,31 @@ def run_test(tp_size, batch_size, model_config, check=False):
152
153
  problem_sizes2,
153
154
  )
154
155
 
156
+ topk_output = StandardTopKOutput(
157
+ topk_weights=topk_weights,
158
+ topk_ids=topk_ids,
159
+ router_logits=torch.randn(
160
+ (batch_size, topk), device=topk_weights.device, dtype=dtype
161
+ ),
162
+ )
163
+
164
+ moe_runner_config = MoeRunnerConfig(
165
+ num_experts=E,
166
+ top_k=topk,
167
+ hidden_size=H,
168
+ intermediate_size_per_partition=I,
169
+ params_dtype=dtype,
170
+ activation="silu",
171
+ inplace=False,
172
+ )
173
+
155
174
  # Note: Triton expects non-transposed weights
156
- moe_config = MoeRunnerConfig(inplace=False)
157
175
  triton_lambda = lambda: fused_experts(
158
176
  x,
159
177
  w1,
160
178
  w2,
161
- (topk_weights, topk_ids, "dummy"),
162
- moe_config,
179
+ topk_output,
180
+ moe_runner_config,
163
181
  use_fp8_w8a8=True,
164
182
  w1_scale=w1_scale,
165
183
  w2_scale=w2_scale,
@@ -224,8 +242,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
224
242
  x,
225
243
  w1, # Original shape
226
244
  w2, # Original shape
227
- (topk_weights, topk_ids, "dummy"),
228
- moe_config,
245
+ topk_output,
246
+ moe_runner_config,
229
247
  use_fp8_w8a8=True,
230
248
  w1_scale=w1_scale,
231
249
  w2_scale=w2_scale,
sglang/test/test_disaggregation_utils.py ADDED
@@ -0,0 +1,66 @@
1
+ import time
2
+
3
+ import requests
4
+
5
+ from sglang.srt.utils import kill_process_tree
6
+ from sglang.test.test_utils import (
7
+ DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
8
+ CustomTestCase,
9
+ popen_with_error_check,
10
+ )
11
+
12
+
13
+ class TestDisaggregationBase(CustomTestCase):
14
+ @classmethod
15
+ def setUpClass(cls):
16
+ cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None
17
+ pass
18
+
19
+ @classmethod
20
+ def launch_lb(cls):
21
+ lb_command = [
22
+ "python3",
23
+ "-m",
24
+ "sglang_router.launch_router",
25
+ "--pd-disaggregation",
26
+ "--mini-lb", # FIXME: remove this
27
+ "--prefill",
28
+ cls.prefill_url,
29
+ "--decode",
30
+ cls.decode_url,
31
+ "--host",
32
+ cls.base_host,
33
+ "--port",
34
+ cls.lb_port,
35
+ ]
36
+ print("Starting load balancer:", " ".join(lb_command))
37
+ cls.process_lb = popen_with_error_check(lb_command)
38
+ cls.wait_server_ready(cls.lb_url + "/health")
39
+
40
+ @classmethod
41
+ def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
42
+ start_time = time.perf_counter()
43
+ while True:
44
+ try:
45
+ response = requests.get(url)
46
+ if response.status_code == 200:
47
+ print(f"Server {url} is ready")
48
+ return
49
+ except Exception:
50
+ pass
51
+
52
+ if time.perf_counter() - start_time > timeout:
53
+ raise RuntimeError(f"Server {url} failed to start in {timeout}s")
54
+ time.sleep(1)
55
+
56
+ @classmethod
57
+ def tearDownClass(cls):
58
+ for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
59
+ if process:
60
+ try:
61
+ kill_process_tree(process.pid)
62
+ except Exception as e:
63
+ print(f"Error killing process {process.pid}: {e}")
64
+
65
+ # wait for 5 seconds
66
+ time.sleep(5)