sglang 0.4.4.post4__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/layers/attention/flashattention_backend.py +286 -9
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/fp8.py +3 -1
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +2 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +63 -0
- sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
- sglang/srt/model_executor/model_runner.py +1 -0
- sglang/srt/models/llama.py +12 -4
- sglang/srt/models/llama4.py +420 -0
- sglang/srt/models/mllama4.py +154 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/METADATA +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/RECORD +32 -22
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
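The two JSON files above are autotuned launch tables for the fused MoE Triton kernel on H100: each top-level key is a token-batch size M, and its value holds the Triton launch parameters (block sizes, `GROUP_SIZE_M`, `num_warps`, `num_stages`) selected for that size. The sketch below shows how such a table can be consumed; the helper names and the nearest-key heuristic are illustrative assumptions, not sglang's exact loader.

```python
import json

def load_moe_config(path: str) -> dict:
    # The shipped JSON keys are strings ("1", "2", ...); convert them to ints.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def pick_config(configs: dict, num_tokens: int) -> dict:
    # Illustrative heuristic: use the tuned entry whose batch size is closest
    # to the number of tokens in the current forward pass.
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]

# Hypothetical usage: 300 tokens falls closest to the "256" entry.
# cfg = pick_config(
#     load_moe_config("E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json"), 300
# )
```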
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -1079,6 +1079,7 @@ def inplace_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
@@ -1099,6 +1100,7 @@ def inplace_fused_experts(
         topk_ids,
         True,
         activation,
+        apply_router_weight_on_input,
         use_fp8_w8a8,
         use_int8_w8a8,
         use_int8_w8a16,
@@ -1120,6 +1122,7 @@ def inplace_fused_experts_fake(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
@@ -1150,6 +1153,7 @@ def outplace_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
@@ -1171,6 +1175,7 @@ def outplace_fused_experts(
         topk_ids,
         False,
         activation,
+        apply_router_weight_on_input,
         use_fp8_w8a8,
         use_int8_w8a8,
         use_int8_w8a16,
@@ -1193,6 +1198,7 @@ def outplace_fused_experts_fake(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
@@ -1225,6 +1231,7 @@ def fused_experts(
     topk_ids: torch.Tensor,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
@@ -1247,6 +1254,7 @@ def fused_experts(
             topk_weights,
             topk_ids,
             activation,
+            apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
@@ -1268,6 +1276,7 @@ def fused_experts(
             topk_weights,
             topk_ids,
             activation,
+            apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
@@ -1291,6 +1300,7 @@ def fused_experts_impl(
     topk_ids: torch.Tensor,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
@@ -1423,7 +1433,7 @@ def fused_experts_impl(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
-        False,
+        apply_router_weight_on_input,
         topk_ids.shape[1],
         config,
         compute_type=compute_type,
@@ -1456,7 +1466,7 @@ def fused_experts_impl(
         (
             intermediate_cache3
             if not no_combine and topk_ids.shape[1] != 1
-            else out_hidden_states[begin_chunk_idx:end_chunk_idx]
+            else out_hidden_states[begin_chunk_idx:end_chunk_idx].unsqueeze(0)
         ),
         a2_scale,
         w2_scale,
@@ -1466,7 +1476,7 @@ def fused_experts_impl(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
-        True,
+        not apply_router_weight_on_input,
         1,
         config,
         compute_type=compute_type,
sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -128,6 +128,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -143,6 +144,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             inplace=inplace,
             no_combine=no_combine,
         )
@@ -160,6 +162,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -200,6 +203,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             topk_ids=topk_ids,
             inplace=inplace and not no_combine,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             no_combine=no_combine,
         )
 
@@ -276,6 +280,7 @@ class FusedMoE(torch.nn.Module):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         use_presharded_weights: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
@@ -302,6 +307,7 @@ class FusedMoE(torch.nn.Module):
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
         self.activation = activation
+        self.apply_router_weight_on_input = apply_router_weight_on_input
         self.use_presharded_weights = use_presharded_weights
         self.inplace = inplace
         self.no_combine = no_combine
@@ -630,6 +636,7 @@ class FusedMoE(torch.nn.Module):
             custom_routing_function=self.custom_routing_function,
             correction_bias=self.correction_bias,
             activation=self.activation,
+            apply_router_weight_on_input=self.apply_router_weight_on_input,
         )
 
         if self.reduce_results and self.tp_size > 1:
sglang/srt/layers/quantization/__init__.py
@@ -280,6 +280,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ):
sglang/srt/layers/quantization/blockwise_int8.py
@@ -370,6 +370,7 @@ class BlockInt8MoEMethod:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -398,6 +399,7 @@ class BlockInt8MoEMethod:
             topk_ids=topk_ids,
             inplace=inplace,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_int8_w8a8=True,
             w1_scale=(layer.w13_weight_scale_inv),
             w2_scale=(layer.w2_weight_scale_inv),
sglang/srt/layers/quantization/fp8.py
@@ -860,7 +860,7 @@ class Fp8MoEMethod:
             layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
             layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
 
-    def process_weights_hip_scale_padding(self, layer: Module
+    def process_weights_hip_scale_padding(self, layer: Module):
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
             padding_size, # Avoid circular import
         )
@@ -905,6 +905,7 @@ class Fp8MoEMethod:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -975,6 +976,7 @@ class Fp8MoEMethod:
             topk_ids=topk_ids,
             inplace=inplace and not no_combine,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_fp8_w8a8=True,
             w1_scale=(
                 layer.w13_weight_scale_inv
sglang/srt/layers/quantization/moe_wna16.py
@@ -344,6 +344,7 @@ class MoeWNA16Method:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -374,6 +375,7 @@ class MoeWNA16Method:
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=inplace,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
             w1_scale=layer.w13_scales,
sglang/srt/layers/quantization/w8a8_int8.py
@@ -230,6 +230,7 @@ class W8A8Int8MoEMethod:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -257,6 +258,7 @@ class W8A8Int8MoEMethod:
             topk_ids=topk_ids,
             inplace=inplace,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_int8_w8a8=True,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
sglang/srt/layers/radix_attention.py
@@ -35,6 +35,7 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         prefix: str = "",
+        use_irope: bool = False,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
@@ -50,6 +51,7 @@ class RadixAttention(nn.Module):
         self.is_cross_attention = is_cross_attention
         self.k_scale = None
         self.v_scale = None
+        self.use_irope = use_irope
 
     def forward(
         self,
sglang/srt/layers/rotary_embedding.py
@@ -733,6 +733,69 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
         return new_freqs
 
 
+class Llama4VisionRotaryEmbedding(RotaryEmbedding):
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+    ):
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        inv_freqs = super()._compute_inv_freq(base)
+        inv_freqs = inv_freqs[: (self.rotary_dim // 2)]
+        return inv_freqs
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.base)
+
+        # self.max_position_embeddings here is number of image patches
+        # i.e. (image_size // patch_size) ** 2
+        num_patches = self.max_position_embeddings
+        img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1)
+        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
+        img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN
+        num_patches_single_dim = int(math.sqrt(num_patches))
+        frequencies_x = img_idx % num_patches_single_dim
+        frequencies_y = img_idx // num_patches_single_dim
+        freqs_x = (
+            (frequencies_x + 1)[..., None] * inv_freq[None, None, :]
+        ).repeat_interleave(2, dim=-1)
+        freqs_y = (
+            (frequencies_y + 1)[..., None] * inv_freq[None, None, :]
+        ).repeat_interleave(2, dim=-1)
+        freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
+        freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
+        cache = torch.view_as_complex(
+            torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
+        )
+        return cache
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
+        query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
+        key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
+        broadcast_shape = [
+            d if i == 1 or i == (query_.ndim - 1) else 1
+            for i, d in enumerate(query_.shape)
+        ]
+        freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
+        query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
+        key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
+        return query_out.type_as(query), key_out.type_as(key)
+
+
 class MRotaryEmbedding(RotaryEmbedding):
     """Rotary Embedding with Multimodal Sections."""
 
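The added `Llama4VisionRotaryEmbedding` precomputes a complex-valued 2-D rotary cache over image patches (the x and y patch coordinates each fill half of the rotary channels, with a masked CLS slot appended) and applies it to query/key by complex multiplication in `forward`. A hedged usage sketch with illustrative sizes (a 16x16 patch grid and 64-dim heads; none of these values are taken from an actual Llama 4 config):

```python
import torch
from sglang.srt.layers.rotary_embedding import Llama4VisionRotaryEmbedding

rope = Llama4VisionRotaryEmbedding(
    head_size=64,
    rotary_dim=32,                # half the head size: x and y each fill half
    max_position_embeddings=256,  # (image_size // patch_size) ** 2
    base=10000,
    is_neox_style=False,
    dtype=torch.float32,
)

# forward() expects [batch, num_patches + 1 (CLS), num_heads, head_size].
q = torch.randn(1, 257, 8, 64)
k = torch.randn(1, 257, 8, 64)
q_rot, k_rot = rope.forward(q, k)  # same shapes and dtypes as the inputs
```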