sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,339 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import Callable, Optional
+
+import torch
+from compressed_tensors.quantization import ActivationOrdering
+
+# yapf conflicts with isort for this block
+# yapf: disable
+from sglang.srt.layers.parameter import (
+    BasevLLMParameter,
+    ChannelQuantScaleParameter,
+    GroupQuantScaleParameter,
+    PackedColumnParameter,
+    PackedvLLMParameter,
+    RowvLLMParameter,
+    permute_param_layout_,
+)
+from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme,
+)
+from sglang.srt.layers.quantization.marlin_utils import (
+    MarlinLinearLayerConfig,
+    apply_gptq_marlin_linear,
+    check_marlin_supports_shape,
+    marlin_is_k_full,
+    marlin_make_empty_g_idx,
+    marlin_make_workspace,
+    marlin_permute_scales,
+    marlin_repeat_scales_on_all_ranks,
+    marlin_sort_g_idx,
+    marlin_zero_points,
+)
+from sglang.srt.layers.quantization.utils import (
+    get_scalar_types,
+    replace_parameter,
+    unpack_cols,
+)
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import gptq_marlin_repack
+
+
+ScalarType, scalar_types = get_scalar_types()
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["CompressedTensorsWNA16"]
+WNA16_SUPPORTED_TYPES_MAP = {
+    4: scalar_types.uint4b8,
+    8: scalar_types.uint8b128
+}
+WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
+WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
+
+
+class CompressedTensorsWNA16(CompressedTensorsScheme):
+    _kernel_backends_being_used: set[str] = set()
+
+    def __init__(self,
+                 strategy: str,
+                 num_bits: int,
+                 group_size: Optional[int] = None,
+                 symmetric: Optional[bool] = True,
+                 actorder: Optional[ActivationOrdering] = None):
+
+        self.pack_factor = 32 // num_bits
+        self.strategy = strategy
+        self.symmetric = symmetric
+        self.group_size = -1 if group_size is None else group_size
+        self.has_g_idx = actorder == ActivationOrdering.GROUP
+
+        if self.group_size == -1 and self.strategy != "channel":
+            raise ValueError("Marlin kernels require group quantization or "
+                             "channelwise quantization, but found no group "
+                             "size and strategy is not channelwise.")
+
+        if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
+            raise ValueError(
+                f"Unsupported num_bits = {num_bits}. "
+                f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
+
+        self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
+                           if not self.symmetric else
+                           WNA16_SUPPORTED_TYPES_MAP[num_bits])
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # ampere and up
+        return 80
+
+    def create_weights(self, layer: torch.nn.Module, output_size: int,
+                       input_size: int, output_partition_sizes: list[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        output_size_per_partition = sum(output_partition_sizes)
+
+        self.kernel_config = MarlinLinearLayerConfig(
+            full_weight_shape=(input_size, output_size),
+            partition_weight_shape=(
+                input_size_per_partition,
+                output_size_per_partition,
+            ),
+            weight_type=self.quant_type,
+            act_type=params_dtype,
+            group_size=self.group_size,
+            zero_points=not self.symmetric,
+            has_g_idx=self.has_g_idx
+        )
+
+        # If group_size is -1, we are in channelwise case.
+        group_size = self.group_size if self.group_size != -1 else input_size
+        row_parallel = (input_size != input_size_per_partition)
+        partition_scales = not marlin_repeat_scales_on_all_ranks(
+            self.has_g_idx, self.group_size, row_parallel)
+
+        scales_and_zp_size = input_size // group_size
+
+        if partition_scales:
+            assert input_size_per_partition % group_size == 0
+            scales_and_zp_size = input_size_per_partition // group_size
+
+        weight = PackedvLLMParameter(input_dim=1,
+                                     output_dim=0,
+                                     weight_loader=weight_loader,
+                                     packed_factor=self.pack_factor,
+                                     packed_dim=1,
+                                     data=torch.empty(
+                                         output_size_per_partition,
+                                         input_size_per_partition //
+                                         self.pack_factor,
+                                         dtype=torch.int32,
+                                     ))
+
+        weight_scale_args = {
+            "weight_loader":
+            weight_loader,
+            "data":
+            torch.empty(
+                output_size_per_partition,
+                scales_and_zp_size,
+                dtype=params_dtype,
+            )
+        }
+
+        zeros_args = {
+            "weight_loader":
+            weight_loader,
+            "data":
+            torch.zeros(
+                output_size_per_partition // self.pack_factor,
+                scales_and_zp_size,
+                dtype=torch.int32,
+            )
+        }
+
+        if not partition_scales:
+            weight_scale = ChannelQuantScaleParameter(output_dim=0,
+                                                      **weight_scale_args)
+
+            if not self.symmetric:
+                qzeros = PackedColumnParameter(output_dim=0,
+                                               packed_dim=0,
+                                               packed_factor=self.pack_factor,
+                                               **zeros_args)
+        else:
+            weight_scale = GroupQuantScaleParameter(output_dim=0,
+                                                    input_dim=1,
+                                                    **weight_scale_args)
+            if not self.symmetric:
+                qzeros = PackedvLLMParameter(input_dim=1,
+                                             output_dim=0,
+                                             packed_dim=0,
+                                             packed_factor=self.pack_factor,
+                                             **zeros_args)
+
+        # A 2D array defining the original shape of the weights
+        # before packing
+        weight_shape = BasevLLMParameter(data=torch.empty(2,
+                                                          dtype=torch.int64),
+                                         weight_loader=weight_loader)
+
+        layer.register_parameter("weight_packed", weight)
+        layer.register_parameter("weight_scale", weight_scale)
+        layer.register_parameter("weight_shape", weight_shape)
+
+        if not self.symmetric:
+            layer.register_parameter("weight_zero_point", qzeros)
+
+        # group index (for activation reordering)
+        if self.has_g_idx:
+            weight_g_idx = RowvLLMParameter(data=torch.empty(
+                input_size_per_partition,
+                dtype=torch.int32,
+            ),
+                                            input_dim=0,
+                                            weight_loader=weight_loader)
+            layer.register_parameter("weight_g_idx", weight_g_idx)
+
+    # Checkpoints are serialized in compressed-tensors format, which is
+    # different from the format the kernel may want. Handle repacking here.
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Default names since marlin requires empty parameters for these,
+        # TODO: remove this requirement from marlin (allow optional tensors)
+        self.w_q_name = "weight_packed"
+        self.w_s_name = "weight_scale"
+        self.w_zp_name = "weight_zero_point"
+        self.w_gidx_name = "weight_g_idx"
+
+        device = getattr(layer, self.w_q_name).device
+        c = self.kernel_config
+
+        check_marlin_supports_shape(
+            c.partition_weight_shape[1],  # out_features
+            c.partition_weight_shape[0],  # in_features
+            c.full_weight_shape[0],  # in_features
+            c.group_size,
+        )
+
+        row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
+        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
+
+        # Allocate marlin workspace.
+        self.workspace = marlin_make_workspace(device)
+
+        def _transform_param(
+            layer: torch.nn.Module, name: Optional[str], fn: Callable
+        ) -> None:
+            if name is not None and getattr(layer, name, None) is not None:
+
+                old_param = getattr(layer, name)
+                new_param = fn(old_param)
+                # replace the parameter with torch.nn.Parameter for TorchDynamo
+                # compatibility
+                replace_parameter(
+                    layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
+                )
+
+        def transform_w_q(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+            x.data = gptq_marlin_repack(
+                x.data.contiguous(),
+                perm=layer.g_idx_sort_indices,
+                size_k=c.partition_weight_shape[0],
+                size_n=c.partition_weight_shape[1],
+                num_bits=c.weight_type.size_bits,
+            )
+            return x
+
+        def transform_w_s(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1)
+            x.data = marlin_permute_scales(
+                x.data.contiguous(),
+                size_k=c.partition_weight_shape[0],
+                size_n=c.partition_weight_shape[1],
+                group_size=c.group_size,
+            )
+            return x
+
+        if c.has_g_idx:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
+                getattr(layer, self.w_gidx_name)
+            )
+            _transform_param(layer, self.w_gidx_name, lambda _: g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+        else:
+            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+
+        if c.zero_points:
+            grouped_k = (
+                c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
+            )
+            _transform_param(
+                layer,
+                self.w_zp_name,
+                lambda x: marlin_zero_points(
+                    unpack_cols(
+                        x.t(),
+                        c.weight_type.size_bits,
+                        grouped_k,
+                        c.partition_weight_shape[1],
+                    ),
+                    size_k=grouped_k,
+                    size_n=c.partition_weight_shape[1],
+                    num_bits=c.weight_type.size_bits,
+                ),
+            )
+        else:
+            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
+        _transform_param(layer, self.w_q_name, transform_w_q)
+        _transform_param(layer, self.w_s_name, transform_w_s)
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        c = self.kernel_config
+
+        def _get_weight_params(
+            layer: torch.nn.Module,
+        ) -> tuple[
+            torch.Tensor,  # w_q
+            torch.Tensor,  # w_s
+            Optional[torch.Tensor],  # w_zp,
+            Optional[torch.Tensor],  # w_gidx
+        ]:
+            return (
+                getattr(layer, self.w_q_name),
+                getattr(layer, self.w_s_name),
+                getattr(layer, self.w_zp_name or "", None),
+                getattr(layer, self.w_gidx_name or "", None),
+            )
+
+        w_q, w_s, w_zp, w_gidx = _get_weight_params(layer)
+
+        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
+        # None for marlin
+        return apply_gptq_marlin_linear(
+            input=x,
+            weight=w_q,
+            weight_scale=w_s,
+            weight_zp=w_zp,  # type: ignore
+            g_idx=w_gidx,  # type: ignore
+            g_idx_sort_indices=layer.g_idx_sort_indices,
+            workspace=self.workspace,
+            wtype=c.weight_type,
+            input_size_per_partition=c.partition_weight_shape[0],
+            output_size_per_partition=c.partition_weight_shape[1],
+            is_k_full=self.is_k_full,
+            bias=bias,
+        )
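As a rough aid to reading the new `CompressedTensorsWNA16` scheme above, the sketch below walks through the shape bookkeeping that `create_weights` performs for a 4-bit, group-128 layer. All sizes are invented for illustration; this is not sglang code.

```python
# Illustrative only: WNA16 packing/scale shape arithmetic with made-up sizes.
num_bits = 4
pack_factor = 32 // num_bits          # 8 int4 values packed per int32 word
input_size_per_partition = 4096       # K on this TP rank (hypothetical)
output_size_per_partition = 11008     # N on this TP rank (hypothetical)
group_size = 128

# "weight_packed": int32 tensor of shape [N, K // pack_factor]
packed_shape = (output_size_per_partition, input_size_per_partition // pack_factor)

# "weight_scale": one scale per K-group -> [N, K // group_size]
scales_shape = (output_size_per_partition, input_size_per_partition // group_size)

print(packed_shape, scales_shape)  # (11008, 512) (11008, 32)
```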
@@ -31,8 +31,8 @@ except ImportError:
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
-from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
 from sglang.srt.layers.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,

@@ -1006,8 +1006,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
+
+        from sglang.srt.layers import deep_gemm_wrapper
+        from sglang.srt.layers.moe.utils import (
+            get_moe_a2a_backend,
+            get_moe_runner_backend,
+        )
+
         self.moe_runner_config = moe_runner_config
-
+        moe_runner_backend = get_moe_runner_backend()
+
+        if moe_runner_backend.is_auto():
+            if (
+                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                and get_moe_a2a_backend().is_deepep()
+            ):
+                moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
+            else:
+                moe_runner_backend = MoeRunnerBackend.TRITON
+        if moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton():
+            self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
+        else:
+            # TODO(cwan): refactor other backends
+            pass
 
     def apply(
         self,

@@ -1087,22 +1108,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )
         return StandardCombineInput(hidden_states=output)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.runner.runner_backend.is_deep_gemm():
+
+            w13_weight = layer.w13_weight
+            w2_weight = layer.w2_weight
+
+            if self.block_quant:
+                block_shape = self.quant_config.weight_block_size
+                w13_scale = layer.w13_weight_scale_inv
+                w2_scale = layer.w2_weight_scale_inv
+            else:
+                # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+                scale_block_size = 128
+                block_shape = [scale_block_size, scale_block_size]
+                w13_scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1
+                w13_scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1
+                w13_scale = (
+                    layer.w13_weight_scale.unsqueeze(1)
+                    .repeat_interleave(w13_scale_n, dim=1)
+                    .unsqueeze(2)
+                    .repeat_interleave(w13_scale_k, dim=2)
+                )
+                w2_scale_n = (w2_weight.shape[1] - 1) // scale_block_size + 1
+                w2_scale_k = (w2_weight.shape[2] - 1) // scale_block_size + 1
+                w2_scale = (
+                    layer.w2_weight_scale.unsqueeze(1)
+                    .repeat_interleave(w2_scale_n, dim=1)
+                    .unsqueeze(2)
+                    .repeat_interleave(w2_scale_k, dim=2)
+                )
+            quant_info = DeepGemmMoeQuantInfo(
+                w13_weight=w13_weight,
+                w2_weight=w2_weight,
+                use_fp8=True,
+                w13_scale=w13_scale,
+                w2_scale=w2_scale,
+                block_shape=block_shape,
+            )
+        elif self.runner.runner_backend.is_triton():
+            quant_info = TritonMoeQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                use_fp8_w8a8=True,
+                w13_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a13_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+            )
+        else:
+            raise NotImplementedError(
+                "Unsupported runner backend: %s" % self.runner.runner_backend
+            )
+
 
         return self.runner.run(dispatch_output, quant_info)
 
     def apply_with_router_logits(
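The per-tensor-to-per-block scale expansion added to the DeepGEMM path above can be hard to parse from the diff alone; the snippet below reproduces the same `repeat_interleave` pattern on dummy tensors. Shapes are invented for illustration and this is not sglang code.

```python
import torch

# Illustrative only: expand one scale per expert into a per-(128x128)-block scale grid.
num_experts, n, k = 2, 256, 384
scale_block = 128
per_tensor_scale = torch.rand(num_experts)       # one scale per expert

blocks_n = (n - 1) // scale_block + 1            # ceil(n / 128) = 2
blocks_k = (k - 1) // scale_block + 1            # ceil(k / 128) = 3
per_block_scale = (
    per_tensor_scale.unsqueeze(1)
    .repeat_interleave(blocks_n, dim=1)
    .unsqueeze(2)
    .repeat_interleave(blocks_k, dim=2)
)
print(per_block_scale.shape)  # torch.Size([2, 2, 3])
```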
@@ -23,7 +23,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.utils import (
     align,
     direct_register_custom_op,

@@ -43,11 +43,17 @@ _is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
-    from sgl_kernel import
-
-
-
-
+    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
+
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+
+        enable_sgl_per_token_group_quant_8bit = True
+    except ImportError:
+        from sgl_kernel import sgl_per_token_group_quant_fp8
+
+        enable_sgl_per_token_group_quant_8bit = False
 
 if _is_hip:
     if _use_aiter:

@@ -61,7 +67,7 @@ if _is_hip:
             raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
     else:
         try:
-            import vllm._C
+            import vllm._C  # noqa: F401
         except ImportError:
             raise ImportError("vllm is required when SGLANG_USE_AITER is set to False")
 

@@ -477,6 +483,7 @@ def sglang_per_token_group_quant_fp8(
     scale_ue8m0: bool = False,
     fuse_silu_and_mul: bool = False,
     masked_m: Optional[torch.Tensor] = None,
+    enable_v2: Optional[bool] = None,
 ):
     assert (
         x.shape[-1] % group_size == 0

@@ -496,9 +503,26 @@ def sglang_per_token_group_quant_fp8(
     )
 
     if x.shape[0] > 0:
-
-
-
+        # Temporary
+        if enable_sgl_per_token_group_quant_8bit:
+            sgl_per_token_group_quant_8bit(
+                x,
+                x_q,
+                x_s,
+                group_size,
+                eps,
+                fp8_min,
+                fp8_max,
+                scale_ue8m0,
+                fuse_silu_and_mul,
+                masked_m,
+                enable_v2=enable_v2,
+            )
+        else:
+            assert not enable_v2
+            sgl_per_token_group_quant_fp8(
+                x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+            )
 
     return x_q, x_s
 

@@ -514,6 +538,7 @@ def sglang_per_token_group_quant_8bit(
     scale_ue8m0: bool = False,
     fuse_silu_and_mul: bool = False,
     masked_m: Optional[torch.Tensor] = None,
+    enable_v2: Optional[bool] = None,
 ):
     from sglang.srt.layers.quantization.int8_kernel import (
         sglang_per_token_group_quant_int8,

@@ -529,6 +554,7 @@ def sglang_per_token_group_quant_8bit(
             group_size=group_size,
             eps=eps,
             dtype=dst_dtype,
+            enable_v2=enable_v2,
         )
 
     return sglang_per_token_group_quant_fp8(

@@ -540,6 +566,7 @@ def sglang_per_token_group_quant_8bit(
         scale_ue8m0=scale_ue8m0,
         fuse_silu_and_mul=fuse_silu_and_mul,
         masked_m=masked_m,
+        enable_v2=enable_v2,
     )
 
 

@@ -1804,3 +1831,21 @@ def triton_scaled_mm(
     )
 
     return result.to(out_dtype)
+
+
+if _is_cuda:
+    if enable_sgl_per_token_group_quant_8bit:
+
+        @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_8bit")
+        def _(
+            input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        ):
+            return
+
+    else:
+
+        @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_fp8")
+        def _(
+            input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        ):
+            return
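The `torch.library.register_fake` registrations added above give the compiler a shape-only view of the opaque quant kernels so tracing does not need to run them. The standalone sketch below shows the same pattern on a hypothetical `demo::scale_rows` op; it is illustrative only and does not touch `sgl_kernel`.

```python
import torch

# Hypothetical op, defined only to demonstrate the register_fake pattern (PyTorch >= 2.4).
@torch.library.custom_op("demo::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

@scale_rows.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Shape/dtype-only "fake" implementation used while tracing under torch.compile.
    return torch.empty_like(x)

compiled = torch.compile(lambda t: scale_rows(t, 2.0))
print(compiled(torch.ones(2, 3)))
```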
@@ -2,11 +2,10 @@ from typing import Callable, List, Optional, Tuple
 
 import torch
 
-from sglang.srt import
-from sglang.srt.layers.quantization import deep_gemm_wrapper
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.utils import is_sm100_supported
+from sglang.srt.utils import ceil_div, is_sm100_supported, offloader
 
 try:
     from vllm import _custom_ops as ops

@@ -29,7 +28,6 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.utils import (
     align,
-    ceil_div,
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,

@@ -443,23 +441,53 @@ def _requant_weight_ue8m0(
         torch.bfloat16,
     )
 
+    out_w, out_s = quant_weight_ue8m0(
+        weight_dequant=weight_dequant,
+        weight_block_size=weight_block_size,
+    )
+
+    out_s = _transform_scale_ue8m0(out_s, mn=out_w.shape[-2])
+
+    return out_w, out_s
+
+
+def quant_weight_ue8m0(
+    weight_dequant: torch.Tensor,
+    weight_block_size: List[int],
+):
+    assert weight_block_size == [128, 128]
+    assert (
+        weight_dequant.dtype == torch.bfloat16
+    ), f"{weight_dequant.dtype=} {weight_dequant.shape=}"
+
+    *batch_dims, n, k = weight_dequant.shape
+
     weight_dequant_flat = weight_dequant.view((-1, k))
     out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)
 
-    out_w = out_w_flat.view(
-    out_s = out_s_flat.view(
+    out_w = out_w_flat.view((*batch_dims, n, k))
+    out_s = out_s_flat.view(
+        (
+            *batch_dims,
+            ceil_div(n, weight_block_size[0]),
+            ceil_div(k, weight_block_size[1]),
+        )
+    )
+
+    return out_w, out_s
+
 
-
-
-    import deep_gemm.utils.layout
+def transform_scale_ue8m0_inplace(param, mn):
+    param.data = _transform_scale_ue8m0(param.data, mn=mn)
 
-    sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
-    sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
-    return sf
 
-
+# NOTE copy and modified from DeepGEMM
+def _transform_scale_ue8m0(sf, mn):
+    import deep_gemm.utils.layout
 
-
+    sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
+    sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
+    return sf
 
 
 # COPIED FROM DeepGEMM
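For reference, the scale-tensor shape that the new `quant_weight_ue8m0` helper above produces for a `[128, 128]` block size can be checked with a few lines. The dimensions below are invented for illustration.

```python
import torch

# Illustrative only: per-block scale shape for a bf16 weight with [128, 128] blocks.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

weight_block_size = [128, 128]
batch_dims, n, k = (8,), 512, 896            # e.g. 8 experts (hypothetical)
w = torch.randn(*batch_dims, n, k, dtype=torch.bfloat16)

scale_shape = (
    *batch_dims,
    ceil_div(n, weight_block_size[0]),
    ceil_div(k, weight_block_size[1]),
)
print(scale_shape)  # (8, 4, 7)
```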
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Optional
+from typing import Any, List, Optional
 
 import torch
 from torch.nn import Module

@@ -11,7 +11,6 @@ from torch.nn.parameter import Parameter
 from sglang.srt.layers.linear import LinearBase
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
-    FusedMoEMethodBase,
     LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,

@@ -28,7 +27,7 @@ from sglang.srt.layers.quantization.marlin_utils_fp8 import (
     prepare_fp8_layer_for_marlin,
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
-from sglang.srt.layers.quantization.utils import is_layer_skipped
+from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import get_bool_env_var, is_cuda
 
 _is_cuda = is_cuda()