sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +1 -11
- sglang/bench_serving.py +149 -1
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +17 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +30 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +14 -2
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +5 -0
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/lora/lora_manager.py +10 -13
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/schedule_batch.py +19 -1
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +28 -13
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +9 -12
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/model_executor/model_runner.py +44 -33
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +55 -20
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +1 -1
- sglang/srt/models/llama4.py +53 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +24 -40
- sglang/srt/openai_api/protocol.py +28 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +30 -6
- sglang/srt/utils.py +35 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,19 @@
|
|
1
|
+
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
1
2
|
from sglang.srt.utils import DeepEPMode
|
2
3
|
|
3
4
|
try:
|
4
5
|
from deep_ep import Buffer
|
5
6
|
|
7
|
+
from sglang.srt.layers.quantization.fp8_kernel import (
|
8
|
+
sglang_per_token_group_quant_fp8,
|
9
|
+
)
|
10
|
+
|
6
11
|
use_deepep = True
|
7
12
|
except ImportError:
|
8
13
|
use_deepep = False
|
9
14
|
|
10
15
|
from enum import IntEnum, auto
|
11
|
-
from typing import Optional, Tuple
|
16
|
+
from typing import Optional, Tuple, Union
|
12
17
|
|
13
18
|
import torch
|
14
19
|
import torch.distributed as dist
|
@@ -78,7 +83,6 @@ class DeepEPBuffer:
|
|
78
83
|
),
|
79
84
|
num_rdma_bytes,
|
80
85
|
)
|
81
|
-
|
82
86
|
cls._buffer = Buffer(
|
83
87
|
group,
|
84
88
|
num_nvl_bytes,
|
@@ -181,44 +185,74 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|
181
185
|
topk_weights: torch.Tensor,
|
182
186
|
):
|
183
187
|
topk_idx = topk_idx.to(torch.int64)
|
188
|
+
if _ENABLE_JIT_DEEPGEMM:
|
189
|
+
# TODO hard code 128 block quant,use fp8 communication
|
190
|
+
hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
|
184
191
|
previous_event = Buffer.capture() if self.async_finish else None
|
185
192
|
return hidden_states, topk_idx, topk_weights, previous_event
|
186
193
|
|
187
194
|
def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
195
|
+
if _ENABLE_JIT_DEEPGEMM:
|
196
|
+
(
|
197
|
+
hidden_states,
|
198
|
+
topk_idx,
|
199
|
+
topk_weights,
|
200
|
+
num_recv_tokens_per_expert_list,
|
201
|
+
event,
|
202
|
+
) = self._dispatch_core(
|
203
|
+
hidden_states, topk_idx, topk_weights, previous_event
|
198
204
|
)
|
199
|
-
|
200
|
-
|
201
|
-
|
205
|
+
event.current_stream_wait() if self.async_finish else ()
|
206
|
+
return (
|
207
|
+
hidden_states,
|
208
|
+
topk_idx,
|
209
|
+
topk_weights,
|
210
|
+
None,
|
211
|
+
num_recv_tokens_per_expert_list,
|
212
|
+
None,
|
213
|
+
None,
|
214
|
+
None,
|
202
215
|
)
|
203
|
-
|
204
|
-
|
216
|
+
else:
|
217
|
+
(
|
218
|
+
hidden_states,
|
219
|
+
topk_idx,
|
220
|
+
topk_weights,
|
221
|
+
num_recv_tokens_per_expert_list,
|
222
|
+
event,
|
223
|
+
) = self._dispatch_core(
|
224
|
+
hidden_states, topk_idx, topk_weights, previous_event
|
205
225
|
)
|
226
|
+
event.current_stream_wait() if self.async_finish else ()
|
227
|
+
if hidden_states.shape[0] > 0:
|
228
|
+
reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
|
229
|
+
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
230
|
+
)
|
231
|
+
else:
|
232
|
+
reorder_topk_ids = torch.empty(
|
233
|
+
(0,), device=hidden_states.device, dtype=torch.int64
|
234
|
+
)
|
235
|
+
seg_indptr = torch.zeros(
|
236
|
+
(self.num_experts + 1,),
|
237
|
+
device=hidden_states.device,
|
238
|
+
dtype=torch.int64,
|
239
|
+
)
|
206
240
|
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
241
|
+
masked_m = expected_m = None
|
242
|
+
return (
|
243
|
+
hidden_states,
|
244
|
+
topk_idx,
|
245
|
+
topk_weights,
|
246
|
+
reorder_topk_ids,
|
247
|
+
None,
|
248
|
+
seg_indptr,
|
249
|
+
masked_m,
|
250
|
+
expected_m,
|
251
|
+
)
|
218
252
|
|
219
253
|
def _dispatch_core(
|
220
254
|
self,
|
221
|
-
x: torch.Tensor,
|
255
|
+
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
222
256
|
topk_idx: torch.Tensor,
|
223
257
|
topk_weights: torch.Tensor,
|
224
258
|
previous_event,
|
@@ -246,7 +280,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|
246
280
|
recv_x,
|
247
281
|
recv_topk_idx,
|
248
282
|
recv_topk_weights,
|
249
|
-
|
283
|
+
num_recv_tokens_per_expert_list,
|
250
284
|
self.handle,
|
251
285
|
event,
|
252
286
|
) = buffer.dispatch(
|
@@ -260,12 +294,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|
260
294
|
previous_event=previous_event,
|
261
295
|
async_finish=self.async_finish,
|
262
296
|
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
|
297
|
+
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
|
263
298
|
)
|
264
299
|
|
265
300
|
return (
|
266
301
|
recv_x,
|
267
302
|
recv_topk_idx,
|
268
303
|
recv_topk_weights,
|
304
|
+
num_recv_tokens_per_expert_list,
|
269
305
|
event,
|
270
306
|
)
|
271
307
|
|
@@ -314,29 +350,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|
314
350
|
topk_idx: torch.Tensor,
|
315
351
|
topk_weights: torch.Tensor,
|
316
352
|
):
|
317
|
-
if
|
318
|
-
|
319
|
-
output = torch.empty(
|
320
|
-
(num_tokens, hidden_states.shape[1]),
|
321
|
-
device=hidden_states.device,
|
322
|
-
dtype=hidden_states.dtype,
|
323
|
-
)
|
324
|
-
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
325
|
-
hidden_states,
|
326
|
-
output,
|
327
|
-
self.src2dst,
|
328
|
-
topk_idx,
|
329
|
-
topk_weights,
|
330
|
-
self.router_topk,
|
331
|
-
hidden_states.shape[1],
|
332
|
-
BLOCK_SIZE=512,
|
333
|
-
)
|
353
|
+
if _ENABLE_JIT_DEEPGEMM:
|
354
|
+
output = hidden_states
|
334
355
|
else:
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
356
|
+
if hidden_states.shape[0] > 0:
|
357
|
+
num_tokens = self.src2dst.shape[0] // self.router_topk
|
358
|
+
output = torch.empty(
|
359
|
+
(num_tokens, hidden_states.shape[1]),
|
360
|
+
device=hidden_states.device,
|
361
|
+
dtype=hidden_states.dtype,
|
362
|
+
)
|
363
|
+
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
364
|
+
hidden_states,
|
365
|
+
output,
|
366
|
+
self.src2dst,
|
367
|
+
topk_idx,
|
368
|
+
topk_weights,
|
369
|
+
self.router_topk,
|
370
|
+
hidden_states.shape[1],
|
371
|
+
BLOCK_SIZE=512,
|
372
|
+
)
|
373
|
+
else:
|
374
|
+
output = torch.zeros(
|
375
|
+
(0, hidden_states.shape[1]),
|
376
|
+
device=hidden_states.device,
|
377
|
+
dtype=hidden_states.dtype,
|
378
|
+
)
|
340
379
|
previous_event = Buffer.capture() if self.async_finish else None
|
341
380
|
return output, previous_event
|
342
381
|
|
@@ -360,6 +399,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|
360
399
|
|
361
400
|
def _get_buffer(self):
|
362
401
|
DeepEPBuffer.set_dispatch_mode_as_normal()
|
402
|
+
|
363
403
|
return DeepEPBuffer.get_deepep_buffer(
|
364
404
|
self.group,
|
365
405
|
self.hidden_size,
|
@@ -426,6 +466,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|
426
466
|
topk_idx,
|
427
467
|
topk_weights,
|
428
468
|
reorder_topk_ids,
|
469
|
+
None,
|
429
470
|
seg_indptr,
|
430
471
|
masked_m,
|
431
472
|
expected_m,
|
@@ -570,7 +611,8 @@ class DeepEPDispatcher:
|
|
570
611
|
|
571
612
|
def dispatch(self, *args, **kwargs) -> Tuple:
|
572
613
|
self.dispatch_a(*args, **kwargs)
|
573
|
-
|
614
|
+
ret = self.dispatch_b()
|
615
|
+
return ret
|
574
616
|
|
575
617
|
def dispatch_a(
|
576
618
|
self,
|
@@ -593,7 +635,8 @@ class DeepEPDispatcher:
|
|
593
635
|
|
594
636
|
def combine(self, *args, **kwargs) -> Tuple:
|
595
637
|
self.combine_a(*args, **kwargs)
|
596
|
-
|
638
|
+
ret = self.combine_b()
|
639
|
+
return ret
|
597
640
|
|
598
641
|
def combine_a(
|
599
642
|
self,
|
@@ -0,0 +1,146 @@
|
|
1
|
+
{
|
2
|
+
"1": {
|
3
|
+
"BLOCK_SIZE_M": 16,
|
4
|
+
"BLOCK_SIZE_N": 64,
|
5
|
+
"BLOCK_SIZE_K": 128,
|
6
|
+
"GROUP_SIZE_M": 32,
|
7
|
+
"num_warps": 4,
|
8
|
+
"num_stages": 4
|
9
|
+
},
|
10
|
+
"2": {
|
11
|
+
"BLOCK_SIZE_M": 16,
|
12
|
+
"BLOCK_SIZE_N": 128,
|
13
|
+
"BLOCK_SIZE_K": 128,
|
14
|
+
"GROUP_SIZE_M": 1,
|
15
|
+
"num_warps": 4,
|
16
|
+
"num_stages": 4
|
17
|
+
},
|
18
|
+
"4": {
|
19
|
+
"BLOCK_SIZE_M": 16,
|
20
|
+
"BLOCK_SIZE_N": 64,
|
21
|
+
"BLOCK_SIZE_K": 128,
|
22
|
+
"GROUP_SIZE_M": 1,
|
23
|
+
"num_warps": 4,
|
24
|
+
"num_stages": 4
|
25
|
+
},
|
26
|
+
"8": {
|
27
|
+
"BLOCK_SIZE_M": 16,
|
28
|
+
"BLOCK_SIZE_N": 128,
|
29
|
+
"BLOCK_SIZE_K": 128,
|
30
|
+
"GROUP_SIZE_M": 1,
|
31
|
+
"num_warps": 4,
|
32
|
+
"num_stages": 4
|
33
|
+
},
|
34
|
+
"16": {
|
35
|
+
"BLOCK_SIZE_M": 16,
|
36
|
+
"BLOCK_SIZE_N": 128,
|
37
|
+
"BLOCK_SIZE_K": 128,
|
38
|
+
"GROUP_SIZE_M": 16,
|
39
|
+
"num_warps": 4,
|
40
|
+
"num_stages": 3
|
41
|
+
},
|
42
|
+
"24": {
|
43
|
+
"BLOCK_SIZE_M": 16,
|
44
|
+
"BLOCK_SIZE_N": 128,
|
45
|
+
"BLOCK_SIZE_K": 128,
|
46
|
+
"GROUP_SIZE_M": 64,
|
47
|
+
"num_warps": 4,
|
48
|
+
"num_stages": 3
|
49
|
+
},
|
50
|
+
"32": {
|
51
|
+
"BLOCK_SIZE_M": 16,
|
52
|
+
"BLOCK_SIZE_N": 128,
|
53
|
+
"BLOCK_SIZE_K": 64,
|
54
|
+
"GROUP_SIZE_M": 64,
|
55
|
+
"num_warps": 4,
|
56
|
+
"num_stages": 3
|
57
|
+
},
|
58
|
+
"48": {
|
59
|
+
"BLOCK_SIZE_M": 16,
|
60
|
+
"BLOCK_SIZE_N": 256,
|
61
|
+
"BLOCK_SIZE_K": 64,
|
62
|
+
"GROUP_SIZE_M": 1,
|
63
|
+
"num_warps": 4,
|
64
|
+
"num_stages": 3
|
65
|
+
},
|
66
|
+
"64": {
|
67
|
+
"BLOCK_SIZE_M": 16,
|
68
|
+
"BLOCK_SIZE_N": 128,
|
69
|
+
"BLOCK_SIZE_K": 128,
|
70
|
+
"GROUP_SIZE_M": 16,
|
71
|
+
"num_warps": 4,
|
72
|
+
"num_stages": 3
|
73
|
+
},
|
74
|
+
"96": {
|
75
|
+
"BLOCK_SIZE_M": 16,
|
76
|
+
"BLOCK_SIZE_N": 128,
|
77
|
+
"BLOCK_SIZE_K": 128,
|
78
|
+
"GROUP_SIZE_M": 16,
|
79
|
+
"num_warps": 4,
|
80
|
+
"num_stages": 3
|
81
|
+
},
|
82
|
+
"128": {
|
83
|
+
"BLOCK_SIZE_M": 16,
|
84
|
+
"BLOCK_SIZE_N": 128,
|
85
|
+
"BLOCK_SIZE_K": 128,
|
86
|
+
"GROUP_SIZE_M": 16,
|
87
|
+
"num_warps": 4,
|
88
|
+
"num_stages": 3
|
89
|
+
},
|
90
|
+
"256": {
|
91
|
+
"BLOCK_SIZE_M": 16,
|
92
|
+
"BLOCK_SIZE_N": 128,
|
93
|
+
"BLOCK_SIZE_K": 128,
|
94
|
+
"GROUP_SIZE_M": 16,
|
95
|
+
"num_warps": 4,
|
96
|
+
"num_stages": 3
|
97
|
+
},
|
98
|
+
"512": {
|
99
|
+
"BLOCK_SIZE_M": 16,
|
100
|
+
"BLOCK_SIZE_N": 128,
|
101
|
+
"BLOCK_SIZE_K": 128,
|
102
|
+
"GROUP_SIZE_M": 16,
|
103
|
+
"num_warps": 4,
|
104
|
+
"num_stages": 3
|
105
|
+
},
|
106
|
+
"1024": {
|
107
|
+
"BLOCK_SIZE_M": 32,
|
108
|
+
"BLOCK_SIZE_N": 128,
|
109
|
+
"BLOCK_SIZE_K": 128,
|
110
|
+
"GROUP_SIZE_M": 16,
|
111
|
+
"num_warps": 4,
|
112
|
+
"num_stages": 3
|
113
|
+
},
|
114
|
+
"1536": {
|
115
|
+
"BLOCK_SIZE_M": 64,
|
116
|
+
"BLOCK_SIZE_N": 256,
|
117
|
+
"BLOCK_SIZE_K": 128,
|
118
|
+
"GROUP_SIZE_M": 32,
|
119
|
+
"num_warps": 8,
|
120
|
+
"num_stages": 4
|
121
|
+
},
|
122
|
+
"2048": {
|
123
|
+
"BLOCK_SIZE_M": 64,
|
124
|
+
"BLOCK_SIZE_N": 256,
|
125
|
+
"BLOCK_SIZE_K": 128,
|
126
|
+
"GROUP_SIZE_M": 32,
|
127
|
+
"num_warps": 8,
|
128
|
+
"num_stages": 4
|
129
|
+
},
|
130
|
+
"3072": {
|
131
|
+
"BLOCK_SIZE_M": 64,
|
132
|
+
"BLOCK_SIZE_N": 256,
|
133
|
+
"BLOCK_SIZE_K": 128,
|
134
|
+
"GROUP_SIZE_M": 32,
|
135
|
+
"num_warps": 8,
|
136
|
+
"num_stages": 4
|
137
|
+
},
|
138
|
+
"4096": {
|
139
|
+
"BLOCK_SIZE_M": 64,
|
140
|
+
"BLOCK_SIZE_N": 256,
|
141
|
+
"BLOCK_SIZE_K": 128,
|
142
|
+
"GROUP_SIZE_M": 32,
|
143
|
+
"num_warps": 8,
|
144
|
+
"num_stages": 4
|
145
|
+
}
|
146
|
+
}
|
@@ -0,0 +1,146 @@
|
|
1
|
+
{
|
2
|
+
"1": {
|
3
|
+
"BLOCK_SIZE_M": 16,
|
4
|
+
"BLOCK_SIZE_N": 128,
|
5
|
+
"BLOCK_SIZE_K": 128,
|
6
|
+
"GROUP_SIZE_M": 16,
|
7
|
+
"num_warps": 4,
|
8
|
+
"num_stages": 4
|
9
|
+
},
|
10
|
+
"2": {
|
11
|
+
"BLOCK_SIZE_M": 16,
|
12
|
+
"BLOCK_SIZE_N": 128,
|
13
|
+
"BLOCK_SIZE_K": 128,
|
14
|
+
"GROUP_SIZE_M": 64,
|
15
|
+
"num_warps": 4,
|
16
|
+
"num_stages": 4
|
17
|
+
},
|
18
|
+
"4": {
|
19
|
+
"BLOCK_SIZE_M": 16,
|
20
|
+
"BLOCK_SIZE_N": 128,
|
21
|
+
"BLOCK_SIZE_K": 128,
|
22
|
+
"GROUP_SIZE_M": 1,
|
23
|
+
"num_warps": 4,
|
24
|
+
"num_stages": 4
|
25
|
+
},
|
26
|
+
"8": {
|
27
|
+
"BLOCK_SIZE_M": 16,
|
28
|
+
"BLOCK_SIZE_N": 128,
|
29
|
+
"BLOCK_SIZE_K": 128,
|
30
|
+
"GROUP_SIZE_M": 32,
|
31
|
+
"num_warps": 4,
|
32
|
+
"num_stages": 5
|
33
|
+
},
|
34
|
+
"16": {
|
35
|
+
"BLOCK_SIZE_M": 16,
|
36
|
+
"BLOCK_SIZE_N": 128,
|
37
|
+
"BLOCK_SIZE_K": 128,
|
38
|
+
"GROUP_SIZE_M": 16,
|
39
|
+
"num_warps": 4,
|
40
|
+
"num_stages": 4
|
41
|
+
},
|
42
|
+
"24": {
|
43
|
+
"BLOCK_SIZE_M": 16,
|
44
|
+
"BLOCK_SIZE_N": 128,
|
45
|
+
"BLOCK_SIZE_K": 128,
|
46
|
+
"GROUP_SIZE_M": 1,
|
47
|
+
"num_warps": 4,
|
48
|
+
"num_stages": 3
|
49
|
+
},
|
50
|
+
"32": {
|
51
|
+
"BLOCK_SIZE_M": 16,
|
52
|
+
"BLOCK_SIZE_N": 128,
|
53
|
+
"BLOCK_SIZE_K": 128,
|
54
|
+
"GROUP_SIZE_M": 1,
|
55
|
+
"num_warps": 4,
|
56
|
+
"num_stages": 3
|
57
|
+
},
|
58
|
+
"48": {
|
59
|
+
"BLOCK_SIZE_M": 16,
|
60
|
+
"BLOCK_SIZE_N": 128,
|
61
|
+
"BLOCK_SIZE_K": 128,
|
62
|
+
"GROUP_SIZE_M": 16,
|
63
|
+
"num_warps": 4,
|
64
|
+
"num_stages": 3
|
65
|
+
},
|
66
|
+
"64": {
|
67
|
+
"BLOCK_SIZE_M": 16,
|
68
|
+
"BLOCK_SIZE_N": 128,
|
69
|
+
"BLOCK_SIZE_K": 128,
|
70
|
+
"GROUP_SIZE_M": 64,
|
71
|
+
"num_warps": 4,
|
72
|
+
"num_stages": 3
|
73
|
+
},
|
74
|
+
"96": {
|
75
|
+
"BLOCK_SIZE_M": 16,
|
76
|
+
"BLOCK_SIZE_N": 128,
|
77
|
+
"BLOCK_SIZE_K": 128,
|
78
|
+
"GROUP_SIZE_M": 64,
|
79
|
+
"num_warps": 4,
|
80
|
+
"num_stages": 3
|
81
|
+
},
|
82
|
+
"128": {
|
83
|
+
"BLOCK_SIZE_M": 16,
|
84
|
+
"BLOCK_SIZE_N": 128,
|
85
|
+
"BLOCK_SIZE_K": 128,
|
86
|
+
"GROUP_SIZE_M": 64,
|
87
|
+
"num_warps": 4,
|
88
|
+
"num_stages": 4
|
89
|
+
},
|
90
|
+
"256": {
|
91
|
+
"BLOCK_SIZE_M": 16,
|
92
|
+
"BLOCK_SIZE_N": 128,
|
93
|
+
"BLOCK_SIZE_K": 128,
|
94
|
+
"GROUP_SIZE_M": 16,
|
95
|
+
"num_warps": 4,
|
96
|
+
"num_stages": 4
|
97
|
+
},
|
98
|
+
"512": {
|
99
|
+
"BLOCK_SIZE_M": 16,
|
100
|
+
"BLOCK_SIZE_N": 128,
|
101
|
+
"BLOCK_SIZE_K": 128,
|
102
|
+
"GROUP_SIZE_M": 64,
|
103
|
+
"num_warps": 4,
|
104
|
+
"num_stages": 3
|
105
|
+
},
|
106
|
+
"1024": {
|
107
|
+
"BLOCK_SIZE_M": 64,
|
108
|
+
"BLOCK_SIZE_N": 128,
|
109
|
+
"BLOCK_SIZE_K": 128,
|
110
|
+
"GROUP_SIZE_M": 32,
|
111
|
+
"num_warps": 4,
|
112
|
+
"num_stages": 3
|
113
|
+
},
|
114
|
+
"1536": {
|
115
|
+
"BLOCK_SIZE_M": 64,
|
116
|
+
"BLOCK_SIZE_N": 128,
|
117
|
+
"BLOCK_SIZE_K": 128,
|
118
|
+
"GROUP_SIZE_M": 32,
|
119
|
+
"num_warps": 4,
|
120
|
+
"num_stages": 4
|
121
|
+
},
|
122
|
+
"2048": {
|
123
|
+
"BLOCK_SIZE_M": 64,
|
124
|
+
"BLOCK_SIZE_N": 128,
|
125
|
+
"BLOCK_SIZE_K": 128,
|
126
|
+
"GROUP_SIZE_M": 64,
|
127
|
+
"num_warps": 4,
|
128
|
+
"num_stages": 3
|
129
|
+
},
|
130
|
+
"3072": {
|
131
|
+
"BLOCK_SIZE_M": 64,
|
132
|
+
"BLOCK_SIZE_N": 128,
|
133
|
+
"BLOCK_SIZE_K": 128,
|
134
|
+
"GROUP_SIZE_M": 16,
|
135
|
+
"num_warps": 4,
|
136
|
+
"num_stages": 4
|
137
|
+
},
|
138
|
+
"4096": {
|
139
|
+
"BLOCK_SIZE_M": 64,
|
140
|
+
"BLOCK_SIZE_N": 128,
|
141
|
+
"BLOCK_SIZE_K": 128,
|
142
|
+
"GROUP_SIZE_M": 16,
|
143
|
+
"num_warps": 4,
|
144
|
+
"num_stages": 4
|
145
|
+
}
|
146
|
+
}
|
@@ -29,6 +29,7 @@ from sglang.srt.utils import (
|
|
29
29
|
get_device_name,
|
30
30
|
is_cuda,
|
31
31
|
is_hip,
|
32
|
+
log_info_on_rank0,
|
32
33
|
)
|
33
34
|
|
34
35
|
_is_hip = is_hip()
|
@@ -945,7 +946,9 @@ def get_moe_configs(
|
|
945
946
|
# For example, updating the Triton version might cause all old configs to become suboptimal.
|
946
947
|
# To achieve the best performance, consider re-tuning the Triton fused MOE kernel in your environment.
|
947
948
|
# For the tuning method, refer to: https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
|
948
|
-
|
949
|
+
log_info_on_rank0(
|
950
|
+
logger, f"Using MoE kernel config from {config_file_path}."
|
951
|
+
)
|
949
952
|
# If a configuration has been found, return it
|
950
953
|
return {int(key): val for key, val in json.load(f).items()}
|
951
954
|
|
@@ -10,16 +10,14 @@ import torch
|
|
10
10
|
from compressed_tensors import CompressionFormat
|
11
11
|
from compressed_tensors.quantization import QuantizationStrategy
|
12
12
|
|
13
|
-
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
13
|
+
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
|
14
14
|
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
15
15
|
from sglang.srt.layers.quantization.utils import (
|
16
16
|
all_close_1d,
|
17
|
-
is_cuda,
|
18
|
-
is_fp8_fnuz,
|
19
17
|
per_tensor_dequantize,
|
20
18
|
replace_parameter,
|
21
19
|
)
|
22
|
-
from sglang.srt.utils import set_weight_attrs
|
20
|
+
from sglang.srt.utils import is_cuda, set_weight_attrs
|
23
21
|
|
24
22
|
_is_cuda = is_cuda()
|
25
23
|
|
@@ -15,11 +15,12 @@ from sglang.srt.layers.parameter import (
|
|
15
15
|
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
16
16
|
CompressedTensorsScheme,
|
17
17
|
)
|
18
|
+
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
|
18
19
|
from sglang.srt.layers.quantization.fp8_utils import (
|
19
20
|
apply_fp8_linear,
|
20
21
|
normalize_e4m3fn_to_e4m3fnuz,
|
21
22
|
)
|
22
|
-
from sglang.srt.layers.quantization.utils import
|
23
|
+
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
|
23
24
|
|
24
25
|
__all__ = ["CompressedTensorsW8A8Fp8"]
|
25
26
|
|
@@ -28,6 +28,11 @@ if is_cuda():
|
|
28
28
|
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
|
29
29
|
_ENABLE_JIT_DEEPGEMM = True
|
30
30
|
|
31
|
+
|
32
|
+
def get_enable_jit_deepgemm():
|
33
|
+
return _ENABLE_JIT_DEEPGEMM
|
34
|
+
|
35
|
+
|
31
36
|
logger = logging.getLogger(__name__)
|
32
37
|
|
33
38
|
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
|