sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -27,8 +27,10 @@ from sglang.srt.layers.quantization import deep_gemm_wrapper
|
|
27
27
|
from sglang.srt.utils import (
|
28
28
|
align,
|
29
29
|
direct_register_custom_op,
|
30
|
+
get_bool_env_var,
|
30
31
|
get_device_core_count,
|
31
32
|
get_device_name,
|
33
|
+
is_cpu,
|
32
34
|
is_cuda,
|
33
35
|
is_hip,
|
34
36
|
log_info_on_rank0,
|
@@ -37,6 +39,8 @@ from sglang.srt.utils import (
|
|
37
39
|
|
38
40
|
# Platform probes, evaluated once at import time and reused by the
# backend-selection branches further down in this module.
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_cpu = is_cpu()
# AITER kernels are opt-in through the SGLANG_USE_AITER environment variable
# and are only honoured when running on HIP.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
40
44
|
|
41
45
|
if _is_cuda:
|
42
46
|
from sgl_kernel import (
|
@@ -45,6 +49,22 @@ if _is_cuda:
|
|
45
49
|
sgl_per_token_quant_fp8,
|
46
50
|
)
|
47
51
|
|
52
|
+
# On HIP the FP8 quantization kernels come from one of two optional packages:
# `aiter` when SGLANG_USE_AITER is enabled, otherwise the compiled vLLM ops
# (importing `vllm._C` is what makes `torch.ops._C.*` available).
if _is_hip:
    if _use_aiter:
        try:
            from aiter import (  # v0.1.3
                dynamic_per_tensor_quant,
                dynamic_per_token_scaled_quant,
                static_per_tensor_quant,
            )
        except ImportError as err:
            # Chain explicitly so the underlying import failure (missing
            # package vs. broken shared library) stays visible as __cause__.
            raise ImportError(
                "aiter is required when SGLANG_USE_AITER is set to True"
            ) from err
    else:
        try:
            import vllm._C
        except ImportError as err:
            raise ImportError(
                "vllm is required when SGLANG_USE_AITER is set to False"
            ) from err
|
67
|
+
|
48
68
|
logger = logging.getLogger(__name__)
|
49
69
|
|
50
70
|
|
@@ -1114,55 +1134,199 @@ def per_token_group_quant_mla_deep_gemm_masked_fp8(
|
|
1114
1134
|
return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
|
1115
1135
|
|
1116
1136
|
|
1117
|
-
|
1118
|
-
|
1119
|
-
|
1120
|
-
|
1121
|
-
|
1122
|
-
|
1123
|
-
|
1124
|
-
|
1137
|
+
# `scaled_fp8_quant` is defined once per platform. The docstring lives inside
# each variant (instead of as a detached module-level string) so that
# `help(scaled_fp8_quant)` and `scaled_fp8_quant.__doc__` actually expose it.
if _is_hip:

    def scaled_fp8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        num_token_padding: Optional[int] = None,
        use_per_token_if_dynamic: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize input tensor to FP8 (8-bit floating point) format (HIP variant).

        Kernels come from `aiter` when `_use_aiter` is set, otherwise from the
        `torch.ops._C` ops.

        Args:
            input (torch.Tensor): Input tensor to be quantized
            scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
                If None, scales will be computed dynamically.
            num_token_padding (Optional[int]): If specified, pad the first dimension
                of the output to at least this value.
            use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
                determines the quantization granularity:
                - True: compute scale per token
                - False: compute single scale per tensor

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - quantized_tensor: The FP8 quantized version of input
                - scale_tensor: The scaling factors used for quantization

        Raises:
            AssertionError: If input is not 2D or if static scale's numel != 1
        """
        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
        shape = input.shape
        if num_token_padding:
            # Only the row count grows; the output is allocated with
            # torch.empty, so any padding rows start uninitialized.
            shape = (max(num_token_padding, input.shape[0]), shape[1])
        output = torch.empty(shape, device=input.device, dtype=fp8_dtype)

        if scale is None:
            # Dynamic scaling
            if use_per_token_if_dynamic:
                # One scale per (possibly padded) output row.
                scale = torch.empty(
                    (shape[0], 1), device=input.device, dtype=torch.float32
                )
                if _use_aiter:
                    dynamic_per_token_scaled_quant(output, input, scale)
                else:
                    # NOTE(review): the trailing None is presumably the op's
                    # optional scale upper bound -- confirm against the op schema.
                    torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                        output, input.contiguous(), scale, None
                    )
            else:
                # Zero-initialised single-element scale that the kernel fills in.
                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
                if _use_aiter:
                    dynamic_per_tensor_quant(output, input, scale)
                else:
                    torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
        else:
            # Static scaling
            assert (
                scale.numel() == 1
            ), f"Expected scalar scale, got numel={scale.numel()}"
            if _use_aiter:
                static_per_tensor_quant(output, input, scale)
            else:
                torch.ops._C.static_scaled_fp8_quant(output, input, scale)

        return output, scale

else:

    def scaled_fp8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        num_token_padding: Optional[int] = None,
        use_per_token_if_dynamic: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize input tensor to FP8 (8-bit floating point) format.

        This variant uses the `sgl_per_token_quant_fp8` and
        `sgl_per_tensor_quant_fp8` kernels.

        Args:
            input (torch.Tensor): Input tensor to be quantized
            scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
                If None, scales will be computed dynamically.
            num_token_padding (Optional[int]): If specified, pad the first dimension
                of the output to at least this value.
            use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
                determines the quantization granularity:
                - True: compute scale per token
                - False: compute single scale per tensor

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - quantized_tensor: The FP8 quantized version of input
                - scale_tensor: The scaling factors used for quantization

        Raises:
            AssertionError: If input is not 2D or if static scale's numel != 1
        """

        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
        shape = input.shape
        if num_token_padding:
            # Only the row count grows; the output is allocated with
            # torch.empty, so any padding rows start uninitialized.
            shape = (max(num_token_padding, input.shape[0]), shape[1])
        output = torch.empty(shape, device=input.device, dtype=fp8_dtype)

        if scale is None:
            # Dynamic scaling
            if use_per_token_if_dynamic:
                # One scale per (possibly padded) output row.
                scale = torch.empty(
                    (shape[0], 1), device=input.device, dtype=torch.float32
                )
                sgl_per_token_quant_fp8(input, output, scale)
            else:
                # Zero-initialised single-element scale that the kernel fills in.
                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
                sgl_per_tensor_quant_fp8(
                    input, output, scale, is_static=False
                )  # False for dynamic
        else:
            # Static scaling
            assert (
                scale.numel() == 1
            ), f"Expected scalar scale, got numel={scale.numel()}"
            sgl_per_tensor_quant_fp8(
                input, output, scale, is_static=True
            )  # True for static

        return output, scale
|
1240
|
+
|
1241
|
+
|
1242
|
+
# Autotuning decorator shared by the FP8 MoE quantization kernel below: it
# sweeps BLOCK_M against the warp count and caches the best configuration per
# distinct (K, BLOCK_K, M_ALIGNMENT) combination.
fp8_autotune = triton.autotune(
    key=["K", "BLOCK_K", "M_ALIGNMENT"],
    configs=[
        triton.Config({"BLOCK_M": rows_per_tile}, num_warps=warp_count)
        for rows_per_tile in (16, 32, 64, 128)
        for warp_count in (2, 4, 8)
    ],
)
|
1250
|
+
|
1167
1251
|
|
1168
|
-
|
1252
|
+
@triton.jit
def _per_token_group_quant_fp8_hopper_moe_mn_major(
    a,  # (M, K):(K, 1)
    expert_offsets,  # (num_experts,)
    problem_sizes,  # (num_experts, 3)
    a_fp8,  # (M, K):(K, 1)
    sfa,  # (M, k)
    K: tl.constexpr,
    BLOCK_K: tl.constexpr,
    M_ALIGNMENT: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tune
):
    """Quantize one K-group of one expert's rows to FP8 and store its scales.

    Launch grid is (number of K-groups, number of experts). Each program walks
    its expert's rows in BLOCK_M-row tiles, writes the quantized tile into
    `a_fp8`, and writes one scale per row into `sfa`.
    """
    group_idx = tl.program_id(0)
    expert_id = tl.program_id(1)

    # First column of this expert's `problem_sizes` row is its row count.
    m = tl.load(problem_sizes + expert_id * 3)
    current_expert_offset = tl.load(expert_offsets + expert_id).to(tl.int64)
    tl.multiple_of(m, M_ALIGNMENT)
    tl.multiple_of(current_expert_offset, M_ALIGNMENT)

    # Loop-invariant pieces: where this expert's rows start in the input and
    # output, how many K-groups a row has, and where this (expert, group) pair
    # starts in the scale buffer.
    src_base = a + current_expert_offset * K
    dst_base = a_fp8 + current_expert_offset * K
    k = tl.cdiv(K, BLOCK_K)
    sfa_base = sfa + current_expert_offset * k + group_idx * m  # MN-Major with sfa

    coord_k = group_idx * BLOCK_K + tl.arange(0, BLOCK_K)
    col_in_bounds = coord_k < K
    for tile in tl.range(tl.cdiv(m, BLOCK_M)):
        coord_m = tile * BLOCK_M + tl.arange(0, BLOCK_M)
        row_in_bounds = coord_m < m
        tile_mask = row_in_bounds[:, None] & col_in_bounds[None, :]

        src_ptrs = src_base + coord_m[:, None] * K + coord_k[None, :]
        vals = tl.load(src_ptrs, mask=tile_mask).to(tl.float32)  # [BLOCK_M, BLOCK_K]

        # Per-row absolute maximum, floored at 1e-4 so the division below
        # never sees a zero.
        row_amax = tl.max(tl.abs(vals), axis=1)  # [BLOCK_M,]
        row_amax = tl.clamp(row_amax, min=1e-4, max=float("inf"))
        # Scale each row so its largest magnitude maps to 448.0 before the
        # cast to float8 e4m3.
        vals_fp8 = (vals * (448.0 / row_amax[:, None])).to(tl.float8e4nv)

        # Store fp8
        dst_ptrs = dst_base + coord_m[:, None] * K + coord_k[None, :]
        tl.store(dst_ptrs, vals_fp8, mask=tile_mask)

        # Store sfa
        tl.store(sfa_base + coord_m, row_amax / 448.0, mask=row_in_bounds)
|
1295
|
+
|
1296
|
+
|
1297
|
+
# Wrap the JIT kernel with the shared autotuner, rebinding the same name so
# the launcher below picks up the tuned version. The wrap is skipped when
# running on CPU.
# NOTE(review): presumably because autotuning needs a GPU to benchmark on -- confirm.
if not _is_cpu:
    _per_token_group_quant_fp8_hopper_moe_mn_major = fp8_autotune(
        _per_token_group_quant_fp8_hopper_moe_mn_major
    )
|
1301
|
+
|
1302
|
+
|
1303
|
+
def per_token_group_quant_fp8_hopper_moe_mn_major(
    A: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes: torch.Tensor,
    group_size: int,
    expert_tokens_alignment: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `A` to FP8 in groups of `group_size` columns, per expert.

    Args:
        A: 2D contiguous tensor whose last dimension is a multiple of
            `group_size`.
        expert_offsets: Per-expert starting row into `A`, read by the kernel.
        problem_sizes: Per-expert sizes; the launch grid uses its first
            dimension as the number of experts and the kernel reads the first
            of three values per expert as that expert's row count.
        group_size: Number of columns that share one scale (passed to the
            kernel as `BLOCK_K`).
        expert_tokens_alignment: Alignment hint passed to the kernel as
            `M_ALIGNMENT`.

    Returns:
        A tuple `(a_q, sfa)`: `a_q` has the shape of `A` with dtype
        `fp8_dtype`; `sfa` is a float32 buffer allocated as
        `(A.shape[0], A.shape[1] // group_size)` holding the scales, which the
        kernel fills expert by expert.

    Raises:
        AssertionError: If `A` is not 2D, not contiguous, or its last
            dimension is not divisible by `group_size`.
    """
    assert A.dim() == 2, f"`A` must be 2D, got {A.dim()}D"
    assert A.is_contiguous(), "`A` is not contiguous"
    # The original message said "cannot be divisible", which states the
    # opposite of what this check enforces.
    assert (
        A.shape[-1] % group_size == 0
    ), "the last dimension of `A` must be divisible by `group_size`"

    a_q = torch.empty_like(A, device=A.device, dtype=fp8_dtype)
    M, K = A.shape[0], A.shape[1]
    k = K // group_size
    sfa = torch.empty((M, k), device=A.device, dtype=torch.float32)
    num_experts = problem_sizes.shape[0]
    # One program per (K-group, expert) pair.
    grid = (k, num_experts)
    _per_token_group_quant_fp8_hopper_moe_mn_major[grid](
        A,
        expert_offsets,
        problem_sizes,
        a_q,
        sfa,
        K,
        group_size,
        expert_tokens_alignment,
    )
    return a_q, sfa
|
@@ -42,7 +42,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
|
42
42
|
|
43
43
|
if _use_aiter:
|
44
44
|
import aiter
|
45
|
-
from aiter import
|
45
|
+
from aiter import gemm_a8w8_blockscale, get_hip_quant
|
46
46
|
|
47
47
|
aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
|
48
48
|
|
@@ -274,7 +274,7 @@ def aiter_w8a8_block_fp8_linear(
|
|
274
274
|
output_shape = [*input.shape[:-1], weight.shape[0]]
|
275
275
|
|
276
276
|
q_input, x_scale = aiter_per1x128_quant(input_2d, quant_dtype=aiter.dtypes.fp8)
|
277
|
-
output =
|
277
|
+
output = gemm_a8w8_blockscale(
|
278
278
|
q_input, weight, x_scale, weight_scale, dtype=input.dtype
|
279
279
|
)
|
280
280
|
|