sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +220 -378
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +208 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,409 @@
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from typing import Any, Callable, Dict, List, Optional
|
5
|
+
|
6
|
+
import torch
|
7
|
+
from torch.nn import Module
|
8
|
+
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
|
9
|
+
|
10
|
+
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
11
|
+
from sglang.srt.layers.linear import (
|
12
|
+
LinearBase,
|
13
|
+
LinearMethodBase,
|
14
|
+
UnquantizedLinearMethod,
|
15
|
+
)
|
16
|
+
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
17
|
+
from sglang.srt.layers.quantization.base_config import (
|
18
|
+
QuantizationConfig,
|
19
|
+
QuantizeMethodBase,
|
20
|
+
)
|
21
|
+
from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
|
22
|
+
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
|
23
|
+
from sglang.srt.utils import set_weight_attrs
|
24
|
+
|
25
|
+
logger = logging.getLogger(__name__)

# Activation quantization schemes accepted by ``BlockInt8Config``.
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
28
|
+
|
29
|
+
|
30
|
+
class BlockInt8Config(QuantizationConfig):
    """Config for block-wise INT8 weight quantization.

    Args:
        is_checkpoint_int8_serialized: Whether the checkpoint already stores
            INT8 weights. Block-wise quantization requires this to be True.
        activation_scheme: How activations are quantized. Must be one of
            ``ACTIVATION_SCHEMES``, and must be ``"dynamic"`` when
            ``weight_block_size`` is set.
        ignored_layers: Layer names handed to ``is_layer_skipped``; linear
            layers it matches stay unquantized. ``None`` means no layer is
            ignored.
        weight_block_size: Two-element block shape of the weight scales
            (indexed as ``[block_n, block_k]`` by the linear/MoE methods),
            or ``None``.

    Raises:
        ValueError: If ``activation_scheme`` is unknown, or if
            ``weight_block_size`` is given together with a non-int8
            checkpoint, a shape that is not 2-D, or a non-dynamic
            activation scheme.
    """

    def __init__(
        self,
        is_checkpoint_int8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
        # Previously annotated ``List[int] = None`` (implicit Optional).
        weight_block_size: Optional[List[int]] = None,
    ) -> None:
        self.is_checkpoint_int8_serialized = is_checkpoint_int8_serialized
        if is_checkpoint_int8_serialized:
            logger.warning(
                "Detected int8 checkpoint. Please note that the "
                "format is experimental and subject to change."
            )
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_int8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports int8-serialized checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "blockwise_int8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU capability required (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return extra config filenames to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config":
        """Build a config from a checkpoint's quantization config dict.

        ``quant_method`` and ``activation_scheme`` are read with
        ``get_from_keys``. ``ignored_layers`` and ``weight_block_size`` are
        read with ``get_from_keys_or`` and default to ``None``. The checkpoint
        counts as int8-serialized when ``quant_method`` contains ``"int8"``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_int8_serialized = "int8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        return cls(
            is_checkpoint_int8_serialized=is_checkpoint_int8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Returns:
            ``UnquantizedLinearMethod`` for linear layers skipped via
            ``ignored_layers``, ``BlockInt8LinearMethod`` for other linear
            layers, ``BlockInt8MoEMethod`` for ``FusedMoE`` layers, and
            ``None`` for anything else.
        """
        # Local import, as in the original; FusedMoE is only needed here.
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return BlockInt8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return BlockInt8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return names of activations that carry scales (none)."""
        return []
|
110
|
+
|
111
|
+
|
112
|
+
class BlockInt8LinearMethod(LinearMethodBase):
    """Linear method for block-wise INT8 weights.

    Loads INT8 checkpoints that carry a static scale per weight block and
    leaves the activation scale to be computed at run time
    (``input_scale`` is ``None``).

    Limitations:
        Only block-wise int8 quantization of an int8-serialized checkpoint
        is supported.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: BlockInt8Config):
        self.quant_config = quant_config
        cfg = quant_config
        assert cfg.weight_block_size is not None
        assert cfg.is_checkpoint_int8_serialized

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the INT8 weight and its per-block scale on ``layer``.

        Registers ``weight`` of shape
        ``(sum(output_partition_sizes), input_size_per_partition)``.
        Registers ``weight_scale_inv`` as a float32 tensor with one entry per
        ``block_n x block_k`` tile, rounded up. ``input_scale`` is registered
        as ``None``.

        Raises:
            ValueError: If a row-parallel input shard is not a multiple of
                ``block_k``, or, under column parallelism or merged
                weights, any output partition is not a multiple of
                ``block_n``.
        """
        cfg = self.quant_config
        out_per_partition = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")
        tp_size = get_tensor_model_parallel_world_size()
        block_n = cfg.weight_block_size[0]
        block_k = cfg.weight_block_size[1]

        # Row parallel: each input shard must cover whole quantization blocks.
        if (
            tp_size > 1
            and input_size // input_size_per_partition == tp_size
            and input_size_per_partition % block_k != 0
        ):
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition} "
                f"is not divisible by weight quantization block_k = {block_k}."
            )

        # Column parallel or merged weights: every output slice must be
        # block-aligned so each scale row belongs to exactly one slice.
        column_sharded = tp_size > 1 and output_size // out_per_partition == tp_size
        if column_sharded or len(output_partition_sizes) > 1:
            for part_size in output_partition_sizes:
                if part_size % block_n != 0:
                    raise ValueError(
                        f"Weight output_partition_size = {part_size} "
                        f"is not divisible by weight quantization block_n = {block_n}."
                    )

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_per_partition
        layer.orig_dtype = params_dtype

        # Weight tensor.
        if cfg.is_checkpoint_int8_serialized:
            weight_dtype = torch.int8
        else:
            weight_dtype = params_dtype
        weight = ModelWeightParameter(
            data=torch.empty(
                out_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight", weight)

        # Per-block weight scale (ceil-divided block counts).
        n_row_blocks = (out_per_partition + block_n - 1) // block_n
        n_col_blocks = (input_size_per_partition + block_k - 1) // block_k
        weight_scale = BlockQuantScaleParameter(
            data=torch.empty(n_row_blocks, n_col_blocks, dtype=torch.float32),
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        # Pre-fill with the most negative float32; presumably a sentinel for
        # entries the checkpoint never writes (TODO confirm).
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale_inv", weight_scale)

        # Input activation scale: none, activations use the dynamic scheme.
        assert cfg.activation_scheme == "dynamic"
        layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Re-wrap the loaded tensors as frozen ``torch.nn.Parameter``s.

        Block quantization needs no numeric post-processing. Re-wrapping in
        a plain ``torch.nn.Parameter`` avoids a CUDA graph capturing issue
        (per the original note).
        """
        for attr_name in ("weight", "weight_scale_inv"):
            loaded = getattr(layer, attr_name)
            setattr(
                layer,
                attr_name,
                torch.nn.Parameter(loaded.data, requires_grad=False),
            )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run ``apply_w8a8_block_int8_linear`` on ``x``.

        Uses the layer's int8 weight and per-block scales. ``input_scale`` is
        passed as ``None``.
        """
        block_size = self.quant_config.weight_block_size
        return apply_w8a8_block_int8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale_inv,
            block_size=block_size,
            input_scale=None,
            bias=bias,
        )
|
232
|
+
|
233
|
+
|
234
|
+
class BlockInt8MoEMethod:
    """MoE method for block-wise INT8 weights.

    Loads INT8 checkpoints that carry a static scale per weight block and
    leaves the activation scales to be computed at run time (the input
    scales are ``None``).

    Limitations:
        Only block-wise int8 quantization of an int8-serialized checkpoint
        is supported.

    Args:
        quant_config: The quantization config.

    Note:
        The class statement has no base class. ``__new__`` builds, once, a
        copy of this class that derives from ``FusedMoEMethodBase`` and
        returns an instance of that copy. ``FusedMoEMethodBase`` is imported
        lazily there, presumably to avoid a circular import with the fused
        MoE package (TODO confirm).
    """

    def __new__(cls, *args, **kwargs):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase

        # Read ``cls.__dict__`` directly (not getattr) so a cached class is
        # never picked up through inheritance.
        moe_cls = cls.__dict__.get("_fused_moe_method_cls")
        if moe_cls is None:
            # Skip the per-class ``__dict__``/``__weakref__`` descriptors of
            # ``cls``; ``type()`` supplies correct ones for the new class.
            # ``__init__`` is already part of ``cls.__dict__``.
            namespace = {
                key: value
                for key, value in cls.__dict__.items()
                if key not in ("__dict__", "__weakref__")
            }
            moe_cls = type(cls.__name__, (FusedMoEMethodBase,), namespace)
            # Cache the rebased class. The previous guard tested an
            # ``_initialized`` attribute that was never set, so every
            # instantiation built a brand-new class object.
            cls._fused_moe_method_cls = moe_cls
        obj = super(moe_cls, moe_cls).__new__(moe_cls)
        # ``moe_cls`` is not a subclass of ``cls``, so ``type.__call__`` will
        # not run ``__init__`` on the returned object; do it explicitly.
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, quant_config):
        self.quant_config = quant_config
        assert self.quant_config.weight_block_size is not None
        assert self.quant_config.is_checkpoint_int8_serialized

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register per-expert INT8 weights and block scales on ``layer``.

        Registers:
        - ``w13_weight`` of shape
          ``(num_experts, 2 * intermediate_size, hidden_size)``.
        - ``w2_weight`` of shape
          ``(num_experts, hidden_size, intermediate_size)``.
        - float32 ``w13_weight_scale_inv`` / ``w2_weight_scale_inv`` with one
          entry per rounded-up block.

        ``layer.w13_input_scale`` and ``layer.w2_input_scale`` are set to
        ``None``.

        Raises:
            ValueError: If ``intermediate_size`` is not a multiple of
                ``block_n``, or, when ``tp_size > 1``, not a multiple of
                ``block_k``.
        """
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        if self.quant_config.is_checkpoint_int8_serialized:
            params_dtype = torch.int8
        tp_size = get_tensor_model_parallel_world_size()

        block_n, block_k = (
            self.quant_config.weight_block_size[0],
            self.quant_config.weight_block_size[1],
        )
        # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
        # Required by column parallel or enabling merged weights
        if intermediate_size % block_n != 0:
            raise ValueError(
                f"The output_size of gate's and up's weight = "
                f"{intermediate_size} is not divisible by "
                f"weight quantization block_n = {block_n}."
            )
        if tp_size > 1:
            # Required by row parallel
            if intermediate_size % block_k != 0:
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, hidden_size, intermediate_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES (ceil-divided block counts; w13 stacks gate and up)
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                2 * ((intermediate_size + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size + block_k - 1) // block_k,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES: none, activations use the dynamic scheme.
        assert self.quant_config.activation_scheme == "dynamic"
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """No-op: block quant doesn't need to process weights after loading."""
        return

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        inplace: bool = True,
        no_combine: bool = False,
    ) -> torch.Tensor:
        """Route ``x`` to experts and run the int8 fused-experts kernel.

        Selects experts with ``select_experts``, then calls
        ``fused_experts`` with ``use_int8_w8a8=True``, the per-block weight
        scales, and ``weight_block_size`` as the block shape.
        """
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
        from sglang.srt.layers.moe.topk import select_experts

        # Expert selection
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
        )

        # Expert fusion with INT8 quantization
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=inplace,
            activation=activation,
            use_int8_w8a8=True,
            w1_scale=layer.w13_weight_scale_inv,
            w2_scale=layer.w2_weight_scale_inv,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
            no_combine=no_combine,
        )
|
@@ -0,0 +1,146 @@
|
|
1
|
+
{
|
2
|
+
"1": {
|
3
|
+
"BLOCK_SIZE_M": 16,
|
4
|
+
"BLOCK_SIZE_N": 64,
|
5
|
+
"BLOCK_SIZE_K": 128,
|
6
|
+
"GROUP_SIZE_M": 64,
|
7
|
+
"num_warps": 4,
|
8
|
+
"num_stages": 5
|
9
|
+
},
|
10
|
+
"2": {
|
11
|
+
"BLOCK_SIZE_M": 16,
|
12
|
+
"BLOCK_SIZE_N": 64,
|
13
|
+
"BLOCK_SIZE_K": 128,
|
14
|
+
"GROUP_SIZE_M": 64,
|
15
|
+
"num_warps": 4,
|
16
|
+
"num_stages": 3
|
17
|
+
},
|
18
|
+
"4": {
|
19
|
+
"BLOCK_SIZE_M": 16,
|
20
|
+
"BLOCK_SIZE_N": 32,
|
21
|
+
"BLOCK_SIZE_K": 128,
|
22
|
+
"GROUP_SIZE_M": 64,
|
23
|
+
"num_warps": 4,
|
24
|
+
"num_stages": 5
|
25
|
+
},
|
26
|
+
"8": {
|
27
|
+
"BLOCK_SIZE_M": 16,
|
28
|
+
"BLOCK_SIZE_N": 32,
|
29
|
+
"BLOCK_SIZE_K": 128,
|
30
|
+
"GROUP_SIZE_M": 16,
|
31
|
+
"num_warps": 4,
|
32
|
+
"num_stages": 5
|
33
|
+
},
|
34
|
+
"16": {
|
35
|
+
"BLOCK_SIZE_M": 16,
|
36
|
+
"BLOCK_SIZE_N": 32,
|
37
|
+
"BLOCK_SIZE_K": 128,
|
38
|
+
"GROUP_SIZE_M": 32,
|
39
|
+
"num_warps": 4,
|
40
|
+
"num_stages": 5
|
41
|
+
},
|
42
|
+
"24": {
|
43
|
+
"BLOCK_SIZE_M": 16,
|
44
|
+
"BLOCK_SIZE_N": 32,
|
45
|
+
"BLOCK_SIZE_K": 128,
|
46
|
+
"GROUP_SIZE_M": 64,
|
47
|
+
"num_warps": 4,
|
48
|
+
"num_stages": 5
|
49
|
+
},
|
50
|
+
"32": {
|
51
|
+
"BLOCK_SIZE_M": 16,
|
52
|
+
"BLOCK_SIZE_N": 32,
|
53
|
+
"BLOCK_SIZE_K": 128,
|
54
|
+
"GROUP_SIZE_M": 1,
|
55
|
+
"num_warps": 4,
|
56
|
+
"num_stages": 4
|
57
|
+
},
|
58
|
+
"48": {
|
59
|
+
"BLOCK_SIZE_M": 16,
|
60
|
+
"BLOCK_SIZE_N": 64,
|
61
|
+
"BLOCK_SIZE_K": 128,
|
62
|
+
"GROUP_SIZE_M": 1,
|
63
|
+
"num_warps": 4,
|
64
|
+
"num_stages": 4
|
65
|
+
},
|
66
|
+
"64": {
|
67
|
+
"BLOCK_SIZE_M": 16,
|
68
|
+
"BLOCK_SIZE_N": 64,
|
69
|
+
"BLOCK_SIZE_K": 128,
|
70
|
+
"GROUP_SIZE_M": 64,
|
71
|
+
"num_warps": 4,
|
72
|
+
"num_stages": 4
|
73
|
+
},
|
74
|
+
"96": {
|
75
|
+
"BLOCK_SIZE_M": 64,
|
76
|
+
"BLOCK_SIZE_N": 64,
|
77
|
+
"BLOCK_SIZE_K": 128,
|
78
|
+
"GROUP_SIZE_M": 32,
|
79
|
+
"num_warps": 8,
|
80
|
+
"num_stages": 3
|
81
|
+
},
|
82
|
+
"128": {
|
83
|
+
"BLOCK_SIZE_M": 64,
|
84
|
+
"BLOCK_SIZE_N": 32,
|
85
|
+
"BLOCK_SIZE_K": 128,
|
86
|
+
"GROUP_SIZE_M": 1,
|
87
|
+
"num_warps": 4,
|
88
|
+
"num_stages": 3
|
89
|
+
},
|
90
|
+
"256": {
|
91
|
+
"BLOCK_SIZE_M": 64,
|
92
|
+
"BLOCK_SIZE_N": 64,
|
93
|
+
"BLOCK_SIZE_K": 128,
|
94
|
+
"GROUP_SIZE_M": 64,
|
95
|
+
"num_warps": 8,
|
96
|
+
"num_stages": 3
|
97
|
+
},
|
98
|
+
"512": {
|
99
|
+
"BLOCK_SIZE_M": 128,
|
100
|
+
"BLOCK_SIZE_N": 64,
|
101
|
+
"BLOCK_SIZE_K": 128,
|
102
|
+
"GROUP_SIZE_M": 64,
|
103
|
+
"num_warps": 8,
|
104
|
+
"num_stages": 5
|
105
|
+
},
|
106
|
+
"1024": {
|
107
|
+
"BLOCK_SIZE_M": 32,
|
108
|
+
"BLOCK_SIZE_N": 64,
|
109
|
+
"BLOCK_SIZE_K": 128,
|
110
|
+
"GROUP_SIZE_M": 32,
|
111
|
+
"num_warps": 4,
|
112
|
+
"num_stages": 3
|
113
|
+
},
|
114
|
+
"1536": {
|
115
|
+
"BLOCK_SIZE_M": 128,
|
116
|
+
"BLOCK_SIZE_N": 64,
|
117
|
+
"BLOCK_SIZE_K": 128,
|
118
|
+
"GROUP_SIZE_M": 64,
|
119
|
+
"num_warps": 4,
|
120
|
+
"num_stages": 3
|
121
|
+
},
|
122
|
+
"2048": {
|
123
|
+
"BLOCK_SIZE_M": 64,
|
124
|
+
"BLOCK_SIZE_N": 256,
|
125
|
+
"BLOCK_SIZE_K": 128,
|
126
|
+
"GROUP_SIZE_M": 1,
|
127
|
+
"num_warps": 8,
|
128
|
+
"num_stages": 4
|
129
|
+
},
|
130
|
+
"3072": {
|
131
|
+
"BLOCK_SIZE_M": 128,
|
132
|
+
"BLOCK_SIZE_N": 32,
|
133
|
+
"BLOCK_SIZE_K": 128,
|
134
|
+
"GROUP_SIZE_M": 16,
|
135
|
+
"num_warps": 4,
|
136
|
+
"num_stages": 4
|
137
|
+
},
|
138
|
+
"4096": {
|
139
|
+
"BLOCK_SIZE_M": 64,
|
140
|
+
"BLOCK_SIZE_N": 64,
|
141
|
+
"BLOCK_SIZE_K": 128,
|
142
|
+
"GROUP_SIZE_M": 1,
|
143
|
+
"num_warps": 4,
|
144
|
+
"num_stages": 3
|
145
|
+
}
|
146
|
+
}
|
@@ -0,0 +1,146 @@
|
|
1
|
+
{
|
2
|
+
"1": {
|
3
|
+
"BLOCK_SIZE_M": 16,
|
4
|
+
"BLOCK_SIZE_N": 32,
|
5
|
+
"BLOCK_SIZE_K": 128,
|
6
|
+
"GROUP_SIZE_M": 64,
|
7
|
+
"num_warps": 4,
|
8
|
+
"num_stages": 5
|
9
|
+
},
|
10
|
+
"2": {
|
11
|
+
"BLOCK_SIZE_M": 16,
|
12
|
+
"BLOCK_SIZE_N": 32,
|
13
|
+
"BLOCK_SIZE_K": 128,
|
14
|
+
"GROUP_SIZE_M": 64,
|
15
|
+
"num_warps": 4,
|
16
|
+
"num_stages": 5
|
17
|
+
},
|
18
|
+
"4": {
|
19
|
+
"BLOCK_SIZE_M": 16,
|
20
|
+
"BLOCK_SIZE_N": 32,
|
21
|
+
"BLOCK_SIZE_K": 128,
|
22
|
+
"GROUP_SIZE_M": 64,
|
23
|
+
"num_warps": 4,
|
24
|
+
"num_stages": 5
|
25
|
+
},
|
26
|
+
"8": {
|
27
|
+
"BLOCK_SIZE_M": 32,
|
28
|
+
"BLOCK_SIZE_N": 32,
|
29
|
+
"BLOCK_SIZE_K": 128,
|
30
|
+
"GROUP_SIZE_M": 16,
|
31
|
+
"num_warps": 8,
|
32
|
+
"num_stages": 5
|
33
|
+
},
|
34
|
+
"16": {
|
35
|
+
"BLOCK_SIZE_M": 16,
|
36
|
+
"BLOCK_SIZE_N": 32,
|
37
|
+
"BLOCK_SIZE_K": 128,
|
38
|
+
"GROUP_SIZE_M": 16,
|
39
|
+
"num_warps": 4,
|
40
|
+
"num_stages": 5
|
41
|
+
},
|
42
|
+
"24": {
|
43
|
+
"BLOCK_SIZE_M": 16,
|
44
|
+
"BLOCK_SIZE_N": 32,
|
45
|
+
"BLOCK_SIZE_K": 128,
|
46
|
+
"GROUP_SIZE_M": 32,
|
47
|
+
"num_warps": 4,
|
48
|
+
"num_stages": 5
|
49
|
+
},
|
50
|
+
"32": {
|
51
|
+
"BLOCK_SIZE_M": 16,
|
52
|
+
"BLOCK_SIZE_N": 32,
|
53
|
+
"BLOCK_SIZE_K": 128,
|
54
|
+
"GROUP_SIZE_M": 16,
|
55
|
+
"num_warps": 4,
|
56
|
+
"num_stages": 5
|
57
|
+
},
|
58
|
+
"48": {
|
59
|
+
"BLOCK_SIZE_M": 16,
|
60
|
+
"BLOCK_SIZE_N": 64,
|
61
|
+
"BLOCK_SIZE_K": 128,
|
62
|
+
"GROUP_SIZE_M": 1,
|
63
|
+
"num_warps": 4,
|
64
|
+
"num_stages": 5
|
65
|
+
},
|
66
|
+
"64": {
|
67
|
+
"BLOCK_SIZE_M": 16,
|
68
|
+
"BLOCK_SIZE_N": 64,
|
69
|
+
"BLOCK_SIZE_K": 128,
|
70
|
+
"GROUP_SIZE_M": 32,
|
71
|
+
"num_warps": 4,
|
72
|
+
"num_stages": 5
|
73
|
+
},
|
74
|
+
"96": {
|
75
|
+
"BLOCK_SIZE_M": 32,
|
76
|
+
"BLOCK_SIZE_N": 64,
|
77
|
+
"BLOCK_SIZE_K": 128,
|
78
|
+
"GROUP_SIZE_M": 1,
|
79
|
+
"num_warps": 4,
|
80
|
+
"num_stages": 4
|
81
|
+
},
|
82
|
+
"128": {
|
83
|
+
"BLOCK_SIZE_M": 32,
|
84
|
+
"BLOCK_SIZE_N": 64,
|
85
|
+
"BLOCK_SIZE_K": 128,
|
86
|
+
"GROUP_SIZE_M": 16,
|
87
|
+
"num_warps": 4,
|
88
|
+
"num_stages": 4
|
89
|
+
},
|
90
|
+
"256": {
|
91
|
+
"BLOCK_SIZE_M": 32,
|
92
|
+
"BLOCK_SIZE_N": 64,
|
93
|
+
"BLOCK_SIZE_K": 128,
|
94
|
+
"GROUP_SIZE_M": 1,
|
95
|
+
"num_warps": 4,
|
96
|
+
"num_stages": 5
|
97
|
+
},
|
98
|
+
"512": {
|
99
|
+
"BLOCK_SIZE_M": 64,
|
100
|
+
"BLOCK_SIZE_N": 64,
|
101
|
+
"BLOCK_SIZE_K": 128,
|
102
|
+
"GROUP_SIZE_M": 1,
|
103
|
+
"num_warps": 4,
|
104
|
+
"num_stages": 5
|
105
|
+
},
|
106
|
+
"1024": {
|
107
|
+
"BLOCK_SIZE_M": 64,
|
108
|
+
"BLOCK_SIZE_N": 128,
|
109
|
+
"BLOCK_SIZE_K": 128,
|
110
|
+
"GROUP_SIZE_M": 1,
|
111
|
+
"num_warps": 4,
|
112
|
+
"num_stages": 3
|
113
|
+
},
|
114
|
+
"1536": {
|
115
|
+
"BLOCK_SIZE_M": 128,
|
116
|
+
"BLOCK_SIZE_N": 64,
|
117
|
+
"BLOCK_SIZE_K": 128,
|
118
|
+
"GROUP_SIZE_M": 1,
|
119
|
+
"num_warps": 4,
|
120
|
+
"num_stages": 3
|
121
|
+
},
|
122
|
+
"2048": {
|
123
|
+
"BLOCK_SIZE_M": 64,
|
124
|
+
"BLOCK_SIZE_N": 128,
|
125
|
+
"BLOCK_SIZE_K": 128,
|
126
|
+
"GROUP_SIZE_M": 64,
|
127
|
+
"num_warps": 4,
|
128
|
+
"num_stages": 3
|
129
|
+
},
|
130
|
+
"3072": {
|
131
|
+
"BLOCK_SIZE_M": 64,
|
132
|
+
"BLOCK_SIZE_N": 128,
|
133
|
+
"BLOCK_SIZE_K": 128,
|
134
|
+
"GROUP_SIZE_M": 64,
|
135
|
+
"num_warps": 4,
|
136
|
+
"num_stages": 3
|
137
|
+
},
|
138
|
+
"4096": {
|
139
|
+
"BLOCK_SIZE_M": 128,
|
140
|
+
"BLOCK_SIZE_N": 64,
|
141
|
+
"BLOCK_SIZE_K": 128,
|
142
|
+
"GROUP_SIZE_M": 1,
|
143
|
+
"num_warps": 4,
|
144
|
+
"num_stages": 3
|
145
|
+
}
|
146
|
+
}
|