sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +3 -6
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +6 -2
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +105 -6
- sglang/srt/disaggregation/mini_lb.py +74 -9
- sglang/srt/disaggregation/mooncake/conn.py +33 -63
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +137 -17
- sglang/srt/disaggregation/utils.py +32 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +883 -209
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +20 -5
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +17 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -88
- sglang/srt/layers/quantization/fp8_utils.py +189 -264
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +9 -8
- sglang/srt/layers/sampler.py +7 -12
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +15 -4
- sglang/srt/managers/scheduler.py +28 -77
- sglang/srt/managers/tokenizer_manager.py +116 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +41 -29
- sglang/srt/mem_cache/memory_pool.py +38 -15
- sglang/srt/model_executor/cuda_graph_runner.py +15 -10
- sglang/srt/model_executor/model_runner.py +39 -31
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +292 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +31 -203
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +86 -72
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +6 -14
- sglang/srt/utils.py +62 -6
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,378 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
from contextlib import contextmanager
|
4
|
+
from dataclasses import dataclass
|
5
|
+
from enum import IntEnum, auto
|
6
|
+
from typing import Callable, Dict, List, Optional, Tuple
|
7
|
+
|
8
|
+
import torch
|
9
|
+
from tqdm.contrib.concurrent import thread_map
|
10
|
+
|
11
|
+
from sglang.srt.server_args import ServerArgs
|
12
|
+
from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
|
13
|
+
|
14
|
+
# Whether the JIT DeepGEMM path is enabled; stays False unless every check
# below passes.
_ENABLE_JIT_DEEPGEMM = False
if is_cuda():
    # DeepGEMM and the JIT internals used for pre-compilation are imported only
    # when is_cuda() is true. Otherwise these names are never bound, and code
    # that references them would raise NameError.
    import deep_gemm
    from deep_gemm import get_num_sms
    from deep_gemm.jit_kernels.gemm import get_best_configs
    from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
    from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
    from deep_gemm.jit_kernels.m_grouped_gemm import (
        template as deep_gemm_grouped_gemm_template,
    )
    from deep_gemm.jit_kernels.tuner import jit_tuner

    sm_version = get_device_sm()
    # Enable only when the device reports SM version 90 (presumably compute
    # capability 9.0 — depends on get_device_sm's encoding, confirm) AND the
    # SGL_ENABLE_JIT_DEEPGEMM env var is truthy (default "false").
    if sm_version == 90:
        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
            _ENABLE_JIT_DEEPGEMM = True
|
30
|
+
|
31
|
+
logger = logging.getLogger(__name__)

# Default M values swept during pre-compilation: 1..16K inclusive.
# update_deep_gemm_config() rebinds this from the server configuration.
_BUILTIN_M_LIST = list(range(1, 16 * 1024 + 1))

# Environment-driven switches, read once when the module is imported.
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var("SGL_JIT_DEEPGEMM_PRECOMPILE", "true")
_DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
_IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false")

# Force DeepGEMM's cache directory: SGL_DG_CACHE_DIR when set, otherwise
# ~/.cache/deep_gemm. Any DG_CACHE_DIR already in the environment is overwritten.
os.environ["DG_CACHE_DIR"] = os.environ.get(
    "SGL_DG_CACHE_DIR",
    os.path.expanduser("~") + "/.cache/deep_gemm",
)
|
45
|
+
|
46
|
+
|
47
|
+
def update_deep_gemm_config(gpu_id: int, server_args: "ServerArgs"):
    """Adapt the module's DeepGEMM pre-compilation settings to a server config.

    Rebinds ``_BUILTIN_M_LIST`` (the M values swept during pre-compilation)
    based on ``server_args.chunked_prefill_size``. Also decides whether this
    process should compile kernels (``_DO_COMPILE``).

    Args:
        gpu_id: GPU index used by the calling process.
        server_args: Server configuration. Only ``chunked_prefill_size`` and
            ``base_gpu_id`` are read.
    """
    global _BUILTIN_M_LIST
    global _DO_COMPILE

    # Upper bound of the M sweep:
    # - 16K by default;
    # - 64K when chunked_prefill_size < 1 (presumably chunked prefill disabled);
    # - twice the chunk size for chunks above 8192.
    # It is capped at 128K in all cases.
    m_max = 1024 * 16
    if server_args.chunked_prefill_size < 1:
        m_max = 1024 * 64
    elif server_args.chunked_prefill_size > 8192:
        m_max = server_args.chunked_prefill_size * 2
    m_max = min(1024 * 128, m_max)
    _BUILTIN_M_LIST = list(range(1, m_max + 1))

    # Only the process whose GPU is the configured base GPU compiles.
    # This must read the *instance* passed in, not the ServerArgs class:
    # the class-level attribute does not reflect this server's setting.
    _DO_COMPILE = server_args.base_gpu_id == gpu_id
|
62
|
+
|
63
|
+
|
64
|
+
class DeepGemmKernelType(IntEnum):
    """Kernel families whose DeepGEMM JIT compilation is managed here."""

    # Values are written out explicitly. They are the same numbers auto()
    # assigns to IntEnum members (starting at 1, in declaration order).
    GROUPED_GEMM_NT_F8F8BF16_MASKED = 1
    GROUPED_GEMM_NT_F8F8BF16_CONTIG = 2
    GEMM_NT_F8F8BF16 = 3
|
68
|
+
|
69
|
+
|
70
|
+
@dataclass
class DeepGemmKernelHelper:
    """Everything needed to pre-compile one DeepGEMM kernel family.

    Fields:
        name: Kernel name shown in log messages.
        compile_func: Invoked as ``compile_func(n, k, num_groups, config)`` to
            JIT-compile a single configuration.
        configure_func: Invoked as ``configure_func(m, n, k, num_groups, num_sms)``.
            Returns the configuration tuple chosen for that problem size.
    """

    name: str
    compile_func: Callable[
        [int, int, int, Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]],
        None,
    ]
    configure_func: Callable[
        [int, int, int, int, int],
        Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
    ]
|
86
|
+
|
87
|
+
|
88
|
+
# (kernel_type, n, k, num_groups) combinations whose full pre-compilation sweep
# has already been started. Each entry is set to True before compiling, so the
# sweep is attempted at most once per key.
_INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = {}
|
89
|
+
|
90
|
+
|
91
|
+
def _compile_warning_1():
    """Warn that a full DeepGEMM JIT pre-compilation sweep is starting.

    Nothing is logged when ``_IN_PRE_COMPILE_STAGE`` is truthy.
    """
    if not _IN_PRE_COMPILE_STAGE:
        # Message text fixes misspellings of the previous version
        # ("Complie", "may takes", "Recommand").
        logger.warning(
            "Entering DeepGEMM JIT Pre-Compile session. "
            "It may take a long time (typically 10-20 mins) "
            "if you have not run `sglang.compile_deep_gemm`. "
            "It is recommended to run `sglang.compile_deep_gemm` with the same args as `sglang.launch_server`"
            " for pre-compilation to reduce the overhead if you have not run it before. "
            "For example: "
            "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
        )
|
102
|
+
|
103
|
+
|
104
|
+
def _compile_warning_2():
    """Warn that a single DeepGEMM kernel is being JIT-compiled on demand."""
    # Message text fixes misspellings of the previous version
    # ("Complie", "will makes ... becomes").
    logger.warning(
        "Entering DeepGEMM JIT Single Kernel Compile session. "
        "It will make inference throughput flaky. "
        "Please run `sglang.compile_deep_gemm` with the same args as `sglang.launch_server`"
        " for pre-compilation to solve this issue. "
        "For example: "
        "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
    )
|
113
|
+
|
114
|
+
|
115
|
+
def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
    n: int,
    k: int,
    num_groups: int,
    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
) -> None:
    """JIT-compile one configuration of the masked grouped FP8 GEMM kernel.

    Args:
        n: Value baked into the kernel as ``N``.
        k: Value baked into the kernel as ``K``.
        num_groups: Value baked into the kernel as ``NUM_GROUPS``.
        config: Six-element tuple from the kernel's configure function. The
            first element is not used here.
    """
    (_leading, block_m, block_n, num_stages, tma_multicast_config, smem_config) = config

    template_keys = {
        "N": n,
        "K": k,
        "BLOCK_M": block_m,
        "BLOCK_N": block_n,
        "SWIZZLE_D_MODE": smem_config[1],
        "BLOCK_N_PADDING": smem_config[2],
        "NUM_GROUPS": num_groups,
        "NUM_STAGES": num_stages,
        "NUM_TMA_MULTICAST": tma_multicast_config[0],
        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
        "GEMM_TYPE": "GroupedMasked",
    }
    # Kernel argument signature; the order must match the kernel template.
    kernel_arg_defs = (
        ("lhs", torch.float8_e4m3fn),
        ("lhs_scales", torch.float),
        ("rhs", torch.float8_e4m3fn),
        ("rhs_scales", torch.float),
        ("out", torch.bfloat16),
        ("grouped_layout", torch.int32),
        ("m", int),
        ("stream", torch.cuda.Stream),
        ("num_sms", int),
        ("smem_size", int),
    )

    # Auto-tuning with compilation; the tuner's return value is not needed.
    jit_tuner.compile_and_tune(
        name="m_grouped_gemm_fp8_fp8_bf16_nt",
        keys=template_keys,
        space=(),
        includes=deep_gemm_includes,
        arg_defs=kernel_arg_defs,
        template=deep_gemm_grouped_gemm_template,
        args=[],
    )
|
156
|
+
|
157
|
+
|
158
|
+
def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
    n: int,
    k: int,
    num_groups: int,
    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
) -> None:
    """JIT-compile one configuration of the contiguous grouped FP8 GEMM kernel.

    Args:
        n: Value baked into the kernel as ``N``.
        k: Value baked into the kernel as ``K``.
        num_groups: Value baked into the kernel as ``NUM_GROUPS``.
        config: Six-element tuple from the kernel's configure function. The
            first element is not used here.
    """
    (_leading, block_m, block_n, num_stages, tma_multicast_config, smem_config) = config

    template_keys = {
        "N": n,
        "K": k,
        "BLOCK_M": block_m,
        "BLOCK_N": block_n,
        "SWIZZLE_D_MODE": smem_config[1],
        "BLOCK_N_PADDING": smem_config[2],
        "NUM_GROUPS": num_groups,
        "NUM_STAGES": num_stages,
        "NUM_TMA_MULTICAST": tma_multicast_config[0],
        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
        "GEMM_TYPE": "GroupedContiguous",
    }
    # Kernel argument signature; unlike the masked variant it also takes
    # num_groups at runtime. The order must match the kernel template.
    kernel_arg_defs = (
        ("lhs", torch.float8_e4m3fn),
        ("lhs_scales", torch.float),
        ("rhs", torch.float8_e4m3fn),
        ("rhs_scales", torch.float),
        ("out", torch.bfloat16),
        ("grouped_layout", torch.int32),
        ("m", int),
        ("num_groups", int),
        ("stream", torch.cuda.Stream),
        ("num_sms", int),
        ("smem_size", int),
    )

    jit_tuner.compile_and_tune(
        name="m_grouped_gemm_fp8_fp8_bf16_nt",
        keys=template_keys,
        space=(),
        includes=deep_gemm_includes,
        arg_defs=kernel_arg_defs,
        template=deep_gemm_grouped_gemm_template,
        args=[],
    )
|
199
|
+
|
200
|
+
|
201
|
+
def _compile_gemm_nt_f8f8bf16_one(
    n: int,
    k: int,
    _: int,  # _ is a dummy parameter to align with other interfaces
    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
) -> None:
    """JIT-compile one configuration of the plain (non-grouped) FP8 GEMM kernel.

    Args:
        n: Value baked into the kernel as ``N``.
        k: Value baked into the kernel as ``K``.
        _: Unused; keeps the signature aligned with the grouped compile
            functions (they take ``num_groups`` here).
        config: Six-element tuple from the kernel's configure function. The
            first element is not used here.
    """
    (_leading, block_m, block_n, num_stages, tma_multicast_config, smem_config) = config

    template_keys = {
        "N": n,
        "K": k,
        "BLOCK_M": block_m,
        "BLOCK_N": block_n,
        "SWIZZLE_D_MODE": smem_config[1],
        "BLOCK_N_PADDING": smem_config[2],
        "NUM_STAGES": num_stages,
        "NUM_TMA_MULTICAST": tma_multicast_config[0],
        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
    }
    # Kernel argument signature (no grouped_layout for the plain GEMM).
    kernel_arg_defs = (
        ("lhs", torch.float8_e4m3fn),
        ("lhs_scales", torch.float),
        ("rhs", torch.float8_e4m3fn),
        ("rhs_scales", torch.float),
        ("out", torch.bfloat16),
        ("m", int),
        ("stream", torch.cuda.Stream),
        ("num_sms", int),
        ("smem_size", int),
    )

    jit_tuner.compile_and_tune(
        name="gemm_fp8_fp8_bf16_nt",
        keys=template_keys,
        space=(),
        includes=deep_gemm_includes,
        arg_defs=kernel_arg_defs,
        template=deep_gemm_gemm_template,
        args=[],
    )
|
238
|
+
|
239
|
+
|
240
|
+
def _configure_grouped_gemm_nt_f8f8bf16_masked(
    m: int, n: int, k: int, num_groups: int, num_sms: int
):
    # Masked layout: the real group count takes part in the config search.
    return get_best_configs(m, n, k, num_groups, num_sms, is_grouped_masked=True)


def _configure_grouped_gemm_nt_f8f8bf16_contig(
    m: int, n: int, k: int, _num_groups: int, num_sms: int
):
    # Contiguous layout: searched with a group count of 1; _num_groups is ignored.
    return get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)


def _configure_gemm_nt_f8f8bf16(
    m: int, n: int, k: int, _num_groups: int, num_sms: int
):
    # Plain GEMM: always a single group; _num_groups is ignored.
    return get_best_configs(m, n, k, 1, num_sms)


# Per-kernel-type compile/configure entry points used by the pre-compile sweep.
_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
        name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
        compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
        configure_func=_configure_grouped_gemm_nt_f8f8bf16_masked,
    ),
    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
        name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
        compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
        configure_func=_configure_grouped_gemm_nt_f8f8bf16_contig,
    ),
    DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
        name="gemm_fp8_fp8_bf16_nt",
        compile_func=_compile_gemm_nt_f8f8bf16_one,
        configure_func=_configure_gemm_nt_f8f8bf16,
    ),
}
|
263
|
+
|
264
|
+
|
265
|
+
def _maybe_compile_deep_gemm_one_type_all(
    kernel_type: DeepGemmKernelType,
    n: int,
    k: int,
    num_groups: int,
    m_list: Optional[List[int]] = None,
) -> None:
    """Pre-compile every distinct config of one kernel type for fixed (N, K, groups).

    The sweep runs at most once per ``(kernel_type, n, k, num_groups)``. It runs
    only when ``_ENABLE_JIT_DEEPGEMM_PRECOMPILE`` and ``_DO_COMPILE`` are both
    truthy.

    A configuration is computed for every M in ``m_list`` (default: the
    current ``_BUILTIN_M_LIST``). Duplicates are removed with a set. The
    remaining configs are compiled via tqdm's ``thread_map`` with
    ``max_workers=_COMPILE_WORKERS``.

    Args:
        kernel_type: Which kernel family to compile.
        n: N dimension passed to the configure and compile functions.
        k: K dimension passed to the configure and compile functions.
        num_groups: Group count passed to the configure and compile functions.
        m_list: Optional explicit list of M values to sweep.
    """
    query_key = (kernel_type, n, k, num_groups)
    if not (
        _ENABLE_JIT_DEEPGEMM_PRECOMPILE
        and _DO_COMPILE
        and _INITIALIZATION_DICT.get(query_key) is None
    ):
        return

    # Mark the key before compiling, so the sweep is not repeated on later
    # calls, even if compilation fails partway through.
    _INITIALIZATION_DICT[query_key] = True

    kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
    _compile_warning_1()
    precompiled_hint = (
        " It only takes a little time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. "
        if not _IN_PRE_COMPILE_STAGE
        else ""
    )
    logger.info(
        "Try DeepGEMM JIT Compiling for "
        f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
        f"{precompiled_hint}"
    )

    # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
    num_sms = get_num_sms()
    m_values = m_list if m_list is not None else _BUILTIN_M_LIST
    # Many M values map to the same config; the set keeps each one once.
    unique_configs = {
        kernel_helper.configure_func(m, n, k, num_groups, num_sms) for m in m_values
    }

    def _compile_one(config):
        kernel_helper.compile_func(n, k, num_groups, config)

    thread_map(_compile_one, unique_configs, max_workers=_COMPILE_WORKERS)
|
304
|
+
|
305
|
+
|
306
|
+
def grouped_gemm_nt_f8f8bf16_masked(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
    masked_m: torch.Tensor,
    expected_m: int,
):
    """Run DeepGEMM's masked grouped FP8 GEMM, pre-compiling kernels if needed.

    ``lhs[0]`` is unpacked as ``(num_groups, _, k)`` and ``rhs[0]`` as
    ``(_, n, _)``. The second element of each pair is passed to DeepGEMM
    untouched (presumably the scale tensors — TODO confirm).
    """
    num_groups, _lhs_rows, k = lhs[0].shape
    _rhs_groups, n, _rhs_k = rhs[0].shape

    masked_kernel = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
    _maybe_compile_deep_gemm_one_type_all(masked_kernel, n, k, num_groups)

    # expected_m stands in for M in the JIT-build warning.
    with _log_jit_build(expected_m, n, k, masked_kernel):
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
            lhs, rhs, out, masked_m, expected_m
        )
|
323
|
+
|
324
|
+
|
325
|
+
def grouped_gemm_nt_f8f8bf16_contig(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
    m_indices: torch.Tensor,
):
    """Run DeepGEMM's contiguous grouped FP8 GEMM, pre-compiling kernels if needed.

    ``lhs[0]`` is unpacked as ``(m, k)`` and ``rhs[0]`` as ``(num_groups, n, _)``.
    The second element of each pair is passed to DeepGEMM untouched
    (presumably the scale tensors — TODO confirm).
    """
    m, k = lhs[0].shape
    num_groups, n, _rhs_k = rhs[0].shape

    contig_kernel = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
    _maybe_compile_deep_gemm_one_type_all(contig_kernel, n, k, num_groups)

    with _log_jit_build(m, n, k, contig_kernel):
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
|
339
|
+
|
340
|
+
|
341
|
+
def gemm_nt_f8f8bf16(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
):
    """Run DeepGEMM's plain FP8 GEMM, pre-compiling kernels if needed.

    ``lhs[0]`` is unpacked as ``(m, k)`` and ``rhs[0]`` as ``(n, _)``. The
    second element of each pair is passed to DeepGEMM untouched (presumably
    the scale tensors — TODO confirm).
    """
    m, k = lhs[0].shape
    n, _rhs_k = rhs[0].shape

    plain_kernel = DeepGemmKernelType.GEMM_NT_F8F8BF16
    # The plain GEMM is treated as a single group.
    _maybe_compile_deep_gemm_one_type_all(plain_kernel, n, k, 1)

    with _log_jit_build(m, n, k, plain_kernel):
        deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
|
354
|
+
|
355
|
+
|
356
|
+
@contextmanager
def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
    """Warn whenever DeepGEMM's runtime cache misses inside this context.

    Temporarily wraps ``deep_gemm.jit.runtime.RuntimeCache.__getitem__``.
    When a lookup returns ``None`` (presumably meaning the kernel must be
    JIT-built now), a warning names the kernel and the M/N/K problem size.
    The original method is restored on exit, including when the body raises.
    Nothing is patched while ``_IN_PRE_COMPILE_STAGE`` is truthy.
    """
    if _IN_PRE_COMPILE_STAGE:
        yield
        return

    from deep_gemm.jit.runtime import RuntimeCache

    origin_func = RuntimeCache.__getitem__

    def __patched_func(self, *args, **kwargs):
        ret = origin_func(self, *args, **kwargs)
        if ret is None:
            kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
            _compile_warning_2()
            logger.warning(
                f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
            )
        return ret

    RuntimeCache.__getitem__ = __patched_func
    try:
        yield
    finally:
        # Always undo the monkey-patch. Without try/finally, an exception from
        # the kernel call (re-raised at the yield) would skip this line and
        # leave RuntimeCache permanently patched.
        RuntimeCache.__getitem__ = origin_func
|