sglang 0.4.5.post3__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -9
- sglang/compile_deep_gemm.py +45 -4
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +9 -3
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +59 -11
- sglang/srt/disaggregation/mini_lb.py +45 -8
- sglang/srt/disaggregation/mooncake/conn.py +198 -31
- sglang/srt/disaggregation/prefill.py +24 -9
- sglang/srt/entrypoints/http_server.py +8 -2
- sglang/srt/function_call_parser.py +77 -5
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +28 -10
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/layernorm.py +38 -16
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/deep_gemm.py +17 -10
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +170 -126
- sglang/srt/managers/data_parallel_controller.py +10 -3
- sglang/srt/managers/io_struct.py +7 -0
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +29 -12
- sglang/srt/managers/scheduler.py +31 -20
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +87 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -3
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +11 -24
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +144 -70
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpmo.py +5 -1
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +18 -8
- sglang/srt/server_args.py +15 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +2 -1
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +36 -18
- sglang/version.py +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/METADATA +4 -5
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/RECORD +70 -68
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
CHANGED
```diff
@@ -1,102 +1,102 @@
 {
     "1": {
-        "BLOCK_SIZE_M":
+        "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 64,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 1,
         "num_warps": 4,
         "num_stages": 4
     },
     "2": {
-        "BLOCK_SIZE_M":
-        "BLOCK_SIZE_N":
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 16,
         "num_warps": 4,
-        "num_stages":
+        "num_stages": 4
     },
     "4": {
-        "BLOCK_SIZE_M":
-        "BLOCK_SIZE_N":
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 16,
         "num_warps": 4,
         "num_stages": 4
     },
     "8": {
-        "BLOCK_SIZE_M":
+        "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 32,
         "num_warps": 4,
-        "num_stages":
+        "num_stages": 4
     },
     "16": {
-        "BLOCK_SIZE_M":
+        "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 1,
         "num_warps": 4,
         "num_stages": 3
     },
     "24": {
-        "BLOCK_SIZE_M":
+        "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 1,
         "num_warps": 4,
-        "num_stages":
+        "num_stages": 4
     },
     "32": {
-        "BLOCK_SIZE_M":
+        "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 16,
         "num_warps": 4,
-        "num_stages":
+        "num_stages": 5
     },
     "48": {
-        "BLOCK_SIZE_M":
+        "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 64,
         "num_warps": 4,
-        "num_stages":
+        "num_stages": 4
     },
     "64": {
-        "BLOCK_SIZE_M":
+        "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 32,
         "num_warps": 4,
         "num_stages": 3
     },
     "96": {
-        "BLOCK_SIZE_M":
+        "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 32,
         "num_warps": 4,
         "num_stages": 3
     },
     "128": {
-        "BLOCK_SIZE_M":
+        "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 64,
         "num_warps": 4,
         "num_stages": 3
     },
     "256": {
-        "BLOCK_SIZE_M":
+        "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 64,
         "num_warps": 4,
         "num_stages": 3
     },
     "512": {
-        "BLOCK_SIZE_M":
+        "BLOCK_SIZE_M": 16,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 16,
@@ -107,9 +107,9 @@
         "BLOCK_SIZE_M": 64,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 16,
         "num_warps": 4,
-        "num_stages":
+        "num_stages": 4
     },
     "1536": {
         "BLOCK_SIZE_M": 64,
@@ -117,21 +117,21 @@
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 32,
         "num_warps": 4,
-        "num_stages":
+        "num_stages": 4
     },
     "2048": {
         "BLOCK_SIZE_M": 64,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 32,
         "num_warps": 4,
-        "num_stages":
+        "num_stages": 4
     },
     "3072": {
-        "BLOCK_SIZE_M":
-        "BLOCK_SIZE_N":
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 16,
         "num_warps": 4,
         "num_stages": 3
     },
@@ -139,8 +139,8 @@
         "BLOCK_SIZE_M": 64,
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M":
+        "GROUP_SIZE_M": 16,
         "num_warps": 4,
-        "num_stages":
+        "num_stages": 4
     }
 }
```
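These config files map a benchmarked token count M (the JSON keys) to Triton launch parameters for the fused MoE kernel. As a rough sketch of how such a table is typically consumed (`load_moe_config` and `pick_config` are illustrative names, not sglang's API), lookup snaps to the nearest benchmarked M:

```python
import json

def load_moe_config(path: str) -> dict[int, dict]:
    # JSON keys are strings ("1", "2", ...); convert to ints, as the diff's
    # get_moe_configs does with {int(key): val for key, val in ...}.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}

def pick_config(configs: dict[int, dict], m: int) -> dict:
    # Exact hit if M was benchmarked, otherwise the nearest benchmarked M
    # (e.g. m=20 would fall back to the "24" entry above).
    if m in configs:
        return configs[m]
    return configs[min(configs, key=lambda k: abs(k - m))]
```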
sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
ADDED
```diff
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
```
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
```diff
@@ -13,7 +13,16 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.moe.topk import select_experts
-from sglang.srt.layers.quantization.fp8_kernel import
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8,
+    scaled_fp8_quant,
+    sglang_per_token_group_quant_fp8,
+)
+from sglang.srt.layers.quantization.int8_kernel import (
+    per_token_group_quant_int8,
+    per_token_quant_int8,
+    sglang_per_token_group_quant_int8,
+)
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -746,18 +755,6 @@ def invoke_fused_moe_kernel(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
 ) -> None:
-    from sglang.srt.layers.quantization.int8_kernel import (
-        per_token_group_quant_int8,
-        per_token_quant_int8,
-    )
-
-    if _is_cuda:
-        from sglang.srt.layers.quantization.fp8_kernel import (
-            sglang_per_token_group_quant_fp8,
-        )
-    else:
-        from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
@@ -794,7 +791,10 @@ def invoke_fused_moe_kernel(
         # activation block-wise int8 quantization
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
-
+        if _is_cuda:
+            A, A_scale = sglang_per_token_group_quant_int8(A, block_k)
+        else:
+            A, A_scale = per_token_group_quant_int8(A, block_k)
         assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
         assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
         assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
@@ -940,7 +940,10 @@ def get_moe_configs(
     )
     if os.path.exists(config_file_path):
         with open(config_file_path) as f:
-            logger.info(
+            logger.info(
+                "Using configuration from %s for MoE layer. Please note that due to the large number of configs under fused_moe_triton/configs potentially not being tuned with the corresponding Triton version in your current environment, using the current configs may result in performance degradation. To achieve best performance, you can consider re-tuning the Triton fused MOE kernel in your current environment. For the tuning method, please refer to: https://github.com/sgl-project/sglang/blob/main/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py. ",
+                config_file_path,
+            )
             # If a configuration has been found, return it
             return {int(key): val for key, val in json.load(f).items()}
 
```
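The new quantization step pairs every `block_k`-wide group of activations with one float32 scale, which is exactly what the `triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]` assert checks. A pure-PyTorch reference of that contract (a sketch of the semantics only, not the shipped Triton/CUDA kernel):

```python
import torch

def per_token_group_quant_int8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    # One symmetric int8 scale per contiguous group of `group_size` values
    # along the last dimension; returns (quantized x, float32 scales).
    assert x.shape[-1] % group_size == 0 and x.is_contiguous()
    g = x.view(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    scale = (g.abs().amax(dim=-1, keepdim=True).float() / 127.0).clamp(min=eps)
    q = (g.float() / scale).round().clamp(-128, 127).to(torch.int8)
    return q.view_as(x), scale.squeeze(-1)  # scales: (..., hidden // group_size)
```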
sglang/srt/layers/pooler.py
CHANGED
```diff
@@ -12,6 +12,7 @@ from sglang.srt.model_executor.model_runner import ForwardBatch
 
 class PoolingType(IntEnum):
     LAST = 0
+    CLS = 1
 
 
 @dataclass
@@ -41,6 +42,11 @@ class Pooler(nn.Module):
         if self.pooling_type == PoolingType.LAST:
            last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
            pooled_data = hidden_states[last_token_indices]
+        elif self.pooling_type == PoolingType.CLS:
+            prompt_lens = forward_batch.extend_seq_lens
+            first_token_flat_indices = torch.zeros_like(prompt_lens)
+            first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
+            pooled_data = hidden_states[first_token_flat_indices]
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
 
```
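The CLS branch gathers the first token of each packed sequence; a standalone worked example of the index arithmetic (made-up lengths):

```python
import torch

# For packed sequences of lengths [3, 2, 4], the first-token (CLS) positions
# in the flattened hidden_states are [0, 3, 5].
prompt_lens = torch.tensor([3, 2, 4])
first = torch.zeros_like(prompt_lens)
first[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
print(first.tolist())  # [0, 3, 5]
```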
sglang/srt/layers/quantization/awq.py
CHANGED
```diff
@@ -3,7 +3,6 @@ import logging
 from typing import Any, Dict, List, Optional
 
 import torch
-from sgl_kernel import awq_dequantize
 
 from sglang.srt.layers.linear import (
     LinearBase,
@@ -12,6 +11,11 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import awq_dequantize
 
 logger = logging.getLogger(__name__)
 
```
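The guard matches the one added to int8_kernel.py below: `sgl_kernel` is a CUDA extension, so hoisting its import behind `is_cuda()` keeps `awq.py` importable on non-CUDA builds, where the previous top-level import would presumably raise at import time.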
|
@@ -25,7 +25,7 @@ if is_cuda():
|
|
25
25
|
|
26
26
|
sm_version = get_device_sm()
|
27
27
|
if sm_version == 90:
|
28
|
-
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="
|
28
|
+
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
|
29
29
|
_ENABLE_JIT_DEEPGEMM = True
|
30
30
|
|
31
31
|
logger = logging.getLogger(__name__)
|
@@ -34,9 +34,10 @@ _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
|
|
34
34
|
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
|
35
35
|
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
|
36
36
|
)
|
37
|
-
|
37
|
+
_DO_COMPILE_ALL = True
|
38
|
+
_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
|
38
39
|
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
|
39
|
-
|
40
|
+
_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
|
40
41
|
|
41
42
|
# Force redirect deep_gemm cache_dir
|
42
43
|
os.environ["DG_CACHE_DIR"] = os.getenv(
|
@@ -46,7 +47,8 @@ os.environ["DG_CACHE_DIR"] = os.getenv(
|
|
46
47
|
|
47
48
|
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
|
48
49
|
global _BUILTIN_M_LIST
|
49
|
-
global
|
50
|
+
global _DO_COMPILE_ALL
|
51
|
+
global _IS_FIRST_RANK_ON_NODE
|
50
52
|
|
51
53
|
# Generate m_max
|
52
54
|
m_max = 1024 * 16
|
@@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
|
|
57
59
|
m_max = min(1024 * 128, m_max)
|
58
60
|
_BUILTIN_M_LIST = list(range(1, m_max + 1))
|
59
61
|
|
60
|
-
|
61
|
-
|
62
|
+
_IS_FIRST_RANK_ON_NODE = ServerArgs.base_gpu_id == gpu_id
|
63
|
+
|
64
|
+
# Check if is the first rank on node.
|
65
|
+
# Default each rank will try compile all Ms to
|
66
|
+
# load all symbols at the launch stages.
|
67
|
+
# Avoid loading symbols at the serving stages.
|
68
|
+
_DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE
|
62
69
|
|
63
70
|
|
64
71
|
class DeepGemmKernelType(IntEnum):
|
@@ -89,7 +96,7 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic
|
|
89
96
|
|
90
97
|
|
91
98
|
def _compile_warning_1():
|
92
|
-
if not
|
99
|
+
if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
|
93
100
|
logger.warning(
|
94
101
|
"Entering DeepGEMM JIT Pre-Complie session. "
|
95
102
|
"And it may takes a long time(Typically 10-20 mins) "
|
@@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all(
|
|
276
283
|
query_key = (kernel_type, n, k, num_groups)
|
277
284
|
if (
|
278
285
|
_ENABLE_JIT_DEEPGEMM_PRECOMPILE
|
279
|
-
and
|
286
|
+
and _DO_COMPILE_ALL
|
280
287
|
and _INITIALIZATION_DICT.get(query_key) is None
|
281
288
|
):
|
282
289
|
_INITIALIZATION_DICT[query_key] = True
|
@@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all(
|
|
286
293
|
logger.info(
|
287
294
|
f"Try DeepGEMM JIT Compiling for "
|
288
295
|
f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
|
289
|
-
f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not
|
296
|
+
f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
|
290
297
|
)
|
291
298
|
|
292
299
|
# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
|
@@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16(
|
|
355
362
|
|
356
363
|
@contextmanager
|
357
364
|
def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
|
358
|
-
if
|
365
|
+
if _IN_PRECOMPILE_STAGE:
|
359
366
|
yield
|
360
367
|
return
|
361
368
|
|
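Taken together, the new environment flags decide which rank JIT-compiles the full set of kernels. A condensed sketch of the gating (the local `get_bool_env_var` below is a stand-in that mimics sglang's helper and is an assumption):

```python
import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Stand-in for sglang.srt.utils.get_bool_env_var.
    return os.getenv(name, default).strip().lower() in ("true", "1")

# New in 0.4.6: JIT DeepGEMM defaults on for SM90 GPUs.
enable_jit = get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")

# Precompile coordination: during a `sglang.compile_deep_gemm` run, only the
# first rank on each node compiles the full M list; other ranks load symbols.
in_precompile = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
first_rank = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
do_compile_all = first_rank or not in_precompile
```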
sglang/srt/layers/quantization/int8_kernel.py
CHANGED
```diff
@@ -8,7 +8,11 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import get_device_name
+from sglang.srt.utils import get_device_name, is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import sgl_per_token_group_quant_int8
 
 logger = logging.getLogger(__name__)
 
@@ -165,6 +169,33 @@ def per_token_group_quant_int8(
     return x_q, x_s
 
 
+def sglang_per_token_group_quant_int8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = torch.int8,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    iinfo = torch.iinfo(dtype)
+    int8_max = iinfo.max
+    int8_min = iinfo.min
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_s = torch.empty(
+        x.shape[:-1] + (x.shape[-1] // group_size,),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+
+    return x_q, x_s
+
+
 @triton.jit
 def _w8a8_block_int8_matmul(
     # Pointers to inputs and output
```
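A usage sketch for the new wrapper (CUDA only, since the `sgl_kernel` import above is gated on it; bf16 input support is assumed):

```python
import torch
from sglang.srt.layers.quantization.int8_kernel import (
    sglang_per_token_group_quant_int8,
)

# Quantize a (num_tokens, hidden) activation in groups of 128 along hidden.
x = torch.randn(8, 512, device="cuda", dtype=torch.bfloat16).contiguous()
x_q, x_s = sglang_per_token_group_quant_int8(x, group_size=128)
assert x_q.shape == (8, 512) and x_q.dtype == torch.int8
assert x_s.shape == (8, 4) and x_s.dtype == torch.float32  # one scale per group
```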
sglang/srt/layers/radix_attention.py
CHANGED
```diff
@@ -87,13 +87,23 @@ class RadixAttention(nn.Module):
         v,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
-
-
+            if "k_rope" not in kwargs:
+                k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+                v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+            else:
+                k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
 
         return forward_batch.attn_backend.forward(
-            q,
+            q,
+            k,
+            v,
+            self,
+            forward_batch,
+            save_kv_cache,
+            **kwargs,
         )
```