sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
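The JSON above is one of the newly added fused-MoE tuning tables: each top-level key is a token count M, and its value holds the Triton launch parameters tuned for that batch size on the named GPU. As a rough sketch of how such a table can be consumed (the helper names are illustrative, not sglang's actual API, and the nearest-key lookup is an assumption about the selection rule):

import json

def load_moe_kernel_config(path: str) -> dict:
    """Read a tuned fused-MoE config file and key it by integer batch size."""
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def pick_config(configs: dict, m: int) -> dict:
    """Pick the tuned entry whose batch-size key is closest to the actual M."""
    best_m = min(configs, key=lambda k: abs(k - m))
    return configs[best_m]

# Example: choose launch parameters for a batch of 300 tokens.
# configs = load_moe_kernel_config("E=264,N=256,...,block_shape=[128, 128].json")
# params = pick_config(configs, 300)  # -> BLOCK_SIZE_M, GROUP_SIZE_M, num_warps, ...

Keeping a per-M table lets the runtime switch tile sizes as the batch grows (in this file BLOCK_SIZE_M jumps from 16 to 64 around M=1024) without re-tuning at serving time.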
@@ -29,6 +29,7 @@ from sglang.srt.utils import (
     get_device_name,
     is_cuda,
     is_hip,
+    log_info_on_rank0,
 )

 _is_hip = is_hip()

@@ -945,7 +946,9 @@ def get_moe_configs(
             # For example, updating the Triton version might cause all old configs to become suboptimal.
             # To achieve the best performance, consider re-tuning the Triton fused MOE kernel in your environment.
             # For the tuning method, refer to: https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
-
+            log_info_on_rank0(
+                logger, f"Using MoE kernel config from {config_file_path}."
+            )
             # If a configuration has been found, return it
             return {int(key): val for key, val in json.load(f).items()}

@@ -991,7 +994,7 @@ def get_default_config(
            "num_stages": 2 if _is_hip else 4,
        }
    else:
-        # Block-wise quant: BLOCK_SIZE_K must be
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_shape[0],
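The first two hunks above route the "Using MoE kernel config ..." message through a new log_info_on_rank0 helper so it is emitted once rather than once per rank. A minimal sketch of what such a helper typically looks like, assuming torch.distributed supplies the rank (an illustration, not the actual implementation in sglang.srt.utils):

import logging
import torch.distributed as dist

def log_info_on_rank0(logger: logging.Logger, msg: str) -> None:
    # Only rank 0 of the initialized default process group emits the message,
    # so multi-GPU runs do not repeat it once per rank.
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        logger.info(msg)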
sglang/srt/layers/moe/topk.py CHANGED
@@ -270,7 +270,7 @@ def select_experts(
     routed_scaling_factor: Optional[float] = None,
 ):
     n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
-    #
+    # DeepSeek V2/V3/R1 series models use grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
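The restored comment notes that DeepSeek V2/V3/R1 models route experts with grouped_top_k. For context, a hedged sketch of group-limited routing in plain PyTorch (the general idea only; sglang's select_experts adds correction biases, renormalization, and fused kernels):

import torch

def grouped_topk_sketch(scores: torch.Tensor, num_expert_group: int,
                        topk_group: int, top_k: int):
    """scores: [num_tokens, num_experts] router scores (e.g. after softmax)."""
    num_tokens, num_experts = scores.shape
    # Score each expert group by its best expert, keep only the top groups.
    group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    score_mask = group_mask.unsqueeze(-1).expand(
        num_tokens, num_expert_group, num_experts // num_expert_group
    ).reshape(num_tokens, num_experts)
    # Pick the per-token top-k experts only from the surviving groups.
    masked = scores.masked_fill(score_mask == 0, float("-inf"))
    topk_weights, topk_ids = torch.topk(masked, k=top_k, dim=-1)
    return topk_weights, topk_ids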
@@ -109,7 +109,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
         raise ValueError(
             f"{quantization} quantization requires some operators from vllm. "
-            "
+            "Please install vllm by `pip install vllm==0.8.4`"
         )

     return QUANTIZATION_METHODS[quantization]
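This hunk completes the error message with a concrete install hint for quantization methods that still rely on vllm operators. The surrounding guard is a common optional-dependency pattern; a self-contained sketch with placeholder entries (the real registry and method sets live in sglang.srt.layers.quantization):

import importlib.util

VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None

# Illustrative registry: method name -> config object (placeholder strings here).
QUANTIZATION_METHODS = {"fp8": "Fp8Config", "awq": "AWQConfig"}
VLLM_QUANTIZATION_METHODS = {"awq"}  # methods that still need vllm kernels (example)

def get_quantization_config(quantization: str):
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")
    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
        raise ValueError(
            f"{quantization} quantization requires some operators from vllm. "
            "Please install vllm by `pip install vllm==0.8.4`"
        )
    return QUANTIZATION_METHODS[quantization]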
@@ -152,7 +152,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
                 f"{input_size_per_partition} is not divisible by "
                 f"weight quantization block_k = {block_k}."
             )
-        # Required by
+        # Required by column parallel or enabling merged weights
         if (tp_size > 1 and output_size // output_size_per_partition == tp_size) or len(
             output_partition_sizes
         ) > 1:

@@ -285,7 +285,7 @@ class BlockInt8MoEMethod:
             self.quant_config.weight_block_size[1],
         )
         # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-        # Required by
+        # Required by column parallel or enabling merged weights
         if intermediate_size % block_n != 0:
             raise ValueError(
                 f"The output_size of gate's and up's weight = "
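Both hunks finish the same truncated comment: the checks exist because column-parallel sharding and merged gate/up weights split the output dimension, and every shard must still cover whole quantization blocks. A small illustrative check (values below are examples, not taken from the diff):

def check_blockwise_partition(output_size: int, tp_size: int, block_n: int) -> None:
    """Each tensor-parallel shard of a block-quantized weight must cover whole
    blocks, otherwise one block's scale would span two shards."""
    output_size_per_partition = output_size // tp_size
    if output_size_per_partition % block_n != 0:
        raise ValueError(
            f"output_size_per_partition={output_size_per_partition} is not "
            f"divisible by weight quantization block_n={block_n}."
        )

# Example: a 4096-wide projection sharded over tp_size=4 with 128x128 blocks
# passes, because 1024 % 128 == 0.
check_blockwise_partition(4096, tp_size=4, block_n=128)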
@@ -10,16 +10,14 @@ import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy

-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
-    is_cuda,
-    is_fp8_fnuz,
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs

 _is_cuda = is_cuda()

@@ -15,11 +15,12 @@ from sglang.srt.layers.parameter import (
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.layers.quantization.utils import
+from sglang.srt.layers.quantization.utils import requantize_with_max_scale

 __all__ = ["CompressedTensorsW8A8Fp8"]

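These two hunks move is_fp8_fnuz next to the other FP8 kernel helpers. The distinction matters because some ROCm GPUs use the e4m3fnuz FP8 variant, whose numeric range differs from the e4m3fn used on NVIDIA hardware, which is why normalize_e4m3fn_to_e4m3fnuz is imported alongside it. A rough sketch of such a platform probe (illustrative only, not sglang's implementation):

import torch

def is_fp8_fnuz_sketch() -> bool:
    # ROCm (HIP) builds of PyTorch commonly use the "fnuz" FP8 variant, whose
    # finite range differs from e4m3fn, so scales may need adjusting when
    # converting checkpoints between the two formats.
    return torch.version.hip is not None and hasattr(torch, "float8_e4m3fnuz")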
@@ -15,12 +15,9 @@ _ENABLE_JIT_DEEPGEMM = False
 if is_cuda():
     import deep_gemm
     from deep_gemm import get_num_sms
+    from deep_gemm.jit.compiler import get_nvcc_compiler
     from deep_gemm.jit_kernels.gemm import get_best_configs
-    from deep_gemm.jit_kernels.
-    from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
-    from deep_gemm.jit_kernels.m_grouped_gemm import (
-        template as deep_gemm_grouped_gemm_template,
-    )
+    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
     from deep_gemm.jit_kernels.tuner import jit_tuner

 sm_version = get_device_sm()

@@ -28,6 +25,11 @@ if is_cuda():
     if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
         _ENABLE_JIT_DEEPGEMM = True

+
+def get_enable_jit_deepgemm():
+    return _ENABLE_JIT_DEEPGEMM
+
+
 logger = logging.getLogger(__name__)

 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))

@@ -40,10 +42,25 @@ _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
 _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

 # Force redirect deep_gemm cache_dir
-os.environ["
-    "SGL_DG_CACHE_DIR", os.path.expanduser("~")
+os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
+    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )

+# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
+# NVRTC may have performance loss with some cases.
+# And NVCC JIT speed is also 9x faster in the ref commit
+_USE_NVRTC_DEFAULT = "0"
+if _ENABLE_JIT_DEEPGEMM:
+    try:
+        get_nvcc_compiler()
+    except:
+        logger.warning(
+            "NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
+            "and may have performance loss with some cases."
+        )
+        _USE_NVRTC_DEFAULT = "1"
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
+

 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     global _BUILTIN_M_LIST

@@ -98,10 +115,10 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic
 def _compile_warning_1():
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
-            "Entering DeepGEMM JIT Pre-
+            "Entering DeepGEMM JIT Pre-Compile session. "
             "And it may takes a long time(Typically 10-20 mins) "
             "if you have not run `sglang.compile_deep_gemm`. "
-            "
+            "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
             " for pre-compilation to reduce the overhead if you have not run it before. "
             "For example: "
             "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"

@@ -110,7 +127,7 @@ def _compile_warning_1():

 def _compile_warning_2():
     logger.warning(
-        "Entering DeepGEMM JIT Single Kernel
+        "Entering DeepGEMM JIT Single Kernel Compile session. "
         "And it will makes inference throughput becomes flaky. "
         "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
         " for pre-compilation to solve this issue. "

@@ -125,10 +142,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-
-
-
-
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,

@@ -141,24 +166,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE":
+            "GEMM_TYPE": GemmType.GroupedMasked,
         },
         space=(),
-
-
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


@@ -168,9 +180,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-
-
-
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,

@@ -183,25 +204,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE":
+            "GEMM_TYPE": GemmType.GroupedContiguous,
         },
         space=(),
-
-
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("num_groups", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


@@ -211,9 +218,20 @@ def _compile_gemm_nt_f8f8bf16_one(
     _: int,  # _ is a dummy parameter to align with other interfaces
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-
-
-
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "GEMM_TYPE": GemmType.Normal,
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "NUM_GROUPS": 1,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,

@@ -227,20 +245,8 @@ def _compile_gemm_nt_f8f8bf16_one(
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
         },
         space=(),
-
-
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


@@ -293,7 +299,7 @@ def _maybe_compile_deep_gemm_one_type_all(
         logger.info(
             f"Try DeepGEMM JIT Compiling for "
             f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
-            f"{' It only takes a
+            f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
         )

         # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced

@@ -368,7 +374,7 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):

     from deep_gemm.jit.runtime import RuntimeCache

-    origin_func = RuntimeCache.
+    origin_func = RuntimeCache.get

     def __patched_func(self, *args, **kwargs):
         ret = origin_func(self, *args, **kwargs)

@@ -380,6 +386,6 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
         )
         return ret

-    RuntimeCache.
+    RuntimeCache.get = __patched_func
     yield
-    RuntimeCache.
+    RuntimeCache.get = origin_func
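The last two hunks complete `_log_jit_build`: it temporarily replaces `RuntimeCache.get` with a wrapper so that a cache miss (which triggers a JIT build) can be logged, then restores the original afterwards. A generic sketch of that patch-and-restore pattern as a context manager (illustrative, not the exact sglang code, which also knows the M/N/K shape and kernel type):

from contextlib import contextmanager

@contextmanager
def log_cache_misses(cache_cls, logger):
    """Temporarily wrap cache_cls.get so that misses (None results) are logged."""
    origin_get = cache_cls.get

    def patched_get(self, *args, **kwargs):
        ret = origin_get(self, *args, **kwargs)
        if ret is None:
            logger.warning("cache miss: a JIT build will be triggered")
        return ret

    cache_cls.get = patched_get
    try:
        yield
    finally:
        # Always restore the original method, even if the body raises.
        cache_cls.get = origin_get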