sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -26,12 +26,20 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
+from sglang.srt.layers.dp_attention import (
+    dp_gather,
+    dp_scatter,
+    get_attention_dp_rank,
+    get_attention_dp_size,
+)
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
 )
+from sglang.srt.utils import dump_to_file
 
 logger = logging.getLogger(__name__)
 
@@ -51,13 +59,19 @@ class LogitsProcessorOutput:
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None
     next_token_top_logprobs_idx: Optional[List] = None
+    # The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
+    next_token_token_ids_logprobs_val: Optional[List] = None
+    next_token_token_ids_logprobs_idx: Optional[List] = None
 
     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logprobs of input tokens. shape: [#token]
-    input_token_logprobs: torch.Tensor = None
+    input_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
+    # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
+    input_token_ids_logprobs_val: Optional[List] = None
+    input_token_ids_logprobs_idx: Optional[List] = None
 
 
 @dataclasses.dataclass
@@ -67,43 +81,114 @@ class LogitsMetadata:
 
     extend_return_logprob: bool = False
     extend_return_top_logprob: bool = False
+    extend_token_ids_logprob: bool = False
     extend_seq_lens: Optional[torch.Tensor] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
     top_logprobs_nums: Optional[List[int]] = None
+    extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
+    token_ids_logprobs: Optional[List[List[int]]] = None
+
+    # logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    temperature: torch.Tensor = None
+    top_p_normalized_logprobs: bool = False
+    top_p: torch.Tensor = None
+
+    # DP attention metadata. Not needed when DP attention is not used.
+    # Number of tokens in the request.
+    global_num_tokens_gpu: Optional[torch.Tensor] = None
+    # The start position of local hidden states.
+    dp_local_start_pos: Optional[torch.Tensor] = None
+    dp_local_num_tokens: Optional[torch.Tensor] = None
+    gathered_buffer: Optional[torch.Tensor] = None
+    # Buffer to gather logits from all ranks.
+    forward_batch_gathered_buffer: Optional[torch.Tensor] = None
+    # Number of tokens to sample per DP rank
+    global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
+    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+
+    # for padding
+    padded_static_len: int = -1
 
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-        if
-
+        if (
+            forward_batch.forward_mode.is_extend()
+            and forward_batch.return_logprob
+            and not forward_batch.forward_mode.is_target_verify()
+        ):
             extend_return_top_logprob = any(
                 x > 0 for x in forward_batch.top_logprobs_nums
             )
-
-
-
-
-
-
-
+            extend_token_ids_logprob = any(
+                x is not None for x in forward_batch.token_ids_logprobs
+            )
+            extend_return_logprob = False
+            extend_logprob_pruned_lens_cpu = []
+            for extend_len, start_len in zip(
+                forward_batch.extend_seq_lens_cpu,
+                forward_batch.extend_logprob_start_lens_cpu,
+            ):
+                if extend_len - start_len > 0:
+                    extend_return_logprob = True
+                extend_logprob_pruned_lens_cpu.append(extend_len - start_len)
         else:
             extend_return_logprob = extend_return_top_logprob = (
-
-            ) = False
+                extend_token_ids_logprob
+            ) = extend_logprob_pruned_lens_cpu = False
 
         return cls(
             forward_mode=forward_batch.forward_mode,
             capture_hidden_mode=forward_batch.capture_hidden_mode,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
+            extend_token_ids_logprob=extend_token_ids_logprob,
             extend_seq_lens=forward_batch.extend_seq_lens,
             extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
             extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
             top_logprobs_nums=forward_batch.top_logprobs_nums,
+            token_ids_logprobs=forward_batch.token_ids_logprobs,
+            extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
+            padded_static_len=forward_batch.padded_static_len,
+            global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
+            dp_local_start_pos=forward_batch.dp_local_start_pos,
+            dp_local_num_tokens=forward_batch.dp_local_num_tokens,
+            gathered_buffer=forward_batch.gathered_buffer,
+            forward_batch_gathered_buffer=forward_batch.gathered_buffer,
+            global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
+            global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
+        )
+
+    def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
+        if self.global_num_tokens_for_logprob_cpu is None:
+            # we are capturing cuda graph
+            return
+
+        cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
+        dp_rank = get_attention_dp_rank()
+        if dp_rank == 0:
+            dp_local_start_pos = torch.zeros_like(
+                self.global_num_tokens_for_logprob_gpu[0]
+            )
+        else:
+            dp_local_start_pos = cumtokens[dp_rank - 1]
+        dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
+        gathered_buffer = torch.zeros(
+            (
                sum(self.global_num_tokens_for_logprob_cpu),
+                hidden_states.shape[1],
+            ),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
         )
 
+        self.dp_local_start_pos = dp_local_start_pos
+        self.dp_local_num_tokens = dp_local_num_tokens
+        self.gathered_buffer = gathered_buffer
+
 
 class LogitsProcessor(nn.Module):
     def __init__(
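Context for the DP-attention additions above: with attention data parallelism, each rank only holds hidden states for its own tokens, and the new `compute_dp_attention_metadata` derives every rank's slice of the gathered buffer from a cumulative sum of the per-rank token counts. A minimal sketch of that offset arithmetic, with made-up token counts rather than anything read from sglang:

```python
import torch

# Hypothetical per-rank token counts for 4 attention-DP ranks.
global_num_tokens = torch.tensor([3, 5, 2, 4])
cumtokens = torch.cumsum(global_num_tokens, dim=0)  # [3, 8, 10, 14]

def local_slice(dp_rank: int):
    # Same arithmetic as compute_dp_attention_metadata: rank 0 starts at 0,
    # rank r starts at the cumulative count of the ranks before it.
    start = 0 if dp_rank == 0 else int(cumtokens[dp_rank - 1])
    return start, int(global_num_tokens[dp_rank])

print([local_slice(r) for r in range(4)])  # [(0, 3), (3, 5), (8, 2), (10, 4)]
```

The gathered buffer is then allocated with `sum(global_num_tokens)` rows so that `dp_gather` and `dp_scatter` can address each rank's slice by its start offset.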
@@ -115,6 +200,9 @@ class LogitsProcessor(nn.Module):
         self.do_tensor_parallel_all_gather = (
             not skip_all_gather and get_tensor_model_parallel_world_size() > 1
         )
+        self.do_tensor_parallel_all_gather_dp_attn = (
+            self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
+        )
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
@@ -124,6 +212,10 @@ class LogitsProcessor(nn.Module):
         ):
             self.final_logit_softcapping = None
 
+        self.debug_tensor_dump_output_folder = global_server_args_dict.get(
+            "debug_tensor_dump_output_folder", None
+        )
+
     def forward(
         self,
         input_ids,
@@ -141,30 +233,74 @@ class LogitsProcessor(nn.Module):
         ):
             pruned_states = hidden_states
             sample_indices = None
+            input_logprob_indices = None
         elif (
             logits_metadata.forward_mode.is_extend()
             and not logits_metadata.extend_return_logprob
         ):
             # Prefill without input logprobs.
-
+            if logits_metadata.padded_static_len < 0:
+                last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
+            else:
+                # If padding_static length is 5 and extended_seq_lens is [2, 3],
+                # then our batch looks like [t00, t01, p, p, p, t10, t11, t12, p, p]
+                # and this retrieves t01 and t12, which are the valid last tokens
+                idx = torch.arange(
+                    len(logits_metadata.extend_seq_lens),
+                    device=logits_metadata.extend_seq_lens.device,
+                )
+                last_index = (
+                    idx * logits_metadata.padded_static_len
+                    + logits_metadata.extend_seq_lens
+                    - 1
+                )
             pruned_states = hidden_states[last_index]
             sample_indices = None
+            input_logprob_indices = None
         else:
-            #
+            # Input logprobs are required.
+            # Find 3 different indices.
+            # 1. pruned_states: hidden states that we want logprobs from.
+            # 2. sample_indices: Indices that have sampled tokens.
+            # 3. input_logprob_indices: Indices that have input logprob tokens.
             sample_index_pt = -1
             sample_indices = []
-
-
+            input_logprob_indices_pt = 0
+            input_logprob_indices = []
+            pt, pruned_states = 0, []
+            for extend_logprob_start_len, extend_len in zip(
                 logits_metadata.extend_logprob_start_lens_cpu,
                 logits_metadata.extend_seq_lens_cpu,
             ):
+                # It can happen in chunked prefill. We still need to sample 1 token,
+                # But we don't want to include it in input logprob.
+                if extend_len == extend_logprob_start_len:
+                    start_len = extend_logprob_start_len - 1
+                else:
+                    start_len = extend_logprob_start_len
+
+                # We always need at least 1 token to sample because that's required
+                # by a caller.
+                assert extend_len > start_len
                 pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                pt += extend_len
                 sample_index_pt += extend_len - start_len
                 sample_indices.append(sample_index_pt)
-
-
+                input_logprob_indices.extend(
+                    [
+                        input_logprob_indices_pt + i
+                        for i in range(extend_len - extend_logprob_start_len)
+                    ]
+                )
+                input_logprob_indices_pt += extend_len - start_len
 
             pruned_states = torch.cat(pruned_states)
+            sample_indices = torch.tensor(
+                sample_indices, device=pruned_states.device, dtype=torch.int64
+            )
+            input_logprob_indices = torch.tensor(
+                input_logprob_indices, device=pruned_states.device, dtype=torch.int64
+            )
 
         # Compute logits for both input and sampled tokens.
         logits = self._get_logits(pruned_states, lm_head, logits_metadata)
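The hunk above builds three index sets when input logprobs are requested. A toy walk-through of that bookkeeping in plain Python, with invented lengths, may help:

```python
# Two made-up requests:
#   extend_seq_lens_cpu           = [4, 3]
#   extend_logprob_start_lens_cpu = [1, 3]
# Request 0 wants input logprobs from position 1 on; request 1 skips all of its
# tokens (start == len, as happens with chunked prefill), so only its last
# token is kept for sampling.
extend_seq_lens_cpu = [4, 3]
extend_logprob_start_lens_cpu = [1, 3]

sample_index_pt = -1
sample_indices, input_logprob_indices = [], []
input_logprob_indices_pt = 0
pt, kept_slices = 0, []

for extend_logprob_start_len, extend_len in zip(
    extend_logprob_start_lens_cpu, extend_seq_lens_cpu
):
    # Keep at least the last token so there is something to sample.
    if extend_len == extend_logprob_start_len:
        start_len = extend_logprob_start_len - 1
    else:
        start_len = extend_logprob_start_len

    kept_slices.append((pt + start_len, pt + extend_len))  # rows taken from hidden_states
    pt += extend_len
    sample_index_pt += extend_len - start_len
    sample_indices.append(sample_index_pt)                 # last kept row per request
    input_logprob_indices.extend(
        input_logprob_indices_pt + i
        for i in range(extend_len - extend_logprob_start_len)
    )
    input_logprob_indices_pt += extend_len - start_len

print(kept_slices)            # [(1, 4), (6, 7)]
print(sample_indices)         # [2, 3]
print(input_logprob_indices)  # [0, 1, 2]
```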
@@ -172,28 +308,51 @@ class LogitsProcessor(nn.Module):
             logits[sample_indices] if sample_indices is not None else logits
         )
 
-        if
-
-
-
+        if self.debug_tensor_dump_output_folder:
+            assert (
+                not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
+            ), "dp attention + sharded lm_head doesn't support full logits"
+            full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
+            dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
+
+        hidden_states_to_store: Optional[torch.Tensor] = None
+        if logits_metadata.capture_hidden_mode.need_capture():
+            if logits_metadata.capture_hidden_mode.is_full():
+                hidden_states_to_store = hidden_states
+            elif logits_metadata.capture_hidden_mode.is_last():
+                # Get the last token hidden states. If sample_indices is None,
+                # pruned states only contain the last tokens already.
+                hidden_states_to_store = (
+                    pruned_states[sample_indices] if sample_indices else pruned_states
+                )
+            else:
+                assert False, "Should never reach"
+
+        if not logits_metadata.extend_return_logprob:
             # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
                 next_token_logits=sampled_logits,
-                hidden_states=(
-                    hidden_states
-                    if logits_metadata.capture_hidden_mode.is_full()
-                    else (
-                        pruned_states
-                        if logits_metadata.capture_hidden_mode.is_last()
-                        else None
-                    )
-                ),
+                hidden_states=hidden_states_to_store,
             )
         else:
-            input_logprobs = logits
+            input_logprobs = logits[input_logprob_indices]
             del hidden_states, logits
 
             # Normalize the logprob w/o temperature, top-p
+            pruned_lens = torch.tensor(
+                logits_metadata.extend_logprob_pruned_lens_cpu,
+                device=input_logprobs.device,
+            )
+            if logits_metadata.temp_scaled_logprobs:
+                logits_metadata.temperature = torch.repeat_interleave(
+                    logits_metadata.temperature.view(-1),
+                    pruned_lens,
+                ).view(-1, 1)
+            if logits_metadata.top_p_normalized_logprobs:
+                logits_metadata.top_p = torch.repeat_interleave(
+                    logits_metadata.top_p,
+                    pruned_lens,
+                )
             input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                 input_logprobs, logits_metadata
             )
@@ -207,14 +366,18 @@ class LogitsProcessor(nn.Module):
             else:
                 input_top_logprobs_val = input_top_logprobs_idx = None
 
+            # Get the logprob of given token id
+            if logits_metadata.extend_token_ids_logprob:
+                (
+                    input_token_ids_logprobs_val,
+                    input_token_ids_logprobs_idx,
+                ) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
+            else:
+                input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None
+
             input_token_logprobs = input_logprobs[
                 torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
-                torch.cat(
-                    [
-                        torch.cat(pruned_input_ids)[1:],
-                        torch.tensor([0], device=input_logprobs.device),
-                    ]
-                ),
+                logits_metadata.extend_input_logprob_token_ids_gpu,
             ]
 
             return LogitsProcessorOutput(
@@ -222,6 +385,9 @@ class LogitsProcessor(nn.Module):
                 input_token_logprobs=input_token_logprobs,
                 input_top_logprobs_val=input_top_logprobs_val,
                 input_top_logprobs_idx=input_top_logprobs_idx,
+                hidden_states=hidden_states_to_store,
+                input_token_ids_logprobs_val=input_token_ids_logprobs_val,
+                input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
             )
 
     def _get_logits(
@@ -231,13 +397,27 @@ class LogitsProcessor(nn.Module):
         logits_metadata: LogitsMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        """Get logits from hidden_states.
+        """Get logits from hidden_states.
+
+        If sampled_logits_only is True, it means hidden_states only contain the
+        last position (e.g., extend without input logprobs). The caller should
+        guarantee the given hidden_states follow this constraint.
+        """
+        if self.do_tensor_parallel_all_gather_dp_attn:
+            logits_metadata.compute_dp_attention_metadata(hidden_states)
+            hidden_states, local_hidden_states = (
+                logits_metadata.gathered_buffer,
+                hidden_states.clone(),
+            )
+            dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
 
         if hasattr(lm_head, "weight"):
-            logits = torch.matmul(
+            logits = torch.matmul(
+                hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
+            )
         else:
             # GGUF models
-            logits = lm_head.
+            logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
 
         if self.logit_scale is not None:
             logits.mul_(self.logit_scale)
@@ -245,6 +425,17 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather:
             logits = tensor_model_parallel_all_gather(logits)
 
+        if self.do_tensor_parallel_all_gather_dp_attn:
+            logits, global_logits = (
+                torch.empty(
+                    (local_hidden_states.shape[0], logits.shape[1]),
+                    device=logits.device,
+                    dtype=logits.dtype,
+                ),
+                logits,
+            )
+            dp_scatter(logits, global_logits, logits_metadata)
+
         logits = logits[:, : self.config.vocab_size].float()
 
         if self.final_logit_softcapping:
@@ -272,21 +463,66 @@ class LogitsProcessor(nn.Module):
                 continue
 
             input_top_logprobs_val.append(
-                [values[pt + j][:k] for j in range(pruned_len
+                [values[pt + j][:k] for j in range(pruned_len)]
             )
             input_top_logprobs_idx.append(
-                [indices[pt + j][:k] for j in range(pruned_len
+                [indices[pt + j][:k] for j in range(pruned_len)]
             )
             pt += pruned_len
 
         return input_top_logprobs_val, input_top_logprobs_idx
 
+    @staticmethod
+    def get_token_ids_logprobs(
+        all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
+    ):
+        input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
+        pt = 0
+        for token_ids, pruned_len in zip(
+            logits_metadata.token_ids_logprobs,
+            logits_metadata.extend_logprob_pruned_lens_cpu,
+        ):
+            if pruned_len <= 0:
+                input_token_ids_logprobs_val.append([])
+                input_token_ids_logprobs_idx.append([])
+                continue
+
+            input_token_ids_logprobs_val.append(
+                [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
+            )
+            input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
+            pt += pruned_len
+
+        return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
+
     @staticmethod
     def compute_temp_top_p_normalized_logprobs(
         last_logits: torch.Tensor, logits_metadata: LogitsMetadata
     ) -> torch.Tensor:
-
-
+        """
+        compute logprobs for the output token from the given logits.
+
+        Returns:
+            torch.Tensor: logprobs from logits
+        """
+        # Scale logits if temperature scaling is enabled
+        if logits_metadata.temp_scaled_logprobs:
+            last_logits = last_logits / logits_metadata.temperature
+
+        # Normalize logprobs if top_p normalization is enabled
+        # NOTE: only normalize logprobs when top_p is set and not equal to 1.0
+        if (
+            logits_metadata.top_p_normalized_logprobs
+            and (logits_metadata.top_p != 1.0).any()
+        ):
+            from sglang.srt.layers.sampler import top_p_normalize_probs_torch
+
+            probs = torch.softmax(last_logits, dim=-1)
+            del last_logits
+            probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
+            return torch.log(probs)
+        else:
+            return torch.nn.functional.log_softmax(last_logits, dim=-1)
 
 
 @triton.jit
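The new `compute_temp_top_p_normalized_logprobs` divides the logits by the per-token temperature and, when any `top_p` differs from 1.0, renormalizes the softmax distribution over the top-p nucleus before taking the log; otherwise it falls back to a plain log-softmax. A standalone sketch of that post-processing follows; the nucleus renormalization here is my own approximation of what `top_p_normalize_probs_torch` does, not the sglang implementation itself:

```python
import torch

def temp_top_p_logprobs(logits: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
    """Sketch of the logprob post-processing above, not the sglang helper itself.

    Assumes top-p normalization means: keep the smallest nucleus whose
    cumulative probability reaches top_p, zero out the rest, renormalize.
    """
    logits = logits / temperature
    if top_p >= 1.0:
        return torch.log_softmax(logits, dim=-1)
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    keep = (cumulative - sorted_probs) < top_p  # nucleus membership
    sorted_probs = sorted_probs * keep
    probs = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)
    probs = probs / probs.sum(dim=-1, keepdim=True)
    return torch.log(probs)  # tokens outside the nucleus get -inf

print(temp_top_p_logprobs(torch.tensor([[2.0, 1.0, 0.1]]), temperature=0.7, top_p=0.9))
```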
sglang/srt/layers/moe/ep_moe/kernels.py

@@ -1,10 +1,17 @@
 import logging
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+if _is_cuda:
+    from sglang.srt.layers.quantization.fp8_kernel import (
+        sglang_per_token_group_quant_fp8,
+    )
 logger = logging.getLogger(__name__)
 
 
@@ -137,6 +144,73 @@ def silu_and_mul_triton_kernel(
         tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
 
 
+@triton.jit
+def tanh(x):
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
+@triton.jit
+def gelu_and_mul_triton_kernel(
+    gateup_output,
+    down_input,
+    hidden_size,
+    reorder_topk_ids,
+    scales,
+    start_expert_id,
+    end_expert_id,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = gateup_output.dtype.element_ty
+    OutDtype = down_input.dtype.element_ty
+
+    half_hidden_size = hidden_size // 2
+
+    pid = tl.program_id(0)
+    expert_id = tl.load(reorder_topk_ids + pid)
+    if expert_id >= start_expert_id and expert_id <= end_expert_id:
+        gateup_output_ptr = gateup_output + pid * hidden_size
+        gate_output_ptr = gateup_output_ptr
+        up_output_ptr = gateup_output_ptr + half_hidden_size
+        down_input_ptr = down_input + pid * half_hidden_size
+
+        if scales is not None:
+            scale = tl.load(scales + expert_id - start_expert_id)
+            scale = (1 / scale).to(InDtype)
+        else:
+            scale = 1
+
+        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
+            offset = start_offset + tl.arange(0, BLOCK_SIZE)
+            mask = offset < half_hidden_size
+
+            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
+            up_output = tl.load(up_output_ptr + offset, mask=mask)
+
+            # gelu & mul & quantize
+            # https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
+            # sqrt(2/pi)
+            kAlpha = 0.7978845608028654
+            gate_output = (
+                0.5
+                * gate_output
+                * (
+                    1
+                    + tanh(
+                        kAlpha
+                        * (
+                            gate_output
+                            + 0.044715 * gate_output * gate_output * gate_output
+                        )
+                    )
+                )
+            )
+            gate_output = gate_output.to(InDtype)
+
+            gelu_mul_output = gate_output * up_output * scale
+            gelu_mul_output = gelu_mul_output.to(OutDtype)
+            tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
+
+
 @triton.jit
 def post_reorder_triton_kernel(
     down_output_ptr,
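The new `gelu_and_mul_triton_kernel` applies the tanh-approximated GELU to the gate half of `gateup_output`, multiplies by the up half, and optionally rescales by a per-expert dequantization factor. Ignoring the expert-range guard and the scale, the per-element math matches PyTorch's `approximate="tanh"` GELU; a small plain-PyTorch reference check:

```python
import torch
import torch.nn.functional as F

def gelu_and_mul_reference(gateup: torch.Tensor) -> torch.Tensor:
    # gateup holds [gate | up] along the last dim, like the kernel's rows.
    gate, up = gateup.chunk(2, dim=-1)
    k_alpha = 0.7978845608028654  # sqrt(2 / pi), the kernel's kAlpha constant
    gelu_gate = 0.5 * gate * (1 + torch.tanh(k_alpha * (gate + 0.044715 * gate**3)))
    return gelu_gate * up

x = torch.randn(4, 16)
gate, up = x.chunk(2, dim=-1)
# The tanh form above is the "approximate" GELU that PyTorch exposes directly.
print(torch.allclose(gelu_and_mul_reference(x), F.gelu(gate, approximate="tanh") * up, atol=1e-6))
```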
@@ -218,12 +292,19 @@ def grouped_gemm_triton_kernel(
     seg_indptr,
     weight_indices,
     m_num_tiles_indptr,
-    use_fp8_w8a8,
     scale_a,
     scale_b,
+    use_fp8_w8a8: tl.constexpr,
+    group_n: tl.constexpr,
+    group_k: tl.constexpr,
     a_stride_0: tl.constexpr,
     b_stride_0: tl.constexpr,
     b_stride_1: tl.constexpr,
+    as_stride_0: tl.constexpr,
+    as_stride_1: tl.constexpr,
+    bs_stride_0: tl.constexpr,
+    bs_stride_2: tl.constexpr,
+    bs_stride_1: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
@@ -260,6 +341,12 @@ def grouped_gemm_triton_kernel(
         + (n_range_start + offs_bn[:, None]) * b_stride_1
         + offs_k[None, :]
     )
+
+    if group_k > 0 and group_n > 0:
+        a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
+        offs_bsn = (n_range_start + offs_bn) // group_n
+        b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1
+
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         a_tile = tl.load(
@@ -268,14 +355,23 @@ def grouped_gemm_triton_kernel(
         b_tile = tl.load(
             b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
         )
-
+
+        if group_k > 0 and group_n > 0:
+            k_start = k * BLOCK_SIZE_K
+            offs_ks = k_start // group_k
+            a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
+            b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
+            accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
+        else:
+            accumulator = tl.dot(a_tile, b_tile.T, accumulator)
         a_ptr += BLOCK_SIZE_K
         b_ptr += BLOCK_SIZE_K
 
-    if use_fp8_w8a8:
+    if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
         scale_a_value = tl.load(scale_a + expert_id)
         scale_b_value = tl.load(scale_b + expert_id)
         accumulator *= scale_a_value * scale_b_value
+
     c_tile = accumulator.to(c_dtype)
 
     offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
@@ -307,14 +403,29 @@ def grouped_gemm_triton(
     use_fp8_w8a8: bool = False,
     scale_a: torch.Tensor = None,
     scale_b: torch.Tensor = None,
+    block_shape: Optional[List[int]] = None,
 ):
     assert weight_column_major == True  # TODO: more
-    if use_fp8_w8a8:
+    if use_fp8_w8a8 and block_shape is None:
         assert scale_a is not None and scale_b is not None
 
+    if block_shape is not None:
+        assert len(block_shape) == 2
+        block_n, block_k = block_shape[0], block_shape[1]
+        if _is_cuda:
+            a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
+        else:
+            a, scale_a = per_token_group_quant_fp8(a, block_k)
+
+        assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
+        assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
+        assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
+
+    # TODO: adjust config or tune kernel
+    # Reduce block size to prevent L40 shared memory overflow.
     config = {
-        "BLOCK_SIZE_M":
-        "BLOCK_SIZE_N":
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
         "BLOCK_SIZE_K": 128,
     }
 
@@ -338,12 +449,19 @@ def grouped_gemm_triton(
         seg_indptr,
         weight_indices,
         m_num_tiles_indptr,
-        use_fp8_w8a8,
         scale_a,
         scale_b,
+        use_fp8_w8a8,
+        0 if block_shape is None else block_shape[0],
+        0 if block_shape is None else block_shape[1],
         a.stride(0),
         b.stride(0),
         b.stride(1),
+        scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0,
+        scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0,
+        scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
+        scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
+        scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
         **config,
     )
     return c
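With these hunks, `grouped_gemm_triton` accepts a `block_shape`, quantizes the activations per token group of `block_k` elements (via `per_token_group_quant_fp8`, or the CUDA-specific variant), and rescales each K-block partial product inside the kernel with the matching activation and weight scales. A toy, pure-PyTorch check of that block-wise scaling identity, with float tensors standing in for fp8 values and the 128x128 block shape shrunk so it runs quickly:

```python
import torch

# Toy sizes; the fused-MoE configs listed above use block_shape = [128, 128].
block_n, block_k = 4, 4
M, K, N = 3, 8, 8

a_q = torch.randn(M, K)          # stand-in for fp8-quantized activations
b_q = torch.randn(N, K)          # stand-in for fp8-quantized expert weights
a_scale = torch.rand(M, K // block_k)             # one scale per token per K-group
b_scale = torch.rand(N // block_n, K // block_k)  # one scale per (N-block, K-block)

# Reference: dequantize first, then matmul.
a_deq = a_q * a_scale.repeat_interleave(block_k, dim=1)
b_deq = b_q * b_scale.repeat_interleave(block_n, dim=0).repeat_interleave(block_k, dim=1)
ref = a_deq @ b_deq.T

# Kernel-style: accumulate one K-block at a time and rescale each partial product,
# mirroring `accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]`.
acc = torch.zeros(M, N)
for kb in range(K // block_k):
    ks = slice(kb * block_k, (kb + 1) * block_k)
    partial = a_q[:, ks] @ b_q[:, ks].T
    acc += partial * a_scale[:, kb : kb + 1] * b_scale[:, kb].repeat_interleave(block_n)[None, :]

print(torch.allclose(acc, ref, atol=1e-6))  # True: per-block rescaling reproduces full dequantization
```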