sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +208 -295
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +238 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +209 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -29
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py

@@ -11,12 +11,14 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.moe.ep_moe.kernels import (
+    gelu_and_mul_triton_kernel,
     grouped_gemm_triton,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
     silu_and_mul_triton_kernel,
 )
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
@@ -61,6 +63,7 @@ class GroupedGemmRunner(torch.nn.Module):
         use_fp8_w8a8: bool = False,
         scale_a: torch.Tensor = None,
         scale_b: torch.Tensor = None,
+        block_shape: Optional[List[int]] = None,
     ):
         if self.use_flashinfer:
             # TODO: flashinfer
@@ -87,6 +90,7 @@ class GroupedGemmRunner(torch.nn.Module):
             use_fp8_w8a8,
             scale_a,
             scale_b,
+            block_shape=block_shape,
         )
         return c

@@ -147,12 +151,20 @@ class EPMoE(torch.nn.Module):
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
             self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.block_shape = None
             self.activation_scheme = None
         else:
             self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
                 quant_config
             )
             self.use_fp8_w8a8 = True
+            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
+            self.block_shape = (
+                self.quant_method.quant_config.weight_block_size
+                if self.use_block_quant
+                else None
+            )
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme

@@ -169,11 +181,11 @@ class EPMoE(torch.nn.Module):

     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
-        assert self.activation == "silu"

         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device,
+                hidden_states.device,
+                use_flashinfer=False,  # TODO: use flashinfer
             )

         topk_weights, topk_ids = select_experts(
@@ -195,9 +207,13 @@ class EPMoE(torch.nn.Module):
         gateup_input = torch.empty(
             (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
             device=hidden_states.device,
-            dtype=
+            dtype=(
+                self.fp8_dtype
+                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                else hidden_states.dtype
+            ),
         )
-        if self.activation_scheme == "dynamic":
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
             max_value = (
                 torch.max(hidden_states)
                 .repeat(self.num_experts_per_partition)
@@ -243,7 +259,12 @@ class EPMoE(torch.nn.Module):
             weight_indices=weight_indices_cur_rank,
             use_fp8_w8a8=self.use_fp8_w8a8,
             scale_a=self.w13_input_scale,
-            scale_b=
+            scale_b=(
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+            block_shape=self.block_shape,
         )

         # Act
@@ -251,9 +272,13 @@ class EPMoE(torch.nn.Module):
             gateup_output.shape[0],
             gateup_output.shape[1] // 2,
             device=gateup_output.device,
-            dtype=
+            dtype=(
+                self.fp8_dtype
+                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                else hidden_states.dtype
+            ),
         )
-        if self.w2_input_scale is None:
+        if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
@@ -271,6 +296,17 @@ class EPMoE(torch.nn.Module):
                 self.end_expert_id,
                 BLOCK_SIZE=512,
             )
+        elif self.activation == "gelu":
+            gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                self.start_expert_id,
+                self.end_expert_id,
+                BLOCK_SIZE=512,
+            )
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")

@@ -291,7 +327,12 @@ class EPMoE(torch.nn.Module):
             weight_indices=weight_indices_cur_rank,
             use_fp8_w8a8=self.use_fp8_w8a8,
             scale_a=self.w2_input_scale,
-            scale_b=
+            scale_b=(
+                self.w2_weight_scale_inv
+                if self.use_block_quant
+                else self.w2_weight_scale
+            ),
+            block_shape=self.block_shape,
         )

         # PostReorder
@@ -358,7 +399,11 @@ class EPMoE(torch.nn.Module):
         # Special case for fp8 scales.
         if "scale" in weight_name:
             self._load_fp8_scale(
-                param.data,
+                param.data,
+                loaded_weight,
+                weight_name,
+                shard_id,
+                expert_id,
             )
             return

@@ -395,18 +440,33 @@ class EPMoE(torch.nn.Module):
             param_data[expert_id] = loaded_weight
         # Weight scales
         elif "weight_scale" in weight_name:
+            if self.use_block_quant:
+                block_n, block_k = self.block_shape[0], self.block_shape[1]
+                if shard_id == "w1":
+                    param_data[expert_id][
+                        : (self.intermediate_size + block_n - 1) // block_n, :
+                    ] = loaded_weight
+                elif shard_id == "w3":
+                    param_data[expert_id][
+                        (self.intermediate_size + block_n - 1) // block_n :, :
+                    ] = loaded_weight
+                else:  # w2
+                    param_data[expert_id] = loaded_weight
             # If we are in merged column case (gate_up_proj)
-            if shard_id in ("w1", "w3"):
-                # We have to keep the weight scales of w1 and w3 because
-                # we need to re-quantize w1/w3 weights after weight loading.
-                idx = 0 if shard_id == "w1" else 1
-                param_data[expert_id][idx] = loaded_weight
-            # If we are in the row parallel case (down_proj)
             else:
-
+                if shard_id in ("w1", "w3"):
+                    # We have to keep the weight scales of w1 and w3 because
+                    # we need to re-quantize w1/w3 weights after weight loading.
+                    idx = 0 if shard_id == "w1" else 1
+                    param_data[expert_id][idx] = loaded_weight
+
+                # If we are in the row parallel case (down_proj)
+                else:
+                    param_data[expert_id] = loaded_weight


 class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -498,6 +558,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):

     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
+        self.block_quant = self.quant_config.weight_block_size is not None

     def create_weights(
         self,
@@ -512,6 +573,29 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn

+        tp_size = get_tensor_model_parallel_world_size()
+        if self.block_quant:
+            block_n, block_k = (
+                self.quant_config.weight_block_size[0],
+                self.quant_config.weight_block_size[1],
+            )
+            # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
+            # Required by collum parallel or enabling merged weights
+            if intermediate_size % block_n != 0:
+                raise ValueError(
+                    f"The output_size of gate's and up's weight = "
+                    f"{intermediate_size} is not divisible by "
+                    f"weight quantization block_n = {block_n}."
+                )
+            if tp_size > 1:
+                # Required by row parallel
+                if intermediate_size % block_k != 0:
+                    raise ValueError(
+                        f"The input_size of down's weight = "
+                        f"{intermediate_size} is not divisible by "
+                        f"weight quantization block_k = {block_k}."
+                    )
+
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
@@ -538,21 +622,49 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         set_weight_attrs(w2_weight, extra_weight_attrs)

         # WEIGHT_SCALES
-
-
-
-
-
-
+        if self.block_quant:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts_per_partition,
+                    2 * ((intermediate_size + block_n - 1) // block_n),
+                    (hidden_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts_per_partition,
+                    (hidden_size + block_n - 1) // block_n,
+                    (intermediate_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+            assert self.quant_config.activation_scheme == "dynamic"
+        else:
+            # WEIGHT_SCALES
+            # Allocate 2 scales for w1 and w3 respectively.
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)

-
-
-
-
-
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts_per_partition, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+            if self.block_quant
+            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
         # If loading fp8 checkpoint, pass the weight loaders.
         # If loading an fp16 checkpoint, do not (we will quantize in
         # process_weights_after_loading()
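Note: the block-quant branch added above keeps one float32 scale per `[block_n, block_k]` tile of each expert weight, which is why the new `w13_weight_scale_inv` / `w2_weight_scale_inv` parameters use ceil-division of the weight dimensions. A minimal sketch of that shape arithmetic (illustrative only; `ceil_div`, `block_scale_shapes`, and the example sizes are not part of the package):

```python
# Illustrative sketch: how the block-wise FP8 scale shapes in the diff above
# are derived. For weight_block_size = [block_n, block_k], one float32 scale
# is kept per (block_n x block_k) tile of each expert's weight.

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def block_scale_shapes(num_experts: int, hidden_size: int,
                       intermediate_size: int, block_n: int, block_k: int):
    # w13 is the merged gate/up projection: (2 * intermediate_size, hidden_size)
    w13_scale = (num_experts,
                 2 * ceil_div(intermediate_size, block_n),
                 ceil_div(hidden_size, block_k))
    # w2 is the down projection: (hidden_size, intermediate_size)
    w2_scale = (num_experts,
                ceil_div(hidden_size, block_n),
                ceil_div(intermediate_size, block_k))
    return w13_scale, w2_scale

# Example with illustrative sizes and the [128, 128] block shape used by the
# new configs in this release.
print(block_scale_shapes(num_experts=8, hidden_size=7168,
                         intermediate_size=2048, block_n=128, block_k=128))
# ((8, 32, 56), (8, 56, 16))
```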
sglang/srt/layers/moe/fused_moe_native.py

@@ -24,6 +24,8 @@ def fused_moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    inplace: bool = True,
+    no_combine: bool = False,
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
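Note: both of the hunks above route the gated MLP through either a SiLU or the newly imported GELU "and-mul" kernel. As a rough reference for what those fused kernels compute, here is a plain-PyTorch sketch under the assumption that the last dimension holds the concatenated gate and up halves (the real Triton kernels additionally apply per-expert input scales; `gated_act_ref` is a hypothetical helper, not part of the package):

```python
import torch
import torch.nn.functional as F

def gated_act_ref(gateup: torch.Tensor, activation: str = "silu") -> torch.Tensor:
    # gateup holds [gate | up] along the last dimension, so the output has
    # half as many columns -- matching gateup_output.shape[1] // 2 above.
    gate, up = gateup.chunk(2, dim=-1)
    act = F.silu if activation == "silu" else F.gelu
    return act(gate) * up

x = torch.randn(4, 2 * 128)
assert gated_act_ref(x, "gelu").shape == (4, 128)
```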
sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json (new file)

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json (new file)

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
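Note: tuning files like the two above map a token count (the JSON keys) to Triton launch parameters for the fused MoE kernel. A simplified sketch of how such a table is typically consulted (an assumption for illustration, not the package's exact lookup; `load_moe_config` and `pick_config` are hypothetical helpers):

```python
import json

def load_moe_config(path: str) -> dict:
    # Keys are stored as strings in the JSON file; convert them to ints.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}

def pick_config(configs: dict, m: int) -> dict:
    # Pick the pre-benchmarked entry whose key is closest to the current
    # number of tokens M (nearest-key heuristic, assumed here).
    return configs[min(configs, key=lambda key: abs(key - m))]

# Usage (hypothetical path):
# configs = load_moe_config(
#     "E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json"
# )
# kernel_kwargs = pick_config(configs, m=200)  # -> the "256" entry
```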