sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/deep_gemm.py (new file, +378 -0)
@@ -0,0 +1,378 @@
+ import logging
+ import os
+ from contextlib import contextmanager
+ from dataclasses import dataclass
+ from enum import IntEnum, auto
+ from typing import Callable, Dict, List, Optional, Tuple
+
+ import torch
+ from tqdm.contrib.concurrent import thread_map
+
+ from sglang.srt.server_args import ServerArgs
+ from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
+
+ _ENABLE_JIT_DEEPGEMM = False
+ if is_cuda():
+     import deep_gemm
+     from deep_gemm import get_num_sms
+     from deep_gemm.jit_kernels.gemm import get_best_configs
+     from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
+     from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
+     from deep_gemm.jit_kernels.m_grouped_gemm import (
+         template as deep_gemm_grouped_gemm_template,
+     )
+     from deep_gemm.jit_kernels.tuner import jit_tuner
+
+     sm_version = get_device_sm()
+     if sm_version == 90:
+         if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
+             _ENABLE_JIT_DEEPGEMM = True
+
+ logger = logging.getLogger(__name__)
+
+ _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
+ _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
+     "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
+ )
+ _DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
+ _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
+ _IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false")
+
+ # Force redirect deep_gemm cache_dir
+ os.environ["DG_CACHE_DIR"] = os.getenv(
+     "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
+ )
+
+
+ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
+     global _BUILTIN_M_LIST
+     global _DO_COMPILE
+
+     # Generate m_max
+     m_max = 1024 * 16
+     if server_args.chunked_prefill_size < 1:
+         m_max = 1024 * 64
+     elif server_args.chunked_prefill_size > 8192:
+         m_max = server_args.chunked_prefill_size * 2
+     m_max = min(1024 * 128, m_max)
+     _BUILTIN_M_LIST = list(range(1, m_max + 1))
+
+     # Check whether this is the first rank on the node
+     _DO_COMPILE = ServerArgs.base_gpu_id == gpu_id
+
+
+ class DeepGemmKernelType(IntEnum):
+     GROUPED_GEMM_NT_F8F8BF16_MASKED = auto()
+     GROUPED_GEMM_NT_F8F8BF16_CONTIG = auto()
+     GEMM_NT_F8F8BF16 = auto()
+
+
+ @dataclass
+ class DeepGemmKernelHelper:
+     name: str
+     compile_func: Callable[
+         [
+             int,
+             int,
+             int,
+             Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+         ],
+         None,
+     ]
+     configure_func: Callable[
+         [int, int, int, int, int],
+         Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+     ]
+
+
+ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
+
+
+ def _compile_warning_1():
+     if not _IN_PRE_COMPILE_STAGE:
+         logger.warning(
+             "Entering the DeepGEMM JIT pre-compile session. "
+             "It may take a long time (typically 10-20 mins) "
+             "if you have not run `sglang.compile_deep_gemm`. "
+             "It is recommended to run `sglang.compile_deep_gemm` with the same args as `sglang.launch_server`"
+             " for pre-compilation to reduce the overhead if you have not run it before. "
+             "For example: "
+             "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
+         )
+
+
+ def _compile_warning_2():
+     logger.warning(
+         "Entering the DeepGEMM JIT single-kernel compile session. "
+         "It will make inference throughput flaky. "
+         "Please run `sglang.compile_deep_gemm` with the same args as `sglang.launch_server`"
+         " for pre-compilation to avoid this issue. "
+         "For example: "
+         "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
+     )
+
+
+ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
+     n: int,
+     k: int,
+     num_groups: int,
+     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+ ) -> None:
+     # Auto-tuning with compilation
+     global deep_gemm_includes, deep_gemm_grouped_gemm_template
+     _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+     _ = jit_tuner.compile_and_tune(
+         name="m_grouped_gemm_fp8_fp8_bf16_nt",
+         keys={
+             "N": n,
+             "K": k,
+             "BLOCK_M": block_m,
+             "BLOCK_N": block_n,
+             "SWIZZLE_D_MODE": smem_config[1],
+             "BLOCK_N_PADDING": smem_config[2],
+             "NUM_GROUPS": num_groups,
+             "NUM_STAGES": num_stages,
+             "NUM_TMA_MULTICAST": tma_multicast_config[0],
+             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
+             "GEMM_TYPE": "GroupedMasked",
+         },
+         space=(),
+         includes=deep_gemm_includes,
+         arg_defs=(
+             ("lhs", torch.float8_e4m3fn),
+             ("lhs_scales", torch.float),
+             ("rhs", torch.float8_e4m3fn),
+             ("rhs_scales", torch.float),
+             ("out", torch.bfloat16),
+             ("grouped_layout", torch.int32),
+             ("m", int),
+             ("stream", torch.cuda.Stream),
+             ("num_sms", int),
+             ("smem_size", int),
+         ),
+         template=deep_gemm_grouped_gemm_template,
+         args=[],
+     )
+
+
+ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
+     n: int,
+     k: int,
+     num_groups: int,
+     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+ ) -> None:
+     global deep_gemm_includes, deep_gemm_grouped_gemm_template
+     _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+     _ = jit_tuner.compile_and_tune(
+         name="m_grouped_gemm_fp8_fp8_bf16_nt",
+         keys={
+             "N": n,
+             "K": k,
+             "BLOCK_M": block_m,
+             "BLOCK_N": block_n,
+             "SWIZZLE_D_MODE": smem_config[1],
+             "BLOCK_N_PADDING": smem_config[2],
+             "NUM_GROUPS": num_groups,
+             "NUM_STAGES": num_stages,
+             "NUM_TMA_MULTICAST": tma_multicast_config[0],
+             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
+             "GEMM_TYPE": "GroupedContiguous",
+         },
+         space=(),
+         includes=deep_gemm_includes,
+         arg_defs=(
+             ("lhs", torch.float8_e4m3fn),
+             ("lhs_scales", torch.float),
+             ("rhs", torch.float8_e4m3fn),
+             ("rhs_scales", torch.float),
+             ("out", torch.bfloat16),
+             ("grouped_layout", torch.int32),
+             ("m", int),
+             ("num_groups", int),
+             ("stream", torch.cuda.Stream),
+             ("num_sms", int),
+             ("smem_size", int),
+         ),
+         template=deep_gemm_grouped_gemm_template,
+         args=[],
+     )
+
+
+ def _compile_gemm_nt_f8f8bf16_one(
+     n: int,
+     k: int,
+     _: int,  # _ is a dummy parameter to align with other interfaces
+     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
+ ) -> None:
+     global deep_gemm_includes, deep_gemm_gemm_template
+     _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+     _ = jit_tuner.compile_and_tune(
+         name="gemm_fp8_fp8_bf16_nt",
+         keys={
+             "N": n,
+             "K": k,
+             "BLOCK_M": block_m,
+             "BLOCK_N": block_n,
+             "SWIZZLE_D_MODE": smem_config[1],
+             "BLOCK_N_PADDING": smem_config[2],
+             "NUM_STAGES": num_stages,
+             "NUM_TMA_MULTICAST": tma_multicast_config[0],
+             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
+         },
+         space=(),
+         includes=deep_gemm_includes,
+         arg_defs=(
+             ("lhs", torch.float8_e4m3fn),
+             ("lhs_scales", torch.float),
+             ("rhs", torch.float8_e4m3fn),
+             ("rhs_scales", torch.float),
+             ("out", torch.bfloat16),
+             ("m", int),
+             ("stream", torch.cuda.Stream),
+             ("num_sms", int),
+             ("smem_size", int),
+         ),
+         template=deep_gemm_gemm_template,
+         args=[],
+     )
+
+
+ _KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
+     DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
+         name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
+         compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
+         configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs(
+             m, n, k, num_groups, num_sms, is_grouped_masked=True
+         ),
+     ),
+     DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
+         name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
+         compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
+         configure_func=lambda m, n, k, _, num_sms: get_best_configs(
+             m, n, k, 1, num_sms, is_grouped_contiguous=True
+         ),
+     ),
+     DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
+         name="gemm_fp8_fp8_bf16_nt",
+         compile_func=_compile_gemm_nt_f8f8bf16_one,
+         configure_func=lambda m, n, k, _, num_sms: get_best_configs(
+             m, n, k, 1, num_sms
+         ),
+     ),
+ }
+
+
+ def _maybe_compile_deep_gemm_one_type_all(
+     kernel_type: DeepGemmKernelType,
+     n: int,
+     k: int,
+     num_groups: int,
+     m_list: Optional[List[int]] = None,
+ ) -> None:
+
+     global _INITIALIZATION_DICT
+     global _BUILTIN_M_LIST
+
+     query_key = (kernel_type, n, k, num_groups)
+     if (
+         _ENABLE_JIT_DEEPGEMM_PRECOMPILE
+         and _DO_COMPILE
+         and _INITIALIZATION_DICT.get(query_key) is None
+     ):
+         _INITIALIZATION_DICT[query_key] = True
+
+         kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
+         _compile_warning_1()
+         logger.info(
+             f"Trying DeepGEMM JIT compilation for "
+             f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
+             f"{' It only takes a little time (typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}"
+         )
+
+         # NOTE(alcanderian): get_num_sms should be changed when 2-batch-overlap is introduced
+         num_sms = get_num_sms()
+         collected_configs = set()
+         for m in m_list if m_list is not None else _BUILTIN_M_LIST:
+             # Put configs into a set to deduplicate them and reduce the number of cases to compile
+             collected_configs.add(
+                 kernel_helper.configure_func(m, n, k, num_groups, num_sms)
+             )
+         compile_func = lambda config: kernel_helper.compile_func(
+             n, k, num_groups, config
+         )
+         thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
+
+
+ def grouped_gemm_nt_f8f8bf16_masked(
+     lhs: Tuple[torch.Tensor, torch.Tensor],
+     rhs: Tuple[torch.Tensor, torch.Tensor],
+     out: torch.Tensor,
+     masked_m: torch.Tensor,
+     expected_m: int,
+ ):
+     num_groups, _, k = lhs[0].shape
+     _, n, _ = rhs[0].shape
+
+     kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
+     _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+
+     with _log_jit_build(expected_m, n, k, kernel_type):
+         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+             lhs, rhs, out, masked_m, expected_m
+         )
+
+
+ def grouped_gemm_nt_f8f8bf16_contig(
+     lhs: Tuple[torch.Tensor, torch.Tensor],
+     rhs: Tuple[torch.Tensor, torch.Tensor],
+     out: torch.Tensor,
+     m_indices: torch.Tensor,
+ ):
+     m, k = lhs[0].shape
+     num_groups, n, _ = rhs[0].shape
+
+     kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
+     _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+
+     with _log_jit_build(m, n, k, kernel_type):
+         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
+
+
+ def gemm_nt_f8f8bf16(
+     lhs: Tuple[torch.Tensor, torch.Tensor],
+     rhs: Tuple[torch.Tensor, torch.Tensor],
+     out: torch.Tensor,
+ ):
+     m, k = lhs[0].shape
+     n, _ = rhs[0].shape
+
+     kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16
+     _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)
+
+     with _log_jit_build(m, n, k, kernel_type):
+         deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
+
+
+ @contextmanager
+ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
+     if _IN_PRE_COMPILE_STAGE:
+         yield
+         return
+
+     from deep_gemm.jit.runtime import RuntimeCache
+
+     origin_func = RuntimeCache.__getitem__
+
+     def __patched_func(self, *args, **kwargs):
+         ret = origin_func(self, *args, **kwargs)
+         if ret is None:
+             kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
+             _compile_warning_2()
+             logger.warning(
+                 f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
+             )
+         return ret
+
+     RuntimeCache.__getitem__ = __patched_func
+     yield
+     RuntimeCache.__getitem__ = origin_func
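Editor's note: the new module gates JIT DeepGEMM behind environment variables (only honored on SM90 CUDA devices), redirects the deep_gemm kernel cache directory, and pre-compiles kernel configurations for a range of M values with a thread pool before the public GEMM wrappers are used. The sketch below is not part of the diff; it only illustrates how these hooks might be driven. The environment variable names and the sglang.compile_deep_gemm entry point are taken from the code above, while the model name, tensor-parallel size, and worker count are illustrative assumptions.

# Editor's sketch (assumptions noted above): enable JIT DeepGEMM and pre-compile
# kernels before launching the server, mirroring the example command embedded in
# the module's warning messages.
import os
import subprocess

# Opt in to JIT DeepGEMM; the module only enables it on CUDA SM90 devices.
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "true"
# Redirect the deep_gemm kernel cache (copied into DG_CACHE_DIR at import time).
os.environ["SGL_DG_CACHE_DIR"] = os.path.expanduser("~/.cache/deep_gemm")
# Number of threads used by thread_map() during ahead-of-time compilation.
os.environ["SGL_JIT_DEEPGEMM_COMPILE_WORKERS"] = "8"

# Pre-compile with the same arguments that will later be passed to
# sglang.launch_server, as the warning messages in the new module recommend.
subprocess.run(
    [
        "python3", "-m", "sglang.compile_deep_gemm",
        "--model", "deepseek-ai/DeepSeek-V3",
        "--tp", "8",
        "--trust-remote-code",
    ],
    check=True,
)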