sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py CHANGED
@@ -84,7 +84,15 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.hidden_states = torch.zeros(
             (
                 self.max_num_token,
-                self.model_runner.model_config.hidden_size * 3,
+                (
+                    self.model_runner.model_config.hf_config.target_hidden_size
+                    * 3
+                    if hasattr(
+                        self.model_runner.model_config.hf_config,
+                        "target_hidden_size",
+                    )
+                    else self.model_runner.model_config.hidden_size * 3
+                ),
             ),
             dtype=self.model_runner.dtype,
         )
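Note (not part of the diff): the new conditional uses target_hidden_size from the HF config when it is defined and otherwise keeps the previous hidden_size * 3 width. A minimal sketch of the equivalent lookup, with model_config standing in for the attributes referenced above:

    # Illustrative only: the same fallback expressed with getattr.
    hf_config = model_config.hf_config
    hidden_size = getattr(hf_config, "target_hidden_size", model_config.hidden_size)
    buffer_width = hidden_size * 3  # the factor of 3 is unchanged from the original code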
sglang/srt/two_batch_overlap.py CHANGED
@@ -500,6 +500,7 @@ class TboForwardBatchPreparer:
             "capture_hidden_mode",
             "padded_static_len",
             "mrope_positions",  # only used by qwen2-vl, thus not care
+            "split_index",  # for split prefill
         ]:
             output_dict[key] = getattr(batch, key)
         if not batch.forward_mode.is_target_verify():
sglang/srt/utils.py CHANGED
@@ -691,12 +691,17 @@ def decode_video_base64(video_base64):
     )  # Return an empty array and size tuple if no frames were found


-def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray:
+def load_audio(
+    audio_file: str, sr: Optional[int] = None, mono: bool = True
+) -> np.ndarray:
     # Use soundfile here, since librosa use it under the hood,
     # and librosa will not support audio loading in the future
     import soundfile as sf
     from scipy.signal import resample

+    if sr is None:
+        sr = 16000
+
     # Load audio data
     if isinstance(audio_file, bytes):
         audio, original_sr = sf.read(BytesIO(audio_file))
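Note (not part of the diff): sr is now optional; omitting it falls back to 16000 inside the function, so existing callers keep the previous behavior. A minimal usage sketch (the file path is a placeholder):

    from sglang.srt.utils import load_audio

    audio = load_audio("sample.wav")                      # sr=None -> falls back to 16000
    audio_explicit = load_audio("sample.wav", sr=22050)   # explicit sample rate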
@@ -1417,6 +1422,13 @@ def get_nvgpu_memory_capacity():
         ]

         if not memory_values:
+            # Fallback to torch.cuda.mem_get_info() when failed to get memory capacity from nvidia-smi,
+            # typically in NVIDIA MIG mode.
+            if torch.cuda.is_available():
+                logger.warning(
+                    "Failed to get GPU memory capacity from nvidia-smi, falling back to torch.cuda.mem_get_info()."
+                )
+                return torch.cuda.mem_get_info()[1] // 1024 // 1024  # unit: MB
             raise ValueError("No GPU memory values found.")

         # Return the minimum memory value
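Note (not part of the diff): torch.cuda.mem_get_info() returns a (free_bytes, total_bytes) tuple for the current device, so index [1] is the total capacity, and the double floor division converts bytes to MB to match the unit of the nvidia-smi path. A standalone sketch of the fallback:

    import torch

    if torch.cuda.is_available():
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        total_mb = total_bytes // 1024 // 1024  # same unit (MB) as the nvidia-smi query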
@@ -2880,3 +2892,17 @@ def parse_module_path(module_path, function_name, create_dummy):
         return final_module, getattr(final_module, function_name)

     return final_module, None
+
+
+# LoRA-related constants and utilities
+SUPPORTED_LORA_TARGET_MODULES = [
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+
+LORA_TARGET_ALL_MODULES = "all"
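Note (not part of the diff): LORA_TARGET_ALL_MODULES is a sentinel value next to the explicit projection list. A hedged sketch of how a caller might expand it; the helper name is hypothetical and not part of the package:

    from sglang.srt.utils import LORA_TARGET_ALL_MODULES, SUPPORTED_LORA_TARGET_MODULES

    def resolve_lora_target_modules(requested_modules):
        # Hypothetical helper: expand the "all" sentinel into the full projection list.
        if LORA_TARGET_ALL_MODULES in requested_modules:
            return list(SUPPORTED_LORA_TARGET_MODULES)
        return list(requested_modules)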
sglang/test/runners.py CHANGED
@@ -134,10 +134,12 @@ class HFRunner:
         model_type: str = "generation",
         output_str_only: bool = False,
         trust_remote_code: bool = False,
+        patch_model_do_sample_false: bool = False,
     ):
         self.model_type = model_type
         self.output_str_only = output_str_only
         self.trust_remote_code = trust_remote_code
+        self.patch_model_do_sample_false = patch_model_do_sample_false

         self.in_queue = mp.Queue()
         self.out_queue = mp.Queue()
@@ -292,6 +294,7 @@ class HFRunner:
                     torch_dtype=torch_dtype,
                     output_str_only=self.output_str_only,
                     token_ids_logprob=token_ids_logprob,
+                    patch_model_do_sample_false=self.patch_model_do_sample_false,
                 )
             )
         elif self.model_type == "embedding":
@@ -380,6 +383,7 @@ class HFRunner:
         lora_paths: Optional[List[str]] = None,
         output_str_only: bool = False,
         token_ids_logprob: Optional[int] = None,
+        patch_model_do_sample_false: Optional[bool] = False,
     ) -> ModelOutput:
         output_strs = []
         top_input_logprobs = []
@@ -407,7 +411,8 @@ class HFRunner:
                 )
             else:
                 model = base_model
-
+            if patch_model_do_sample_false:
+                model.generation_config.do_sample = False
             outputs = model.generate(
                 input_ids=input_ids,
                 generation_config=GenerationConfig(
@@ -481,7 +486,7 @@ class SRTRunner:
         torch_dtype: torch.dtype,
         model_type: str,
         tp_size: int = 1,
-        impl: str = "auto",
+        model_impl: str = "auto",
         port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
         lora_paths: List[str] = None,
         max_loras_per_batch: int = 4,
@@ -505,6 +510,9 @@ class SRTRunner:
         torchao_config: Optional[str] = None,
         cuda_graph_max_bs: int = 4,
         sleep_on_idle=False,
+        max_lora_rank: Optional[int] = None,
+        lora_target_modules: Optional[List[str]] = None,
+        enable_lora: Optional[bool] = None,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
@@ -523,7 +531,7 @@ class SRTRunner:
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
-            impl=impl,
+            model_impl=model_impl,
             torchao_config=torchao_config,
             mem_fraction_static=mem_fraction_static,
             trust_remote_code=trust_remote_code,
@@ -543,6 +551,9 @@ class SRTRunner:
             cuda_graph_max_bs=cuda_graph_max_bs,
             disable_custom_all_reduce=disable_custom_all_reduce,
             sleep_on_idle=sleep_on_idle,
+            max_lora_rank=max_lora_rank,
+            lora_target_modules=lora_target_modules,
+            enable_lora=enable_lora,
             **spec_kwargs,
         )

sglang/test/test_block_fp8.py CHANGED
@@ -6,6 +6,7 @@ import torch

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_tensor_quant_mla_fp8,
     per_token_group_quant_fp8,
@@ -497,13 +498,17 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
         score = torch.randn((M, E), dtype=dtype)

         with torch.inference_mode():
+            topk_output = select_experts(
+                hidden_states=a,
+                router_logits=score,
+                top_k=topk,
+                renormalize=False,
+            )
             out = fused_moe(
                 a,
                 w1,
                 w2,
-                score,
-                topk,
-                renormalize=False,
+                topk_output,
                 use_fp8_w8a8=True,
                 w1_scale=w1_s,
                 w2_scale=w2_s,
sglang/test/test_block_fp8_ep.py CHANGED
@@ -40,7 +40,7 @@ def ep_moe(
     block_shape: Optional[List[int]] = None,
 ):
     use_blockwise_fp8 = block_shape is not None
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=top_k,
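Note (not part of the diff): the test updates in this release reflect that select_experts now returns a third element alongside the top-k weights and ids; call sites that only need the first two simply discard it, as in this minimal sketch (shapes and device are illustrative):

    import torch
    from sglang.srt.layers.moe.topk import select_experts

    hidden_states = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    router_logits = torch.randn(4, 8, dtype=torch.float16, device="cuda")

    # Trailing return value ignored, mirroring the updated tests above.
    topk_weights, topk_ids, _ = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=2,
    )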
sglang/test/test_custom_ops.py CHANGED
@@ -3,8 +3,13 @@
 import pytest
 import torch

-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
+fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn


 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
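Note (not part of the diff): on ROCm devices that report is_fp8_fnuz(), the FP8 storage format is float8_e4m3fnuz rather than CUDA's float8_e4m3fn, and the two have different numeric ranges; computing the module-level fp8_dtype once lets the reference quantizers below pick the right clamp bounds. A minimal sketch of the bound lookup, assuming the CUDA variant:

    import torch

    fp8_dtype = torch.float8_e4m3fn  # torch.float8_e4m3fnuz when is_fp8_fnuz() is True
    finfo = torch.finfo(fp8_dtype)
    clamp_min, clamp_max = finfo.min, finfo.max  # bounds used by the reference quantizers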
@@ -13,10 +18,10 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     def quantize_ref_per_tensor(tensor, inv_scale):
         # The reference implementation that fully aligns to
         # the kernel being tested.
-        finfo = torch.finfo(torch.float8_e4m3fn)
+        finfo = torch.finfo(fp8_dtype)
         scale = inv_scale.reciprocal()
         qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
-        qweight = qweight.to(torch.float8_e4m3fn)
+        qweight = qweight.to(fp8_dtype)
         return qweight

     def dequantize_per_tensor(tensor, inv_scale, dtype):
@@ -48,19 +53,19 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     )


-if is_cuda:
+if _is_cuda or _is_hip:

     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
     def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
         def quantize_ref_per_token(tensor, inv_scale):
             # The reference implementation that fully aligns to
             # the kernel being tested.
-            finfo = torch.finfo(torch.float8_e4m3fn)
+            finfo = torch.finfo(fp8_dtype)
             scale = inv_scale.reciprocal()
             qweight = (tensor.to(torch.float32) * scale).clamp(
                 min=finfo.min, max=finfo.max
             )
-            qweight = qweight.to(torch.float8_e4m3fn)
+            qweight = qweight.to(fp8_dtype)
             return qweight

         def dequantize_per_token(tensor, inv_scale, dtype):
sglang/test/test_cutlass_w4a8_moe.py CHANGED
@@ -100,12 +100,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
     s_strides2 = c_strides2

     score = torch.randn((M, E), dtype=dtype, device=device)
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=a,
         router_logits=score,
         top_k=topk,
-        use_grouped_topk=False,
-        renormalize=False,
     )
     expert_map = torch.arange(E, dtype=torch.int32, device=device)
     expert_map[local_e:] = E
sglang/test/test_fp4_moe.py CHANGED
@@ -159,12 +159,10 @@ def test_cutlass_fp4_moe_no_graph(

     score = torch.randn((m, e), device="cuda", dtype=dtype)

-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
         hidden_states=a,
         router_logits=score,
         top_k=topk,
-        use_grouped_topk=False,
-        renormalize=False,
     )

     a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
sglang/test/test_marlin_moe.py ADDED
@@ -0,0 +1,286 @@
+import types
+from typing import Optional
+
+import pytest
+import torch
+from sgl_kernel import fused_marlin_moe
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
+from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
+
+
+def stack_and_dev(tensors: list[torch.Tensor]):
+    dev = tensors[0].device
+    return torch.stack(tensors, dim=0).to(dev)
+
+
+def torch_experts(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    quant_dtype: Optional[torch.dtype] = None,
+    apply_router_weights_on_input: bool = False,
+) -> torch.Tensor:
+    assert (
+        global_num_experts == -1
+        or (global_num_experts == w1.shape[0] and expert_map is None)
+        or (expert_map is not None and global_num_experts == expert_map.shape[0])
+    )
+
+    M, K = a.shape
+    topk = topk_ids.shape[1]
+    print("quant_dtype", quant_dtype)
+    # exit(0)
+    if apply_router_weights_on_input:
+        assert topk == 1
+        a = a * topk_weight.to(a.dtype)
+
+    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
+
+    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+
+    num_experts = w1.shape[0]
+
+    topk_ids = topk_ids.view(-1)
+    if expert_map is not None:
+        topk_ids = expert_map[topk_ids]
+
+    f32 = torch.float32
+
+    for i in range(num_experts):
+        mask = topk_ids == i
+        if mask.sum():
+            if quant_dtype is None:
+                tmp1 = a[mask] @ w1[i].transpose(0, 1)
+                tmp2 = SiluAndMul()(tmp1)
+                out[mask] = tmp2 @ w2[i].transpose(0, 1)
+
+    if apply_router_weights_on_input:
+        return out
+    else:
+        return (
+            (out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1))
+            .sum(dim=1)
+            .to(out.dtype)
+        )
+
+
+def torch_moe(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    score: torch.Tensor,
+    topk: int,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    score = torch.softmax(score, dim=-1, dtype=torch.float32)
+    topk_weight, topk_ids = torch.topk(score, topk)
+    return torch_experts(
+        a, w1, w2, topk_weight, topk_ids, global_num_experts, expert_map
+    )
+
+
+def marlin_moe_generate_valid_test_cases():
+    import itertools
+
+    m_list = [1, 123, 666]
+    n_list = [128, 1024]
+    k_list = [256, 2048]
+    e_list = [4, 12]
+    topk_list = [2, 3]
+    dtype_list = [torch.half, torch.bfloat16]
+    group_size_list = [128]
+    act_order_list = [True, False]
+    quant_type_list = [
+        scalar_types.uint4,
+        scalar_types.uint4b8,
+    ]
+    is_k_full_list = [True, False]
+
+    all_combinations = itertools.product(
+        m_list,
+        n_list,
+        k_list,
+        e_list,
+        topk_list,
+        dtype_list,
+        group_size_list,
+        act_order_list,
+        quant_type_list,
+        is_k_full_list,
+    )
+
+    def is_invalid(
+        m, n, k, e, topk, dtype, group_size, act_order, quant_type, is_k_full
+    ):
+
+        # Filter act_order
+        if act_order:
+            if group_size in (-1, k, n):
+                return False
+            if quant_type not in [scalar_types.uint4b8]:
+                return False
+        elif not is_k_full:
+            return False
+
+        return True
+
+    cases = []
+    for case in all_combinations:
+        if is_invalid(*case):
+            cases.append(case)
+    return cases
+
+
+@pytest.mark.flaky(reruns=2)
+@pytest.mark.parametrize(
+    ("m, n, k, e, topk, dtype, group_size," "act_order, quant_type, is_k_full"),
+    marlin_moe_generate_valid_test_cases(),
+)
+def test_fused_marlin_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+    group_size: int,
+    act_order: bool,
+    quant_type: ScalarType,
+    is_k_full: bool,
+):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA device not available")
+
+    torch.manual_seed(0)
+
+    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
+
+    # Filter act_order
+    if act_order:
+        if group_size == -1:
+            return
+        if group_size in (k, n):
+            return
+        if has_zp:
+            return
+    else:
+        if not is_k_full:
+            return
+
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
+
+    e_map = None
+
+    w_ref1_l = []
+    qweight1_l = []
+    scales1_l = []
+    zeros1_l = []
+    g_idx1_l = []
+    sort_indices1_l = []
+
+    for i in range(w1.shape[0]):
+        if has_zp:
+            w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
+                w1[i].transpose(1, 0), quant_type, group_size
+            )
+
+            w_ref1_l.append(w_ref1.T)
+            qweight1_l.append(qweight1)
+            scales1_l.append(scales1)
+            zeros1_l.append(zeros1)
+        else:
+            test_perm = torch.randperm(k)
+            w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
+                w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
+            )
+
+            w_ref1_l.append(w_ref1.T)
+            qweight1_l.append(qweight1)
+            scales1_l.append(scales1)
+            g_idx1_l.append(g_idx1)
+            sort_indices1_l.append(sort_indices1)
+
+    w_ref1 = stack_and_dev(w_ref1_l)
+    qweight1 = stack_and_dev(qweight1_l).contiguous()
+    scales1 = stack_and_dev(scales1_l)
+    g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
+    zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
+    sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
+
+    w_ref2_l = []
+    qweight2_l = []
+    scales2_l = []
+    zeros2_l = []
+    g_idx2_l = []
+    sort_indices2_l = []
+
+    for i in range(w2.shape[0]):
+        if has_zp:
+            w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
+                w2[i].transpose(1, 0), quant_type, group_size
+            )
+
+            w_ref2_l.append(w_ref2.T)
+            qweight2_l.append(qweight2)
+            scales2_l.append(scales2)
+            zeros2_l.append(zeros2)
+        else:
+            test_perm = torch.randperm(n)
+            w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
+                w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
+            )
+
+            w_ref2_l.append(w_ref2.T)
+            qweight2_l.append(qweight2)
+            scales2_l.append(scales2)
+            g_idx2_l.append(g_idx2)
+            sort_indices2_l.append(sort_indices2)
+
+    w_ref2 = stack_and_dev(w_ref2_l)
+    qweight2 = stack_and_dev(qweight2_l).contiguous()
+    scales2 = stack_and_dev(scales2_l)
+    g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
+    zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
+    sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
+
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+    from sglang.srt.layers.moe.topk import fused_topk_torch_native
+
+    topk_weights, topk_ids = fused_topk_torch_native(a, score, topk, False)
+
+    torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
+
+    marlin_output = fused_marlin_moe(
+        a,
+        qweight1,
+        qweight2,
+        scales1,
+        scales2,
+        score,
+        topk_weights,
+        topk_ids,
+        g_idx1=g_idx1,
+        g_idx2=g_idx2,
+        sort_indices1=sort_indices1,
+        sort_indices2=sort_indices2,
+        w1_zeros=zeros1,
+        w2_zeros=zeros2,
+        num_bits=4,
+        is_k_full=is_k_full,
+    )
+
+    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
+
+
+if __name__ == "__main__":
+    # Run the specific test function directly
+    pytest.main([__file__])