sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py
@@ -27,8 +27,10 @@ from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     align,
     direct_register_custom_op,
+    get_bool_env_var,
     get_device_core_count,
     get_device_name,
+    is_cpu,
     is_cuda,
     is_hip,
     log_info_on_rank0,
@@ -37,6 +39,8 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_cpu = is_cpu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
     from sgl_kernel import (
@@ -45,6 +49,22 @@ if _is_cuda:
         sgl_per_token_quant_fp8,
     )
 
+if _is_hip:
+    if _use_aiter:
+        try:
+            from aiter import (  # v0.1.3
+                dynamic_per_tensor_quant,
+                dynamic_per_token_scaled_quant,
+                static_per_tensor_quant,
+            )
+        except ImportError:
+            raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+    else:
+        try:
+            import vllm._C
+        except ImportError:
+            raise ImportError("vllm is required when SGLANG_USE_AITER is set to False")
+
 logger = logging.getLogger(__name__)
 
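The guarded imports above make the aiter quantization kernels opt-in: they are used only on ROCm builds and only when SGLANG_USE_AITER is set. A minimal sketch of enabling the flag before sglang is imported; the accepted truthy value and the exact module path are assumptions, not part of this diff.

    # Hedged sketch: enable the aiter path on a ROCm build. The flag is read at
    # import time, so it must be set before sglang's quantization modules load.
    import os

    os.environ["SGLANG_USE_AITER"] = "1"  # assumed truthy value for get_bool_env_var

    # Importing the module now runs the import guard shown in the hunk above
    # (and raises ImportError if aiter is missing).
    from sglang.srt.layers.quantization import fp8_kernel  # assumed module path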
 
@@ -1114,55 +1134,199 @@ def per_token_group_quant_mla_deep_gemm_masked_fp8(
     return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
 
 
-def scaled_fp8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None,
-    num_token_padding: Optional[int] = None,
-    use_per_token_if_dynamic: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Quantize input tensor to FP8 (8-bit floating point) format.
+"""
+Quantize input tensor to FP8 (8-bit floating point) format.
+
+Args:
+    input (torch.Tensor): Input tensor to be quantized
+    scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
+        If None, scales will be computed dynamically.
+    num_token_padding (Optional[int]): If specified, pad the first dimension
+        of the output to at least this value.
+    use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
+        determines the quantization granularity:
+        - True: compute scale per token
+        - False: compute single scale per tensor
+
+Returns:
+    Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+        - quantized_tensor: The FP8 quantized version of input
+        - scale_tensor: The scaling factors used for quantization
+
+Raises:
+    AssertionError: If input is not 2D or if static scale's numel != 1
+"""
+if _is_hip:
 
-    Args:
-        input (torch.Tensor): Input tensor to be quantized
-        scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
-            If None, scales will be computed dynamically.
-        num_token_padding (Optional[int]): If specified, pad the first dimension
-            of the output to at least this value.
-        use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
-            determines the quantization granularity:
-            - True: compute scale per token
-            - False: compute single scale per tensor
+    def scaled_fp8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+        num_token_padding: Optional[int] = None,
+        use_per_token_if_dynamic: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+        shape = input.shape
+        if num_token_padding:
+            shape = (max(num_token_padding, input.shape[0]), shape[1])
+        output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
+
+        if scale is None:
+            # Dynamic scaling
+            if use_per_token_if_dynamic:
+                scale = torch.empty(
+                    (shape[0], 1), device=input.device, dtype=torch.float32
+                )
+                if _use_aiter:
+                    dynamic_per_token_scaled_quant(output, input, scale)
+                else:
+                    torch.ops._C.dynamic_per_token_scaled_fp8_quant(
+                        output, input.contiguous(), scale, None
+                    )
+            else:
+                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+                if _use_aiter:
+                    dynamic_per_tensor_quant(output, input, scale)
+                else:
+                    torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
+        else:
+            # Static scaling
+            assert (
+                scale.numel() == 1
+            ), f"Expected scalar scale, got numel={scale.numel()}"
+            if _use_aiter:
+                static_per_tensor_quant(output, input, scale)
+            else:
+                torch.ops._C.static_scaled_fp8_quant(output, input, scale)
 
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - quantized_tensor: The FP8 quantized version of input
-            - scale_tensor: The scaling factors used for quantization
+        return output, scale
 
-    Raises:
-        AssertionError: If input is not 2D or if static scale's numel != 1
-    """
-    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
-    shape = input.shape
-    if num_token_padding:
-        shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
-
-    if scale is None:
-        # Dynamic scaling
-        if use_per_token_if_dynamic:
-            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
-            sgl_per_token_quant_fp8(input, output, scale)
+else:
+
+    def scaled_fp8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+        num_token_padding: Optional[int] = None,
+        use_per_token_if_dynamic: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+
+        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+        shape = input.shape
+        if num_token_padding:
+            shape = (max(num_token_padding, input.shape[0]), shape[1])
+        output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
+
+        if scale is None:
+            # Dynamic scaling
+            if use_per_token_if_dynamic:
+                scale = torch.empty(
+                    (shape[0], 1), device=input.device, dtype=torch.float32
+                )
+                sgl_per_token_quant_fp8(input, output, scale)
+            else:
+                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+                sgl_per_tensor_quant_fp8(
+                    input, output, scale, is_static=False
+                )  # False for dynamic
         else:
-            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            # Static scaling
+            assert (
+                scale.numel() == 1
+            ), f"Expected scalar scale, got numel={scale.numel()}"
             sgl_per_tensor_quant_fp8(
-                input, output, scale, is_static=False
-            )  # False for dynamic
-    else:
-        # Static scaling
-        assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
-        sgl_per_tensor_quant_fp8(
-            input, output, scale, is_static=True
-        )  # True for static
+                input, output, scale, is_static=True
+            )  # True for static
+
+        return output, scale
+
+
+fp8_autotune = triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_M": block_m}, num_warps=num_warps)
+        for block_m in [16, 32, 64, 128]
+        for num_warps in [2, 4, 8]
+    ],
+    key=["K", "BLOCK_K", "M_ALIGNMENT"],
+)
+
 
-    return output, scale
+@triton.jit
+def _per_token_group_quant_fp8_hopper_moe_mn_major(
+    a,  # (M, K):(K, 1)
+    expert_offsets,  # (num_experts,)
+    problem_sizes,  # (num_experts, 3)
+    a_fp8,  # (M, K):(K, 1)
+    sfa,  # (M, k)
+    K: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    M_ALIGNMENT: tl.constexpr,
+    BLOCK_M: tl.constexpr,  # tune
+):
+    k_offset = tl.program_id(0)
+    expert_id = tl.program_id(1)
+
+    m = tl.load(problem_sizes + expert_id * 3)
+    current_expert_offset = tl.load(expert_offsets + expert_id).to(tl.int64)
+    tl.multiple_of(m, M_ALIGNMENT)
+    tl.multiple_of(current_expert_offset, M_ALIGNMENT)
+
+    coord_k = k_offset * BLOCK_K + tl.arange(0, BLOCK_K)
+    for i in tl.range(tl.cdiv(m, BLOCK_M)):
+        coord_m = i * BLOCK_M + tl.arange(0, BLOCK_M)
+        a_ptrs = a + current_expert_offset * K + coord_m[:, None] * K + coord_k[None, :]
+        a_mask = (coord_m < m)[:, None] & (coord_k < K)[None, :]
+
+        inp = tl.load(a_ptrs, mask=a_mask).to(tl.float32)  # [BLOCK_M, BLOCK_K]
+        inp_amax = tl.max(tl.abs(inp), axis=1)  # [BLOCK_M,]
+        inp_amax = tl.clamp(inp_amax, min=1e-4, max=float("inf"))
+        inp_fp8 = (inp * (448.0 / inp_amax[:, None])).to(tl.float8e4nv)
+
+        # Store fp8
+        a_fp8_ptrs = (
+            a_fp8 + current_expert_offset * K + coord_m[:, None] * K + coord_k[None, :]
+        )
+        tl.store(a_fp8_ptrs, inp_fp8, mask=a_mask)
+
+        # Store sfa
+        k = tl.cdiv(K, BLOCK_K)
+        sfa_ptrs = (
+            sfa + current_expert_offset * k + k_offset * m + coord_m
+        )  # MN-Major with sfa
+        tl.store(sfa_ptrs, inp_amax / 448.0, mask=coord_m < m)
+
+
+if not _is_cpu:
+    _per_token_group_quant_fp8_hopper_moe_mn_major = fp8_autotune(
+        _per_token_group_quant_fp8_hopper_moe_mn_major
+    )
+
+
+def per_token_group_quant_fp8_hopper_moe_mn_major(
+    A: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes: torch.Tensor,
+    group_size: int,
+    expert_tokens_alignment: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert A.dim() == 2
+    assert A.is_contiguous(), "`A` is not contiguous"
+    assert (
+        A.shape[-1] % group_size == 0
+    ), "the last dimension of `A` cannot be divisible by `group_size`"
+
+    a_q = torch.empty_like(A, device=A.device, dtype=fp8_dtype)
+    M, K = A.shape[0], A.shape[1]
+    k = K // group_size
+    sfa = torch.empty((M, k), device=A.device, dtype=torch.float32)
+    num_experts = problem_sizes.shape[0]
+    grid = (k, num_experts)
+    _per_token_group_quant_fp8_hopper_moe_mn_major[grid](
+        A,
+        expert_offsets,
+        problem_sizes,
+        a_q,
+        sfa,
+        K,
+        group_size,
+        expert_tokens_alignment,
+    )
+    return a_q, sfa
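The hunk above keeps the public signature of scaled_fp8_quant identical on CUDA and ROCm while swapping the backing kernels (sgl_kernel on CUDA; aiter or vllm custom ops on HIP). A minimal usage sketch follows, assuming a CUDA device, sgl_kernel installed, the module path sglang.srt.layers.quantization.fp8_kernel, and that the kernels accept bfloat16 input; none of this example code is part of the diff.

    # Hedged usage sketch of the rewritten scaled_fp8_quant (assumptions noted above).
    import torch

    from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)  # input must be 2D

    # Dynamic per-token scaling: one float32 scale per row, returned with shape (8, 1).
    x_fp8, token_scales = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

    # Dynamic per-tensor scaling: a single scale is computed for the whole tensor.
    y_fp8, tensor_scale = scaled_fp8_quant(x)

    # Static per-tensor scaling: the caller supplies one precomputed scale (numel == 1).
    static_scale = torch.tensor([0.05], device="cuda", dtype=torch.float32)
    z_fp8, _ = scaled_fp8_quant(x, scale=static_scale)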
sglang/srt/layers/quantization/fp8_utils.py
@@ -42,7 +42,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
     import aiter
-    from aiter import gemm_a8w8_blockscale_CK, get_hip_quant
+    from aiter import gemm_a8w8_blockscale, get_hip_quant
 
     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
 
@@ -274,7 +274,7 @@ def aiter_w8a8_block_fp8_linear(
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
     q_input, x_scale = aiter_per1x128_quant(input_2d, quant_dtype=aiter.dtypes.fp8)
-    output = gemm_a8w8_blockscale_CK(
+    output = gemm_a8w8_blockscale(
         q_input, weight, x_scale, weight_scale, dtype=input.dtype
     )
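The two fp8_utils.py changes appear to track a rename of the aiter block-scale GEMM entry point (gemm_a8w8_blockscale_CK to gemm_a8w8_blockscale). For code that must run against both older and newer aiter releases, a hedged compatibility sketch; this shim is not part of the diff and is only needed if both names must be supported.

    # Hypothetical compatibility shim: prefer the new aiter name, fall back to the
    # old CK-suffixed name on older aiter releases.
    try:
        from aiter import gemm_a8w8_blockscale
    except ImportError:  # assumed: older aiter only exposes gemm_a8w8_blockscale_CK
        from aiter import gemm_a8w8_blockscale_CK as gemm_a8w8_blockscale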