sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +87 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +26 -7
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +374 -136
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +13 -13
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +25 -27
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/managers/cache_controller.py +237 -204
- sglang/srt/managers/detokenizer_manager.py +48 -2
- sglang/srt/managers/io_struct.py +57 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +94 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +122 -42
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +51 -23
- sglang/srt/mem_cache/hiradix_cache.py +87 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +77 -14
- sglang/srt/mem_cache/memory_pool_host.py +4 -5
- sglang/srt/mem_cache/radix_cache.py +6 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +6 -5
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +38 -13
- sglang/srt/models/gpt_oss.py +2 -15
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +66 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +122 -56
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +73 -5
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py

@@ -1,26 +1,22 @@
 import logging
 import os
 from contextlib import contextmanager
-from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import
+from typing import Dict, List, Tuple

-
+import torch
+from tqdm import tqdm

 from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
-    DEEPGEMM_BLACKWELL,
     ENABLE_JIT_DEEPGEMM,
 )
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, get_int_env_var
+from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var

 logger = logging.getLogger(__name__)

-if ENABLE_JIT_DEEPGEMM
-
-    from deep_gemm.jit import build
-    from deep_gemm.jit_kernels.gemm import get_best_configs
-    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm


 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))

@@ -40,19 +36,7 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
 # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
 # NVRTC may have performance loss with some cases.
 # And NVCC JIT speed is also 9x faster in the ref commit
-
-if ENABLE_JIT_DEEPGEMM:
-    try:
-        from deep_gemm.jit.compiler import get_nvcc_compiler
-
-        get_nvcc_compiler()
-    except:
-        logger.warning(
-            "NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
-            "and may have performance loss with some cases."
-        )
-        _USE_NVRTC_DEFAULT = "1"
-    os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")


 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):

@@ -75,7 +59,7 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     # Default each rank will try compile all Ms to
     # load all symbols at the launch stages.
     # Avoid loading symbols at the serving stages.
-    _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE
+    _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE


 class DeepGemmKernelType(IntEnum):

@@ -84,185 +68,15 @@ class DeepGemmKernelType(IntEnum):
     GEMM_NT_F8F8BF16 = auto()


-@dataclass
-class DeepGemmKernelHelper:
-    name: str
-    compile_func: Callable[
-        [
-            int,
-            int,
-            int,
-            Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-        ],
-        None,
-    ]
-    configure_func: Callable[
-        [int, int, int, int, int],
-        Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-    ]
-
-
 _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()


-# TODO improve
-def _compile_warning_1():
-    if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
-        logger.warning(
-            "Entering DeepGEMM JIT Pre-Compile session. "
-            "It may takes a long time (typically 10-20 mins) "
-            "if you have not run `sglang.compile_deep_gemm`. "
-            "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
-            " for pre-compilation to reduce the overhead if you have not run it before. "
-            "For example: "
-            "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
-        )
-
-
-# TODO improve naming
-def _compile_warning_2():
-    logger.warning(
-        "Entering DeepGEMM JIT Single Kernel Compile session. "
-        "And it will makes inference throughput becomes flaky. "
-        "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
-        " for pre-compilation to solve this issue. "
-        "For example: "
-        "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
-    )
-
-
-def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
-    n: int,
-    k: int,
-    num_groups: int,
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-
-    kwargs = {
-        "GEMM_TYPE": GemmType.GroupedMasked,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": num_groups,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
-def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
-    n: int,
-    k: int,
-    num_groups: int,
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-    kwargs = {
-        "GEMM_TYPE": GemmType.GroupedContiguous,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": 1,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
-def _compile_gemm_nt_f8f8bf16_one(
-    n: int,
-    k: int,
-    _: int,  # _ is a dummy parameter to align with other interfaces
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-    kwargs = {
-        "GEMM_TYPE": GemmType.Normal,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": 1,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
-# TODO further refactor warmup-related
-_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
-    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
-        name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
-        compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
-        configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs(
-            m, n, k, num_groups, num_sms, is_grouped_masked=True
-        ),
-    ),
-    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
-        name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
-        compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
-        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
-            m, n, k, 1, num_sms, is_grouped_contiguous=True
-        ),
-    ),
-    DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
-        name="gemm_fp8_fp8_bf16_nt",
-        compile_func=_compile_gemm_nt_f8f8bf16_one,
-        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
-            m, n, k, 1, num_sms
-        ),
-    ),
-}
-
-
+# TODO improve code
 def _maybe_compile_deep_gemm_one_type_all(
     kernel_type: DeepGemmKernelType,
     n: int,
     k: int,
     num_groups: int,
-    m_list: Optional[List[int]] = None,
 ) -> None:
     global _INITIALIZATION_DICT
     global _BUILTIN_M_LIST

@@ -275,61 +89,145 @@ def _maybe_compile_deep_gemm_one_type_all(
     ):
         _INITIALIZATION_DICT[query_key] = True

-
-
+        # TODO maybe improve logs
+        if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
+            logger.warning(
+                "Entering DeepGEMM JIT Pre-Compile session. "
+                "It may take a long time (typically 10-20 mins) "
+                "if you have not run `sglang.compile_deep_gemm`. "
+                "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
+                " for pre-compilation to reduce the overhead if you have not run it before. "
+                "For example: "
+                "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
+            )
+
         logger.info(
             f"Try DeepGEMM JIT Compiling for "
-            f"<{
+            f"<{kernel_type.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
            f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
        )

-
-
-
-
-
-
-            kernel_helper.configure_func(m, n, k, num_groups, num_sms)
-        )
-        compile_func = lambda config: kernel_helper.compile_func(
-            n, k, num_groups, config
+        _compile_deep_gemm_one_type_all(
+            kernel_type=kernel_type,
+            n=n,
+            k=k,
+            num_groups=num_groups,
+            m_list=_BUILTIN_M_LIST,
         )
-        thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)


-
-def
-
-
-
+# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
+def _compile_deep_gemm_one_type_all(
+    kernel_type: DeepGemmKernelType,
+    n: int,
+    k: int,
+    num_groups: int,
+    m_list: List[int],
+) -> None:
+    if kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG:
+        m_alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()
+        m_list = sorted(list(set(m for m in m_list if m % m_alignment == 0)))

-
+    executor = _BaseWarmupExecutor.create(
+        kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
+    )

-
+    # TODO can use multi thread
+    for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
+        executor.execute(m=m)

-    def __patched_func(self, *args, **kwargs):
-        ret = origin_func(self, *args, **kwargs)
-        if ret is None:
-            kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
-            if not DEEPGEMM_BLACKWELL:
-                _compile_warning_2()
-            logger.warning(
-                f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
-            )
-        return ret

-
-
-
+class _BaseWarmupExecutor:
+    @staticmethod
+    def create(kernel_type: DeepGemmKernelType, **kwargs):
+        return {
+            DeepGemmKernelType.GEMM_NT_F8F8BF16: _NormalWarmupExecutor,
+            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: _GroupedContWarmupExecutor,
+            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: _GroupedMaskedWarmupExecutor,
+        }[kernel_type](**kwargs)
+
+    def execute(self, m):
+        raise NotImplementedError
+
+
+def _empty_token_fp8(size):
+    *dims, k = size
+    return (
+        torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
+        torch.empty(
+            (*dims, ceil_div(k, _BLOCK_SIZE)), device="cuda", dtype=torch.float32
+        ),
+    )
+
+
+def _empty_block_fp8(size):
+    *dims, n, k = size
+    return (
+        torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
+        torch.empty(
+            (*dims, ceil_div(n, _BLOCK_SIZE), ceil_div(k, _BLOCK_SIZE)),
+            device="cuda",
+            dtype=torch.float32,
+        ),
+    )
+
+
+_BLOCK_SIZE = 128
+
+
+class _NormalWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((n, k))
+        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
+
+    def execute(self, m):
+        deep_gemm.fp8_gemm_nt(
+            (self.lhs_q[:m], self.lhs_s[:m]),
+            (self.rhs_q, self.rhs_s),
+            self.out[:m],
+        )
+
+
+class _GroupedContWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
+        self.m_indices = torch.zeros((max_m,), device="cuda", dtype=torch.int32)
+        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
+
+    def execute(self, m):
+        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
+            (self.lhs_q[:m], self.lhs_s[:m]),
+            (self.rhs_q, self.rhs_s),
+            self.out[:m],
+            m_indices=self.m_indices[:m],
+        )
+
+
+class _GroupedMaskedWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((num_groups, max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
+        self.masked_m = torch.zeros((num_groups,), device="cuda", dtype=torch.int32)
+        self.out = torch.empty(
+            (num_groups, max_m, n), device="cuda", dtype=torch.bfloat16
+        )
+
+    def execute(self, m):
+        deep_gemm.fp8_m_grouped_gemm_nt_masked(
+            (self.lhs_q, self.lhs_s),
+            (self.rhs_q, self.rhs_s),
+            self.out,
+            masked_m=self.masked_m,
+            # DeepGEMM uses `expect_m` instead of input shape for `get_best_config`
+            expected_m=m,
+        )


 @contextmanager
 def deep_gemm_execution_hook(
     m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
 ):
-
-
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        yield
+    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+    yield
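Note: the hunks above replace the per-kernel JIT `compile_func`/`configure_func` helpers with small warmup executors that simply run each DeepGEMM kernel once per candidate M so the JIT cache gets populated at launch. The sketch below is a minimal, self-contained illustration of that dispatch-then-warm-up pattern only; the enum and class names here are illustrative and not part of sglang.

from enum import IntEnum, auto


class KernelType(IntEnum):
    NORMAL = auto()
    GROUPED_CONTIG = auto()


class BaseWarmupExecutor:
    @staticmethod
    def create(kernel_type: KernelType, **kwargs):
        # Map enum members to executor classes, then instantiate the match,
        # mirroring _BaseWarmupExecutor.create in the hunk above.
        return {
            KernelType.NORMAL: NormalWarmupExecutor,
            KernelType.GROUPED_CONTIG: GroupedContigWarmupExecutor,
        }[kernel_type](**kwargs)

    def execute(self, m: int) -> None:
        raise NotImplementedError


class NormalWarmupExecutor(BaseWarmupExecutor):
    def __init__(self, max_m: int):
        self.max_m = max_m

    def execute(self, m: int) -> None:
        print(f"warm plain GEMM with m={m}")


class GroupedContigWarmupExecutor(BaseWarmupExecutor):
    def __init__(self, max_m: int):
        self.max_m = max_m

    def execute(self, m: int) -> None:
        print(f"warm grouped-contiguous GEMM with m={m}")


# Run every candidate M once, as _compile_deep_gemm_one_type_all now does.
executor = BaseWarmupExecutor.create(KernelType.NORMAL, max_m=512)
for m in (128, 256, 512):
    executor.execute(m=m)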
sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py

@@ -11,9 +11,6 @@ def _compute_enable_deep_gemm():
     sm_version = get_device_sm()
     if sm_version < 90:
         return False
-    # TODO fix deepgemm cu129 fp8 issue
-    if torch.version.cuda == "12.9":
-        return False

     try:
         import deep_gemm

@@ -24,14 +21,12 @@ def _compute_enable_deep_gemm():
     return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")


-
+def _is_blackwell_arch() -> bool:
+    major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
+    return major == 10

-try:
-    from deep_gemm import fp8_gemm_nt

-
-    DEEPGEMM_BLACKWELL = True
-except ImportError:
-    DEEPGEMM_BLACKWELL = False
+ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()

+DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch()
 DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
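For context on the new `_is_blackwell_arch` helper: `torch.cuda.get_device_capability` returns a `(major, minor)` compute-capability pair, and major version 10 corresponds to Blackwell-class GPUs, so `DEEPGEMM_BLACKWELL` now comes from a capability check instead of an import probe. A standalone sketch (requires a CUDA build of PyTorch; prints False on a machine without a Blackwell GPU):

import torch


def is_blackwell() -> bool:
    # Compute capability 10.x corresponds to the Blackwell generation.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    return major == 10


print(is_blackwell())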
sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py

@@ -16,33 +16,16 @@ logger = logging.getLogger(__name__)

 if ENABLE_JIT_DEEPGEMM:
     import deep_gemm
-
-    if DEEPGEMM_BLACKWELL:
-        from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
-        from deep_gemm import (
-            fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
-        )
-        from deep_gemm import (
-            m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
-        )
-    else:
-        from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
-        from deep_gemm import get_col_major_tma_aligned_tensor
-        from deep_gemm import (
-            m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
-        )
-        from deep_gemm import (
-            m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
-        )
+    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor


+# TODO maybe rename these functions
 def grouped_gemm_nt_f8f8bf16_masked(
     lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs: Tuple[torch.Tensor, torch.Tensor],
     out: torch.Tensor,
     masked_m: torch.Tensor,
     expected_m: int,
-    recipe=None,
 ):
     num_groups, _, k = lhs[0].shape
     _, n, _ = rhs[0].shape

@@ -51,13 +34,12 @@ def grouped_gemm_nt_f8f8bf16_masked(
     with compile_utils.deep_gemm_execution_hook(
         expected_m, n, k, num_groups, kernel_type
     ):
-
+        deep_gemm.fp8_m_grouped_gemm_nt_masked(
             lhs,
             rhs,
             out,
             masked_m,
             expected_m,
-            **({"recipe": recipe} if DEEPGEMM_BLACKWELL else {})
         )


@@ -72,7 +54,7 @@ def grouped_gemm_nt_f8f8bf16_contig(
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG

     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
-
+        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)


 def gemm_nt_f8f8bf16(

@@ -86,7 +68,7 @@ def gemm_nt_f8f8bf16(
     kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16

     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
-
+        deep_gemm.fp8_gemm_nt(
             lhs,
             rhs,
             out,
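The wrappers above now pass `(quantized tensor, scale tensor)` pairs straight to `deep_gemm`; per the warmup buffers in compile_utils.py, each FP8 tensor is paired with one float32 scale per 128-wide block, which is why `ceil_div` is now imported. A small self-contained sketch of that shape arithmetic (the concrete sizes are made-up examples):

def ceil_div(a: int, b: int) -> int:
    # Round-up integer division, as used for per-block scale shapes.
    return -(a // -b)


BLOCK_SIZE = 128

# Per-token activations: one scale per 128-element block along k.
max_m, k = 4096, 7168
act_scale_shape = (max_m, ceil_div(k, BLOCK_SIZE))  # -> (4096, 56)

# Block-quantized weights: one scale per 128x128 tile of (n, k).
n = 2048
weight_scale_shape = (ceil_div(n, BLOCK_SIZE), ceil_div(k, BLOCK_SIZE))  # -> (16, 56)

print(act_scale_shape, weight_scale_shape)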
sglang/srt/layers/quantization/fp8.py

@@ -64,7 +64,6 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
-from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,

@@ -72,6 +71,8 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_sm90_supported,
+    is_sm100_supported,
     log_info_on_rank0,
     next_power_of_2,
     print_warning_once,
sglang/srt/layers/quantization/fp8_kernel.py

@@ -298,7 +298,7 @@ def _per_token_group_quant_8bit_raw(
     )

     if scale_ue8m0:
-        from deep_gemm
+        from deep_gemm import transform_sf_into_required_layout

         assert group_size == 128
         x_s = transform_sf_into_required_layout(

@@ -338,7 +338,7 @@ def _per_token_group_quant_8bit_fuse_silu_and_mul(
     # scale_ue8m0=scale_ue8m0,
     # )

-    from deep_gemm
+    from deep_gemm import transform_sf_into_required_layout

     from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd

sglang/srt/layers/quantization/fp8_utils.py

@@ -5,7 +5,7 @@ import torch
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.
+from sglang.srt.utils import is_sm100_supported

 try:
     from vllm import _custom_ops as ops

@@ -459,7 +459,7 @@ def _requant_weight_ue8m0(
         import deep_gemm.utils.layout

         sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
-        sf = deep_gemm.utils.layout.
+        sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
         return sf

     out_s = _transform_scale(out_s, mn=out_w.shape[-2])
sglang/srt/layers/quantization/modelopt_quant.py

@@ -599,6 +599,13 @@ class ModelOptFp4Config(QuantizationConfig):
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
             if re.fullmatch(regex_str, prefix):
                 return True
+
+            # Check if the last part of the excluded pattern is contained in the last part of the prefix
+            # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+            pattern_last_part = pattern.split(".")[-1]
+            prefix_last_part = prefix.split(".")[-1]
+            if pattern_last_part in prefix_last_part:
+                return True
         return False

     def get_quant_method(
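The added branch relaxes the exclude-pattern check so a fused module name is still excluded when its last dotted component contains an excluded sub-module name. A self-contained sketch of the check with hypothetical layer names:

import re


def is_excluded(prefix: str, exclude_patterns: list) -> bool:
    for pattern in exclude_patterns:
        # Existing behaviour: glob-style full match over the whole prefix.
        regex_str = pattern.replace(".", r"\.").replace("*", r".*")
        if re.fullmatch(regex_str, prefix):
            return True
        # Added behaviour: compare only the last dotted component, so a fused
        # module such as ...fused_qkv_a_proj_with_mqa still matches an
        # excluded sub-module like kv_a_proj_with_mqa.
        if pattern.split(".")[-1] in prefix.split(".")[-1]:
            return True
    return False


# Hypothetical names, for illustration only.
print(is_excluded("model.layers.0.self_attn.fused_qkv_a_proj_with_mqa",
                  ["*.kv_a_proj_with_mqa"]))  # True, via the last-component check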
sglang/srt/layers/quantization/mxfp4.py

@@ -29,14 +29,13 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
-    get_bool_env_var,
     is_cuda,
     is_flashinfer_available,
     is_hip,
+    is_sm100_supported,
     is_triton_kernels_available,
     log_info_on_rank0,
     mxfp_supported,

@@ -67,10 +66,15 @@ _is_hip = is_hip()

 if _is_hip:
     # import aiter
-
-
-
-
+    try:
+        from aiter import ActivationType, QuantType, dtypes
+        from aiter.fused_moe import fused_moe
+        from aiter.ops.triton.quant import dynamic_mxfp4_quant
+        from aiter.utility.fp4_utils import e8m0_shuffle
+    except ImportError as err:
+        ActivationType = QuantType = dtypes = fused_moe = dynamic_mxfp4_quant = (
+            e8m0_shuffle
+        ) = err


 def _swizzle_mxfp4(quant_tensor, scale, num_warps):

@@ -146,27 +150,21 @@ def _quant_dequant_mxfp4_fake(
     return torch.empty_like(x)


-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        mutates_args=[],
-        fake_impl=_quant_dequant_mxfp4_fake,
-    )
-    quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
-except AttributeError as error:
-    raise error
+direct_register_custom_op(
+    op_name="dequant_mxfp4",
+    op_func=_dequant_mxfp4,
+    mutates_args=[],
+    fake_impl=_dequant_mxfp4_fake,
+)
+dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
+
+direct_register_custom_op(
+    op_name="quant_dequant_mxfp4",
+    op_func=_quant_dequant_mxfp4,
+    mutates_args=[],
+    fake_impl=_quant_dequant_mxfp4_fake,
+)
+quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4


 class Mxfp4Config(QuantizationConfig):
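The aiter imports above are wrapped so that a missing package binds every imported name to the `ImportError` instance: importing the module always succeeds, and the failure only surfaces if one of those names is actually used. A minimal standalone sketch of that deferred-failure pattern (the package name is deliberately nonexistent):

try:
    from nonexistent_accel_lib import fused_kernel  # stand-in for an optional dependency
except ImportError as err:
    # Bind the name to the exception so importing this module never fails;
    # any later use of fused_kernel raises instead.
    fused_kernel = err

if isinstance(fused_kernel, ImportError):
    print(f"optional kernel unavailable: {fused_kernel}")
else:
    fused_kernel()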
|