sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -23,15 +23,17 @@ import triton.language as tl
 from torch import nn
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
-    get_attention_dp_rank,
     get_attention_dp_size,
+    get_attention_tp_size,
+    get_local_attention_dp_rank,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -45,6 +47,18 @@ from sglang.srt.utils import dump_to_file
 logger = logging.getLogger(__name__)
 
 
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.utils import dump_to_file
+
+logger = logging.getLogger(__name__)
+
+
 @dataclasses.dataclass
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -169,7 +183,7 @@ class LogitsMetadata:
             return
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank =
+        dp_rank = get_local_attention_dp_rank()
         if dp_rank == 0:
             dp_local_start_pos = torch.zeros_like(
                 self.global_num_tokens_for_logprob_gpu[0]
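For orientation, here is a small self-contained sketch (not sglang code) of the offset arithmetic this hunk touches: the cumulative sum of per-rank token counts gives each attention-DP rank its start position in the gathered logprob tensor; rank 0 starts at zero, and the non-zero branch (not shown in the hunk) presumably reads the previous rank's cumulative count.

```python
import torch

# Illustrative per-DP-rank token counts, standing in for global_num_tokens_for_logprob_gpu.
global_num_tokens = torch.tensor([5, 3, 7, 2])
cumtokens = torch.cumsum(global_num_tokens, dim=0)   # tensor([ 5,  8, 15, 17])

for dp_rank in range(global_num_tokens.numel()):
    # Rank 0 starts at 0; every other rank starts where the previous ranks end.
    start = cumtokens[dp_rank - 1] if dp_rank > 0 else torch.zeros_like(cumtokens[0])
    length = global_num_tokens[dp_rank]
    print(f"dp_rank={dp_rank} owns rows [{int(start)}, {int(start + length)})")
# dp_rank=0 owns rows [0, 5)
# dp_rank=1 owns rows [5, 8)
# dp_rank=2 owns rows [8, 15)
# dp_rank=3 owns rows [15, 17)
```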
@@ -198,12 +212,20 @@ class LogitsProcessor(nn.Module):
         super().__init__()
         self.config = config
         self.logit_scale = logit_scale
-        self.
-
-
-
-
-
+        self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        if self.use_attn_tp_group:
+            self.attn_tp_size = get_attention_tp_size()
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and self.attn_tp_size > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = False
+        else:
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = (
+                self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
+            )
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
@@ -315,7 +337,8 @@ class LogitsProcessor(nn.Module):
 
         if self.debug_tensor_dump_output_folder:
             assert (
-                not self.do_tensor_parallel_all_gather
+                not self.do_tensor_parallel_all_gather
+                or get_local_attention_dp_size() == 1
             ), "dp attention + sharded lm_head doesn't support full logits"
             full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
             dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
@@ -442,7 +465,19 @@
             logits.mul_(self.logit_scale)
 
         if self.do_tensor_parallel_all_gather:
-
+            if self.use_attn_tp_group:
+                global_logits = torch.empty(
+                    (self.config.vocab_size, logits.shape[0]),
+                    device=logits.device,
+                    dtype=logits.dtype,
+                )
+                global_logits = global_logits.T
+                attn_tp_all_gather(
+                    list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
+                )
+                logits = global_logits
+            else:
+                logits = tensor_model_parallel_all_gather(logits)
 
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits, global_logits = (
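The transposed-empty-buffer trick in the hunk above is easy to miss: `global_logits` is allocated as `(vocab_size, num_tokens)` and then transposed, so each `tensor_split` chunk along the vocab dimension views one contiguous slab of storage and per-rank shards can be written straight into it. A minimal single-process sketch (plain PyTorch, no distributed collectives; the per-rank shards are simulated locally):

```python
import torch

m, vocab, attn_tp_size = 4, 12, 3
shard = vocab // attn_tp_size

# Simulated per-rank logits shards, each (m, vocab // attn_tp_size).
local_shards = [torch.randn(m, shard) for _ in range(attn_tp_size)]

# Allocate (vocab, m) and transpose: the result is (m, vocab) but vocab-major in memory.
global_logits = torch.empty((vocab, m)).T
chunks = list(global_logits.tensor_split(attn_tp_size, dim=-1))
assert all(c.T.is_contiguous() for c in chunks)  # each chunk views one contiguous slab

# Stand-in for attn_tp_all_gather(chunks, local_logits): rank i's shard fills chunk i.
for dst, src in zip(chunks, local_shards):
    dst.copy_(src)

assert torch.equal(global_logits, torch.cat(local_shards, dim=-1))
```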
sglang/srt/layers/moe/cutlass_moe.py (new file)

@@ -0,0 +1,207 @@
+"""Cutlass MoE kernel."""
+
+import functools
+import json
+import logging
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    import sgl_kernel
+    from sgl_kernel import (
+        fp8_blockwise_scaled_grouped_mm,
+        prepare_moe_input,
+        silu_and_mul,
+    )
+
+
+def cutlass_fused_experts(
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    a1_strides: torch.Tensor,
+    c1_strides: torch.Tensor,
+    a2_strides: torch.Tensor,
+    c2_strides: torch.Tensor,
+    workspace: torch.Tensor,
+    a_ptrs: torch.Tensor,
+    b_ptrs: torch.Tensor,
+    out_ptrs: torch.Tensor,
+    a_scales_ptrs: torch.Tensor,
+    b_scales_ptrs: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    use_fp8_blockscale: bool = True,
+) -> torch.Tensor:
+    """Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.
+
+    This function implements a Mixture of Experts (MoE) layer with a SwiGLU/SiLU
+    activation, leveraging custom kernels likely derived from CUTLASS principles
+    for grouped matrix multiplication (`fp8_blockwise_scaled_grouped_mm`) and
+    data preparation (`prepare_moe_input`, `silu_and_mul`).
+
+    It handles per-token routing, quantizes input activations to FP8 with
+    per-token scales, performs the expert computations using FP8 GEMMs with
+    pre-quantized FP8 weights (per-block scales), applies the SiLU activation,
+    and combines the results weighted by the router scores.
+
+    Args:
+        a (torch.Tensor): Input activations. Shape: `(m, k)`, where `m` is the total
+            number of tokens and `k` is the hidden size. Expected dtype: `torch.half`
+            or `torch.bfloat16`.
+        w1_q (torch.Tensor): Pre-quantized FP8 weight tensor for the first GEMM
+            (up-projection part of SwiGLU). Expected shape: `(E, k, n*2)`, where
+            `E` is the number of experts, `k` is the hidden size, and `n*2` is the
+            intermediate size (`I`). Expected dtype: `torch.float8_e4m3fn`.
+            Note: This shape implies weights are stored as (num_experts, hidden_size, intermediate_size).
+        w2_q (torch.Tensor): Pre-quantized FP8 weight tensor for the second GEMM
+            (down-projection). Expected shape: `(E, n, k)`, where `n` is half the
+            intermediate size (`I // 2`). Expected dtype: `torch.float8_e4m3fn`.
+            Note: This shape implies weights are stored as (num_experts, intermediate_size // 2, hidden_size).
+        w1_scale (torch.Tensor): Scales corresponding to `w1_q` (per-block scales).
+            Shape: `(E, num_blocks_n, num_blocks_k)`. Dtype: `torch.float32`.
+        w2_scale (torch.Tensor): Scales corresponding to `w2_q` (per-block scales).
+            Shape: `(E, num_blocks_k, num_blocks_n)`. Dtype: `torch.float32`.
+        topk_weights (torch.Tensor): Router weights for the selected top-k experts
+            for each token. Shape: `(m, topk)`. Dtype should ideally match `a`.
+        topk_ids (torch.Tensor): Indices of the selected top-k experts for each token.
+            Shape: `(m, topk)`. Dtype: `torch.int32`.
+        a1_strides (torch.Tensor): Stride information for the first GEMM's 'a' input.
+            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+            Note: Its exact usage within `fp8_blockwise_scaled_grouped_mm` needs clarification
+            as it's passed as both a_stride and b_stride in the first call.
+        c1_strides (torch.Tensor): Stride information for the first GEMM's 'c' output.
+            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+        a2_strides (torch.Tensor): Stride information for the second GEMM's 'a' input.
+            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+            Note: Its exact usage within `fp8_blockwise_scaled_grouped_mm` needs clarification
+            as it's passed as both a_stride and b_stride in the second call.
+        c2_strides (torch.Tensor): Stride information for the second GEMM's 'c' output.
+            Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
+        workspace (torch.Tensor): Reusable workspace for the underlying kernel.
+        a_ptrs (torch.Tensor): Pointers container for calculating offsets of the input activations for each expert.
+        b_ptrs (torch.Tensor): Pointers container for calculating offsets of the input weights for each expert.
+        out_ptrs (torch.Tensor): Pointers container for calculating offsets of the output activations for each expert.
+        a_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
+        b_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
+        use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with
+            block scaling. Currently, only `True` is supported. Defaults to `True`.
+
+    Returns:
+        torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`.
+
+    Raises:
+        AssertionError: If input shapes, dtypes, or flags are inconsistent or unsupported.
+        NotImplementedError: If CUDA is not available or `sgl_kernel` is not properly installed.
+    """
+    assert use_fp8_blockscale, "Only support fp8 blockscale for now"
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert w1_q.dtype == torch.float8_e4m3fn
+    assert w2_q.dtype == torch.float8_e4m3fn
+    assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
+    assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+    assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
+
+    if is_cuda:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
+
+    out_dtype = a.dtype
+    num_experts = w1_q.size(0)
+    m = a.size(0)
+    k = w1_q.size(1)
+    n = w2_q.size(1)
+
+    topk = topk_ids.size(1)
+
+    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+    device = a_q.device
+
+    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+
+    prepare_moe_input(
+        topk_ids,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        a_map,
+        c_map,
+        num_experts,
+        n,
+        k,
+    )
+
+    rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
+    rep_a1_scales = a1_scale[a_map]
+
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
+    c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
+
+    a_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
+    w_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
+
+    fp8_blockwise_scaled_grouped_mm(
+        c1,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        rep_a_q,
+        w1_q,
+        rep_a1_scales,
+        w1_scale,
+        a1_strides,
+        a1_strides,
+        c1_strides,
+        a_sf_layout,
+        w_sf_layout,
+        problem_sizes1,
+        expert_offsets[:-1],
+        workspace,
+    )
+
+    intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
+    silu_and_mul(c1, intermediate)
+
+    intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+
+    fp8_blockwise_scaled_grouped_mm(
+        c2,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        intemediate_q,
+        w2_q,
+        a2_scale,
+        w2_scale,
+        a2_strides,
+        a2_strides,
+        c2_strides,
+        a_sf_layout,
+        w_sf_layout,
+        problem_sizes2,
+        expert_offsets[:-1],
+        workspace,
+    )
+    return (
+        c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
+    ).sum(dim=1)
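To make the data flow in `cutlass_fused_experts` concrete, here is a small, unquantized PyTorch reference (illustrative only, not sglang code) of the computation the docstring describes: per-token top-k routing, a first GEMM against `w1` laid out as `(E, k, 2n)`, a SiLU-and-mul over the two halves (the usual gate/up convention), a second GEMM against `w2` laid out as `(E, n, k)`, and a router-weighted combine. FP8 quantization, block scales, and the grouped CUTLASS kernels are deliberately omitted.

```python
import torch
import torch.nn.functional as F

def reference_moe(a, w1, w2, topk_weights, topk_ids):
    m, k = a.shape
    out = torch.zeros(m, k, dtype=a.dtype)
    for t in range(m):
        for j in range(topk_ids.shape[1]):
            e = int(topk_ids[t, j])
            gate_up = a[t] @ w1[e]                      # (2n,) first GEMM
            n = gate_up.shape[0] // 2
            h = F.silu(gate_up[:n]) * gate_up[n:]       # silu_and_mul
            out[t] += topk_weights[t, j] * (h @ w2[e])  # (k,) second GEMM + combine
    return out

E, m, k, n, topk = 4, 8, 16, 32, 2
a = torch.randn(m, k)
w1 = torch.randn(E, k, 2 * n)   # (E, k, n*2), as in the w1_q docstring
w2 = torch.randn(E, n, k)       # (E, n, k), as in the w2_q docstring
topk_weights = torch.softmax(torch.randn(m, topk), dim=-1)
topk_ids = torch.randint(0, E, (m, topk), dtype=torch.int32)
print(reference_moe(a, w1, w2, topk_weights, topk_ids).shape)  # torch.Size([8, 16])
```

In the FP8 path above, tokens are additionally permuted per expert (`a_map`/`c_map` from `prepare_moe_input`) so each expert's rows are contiguous for the grouped GEMM, and activations are re-quantized with 128-wide per-token group scales before each GEMM.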
sglang/srt/layers/moe/ep_moe/kernels.py

@@ -3,10 +3,9 @@ from typing import List, Optional
 
 import torch
 import triton
-import triton.language as tl
 
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import dispose_tensor, is_cuda
 
 logger = logging.getLogger(__name__)
 
@@ -116,7 +115,7 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
     seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)
 
-    # Find
+    # Find offset
     expert_ids = torch.arange(
         num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
     )
@@ -653,12 +652,15 @@ def grouped_gemm_triton(
     scale_a: torch.Tensor = None,
     scale_b: torch.Tensor = None,
     block_shape: Optional[List[int]] = None,
+    c_dtype=None,
 ):
     assert weight_column_major == True  # TODO: more
     if use_fp8_w8a8 and block_shape is None:
         assert scale_a is not None and scale_b is not None
 
     if block_shape is not None:
+        a_original = a
+
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
         a, scale_a = per_token_group_quant_fp8(a, block_k)
@@ -667,6 +669,8 @@
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
         assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
 
+        dispose_tensor(a_original)
+
     # TODO: adjust config or tune kernel
     # Reduce block size to prevent L40 shared memory overflow.
     config = {
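`dispose_tensor` is imported from `sglang.srt.utils` in the first hunk of this file; the intent of saving `a_original = a` and then calling `dispose_tensor(a_original)` is to release the unquantized activation buffer as soon as the FP8 copy exists, since `a` has been rebound to the quantized tensor by then. A plausible sketch of such a helper (an assumption about its behavior, not necessarily sglang's exact implementation):

```python
import torch

def dispose_tensor(x: torch.Tensor) -> None:
    # Rebind the tensor to a zero-element buffer so the old storage can be freed
    # immediately, even while other Python references to the tensor object remain.
    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))

a = torch.randn(1024, 1024)          # stand-in for the high-precision activations
a_original = a
a = a.to(torch.float16)              # stand-in for per_token_group_quant_fp8(a, block_k)
dispose_tensor(a_original)           # the high-precision buffer is released here
print(a_original.numel(), a.shape)   # 0 torch.Size([1024, 1024])
```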
@@ -680,6 +684,10 @@
         m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
     )
 
+    if c is None:
+        assert c_dtype is not None
+        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
+
     grid = lambda META: (
         triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
         triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
@@ -783,19 +791,23 @@ def _fwd_kernel_ep_scatter_2(
     offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
     mask_s = offset_in_s < SCALE_HIDDEN_SIZE
 
-    for
+    for token_id_int32 in range(start_token_id, total_token_num, grid_num):
+        token_id = token_id_int32.to(tl.int64)
         to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
         to_copy_s = tl.load(
             recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
         )
 
-        for
+        for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
+            topk_index = topk_idx_int32.to(tl.int64)
             expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
             if expert_id >= 0:
-
+                dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
+                dest_token_index = dest_token_index_int32.to(tl.int64)
+
                 tl.store(
                     output_index + token_id * output_index_stride0 + topk_index,
-
+                    dest_token_index_int32,
                 )
                 output_tensor_ptr = (
                     output_tensor + dest_token_index * output_tensor_stride0
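The recurring pattern in this kernel (and in `_fwd_kernel_ep_gather` below) is to keep loop counters in int32 but promote them to int64 before they enter pointer arithmetic such as `token_id * recv_x_stride0`, so the element offset cannot overflow 32 bits for large batches or wide hidden sizes. A minimal stand-alone Triton kernel showing just that promotion (illustrative, not the sglang kernel; assumes a CUDA device):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def copy_rows_kernel(src_ptr, dst_ptr, stride0, hidden: tl.constexpr):
    row_int32 = tl.program_id(0)
    row = row_int32.to(tl.int64)                      # promote before computing offsets
    offs = tl.arange(0, hidden)
    vals = tl.load(src_ptr + row * stride0 + offs)    # offsets computed in int64
    tl.store(dst_ptr + row * stride0 + offs, vals)

src = torch.randn(64, 128, device="cuda")
dst = torch.empty_like(src)
copy_rows_kernel[(src.shape[0],)](src, dst, src.stride(0), hidden=src.shape[1])
assert torch.equal(src, dst)
```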
@@ -894,21 +906,31 @@ def _fwd_kernel_ep_gather(
     topk_num: tl.constexpr,
     BLOCK_D: tl.constexpr,
 ):
-
-
+    cur_block_int32 = tl.program_id(0)
+    cur_block = cur_block_int32.to(tl.int64)
+
+    start_cur_token_int32 = tl.program_id(1)
+
     grid_num = tl.num_programs(1)
 
-    for
+    for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
+        cur_token = cur_token_int32.to(tl.int64)
+
         off_d = tl.arange(0, BLOCK_D)
         accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
-
+
+        for topk_index_int32 in range(0, topk_num):
+            topk_index = topk_index_int32.to(tl.int64)
+
             expert_id = tl.load(
                 recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
             )
             if expert_id >= 0:
-
+                source_token_index_int32 = tl.load(
                     input_index + cur_token * input_index_stride0 + topk_index
                 )
+                source_token_index = source_token_index_int32.to(tl.int64)
+
                 acc_weight = tl.load(
                     recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
                 )