sglang 0.5.1.post1__py3-none-any.whl → 0.5.1.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/bench_one_batch_server.py +79 -53
  2. sglang/bench_serving.py +186 -14
  3. sglang/profiler.py +0 -1
  4. sglang/srt/conversation.py +38 -5
  5. sglang/srt/disaggregation/decode.py +4 -0
  6. sglang/srt/disaggregation/prefill.py +4 -0
  7. sglang/srt/entrypoints/engine.py +2 -2
  8. sglang/srt/entrypoints/openai/protocol.py +27 -24
  9. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  10. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  11. sglang/srt/entrypoints/tool.py +7 -7
  12. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  13. sglang/srt/function_call/function_call_parser.py +2 -0
  14. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  15. sglang/srt/harmony_parser.py +588 -0
  16. sglang/srt/hf_transformers_utils.py +16 -7
  17. sglang/srt/layers/attention/ascend_backend.py +218 -111
  18. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  19. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  20. sglang/srt/layers/attention/flashinfer_mla_backend.py +76 -91
  21. sglang/srt/layers/attention/utils.py +15 -94
  22. sglang/srt/layers/communicator.py +1 -2
  23. sglang/srt/layers/moe/cutlass_moe.py +0 -15
  24. sglang/srt/layers/moe/ep_moe/layer.py +1 -7
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  27. sglang/srt/layers/moe/topk.py +1 -1
  28. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  29. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -7
  30. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  31. sglang/srt/layers/quantization/fp8.py +2 -1
  32. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  33. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  35. sglang/srt/layers/quantization/mxfp4.py +16 -23
  36. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  37. sglang/srt/layers/utils.py +0 -14
  38. sglang/srt/lora/lora_manager.py +29 -12
  39. sglang/srt/managers/cache_controller.py +223 -156
  40. sglang/srt/managers/detokenizer_manager.py +5 -0
  41. sglang/srt/managers/io_struct.py +30 -0
  42. sglang/srt/managers/scheduler.py +58 -7
  43. sglang/srt/managers/scheduler_metrics_mixin.py +15 -0
  44. sglang/srt/managers/tokenizer_manager.py +36 -3
  45. sglang/srt/mem_cache/hicache_storage.py +31 -20
  46. sglang/srt/mem_cache/hiradix_cache.py +12 -3
  47. sglang/srt/mem_cache/memory_pool.py +73 -14
  48. sglang/srt/mem_cache/memory_pool_host.py +3 -2
  49. sglang/srt/mem_cache/radix_cache.py +1 -0
  50. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +5 -13
  51. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +85 -81
  52. sglang/srt/metrics/collector.py +5 -5
  53. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  54. sglang/srt/model_executor/model_runner.py +1 -1
  55. sglang/srt/models/deepseek_v2.py +12 -3
  56. sglang/srt/models/gpt_oss.py +2 -1
  57. sglang/srt/models/qwen2_5_vl.py +1 -0
  58. sglang/srt/offloader.py +115 -0
  59. sglang/srt/reasoning_parser.py +56 -300
  60. sglang/srt/server_args.py +10 -5
  61. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  62. sglang/srt/utils.py +59 -12
  63. sglang/test/test_cutlass_moe.py +33 -28
  64. sglang/version.py +1 -1
  65. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/METADATA +6 -5
  66. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/RECORD +69 -65
  67. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/WHEEL +0 -0
  68. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/licenses/LICENSE +0 -0
  69. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
@@ -1,26 +1,22 @@
  import logging
  import os
  from contextlib import contextmanager
- from dataclasses import dataclass
  from enum import IntEnum, auto
- from typing import Callable, Dict, List, Optional, Tuple
+ from typing import Dict, List, Tuple

- from tqdm.contrib.concurrent import thread_map
+ import torch
+ from tqdm import tqdm

  from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
- DEEPGEMM_BLACKWELL,
  ENABLE_JIT_DEEPGEMM,
  )
  from sglang.srt.server_args import ServerArgs
- from sglang.srt.utils import get_bool_env_var, get_int_env_var
+ from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var

  logger = logging.getLogger(__name__)

- if ENABLE_JIT_DEEPGEMM and not DEEPGEMM_BLACKWELL:
- from deep_gemm import get_num_sms
- from deep_gemm.jit import build
- from deep_gemm.jit_kernels.gemm import get_best_configs
- from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
+ if ENABLE_JIT_DEEPGEMM:
+ import deep_gemm


  _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
@@ -40,19 +36,7 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
  # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
  # NVRTC may have performance loss with some cases.
  # And NVCC JIT speed is also 9x faster in the ref commit
- _USE_NVRTC_DEFAULT = "0"
- if ENABLE_JIT_DEEPGEMM:
- try:
- from deep_gemm.jit.compiler import get_nvcc_compiler
-
- get_nvcc_compiler()
- except:
- logger.warning(
- "NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
- "and may have performance loss with some cases."
- )
- _USE_NVRTC_DEFAULT = "1"
- os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
+ os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")


  def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
@@ -75,7 +59,7 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
  # Default each rank will try compile all Ms to
  # load all symbols at the launch stages.
  # Avoid loading symbols at the serving stages.
- _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE
+ _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE


  class DeepGemmKernelType(IntEnum):
@@ -84,185 +68,15 @@ class DeepGemmKernelType(IntEnum):
  GEMM_NT_F8F8BF16 = auto()


- @dataclass
- class DeepGemmKernelHelper:
- name: str
- compile_func: Callable[
- [
- int,
- int,
- int,
- Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
- ],
- None,
- ]
- configure_func: Callable[
- [int, int, int, int, int],
- Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
- ]
-
-
  _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()


- # TODO improve naming
- def _compile_warning_1():
- if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
- logger.warning(
- "Entering DeepGEMM JIT Pre-Compile session. "
- "It may takes a long time (typically 10-20 mins) "
- "if you have not run `sglang.compile_deep_gemm`. "
- "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
- " for pre-compilation to reduce the overhead if you have not run it before. "
- "For example: "
- "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
- )
-
-
- # TODO improve naming
- def _compile_warning_2():
- logger.warning(
- "Entering DeepGEMM JIT Single Kernel Compile session. "
- "And it will makes inference throughput becomes flaky. "
- "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
- " for pre-compilation to solve this issue. "
- "For example: "
- "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
- )
-
-
- def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
- n: int,
- k: int,
- num_groups: int,
- config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
- ) -> None:
- num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
- block_k = 128
- num_tma_threads = 128
- num_math_threads_per_group = 128
-
- kwargs = {
- "GEMM_TYPE": GemmType.GroupedMasked,
- "NUM_TMA_THREADS": num_tma_threads,
- "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
- "N": n,
- "K": k,
- "NUM_GROUPS": num_groups,
- "BLOCK_M": block_m,
- "BLOCK_N": block_n,
- "BLOCK_K": block_k,
- "SWIZZLE_D_MODE": smem_config[1],
- "BLOCK_N_PADDING": smem_config[2],
- "NUM_STAGES": num_stages,
- "NUM_TMA_MULTICAST": tma_multicast_config[0],
- "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
- "NUM_SMS": num_sms,
- "SMEM_SIZE": smem_config[0],
- }
-
- code = FP8GemmRuntime.generate(kwargs)
- _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
- def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
- n: int,
- k: int,
- num_groups: int,
- config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
- ) -> None:
- num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
- block_k = 128
- num_tma_threads = 128
- num_math_threads_per_group = 128
- kwargs = {
- "GEMM_TYPE": GemmType.GroupedContiguous,
- "NUM_TMA_THREADS": num_tma_threads,
- "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
- "N": n,
- "K": k,
- "NUM_GROUPS": 1,
- "BLOCK_M": block_m,
- "BLOCK_N": block_n,
- "BLOCK_K": block_k,
- "SWIZZLE_D_MODE": smem_config[1],
- "BLOCK_N_PADDING": smem_config[2],
- "NUM_STAGES": num_stages,
- "NUM_TMA_MULTICAST": tma_multicast_config[0],
- "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
- "NUM_SMS": num_sms,
- "SMEM_SIZE": smem_config[0],
- }
-
- code = FP8GemmRuntime.generate(kwargs)
- _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
- def _compile_gemm_nt_f8f8bf16_one(
- n: int,
- k: int,
- _: int, # _ is a dummy parameter to align with other interfaces
- config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
- ) -> None:
- num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
- block_k = 128
- num_tma_threads = 128
- num_math_threads_per_group = 128
- kwargs = {
- "GEMM_TYPE": GemmType.Normal,
- "NUM_TMA_THREADS": num_tma_threads,
- "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
- "N": n,
- "K": k,
- "NUM_GROUPS": 1,
- "BLOCK_M": block_m,
- "BLOCK_N": block_n,
- "BLOCK_K": block_k,
- "SWIZZLE_D_MODE": smem_config[1],
- "BLOCK_N_PADDING": smem_config[2],
- "NUM_STAGES": num_stages,
- "NUM_TMA_MULTICAST": tma_multicast_config[0],
- "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
- "NUM_SMS": num_sms,
- "SMEM_SIZE": smem_config[0],
- }
-
- code = FP8GemmRuntime.generate(kwargs)
- _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
- # TODO further refactor warmup-related
- _KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
- DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
- name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
- compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
- configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs(
- m, n, k, num_groups, num_sms, is_grouped_masked=True
- ),
- ),
- DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
- name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
- compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
- configure_func=lambda m, n, k, _, num_sms: get_best_configs(
- m, n, k, 1, num_sms, is_grouped_contiguous=True
- ),
- ),
- DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
- name="gemm_fp8_fp8_bf16_nt",
- compile_func=_compile_gemm_nt_f8f8bf16_one,
- configure_func=lambda m, n, k, _, num_sms: get_best_configs(
- m, n, k, 1, num_sms
- ),
- ),
- }
-
-
+ # TODO improve code
  def _maybe_compile_deep_gemm_one_type_all(
  kernel_type: DeepGemmKernelType,
  n: int,
  k: int,
  num_groups: int,
- m_list: Optional[List[int]] = None,
  ) -> None:
  global _INITIALIZATION_DICT
  global _BUILTIN_M_LIST
@@ -275,61 +89,145 @@ def _maybe_compile_deep_gemm_one_type_all(
  ):
  _INITIALIZATION_DICT[query_key] = True

- kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
- _compile_warning_1()
+ # TODO maybe improve logs
+ if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
+ logger.warning(
+ "Entering DeepGEMM JIT Pre-Compile session. "
+ "It may takes a long time (typically 10-20 mins) "
+ "if you have not run `sglang.compile_deep_gemm`. "
+ "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
+ " for pre-compilation to reduce the overhead if you have not run it before. "
+ "For example: "
+ "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
+ )
+
  logger.info(
  f"Try DeepGEMM JIT Compiling for "
- f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
+ f"<{kernel_type.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
  f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
  )

- # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
- num_sms = get_num_sms()
- collected_configs = set()
- for m in m_list if m_list is not None else _BUILTIN_M_LIST:
- # Put config into set to get unique configs and reduce cases to be compiled
- collected_configs.add(
- kernel_helper.configure_func(m, n, k, num_groups, num_sms)
- )
- compile_func = lambda config: kernel_helper.compile_func(
- n, k, num_groups, config
+ _compile_deep_gemm_one_type_all(
+ kernel_type=kernel_type,
+ n=n,
+ k=k,
+ num_groups=num_groups,
+ m_list=_BUILTIN_M_LIST,
  )
- thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)


- @contextmanager
- def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
- if _IN_PRECOMPILE_STAGE:
- yield
- return
+ # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
+ def _compile_deep_gemm_one_type_all(
+ kernel_type: DeepGemmKernelType,
+ n: int,
+ k: int,
+ num_groups: int,
+ m_list: List[int],
+ ) -> None:
+ if kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG:
+ m_alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()
+ m_list = sorted(list(set(m for m in m_list if m % m_alignment == 0)))

- from deep_gemm.jit.runtime import RuntimeCache
+ executor = _BaseWarmupExecutor.create(
+ kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
+ )

- origin_func = RuntimeCache.get
+ # TODO can use multi thread
+ for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
+ executor.execute(m=m)

- def __patched_func(self, *args, **kwargs):
- ret = origin_func(self, *args, **kwargs)
- if ret is None:
- kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
- if not DEEPGEMM_BLACKWELL:
- _compile_warning_2()
- logger.warning(
- f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
- )
- return ret

- RuntimeCache.get = __patched_func
- yield
- RuntimeCache.get = origin_func
+ class _BaseWarmupExecutor:
+ @staticmethod
+ def create(kernel_type: DeepGemmKernelType, **kwargs):
+ return {
+ DeepGemmKernelType.GEMM_NT_F8F8BF16: _NormalWarmupExecutor,
+ DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: _GroupedContWarmupExecutor,
+ DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: _GroupedMaskedWarmupExecutor,
+ }[kernel_type](**kwargs)
+
+ def execute(self, m):
+ raise NotImplementedError
+
+
+ def _empty_token_fp8(size):
+ *dims, k = size
+ return (
+ torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
+ torch.empty(
+ (*dims, ceil_div(k, _BLOCK_SIZE)), device="cuda", dtype=torch.float32
+ ),
+ )
+
+
+ def _empty_block_fp8(size):
+ *dims, n, k = size
+ return (
+ torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
+ torch.empty(
+ (*dims, ceil_div(n, _BLOCK_SIZE), ceil_div(k, _BLOCK_SIZE)),
+ device="cuda",
+ dtype=torch.float32,
+ ),
+ )
+
+
+ _BLOCK_SIZE = 128
+
+
+ class _NormalWarmupExecutor(_BaseWarmupExecutor):
+ def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+ self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
+ self.rhs_q, self.rhs_s = _empty_block_fp8((n, k))
+ self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
+
+ def execute(self, m):
+ deep_gemm.fp8_gemm_nt(
+ (self.lhs_q[:m], self.lhs_s[:m]),
+ (self.rhs_q, self.rhs_s),
+ self.out[:m],
+ )
+
+
+ class _GroupedContWarmupExecutor(_BaseWarmupExecutor):
+ def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+ self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
+ self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
+ self.m_indices = torch.zeros((max_m,), device="cuda", dtype=torch.int32)
+ self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
+
+ def execute(self, m):
+ deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
+ (self.lhs_q[:m], self.lhs_s[:m]),
+ (self.rhs_q, self.rhs_s),
+ self.out[:m],
+ m_indices=self.m_indices[:m],
+ )
+
+
+ class _GroupedMaskedWarmupExecutor(_BaseWarmupExecutor):
+ def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+ self.lhs_q, self.lhs_s = _empty_token_fp8((num_groups, max_m, k))
+ self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
+ self.masked_m = torch.zeros((num_groups,), device="cuda", dtype=torch.int32)
+ self.out = torch.empty(
+ (num_groups, max_m, n), device="cuda", dtype=torch.bfloat16
+ )
+
+ def execute(self, m):
+ deep_gemm.fp8_m_grouped_gemm_nt_masked(
+ (self.lhs_q, self.lhs_s),
+ (self.rhs_q, self.rhs_s),
+ self.out,
+ masked_m=self.masked_m,
+ # DeepGEMM uses `expect_m` instead of input shape for `get_best_config`
+ expected_m=m,
+ )


  @contextmanager
  def deep_gemm_execution_hook(
  m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
  ):
- # not supported yet
- if not DEEPGEMM_BLACKWELL:
- _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
- with _log_jit_build(m, n, k, kernel_type):
- yield
+ _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+ yield
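
The net effect of the compile_utils.py rewrite above is that warmup no longer pre-generates kernel sources per config; it drives the renamed DeepGEMM entry points with dummy FP8 tensors so the JIT cache fills for every M. Below is a minimal standalone sketch of that idea for the plain GEMM case, not part of the diff: the n, k and m_list values are hypothetical, and it assumes a CUDA device plus a DeepGEMM build exposing fp8_gemm_nt as used above.

import torch
from tqdm import tqdm

import deep_gemm  # assumes the renamed DeepGEMM API shown in this diff

_BLOCK = 128  # FP8 scaling granularity, matching _BLOCK_SIZE above


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def warmup_plain_gemm(n: int = 4096, k: int = 7168, m_list=range(128, 1024 + 1, 128)):
    """Populate the DeepGEMM JIT cache by running real fp8_gemm_nt calls on dummy data."""
    max_m = max(m_list)
    # FP8 activations with per-128-column-group scales.
    lhs_q = torch.empty((max_m, k), device="cuda", dtype=torch.float8_e4m3fn)
    lhs_s = torch.empty((max_m, ceil_div(k, _BLOCK)), device="cuda", dtype=torch.float32)
    # FP8 weights with 128x128 block scales.
    rhs_q = torch.empty((n, k), device="cuda", dtype=torch.float8_e4m3fn)
    rhs_s = torch.empty((ceil_div(n, _BLOCK), ceil_div(k, _BLOCK)), device="cuda", dtype=torch.float32)
    out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
    for m in tqdm(m_list, desc="DeepGEMM warmup"):
        # Each distinct M exercises (and caches) its own JIT-compiled kernel.
        deep_gemm.fp8_gemm_nt((lhs_q[:m], lhs_s[:m]), (rhs_q, rhs_s), out[:m])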

sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
@@ -24,14 +24,12 @@ def _compute_enable_deep_gemm():
  return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")


- ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
+ def _is_blackwell_arch() -> bool:
+ major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
+ return major == 10

- try:
- from deep_gemm import fp8_gemm_nt

- # They have not given a name to this breaking change
- DEEPGEMM_BLACKWELL = True
- except ImportError:
- DEEPGEMM_BLACKWELL = False
+ ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()

+ DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch()
  DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
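
DEEPGEMM_BLACKWELL is now decided by the device's compute capability instead of probing which symbols deep_gemm exports. As a rough illustration (assumes a CUDA device is present; SM 10.x corresponds to Blackwell parts such as the B200 referenced by the new fused-MoE config above):

import torch

# Blackwell GPUs report compute capability major == 10, which is what
# _is_blackwell_arch() keys off in the new configurer.py.
major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
print(f"compute capability sm_{major}{minor}; Blackwell path enabled: {major == 10}")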

sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
@@ -16,33 +16,16 @@ logger = logging.getLogger(__name__)

  if ENABLE_JIT_DEEPGEMM:
  import deep_gemm
-
- if DEEPGEMM_BLACKWELL:
- from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
- from deep_gemm import (
- fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
- )
- from deep_gemm import (
- m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
- )
- else:
- from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
- from deep_gemm import get_col_major_tma_aligned_tensor
- from deep_gemm import (
- m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
- )
- from deep_gemm import (
- m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
- )
+ from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor


+ # TODO maybe rename these functions
  def grouped_gemm_nt_f8f8bf16_masked(
  lhs: Tuple[torch.Tensor, torch.Tensor],
  rhs: Tuple[torch.Tensor, torch.Tensor],
  out: torch.Tensor,
  masked_m: torch.Tensor,
  expected_m: int,
- recipe=None,
  ):
  num_groups, _, k = lhs[0].shape
  _, n, _ = rhs[0].shape
@@ -51,13 +34,12 @@ def grouped_gemm_nt_f8f8bf16_masked(
  with compile_utils.deep_gemm_execution_hook(
  expected_m, n, k, num_groups, kernel_type
  ):
- _grouped_gemm_nt_f8f8bf16_masked_raw(
+ deep_gemm.fp8_m_grouped_gemm_nt_masked(
  lhs,
  rhs,
  out,
  masked_m,
  expected_m,
- **({"recipe": recipe} if DEEPGEMM_BLACKWELL else {})
  )


@@ -72,7 +54,7 @@ def grouped_gemm_nt_f8f8bf16_contig(
  kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG

  with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
- _grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices)
+ deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)


  def gemm_nt_f8f8bf16(
@@ -86,7 +68,7 @@ def gemm_nt_f8f8bf16(
  kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16

  with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
- _gemm_nt_f8f8bf16_raw(
+ deep_gemm.fp8_gemm_nt(
  lhs,
  rhs,
  out,
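
The wrappers in entrypoint.py keep their sglang-facing names but now dispatch straight to the renamed DeepGEMM kernels (fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, fp8_m_grouped_gemm_nt_masked), and the masked variant drops its recipe argument. A usage sketch for the masked path with dummy data shaped like the warmup executors above; the package-level re-export of grouped_gemm_nt_f8f8bf16_masked and the sizes chosen here are assumptions for illustration only.

import torch

# Import path as seen in the fp8_utils.py hunk below; the attribute re-export is assumed.
from sglang.srt.layers.quantization import deep_gemm_wrapper

num_groups, m, n, k, block = 8, 128, 512, 256, 128

# FP8 activations with per-128-group scales, FP8 weights with 128x128 block scales,
# shaped the way _GroupedMaskedWarmupExecutor allocates them.
lhs = (
    torch.empty((num_groups, m, k), device="cuda", dtype=torch.float8_e4m3fn),
    torch.empty((num_groups, m, k // block), device="cuda", dtype=torch.float32),
)
rhs = (
    torch.empty((num_groups, n, k), device="cuda", dtype=torch.float8_e4m3fn),
    torch.empty((num_groups, n // block, k // block), device="cuda", dtype=torch.float32),
)
out = torch.empty((num_groups, m, n), device="cuda", dtype=torch.bfloat16)
masked_m = torch.zeros((num_groups,), device="cuda", dtype=torch.int32)

# `recipe` no longer exists in 0.5.1.post3; only these arguments remain.
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(lhs, rhs, out, masked_m=masked_m, expected_m=m)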

sglang/srt/layers/quantization/fp8.py
@@ -64,7 +64,6 @@ from sglang.srt.layers.quantization.utils import (
  per_tensor_dequantize,
  requantize_with_max_scale,
  )
- from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
  from sglang.srt.utils import (
  cpu_has_amx_support,
  get_bool_env_var,
@@ -72,6 +71,8 @@ from sglang.srt.utils import (
  is_cuda,
  is_hip,
  is_npu,
+ is_sm90_supported,
+ is_sm100_supported,
  log_info_on_rank0,
  next_power_of_2,
  print_warning_once,

sglang/srt/layers/quantization/fp8_kernel.py
@@ -298,7 +298,7 @@ def _per_token_group_quant_8bit_raw(
  )

  if scale_ue8m0:
- from deep_gemm.utils.layout import transform_sf_into_required_layout
+ from deep_gemm import transform_sf_into_required_layout

  assert group_size == 128
  x_s = transform_sf_into_required_layout(
@@ -338,7 +338,7 @@ def _per_token_group_quant_8bit_fuse_silu_and_mul(
  # scale_ue8m0=scale_ue8m0,
  # )

- from deep_gemm.utils.layout import transform_sf_into_required_layout
+ from deep_gemm import transform_sf_into_required_layout

  from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd


sglang/srt/layers/quantization/fp8_utils.py
@@ -5,7 +5,7 @@ import torch
  from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
  from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
- from sglang.srt.layers.utils import is_sm100_supported
+ from sglang.srt.utils import is_sm100_supported

  try:
  from vllm import _custom_ops as ops
@@ -459,7 +459,7 @@ def _requant_weight_ue8m0(
  import deep_gemm.utils.layout

  sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
- sf = deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor(sf)
+ sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
  return sf

  out_s = _transform_scale(out_s, mn=out_w.shape[-2])

sglang/srt/layers/quantization/modelopt_quant.py
@@ -876,7 +876,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
  data=torch.empty(
  layer.num_local_experts,
  2 * intermediate_size_per_partition,
- # 2 fp4 items are packed in the input dimension
  hidden_size // self.quant_config.group_size,
  dtype=weight_scale_dtype,
  ),
@@ -895,7 +894,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
  data=torch.empty(
  layer.num_local_experts,
  hidden_size,
- # 2 fp4 items are packed in the input dimension
  intermediate_size_per_partition // self.quant_config.group_size,
  dtype=weight_scale_dtype,
  ),
@@ -1212,11 +1210,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):

  # Process w13 weights
  w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
+ del layer.w13_weight_scale
  layer.w13_blockscale_swizzled.data.copy_(w13_blockscale_swizzled)
  layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)

  # Process w2 weights
  w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+ del layer.w2_weight_scale
  layer.w2_blockscale_swizzled.data.copy_(w2_blockscale_swizzled)
  layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)


sglang/srt/layers/quantization/mxfp4.py
@@ -29,14 +29,13 @@ from sglang.srt.layers.quantization.base_config import (
  QuantizeMethodBase,
  )
  from sglang.srt.layers.quantization.utils import is_layer_skipped
- from sglang.srt.layers.utils import is_sm100_supported
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.utils import (
  direct_register_custom_op,
- get_bool_env_var,
  is_cuda,
  is_flashinfer_available,
  is_hip,
+ is_sm100_supported,
  is_triton_kernels_available,
  log_info_on_rank0,
  mxfp_supported,
@@ -146,27 +145,21 @@ def _quant_dequant_mxfp4_fake(
  return torch.empty_like(x)


- try:
- direct_register_custom_op(
- op_name="dequant_mxfp4",
- op_func=_dequant_mxfp4,
- mutates_args=[],
- fake_impl=_dequant_mxfp4_fake,
- )
- dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
- except AttributeError as error:
- raise error
-
- try:
- direct_register_custom_op(
- op_name="quant_dequant_mxfp4",
- op_func=_quant_dequant_mxfp4,
- mutates_args=[],
- fake_impl=_quant_dequant_mxfp4_fake,
- )
- quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
- except AttributeError as error:
- raise error
+ direct_register_custom_op(
+ op_name="dequant_mxfp4",
+ op_func=_dequant_mxfp4,
+ mutates_args=[],
+ fake_impl=_dequant_mxfp4_fake,
+ )
+ dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
+
+ direct_register_custom_op(
+ op_name="quant_dequant_mxfp4",
+ op_func=_quant_dequant_mxfp4,
+ mutates_args=[],
+ fake_impl=_quant_dequant_mxfp4_fake,
+ )
+ quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4


  class Mxfp4Config(QuantizationConfig):

sglang/srt/layers/quantization/mxfp4_tensor.py
@@ -13,6 +13,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from typing import Optional
+
  import torch


@@ -24,7 +26,7 @@ class MXFP4QuantizeUtil:
  E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])

  @classmethod
- def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
+ def quantize(cls, input: torch.Tensor, block_size: Optional[int]) -> tuple:
  """Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
  Args:
  input (torch.Tensor): The input tensor to be quantized.