sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py (new file)

@@ -0,0 +1,32 @@
+import logging
+
+from sglang.srt.utils import get_bool_env_var, get_device_sm
+
+logger = logging.getLogger(__name__)
+
+
+def _compute_enable_deep_gemm():
+    sm_version = get_device_sm()
+    if sm_version < 90:
+        return False
+
+    try:
+        import deep_gemm
+    except ImportError:
+        logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
+        return False
+
+    return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
+
+
+ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
+
+try:
+    from deep_gemm import fp8_gemm_nt
+
+    # They have not given a name to this breaking change
+    DEEPGEMM_BLACKWELL = True
+except ImportError:
+    DEEPGEMM_BLACKWELL = False
+
+DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py (new file)

@@ -0,0 +1,110 @@
+import logging
+from contextlib import contextmanager
+from typing import Tuple
+
+import torch
+
+from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
+from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+    DEEPGEMM_BLACKWELL,
+    DEEPGEMM_SCALE_UE8M0,
+    ENABLE_JIT_DEEPGEMM,
+)
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm
+
+    if DEEPGEMM_BLACKWELL:
+        from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
+        from deep_gemm import (
+            fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+        )
+        from deep_gemm import (
+            m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+        )
+    else:
+        from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
+        from deep_gemm import get_col_major_tma_aligned_tensor
+        from deep_gemm import (
+            m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+        )
+        from deep_gemm import (
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+        )
+
+
+def grouped_gemm_nt_f8f8bf16_masked(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m: int,
+    recipe=None,
+):
+    num_groups, _, k = lhs[0].shape
+    _, n, _ = rhs[0].shape
+    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
+
+    with compile_utils.deep_gemm_execution_hook(
+        expected_m, n, k, num_groups, kernel_type
+    ):
+        _grouped_gemm_nt_f8f8bf16_masked_raw(
+            lhs,
+            rhs,
+            out,
+            masked_m,
+            expected_m,
+            **({"recipe": recipe} if DEEPGEMM_BLACKWELL else {})
+        )
+
+
+def grouped_gemm_nt_f8f8bf16_contig(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    m_indices: torch.Tensor,
+):
+    m, k = lhs[0].shape
+    num_groups, n, _ = rhs[0].shape
+    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
+
+    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
+        _grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices)
+
+
+def gemm_nt_f8f8bf16(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+):
+    m, k = lhs[0].shape
+    n, _ = rhs[0].shape
+    num_groups = 1
+    kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
+
+    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
+        _gemm_nt_f8f8bf16_raw(
+            lhs,
+            rhs,
+            out,
+        )
+
+
+def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
+    compile_utils.update_deep_gemm_config(gpu_id, server_args)
+
+
+@contextmanager
+def configure_deep_gemm_num_sms(num_sms):
+    if num_sms is None:
+        yield
+    else:
+        original_num_sms = deep_gemm.get_num_sms()
+        deep_gemm.set_num_sms(num_sms)
+        try:
+            yield
+        finally:
+            deep_gemm.set_num_sms(original_num_sms)
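The wrapper hides the naming differences between pre-Blackwell and Blackwell deep_gemm releases behind stable sglang-side names and runs every call inside compile_utils.deep_gemm_execution_hook. A hedged usage sketch follows (not part of the diff; the shapes, dtypes, and scale layout are simplified assumptions, and sglang's own call sites go through fp8_kernel.py, which produces the TMA-aligned scale layout; the sketch imports the entrypoint module directly rather than relying on the package re-exports used elsewhere in the diff):

import torch
from sglang.srt.layers.quantization.deep_gemm_wrapper import entrypoint as dgw

M, N, K = 128, 4096, 4096
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)              # fp8 activations
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)              # fp8 weights ("nt" layout)
As = torch.ones(M, K // 128, device="cuda", dtype=torch.float32)          # per-token-group scales
Bs = torch.ones(N // 128, K // 128, device="cuda", dtype=torch.float32)   # per-block scales
C = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)                # DeepGEMM output must be bf16

# Optionally cap the SM count for the duration of the call (e.g. to leave SMs
# free for overlapped work); the context manager restores the old value.
with dgw.configure_deep_gemm_num_sms(64):
    dgw.gemm_nt_f8f8bf16((A, As), (B, Bs), C)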
sglang/srt/layers/quantization/fp8.py

@@ -64,9 +64,12 @@ from sglang.srt.layers.quantization.utils import (
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     get_bool_env_var,
+    is_cpu,
     is_cuda,
     is_hip,
+    is_npu,
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,

@@ -74,6 +77,9 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 _is_fp8_fnuz = is_fp8_fnuz()
 

@@ -82,10 +88,11 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_hip:
     from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm._custom_ops import scaled_fp8_quant
 
 

@@ -1045,15 +1052,15 @@ class Fp8MoEMethod:
         if _use_hip_int4:
             # TODO: add triton kernel and add check _use_aiter
             assert not no_combine, f"{no_combine=} is not supported."
-            return
+            return fused_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
                 topk_weights,
                 topk_ids,
-                QuantType.per_Token,
-                layer.w13_weight_scale1,
-                layer.w2_weight_scale1,
+                quant_type=QuantType.per_Token,
+                w1_scale=layer.w13_weight_scale1,
+                w2_scale=layer.w2_weight_scale1,
                 activation=(
                     ActivationType.Silu if activation == "silu" else ActivationType.Gelu
                 ),

@@ -1062,31 +1069,32 @@ class Fp8MoEMethod:
         if _use_aiter:
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
-
-                assert (
-                    activation == "silu"
-                ), f"_use_aiter: FP8 bloack_quant {activation=} will be supported later, unset _use_aiter"
-                return asm_moe(
+                return fused_moe(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
-                    layer.w13_weight_scale_inv,
-                    layer.w2_weight_scale_inv,
-
+                    w1_scale=layer.w13_weight_scale_inv,
+                    w2_scale=layer.w2_weight_scale_inv,
+                    quant_type=QuantType.per_128x128,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                     expert_mask=None,
                 )
             else:
-                return
+                return fused_moe(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
-                    QuantType.per_Token,
-                    layer.w13_weight_scale1,
-                    layer.w2_weight_scale1,
+                    quant_type=QuantType.per_Token,
+                    w1_scale=layer.w13_weight_scale1,
+                    w2_scale=layer.w2_weight_scale1,
                     activation=(
                         ActivationType.Silu
                         if activation == "silu"
sglang/srt/layers/quantization/fp8_kernel.py

@@ -23,7 +23,8 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.
+from sglang.math_utils import align
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_device_core_count,

@@ -44,10 +45,6 @@ if _is_cuda:
         sgl_per_token_quant_fp8,
     )
 
-    from sglang.srt.layers.quantization.deep_gemm import (
-        gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
-    )
-
 logger = logging.getLogger(__name__)
 
 

@@ -67,7 +64,6 @@ else:
 fp8_max = torch.finfo(fp8_dtype).max
 fp8_min = -fp8_max
 
-
 if supports_custom_op():
 
     def deep_gemm_fp8_fp8_bf16_nt(

@@ -77,7 +73,7 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-
+        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
 
     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,

@@ -280,6 +276,7 @@ def sglang_per_token_group_quant_fp8(
     eps: float = 1e-10,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
 ):
     assert (
         x.shape[-1] % group_size == 0

@@ -287,8 +284,21 @@ def sglang_per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
 
     x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
-    if
+    if scale_ue8m0:
+        assert column_major_scales and scale_tma_aligned
+        x_q_mn, x_q_k = x.shape
+        x_s_mn, x_s_k = x_q_mn, x_q_k // 128
+        aligned_mn = align(x_s_mn, 4)
+        aligned_k = align(x_s_k, 4)
+        # TODO(FIXME): Fix cuda kernel and recover here to empty.
+        x_s = torch.zeros(
+            (aligned_k // 4, aligned_mn),
+            device=x.device,
+            dtype=torch.int,
+        ).transpose(0, 1)[:x_s_mn, :]
+    elif column_major_scales:
         if scale_tma_aligned:
+            # TODO extract "align" function
            # aligned to 4 * sizeof(float)
            aligned_size = (x.shape[-2] + 3) // 4 * 4
            x_s = torch.empty(

@@ -309,7 +319,9 @@ def sglang_per_token_group_quant_fp8(
             dtype=torch.float32,
         )
     if x.shape[0] > 0:
-        sgl_per_token_group_quant_fp8(
+        sgl_per_token_group_quant_fp8(
+            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        )
 
     return x_q, x_s
 
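For the new scale_ue8m0 branch above, a worked shape example may help (a sketch, not from the diff; it assumes group_size == 128, which the branch itself hard-codes). Scales are stored as UE8M0 exponents packed four-per-int32 along the group dimension, over a 4-aligned, column-major backing buffer:

import torch

# x: a (7, 512) input, quantized in groups of 128 along the last dim
x_s_mn, x_s_k = 7, 512 // 128            # 7 rows of scales, 4 groups per row
aligned_mn, aligned_k = 8, 4              # align(7, 4) == 8, align(4, 4) == 4
backing = torch.zeros(aligned_k // 4, aligned_mn, dtype=torch.int)   # (1, 8) int32
x_s = backing.transpose(0, 1)[:x_s_mn, :]                            # logical view (7, 1)

# Each int32 packs four UE8M0 (power-of-two) scales, which is why the matmul-side
# checks in the next hunks accept cdiv(cdiv(K, 128), 4) columns for int scales.
assert x_s.shape == (7, 1)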
@@ -754,7 +766,15 @@ def prepare_block_fp8_matmul_inputs(
     assert A.shape[-1] == B.shape[-1]
     assert A.shape[:-1] == As.shape[:-1]
     assert A.is_contiguous()
-
+
+    if As.dtype == torch.float:
+        assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+    elif As.dtype == torch.int:
+        assert (
+            triton.cdiv(triton.cdiv(A.shape[-1], block_k), 4) == As.shape[-1]
+        ), f"{A.shape=} {As.shape=} {block_size=}"
+    else:
+        raise NotImplementedError
 
     M = A.numel() // A.shape[-1]
 

@@ -762,8 +782,17 @@ def prepare_block_fp8_matmul_inputs(
     assert B.is_contiguous()
     assert Bs.ndim == 2
     N, K = B.shape
-
-
+
+    if Bs.dtype == torch.float:
+        assert triton.cdiv(N, block_n) == Bs.shape[0]
+        assert triton.cdiv(K, block_k) == Bs.shape[1]
+    elif Bs.dtype == torch.int:
+        assert N == Bs.shape[0], f"{B.shape=} {Bs.shape=} {block_size=}"
+        assert (
+            triton.cdiv(triton.cdiv(K, block_k), 4) == Bs.shape[1]
+        ), f"{B.shape=} {Bs.shape=} {block_size=}"
+    else:
+        raise NotImplementedError
 
     C_shape = A.shape[:-1] + (N,)
     C = A.new_empty(C_shape, dtype=output_dtype)

@@ -782,12 +811,12 @@ def w8a8_block_fp8_matmul_deepgemm(
     M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
 
     # Deepgemm only supports output tensor type as bfloat16
-    assert C.dtype == torch.bfloat16 and
+    assert C.dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
 
     if supports_custom_op():
         torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
     else:
-
+        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
 
     return C
 

@@ -881,7 +910,7 @@ def w8a8_block_fp8_matmul(
     block_size: List[int],
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    if output_dtype == torch.bfloat16 and
+    if output_dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
         return w8a8_block_fp8_matmul_deepgemm(
             A, B, As, Bs, block_size, output_dtype=output_dtype
         )
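To make the new dtype-dependent assertions concrete, here is a small worked example (a sketch, not from the diff) of the scale shapes prepare_block_fp8_matmul_inputs now accepts for block_size = [128, 128]:

M, N, K = 256, 512, 1024
block_n, block_k = 128, 128


def cdiv(a, b):
    # ceiling division, like triton.cdiv
    return -(-a // b)


# float32 scales: one scale per 128-wide group of A, one per 128x128 block of B
As_float = (M, cdiv(K, block_k))                   # (256, 8)
Bs_float = (cdiv(N, block_n), cdiv(K, block_k))    # (4, 8)

# int32 scales (packed UE8M0): four exponents per int32 along K,
# and one row of packed scales per output row of B
As_int = (M, cdiv(cdiv(K, block_k), 4))            # (256, 2)
Bs_int = (N, cdiv(cdiv(K, block_k), 4))            # (512, 2)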
sglang/srt/layers/quantization/fp8_utils.py

@@ -1,9 +1,10 @@
-import os
-from curses import flash
 from typing import Callable, List, Optional, Tuple
 
+import einops
 import torch
 
+from sglang.math_utils import align
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
 

@@ -14,7 +15,6 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
     fp8_dtype,
     fp8_max,

@@ -137,7 +137,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
         return cutlass_w8a8_block_fp8_linear_with_fallback
     elif _use_aiter:
         return aiter_w8a8_block_fp8_linear
-    elif
+    elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
         return deepgemm_w8a8_block_fp8_linear_with_fallback
     else:
         return triton_w8a8_block_fp8_linear

@@ -238,7 +238,14 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         block_size[1],
         column_major_scales=True,
         scale_tma_aligned=True,
+        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
     )
+
+    # NOTE(alcanderian): Useless when scale is packed to int32
+    # if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
+    #     _check_ue8m0("x_scale", x_scale)
+    #     _check_ue8m0("weight_scale", ws)
+
     output = w8a8_block_fp8_matmul_deepgemm(
         q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
     )

@@ -247,6 +254,11 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
     return output.to(dtype=output_dtype).view(*output_shape)
 
 
+def _check_ue8m0(name, x):
+    x_ceil = ceil_to_ue8m0(x)
+    assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
+
+
 def aiter_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -369,27 +381,80 @@ def block_quant_dequant(
     The output is an unquantized tensor with dtype.
     """
     block_n, block_k = block_size[0], block_size[1]
-    n, k = x_q_block.shape
-    n_tiles = (n + block_n - 1) // block_n
-    k_tiles = (k + block_k - 1) // block_k
-    assert n_tiles == x_s.shape[0]
-    assert k_tiles == x_s.shape[1]
+    *_, n, k = x_q_block.shape
 
-
+    # ... n_scale k_scale -> ... (n_scale block_n) (k_scale block_k)
+    x_scale_repeat = x_s.repeat_interleave(block_n, dim=-2).repeat_interleave(
+        block_k, dim=-1
+    )
+    x_scale_repeat = x_scale_repeat[..., :n, :k]
+
+    return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)
+
+
+def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
+    assert isinstance(weight, torch.nn.Parameter)
+    assert isinstance(weight_scale_inv, torch.nn.Parameter)
+    weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
+        weight, weight_scale_inv, weight_block_size
+    )
+
+
+def _requant_weight_ue8m0(
+    weight: torch.Tensor,
+    weight_scale_inv: torch.Tensor,
+    weight_block_size: List[int],
+):
+    assert weight_block_size == [128, 128]
+
+    *_, n, k = weight.shape
+
+    weight_dequant = block_quant_dequant(
+        weight,
+        weight_scale_inv,
+        weight_block_size,
+        torch.bfloat16,
+    )
+
+    weight_dequant_flat = weight_dequant.view((-1, k))
+    out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)
+
+    out_w = out_w_flat.view(weight.shape)
+    out_s = out_s_flat.view(weight_scale_inv.shape)
+
+    # NOTE copy and modified from DeepGEMM
+    def _transform_scale(sf, mn: int):
+        import deep_gemm.utils.layout
+
+        sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
+        sf = deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor(sf)
+        return sf
+
+    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
+
+    return out_w, out_s
+
+
+# COPIED FROM DeepGEMM
+def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert x.dim() == 2
+    m, n = x.shape
+    x_padded = torch.zeros(
+        (align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device
+    )
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    sf = ceil_to_ue8m0(x_amax / 448.0)
+    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
+        x_view.size(0), x_view.size(2)
+    )
 
-    for j in range(n_tiles):
-        for i in range(k_tiles):
-            x_q_block_tile = x_q_block[
-                j * block_n : min((j + 1) * block_n, n),
-                i * block_k : min((i + 1) * block_k, k),
-            ]
-            x_dq_block_tile = x_dq_block[
-                j * block_n : min((j + 1) * block_n, n),
-                i * block_k : min((i + 1) * block_k, k),
-            ]
-            x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i]
 
-
+# COPIED FROM DeepGEMM
+def ceil_to_ue8m0(x: torch.Tensor):
+    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
 
 
 def channel_quant_to_tensor_quant(
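Two of the helpers added above are easiest to understand with tiny examples (a sketch, not from the diff):

import torch

# ceil_to_ue8m0 rounds a scale up to the next power of two, i.e. to a value that
# UE8M0 (an 8-bit exponent with no mantissa) can represent exactly:
s = torch.tensor([0.3, 1.0, 3.0])
print(torch.pow(2.0, torch.ceil(torch.log2(s.abs()))))  # -> 0.5, 1.0, 4.0

# The rewritten block_quant_dequant broadcasts the (n/128, k/128) scale grid over
# the quantized tensor with repeat_interleave instead of the old Python tile loop:
x_q = torch.randn(256, 256).to(torch.float8_e4m3fn)
x_s = torch.rand(2, 2)  # one scale per 128x128 block
x_scale_repeat = x_s.repeat_interleave(128, dim=-2).repeat_interleave(128, dim=-1)
x_dq = (x_q.to(torch.float32) * x_scale_repeat).to(torch.bfloat16)
assert x_dq.shape == (256, 256)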
sglang/srt/layers/quantization/modelopt_quant.py

@@ -29,11 +29,17 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, next_power_of_2
 
 if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
 
+try:
+    from flashinfer import fp4_quantize as fp4_quantize
+    from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
+except ImportError:
+    flashinfer_cutlass_fused_moe = None
+
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
 

@@ -429,6 +435,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         layer.alpha = Parameter(
             layer.input_scale * layer.weight_scale_2, requires_grad=False
         )
+        layer.input_scale_inv = Parameter(
+            (1 / input_scale_2).to(torch.float32), requires_grad=False
+        )
 
         # Pad and blockwise interleave weight_scale
         scales = layer.weight_scale

@@ -467,7 +476,7 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         output_shape = [x_m, w_n]
 
         # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
-        x_fp4, x_scale_interleaved = scaled_fp4_quant(x,
+        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv)
 
         assert x_fp4.dtype == torch.uint8
         assert x_scale_interleaved.dtype == torch.float8_e4m3fn

@@ -521,6 +530,7 @@ class ModelOptNvFp4FusedMoEMethod:
                 " quantization. Please use Blackwell and"
                 " above."
             )
+        self.enable_flashinfer_moe = False
 
     def create_weights(
         self,

@@ -674,7 +684,10 @@ class ModelOptNvFp4FusedMoEMethod:
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
 
-
+        if self.enable_flashinfer_moe:
+            w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
+        else:
+            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False,

@@ -700,14 +713,19 @@ class ModelOptNvFp4FusedMoEMethod:
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
 
         # GEMM 2
+        if self.enable_flashinfer_moe:
+            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        else:
+            w2_input_scale = layer.w2_input_scale
+
         layer.g2_alphas = Parameter(
-            (
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )
 
         # This is for quantization, so we need to invert it.
         layer.w2_input_scale_quant = Parameter(
-            (1 /
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )
 
         assert (

@@ -727,11 +745,16 @@ class ModelOptNvFp4FusedMoEMethod:
         layer.cutlass_moe_params = CutlassMoEParams(
             CutlassMoEType.BlockscaledFP4,
             device,
-            num_experts=layer.num_experts,
+            num_experts=layer.num_experts,  # global num experts
             intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
             hidden_size=layer.w13_weight.shape[2] * 2,
         )  # k
 
+    @property
+    def load_up_proj_weight_first(self) -> bool:
+        # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
+        return self.enable_flashinfer_moe
+
     def apply(
         self,
         layer: torch.nn.Module,

@@ -750,11 +773,13 @@ class ModelOptNvFp4FusedMoEMethod:
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        ep_rank: Optional[int] = None,
+        ep_size: Optional[int] = None,
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> torch.Tensor:
 
         assert activation == "silu", "Only SiLU activation is supported."
-
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
 
         topk_weights, topk_ids = select_experts(

@@ -771,6 +796,35 @@ class ModelOptNvFp4FusedMoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )
 
+        if self.enable_flashinfer_moe:
+            assert (
+                not apply_router_weight_on_input
+            ), "apply_router_weight_on_input is not supported for Flashinfer"
+            # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
+            # and fp4 quantized weights loaded from the checkpoint
+            output = flashinfer_cutlass_fused_moe(
+                x,
+                topk_ids.to(torch.int),
+                topk_weights,
+                layer.w13_weight.view(torch.long),
+                layer.w2_weight.view(torch.long),
+                x.dtype,
+                quant_scales=[
+                    layer.w13_input_scale_quant,
+                    layer.w13_blockscale_swizzled.view(torch.int32),
+                    layer.g1_alphas,
+                    layer.w2_input_scale_quant,
+                    layer.w2_blockscale_swizzled.view(torch.int32),
+                    layer.g2_alphas,
+                ],
+                ep_size=ep_size,
+                ep_rank=ep_rank,
+                tp_size=tp_size,
+                tp_rank=tp_rank,
+                tune_max_num_tokens=next_power_of_2(x.shape[0]),
+            )
+            return output[0]
+
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
         return cutlass_moe_fp4(