sglang-0.4.3.post4-py3-none-any.whl → sglang-0.4.4.post1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py

@@ -16,9 +16,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
-    apply_fp8_linear,
     convert_to_channelwise,
-    cutlass_fp8_supported,
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
@@ -29,14 +27,21 @@ from sglang.srt.layers.linear import (
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
-from sglang.srt.layers.parameter import
+from sglang.srt.layers.parameter import (
+    BlockQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.layers.quantization.fp8_utils import (
-
+    apply_fp8_linear,
     apply_w8a8_block_fp8_linear,
+    cutlass_fp8_supported,
+    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.utils import (
@@ -49,9 +54,9 @@ from sglang.srt.utils import (

 ACTIVATION_SCHEMES = ["static", "dynamic"]

-
+_is_hip = is_hip()

-if
+if _is_hip:
     from aiter.fused_moe_bf16_asm import asm_moe
     from aiter.ops.shuffle import shuffle_weight

@@ -170,7 +175,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
         # Disable marlin for ROCm
-        if
+        if _is_hip:
             self.use_marlin = False

         self.block_quant = self.quant_config.weight_block_size is not None
@@ -282,7 +287,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if
+            if _is_hip:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
@@ -305,15 +310,15 @@ class Fp8LinearMethod(LinearMethodBase):
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-
-
-
-
-            if self.use_marlin:
-                assert weight_scale.numel() == 1
-                weight_scale = convert_to_channelwise(
-                    weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
+            if self.cutlass_fp8_supported or self.use_marlin:
+                # apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
+                qweight, weight_scale = per_token_group_quant_fp8(
+                    layer.weight, layer.weight.shape[-1]
                 )
+                weight_scale = weight_scale.t().contiguous()
+            else:
+                # per-tensor quantization
+                qweight, weight_scale = input_to_float8(layer.weight)

             # Update the layer with the new values.
             layer.weight = Parameter(qweight.t(), requires_grad=False)
@@ -330,23 +335,19 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.input_scale = torch.nn.Parameter(
                 layer.input_scale.data, requires_grad=False
             )
-
-            #
-            if self.use_marlin:
+
+            # cutlass sgl-kernel and marlin only support per-channel scale
+            if self.cutlass_fp8_supported or self.use_marlin:
                 weight = layer.weight
                 weight_scale = convert_to_channelwise(
                     layer.weight_scale, layer.logical_widths
                 )
-
-            # If using w8a8, torch._scaled_mm needs per tensor, so
-            # requantize the logical shards as a single weight.
             else:
                 # Dequant -> Quant with max scale so we can run per tensor.
                 weight = layer.weight
                 weight_scale = layer.weight_scale
-
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if
+            if _is_hip:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
@@ -460,7 +461,11 @@ class Fp8MoEMethod:
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

         if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype =
+            params_dtype = (
+                torch.int32
+                if get_bool_env_var("USE_INT4_WEIGHT")
+                else torch.float8_e4m3fn
+            )
         tp_size = get_tensor_model_parallel_world_size()
         if self.block_quant:
             block_n, block_k = (
@@ -485,21 +490,40 @@ class Fp8MoEMethod:
         )

         # WEIGHTS
-
-
-
-
-
-
+        if get_bool_env_var("USE_INT4_WEIGHT"):
+            # INT4 MoE weight - INT32 packed
+            w13_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts,
+                    2 * intermediate_size,
+                    hidden_size // 8,
+                    dtype=params_dtype,
+                ),
+                requires_grad=False,
+            )
+            w2_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts, hidden_size, intermediate_size // 8, dtype=params_dtype
+                ),
+                requires_grad=False,
+            )
+        else:
+            w13_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+                ),
+                requires_grad=False,
+            )
+            w2_weight = torch.nn.Parameter(
+                torch.empty(
+                    num_experts, hidden_size, intermediate_size, dtype=params_dtype
+                ),
+                requires_grad=False,
+            )
+
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)

-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
-            ),
-            requires_grad=False,
-        )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

@@ -538,7 +562,9 @@ class Fp8MoEMethod:
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)

-        if
+        if (
+            _is_hip
+        ):  # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -565,6 +591,13 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

+        if get_bool_env_var("USE_INT4_WEIGHT"):
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
+            )
+            set_weight_attrs(w13_weight_scale1, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale1, extra_weight_attrs)
+
         # INPUT_SCALES
         if self.quant_config.activation_scheme == "static":
             if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -590,14 +623,14 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-
-
-
+        if get_bool_env_var("USE_INT4_WEIGHT"):
+            self.process_weights_hip_int4(layer)
+            return

         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if
+            if _is_hip:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -630,10 +663,11 @@ class Fp8MoEMethod:
                     layer.w2_weight.contiguous(), (16, 16)
                 )
             return
+
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

@@ -655,33 +689,8 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

-            if
-
-                layer.w13_weight = torch.nn.Parameter(
-                    permute_weight(layer.w13_weight.data),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    permute_weight(layer.w2_weight.data),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                # ROCm (CK_MOE): using column-wise scaling
-                layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
-                layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-            elif get_bool_env_var("MOE_PADDING"):
-                # If ROCm, apply weight padding (min. Mem channel contention) only if set
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
+            if _is_hip:
+                self.process_weights_hip_scale_padding(layer)
             return

         # If checkpoint is fp8, we need to handle that the
@@ -712,7 +721,7 @@ class Fp8MoEMethod:
             )

             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if
+            if _is_hip:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -762,35 +771,85 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )

-            if
-
-                layer.w13_weight = torch.nn.Parameter(
-                    permute_weight(layer.w13_weight.data),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    permute_weight(layer.w2_weight.data),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                # ROCm (CK_MOE): using column-wise scaling
-                layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
-                layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-            elif get_bool_env_var("MOE_PADDING"):
-                # If ROCm, apply weight padding (min. Mem channel contention) only if set
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
+            if _is_hip:
+                self.process_weights_hip_scale_padding(layer)
             return

+    def process_weights_hip_int4(self, layer: Module):
+        # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+        # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
+        # Weight Permutation
+        layer.w13_weight = torch.nn.Parameter(
+            permute_weight(layer.w13_weight.data),
+            requires_grad=False,
+        )
+        torch.cuda.empty_cache()
+        layer.w2_weight = torch.nn.Parameter(
+            permute_weight(layer.w2_weight.data),
+            requires_grad=False,
+        )
+        torch.cuda.empty_cache()
+
+        # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
+        # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
+        # We won't do requant each expert's fp8 weight (not direct available),
+        # instead we adjust half of INT4 w13_weight_scale1 numbers
+        assert layer.w13_weight_scale is not None
+        shard_size = layer.intermediate_size_per_partition
+        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+        for expert_id in range(layer.num_experts):
+            start = 0
+            max_w13_scale_fp8 = max_w13_scales[expert_id]
+            for shard_id in range(2):
+                if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
+                    int4_rescale = (
+                        layer.w13_weight_scale[expert_id][shard_id] / max_w13_scale_fp8
+                    )
+                    layer.w13_weight_scale1[expert_id][
+                        start : start + shard_size
+                    ] *= int4_rescale
+                start += shard_size
+
+        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+
+        # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
+        # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
+        for expert_id in range(layer.num_experts):
+            layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
+            layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
+
+    def process_weights_hip_scale_padding(self, layer: Module, padding_size: int):
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+            padding_size,  # Avoid circular import
+        )
+
+        if get_bool_env_var("CK_MOE"):
+            layer.w13_weight = torch.nn.Parameter(
+                permute_weight(layer.w13_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                permute_weight(layer.w2_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            # ROCm (CK_MOE): using column-wise scaling
+            layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
+            layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
+        elif get_bool_env_var("MOE_PADDING"):
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            layer.w13_weight = torch.nn.Parameter(
+                F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -823,8 +882,24 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )

-        if
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+            # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+            assert not no_combine, f"{no_combine=} is not supported."
+            return asm_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                layer.w13_weight_scale1,
+                layer.w2_weight_scale1,
+                activation=activation,
+            )
+        if _is_hip and get_bool_env_var("CK_MOE"):
             # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
+            assert (
+                activation == "silu"
+            ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
                 return asm_moe(
@@ -835,10 +910,6 @@ class Fp8MoEMethod:
                     topk_ids,
                     layer.w13_weight_scale_inv,
                     layer.w2_weight_scale_inv,
-                    None,
-                    None,
-                    False,
-                    None,
                     block_shape=tuple(self.quant_config.weight_block_size),
                     expert_mask=None,
                 )
@@ -851,9 +922,6 @@ class Fp8MoEMethod:
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
-                    None,
-                    None,
-                    False,
                 )
         else:
             # Expert fusion with FP8 quantization