sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0

```diff
--- a/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -11,6 +11,8 @@ import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
@@ -30,8 +32,10 @@ from sglang.srt.utils import (
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
     from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
         CompressedTensorsConfig,
     )
@@ -293,14 +297,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         )
         torch.cuda.empty_cache()
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        moe_runner_config = self.moe_runner_config
 
         if (
             _use_aiter
@@ -308,7 +322,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             and moe_runner_config.apply_router_weight_on_input
         ):
             topk_weights, topk_ids, _ = topk_output
-
+            output = rocm_fused_experts_tkw1(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
@@ -324,21 +338,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 a1_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
             )
+            return StandardCombineInput(hidden_states=output)
         else:
-
-
-                layer.
-                layer.w2_weight,
-                topk_output=topk_output,
-                moe_runner_config=moe_runner_config,
+            quant_info = TritonMoeQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
                 use_fp8_w8a8=True,
                 per_channel_quant=self.weight_quant.strategy
                 == QuantizationStrategy.CHANNEL,
-
+                w13_scale=layer.w13_weight_scale,
                 w2_scale=layer.w2_weight_scale,
-
+                a13_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
             )
+            return self.runner.run(dispatch_output, quant_info)
 
 
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
@@ -380,8 +393,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             params_dtype == torch.float16
         ), "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
 
-        intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
-
         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
         # shard for TP along the transposed dims
@@ -415,13 +426,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         # In the case where we have actorder/g_idx,
         # we do not partition the w2 scales
         load_full_w2 = self.actorder and self.group_size != -1
-        w2_scales_size = (
-            intermediate_size_full if load_full_w2 else intermediate_size_per_partition
-        )
 
-
-            intermediate_size_per_partition
-
+        if load_full_w2:
+            w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
+        else:
+            w2_scales_size = intermediate_size_per_partition
+
+        self.is_k_full = (not self.actorder) or layer.moe_tp_size == 1
 
         if self.strategy == "channel":
             num_groups_w2 = num_groups_w13 = 1
@@ -640,21 +651,29 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         )
         replace_tensor("w2_weight_scale", marlin_w2_scales)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
 
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         topk_weights, topk_ids, router_logits = topk_output
 
-
+        output = torch.ops.vllm.fused_marlin_moe(
             x,
             layer.w13_weight_packed,
             layer.w2_weight_packed,
@@ -670,3 +689,4 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             num_bits=self.num_bits,
             is_k_full=self.is_k_full,
         )
+        return StandardCombineInput(hidden_states=output)
```
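
The hunks above reflect the wider 0.5.2 MoE refactor visible in the file list (the new `sglang/srt/layers/moe/moe_runner/` modules and the reworked `token_dispatcher`): a quantization method now implements `create_moe_runner()` and an `apply()` that takes a `StandardDispatchOutput` and returns a `CombineInput`, instead of receiving raw hidden states and top-k results. Below is a minimal sketch, not part of the wheel, of a method written against that interface; the class name `SketchFp8MoEMethod` is invented, and it assumes `layer` already carries fp8 weights and scales under the attribute names used in the diff.

```python
# Hypothetical sketch of the 0.5.2-style MoE quant-method interface shown above.
# Only the sglang imports are taken from the diff; everything else is illustrative.
import torch

from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher import (
    CombineInput,
    StandardDispatchOutput,
)


class SketchFp8MoEMethod:
    """Illustrative only: mirrors the create_moe_runner()/apply() split above."""

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # Keep the config and build a Triton-backed runner once per layer.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(
        self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput
    ) -> CombineInput:
        # Weights and scales travel in a quant-info struct; the runner consumes
        # the dispatch output directly and produces the combine input.
        quant_info = TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
            use_fp8_w8a8=True,
            per_channel_quant=True,  # assumption: per-channel fp8 weights
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )
        return self.runner.run(dispatch_output, quant_info)
```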
```diff
--- a/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -21,9 +21,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import requantize_with_max_scale
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 __all__ = ["CompressedTensorsW8A8Fp8"]
 
+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+if _use_aiter:
+    from aiter.ops.shuffle import shuffle_weight
+
 
 class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
 
@@ -76,7 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         else:
             weight_scale = layer.weight_scale.data
 
-
+        if _use_aiter:
+            layer.weight = Parameter(
+                shuffle_weight(weight, (16, 16)), requires_grad=False
+            )
+        else:
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+
         # required by torch.compile to be torch.nn.Parameter
         layer.weight_scale = Parameter(weight_scale, requires_grad=False)
 
```
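
This scheme change gates an AITER-specific weight layout behind the `SGLANG_USE_AITER` environment variable on ROCm. A rough standalone sketch of the same gating is below; the env-var parsing is approximated rather than taken from `sglang.srt.utils`, and the tensor is dummy data.

```python
# Hedged sketch: choose between the AITER pre-shuffled layout and the default
# transposed layout, mirroring the logic added above.
import os

import torch


def use_aiter_layout() -> bool:
    # Stand-in for get_bool_env_var("SGLANG_USE_AITER") and is_hip() from the hunk;
    # ROCm builds of PyTorch report a non-None torch.version.hip.
    flag = os.getenv("SGLANG_USE_AITER", "false").lower() in ("1", "true", "yes")
    return flag and torch.version.hip is not None


weight = torch.randn(128, 256)
if use_aiter_layout():
    # On ROCm with AITER enabled, the scheme stores a pre-shuffled weight,
    # matching shuffle_weight(weight, (16, 16)) in the diff.
    from aiter.ops.shuffle import shuffle_weight  # only present on ROCm installs

    weight = shuffle_weight(weight, (16, 16))
else:
    # Default path keeps the transposed weight, as in 0.5.1.post2.
    weight = weight.t()
```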
```diff
--- a/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
+++ b/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
@@ -1,26 +1,22 @@
 import logging
 import os
 from contextlib import contextmanager
-from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import
+from typing import Dict, List, Tuple
 
-
+import torch
+from tqdm import tqdm
 
 from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
-    DEEPGEMM_BLACKWELL,
     ENABLE_JIT_DEEPGEMM,
 )
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, get_int_env_var
+from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var
 
 logger = logging.getLogger(__name__)
 
-if ENABLE_JIT_DEEPGEMM
-
-    from deep_gemm.jit import build
-    from deep_gemm.jit_kernels.gemm import get_best_configs
-    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm
 
 
 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
@@ -40,19 +36,7 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
 # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
 # NVRTC may have performance loss with some cases.
 # And NVCC JIT speed is also 9x faster in the ref commit
-
-if ENABLE_JIT_DEEPGEMM:
-    try:
-        from deep_gemm.jit.compiler import get_nvcc_compiler
-
-        get_nvcc_compiler()
-    except:
-        logger.warning(
-            "NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
-            "and may have performance loss with some cases."
-        )
-        _USE_NVRTC_DEFAULT = "1"
-    os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
 
 
 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
@@ -75,7 +59,7 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     # Default each rank will try compile all Ms to
     # load all symbols at the launch stages.
     # Avoid loading symbols at the serving stages.
-    _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE
+    _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE
 
 
 class DeepGemmKernelType(IntEnum):
@@ -84,185 +68,15 @@ class DeepGemmKernelType(IntEnum):
     GEMM_NT_F8F8BF16 = auto()
 
 
-@dataclass
-class DeepGemmKernelHelper:
-    name: str
-    compile_func: Callable[
-        [
-            int,
-            int,
-            int,
-            Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-        ],
-        None,
-    ]
-    configure_func: Callable[
-        [int, int, int, int, int],
-        Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-    ]
-
-
 _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
 
 
-# TODO improve
-def _compile_warning_1():
-    if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
-        logger.warning(
-            "Entering DeepGEMM JIT Pre-Compile session. "
-            "It may takes a long time (typically 10-20 mins) "
-            "if you have not run `sglang.compile_deep_gemm`. "
-            "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
-            " for pre-compilation to reduce the overhead if you have not run it before. "
-            "For example: "
-            "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
-        )
-
-
-# TODO improve naming
-def _compile_warning_2():
-    logger.warning(
-        "Entering DeepGEMM JIT Single Kernel Compile session. "
-        "And it will makes inference throughput becomes flaky. "
-        "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
-        " for pre-compilation to solve this issue. "
-        "For example: "
-        "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
-    )
-
-
-def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
-    n: int,
-    k: int,
-    num_groups: int,
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-
-    kwargs = {
-        "GEMM_TYPE": GemmType.GroupedMasked,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": num_groups,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
-def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
-    n: int,
-    k: int,
-    num_groups: int,
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-    kwargs = {
-        "GEMM_TYPE": GemmType.GroupedContiguous,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": 1,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
-def _compile_gemm_nt_f8f8bf16_one(
-    n: int,
-    k: int,
-    _: int,  # _ is a dummy parameter to align with other interfaces
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-    kwargs = {
-        "GEMM_TYPE": GemmType.Normal,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": 1,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
-# TODO further refactor warmup-related
-_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
-    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
-        name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
-        compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
-        configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs(
-            m, n, k, num_groups, num_sms, is_grouped_masked=True
-        ),
-    ),
-    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
-        name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
-        compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
-        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
-            m, n, k, 1, num_sms, is_grouped_contiguous=True
-        ),
-    ),
-    DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
-        name="gemm_fp8_fp8_bf16_nt",
-        compile_func=_compile_gemm_nt_f8f8bf16_one,
-        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
-            m, n, k, 1, num_sms
-        ),
-    ),
-}
-
-
+# TODO improve code
 def _maybe_compile_deep_gemm_one_type_all(
     kernel_type: DeepGemmKernelType,
     n: int,
     k: int,
     num_groups: int,
-    m_list: Optional[List[int]] = None,
 ) -> None:
     global _INITIALIZATION_DICT
     global _BUILTIN_M_LIST
@@ -275,61 +89,153 @@ def _maybe_compile_deep_gemm_one_type_all(
     ):
         _INITIALIZATION_DICT[query_key] = True
 
-
-
+        # TODO maybe improve logs
+        if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
+            logger.warning(
+                "Entering DeepGEMM JIT Pre-Compile session. "
+                "It may take a long time (typically 10-20 mins) "
+                "if you have not run `sglang.compile_deep_gemm`. "
+                "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
+                " for pre-compilation to reduce the overhead if you have not run it before. "
+                "For example: "
+                "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
+            )
+
         logger.info(
             f"Try DeepGEMM JIT Compiling for "
-            f"<{
+            f"<{kernel_type.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
             f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
         )
 
-
-
-
-
-
-
-                kernel_helper.configure_func(m, n, k, num_groups, num_sms)
-            )
-            compile_func = lambda config: kernel_helper.compile_func(
-                n, k, num_groups, config
+        _compile_deep_gemm_one_type_all(
+            kernel_type=kernel_type,
+            n=n,
+            k=k,
+            num_groups=num_groups,
+            m_list=_BUILTIN_M_LIST,
         )
-        thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
 
 
-
-def
-
-
-
+# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
+def _compile_deep_gemm_one_type_all(
+    kernel_type: DeepGemmKernelType,
+    n: int,
+    k: int,
+    num_groups: int,
+    m_list: List[int],
+) -> None:
+    if kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG:
+        m_alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()
+        m_list = sorted(list(set(m for m in m_list if m % m_alignment == 0)))
+
+    executor = _BaseWarmupExecutor.create(
+        kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
+    )
 
-
+    old_compile_mode = deep_gemm.get_compile_mode()
+    deep_gemm.set_compile_mode(1)
+    # TODO can use multi thread
+    for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
+        executor.execute(m=m)
+    deep_gemm.set_compile_mode(old_compile_mode)
+
+    # clean up input buffers
+    torch.cuda.current_stream().synchronize()
+    del executor
+    torch.cuda.empty_cache()
+
+
+class _BaseWarmupExecutor:
+    @staticmethod
+    def create(kernel_type: DeepGemmKernelType, **kwargs):
+        return {
+            DeepGemmKernelType.GEMM_NT_F8F8BF16: _NormalWarmupExecutor,
+            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: _GroupedContWarmupExecutor,
+            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: _GroupedMaskedWarmupExecutor,
+        }[kernel_type](**kwargs)
+
+    def execute(self, m):
+        raise NotImplementedError
+
+
+def _empty_token_fp8(size):
+    *dims, k = size
+    return (
+        torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
+        torch.empty(
+            (*dims, ceil_div(k, _BLOCK_SIZE)), device="cuda", dtype=torch.float32
+        ),
+    )
 
-    origin_func = RuntimeCache.get
 
-
-
-
-
-
-
-
-
-
-
+def _empty_block_fp8(size):
+    *dims, n, k = size
+    return (
+        torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
+        torch.empty(
+            (*dims, ceil_div(n, _BLOCK_SIZE), ceil_div(k, _BLOCK_SIZE)),
+            device="cuda",
+            dtype=torch.float32,
+        ),
+    )
 
-
-
-
+
+_BLOCK_SIZE = 128
+
+
+class _NormalWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((n, k))
+        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
+
+    def execute(self, m):
+        deep_gemm.fp8_gemm_nt(
+            (self.lhs_q[:m], self.lhs_s[:m]),
+            (self.rhs_q, self.rhs_s),
+            self.out[:m],
+        )
+
+
+class _GroupedContWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
+        self.m_indices = torch.zeros((max_m,), device="cuda", dtype=torch.int32)
+        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
+
+    def execute(self, m):
+        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
+            (self.lhs_q[:m], self.lhs_s[:m]),
+            (self.rhs_q, self.rhs_s),
+            self.out[:m],
+            m_indices=self.m_indices[:m],
+        )
+
+
+class _GroupedMaskedWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((num_groups, max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
+        self.masked_m = torch.zeros((num_groups,), device="cuda", dtype=torch.int32)
+        self.out = torch.empty(
+            (num_groups, max_m, n), device="cuda", dtype=torch.bfloat16
+        )
+
+    def execute(self, m):
+        deep_gemm.fp8_m_grouped_gemm_nt_masked(
+            (self.lhs_q, self.lhs_s),
+            (self.rhs_q, self.rhs_s),
+            self.out,
+            masked_m=self.masked_m,
+            # DeepGEMM uses `expect_m` instead of input shape for `get_best_config`
+            expected_m=m,
+        )
 
 
 @contextmanager
 def deep_gemm_execution_hook(
     m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
 ):
-
-
-        _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        yield
+    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+    yield
```
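
The compile_utils.py rewrite above drops the hand-rolled NVCC/NVRTC kernel codegen helpers and instead warms up DeepGEMM by invoking the real `deep_gemm` ops for every M in a list, so the library's own JIT caches the kernels. The structure is a small executor factory keyed by kernel type with preallocated input buffers. The sketch below reproduces only that shape with plain torch matmuls so it runs without `deep_gemm` or a GPU; every name in it is invented for illustration.

```python
# Hedged sketch of the warmup-executor pattern from the diff, using toy matmuls.
from enum import IntEnum, auto

import torch


class KernelType(IntEnum):
    NORMAL = auto()
    GROUPED = auto()


class _BaseWarmup:
    @staticmethod
    def create(kernel_type: KernelType, **kwargs):
        # Dispatch to a concrete executor, mirroring _BaseWarmupExecutor.create above.
        return {KernelType.NORMAL: _NormalWarmup, KernelType.GROUPED: _GroupedWarmup}[
            kernel_type
        ](**kwargs)

    def execute(self, m: int):
        raise NotImplementedError


class _NormalWarmup(_BaseWarmup):
    def __init__(self, max_m: int, n: int, k: int, num_groups: int = 1):
        # Buffers are allocated once at max_m and sliced per M, as in the diff.
        self.lhs = torch.empty(max_m, k)
        self.rhs = torch.empty(n, k)

    def execute(self, m: int):
        _ = self.lhs[:m] @ self.rhs.t()


class _GroupedWarmup(_BaseWarmup):
    def __init__(self, max_m: int, n: int, k: int, num_groups: int = 1):
        self.lhs = torch.empty(num_groups, max_m, k)
        self.rhs = torch.empty(num_groups, n, k)

    def execute(self, m: int):
        _ = torch.bmm(self.lhs[:, :m], self.rhs.transpose(1, 2))


def warmup(kernel_type: KernelType, m_list, **kwargs):
    # Running the op once per M is what populates a JIT cache in the real code.
    executor = _BaseWarmup.create(kernel_type, **kwargs)
    for m in m_list:
        executor.execute(m)


warmup(KernelType.NORMAL, m_list=[1, 8, 64], max_m=64, n=16, k=32)
```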
```diff
--- a/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
+++ b/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
@@ -11,9 +11,6 @@ def _compute_enable_deep_gemm():
     sm_version = get_device_sm()
     if sm_version < 90:
         return False
-    # TODO fix deepgemm cu129 fp8 issue
-    if torch.version.cuda == "12.9":
-        return False
 
     try:
         import deep_gemm
@@ -24,14 +21,12 @@ def _compute_enable_deep_gemm():
     return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
 
 
-
+def _is_blackwell_arch() -> bool:
+    major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
+    return major == 10
 
-try:
-    from deep_gemm import fp8_gemm_nt
 
-
-    DEEPGEMM_BLACKWELL = True
-except ImportError:
-    DEEPGEMM_BLACKWELL = False
+ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
 
+DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch()
 DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
```