sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+ {
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "256": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "512": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ }
+ }
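This new JSON file is one of the two fused-MoE Triton tuning tables listed above (files 53-54): each top-level key is a candidate batch size M, and the value holds the Triton tile sizes, group size, warp count, and pipeline stages tuned for that M. As a rough sketch of how such a table is consumed (the nearest-key selection below is a simplified illustration, not the exact sglang lookup), the runtime parses the string keys back to integers and picks the entry closest to the actual M:

    import json

    def load_moe_config(path: str) -> dict[int, dict]:
        # JSON stores the batch-size keys as strings; convert them back to ints,
        # mirroring the `{int(key): val ...}` expression in get_moe_configs below.
        with open(path) as f:
            return {int(key): val for key, val in json.load(f).items()}

    def pick_config(configs: dict[int, dict], m: int) -> dict:
        # Illustrative nearest-key lookup: use the tuned entry whose batch size
        # is closest to the runtime M.
        return configs[min(configs, key=lambda k: abs(k - m))]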
@@ -29,6 +29,7 @@ from sglang.srt.utils import (
  get_device_name,
  is_cuda,
  is_hip,
+ log_info_on_rank0,
  )

  _is_hip = is_hip()
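The newly imported log_info_on_rank0 is used in the next hunk so the "Using MoE kernel config" message is printed once rather than once per tensor-parallel rank. The real helper lives in sglang/srt/utils.py and is not part of this diff; a minimal sketch of the idea, assuming torch.distributed supplies the rank, would be:

    import logging

    import torch.distributed as dist

    def log_info_on_rank0(logger: logging.Logger, msg: str) -> None:
        # Log only on rank 0 (or when no process group is initialized),
        # so multi-GPU launches do not repeat the same message.
        if not dist.is_initialized() or dist.get_rank() == 0:
            logger.info(msg)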
@@ -945,7 +946,9 @@ def get_moe_configs(
  # For example, updating the Triton version might cause all old configs to become suboptimal.
  # To achieve the best performance, consider re-tuning the Triton fused MOE kernel in your environment.
  # For the tuning method, refer to: https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
- logger.info("Using MoE kernel config from %s.", config_file_path)
+ log_info_on_rank0(
+ logger, f"Using MoE kernel config from {config_file_path}."
+ )
  # If a configuration has been found, return it
  return {int(key): val for key, val in json.load(f).items()}
@@ -991,7 +994,7 @@ def get_default_config(
  "num_stages": 2 if _is_hip else 4,
  }
  else:
- # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
+ # Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
  config = {
  "BLOCK_SIZE_M": 64,
  "BLOCK_SIZE_N": block_shape[0],
@@ -270,7 +270,7 @@ def select_experts(
  routed_scaling_factor: Optional[float] = None,
  ):
  n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
- # DeekSeek V2/V3/R1 serices models uses grouped_top_k
+ # DeepSeek V2/V3/R1 series models use grouped_top_k
  if use_grouped_topk:
  assert topk_group is not None
  assert num_expert_group is not None
@@ -109,7 +109,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
  if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
  raise ValueError(
  f"{quantization} quantization requires some operators from vllm. "
- "Pleaes install vllm by `pip install vllm==0.8.4`"
+ "Please install vllm by `pip install vllm==0.8.4`"
  )

  return QUANTIZATION_METHODS[quantization]
@@ -152,7 +152,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
  f"{input_size_per_partition} is not divisible by "
  f"weight quantization block_k = {block_k}."
  )
- # Required by collum parallel or enabling merged weights
+ # Required by column parallel or enabling merged weights
  if (tp_size > 1 and output_size // output_size_per_partition == tp_size) or len(
  output_partition_sizes
  ) > 1:
@@ -285,7 +285,7 @@ class BlockInt8MoEMethod:
  self.quant_config.weight_block_size[1],
  )
  # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
- # Required by collum parallel or enabling merged weights
+ # Required by column parallel or enabling merged weights
  if intermediate_size % block_n != 0:
  raise ValueError(
  f"The output_size of gate's and up's weight = "
@@ -10,16 +10,14 @@ import torch
  from compressed_tensors import CompressionFormat
  from compressed_tensors.quantization import QuantizationStrategy

- from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
  from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
  from sglang.srt.layers.quantization.utils import (
  all_close_1d,
- is_cuda,
- is_fp8_fnuz,
  per_tensor_dequantize,
  replace_parameter,
  )
- from sglang.srt.utils import set_weight_attrs
+ from sglang.srt.utils import is_cuda, set_weight_attrs

  _is_cuda = is_cuda()
@@ -15,11 +15,12 @@ from sglang.srt.layers.parameter import (
  from sglang.srt.layers.quantization.compressed_tensors.schemes import (
  CompressedTensorsScheme,
  )
+ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
  from sglang.srt.layers.quantization.fp8_utils import (
  apply_fp8_linear,
  normalize_e4m3fn_to_e4m3fnuz,
  )
- from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
+ from sglang.srt.layers.quantization.utils import requantize_with_max_scale

  __all__ = ["CompressedTensorsW8A8Fp8"]
@@ -15,12 +15,9 @@ _ENABLE_JIT_DEEPGEMM = False
  if is_cuda():
  import deep_gemm
  from deep_gemm import get_num_sms
+ from deep_gemm.jit.compiler import get_nvcc_compiler
  from deep_gemm.jit_kernels.gemm import get_best_configs
- from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
- from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
- from deep_gemm.jit_kernels.m_grouped_gemm import (
- template as deep_gemm_grouped_gemm_template,
- )
+ from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
  from deep_gemm.jit_kernels.tuner import jit_tuner

  sm_version = get_device_sm()
@@ -28,6 +25,11 @@ if is_cuda():
  if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
  _ENABLE_JIT_DEEPGEMM = True

+
+ def get_enable_jit_deepgemm():
+ return _ENABLE_JIT_DEEPGEMM
+
+
  logger = logging.getLogger(__name__)

  _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
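The new get_enable_jit_deepgemm() accessor lets other modules read the flag at call time. This matters in Python because `from x import _ENABLE_JIT_DEEPGEMM` copies the value bound at import time, while calling the function always reflects the current module-level value. A standalone illustration (not sglang code):

    # flag_demo.py -- illustrative module, not part of sglang
    FLAG = False

    def get_flag() -> bool:
        return FLAG  # evaluated at call time, so later reassignment is visible

    def set_flag(value: bool) -> None:
        global FLAG
        FLAG = value

A caller holding `from flag_demo import FLAG` keeps the value it saw at import, whereas get_flag() tracks set_flag() calls; the accessor presumably gives the rest of sglang the same late-bound view of _ENABLE_JIT_DEEPGEMM.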
@@ -40,10 +42,25 @@ _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
  _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

  # Force redirect deep_gemm cache_dir
- os.environ["DG_CACHE_DIR"] = os.getenv(
- "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
+ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
+ "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
  )

+ # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
+ # NVRTC may have performance loss with some cases.
+ # And NVCC JIT speed is also 9x faster in the ref commit
+ _USE_NVRTC_DEFAULT = "0"
+ if _ENABLE_JIT_DEEPGEMM:
+ try:
+ get_nvcc_compiler()
+ except:
+ logger.warning(
+ "NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
+ "and may have performance loss with some cases."
+ )
+ _USE_NVRTC_DEFAULT = "1"
+ os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
+

  def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
  global _BUILTIN_M_LIST
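The block above switches the cache-directory variable to DG_JIT_CACHE_DIR and prefers NVCC over NVRTC whenever get_nvcc_compiler() succeeds, falling back to NVRTC (with a warning) otherwise. Both choices can be overridden through the SGL_DG_CACHE_DIR and SGL_DG_USE_NVRTC environment variables read here, provided they are set before sglang is imported; a minimal usage sketch (the cache path is hypothetical):

    import os

    # Must be in place before sglang (and therefore deep_gemm) is imported.
    os.environ["SGL_DG_CACHE_DIR"] = "/data/deep_gemm_cache"  # hypothetical location
    os.environ["SGL_DG_USE_NVRTC"] = "0"  # "1" forces NVRTC, "0" prefers NVCC

    import sglang  # noqa: E402  import after the overrides are set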
@@ -98,10 +115,10 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic
  def _compile_warning_1():
  if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
  logger.warning(
- "Entering DeepGEMM JIT Pre-Complie session. "
+ "Entering DeepGEMM JIT Pre-Compile session. "
  "And it may takes a long time(Typically 10-20 mins) "
  "if you have not run `sglang.compile_deep_gemm`. "
- "Recommand to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
+ "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
  " for pre-compilation to reduce the overhead if you have not run it before. "
  "For example: "
  "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
@@ -110,7 +127,7 @@ def _compile_warning_1():

  def _compile_warning_2():
  logger.warning(
- "Entering DeepGEMM JIT Single Kernel Complie session. "
+ "Entering DeepGEMM JIT Single Kernel Compile session. "
  "And it will makes inference throughput becomes flaky. "
  "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
  " for pre-compilation to solve this issue. "
@@ -125,10 +142,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
  num_groups: int,
  config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
  ) -> None:
- # Auto-tuning with compilation
- global deep_gemm_includes, deep_gemm_grouped_gemm_template
- _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
- _ = jit_tuner.compile_and_tune(
+ num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+ block_k = 128
+ num_tma_threads = 128
+ num_math_threads_per_group = 128
+ kwargs = {
+ "NUM_TMA_THREADS": num_tma_threads,
+ "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+ "BLOCK_K": block_k,
+ "NUM_SMS": num_sms,
+ "SMEM_SIZE": smem_config[0],
+ }
+ _, _ = jit_tuner.compile_and_tune(
  name="m_grouped_gemm_fp8_fp8_bf16_nt",
  keys={
  "N": n,
@@ -141,24 +166,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
  "NUM_STAGES": num_stages,
  "NUM_TMA_MULTICAST": tma_multicast_config[0],
  "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
- "GEMM_TYPE": "GroupedMasked",
+ "GEMM_TYPE": GemmType.GroupedMasked,
  },
  space=(),
- includes=deep_gemm_includes,
- arg_defs=(
- ("lhs", torch.float8_e4m3fn),
- ("lhs_scales", torch.float),
- ("rhs", torch.float8_e4m3fn),
- ("rhs_scales", torch.float),
- ("out", torch.bfloat16),
- ("grouped_layout", torch.int32),
- ("m", int),
- ("stream", torch.cuda.Stream),
- ("num_sms", int),
- ("smem_size", int),
- ),
- template=deep_gemm_grouped_gemm_template,
- args=[],
+ kwargs=kwargs,
+ runtime_cls=FP8GemmRuntime,
  )
@@ -168,9 +180,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
  num_groups: int,
  config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
  ) -> None:
- global deep_gemm_includes, deep_gemm_grouped_gemm_template
- _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
- _ = jit_tuner.compile_and_tune(
+ num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+ block_k = 128
+ num_tma_threads = 128
+ num_math_threads_per_group = 128
+ kwargs = {
+ "NUM_TMA_THREADS": num_tma_threads,
+ "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+ "BLOCK_K": block_k,
+ "NUM_SMS": num_sms,
+ "SMEM_SIZE": smem_config[0],
+ }
+ _, _ = jit_tuner.compile_and_tune(
  name="m_grouped_gemm_fp8_fp8_bf16_nt",
  keys={
  "N": n,
@@ -183,25 +204,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
  "NUM_STAGES": num_stages,
  "NUM_TMA_MULTICAST": tma_multicast_config[0],
  "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
- "GEMM_TYPE": "GroupedContiguous",
+ "GEMM_TYPE": GemmType.GroupedContiguous,
  },
  space=(),
- includes=deep_gemm_includes,
- arg_defs=(
- ("lhs", torch.float8_e4m3fn),
- ("lhs_scales", torch.float),
- ("rhs", torch.float8_e4m3fn),
- ("rhs_scales", torch.float),
- ("out", torch.bfloat16),
- ("grouped_layout", torch.int32),
- ("m", int),
- ("num_groups", int),
- ("stream", torch.cuda.Stream),
- ("num_sms", int),
- ("smem_size", int),
- ),
- template=deep_gemm_grouped_gemm_template,
- args=[],
+ kwargs=kwargs,
+ runtime_cls=FP8GemmRuntime,
  )
@@ -211,9 +218,20 @@ def _compile_gemm_nt_f8f8bf16_one(
  _: int, # _ is a dummy parameter to align with other interfaces
  config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
  ) -> None:
- global deep_gemm_includes, deep_gemm_gemm_template
- _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
- _ = jit_tuner.compile_and_tune(
+ num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+ block_k = 128
+ num_tma_threads = 128
+ num_math_threads_per_group = 128
+ kwargs = {
+ "GEMM_TYPE": GemmType.Normal,
+ "NUM_TMA_THREADS": num_tma_threads,
+ "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+ "NUM_GROUPS": 1,
+ "BLOCK_K": block_k,
+ "NUM_SMS": num_sms,
+ "SMEM_SIZE": smem_config[0],
+ }
+ _, _ = jit_tuner.compile_and_tune(
  name="gemm_fp8_fp8_bf16_nt",
  keys={
  "N": n,
@@ -227,20 +245,8 @@ def _compile_gemm_nt_f8f8bf16_one(
  "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
  },
  space=(),
- includes=deep_gemm_includes,
- arg_defs=(
- ("lhs", torch.float8_e4m3fn),
- ("lhs_scales", torch.float),
- ("rhs", torch.float8_e4m3fn),
- ("rhs_scales", torch.float),
- ("out", torch.bfloat16),
- ("m", int),
- ("stream", torch.cuda.Stream),
- ("num_sms", int),
- ("smem_size", int),
- ),
- template=deep_gemm_gemm_template,
- args=[],
+ kwargs=kwargs,
+ runtime_cls=FP8GemmRuntime,
  )
@@ -293,7 +299,7 @@ def _maybe_compile_deep_gemm_one_type_all(
  logger.info(
  f"Try DeepGEMM JIT Compiling for "
  f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
- f"{' It only takes a litte time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
+ f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
  )

  # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
@@ -368,7 +374,7 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):

  from deep_gemm.jit.runtime import RuntimeCache

- origin_func = RuntimeCache.__getitem__
+ origin_func = RuntimeCache.get

  def __patched_func(self, *args, **kwargs):
  ret = origin_func(self, *args, **kwargs)
@@ -380,6 +386,6 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
  )
  return ret

- RuntimeCache.__getitem__ = __patched_func
+ RuntimeCache.get = __patched_func
  yield
- RuntimeCache.__getitem__ = origin_func
+ RuntimeCache.get = origin_func
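This last hunk re-points the monkey patch from RuntimeCache.__getitem__ to RuntimeCache.get, matching the newer deep_gemm cache API. The surrounding _log_jit_build helper follows a wrap-yield-restore pattern: temporarily replace the cache accessor so a miss (and thus an impending JIT build) can be logged, then put the original back. A generic sketch of that pattern, illustrative rather than the exact sglang code:

    from contextlib import contextmanager

    @contextmanager
    def log_cache_misses(cache_cls, logger):
        # Wrap cache_cls.get so that a miss (None result) is logged before the
        # caller falls through to a JIT build; restore the original on exit.
        origin_get = cache_cls.get

        def patched_get(self, *args, **kwargs):
            ret = origin_get(self, *args, **kwargs)
            if ret is None:
                logger.warning("DeepGEMM JIT cache miss; a kernel build will follow.")
            return ret

        cache_cls.get = patched_get
        try:
            yield
        finally:
            cache_cls.get = origin_get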