sglang 0.4.5.post3__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -9
  3. sglang/compile_deep_gemm.py +45 -4
  4. sglang/srt/code_completion_parser.py +1 -1
  5. sglang/srt/configs/deepseekvl2.py +1 -1
  6. sglang/srt/configs/model_config.py +9 -3
  7. sglang/srt/constrained/llguidance_backend.py +78 -61
  8. sglang/srt/conversation.py +34 -1
  9. sglang/srt/disaggregation/decode.py +59 -11
  10. sglang/srt/disaggregation/mini_lb.py +45 -8
  11. sglang/srt/disaggregation/mooncake/conn.py +198 -31
  12. sglang/srt/disaggregation/prefill.py +24 -9
  13. sglang/srt/entrypoints/http_server.py +8 -2
  14. sglang/srt/function_call_parser.py +77 -5
  15. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  16. sglang/srt/layers/attention/flashattention_backend.py +28 -10
  17. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  18. sglang/srt/layers/attention/vision.py +2 -0
  19. sglang/srt/layers/layernorm.py +38 -16
  20. sglang/srt/layers/logits_processor.py +2 -2
  21. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  25. sglang/srt/layers/pooler.py +6 -0
  26. sglang/srt/layers/quantization/awq.py +5 -1
  27. sglang/srt/layers/quantization/deep_gemm.py +17 -10
  28. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  29. sglang/srt/layers/radix_attention.py +13 -3
  30. sglang/srt/layers/rotary_embedding.py +170 -126
  31. sglang/srt/managers/data_parallel_controller.py +10 -3
  32. sglang/srt/managers/io_struct.py +7 -0
  33. sglang/srt/managers/mm_utils.py +85 -28
  34. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  35. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  36. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  37. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  38. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  40. sglang/srt/managers/schedule_batch.py +29 -12
  41. sglang/srt/managers/scheduler.py +31 -20
  42. sglang/srt/managers/tokenizer_manager.py +5 -1
  43. sglang/srt/mem_cache/memory_pool.py +87 -0
  44. sglang/srt/model_executor/cuda_graph_runner.py +4 -3
  45. sglang/srt/model_executor/forward_batch_info.py +51 -95
  46. sglang/srt/model_executor/model_runner.py +11 -24
  47. sglang/srt/models/deepseek.py +12 -2
  48. sglang/srt/models/deepseek_nextn.py +101 -6
  49. sglang/srt/models/deepseek_v2.py +144 -70
  50. sglang/srt/models/deepseek_vl2.py +9 -4
  51. sglang/srt/models/gemma3_causal.py +1 -1
  52. sglang/srt/models/llama4.py +0 -1
  53. sglang/srt/models/minicpmo.py +5 -1
  54. sglang/srt/models/mllama4.py +2 -2
  55. sglang/srt/models/qwen2_5_vl.py +3 -6
  56. sglang/srt/models/qwen2_vl.py +3 -7
  57. sglang/srt/models/roberta.py +178 -0
  58. sglang/srt/openai_api/adapter.py +18 -8
  59. sglang/srt/server_args.py +15 -22
  60. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  61. sglang/srt/torch_memory_saver_adapter.py +10 -1
  62. sglang/srt/utils.py +2 -1
  63. sglang/test/runners.py +6 -13
  64. sglang/test/test_utils.py +36 -18
  65. sglang/version.py +1 -1
  66. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/METADATA +4 -5
  67. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/RECORD +70 -68
  68. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  69. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  70. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
@@ -1,102 +1,102 @@
  {
  "1": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 64,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 1,
  "num_warps": 4,
  "num_stages": 4
  },
  "2": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "4": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
  "num_stages": 4
  },
  "8": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
  "GROUP_SIZE_M": 32,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "16": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 1,
  "num_warps": 4,
  "num_stages": 3
  },
  "24": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 1,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "32": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 5
  },
  "48": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
+ "GROUP_SIZE_M": 64,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "64": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 64,
+ "GROUP_SIZE_M": 32,
  "num_warps": 4,
  "num_stages": 3
  },
  "96": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 64,
+ "GROUP_SIZE_M": 32,
  "num_warps": 4,
  "num_stages": 3
  },
  "128": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 64,
  "num_warps": 4,
  "num_stages": 3
  },
  "256": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 64,
  "num_warps": 4,
  "num_stages": 3
  },
  "512": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
  "GROUP_SIZE_M": 16,
@@ -107,9 +107,9 @@
  "BLOCK_SIZE_M": 64,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "1536": {
  "BLOCK_SIZE_M": 64,
@@ -117,21 +117,21 @@
  "BLOCK_SIZE_K": 128,
  "GROUP_SIZE_M": 32,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "2048": {
  "BLOCK_SIZE_M": 64,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 32,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "3072": {
- "BLOCK_SIZE_M": 128,
- "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
  "num_stages": 3
  },
@@ -139,8 +139,8 @@
  "BLOCK_SIZE_M": 64,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 64,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  }
  }
sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json (new file)
@@ -0,0 +1,146 @@
+ {
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ }
+ }
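
Note on how these files are consumed: each config maps a candidate token count M (the JSON keys) to Triton tile sizes and pipeline settings for the fused MoE kernel. The sketch below is illustrative only; load_moe_config and the nearest-M selection are assumptions about how such a lookup table is typically used, not the exact sglang code path (which lives in get_moe_configs in fused_moe.py, shown further down).

# Illustrative sketch only -- not the sglang implementation.  Shows how a tuned
# config file like the ones above can be loaded and matched to a batch size.
import json


def load_moe_config(path: str) -> dict:
    with open(path) as f:
        # JSON keys are strings ("1", "2", ..., "4096"); convert them to ints.
        return {int(m): cfg for m, cfg in json.load(f).items()}


def pick_config(configs: dict, m: int) -> dict:
    # Assumption: choose the tuned entry whose M is closest to the actual
    # token count, the usual way such lookup tables are applied.
    nearest_m = min(configs, key=lambda tuned_m: abs(tuned_m - m))
    return configs[nearest_m]


# e.g. pick_config(load_moe_config("E=272,N=128,...H20...json"), 100)
# would land on the "96" entry of the new file above.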
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -13,7 +13,16 @@ import triton
  import triton.language as tl

  from sglang.srt.layers.moe.topk import select_experts
- from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+ from sglang.srt.layers.quantization.fp8_kernel import (
+ per_token_group_quant_fp8,
+ scaled_fp8_quant,
+ sglang_per_token_group_quant_fp8,
+ )
+ from sglang.srt.layers.quantization.int8_kernel import (
+ per_token_group_quant_int8,
+ per_token_quant_int8,
+ sglang_per_token_group_quant_int8,
+ )
  from sglang.srt.utils import (
  direct_register_custom_op,
  get_bool_env_var,
@@ -746,18 +755,6 @@ def invoke_fused_moe_kernel(
  block_shape: Optional[List[int]] = None,
  no_combine: bool = False,
  ) -> None:
- from sglang.srt.layers.quantization.int8_kernel import (
- per_token_group_quant_int8,
- per_token_quant_int8,
- )
-
- if _is_cuda:
- from sglang.srt.layers.quantization.fp8_kernel import (
- sglang_per_token_group_quant_fp8,
- )
- else:
- from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-
  assert topk_weights.stride(1) == 1
  assert sorted_token_ids.stride(0) == 1

@@ -794,7 +791,10 @@ def invoke_fused_moe_kernel(
  # activation block-wise int8 quantization
  assert len(block_shape) == 2
  block_n, block_k = block_shape[0], block_shape[1]
- A, A_scale = per_token_group_quant_int8(A, block_k)
+ if _is_cuda:
+ A, A_scale = sglang_per_token_group_quant_int8(A, block_k)
+ else:
+ A, A_scale = per_token_group_quant_int8(A, block_k)
  assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
  assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
  assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
@@ -940,7 +940,10 @@ def get_moe_configs(
  )
  if os.path.exists(config_file_path):
  with open(config_file_path) as f:
- logger.info("Using configuration from %s for MoE layer.", config_file_path)
+ logger.info(
+ "Using configuration from %s for MoE layer. Please note that due to the large number of configs under fused_moe_triton/configs potentially not being tuned with the corresponding Triton version in your current environment, using the current configs may result in performance degradation. To achieve best performance, you can consider re-tuning the Triton fused MOE kernel in your current environment. For the tuning method, please refer to: https://github.com/sgl-project/sglang/blob/main/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py. ",
+ config_file_path,
+ )
  # If a configuration has been found, return it
  return {int(key): val for key, val in json.load(f).items()}
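
The hunks above hoist the quantization helpers to module-level imports and, inside invoke_fused_moe_kernel, route activation quantization through sglang_per_token_group_quant_int8 on CUDA, with the Triton per_token_group_quant_int8 as the fallback. As a rough reference for what both paths compute, here is a plain-PyTorch sketch of symmetric per-token-group int8 quantization; it illustrates the math only and is not either kernel.

import torch


def per_group_int8_reference(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    """Plain-PyTorch reference for symmetric per-group int8 quantization.

    The last dimension is split into groups of `group_size`; each group is
    scaled by its abs-max so values fit in [-127, 127].  Returns the int8
    values plus float32 per-group scales (dequantize with x_q * scale).
    """
    assert x.shape[-1] % group_size == 0
    groups = x.float().reshape(*x.shape[:-1], -1, group_size)
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / 127.0
    x_q = torch.clamp(torch.round(groups / scale), -128, 127).to(torch.int8)
    return x_q.reshape(x.shape), scale.squeeze(-1)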
 
sglang/srt/layers/pooler.py
@@ -12,6 +12,7 @@ from sglang.srt.model_executor.model_runner import ForwardBatch

  class PoolingType(IntEnum):
  LAST = 0
+ CLS = 1


  @dataclass
@@ -41,6 +42,11 @@ class Pooler(nn.Module):
  if self.pooling_type == PoolingType.LAST:
  last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
  pooled_data = hidden_states[last_token_indices]
+ elif self.pooling_type == PoolingType.CLS:
+ prompt_lens = forward_batch.extend_seq_lens
+ first_token_flat_indices = torch.zeros_like(prompt_lens)
+ first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
+ pooled_data = hidden_states[first_token_flat_indices]
  else:
  raise ValueError(f"Invalid pooling type: {self.pooling_type}")
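
To make the new CLS branch concrete: extend_seq_lens holds the per-request prompt lengths of a flattened batch, so the index of request i's first (CLS) token is the running sum of the lengths of requests 0..i-1, while the last-token index is the cumulative sum minus one. A small standalone illustration of that index arithmetic (the lengths here are made up for the example):

import torch

# Three requests flattened into one batch, with prompt lengths 4, 2 and 5.
prompt_lens = torch.tensor([4, 2, 5])

# LAST pooling: final token of each request -> indices 3, 5, 10.
last_token_indices = torch.cumsum(prompt_lens, dim=0) - 1

# CLS pooling (the new branch above): first token of each request -> indices 0, 4, 6.
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]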
 
sglang/srt/layers/quantization/awq.py
@@ -3,7 +3,6 @@ import logging
  from typing import Any, Dict, List, Optional

  import torch
- from sgl_kernel import awq_dequantize

  from sglang.srt.layers.linear import (
  LinearBase,
@@ -12,6 +11,11 @@ from sglang.srt.layers.linear import (
  )
  from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.utils import is_cuda
+
+ _is_cuda = is_cuda()
+ if _is_cuda:
+ from sgl_kernel import awq_dequantize

  logger = logging.getLogger(__name__)
 
sglang/srt/layers/quantization/deep_gemm.py
@@ -25,7 +25,7 @@ if is_cuda():

  sm_version = get_device_sm()
  if sm_version == 90:
- if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
+ if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
  _ENABLE_JIT_DEEPGEMM = True

  logger = logging.getLogger(__name__)
@@ -34,9 +34,10 @@ _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
  _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
  "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
  )
- _DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
+ _DO_COMPILE_ALL = True
+ _IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
  _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
- _IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false")
+ _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

  # Force redirect deep_gemm cache_dir
  os.environ["DG_CACHE_DIR"] = os.getenv(
@@ -46,7 +47,8 @@ os.environ["DG_CACHE_DIR"] = os.getenv(

  def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
  global _BUILTIN_M_LIST
- global _DO_COMPILE
+ global _DO_COMPILE_ALL
+ global _IS_FIRST_RANK_ON_NODE

  # Generate m_max
  m_max = 1024 * 16
@@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
  m_max = min(1024 * 128, m_max)
  _BUILTIN_M_LIST = list(range(1, m_max + 1))

- # Check if is the first rank on node
- _DO_COMPILE = ServerArgs.base_gpu_id == gpu_id
+ _IS_FIRST_RANK_ON_NODE = ServerArgs.base_gpu_id == gpu_id
+
+ # Check if is the first rank on node.
+ # Default each rank will try compile all Ms to
+ # load all symbols at the launch stages.
+ # Avoid loading symbols at the serving stages.
+ _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE


  class DeepGemmKernelType(IntEnum):
@@ -89,7 +96,7 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic


  def _compile_warning_1():
- if not _IN_PRE_COMPILE_STAGE:
+ if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
  logger.warning(
  "Entering DeepGEMM JIT Pre-Complie session. "
  "And it may takes a long time(Typically 10-20 mins) "
@@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all(
  query_key = (kernel_type, n, k, num_groups)
  if (
  _ENABLE_JIT_DEEPGEMM_PRECOMPILE
- and _DO_COMPILE
+ and _DO_COMPILE_ALL
  and _INITIALIZATION_DICT.get(query_key) is None
  ):
  _INITIALIZATION_DICT[query_key] = True
@@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all(
  logger.info(
  f"Try DeepGEMM JIT Compiling for "
  f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
- f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}"
+ f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
  )

  # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
@@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16(

  @contextmanager
  def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
- if _IN_PRE_COMPILE_STAGE:
+ if _IN_PRECOMPILE_STAGE:
  yield
  return
sglang/srt/layers/quantization/int8_kernel.py
@@ -8,7 +8,11 @@ import torch
  import triton
  import triton.language as tl

- from sglang.srt.utils import get_device_name
+ from sglang.srt.utils import get_device_name, is_cuda
+
+ _is_cuda = is_cuda()
+ if _is_cuda:
+ from sgl_kernel import sgl_per_token_group_quant_int8

  logger = logging.getLogger(__name__)

@@ -165,6 +169,33 @@ def per_token_group_quant_int8(
  return x_q, x_s


+ def sglang_per_token_group_quant_int8(
+ x: torch.Tensor,
+ group_size: int,
+ eps: float = 1e-10,
+ dtype: torch.dtype = torch.int8,
+ ):
+ assert (
+ x.shape[-1] % group_size == 0
+ ), "the last dimension of `x` cannot be divisible by `group_size`"
+ assert x.is_contiguous(), "`x` is not contiguous"
+
+ iinfo = torch.iinfo(dtype)
+ int8_max = iinfo.max
+ int8_min = iinfo.min
+
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+ x_s = torch.empty(
+ x.shape[:-1] + (x.shape[-1] // group_size,),
+ device=x.device,
+ dtype=torch.float32,
+ )
+
+ sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+
+ return x_q, x_s
+
+
  @triton.jit
  def _w8a8_block_int8_matmul(
  # Pointers to inputs and output
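
A brief usage sketch for the new wrapper. The output shapes follow from the allocations shown above; the snippet assumes a CUDA device with the sgl_kernel extension installed, since the wrapper dispatches to sgl_per_token_group_quant_int8.

import torch

from sglang.srt.layers.quantization.int8_kernel import sglang_per_token_group_quant_int8

x = torch.randn(4, 512, device="cuda", dtype=torch.bfloat16)

# x_q: int8 tensor with x's shape; x_s: float32 scales of shape
# x.shape[:-1] + (x.shape[-1] // group_size,) -> here (4, 4).
x_q, x_s = sglang_per_token_group_quant_int8(x, group_size=128)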
sglang/srt/layers/radix_attention.py
@@ -87,13 +87,23 @@ class RadixAttention(nn.Module):
  v,
  forward_batch: ForwardBatch,
  save_kv_cache: bool = True,
+ **kwargs,
  ):
  if k is not None:
  # For cross-layer sharing, kv can be None
  assert v is not None
- k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
- v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+ if "k_rope" not in kwargs:
+ k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+ v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+ else:
+ k = k.view(-1, self.tp_k_head_num, self.v_head_dim)

  return forward_batch.attn_backend.forward(
- q, k, v, self, forward_batch, save_kv_cache
+ q,
+ k,
+ v,
+ self,
+ forward_batch,
+ save_kv_cache,
+ **kwargs,
  )
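
The hunk above lets RadixAttention.forward pass arbitrary extra keyword arguments straight through to the attention backend, and skips folding the rope portion into k when a separate k_rope is supplied (presumably for MLA-style models that keep the rotary key component apart). A toy restatement of that pass-through pattern, not tied to any sglang backend API; the head counts and dimensions are made up for illustration:

# Toy sizes for illustration only.
NUM_KV_HEADS, QK_HEAD_DIM, V_HEAD_DIM = 8, 192, 128


def toy_attention_forward(q, k, v, backend_forward, save_kv_cache=True, **kwargs):
    # Extras such as `k_rope` are not interpreted here; their mere presence
    # changes how `k` is viewed, and they are forwarded to the backend as-is.
    if "k_rope" not in kwargs:
        k = k.view(-1, NUM_KV_HEADS, QK_HEAD_DIM)
        v = v.view(-1, NUM_KV_HEADS, V_HEAD_DIM)
    else:
        k = k.view(-1, NUM_KV_HEADS, V_HEAD_DIM)
    return backend_forward(q, k, v, save_kv_cache=save_kv_cache, **kwargs)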