sglang 0.4.4.post4__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
Files changed (32)
  1. sglang/lang/chat_template.py +24 -0
  2. sglang/srt/configs/model_config.py +4 -0
  3. sglang/srt/conversation.py +29 -4
  4. sglang/srt/layers/attention/flashattention_backend.py +286 -9
  5. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  6. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  7. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  8. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -3
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  15. sglang/srt/layers/quantization/__init__.py +1 -0
  16. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  17. sglang/srt/layers/quantization/fp8.py +3 -1
  18. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  19. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  20. sglang/srt/layers/radix_attention.py +2 -0
  21. sglang/srt/layers/rotary_embedding.py +63 -0
  22. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  23. sglang/srt/model_executor/model_runner.py +1 -0
  24. sglang/srt/models/llama.py +12 -4
  25. sglang/srt/models/llama4.py +420 -0
  26. sglang/srt/models/mllama4.py +154 -0
  27. sglang/version.py +1 -1
  28. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/METADATA +1 -1
  29. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/RECORD +32 -22
  30. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  31. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  32. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
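
The seven new JSON files under sglang/srt/layers/moe/fused_moe_triton/configs/ are pre-tuned Triton kernel configurations; their filenames encode the expert count (E), a weight-shard dimension (N), and the GPU device name. A minimal sketch of how such a file could be resolved at runtime, assuming only the naming scheme visible above (the helper below is illustrative, not sglang's actual loader):

import json
import os

def load_moe_tuning_config(num_experts: int, shard_n: int,
                           device_name: str, config_dir: str) -> dict:
    # Illustrative lookup mirroring the "E=...,N=...,device_name=....json"
    # naming in the file list above; the real logic lives in fused_moe.py.
    fname = f"E={num_experts},N={shard_n},device_name={device_name}.json"
    path = os.path.join(config_dir, fname)
    if not os.path.isfile(path):
        return {}  # caller would fall back to heuristic defaults
    with open(path) as f:
        return json.load(f)
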
@@ -0,0 +1,146 @@
+ {
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "48": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "256": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 5
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 5
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 256,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ }
+ }
@@ -0,0 +1,146 @@
+ {
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "256": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "512": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 5
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 5
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ }
+ }
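
Both tables above map a routed-token batch-size bucket (the top-level keys "1" through "4096") to Triton launch parameters for the fused MoE kernel: tile sizes (BLOCK_SIZE_M/N/K), the scheduling group (GROUP_SIZE_M), and warp/pipeline counts (num_warps, num_stages). A hedged sketch of how one entry could be selected for an actual batch; the nearest-bucket rule is an assumption for illustration:

def pick_kernel_config(table: dict, num_tokens: int) -> dict:
    # Choose the tuned entry whose bucket is closest to the real token count.
    buckets = sorted(int(k) for k in table)
    best = min(buckets, key=lambda b: abs(b - num_tokens))
    return table[str(best)]

# e.g. with 300 routed tokens this returns the "256" entry of the table.
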
@@ -1079,6 +1079,7 @@ def inplace_fused_experts(
  topk_weights: torch.Tensor,
  topk_ids: torch.Tensor,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_fp8_w8a8: bool = False,
  use_int8_w8a8: bool = False,
  use_int8_w8a16: bool = False,
@@ -1099,6 +1100,7 @@ def inplace_fused_experts(
  topk_ids,
  True,
  activation,
+ apply_router_weight_on_input,
  use_fp8_w8a8,
  use_int8_w8a8,
  use_int8_w8a16,
@@ -1120,6 +1122,7 @@ def inplace_fused_experts_fake(
  topk_weights: torch.Tensor,
  topk_ids: torch.Tensor,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_fp8_w8a8: bool = False,
  use_int8_w8a8: bool = False,
  use_int8_w8a16: bool = False,
@@ -1150,6 +1153,7 @@ def outplace_fused_experts(
  topk_weights: torch.Tensor,
  topk_ids: torch.Tensor,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_fp8_w8a8: bool = False,
  use_int8_w8a8: bool = False,
  use_int8_w8a16: bool = False,
@@ -1171,6 +1175,7 @@ def outplace_fused_experts(
  topk_ids,
  False,
  activation,
+ apply_router_weight_on_input,
  use_fp8_w8a8,
  use_int8_w8a8,
  use_int8_w8a16,
@@ -1193,6 +1198,7 @@ def outplace_fused_experts_fake(
  topk_weights: torch.Tensor,
  topk_ids: torch.Tensor,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_fp8_w8a8: bool = False,
  use_int8_w8a8: bool = False,
  use_int8_w8a16: bool = False,
@@ -1225,6 +1231,7 @@ def fused_experts(
  topk_ids: torch.Tensor,
  inplace: bool = False,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_fp8_w8a8: bool = False,
  use_int8_w8a8: bool = False,
  use_int8_w8a16: bool = False,
@@ -1247,6 +1254,7 @@ def fused_experts(
  topk_weights,
  topk_ids,
  activation,
+ apply_router_weight_on_input,
  use_fp8_w8a8,
  use_int8_w8a8,
  use_int8_w8a16,
@@ -1268,6 +1276,7 @@ def fused_experts(
  topk_weights,
  topk_ids,
  activation,
+ apply_router_weight_on_input,
  use_fp8_w8a8,
  use_int8_w8a8,
  use_int8_w8a16,
@@ -1291,6 +1300,7 @@ def fused_experts_impl(
  topk_ids: torch.Tensor,
  inplace: bool = False,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_fp8_w8a8: bool = False,
  use_int8_w8a8: bool = False,
  use_int8_w8a16: bool = False,
@@ -1423,7 +1433,7 @@ def fused_experts_impl(
  sorted_token_ids,
  expert_ids,
  num_tokens_post_padded,
- False,
+ apply_router_weight_on_input,
  topk_ids.shape[1],
  config,
  compute_type=compute_type,
@@ -1456,7 +1466,7 @@ def fused_experts_impl(
  (
  intermediate_cache3
  if not no_combine and topk_ids.shape[1] != 1
- else out_hidden_states[begin_chunk_idx:end_chunk_idx]
+ else out_hidden_states[begin_chunk_idx:end_chunk_idx].unsqueeze(0)
  ),
  a2_scale,
  w2_scale,
@@ -1466,7 +1476,7 @@ def fused_experts_impl(
  sorted_token_ids,
  expert_ids,
  num_tokens_post_padded,
- True,
+ not apply_router_weight_on_input,
  1,
  config,
  compute_type=compute_type,
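
The apply_router_weight_on_input flag threaded through the calls above decides where the top-k routing weights are multiplied in: the first grouped GEMM now receives the flag where it previously passed False, and the second GEMM receives not apply_router_weight_on_input where it previously passed True. A rough reference-semantics sketch with illustrative names (not sglang's internals); the two variants differ for nonlinear experts, which is why this is exposed as a per-model choice (Llama 4-style MoE scales the expert input):

import torch

def moe_reference(x, experts, topk_ids, topk_weights,
                  apply_router_weight_on_input: bool = False):
    # x: [T, H]; topk_ids, topk_weights: [T, K]; experts: list of callables H -> H
    out = torch.zeros_like(x)
    T, K = topk_ids.shape
    for t in range(T):
        for k in range(K):
            w = topk_weights[t, k]
            expert = experts[int(topk_ids[t, k])]
            if apply_router_weight_on_input:
                out[t] += expert(w * x[t])   # weight folded into the expert input
            else:
                out[t] += w * expert(x[t])   # weight applied when combining outputs
    return out
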
@@ -128,6 +128,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
@@ -143,6 +144,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  custom_routing_function=custom_routing_function,
  correction_bias=correction_bias,
  activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  inplace=inplace,
  no_combine=no_combine,
  )
@@ -160,6 +162,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
@@ -200,6 +203,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  topk_ids=topk_ids,
  inplace=inplace and not no_combine,
  activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  no_combine=no_combine,
  )

@@ -276,6 +280,7 @@ class FusedMoE(torch.nn.Module):
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_presharded_weights: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
@@ -302,6 +307,7 @@ class FusedMoE(torch.nn.Module):
  self.custom_routing_function = custom_routing_function
  self.correction_bias = correction_bias
  self.activation = activation
+ self.apply_router_weight_on_input = apply_router_weight_on_input
  self.use_presharded_weights = use_presharded_weights
  self.inplace = inplace
  self.no_combine = no_combine
@@ -630,6 +636,7 @@ class FusedMoE(torch.nn.Module):
  custom_routing_function=self.custom_routing_function,
  correction_bias=self.correction_bias,
  activation=self.activation,
+ apply_router_weight_on_input=self.apply_router_weight_on_input,
  )

  if self.reduce_results and self.tp_size > 1:
@@ -280,6 +280,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ):
@@ -370,6 +370,7 @@ class BlockInt8MoEMethod:
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
@@ -398,6 +399,7 @@ class BlockInt8MoEMethod:
  topk_ids=topk_ids,
  inplace=inplace,
  activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  use_int8_w8a8=True,
  w1_scale=(layer.w13_weight_scale_inv),
  w2_scale=(layer.w2_weight_scale_inv),
@@ -860,7 +860,7 @@ class Fp8MoEMethod:
  layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
  layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]

- def process_weights_hip_scale_padding(self, layer: Module, padding_size: int):
+ def process_weights_hip_scale_padding(self, layer: Module):
  from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
  padding_size, # Avoid circular import
  )
@@ -905,6 +905,7 @@ class Fp8MoEMethod:
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
@@ -975,6 +976,7 @@ class Fp8MoEMethod:
  topk_ids=topk_ids,
  inplace=inplace and not no_combine,
  activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  use_fp8_w8a8=True,
  w1_scale=(
  layer.w13_weight_scale_inv
@@ -344,6 +344,7 @@ class MoeWNA16Method:
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
@@ -374,6 +375,7 @@ class MoeWNA16Method:
  topk_weights=topk_weights,
  topk_ids=topk_ids,
  inplace=inplace,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  use_int4_w4a16=weight_bits == 4,
  use_int8_w8a16=weight_bits == 8,
  w1_scale=layer.w13_scales,
@@ -230,6 +230,7 @@ class W8A8Int8MoEMethod:
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
@@ -257,6 +258,7 @@ class W8A8Int8MoEMethod:
  topk_ids=topk_ids,
  inplace=inplace,
  activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  use_int8_w8a8=True,
  w1_scale=(layer.w13_weight_scale),
  w2_scale=(layer.w2_weight_scale),
@@ -35,6 +35,7 @@ class RadixAttention(nn.Module):
  sliding_window_size: int = -1,
  is_cross_attention: bool = False,
  prefix: str = "",
+ use_irope: bool = False,
  ):
  super().__init__()
  self.tp_q_head_num = num_heads
@@ -50,6 +51,7 @@ class RadixAttention(nn.Module):
  self.is_cross_attention = is_cross_attention
  self.k_scale = None
  self.v_scale = None
+ self.use_irope = use_irope

  def forward(
  self,
@@ -733,6 +733,69 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
  return new_freqs


+ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
+
+ def __init__(
+ self,
+ head_size: int,
+ rotary_dim: int,
+ max_position_embeddings: int,
+ base: int,
+ is_neox_style: bool,
+ dtype: torch.dtype,
+ ):
+ super().__init__(
+ head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+ )
+
+ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+ inv_freqs = super()._compute_inv_freq(base)
+ inv_freqs = inv_freqs[: (self.rotary_dim // 2)]
+ return inv_freqs
+
+ def _compute_cos_sin_cache(self) -> torch.Tensor:
+ inv_freq = self._compute_inv_freq(self.base)
+
+ # self.max_position_embeddings here is number of image patches
+ # i.e. (image_size // patch_size) ** 2
+ num_patches = self.max_position_embeddings
+ img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1)
+ img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
+ img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN
+ num_patches_single_dim = int(math.sqrt(num_patches))
+ frequencies_x = img_idx % num_patches_single_dim
+ frequencies_y = img_idx // num_patches_single_dim
+ freqs_x = (
+ (frequencies_x + 1)[..., None] * inv_freq[None, None, :]
+ ).repeat_interleave(2, dim=-1)
+ freqs_y = (
+ (frequencies_y + 1)[..., None] * inv_freq[None, None, :]
+ ).repeat_interleave(2, dim=-1)
+ freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
+ freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
+ cache = torch.view_as_complex(
+ torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
+ )
+ return cache
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
+ query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
+ key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
+ broadcast_shape = [
+ d if i == 1 or i == (query_.ndim - 1) else 1
+ for i, d in enumerate(query_.shape)
+ ]
+ freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
+ query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
+ key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
+ return query_out.type_as(query), key_out.type_as(key)
+
+
  class MRotaryEmbedding(RotaryEmbedding):
  """Rotary Embedding with Multimodal Sections."""
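
Llama4VisionRotaryEmbedding (added above) builds a 2-D rotary cache over the image-patch grid: each patch index is split into (x, y) grid coordinates, each coordinate drives half of the rotary frequencies, and the extra trailing index set to -2 marks the CLS token, whose frequencies the masked_fill zeroes out. A tiny standalone illustration of that coordinate decomposition for a 4x4 patch grid (shapes only, not the class itself):

import torch

num_patches = 16                                    # (image_size // patch_size) ** 2, a 4x4 grid here
side = int(num_patches ** 0.5)
img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1)
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)  # one extra slot for the CLS token
img_idx[-1, -1] = -2                                # CLS marker, masked to zero frequencies later
x = img_idx % side                                  # column of each patch
y = img_idx // side                                 # row of each patch
print(x[:num_patches].reshape(side, side))          # columns 0..3 repeated on every row
print(y[:num_patches].reshape(side, side))          # row index, constant within each row
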