sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff compares the contents of two package versions as they were publicly released to their registry. It is provided for informational purposes only and reflects the changes between the versions as published.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -11,6 +11,8 @@ import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
@@ -30,8 +32,10 @@ from sglang.srt.utils import (
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
     from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
         CompressedTensorsConfig,
     )
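These import changes track a refactor of the quantized-MoE interface: methods no longer take loose `(x, topk_output, moe_runner_config)` arguments but receive a dispatch output and return a combine input. A minimal sketch of that shape, using hypothetical stand-in classes (the real `StandardDispatchOutput`/`StandardCombineInput` live in `sglang.srt.layers.moe.token_dispatcher` and carry more fields):

```python
# Stand-in types only; the real classes are defined in
# sglang.srt.layers.moe.token_dispatcher and carry additional fields.
from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class DispatchOutputSketch:
    """What an MoE method now receives: routed activations plus routing results."""

    hidden_states: torch.Tensor  # (num_tokens, hidden_size)
    topk_output: Any  # e.g. (topk_weights, topk_ids, router_logits)


@dataclass
class CombineInputSketch:
    """What an MoE method now returns: expert outputs ready to be combined."""

    hidden_states: torch.Tensor


def apply_sketch(dispatch_output: DispatchOutputSketch) -> CombineInputSketch:
    # A real method would launch the fused-MoE kernel here; identity for the sketch.
    return CombineInputSketch(hidden_states=dispatch_output.hidden_states)
```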
@@ -293,14 +297,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         )
         torch.cuda.empty_cache()
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        moe_runner_config = self.moe_runner_config
 
         if (
             _use_aiter
@@ -308,7 +322,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             and moe_runner_config.apply_router_weight_on_input
         ):
             topk_weights, topk_ids, _ = topk_output
-
+            output = rocm_fused_experts_tkw1(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
@@ -324,21 +338,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 a1_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
             )
+            return StandardCombineInput(hidden_states=output)
         else:
-
-
-                layer.
-                layer.w2_weight,
-                topk_output=topk_output,
-                moe_runner_config=moe_runner_config,
+            quant_info = TritonMoeQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
                 use_fp8_w8a8=True,
                 per_channel_quant=self.weight_quant.strategy
                 == QuantizationStrategy.CHANNEL,
-
+                w13_scale=layer.w13_weight_scale,
                 w2_scale=layer.w2_weight_scale,
-
+                a13_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
             )
+            return self.runner.run(dispatch_output, quant_info)
 
 
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
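The hunks above split the old single `apply` into two phases: `create_moe_runner` builds a Triton-backed `MoeRunner` once at setup, and `apply` packs the layer's weights and scales into a `TritonMoeQuantInfo` and delegates to `runner.run`. A rough sketch of that lifecycle with placeholder types (the real `MoeRunner`, `MoeRunnerBackend`, and `TritonMoeQuantInfo` come from `sglang.srt.layers.moe`):

```python
import torch


class RunnerSketch:
    """Placeholder for MoeRunner(MoeRunnerBackend.TRITON, config)."""

    def __init__(self, config: dict):
        self.config = config

    def run(self, hidden_states: torch.Tensor, quant_info: dict) -> torch.Tensor:
        # The real runner launches fused Triton MoE kernels; identity here.
        return hidden_states


class MoEMethodSketch:
    def create_moe_runner(self, moe_runner_config: dict) -> None:
        # Setup phase: remember the config and build the backend runner once.
        self.moe_runner_config = moe_runner_config
        self.runner = RunnerSketch(moe_runner_config)

    def apply(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Forward phase: describe this layer's quantized weights, then delegate.
        quant_info = {"use_fp8_w8a8": True}  # stand-in for TritonMoeQuantInfo
        return self.runner.run(hidden_states, quant_info)


method = MoEMethodSketch()
method.create_moe_runner({"activation": "silu"})
out = method.apply(torch.randn(4, 8))
```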
@@ -380,8 +393,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             params_dtype == torch.float16
         ), "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
 
-        intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
-
         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
         # shard for TP along the transposed dims
@@ -415,13 +426,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         # In the case where we have actorder/g_idx,
         # we do not partition the w2 scales
         load_full_w2 = self.actorder and self.group_size != -1
-        w2_scales_size = (
-            intermediate_size_full if load_full_w2 else intermediate_size_per_partition
-        )
 
-
-            intermediate_size_per_partition
-
+        if load_full_w2:
+            w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
+        else:
+            w2_scales_size = intermediate_size_per_partition
+
+        self.is_k_full = (not self.actorder) or layer.moe_tp_size == 1
 
         if self.strategy == "channel":
             num_groups_w2 = num_groups_w13 = 1
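The replaced code derived the full w2 scale size from a popped `intermediate_size_full` attribute; the new code computes it from the per-partition size and the layer's MoE tensor-parallel degree. A worked example with illustrative numbers:

```python
# Illustrative values; in the diff these come from the layer and its config.
intermediate_size_per_partition = 1024
moe_tp_size = 4  # layer.moe_tp_size in the diff
actorder, group_size = True, 128

# With act-order and grouped quantization, the w2 scales are not partitioned,
# so the full (unsharded) size is the per-partition size times the TP degree.
load_full_w2 = actorder and group_size != -1
if load_full_w2:
    w2_scales_size = intermediate_size_per_partition * moe_tp_size  # 4096
else:
    w2_scales_size = intermediate_size_per_partition  # 1024

is_k_full = (not actorder) or moe_tp_size == 1
print(w2_scales_size, is_k_full)  # 4096 False
```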
@@ -640,21 +651,29 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         )
         replace_tensor("w2_weight_scale", marlin_w2_scales)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
 
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         topk_weights, topk_ids, router_logits = topk_output
 
-
+        output = torch.ops.vllm.fused_marlin_moe(
             x,
             layer.w13_weight_packed,
             layer.w2_weight_packed,
@@ -670,3 +689,4 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             num_bits=self.num_bits,
             is_k_full=self.is_k_full,
         )
+        return StandardCombineInput(hidden_states=output)
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

@@ -21,9 +21,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import requantize_with_max_scale
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 __all__ = ["CompressedTensorsW8A8Fp8"]
 
+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+if _use_aiter:
+    from aiter.ops.shuffle import shuffle_weight
+
 
 class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
 
@@ -76,7 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         else:
             weight_scale = layer.weight_scale.data
 
-
+        if _use_aiter:
+            layer.weight = Parameter(
+                shuffle_weight(weight, (16, 16)), requires_grad=False
+            )
+        else:
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+
         # required by torch.compile to be torch.nn.Parameter
         layer.weight_scale = Parameter(weight_scale, requires_grad=False)
 
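The scheme now branches on `SGLANG_USE_AITER` at weight-processing time: on ROCm with aiter enabled, the fp8 weight is pre-shuffled into the tile layout aiter's GEMMs expect instead of being stored transposed. A self-contained sketch of that decision (the shuffle and the env-var helper are stand-ins for `aiter.ops.shuffle.shuffle_weight(weight, (16, 16))` and `sglang.srt.utils.get_bool_env_var`):

```python
import os

import torch


def get_bool_env_var_sketch(name: str) -> bool:
    # Simplified stand-in for sglang.srt.utils.get_bool_env_var.
    return os.environ.get(name, "false").lower() in ("1", "true")


def process_weight_sketch(weight: torch.Tensor, on_hip: bool) -> torch.nn.Parameter:
    use_aiter = get_bool_env_var_sketch("SGLANG_USE_AITER") and on_hip
    if use_aiter:
        # Stand-in for shuffle_weight(weight, (16, 16)), which reorders
        # 16x16 tiles into the layout aiter GEMM kernels expect.
        shuffled = weight.contiguous()
        return torch.nn.Parameter(shuffled, requires_grad=False)
    # Default path: store the transposed weight for the regular GEMM.
    return torch.nn.Parameter(weight.t(), requires_grad=False)


param = process_weight_sketch(torch.randn(8, 16), on_hip=False)
```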
sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py

@@ -93,7 +93,7 @@ def _maybe_compile_deep_gemm_one_type_all(
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
             "Entering DeepGEMM JIT Pre-Compile session. "
-            "It may
+            "It may take a long time (typically 10-20 mins) "
             "if you have not run `sglang.compile_deep_gemm`. "
             "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
             " for pre-compilation to reduce the overhead if you have not run it before. "
@@ -132,9 +132,17 @@ def _compile_deep_gemm_one_type_all(
         kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
     )
 
+    old_compile_mode = deep_gemm.get_compile_mode()
+    deep_gemm.set_compile_mode(1)
     # TODO can use multi thread
     for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
         executor.execute(m=m)
+    deep_gemm.set_compile_mode(old_compile_mode)
+
+    # clean up input buffers
+    torch.cuda.current_stream().synchronize()
+    del executor
+    torch.cuda.empty_cache()
 
 
 class _BaseWarmupExecutor:
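The warmup loop now pins DeepGEMM's compile mode for its duration, restores the previous value afterwards, and then frees the warmup buffers. The same save/restore pattern is often written as a context manager; a sketch under the assumption that only `get_compile_mode`/`set_compile_mode` matter (the module below is a fake so the snippet runs anywhere):

```python
from contextlib import contextmanager


class FakeDeepGemm:
    """Stand-in for the deep_gemm module, just enough for the sketch."""

    _mode = 0

    @classmethod
    def get_compile_mode(cls) -> int:
        return cls._mode

    @classmethod
    def set_compile_mode(cls, mode: int) -> None:
        cls._mode = mode


@contextmanager
def compile_mode(dg, mode: int):
    old = dg.get_compile_mode()
    dg.set_compile_mode(mode)
    try:
        yield
    finally:
        # Restore the previous mode even if warmup raises.
        dg.set_compile_mode(old)


with compile_mode(FakeDeepGemm, 1):
    pass  # run warmup kernels here
assert FakeDeepGemm.get_compile_mode() == 0
```

The diff uses explicit save/restore rather than a context manager, but the effect is the same.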
sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py

@@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
     ENABLE_JIT_DEEPGEMM,
 )
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_bool_env_var
 
 logger = logging.getLogger(__name__)
 
@@ -18,6 +19,8 @@ if ENABLE_JIT_DEEPGEMM:
     import deep_gemm
     from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
 
+_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
+
 
 # TODO maybe rename these functions
 def grouped_gemm_nt_f8f8bf16_masked(
@@ -31,6 +34,9 @@ def grouped_gemm_nt_f8f8bf16_masked(
     _, n, _ = rhs[0].shape
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(
         expected_m, n, k, num_groups, kernel_type
     ):
@@ -53,6 +59,9 @@ def grouped_gemm_nt_f8f8bf16_contig(
     num_groups, n, _ = rhs[0].shape
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
         deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)
 
@@ -67,6 +76,9 @@ def gemm_nt_f8f8bf16(
     num_groups = 1
     kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
         deep_gemm.fp8_gemm_nt(
             lhs,
@@ -90,3 +102,18 @@ def configure_deep_gemm_num_sms(num_sms):
         yield
     finally:
         deep_gemm.set_num_sms(original_num_sms)
+
+
+def _sanity_check_input(x_fp8: Tuple[torch.Tensor, torch.Tensor]):
+    if not _SANITY_CHECK:
+        return
+
+    x, x_scale = x_fp8
+
+    if x_scale.dtype == torch.int:
+        return
+
+    from sglang.srt.layers.quantization.fp8_utils import ceil_to_ue8m0
+
+    x_scale_ceil = ceil_to_ue8m0(x_scale)
+    assert torch.all(x_scale == x_scale_ceil), f"{x_scale=} {x_scale_ceil=}"
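The new `_sanity_check_input`, gated by `SGLANG_DEEPGEMM_SANITY_CHECK`, asserts that every fp8 scale is already UE8M0, i.e. an exponent-only power of two, by checking that rounding it up to the next power of two is a no-op. A self-contained sketch of that property; `ceil_to_ue8m0_sketch` is a hypothetical re-implementation (the real helper is `sglang.srt.layers.quantization.fp8_utils.ceil_to_ue8m0`):

```python
import torch


def ceil_to_ue8m0_sketch(x: torch.Tensor) -> torch.Tensor:
    # Round each scale up to the next power of two: the only values an
    # E8M0 (8-bit exponent, no mantissa) scale format can represent.
    return torch.exp2(torch.ceil(torch.log2(x)))


power_of_two_scales = torch.tensor([0.25, 0.5, 1.0, 8.0])
arbitrary_scales = torch.tensor([0.3, 1.5])

# Power-of-two scales survive the round-trip unchanged...
assert torch.all(power_of_two_scales == ceil_to_ue8m0_sketch(power_of_two_scales))
# ...arbitrary scales do not (0.3 -> 0.5, 1.5 -> 2.0), so the assert would fire.
assert not torch.all(arbitrary_scales == ceil_to_ue8m0_sketch(arbitrary_scales))
```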
sglang/srt/layers/quantization/fp8.py

@@ -30,6 +30,9 @@ except ImportError:
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
+from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
 from sglang.srt.layers.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,
@@ -81,7 +84,11 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        DispatchOutput,
+        StandardDispatchOutput,
+    )
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 
@@ -345,6 +352,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
             ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["weight"])
+            layer.weight_scale_inv = torch.nn.Parameter(
+                layer.weight_scale_inv.data, requires_grad=False
+            )
             return
         else:
             weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
@@ -527,7 +537,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer: Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -543,18 +553,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )
         # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
         # Required by column parallel or enabling merged weights
-        if
+        if intermediate_size_per_partition % block_n != 0:
             raise ValueError(
                 f"The output_size of gate's and up's weight = "
-                f"{
+                f"{intermediate_size_per_partition} is not divisible by "
                 f"weight quantization block_n = {block_n}."
             )
         if tp_size > 1:
             # Required by row parallel
-            if
+            if intermediate_size_per_partition % block_k != 0:
                 raise ValueError(
                     f"The input_size of down's weight = "
-                    f"{
+                    f"{intermediate_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}."
                 )
 
@@ -564,7 +574,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w13_weight = torch.nn.Parameter(
                 torch.empty(
                     num_experts,
-                    2 *
+                    2 * intermediate_size_per_partition,
                     hidden_size // 8,
                     dtype=params_dtype,
                 ),
@@ -572,20 +582,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             w2_weight = torch.nn.Parameter(
                 torch.empty(
-                    num_experts,
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition // 8,
+                    dtype=params_dtype,
                 ),
                 requires_grad=False,
             )
         else:
             w13_weight = torch.nn.Parameter(
                 torch.empty(
-                    num_experts,
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    hidden_size,
+                    dtype=params_dtype,
                 ),
                 requires_grad=False,
             )
             w2_weight = torch.nn.Parameter(
                 torch.empty(
-                    num_experts,
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition,
+                    dtype=params_dtype,
                 ),
                 requires_grad=False,
             )
@@ -601,7 +620,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w13_weight_scale = torch.nn.Parameter(
                 torch.ones(
                     num_experts,
-                    2 * ((
+                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                     (hidden_size + block_k - 1) // block_k,
                     dtype=torch.float32,
                 ),
@@ -611,7 +630,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 torch.ones(
                     num_experts,
                     (hidden_size + block_n - 1) // block_n,
-                    (
+                    (intermediate_size_per_partition + block_k - 1) // block_k,
                     dtype=torch.float32,
                 ),
                 requires_grad=False,
@@ -619,11 +638,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
-            if
-                get_bool_env_var("SGLANG_CUTLASS_MOE")
-                and self.cutlass_fp8_supported
-                and (is_sm100_supported() or is_sm90_supported())
-            ):
+            if self.use_cutlass_fused_experts_fp8:
                 self.ab_strides1 = torch.full(
                     (num_experts,),
                     hidden_size,
@@ -632,13 +647,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 )
                 self.c_strides1 = torch.full(
                     (num_experts,),
-                    2 *
+                    2 * intermediate_size_per_partition,
                     device=w13_weight.device,
                     dtype=torch.int64,
                 )
                 self.ab_strides2 = torch.full(
                     (num_experts,),
-
+                    intermediate_size_per_partition,
                     device=w2_weight.device,
                     dtype=torch.int64,
                 )
@@ -691,7 +706,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if _is_hip:  # _use_aiter: TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
-                torch.ones(
+                torch.ones(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    dtype=torch.float32,
+                ),
                 requires_grad=False,
             )
             w2_weight_scale1 = torch.nn.Parameter(
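These hunks are mostly a rename (`intermediate_size` → `intermediate_size_per_partition`) plus previously-truncated shape expressions restored in full. The resulting tensor shapes, worked through with illustrative numbers:

```python
# Illustrative sizes; block_n/block_k are the fp8 block-quantization tiles.
num_experts = 8
hidden_size = 4096
intermediate_size_per_partition = 1536  # this TP rank's slice
block_n = block_k = 128

# Gate+up projection weights are fused, hence the factor of 2.
w13_weight_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
w2_weight_shape = (num_experts, hidden_size, intermediate_size_per_partition)

# One scale per (block_n x block_k) tile, rounded up at the edges.
w13_scale_shape = (
    num_experts,
    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),  # 24
    (hidden_size + block_k - 1) // block_k,  # 32
)
w2_scale_shape = (
    num_experts,
    (hidden_size + block_n - 1) // block_n,  # 32
    (intermediate_size_per_partition + block_k - 1) // block_k,  # 12
)
print(w13_weight_shape, w2_weight_shape, w13_scale_shape, w2_scale_shape)
```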
@@ -984,14 +1003,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )
         torch.cuda.empty_cache()
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
-
+        dispatch_output: DispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
 
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
@@ -1001,7 +1029,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
 
-
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -1017,6 +1045,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 None,  # a2_scale
                 True,  # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)
 
         if _is_hip:
             ret = self.maybe_apply_hip_fused_experts(
@@ -1027,7 +1056,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 moe_runner_config.no_combine,
             )
             if ret is not None:
-                return ret
+                return StandardCombineInput(hidden_states=ret)
 
         if self.use_cutlass_fused_experts_fp8:
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
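`Fp8MoEMethod.apply` now has several early exits — the Intel AMX CPU path, the HIP fused-experts path, and the CUTLASS fp8 path — each wrapping its raw kernel output in a `StandardCombineInput` before returning, with the Triton runner as the fallback. A compact sketch of that control flow (the backend predicates and `wrap` are placeholders, not sglang APIs):

```python
from typing import Optional

import torch


def wrap(hidden_states: torch.Tensor) -> torch.Tensor:
    # Stand-in for StandardCombineInput(hidden_states=...).
    return hidden_states


def apply_sketch(
    x: torch.Tensor,
    use_cpu_amx: bool = False,
    is_hip: bool = False,
    use_cutlass: bool = False,
) -> torch.Tensor:
    if use_cpu_amx:
        output = x  # stand-in for torch.ops.sgl_kernel.fused_experts_cpu(...)
        return wrap(output)
    if is_hip:
        ret: Optional[torch.Tensor] = x  # maybe_apply_hip_fused_experts(...)
        if ret is not None:
            return wrap(ret)
    if use_cutlass:
        output = x  # stand-in for cutlass_fused_experts_fp8(...)
        return wrap(output)
    # Fallback: pack quant_info and delegate to the Triton MoeRunner.
    return wrap(x)  # stand-in for self.runner.run(dispatch_output, quant_info)


out = apply_sketch(torch.randn(4, 8), is_hip=True)
```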
@@ -1056,17 +1085,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 self.problem_sizes2,
                 use_fp8_blockscale=True,
             )
-
-
-
-
-
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+            return StandardCombineInput(hidden_states=output)
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-
+            w13_scale=(
                 layer.w13_weight_scale_inv
                 if self.block_quant
                 else layer.w13_weight_scale
@@ -1074,20 +1099,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w2_scale=(
                 layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
             ),
-
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
         )
+        return self.runner.run(dispatch_output, quant_info)
 
     def apply_with_router_logits(
         self,
         layer: torch.nn.Module,
-
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
+        dispatch_output: StandardDispatchOutput,
     ) -> torch.Tensor:
-
-
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        activation = self.moe_runner_config.activation
+        routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
 
         from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
 
@@ -1108,10 +1135,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             and topk_config.topk_group is not None
         ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"
 
-
-
-
-        correction_bias
+        correction_bias = (
+            None
+            if topk_config.correction_bias is None
+            else topk_config.correction_bias.to(x.dtype)
+        )
+
         return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
             routing_bias=correction_bias,