sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/test/attention/test_trtllm_mla_backend.py CHANGED
@@ -43,6 +43,37 @@ DEFAULT_CONFIG = {
     "layer_id": 0,
 }
 
+ROPE_BASE = 10000
+ROPE_SCALING_CONFIG = {
+    "beta_fast": 32,
+    "beta_slow": 1,
+    "factor": 40,
+    "mscale": 1.0,
+    "mscale_all_dim": 1.0,
+    "original_max_position_embeddings": 4096,
+    "type": "yarn",
+    "rope_type": "deepseek_yarn",
+}
+
+
+def build_rotary_emb(config, device=None):
+    from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+
+    dev = device or config["device"]
+    rope_scaling = config.get("rope_scaling", ROPE_SCALING_CONFIG)
+    rotary = get_rope_wrapper(
+        head_size=config["qk_rope_head_dim"],
+        rotary_dim=config["qk_rope_head_dim"],
+        max_position=config["context_len"],
+        base=ROPE_BASE,
+        rope_scaling=rope_scaling,
+        is_neox_style=False,
+        device=dev,
+    )
+    rotary.cos_sin_cache = rotary.cos_sin_cache.to(dev)
+    return rotary
+
+
 # Centralized test cases for different test scenarios
 TEST_CASES = {
     "basic_functionality": [
@@ -63,18 +94,36 @@ TEST_CASES = {
     ],
     "decode_output_match": [
         {
-            "name": "single",
+            "name": "single_fp16",
             "batch_size": 1,
             "max_seq_len": 64,
             "page_size": 32,
-            "description": "Single vs reference",
+            "description": "Single FP16 vs reference",
         },
         {
-            "name": "batch",
+            "name": "single_fp8",
+            "batch_size": 1,
+            "max_seq_len": 64,
+            "page_size": 64,
+            "tolerance": 1e-1,
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "description": "Single FP8 vs reference",
+        },
+        {
+            "name": "batch_fp16",
             "batch_size": 32,
             "max_seq_len": 64,
             "page_size": 32,
-            "description": "Batch vs reference",
+            "description": "Batch FP16 vs reference",
+        },
+        {
+            "name": "batch_fp8",
+            "batch_size": 32,
+            "max_seq_len": 64,
+            "page_size": 64,
+            "tolerance": 1e-1,
+            "kv_cache_dtype": torch.float8_e4m3fn,
+            "description": "Batch FP8 vs reference",
         },
     ],
     "page_size_consistency": [
@@ -293,26 +342,52 @@ class TestTRTLLMMLA(CustomTestCase):
             layer,
         )
 
-    def _create_qkv_tensors(self, batch_size, config):
-        """Create Q, K, V tensors for testing."""
-        head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
+    def _create_qkv_tensors(self, batch_size, config, dtype_override=None):
+        """Create Q, K, V random tensors for given batch size with separate MLA components.
+
+        Args:
+            batch_size: Batch size.
+            config: Configuration dict with model dims and device.
+            dtype_override: Optional torch dtype to override config["dtype"].
+
+        Returns:
+            Tuple of (q_nope, q_rope, k_nope, k_rope, v, cos_sin_cache)
+        """
         device = config["device"]
-        dtype = config["dtype"]
+        target_dtype = dtype_override or config["dtype"]
 
-        q = torch.randn(
-            (batch_size, config["num_attention_heads"], head_dim),
-            dtype=dtype,
+        # Create separate nope and rope components for Q
+        q_nope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
+            dtype=config["dtype"],
             device=device,
         )
-        k = torch.randn(
-            (batch_size, config["num_kv_heads"], head_dim), dtype=dtype, device=device
+        q_rope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["qk_rope_head_dim"]),
+            dtype=config["dtype"],
+            device=device,
+        )
+
+        # Create separate nope and rope components for K
+        k_nope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
+            dtype=config["dtype"],
+            device=device,
+        )
+        k_rope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
+            dtype=config["dtype"],
+            device=device,
         )
+
+        # V tensor (unchanged)
         v = torch.randn(
             (batch_size, config["num_kv_heads"], config["v_head_dim"]),
-            dtype=dtype,
+            dtype=config["dtype"],
             device=device,
         )
-        return q, k, v
+
+        return q_nope, q_rope, k_nope, k_rope, v
 
     def _create_forward_batch(
         self, batch_size, seq_lens, backend, model_runner, config
@@ -331,6 +406,10 @@ class TestTRTLLMMLA(CustomTestCase):
         )
         fb.req_to_token_pool = model_runner.req_to_token_pool
         fb.token_to_kv_pool = model_runner.token_to_kv_pool
+
+        # Add position information for RoPE
+        fb.positions = torch.arange(batch_size, device=config["device"])
+
         return fb
 
     def _populate_kv_cache(self, batch_size, seq_lens, model_runners, layer, config):
@@ -344,7 +423,7 @@ class TestTRTLLMMLA(CustomTestCase):
             for token_idx in range(seq_len - 1):
                 # Create random K components for MLA
                 cache_k_nope = torch.randn(
-                    (1, config["qk_nope_head_dim"]),
+                    (1, config["kv_lora_rank"]),
                     dtype=config["dtype"],
                     device=config["device"],
                 )
@@ -411,12 +490,16 @@ class TestTRTLLMMLA(CustomTestCase):
             batch_size, seq_lens, [model_runner_trtllm], layer, config
         )
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-        q, k, v = self._create_qkv_tensors(batch_size, config)
+        q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
+            batch_size, config
+        )
 
-        # Run forward decode
-        output = trtllm_backend.forward_decode(q, k, v, layer, fb)
+        # Run forward decode with separate MLA components
+        output = trtllm_backend.forward_decode(
+            q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         # Basic checks
         expected_shape = (
@@ -439,6 +522,7 @@ class TestTRTLLMMLA(CustomTestCase):
         config = self._merge_config(test_case)
         batch_size = config["batch_size"]
         max_seq_len = config["max_seq_len"]
+        use_fp8 = config["kv_cache_dtype"] == torch.float8_e4m3fn
 
         # Create components
         (
@@ -487,19 +571,66 @@ class TestTRTLLMMLA(CustomTestCase):
 
         # Create Q, K, V tensors for current decode step
         torch.manual_seed(config["seed_qkv"])
-        q, k, v = self._create_qkv_tensors(batch_size, config)
+
+        q_nope_ref, q_rope_ref, k_nope_ref, k_rope_ref, v_ref = (
+            self._create_qkv_tensors(batch_size, config)
+        )
+        q_nope_trt, q_rope_trt, k_nope_trt, k_rope_trt, v_trt = (
+            q_nope_ref.clone(),
+            q_rope_ref.clone(),
+            k_nope_ref.clone(),
+            k_rope_ref.clone(),
+            v_ref.clone(),
+        )
+        tolerance = config["tolerance"]
+
+        extra_args = {}
+        if use_fp8:
+            # TRT kernel applies RoPE + FP8 quantization internally
+            # pre-apply RoPE on the reference (FlashInfer) path here so
+            # both paths share the same rope params/cache while keeping
+            # the TRT path unrotated.
+            rotary_emb = build_rotary_emb(config)
+            q_rope_ref, k_rope_ref = rotary_emb(
+                fb_reference.positions, q_rope_ref, k_rope_ref
+            )
+            extra_args = {
+                "cos_sin_cache": rotary_emb.cos_sin_cache,
+                "is_neox": rotary_emb.is_neox_style,
+            }
+
+            dtype = q_rope_ref.dtype
+            q_rope_ref = q_rope_ref.to(torch.float8_e4m3fn).to(dtype)
+            q_nope_ref = q_nope_ref.to(torch.float8_e4m3fn).to(dtype)
+            k_rope_ref = k_rope_ref.to(torch.float8_e4m3fn).to(dtype)
+            k_nope_ref = k_nope_ref.to(torch.float8_e4m3fn).to(dtype)
 
         # Run forward decode on both backends
         out_trtllm = trtllm_backend.forward_decode(
-            q.clone(), k.clone(), v.clone(), layer, fb_trtllm
+            q_nope_trt,
+            k_nope_trt,
+            None,
+            layer,
+            fb_trtllm,
+            q_rope=q_rope_trt,
+            k_rope=k_rope_trt,
+            **extra_args,
         )
+
+        # Reference backend should also take separate components, not concatenated
         out_reference = reference_backend.forward_decode(
-            q.clone(), k.clone(), v.clone(), layer, fb_reference
+            q_nope_ref,
+            k_nope_ref,
+            v_ref,
+            layer,
+            fb_reference,
+            q_rope=q_rope_ref,
+            k_rope=k_rope_ref,
         )
 
         # Compare outputs
         comparison_passed = compare_outputs(
-            out_trtllm, out_reference, tolerance=config["tolerance"]
+            out_trtllm, out_reference, tolerance=tolerance
         )
 
         self.assertTrue(
@@ -544,12 +675,16 @@ class TestTRTLLMMLA(CustomTestCase):
             batch_size, seq_lens, [model_runner], layer, config
         )
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-        q, k, v = self._create_qkv_tensors(batch_size, config)
+        q_nope, q_rope, k_nope, k_rope, v = self._create_qkv_tensors(
+            batch_size, config
+        )
 
-        # Run forward decode
-        output = backend.forward_decode(q, k, v, layer, fb)
+        # Run forward decode with separate MLA components
+        output = backend.forward_decode(
+            q_nope, k_nope, None, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         expected_shape = (
             batch_size,
@@ -591,23 +726,38 @@ class TestTRTLLMMLA(CustomTestCase):
         )
         backend.init_forward_metadata(fb)
 
-        # Create Q, K, V tensors
+        # Create Q, K, V tensors with separate MLA components
         torch.manual_seed(config["seed_qkv"])
-        head_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"]
-        q = torch.randn(
-            (batch_size, config["num_attention_heads"], head_dim),
+        q_nope = torch.randn(
+            (batch_size, config["num_attention_heads"], config["kv_lora_rank"]),
             dtype=config["dtype"],
             device=config["device"],
         )
-        k = torch.randn(
-            (batch_size, config["num_kv_heads"], head_dim),
+        k_nope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["kv_lora_rank"]),
             dtype=config["dtype"],
             device=config["device"],
         )
-        v = None
+        q_rope = torch.randn(
+            (
+                batch_size,
+                config["num_attention_heads"],
+                config["qk_rope_head_dim"],
+            ),
+            dtype=config["dtype"],
+            device=config["device"],
+        )
+        k_rope = torch.randn(
+            (batch_size, config["num_kv_heads"], config["qk_rope_head_dim"]),
+            dtype=config["dtype"],
+            device=config["device"],
+        )
+        v = None  # Test with None v
 
         # Run forward decode
-        output = backend.forward_decode(q, k, v, layer, fb)
+        output = backend.forward_decode(
+            q_nope, k_nope, v, layer, fb, q_rope=q_rope, k_rope=k_rope
+        )
 
         # Shape and sanity checks
         expected_shape = (
sglang/test/doc_patch.py ADDED
@@ -0,0 +1,59 @@
+"""
+Do some monkey patch to make the documentation compilation faster and more reliable.
+
+- Avoid port conflicts
+- Reduce the server launch time
+"""
+
+import weakref
+
+import nest_asyncio
+
+nest_asyncio.apply()
+
+import sglang.srt.server_args as server_args_mod
+from sglang.utils import execute_shell_command, reserve_port
+
+DEFAULT_MAX_RUNNING_REQUESTS = 128
+DEFAULT_MAX_TOTAL_TOKENS = 20480  # To allow multiple servers on the same machine
+
+_original_post_init = server_args_mod.ServerArgs.__post_init__
+
+
+def patched_post_init(self):
+    _original_post_init(self)
+    if self.max_running_requests is None:
+        self.max_running_requests = DEFAULT_MAX_RUNNING_REQUESTS
+    if self.max_total_tokens is None:
+        self.max_total_tokens = DEFAULT_MAX_TOTAL_TOKENS
+    self.cuda_graph_max_bs = 4
+
+
+server_args_mod.ServerArgs.__post_init__ = patched_post_init
+
+process_socket_map = weakref.WeakKeyDictionary()
+
+
+def launch_server_cmd(command: str, host: str = "0.0.0.0", port: int = None):
+    """
+    Launch the server using the given command.
+    If no port is specified, a free port is reserved.
+    """
+    if port is None:
+        port, lock_socket = reserve_port(host)
+    else:
+        lock_socket = None
+
+    extra_flags = (
+        f"--max-running-requests {DEFAULT_MAX_RUNNING_REQUESTS} "
+        f"--max-total-tokens {DEFAULT_MAX_TOTAL_TOKENS} "
+        f"--cuda-graph-max-bs 4"
+    )
+
+    full_command = f"{command} --port {port} {extra_flags}"
+    process = execute_shell_command(full_command)
+
+    if lock_socket is not None:
+        process_socket_map[process] = lock_socket
+
+    return process, port
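The new doc_patch module caps server resources and reserves ports so several documentation-build servers can coexist on one machine. A hedged usage sketch (not part of the diff; wait_for_server and terminate_process are assumed to be the existing helpers in sglang.utils, and the model path is only an example):

    from sglang.test.doc_patch import launch_server_cmd
    from sglang.utils import terminate_process, wait_for_server

    # The helper appends the doc-friendly flags and reserves a free port.
    process, port = launch_server_cmd(
        "python -m sglang.launch_server --model-path Qwen/Qwen2.5-1.5B-Instruct"
    )
    wait_for_server(f"http://localhost:{port}")
    terminate_process(process)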
sglang/test/few_shot_gsm8k.py CHANGED
@@ -12,7 +12,7 @@ import time
 
 import numpy as np
 
-from sglang.api import set_default_backend
+from sglang.lang.api import set_default_backend
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
 
sglang/test/few_shot_gsm8k_engine.py CHANGED
@@ -8,7 +8,7 @@ import time
 import numpy as np
 
 import sglang as sgl
-from sglang.api import set_default_backend
+from sglang.lang.api import set_default_backend
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
 
sglang/test/run_eval.py CHANGED
@@ -65,9 +65,10 @@ def run_eval(args):
 
     sampler = ChatCompletionSampler(
         model=args.model,
-        max_tokens=2048,
+        max_tokens=getattr(args, "max_tokens", 2048),
         base_url=base_url,
         temperature=getattr(args, "temperature", 0.0),
+        reasoning_effort=getattr(args, "reasoning_effort", None),
     )
 
     # Run eval
@@ -120,7 +121,9 @@ if __name__ == "__main__":
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
    parser.add_argument("--num-threads", type=int, default=512)
+    parser.add_argument("--max-tokens", type=int, default=2048)
     parser.add_argument("--temperature", type=float, default=0.0)
+    parser.add_argument("--reasoning-effort", type=str)
     args = parser.parse_args()
 
     run_eval(args)
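With the two new arguments wired into both the sampler and the argument parser, a run can now raise the completion budget and pass a reasoning effort, along the lines of the following (hedged sketch; the endpoint and model flags come from the unchanged part of the script and are omitted here):

    python -m sglang.test.run_eval --eval-name gpqa --max-tokens 4096 --reasoning-effort high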
sglang/test/runners.py CHANGED
@@ -568,8 +568,8 @@ class SRTRunner:
         else:
             self.tokenizer = None
 
-    def load_lora_adapter(self, lora_name: str, lora_path: str):
-        return self.engine.load_lora_adapter(lora_name, lora_path)
+    def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
+        return self.engine.load_lora_adapter(lora_name, lora_path, pinned)
 
     def unload_lora_adapter(self, lora_name: str):
         return self.engine.unload_lora_adapter(lora_name)
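The extra argument is forwarded verbatim to the engine. A hedged sketch of the call on an already-constructed SRTRunner (names and path are placeholders):

    # 'pinned' is the new third argument forwarded to engine.load_lora_adapter
    runner.load_lora_adapter("my_adapter", "/path/to/my_adapter", pinned=True)
    runner.unload_lora_adapter("my_adapter")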
sglang/test/simple_eval_common.py CHANGED
@@ -91,6 +91,7 @@ class ChatCompletionSampler(SamplerBase):
         model: Optional[str] = None,
         system_message: Optional[str] = None,
         temperature: float = 0.0,
+        reasoning_effort: Optional[str] = None,
         max_tokens: int = 2048,
     ):
         self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
@@ -102,7 +103,11 @@ class ChatCompletionSampler(SamplerBase):
         self.system_message = system_message
         self.temperature = temperature
         self.max_tokens = max_tokens
+        self.reasoning_effort = reasoning_effort
         self.image_format = "url"
+        print(
+            f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=}"
+        )
 
     def _handle_image(
         self,
@@ -138,6 +143,7 @@ class ChatCompletionSampler(SamplerBase):
                     messages=message_list,
                     temperature=self.temperature,
                     max_tokens=self.max_tokens,
+                    reasoning_effort=self.reasoning_effort,
                 )
                 return response.choices[0].message.content
             # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
sglang/test/simple_eval_gpqa.py CHANGED
@@ -71,6 +71,8 @@ class GPQAEval(Eval):
                 )
             ]
             response_text = sampler(prompt_messages)
+            if response_text is None:
+                response_text = ""
             match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
             extracted_answer = match.group(1) if match else None
             score = 1.0 if extracted_answer == correct_answer else 0.0
sglang/test/test_fp4_moe.py CHANGED
@@ -1,6 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
+from typing import Callable
+
 import pytest
 import torch
+from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
 from sgl_kernel import scaled_fp4_quant
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -111,15 +114,16 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
     ).sum(dim=1)
 
 
-@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
-@pytest.mark.parametrize("e", [40, 64, 256])
-@pytest.mark.parametrize("topk", [1, 6, 8])
-@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
-@torch.inference_mode()
-def test_cutlass_fp4_moe_no_graph(
-    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
+def check_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+    moe_impl: Callable,
+    flip_w13: bool,
 ):
-
     torch.manual_seed(7)
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -167,38 +171,18 @@ def test_cutlass_fp4_moe_no_graph(
 
     a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
     a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
-    # strides for the cutlass moe_fp4 kernel
-    ab_strides_13 = torch.full(
-        (e,), w1_q.shape[2] * 2, dtype=torch.int64, device=w1_q.device
-    )
-    c_strides_13 = torch.full(
-        (e,), w1_q.shape[1], dtype=torch.int64, device=w1_q.device
-    )
-    ab_strides_2 = torch.full(
-        (e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device
-    )
-    c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device)
-    params = CutlassMoEParams(
-        CutlassMoEType.BlockscaledFP4,
-        device=a.device,
-        num_experts=e,
-        intermediate_size_per_partition=n,  # n
-        hidden_size=k,
-    )  # k
-    cutlass_output = cutlass_moe_fp4(
+    test_output = moe_impl(
         a=a,
-        a1_gscale=a1_gs,
-        w1_fp4=w1_q,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        w1_q=w1_q,
+        w2_q=w2_q,
+        a1_gs=a1_gs,
         w1_blockscale=w1_blockscale,
         w1_alphas=(1 / w1_gs),
-        a2_gscale=a2_gs,
-        w2_fp4=w2_q,
+        a2_gs=a2_gs,
         w2_blockscale=w2_blockscale,
         w2_alphas=(1 / w2_gs),
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        params=params,
-        apply_router_weight_on_input=False,
     )
 
     # Reference check:
@@ -237,10 +221,108 @@ def test_cutlass_fp4_moe_no_graph(
         block_size=quant_blocksize,
     )
 
+    if flip_w13:
+        dim = -2
+        size = w1_d.size(dim)
+        assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
+        half = size // 2
+        # Reorder weight
+        w1, w3 = w1_d.split(half, dim=dim)
+        w1_d = torch.cat([w3, w1], dim=dim).contiguous()
+
     torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
 
-    torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)
+    torch.testing.assert_close(torch_output, test_output, atol=1e-1, rtol=1e-1)
+
+
+@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
+@pytest.mark.parametrize("e", [40, 64, 256])
+@pytest.mark.parametrize("topk", [1, 6, 8])
+@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@torch.inference_mode()
+def test_cutlass_fp4_moe_no_graph(
+    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
+):
+    def cutlass_moe_impl(
+        a,
+        topk_weights,
+        topk_ids,
+        w1_q,
+        w2_q,
+        a1_gs,
+        w1_blockscale,
+        w1_alphas,
+        a2_gs,
+        w2_blockscale,
+        w2_alphas,
+    ):
+        params = CutlassMoEParams(
+            CutlassMoEType.BlockscaledFP4,
+            device=a.device,
+            num_experts=e,
+            intermediate_size_per_partition=n,  # n
+            hidden_size=k,
+        )  # k
+        return cutlass_moe_fp4(
+            a=a,
+            a1_gscale=a1_gs,
+            w1_fp4=w1_q,
+            w1_blockscale=w1_blockscale,
+            w1_alphas=w1_alphas,
+            a2_gscale=a2_gs,
+            w2_fp4=w2_q,
+            w2_blockscale=w2_blockscale,
+            w2_alphas=w2_alphas,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            params=params,
+            apply_router_weight_on_input=False,
+        )
+
+    check_moe(m, n, k, e, topk, dtype, cutlass_moe_impl, flip_w13=False)
+
+
+@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
+@pytest.mark.parametrize("e", [40, 64, 256])
+@pytest.mark.parametrize("topk", [1, 6, 8])
+@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@torch.inference_mode()
+def test_flashinfer_fp4_moe_no_graph(
+    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
+):
+    def flashinfer_moe_impl(
+        a,
+        topk_weights,
+        topk_ids,
+        w1_q,
+        w2_q,
+        a1_gs,
+        w1_blockscale,
+        w1_alphas,
+        a2_gs,
+        w2_blockscale,
+        w2_alphas,
+    ):
+        return flashinfer_cutlass_fused_moe(
+            a,
+            topk_ids.to(torch.int),
+            topk_weights,
+            w1_q.view(torch.long),
+            w2_q.view(torch.long),
+            a.dtype,
+            quant_scales=[
+                a1_gs,
+                w1_blockscale.view(torch.int32),
+                w1_alphas,
+                a2_gs,
+                w2_blockscale.view(torch.int32),
+                w2_alphas,
+            ],
+        )[0]
+
+    check_moe(m, n, k, e, topk, dtype, flashinfer_moe_impl, flip_w13=True)
 
 
 if __name__ == "__main__":
     test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
+    test_flashinfer_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
sglang/test/test_utils.py CHANGED
@@ -83,7 +83,7 @@ DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST = "Qwen/Qwen3-30B-A3B"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
-DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
+DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,zai-org/GLM-4.5-Air-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4,hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
 DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST = "Qwen/Qwen2.5-VL-3B-Instruct"