sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py (new file)
@@ -0,0 +1,799 @@
+from __future__ import annotations
+
+import os
+from typing import Any, Dict, List, Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8,
+    scaled_fp8_quant,
+    sglang_per_token_group_quant_fp8,
+)
+from sglang.srt.layers.quantization.int8_kernel import (
+    per_token_group_quant_int8,
+    per_token_quant_int8,
+    sglang_per_token_group_quant_int8,
+)
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+)
+
+_is_hip = is_hip()
+_is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if _is_cuda:
+    pass
+elif _is_cpu and _is_cpu_amx_available:
+    pass
+elif _is_hip:
+    pass
+
+padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
+
+
+@triton.jit
+def write_zeros_to_output(
+    c_ptr,
+    stride_cm,
+    stride_cn,
+    pid_n,
+    N,
+    offs_token,
+    token_mask,
+    BLOCK_SIZE_M,
+    BLOCK_SIZE_N,
+    compute_type,
+):
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+@triton.jit
+def fused_moe_kernel_gptq_awq(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    b_scale_ptr,
+    b_zp_ptr,
+    topk_weights_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    num_tokens_post_padded_ptr,
+    # Matrix dimensions
+    N: tl.constexpr,
+    K: tl.constexpr,
+    EM,
+    num_valid_tokens,
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
+    stride_am,
+    stride_ak,
+    stride_be,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_bse,
+    stride_bsk,
+    stride_bsn,
+    stride_bze,
+    stride_bzk,
+    stride_bzn,
+    group_size: tl.constexpr,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    MUL_ROUTED_WEIGHT: tl.constexpr,
+    top_k: tl.constexpr,
+    compute_type: tl.constexpr,
+    has_zp: tl.constexpr,
+    use_int4_w4a16: tl.constexpr,
+    use_int8_w8a16: tl.constexpr,
+    even_Ks: tl.constexpr,
+):
+    """
+    Implements the fused computation for a Mixture of Experts (MOE) using
+    token and expert matrices.
+    Key Parameters:
+    - A: The input tensor representing tokens with shape (*, K), where '*' can
+        be any shape representing batches and K is the feature dimension of
+        each token.
+    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
+        the number of experts, K is the input feature dimension, and N is
+        the output feature dimension.
+    - C: The output cache tensor with shape (M, topk, N), where M is the
+        total number of tokens post padding, topk is the number of times
+        each token is repeated, and N is the output feature dimension.
+    - sorted_token_ids: A tensor containing the sorted indices of tokens,
+        repeated topk times and arranged by the expert index they are
+        assigned to.
+    - expert_ids: A tensor containing the indices of the expert for each
+        block. It determines which expert matrix from B should be used for
+        each block in A.
+    This kernel performs the multiplication of a token by its corresponding
+    expert matrix as determined by `expert_ids`. The sorting of
+    `sorted_token_ids` by expert index and padding ensures divisibility by
+    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
+    multiplication across different blocks processed by the same expert.
+    """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+        return
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    token_mask = offs_token < num_valid_tokens
+
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(
+            c_ptr,
+            stride_cm,
+            stride_cn,
+            pid_n,
+            N,
+            offs_token,
+            token_mask,
+            BLOCK_SIZE_M,
+            BLOCK_SIZE_N,
+            compute_type,
+        )
+        return
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )
+
+    if use_int4_w4a16:
+        b_ptrs = (
+            b_ptr
+            + off_experts * stride_be
+            + (offs_k[:, None] // 2) * stride_bk
+            + offs_bn[None, :] * stride_bn
+        )
+        b_shifter = (offs_k[:, None] % 2) * 4
+    elif use_int8_w8a16:
+        b_ptrs = (
+            b_ptr
+            + off_experts * stride_be
+            + offs_k[:, None] * stride_bk
+            + offs_bn[None, :] * stride_bn
+        )
+
+    if not has_zp and use_int4_w4a16:
+        b_zp_num = 8
+    if not has_zp and use_int8_w8a16:
+        b_zp_num = 128
+    elif has_zp and use_int4_w4a16:
+        b_zp_shifter = (offs_bn[None, :] % 2) * 4
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generate a mask by checking the
+        # K dimension.
+
+        if not even_Ks:
+            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
+            k_other = 0.0
+        else:
+            k_mask = None
+            k_other = None
+
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs)
+        if use_int4_w4a16:
+            b = (b >> b_shifter) & 0xF
+
+        b_scale_ptrs = (
+            b_scale_ptr
+            + off_experts * stride_bse
+            + offs_bn[None, :] * stride_bsn
+            + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
+        )
+        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
+        b_scale = b_scale.to(tl.float32)
+
+        if has_zp and use_int4_w4a16:
+            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
+            b_zp_ptrs = (
+                b_zp_ptr
+                + off_experts * stride_bze
+                + (offs_bn[None, :] // 2) * stride_bzn
+                + offs_k_true * stride_bzk
+            )
+            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
+            b_zp = (b_zp >> b_zp_shifter) & 0xF
+            b_zp = b_zp.to(tl.float32)
+        elif has_zp and use_int8_w8a16:
+            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
+            b_zp_ptrs = (
+                b_zp_ptr
+                + off_experts * stride_bze
+                + offs_bn[None, :] * stride_bzn
+                + offs_k_true * stride_bzk
+            )
+            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
+            b_zp = b_zp.to(tl.float32)
+
+        # We accumulate along the K dimension.
+        if has_zp:
+            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
+        else:
+            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
+        accumulator = tl.dot(a, b, acc=accumulator)
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        if use_int4_w4a16:
+            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
+        else:
+            b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        accumulator = accumulator * moe_weight[:, None]
+
+    accumulator = accumulator.to(compute_type)
+    # -----------------------------------------------------------
+    # Write back the block of the output
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+@triton.jit
+def fused_moe_kernel(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    bias_ptr,
+    c_ptr,
+    a_scale_ptr,
+    b_scale_ptr,
+    topk_weights_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    num_tokens_post_padded_ptr,
+    # Matrix dimensions
+    N,
+    K,
+    EM,
+    num_valid_tokens,
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
+    stride_am,
+    stride_ak,
+    stride_be,
+    stride_bk,
+    stride_bn,
+    stride_bias_e,
+    stride_bias_n,
+    stride_cm,
+    stride_cn,
+    stride_asm,
+    stride_ask,
+    stride_bse,
+    stride_bsk,
+    stride_bsn,
+    # Block size for block-wise quantization
+    group_n: tl.constexpr,
+    group_k: tl.constexpr,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    MUL_ROUTED_WEIGHT: tl.constexpr,
+    top_k: tl.constexpr,
+    compute_type: tl.constexpr,
+    use_fp8_w8a8: tl.constexpr,
+    use_int8_w8a8: tl.constexpr,
+    use_int8_w8a16: tl.constexpr,
+    per_channel_quant: tl.constexpr,
+    even_Ks: tl.constexpr,
+):
+    """
+    Implements the fused computation for a Mixture of Experts (MOE) using
+    token and expert matrices.
+
+    Key Parameters:
+    - A: The input tensor representing tokens with shape (*, K), where '*' can
+        be any shape representing batches and K is the feature dimension of
+        each token.
+    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
+        the number of experts, K is the input feature dimension, and N is
+        the output feature dimension.
+    - C: The output cache tensor with shape (M, topk, N), where M is the
+        total number of tokens post padding, topk is the number of times
+        each token is repeated, and N is the output feature dimension.
+    - sorted_token_ids: A tensor containing the sorted indices of tokens,
+        repeated topk times and arranged by the expert index they are
+        assigned to.
+    - expert_ids: A tensor containing the indices of the expert for each
+        block. It determines which expert matrix from B should be used for
+        each block in A.
+
+    This kernel performs the multiplication of a token by its corresponding
+    expert matrix as determined by `expert_ids`. The sorting of
+    `sorted_token_ids` by expert index and padding ensures divisibility by
+    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
+    multiplication across different blocks processed by the same expert.
+    """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+        return
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    offs_token = offs_token.to(tl.int64)
+    token_mask = offs_token < num_valid_tokens
+
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(
+            c_ptr,
+            stride_cm,
+            stride_cn,
+            pid_n,
+            N,
+            offs_token,
+            token_mask,
+            BLOCK_SIZE_M,
+            BLOCK_SIZE_N,
+            compute_type,
+        )
+        return
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )
+
+    b_ptrs = (
+        b_ptr
+        + off_experts * stride_be
+        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    )
+    if bias_ptr is not None:
+        bias = tl.load(
+            bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
+        )
+    if use_int8_w8a16:
+        b_scale_ptrs = (
+            b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
+        )
+        b_scale = tl.load(b_scale_ptrs)
+
+    if use_fp8_w8a8 or use_int8_w8a8:
+        # block-wise
+        if group_k > 0 and group_n > 0:
+            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            offs_bsn = offs_bn // group_n
+            b_scale_ptrs = (
+                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
+            )
+        # channel-wise
+        elif per_channel_quant:
+            b_scale_ptrs = (
+                b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
+            )
+            b_scale = tl.load(b_scale_ptrs)
+            # Load per-token scale for activations
+            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
+        # tensor-wise
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generate a mask by checking the
+        # K dimension.
+        if even_Ks:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None],
+                other=0.0,
+            )
+            b = tl.load(b_ptrs)
+        else:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                other=0.0,
+            )
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        # We accumulate along the K dimension.
+        if use_int8_w8a16:
+            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
+        elif use_fp8_w8a8 or use_int8_w8a8:
+            if group_k > 0 and group_n > 0:
+                k_start = k * BLOCK_SIZE_K
+                offs_ks = k_start // group_k
+                a_scale = tl.load(
+                    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
+                )
+                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
+
+                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+            else:
+                if use_fp8_w8a8:
+                    accumulator = tl.dot(a, b, acc=accumulator)
+                else:
+                    accumulator += tl.dot(a, b)
+        else:
+            accumulator += tl.dot(a, b)
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if use_int8_w8a16:
+        accumulator *= b_scale
+    elif use_fp8_w8a8 or use_int8_w8a8:
+        if group_k == 0 or group_n == 0:
+            accumulator *= a_scale * b_scale
+
+    if bias_ptr is not None:
+        accumulator += bias
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        accumulator *= moe_weight[:, None]
+
+    accumulator = accumulator.to(compute_type)
+    # -----------------------------------------------------------
+    # Write back the block of the output
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+def invoke_fused_moe_kernel(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    C: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    B_scale: Optional[torch.Tensor],
+    B_zp: Optional[torch.Tensor],
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor,
+    mul_routed_weight: bool,
+    top_k: int,
+    config: Dict[str, Any],
+    compute_type: tl.dtype,
+    use_fp8_w8a8: bool,
+    use_int8_w8a8: bool,
+    use_int8_w8a16: bool,
+    use_int4_w4a16: bool,
+    per_channel_quant: bool,
+    block_shape: Optional[List[int]] = None,
+    no_combine: bool = False,
+) -> None:
+    assert topk_weights.stride(1) == 1
+    assert sorted_token_ids.stride(0) == 1
+
+    padded_size = 0
+    if use_fp8_w8a8:
+        assert B_scale is not None
+        if block_shape is None:
+            # activation tensor-wise fp8 quantization, dynamic or static
+            padded_size = padding_size
+            # activations apply per-token quantization when weights apply per-channel quantization by default
+            A, A_scale = scaled_fp8_quant(
+                A, A_scale, use_per_token_if_dynamic=per_channel_quant
+            )
+        else:
+            # activation block-wise fp8 quantization
+            assert len(block_shape) == 2
+            block_n, block_k = block_shape[0], block_shape[1]
+            if _is_cuda:
+                A, A_scale = sglang_per_token_group_quant_fp8(A, block_k)
+            else:
+                A, A_scale = per_token_group_quant_fp8(A, block_k)
+            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
+    elif use_int8_w8a8:
+        assert B_scale is not None
+        if block_shape is None:
+            # activation channel-wise int8 quantization
+            assert (
+                per_channel_quant
+            ), "int8 quantization only supports channel-wise quantization except for block-wise quantization"
+            A, A_scale = per_token_quant_int8(A)
+        else:
+            # activation block-wise int8 quantization
+            assert len(block_shape) == 2
+            block_n, block_k = block_shape[0], block_shape[1]
+            if _is_cuda:
+                A, A_scale = sglang_per_token_group_quant_int8(A, block_k)
+            else:
+                A, A_scale = per_token_group_quant_int8(A, block_k)
+            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
+    elif use_int8_w8a16 or use_int4_w4a16:
+        assert B_scale is not None
+        assert block_shape is None or block_shape[0] == 0
+    else:
+        assert A_scale is None
+        assert B_scale is None
+
+    grid = lambda META: (
+        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
+        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
+    )
+
+    K = B.shape[2] - padded_size
+    if K % config["BLOCK_SIZE_K"] == 0:
+        even_Ks = True
+    else:
+        even_Ks = False
+
+    if (
+        (use_int8_w8a16 or use_int4_w4a16)
+        and block_shape is not None
+        and block_shape[1] > 0
+    ):
+        assert B_scale is not None and B_scale.ndim == 3
+        assert B_zp is None or B_zp.ndim == 3
+        assert bias is None
+        fused_moe_kernel_gptq_awq[grid](
+            A,
+            B,
+            C,
+            B_scale,
+            B_zp,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            B.shape[1],
+            A.shape[1],
+            sorted_token_ids.shape[0],
+            topk_ids.numel(),
+            A.stride(0),
+            A.stride(1),
+            B.stride(0),
+            B.stride(2),
+            B.stride(1),
+            C.stride(1),
+            C.stride(2),
+            B_scale.stride(0),
+            B_scale.stride(2),
+            B_scale.stride(1),
+            B_zp.stride(0) if B_zp is not None else 0,
+            B_zp.stride(2) if B_zp is not None else 0,
+            B_zp.stride(1) if B_zp is not None else 0,
+            group_size=block_shape[1],
+            MUL_ROUTED_WEIGHT=mul_routed_weight,
+            top_k=top_k,
+            compute_type=compute_type,
+            has_zp=B_zp is not None,
+            use_int4_w4a16=use_int4_w4a16,
+            use_int8_w8a16=use_int8_w8a16,
+            even_Ks=even_Ks,
+            **config,
+        )
+
+    else:
+
+        fused_moe_kernel[grid](
+            A,
+            B,
+            bias,
+            C,
+            A_scale,
+            B_scale,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            B.shape[1],
+            B.shape[2] - padded_size,
+            sorted_token_ids.shape[0],
+            topk_ids.numel(),
+            A.stride(0),
+            A.stride(1),
+            B.stride(0),
+            B.stride(2),
+            B.stride(1),
+            bias.stride(0) if bias is not None else 0,
+            bias.stride(1) if bias is not None else 0,
+            C.stride(1),
+            C.stride(2),
+            A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+            A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+            B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+            B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+            B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+            0 if block_shape is None else block_shape[0],
+            0 if block_shape is None else block_shape[1],
+            MUL_ROUTED_WEIGHT=mul_routed_weight,
+            top_k=top_k,
+            compute_type=compute_type,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
+            per_channel_quant=per_channel_quant,
+            even_Ks=even_Ks,
+            **config,
+        )
+
+
+# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
+@triton.jit
+def _moe_sum_reduce_kernel(
+    input_ptr,
+    input_stride_0,
+    input_stride_1,
+    input_stride_2,
+    output_ptr,
+    output_stride_0,
+    output_stride_1,
+    token_num: int,
+    topk_num: int,
+    hidden_dim: int,
+    routed_scaling_factor: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DIM: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
+):
+    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
+    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
+    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
+
+    token_block_id = tl.program_id(0)
+    dim_block_id = tl.program_id(1)
+
+    offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
+
+    mask_token = offs_token < token_num
+    mask_dim = offs_dim < hidden_dim
+
+    base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
+
+    accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
+
+    for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
+        tile = tl.load(
+            base_ptrs + i * input_stride_1,
+            mask=mask_token[:, None] & mask_dim[None, :],
+            other=0.0,
+        )
+        accumulator += tile.to(tl.float32)
+    accumulator *= routed_scaling_factor
+
+    # -------- Write back --------
+    store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
+    tl.store(
+        store_ptrs,
+        accumulator.to(input_ptr.dtype.element_ty),
+        mask=mask_token[:, None] & mask_dim[None, :],
+    )
+
+
+def moe_sum_reduce_triton(
+    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
+):
+    assert input.is_contiguous()
+    assert output.is_contiguous()
+
+    token_num, topk_num, hidden_dim = input.shape
+    assert output.shape[0] == token_num and output.shape[1] == hidden_dim
+
+    BLOCK_M = 1
+    BLOCK_DIM = 2048
+    NUM_STAGE = 1
+    num_warps = 16
+
+    grid = (
+        triton.cdiv(token_num, BLOCK_M),
+        triton.cdiv(hidden_dim, BLOCK_DIM),
+    )
+
+    _moe_sum_reduce_kernel[grid](
+        input,
+        *input.stride(),
+        output,
+        *output.stride(),
+        token_num=token_num,
+        topk_num=topk_num,
+        hidden_dim=hidden_dim,
+        routed_scaling_factor=routed_scaling_factor,
+        BLOCK_M=BLOCK_M,
+        BLOCK_DIM=BLOCK_DIM,
+        NUM_STAGE=NUM_STAGE,
+        num_warps=num_warps,
+    )
+    return
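As a quick orientation to the new module above, here is a minimal, hypothetical sketch of exercising `moe_sum_reduce_triton` on its own. The import path follows the file location shown in the file list; the tensor shapes and the standalone check are illustrative assumptions, not how sglang itself calls it (the real call sites live inside the fused MoE layers).

```python
import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_kernels import (
    moe_sum_reduce_triton,
)

# Per-expert outputs for 8 tokens, top-2 routing, hidden size 4096
# (shape (token_num, topk, hidden_dim), contiguous as the helper asserts).
expert_out = torch.randn(8, 2, 4096, dtype=torch.float16, device="cuda")
reduced = torch.empty(8, 4096, dtype=torch.float16, device="cuda")

# Sums over the top-k dimension and applies the routed scaling factor.
moe_sum_reduce_triton(expert_out, reduced, routed_scaling_factor=1.0)

# Sanity check against a plain PyTorch reduction.
torch.testing.assert_close(reduced, expert_out.sum(dim=1), atol=2e-2, rtol=2e-2)
```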