sglang 0.4.5.post2__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +3 -2
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/disaggregation/decode.py +43 -0
- sglang/srt/disaggregation/mini_lb.py +69 -8
- sglang/srt/disaggregation/mooncake/conn.py +1 -1
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +100 -16
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +781 -150
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +19 -4
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/rotary_embedding.py +6 -6
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/io_struct.py +14 -3
- sglang/srt/managers/schedule_batch.py +13 -0
- sglang/srt/managers/scheduler.py +16 -6
- sglang/srt/managers/tokenizer_manager.py +115 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +31 -13
- sglang/srt/model_executor/cuda_graph_runner.py +13 -8
- sglang/srt/model_executor/model_runner.py +19 -4
- sglang/srt/models/deepseek_v2.py +9 -6
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +52 -40
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/utils.py +46 -5
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +3 -3
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +62 -57
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/deep_gemm.py ADDED
@@ -0,0 +1,378 @@
+import logging
+import os
+from contextlib import contextmanager
+from dataclasses import dataclass
+from enum import IntEnum, auto
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+from tqdm.contrib.concurrent import thread_map
+
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
+
+_ENABLE_JIT_DEEPGEMM = False
+if is_cuda():
+    import deep_gemm
+    from deep_gemm import get_num_sms
+    from deep_gemm.jit_kernels.gemm import get_best_configs
+    from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
+    from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
+    from deep_gemm.jit_kernels.m_grouped_gemm import (
+        template as deep_gemm_grouped_gemm_template,
+    )
+    from deep_gemm.jit_kernels.tuner import jit_tuner
+
+    sm_version = get_device_sm()
+    if sm_version == 90:
+        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
+            _ENABLE_JIT_DEEPGEMM = True
+
+logger = logging.getLogger(__name__)
+
+_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
+_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
+    "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
+)
+_DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
+_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
+_IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false")
+
+# Force redirect deep_gemm cache_dir
+os.environ["DG_CACHE_DIR"] = os.getenv(
+    "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
+)
+
+
+def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
+    global _BUILTIN_M_LIST
+    global _DO_COMPILE
+
+    # Generate m_max
+    m_max = 1024 * 16
+    if server_args.chunked_prefill_size < 1:
+        m_max = 1024 * 64
+    elif server_args.chunked_prefill_size > 8192:
+        m_max = server_args.chunked_prefill_size * 2
+    m_max = min(1024 * 128, m_max)
+    _BUILTIN_M_LIST = list(range(1, m_max + 1))
+
+    # Check if is the first rank on node
+    _DO_COMPILE = ServerArgs.base_gpu_id == gpu_id
+
+
+class DeepGemmKernelType(IntEnum):
+    GROUPED_GEMM_NT_F8F8BF16_MASKED = auto()
+    GROUPED_GEMM_NT_F8F8BF16_CONTIG = auto()
+    GEMM_NT_F8F8BF16 = auto()
+
+
+@dataclass
+class DeepGemmKernelHelper:
+    name: str
+    compile_func: Callable[
+        [
+            int,
+            int,
+            int,
+            Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+        ],
+        None,
+    ]
+    configure_func: Callable[
+        [int, int, int, int, int],
+        Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+    ]
+
+
+_INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
+
+
+def _compile_warning_1():
+    if not _IN_PRE_COMPILE_STAGE:
+        logger.warning(
+            "Entering DeepGEMM JIT Pre-Complie session. "
+            "And it may takes a long time(Typically 10-20 mins) "
+            "if you have not run `sglang.compile_deep_gemm`. "
+            "Recommand to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
+            " for pre-compilation to reduce the overhead if you have not run it before. "
+            "For example: "
+            "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
+        )
+
+
+def _compile_warning_2():
+    logger.warning(
+        "Entering DeepGEMM JIT Single Kernel Complie session. "
+        "And it will makes inference throughput becomes flaky. "
+        "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
+        " for pre-compilation to solve this issue. "
+        "For example: "
+        "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
+    )
+
+
+def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
+    n: int,
+    k: int,
+    num_groups: int,
+    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+) -> None:
+    # Auto-tuning with compilation
+    global deep_gemm_includes, deep_gemm_grouped_gemm_template
+    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    _ = jit_tuner.compile_and_tune(
+        name="m_grouped_gemm_fp8_fp8_bf16_nt",
+        keys={
+            "N": n,
+            "K": k,
+            "BLOCK_M": block_m,
+            "BLOCK_N": block_n,
+            "SWIZZLE_D_MODE": smem_config[1],
+            "BLOCK_N_PADDING": smem_config[2],
+            "NUM_GROUPS": num_groups,
+            "NUM_STAGES": num_stages,
+            "NUM_TMA_MULTICAST": tma_multicast_config[0],
+            "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
+            "GEMM_TYPE": "GroupedMasked",
+        },
+        space=(),
+        includes=deep_gemm_includes,
+        arg_defs=(
+            ("lhs", torch.float8_e4m3fn),
+            ("lhs_scales", torch.float),
+            ("rhs", torch.float8_e4m3fn),
+            ("rhs_scales", torch.float),
+            ("out", torch.bfloat16),
+            ("grouped_layout", torch.int32),
+            ("m", int),
+            ("stream", torch.cuda.Stream),
+            ("num_sms", int),
+            ("smem_size", int),
+        ),
+        template=deep_gemm_grouped_gemm_template,
+        args=[],
+    )
+
+
+def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
+    n: int,
+    k: int,
+    num_groups: int,
+    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+) -> None:
+    global deep_gemm_includes, deep_gemm_grouped_gemm_template
+    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    _ = jit_tuner.compile_and_tune(
+        name="m_grouped_gemm_fp8_fp8_bf16_nt",
+        keys={
+            "N": n,
+            "K": k,
+            "BLOCK_M": block_m,
+            "BLOCK_N": block_n,
+            "SWIZZLE_D_MODE": smem_config[1],
+            "BLOCK_N_PADDING": smem_config[2],
+            "NUM_GROUPS": num_groups,
+            "NUM_STAGES": num_stages,
+            "NUM_TMA_MULTICAST": tma_multicast_config[0],
+            "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
+            "GEMM_TYPE": "GroupedContiguous",
+        },
+        space=(),
+        includes=deep_gemm_includes,
+        arg_defs=(
+            ("lhs", torch.float8_e4m3fn),
+            ("lhs_scales", torch.float),
+            ("rhs", torch.float8_e4m3fn),
+            ("rhs_scales", torch.float),
+            ("out", torch.bfloat16),
+            ("grouped_layout", torch.int32),
+            ("m", int),
+            ("num_groups", int),
+            ("stream", torch.cuda.Stream),
+            ("num_sms", int),
+            ("smem_size", int),
+        ),
+        template=deep_gemm_grouped_gemm_template,
+        args=[],
+    )
+
+
+def _compile_gemm_nt_f8f8bf16_one(
+    n: int,
+    k: int,
+    _: int,  # _ is a dummy parameter to align with other interfaces
+    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+) -> None:
+    global deep_gemm_includes, deep_gemm_gemm_template
+    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    _ = jit_tuner.compile_and_tune(
+        name="gemm_fp8_fp8_bf16_nt",
+        keys={
+            "N": n,
+            "K": k,
+            "BLOCK_M": block_m,
+            "BLOCK_N": block_n,
+            "SWIZZLE_D_MODE": smem_config[1],
+            "BLOCK_N_PADDING": smem_config[2],
+            "NUM_STAGES": num_stages,
+            "NUM_TMA_MULTICAST": tma_multicast_config[0],
+            "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
+        },
+        space=(),
+        includes=deep_gemm_includes,
+        arg_defs=(
+            ("lhs", torch.float8_e4m3fn),
+            ("lhs_scales", torch.float),
+            ("rhs", torch.float8_e4m3fn),
+            ("rhs_scales", torch.float),
+            ("out", torch.bfloat16),
+            ("m", int),
+            ("stream", torch.cuda.Stream),
+            ("num_sms", int),
+            ("smem_size", int),
+        ),
+        template=deep_gemm_gemm_template,
+        args=[],
+    )
+
+
+_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
+    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
+        name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
+        compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
+        configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs(
+            m, n, k, num_groups, num_sms, is_grouped_masked=True
+        ),
+    ),
+    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
+        name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
+        compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
+        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
+            m, n, k, 1, num_sms, is_grouped_contiguous=True
+        ),
+    ),
+    DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
+        name="gemm_fp8_fp8_bf16_nt",
+        compile_func=_compile_gemm_nt_f8f8bf16_one,
+        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
+            m, n, k, 1, num_sms
+        ),
+    ),
+}
+
+
+def _maybe_compile_deep_gemm_one_type_all(
+    kernel_type: DeepGemmKernelType,
+    n: int,
+    k: int,
+    num_groups: int,
+    m_list: Optional[List[int]] = None,
+) -> None:
+
+    global _INITIALIZATION_DICT
+    global _BUILTIN_M_LIST
+
+    query_key = (kernel_type, n, k, num_groups)
+    if (
+        _ENABLE_JIT_DEEPGEMM_PRECOMPILE
+        and _DO_COMPILE
+        and _INITIALIZATION_DICT.get(query_key) is None
+    ):
+        _INITIALIZATION_DICT[query_key] = True
+
+        kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
+        _compile_warning_1()
+        logger.info(
+            f"Try DeepGEMM JIT Compiling for "
+            f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
+            f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}"
+        )
+
+        # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
+        num_sms = get_num_sms()
+        collected_configs = set()
+        for m in m_list if m_list is not None else _BUILTIN_M_LIST:
+            # Put config into set to get unique configs and reduce cases to be compiled
+            collected_configs.add(
+                kernel_helper.configure_func(m, n, k, num_groups, num_sms)
+            )
+        compile_func = lambda config: kernel_helper.compile_func(
+            n, k, num_groups, config
+        )
+        thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
+
+
+def grouped_gemm_nt_f8f8bf16_masked(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m: int,
+):
+    num_groups, _, k = lhs[0].shape
+    _, n, _ = rhs[0].shape
+
+    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
+    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+
+    with _log_jit_build(expected_m, n, k, kernel_type):
+        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            lhs, rhs, out, masked_m, expected_m
+        )
+
+
+def grouped_gemm_nt_f8f8bf16_contig(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    m_indices: torch.Tensor,
+):
+    m, k = lhs[0].shape
+    num_groups, n, _ = rhs[0].shape
+
+    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
+    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+
+    with _log_jit_build(m, n, k, kernel_type):
+        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
+
+
+def gemm_nt_f8f8bf16(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+):
+    m, k = lhs[0].shape
+    n, _ = rhs[0].shape
+
+    kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16
+    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)
+
+    with _log_jit_build(m, n, k, kernel_type):
+        deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
+
+
+@contextmanager
+def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
+    if _IN_PRE_COMPILE_STAGE:
+        yield
+        return
+
+    from deep_gemm.jit.runtime import RuntimeCache
+
+    origin_func = RuntimeCache.__getitem__
+
+    def __patched_func(self, *args, **kwargs):
+        ret = origin_func(self, *args, **kwargs)
+        if ret is None:
+            kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
+            _compile_warning_2()
+            logger.warning(
+                f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
+            )
+        return ret
+
+    RuntimeCache.__getitem__ = __patched_func
+    yield
+    RuntimeCache.__getitem__ = origin_func
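The new module above is configured entirely through environment variables. As a rough, hedged illustration (the variable names and the pre-compilation command are taken from the code above; the particular values and cache path are only examples), a Hopper (SM90) deployment could opt in like this before launching the server:

import os

os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "true"            # opt in; only honored on SM90 GPUs
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "true"        # allow bulk pre-compilation of kernel configs
os.environ["SGL_JIT_DEEPGEMM_COMPILE_WORKERS"] = "8"      # worker threads passed to thread_map during compilation
os.environ["SGL_DG_CACHE_DIR"] = "/data/deep_gemm_cache"  # redirected into DG_CACHE_DIR by the module

# One-time pre-compilation, run with the same arguments as sglang.launch_server, e.g.:
#   python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code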
sglang/srt/layers/quantization/fp8_kernel.py CHANGED
@@ -16,19 +16,17 @@ import functools
 import json
 import logging
 import os
-from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple

 import torch
 import triton
 import triton.language as tl

+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.utils import (
     direct_register_custom_op,
-    get_bool_env_var,
     get_device_core_count,
     get_device_name,
-    get_device_sm,
     is_cuda,
     is_hip,
     supports_custom_op,
@@ -43,22 +41,16 @@ else:
     fp8_max = torch.finfo(_fp8_type).max
     fp8_min = -fp8_max

-_enable_jit_deepgemm = False
-_enable_jit_deepgemm_bmm = False
 if _is_cuda:
-    import deep_gemm
     from sgl_kernel import (
         sgl_per_tensor_quant_fp8,
         sgl_per_token_group_quant_fp8,
         sgl_per_token_quant_fp8,
     )

-
-
-
-            _enable_jit_deepgemm = True
-        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
-            _enable_jit_deepgemm_bmm = True
+    from sglang.srt.layers.quantization.deep_gemm import (
+        gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
+    )

 logger = logging.getLogger(__name__)

@@ -71,10 +63,7 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-
-        N, _ = B.shape
-        with _log_jit_build(M, N, K):
-            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)

     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,
@@ -715,25 +704,6 @@ def get_w8a8_block_fp8_configs(
     return None


-@contextmanager
-def _log_jit_build(M: int, N: int, K: int):
-    from deep_gemm.jit.runtime import RuntimeCache
-
-    origin_func = RuntimeCache.__getitem__
-
-    def __patched_func(self, *args, **kwargs):
-        ret = origin_func(self, *args, **kwargs)
-        if ret is None:
-            logger.warning(
-                f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
-            )
-        return ret
-
-    RuntimeCache.__getitem__ = __patched_func
-    yield
-    RuntimeCache.__getitem__ = origin_func
-
-
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -804,12 +774,11 @@ def w8a8_block_fp8_matmul(
     )

     # deepgemm only support bf16
-    if C.dtype == torch.bfloat16 and
+    if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
         if supports_custom_op():
             torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
         else:
-
-            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+            deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
     else:
         kernel = (
             _w8a8_block_fp8_matmul_unrolledx4
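For orientation, the dispatch that the last hunk changes can be summarized as the following sketch. It reuses the names from the diff and assumes it sits inside sglang/srt/layers/quantization/fp8_kernel.py, where those names are imported or defined; it is not the full w8a8_block_fp8_matmul function.

def _w8a8_block_fp8_output_path(A, As, B, Bs, C):
    # DeepGEMM handles only bf16 outputs, and only when the centralized flag is on.
    if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
        if supports_custom_op():
            # dispatch through the registered sglang custom op
            torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
        else:
            deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
    else:
        pass  # fall back to the Triton _w8a8_block_fp8_matmul kernels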
sglang/srt/layers/quantization/fp8_utils.py CHANGED
@@ -12,8 +12,8 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False

+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
-    _enable_jit_deepgemm,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
@@ -143,7 +143,7 @@ def apply_w8a8_block_fp8_linear(
         )
         gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
     else:
-        if
+        if _ENABLE_JIT_DEEPGEMM:
             q_input, x_scale = sglang_per_token_group_quant_fp8(
                 input_2d,
                 block_size[1],
sglang/srt/layers/quantization/gptq.py CHANGED
@@ -37,6 +37,14 @@ except ImportError:
 logger = logging.getLogger(__name__)


+def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
+    # compat: gptqmodel and autogptq (eol) main use checkpoint_format: str
+    # compat: autogptq <=0.7.1 is_marlin_format: bool
+    return hf_quant_cfg.get("checkpoint_format") == "marlin" or hf_quant_cfg.get(
+        "is_marlin_format", False
+    )
+
+
 class GPTQConfig(QuantizationConfig):
     """Config class for GPTQ.

@@ -262,13 +270,15 @@ class GPTQMarlinConfig(QuantizationConfig):

     @classmethod
     def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        is_marlin_format = check_marlin_format(hf_quant_cfg)
+
         can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

         is_valid_user_quant = (
             user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
         )

-        if can_convert and is_valid_user_quant:
+        if not is_marlin_format and can_convert and is_valid_user_quant:
             msg = (
                 "The model is convertible to {} during runtime."
                 " Using {} kernel.".format(cls.get_name(), cls.get_name())
@@ -276,7 +286,7 @@ class GPTQMarlinConfig(QuantizationConfig):
             logger.info(msg)
             return cls.get_name()

-        if can_convert and user_quant == "gptq":
+        if not is_marlin_format and can_convert and user_quant == "gptq":
             logger.info(
                 "Detected that the model can run with gptq_marlin"
                 ", however you specified quantization=gptq explicitly,"
@@ -401,11 +411,7 @@ class MarlinConfig(QuantizationConfig):

     @classmethod
     def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
-
-        # compat: autogptq <=0.7.1 is_marlin_format: bool
-        is_marlin_format = hf_quant_cfg.get(
-            "checkpoint_format"
-        ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
+        is_marlin_format = check_marlin_format(hf_quant_cfg)

         is_valid_user_quant = (
             user_quant is None or user_quant == "gptq" or user_quant == "marlin"
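The new check_marlin_format helper centralizes a check that MarlinConfig previously inlined. A few illustrative calls (the dict keys come from the helper above; the example configs themselves are hypothetical):

from sglang.srt.layers.quantization.gptq import check_marlin_format

check_marlin_format({"checkpoint_format": "marlin"})       # True: gptqmodel / newer autogptq checkpoints
check_marlin_format({"is_marlin_format": True})            # True: autogptq <= 0.7.1 checkpoints
check_marlin_format({"quant_method": "gptq", "bits": 4})   # False: plain GPTQ checkpoint (may still be converted)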
sglang/srt/layers/quantization/modelopt_quant.py CHANGED
@@ -22,9 +22,9 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import
+from sglang.srt.utils import is_cuda

-if
+if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

 # Initialize logger for the module
sglang/srt/layers/quantization/w8a8_int8.py CHANGED
@@ -11,10 +11,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
-from sglang.srt.utils import
+from sglang.srt.utils import is_cuda, set_weight_attrs

-
-if
+_is_cuda = is_cuda()
+if _is_cuda:
     from sgl_kernel import int8_scaled_mm

sglang/srt/layers/rotary_embedding.py CHANGED
@@ -8,11 +8,11 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import
+from sglang.srt.utils import is_cuda

-
+_is_cuda = is_cuda()

-
+if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
     from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
@@ -82,7 +82,7 @@ class RotaryEmbedding(CustomOp):

         cache = self._compute_cos_sin_cache()
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
-        if not
+        if not _is_cuda:
             cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
@@ -149,7 +149,7 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if
+        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
@@ -652,7 +652,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
     def forward(self, *args, **kwargs):
         if torch.compiler.is_compiling():
             return self.forward_native(*args, **kwargs)
-        if
+        if _is_cuda:
             return self.forward_cuda(*args, **kwargs)
         else:
             return self.forward_native(*args, **kwargs)
sglang/srt/layers/sampler.py CHANGED
@@ -10,9 +10,9 @@ from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, get_bool_env_var,
+from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda

-if
+if is_cuda():
     from sgl_kernel import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,
sglang/srt/managers/data_parallel_controller.py CHANGED
@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
 from sglang.utils import get_exception_traceback

@@ -174,6 +175,10 @@ class DataParallelController:
         if not server_args.enable_dp_attention:
             logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")

+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
+
         # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
         tp_size_per_node = server_args.tp_size // server_args.nnodes
@@ -208,7 +213,8 @@ class DataParallelController:
                 target=run_scheduler_process,
                 args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
             )
-
+            with memory_saver_adapter.configure_subprocess():
+                proc.start()
             self.scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)