sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +49 -7
- sglang/lang/chat_template.py +24 -0
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +5 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/conversation.py +29 -4
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +678 -83
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +5 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8.py +3 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +503 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_int8.py +2 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +63 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +60 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +16 -5
- sglang/srt/models/llama4.py +420 -0
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/mllama4.py +154 -0
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py

@@ -14,7 +14,6 @@ from functools import partial
 from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import torch
-import triton
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -22,7 +21,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import
+from sglang.srt.utils import is_flashinfer_available, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -932,6 +931,7 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        self.page_size = model_runner.page_size
 
         max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
@@ -985,9 +985,9 @@ class FlashInferMultiStepDraftBackend:
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
-
-
-
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
+            next_power_of_2(bs),
         )
 
         assert forward_batch.spec_info is not None
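The three arguments added above are rounded up with `next_power_of_2` from `sglang.srt.utils` (imported in the previous hunk), which removes the need for the module-level `import triton` dropped at the top of the file. A minimal sketch of what such a helper computes, assuming the usual bit-length trick; the packaged implementation may differ:

```python
# Hedged sketch only: the real next_power_of_2 lives in sglang.srt.utils.
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n (assumes n >= 1).
    return 1 << (n - 1).bit_length()


assert next_power_of_2(1) == 1
assert next_power_of_2(5) == 8
assert next_power_of_2(64) == 64
```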
@@ -1018,8 +1018,6 @@ class FlashInferMultiStepDraftBackend:
         )
 
         def call_fn(i, forward_batch):
-            assert forward_batch.spec_info is not None
-            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -71,8 +71,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.device = model_runner.device
         self.skip_prefill = skip_prefill
 
-        global_config.enable_flashinfer_mla = True
-
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
@@ -797,7 +795,7 @@ class FlashInferMLAMultiStepDraftBackend:
             encoder_lens=None,
             forward_mode=ForwardMode.DECODE,
             spec_info=forward_batch.spec_info,
-            seq_lens_cpu=forward_batch.
+            seq_lens_cpu=forward_batch.seq_lens_cpu,
         )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
sglang/srt/layers/attention/flashmla_backend.py

@@ -92,7 +92,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 max_seqlen_pad = triton.cdiv(
-                    forward_batch.
+                    forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
                 )
                 block_kv_indices = torch.full(
                     (bs, max_seqlen_pad),
sglang/srt/layers/moe/ep_moe/kernels.py

@@ -244,6 +244,148 @@ def silu_and_mul_triton_kernel(
     tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
 
 
+# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
+@triton.jit
+def _silu_and_mul_post_quant_kernel(
+    input_ptr,
+    stride_input_0,
+    stride_input_1,
+    stride_input_2,
+    output_ptr,
+    stride_output_0,
+    stride_output_1,
+    stride_output_2,
+    output_scale_ptr,
+    stride_output_scale_0,
+    stride_output_scale_1,
+    stride_output_scale_2,
+    masked_m_ptr,
+    size_n,
+    fp8_max,
+    fp8_min,
+    BLOCK_N: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
+):
+    expert_id = tl.program_id(2)
+    token_id = tl.program_id(1)
+    hidden_dim_block_index = tl.program_id(0)
+
+    block_num_per_expert = tl.num_programs(1)
+
+    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
+
+    stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
+    stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
+    stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
+    stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
+
+    offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
+    input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
+    output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
+    output_scale_offs = (
+        output_scale_ptr
+        + expert_id * stride_output_scale_0
+        + hidden_dim_block_index * stride_output_scale_2
+    )
+
+    for token_index in tl.range(
+        token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE
+    ):
+        gate = tl.load(
+            input_ptr_offs + token_index * stride_input_1,
+            mask=offs_in_d < size_n,
+            other=0.0,
+        ).to(tl.float32)
+        up = tl.load(
+            input_ptr_offs + token_index * stride_input_1 + size_n,
+            mask=offs_in_d < size_n,
+            other=0.0,
+        )
+        gate = gate / (1 + tl.exp(-gate))
+        gate = gate.to(input_ptr.dtype.element_ty)
+        gate_up = up * gate
+        _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
+        output_s = _absmax / fp8_max
+        output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
+            output_ptr.dtype.element_ty
+        )
+        tl.store(
+            output_ptr_offs + token_index * stride_output_1,
+            output_q,
+            mask=offs_in_d < size_n,
+        )
+        tl.store(
+            output_scale_offs + token_index * stride_output_scale_1,
+            output_s,
+        )
+
+
+def silu_and_mul_masked_post_quant_fwd(
+    input: torch.Tensor,
+    output: torch.Tensor,
+    output_scale: torch.Tensor,
+    quant_group_size: int,
+    masked_m: torch.Tensor,
+):
+    """
+    input shape [expert_num, token_num_padded, hidden_dim]
+    output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
+    output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32
+    quant_group_size int,
+    masked_m shape [expert_num],
+    """
+
+    assert input.is_contiguous()
+    assert output.dtype == torch.float8_e4m3fn
+    assert output.is_contiguous()
+    assert len(input.shape) == 3
+    assert input.shape[0] == masked_m.shape[0]
+    assert input.shape[-1] % 2 == 0
+
+    size_n = input.shape[-1] // 2
+    assert size_n % quant_group_size == 0
+
+    expert_num = len(masked_m)
+
+    if expert_num < 4:
+        BLOCK_NUM_PER_EXPERT = 64
+    else:
+        BLOCK_NUM_PER_EXPERT = 32
+
+    BLOCK_N = quant_group_size
+    num_warps = 1
+    NUM_STAGES = 6
+    hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N)
+    assert BLOCK_N % quant_group_size == 0
+
+    grid = (
+        hidden_dim_split_block_num,
+        BLOCK_NUM_PER_EXPERT,
+        expert_num,
+    )
+
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = finfo.max
+    fp8_min = -fp8_max
+
+    _silu_and_mul_post_quant_kernel[grid](
+        input,
+        *input.stride(),
+        output,
+        *output.stride(),
+        output_scale,
+        *output_scale.stride(),
+        masked_m,
+        size_n,
+        fp8_max,
+        fp8_min,
+        BLOCK_N=BLOCK_N,
+        NUM_STAGE=NUM_STAGES,
+        num_warps=num_warps,
+    )
+    return
+
+
 @triton.jit
 def tanh(x):
     return 2 * tl.sigmoid(2 * x) - 1
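For orientation, a hedged usage sketch of the new `silu_and_mul_masked_post_quant_fwd` wrapper, using made-up sizes consistent with its docstring (8 local experts, 128 padded tokens per expert, gate+up width 4096, group size 128). The real call site is the rewritten `forward_deepgemm_masked` shown further below.

```python
import torch

from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd

# Hypothetical sizes for illustration only.
expert_num, token_num_padded, hidden_dim, group_size = 8, 128, 4096, 128

# Gate/up activations for each expert's padded token block (bf16).
gateup_output = torch.randn(
    expert_num, token_num_padded, hidden_dim, device="cuda", dtype=torch.bfloat16
)
# Number of valid (non-padding) tokens routed to each expert.
masked_m = torch.randint(
    0, token_num_padded, (expert_num,), device="cuda", dtype=torch.int32
)

# Caller-allocated outputs, matching the docstring shapes.
down_input = torch.empty(
    (expert_num, token_num_padded, hidden_dim // 2),
    device="cuda",
    dtype=torch.float8_e4m3fn,
)
down_input_scale = torch.empty(
    (expert_num, token_num_padded, hidden_dim // 2 // group_size),
    device="cuda",
    dtype=torch.float32,
)

silu_and_mul_masked_post_quant_fwd(
    gateup_output, down_input, down_input_scale, group_size, masked_m
)
```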
sglang/srt/layers/moe/ep_moe/layer.py

@@ -3,12 +3,16 @@ from typing import Callable, List, Optional, Tuple
 
 import torch
 
-
-
-
-
-
-
+try:
+    from deep_gemm import (
+        get_col_major_tma_aligned_tensor,
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
+    )
+
+    use_deep_gemm = True
+except ImportError:
+    use_deep_gemm = False
+
 from torch.nn import Module
 
 from sglang.srt.custom_op import CustomOp
@@ -22,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
+    silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
@@ -33,7 +38,7 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
 
 _is_cuda = is_cuda()
 
@@ -42,7 +47,6 @@ if _is_cuda:
 else:
     from vllm import _custom_ops as vllm_ops
 
-
 logger = logging.getLogger(__name__)
 
 _is_hip = is_hip()
@@ -809,6 +813,7 @@ class DeepEPMoE(EPMoE):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        deepep_mode: DeepEPMode = DeepEPMode.auto,
     ):
         super().__init__(
             num_experts,
@@ -827,21 +832,38 @@ class DeepEPMoE(EPMoE):
             custom_routing_function,
             activation,
         )
+        self.deepep_mode = deepep_mode
+        if self.deepep_mode.enable_low_latency():
+            assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
+        masked_m: torch.Tensor,
+        expected_m: int,
         forward_mode: ForwardMode,
     ):
-
-        if
+        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        if resolved_deepep_mode == DeepEPMode.normal:
             return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+        elif resolved_deepep_mode == DeepEPMode.low_latency:
+            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
         else:
-
-                hidden_states, reorder_topk_ids, seg_indptr
-            )
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
 
     def forward_normal(
         self,
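The dispatch above leans on `DeepEPMode`, which is imported from `sglang.srt.utils` but not included in this diff. The sketch below is inferred purely from the call sites (`enable_low_latency()`, `resolve(forward_mode)`, and the `auto`/`normal`/`low_latency` members) and is not the actual definition:

```python
from enum import Enum


# Hedged sketch of a DeepEPMode-like enum; the real one in sglang.srt.utils
# may be implemented differently.
class DeepEPMode(Enum):
    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

    def enable_low_latency(self) -> bool:
        # Assumption: auto must also be able to take the low-latency path,
        # so it requires deep_gemm just like low_latency does.
        return self in (DeepEPMode.low_latency, DeepEPMode.auto)

    def resolve(self, forward_mode) -> "DeepEPMode":
        if self != DeepEPMode.auto:
            return self
        # Assumption: auto picks the masked low-latency path for decode
        # batches and the normal path for prefill/extend batches.
        if forward_mode.is_decode_or_idle():
            return DeepEPMode.low_latency
        return DeepEPMode.normal
```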
@@ -958,89 +980,66 @@ class DeepEPMoE(EPMoE):
 
     def forward_deepgemm_masked(
         self,
-
-
-
+        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
+        masked_m: torch.Tensor,
+        expected_m: int,
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
-
-
-
-            torch.max(hidden_states)
-            .repeat(self.num_experts_per_partition)
-            .to(torch.float32)
-        )
-        self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+        assert (
+            hidden_states_fp8[0].size(0) % 4 == 0
+        ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"
 
         # GroupGemm-0
+        num_groups, m, k = hidden_states_fp8[0].size()
+        n = self.w13_weight.size(1)
+        expected_m = min(expected_m, m)
         gateup_output = torch.empty(
-
-
-
-
+            (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
+        )
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
         )
-        if hidden_states.shape[0] > 0:
-            # Transpose earlier so that the testing will not trigger transposing kernels
-            hidden_states = (
-                hidden_states[0],
-                get_col_major_tma_aligned_tensor(hidden_states[1]),
-            )
-        """
-        gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            hidden_states, self.w13_weight, out, masked_m, expected_m
-        )
-        """
 
         # Act
         down_input = torch.empty(
-
-
-
-
-            self.fp8_dtype
-            if (self.use_fp8_w8a8 and not self.use_block_quant)
-            else hidden_states.dtype
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
             ),
+            device=gateup_output.device,
+            dtype=self.fp8_dtype,
         )
-
-
-
-
-            device=hidden_states.device,
-        )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
                 gateup_output.shape[1],
-
-
-
-
-
-
-
-
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=gateup_output.device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+        )
 
         # GroupGemm-1
+        n = self.w2_weight.size(1)
+        down_input_fp8 = (
+            down_input,
+            get_col_major_tma_aligned_tensor(down_input_scale),
+        )
         down_output = torch.empty(
-            down_input.
-
-
-
+            (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
+        )
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
         )
-        if down_input.shape[0] > 0:
-            # Transpose earlier so that the testing will not trigger transposing kernels
-            down_input = (
-                down_input[0],
-                get_col_major_tma_aligned_tensor(down_input[1]),
-            )
-        """
-        down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            down_input, self.w2_weight, out, masked_m, expected_m
-        )
-        """
 
         return down_output
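Taken together, the masked path now passes FP8 activations around as `(tensor, scale)` pairs: GroupGemm-0 produces a bf16 gate/up tensor, the fused kernel halves its last dimension while quantizing to FP8 with one scale per 128 values, and GroupGemm-1 consumes that pair; `expected_m` is clamped to the padded token count `m` before the grouped GEMMs. A quick sanity check of the shape arithmetic, using hypothetical sizes only:

```python
# Hypothetical sizes to trace the shape bookkeeping of the new masked path.
num_groups, m, k, n13, scale_block = 8, 128, 7168, 4096, 128

# The new assert requires the leading dimension of the FP8 activations
# (the number of local expert groups here) to be a multiple of 4.
assert num_groups % 4 == 0

gateup_shape = (num_groups, m, n13)                          # GroupGemm-0 output (bf16)
down_input_shape = (num_groups, m, n13 // 2)                 # fused SiLU*up, FP8
down_scale_shape = (num_groups, m, n13 // 2 // scale_block)  # per-group scales (fp32)

print(gateup_shape, down_input_shape, down_scale_shape)
# (8, 128, 4096) (8, 128, 2048) (8, 128, 16)
```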