sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/srt/layers/attention/__init__.py +14 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +211 -81
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/logits_processor.py +167 -212
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/managers/io_struct.py +1 -2
- sglang/srt/managers/schedule_batch.py +26 -2
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +62 -26
- sglang/srt/managers/tokenizer_manager.py +22 -20
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/model_executor/cuda_graph_runner.py +118 -73
- sglang/srt/model_executor/forward_batch_info.py +33 -8
- sglang/srt/model_executor/model_runner.py +63 -61
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +97 -26
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +21 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +9 -5
- sglang/srt/server_args.py +108 -57
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +618 -0
- sglang/srt/speculative/eagle_worker.py +170 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +15 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/RECORD +63 -39
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
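Note: the hunk above is one of the new NVIDIA H200 fused-MoE tuning tables listed in the file summary. Each top-level key is a token count M and each value is a set of Triton launch parameters for the fused MoE kernel. A minimal sketch of how such a table could be consulted, assuming the usual nearest-key lookup; load_moe_config and pick_config are hypothetical helpers, not sglang APIs:

import json

def load_moe_config(path: str) -> dict:
    # Keys are stored as strings of token counts (M); convert them to int for lookup.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}

def pick_config(configs: dict, m: int) -> dict:
    # Use the tuned entry whose M is closest to the actual number of routed tokens.
    return configs[min(configs, key=lambda k: abs(k - m))]

# Hypothetical usage against one of the new files listed above:
# cfg = pick_config(
#     load_moe_config("E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json"), m=300
# )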
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -17,15 +17,21 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
 
-
+is_hip_flag = False
 if not is_hip():
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
-
+    is_hip_flag = False
+else:
+    is_hip_flag = True
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
 
+enable_moe_align_block_size_triton = bool(
+    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
+)
+
 
 @triton.jit
 def fused_moe_kernel(
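Note: the new ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON flag is evaluated once at module import time via os.getenv, so it has to be set before sglang imports this module; a minimal example:

import os

# Must be set before sglang.srt.layers.moe.fused_moe_triton.fused_moe is imported,
# because the flag is read at module level.
os.environ["ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"] = "1"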
@@ -222,6 +228,139 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+@triton.jit
+def moe_align_block_size_stage1(
+    topk_ids_ptr,
+    tokens_cnts_ptr,
+    num_experts: tl.constexpr,
+    numel: tl.constexpr,
+    tokens_per_thread: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    start_idx = pid * tokens_per_thread
+
+    off_c = (pid + 1) * num_experts
+
+    for i in range(tokens_per_thread):
+        if start_idx + i < numel:
+            idx = tl.load(topk_ids_ptr + start_idx + i)
+            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
+            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
+
+
+@triton.jit
+def moe_align_block_size_stage2(
+    tokens_cnts_ptr,
+    num_experts: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    last_cnt = 0
+    for i in range(1, num_experts + 1):
+        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
+        last_cnt = last_cnt + token_cnt
+        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
+
+
+@triton.jit
+def moe_align_block_size_stage3(
+    total_tokens_post_pad_ptr,
+    tokens_cnts_ptr,
+    cumsum_ptr,
+    num_experts: tl.constexpr,
+    block_size: tl.constexpr,
+):
+    last_cumsum = 0
+    off_cnt = num_experts * num_experts
+    for i in range(1, num_experts + 1):
+        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
+        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
+        tl.store(cumsum_ptr + i, last_cumsum)
+    tl.store(total_tokens_post_pad_ptr, last_cumsum)
+
+
+@triton.jit
+def moe_align_block_size_stage4(
+    topk_ids_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    tokens_cnts_ptr,
+    cumsum_ptr,
+    num_experts: tl.constexpr,
+    block_size: tl.constexpr,
+    numel: tl.constexpr,
+    tokens_per_thread: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    start_idx = tl.load(cumsum_ptr + pid)
+    end_idx = tl.load(cumsum_ptr + pid + 1)
+
+    for i in range(start_idx, end_idx, block_size):
+        tl.store(expert_ids_ptr + i // block_size, pid)
+
+    start_idx = pid * tokens_per_thread
+    off_t = pid * num_experts
+
+    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
+        expert_id = tl.load(topk_ids_ptr + i)
+        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
+        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
+        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
+        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
+
+
+def moe_align_block_size_triton(
+    topk_ids: torch.Tensor,
+    num_experts: int,
+    block_size: int,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_pad: torch.Tensor,
+) -> None:
+    numel = topk_ids.numel()
+    grid = (num_experts,)
+    tokens_cnts = torch.zeros(
+        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
+    )
+    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
+    tokens_per_thread = ceil_div(numel, num_experts)
+
+    moe_align_block_size_stage1[grid](
+        topk_ids,
+        tokens_cnts,
+        num_experts,
+        numel,
+        tokens_per_thread,
+    )
+    moe_align_block_size_stage2[grid](
+        tokens_cnts,
+        num_experts,
+    )
+    moe_align_block_size_stage3[(1,)](
+        num_tokens_post_pad,
+        tokens_cnts,
+        cumsum,
+        num_experts,
+        block_size,
+    )
+    moe_align_block_size_stage4[grid](
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        tokens_cnts,
+        cumsum,
+        num_experts,
+        block_size,
+        numel,
+        tokens_per_thread,
+    )
+
+
 def moe_align_block_size(
     topk_ids: torch.Tensor, block_size: int, num_experts: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
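Note: a hedged sketch of driving moe_align_block_size_triton directly; buffer shapes follow the surrounding moe_align_block_size code, sizes are toy values, and a CUDA device with Triton is assumed:

import torch
# from sglang.srt.layers.moe.fused_moe_triton.fused_moe import moe_align_block_size_triton

num_experts, block_size, topk = 8, 16, 2
topk_ids = torch.randint(0, num_experts, (64, topk), dtype=torch.int32, device="cuda")

max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.full(
    (max_num_tokens_padded,), topk_ids.numel(), dtype=torch.int32, device="cuda"
)
expert_ids = torch.empty(
    ((max_num_tokens_padded + block_size - 1) // block_size,),
    dtype=torch.int32,
    device="cuda",
)
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")

# moe_align_block_size_triton(topk_ids, num_experts, block_size,
#                             sorted_ids, expert_ids, num_tokens_post_pad)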
@@ -272,24 +411,36 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    if
-
-    (
-
-
-
-
+    if num_experts >= 224:
+        if enable_moe_align_block_size_triton or is_hip_flag:
+            moe_align_block_size_triton(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids,
+                expert_ids,
+                num_tokens_post_pad,
+            )
+        else:
+            token_cnts_buffer = torch.empty(
+                (num_experts + 1) * num_experts,
+                dtype=torch.int32,
+                device=topk_ids.device,
+            )
+            cumsum_buffer = torch.empty(
+                num_experts + 1, dtype=torch.int32, device=topk_ids.device
+            )
 
-
-
-
-
-
-
-
-
-
-
+            sgl_moe_align_block_size(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids,
+                expert_ids,
+                num_tokens_post_pad,
+                token_cnts_buffer,
+                cumsum_buffer,
+            )
     else:
         ops.moe_align_block_size(
             topk_ids,
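Note: for intuition, a minimal pure-Python reference (a hypothetical helper, not part of the package) of the layout that moe_align_block_size produces: token indices are grouped by expert, each expert's group is padded to a multiple of block_size with the sentinel value topk_ids.numel(), and expert_ids records which expert owns each block. Ordering within an expert's block region may differ on the GPU.

import torch

def reference_moe_align(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    flat = topk_ids.flatten().tolist()
    numel = len(flat)
    sorted_ids, expert_ids = [], []
    for e in range(num_experts):
        idxs = [i for i, t in enumerate(flat) if t == e]
        n_blocks = -(-len(idxs) // block_size)  # ceil division
        sorted_ids += idxs + [numel] * (n_blocks * block_size - len(idxs))
        expert_ids += [e] * n_blocks
    return sorted_ids, expert_ids, len(sorted_ids)

# reference_moe_align(torch.tensor([[0, 2], [0, 1]]), num_experts=4, block_size=4)
# -> ([0, 2, 4, 4, 3, 4, 4, 4, 1, 4, 4, 4], [0, 1, 2], 12)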
@@ -326,9 +477,9 @@ def invoke_fused_moe_kernel(
 
     padded_size = 0
     if use_fp8_w8a8:
-        padded_size = padding_size
         assert B_scale is not None
         if block_shape is None:
+            padded_size = padding_size
             A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         else:
             assert len(block_shape) == 2
@@ -463,7 +614,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": 128,
                 "GROUP_SIZE_M": 32,
                 "num_warps": 8,
-                "num_stages": 4,
+                "num_stages": 2 if is_hip_flag else 4,
             }
             if M <= E:
                 config = {
@@ -472,7 +623,7 @@
                     "BLOCK_SIZE_K": 128,
                     "GROUP_SIZE_M": 1,
                     "num_warps": 4,
-                    "num_stages": 4,
+                    "num_stages": 2 if is_hip_flag else 4,
                 }
         else:
             # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
@@ -482,7 +633,7 @@
                 "BLOCK_SIZE_K": block_shape[1],
                 "GROUP_SIZE_M": 32,
                 "num_warps": 4,
-                "num_stages": 3,
+                "num_stages": 2 if is_hip_flag else 3,
             }
     else:
         config = {
@@ -727,7 +878,7 @@ def fused_experts_impl(
     block_shape: Optional[List[int]] = None,
 ):
     padded_size = padding_size
-    if not use_fp8_w8a8:
+    if not use_fp8_w8a8 or block_shape is not None:
        padded_size = 0
 
     # Check constraints.
@@ -854,11 +1005,18 @@
             block_shape=block_shape,
         )
 
-
-
-
-
-
+        if is_hip_flag:
+            ops.moe_sum(
+                intermediate_cache3.view(*intermediate_cache3.shape),
+                out_hidden_states[begin_chunk_idx:end_chunk_idx],
+            )
+        else:
+            torch.sum(
+                intermediate_cache3.view(*intermediate_cache3.shape),
+                dim=1,
+                out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+            )
+
     return out_hidden_states
 
 
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -321,9 +321,12 @@ class FusedMoE(torch.nn.Module):
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
-
-
-
+
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                shard_dim, shard_size * tp_rank, shard_size
+            )
+
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         if shard_id == "w1":
@@ -347,9 +350,12 @@ class FusedMoE(torch.nn.Module):
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-
-
-
+
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                shard_dim, shard_size * tp_rank, shard_size
+            )
+
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
 
@@ -389,7 +395,9 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
+        use_presharded_weights: bool = False,
     ) -> None:
+        self.use_presharded_weights = use_presharded_weights
 
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
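Note: for context, a toy sketch of the tensor-parallel slice that loaded_weight.narrow(...) performs, and that is now skipped when a checkpoint is already presharded (hypothetical shapes and rank):

import torch

loaded_weight = torch.arange(16.0).reshape(8, 2)  # full, unsharded expert weight
tp_size, tp_rank, shard_dim = 4, 1, 0
shard_size = loaded_weight.shape[shard_dim] // tp_size

# Each rank keeps only its contiguous slice along the sharded dimension.
shard = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
# With use_presharded_weights=True the checkpoint already stores per-rank shards,
# so this narrow step is skipped and loaded_weight is copied as-is.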
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -280,9 +280,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 weight_scale=layer.weight_scale_inv,
                 input_scale=None,
             )
-            layer.weight = torch.nn.Parameter(weight,
+            layer.weight = torch.nn.Parameter(weight, requires_grad=False)
             layer.weight_scale_inv = torch.nn.Parameter(
-                weight_scale,
+                weight_scale, requires_grad=False
             )
             layer.input_scale = None
             return
sglang/srt/layers/sampler.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import
+from typing import List
 
 import torch
 from torch import nn
@@ -28,13 +28,12 @@ class Sampler(nn.Module):
 
     def forward(
         self,
-
+        logits_output: LogitsProcessorOutput,
         sampling_info: SamplingBatchInfo,
+        return_logprob: bool,
+        top_logprobs_nums: List[int],
     ):
-
-        logits = logits.next_token_logits
-
-        logits = logits.contiguous()
+        logits = logits_output.next_token_logits
 
         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
@@ -47,6 +46,8 @@
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
             batch_next_token_ids = torch.argmax(logits, -1)
+            if return_logprob:
+                logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
         else:
             # Post process logits
             logits.div_(sampling_info.temperatures)
@@ -54,6 +55,14 @@
             del logits
 
             if global_server_args_dict["sampling_backend"] == "flashinfer":
+                if return_logprob:
+                    # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
+                    # https://github.com/flashinfer-ai/flashinfer/issues/708
+                    # so we use the torch implementation.
+                    logprobs = torch.log(
+                        top_p_normalize_probs_torch(probs, sampling_info.top_ps)
+                    )
+
                 max_top_k_round, batch_size = 32, probs.shape[0]
                 uniform_samples = torch.rand(
                     (max_top_k_round, batch_size), device=probs.device
@@ -76,6 +85,7 @@
                 if self.use_nan_detectioin and not torch.all(success):
                     logger.warning("Detected errors during sampling!")
                     batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+
             elif global_server_args_dict["sampling_backend"] == "pytorch":
                 # A slower fallback implementation with torch native operations.
                 batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
@@ -85,12 +95,31 @@
                     sampling_info.min_ps,
                     sampling_info.need_min_p_sampling,
                 )
+                if return_logprob:
+                    logprobs = torch.log(
+                        top_p_normalize_probs_torch(probs, sampling_info.top_ps)
+                    )
             else:
                 raise ValueError(
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )
 
-
+        batch_next_token_ids = batch_next_token_ids.to(torch.int32)
+
+        # Attach logprobs to logits_output (in-place modification)
+        if return_logprob:
+            if any(x > 0 for x in top_logprobs_nums):
+                (
+                    logits_output.next_token_top_logprobs_val,
+                    logits_output.next_token_top_logprobs_idx,
+                ) = get_top_logprobs(logprobs, top_logprobs_nums)
+
+            logits_output.next_token_logprobs = logprobs[
+                torch.arange(len(batch_next_token_ids), device=sampling_info.device),
+                batch_next_token_ids,
+            ]
+
+        return batch_next_token_ids
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -120,20 +149,27 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     return batch_next_token_ids
 
 
-def
+def top_p_normalize_probs_torch(
     probs: torch.Tensor,
     top_ps: torch.Tensor,
 ):
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # See also top_k_top_p_min_p_sampling_from_probs_torch
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+
+
+def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
+    max_k = max(top_logprobs_nums)
+    ret = logprobs.topk(max_k, dim=1)
+    values = ret.values.tolist()
+    indices = ret.indices.tolist()
+
+    output_top_logprobs_val = []
+    output_top_logprobs_idx = []
+    for i, k in enumerate(top_logprobs_nums):
+        output_top_logprobs_val.append(values[i][:k])
+        output_top_logprobs_idx.append(indices[i][:k])
+    return output_top_logprobs_val, output_top_logprobs_idx
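Note: a hedged toy example of what the new top_p_normalize_probs_torch helper computes (values chosen for illustration; the import and call are commented out so the snippet stands alone):

import torch

probs = torch.tensor([[0.4, 0.3, 0.2, 0.1]])
top_ps = torch.tensor([0.6])

# A token survives while the cumulative probability of the tokens ranked above it
# stays within top_p; the rest are zeroed and the survivors are renormalized.
# Here tokens 0 and 1 survive, giving roughly [[0.571, 0.429, 0.0, 0.0]].
# from sglang.srt.layers.sampler import top_p_normalize_probs_torch
# out = top_p_normalize_probs_torch(probs, top_ps)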
sglang/srt/layers/torchao_utils.py
CHANGED
@@ -11,6 +11,22 @@ import torch
 logger = logging.getLogger(__name__)
 
 
+def get_gemlite_cache_path() -> str:
+    return f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
+
+
+def save_gemlite_cache(print_error: bool = False) -> bool:
+    try:
+        from gemlite.core import GemLiteLinearTriton
+
+        GemLiteLinearTriton.cache_config(get_gemlite_cache_path())
+    except Exception:
+        if print_error:
+            logger.error("Failed to save the GemLite cache.")
+        return False
+    return True
+
+
 def apply_torchao_config_to_model(
     model: torch.nn.Module, torchao_config: str, filter_fn=None
 ):
@@ -74,9 +90,7 @@
         )
 
         # try to load gemlite kernel config
-        GemLiteLinearTriton.load_config(
-            f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
-        )
+        GemLiteLinearTriton.load_config(get_gemlite_cache_path())
 
     elif "fp8wo" in torchao_config:
         # this requires newer hardware
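Note: a short usage sketch for the new cache helpers; it assumes the gemlite package is importable, otherwise save_gemlite_cache simply returns False:

from sglang.srt.layers.torchao_utils import get_gemlite_cache_path, save_gemlite_cache

print(get_gemlite_cache_path())             # per-user JSON path under /tmp
ok = save_gemlite_cache(print_error=True)   # True if GemLite's tuned configs were written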
sglang/srt/managers/io_struct.py
CHANGED