sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (99)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/pooler.py
@@ -3,10 +3,13 @@
 
 from dataclasses import dataclass
 from enum import IntEnum
+from typing import Optional
 
 import torch
 import torch.nn as nn
+from transformers import PretrainedConfig
 
+from sglang.srt.layers.activation import get_cross_encoder_activation_function
 from sglang.srt.model_executor.model_runner import ForwardBatch
 
 
@@ -54,3 +57,56 @@ class Pooler(nn.Module):
             pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
 
         return EmbeddingPoolerOutput(embeddings=pooled_data)
+
+
+class CrossEncodingPooler(nn.Module):
+    """A layer that pools specific information from hidden states.
+
+    This layer does the following:
+    1. Extracts specific tokens or aggregates data based on pooling method.
+    2. Normalizes output if specified.
+    3. Returns structured results as `EmbeddingPoolerOutput`.
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        classifier: nn.Module,
+        pooler: Optional[nn.Module] = None,
+    ):
+        super().__init__()
+        self.classifier = classifier
+        self.pooler = pooler
+        self.default_activation_function = get_cross_encoder_activation_function(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> EmbeddingPoolerOutput:
+        """Pools sentence pair scores from the hidden_states."""
+
+        prompt_lens = forward_batch.extend_seq_lens
+
+        offset = 0
+        pooled_data_lst = []
+        for prompt_len in prompt_lens:
+            pooled_data_i = hidden_states[offset : offset + prompt_len]
+
+            if self.pooler is not None:
+                final_shape_tensor = self.pooler(pooled_data_i, forward_batch)
+            else:
+                final_shape_tensor = self.classifier(pooled_data_i)
+
+            pooled_data_lst.append(final_shape_tensor)
+            offset += prompt_len
+
+        pooled_output = torch.stack(pooled_data_lst)
+
+        if self.pooler is not None:
+            # apply classifier once on the full batch if possible
+            pooled_output = self.classifier(pooled_output)
+
+        scores = self.default_activation_function(pooled_output).squeeze(-1)
+
+        return EmbeddingPoolerOutput(embeddings=scores)
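
For context, the new CrossEncodingPooler slices the packed hidden states back into per-request segments, pools or classifies each segment, stacks the results, and applies the model's cross-encoder activation to produce one relevance score per sentence pair. The snippet below is a minimal, self-contained sketch of that slicing-and-scoring pattern using plain torch on CPU; the toy CLS-token pooler, the linear classifier, and the sigmoid activation are illustrative stand-ins, not the sglang API.

# Illustrative sketch only -- toy stand-ins, not the sglang classes.
import torch
import torch.nn as nn

hidden_dim = 8
classifier = nn.Linear(hidden_dim, 1)   # stands in for the model's classification head
activation = nn.Sigmoid()               # stands in for get_cross_encoder_activation_function(config)


def cls_pooler(segment: torch.Tensor) -> torch.Tensor:
    """Toy pooler: keep the first ([CLS]) token of each request."""
    return segment[0]


# Packed hidden states of a batch whose requests have lengths 4, 2 and 5.
prompt_lens = [4, 2, 5]
hidden_states = torch.randn(sum(prompt_lens), hidden_dim)

offset = 0
pooled_per_request = []
for prompt_len in prompt_lens:
    segment = hidden_states[offset : offset + prompt_len]
    pooled_per_request.append(cls_pooler(segment))
    offset += prompt_len

pooled = torch.stack(pooled_per_request)              # (num_requests, hidden_dim)
scores = activation(classifier(pooled)).squeeze(-1)   # one score per sentence pair
print(scores.shape)                                   # torch.Size([3])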

sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py
@@ -0,0 +1 @@
+from .entrypoint import *

sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py}
@@ -5,34 +5,23 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple
 
-import torch
 from tqdm.contrib.concurrent import thread_map
 
+from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+    DEEPGEMM_BLACKWELL,
+    ENABLE_JIT_DEEPGEMM,
+)
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
+from sglang.srt.utils import get_bool_env_var, get_int_env_var
 
 logger = logging.getLogger(__name__)
-_ENABLE_JIT_DEEPGEMM = False
 
-try:
-    import deep_gemm
+if ENABLE_JIT_DEEPGEMM and not DEEPGEMM_BLACKWELL:
     from deep_gemm import get_num_sms
     from deep_gemm.jit import build
-    from deep_gemm.jit.compiler import get_nvcc_compiler
     from deep_gemm.jit_kernels.gemm import get_best_configs
     from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
 
-    sm_version = get_device_sm()
-    if sm_version == 90:
-        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
-            _ENABLE_JIT_DEEPGEMM = True
-except ImportError:
-    logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")
-
-
-def get_enable_jit_deepgemm():
-    return _ENABLE_JIT_DEEPGEMM
-
 
 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
 _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
@@ -52,8 +41,10 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
 # NVRTC may have performance loss with some cases.
 # And NVCC JIT speed is also 9x faster in the ref commit
 _USE_NVRTC_DEFAULT = "0"
-if _ENABLE_JIT_DEEPGEMM:
+if ENABLE_JIT_DEEPGEMM:
     try:
+        from deep_gemm.jit.compiler import get_nvcc_compiler
+
         get_nvcc_compiler()
     except:
         logger.warning(
@@ -114,11 +105,12 @@ class DeepGemmKernelHelper:
 _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
 
 
+# TODO improve naming
 def _compile_warning_1():
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
             "Entering DeepGEMM JIT Pre-Compile session. "
-            "And it may takes a long time(Typically 10-20 mins) "
+            "It may takes a long time (typically 10-20 mins) "
            "if you have not run `sglang.compile_deep_gemm`. "
            "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
            " for pre-compilation to reduce the overhead if you have not run it before. "
@@ -127,6 +119,7 @@ def _compile_warning_1():
         )
 
 
+# TODO improve naming
 def _compile_warning_2():
     logger.warning(
         "Entering DeepGEMM JIT Single Kernel Compile session. "
@@ -238,6 +231,7 @@ def _compile_gemm_nt_f8f8bf16_one(
     _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
 
 
+# TODO further refactor warmup-related
 _KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
     DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
         name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
@@ -270,7 +264,6 @@ def _maybe_compile_deep_gemm_one_type_all(
     num_groups: int,
     m_list: Optional[List[int]] = None,
 ) -> None:
-
     global _INITIALIZATION_DICT
     global _BUILTIN_M_LIST
 
@@ -304,56 +297,6 @@ def _maybe_compile_deep_gemm_one_type_all(
     thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
 
 
-def grouped_gemm_nt_f8f8bf16_masked(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-    masked_m: torch.Tensor,
-    expected_m: int,
-):
-    num_groups, _, k = lhs[0].shape
-    _, n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(expected_m, n, k, kernel_type):
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            lhs, rhs, out, masked_m, expected_m
-        )
-
-
-def grouped_gemm_nt_f8f8bf16_contig(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-    m_indices: torch.Tensor,
-):
-    m, k = lhs[0].shape
-    num_groups, n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
-
-
-def gemm_nt_f8f8bf16(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-):
-    m, k = lhs[0].shape
-    n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
-
-
 @contextmanager
 def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
     if _IN_PRECOMPILE_STAGE:
@@ -368,7 +311,8 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
         ret = origin_func(self, *args, **kwargs)
         if ret is None:
             kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
-            _compile_warning_2()
+            if not DEEPGEMM_BLACKWELL:
+                _compile_warning_2()
             logger.warning(
                 f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
             )
@@ -380,13 +324,12 @@
 
 
 @contextmanager
-def configure_deep_gemm_num_sms(num_sms):
-    if num_sms is None:
+def deep_gemm_execution_hook(
+    m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
+):
+    # not supported yet
+    if not DEEPGEMM_BLACKWELL:
+        _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+
+    with _log_jit_build(m, n, k, kernel_type):
         yield
-    else:
-        original_num_sms = deep_gemm.get_num_sms()
-        deep_gemm.set_num_sms(num_sms)
-        try:
-            yield
-        finally:
-            deep_gemm.set_num_sms(original_num_sms)
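
In this refactor the kernel wrapper functions and configure_deep_gemm_num_sms move out to the new entrypoint module, while compile_utils keeps the compilation bookkeeping behind a single deep_gemm_execution_hook context manager: pre-compile the requested kernel shape if needed, then warn if the wrapped call still falls into a slow JIT build. The sketch below only mirrors that "prepare once, then warn if slow" shape of the hook; the names and the timing heuristic are made up, and the real hook detects JIT builds by patching DeepGEMM's runtime cache rather than by timing the call.

# Illustrative sketch only -- made-up names; not the sglang implementation.
import logging
import time
from contextlib import contextmanager

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

_PREPARED_SHAPES = set()  # plays the role of _INITIALIZATION_DICT


def _maybe_prepare(shape):
    # Expensive one-time preparation (pre-compilation) keyed by kernel shape.
    if shape not in _PREPARED_SHAPES:
        _PREPARED_SHAPES.add(shape)


@contextmanager
def execution_hook(shape, slow_threshold_s=0.5):
    _maybe_prepare(shape)
    start = time.perf_counter()
    yield
    elapsed = time.perf_counter() - start
    if elapsed > slow_threshold_s:
        logger.warning("Kernel for shape %s took %.2fs (likely JIT compiled)", shape, elapsed)


with execution_hook((4096, 7168, 1)):
    pass  # the real wrappers launch the DeepGEMM kernel here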

sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
@@ -0,0 +1,32 @@
+import logging
+
+from sglang.srt.utils import get_bool_env_var, get_device_sm
+
+logger = logging.getLogger(__name__)
+
+
+def _compute_enable_deep_gemm():
+    sm_version = get_device_sm()
+    if sm_version < 90:
+        return False
+
+    try:
+        import deep_gemm
+    except ImportError:
+        logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
+        return False
+
+    return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
+
+
+ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
+
+try:
+    from deep_gemm import fp8_gemm_nt
+
+    # They have not given a name to this breaking change
+    DEEPGEMM_BLACKWELL = True
+except ImportError:
+    DEEPGEMM_BLACKWELL = False
+
+DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
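
configurer.py centralizes the feature flags the rest of the wrapper relies on: ENABLE_JIT_DEEPGEMM (env var plus SM >= 90 plus a successful deep_gemm import) and DEEPGEMM_BLACKWELL, which is inferred from whether deep_gemm exposes the renamed fp8_gemm_nt entry point. The snippet below is a generic sketch of that probe pattern that runs even without deep_gemm installed; the helper names are illustrative, and it omits the GPU SM check the real configurer performs via get_device_sm().

# Illustrative sketch only -- helper names are made up; not the sglang implementation.
import importlib
import logging
import os

logger = logging.getLogger(__name__)


def _env_flag(name: str, default: str = "true") -> bool:
    return os.getenv(name, default).lower() in ("1", "true", "yes")


def _module_available(name: str) -> bool:
    try:
        importlib.import_module(name)
        return True
    except ImportError:
        logger.warning("Failed to import %s", name)
        return False


def _symbol_available(module_name: str, symbol: str) -> bool:
    return _module_available(module_name) and hasattr(
        importlib.import_module(module_name), symbol
    )


ENABLE_JIT_DEEPGEMM = _env_flag("SGL_ENABLE_JIT_DEEPGEMM") and _module_available("deep_gemm")
# The Blackwell-era deep_gemm API is detected by the presence of the renamed fp8_gemm_nt.
DEEPGEMM_BLACKWELL = _symbol_available("deep_gemm", "fp8_gemm_nt")
print(ENABLE_JIT_DEEPGEMM, DEEPGEMM_BLACKWELL)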

sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
@@ -0,0 +1,110 @@
+import logging
+from contextlib import contextmanager
+from typing import Tuple
+
+import torch
+
+from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
+from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+    DEEPGEMM_BLACKWELL,
+    DEEPGEMM_SCALE_UE8M0,
+    ENABLE_JIT_DEEPGEMM,
+)
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm
+
+    if DEEPGEMM_BLACKWELL:
+        from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
+        from deep_gemm import (
+            fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+        )
+        from deep_gemm import (
+            m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+        )
+    else:
+        from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
+        from deep_gemm import get_col_major_tma_aligned_tensor
+        from deep_gemm import (
+            m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+        )
+        from deep_gemm import (
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+        )
+
+
+def grouped_gemm_nt_f8f8bf16_masked(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m: int,
+    recipe=None,
+):
+    num_groups, _, k = lhs[0].shape
+    _, n, _ = rhs[0].shape
+    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
+
+    with compile_utils.deep_gemm_execution_hook(
+        expected_m, n, k, num_groups, kernel_type
+    ):
+        _grouped_gemm_nt_f8f8bf16_masked_raw(
+            lhs,
+            rhs,
+            out,
+            masked_m,
+            expected_m,
+            **({"recipe": recipe} if DEEPGEMM_BLACKWELL else {})
+        )
+
+
+def grouped_gemm_nt_f8f8bf16_contig(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    m_indices: torch.Tensor,
+):
+    m, k = lhs[0].shape
+    num_groups, n, _ = rhs[0].shape
+    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
+
+    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
+        _grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices)
+
+
+def gemm_nt_f8f8bf16(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+):
+    m, k = lhs[0].shape
+    n, _ = rhs[0].shape
+    num_groups = 1
+    kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
+
+    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
+        _gemm_nt_f8f8bf16_raw(
+            lhs,
+            rhs,
+            out,
+        )
+
+
+def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
+    compile_utils.update_deep_gemm_config(gpu_id, server_args)
+
+
+@contextmanager
+def configure_deep_gemm_num_sms(num_sms):
+    if num_sms is None:
+        yield
+    else:
+        original_num_sms = deep_gemm.get_num_sms()
+        deep_gemm.set_num_sms(num_sms)
+        try:
+            yield
+        finally:
+            deep_gemm.set_num_sms(original_num_sms)
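
entrypoint.py is the surface the rest of the codebase imports (via `from sglang.srt.layers.quantization import deep_gemm_wrapper`); it hides the Hopper-versus-Blackwell naming differences behind stable wrapper functions and carries configure_deep_gemm_num_sms. The sketch below demonstrates the save/set/restore pattern that context manager uses, with a dummy object standing in for the deep_gemm module so it runs without a GPU or deep_gemm installed.

# Illustrative sketch only -- _DummyLibrary stands in for deep_gemm's
# get_num_sms()/set_num_sms() pair.
from contextlib import contextmanager


class _DummyLibrary:
    def __init__(self, num_sms: int):
        self._num_sms = num_sms

    def get_num_sms(self) -> int:
        return self._num_sms

    def set_num_sms(self, value: int) -> None:
        self._num_sms = value


lib = _DummyLibrary(num_sms=132)


@contextmanager
def configure_num_sms(num_sms):
    if num_sms is None:
        yield  # nothing to change, nothing to restore
    else:
        original = lib.get_num_sms()
        lib.set_num_sms(num_sms)
        try:
            yield
        finally:
            lib.set_num_sms(original)  # always restore, even if the body raises


with configure_num_sms(64):
    assert lib.get_num_sms() == 64
assert lib.get_num_sms() == 132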

sglang/srt/layers/quantization/fp8_kernel.py
@@ -23,7 +23,8 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.math_utils import align
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_device_core_count,
@@ -44,10 +45,6 @@ if _is_cuda:
         sgl_per_token_quant_fp8,
     )
 
-    from sglang.srt.layers.quantization.deep_gemm import (
-        gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
-    )
-
 logger = logging.getLogger(__name__)
 
 
@@ -67,7 +64,6 @@ else:
     fp8_max = torch.finfo(fp8_dtype).max
     fp8_min = -fp8_max
 
-
 if supports_custom_op():
 
     def deep_gemm_fp8_fp8_bf16_nt(
@@ -77,7 +73,7 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
 
     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,
@@ -280,6 +276,7 @@ def sglang_per_token_group_quant_fp8(
     eps: float = 1e-10,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
 ):
     assert (
         x.shape[-1] % group_size == 0
@@ -287,8 +284,21 @@
     assert x.is_contiguous(), "`x` is not contiguous"
 
     x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
-    if column_major_scales:
+    if scale_ue8m0:
+        assert column_major_scales and scale_tma_aligned
+        x_q_mn, x_q_k = x.shape
+        x_s_mn, x_s_k = x_q_mn, x_q_k // 128
+        aligned_mn = align(x_s_mn, 4)
+        aligned_k = align(x_s_k, 4)
+        # TODO(FIXME): Fix cuda kernel and recover here to empty.
+        x_s = torch.zeros(
+            (aligned_k // 4, aligned_mn),
+            device=x.device,
+            dtype=torch.int,
+        ).transpose(0, 1)[:x_s_mn, :]
+    elif column_major_scales:
         if scale_tma_aligned:
+            # TODO extract "align" function
             # aligned to 4 * sizeof(float)
             aligned_size = (x.shape[-2] + 3) // 4 * 4
             x_s = torch.empty(
@@ -309,7 +319,9 @@
             dtype=torch.float32,
         )
     if x.shape[0] > 0:
-        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+        sgl_per_token_group_quant_fp8(
+            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        )
 
     return x_q, x_s
 
@@ -754,7 +766,15 @@
     assert A.shape[-1] == B.shape[-1]
     assert A.shape[:-1] == As.shape[:-1]
     assert A.is_contiguous()
-    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+
+    if As.dtype == torch.float:
+        assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+    elif As.dtype == torch.int:
+        assert (
+            triton.cdiv(triton.cdiv(A.shape[-1], block_k), 4) == As.shape[-1]
+        ), f"{A.shape=} {As.shape=} {block_size=}"
+    else:
+        raise NotImplementedError
 
     M = A.numel() // A.shape[-1]
 
@@ -762,8 +782,17 @@
     assert B.is_contiguous()
     assert Bs.ndim == 2
     N, K = B.shape
-    assert triton.cdiv(N, block_n) == Bs.shape[0]
-    assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+    if Bs.dtype == torch.float:
+        assert triton.cdiv(N, block_n) == Bs.shape[0]
+        assert triton.cdiv(K, block_k) == Bs.shape[1]
+    elif Bs.dtype == torch.int:
+        assert N == Bs.shape[0], f"{B.shape=} {Bs.shape=} {block_size=}"
+        assert (
+            triton.cdiv(triton.cdiv(K, block_k), 4) == Bs.shape[1]
+        ), f"{B.shape=} {Bs.shape=} {block_size=}"
+    else:
+        raise NotImplementedError
 
     C_shape = A.shape[:-1] + (N,)
     C = A.new_empty(C_shape, dtype=output_dtype)
@@ -782,12 +811,12 @@ def w8a8_block_fp8_matmul_deepgemm(
     M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
 
     # Deepgemm only supports output tensor type as bfloat16
-    assert C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM
+    assert C.dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
 
     if supports_custom_op():
         torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
     else:
-        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
 
     return C
 
@@ -881,7 +910,7 @@ def w8a8_block_fp8_matmul(
     block_size: List[int],
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    if output_dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
+    if output_dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
         return w8a8_block_fp8_matmul_deepgemm(
             A, B, As, Bs, block_size, output_dtype=output_dtype
         )
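
The relaxed assertions in prepare_block_fp8_matmul_inputs accept either float32 scales (one per block_k-wide block) or int32 scales in which four UE8M0 exponents are packed per element, shrinking the scale tensor's K dimension by another factor of four. A quick arithmetic check of the two expected shapes; the K and block_k values below are examples only.

# Illustrative arithmetic only -- K and block_k values are examples.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


K, block_k = 7168, 128

float32_scale_cols = cdiv(K, block_k)                 # one float32 scale per block
packed_int32_scale_cols = cdiv(cdiv(K, block_k), 4)   # four UE8M0 exponents per int32

print(float32_scale_cols, packed_int32_scale_cols)    # 56 14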

sglang/srt/layers/quantization/fp8_utils.py
@@ -1,9 +1,10 @@
-import os
-from curses import flash
 from typing import Callable, List, Optional, Tuple
 
+import einops
 import torch
 
+from sglang.math_utils import align
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
 
@@ -14,7 +15,6 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
     fp8_dtype,
     fp8_max,
@@ -137,7 +137,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
         return cutlass_w8a8_block_fp8_linear_with_fallback
     elif _use_aiter:
         return aiter_w8a8_block_fp8_linear
-    elif _ENABLE_JIT_DEEPGEMM:
+    elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
         return deepgemm_w8a8_block_fp8_linear_with_fallback
     else:
         return triton_w8a8_block_fp8_linear
@@ -238,7 +238,14 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         block_size[1],
         column_major_scales=True,
         scale_tma_aligned=True,
+        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
     )
+
+    # NOTE(alcanderian): Useless when scale is packed to int32
+    # if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
+    #     _check_ue8m0("x_scale", x_scale)
+    #     _check_ue8m0("weight_scale", ws)
+
     output = w8a8_block_fp8_matmul_deepgemm(
         q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
     )
@@ -247,6 +254,11 @@
     return output.to(dtype=output_dtype).view(*output_shape)
 
 
+def _check_ue8m0(name, x):
+    x_ceil = ceil_to_ue8m0(x)
+    assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
+
+
 def aiter_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -369,27 +381,80 @@ def block_quant_dequant(
         The output is an unquantized tensor with dtype.
     """
     block_n, block_k = block_size[0], block_size[1]
-    n, k = x_q_block.shape
-    n_tiles = (n + block_n - 1) // block_n
-    k_tiles = (k + block_k - 1) // block_k
-    assert n_tiles == x_s.shape[0]
-    assert k_tiles == x_s.shape[1]
+    *_, n, k = x_q_block.shape
 
-    x_dq_block = torch.empty_like(x_q_block, dtype=dtype)
+    # ... n_scale k_scale -> ... (n_scale block_n) (k_scale block_k)
+    x_scale_repeat = x_s.repeat_interleave(block_n, dim=-2).repeat_interleave(
+        block_k, dim=-1
+    )
+    x_scale_repeat = x_scale_repeat[..., :n, :k]
+
+    return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)
+
+
+def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
+    assert isinstance(weight, torch.nn.Parameter)
+    assert isinstance(weight_scale_inv, torch.nn.Parameter)
+    weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
+        weight, weight_scale_inv, weight_block_size
+    )
+
+
+def _requant_weight_ue8m0(
+    weight: torch.Tensor,
+    weight_scale_inv: torch.Tensor,
+    weight_block_size: List[int],
+):
+    assert weight_block_size == [128, 128]
+
+    *_, n, k = weight.shape
+
+    weight_dequant = block_quant_dequant(
+        weight,
+        weight_scale_inv,
+        weight_block_size,
+        torch.bfloat16,
+    )
+
+    weight_dequant_flat = weight_dequant.view((-1, k))
+    out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)
+
+    out_w = out_w_flat.view(weight.shape)
+    out_s = out_s_flat.view(weight_scale_inv.shape)
+
+    # NOTE copy and modified from DeepGEMM
+    def _transform_scale(sf, mn: int):
+        import deep_gemm.utils.layout
+
+        sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
+        sf = deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor(sf)
+        return sf
+
+    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
+
+    return out_w, out_s
+
+
+# COPIED FROM DeepGEMM
+def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert x.dim() == 2
+    m, n = x.shape
+    x_padded = torch.zeros(
+        (align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device
+    )
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    sf = ceil_to_ue8m0(x_amax / 448.0)
+    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
+        x_view.size(0), x_view.size(2)
+    )
 
-    for j in range(n_tiles):
-        for i in range(k_tiles):
-            x_q_block_tile = x_q_block[
-                j * block_n : min((j + 1) * block_n, n),
-                i * block_k : min((i + 1) * block_k, k),
-            ]
-            x_dq_block_tile = x_dq_block[
-                j * block_n : min((j + 1) * block_n, n),
-                i * block_k : min((i + 1) * block_k, k),
-            ]
-            x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i]
 
-    return x_dq_block
+# COPIED FROM DeepGEMM
+def ceil_to_ue8m0(x: torch.Tensor):
+    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
 
 
 def channel_quant_to_tensor_quant(
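
ceil_to_ue8m0, copied from DeepGEMM, rounds each scale up to the nearest power of two so it can be stored as a bare 8-bit exponent (UE8M0: unsigned, 8 exponent bits, no mantissa), which is what the int32-packed scale layout above carries. A small, CPU-only demonstration of that rounding; the input values are examples only.

# Illustrative demo of the ceil_to_ue8m0 helper added above.
import torch


def ceil_to_ue8m0(x: torch.Tensor):
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))


scales = torch.tensor([0.3, 1.0, 1.5, 100.0])
rounded = ceil_to_ue8m0(scales)
print(rounded)  # 0.5, 1.0, 2.0, 128.0 -- every value is an exact power of two
assert torch.allclose(torch.log2(rounded), torch.log2(rounded).round())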