sglang 0.4.5.post2__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sglang/bench_serving.py +3 -2
  2. sglang/compile_deep_gemm.py +136 -0
  3. sglang/lang/backend/openai.py +5 -1
  4. sglang/lang/backend/runtime_endpoint.py +5 -1
  5. sglang/srt/configs/model_config.py +4 -1
  6. sglang/srt/constrained/xgrammar_backend.py +1 -0
  7. sglang/srt/disaggregation/decode.py +43 -0
  8. sglang/srt/disaggregation/mini_lb.py +69 -8
  9. sglang/srt/disaggregation/mooncake/conn.py +1 -1
  10. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  11. sglang/srt/disaggregation/nixl/conn.py +622 -0
  12. sglang/srt/disaggregation/prefill.py +100 -16
  13. sglang/srt/disaggregation/utils.py +17 -0
  14. sglang/srt/entrypoints/engine.py +4 -0
  15. sglang/srt/entrypoints/http_server.py +3 -7
  16. sglang/srt/function_call_parser.py +60 -0
  17. sglang/srt/layers/activation.py +2 -2
  18. sglang/srt/layers/attention/flashattention_backend.py +781 -150
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  21. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  22. sglang/srt/layers/dp_attention.py +1 -1
  23. sglang/srt/layers/layernorm.py +19 -4
  24. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  25. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  26. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  27. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  28. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  29. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  30. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  31. sglang/srt/layers/quantization/gptq.py +13 -7
  32. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  33. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  34. sglang/srt/layers/rotary_embedding.py +6 -6
  35. sglang/srt/layers/sampler.py +2 -2
  36. sglang/srt/managers/data_parallel_controller.py +7 -1
  37. sglang/srt/managers/io_struct.py +14 -3
  38. sglang/srt/managers/schedule_batch.py +13 -0
  39. sglang/srt/managers/scheduler.py +16 -6
  40. sglang/srt/managers/tokenizer_manager.py +115 -29
  41. sglang/srt/managers/tp_worker.py +1 -0
  42. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  43. sglang/srt/mem_cache/memory_pool.py +31 -13
  44. sglang/srt/model_executor/cuda_graph_runner.py +13 -8
  45. sglang/srt/model_executor/model_runner.py +19 -4
  46. sglang/srt/models/deepseek_v2.py +9 -6
  47. sglang/srt/models/minicpm3.py +2 -2
  48. sglang/srt/models/minicpmo.py +17 -6
  49. sglang/srt/openai_api/adapter.py +71 -4
  50. sglang/srt/openai_api/protocol.py +6 -1
  51. sglang/srt/server_args.py +52 -40
  52. sglang/srt/speculative/build_eagle_tree.py +2 -2
  53. sglang/srt/speculative/eagle_utils.py +2 -2
  54. sglang/srt/speculative/eagle_worker.py +2 -7
  55. sglang/srt/utils.py +46 -5
  56. sglang/test/test_utils.py +3 -1
  57. sglang/version.py +1 -1
  58. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +3 -3
  59. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +62 -57
  60. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +0 -0
  61. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  62. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
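A notable addition in this release is a DeepGEMM JIT pre-compilation path: a new `sglang.compile_deep_gemm` entry point (`sglang/compile_deep_gemm.py`, file 2 above) plus the new `sglang/srt/layers/quantization/deep_gemm.py` module (file 28), whose full contents appear in the first hunk below. Going by the command quoted in that module's own warning messages, pre-compilation is run with the same arguments as `sglang.launch_server`, for example:

    python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code

The remaining hunks shown here are mostly supporting work: `fp8_kernel.py` and `fp8_utils.py` switch from a local `_enable_jit_deepgemm` flag to the `_ENABLE_JIT_DEEPGEMM` flag exported by the new module, several files pick up the `is_cuda_available` to `is_cuda` rename, `gptq.py` factors Marlin-format detection into a shared `check_marlin_format` helper, and the data parallel controller launches scheduler subprocesses under the torch memory saver adapter.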
sglang/srt/layers/quantization/deep_gemm.py (new file)
@@ -0,0 +1,378 @@
+ import logging
+ import os
+ from contextlib import contextmanager
+ from dataclasses import dataclass
+ from enum import IntEnum, auto
+ from typing import Callable, Dict, List, Optional, Tuple
+
+ import torch
+ from tqdm.contrib.concurrent import thread_map
+
+ from sglang.srt.server_args import ServerArgs
+ from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
+
+ _ENABLE_JIT_DEEPGEMM = False
+ if is_cuda():
+     import deep_gemm
+     from deep_gemm import get_num_sms
+     from deep_gemm.jit_kernels.gemm import get_best_configs
+     from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
+     from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
+     from deep_gemm.jit_kernels.m_grouped_gemm import (
+         template as deep_gemm_grouped_gemm_template,
+     )
+     from deep_gemm.jit_kernels.tuner import jit_tuner
+
+     sm_version = get_device_sm()
+     if sm_version == 90:
+         if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
+             _ENABLE_JIT_DEEPGEMM = True
+
+ logger = logging.getLogger(__name__)
+
+ _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
+ _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
+     "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
+ )
+ _DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
+ _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
+ _IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false")
+
+ # Force redirect deep_gemm cache_dir
+ os.environ["DG_CACHE_DIR"] = os.getenv(
+     "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
+ )
+
+
+ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
+     global _BUILTIN_M_LIST
+     global _DO_COMPILE
+
+     # Generate m_max
+     m_max = 1024 * 16
+     if server_args.chunked_prefill_size < 1:
+         m_max = 1024 * 64
+     elif server_args.chunked_prefill_size > 8192:
+         m_max = server_args.chunked_prefill_size * 2
+     m_max = min(1024 * 128, m_max)
+     _BUILTIN_M_LIST = list(range(1, m_max + 1))
+
+     # Check if this is the first rank on the node
+     _DO_COMPILE = ServerArgs.base_gpu_id == gpu_id
+
+
+ class DeepGemmKernelType(IntEnum):
+     GROUPED_GEMM_NT_F8F8BF16_MASKED = auto()
+     GROUPED_GEMM_NT_F8F8BF16_CONTIG = auto()
+     GEMM_NT_F8F8BF16 = auto()
+
+
+ @dataclass
+ class DeepGemmKernelHelper:
+     name: str
+     compile_func: Callable[
+         [
+             int,
+             int,
+             int,
+             Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+         ],
+         None,
+     ]
+     configure_func: Callable[
+         [int, int, int, int, int],
+         Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+     ]
+
+
+ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
+
+
+ def _compile_warning_1():
+     if not _IN_PRE_COMPILE_STAGE:
+         logger.warning(
+             "Entering DeepGEMM JIT Pre-Compile session. "
+             "It may take a long time (typically 10-20 mins) "
+             "if you have not run `sglang.compile_deep_gemm` before. "
+             "It is recommended to run `sglang.compile_deep_gemm` with the same args as `sglang.launch_server`"
+             " for pre-compilation to reduce this overhead. "
+             "For example: "
+             "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
+         )
+
+
+ def _compile_warning_2():
+     logger.warning(
+         "Entering DeepGEMM JIT Single Kernel Compile session. "
+         "This will make inference throughput flaky. "
+         "Please run `sglang.compile_deep_gemm` with the same args as `sglang.launch_server`"
+         " for pre-compilation to avoid this issue. "
+         "For example: "
+         "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
+     )
+
+
+ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
+     n: int,
+     k: int,
+     num_groups: int,
+     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+ ) -> None:
+     # Auto-tuning with compilation
+     global deep_gemm_includes, deep_gemm_grouped_gemm_template
+     _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+     _ = jit_tuner.compile_and_tune(
+         name="m_grouped_gemm_fp8_fp8_bf16_nt",
+         keys={
+             "N": n,
+             "K": k,
+             "BLOCK_M": block_m,
+             "BLOCK_N": block_n,
+             "SWIZZLE_D_MODE": smem_config[1],
+             "BLOCK_N_PADDING": smem_config[2],
+             "NUM_GROUPS": num_groups,
+             "NUM_STAGES": num_stages,
+             "NUM_TMA_MULTICAST": tma_multicast_config[0],
+             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
+             "GEMM_TYPE": "GroupedMasked",
+         },
+         space=(),
+         includes=deep_gemm_includes,
+         arg_defs=(
+             ("lhs", torch.float8_e4m3fn),
+             ("lhs_scales", torch.float),
+             ("rhs", torch.float8_e4m3fn),
+             ("rhs_scales", torch.float),
+             ("out", torch.bfloat16),
+             ("grouped_layout", torch.int32),
+             ("m", int),
+             ("stream", torch.cuda.Stream),
+             ("num_sms", int),
+             ("smem_size", int),
+         ),
+         template=deep_gemm_grouped_gemm_template,
+         args=[],
+     )
+
+
+ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
+     n: int,
+     k: int,
+     num_groups: int,
+     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+ ) -> None:
+     global deep_gemm_includes, deep_gemm_grouped_gemm_template
+     _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+     _ = jit_tuner.compile_and_tune(
+         name="m_grouped_gemm_fp8_fp8_bf16_nt",
+         keys={
+             "N": n,
+             "K": k,
+             "BLOCK_M": block_m,
+             "BLOCK_N": block_n,
+             "SWIZZLE_D_MODE": smem_config[1],
+             "BLOCK_N_PADDING": smem_config[2],
+             "NUM_GROUPS": num_groups,
+             "NUM_STAGES": num_stages,
+             "NUM_TMA_MULTICAST": tma_multicast_config[0],
+             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
+             "GEMM_TYPE": "GroupedContiguous",
+         },
+         space=(),
+         includes=deep_gemm_includes,
+         arg_defs=(
+             ("lhs", torch.float8_e4m3fn),
+             ("lhs_scales", torch.float),
+             ("rhs", torch.float8_e4m3fn),
+             ("rhs_scales", torch.float),
+             ("out", torch.bfloat16),
+             ("grouped_layout", torch.int32),
+             ("m", int),
+             ("num_groups", int),
+             ("stream", torch.cuda.Stream),
+             ("num_sms", int),
+             ("smem_size", int),
+         ),
+         template=deep_gemm_grouped_gemm_template,
+         args=[],
+     )
+
+
+ def _compile_gemm_nt_f8f8bf16_one(
+     n: int,
+     k: int,
+     _: int,  # _ is a dummy parameter to align with other interfaces
+     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+ ) -> None:
+     global deep_gemm_includes, deep_gemm_gemm_template
+     _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+     _ = jit_tuner.compile_and_tune(
+         name="gemm_fp8_fp8_bf16_nt",
+         keys={
+             "N": n,
+             "K": k,
+             "BLOCK_M": block_m,
+             "BLOCK_N": block_n,
+             "SWIZZLE_D_MODE": smem_config[1],
+             "BLOCK_N_PADDING": smem_config[2],
+             "NUM_STAGES": num_stages,
+             "NUM_TMA_MULTICAST": tma_multicast_config[0],
+             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
+         },
+         space=(),
+         includes=deep_gemm_includes,
+         arg_defs=(
+             ("lhs", torch.float8_e4m3fn),
+             ("lhs_scales", torch.float),
+             ("rhs", torch.float8_e4m3fn),
+             ("rhs_scales", torch.float),
+             ("out", torch.bfloat16),
+             ("m", int),
+             ("stream", torch.cuda.Stream),
+             ("num_sms", int),
+             ("smem_size", int),
+         ),
+         template=deep_gemm_gemm_template,
+         args=[],
+     )
+
+
+ _KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
+     DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
+         name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
+         compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
+         configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs(
+             m, n, k, num_groups, num_sms, is_grouped_masked=True
+         ),
+     ),
+     DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
+         name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
+         compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
+         configure_func=lambda m, n, k, _, num_sms: get_best_configs(
+             m, n, k, 1, num_sms, is_grouped_contiguous=True
+         ),
+     ),
+     DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
+         name="gemm_fp8_fp8_bf16_nt",
+         compile_func=_compile_gemm_nt_f8f8bf16_one,
+         configure_func=lambda m, n, k, _, num_sms: get_best_configs(
+             m, n, k, 1, num_sms
+         ),
+     ),
+ }
+
+
+ def _maybe_compile_deep_gemm_one_type_all(
+     kernel_type: DeepGemmKernelType,
+     n: int,
+     k: int,
+     num_groups: int,
+     m_list: Optional[List[int]] = None,
+ ) -> None:
+
+     global _INITIALIZATION_DICT
+     global _BUILTIN_M_LIST
+
+     query_key = (kernel_type, n, k, num_groups)
+     if (
+         _ENABLE_JIT_DEEPGEMM_PRECOMPILE
+         and _DO_COMPILE
+         and _INITIALIZATION_DICT.get(query_key) is None
+     ):
+         _INITIALIZATION_DICT[query_key] = True
+
+         kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
+         _compile_warning_1()
+         logger.info(
+             f"Try DeepGEMM JIT Compiling for "
+             f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
+             f"{' It only takes a little time (typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}"
+         )
+
+         # NOTE(alcanderian): get_num_sms should be changed when 2-batch-overlap is introduced
+         num_sms = get_num_sms()
+         collected_configs = set()
+         for m in m_list if m_list is not None else _BUILTIN_M_LIST:
+             # Put config into set to get unique configs and reduce cases to be compiled
+             collected_configs.add(
+                 kernel_helper.configure_func(m, n, k, num_groups, num_sms)
+             )
+         compile_func = lambda config: kernel_helper.compile_func(
+             n, k, num_groups, config
+         )
+         thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
+
+
+ def grouped_gemm_nt_f8f8bf16_masked(
+     lhs: Tuple[torch.Tensor, torch.Tensor],
+     rhs: Tuple[torch.Tensor, torch.Tensor],
+     out: torch.Tensor,
+     masked_m: torch.Tensor,
+     expected_m: int,
+ ):
+     num_groups, _, k = lhs[0].shape
+     _, n, _ = rhs[0].shape
+
+     kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
+     _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+
+     with _log_jit_build(expected_m, n, k, kernel_type):
+         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+             lhs, rhs, out, masked_m, expected_m
+         )
+
+
+ def grouped_gemm_nt_f8f8bf16_contig(
+     lhs: Tuple[torch.Tensor, torch.Tensor],
+     rhs: Tuple[torch.Tensor, torch.Tensor],
+     out: torch.Tensor,
+     m_indices: torch.Tensor,
+ ):
+     m, k = lhs[0].shape
+     num_groups, n, _ = rhs[0].shape
+
+     kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
+     _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+
+     with _log_jit_build(m, n, k, kernel_type):
+         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
+
+
+ def gemm_nt_f8f8bf16(
+     lhs: Tuple[torch.Tensor, torch.Tensor],
+     rhs: Tuple[torch.Tensor, torch.Tensor],
+     out: torch.Tensor,
+ ):
+     m, k = lhs[0].shape
+     n, _ = rhs[0].shape
+
+     kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16
+     _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)
+
+     with _log_jit_build(m, n, k, kernel_type):
+         deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
+
+
+ @contextmanager
+ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
+     if _IN_PRE_COMPILE_STAGE:
+         yield
+         return
+
+     from deep_gemm.jit.runtime import RuntimeCache
+
+     origin_func = RuntimeCache.__getitem__
+
+     def __patched_func(self, *args, **kwargs):
+         ret = origin_func(self, *args, **kwargs)
+         if ret is None:
+             kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
+             _compile_warning_2()
+             logger.warning(
+                 f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
+             )
+         return ret
+
+     RuntimeCache.__getitem__ = __patched_func
+     yield
+     RuntimeCache.__getitem__ = origin_func
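For orientation, here is a shape-only sketch of what the new `gemm_nt_f8f8bf16` wrapper above expects. The scale granularity (1x128 groups for activations, 128x128 blocks for weights) is DeepGEMM's convention and is assumed here rather than defined by this diff; the hypothetical sizes are illustrative. Actually invoking the kernel additionally requires a CUDA SM90 GPU, the `deep_gemm` package, and `SGL_ENABLE_JIT_DEEPGEMM` enabled, so the call itself is left commented.

    # Shape sketch only; sizes and scale layouts are assumptions, not taken from the diff.
    import torch

    m, n, k = 128, 4096, 7168
    lhs = (
        torch.empty(m, k, dtype=torch.float8_e4m3fn),          # FP8 activations
        torch.empty(m, k // 128, dtype=torch.float32),         # per-group activation scales (assumed layout)
    )
    rhs = (
        torch.empty(n, k, dtype=torch.float8_e4m3fn),          # FP8 weights
        torch.empty(n // 128, k // 128, dtype=torch.float32),  # per-block weight scales (assumed layout)
    )
    out = torch.empty(m, n, dtype=torch.bfloat16)

    # from sglang.srt.layers.quantization.deep_gemm import gemm_nt_f8f8bf16
    # gemm_nt_f8f8bf16(lhs, rhs, out)  # first call for a given (N, K) may pre-compile configs for all Ms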
sglang/srt/layers/quantization/fp8_kernel.py
@@ -16,19 +16,17 @@ import functools
  import json
  import logging
  import os
- from contextlib import contextmanager
  from typing import Any, Dict, List, Optional, Tuple

  import torch
  import triton
  import triton.language as tl

+ from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
  from sglang.srt.utils import (
      direct_register_custom_op,
-     get_bool_env_var,
      get_device_core_count,
      get_device_name,
-     get_device_sm,
      is_cuda,
      is_hip,
      supports_custom_op,
@@ -43,22 +41,16 @@ else:
  fp8_max = torch.finfo(_fp8_type).max
  fp8_min = -fp8_max

- _enable_jit_deepgemm = False
- _enable_jit_deepgemm_bmm = False
  if _is_cuda:
-     import deep_gemm
      from sgl_kernel import (
          sgl_per_tensor_quant_fp8,
          sgl_per_token_group_quant_fp8,
          sgl_per_token_quant_fp8,
      )

-     sm_version = get_device_sm()
-     if sm_version == 90:
-         if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
-             _enable_jit_deepgemm = True
-         if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
-             _enable_jit_deepgemm_bmm = True
+     from sglang.srt.layers.quantization.deep_gemm import (
+         gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
+     )

  logger = logging.getLogger(__name__)

@@ -71,10 +63,7 @@ if supports_custom_op():
          Bs: torch.Tensor,
          C: torch.Tensor,
      ) -> None:
-         M, K = A.shape
-         N, _ = B.shape
-         with _log_jit_build(M, N, K):
-             deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+         deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)

      def deep_gemm_fp8_fp8_bf16_nt_fake(
          A: torch.Tensor,
@@ -715,25 +704,6 @@ def get_w8a8_block_fp8_configs(
      return None


- @contextmanager
- def _log_jit_build(M: int, N: int, K: int):
-     from deep_gemm.jit.runtime import RuntimeCache
-
-     origin_func = RuntimeCache.__getitem__
-
-     def __patched_func(self, *args, **kwargs):
-         ret = origin_func(self, *args, **kwargs)
-         if ret is None:
-             logger.warning(
-                 f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
-             )
-         return ret
-
-     RuntimeCache.__getitem__ = __patched_func
-     yield
-     RuntimeCache.__getitem__ = origin_func
-
-
  def w8a8_block_fp8_matmul(
      A: torch.Tensor,
      B: torch.Tensor,
@@ -804,12 +774,11 @@ def w8a8_block_fp8_matmul(
      )

      # deepgemm only support bf16
-     if C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
+     if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
          if supports_custom_op():
              torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
          else:
-             with _log_jit_build(M, N, K):
-                 deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+             deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
      else:
          kernel = (
              _w8a8_block_fp8_matmul_unrolledx4
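With the hunks above, the DeepGEMM fast path in `fp8_kernel.py` is now gated by the single `_ENABLE_JIT_DEEPGEMM` flag imported from the new module instead of locally computed flags (the old `_enable_jit_deepgemm_bmm` flag is removed in these hunks). A minimal, hedged sketch of checking that gate from user code; the environment variable name comes from the diff, while the accepted value ("true") is an assumption about `get_bool_env_var`:

    import os

    # Must be set before sglang modules are imported; it only takes effect on SM90 GPUs
    # with deep_gemm installed.
    os.environ.setdefault("SGL_ENABLE_JIT_DEEPGEMM", "true")

    from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM

    print("DeepGEMM JIT path enabled:", _ENABLE_JIT_DEEPGEMM)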
sglang/srt/layers/quantization/fp8_utils.py
@@ -12,8 +12,8 @@ try:
  except ImportError:
      VLLM_AVAILABLE = False

+ from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
  from sglang.srt.layers.quantization.fp8_kernel import (
-     _enable_jit_deepgemm,
      per_token_group_quant_fp8,
      scaled_fp8_quant,
      sglang_per_token_quant_fp8,
@@ -143,7 +143,7 @@ def apply_w8a8_block_fp8_linear(
          )
          gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
      else:
-         if _enable_jit_deepgemm:
+         if _ENABLE_JIT_DEEPGEMM:
              q_input, x_scale = sglang_per_token_group_quant_fp8(
                  input_2d,
                  block_size[1],

sglang/srt/layers/quantization/gptq.py
@@ -37,6 +37,14 @@ except ImportError:
  logger = logging.getLogger(__name__)


+ def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
+     # compat: gptqmodel and autogptq (eol) main use checkpoint_format: str
+     # compat: autogptq <=0.7.1 is_marlin_format: bool
+     return hf_quant_cfg.get("checkpoint_format") == "marlin" or hf_quant_cfg.get(
+         "is_marlin_format", False
+     )
+
+
  class GPTQConfig(QuantizationConfig):
      """Config class for GPTQ.

@@ -262,13 +270,15 @@ class GPTQMarlinConfig(QuantizationConfig):

      @classmethod
      def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+         is_marlin_format = check_marlin_format(hf_quant_cfg)
+
          can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

          is_valid_user_quant = (
              user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
          )

-         if can_convert and is_valid_user_quant:
+         if not is_marlin_format and can_convert and is_valid_user_quant:
              msg = (
                  "The model is convertible to {} during runtime."
                  " Using {} kernel.".format(cls.get_name(), cls.get_name())
@@ -276,7 +286,7 @@ class GPTQMarlinConfig(QuantizationConfig):
              logger.info(msg)
              return cls.get_name()

-         if can_convert and user_quant == "gptq":
+         if not is_marlin_format and can_convert and user_quant == "gptq":
              logger.info(
                  "Detected that the model can run with gptq_marlin"
                  ", however you specified quantization=gptq explicitly,"
@@ -401,11 +411,7 @@ class MarlinConfig(QuantizationConfig):

      @classmethod
      def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
-         # compat: autogptq >=0.8.0 use checkpoint_format: str
-         # compat: autogptq <=0.7.1 is_marlin_format: bool
-         is_marlin_format = hf_quant_cfg.get(
-             "checkpoint_format"
-         ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
+         is_marlin_format = check_marlin_format(hf_quant_cfg)

          is_valid_user_quant = (
              user_quant is None or user_quant == "gptq" or user_quant == "marlin"
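The `check_marlin_format` helper introduced above is small enough to illustrate directly. The sample config dicts below are hypothetical, but the keys (`checkpoint_format`, `is_marlin_format`) are exactly the ones the helper inspects; its logic is reproduced so the snippet runs on its own:

    # Reproduction of the helper from the gptq.py hunk above, plus hypothetical configs.
    def check_marlin_format(hf_quant_cfg):
        return hf_quant_cfg.get("checkpoint_format") == "marlin" or hf_quant_cfg.get(
            "is_marlin_format", False
        )

    newer_cfg = {"checkpoint_format": "marlin", "bits": 4}  # gptqmodel / newer autogptq style
    older_cfg = {"is_marlin_format": True, "bits": 4}       # autogptq <= 0.7.1 style
    plain_cfg = {"bits": 4, "group_size": 128}              # regular GPTQ checkpoint

    print(check_marlin_format(newer_cfg))  # True
    print(check_marlin_format(older_cfg))  # True
    print(check_marlin_format(plain_cfg))  # False

With the added `not is_marlin_format and ...` guards, checkpoints already serialized in Marlin format no longer go through the runtime GPTQ-to-Marlin conversion branches.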
sglang/srt/layers/quantization/modelopt_quant.py
@@ -22,9 +22,9 @@ from sglang.srt.layers.quantization.utils import (
      requantize_with_max_scale,
  )
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.utils import is_cuda_available
+ from sglang.srt.utils import is_cuda

- if is_cuda_available():
+ if is_cuda():
      from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

  # Initialize logger for the module

sglang/srt/layers/quantization/w8a8_int8.py
@@ -11,10 +11,10 @@ from sglang.srt.layers.quantization.base_config import (
      QuantizeMethodBase,
  )
  from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
- from sglang.srt.utils import is_cuda_available, set_weight_attrs
+ from sglang.srt.utils import is_cuda, set_weight_attrs

- is_cuda = is_cuda_available()
- if is_cuda:
+ _is_cuda = is_cuda()
+ if _is_cuda:
      from sgl_kernel import int8_scaled_mm


sglang/srt/layers/rotary_embedding.py
@@ -8,11 +8,11 @@ import torch
  import torch.nn as nn

  from sglang.srt.custom_op import CustomOp
- from sglang.srt.utils import is_cuda_available
+ from sglang.srt.utils import is_cuda

- _is_cuda_available = is_cuda_available()
+ _is_cuda = is_cuda()

- if _is_cuda_available:
+ if _is_cuda:
      from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
  else:
      from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
@@ -82,7 +82,7 @@ class RotaryEmbedding(CustomOp):

          cache = self._compute_cos_sin_cache()
          # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
-         if not _is_cuda_available:
+         if not _is_cuda:
              cache = cache.to(dtype)
          self.cos_sin_cache: torch.Tensor
          self.register_buffer("cos_sin_cache", cache, persistent=False)
@@ -149,7 +149,7 @@ class RotaryEmbedding(CustomOp):
          key: torch.Tensor,
          offsets: Optional[torch.Tensor] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
-         if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
+         if _is_cuda and (self.head_size in [64, 128, 256, 512]):
              apply_rope_with_cos_sin_cache_inplace(
                  positions=positions,
                  query=query,
@@ -652,7 +652,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
      def forward(self, *args, **kwargs):
          if torch.compiler.is_compiling():
              return self.forward_native(*args, **kwargs)
-         if _is_cuda_available:
+         if _is_cuda:
              return self.forward_cuda(*args, **kwargs)
          else:
              return self.forward_native(*args, **kwargs)

sglang/srt/layers/sampler.py
@@ -10,9 +10,9 @@ from sglang.srt.layers.dp_attention import get_attention_tp_group
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
- from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
+ from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda

- if is_cuda_available():
+ if is_cuda():
      from sgl_kernel import (
          min_p_sampling_from_probs,
          top_k_renorm_prob,
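The `modelopt_quant.py`, `w8a8_int8.py`, `rotary_embedding.py`, and `sampler.py` hunks above are the same mechanical rename: `is_cuda_available()` from `sglang.srt.utils` becomes `is_cuda()`, with the result cached in a private `_is_cuda` module flag. A minimal sketch of the shared guarded-import pattern (the `sgl_kernel` import is taken from the w8a8_int8.py hunk as an example):

    from sglang.srt.utils import is_cuda

    _is_cuda = is_cuda()

    if _is_cuda:
        from sgl_kernel import int8_scaled_mm  # CUDA-only kernel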
sglang/srt/managers/data_parallel_controller.py
@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
  )
  from sglang.srt.managers.scheduler import run_scheduler_process
  from sglang.srt.server_args import PortArgs, ServerArgs
+ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
  from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
  from sglang.utils import get_exception_traceback

@@ -174,6 +175,10 @@ class DataParallelController:
          if not server_args.enable_dp_attention:
              logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")

+         memory_saver_adapter = TorchMemorySaverAdapter.create(
+             enable=server_args.enable_memory_saver
+         )
+
          # Launch tensor parallel scheduler processes
          scheduler_pipe_readers = []
          tp_size_per_node = server_args.tp_size // server_args.nnodes
@@ -208,7 +213,8 @@ class DataParallelController:
                  target=run_scheduler_process,
                  args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
              )
-             proc.start()
+             with memory_saver_adapter.configure_subprocess():
+                 proc.start()
              self.scheduler_procs.append(proc)
              scheduler_pipe_readers.append(reader)
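Finally, the data parallel controller now starts each scheduler subprocess inside the memory saver adapter's `configure_subprocess()` context. A minimal sketch of that launch pattern in isolation, assuming the adapter is effectively a no-op when `enable=False` (the `worker` function is a hypothetical stand-in for `run_scheduler_process`):

    import multiprocessing as mp

    from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

    def worker():
        print("scheduler subprocess running")

    if __name__ == "__main__":
        # Mirrors enable=server_args.enable_memory_saver in the diff; False keeps it inert.
        adapter = TorchMemorySaverAdapter.create(enable=False)
        proc = mp.Process(target=worker)
        with adapter.configure_subprocess():
            proc.start()
        proc.join()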