sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py
CHANGED

@@ -18,7 +18,14 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F

-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location_dispatch import (
+    ExpertLocationDispatchInfo,
+    topk_ids_logical_to_physical,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip

@@ -32,9 +39,6 @@ if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax


-expert_distribution_recorder = ExpertDistributionRecorder()
-
-
 def fused_topk_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -61,6 +65,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -84,7 +89,7 @@ def fused_topk(

     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     return topk_weights, topk_ids


@@ -99,6 +104,8 @@ def grouped_topk(
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -138,7 +145,10 @@ def grouped_topk(
     )
     topk_weights = topk_weights / topk_weights_sum

-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids


 def biased_grouped_topk_impl(
@@ -151,6 +161,8 @@ def biased_grouped_topk_impl(
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -197,13 +209,26 @@ def biased_grouped_topk_impl(
     )
     topk_weights = topk_weights / topk_weights_sum

-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids


 def is_power_of_two(n):
     return n > 0 and math.log2(n).is_integer()


+def _mask_topk_ids_padded_region(
+    topk_ids: torch.Tensor,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+):
+    if num_token_non_padded is None:
+        return
+    indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
+    topk_ids[indices >= num_token_non_padded, :] = -1
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -215,6 +240,8 @@ def biased_grouped_topk(
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert (
         routed_scaling_factor is not None
@@ -226,7 +253,7 @@ def biased_grouped_topk(
         <= 32  # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
         and is_power_of_two(correction_bias.shape[0])
     ):
-        return moe_fused_gate(
+        topk_weights, topk_ids = moe_fused_gate(
            gating_output,
            correction_bias,
            num_expert_group,
@@ -235,6 +262,15 @@ def biased_grouped_topk(
            n_share_experts_fusion,
            routed_scaling_factor,
        )
+        # TODO merge into kernel for this branch
+        topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+        # TODO will fuse this into kernel, thus use slow manual operation now
+        if num_token_non_padded is None:
+            return topk_weights, topk_ids
+        torch.compile(
+            _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
+        )(topk_ids, num_token_non_padded)
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
             torch.compile(
@@ -253,6 +289,8 @@ def biased_grouped_topk(
             topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
             routed_scaling_factor=routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
         )


@@ -268,9 +306,11 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
-    #
+    # DeepSeek V2/V3/R1 series models use grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
@@ -284,6 +324,8 @@ def select_experts(
             topk_group=topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
             routed_scaling_factor=routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
         )
     else:
         topk_weights, topk_ids = biased_grouped_topk(
@@ -296,8 +338,14 @@ def select_experts(
             topk_group=topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
             routed_scaling_factor=routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
         )
     elif torch_native and custom_routing_function is None:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in fused_topk_native"
+        assert expert_location_dispatch_info is None
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -305,13 +353,22 @@ def select_experts(
             renormalize=renormalize,
         )
     elif custom_routing_function is None:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in fused_topk"
+        # Qwen3MOE uses fused_topk
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,
             gating_output=router_logits,
             topk=top_k,
             renormalize=renormalize,
+            expert_location_dispatch_info=expert_location_dispatch_info,
         )
     else:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in custom_routing_function"
+        assert expert_location_dispatch_info is None
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -319,6 +376,6 @@ def select_experts(
             renormalize=renormalize,
         )

-
+    get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

     return topk_weights, topk_ids
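Note on the topk.py hunks above: every routing path now threads through two new optional arguments. expert_location_dispatch_info remaps logical expert ids to physical ones via topk_ids_logical_to_physical (presumably for the expert-location / EPLB machinery added elsewhere in this release, e.g. expert_location_dispatch.py and eplb_manager.py in the file list), and num_token_non_padded lets _mask_topk_ids_padded_region overwrite the expert ids of padded rows with -1 so later dispatch stages can skip them. A minimal, self-contained sketch of the masking behavior with toy values (not sglang's real call path; in the diff the helper may additionally be wrapped in torch.compile):

import torch


def mask_topk_ids_padded_region(topk_ids, num_token_non_padded):
    # Mirrors _mask_topk_ids_padded_region above: rows at or beyond the count
    # of real (non-padded) tokens get all their expert ids set to -1.
    if num_token_non_padded is None:
        return
    indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
    topk_ids[indices >= num_token_non_padded, :] = -1


topk_ids = torch.tensor([[0, 3], [2, 1], [5, 4], [7, 6]])  # 4 tokens, top-2 experts each
mask_topk_ids_padded_region(topk_ids, torch.tensor(3))     # only the first 3 tokens are real
print(topk_ids)                                            # last row becomes [-1, -1]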
sglang/srt/layers/multimodal.py
ADDED

@@ -0,0 +1,70 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Logits processing."""
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def hash_kernel(
+    input_ptr,
+    output_ptr,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+    PRIME: tl.constexpr,
+    XCONST: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    data = tl.load(input_ptr + offsets, mask=mask, other=0)
+    mixed = data ^ (offsets + XCONST)
+    hash_val = mixed * PRIME
+    hash_val = hash_val ^ (hash_val >> 16)
+    hash_val = hash_val * (PRIME ^ XCONST)
+    hash_val = hash_val ^ (hash_val >> 13)
+
+    tl.store(output_ptr + offsets, hash_val, mask=mask)
+
+
+PRIME_1 = -(11400714785074694791 ^ 0xFFFFFFFFFFFFFFFF) - 1
+PRIME_2 = -(14029467366897019727 ^ 0xFFFFFFFFFFFFFFFF) - 1
+
+
+def gpu_tensor_hash(tensor: torch.Tensor) -> int:
+    assert tensor.is_cuda
+    tensor = tensor.contiguous().view(torch.int32)
+    n = tensor.numel()
+    BLOCK_SIZE = 1024
+    grid = (triton.cdiv(n, BLOCK_SIZE),)
+
+    intermediate_hashes = torch.empty(n, dtype=torch.int32, device=tensor.device)
+
+    hash_kernel[grid](
+        tensor,
+        intermediate_hashes,
+        n,
+        BLOCK_SIZE=BLOCK_SIZE,
+        PRIME=PRIME_1,
+        XCONST=PRIME_2,
+    )
+
+    # TODO: threads can't be synced on triton kernel
+    final_hash = intermediate_hashes.sum().item()
+
+    return final_hash
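The 70-line file added above (taken to be sglang/srt/layers/multimodal.py from the "+70 -0" entry in the file list) provides gpu_tensor_hash, a Triton-based fingerprint for CUDA tensors: each int32 word is mixed with its offset and two constants (which appear to be the 64-bit xxHash primes written as signed integers), and the per-element results are summed. A hedged usage sketch; the shape is arbitrary and the import path is an assumption:

import torch

from sglang.srt.layers.multimodal import gpu_tensor_hash  # path assumed from the file list

# Needs a CUDA device plus triton. The helper views the buffer as int32, so the
# last dimension's byte count must be divisible by 4 (224 bytes per row here).
a = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8, device="cuda")
b = a.clone()
assert gpu_tensor_hash(a) == gpu_tensor_hash(b)  # identical contents, identical hash
# A fast, non-cryptographic key (presumably for the new multimodal cache, see
# sglang/srt/mem_cache/multimodal_cache.py in the file list), not a collision-resistant hash.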
sglang/srt/layers/quantization/__init__.py
CHANGED

@@ -25,7 +25,6 @@ try:
     from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
     from vllm.model_executor.layers.quantization.gptq_marlin import (
         GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
     )
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
@@ -58,12 +57,17 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
     CompressedTensorsConfig,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
+from sglang.srt.layers.quantization.gptq import (
+    GPTQConfig,
+    GPTQMarlinConfig,
+    GPTQMarlinMoEMethod,
+)
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
     ModelOptFp8Config,
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+from sglang.srt.layers.quantization.qoq import QoQConfig
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

@@ -77,6 +81,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "w8a8_fp8": W8A8Fp8Config,
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
+    "qoq": QoQConfig,
 }

 # VLLM-dependent quantization methods
@@ -109,7 +114,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
         raise ValueError(
             f"{quantization} quantization requires some operators from vllm. "
-            "
+            "Please install vllm by `pip install vllm==0.8.4`"
         )

     return QUANTIZATION_METHODS[quantization]
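With the "qoq" entry registered above, a quantization backend class can be resolved by name through get_quantization_config. A small sketch using only keys visible in this diff (whether vllm-dependent methods resolve depends on the local vllm install):

from sglang.srt.layers.quantization import get_quantization_config

qoq_cls = get_quantization_config("qoq")                # QoQConfig, new in this release
ct_cls = get_quantization_config("compressed-tensors")  # CompressedTensorsConfig
print(qoq_cls.__name__, ct_cls.__name__)
# Requesting a vllm-dependent method without vllm installed now raises ValueError
# with the updated hint: "Please install vllm by `pip install vllm==0.8.4`".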
sglang/srt/layers/quantization/blockwise_int8.py
CHANGED

@@ -152,7 +152,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
                 f"{input_size_per_partition} is not divisible by "
                 f"weight quantization block_k = {block_k}."
             )
-        # Required by
+        # Required by column parallel or enabling merged weights
         if (tp_size > 1 and output_size // output_size_per_partition == tp_size) or len(
             output_partition_sizes
         ) > 1:
@@ -285,7 +285,7 @@ class BlockInt8MoEMethod:
             self.quant_config.weight_block_size[1],
         )
         # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-        # Required by
+        # Required by column parallel or enabling merged weights
         if intermediate_size % block_n != 0:
             raise ValueError(
                 f"The output_size of gate's and up's weight = "
sglang/srt/layers/quantization/deep_gemm.py
CHANGED

@@ -11,30 +11,29 @@ from tqdm.contrib.concurrent import thread_map
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda

+logger = logging.getLogger(__name__)
 _ENABLE_JIT_DEEPGEMM = False
-
+
+try:
     import deep_gemm
     from deep_gemm import get_num_sms
+    from deep_gemm.jit.compiler import get_nvcc_compiler
     from deep_gemm.jit_kernels.gemm import get_best_configs
-    from deep_gemm.jit_kernels.
-    from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
-    from deep_gemm.jit_kernels.m_grouped_gemm import (
-        template as deep_gemm_grouped_gemm_template,
-    )
+    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
     from deep_gemm.jit_kernels.tuner import jit_tuner

     sm_version = get_device_sm()
     if sm_version == 90:
         if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
             _ENABLE_JIT_DEEPGEMM = True
+except ImportError:
+    logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")


 def get_enable_jit_deepgemm():
     return _ENABLE_JIT_DEEPGEMM


-logger = logging.getLogger(__name__)
-
 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
 _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
     "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
@@ -45,10 +44,25 @@ _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
 _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

 # Force redirect deep_gemm cache_dir
-os.environ["
-    "SGL_DG_CACHE_DIR", os.path.expanduser("~")
+os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
+    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )

+# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
+# NVRTC may have performance loss with some cases.
+# And NVCC JIT speed is also 9x faster in the ref commit
+_USE_NVRTC_DEFAULT = "0"
+if _ENABLE_JIT_DEEPGEMM:
+    try:
+        get_nvcc_compiler()
+    except:
+        logger.warning(
+            "NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
+            "and may have performance loss with some cases."
+        )
+        _USE_NVRTC_DEFAULT = "1"
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
+

 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     global _BUILTIN_M_LIST
@@ -103,10 +117,10 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
 def _compile_warning_1():
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
-            "Entering DeepGEMM JIT Pre-
+            "Entering DeepGEMM JIT Pre-Compile session. "
             "And it may takes a long time(Typically 10-20 mins) "
             "if you have not run `sglang.compile_deep_gemm`. "
-            "
+            "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
             " for pre-compilation to reduce the overhead if you have not run it before. "
             "For example: "
             "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
@@ -115,7 +129,7 @@ def _compile_warning_1():

 def _compile_warning_2():
     logger.warning(
-        "Entering DeepGEMM JIT Single Kernel
+        "Entering DeepGEMM JIT Single Kernel Compile session. "
         "And it will makes inference throughput becomes flaky. "
         "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
         " for pre-compilation to solve this issue. "
@@ -130,10 +144,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-
-
-
-
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -146,24 +168,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE":
+            "GEMM_TYPE": GemmType.GroupedMasked,
         },
         space=(),
-
-
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


@@ -173,9 +182,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-
-
-
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -188,25 +206,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE":
+            "GEMM_TYPE": GemmType.GroupedContiguous,
         },
         space=(),
-
-
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("num_groups", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


@@ -216,9 +220,20 @@ def _compile_gemm_nt_f8f8bf16_one(
     _: int,  # _ is a dummy parameter to align with other interfaces
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-
-
-
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "GEMM_TYPE": GemmType.Normal,
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "NUM_GROUPS": 1,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -232,20 +247,8 @@ def _compile_gemm_nt_f8f8bf16_one(
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
         },
         space=(),
-
-
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


@@ -298,7 +301,7 @@ def _maybe_compile_deep_gemm_one_type_all(
     logger.info(
         f"Try DeepGEMM JIT Compiling for "
         f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
-        f"{' It only takes a
+        f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
     )

     # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
@@ -373,7 +376,7 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):

     from deep_gemm.jit.runtime import RuntimeCache

-    origin_func = RuntimeCache.
+    origin_func = RuntimeCache.get

     def __patched_func(self, *args, **kwargs):
         ret = origin_func(self, *args, **kwargs)
@@ -385,6 +388,6 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
         )
         return ret

-    RuntimeCache.
+    RuntimeCache.get = __patched_func
     yield
-    RuntimeCache.
+    RuntimeCache.get = origin_func
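The DeepGEMM knobs above (SGL_ENABLE_JIT_DEEPGEMM, SGL_DG_CACHE_DIR, SGL_DG_USE_NVRTC, SGL_JIT_DEEPGEMM_PRECOMPILE, SGL_JIT_DEEPGEMM_COMPILE_WORKERS) are read once at module import time, so they must be in the environment before sglang imports this module. A hedged sketch of configuring them from Python; the cache path is a placeholder and the values shown simply restate the defaults visible in the diff:

import os

os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "true"        # JIT DeepGEMM on SM90 GPUs
os.environ["SGL_DG_CACHE_DIR"] = "/tmp/deep_gemm"     # placeholder path; feeds DG_JIT_CACHE_DIR
os.environ["SGL_DG_USE_NVRTC"] = "0"                  # "0" = NVCC (default when found), "1" = NVRTC
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "true"
os.environ["SGL_JIT_DEEPGEMM_COMPILE_WORKERS"] = "4"

# Import only after the environment is set; the module-level os.getenv calls above
# run on first import.
import sglang.srt.layers.quantization.deep_gemm  # noqa: E402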