sglang 0.5.1.post1__py3-none-any.whl → 0.5.1.post3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/decode.py +4 -0
- sglang/srt/disaggregation/prefill.py +4 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/openai/protocol.py +27 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/entrypoints/tool.py +7 -7
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +16 -7
- sglang/srt/layers/attention/ascend_backend.py +218 -111
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +76 -91
- sglang/srt/layers/attention/utils.py +15 -94
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/moe/cutlass_moe.py +0 -15
- sglang/srt/layers/moe/ep_moe/layer.py +1 -7
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -7
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/mxfp4.py +16 -23
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/lora_manager.py +29 -12
- sglang/srt/managers/cache_controller.py +223 -156
- sglang/srt/managers/detokenizer_manager.py +5 -0
- sglang/srt/managers/io_struct.py +30 -0
- sglang/srt/managers/scheduler.py +58 -7
- sglang/srt/managers/scheduler_metrics_mixin.py +15 -0
- sglang/srt/managers/tokenizer_manager.py +36 -3
- sglang/srt/mem_cache/hicache_storage.py +31 -20
- sglang/srt/mem_cache/hiradix_cache.py +12 -3
- sglang/srt/mem_cache/memory_pool.py +73 -14
- sglang/srt/mem_cache/memory_pool_host.py +3 -2
- sglang/srt/mem_cache/radix_cache.py +1 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +5 -13
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +85 -81
- sglang/srt/metrics/collector.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +12 -3
- sglang/srt/models/gpt_oss.py +2 -1
- sglang/srt/models/qwen2_5_vl.py +1 -0
- sglang/srt/offloader.py +115 -0
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/server_args.py +10 -5
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +59 -12
- sglang/test/test_cutlass_moe.py +33 -28
- sglang/version.py +1 -1
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/METADATA +6 -5
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/RECORD +69 -65
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py

@@ -1,26 +1,22 @@
 import logging
 import os
 from contextlib import contextmanager
-from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import
+from typing import Dict, List, Tuple
 
-
+import torch
+from tqdm import tqdm
 
 from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
-    DEEPGEMM_BLACKWELL,
     ENABLE_JIT_DEEPGEMM,
 )
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, get_int_env_var
+from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var
 
 logger = logging.getLogger(__name__)
 
-if ENABLE_JIT_DEEPGEMM
-
-    from deep_gemm.jit import build
-    from deep_gemm.jit_kernels.gemm import get_best_configs
-    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm
 
 
 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
@@ -40,19 +36,7 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
 # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
 # NVRTC may have performance loss with some cases.
 # And NVCC JIT speed is also 9x faster in the ref commit
-
-if ENABLE_JIT_DEEPGEMM:
-    try:
-        from deep_gemm.jit.compiler import get_nvcc_compiler
-
-        get_nvcc_compiler()
-    except:
-        logger.warning(
-            "NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
-            "and may have performance loss with some cases."
-        )
-        _USE_NVRTC_DEFAULT = "1"
-os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
 
 
 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
@@ -75,7 +59,7 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     # Default each rank will try compile all Ms to
     # load all symbols at the launch stages.
     # Avoid loading symbols at the serving stages.
-    _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE
+    _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE
 
 
 class DeepGemmKernelType(IntEnum):
@@ -84,185 +68,15 @@ class DeepGemmKernelType(IntEnum):
     GEMM_NT_F8F8BF16 = auto()
 
 
-@dataclass
-class DeepGemmKernelHelper:
-    name: str
-    compile_func: Callable[
-        [
-            int,
-            int,
-            int,
-            Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-        ],
-        None,
-    ]
-    configure_func: Callable[
-        [int, int, int, int, int],
-        Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-    ]
-
-
 _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
 
 
-# TODO improve
-def _compile_warning_1():
-    if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
-        logger.warning(
-            "Entering DeepGEMM JIT Pre-Compile session. "
-            "It may takes a long time (typically 10-20 mins) "
-            "if you have not run `sglang.compile_deep_gemm`. "
-            "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
-            " for pre-compilation to reduce the overhead if you have not run it before. "
-            "For example: "
-            "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
-        )
-
-
-# TODO improve naming
-def _compile_warning_2():
-    logger.warning(
-        "Entering DeepGEMM JIT Single Kernel Compile session. "
-        "And it will makes inference throughput becomes flaky. "
-        "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
-        " for pre-compilation to solve this issue. "
-        "For example: "
-        "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
-    )
-
-
-def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
-    n: int,
-    k: int,
-    num_groups: int,
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-
-    kwargs = {
-        "GEMM_TYPE": GemmType.GroupedMasked,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": num_groups,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
-def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
-    n: int,
-    k: int,
-    num_groups: int,
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-    kwargs = {
-        "GEMM_TYPE": GemmType.GroupedContiguous,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": 1,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
-def _compile_gemm_nt_f8f8bf16_one(
-    n: int,
-    k: int,
-    _: int,  # _ is a dummy parameter to align with other interfaces
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-    kwargs = {
-        "GEMM_TYPE": GemmType.Normal,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": 1,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
-# TODO further refactor warmup-related
-_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
-    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
-        name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
-        compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
-        configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs(
-            m, n, k, num_groups, num_sms, is_grouped_masked=True
-        ),
-    ),
-    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
-        name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
-        compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
-        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
-            m, n, k, 1, num_sms, is_grouped_contiguous=True
-        ),
-    ),
-    DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
-        name="gemm_fp8_fp8_bf16_nt",
-        compile_func=_compile_gemm_nt_f8f8bf16_one,
-        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
-            m, n, k, 1, num_sms
-        ),
-    ),
-}
-
-
+# TODO improve code
 def _maybe_compile_deep_gemm_one_type_all(
     kernel_type: DeepGemmKernelType,
     n: int,
     k: int,
     num_groups: int,
-    m_list: Optional[List[int]] = None,
 ) -> None:
     global _INITIALIZATION_DICT
     global _BUILTIN_M_LIST
@@ -275,61 +89,145 @@ def _maybe_compile_deep_gemm_one_type_all(
     ):
         _INITIALIZATION_DICT[query_key] = True
 
-
-
+        # TODO maybe improve logs
+        if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
+            logger.warning(
+                "Entering DeepGEMM JIT Pre-Compile session. "
+                "It may takes a long time (typically 10-20 mins) "
+                "if you have not run `sglang.compile_deep_gemm`. "
+                "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
+                " for pre-compilation to reduce the overhead if you have not run it before. "
+                "For example: "
+                "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
+            )
+
         logger.info(
             f"Try DeepGEMM JIT Compiling for "
-            f"<{
+            f"<{kernel_type.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
            f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
         )
 
-
-
-
-
-
-
-            kernel_helper.configure_func(m, n, k, num_groups, num_sms)
-        )
-        compile_func = lambda config: kernel_helper.compile_func(
-            n, k, num_groups, config
+        _compile_deep_gemm_one_type_all(
+            kernel_type=kernel_type,
+            n=n,
+            k=k,
+            num_groups=num_groups,
+            m_list=_BUILTIN_M_LIST,
         )
-        thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
 
 
-
-def
-
-
-
+# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
+def _compile_deep_gemm_one_type_all(
+    kernel_type: DeepGemmKernelType,
+    n: int,
+    k: int,
+    num_groups: int,
+    m_list: List[int],
+) -> None:
+    if kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG:
+        m_alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()
+        m_list = sorted(list(set(m for m in m_list if m % m_alignment == 0)))
 
-
+    executor = _BaseWarmupExecutor.create(
+        kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
+    )
 
-
+    # TODO can use multi thread
+    for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
+        executor.execute(m=m)
 
-    def __patched_func(self, *args, **kwargs):
-        ret = origin_func(self, *args, **kwargs)
-        if ret is None:
-            kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
-            if not DEEPGEMM_BLACKWELL:
-                _compile_warning_2()
-            logger.warning(
-                f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
-            )
-        return ret
 
-
-
-
+class _BaseWarmupExecutor:
+    @staticmethod
+    def create(kernel_type: DeepGemmKernelType, **kwargs):
+        return {
+            DeepGemmKernelType.GEMM_NT_F8F8BF16: _NormalWarmupExecutor,
+            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: _GroupedContWarmupExecutor,
+            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: _GroupedMaskedWarmupExecutor,
+        }[kernel_type](**kwargs)
+
+    def execute(self, m):
+        raise NotImplementedError
+
+
+def _empty_token_fp8(size):
+    *dims, k = size
+    return (
+        torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
+        torch.empty(
+            (*dims, ceil_div(k, _BLOCK_SIZE)), device="cuda", dtype=torch.float32
+        ),
+    )
+
+
+def _empty_block_fp8(size):
+    *dims, n, k = size
+    return (
+        torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
+        torch.empty(
+            (*dims, ceil_div(n, _BLOCK_SIZE), ceil_div(k, _BLOCK_SIZE)),
+            device="cuda",
+            dtype=torch.float32,
+        ),
+    )
+
+
+_BLOCK_SIZE = 128
+
+
+class _NormalWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((n, k))
+        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
+
+    def execute(self, m):
+        deep_gemm.fp8_gemm_nt(
+            (self.lhs_q[:m], self.lhs_s[:m]),
+            (self.rhs_q, self.rhs_s),
+            self.out[:m],
+        )
+
+
+class _GroupedContWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
+        self.m_indices = torch.zeros((max_m,), device="cuda", dtype=torch.int32)
+        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
+
+    def execute(self, m):
+        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
+            (self.lhs_q[:m], self.lhs_s[:m]),
+            (self.rhs_q, self.rhs_s),
+            self.out[:m],
+            m_indices=self.m_indices[:m],
+        )
+
+
+class _GroupedMaskedWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((num_groups, max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
+        self.masked_m = torch.zeros((num_groups,), device="cuda", dtype=torch.int32)
+        self.out = torch.empty(
+            (num_groups, max_m, n), device="cuda", dtype=torch.bfloat16
+        )
+
+    def execute(self, m):
+        deep_gemm.fp8_m_grouped_gemm_nt_masked(
+            (self.lhs_q, self.lhs_s),
+            (self.rhs_q, self.rhs_s),
+            self.out,
+            masked_m=self.masked_m,
+            # DeepGEMM uses `expect_m` instead of input shape for `get_best_config`
+            expected_m=m,
+        )
 
 
 @contextmanager
 def deep_gemm_execution_hook(
     m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
 ):
-
-
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        yield
+    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+    yield

sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py

@@ -24,14 +24,12 @@ def _compute_enable_deep_gemm():
     return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
 
 
-
+def _is_blackwell_arch() -> bool:
+    major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
+    return major == 10
 
-try:
-    from deep_gemm import fp8_gemm_nt
 
-
-    DEEPGEMM_BLACKWELL = True
-except ImportError:
-    DEEPGEMM_BLACKWELL = False
+ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
 
+DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch()
 DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL

sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py

@@ -16,33 +16,16 @@ logger = logging.getLogger(__name__)
 
 if ENABLE_JIT_DEEPGEMM:
     import deep_gemm
-
-    if DEEPGEMM_BLACKWELL:
-        from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
-        from deep_gemm import (
-            fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
-        )
-        from deep_gemm import (
-            m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
-        )
-    else:
-        from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
-        from deep_gemm import get_col_major_tma_aligned_tensor
-        from deep_gemm import (
-            m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
-        )
-        from deep_gemm import (
-            m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
-        )
+    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
 
 
+# TODO maybe rename these functions
 def grouped_gemm_nt_f8f8bf16_masked(
     lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs: Tuple[torch.Tensor, torch.Tensor],
     out: torch.Tensor,
     masked_m: torch.Tensor,
     expected_m: int,
-    recipe=None,
 ):
     num_groups, _, k = lhs[0].shape
     _, n, _ = rhs[0].shape
@@ -51,13 +34,12 @@ def grouped_gemm_nt_f8f8bf16_masked(
     with compile_utils.deep_gemm_execution_hook(
         expected_m, n, k, num_groups, kernel_type
     ):
-
+        deep_gemm.fp8_m_grouped_gemm_nt_masked(
             lhs,
             rhs,
             out,
             masked_m,
             expected_m,
-            **({"recipe": recipe} if DEEPGEMM_BLACKWELL else {})
         )
 
 
@@ -72,7 +54,7 @@ def grouped_gemm_nt_f8f8bf16_contig(
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
 
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
-
+        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)
 
 
 def gemm_nt_f8f8bf16(
@@ -86,7 +68,7 @@ def gemm_nt_f8f8bf16(
     kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
 
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
-
+        deep_gemm.fp8_gemm_nt(
             lhs,
             rhs,
             out,

sglang/srt/layers/quantization/fp8.py

@@ -64,7 +64,6 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
-from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -72,6 +71,8 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_sm90_supported,
+    is_sm100_supported,
     log_info_on_rank0,
     next_power_of_2,
     print_warning_once,

sglang/srt/layers/quantization/fp8_kernel.py

@@ -298,7 +298,7 @@ def _per_token_group_quant_8bit_raw(
     )
 
     if scale_ue8m0:
-        from deep_gemm
+        from deep_gemm import transform_sf_into_required_layout
 
         assert group_size == 128
         x_s = transform_sf_into_required_layout(
@@ -338,7 +338,7 @@ def _per_token_group_quant_8bit_fuse_silu_and_mul(
     #     scale_ue8m0=scale_ue8m0,
     # )
 
-    from deep_gemm
+    from deep_gemm import transform_sf_into_required_layout
 
     from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
 

sglang/srt/layers/quantization/fp8_utils.py

@@ -5,7 +5,7 @@ import torch
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.
+from sglang.srt.utils import is_sm100_supported
 
 try:
     from vllm import _custom_ops as ops
@@ -459,7 +459,7 @@ def _requant_weight_ue8m0(
         import deep_gemm.utils.layout
 
         sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
-        sf = deep_gemm.utils.layout.
+        sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
         return sf
 
     out_s = _transform_scale(out_s, mn=out_w.shape[-2])

sglang/srt/layers/quantization/modelopt_quant.py

@@ -876,7 +876,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             data=torch.empty(
                 layer.num_local_experts,
                 2 * intermediate_size_per_partition,
-                # 2 fp4 items are packed in the input dimension
                 hidden_size // self.quant_config.group_size,
                 dtype=weight_scale_dtype,
             ),
@@ -895,7 +894,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             data=torch.empty(
                 layer.num_local_experts,
                 hidden_size,
-                # 2 fp4 items are packed in the input dimension
                 intermediate_size_per_partition // self.quant_config.group_size,
                 dtype=weight_scale_dtype,
             ),
@@ -1212,11 +1210,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         # Process w13 weights
         w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
+        del layer.w13_weight_scale
         layer.w13_blockscale_swizzled.data.copy_(w13_blockscale_swizzled)
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
 
         # Process w2 weights
         w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+        del layer.w2_weight_scale
         layer.w2_blockscale_swizzled.data.copy_(w2_blockscale_swizzled)
         layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 

sglang/srt/layers/quantization/mxfp4.py

@@ -29,14 +29,13 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
-    get_bool_env_var,
     is_cuda,
     is_flashinfer_available,
     is_hip,
+    is_sm100_supported,
     is_triton_kernels_available,
     log_info_on_rank0,
     mxfp_supported,
@@ -146,27 +145,21 @@ def _quant_dequant_mxfp4_fake(
     return torch.empty_like(x)
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        mutates_args=[],
-        fake_impl=_quant_dequant_mxfp4_fake,
-    )
-    quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
-except AttributeError as error:
-    raise error
+direct_register_custom_op(
+    op_name="dequant_mxfp4",
+    op_func=_dequant_mxfp4,
+    mutates_args=[],
+    fake_impl=_dequant_mxfp4_fake,
+)
+dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
+
+direct_register_custom_op(
+    op_name="quant_dequant_mxfp4",
+    op_func=_quant_dequant_mxfp4,
+    mutates_args=[],
+    fake_impl=_quant_dequant_mxfp4_fake,
+)
+quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
 
 
 class Mxfp4Config(QuantizationConfig):

sglang/srt/layers/quantization/mxfp4_tensor.py

@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
+
 import torch
 
 
@@ -24,7 +26,7 @@ class MXFP4QuantizeUtil:
     E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
 
     @classmethod
-    def quantize(cls, input: torch.Tensor, block_size: int
+    def quantize(cls, input: torch.Tensor, block_size: Optional[int]) -> tuple:
         """Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
         Args:
             input (torch.Tensor): The input tensor to be quantized.