sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff compares the contents of two package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
@@ -0,0 +1,32 @@
+import logging
+
+from sglang.srt.utils import get_bool_env_var, get_device_sm
+
+logger = logging.getLogger(__name__)
+
+
+def _compute_enable_deep_gemm():
+    sm_version = get_device_sm()
+    if sm_version < 90:
+        return False
+
+    try:
+        import deep_gemm
+    except ImportError:
+        logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
+        return False
+
+    return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
+
+
+ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
+
+try:
+    from deep_gemm import fp8_gemm_nt
+
+    # They have not given a name to this breaking change
+    DEEPGEMM_BLACKWELL = True
+except ImportError:
+    DEEPGEMM_BLACKWELL = False
+
+DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
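
The new configurer gates JIT DeepGEMM on three conditions: an SM90+ device, a successful `import deep_gemm`, and the `SGL_ENABLE_JIT_DEEPGEMM` environment variable, which defaults to on. A minimal stand-alone sketch of that gating; the env-var parsing below is an assumption for illustration, not sglang's `get_bool_env_var`:

import os

def _enable_jit_deep_gemm(sm_version: int, deep_gemm_importable: bool) -> bool:
    # Hypothetical restatement of _compute_enable_deep_gemm above: hardware check,
    # then an import probe, then an env-var switch that defaults to enabled.
    if sm_version < 90:
        return False
    if not deep_gemm_importable:
        return False
    return os.environ.get("SGL_ENABLE_JIT_DEEPGEMM", "true").lower() in ("1", "true", "yes")

print(_enable_jit_deep_gemm(90, True))   # True unless the env var turns it off
print(_enable_jit_deep_gemm(80, True))   # False: requires Hopper (SM90) or newer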

sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
@@ -0,0 +1,110 @@
+import logging
+from contextlib import contextmanager
+from typing import Tuple
+
+import torch
+
+from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
+from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+    DEEPGEMM_BLACKWELL,
+    DEEPGEMM_SCALE_UE8M0,
+    ENABLE_JIT_DEEPGEMM,
+)
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm
+
+    if DEEPGEMM_BLACKWELL:
+        from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
+        from deep_gemm import (
+            fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+        )
+        from deep_gemm import (
+            m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+        )
+    else:
+        from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
+        from deep_gemm import get_col_major_tma_aligned_tensor
+        from deep_gemm import (
+            m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+        )
+        from deep_gemm import (
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+        )
+
+
+def grouped_gemm_nt_f8f8bf16_masked(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m: int,
+    recipe=None,
+):
+    num_groups, _, k = lhs[0].shape
+    _, n, _ = rhs[0].shape
+    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
+
+    with compile_utils.deep_gemm_execution_hook(
+        expected_m, n, k, num_groups, kernel_type
+    ):
+        _grouped_gemm_nt_f8f8bf16_masked_raw(
+            lhs,
+            rhs,
+            out,
+            masked_m,
+            expected_m,
+            **({"recipe": recipe} if DEEPGEMM_BLACKWELL else {})
+        )
+
+
+def grouped_gemm_nt_f8f8bf16_contig(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    m_indices: torch.Tensor,
+):
+    m, k = lhs[0].shape
+    num_groups, n, _ = rhs[0].shape
+    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
+
+    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
+        _grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices)
+
+
+def gemm_nt_f8f8bf16(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+):
+    m, k = lhs[0].shape
+    n, _ = rhs[0].shape
+    num_groups = 1
+    kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
+
+    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
+        _gemm_nt_f8f8bf16_raw(
+            lhs,
+            rhs,
+            out,
+        )
+
+
+def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
+    compile_utils.update_deep_gemm_config(gpu_id, server_args)
+
+
+@contextmanager
+def configure_deep_gemm_num_sms(num_sms):
+    if num_sms is None:
+        yield
+    else:
+        original_num_sms = deep_gemm.get_num_sms()
+        deep_gemm.set_num_sms(num_sms)
+        try:
+            yield
+        finally:
+            deep_gemm.set_num_sms(original_num_sms)
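
configure_deep_gemm_num_sms is a standard save/restore context manager: it overrides deep_gemm's SM budget for the duration of the block and restores the previous value even if the block raises. A self-contained sketch of the same pattern, with a toy module-level setting standing in for deep_gemm (these names are illustrative, not sglang or deep_gemm APIs):

from contextlib import contextmanager

_num_sms = 132  # hypothetical "current" value

def get_num_sms() -> int:
    return _num_sms

def set_num_sms(n: int) -> None:
    global _num_sms
    _num_sms = n

@contextmanager
def configure_num_sms(num_sms):
    if num_sms is None:
        yield  # no-op: keep whatever is currently configured
    else:
        original = get_num_sms()
        set_num_sms(num_sms)
        try:
            yield
        finally:
            set_num_sms(original)  # always restore, even on exceptions

with configure_num_sms(64):
    assert get_num_sms() == 64
assert get_num_sms() == 132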

sglang/srt/layers/quantization/fp8.py
@@ -64,9 +64,12 @@ from sglang.srt.layers.quantization.utils import (
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     get_bool_env_var,
+    is_cpu,
     is_cuda,
     is_hip,
+    is_npu,
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
@@ -74,6 +77,9 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 _is_fp8_fnuz = is_fp8_fnuz()
 
@@ -82,10 +88,11 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_hip:
     from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
    from vllm._custom_ops import scaled_fp8_quant
 
 
@@ -1045,15 +1052,15 @@ class Fp8MoEMethod:
        if _use_hip_int4:
            # TODO: add triton kernel and add check _use_aiter
            assert not no_combine, f"{no_combine=} is not supported."
-            return ck_moe_2stages(
+            return fused_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
-                QuantType.per_Token,
-                layer.w13_weight_scale1,
-                layer.w2_weight_scale1,
+                quant_type=QuantType.per_Token,
+                w1_scale=layer.w13_weight_scale1,
+                w2_scale=layer.w2_weight_scale1,
                activation=(
                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
                ),
@@ -1062,31 +1069,32 @@
        if _use_aiter:
            assert not no_combine, f"{no_combine=} is not supported."
            if self.block_quant:
-                # TODO(_use_aiter): FP8 block_quant only supports 'silu' for the time-being.
-                assert (
-                    activation == "silu"
-                ), f"_use_aiter: FP8 bloack_quant {activation=} will be supported later, unset _use_aiter"
-                return asm_moe(
+                return fused_moe(
                    x,
                    layer.w13_weight,
                    layer.w2_weight,
                    topk_weights,
                    topk_ids,
-                    layer.w13_weight_scale_inv,
-                    layer.w2_weight_scale_inv,
-                    block_shape=tuple(self.quant_config.weight_block_size),
+                    w1_scale=layer.w13_weight_scale_inv,
+                    w2_scale=layer.w2_weight_scale_inv,
+                    quant_type=QuantType.per_128x128,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                    expert_mask=None,
                )
            else:
-                return ck_moe_2stages(
+                return fused_moe(
                    x,
                    layer.w13_weight,
                    layer.w2_weight,
                    topk_weights,
                    topk_ids,
-                    QuantType.per_Token,
-                    layer.w13_weight_scale1,
-                    layer.w2_weight_scale1,
+                    quant_type=QuantType.per_Token,
+                    w1_scale=layer.w13_weight_scale1,
+                    w2_scale=layer.w2_weight_scale1,
                    activation=(
                        ActivationType.Silu
                        if activation == "silu"

sglang/srt/layers/quantization/fp8_kernel.py
@@ -23,7 +23,8 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.math_utils import align
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_device_core_count,
@@ -44,10 +45,6 @@ if _is_cuda:
        sgl_per_token_quant_fp8,
    )
 
-    from sglang.srt.layers.quantization.deep_gemm import (
-        gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
-    )
-
 logger = logging.getLogger(__name__)
 
 
@@ -67,7 +64,6 @@ else:
 fp8_max = torch.finfo(fp8_dtype).max
 fp8_min = -fp8_max
 
-
 if supports_custom_op():
 
    def deep_gemm_fp8_fp8_bf16_nt(
@@ -77,7 +73,7 @@ if supports_custom_op():
        Bs: torch.Tensor,
        C: torch.Tensor,
    ) -> None:
-        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
 
    def deep_gemm_fp8_fp8_bf16_nt_fake(
        A: torch.Tensor,
@@ -280,6 +276,7 @@ def sglang_per_token_group_quant_fp8(
    eps: float = 1e-10,
    column_major_scales: bool = False,
    scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
 ):
    assert (
        x.shape[-1] % group_size == 0
@@ -287,8 +284,21 @@
    assert x.is_contiguous(), "`x` is not contiguous"
 
    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
-    if column_major_scales:
+    if scale_ue8m0:
+        assert column_major_scales and scale_tma_aligned
+        x_q_mn, x_q_k = x.shape
+        x_s_mn, x_s_k = x_q_mn, x_q_k // 128
+        aligned_mn = align(x_s_mn, 4)
+        aligned_k = align(x_s_k, 4)
+        # TODO(FIXME): Fix cuda kernel and recover here to empty.
+        x_s = torch.zeros(
+            (aligned_k // 4, aligned_mn),
+            device=x.device,
+            dtype=torch.int,
+        ).transpose(0, 1)[:x_s_mn, :]
+    elif column_major_scales:
        if scale_tma_aligned:
+            # TODO extract "align" function
            # aligned to 4 * sizeof(float)
            aligned_size = (x.shape[-2] + 3) // 4 * 4
            x_s = torch.empty(
@@ -309,7 +319,9 @@
            dtype=torch.float32,
        )
    if x.shape[0] > 0:
-        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+        sgl_per_token_group_quant_fp8(
+            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        )
 
    return x_q, x_s
 
@@ -754,7 +766,15 @@ def prepare_block_fp8_matmul_inputs(
    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]
    assert A.is_contiguous()
-    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+
+    if As.dtype == torch.float:
+        assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+    elif As.dtype == torch.int:
+        assert (
+            triton.cdiv(triton.cdiv(A.shape[-1], block_k), 4) == As.shape[-1]
+        ), f"{A.shape=} {As.shape=} {block_size=}"
+    else:
+        raise NotImplementedError
 
    M = A.numel() // A.shape[-1]
 
@@ -762,8 +782,17 @@
    assert B.is_contiguous()
    assert Bs.ndim == 2
    N, K = B.shape
-    assert triton.cdiv(N, block_n) == Bs.shape[0]
-    assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+    if Bs.dtype == torch.float:
+        assert triton.cdiv(N, block_n) == Bs.shape[0]
+        assert triton.cdiv(K, block_k) == Bs.shape[1]
+    elif Bs.dtype == torch.int:
+        assert N == Bs.shape[0], f"{B.shape=} {Bs.shape=} {block_size=}"
+        assert (
+            triton.cdiv(triton.cdiv(K, block_k), 4) == Bs.shape[1]
+        ), f"{B.shape=} {Bs.shape=} {block_size=}"
+    else:
+        raise NotImplementedError
 
    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)
@@ -782,12 +811,12 @@ def w8a8_block_fp8_matmul_deepgemm(
    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
 
    # Deepgemm only supports output tensor type as bfloat16
-    assert C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM
+    assert C.dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
 
    if supports_custom_op():
        torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
    else:
-        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
 
    return C
 
@@ -881,7 +910,7 @@ def w8a8_block_fp8_matmul(
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    if output_dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
+    if output_dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
        return w8a8_block_fp8_matmul_deepgemm(
            A, B, As, Bs, block_size, output_dtype=output_dtype
        )
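
sglang_per_token_group_quant_fp8 produces one scale per group_size-sized group of consecutive channels (128 in the UE8M0 path above); with scale_ue8m0=True the scales land in an int32 buffer sized aligned_k // 4, consistent with the cdiv(..., 4) checks in prepare_block_fp8_matmul_inputs (four 8-bit UE8M0 exponents per word). A plain-PyTorch sketch of the unpacked float-scale variant, as a reference for the math only; the real implementations are the Triton/CUDA kernels in this file:

import torch

def per_token_group_quant_ref(x: torch.Tensor, group_size: int = 128, eps: float = 1e-10):
    # One float32 scale per (token, group of `group_size` channels).
    assert x.shape[-1] % group_size == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    groups = x.view(*x.shape[:-1], -1, group_size)
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    scale = amax / fp8_max
    x_q = (groups / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q.view_as(x), scale.squeeze(-1)

x = torch.randn(4, 256)
x_q, x_s = per_token_group_quant_ref(x)
print(x_q.shape, x_s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])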

sglang/srt/layers/quantization/fp8_utils.py
@@ -1,9 +1,10 @@
-import os
-from curses import flash
 from typing import Callable, List, Optional, Tuple
 
+import einops
 import torch
 
+from sglang.math_utils import align
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
 
@@ -14,7 +15,6 @@ try:
 except ImportError:
    VLLM_AVAILABLE = False
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
    fp8_dtype,
    fp8_max,
@@ -137,7 +137,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
        return cutlass_w8a8_block_fp8_linear_with_fallback
    elif _use_aiter:
        return aiter_w8a8_block_fp8_linear
-    elif _ENABLE_JIT_DEEPGEMM:
+    elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
        return deepgemm_w8a8_block_fp8_linear_with_fallback
    else:
        return triton_w8a8_block_fp8_linear
@@ -238,7 +238,14 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
        block_size[1],
        column_major_scales=True,
        scale_tma_aligned=True,
+        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
    )
+
+    # NOTE(alcanderian): Useless when scale is packed to int32
+    # if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
+    #     _check_ue8m0("x_scale", x_scale)
+    #     _check_ue8m0("weight_scale", ws)
+
    output = w8a8_block_fp8_matmul_deepgemm(
        q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
    )
@@ -247,6 +254,11 @@
    return output.to(dtype=output_dtype).view(*output_shape)
 
 
+def _check_ue8m0(name, x):
+    x_ceil = ceil_to_ue8m0(x)
+    assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
+
+
 def aiter_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
@@ -369,27 +381,80 @@ def block_quant_dequant(
    The output is an unquantized tensor with dtype.
    """
    block_n, block_k = block_size[0], block_size[1]
-    n, k = x_q_block.shape
-    n_tiles = (n + block_n - 1) // block_n
-    k_tiles = (k + block_k - 1) // block_k
-    assert n_tiles == x_s.shape[0]
-    assert k_tiles == x_s.shape[1]
+    *_, n, k = x_q_block.shape
 
-    x_dq_block = torch.empty_like(x_q_block, dtype=dtype)
+    # ... n_scale k_scale -> ... (n_scale block_n) (k_scale block_k)
+    x_scale_repeat = x_s.repeat_interleave(block_n, dim=-2).repeat_interleave(
+        block_k, dim=-1
+    )
+    x_scale_repeat = x_scale_repeat[..., :n, :k]
+
+    return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)
+
+
+def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
+    assert isinstance(weight, torch.nn.Parameter)
+    assert isinstance(weight_scale_inv, torch.nn.Parameter)
+    weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
+        weight, weight_scale_inv, weight_block_size
+    )
+
+
+def _requant_weight_ue8m0(
+    weight: torch.Tensor,
+    weight_scale_inv: torch.Tensor,
+    weight_block_size: List[int],
+):
+    assert weight_block_size == [128, 128]
+
+    *_, n, k = weight.shape
+
+    weight_dequant = block_quant_dequant(
+        weight,
+        weight_scale_inv,
+        weight_block_size,
+        torch.bfloat16,
+    )
+
+    weight_dequant_flat = weight_dequant.view((-1, k))
+    out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)
+
+    out_w = out_w_flat.view(weight.shape)
+    out_s = out_s_flat.view(weight_scale_inv.shape)
+
+    # NOTE copy and modified from DeepGEMM
+    def _transform_scale(sf, mn: int):
+        import deep_gemm.utils.layout
+
+        sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
+        sf = deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor(sf)
+        return sf
+
+    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
+
+    return out_w, out_s
+
+
+# COPIED FROM DeepGEMM
+def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert x.dim() == 2
+    m, n = x.shape
+    x_padded = torch.zeros(
+        (align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device
+    )
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    sf = ceil_to_ue8m0(x_amax / 448.0)
+    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
+        x_view.size(0), x_view.size(2)
+    )
 
-    for j in range(n_tiles):
-        for i in range(k_tiles):
-            x_q_block_tile = x_q_block[
-                j * block_n : min((j + 1) * block_n, n),
-                i * block_k : min((i + 1) * block_k, k),
-            ]
-            x_dq_block_tile = x_dq_block[
-                j * block_n : min((j + 1) * block_n, n),
-                i * block_k : min((i + 1) * block_k, k),
-            ]
-            x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i]
 
-    return x_dq_block
+# COPIED FROM DeepGEMM
+def ceil_to_ue8m0(x: torch.Tensor):
+    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
 
 
 def channel_quant_to_tensor_quant(
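
Two small, runnable illustrations of the helpers above (the examples are mine, not part of the diff): ceil_to_ue8m0 rounds a scale up to the nearest power of two, i.e. an exponent-only UE8M0 value, and the rewritten block_quant_dequant broadcasts one scale per (block_n, block_k) tile via repeat_interleave instead of the old Python tile loop:

import torch

def ceil_to_ue8m0(x: torch.Tensor):
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))

print(ceil_to_ue8m0(torch.tensor([0.3, 1.0, 1.5])))  # tensor([0.5000, 1.0000, 2.0000])

# Tile-wise dequantization: each scale covers one (block_n, block_k) tile.
x_s = torch.tensor([[2.0, 4.0]])          # scales for a 1 x 2 grid of tiles
x_q = torch.ones(2, 4)                    # quantized block, block_n=2, block_k=2
x_scale_repeat = x_s.repeat_interleave(2, dim=-2).repeat_interleave(2, dim=-1)
print(x_q * x_scale_repeat)               # left 2x2 tile scaled by 2, right by 4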

sglang/srt/layers/quantization/modelopt_quant.py
@@ -29,11 +29,17 @@ from sglang.srt.layers.quantization.utils import (
    requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, next_power_of_2
 
 if is_cuda():
    from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
 
+try:
+    from flashinfer import fp4_quantize as fp4_quantize
+    from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
+except ImportError:
+    flashinfer_cutlass_fused_moe = None
+
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
 
@@ -429,6 +435,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
        layer.alpha = Parameter(
            layer.input_scale * layer.weight_scale_2, requires_grad=False
        )
+        layer.input_scale_inv = Parameter(
+            (1 / input_scale_2).to(torch.float32), requires_grad=False
+        )
 
        # Pad and blockwise interleave weight_scale
        scales = layer.weight_scale
@@ -467,7 +476,7 @@
        output_shape = [x_m, w_n]
 
        # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
-        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale)
+        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv)
 
        assert x_fp4.dtype == torch.uint8
        assert x_scale_interleaved.dtype == torch.float8_e4m3fn
@@ -521,6 +530,7 @@ class ModelOptNvFp4FusedMoEMethod:
                " quantization. Please use Blackwell and"
                " above."
            )
+        self.enable_flashinfer_moe = False
 
    def create_weights(
        self,
@@ -674,7 +684,10 @@
        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
 
-        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+        if self.enable_flashinfer_moe:
+            w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
+        else:
+            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
        layer.g1_alphas = Parameter(
            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
            requires_grad=False,
@@ -700,14 +713,19 @@
        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
 
        # GEMM 2
+        if self.enable_flashinfer_moe:
+            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        else:
+            w2_input_scale = layer.w2_input_scale
+
        layer.g2_alphas = Parameter(
-            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )
 
        # This is for quantization, so we need to invert it.
        layer.w2_input_scale_quant = Parameter(
-            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
        )
 
        assert (
@@ -727,11 +745,16 @@
        layer.cutlass_moe_params = CutlassMoEParams(
            CutlassMoEType.BlockscaledFP4,
            device,
-            num_experts=layer.num_experts,
+            num_experts=layer.num_experts,  # global num experts
            intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
            hidden_size=layer.w13_weight.shape[2] * 2,
        )  # k
 
+    @property
+    def load_up_proj_weight_first(self) -> bool:
+        # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
+        return self.enable_flashinfer_moe
+
    def apply(
        self,
        layer: torch.nn.Module,
@@ -750,11 +773,13 @@
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
+        ep_rank: Optional[int] = None,
+        ep_size: Optional[int] = None,
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
    ) -> torch.Tensor:
 
        assert activation == "silu", "Only SiLU activation is supported."
-
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
        from sglang.srt.layers.moe.topk import select_experts
 
        topk_weights, topk_ids = select_experts(
@@ -771,6 +796,35 @@
            routed_scaling_factor=routed_scaling_factor,
        )
 
+        if self.enable_flashinfer_moe:
+            assert (
+                not apply_router_weight_on_input
+            ), "apply_router_weight_on_input is not supported for Flashinfer"
+            # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
+            # and fp4 quantized weights loaded from the checkpoint
+            output = flashinfer_cutlass_fused_moe(
+                x,
+                topk_ids.to(torch.int),
+                topk_weights,
+                layer.w13_weight.view(torch.long),
+                layer.w2_weight.view(torch.long),
+                x.dtype,
+                quant_scales=[
+                    layer.w13_input_scale_quant,
+                    layer.w13_blockscale_swizzled.view(torch.int32),
+                    layer.g1_alphas,
+                    layer.w2_input_scale_quant,
+                    layer.w2_blockscale_swizzled.view(torch.int32),
+                    layer.g2_alphas,
+                ],
+                ep_size=ep_size,
+                ep_rank=ep_rank,
+                tp_size=tp_size,
+                tp_rank=tp_rank,
+                tune_max_num_tokens=next_power_of_2(x.shape[0]),
+            )
+            return output[0]
+
        from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
        return cutlass_moe_fp4(
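
The FlashInfer path above sizes tune_max_num_tokens with next_power_of_2; a short sketch of that rounding, assuming the helper behaves like the usual bit-length trick for positive n:

def next_power_of_2(n: int) -> int:
    # Assumed behavior: smallest power of two >= n (for n >= 1).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

print([next_power_of_2(n) for n in (1, 3, 8, 100)])  # [1, 4, 8, 128]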